add_standard_layer_from_onnx Subroutine

subroutine add_standard_layer_from_onnx(network, layer_id, node_index, nodes, num_nodes, inits, num_inits, verbose_)

Create standard (non-GNN) layers for a given layer_id using the registered ONNX creator framework (list_of_onnx_layer_creators).

Processes the primary node and any trailing activation node.

Arguments

Type | Intent | Optional | Attributes | Name
type(network_type), intent(inout) :: network

Network receiving the created layer(s)

integer, intent(in) :: layer_id

Layer id for which the standard layer(s) are created

integer, intent(in) :: node_index

Index of the primary node within `nodes`

type(onnx_node_type), intent(in) :: nodes(:)

Parsed ONNX nodes

integer, intent(in) :: num_nodes

Node count

type(onnx_initialiser_type), intent(in) :: inits(:)

Parsed ONNX initialisers

integer, intent(in) :: num_inits

Initialiser count

integer, intent(in) :: verbose_

Effective verbosity level


Source Code

  subroutine add_standard_layer_from_onnx( &
       network, layer_id, node_index, nodes, num_nodes, &
       inits, num_inits, verbose_)
    !! Build the standard (non-GNN) layer(s) belonging to one layer_id by
    !! dispatching to the registered ONNX creators
    !! (list_of_onnx_layer_creators).
    !!
    !! The primary node is converted first. If an activation node is then
    !! found for the same layer_id and a creator is registered for its
    !! op_type, that node is converted and appended as well. When no
    !! creator is registered for the primary node's op_type the routine
    !! returns without adding anything (optionally printing a notice).
    use athena__base_layer, only: base_layer_type
    use athena__container_layer, only: list_of_onnx_layer_creators
    use athena__onnx_utils, only: row_to_col_major_2d
    implicit none

    ! Arguments
    type(network_type), intent(inout) :: network
    !! Network receiving the created layer(s)
    integer, intent(in) :: layer_id
    !! Layer id used to look up the trailing activation node
    integer, intent(in) :: node_index
    !! Index of the primary node within nodes
    integer, intent(in) :: num_nodes
    !! Node count forwarded to the activation-node search
    integer, intent(in) :: num_inits
    !! Number of leading entries of inits that are searched
    integer, intent(in) :: verbose_
    !! Effective verbosity level
    type(onnx_node_type), intent(in) :: nodes(:)
    !! Parsed ONNX nodes
    type(onnx_initialiser_type), intent(in) :: inits(:)
    !! Parsed ONNX initialisers

    ! Local variables
    integer :: i_input, i_init, i_hint
    !! Loop indices over node inputs / initialisers, and the index of the
    !! matched initialiser used to derive the shape hint (0 if none)
    integer :: creator_index, activation_index
    !! Position in list_of_onnx_layer_creators and activation node index
    integer :: num_matched, hint_rank
    !! Running count of matched initialisers and rank of the hint source
    logical :: is_matrix
    !! True when the initialiser being copied has exactly two dims
    character(128) :: op_name, output_name
    !! Current ONNX op_type and first output tensor name
    type(onnx_initialiser_type), allocatable :: matched_inits(:)
    !! Copies of the initialisers whose names appear among the node inputs
    type(onnx_tensor_type), allocatable :: shape_hints(:)
    !! Synthetic output shape hint (zero or one entry) passed to creator
    class(base_layer_type), allocatable :: layer
    !! Created ATHENA layer instance

    !---------------------------------------------------------------------
    ! Locate the creator registered for the primary node's op_type
    !---------------------------------------------------------------------
    op_name = trim(adjustl(nodes(node_index)%op_type))
    creator_index = findloc( &
         [ list_of_onnx_layer_creators(:)%op_type ], &
         trim(op_name), dim = 1)

    if(creator_index == 0)then
       if(verbose_ > 0)then
          write(*,*) 'Skipping unsupported ONNX node in GNN import: ', &
               trim(nodes(node_index)%name), ' op=', trim(op_name)
       end if
       return
    end if

    !---------------------------------------------------------------------
    ! Pass 1: count (input, initialiser) name matches to size the list
    !---------------------------------------------------------------------
    num_matched = 0
    if(allocated(nodes(node_index)%inputs))then
       do i_input = 1, size(nodes(node_index)%inputs)
          do i_init = 1, num_inits
             if(trim(nodes(node_index)%inputs(i_input)) /= &
                  trim(inits(i_init)%name)) cycle
             num_matched = num_matched + 1
          end do
       end do
    end if
    allocate(matched_inits(num_matched))

    !---------------------------------------------------------------------
    ! Pass 2: copy each matched initialiser, in the same traversal order
    !---------------------------------------------------------------------
    num_matched = 0
    if(allocated(nodes(node_index)%inputs))then
       do i_input = 1, size(nodes(node_index)%inputs)
          do i_init = 1, num_inits
             if(trim(nodes(node_index)%inputs(i_input)) == &
                  trim(inits(i_init)%name))then

                num_matched = num_matched + 1
                matched_inits(num_matched)%name = inits(i_init)%name
                matched_inits(num_matched)%data_type = &
                     inits(i_init)%data_type

                is_matrix = .false.
                if(allocated(inits(i_init)%dims))then
                   allocate(matched_inits(num_matched)%dims( &
                        size(inits(i_init)%dims)))
                   matched_inits(num_matched)%dims = inits(i_init)%dims
                   is_matrix = (size(inits(i_init)%dims) == 2)
                end if

                if(allocated(inits(i_init)%data))then
                   allocate(matched_inits(num_matched)%data( &
                        size(inits(i_init)%data)))
                   if(is_matrix)then
                      ! Two-dimensional data goes through the layout
                      ! helper; presumably ONNX row-major to Fortran
                      ! column-major - confirm against athena__onnx_utils
                      call row_to_col_major_2d( &
                           inits(i_init)%data, &
                           matched_inits(num_matched)%data, &
                           inits(i_init)%dims(1), inits(i_init)%dims(2))
                   else
                      matched_inits(num_matched)%data = inits(i_init)%data
                   end if
                end if

                if(allocated(inits(i_init)%int_data))then
                   allocate(matched_inits(num_matched)%int_data( &
                        size(inits(i_init)%int_data)))
                   matched_inits(num_matched)%int_data = &
                        inits(i_init)%int_data
                end if

             end if
          end do
       end do
    end if

    !---------------------------------------------------------------------
    ! Derive at most one output shape hint from the first matched
    ! initialiser that has two or more dims; only attempted when the
    ! node exposes at least one output
    !---------------------------------------------------------------------
    i_hint = 0
    if(allocated(nodes(node_index)%outputs) .and. &
         nodes(node_index)%num_outputs >= 1)then
       output_name = trim(nodes(node_index)%outputs(1))
       do i_init = 1, size(matched_inits)
          if(.not.allocated(matched_inits(i_init)%dims)) cycle
          if(size(matched_inits(i_init)%dims) >= 2)then
             i_hint = i_init
             exit
          end if
       end do
    end if

    if(i_hint > 0)then
       hint_rank = size(matched_inits(i_hint)%dims)
       allocate(shape_hints(1))
       shape_hints(1)%name = output_name
       ! NOTE(review): 1 is presumably the ONNX FLOAT element type - confirm
       shape_hints(1)%elem_type = 1
       if(trim(op_name) == 'Conv' .and. hint_rank >= 3)then
          ! Conv: same rank as the initialiser, second entry taken from
          ! its last dim, remaining entries left as 0
          allocate(shape_hints(1)%dims(hint_rank))
          shape_hints(1)%dims(1) = 1
          shape_hints(1)%dims(2) = matched_inits(i_hint)%dims(hint_rank)
          shape_hints(1)%dims(3:hint_rank) = 0
       else
          ! Otherwise: two entries, second taken from the first dim
          allocate(shape_hints(1)%dims(2))
          shape_hints(1)%dims(1) = 1
          shape_hints(1)%dims(2) = matched_inits(i_hint)%dims(1)
       end if
    else
       allocate(shape_hints(0))
    end if

    !---------------------------------------------------------------------
    ! Create and register the primary layer
    !---------------------------------------------------------------------
    layer = list_of_onnx_layer_creators(creator_index)%create_ptr( &
         nodes(node_index), matched_inits, shape_hints, verbose=verbose_)
    call network%add(layer)

    deallocate(matched_inits, shape_hints)

    !---------------------------------------------------------------------
    ! Trailing activation node (if any) for the same layer_id; it is
    ! created with empty initialiser and shape-hint lists
    !---------------------------------------------------------------------
    activation_index = find_activation_node_for_layer_id( &
         nodes, num_nodes, layer_id)
    if(activation_index <= 0) return

    op_name = trim(adjustl(nodes(activation_index)%op_type))
    creator_index = findloc( &
         [ list_of_onnx_layer_creators(:)%op_type ], &
         trim(op_name), dim = 1)
    if(creator_index <= 0) return

    allocate(matched_inits(0))
    allocate(shape_hints(0))
    if(allocated(layer)) deallocate(layer)
    layer = list_of_onnx_layer_creators(creator_index)%create_ptr( &
         nodes(activation_index), matched_inits, shape_hints, &
         verbose=verbose_)
    call network%add(layer)

    deallocate(matched_inits, shape_hints)

  end subroutine add_standard_layer_from_onnx