Encode input through an orthogonalised basis.
Forward: y = Q(B)^T @ u, giving an output of shape [k, batch], where Q = modified_gram_schmidt(B) and B [n x k] is reshaped from basis_weights.
Operands: left_operand → input u [n, batch]; right_operand → basis weights B [n*k, 1]; output → encoded [k, batch].
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| class(array_type), | intent(in), | target | :: | input |
Input tensor [n, batch] |
|
| class(array_type), | intent(in), | target | :: | basis_weights |
Flattened basis matrix parameters [n*k, 1] |
|
| integer, | intent(in) | :: | num_inputs |
Input dimension n (number of rows of input) |
||
| integer, | intent(in) | :: | num_basis |
Basis size k (number of basis vectors; rows of the encoded output)
Encoded output tensor
module function ono_encode( &
     input, basis_weights, num_inputs, num_basis &
) result(c)
  !! Encode input through an orthogonalised basis.
  !!
  !! Forward: y = Q(B)^T @ u                      [k, batch]
  !!   Q = modified_gram_schmidt(B), B [n x k] reshaped from basis_weights
  !!
  !! left_operand  → input u         [n, batch]
  !! right_operand → basis weights B [n*k, 1]
  !! output        → encoded         [k, batch]
  !!
  !! Columns of B are orthonormalised in order (column 1 first). A column
  !! whose residual norm after projection is not greater than norm_tol is
  !! set to zero instead of being normalised, so it contributes a zero row
  !! to the output.
  !!
  !! NOTE(review): size(input%val, 1) is assumed to equal num_inputs and
  !! basis_weights%val(:, 1) to hold at least num_inputs*num_basis values;
  !! neither is checked here.
  implicit none

  ! Arguments
  class(array_type), intent(in), target :: input
  !! Input tensor [n, batch]
  class(array_type), intent(in), target :: basis_weights
  !! Flattened basis matrix parameters [n*k, 1]
  integer, intent(in) :: num_inputs, num_basis
  !! Input dimension (n) and basis size (k)
  type(array_type), pointer :: c
  !! Encoded output tensor

  ! Local variables
  real(real32), parameter :: norm_tol = 1.0e-12_real32
  !! Residual norm at or below which a basis column is treated as degenerate
  integer :: num_samples, n, k, i, j
  !! Batch/dimension values and loop indices
  real(real32), allocatable :: Q(:,:), QT(:,:)
  !! Orthonormal basis [n, k] and its transpose [k, n]
  real(real32) :: norm_val, proj
  !! Gram-Schmidt norm and projection scalars

  n = num_inputs
  k = num_basis
  num_samples = size(input%val, 2)

  ! Modified Gram-Schmidt, in place on the reshaped basis (B -> Q).
  ! Each earlier (already orthonormal) column is projected out of column j
  ! one at a time, always using the partially updated column j; this is
  ! what distinguishes MGS from classical Gram-Schmidt.
  allocate(Q(n, k), QT(k, n))
  Q = reshape(basis_weights%val(:, 1), [n, k])
  do j = 1, k
    do i = 1, j - 1
      proj = dot_product(Q(:, i), Q(:, j))
      Q(:, j) = Q(:, j) - proj * Q(:, i)
    end do
    norm_val = sqrt(dot_product(Q(:, j), Q(:, j)))
    if (norm_val > norm_tol) then
      Q(:, j) = Q(:, j) / norm_val
    else
      ! Degenerate column (numerically dependent on earlier ones, or zero).
      Q(:, j) = 0.0_real32
    end if
  end do

  ! Forward: y = Q^T @ u
  QT = transpose(Q)
  c => input%create_result(array_shape=[k, num_samples])
  c%val = matmul(QT, input%val)
  deallocate(Q, QT)

  ! Store metadata (dimensions) for the backward procedures
  allocate(c%indices(2))
  c%indices = [n, k]

  c%get_partial_left => get_partial_ono_encode_input
  c%get_partial_right => get_partial_ono_encode_basis
  c%get_partial_left_val => get_partial_ono_encode_input_val
  c%get_partial_right_val => get_partial_ono_encode_basis_val

  ! Autograd wiring only when at least one operand needs gradients
  if (input%requires_grad .or. basis_weights%requires_grad) then
    c%requires_grad = .true.
    c%is_forward = input%is_forward .or. basis_weights%is_forward
    c%operation = 'ono_encode'
    c%left_operand => input
    c%right_operand => basis_weights
    c%owns_left_operand = input%is_temporary
    c%owns_right_operand = basis_weights%is_temporary
  end if

end function ono_encode