ono_decode Module Function

module function ono_decode(mixed, basis_weights, num_inputs, num_basis) result(c)

Decode through an orthogonalised basis.

Forward: y = Q(B) @ x, with y of shape [n, batch], where Q = modified_gram_schmidt(B) and B [n x k] is reshaped from basis_weights.

left_operand → mixed x [k, batch]; right_operand → basis weights B [n*k, 1]; output → decoded [n, batch].

Arguments

Type Intent Optional Attributes Name
class(array_type), intent(in), target :: mixed

Mixed spectral tensor [k, batch]

class(array_type), intent(in), target :: basis_weights

Flattened basis matrix parameters [n*k, 1]

integer, intent(in) :: num_inputs

Output dimension n (number of rows of the basis matrix)

integer, intent(in) :: num_basis

Basis size k (number of basis vectors, i.e. columns of the basis matrix)

Return Value type(array_type), pointer

Decoded output tensor


Source Code

  module function ono_decode( &
       mixed, basis_weights, num_inputs, num_basis &
  ) result(c)
    !! Decode a mixed spectral tensor through an orthogonalised basis.
    !!
    !! Forward: y = Q(B) @ x   [n, batch]
    !!   B [n x k] is the column-major reshape of the first column of
    !!   basis_weights; Q is obtained from B by modified Gram-Schmidt.
    !!   A column whose residual norm does not exceed the tolerance is
    !!   set to zero instead of being normalised.
    !!
    !! left_operand  → mixed x          [k, batch]
    !! right_operand → basis weights B   [n*k, 1]
    !! output        → decoded           [n, batch]
    implicit none

    ! Arguments
    class(array_type), intent(in), target :: mixed
    !! Mixed spectral tensor [k, batch]
    class(array_type), intent(in), target :: basis_weights
    !! Flattened basis matrix parameters [n*k, 1]
    integer, intent(in) :: num_inputs, num_basis
    !! Output dimension (n) and basis size (k)
    type(array_type), pointer :: c
    !! Decoded output tensor [n, batch]

    ! Local variables
    real(real32), parameter :: norm_tol = 1.0e-12_real32
    !! Residual norm at or below which a basis column is zeroed
    integer :: num_samples, n, k, i, j
    !! Batch/dimension values and loop indices
    real(real32), allocatable :: Q(:,:)
    !! Basis matrix, orthonormalised in place
    real(real32) :: norm_val, proj
    !! Gram-Schmidt norm and projection scalars

    n = num_inputs
    k = num_basis
    num_samples = size(mixed%val, 2)

    ! Load the flattened basis straight into the working matrix.
    ! NOTE(review): assumes basis_weights%val has at least n*k rows and
    ! mixed%val has k rows; neither is checked here - confirm at call sites.
    allocate(Q(n, k))
    Q = reshape(basis_weights%val(:, 1), [n, k])

    ! Modified Gram-Schmidt, in place: each column is orthogonalised against
    ! the already-finished columns one at a time, with the projection
    ! coefficient computed from the partially updated column.
    do j = 1, k
       do i = 1, j - 1
          proj = dot_product(Q(:,i), Q(:,j))
          Q(:,j) = Q(:,j) - proj * Q(:,i)
       end do
       norm_val = sqrt(dot_product(Q(:,j), Q(:,j)))
       if(norm_val .gt. norm_tol)then
          Q(:,j) = Q(:,j) / norm_val
       else
          ! Degenerate direction: drop it rather than divide by ~0
          Q(:,j) = 0.0_real32
       end if
    end do

    ! Forward: y = Q @ x
    c => mixed%create_result(array_shape=[n, num_samples])
    c%val = matmul(Q, mixed%val)
    ! Q is an allocatable local and is released automatically on return.

    ! Store metadata: the (n, k) dimensions of the basis matrix
    allocate(c%indices(2))
    c%indices = [n, k]

    ! Partial-derivative callbacks are attached unconditionally; the links
    ! to the operands are only recorded when a gradient is required.
    c%get_partial_left     => get_partial_ono_decode_mixed
    c%get_partial_right    => get_partial_ono_decode_basis
    c%get_partial_left_val => get_partial_ono_decode_mixed_val
    c%get_partial_right_val => get_partial_ono_decode_basis_val
    if(mixed%requires_grad .or. basis_weights%requires_grad)then
       c%requires_grad    = .true.
       c%is_forward       = mixed%is_forward .or. basis_weights%is_forward
       c%operation        = 'ono_decode'
       c%left_operand     => mixed
       c%right_operand    => basis_weights
       ! Ownership flags mirror whether each operand is a temporary
       c%owns_left_operand  = mixed%is_temporary
       c%owns_right_operand = basis_weights%is_temporary
    end if

  end function ono_decode