athena_orthogonal_attention_layer.f90 Source File


Source Code

module athena__orthogonal_attention_layer
  !! Module containing implementation of an Orthogonal Attention layer
  !!
  !! This module implements the Orthogonal Attention mechanism from
  !! "Improved Operator Learning by Orthogonal Attention" (Luo et al., 2024).
  !!
  !! Instead of softmax attention, this layer projects queries and keys
  !! onto a learned orthonormal basis of dimension \(k \ll N\), giving
  !! a linear-cost approximation to the attention kernel.
  !!
  !! Given input \(\mathbf{u} \in \mathbb{R}^{n_{in}}\):
  !!
  !! \[
  !!   \mathbf{Q} = \mathbf{W}_Q\,\mathbf{u}, \quad
  !!   \mathbf{K} = \mathbf{W}_K\,\mathbf{u}, \quad
  !!   \mathbf{V} = \mathbf{W}_V\,\mathbf{u}
  !! \]
  !!
  !! The orthogonal basis \(\mathbf{\Phi} \in \mathbb{R}^{n_{in} \times k}\)
  !! is obtained by QR decomposition of learnable weights
  !! \(\mathbf{B} \in \mathbb{R}^{n_{in} \times k}\).
  !!
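  !! The attention output in the reference formulation is:
  !! \[
  !!   \text{Attn}(\mathbf{u}) = \mathbf{\Phi}\,
  !!     (\mathbf{\Phi}^T \mathbf{Q})^T\,
  !!     (\mathbf{\Phi}^T \mathbf{K})\,
  !!     \mathbf{V}
  !! \]
  !!
  !! Note: forward_ono_attn implements a modified, bounded variant of this
  !! kernel rather than the formula above. Per-basis coefficients
  !! \(\mathbf{c} = \mathrm{softmax}\bigl(\tanh(\mathbf{Q} \odot \mathbf{K}
  !! / \sqrt{d_k})\bigr)\) scale the spectral encoding of \(\mathbf{u}\)
  !! computed from the basis weights \(\mathbf{B}\). The result is decoded
  !! back to \(n_{in}\) points and then projected by \(\mathbf{W}_V\).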
  !!
  !! The layer output is:
  !! \[
  !!   \mathbf{v} = \sigma\!\bigl(
  !!     \text{Attn}(\mathbf{u}) + \mathbf{W}\,\mathbf{u} + \mathbf{b}
  !!   \bigr)
  !! \]
  !!
  !! Parameters (learnable):
  !!   - \(\mathbf{W}_Q \in \mathbb{R}^{d_k \times n_{in}}\)
  !!   - \(\mathbf{W}_K \in \mathbb{R}^{d_k \times n_{in}}\)
  !!   - \(\mathbf{W}_V \in \mathbb{R}^{n_{out} \times n_{in}}\)
  !!   - \(\mathbf{B}   \in \mathbb{R}^{n_{in} \times k}\)  (basis, orthogonalised)
  !!   - \(\mathbf{W}   \in \mathbb{R}^{n_{out} \times n_{in}}\)  (bypass)
  !!   - \(\mathbf{b}   \in \mathbb{R}^{n_{out}}\)  (optional bias)
  use coreutils, only: real32, stop_program
  use athena__base_layer, only: learnable_layer_type, base_layer_type
  use athena__misc_types, only: base_actv_type, base_init_type, &
       onnx_attribute_type
  use diffstruc, only: array_type, matmul, operator(+), operator(*), tanh
  use athena__diffstruc_extd, only: ono_encode, ono_decode, softmax
  implicit none


  private

  public :: orthogonal_attention_layer_type
  public :: read_orthogonal_attention_layer


  type, extends(learnable_layer_type) :: orthogonal_attention_layer_type
     !! Type for an Orthogonal Attention layer
     !!
     !! Only the layer dimensions are declared here. The learnable tensors are
     !! stored in the inherited `params` component, which init_ono_attn fills
     !! in the order W_Q, W_K, W_V, B, W and then b (only when use_bias is set).
     integer :: num_inputs = 0
     !! Number of input features / discretisation points (set by init from
     !! the input shape)
     integer :: num_outputs = 0
     !! Number of output features / discretisation points
     integer :: num_basis = 0
     !! Number of orthogonal basis functions (k)
     integer :: key_dim = 0
     !! Dimension of query/key projections (d_k). forward_ono_attn multiplies
     !! the key_dim-sized scores element-wise with num_basis-sized spectral
     !! coefficients, so the forward pass expects key_dim == num_basis.
     type(array_type), dimension(1) :: z
     !! Pre-activation buffer. forward_ono_attn only fills it when the
     !! activation is not "none".
   contains
     ! Setup, sizing and I/O
     procedure, pass(this) :: get_num_params => get_num_params_ono_attn
     procedure, pass(this) :: set_hyperparams => set_hyperparams_ono_attn
     procedure, pass(this) :: init => init_ono_attn
     procedure, pass(this) :: print_to_unit => print_to_unit_ono_attn
     procedure, pass(this) :: read => read_ono_attn

     ! Computation and export
     procedure, pass(this) :: forward => forward_ono_attn
     procedure, pass(this) :: get_bases => get_bases_ono_attn
     procedure, pass(this) :: get_attributes => get_attributes_ono_attn

     final :: finalise_ono_attn
  end type orthogonal_attention_layer_type

  interface orthogonal_attention_layer_type
     !! Generic constructor for orthogonal attention layers.
     !! The body (layer_setup) sets the hyperparameters and, when num_inputs is
     !! given, also initialises the parameter storage.
     module function layer_setup( &
          num_outputs, num_basis, key_dim, &
          num_inputs, use_bias, &
          activation, &
          kernel_initialiser, bias_initialiser, verbose &
     ) result(layer)
       integer, intent(in) :: num_outputs
       !! Number of output features
       integer, intent(in) :: num_basis
       !! Number of orthogonal basis vectors (k)
       integer, optional, intent(in) :: key_dim
       !! Query/key projection dimension (defaults to num_basis)
       integer, optional, intent(in) :: num_inputs
       !! Number of input features; when present the layer is initialised
       logical, optional, intent(in) :: use_bias
       !! Whether to allocate a bias term (defaults to .true.)
       class(*), optional, intent(in) :: activation
       !! Activation specification (defaults to "none")
       class(*), optional, intent(in) :: kernel_initialiser, bias_initialiser
       !! Kernel and bias initialiser specifications
       integer, optional, intent(in) :: verbose
       !! Verbosity level
       type(orthogonal_attention_layer_type) :: layer
       !! Constructed layer
     end function layer_setup
  end interface orthogonal_attention_layer_type



contains

!###############################################################################
  subroutine finalise_ono_attn(this)
    !! Finaliser for the orthogonal attention layer.
    !!
    !! Releases the recorded input shape, the output buffer and the
    !! pre-activation buffer, in that order. Each is released only if it is
    !! currently allocated.
    implicit none

    ! Arguments
    type(orthogonal_attention_layer_type), intent(inout) :: this
    !! Layer instance going out of scope

    if(allocated(this%input_shape))then
       deallocate(this%input_shape)
    end if

    if(allocated(this%output))then
       deallocate(this%output)
    end if

    ! z is an array_type, so it is released through its own method
    if(this%z(1)%allocated)then
       call this%z(1)%deallocate()
    end if

  end subroutine finalise_ono_attn
!###############################################################################


!###############################################################################
  pure function get_num_params_ono_attn(this) result(num_params)
    !! Count the learnable parameters of the orthogonal attention layer.
    !!
    !! The count is the sum of the following sizes:
    !!   - query and key projections W_Q, W_K : key_dim x num_inputs each
    !!   - value projection and bypass W_V, W : num_outputs x num_inputs each
    !!   - basis weights B                    : num_inputs x num_basis
    !!   - bias b (only when use_bias)        : num_outputs
    implicit none

    ! Arguments
    class(orthogonal_attention_layer_type), intent(in) :: this
    !! Layer instance
    integer :: num_params
    !! Total number of learnable parameters

    ! Local variables
    integer :: n_query_key, n_value_bypass, n_basis
    !! Sizes of the query/key pair, value/bypass pair and basis group

    n_query_key    = 2 * this%key_dim * this%num_inputs
    n_value_bypass = 2 * this%num_outputs * this%num_inputs
    n_basis        = this%num_inputs * this%num_basis

    num_params = n_query_key + n_value_bypass + n_basis
    if(this%use_bias) num_params = num_params + this%num_outputs

  end function get_num_params_ono_attn
!###############################################################################


!###############################################################################
  module function layer_setup( &
       num_outputs, num_basis, key_dim, &
       num_inputs, use_bias, &
       activation, &
       kernel_initialiser, bias_initialiser, verbose &
  ) result(layer)
    !! Construct an orthogonal attention layer.
    !!
    !! Optional arguments default as follows:
    !!   - key_dim defaults to num_basis
    !!   - use_bias defaults to .true.
    !!   - activation defaults to "none"
    !!   - verbose defaults to 0
    !! Initialisers that are not given are chosen by set_hyperparams. When
    !! num_inputs is present the parameter storage is initialised immediately.
    use athena__activation, only: activation_setup
    use athena__initialiser, only: initialiser_setup
    implicit none

    ! Arguments
    integer, intent(in) :: num_outputs
    !! Number of output features
    integer, intent(in) :: num_basis
    !! Number of orthogonal basis vectors
    integer, optional, intent(in) :: key_dim
    !! Query/key projection dimension
    integer, optional, intent(in) :: num_inputs
    !! Number of input features when known at construction time
    logical, optional, intent(in) :: use_bias
    !! Whether to allocate a bias term
    class(*), optional, intent(in) :: activation
    !! Activation function specification
    class(*), optional, intent(in) :: kernel_initialiser, bias_initialiser
    !! Kernel and bias initialiser specifications
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    type(orthogonal_attention_layer_type) :: layer
    !! Constructed orthogonal attention layer

    ! Local variables
    integer :: verbose_
    !! Effective verbosity level
    integer :: key_dim_
    !! Query/key projection dimension after defaults
    logical :: use_bias_
    !! Effective bias flag
    class(base_actv_type), allocatable :: activation_
    !! Materialised activation object
    class(base_init_type), allocatable :: kernel_initialiser_, bias_initialiser_
    !! Materialised kernel and bias initialisers

    ! Defaults are assigned here, not in the declarations. A declaration
    ! initialiser implies SAVE, so a value passed in one call (e.g.
    ! use_bias=.false.) would otherwise become the default for later calls.
    verbose_ = 0
    if(present(verbose)) verbose_ = verbose
    use_bias_ = .true.
    if(present(use_bias)) use_bias_ = use_bias
    key_dim_ = num_basis
    if(present(key_dim)) key_dim_ = key_dim

    if(present(activation))then
       activation_ = activation_setup(activation)
    else
       activation_ = activation_setup("none")
    end if

    ! Initialisers that are left unallocated are defaulted in set_hyperparams
    if(present(kernel_initialiser))then
       kernel_initialiser_ = initialiser_setup(kernel_initialiser)
    end if
    if(present(bias_initialiser))then
       bias_initialiser_ = initialiser_setup(bias_initialiser)
    end if

    call layer%set_hyperparams( &
         num_outputs = num_outputs, &
         num_basis = num_basis, &
         key_dim = key_dim_, &
         use_bias = use_bias_, &
         activation = activation_, &
         kernel_initialiser = kernel_initialiser_, &
         bias_initialiser = bias_initialiser_, &
         verbose = verbose_ &
    )

    if(present(num_inputs)) call layer%init(input_shape=[num_inputs])

  end function layer_setup
!###############################################################################


!###############################################################################
  subroutine set_hyperparams_ono_attn( &
       this, num_outputs, num_basis, key_dim, &
       use_bias, &
       activation, &
       kernel_initialiser, bias_initialiser, &
       verbose &
  )
    !! Configure the layer's hyperparameters, activation and initialisers.
    !!
    !! Defaults are applied as follows:
    !!   - an unallocated activation becomes "none";
    !!   - unallocated kernel/bias initialisers become the defaults that
    !!     get_default_initialiser returns for the activation.
    !!
    !! A warning is written to error_unit when key_dim differs from
    !! num_basis. forward_ono_attn multiplies a key_dim-sized score tensor
    !! element-wise with a num_basis-sized spectral tensor, so the two
    !! sizes are expected to match.
    use, intrinsic :: iso_fortran_env, only: error_unit
    use athena__activation, only: activation_setup
    use athena__initialiser, only: get_default_initialiser, initialiser_setup
    implicit none

    ! Arguments
    class(orthogonal_attention_layer_type), intent(inout) :: this
    !! Layer instance to configure
    integer, intent(in) :: num_outputs
    !! Number of output features
    integer, intent(in) :: num_basis
    !! Number of orthogonal basis vectors
    integer, intent(in) :: key_dim
    !! Query/key projection dimension
    logical, intent(in) :: use_bias
    !! Whether to use a bias term
    class(base_actv_type), allocatable, intent(in) :: activation
    !! Activation function object
    class(base_init_type), allocatable, intent(in) :: &
         kernel_initialiser, bias_initialiser
    !! Kernel and bias initialiser objects
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    character(len=256) :: buffer
    !! Buffer for default initialiser lookup

    this%name = "orthogonal_attention"
    this%type = "nop"
    this%input_rank = 1
    this%output_rank = 1
    this%use_bias = use_bias
    this%num_outputs = num_outputs
    this%num_basis = num_basis
    this%key_dim = key_dim

    ! The forward pass combines key_dim- and num_basis-sized tensors
    ! element-wise; flag configurations it cannot evaluate consistently.
    if(this%key_dim .ne. this%num_basis)then
       write(error_unit,'(A,I0,A,I0,A)') &
            "WARNING: ORTHOGONAL_ATTENTION key_dim (", this%key_dim, &
            ") differs from num_basis (", this%num_basis, &
            "); forward pass expects them to match"
    end if

    ! Activation (default: none)
    if(allocated(this%activation)) deallocate(this%activation)
    if(.not.allocated(activation))then
       this%activation = activation_setup("none")
    else
       allocate(this%activation, source=activation)
    end if

    ! Kernel initialiser (default depends on the activation)
    if(allocated(this%kernel_init)) deallocate(this%kernel_init)
    if(.not.allocated(kernel_initialiser))then
       buffer = get_default_initialiser(this%activation%name)
       this%kernel_init = initialiser_setup(buffer)
    else
       allocate(this%kernel_init, source=kernel_initialiser)
    end if

    ! Bias initialiser (default depends on the activation)
    if(allocated(this%bias_init)) deallocate(this%bias_init)
    if(.not.allocated(bias_initialiser))then
       buffer = get_default_initialiser( &
            this%activation%name, &
            is_bias=.true. &
       )
       this%bias_init = initialiser_setup(buffer)
    else
       allocate(this%bias_init, source=bias_initialiser)
    end if

    if(present(verbose))then
       if(abs(verbose).gt.0)then
          write(*,'("ORTHOGONAL_ATTENTION activation: ",A)') &
               trim(this%activation%name)
       end if
    end if

  end subroutine set_hyperparams_ono_attn
!###############################################################################


!###############################################################################
  subroutine init_ono_attn(this, input_shape, verbose)
    !! Initialise parameter storage and output buffers for the layer
    !!
    !! The routine does the following, in order:
    !!   1. records the input shape (unless one is already set);
    !!   2. allocates the learnable tensors listed below as trainable,
    !!      sample-independent, persistent arrays;
    !!   3. fills them with the kernel/bias initialisers;
    !!   4. resets the output and pre-activation buffers.
    !! Any previously allocated weight shapes and parameters are released
    !! first, so re-initialising the layer (e.g. after it was constructed
    !! with num_inputs) does not fail on an already-allocated array.
    !!
    !!   params(1): W_Q  query projection   [key_dim x num_inputs]
    !!   params(2): W_K  key projection     [key_dim x num_inputs]
    !!   params(3): W_V  value projection   [num_outputs x num_inputs]
    !!   params(4): B    basis weights      [num_inputs x num_basis]
    !!   params(5): W    bypass weights     [num_outputs x num_inputs]
    !!   params(6): b    bias               [num_outputs]  (only if use_bias)
    implicit none

    ! Arguments
    class(orthogonal_attention_layer_type), intent(inout) :: this
    !! Layer instance to initialise
    integer, dimension(:), intent(in) :: input_shape
    !! Input shape used to infer num_inputs (ignored if a shape is already set)
    integer, optional, intent(in) :: verbose
    !! Verbosity level (accepted for interface compatibility; not used)

    ! Local variables
    integer :: out_fan_in
    !! Fan-in for W_V, W and b: num_inputs, plus one when a bias is used
    integer :: ip
    !! Parameter index


    !---------------------------------------------------------------------------
    ! Set shapes
    !---------------------------------------------------------------------------
    if(.not.allocated(this%input_shape)) call this%set_shape(input_shape)
    this%num_inputs = this%input_shape(1)
    this%output_shape = [this%num_outputs]
    this%num_params = this%get_num_params()


    !---------------------------------------------------------------------------
    ! Describe and allocate learnable parameters
    !---------------------------------------------------------------------------
    if(allocated(this%weight_shape)) deallocate(this%weight_shape)
    allocate(this%weight_shape(2,5))
    this%weight_shape(:,1) = [ this%key_dim, this%num_inputs ]
    this%weight_shape(:,2) = [ this%key_dim, this%num_inputs ]
    this%weight_shape(:,3) = [ this%num_outputs, this%num_inputs ]
    this%weight_shape(:,4) = [ this%num_inputs, this%num_basis ]
    this%weight_shape(:,5) = [ this%num_outputs, this%num_inputs ]

    if(allocated(this%params)) deallocate(this%params)
    if(this%use_bias)then
       this%bias_shape = [ this%num_outputs ]
       allocate(this%params(6))
    else
       allocate(this%params(5))
    end if

    ! Weight tensors take their shape from weight_shape (plus a unit trailing
    ! dimension); the bias takes bias_shape.
    do ip = 1, 5
       call this%params(ip)%allocate([ this%weight_shape(:,ip), 1 ])
    end do
    if(this%use_bias) call this%params(6)%allocate([ this%bias_shape, 1 ])

    ! Every parameter is trainable, shared across samples and persistent
    do ip = 1, size(this%params)
       call this%params(ip)%set_requires_grad(.true.)
       this%params(ip)%fix_pointer = .true.
       this%params(ip)%is_sample_dependent = .false.
       this%params(ip)%is_temporary = .false.
    end do


    !---------------------------------------------------------------------------
    ! Initialise learnable parameters
    !---------------------------------------------------------------------------
    out_fan_in = this%num_inputs
    if(this%use_bias) out_fan_in = this%num_inputs + 1

    call this%kernel_init%initialise( &
         this%params(1)%val(:,1), &
         fan_in = this%num_inputs, fan_out = this%key_dim, &
         spacing = [ this%key_dim ] &
    )
    call this%kernel_init%initialise( &
         this%params(2)%val(:,1), &
         fan_in = this%num_inputs, fan_out = this%key_dim, &
         spacing = [ this%key_dim ] &
    )
    call this%kernel_init%initialise( &
         this%params(3)%val(:,1), &
         fan_in = out_fan_in, fan_out = this%num_outputs, &
         spacing = [ this%num_outputs ] &
    )
    call this%kernel_init%initialise( &
         this%params(4)%val(:,1), &
         fan_in = this%num_inputs, fan_out = this%num_basis, &
         spacing = [ this%num_inputs ] &
    )
    call this%kernel_init%initialise( &
         this%params(5)%val(:,1), &
         fan_in = out_fan_in, fan_out = this%num_outputs, &
         spacing = [ this%num_outputs ] &
    )
    if(this%use_bias)then
       call this%bias_init%initialise( &
            this%params(6)%val(:,1), &
            fan_in = out_fan_in, fan_out = this%num_outputs &
       )
    end if


    !---------------------------------------------------------------------------
    ! Allocate output arrays
    !---------------------------------------------------------------------------
    if(allocated(this%output)) deallocate(this%output)
    allocate(this%output(1,1))
    if(this%z(1)%allocated) call this%z(1)%deallocate()

  end subroutine init_ono_attn
!###############################################################################


!###############################################################################
  function get_bases_ono_attn(this) result(phi)
    !! Return the orthonormalised basis Phi built from the basis weights B.
    !!
    !! B is read from params(4), interpreted column-major as
    !! num_inputs x num_basis, and orthonormalised column by column with
    !! modified Gram-Schmidt. A column whose residual norm is not above 1e-12
    !! is set to zero instead of being normalised. The result is returned as a
    !! non-trainable, sample-independent array_type of shape
    !! [num_inputs, num_basis, 1].
    implicit none

    ! Arguments
    class(orthogonal_attention_layer_type), intent(in) :: this
    !! Layer instance providing basis parameters
    type(array_type) :: phi
    !! Orthonormalised basis packed in an array_type

    ! Local variables
    integer :: n_rows, n_cols
    !! Basis length (num_inputs) and number of basis vectors (num_basis)
    integer :: col, prev
    !! Column being processed and previously finished column
    real(real32), allocatable :: basis(:,:)
    !! Working copy of B, orthonormalised in place
    real(real32) :: coeff, col_norm
    !! Projection coefficient and residual column norm

    n_rows = this%num_inputs
    n_cols = this%num_basis

    ! Unpack the flat parameter storage into a matrix (auto-allocates)
    basis = reshape(this%params(4)%val(:,1), [n_rows, n_cols])

    column_loop: do col = 1, n_cols
       ! Remove components along the already-orthonormal columns, projecting
       ! the progressively updated column (modified Gram-Schmidt)
       do prev = 1, col - 1
          coeff = dot_product(basis(:,prev), basis(:,col))
          basis(:,col) = basis(:,col) - coeff * basis(:,prev)
       end do

       ! NOTE(review): 1e-12 is an absolute threshold. In real32, columns that
       ! are numerically dependent (e.g. num_basis > num_inputs) leave
       ! residuals around 1e-7 * |b|, which pass this test and get normalised.
       ! Confirm this matches what ono_encode does before tightening it.
       col_norm = sqrt(dot_product(basis(:,col), basis(:,col)))
       if(col_norm .gt. 1.0e-12_real32)then
          basis(:,col) = basis(:,col) / col_norm
       else
          basis(:,col) = 0.0_real32
       end if
    end do column_loop

    ! Package as a fixed, non-trainable array
    call phi%allocate([n_rows, n_cols, 1])
    phi%requires_grad = .false.
    phi%is_sample_dependent = .false.
    phi%is_temporary = .false.
    phi%fix_pointer = .true.
    phi%val(:,1) = reshape(basis, [n_rows * n_cols])

  end function get_bases_ono_attn
!###############################################################################


!###############################################################################
  subroutine print_to_unit_ono_attn(this, unit)
    !! Write the layer settings and its flattened parameters to `unit`.
    !!
    !! Output layout:
    !!   - the settings lines;
    !!   - the activation block, unless the activation is "none";
    !!   - a WEIGHTS ... END WEIGHTS section holding W_Q, W_K, W_V, B, W and,
    !!     when use_bias is set, b.
    !! Each tensor in the WEIGHTS section starts a new record, five values
    !! per line.
    use coreutils, only: to_upper
    implicit none

    ! Arguments
    class(orthogonal_attention_layer_type), intent(in) :: this
    !! Layer instance to print
    integer, intent(in) :: unit
    !! Output unit number

    ! Local variables
    character(len=*), parameter :: param_fmt = '(5(E16.8E2))'
    !! Record format for parameter values: five per line
    integer :: ip
    !! Index into this%params

    ! Settings
    write(unit,'(3X,"NUM_INPUTS = ",I0)') this%num_inputs
    write(unit,'(3X,"NUM_OUTPUTS = ",I0)') this%num_outputs
    write(unit,'(3X,"NUM_BASIS = ",I0)') this%num_basis
    write(unit,'(3X,"KEY_DIM = ",I0)') this%key_dim
    write(unit,'(3X,"USE_BIAS = ",L1)') this%use_bias
    if(this%activation%name .ne. 'none') &
         call this%activation%print_to_unit(unit)

    ! Parameters in storage order: W_Q, W_K, W_V, B, W, then optional b
    write(unit,'("WEIGHTS")')
    do ip = 1, 5
       write(unit, param_fmt) this%params(ip)%val(:,1)
    end do
    if(this%use_bias) write(unit, param_fmt) this%params(6)%val(:,1)
    write(unit,'("END WEIGHTS")')

  end subroutine print_to_unit_ono_attn
!###############################################################################


!###############################################################################
  subroutine read_ono_attn(this, unit, verbose)
    !! Read orthogonal attention layer settings and parameters from a unit
    !!
    !! The routine proceeds as follows:
    !!   1. Parses the tagged settings lines (NUM_INPUTS, NUM_OUTPUTS,
    !!      NUM_BASIS, KEY_DIM, USE_BIAS, ACTIVATION and the kernel/bias
    !!      initialiser names) up to the closing "END <NAME>" line.
    !!   2. Configures and initialises the layer. KEY_DIM defaults to
    !!      NUM_BASIS when absent; NUM_INPUTS, NUM_OUTPUTS and NUM_BASIS must
    !!      be present and positive.
    !!   3. If a WEIGHTS card was seen, loads W_Q, W_K, W_V, B, W and (when
    !!      USE_BIAS is true) b, in the order print_to_unit_ono_attn writes
    !!      them.
    use athena__tools_infile, only: assign_val, assign_vec, move
    use coreutils, only: to_lower, to_upper, icount
    use athena__activation, only: read_activation
    use athena__initialiser, only: initialiser_setup
    implicit none

    ! Arguments
    class(orthogonal_attention_layer_type), intent(inout) :: this
    !! Layer instance to populate from file data
    integer, intent(in) :: unit
    !! Input unit number
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: stat, verbose_
    !! I/O status and effective verbosity level
    integer :: itmp1, iline
    !! Parser scratch integer and count of processed non-blank lines
    integer :: num_inputs, num_outputs, num_basis, key_dim
    !! Parsed layer dimensions
    logical :: use_bias
    !! Parsed bias flag
    character(14) :: kernel_initialiser_name, bias_initialiser_name
    !! Parsed initialiser names
    class(base_actv_type), allocatable :: activation
    !! Parsed activation object
    class(base_init_type), allocatable :: kernel_initialiser, bias_initialiser
    !! Parsed initialiser objects
    character(256) :: buffer, tag, err_msg
    !! Input buffer, parsed tag and formatted error message
    integer :: param_line, final_line
    !! Processed-line counts at the WEIGHTS card and at the closing END line


    ! Defaults are assigned at run time, not in the declarations. A
    ! declaration initialiser implies SAVE, which would carry USE_BIAS and the
    ! initialiser names over from one read into the next.
    verbose_ = 0
    if(present(verbose)) verbose_ = verbose
    num_inputs = 0
    num_outputs = 0
    num_basis = 0
    key_dim = 0
    use_bias = .true.
    kernel_initialiser_name = ''
    bias_initialiser_name = ''
    iline = 0
    param_line = 0
    final_line = 0

    tag_loop: do
       read(unit,'(A)',iostat=stat) buffer
       if(stat.ne.0)then
          write(err_msg,'("file encountered error (EoF?) before END ",A)') &
               to_upper(this%name)
          call stop_program(err_msg)
          return
       end if
       if(trim(adjustl(buffer)).eq."") cycle tag_loop

       if(trim(adjustl(buffer)).eq."END "//to_upper(trim(this%name)))then
          final_line = iline
          backspace(unit)
          exit tag_loop
       end if
       iline = iline + 1

       tag=trim(adjustl(buffer))
       if(scan(buffer,"=").ne.0) tag=trim(tag(:scan(tag,"=")-1))

       select case(trim(tag))
       case("NUM_INPUTS")
          call assign_val(buffer, num_inputs, itmp1)
       case("NUM_OUTPUTS")
          call assign_val(buffer, num_outputs, itmp1)
       case("NUM_BASIS")
          call assign_val(buffer, num_basis, itmp1)
       case("KEY_DIM")
          call assign_val(buffer, key_dim, itmp1)
       case("USE_BIAS")
          call assign_val(buffer, use_bias, itmp1)
       case("ACTIVATION")
          iline = iline - 1
          backspace(unit)
          activation = read_activation(unit, iline)
       case("KERNEL_INITIALISER", "KERNEL_INIT", "KERNEL_INITIALIZER")
          call assign_val(buffer, kernel_initialiser_name, itmp1)
       case("BIAS_INITIALISER", "BIAS_INIT", "BIAS_INITIALIZER")
          call assign_val(buffer, bias_initialiser_name, itmp1)
       case("WEIGHTS")
          ! Stored weights override any initialiser named in the file
          kernel_initialiser_name = 'zeros'
          bias_initialiser_name   = 'zeros'
          param_line = iline
       case default
          ! Numeric data lines (no letters other than the exponent 'e') and
          ! nested END cards are skipped here and read later
          if(scan(to_lower(trim(adjustl(buffer))),&
               'abcdfghijklmnopqrstuvwxyz').eq.0)then
             cycle tag_loop
          elseif(tag(:3).eq.'END')then
             cycle tag_loop
          end if
          write(err_msg,'("Unrecognised line in input file: ",A)') &
               trim(adjustl(buffer))
          call stop_program(err_msg)
          return
       end select
    end do tag_loop

    if(num_inputs .le. 0 .or. num_outputs .le. 0 .or. num_basis .le. 0)then
       write(err_msg,'(A,A)') &
            "NUM_INPUTS, NUM_OUTPUTS and NUM_BASIS must be positive in ", &
            to_upper(trim(this%name))
       call stop_program(err_msg)
       return
    end if

    kernel_initialiser = initialiser_setup(kernel_initialiser_name)
    bias_initialiser = initialiser_setup(bias_initialiser_name)

    if(key_dim .eq. 0) key_dim = num_basis

    call this%set_hyperparams( &
         num_outputs = num_outputs, &
         num_basis = num_basis, &
         key_dim = key_dim, &
         use_bias = use_bias, &
         activation = activation, &
         kernel_initialiser = kernel_initialiser, &
         bias_initialiser = bias_initialiser, &
         verbose = verbose_ &
    )
    call this%init(input_shape=[num_inputs])

    if(param_line.eq.0)then
       write(0,*) "WARNING: WEIGHTS card in " // trim(this%name) // " not found"
    else
       ! Rewind from the END line to the first line after the WEIGHTS card
       call move(unit, param_line - iline, iostat=stat)

       call read_param_block(this%params(1)%val(:,1))   ! W_Q
       call read_param_block(this%params(2)%val(:,1))   ! W_K
       call read_param_block(this%params(3)%val(:,1))   ! W_V
       call read_param_block(this%params(4)%val(:,1))   ! B
       call read_param_block(this%params(5)%val(:,1))   ! W (bypass)
       if(use_bias) call read_param_block(this%params(6)%val(:,1))   ! b

       read(unit,'(A)') buffer
       if(trim(adjustl(buffer)).ne."END WEIGHTS")then
          call stop_program("END WEIGHTS not where expected")
          return
       end if
    end if

    call move(unit, final_line - iline, iostat=stat)
    read(unit,'(A)') buffer
    if(trim(adjustl(buffer)).ne."END "//to_upper(trim(this%name)))then
       write(err_msg,'("END ",A," not where expected")') to_upper(this%name)
       call stop_program(err_msg)
       return
    end if

  contains

    subroutine read_param_block(values)
      !! Read size(values) reals from `unit`, one text line at a time.
      !!
      !! Values not supplied before end-of-file stay zero. A line holding
      !! more values than remain in the block is reported as an error
      !! instead of being written past the end of `values`.
      real(real32), dimension(:), intent(out) :: values
      !! Destination for the flattened parameter block

      integer :: first, n_line, j, ios
      !! Next slot to fill, values on the current line, index, I/O status
      character(256) :: line
      !! Current input line
      character(256) :: msg
      !! Formatted error message

      values = 0._real32
      first = 1
      do while(first .le. size(values))
         read(unit,'(A)',iostat=ios) line
         if(ios .ne. 0) exit
         n_line = icount(line)
         if(first + n_line - 1 .gt. size(values))then
            write(msg,'("Too many values in a WEIGHTS block of ",A)') &
                 to_upper(trim(this%name))
            call stop_program(msg)
            return
         end if
         read(line,*,iostat=ios) (values(j), j = first, first + n_line - 1)
         first = first + n_line
      end do

    end subroutine read_param_block

  end subroutine read_ono_attn
!###############################################################################


!###############################################################################
  function read_orthogonal_attention_layer(unit, verbose) result(layer)
    !! Read an orthogonal attention layer from file and return it
    !!
    !! A placeholder layer (num_outputs=0, num_basis=1, not initialised) is
    !! created and then populated by its `read` binding, which sets the real
    !! dimensions and parameters.
    implicit none

    ! Arguments
    integer, intent(in) :: unit
    !! Input unit number
    integer, optional, intent(in) :: verbose
    !! Verbosity level
    class(base_layer_type), allocatable :: layer
    !! Allocated base-layer instance containing the result

    ! Local variables
    integer :: verbose_
    !! Effective verbosity level

    ! Assigned at run time: a declaration initialiser would imply SAVE and
    ! keep a previous call's verbosity as the default for the next.
    verbose_ = 0
    if(present(verbose)) verbose_ = verbose

    allocate(layer, source=orthogonal_attention_layer_type( &
         num_outputs=0, num_basis=1))
    call layer%read(unit, verbose=verbose_)

  end function read_orthogonal_attention_layer
!###############################################################################


!###############################################################################
  subroutine forward_ono_attn(this, input)
    !! Forward propagation for the Orthogonal Attention layer
    !!
    !! For the input batch u = input(1,1) it computes:
    !!   Q = W_Q @ u                                   [key_dim, batch]
    !!   K = W_K @ u                                   [key_dim, batch]
    !!
    !!   coeff = softmax( tanh( Q * K / sqrt(key_dim) ), dim=1 )
    !!           element-wise Q*K, bounded to [-1, 1] by tanh, then
    !!           normalised along dimension 1
    !!
    !!   spec    = ono_encode(u, B)                    [num_basis, batch]
    !!   mod     = coeff * spec      (element-wise product; no residual term)
    !!   decoded = ono_decode(mod, B)                  [num_inputs, batch]
    !!
    !!   attn_out = W_V @ decoded                      [num_outputs, batch]
    !!   bypass   = W @ u                              [num_outputs, batch]
    !!
    !!   v = sigma( attn_out + bypass [+ b] )
    !!
    !! The result is stored in this%output(1,1). When the activation is not
    !! "none", the pre-activation value is also kept in this%z(1).
    !!
    !! The product coeff * spec combines a key_dim-sized tensor with a
    !! num_basis-sized one, so this routine assumes key_dim == num_basis.
    !! NOTE(review): ono_encode/ono_decode are defined elsewhere. The shapes
    !! above, and the reading that they project onto and back from the
    !! orthonormalised basis of B, come from the inline comments; confirm them.
    implicit none

    ! Arguments
    class(orthogonal_attention_layer_type), intent(inout) :: this
    !! Layer instance to execute
    class(array_type), dimension(:,:), intent(in) :: input
    !! Input batch tensor collection (only input(1,1) is used)

    ! Local variables
    type(array_type), pointer :: ptr, ptr_attn, ptr_bypass
    !! Combined output, attention-path output and bypass-path output
    type(array_type), pointer :: ptr_Q, ptr_K, ptr_coeff
    !! Query, key and per-basis attention coefficient tensors
    type(array_type), pointer :: ptr_spec, ptr_mod, ptr_decoded
    !! Spectral encoding, modulated spectrum and decoded tensors

    integer :: n, nb
    !! Input size and basis count
    real(real32) :: scale
    !! Precomputed scaling factor 1/sqrt(key_dim) for attention scores


    n = this%num_inputs
    nb = this%num_basis


    !---------------------------------------------------------------------------
    ! Scaling (critical for stability)
    !---------------------------------------------------------------------------
    scale = 1.0_real32 / sqrt(real(this%key_dim, kind=real32))


    !---------------------------------------------------------------------------
    ! Attention scores from Q and K projections
    !---------------------------------------------------------------------------
    ptr_Q => matmul(this%params(1), input(1,1))    ! W_Q @ u: [key_dim, batch]
    ptr_K => matmul(this%params(2), input(1,1))    ! W_K @ u: [key_dim, batch]


    !---------------------------------------------------------------------------
    ! Stable interaction (bounded instead of raw product)
    !---------------------------------------------------------------------------
    ptr_coeff => ptr_Q * ptr_K * scale             ! scaled element-wise product
    ptr_coeff => tanh(ptr_coeff)                   ! bound to [-1, 1]
    ptr_coeff => softmax(ptr_coeff, 1)             ! normalise along dim 1


    !---------------------------------------------------------------------------
    ! Spectral pathway: scale spectral coefficients by attention weights
    ! (pure multiplicative modulation; no residual is added)
    !---------------------------------------------------------------------------
    ptr_spec => ono_encode(input(1,1), this%params(4), n, nb)  ! [k, batch]
    ptr_mod  => ptr_coeff * ptr_spec                           ! needs key_dim==k
    ptr_decoded => ono_decode(ptr_mod, this%params(4), n, nb)  ! [n, batch]

    ! Value projection
    ptr_attn => matmul(this%params(3), ptr_decoded)  ! [n_out, batch]

    ! Bypass: W @ u
    ptr_bypass => matmul(this%params(5), input(1,1))   ! [n_out, batch]

    ! Combine: attn_out + bypass
    ptr => ptr_attn + ptr_bypass

    ! Add bias
    if(this%use_bias)then
       ptr => ptr + this%params(6)
    end if

    ! Apply activation; z(1) keeps the pre-activation only when one is applied
    call this%output(1,1)%zero_grad()
    if(trim(this%activation%name) .eq. "none")then
       call this%output(1,1)%assign_and_deallocate_source(ptr)
    else
       call this%z(1)%zero_grad()
       call this%z(1)%assign_and_deallocate_source(ptr)
       this%z(1)%is_temporary = .false.
       ptr => this%activation%apply(this%z(1))
       call this%output(1,1)%assign_and_deallocate_source(ptr)
    end if
    this%output(1,1)%is_temporary = .false.

  end subroutine forward_ono_attn
!###############################################################################


!###############################################################################
  function get_attributes_ono_attn(this) result(attributes)
    !! Build the ONNX attribute list describing this layer.
    !!
    !! Six attributes are returned, in this order:
    !!   - the four integer dimensions (num_inputs, num_outputs, num_basis,
    !!     key_dim);
    !!   - the bias flag, encoded as the integer 1 or 0;
    !!   - the activation name, as a string.
    implicit none

    ! Arguments
    class(orthogonal_attention_layer_type), intent(in) :: this
    !! Instance of the orthogonal attention layer
    type(onnx_attribute_type), allocatable, dimension(:) :: attributes
    !! List of attributes for ONNX export

    allocate(attributes(6))

    attributes(1) = int_attribute('num_inputs', this%num_inputs)
    attributes(2) = int_attribute('num_outputs', this%num_outputs)
    attributes(3) = int_attribute('num_basis', this%num_basis)
    attributes(4) = int_attribute('key_dim', this%key_dim)
    attributes(5) = int_attribute('use_bias', merge(1, 0, this%use_bias))
    attributes(6) = onnx_attribute_type( &
         name='activation', type='string', val=trim(this%activation%name))

  contains

    function int_attribute(attr_name, attr_value) result(attr)
      !! Wrap an integer as an ONNX 'int' attribute with a minimal-width value
      character(len=*), intent(in) :: attr_name
      !! Attribute name
      integer, intent(in) :: attr_value
      !! Attribute value
      type(onnx_attribute_type) :: attr
      !! Resulting attribute

      character(32) :: text
      !! Decimal representation of attr_value

      write(text, '(I0)') attr_value
      attr = onnx_attribute_type( &
           name=attr_name, type='int', val=trim(text))

    end function int_attribute

  end function get_attributes_ono_attn
!###############################################################################

end module athena__orthogonal_attention_layer