Emit decomposed standard ONNX nodes for a Fixed LNO layer.
Forward pass: $v = \sigma(D\,R\,E\,u + W\,u + b)$, where $E$ and $D$ are fixed Laplace bases, $R$ is a learnable mixing matrix, $W$ holds the bypass weights, and $b$ is an optional bias.
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| `class(fixed_lno_layer_type)` | `intent(in)` | | | `this` | Fixed LNO layer instance |
| `character(len=*)` | `intent(in)` | | | `prefix` | Layer name prefix (e.g. "layer1") |
| `type(onnx_node_type)` | `intent(inout)` | | `dimension(:)` | `nodes` | Node accumulator |
| `integer` | `intent(inout)` | | | `num_nodes` | Node counter |
| `integer` | `intent(in)` | | | `max_nodes` | Node limit |
| `type(onnx_initialiser_type)` | `intent(inout)` | | `dimension(:)` | `inits` | Initialiser accumulator |
| `integer` | `intent(inout)` | | | `num_inits` | Initialiser counter |
| `integer` | `intent(in)` | | | `max_inits` | Initialiser limit |
| `character(len=*)` | `intent(in)` | optional | | `input_name` | Name of the input tensor |
| `logical` | `intent(in)` | optional | | `is_last_layer` | Whether this is the last layer |
| `integer` | `intent(in)` | optional | | `format` | Export format selector |
subroutine emit_onnx_nodes_fixed_lno( &
     this, prefix, nodes, num_nodes, max_nodes, inits, num_inits, &
     max_inits, input_name, is_last_layer, format)
  !! Emit decomposed standard ONNX nodes for a Fixed LNO layer.
  !!
  !! Forward: v = sigma(D * R * E * u + W * u + b)
  !! where E and D are fixed Laplace bases, R is a learnable mixing matrix,
  !! W holds the bypass weights and b is an optional bias.
  !!
  !! Emitted nodes (all named "/<prefix>/..."):
  !!   Transpose(input)                 -> x_t   (via emit_nop_input_transpose)
  !!   MatMul(E, x_t), MatMul(R, .), MatMul(D, .)   spectral path
  !!   MatMul(W, x_t)                               bypass path
  !!   Add(spectral, bypass), then Add(., b) only if this%use_bias
  !!   output tail                                  (via emit_nop_output_tail)
  !! Initialisers appended: E, D, R, W and, if this%use_bias, b.
  !!
  !! Nothing is emitted unless format == 2 and both input_name and
  !! is_last_layer are present; otherwise the routine returns immediately.
  !!
  !! Every node/initialiser appended directly by this routine is checked
  !! against max_nodes/max_inits and the accumulator sizes; running out of
  !! room triggers `error stop` instead of an out-of-bounds write.
  !! NOTE(review): nodes appended inside emit_nop_input_transpose and
  !! emit_nop_output_tail are not checked here.
  use coreutils, only: pi
  implicit none

  ! Arguments
  class(fixed_lno_layer_type), intent(in) :: this
  !! Fixed LNO layer instance
  character(*), intent(in) :: prefix
  !! Layer name prefix (e.g. "layer1")
  type(onnx_node_type), intent(inout), dimension(:) :: nodes
  !! Node accumulator
  integer, intent(inout) :: num_nodes
  !! Node counter
  integer, intent(in) :: max_nodes
  !! Node limit
  type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
  !! Initialiser accumulator
  integer, intent(inout) :: num_inits
  !! Initialiser counter
  integer, intent(in) :: max_inits
  !! Initialiser limit
  character(*), optional, intent(in) :: input_name
  !! Name of the input tensor
  logical, optional, intent(in) :: is_last_layer
  !! Whether this is the last layer
  integer, optional, intent(in) :: format
  !! Export format selector

  ! Local variables
  integer :: j, k
  integer :: format_
  real(real32), allocatable :: e_data(:), d_data(:)
  character(128) :: e_name, d_name, r_name, w_name, b_name
  character(128) :: trans_in_out, mm_e_out, mm_r_out, mm_d_out
  character(128) :: mm_w_out, add_out, add_b_out, final_output, &
       output_source

  format_ = 1
  if(present(format)) format_ = format
  ! Only format 2 is handled by this routine.
  if(format_ .ne. 2) return
  if(.not.present(input_name)) return
  if(.not.present(is_last_layer)) return

  !--------------------------------------------------------------------------
  ! Build names
  !--------------------------------------------------------------------------
  write(e_name, '(A,".E")') trim(prefix)
  write(d_name, '(A,".D")') trim(prefix)
  write(r_name, '(A,".R")') trim(prefix)
  write(w_name, '(A,".W")') trim(prefix)
  write(b_name, '(A,".b")') trim(prefix)
  write(trans_in_out, '("/",A,"/Transpose_output_0")') trim(prefix)
  write(mm_e_out, '("/",A,"/MatMul_output_0")') trim(prefix)
  write(mm_r_out, '("/",A,"/MatMul_1_output_0")') trim(prefix)
  write(mm_d_out, '("/",A,"/MatMul_2_output_0")') trim(prefix)
  write(mm_w_out, '("/",A,"/MatMul_3_output_0")') trim(prefix)
  write(add_out, '("/",A,"/Add_output_0")') trim(prefix)
  write(add_b_out, '("/",A,"/Add_1_output_0")') trim(prefix)

  !--------------------------------------------------------------------------
  ! Emit nodes
  !--------------------------------------------------------------------------
  ! 1. Transpose(input); the result feeds both the spectral and bypass paths.
  call emit_nop_input_transpose(trim(prefix), trim(input_name), nodes, &
       num_nodes, trim(trans_in_out))

  ! 2-4. Spectral path: D * (R * (E * x_t))
  call append_binary_node('MatMul', 'MatMul', &
       trim(e_name), trim(trans_in_out), trim(mm_e_out))
  call append_binary_node('MatMul_1', 'MatMul', &
       trim(r_name), trim(mm_e_out), trim(mm_r_out))
  call append_binary_node('MatMul_2', 'MatMul', &
       trim(d_name), trim(mm_r_out), trim(mm_d_out))

  ! 5. Bypass path: W * x_t
  call append_binary_node('MatMul_3', 'MatMul', &
       trim(w_name), trim(trans_in_out), trim(mm_w_out))

  ! 6. Add(spectral, bypass)
  call append_binary_node('Add', 'Add', &
       trim(mm_d_out), trim(mm_w_out), trim(add_out))

  ! 7. Optional Add(combined, bias); the tail consumes whichever Add is last.
  if(this%use_bias)then
     call append_binary_node('Add_1', 'Add', &
          trim(add_out), trim(b_name), trim(add_b_out))
     output_source = add_b_out
  else
     output_source = add_out
  end if

  call emit_nop_output_tail(trim(prefix), trim(this%activation%name), &
       is_last_layer, trim(output_source), nodes, num_nodes, final_output)

  !--------------------------------------------------------------------------
  ! Emit initialisers
  !--------------------------------------------------------------------------
  ! E: fixed encoder basis [M, n_in] in row-major order.
  ! The inner implied-do runs over j, so element (k, j) lands at
  ! (k - 1) * n_in + j.
  e_data = [((laplace_basis_entry(k, j, this%num_inputs), &
       j = 1, this%num_inputs), k = 1, this%num_modes)]
  call require_init_slot()
  call emit_float_initialiser(trim(e_name), e_data, &
       [this%num_modes, this%num_inputs], inits, num_inits)
  deallocate(e_data)

  ! D: fixed decoder basis [n_out, M] in row-major order.
  ! The inner implied-do runs over k, so element (j, k) lands at
  ! (j - 1) * M + k.
  d_data = [((laplace_basis_entry(k, j, this%num_outputs), &
       k = 1, this%num_modes), j = 1, this%num_outputs)]
  call require_init_slot()
  call emit_float_initialiser(trim(d_name), d_data, &
       [this%num_outputs, this%num_modes], inits, num_inits)
  deallocate(d_data)

  ! R: spectral mixing [M, M] in row-major order.
  call require_init_slot()
  call emit_matrix_initialiser(trim(r_name), this%params(1)%val(:,1), &
       this%num_modes, this%num_modes, inits, num_inits)

  ! W: bypass weights [n_out, n_in] in row-major order.
  call require_init_slot()
  call emit_matrix_initialiser(trim(w_name), this%params(2)%val(:,1), &
       this%num_outputs, this%num_inputs, inits, num_inits)

  ! b: bias [n_out, 1]
  if(this%use_bias)then
     call require_init_slot()
     call emit_float_initialiser(trim(b_name), this%params(3)%val(:,1), &
          [this%num_outputs, 1], inits, num_inits)
  end if

