Emit decomposed standard ONNX nodes for a Dynamic LNO layer.
Decomposes the forward pass $v = \sigma\left(D(\mu)\,\mathrm{diag}(\beta)\,E(\mu)\,u + W u + b\right)$ into: Exp, MatMul, Mul, Add, Transpose, and optional Relu nodes.
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| class(dynamic_lno_layer_type) | intent(in) | | | this | Dynamic LNO layer instance |
| character(len=*) | intent(in) | | | prefix | Layer name prefix (e.g. "layer1") |
| type(onnx_node_type) | intent(inout) | | dimension(:) | nodes | Node accumulator |
| integer | intent(inout) | | | num_nodes | Node counter |
| integer | intent(in) | | | max_nodes | Node limit |
| type(onnx_initialiser_type) | intent(inout) | | dimension(:) | inits | Initialiser accumulator |
| integer | intent(inout) | | | num_inits | Initialiser counter |
| integer | intent(in) | | | max_inits | Initialiser limit |
| character(len=*) | intent(in) | optional | | input_name | Name of the input tensor (e.g. "input" or previous layer output) |
| logical | intent(in) | optional | | is_last_layer | Whether this is the last layer in the network |
| integer | intent(in) | optional | | format | Export format selector |
subroutine emit_onnx_nodes_dynamic_lno( &
     this, prefix, nodes, num_nodes, max_nodes, inits, num_inits, &
     max_inits, input_name, is_last_layer, format)
  !! Emit decomposed standard ONNX nodes for a Dynamic LNO layer.
  !!
  !! Decomposes the forward pass v = sigma(D(mu)*diag(beta)*E(mu)*u + W*u + b)
  !! into: Exp, MatMul, Mul, Add, Transpose, and optional Relu nodes.
  !!
  !! Nothing is emitted (the routine returns immediately) unless `format`
  !! is present and equal to 2 and both `input_name` and `is_last_layer`
  !! are supplied.
  !!
  !! Every node appended directly by this routine, and every initialiser
  !! helper call, is preceded by a capacity check against
  !! `max_nodes` / `max_inits` and the extent of the accumulator array;
  !! execution stops with `error stop` if no slot is left. Nodes appended
  !! inside `emit_nop_input_transpose` and `emit_nop_output_tail` are not
  !! pre-checked here because their node count is not visible from this
  !! routine.
  implicit none

  ! Arguments
  class(dynamic_lno_layer_type), intent(in) :: this
  !! Dynamic LNO layer instance
  character(*), intent(in) :: prefix
  !! Layer name prefix (e.g. "layer1")
  type(onnx_node_type), intent(inout), dimension(:) :: nodes
  !! Node accumulator
  integer, intent(inout) :: num_nodes
  !! Node counter
  integer, intent(in) :: max_nodes
  !! Node limit
  type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
  !! Initialiser accumulator
  integer, intent(inout) :: num_inits
  !! Initialiser counter
  integer, intent(in) :: max_inits
  !! Initialiser limit
  character(*), optional, intent(in) :: input_name
  !! Name of the input tensor (e.g. "input" or previous layer output)
  logical, optional, intent(in) :: is_last_layer
  !! Whether this is the last layer in the network
  integer, optional, intent(in) :: format
  !! Export format selector

  ! Local variables
  integer, parameter :: supported_format = 2
  !! The only export format value for which this routine emits anything
  integer :: j, k
  real(real32) :: s
  real(real32), allocatable :: e_args(:), d_args(:)
  character(128) :: e_args_name, d_args_name, beta_name, w_name, b_name
  character(128) :: exp_e_out, exp_d_out, trans_in_out
  character(128) :: mm_e_out, mul_out, mm_d_out, mm_w_out
  character(128) :: add_out, add_b_out, final_output, output_source
  integer :: format_

  format_ = 1
  if(present(format)) format_ = format
  if(format_ .ne. supported_format) return
  if(.not.present(input_name)) return
  if(.not.present(is_last_layer)) return

  !--------------------------------------------------------------------------
  ! Build initialiser names
  !--------------------------------------------------------------------------
  write(e_args_name, '(A,".E_args")') trim(prefix)
  write(d_args_name, '(A,".D_args")') trim(prefix)
  write(beta_name, '(A,".beta")') trim(prefix)
  write(w_name, '(A,".W")') trim(prefix)
  write(b_name, '(A,".b")') trim(prefix)

  !--------------------------------------------------------------------------
  ! Build intermediate tensor names ("/<prefix>/<node>_output_0")
  !--------------------------------------------------------------------------
  exp_e_out    = tensor_name('Exp')
  exp_d_out    = tensor_name('Exp_1')
  trans_in_out = tensor_name('Transpose')
  mm_e_out     = tensor_name('MatMul')
  mul_out      = tensor_name('Mul')
  mm_d_out     = tensor_name('MatMul_1')
  mm_w_out     = tensor_name('MatMul_2')
  add_out      = tensor_name('Add')
  add_b_out    = tensor_name('Add_1')

  !--------------------------------------------------------------------------
  ! Emit ONNX nodes
  !--------------------------------------------------------------------------
  ! 1. Exp(E_args) -> E [M, n_in]
  call append_node('Exp', 'Exp', trim(exp_e_out), trim(e_args_name))

  ! 2. Exp(D_args) -> D [n_out, M]
  call append_node('Exp_1', 'Exp', trim(exp_d_out), trim(d_args_name))

  ! 3. Transpose(input) -> x_t [n_in, batch]
  call emit_nop_input_transpose(trim(prefix), trim(input_name), nodes, &
       num_nodes, trim(trans_in_out))

  ! 4. MatMul(E, x_t) -> encoded [M, batch]
  call append_node('MatMul', 'MatMul', trim(mm_e_out), &
       trim(exp_e_out), trim(trans_in_out))

  ! 5. Mul(beta, encoded) -> scaled [M, batch]
  call append_node('Mul', 'Mul', trim(mul_out), &
       trim(beta_name), trim(mm_e_out))

  ! 6. MatMul(D, scaled) -> spectral [n_out, batch]
  call append_node('MatMul_1', 'MatMul', trim(mm_d_out), &
       trim(exp_d_out), trim(mul_out))

  ! 7. MatMul(W, x_t) -> local [n_out, batch]
  call append_node('MatMul_2', 'MatMul', trim(mm_w_out), &
       trim(w_name), trim(trans_in_out))

  ! 8. Add(spectral, local) -> combined [n_out, batch]
  call append_node('Add', 'Add', trim(add_out), &
       trim(mm_d_out), trim(mm_w_out))

  ! 9. Add(combined, bias) -> biased [n_out, batch] (only if use_bias)
  output_source = add_out
  if(this%use_bias)then
     call append_node('Add_1', 'Add', trim(add_b_out), &
          trim(add_out), trim(b_name))
     output_source = add_b_out
  end if

  ! Hand the pre-activation tensor to the shared output tail; it receives
  ! the activation name and the last-layer flag and reports the final
  ! tensor name back through final_output.
  call emit_nop_output_tail(trim(prefix), trim(this%activation%name), &
       is_last_layer, trim(output_source), nodes, num_nodes, final_output)

  !--------------------------------------------------------------------------
  ! Emit initialisers
  !--------------------------------------------------------------------------
  ! NOTE(review): the capacity checks below assume each emit_*_initialiser
  ! call appends at least one entry to inits - confirm against the helpers.

  ! W: bypass weights [n_out, n_in] in row-major
  call require_init_slot()
  call emit_matrix_initialiser(trim(w_name), this%params(3)%val(:,1), &
       this%num_outputs, this%num_inputs, inits, num_inits)

  ! E_args: -mu*t for encoder [M, n_in] in row-major
  ! (element (k, j) is stored at (k-1)*n_in + j, so j is the fast index)
  allocate(e_args(this%num_modes * this%num_inputs))
  do k = 1, this%num_modes
     s = this%params(1)%val(k, 1)
     do j = 1, this%num_inputs
        e_args((k - 1) * this%num_inputs + j) = &
             -s * unit_grid(j, this%num_inputs)
     end do
  end do
  call require_init_slot()
  call emit_float_initialiser(trim(e_args_name), e_args, &
       [this%num_modes, this%num_inputs], inits, num_inits)
  deallocate(e_args)

  ! D_args: -mu*tau for decoder [n_out, M] in row-major
  ! (element (j, k) is stored at (j-1)*M + k, so k is the fast index)
  allocate(d_args(this%num_outputs * this%num_modes))
  do j = 1, this%num_outputs
     do k = 1, this%num_modes
        s = this%params(1)%val(k, 1)
        d_args((j - 1) * this%num_modes + k) = &
             -s * unit_grid(j, this%num_outputs)
     end do
  end do
  call require_init_slot()
  call emit_float_initialiser(trim(d_args_name), d_args, &
       [this%num_outputs, this%num_modes], inits, num_inits)
  deallocate(d_args)

  ! beta: residues [M, 1]
  call require_init_slot()
  call emit_float_initialiser(trim(beta_name), this%params(2)%val(:,1), &
       [this%num_modes, 1], inits, num_inits)

  ! b: bias [n_out, 1] (if use_bias)
  if(this%use_bias)then
     call require_init_slot()
     call emit_float_initialiser(trim(b_name), this%params(4)%val(:,1), &
          [this%num_outputs, 1], inits, num_inits)
  end if

