forward_ono Subroutine

public subroutine forward_ono(this, input)

Forward propagation for the Orthogonal Neural Operator layer

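Computes:

- $\mathrm{encoded} = \Phi^\top u$, shape $[k, \mathrm{batch}]$
- $\mathrm{mixed} = R\,\mathrm{encoded}$, shape $[k, \mathrm{batch}]$
- $\mathrm{decoded} = \Phi\,\mathrm{mixed}$, shape $[n_{in}, \mathrm{batch}]$
- $\mathrm{spectral} = W\,\mathrm{decoded}$, shape $[n_{out}, \mathrm{batch}]$ (reusing $W$ as the output projection)
- $\mathrm{bypass} = W u$, shape $[n_{out}, \mathrm{batch}]$

The output is $v = \sigma(\mathrm{spectral} + \mathrm{bypass} + b)$.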

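The spectral and bypass paths are kept separate. The spectral path uses the orthogonal basis followed by the $R$ mixing. The bypass path applies $W$ directly to the input. Both paths are projected to $n_{out}$ features. In this implementation they share the same matrix $W$, although separate projection matrices would also be possible.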

Here we implement:

- $\mathrm{spectral} = W\,\Phi\,R\,\Phi^\top u$
- $\mathrm{bypass} = W u$
- $v = \sigma(\mathrm{spectral} + \mathrm{bypass} + b)$

Note: $W$ is `params(3)`, with shape $[n_{out} \times n_{in}]$, and is shared by both paths. This means $v = \sigma\left(W(\Phi R \Phi^\top u + u) + b\right) = \sigma\left(W\left((\Phi R \Phi^\top + I)\,u\right) + b\right)$.

Type Bound

orthogonal_nop_block_type

Arguments

Type | Intent | Optional | Attributes | Name
class(orthogonal_nop_block_type), intent(inout) :: this

Layer instance to execute

class(array_type), intent(in), dimension(:,:) :: input

Input batch tensor collection


Source Code

  subroutine forward_ono(this, input)
    !! Forward pass of the Orthogonal Neural Operator (ONO) block.
    !!
    !! Parameter roles, as used below:
    !!   params(1) = R  (mixing matrix applied to the basis coefficients)
    !!   params(2) = B  (basis parameters handed to ono_encode/ono_decode,
    !!                   written Phi = Q(B) below; presumably an
    !!                   orthonormalisation of B -- confirm in ono_encode)
    !!   params(3) = W  (output projection, shared by both paths)
    !!   params(4) = b  (bias, only added when this%use_bias is set)
    !!
    !! The block evaluates
    !!   spectral = W @ Phi @ R @ Phi^T @ u      [n_out, batch]
    !!   bypass   = W @ u                        [n_out, batch]
    !!   v        = sigma( spectral + bypass [+ b] )
    !! with intermediate shapes
    !!   Phi^T @ u   -> [k, batch]
    !!   R @ (...)   -> [k, batch]
    !!   Phi @ (...) -> [n_in, batch]
    !!
    !! Because W is shared, this is mathematically
    !!   v = sigma( W @ ((Phi @ R @ Phi^T + I) @ u) [+ b] ),
    !! but the two W products are computed as separate matmul calls.
    !!
    !! Only input(1,1) is read. The result is stored in this%output(1,1).
    !! When an activation other than "none" is configured, the
    !! pre-activation is also stored in this%z(1).
    implicit none

    ! Arguments
    class(orthogonal_nop_block_type), intent(inout) :: this
    !! Layer instance to execute
    class(array_type), dimension(:,:), intent(in) :: input
    !! Input batch tensor collection

    ! Local variables
    type(array_type), pointer :: coeffs, mixed_coeffs, recon
    !! Basis coefficients, coefficients after R mixing, reconstruction
    type(array_type), pointer :: spec_out, skip_out, pre_act
    !! Projected spectral path, projected bypass path, running result


    ! Spectral path ------------------------------------------------------------
    ! ono_encode/ono_decode are autodiff-tracked, so gradients reach B.
    coeffs       => ono_encode(input(1,1), this%params(2), &
         this%num_inputs, this%num_basis)
    mixed_coeffs => matmul(this%params(1), coeffs)
    recon        => ono_decode(mixed_coeffs, this%params(2), &
         this%num_inputs, this%num_basis)
    spec_out     => matmul(this%params(3), recon)

    ! Bypass path: the same W applied directly to the input --------------------
    skip_out     => matmul(this%params(3), input(1,1))

    ! Sum both paths, then optionally add the bias
    pre_act => spec_out + skip_out
    if(this%use_bias) pre_act => pre_act + this%params(4)

    ! Activation: when one is configured, store the pre-activation in z(1),
    ! clear its is_temporary flag, and activate from z(1). Either way the
    ! final value of pre_act is moved into output(1,1).
    call this%output(1,1)%zero_grad()
    if(trim(this%activation%name) .ne. "none")then
       call this%z(1)%zero_grad()
       call this%z(1)%assign_and_deallocate_source(pre_act)
       this%z(1)%is_temporary = .false.
       pre_act => this%activation%apply(this%z(1))
    end if
    call this%output(1,1)%assign_and_deallocate_source(pre_act)
    this%output(1,1)%is_temporary = .false.

  end subroutine forward_ono