collect_export_nodes Subroutine

subroutine collect_export_nodes(network, ifmt, nodes, num_nodes, max_nodes, inits, num_inits, max_inits, gnn_metadata, num_gnn_meta)

Build the ONNX nodes, initialisers and GNN metadata.

Arguments

Type | Intent | Optional Attributes | Name
class(network_type), intent(in) :: network

Instance of the network

integer, intent(in) :: ifmt

Export format selector (a value of 2 selects the PyTorch-style export, which supports NOP layers only)

type(onnx_node_type), intent(inout) :: nodes(:)

Exported ONNX nodes

integer, intent(inout) :: num_nodes

Node counter

integer, intent(inout) :: max_nodes

Node allocation limit

type(onnx_initialiser_type), intent(inout) :: inits(:)

Exported ONNX initialisers

integer, intent(inout) :: num_inits

Initialiser counter

integer, intent(inout) :: max_inits

Initialiser allocation limit

character(len=4096), intent(inout) :: gnn_metadata(:)

Exported GNN metadata entries

integer, intent(inout) :: num_gnn_meta

Number of metadata entries


Source Code

  subroutine collect_export_nodes( &
       network, ifmt, nodes, num_nodes, max_nodes, &
       inits, num_inits, max_inits, &
       gnn_metadata, num_gnn_meta)
    !! Build the ONNX nodes, initialisers and GNN metadata.
    !!
    !! Vertices are visited in the order given by `network%vertex_order`;
    !! input ('inpt') layers never produce nodes.
    !!
    !! * `ifmt == 2` (PyTorch-style export): every non-input layer must be
    !!   a NOP layer, otherwise `stop_program` is called and the routine
    !!   returns. Nodes are named "layer1", "layer2", ... in visit order,
    !!   the final non-input layer is flagged via `is_last_layer`, and
    !!   `update_pytorch_prev_output` is handed `input_name` after each
    !!   layer (presumably to set the next layer's input name - confirm).
    !!   No GNN metadata is written on this path.
    !! * Any other `ifmt`: nodes are named "node_<layer id>". Message
    !!   passing ('msgp') layers emit GNN input renames, their own ONNX
    !!   nodes and GNN metadata; NOP layers emit standard nodes plus NOP
    !!   metadata; all other layers emit standard nodes only.
    implicit none

    ! Arguments
    class(network_type), intent(in) :: network
    !! Instance of the network
    integer, intent(in) :: ifmt
    !! Export format selector
    type(onnx_node_type), intent(inout) :: nodes(:)
    !! Exported ONNX nodes
    integer, intent(inout) :: num_nodes, max_nodes
    !! Node counter and allocation limit
    type(onnx_initialiser_type), intent(inout) :: inits(:)
    !! Exported ONNX initialisers
    integer, intent(inout) :: num_inits, max_inits
    !! Initialiser counter and allocation limit
    character(4096), intent(inout) :: gnn_metadata(:)
    !! Exported GNN metadata entries
    integer, intent(inout) :: num_gnn_meta
    !! Number of metadata entries

    ! Local parameters
    integer, parameter :: pytorch_format = 2
    !! `ifmt` value selecting the PyTorch-style (NOP-only) export

    ! Local variables
    integer :: i, layer_id, layer_num, last_pos
    !! Loop index, layer identifier, PyTorch layer counter and the
    !! visit-order position of the final non-input vertex
    character(128) :: node_name, input_name
    !! Node name prefix and sequential input name
    logical :: is_last_layer
    !! Whether the current NOP is the last non-input layer

    if(ifmt .eq. pytorch_format)then
       ! Locate the final non-input vertex once, so each layer's
       ! is_last_layer flag is an index comparison rather than a scan of
       ! all remaining vertices.
       last_pos = 0
       do i = 1, network%auto_graph%num_vertices
          layer_id = network%auto_graph%vertex(network%vertex_order(i))%id
          if(trim(network%model(layer_id)%layer%type) .ne. 'inpt') &
               last_pos = i
       end do

       layer_num = 0
       input_name = 'input'

       do i = 1, network%auto_graph%num_vertices
          layer_id = network%auto_graph%vertex(network%vertex_order(i))%id
          if(trim(network%model(layer_id)%layer%type) .eq. 'inpt') cycle

          if(trim(network%model(layer_id)%layer%type) .ne. 'nop')then
             call stop_program( &
                  'write_onnx: pytorch format supports NOP layers only')
             ! Return explicitly in case stop_program does not terminate.
             return
          end if

          layer_num = layer_num + 1
          write(node_name, '("layer",I0)') layer_num

          ! Vertex i is non-input here, so it is the last such vertex
          ! exactly when it sits at last_pos.
          is_last_layer = i .eq. last_pos

          call network%model(layer_id)%layer%emit_onnx_nodes( &
               trim(node_name), nodes, num_nodes, max_nodes, &
               inits, num_inits, max_inits, input_name=trim(input_name), &
               is_last_layer=is_last_layer, format=ifmt)

          call update_pytorch_prev_output( &
               network%model(layer_id)%layer, trim(node_name), &
               is_last_layer, input_name)
       end do

       return
    end if

    do i = 1, network%auto_graph%num_vertices
       layer_id = network%auto_graph%vertex(network%vertex_order(i))%id
       ! Input layers are skipped: they produce no nodes of their own.
       if(trim(network%model(layer_id)%layer%type) .eq. 'inpt') cycle

       write(node_name, '("node_",I0)') network%model(layer_id)%layer%id

       select case(trim(network%model(layer_id)%layer%type))
       case('msgp')
          call emit_gnn_input_renames( &
               network, layer_id, i, nodes, num_nodes)
          call network%model(layer_id)%layer%emit_onnx_nodes( &
               trim(node_name), nodes, num_nodes, max_nodes, &
               inits, num_inits, max_inits)
          call build_gnn_metadata( &
               network%model(layer_id)%layer, trim(node_name), &
               gnn_metadata, num_gnn_meta)
       case('nop')
          call emit_standard_node_json( &
               network, layer_id, i, nodes, num_nodes, max_nodes, &
               inits, num_inits, max_inits)
          call emit_nop_metadata( &
               network%model(layer_id)%layer, trim(node_name), &
               gnn_metadata, num_gnn_meta)
       case default
          call emit_standard_node_json( &
               network, layer_id, i, nodes, num_nodes, max_nodes, &
               inits, num_inits, max_inits)
       end select
    end do

  end subroutine collect_export_nodes