Forward propagation for the Orthogonal Attention layer
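Computes:

$$Q = W_Q\,u, \qquad K = W_K\,u \qquad [k \times \mathrm{batch}]$$

$$\mathrm{scores} = \tanh\!\left(\frac{Q \odot K}{\sqrt{k}}\right) \qquad [k \times \mathrm{batch}]$$

These are bounded per-basis interaction scores.

$$\mathrm{attn} = \operatorname{softmax}_{\dim=1}(\mathrm{scores}) \qquad [k \times \mathrm{batch}]$$

These are attention weights normalised across the basis modes.

$$\mathrm{spectral} = Q(B)^{\top} u \qquad [k \times \mathrm{batch}]$$

This projects the input onto the orthogonal spectral basis.

$$\mathrm{modulated} = \mathrm{spectral} + \mathrm{attn} \odot \mathrm{spectral} \qquad [k \times \mathrm{batch}]$$

This is a residual spectral modulation.

$$\mathrm{decoded} = Q(B)\,\mathrm{modulated} \qquad [n_{\mathrm{in}} \times \mathrm{batch}]$$

This decodes the modulated spectral representation.

$$\mathrm{attn\_out} = W_V\,\mathrm{decoded}, \qquad \mathrm{bypass} = W\,u \qquad [n_{\mathrm{out}} \times \mathrm{batch}]$$

$$v = \sigma\left(\mathrm{attn\_out} + \mathrm{bypass} + b\right)$$

Here $\odot$ denotes element-wise multiplication. $Q(B)$ is the orthogonal spectral basis matrix, which is distinct from the query $Q$.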
orthogonal_attention_layer_type
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| `class(orthogonal_attention_layer_type)` | `intent(inout)` | | | `this` | Layer instance to execute |
| `class(array_type)` | `intent(in)` | | `dimension(:,:)` | `input` | Input batch tensor collection |
subroutine forward_ono_attn(this, input)
  !! Forward propagation for the Orthogonal Attention layer
  !!
  !! Computes:
  !!   Q = W_Q @ u                              [k, batch]
  !!   K = W_K @ u                              [k, batch]
  !!
  !!   scores    = tanh( (Q * K) / sqrt(k) )    [k, batch]
  !!               bounded per-basis interaction scores
  !!
  !!   attn      = softmax(scores, dim=1)       [k, batch]
  !!               normalised attention weights across basis modes
  !!
  !!   spectral  = Q(B)^T @ u                   [k, batch]
  !!               project input to orthogonal spectral basis
  !!
  !!   modulated = spectral + attn * spectral   [k, batch]
  !!               residual spectral modulation
  !!
  !!   decoded   = Q(B) @ modulated             [n_in, batch]
  !!               decode modulated spectral representation
  !!
  !!   attn_out  = W_V @ decoded                [n_out, batch]
  !!   bypass    = W @ u                        [n_out, batch]
  !!
  !!   v = sigma( attn_out + bypass + b )
  !!
  !! Parameter slots read from this%params below:
  !!   params(1) -> W_Q, params(2) -> W_K, params(3) -> W_V,
  !!   params(4) -> passed to ono_encode/ono_decode (the basis B),
  !!   params(5) -> W (bypass), params(6) -> b (only when use_bias)
  !!
  !! The result is stored in this%output(1,1). When an activation other than
  !! "none" is used, the pre-activation is also stored in this%z(1).
  implicit none

  ! Arguments
  class(orthogonal_attention_layer_type), intent(inout) :: this
  !! Layer instance to execute
  class(array_type), dimension(:,:), intent(in) :: input
  !! Input batch tensor collection

  ! Local variables
  type(array_type), pointer :: ptr, ptr_attn, ptr_bypass
  !! Combined output, attention-path output and bypass-path output
  type(array_type), pointer :: ptr_Q, ptr_K, ptr_coeff
  !! Query, key and per-basis attention coefficient tensors
  type(array_type), pointer :: ptr_spec, ptr_mod, ptr_decoded
  !! Spectral encoding, modulated spectrum and decoded tensors
  integer :: n, nb
  !! Input size and basis count
  real(real32) :: scale
  !! Precomputed scaling factor for attention scores

  n = this%num_inputs
  nb = this%num_basis

  !---------------------------------------------------------------------------
  ! Scaling (critical for stability): 1 / sqrt(key_dim)
  !---------------------------------------------------------------------------
  scale = 1.0_real32 / sqrt(real(this%key_dim, kind=real32))

  !---------------------------------------------------------------------------
  ! Attention scores from Q and K projections
  !---------------------------------------------------------------------------
  ptr_Q => matmul(this%params(1), input(1,1))   ! W_Q @ u: [k, batch]
  ptr_K => matmul(this%params(2), input(1,1))   ! W_K @ u: [k, batch]

  !---------------------------------------------------------------------------
  ! Stable interaction (bounded instead of raw product)
  !---------------------------------------------------------------------------
  ptr_coeff => ptr_Q * ptr_K * scale            ! scaled interaction
  ptr_coeff => tanh(ptr_coeff)                  ! bound to [-1, 1]
  ptr_coeff => softmax(ptr_coeff, 1)            ! [k, batch], sum_k = 1

  !---------------------------------------------------------------------------
  ! Spectral pathway: residual modulation of the spectral coefficients
  !---------------------------------------------------------------------------
  ptr_spec => ono_encode(input(1,1), this%params(4), n, nb)   ! [k, batch]
  ! Residual form, as documented: modulated = spectral + attn * spectral.
  ! attn is softmax-normalised over k (weights sum to 1), so a purely
  ! multiplicative modulation would shrink the coefficients by 1/k on
  ! average. The identity path keeps the spectral signal intact.
  ! NOTE(review): the element-wise product assumes the attention rows
  ! (key_dim) match the spectral rows (num_basis). Confirm this is enforced
  ! at layer construction.
  ptr_mod => ptr_spec + ptr_coeff * ptr_spec                  ! [k, batch]
  ptr_decoded => ono_decode(ptr_mod, this%params(4), n, nb)   ! [n, batch]

  ! Value projection
  ptr_attn => matmul(this%params(3), ptr_decoded)   ! [n_out, batch]

  ! Bypass: W @ u
  ptr_bypass => matmul(this%params(5), input(1,1))  ! [n_out, batch]

  ! Combine: attn_out + bypass
  ptr => ptr_attn + ptr_bypass

  ! Add bias
  if (this%use_bias) then
    ptr => ptr + this%params(6)
  end if

  ! Apply activation. The pre-activation is kept in z(1) only when an
  ! activation function is actually applied.
  call this%output(1,1)%zero_grad()
  if (trim(this%activation%name) == "none") then
    call this%output(1,1)%assign_and_deallocate_source(ptr)
  else
    call this%z(1)%zero_grad()
    call this%z(1)%assign_and_deallocate_source(ptr)
    this%z(1)%is_temporary = .false.
    ptr => this%activation%apply(this%z(1))
    call this%output(1,1)%assign_and_deallocate_source(ptr)
  end if
  this%output(1,1)%is_temporary = .false.

end subroutine forward_ono_attn