module athena__orthogonal_nop_block
  !! Module containing implementation of an Orthogonal Neural Operator layer
  !!
  !! This module implements the Orthogonal Neural Operator (ONO) from
  !! "Improved Operator Learning by Orthogonal Attention" (Luo et al., 2024).
  !!
  !! The ONO layer uses an orthogonal attention kernel to approximate the
  !! integral operator. It combines:
  !!   1. A learned orthogonal basis for efficient attention (k << N)
  !!   2. A spectral pathway through the orthogonal basis
  !!   3. A local affine bypass
  !!
  !! The layer computes:
  !! \[
  !!   \mathbf{v} = \sigma\!\bigl(
  !!     \mathbf{W}\,\mathbf{\Phi}\,\mathbf{R}\,(\mathbf{\Phi}^T\,\mathbf{u})
  !!   + \mathbf{W}\,\mathbf{u}
  !!   + \mathbf{b}\bigr)
  !! \]
  !!
  !! where \(\mathbf{\Phi} \in \mathbb{R}^{n_{in} \times k}\) is obtained
  !! by QR/Gram-Schmidt orthogonalisation of learnable basis weights
  !! \(\mathbf{B}\).
  !!
  !! Parameters (learnable):
  !!   - \(\mathbf{R} \in \mathbb{R}^{k \times k}\) spectral mixing
  !!   - \(\mathbf{B} \in \mathbb{R}^{n_{in} \times k}\) basis (orthogonalised)
  !!   - \(\mathbf{W} \in \mathbb{R}^{n_{out} \times n_{in}}\) output projection (shared by spectral and bypass paths)
  !!   - \(\mathbf{b} \in \mathbb{R}^{n_{out}}\) bias (optional)
  !!
  !! The spectral path is: decode( R @ encode(u) )
  !!   encode = Phi^T @ u       [k, batch]
  !!   mix    = R @ encoded     [k, batch]
  !!   decode = Phi @ mix       [n_in, batch]
  !!
  !! Then a linear projection to the output: W @ decode -> [n_out, batch]
  !! Plus a bypass using the same W: W @ u -> [n_out, batch]
  use coreutils, only: real32, stop_program
  use athena__base_layer, only: learnable_layer_type, base_layer_type
  use athena__misc_types, only: base_actv_type, base_init_type, &
       onnx_attribute_type
  use diffstruc, only: array_type, matmul, operator(+)
  use athena__diffstruc_extd, only: ono_encode, ono_decode
  implicit none
  public :: read_orthogonal_nop_block


  type, extends(learnable_layer_type) :: orthogonal_nop_block_type
     !! Type for an Orthogonal Neural Operator layer
     !!
     !! The learnable tensors live in the inherited params array, which is
     !! allocated by init_ono as:
     !!   params(1) = R  [num_basis  x num_basis]   spectral mixing
     !!   params(2) = B  [num_inputs x num_basis]   raw basis weights
     !!   params(3) = W  [num_outputs x num_inputs] shared output projection
     !!   params(4) = b  [num_outputs]              only when use_bias is set
     integer :: num_inputs = 0
     !! Number of inputs (discretisation points)
     integer :: num_outputs = 0
     !! Number of outputs (discretisation points)
     integer :: num_basis = 0
     !! Number of orthogonal basis functions (k)
     type(array_type), dimension(1) :: z
     !! Temporary array for pre-activation values
     !! (filled by forward_ono only when the activation is not "none")
   contains
     procedure, pass(this) :: get_num_params => get_num_params_ono
     !! Count of learnable scalars (R, B, W and optional b)
     procedure, pass(this) :: set_hyperparams => set_hyperparams_ono
     !! Store sizes and resolve activation / initialiser objects
     procedure, pass(this) :: init => init_ono
     !! Allocate and initialise parameters for a given input shape
     procedure, pass(this) :: print_to_unit => print_to_unit_ono
     !! Write hyperparameters and weights in the format read by read_ono
     procedure, pass(this) :: read => read_ono
     !! Populate the layer from a formatted unit

     procedure, pass(this) :: forward => forward_ono
     !! Build the forward graph for one input tensor
     procedure, pass(this) :: get_bases => get_bases_ono
     !! Orthonormalised basis Phi computed from B
     procedure, pass(this) :: get_orthogonality_metric
     !! max |Phi^T Phi - I| diagnostic
     procedure, pass(this) :: get_attributes => get_attributes_ono
     !! Attribute list for ONNX export

     final :: finalise_ono
  end type orthogonal_nop_block_type

  interface orthogonal_nop_block_type
     !! Constructor-style interface: orthogonal_nop_block_type(...) builds a
     !! configured layer through layer_setup. When num_inputs is supplied the
     !! layer is also initialised immediately.
     module function layer_setup( &
          num_outputs, num_basis, &
          num_inputs, use_bias, &
          activation, &
          kernel_initialiser, bias_initialiser, verbose &
     ) result(layer)
       integer, intent(in) :: num_outputs
       !! Number of output features
       integer, intent(in) :: num_basis
       !! Number of orthogonal basis vectors (k)
       integer, optional, intent(in) :: num_inputs
       !! Number of input features; triggers init when present
       logical, optional, intent(in) :: use_bias
       !! Whether to allocate a bias term
       class(*), optional, intent(in) :: activation
       !! Activation specification passed to activation_setup
       class(*), optional, intent(in) :: kernel_initialiser, bias_initialiser
       !! Initialiser specifications passed to initialiser_setup
       integer, optional, intent(in) :: verbose
       !! Verbosity level
       type(orthogonal_nop_block_type) :: layer
       !! Constructed layer
     end function layer_setup
  end interface orthogonal_nop_block_type