contains

  subroutine append_binary_node(node_suffix, op_type, lhs, rhs, output)
    !! Append one node named "/<prefix>/<node_suffix>" with two inputs and
    !! one output to the host's `nodes`, advancing the host's `num_nodes`.
    !! Stops if that would exceed max_nodes or the size of `nodes`.
    character(*), intent(in) :: node_suffix
    !! Name suffix, e.g. "MatMul_1"
    character(*), intent(in) :: op_type
    !! ONNX operator type, e.g. "MatMul" or "Add"
    character(*), intent(in) :: lhs, rhs
    !! Input tensor names, in operator argument order
    character(*), intent(in) :: output
    !! Output tensor name

    if(num_nodes .ge. min(max_nodes, size(nodes)))then
       error stop "emit_onnx_nodes_fixed_lno: node limit exceeded"
    end if
    num_nodes = num_nodes + 1
    write(nodes(num_nodes)%name, '("/",A,"/",A)') trim(prefix), node_suffix
    nodes(num_nodes)%op_type = op_type
    allocate(nodes(num_nodes)%inputs(2))
    nodes(num_nodes)%inputs(1) = lhs
    nodes(num_nodes)%inputs(2) = rhs
    allocate(nodes(num_nodes)%outputs(1))
    nodes(num_nodes)%outputs(1) = output
    nodes(num_nodes)%attributes_json = ''
  end subroutine append_binary_node

  subroutine require_init_slot()
    !! Stop before appending an initialiser when the host's `num_inits`
    !! has already reached max_inits or the size of `inits`.
    if(num_inits .ge. min(max_inits, size(inits)))then
       error stop "emit_onnx_nodes_fixed_lno: initialiser limit exceeded"
    end if
  end subroutine require_init_slot

  pure function laplace_basis_entry(mode, point, num_points) result(val)
    !! Basis value exp(-(mode * pi) * t). Here t is the point-th of
    !! num_points evenly spaced samples on [0, 1], with t = 0 when
    !! num_points <= 1.
    integer, intent(in) :: mode
    !! 1-based mode index k
    integer, intent(in) :: point
    !! 1-based sample index j
    integer, intent(in) :: num_points
    !! Number of samples along the basis axis
    real(real32) :: val
    real(real32) :: s, t

    if(num_points .gt. 1)then
       t = real(point - 1, real32) / real(num_points - 1, real32)
    else
       t = 0.0_real32
    end if
    s = real(mode, real32) * pi
    val = exp(-s * t)
  end function laplace_basis_entry

end subroutine emit_onnx_nodes_fixed_lno