ono_encode Module Function

module function ono_encode(input, basis_weights, num_inputs, num_basis) result(c)

Encode input through an orthogonalised basis.

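Forward: $y = Q(B)^\top u$, with $y$ of shape [k, batch], where $Q = \mathrm{MGS}(B)$ is the modified Gram–Schmidt orthonormalisation of $B \in \mathbb{R}^{n \times k}$, the matrix obtained by reshaping basis_weights.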

- left_operand → input u [n, batch]
- right_operand → basis weights B [n*k, 1]
- output → encoded [k, batch]

Arguments

Type | Intent | Optional Attributes | Name
class(array_type), intent(in), target :: input

Input tensor [n, batch]

class(array_type), intent(in), target :: basis_weights

Flattened basis matrix parameters [n*k, 1]

integer, intent(in) :: num_inputs

Input dimension n (number of rows of input)

integer, intent(in) :: num_basis

Basis size k (number of basis vectors, i.e. rows of the encoded output)

Return Value type(array_type), pointer

Encoded output tensor


Source Code

  module function ono_encode( &
       input, basis_weights, num_inputs, num_basis &
  ) result(c)
    !! Encode input through an orthogonalised basis.
    !!
    !! Forward: y = Q(B)^T @ u   [k, batch]
    !!   Q = modified_gram_schmidt(B), B [n x k] from basis_weights
    !!
    !! Any column whose residual norm after projection is at or below
    !! norm_tol is replaced by a zero column in Q (treated as dependent).
    !!
    !! left_operand  → input u         [n, batch]
    !! right_operand → basis weights B  [n*k, 1]
    !! output        → encoded          [k, batch]
    !!
    !! Execution stops with an error if input does not have num_inputs
    !! rows, or if basis_weights holds fewer than num_inputs*num_basis
    !! values in its first column.
    implicit none

    ! Arguments
    class(array_type), intent(in), target :: input
    !! Input tensor [n, batch]
    class(array_type), intent(in), target :: basis_weights
    !! Flattened basis matrix parameters [n*k, 1]
    integer, intent(in) :: num_inputs, num_basis
    !! Input dimension (n) and basis size (k)
    type(array_type), pointer :: c
    !! Encoded output tensor [k, batch]

    ! Local variables
    integer :: num_samples, n, k, i, j
    !! Batch/dimension values and loop indices
    real(real32), allocatable :: Q(:,:)
    !! Basis matrix, orthonormalised in place [n, k]
    real(real32) :: norm_val, proj
    !! Gram-Schmidt norm and projection scalars
    real(real32), parameter :: norm_tol = 1.0e-12_real32
    !! Residual norm at or below which a column is zeroed.
    !! NOTE(review): this is far below real32 epsilon (~1.2e-7), so a
    !! numerically dependent column will usually be normalised rather than
    !! zeroed. Kept unchanged because the backward routines presumably use
    !! the same threshold — verify before altering.

    n = num_inputs
    k = num_basis
    num_samples = size(input%val, 2)

    ! The matmul and reshape below are non-conforming (undefined results)
    ! if these shape assumptions fail, so check them explicitly.
    if(size(input%val, 1) .ne. n)then
       error stop 'ono_encode: input row count does not match num_inputs'
    end if
    if(size(basis_weights%val, 1) .lt. n * k)then
       error stop 'ono_encode: basis_weights has fewer than num_inputs*num_basis values'
    end if

    ! Modified Gram-Schmidt, performed in place on Q (initialised to B).
    ! Each projection uses the already-updated column j, which is what
    ! distinguishes MGS from classical Gram-Schmidt.
    Q = reshape(basis_weights%val(:, 1), [n, k])
    do j = 1, k
       do i = 1, j - 1
          proj = dot_product(Q(:,i), Q(:,j))
          Q(:,j) = Q(:,j) - proj * Q(:,i)
       end do
       norm_val = sqrt(dot_product(Q(:,j), Q(:,j)))
       if(norm_val .gt. norm_tol)then
          Q(:,j) = Q(:,j) / norm_val
       else
          Q(:,j) = 0.0_real32
       end if
    end do

    ! Forward: y = Q^T @ u
    c => input%create_result(array_shape=[k, num_samples])
    c%val = matmul(transpose(Q), input%val)

    ! Store metadata: the dimensions [n, k] are kept on the result,
    ! presumably for the backward routines to rebuild the basis — confirm.
    allocate(c%indices(2))
    c%indices = [n, k]

    ! Partial-derivative procedures are attached unconditionally; the graph
    ! links to the operands are only recorded when a gradient is required.
    c%get_partial_left     => get_partial_ono_encode_input
    c%get_partial_right    => get_partial_ono_encode_basis
    c%get_partial_left_val => get_partial_ono_encode_input_val
    c%get_partial_right_val => get_partial_ono_encode_basis_val
    if(input%requires_grad .or. basis_weights%requires_grad)then
       c%requires_grad    = .true.
       c%is_forward       = input%is_forward .or. basis_weights%is_forward
       c%operation        = 'ono_encode'
       c%left_operand     => input
       c%right_operand    => basis_weights
       ! Ownership flags mirror each operand's is_temporary status.
       c%owns_left_operand  = input%is_temporary
       c%owns_right_operand = basis_weights%is_temporary
    end if

  end function ono_encode