emit_onnx_nodes_fixed_lno Subroutine

private subroutine emit_onnx_nodes_fixed_lno(this, prefix, nodes, num_nodes, max_nodes, inits, num_inits, max_inits, input_name, is_last_layer, format)

Uses

    • coreutils

Emit decomposed standard ONNX nodes for a Fixed LNO layer.

Forward: v = sigma(D * R * E * u + W * u + b), where E and D are fixed Laplace bases and R is a learnable mixing matrix.

Type Bound

fixed_lno_layer_type

Arguments

Type | Intent | Optional | Attributes | Name
class(fixed_lno_layer_type), intent(in) :: this

Fixed LNO layer instance

character(len=*), intent(in) :: prefix

Layer name prefix (e.g. "layer1")

type(onnx_node_type), intent(inout), dimension(:) :: nodes

Node accumulator

integer, intent(inout) :: num_nodes

Node counter

integer, intent(in) :: max_nodes

Node limit

type(onnx_initialiser_type), intent(inout), dimension(:) :: inits

Initialiser accumulator

integer, intent(inout) :: num_inits

Initialiser counter

integer, intent(in) :: max_inits

Initialiser limit

character(len=*), intent(in), optional :: input_name

Name of the input tensor

logical, intent(in), optional :: is_last_layer

Whether this is the last layer

integer, intent(in), optional :: format

Export format selector