contains

!###############################################################################
  subroutine finalise_ono(this)
    !! Finaliser: release the shape, output and pre-activation buffers
    !!
    !! Each component is released only if it is currently allocated, so the
    !! finaliser is safe on a layer that was never (or only partly) initialised.
    implicit none

    ! Arguments
    type(orthogonal_nop_block_type), intent(inout) :: this
    !! Layer instance being finalised

    ! Stored input shape
    if(allocated(this%input_shape))then
       deallocate(this%input_shape)
    end if

    ! Output tensors produced by forward
    if(allocated(this%output))then
       deallocate(this%output)
    end if

    ! Cached pre-activation tensor
    if(this%z(1)%allocated)then
       call this%z(1)%deallocate()
    end if

  end subroutine finalise_ono
!###############################################################################


!###############################################################################
  pure function get_num_params_ono(this) result(num_params)
    !! Count the learnable scalars held by the block
    !!
    !! Sum of the element counts of
    !!   R (num_basis x num_basis), B (num_inputs x num_basis),
    !!   W (num_outputs x num_inputs) and, with use_bias, b (num_outputs).
    implicit none

    ! Arguments
    class(orthogonal_nop_block_type), intent(in) :: this
    !! Layer instance whose sizes are counted
    integer :: num_params
    !! Total number of learnable parameters

    ! R and B share the num_basis factor: k*k + n_in*k = k*(k + n_in)
    num_params = this%num_basis * (this%num_basis + this%num_inputs) + &
         this%num_outputs * this%num_inputs + &
         merge(this%num_outputs, 0, this%use_bias)

  end function get_num_params_ono
!###############################################################################


!###############################################################################
  module function layer_setup( &
       num_outputs, num_basis, &
       num_inputs, use_bias, &
       activation, &
       kernel_initialiser, bias_initialiser, verbose &
  ) result(layer)
    !! Build an orthogonal neural operator block from user-level options
    !!
    !! Optional activation / initialiser specifications are materialised via
    !! activation_setup / initialiser_setup and forwarded to set_hyperparams.
    !! A missing activation becomes "none"; missing initialisers are left
    !! unallocated so set_hyperparams picks activation-dependent defaults.
    !! When num_inputs is present the layer is initialised straight away.
    use athena__activation, only: activation_setup
    use athena__initialiser, only: initialiser_setup
    implicit none

    ! Arguments
    integer, intent(in) :: num_outputs
    !! Number of output features
    integer, intent(in) :: num_basis
    !! Number of orthogonal basis vectors
    integer, optional, intent(in) :: num_inputs
    !! Number of input features when known at construction time
    logical, optional, intent(in) :: use_bias
    !! Whether to allocate a bias term (default .true.)
    class(*), optional, intent(in) :: activation
    !! Activation function specification
    class(*), optional, intent(in) :: kernel_initialiser, bias_initialiser
    !! Kernel and bias initialiser specifications
    integer, optional, intent(in) :: verbose
    !! Verbosity level (default 0)

    type(orthogonal_nop_block_type) :: layer
    !! Constructed orthogonal neural operator block

    ! Local variables
    integer :: verbose_
    !! Effective verbosity level
    logical :: use_bias_
    !! Effective bias flag
    class(base_actv_type), allocatable :: activation_
    !! Materialised activation object
    class(base_init_type), allocatable :: kernel_initialiser_, bias_initialiser_
    !! Materialised kernel and bias initialisers

    ! Defaults are assigned at run time rather than in the declarations: an
    ! initialised local variable is implicitly SAVEd, so a value supplied on
    ! one call (e.g. use_bias=.false.) would otherwise leak into later calls.
    verbose_ = 0
    use_bias_ = .true.
    if(present(verbose)) verbose_ = verbose
    if(present(use_bias)) use_bias_ = use_bias

    if(present(activation))then
       activation_ = activation_setup(activation)
    else
       activation_ = activation_setup("none")
    end if

    if(present(kernel_initialiser))then
       kernel_initialiser_ = initialiser_setup(kernel_initialiser)
    end if
    if(present(bias_initialiser))then
       bias_initialiser_ = initialiser_setup(bias_initialiser)
    end if

    call layer%set_hyperparams( &
         num_outputs = num_outputs, &
         num_basis = num_basis, &
         use_bias = use_bias_, &
         activation = activation_, &
         kernel_initialiser = kernel_initialiser_, &
         bias_initialiser = bias_initialiser_, &
         verbose = verbose_ &
    )

    if(present(num_inputs)) call layer%init(input_shape=[num_inputs])

  end function layer_setup