contains

  function tensor_name(label) result(name)
    !! Build the intermediate tensor name "/<prefix>/<label>_output_0".
    character(*), intent(in) :: label
    !! Node label, e.g. "MatMul_1"
    character(128) :: name

    write(name, '("/",A,"/",A,"_output_0")') trim(prefix), label
  end function tensor_name

  subroutine append_node(label, op_type, out_name, in1, in2)
    !! Append one single-output node named "/<prefix>/<label>" to `nodes`.
    !!
    !! Stops execution if the accumulator has no free slot left.
    character(*), intent(in) :: label
    !! Node label appended to the prefix, e.g. "Exp_1"
    character(*), intent(in) :: op_type
    !! ONNX operator type, e.g. "MatMul"
    character(*), intent(in) :: out_name
    !! Name of the single output tensor
    character(*), intent(in) :: in1
    !! Name of the first input tensor
    character(*), intent(in), optional :: in2
    !! Name of the second input tensor (binary operators only)

    ! Guard the write below: honour both the caller-supplied limit and the
    ! actual extent of the accumulator array.
    if(num_nodes .ge. min(max_nodes, size(nodes)))then
       error stop 'emit_onnx_nodes_dynamic_lno: node limit exceeded'
    end if

    num_nodes = num_nodes + 1
    write(nodes(num_nodes)%name, '("/",A,"/",A)') trim(prefix), label
    nodes(num_nodes)%op_type = op_type
    if(present(in2))then
       allocate(nodes(num_nodes)%inputs(2))
       nodes(num_nodes)%inputs(1) = in1
       nodes(num_nodes)%inputs(2) = in2
    else
       allocate(nodes(num_nodes)%inputs(1))
       nodes(num_nodes)%inputs(1) = in1
    end if
    allocate(nodes(num_nodes)%outputs(1))
    nodes(num_nodes)%outputs(1) = out_name
    nodes(num_nodes)%attributes_json = ''
  end subroutine append_node

  subroutine require_init_slot()
    !! Stop execution unless at least one initialiser slot is still free.
    if(num_inits .ge. min(max_inits, size(inits)))then
       error stop 'emit_onnx_nodes_dynamic_lno: initialiser limit exceeded'
    end if
  end subroutine require_init_slot

  pure function unit_grid(i, npts) result(x)
    !! Coordinate of point `i` on a uniform grid of `npts` points spanning
    !! [0, 1]; a single-point grid sits at 0.
    integer, intent(in) :: i
    !! 1-based grid index
    integer, intent(in) :: npts
    !! Number of grid points
    real(real32) :: x

    if(npts .gt. 1)then
       x = real(i - 1, real32) / real(npts - 1, real32)
    else
       x = 0.0_real32
    end if
  end function unit_grid

end subroutine emit_onnx_nodes_dynamic_lno