forward_ono_attn Subroutine

private subroutine forward_ono_attn(this, input)

Forward propagation for the Orthogonal Attention layer

Computes:

Q = W_Q @ u [k, batch]

K = W_K @ u [k, batch]

scores = tanh( (Q * K) / sqrt(k) ) [k, batch] bounded per-basis interaction scores

attn = softmax(scores, dim=1) [k, batch] normalised attention weights across basis modes

spectral = Q(B)^T @ u [k, batch] project input to orthogonal spectral basis

modulated = spectral + attn * spectral [k, batch] residual spectral modulation

decoded = Q(B) @ modulated [n_in, batch] decode modulated spectral representation

attn_out = W_V @ decoded [n_out, batch]

bypass = W @ u [n_out, batch]

v = sigma( attn_out + bypass + b )

Type Bound

orthogonal_attention_layer_type

Arguments

Type Intent Optional Attributes Name
class(orthogonal_attention_layer_type), intent(inout) :: this

Layer instance to execute

class(array_type), intent(in), dimension(:,:) :: input

Input batch tensor collection


Source Code

  subroutine forward_ono_attn(this, input)
    !! Forward propagation for the Orthogonal Attention layer
    !!
    !! Computes:
    !!   Q = W_Q @ u                                          [k, batch]
    !!   K = W_K @ u                                          [k, batch]
    !!
    !!   scores = tanh( (Q * K) / sqrt(k) )                   [k, batch]
    !!            bounded per-basis interaction scores
    !!
    !!   attn = softmax(scores, dim=1)                        [k, batch]
    !!          normalised attention weights across basis modes
    !!
    !!   spectral = Q(B)^T @ u                                [k, batch]
    !!              project input to orthogonal spectral basis
    !!
    !!   modulated = spectral + attn * spectral               [k, batch]
    !!               residual spectral modulation
    !!
    !!   decoded = Q(B) @ modulated                           [n_in, batch]
    !!             decode modulated spectral representation
    !!
    !!   attn_out = W_V @ decoded                             [n_out, batch]
    !!   bypass   = W @ u                                     [n_out, batch]
    !!
    !!   v = sigma( attn_out + bypass + b )
    !!
    !! Parameter slots read from this%params:
    !!   (1) W_Q, (2) W_K, (3) W_V, (4) basis B, (5) bypass W,
    !!   (6) bias b (only read when this%use_bias is true)
    !!
    !! Only input(1,1) is read. The result is stored in this%output(1,1).
    !! When the activation name is not "none", the pre-activation value is
    !! stored in this%z(1) before the activation is applied.
    implicit none

    ! Arguments
    class(orthogonal_attention_layer_type), intent(inout) :: this
    !! Layer instance to execute
    class(array_type), dimension(:,:), intent(in) :: input
    !! Input batch tensor collection

    ! Local variables
    type(array_type), pointer :: ptr, ptr_attn, ptr_bypass
    !! Combined output, attention-path output and bypass-path output
    type(array_type), pointer :: ptr_Q, ptr_K, ptr_coeff
    !! Query, key and per-basis attention coefficient tensors
    type(array_type), pointer :: ptr_spec, ptr_mod, ptr_decoded
    !! Spectral encoding, modulated spectrum and decoded tensors

    integer :: n, nb
    !! Input size and basis count
    real(real32) :: scale
    !! Precomputed scaling factor for attention scores


    n = this%num_inputs
    nb = this%num_basis


    !---------------------------------------------------------------------------
    ! Scaling (critical for stability)
    !---------------------------------------------------------------------------
    ! 1 / sqrt(key_dim), applied to the Q * K product before it is bounded
    scale = 1.0_real32 / sqrt(real(this%key_dim, kind=real32))


    !---------------------------------------------------------------------------
    ! Attention scores from Q and K projections
    !---------------------------------------------------------------------------
    ptr_Q => matmul(this%params(1), input(1,1))    ! W_Q @ u: [k, batch]
    ptr_K => matmul(this%params(2), input(1,1))    ! W_K @ u: [k, batch]


    !---------------------------------------------------------------------------
    ! Stable interaction (bounded instead of raw product)
    !---------------------------------------------------------------------------
    ! ptr_coeff is re-pointed at each stage; the earlier results stay reachable
    ! only through the operation results that consumed them.
    ptr_coeff => ptr_Q * ptr_K * scale             ! scaled interaction
    ptr_coeff => tanh(ptr_coeff)                   ! bound to [-1, 1]
    ptr_coeff => softmax(ptr_coeff, 1)             ! [k, batch], sum_k = 1


    !---------------------------------------------------------------------------
    ! Spectral pathway: residual modulation of the spectral coefficients
    !---------------------------------------------------------------------------
    ptr_spec => ono_encode(input(1,1), this%params(4), n, nb)  ! [k, batch]

    ! modulated = spectral + attn * spectral, as stated in the header.
    ! The residual term keeps the unmodulated spectrum in the signal; the
    ! softmax weights sum to one over the basis modes, so the bare product
    ! alone would only ever scale the coefficients down.
    ptr_mod  => ptr_spec + ptr_coeff * ptr_spec                ! [k, batch]

    ptr_decoded => ono_decode(ptr_mod, this%params(4), n, nb)  ! [n, batch]

    ! Value projection
    ptr_attn => matmul(this%params(3), ptr_decoded)  ! [n_out, batch]

    ! Bypass: W @ u
    ptr_bypass => matmul(this%params(5), input(1,1))   ! [n_out, batch]

    ! Combine: attn_out + bypass
    ptr => ptr_attn + ptr_bypass

    ! Add bias
    if(this%use_bias)then
       ptr => ptr + this%params(6)
    end if

    ! Apply activation
    call this%output(1,1)%zero_grad()
    if(trim(this%activation%name) == "none")then
       ! No activation: the combined pre-activation is the layer output
       call this%output(1,1)%assign_and_deallocate_source(ptr)
    else
       ! Store the pre-activation in z(1) and mark it non-temporary before
       ! feeding it to the activation
       call this%z(1)%zero_grad()
       call this%z(1)%assign_and_deallocate_source(ptr)
       this%z(1)%is_temporary = .false.
       ptr => this%activation%apply(this%z(1))
       call this%output(1,1)%assign_and_deallocate_source(ptr)
    end if
    ! Mark the stored output as non-temporary
    this%output(1,1)%is_temporary = .false.

  end subroutine forward_ono_attn