!###############################################################################


!###############################################################################
  subroutine set_hyperparams_ono( &
       this, num_outputs, num_basis, &
       use_bias, &
       activation, &
       kernel_initialiser, bias_initialiser, &
       verbose &
  )
    !! Store hyperparameters and resolve the activation / initialiser objects
    !!
    !! An unallocated activation falls back to "none". Unallocated
    !! initialisers fall back to get_default_initialiser for the chosen
    !! activation (the bias variant uses is_bias=.true.). Any previously held
    !! activation / initialiser objects are released first.
    use athena__activation, only: activation_setup
    use athena__initialiser, only: get_default_initialiser, initialiser_setup
    implicit none

    ! Arguments
    class(orthogonal_nop_block_type), intent(inout) :: this
    !! Layer instance to configure
    integer, intent(in) :: num_outputs
    !! Number of output features
    integer, intent(in) :: num_basis
    !! Number of orthogonal basis vectors
    logical, intent(in) :: use_bias
    !! Whether to use a bias term
    class(base_actv_type), allocatable, intent(in) :: activation
    !! Activation function object
    class(base_init_type), allocatable, intent(in) :: &
         kernel_initialiser, bias_initialiser
    !! Kernel and bias initialiser objects
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    character(len=256) :: default_name
    !! Name returned by the default-initialiser lookup

    ! Layer identity and tensor ranks
    this%name = "orthogonal_nop"
    this%type = "nop"
    this%input_rank = 1
    this%output_rank = 1

    ! Architecture
    this%num_outputs = num_outputs
    this%num_basis = num_basis
    this%use_bias = use_bias

    ! Activation
    if(allocated(this%activation)) deallocate(this%activation)
    if(allocated(activation))then
       allocate(this%activation, source=activation)
    else
       this%activation = activation_setup("none")
    end if

    ! Kernel initialiser (default depends on the activation)
    if(allocated(this%kernel_init)) deallocate(this%kernel_init)
    if(allocated(kernel_initialiser))then
       allocate(this%kernel_init, source=kernel_initialiser)
    else
       default_name = get_default_initialiser(this%activation%name)
       this%kernel_init = initialiser_setup(default_name)
    end if

    ! Bias initialiser (default depends on the activation)
    if(allocated(this%bias_init)) deallocate(this%bias_init)
    if(allocated(bias_initialiser))then
       allocate(this%bias_init, source=bias_initialiser)
    else
       default_name = get_default_initialiser( &
            this%activation%name, &
            is_bias=.true. &
       )
       this%bias_init = initialiser_setup(default_name)
    end if

    if(.not.present(verbose)) return
    if(abs(verbose).gt.0) write(*,'("ORTHOGONAL_NOP activation: ",A)') &
         trim(this%activation%name)

  end subroutine set_hyperparams_ono
!###############################################################################


