build_fixed_lno_onnx_expanded_nop Function

public function build_fixed_lno_onnx_expanded_nop(prefix, nodes, num_nodes, inits, num_inits) result(layer)

Build one fixed LNO layer from an expanded-ONNX node cluster.

Arguments

Type Intent Optional Attributes Name
character(len=*), intent(in) :: prefix

Layer node prefix (e.g. layer2)

type(onnx_node_type), intent(in) :: nodes(:)

Parsed ONNX nodes

integer, intent(in) :: num_nodes

Number of valid node entries

type(onnx_initialiser_type), intent(in) :: inits(:)

Parsed ONNX initialisers

integer, intent(in) :: num_inits

Number of valid initialiser entries

Return Value class(base_layer_type), allocatable

Constructed fixed LNO layer


Source Code

  function build_fixed_lno_onnx_expanded_nop( &
       prefix, nodes, num_nodes, inits, num_inits) result(layer)
    !! Build one fixed LNO layer from an expanded-ONNX node cluster.
    !!
    !! The cluster is located by node-name suffix under `prefix`:
    !!  - `MatMul_1` carries the R initialiser, loaded as a
    !!    (num_modes x num_modes) matrix into params(1);
    !!  - `MatMul_3` carries the W initialiser, loaded as a
    !!    (num_outputs x num_inputs) matrix into params(2);
    !!  - `Add_1` (optional) carries the bias; its presence sets use_bias
    !!    and its data is copied into params(3).
    !! The activation name is recovered via
    !! detect_onnx_expanded_nop_activation.
    !!
    !! Any missing or malformed piece is reported through stop_program.
    !! If stop_program returns instead of terminating, the function returns
    !! early and `layer` is left unallocated.
    use athena__fixed_lno_layer, only: fixed_lno_layer_type
    implicit none

    ! Arguments
    character(*), intent(in) :: prefix
    !! Layer node prefix (e.g. layer2)
    type(onnx_node_type), intent(in) :: nodes(:)
    !! Parsed ONNX nodes
    integer, intent(in) :: num_nodes
    !! Number of valid node entries
    type(onnx_initialiser_type), intent(in) :: inits(:)
    !! Parsed ONNX initialisers
    integer, intent(in) :: num_inits
    !! Number of valid initialiser entries
    class(base_layer_type), allocatable :: layer
    !! Constructed fixed LNO layer

    ! Local variables
    type(fixed_lno_layer_type) :: typed_layer
    !! Concrete layer object before up-casting
    integer :: matmul1_idx, matmul3_idx, add1_idx
    !! Node indices for learnable parameters in the fixed LNO decomposition
    integer :: r_idx, w_idx, b_idx
    !! Initialiser indices used to populate the layer parameters
    integer :: num_inputs, num_outputs, num_modes
    !! Reconstructed layer dimensions
    logical :: use_bias
    !! Whether the graph includes a bias add
    character(64) :: activation_name
    !! Activation reconstructed from the tail of the graph

    !---------------------------------------------------------------------------
    ! Locate the nodes that carry the learnable parameters
    !---------------------------------------------------------------------------
    matmul1_idx = find_onnx_expanded_node_by_suffix(nodes, num_nodes, prefix, &
         'MatMul_1')
    matmul3_idx = find_onnx_expanded_node_by_suffix(nodes, num_nodes, prefix, &
         'MatMul_3')
    ! Add_1 is optional: its absence simply means the layer has no bias
    add1_idx = find_onnx_expanded_node_by_suffix(nodes, num_nodes, prefix, &
         'Add_1')

    if(matmul1_idx .le. 0 .or. matmul3_idx .le. 0)then
       call stop_program('Fixed LNO ONNX cluster is incomplete for ' // &
            trim(prefix))
       ! Guard against stop_program returning: nodes(0) would be out of bounds
       return
    end if

    !---------------------------------------------------------------------------
    ! Resolve the initialisers feeding R and W
    !---------------------------------------------------------------------------
    r_idx = find_node_initialiser_index(nodes(matmul1_idx), inits, num_inits)
    w_idx = find_node_initialiser_index(nodes(matmul3_idx), inits, num_inits)

    if(min(r_idx, w_idx) .le. 0)then
       call stop_program('Fixed LNO ONNX parameters are missing for ' // &
            trim(prefix))
       return
    end if

    ! W's dims(1) and dims(2) are both read below, so W must have rank >= 2;
    ! R's dims(1) is read, so R must have rank >= 1
    if(size(inits(r_idx)%dims) .lt. 1 .or. &
         size(inits(w_idx)%dims) .lt. 2)then
       call stop_program('Fixed LNO ONNX parameter shapes are invalid for ' // &
            trim(prefix))
       return
    end if

    !---------------------------------------------------------------------------
    ! Reconstruct the layer dimensions and build the layer
    !---------------------------------------------------------------------------
    num_modes = inits(r_idx)%dims(1)
    num_outputs = inits(w_idx)%dims(1)
    num_inputs = inits(w_idx)%dims(2)
    use_bias = add1_idx .gt. 0
    activation_name = detect_onnx_expanded_nop_activation( &
         prefix, nodes, num_nodes)

    typed_layer = fixed_lno_layer_type( &
         num_outputs=num_outputs, num_modes=num_modes, &
         num_inputs=num_inputs, use_bias=use_bias, &
         activation=trim(activation_name))

    call load_onnx_expanded_matrix_param( &
         typed_layer%params(1), inits(r_idx), num_modes, num_modes)
    call load_onnx_expanded_matrix_param( &
         typed_layer%params(2), inits(w_idx), num_outputs, num_inputs)

    !---------------------------------------------------------------------------
    ! Optional bias
    !---------------------------------------------------------------------------
    if(use_bias)then
       b_idx = find_node_initialiser_index(nodes(add1_idx), inits, num_inits)
       if(b_idx .le. 0)then
          call stop_program('Fixed LNO bias initialiser missing for ' // &
               trim(prefix))
          return
       end if
       ! A size mismatch here would make the array assignment non-conformant
       ! (undefined behaviour without runtime bounds checking)
       if(size(inits(b_idx)%data) .ne. &
            size(typed_layer%params(3)%val, 1))then
          call stop_program('Fixed LNO bias size mismatch for ' // &
               trim(prefix))
          return
       end if
       typed_layer%params(3)%val(:,1) = inits(b_idx)%data
    end if

    allocate(layer, source=typed_layer)

  end function build_fixed_lno_onnx_expanded_nop