Source Code

  subroutine emit_onnx_nodes_fixed_lno( &
       this, prefix, nodes, num_nodes, max_nodes, inits, num_inits, &
       max_inits, input_name, is_last_layer, format)
    !! Export a Fixed LNO layer as a chain of standard ONNX operators.
    !!
    !! The emitted graph evaluates D * R * E * x_t + W * x_t (+ b), where
    !! x_t is the transposed input. E and D are exponential bases that are
    !! generated in this routine; R, W and b are read from the layer
    !! parameters. Whatever follows the (optional) bias addition is
    !! delegated to emit_nop_output_tail.
    !!
    !! Nothing is emitted unless format is present and equal to 2 and both
    !! input_name and is_last_layer are supplied.
    use coreutils, only: pi
    implicit none

    ! Arguments
    class(fixed_lno_layer_type), intent(in) :: this
    !! Fixed LNO layer instance
    character(*), intent(in) :: prefix
    !! Layer name prefix (e.g. "layer1")
    type(onnx_node_type), intent(inout), dimension(:) :: nodes
    !! Node accumulator
    integer, intent(inout) :: num_nodes
    !! Node counter (index of the last used entry of nodes)
    integer, intent(in) :: max_nodes
    !! Node limit (not consulted by this routine)
    type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
    !! Initialiser accumulator
    integer, intent(inout) :: num_inits
    !! Initialiser counter
    integer, intent(in) :: max_inits
    !! Initialiser limit (not consulted by this routine)
    character(*), optional, intent(in) :: input_name
    !! Name of the input tensor
    logical, optional, intent(in) :: is_last_layer
    !! Whether this is the last layer
    integer, optional, intent(in) :: format
    !! Export format selector

    ! Local variables
    integer :: n_modes, n_in, n_out
    !! Local copies of the layer dimensions used for looping
    integer :: i_mode, i_pt, flat
    !! Mode index, grid-point index and running row-major offset
    real(real32) :: rate
    !! Decay rate of the current mode (mode index times pi)
    real(real32), allocatable :: basis(:)
    !! Scratch buffer holding one flattened basis matrix at a time
    character(128) :: enc_name, dec_name, mix_name, lin_name, bias_name
    !! Initialiser names
    character(128) :: xt_out, enc_out, mix_out, dec_out, lin_out
    character(128) :: sum_out, biased_out, tail_in, final_output
    !! Intermediate tensor names

    ! Only format 2 is handled here. An absent format argument counts as
    ! format 1, so it is skipped as well.
    if(.not.present(format)) return
    if(format .ne. 2) return
    if(.not.(present(input_name) .and. present(is_last_layer))) return

    !--------------------------------------------------------------------------
    ! Tensor names
    !--------------------------------------------------------------------------
    enc_name  = weight_name('.E')
    dec_name  = weight_name('.D')
    mix_name  = weight_name('.R')
    lin_name  = weight_name('.W')
    bias_name = weight_name('.b')

    xt_out     = output_name('Transpose')
    enc_out    = output_name('MatMul')
    mix_out    = output_name('MatMul_1')
    dec_out    = output_name('MatMul_2')
    lin_out    = output_name('MatMul_3')
    sum_out    = output_name('Add')
    biased_out = output_name('Add_1')

    !--------------------------------------------------------------------------
    ! Graph nodes
    !--------------------------------------------------------------------------
    ! Transposed input, shared by the spectral and the bypass branch
    call emit_nop_input_transpose(trim(prefix), trim(input_name), nodes, &
         num_nodes, trim(xt_out))

    ! Spectral branch: E * x_t, then R * (.), then D * (.)
    call push_binary_node('MatMul', 'MatMul', enc_name, xt_out, enc_out)
    call push_binary_node('MatMul', 'MatMul_1', mix_name, enc_out, mix_out)
    call push_binary_node('MatMul', 'MatMul_2', dec_name, mix_out, dec_out)

    ! Bypass branch: W * x_t
    call push_binary_node('MatMul', 'MatMul_3', lin_name, xt_out, lin_out)

    ! Sum of both branches
    call push_binary_node('Add', 'Add', dec_out, lin_out, sum_out)

    ! Optional bias; whichever tensor is last feeds the output tail
    if(this%use_bias)then
       call push_binary_node('Add', 'Add_1', sum_out, bias_name, biased_out)
       tail_in = biased_out
    else
       tail_in = sum_out
    end if

    ! final_output is handed to the tail helper and is not read afterwards
    ! in this routine.
    call emit_nop_output_tail(trim(prefix), trim(this%activation%name), &
         is_last_layer, trim(tail_in), nodes, num_nodes, final_output)

    !--------------------------------------------------------------------------
    ! Initialisers
    !--------------------------------------------------------------------------
    n_modes = this%num_modes
    n_in    = this%num_inputs
    n_out   = this%num_outputs

    ! E(k, j) = exp(-k * pi * t_j), flattened row-major as [M, n_in].
    ! Walking modes in the outer loop makes the flat offset sequential.
    allocate(basis(n_modes * n_in))
    flat = 0
    do i_mode = 1, n_modes
       rate = real(i_mode, real32) * pi
       do i_pt = 1, n_in
          flat = flat + 1
          basis(flat) = exp(-rate * unit_coord(i_pt, n_in))
       end do
    end do
    call emit_float_initialiser(trim(enc_name), basis, &
         [this%num_modes, this%num_inputs], inits, num_inits)
    deallocate(basis)

    ! D(j, k) = exp(-k * pi * t_j), flattened row-major as [n_out, M].
    allocate(basis(n_out * n_modes))
    flat = 0
    do i_pt = 1, n_out
       do i_mode = 1, n_modes
          rate = real(i_mode, real32) * pi
          flat = flat + 1
          basis(flat) = exp(-rate * unit_coord(i_pt, n_out))
       end do
    end do
    call emit_float_initialiser(trim(dec_name), basis, &
         [this%num_outputs, this%num_modes], inits, num_inits)
    deallocate(basis)

    ! R: spectral mixing [M, M], taken from the first parameter set
    call emit_matrix_initialiser(trim(mix_name), this%params(1)%val(:,1), &
         this%num_modes, this%num_modes, inits, num_inits)

    ! W: bypass weights [n_out, n_in], taken from the second parameter set
    call emit_matrix_initialiser(trim(lin_name), this%params(2)%val(:,1), &
         this%num_outputs, this%num_inputs, inits, num_inits)

    ! b: bias [n_out, 1], only present when the layer uses a bias
    if(this%use_bias)then
       call emit_float_initialiser(trim(bias_name), &
            this%params(3)%val(:,1), [this%num_outputs, 1], inits, num_inits)
    end if

  contains

    function weight_name(suffix) result(label)
      !! Build "<prefix><suffix>", the name of an initialiser tensor.
      character(*), intent(in) :: suffix
      !! Suffix including its leading dot (e.g. ".E")
      character(128) :: label

      write(label, '(A,A)') trim(prefix), suffix
    end function weight_name

    function output_name(node_label) result(label)
      !! Build "/<prefix>/<node_label>_output_0", the name of a node output.
      character(*), intent(in) :: node_label
      !! Node label (e.g. "MatMul_1")
      character(128) :: label

      write(label, '("/",A,"/",A,"_output_0")') trim(prefix), node_label
    end function output_name

    subroutine push_binary_node(op, node_label, lhs, rhs, res)
      !! Append a two-input, one-output node without attributes to the
      !! host's node list and advance the host's node counter.
      character(*), intent(in) :: op
      !! ONNX operator type
      character(*), intent(in) :: node_label
      !! Node label appended after "/<prefix>/"
      character(*), intent(in) :: lhs, rhs
      !! Names of the first and second input tensors
      character(*), intent(in) :: res
      !! Name of the output tensor

      num_nodes = num_nodes + 1
      write(nodes(num_nodes)%name, '("/",A,"/",A)') trim(prefix), node_label
      nodes(num_nodes)%op_type = op
      allocate(nodes(num_nodes)%inputs(2))
      nodes(num_nodes)%inputs(1) = trim(lhs)
      nodes(num_nodes)%inputs(2) = trim(rhs)
      allocate(nodes(num_nodes)%outputs(1))
      nodes(num_nodes)%outputs(1) = trim(res)
      nodes(num_nodes)%attributes_json = ''
    end subroutine push_binary_node

    pure function unit_coord(i, n) result(x)
      !! Position of point i on a uniform grid of n points spanning [0, 1].
      !! A grid with a single point (or none) maps to 0.
      integer, intent(in) :: i
      !! One-based point index
      integer, intent(in) :: n
      !! Number of grid points
      real(real32) :: x

      if(n .gt. 1)then
         x = real(i - 1, real32) / real(n - 1, real32)
      else
         x = 0.0_real32
      end if
    end function unit_coord

  end subroutine emit_onnx_nodes_fixed_lno