!###############################################################################
  subroutine init_ono(this, input_shape, verbose)
    !! Initialise parameter storage and output buffers for the block
    !!
    !! num_inputs is taken from this%input_shape; input_shape is only used to
    !! set it when no shape has been stored yet. The routine may be called
    !! more than once: weight-shape and parameter storage from a previous call
    !! are released before being re-allocated.
    !!
    !! Parameter layout:
    !!   params(1): R  spectral mixing       [num_basis x num_basis]
    !!   params(2): B  basis weights         [num_inputs x num_basis]
    !!   params(3): W  bypass/output weights [num_outputs x num_inputs]
    !!   params(4): b  bias                  [num_outputs]  (only if use_bias)
    implicit none

    ! Arguments
    class(orthogonal_nop_block_type), intent(inout) :: this
    !! Layer instance to initialise
    integer, dimension(:), intent(in) :: input_shape
    !! Input shape used to infer num_inputs
    integer, optional, intent(in) :: verbose
    !! Verbosity level (currently unused by this routine)

    ! Local variables
    integer :: fan_in_w
    !! Fan-in used for W and b (one extra input accounts for the bias)
    integer :: i
    !! Index over previously allocated parameter tensors


    !---------------------------------------------------------------------------
    ! Set shapes
    !---------------------------------------------------------------------------
    if(.not.allocated(this%input_shape)) call this%set_shape(input_shape)
    this%num_inputs = this%input_shape(1)
    this%output_shape = [this%num_outputs]
    this%num_params = this%get_num_params()

    if(allocated(this%weight_shape)) deallocate(this%weight_shape)
    allocate(this%weight_shape(2,3))
    this%weight_shape(:,1) = [ this%num_basis, this%num_basis ]
    this%weight_shape(:,2) = [ this%num_inputs, this%num_basis ]
    this%weight_shape(:,3) = [ this%num_outputs, this%num_inputs ]
    if(this%use_bias) this%bias_shape = [ this%num_outputs ]


    !---------------------------------------------------------------------------
    ! Allocate learnable parameters (releasing any from a previous init)
    !---------------------------------------------------------------------------
    if(allocated(this%params))then
       do i = 1, size(this%params)
          if(this%params(i)%allocated) call this%params(i)%deallocate()
       end do
       deallocate(this%params)
    end if
    allocate(this%params(merge(4, 3, this%use_bias)))

    call prepare_param(this%params(1), [this%num_basis, this%num_basis, 1])
    call prepare_param(this%params(2), [this%num_inputs, this%num_basis, 1])
    call prepare_param(this%params(3), [this%num_outputs, this%num_inputs, 1])
    if(this%use_bias) call prepare_param(this%params(4), [this%bias_shape, 1])


    !---------------------------------------------------------------------------
    ! Initialise learnable parameters
    !---------------------------------------------------------------------------
    fan_in_w = this%num_inputs
    if(this%use_bias) fan_in_w = this%num_inputs + 1

    call this%kernel_init%initialise( &
         this%params(1)%val(:,1), &
         fan_in = this%num_basis, fan_out = this%num_basis, &
         spacing = [ this%num_basis ] &
    )
    call this%kernel_init%initialise( &
         this%params(2)%val(:,1), &
         fan_in = this%num_inputs, fan_out = this%num_basis, &
         spacing = [ this%num_inputs ] &
    )
    call this%kernel_init%initialise( &
         this%params(3)%val(:,1), &
         fan_in = fan_in_w, fan_out = this%num_outputs, &
         spacing = [ this%num_outputs ] &
    )
    if(this%use_bias)then
       call this%bias_init%initialise( &
            this%params(4)%val(:,1), &
            fan_in = fan_in_w, fan_out = this%num_outputs &
       )
    end if


    !---------------------------------------------------------------------------
    ! Allocate output arrays
    !---------------------------------------------------------------------------
    if(allocated(this%output)) deallocate(this%output)
    allocate(this%output(1,1))
    if(this%z(1)%allocated) call this%z(1)%deallocate()

  contains

    subroutine prepare_param(param, param_shape)
      !! Allocate a trainable, sample-independent, persistent parameter tensor
      class(array_type), intent(inout) :: param
      !! Parameter tensor to allocate and flag
      integer, dimension(:), intent(in) :: param_shape
      !! Shape passed to array_type%allocate (trailing 1 = single sample)

      call param%allocate(param_shape)
      call param%set_requires_grad(.true.)
      param%fix_pointer = .true.
      param%is_sample_dependent = .false.
      param%is_temporary = .false.

    end subroutine prepare_param

  end subroutine init_ono
!###############################################################################


