build_dynamic_lno_onnx_expanded_nop Function

public function build_dynamic_lno_onnx_expanded_nop(prefix, nodes, num_nodes, inits, num_inits) result(layer)

Build one dynamic LNO layer from an expanded-ONNX node cluster.

Arguments

Type | Intent | Optional Attributes | Name
character(len=*), intent(in) :: prefix

Layer node prefix (e.g. layer1)

type(onnx_node_type), intent(in) :: nodes(:)

Parsed ONNX nodes

integer, intent(in) :: num_nodes

Number of valid node entries

type(onnx_initialiser_type), intent(in) :: inits(:)

Parsed ONNX initialisers

integer, intent(in) :: num_inits

Number of valid initialiser entries

Return Value class(base_layer_type), allocatable

Constructed dynamic LNO layer


Source Code

  function build_dynamic_lno_onnx_expanded_nop( &
       prefix, nodes, num_nodes, inits, num_inits) result(layer)
    !! Build one dynamic LNO layer from an expanded-ONNX node cluster.
    !!
    !! The cluster is located by `prefix` and must contain the nodes
    !! `Exp`, `Exp_1`, `Mul` and `MatMul_2`. An `Add_1` node is optional
    !! and, when present, marks a bias add. The initialiser referenced by
    !! each node populates the layer parameters:
    !!
    !! * params(1): poles inferred from the `Exp` and `Exp_1` initialisers
    !! * params(2): the `Mul` initialiser (its first dim gives num_modes)
    !! * params(3): the `MatMul_2` initialiser, whose first two dims give
    !!   (num_outputs, num_inputs)
    !! * params(4): the `Add_1` initialiser (only when a bias is present)
    !!
    !! Every failure path calls `stop_program` and then returns. This means
    !! no out-of-range node or initialiser index is dereferenced even if
    !! `stop_program` does not terminate; in that case `layer` is left
    !! unallocated.
    use athena__dynamic_lno_layer, only: dynamic_lno_layer_type
    use athena__onnx_nop_utils, only: infer_dynamic_lno_poles
    implicit none

    ! Arguments
    character(*), intent(in) :: prefix
    !! Layer node prefix (e.g. layer1)
    type(onnx_node_type), intent(in) :: nodes(:)
    !! Parsed ONNX nodes
    integer, intent(in) :: num_nodes
    !! Number of valid node entries
    type(onnx_initialiser_type), intent(in) :: inits(:)
    !! Parsed ONNX initialisers
    integer, intent(in) :: num_inits
    !! Number of valid initialiser entries
    class(base_layer_type), allocatable :: layer
    !! Constructed dynamic LNO layer

    ! Local variables
    type(dynamic_lno_layer_type) :: typed_layer
    !! Concrete layer object before up-casting
    integer :: exp_idx, exp1_idx, mul_idx, matmul2_idx, add1_idx
    !! Node indices for the dynamic LNO decomposition
    integer :: e_idx, d_idx, beta_idx, w_idx, b_idx
    !! Initialiser indices used to populate the layer parameters
    integer :: num_inputs, num_outputs, num_modes
    !! Reconstructed layer dimensions
    logical :: use_bias
    !! Whether the graph includes a bias add
    character(64) :: activation_name
    !! Activation reconstructed from the tail of the graph
    real(real32), allocatable :: poles(:)
    !! Dynamic poles reconstructed from exported encoder/decoder arguments

    exp_idx = find_onnx_expanded_node_by_suffix(nodes, num_nodes, prefix, &
         'Exp')
    exp1_idx = find_onnx_expanded_node_by_suffix(nodes, num_nodes, prefix, &
         'Exp_1')
    mul_idx = find_onnx_expanded_node_by_suffix(nodes, num_nodes, prefix, &
         'Mul')
    matmul2_idx = find_onnx_expanded_node_by_suffix(nodes, num_nodes, prefix, &
         'MatMul_2')
    add1_idx = find_onnx_expanded_node_by_suffix(nodes, num_nodes, prefix, &
         'Add_1')

    if(exp_idx .le. 0 .or. exp1_idx .le. 0 .or. mul_idx .le. 0 .or. &
         matmul2_idx .le. 0)then
       call stop_program('Dynamic LNO ONNX cluster is incomplete for ' // &
            trim(prefix))
       ! Do not fall through if stop_program returns: the lookups below
       ! would index nodes() with a non-positive index.
       return
    end if

    e_idx = find_node_initialiser_index(nodes(exp_idx), inits, num_inits)
    d_idx = find_node_initialiser_index(nodes(exp1_idx), inits, num_inits)
    beta_idx = find_node_initialiser_index(nodes(mul_idx), inits, num_inits)
    w_idx = find_node_initialiser_index(nodes(matmul2_idx), inits, num_inits)

    if(min(e_idx, d_idx, beta_idx, w_idx) .le. 0)then
       call stop_program('Dynamic LNO ONNX parameters are missing for ' // &
            trim(prefix))
       return
    end if

    ! The Mul initialiser provides the mode count. The MatMul initialiser
    ! must be at least 2-D to provide both output and input widths.
    if(size(inits(beta_idx)%dims) .lt. 1 .or. &
         size(inits(w_idx)%dims) .lt. 2)then
       call stop_program('Dynamic LNO ONNX parameter ranks are invalid ' // &
            'for ' // trim(prefix))
       return
    end if

    num_modes = inits(beta_idx)%dims(1)
    num_outputs = inits(w_idx)%dims(1)
    num_inputs = inits(w_idx)%dims(2)
    use_bias = add1_idx .gt. 0
    activation_name = detect_onnx_expanded_nop_activation( &
         prefix, nodes, num_nodes)

    typed_layer = dynamic_lno_layer_type( &
         num_outputs=num_outputs, num_modes=num_modes, &
         num_inputs=num_inputs, use_bias=use_bias, &
         activation=trim(activation_name))

    ! The whole-array copies below require conforming sizes. Check them
    ! explicitly rather than relying on runtime bounds checking.
    if(size(inits(beta_idx)%data) .ne. &
         size(typed_layer%params(2)%val, 1))then
       call stop_program('Dynamic LNO ONNX Mul initialiser size mismatch ' // &
            'for ' // trim(prefix))
       return
    end if

    allocate(poles(num_modes))
    call infer_dynamic_lno_poles( &
         inits(e_idx), inits(d_idx), num_inputs, num_outputs, poles)
    typed_layer%params(1)%val(:,1) = poles
    typed_layer%params(2)%val(:,1) = inits(beta_idx)%data
    call load_onnx_expanded_matrix_param( &
         typed_layer%params(3), inits(w_idx), num_outputs, num_inputs)

    if(use_bias)then
       b_idx = find_node_initialiser_index(nodes(add1_idx), inits, num_inits)
       if(b_idx .le. 0)then
          call stop_program('Dynamic LNO bias initialiser missing for ' // &
               trim(prefix))
          return
       end if
       if(size(inits(b_idx)%data) .ne. &
            size(typed_layer%params(4)%val, 1))then
          call stop_program('Dynamic LNO ONNX bias initialiser size ' // &
               'mismatch for ' // trim(prefix))
          return
       end if
       typed_layer%params(4)%val(:,1) = inits(b_idx)%data
    end if

    allocate(layer, source=typed_layer)

  end function build_dynamic_lno_onnx_expanded_nop