Decode through an orthogonalised basis.
Forward: y = Q(B) @ x, giving y of shape [n, batch], where Q = modified_gram_schmidt(B) and the basis matrix B [n x k] is unpacked from basis_weights.
Operands: left_operand → mixed x [k, batch]; right_operand → basis weights B [n*k, 1]; output → decoded [n, batch].
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| class(array_type) | intent(in) | | target | mixed | Mixed spectral tensor [k, batch] |
| class(array_type) | intent(in) | | target | basis_weights | Flattened basis matrix parameters [n*k, 1] |
| integer | intent(in) | | | num_inputs | Output dimension (n) |
| integer | intent(in) | | | num_basis | Basis size (k) |
Return value: decoded output tensor [n, batch] (`type(array_type), pointer :: c`).
module function ono_decode( &
     mixed, basis_weights, num_inputs, num_basis &
) result(c)
  !! Decode through an orthogonalised basis.
  !!
  !! Forward:  y = Q(B) @ x                 [n, batch]
  !!   Q = modified_gram_schmidt(B),        B [n x k] unpacked from basis_weights
  !!
  !!   left_operand  -> mixed x             [k, batch]
  !!   right_operand -> basis weights B     [n*k, 1]
  !!   output        -> decoded             [n, batch]
  !!
  !! Columns of B whose norm after orthogonalisation does not exceed
  !! norm_tol are replaced by zero columns instead of being normalised.
  !! Execution stops with an error if mixed does not have num_basis rows,
  !! or if the first column of basis_weights holds fewer than
  !! num_inputs*num_basis values.
  implicit none

  ! Arguments
  class(array_type), intent(in), target :: mixed
  !! Mixed spectral tensor [k, batch]
  class(array_type), intent(in), target :: basis_weights
  !! Flattened basis matrix parameters [n*k, 1]; the first n*k entries of
  !! column 1 are read in column-major order into B [n, k]
  integer, intent(in) :: num_inputs, num_basis
  !! Output dimension (n) and basis size (k)
  type(array_type), pointer :: c
  !! Decoded output tensor [n, batch]

  ! Local variables
  real(real32), parameter :: norm_tol = 1.0e-12_real32
  !! Absolute norm threshold below which an orthogonalised column is zeroed
  integer :: num_samples, n, k, i, j
  !! Batch/dimension values and loop indices
  real(real32), allocatable :: Q(:,:)
  !! Basis matrix B, orthonormalised in place into Q
  real(real32) :: norm_val, proj
  !! Gram-Schmidt norm and projection scalars

  n = num_inputs
  k = num_basis

  ! Validate operand shapes up front: a mismatch would otherwise make the
  ! reshape (no PAD) or the matmul below non-conforming, which is invalid.
  if(size(mixed%val, 1) .ne. k)then
     error stop "ono_decode: size(mixed%val, 1) must equal num_basis"
  end if
  if(size(basis_weights%val, 1) .lt. n * k)then
     error stop "ono_decode: basis_weights holds fewer than num_inputs*num_basis values"
  end if
  num_samples = size(mixed%val, 2)

  ! Modified Gram-Schmidt, in place: B -> Q (Q is auto-allocated to [n, k])
  Q = reshape(basis_weights%val(:, 1), [n, k])
  do j = 1, k
     ! Remove each earlier orthonormal direction from column j one at a
     ! time, always projecting the already-updated column (MGS variant).
     do i = 1, j - 1
        proj = dot_product(Q(:,i), Q(:,j))
        Q(:,j) = Q(:,j) - proj * Q(:,i)
     end do
     norm_val = sqrt(dot_product(Q(:,j), Q(:,j)))
     if(norm_val .gt. norm_tol)then
        Q(:,j) = Q(:,j) / norm_val
     else
        ! (Near-)dependent column: zero it rather than divide by ~0
        Q(:,j) = 0.0_real32
     end if
  end do

  ! Forward: y = Q @ x
  c => mixed%create_result(array_shape=[n, num_samples])
  c%val = matmul(Q, mixed%val)
  ! Q is an allocatable local and is released automatically on return.

  ! Store n and k on the result; presumably read back by the
  ! get_partial_ono_decode_* routines (defined elsewhere) -- verify.
  allocate(c%indices(2))
  c%indices = [n, k]
  c%get_partial_left => get_partial_ono_decode_mixed
  c%get_partial_right => get_partial_ono_decode_basis
  c%get_partial_left_val => get_partial_ono_decode_mixed_val
  c%get_partial_right_val => get_partial_ono_decode_basis_val

  ! Record the graph edge only when some operand needs gradients
  if(mixed%requires_grad .or. basis_weights%requires_grad)then
     c%requires_grad = .true.
     c%is_forward = mixed%is_forward .or. &
          basis_weights%is_forward
     c%operation = 'ono_decode'
     c%left_operand => mixed
     c%right_operand => basis_weights
     c%owns_left_operand = mixed%is_temporary
     c%owns_right_operand = basis_weights%is_temporary
  end if
end function ono_decode