!###############################################################################
  function get_bases_ono(this) result(phi)
    !! Build the orthonormal basis Phi from the raw basis weights B
    !!
    !! B (params(2), viewed column-major as num_inputs x num_basis) is
    !! orthonormalised column by column with modified Gram-Schmidt. Each
    !! column has its projections onto the already-processed columns removed
    !! one at a time and is then scaled to unit length. A column whose
    !! remaining norm is not above 1e-12 is set to zero. The result is a
    !! non-trainable array_type holding Phi flattened column-major.
    implicit none

    ! Arguments
    class(orthogonal_nop_block_type), intent(in) :: this
    !! Layer instance providing basis parameters
    type(array_type) :: phi
    !! Orthogonalised basis matrix packed in an array_type

    ! Local variables
    integer :: n_pts, n_basis, col, prev
    !! Basis dimensions and column indices
    real(real32), allocatable :: q(:,:)
    !! Working copy of B, orthonormalised in place
    real(real32) :: col_norm
    !! Norm of the current column after projection removal

    n_pts = this%num_inputs
    n_basis = this%num_basis
    q = reshape(this%params(2)%val(:,1), [n_pts, n_basis])

    columns: do col = 1, n_basis
       ! The right-hand side is fully evaluated before assignment, so the
       ! projection uses the column as updated by the earlier iterations.
       previous: do prev = 1, col - 1
          q(:,col) = q(:,col) - dot_product(q(:,prev), q(:,col)) * q(:,prev)
       end do previous

       col_norm = sqrt(dot_product(q(:,col), q(:,col)))
       if(col_norm .gt. 1.0e-12_real32)then
          q(:,col) = q(:,col) / col_norm
       else
          q(:,col) = 0.0_real32
       end if
    end do columns

    ! Pack Phi as a non-trainable, persistent single-sample tensor
    call phi%allocate([n_pts, n_basis, 1])
    phi%is_sample_dependent = .false.
    phi%requires_grad = .false.
    phi%fix_pointer = .true.
    phi%is_temporary = .false.
    phi%val(:,1) = reshape(q, [n_pts * n_basis])

  end function get_bases_ono
!###############################################################################


!###############################################################################
  function get_orthogonality_metric(this) result(metric)
    !! Compute max(|Phi^T @ Phi - I|) as a measure of basis orthogonality
    !!
    !! Phi is obtained from get_bases, so the metric always describes exactly
    !! the basis that routine produces rather than a separately maintained
    !! copy of the Gram-Schmidt step. A value near zero means the columns of
    !! Phi are orthonormal. A column that was zeroed during
    !! orthogonalisation contributes a deviation of 1 on the diagonal.
    !! Returns 0 when num_basis <= 0.
    implicit none

    ! Arguments
    class(orthogonal_nop_block_type), intent(in) :: this
    !! Layer instance providing basis parameters
    real(real32) :: metric
    !! Maximum absolute deviation from orthogonality

    ! Local variables
    integer :: n, k, i
    !! Matrix dimensions and diagonal index
    type(array_type) :: phi
    !! Orthonormalised basis as returned by get_bases
    real(real32), allocatable :: q(:,:), gram(:,:)
    !! Phi as an [n, k] matrix and its Gram matrix Phi^T Phi

    n = this%num_inputs
    k = this%num_basis

    metric = 0.0_real32
    ! An empty basis is trivially orthonormal (and maxval of an empty array
    ! would return -huge).
    if(k .le. 0) return

    phi = this%get_bases()
    q = reshape(phi%val(:,1), [n, k])
    if(phi%allocated) call phi%deallocate()

    gram = matmul(transpose(q), q)
    do i = 1, k
       gram(i,i) = gram(i,i) - 1.0_real32
    end do
    metric = maxval(abs(gram))

  end function get_orthogonality_metric
!###############################################################################


!###############################################################################
  subroutine print_to_unit_ono(this, unit)
    !! Write the block's hyperparameters and weights to a formatted unit
    !!
    !! Emits the NUM_INPUTS / NUM_OUTPUTS / NUM_BASIS / USE_BIAS cards and
    !! the activation description (skipped for "none"). It then writes a
    !! WEIGHTS ... END WEIGHTS section holding R, B, W and, when use_bias is
    !! set, b. Each tensor starts on a new record with five values per line.
    implicit none

    ! Arguments
    class(orthogonal_nop_block_type), intent(in) :: this
    !! Layer instance to print
    integer, intent(in) :: unit
    !! Output unit number

    ! Local variables
    integer :: ip
    !! Index of the parameter tensor being written

    write(unit,'(3X,"NUM_INPUTS = ",I0)') this%num_inputs
    write(unit,'(3X,"NUM_OUTPUTS = ",I0)') this%num_outputs
    write(unit,'(3X,"NUM_BASIS = ",I0)') this%num_basis
    write(unit,'(3X,"USE_BIAS = ",L1)') this%use_bias
    if(this%activation%name .ne. 'none') call this%activation%print_to_unit(unit)

    ! Tensor order R, B, W[, b] is the order read_ono expects
    write(unit,'("WEIGHTS")')
    do ip = 1, merge(4, 3, this%use_bias)
       write(unit,'(5(E16.8E2))') this%params(ip)%val(:,1)
    end do
    write(unit,'("END WEIGHTS")')

  end subroutine print_to_unit_ono
