emit_onnx_nodes_dynamic_lno Subroutine

private subroutine emit_onnx_nodes_dynamic_lno(this, prefix, nodes, num_nodes, max_nodes, inits, num_inits, max_inits, input_name, is_last_layer, format)

Emit decomposed standard ONNX nodes for a Dynamic LNO layer.

Decomposes the forward pass $v = \sigma\left(D(\mu)\,\mathrm{diag}(\beta)\,E(\mu)\,u + W u + b\right)$ into: Exp, MatMul, Mul, Add, Transpose, and optional Relu nodes.

Type Bound

dynamic_lno_layer_type

Arguments

Type | Intent | Optional Attributes | Name
class(dynamic_lno_layer_type), intent(in) :: this

Dynamic LNO layer instance

character(len=*), intent(in) :: prefix

Layer name prefix (e.g. "layer1")

type(onnx_node_type), intent(inout), dimension(:) :: nodes

Node accumulator

integer, intent(inout) :: num_nodes

Node counter

integer, intent(in) :: max_nodes

Node limit

type(onnx_initialiser_type), intent(inout), dimension(:) :: inits

Initialiser accumulator

integer, intent(inout) :: num_inits

Initialiser counter

integer, intent(in) :: max_inits

Initialiser limit

character(len=*), intent(in), optional :: input_name

Name of the input tensor (e.g. "input" or previous layer output). If absent, nothing is emitted.

logical, intent(in), optional :: is_last_layer

Whether this is the last layer in the network. If absent, nothing is emitted.

integer, intent(in), optional :: format

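Export format selector. Nodes are emitted only when this equals 2; it defaults to 1 when absent, in which case nothing is emitted.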


Source Code

  subroutine emit_onnx_nodes_dynamic_lno( &
       this, prefix, nodes, num_nodes, max_nodes, inits, num_inits, &
       max_inits, input_name, is_last_layer, format)
    !! Emit decomposed standard ONNX nodes for a Dynamic LNO layer.
    !!
    !! Decomposes the forward pass v = sigma(D(mu)*diag(beta)*E(mu)*u + W*u + b)
    !! into: Exp, MatMul, Mul, Add, Transpose, and optional Relu nodes.
    !!
    !! Nothing is emitted unless format == 2 and both input_name and
    !! is_last_layer are present. Nodes are appended after nodes(num_nodes)
    !! and initialisers (W, E_args, D_args, beta and, if use_bias, b) after
    !! inits(num_inits); both counters are advanced accordingly.
    !! Execution stops with an error message if either accumulator would
    !! grow beyond its limit (max_nodes / max_inits) instead of silently
    !! writing out of bounds.
    implicit none

    ! Arguments
    class(dynamic_lno_layer_type), intent(in) :: this
    !! Dynamic LNO layer instance
    character(*), intent(in) :: prefix
    !! Layer name prefix (e.g. "layer1")
    type(onnx_node_type), intent(inout), dimension(:) :: nodes
    !! Node accumulator
    integer, intent(inout) :: num_nodes
    !! Node counter
    integer, intent(in) :: max_nodes
    !! Node limit
    type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
    !! Initialiser accumulator
    integer, intent(inout) :: num_inits
    !! Initialiser counter
    integer, intent(in) :: max_inits
    !! Initialiser limit
    character(*), optional, intent(in) :: input_name
    !! Name of the input tensor (e.g. "input" or previous layer output)
    logical, optional, intent(in) :: is_last_layer
    !! Whether this is the last layer in the network
    integer, optional, intent(in) :: format
    !! Export format selector

    ! Local variables
    integer :: j, k, num_new_inits
    integer :: format_
    real(real32), allocatable :: mu(:), t_in(:), t_out(:)
    real(real32), allocatable :: e_args(:), d_args(:)
    character(128) :: e_args_name, d_args_name, beta_name, w_name, b_name
    character(128) :: exp_e_out, exp_d_out, trans_in_out
    character(128) :: mm_e_out, mul_out, mm_d_out, mm_w_out
    character(128) :: add_out, add_b_out, final_output, output_source

    format_ = 1
    if(present(format)) format_ = format
    if(format_ .ne. 2) return
    if(.not.present(input_name)) return
    if(.not.present(is_last_layer)) return

    !--------------------------------------------------------------------------
    ! Capacity check for initialisers, done before anything is emitted so a
    ! too-small buffer never leaves a half-written layer behind.
    ! NOTE(review): assumes each emit_*_initialiser call appends exactly one
    ! entry -- confirm against those helpers.
    !--------------------------------------------------------------------------
    num_new_inits = 4
    if(this%use_bias) num_new_inits = 5
    if(num_inits + num_new_inits .gt. min(max_inits, size(inits))) &
         error stop "emit_onnx_nodes_dynamic_lno: max_inits exceeded"

    !--------------------------------------------------------------------------
    ! Build initialiser names
    !--------------------------------------------------------------------------
    write(e_args_name, '(A,".E_args")') trim(prefix)
    write(d_args_name, '(A,".D_args")') trim(prefix)
    write(beta_name, '(A,".beta")') trim(prefix)
    write(w_name, '(A,".W")') trim(prefix)
    write(b_name, '(A,".b")') trim(prefix)

    !--------------------------------------------------------------------------
    ! Build intermediate tensor names
    !--------------------------------------------------------------------------
    write(exp_e_out, '("/",A,"/Exp_output_0")') trim(prefix)
    write(exp_d_out, '("/",A,"/Exp_1_output_0")') trim(prefix)
    write(trans_in_out, '("/",A,"/Transpose_output_0")') trim(prefix)
    write(mm_e_out, '("/",A,"/MatMul_output_0")') trim(prefix)
    write(mul_out, '("/",A,"/Mul_output_0")') trim(prefix)
    write(mm_d_out, '("/",A,"/MatMul_1_output_0")') trim(prefix)
    write(mm_w_out, '("/",A,"/MatMul_2_output_0")') trim(prefix)
    write(add_out, '("/",A,"/Add_output_0")') trim(prefix)
    write(add_b_out, '("/",A,"/Add_1_output_0")') trim(prefix)

    !--------------------------------------------------------------------------
    ! Emit ONNX nodes (order and names are part of the exported graph)
    !--------------------------------------------------------------------------
    ! 1. Exp(E_args) -> E [M, n_in]
    call push_node('Exp', 'Exp', [e_args_name], exp_e_out)

    ! 2. Exp(D_args) -> D [n_out, M]
    call push_node('Exp_1', 'Exp', [d_args_name], exp_d_out)

    ! 3. Transpose(input) -> x_t [n_in, batch]
    call emit_nop_input_transpose(trim(prefix), trim(input_name), nodes, &
         num_nodes, trim(trans_in_out))

    ! 4. MatMul(E, x_t) -> encoded [M, batch]
    call push_node('MatMul', 'MatMul', [exp_e_out, trans_in_out], mm_e_out)

    ! 5. Mul(beta, encoded) -> scaled [M, batch]
    call push_node('Mul', 'Mul', [beta_name, mm_e_out], mul_out)

    ! 6. MatMul(D, scaled) -> spectral [n_out, batch]
    call push_node('MatMul_1', 'MatMul', [exp_d_out, mul_out], mm_d_out)

    ! 7. MatMul(W, x_t) -> local [n_out, batch]
    call push_node('MatMul_2', 'MatMul', [w_name, trans_in_out], mm_w_out)

    ! 8. Add(spectral, local) -> combined [n_out, batch]
    call push_node('Add', 'Add', [mm_d_out, mm_w_out], add_out)

    ! 9. Add(combined, bias) -> biased [n_out, batch]; the activation/output
    !    tail reads whichever tensor was produced last.
    if(this%use_bias)then
       call push_node('Add_1', 'Add', [add_out, b_name], add_b_out)
       output_source = add_b_out
    else
       output_source = add_out
    end if
    call emit_nop_output_tail(trim(prefix), trim(this%activation%name), &
         is_last_layer, trim(output_source), nodes, num_nodes, final_output)

    ! The external helpers above do not receive max_nodes, so verify the
    ! limit once they have run.
    if(num_nodes .gt. max_nodes) &
         error stop "emit_onnx_nodes_dynamic_lno: max_nodes exceeded"

    !--------------------------------------------------------------------------
    ! Emit initialisers
    !--------------------------------------------------------------------------

    ! W: bypass weights [n_out, n_in] in row-major
    call emit_matrix_initialiser(trim(w_name), this%params(3)%val(:,1), &
         this%num_outputs, this%num_inputs, inits, num_inits)

    ! Mode parameters mu (params(1), one per mode) and the input/output
    ! grids t_in, t_out, evenly spaced on [0, 1].
    mu = real(this%params(1)%val(1:this%num_modes, 1), real32)
    t_in = unit_grid(this%num_inputs)
    t_out = unit_grid(this%num_outputs)

    ! E_args: -mu*t for encoder [M, n_in] in row-major
    ! (inner implied-do over the input index => it varies fastest)
    e_args = [((-(mu(k) * t_in(j)), j = 1, this%num_inputs), &
         k = 1, this%num_modes)]
    call emit_float_initialiser(trim(e_args_name), e_args, &
         [this%num_modes, this%num_inputs], inits, num_inits)

    ! D_args: -mu*tau for decoder [n_out, M] in row-major
    ! (inner implied-do over the mode index => it varies fastest)
    d_args = [((-(mu(k) * t_out(j)), k = 1, this%num_modes), &
         j = 1, this%num_outputs)]
    call emit_float_initialiser(trim(d_args_name), d_args, &
         [this%num_outputs, this%num_modes], inits, num_inits)

    ! beta: residues [M, 1]
    call emit_float_initialiser(trim(beta_name), this%params(2)%val(:,1), &
         [this%num_modes, 1], inits, num_inits)

    ! b: bias [n_out, 1] (if use_bias)
    if(this%use_bias)then
       call emit_float_initialiser(trim(b_name), this%params(4)%val(:,1), &
            [this%num_outputs, 1], inits, num_inits)
    end if

  contains

    subroutine push_node(suffix, op, node_inputs, node_output)
      !! Append one attribute-free node named "/<prefix>/<suffix>" to the
      !! host's nodes accumulator, stopping if max_nodes would be exceeded.
      character(*), intent(in) :: suffix
      !! Node-name suffix (e.g. "MatMul_1")
      character(*), intent(in) :: op
      !! ONNX op_type
      character(*), dimension(:), intent(in) :: node_inputs
      !! Input tensor names (trailing blanks are trimmed)
      character(*), intent(in) :: node_output
      !! Output tensor name (trailing blanks are trimmed)
      integer :: i

      if(num_nodes .ge. min(max_nodes, size(nodes))) &
           error stop "emit_onnx_nodes_dynamic_lno: max_nodes exceeded"

      num_nodes = num_nodes + 1
      write(nodes(num_nodes)%name, '("/",A,"/",A)') trim(prefix), suffix
      nodes(num_nodes)%op_type = op
      allocate(nodes(num_nodes)%inputs(size(node_inputs)))
      do i = 1, size(node_inputs)
         nodes(num_nodes)%inputs(i) = trim(node_inputs(i))
      end do
      allocate(nodes(num_nodes)%outputs(1))
      nodes(num_nodes)%outputs(1) = trim(node_output)
      nodes(num_nodes)%attributes_json = ''
    end subroutine push_node

    pure function unit_grid(m) result(grid)
      !! Return m evenly spaced points on [0, 1]; a single point sits at 0.
      integer, intent(in) :: m
      !! Number of grid points
      real(real32) :: grid(m)
      integer :: i

      if(m .gt. 1)then
         grid = [(real(i - 1, real32) / real(m - 1, real32), i = 1, m)]
      else
         grid = 0.0_real32
      end if
    end function unit_grid

  end subroutine emit_onnx_nodes_dynamic_lno