module athena__onnx_nop_utils !! Shared utility routines for NOP ONNX export/import. use coreutils, only: real32, stop_program use athena__base_layer, only: base_layer_type use athena__misc_types, only: onnx_attribute_type, onnx_node_type, & onnx_initialiser_type use athena__onnx_utils, only: emit_node, col_to_row_major_2d, & row_to_col_major_2d use diffstruc, only: array_type implicit none private public :: emit_nop_input_transpose public :: emit_nop_output_tail public :: emit_float_initialiser public :: emit_matrix_initialiser public :: emit_nop_metadata public :: parse_nop_metadata public :: extract_nop_prefix public :: load_nop_param_from_inits public :: find_initialiser_by_name public :: infer_dynamic_lno_poles public :: find_onnx_expanded_node_by_suffix public :: find_node_initialiser_index public :: detect_onnx_expanded_nop_activation public :: load_onnx_expanded_matrix_param contains !############################################################################### subroutine emit_nop_input_transpose(prefix, input_name, nodes, num_nodes, & output_name) !! Emit the common NOP input transpose. implicit none character(*), intent(in) :: prefix, input_name type(onnx_node_type), intent(inout) :: nodes(:) integer, intent(inout) :: num_nodes character(*), intent(in) :: output_name character(4096) :: perm_attr perm_attr = ' "attribute": [{"name": "perm", "ints": ' // & '["1", "0"], "type": "INTS"}]' call emit_node('Transpose', '/' // trim(prefix) // '/Transpose', & trim(output_name), trim(perm_attr), nodes, num_nodes, & in1=trim(input_name)) end subroutine emit_nop_input_transpose !############################################################################### !############################################################################### subroutine emit_nop_output_tail(prefix, activation_name, is_last_layer, & input_name, nodes, num_nodes, final_output) !! Emit the common transpose and optional activation at the end of a NOP. use coreutils, only: to_camel_case implicit none character(*), intent(in) :: prefix, activation_name, input_name logical, intent(in) :: is_last_layer type(onnx_node_type), intent(inout) :: nodes(:) integer, intent(inout) :: num_nodes character(128), intent(out) :: final_output character(4096) :: perm_attr character(128) :: transpose_output character(128) :: activation_op, activation_node character(4096) :: activation_attr perm_attr = ' "attribute": [{"name": "perm", "ints": ' // & '["1", "0"], "type": "INTS"}]' if(is_last_layer .and. trim(activation_name) .eq. 'none')then transpose_output = 'output' else write(transpose_output, '("/",A,"/Transpose_1_output_0")') & trim(prefix) end if call emit_node('Transpose', '/' // trim(prefix) // '/Transpose_1', & trim(transpose_output), trim(perm_attr), nodes, num_nodes, & in1=trim(input_name)) if(trim(activation_name) .ne. 'none')then activation_op = to_camel_case( & trim(adjustl(activation_name)), & capitalise_first_letter = .true.) activation_attr = '' if(trim(activation_name) .eq. 'leaky_relu')then activation_op = 'LeakyRelu' activation_attr = ' "attribute": [{"name": "alpha", ' // & '"f": 0.01, "type": "FLOAT"}]' end if if(is_last_layer)then final_output = 'output' else write(final_output, '("/",A,"/",A,"_output_0")') & trim(prefix), trim(activation_op) end if activation_node = '/' // trim(prefix) // '/' // trim(activation_op) call emit_node(trim(activation_op), trim(activation_node), & trim(final_output), trim(activation_attr), nodes, num_nodes, & in1=trim(transpose_output)) else final_output = transpose_output end if end subroutine emit_nop_output_tail !############################################################################### !############################################################################### subroutine emit_float_initialiser(name, data, dims, inits, num_inits) !! Emit a float32 initialiser with explicit dimensions. implicit none character(*), intent(in) :: name real(real32), intent(in) :: data(:) integer, intent(in) :: dims(:) type(onnx_initialiser_type), intent(inout) :: inits(:) integer, intent(inout) :: num_inits num_inits = num_inits + 1 inits(num_inits)%name = trim(name) inits(num_inits)%data_type = 1 allocate(inits(num_inits)%dims(size(dims))) inits(num_inits)%dims = dims allocate(inits(num_inits)%data(size(data))) inits(num_inits)%data = data end subroutine emit_float_initialiser !############################################################################### !############################################################################### subroutine emit_matrix_initialiser(name, data_col_major, rows, cols, inits, & num_inits) !! Emit a 2D float32 initialiser after converting to row-major order. implicit none character(*), intent(in) :: name real(real32), intent(in) :: data_col_major(:) integer, intent(in) :: rows, cols type(onnx_initialiser_type), intent(inout) :: inits(:) integer, intent(inout) :: num_inits real(real32), allocatable :: row_major(:) allocate(row_major(size(data_col_major))) call col_to_row_major_2d(data_col_major, row_major, rows, cols) call emit_float_initialiser(name, row_major, [rows, cols], inits, & num_inits) deallocate(row_major) end subroutine emit_matrix_initialiser !############################################################################### !############################################################################### subroutine emit_nop_metadata(layer, prefix, metadata, num_meta) !! Build the metadata entry required to reconstruct a NOP layer. implicit none class(base_layer_type), intent(in) :: layer character(*), intent(in) :: prefix character(4096), intent(inout) :: metadata(:) integer, intent(inout) :: num_meta type(onnx_attribute_type), allocatable :: attrs(:) integer :: i character(2048) :: value_str attrs = layer%get_attributes() if(.not.allocated(attrs)) return if(size(attrs) .eq. 0) return value_str = 'subtype=' // trim(adjustl(layer%name)) do i = 1, size(attrs) value_str = trim(value_str) // ';' // trim(attrs(i)%name) // '=' // & trim(adjustl(attrs(i)%val)) end do num_meta = num_meta + 1 write(metadata(num_meta), '(A)') & ' {"key": "athena_nop_' // trim(prefix) // & '", "value": "' // trim(value_str) // '"}' end subroutine emit_nop_metadata !############################################################################### !############################################################################### subroutine parse_nop_metadata(meta_value, & num_inputs, num_outputs, num_modes, use_bias, activation_name) !! Parse common NOP hyperparameters from metadata value string. implicit none character(*), intent(in) :: meta_value integer, intent(inout) :: num_inputs, num_outputs, num_modes logical, intent(inout) :: use_bias character(64), intent(inout) :: activation_name integer :: k, pos, pos2, stat character(256) :: token, key, val logical :: logical_val pos = 1 do while(pos .le. len_trim(meta_value)) pos2 = index(meta_value(pos:), ';') if(pos2 .eq. 0)then token = meta_value(pos:len_trim(meta_value)) pos = len_trim(meta_value) + 1 else token = meta_value(pos:pos+pos2-2) pos = pos + pos2 end if k = index(token, '=') if(k .eq. 0) cycle key = trim(adjustl(token(1:k-1))) val = trim(adjustl(token(k+1:))) select case(trim(key)) case('num_inputs') read(val, *) num_inputs case('num_outputs') read(val, *) num_outputs case('num_modes', 'num_basis') read(val, *) num_modes case('use_bias') read(val, *, iostat=stat) logical_val if(stat .eq. 0)then use_bias = logical_val else select case(trim(adjustl(val))) case('1', 'T', 't', 'true', 'TRUE', 'True') use_bias = .true. case('0', 'F', 'f', 'false', 'FALSE', 'False') use_bias = .false. case default call stop_program('parse_nop_metadata: invalid use_bias value') end select end if case('activation') activation_name = trim(val) end select end do end subroutine parse_nop_metadata !############################################################################### !############################################################################### function extract_nop_prefix(meta_key) result(prefix) !! Extract the node prefix from an athena_nop_node_X metadata key. implicit none character(*), intent(in) :: meta_key character(64) :: prefix integer :: pos prefix = trim(meta_key) pos = index(prefix, 'athena_nop_') if(pos .gt. 0) prefix = prefix(pos+11:) end function extract_nop_prefix !############################################################################### !############################################################################### subroutine load_nop_param_from_inits( & param, prefix, suffix, inits, num_inits, dims) !! Load a parameter from ONNX initialisers into a diffstruc array. implicit none type(array_type), intent(inout) :: param character(*), intent(in) :: prefix, suffix type(onnx_initialiser_type), intent(in) :: inits(:) integer, intent(in) :: num_inits integer, intent(in) :: dims(2) integer :: k character(128) :: target_name real(real32), allocatable :: col_data(:) write(target_name, '(A,A)') trim(prefix), suffix do k = 1, num_inits if(trim(inits(k)%name) .ne. trim(target_name)) cycle if(.not.allocated(inits(k)%data)) cycle if(dims(2) .gt. 1)then allocate(col_data(size(inits(k)%data))) call row_to_col_major_2d(inits(k)%data, col_data, dims(1), dims(2)) param%val(:,1) = col_data deallocate(col_data) else param%val(:,1) = inits(k)%data end if return end do end subroutine load_nop_param_from_inits !############################################################################### !############################################################################### integer function find_initialiser_by_name(name, inits, num_inits) !! Return the index of a named initialiser, or zero when not found. implicit none character(*), intent(in) :: name type(onnx_initialiser_type), intent(in) :: inits(:) integer, intent(in) :: num_inits integer :: i find_initialiser_by_name = 0 do i = 1, num_inits if(trim(inits(i)%name) .eq. trim(name))then find_initialiser_by_name = i return end if end do end function find_initialiser_by_name !############################################################################### !############################################################################### subroutine infer_dynamic_lno_poles(e_args_init, d_args_init, num_inputs, & num_outputs, poles) !! Reconstruct dynamic LNO poles from exported encoder/decoder arguments. implicit none type(onnx_initialiser_type), intent(in) :: e_args_init, d_args_init integer, intent(in) :: num_inputs, num_outputs real(real32), intent(out) :: poles(:) integer :: k, idx, num_modes real(real32) :: pi_value num_modes = size(poles) if(num_inputs .gt. 1 .and. allocated(e_args_init%data))then do k = 1, num_modes idx = (k - 1) * num_inputs + num_inputs poles(k) = -e_args_init%data(idx) end do return end if if(num_outputs .gt. 1 .and. allocated(d_args_init%data))then do k = 1, num_modes idx = (num_outputs - 1) * num_modes + k poles(k) = -d_args_init%data(idx) end do return end if pi_value = acos(-1.0_real32) do k = 1, num_modes poles(k) = real(k, real32) * pi_value end do end subroutine infer_dynamic_lno_poles !############################################################################### !############################################################################### integer function find_onnx_expanded_node_by_suffix( & nodes, num_nodes, prefix, suffix) !! Return the node index matching one /layerN/suffix name, or zero. implicit none ! Arguments type(onnx_node_type), intent(in) :: nodes(:) !! Parsed ONNX nodes integer, intent(in) :: num_nodes !! Number of valid node entries character(*), intent(in) :: prefix, suffix !! Layer prefix and trailing node name token ! Local variables integer :: i !! Loop index character(128) :: target_name !! Full node name to match write(target_name, '("/",A,"/",A)') trim(prefix), trim(suffix) find_onnx_expanded_node_by_suffix = 0 do i = 1, num_nodes if(trim(nodes(i)%name) .eq. trim(target_name))then find_onnx_expanded_node_by_suffix = i return end if end do end function find_onnx_expanded_node_by_suffix !############################################################################### !############################################################################### integer function find_node_initialiser_index(node, inits, num_inits) !! Return the first initialiser referenced by a node's inputs. implicit none ! Arguments type(onnx_node_type), intent(in) :: node !! Parsed ONNX node whose inputs may reference an initialiser type(onnx_initialiser_type), intent(in) :: inits(:) !! Parsed ONNX initialisers integer, intent(in) :: num_inits !! Number of valid initialiser entries ! Local variables integer :: i, init_idx !! Loop index and candidate initialiser index find_node_initialiser_index = 0 if(.not.allocated(node%inputs)) return do i = 1, size(node%inputs) init_idx = find_initialiser_by_name(node%inputs(i), inits, num_inits) if(init_idx .gt. 0)then find_node_initialiser_index = init_idx return end if end do end function find_node_initialiser_index !############################################################################### !############################################################################### function detect_onnx_expanded_nop_activation(prefix, nodes, num_nodes) & result(name) !! Reconstruct the activation name from the tail of an expanded-ONNX NOP !! cluster. use athena__onnx_utils, only: onnx_to_athena_activation implicit none ! Arguments character(*), intent(in) :: prefix !! Layer node prefix without leading slash type(onnx_node_type), intent(in) :: nodes(:) !! Parsed ONNX nodes integer, intent(in) :: num_nodes !! Number of valid node entries character(64) :: name !! Reconstructed ATHENA activation name integer :: i character(128) :: cluster_prefix write(cluster_prefix, '("/",A,"/")') trim(prefix) name = 'none' do i = 1, num_nodes if(index(trim(nodes(i)%name), trim(cluster_prefix)) .ne. 1) cycle select case(trim(nodes(i)%op_type)) case('Relu', 'LeakyRelu', 'Sigmoid', 'Softmax', 'Tanh', 'Selu', & 'Swish') name = onnx_to_athena_activation(trim(nodes(i)%op_type)) return end select end do end function detect_onnx_expanded_nop_activation !############################################################################### !############################################################################### subroutine load_onnx_expanded_matrix_param(param, init, rows, cols) !! Copy a row-major ONNX matrix initialiser into a diffstruc parameter. implicit none ! Arguments type(array_type), intent(inout) :: param !! Destination diffstruc parameter tensor type(onnx_initialiser_type), intent(in) :: init !! Row-major ONNX initialiser data integer, intent(in) :: rows, cols !! Matrix shape ! Local variables real(real32), allocatable :: col_major(:) !! Temporary column-major buffer for ATHENA internal storage allocate(col_major(rows * cols)) call row_to_col_major_2d(init%data, col_major, rows, cols) param%val(:,1) = col_major deallocate(col_major) end subroutine load_onnx_expanded_matrix_param !############################################################################### end module athena__onnx_nop_utils