!###############################################################################


!###############################################################################
  subroutine read_ono(this, unit, verbose)
    !! Read an orthogonal neural operator block from a formatted unit
    !!
    !! Parses the hyperparameter cards up to the "END <NAME>" line, then
    !! configures (set_hyperparams) and initialises (init) the layer. The
    !! cards are NUM_INPUTS, NUM_OUTPUTS, NUM_BASIS, USE_BIAS, ACTIVATION,
    !! KERNEL_INITIALISER and BIAS_INITIALISER, plus their aliases. If a
    !! WEIGHTS card was found, the initialisers are forced to 'zeros' and the
    !! R, B, W (and b) blocks are then loaded in the order written by
    !! print_to_unit_ono. Malformed input is reported through stop_program.
    use athena__tools_infile, only: assign_val, move
    use coreutils, only: to_lower, to_upper, icount
    use athena__activation, only: read_activation
    use athena__initialiser, only: initialiser_setup
    implicit none

    ! Arguments
    class(orthogonal_nop_block_type), intent(inout) :: this
    !! Layer instance to populate from file data
    integer, intent(in) :: unit
    !! Input unit number
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: stat, verbose_
    !! I/O status and effective verbosity level
    integer :: itmp1, iline
    !! Parser scratch integer and running count of header lines consumed
    integer :: num_inputs, num_outputs, num_basis
    !! Parsed layer dimensions
    logical :: use_bias
    !! Parsed bias flag
    character(14) :: kernel_initialiser_name, bias_initialiser_name
    !! Parsed initialiser names
    class(base_actv_type), allocatable :: activation
    !! Parsed activation object
    class(base_init_type), allocatable :: kernel_initialiser, bias_initialiser
    !! Parsed initialiser objects
    character(256) :: buffer, tag, err_msg
    !! Input buffer, parsed tag and formatted error message
    integer :: param_line, final_line
    !! Line indices of the WEIGHTS card and of the END <NAME> line


    ! Defaults are assigned at run time, not in the declarations: an
    ! initialised local is implicitly SAVEd, so values parsed from one file
    ! would otherwise leak into the next layer read. Examples are
    ! USE_BIAS = F, or the 'zeros' initialisers forced by a WEIGHTS card.
    verbose_ = 0
    use_bias = .true.
    kernel_initialiser_name = ''
    bias_initialiser_name = ''
    num_inputs = 0
    num_outputs = 0
    num_basis = 0
    if(present(verbose)) verbose_ = verbose

    iline = 0
    param_line = 0
    final_line = 0
    tag_loop: do
       read(unit,'(A)',iostat=stat) buffer
       if(stat.ne.0)then
          write(err_msg,'("file encountered error (EoF?) before END ",A)') &
               to_upper(this%name)
          call stop_program(err_msg)
          return
       end if
       if(trim(adjustl(buffer)).eq."") cycle tag_loop

       if(trim(adjustl(buffer)).eq."END "//to_upper(trim(this%name)))then
          final_line = iline
          backspace(unit)
          exit tag_loop
       end if
       iline = iline + 1

       tag=trim(adjustl(buffer))
       if(scan(buffer,"=").ne.0) tag=trim(tag(:scan(tag,"=")-1))

       select case(trim(tag))
       case("NUM_INPUTS")
          call assign_val(buffer, num_inputs, itmp1)
       case("NUM_OUTPUTS")
          call assign_val(buffer, num_outputs, itmp1)
       case("NUM_BASIS")
          call assign_val(buffer, num_basis, itmp1)
       case("USE_BIAS")
          call assign_val(buffer, use_bias, itmp1)
       case("ACTIVATION")
          iline = iline - 1
          backspace(unit)
          activation = read_activation(unit, iline)
       case("KERNEL_INITIALISER", "KERNEL_INIT", "KERNEL_INITIALIZER")
          call assign_val(buffer, kernel_initialiser_name, itmp1)
       case("BIAS_INITIALISER", "BIAS_INIT", "BIAS_INITIALIZER")
          call assign_val(buffer, bias_initialiser_name, itmp1)
       case("WEIGHTS")
          kernel_initialiser_name = 'zeros'
          bias_initialiser_name   = 'zeros'
          param_line = iline
       case default
          ! 'e' is deliberately absent from the letter set so that numeric
          ! weight lines in scientific notation are skipped here.
          if(scan(to_lower(trim(adjustl(buffer))),&
               'abcdfghijklmnopqrstuvwxyz').eq.0)then
             cycle tag_loop
          elseif(tag(:3).eq.'END')then
             cycle tag_loop
          end if
          write(err_msg,'("Unrecognised line in input file: ",A)') &
               trim(adjustl(buffer))
          call stop_program(err_msg)
          return
       end select
    end do tag_loop
    kernel_initialiser = initialiser_setup(kernel_initialiser_name)
    bias_initialiser = initialiser_setup(bias_initialiser_name)

    call this%set_hyperparams( &
         num_outputs = num_outputs, &
         num_basis = num_basis, &
         use_bias = use_bias, &
         activation = activation, &
         kernel_initialiser = kernel_initialiser, &
         bias_initialiser = bias_initialiser, &
         verbose = verbose_ &
    )
    call this%init(input_shape=[num_inputs])

    if(param_line.eq.0)then
       write(0,*) "WARNING: WEIGHTS card in " // trim(this%name) // " not found"
    else
       ! Rewind to the line just after the WEIGHTS card
       call move(unit, param_line - iline, iostat=stat)

       ! Tensor blocks appear in the order written by print_to_unit_ono
       this%params(1)%val(:,1) = read_param_block(num_basis * num_basis)
       this%params(2)%val(:,1) = read_param_block(num_inputs * num_basis)
       this%params(3)%val(:,1) = read_param_block(num_outputs * num_inputs)
       if(use_bias) this%params(4)%val(:,1) = read_param_block(num_outputs)

       read(unit,'(A)',iostat=stat) buffer
       if(stat.ne.0 .or. trim(adjustl(buffer)).ne."END WEIGHTS")then
          call stop_program("END WEIGHTS not where expected")
          return
       end if
    end if

    call move(unit, final_line - iline, iostat=stat)
    read(unit,'(A)',iostat=stat) buffer
    if(stat.ne.0 .or. &
         trim(adjustl(buffer)).ne."END "//to_upper(trim(this%name)))then
       write(err_msg,'("END ",A," not where expected")') to_upper(this%name)
       call stop_program(err_msg)
       return
    end if

  contains

    function read_param_block(num_vals) result(values)
      !! Read num_vals list-directed reals spread over consecutive lines of
      !! unit. Values not reached before an I/O error/EOF are left at zero.
      integer, intent(in) :: num_vals
      !! Number of values in this tensor block
      real(real32), dimension(num_vals) :: values
      !! Values read, in file order

      integer :: pos, n_line, ios, idx
      !! Next slot to fill, values taken from the current line, status, index
      character(256) :: line
      !! Current input line

      values = 0._real32
      pos = 1
      do while(pos .le. num_vals)
         read(unit,'(A)',iostat=ios) line
         if(ios.ne.0) exit
         ! Clamp to the slots still free so an over-long line cannot write
         ! past the end of the block.
         n_line = min(icount(line), num_vals - pos + 1)
         read(line,*,iostat=ios) (values(idx), idx = pos, pos + n_line - 1)
         pos = pos + n_line
      end do

    end function read_param_block

  end subroutine read_ono
!###############################################################################


!###############################################################################
  function read_orthogonal_nop_block(unit, verbose) result(layer)
    !! Read an orthogonal neural operator block from file and return it
    !!
    !! A placeholder layer (num_outputs=0, num_basis=1) is constructed first.
    !! Its hyperparameters and parameters are then populated by read_ono
    !! from the given unit.
    implicit none

    ! Arguments
    integer, intent(in) :: unit
    !! Input unit number
    integer, optional, intent(in) :: verbose
    !! Verbosity level
    class(base_layer_type), allocatable :: layer
    !! Allocated base-layer instance containing the result

    ! Local variables
    integer :: verbose_
    !! Effective verbosity level

    ! Assigned at run time: an initialised declaration would imply SAVE and
    ! carry one call's verbosity into later calls that omit it.
    verbose_ = 0
    if(present(verbose)) verbose_ = verbose
    allocate(layer, source=orthogonal_nop_block_type( &
         num_outputs=0, num_basis=1))
    call layer%read(unit, verbose=verbose_)

  end function read_orthogonal_nop_block
!###############################################################################


!###############################################################################
  subroutine forward_ono(this, input)
    !! Forward propagation for the Orthogonal Neural Operator layer
    !!
    !! The symbols are:
    !!   u   = input(1,1)
    !!   B   = params(2), the raw basis weights
    !!   Phi = the orthonormalised basis built from B
    !!   R   = params(1), the spectral mixing matrix
    !!   W   = params(3), the projection
    !!   b   = params(4), the optional bias
    !!
    !! The layer evaluates
    !!
    !!   v = sigma( W @ (Phi @ R @ Phi^T @ u) + W @ u + b )
    !!
    !! Phi is applied through ono_encode / ono_decode, which receive B
    !! directly so that gradients can reach the basis weights. The same W
    !! projects the spectral term and the bypass term; the two products are
    !! formed separately and then summed. When the activation is not "none",
    !! the pre-activation is kept in this%z(1).
    implicit none

    ! Arguments
    class(orthogonal_nop_block_type), intent(inout) :: this
    !! Layer instance to execute
    class(array_type), dimension(:,:), intent(in) :: input
    !! Input batch tensor collection (only input(1,1) is used)

    ! Local variables
    type(array_type), pointer :: acc
    !! Running pre-activation sum, later the activated output
    type(array_type), pointer :: coeffs, mixed_coeffs, reconstructed
    !! Basis coefficients, mixed coefficients and decoded field
    type(array_type), pointer :: spec_term, skip_term
    !! Projected spectral term and projected bypass term


    ! Spectral pathway: Phi^T u -> R (.) -> Phi (.) -> W (.)
    coeffs => ono_encode(input(1,1), this%params(2), &
         this%num_inputs, this%num_basis)
    mixed_coeffs => matmul(this%params(1), coeffs)
    reconstructed => ono_decode(mixed_coeffs, this%params(2), &
         this%num_inputs, this%num_basis)
    spec_term => matmul(this%params(3), reconstructed)

    ! Bypass pathway through the same W
    skip_term => matmul(this%params(3), input(1,1))

    ! Sum the two paths and add the optional bias
    acc => spec_term + skip_term
    if(this%use_bias) acc => acc + this%params(4)

    ! Optional activation, keeping the pre-activation in z(1)
    call this%output(1,1)%zero_grad()
    if(trim(this%activation%name) .ne. "none")then
       call this%z(1)%zero_grad()
       call this%z(1)%assign_and_deallocate_source(acc)
       this%z(1)%is_temporary = .false.
       acc => this%activation%apply(this%z(1))
    end if
    call this%output(1,1)%assign_and_deallocate_source(acc)
    this%output(1,1)%is_temporary = .false.

  end subroutine forward_ono
!###############################################################################


!###############################################################################
  function get_attributes_ono(this) result(attributes)
    !! Return list of ONO attributes for ONNX export
    !!
    !! The five entries, in order, are:
    !!   num_inputs, num_outputs, num_basis (int, decimal text)
    !!   use_bias (int, '1' or '0')
    !!   activation (string, trimmed activation name)
    implicit none

    ! Arguments
    class(orthogonal_nop_block_type), intent(in) :: this
    !! Instance of the ONO block
    type(onnx_attribute_type), allocatable, dimension(:) :: attributes
    !! List of attributes for ONNX export

    allocate(attributes(5))

    attributes(1) = onnx_attribute_type( &
         name='num_inputs', type='int', val=int_to_str(this%num_inputs))
    attributes(2) = onnx_attribute_type( &
         name='num_outputs', type='int', val=int_to_str(this%num_outputs))
    attributes(3) = onnx_attribute_type( &
         name='num_basis', type='int', val=int_to_str(this%num_basis))
    attributes(4) = onnx_attribute_type( &
         name='use_bias', type='int', val=merge('1', '0', this%use_bias))
    attributes(5) = onnx_attribute_type( &
         name='activation', type='string', val=trim(this%activation%name))

  contains

    function int_to_str(ival) result(str)
      !! Minimal-width decimal text of an integer (I0 edit descriptor)
      integer, intent(in) :: ival
      !! Value to format
      character(:), allocatable :: str
      !! Formatted, trimmed text

      character(32) :: tmp
      !! Fixed-width formatting buffer

      write(tmp, '(I0)') ival
      str = trim(tmp)

    end function int_to_str

  end function get_attributes_ono
!###############################################################################

end module athena__orthogonal_nop_block
