get_bases_ono_attn Function

private function get_bases_ono_attn(this) result(phi)

Orthogonalise the basis matrix B using modified Gram-Schmidt

Type Bound

orthogonal_attention_layer_type

Arguments

Type | Intent | Optional Attributes | Name
class(orthogonal_attention_layer_type), intent(in) :: this

Layer instance providing basis parameters

Return Value type(array_type)

Orthogonalised basis matrix packed in an array_type


Source Code

  function get_bases_ono_attn(this) result(phi)
    !! Orthogonalise the basis matrix B using modified Gram-Schmidt
    !!
    !! The flat parameter vector this%params(4)%val(:,1) is read
    !! column-major as a [num_inputs, num_basis] matrix B. Each column is
    !! orthogonalised against the previously processed columns (using the
    !! progressively updated column, i.e. modified Gram-Schmidt) and then
    !! normalised to unit length.
    !!
    !! A column whose residual norm is negligible relative to its norm
    !! before projection (zero, or numerically linearly dependent on the
    !! earlier columns) is set to zero instead of being normalised, so
    !! round-off noise is never promoted to a spurious basis direction.
    !!
    !! The result is packed as an [num_inputs, num_basis, 1] array_type
    !! marked sample-independent, not requiring gradients, fixed-pointer
    !! and non-temporary.
    implicit none

    ! Arguments
    class(orthogonal_attention_layer_type), intent(in) :: this
    !! Layer instance providing basis parameters
    type(array_type) :: phi
    !! Orthogonalised basis matrix packed in an array_type

    ! Local variables
    integer :: n, k, i, j
    !! Basis dimensions and Gram-Schmidt loop indices
    real(real32), allocatable :: Q(:,:)
    !! Basis matrix, orthogonalised in place column by column
    real(real32) :: col_norm, norm_val, proj, rel_tol
    !! Pre-projection column norm, residual norm, projection coefficient
    !! and relative tolerance for detecting dependent columns

    n = this%num_inputs
    k = this%num_basis

    ! reshape() from a source shorter than n*k (without pad=) is not
    ! standard-conforming, so reject undersized parameter storage up front
    if(size(this%params(4)%val(:,1)) .lt. n * k)then
       error stop "get_bases_ono_attn: params(4) is smaller than num_inputs * num_basis"
    end if

    ! Relative rank-deficiency tolerance, following the common
    ! max(dimension) * machine-epsilon convention. An absolute threshold
    ! would sit far below real32 resolution and would not be scale-invariant.
    rel_tol = real(max(n, k), real32) * epsilon(1.0_real32)

    ! Reshape B from flat params(4) into [n, k] and orthogonalise in place
    allocate(Q(n, k))
    Q = reshape(this%params(4)%val(:,1), [n, k])

    ! Modified Gram-Schmidt orthogonalisation
    do j = 1, k
       ! norm2 avoids the overflow/underflow of sqrt(dot_product(v, v))
       col_norm = norm2(Q(:,j))
       ! Subtract projections onto the already-orthonormalised columns,
       ! using the updated column at each step (modified Gram-Schmidt)
       do i = 1, j - 1
          proj = dot_product(Q(:,i), Q(:,j))
          Q(:,j) = Q(:,j) - proj * Q(:,i)
       end do
       ! Normalise, or zero out columns that are zero or numerically
       ! dependent on earlier ones (residual is only round-off noise)
       norm_val = norm2(Q(:,j))
       if(norm_val .gt. rel_tol * col_norm)then
          Q(:,j) = Q(:,j) / norm_val
       else
          Q(:,j) = 0.0_real32
       end if
    end do

    ! Store in phi as a fixed array_type
    call phi%allocate([n, k, 1])
    phi%is_sample_dependent = .false.
    phi%requires_grad = .false.
    phi%fix_pointer = .true.
    phi%is_temporary = .false.
    phi%val(:,1) = reshape(Q, [n * k])

  end function get_bases_ono_attn