emit_gnn_input_renames Subroutine

subroutine emit_gnn_input_renames(network, layer_id, vertex_idx, nodes, num_nodes)

Emit Identity nodes that rename GNN inputs to the expected convention.

Arguments

Type | Intent | Optional | Attributes | Name
class(network_type), intent(in) :: network

Instance of the network

integer, intent(in) :: layer_id

Layer identifier: index into the network's model array of the layer whose inputs are renamed

integer, intent(in) :: vertex_idx

Vertex position: index into the network's vertex order that selects the target vertex of the auto graph

type(onnx_node_type), intent(inout) :: nodes(:)

Exported ONNX nodes

integer, intent(inout) :: num_nodes

Number of exported nodes


Source Code

  subroutine emit_gnn_input_renames( &
       network, layer_id, vertex_idx, nodes, num_nodes)
    !! Emit Identity nodes that alias the tensors feeding a GNN layer to
    !! names of the form `node_<id>_<kind>_in`.
    !!
    !! Every vertex of the auto graph that has a non-zero adjacency entry
    !! towards the target vertex is inspected. A source vertex whose own
    !! adjacency column is entirely zero supplies the four tensors
    !! `input_<id>_vertex`, `input_<id>_edge`, `input_<id>_edge_index`
    !! and `input_<id>_degree`. Any other source supplies only the vertex
    !! tensor, named `node_<id>_output` or, for a learnable layer whose
    !! activation is not 'none', `node_<id>_<activation>_output`.
    !! Later matching sources overwrite names set by earlier ones.
    !! One Identity node is emitted per name that ended up non-blank.
    use athena__onnx_utils, only: emit_node
    implicit none

    ! Arguments
    class(network_type), intent(in) :: network
    !! Instance of the network
    integer, intent(in) :: layer_id
    !! Index into the network's model array of the layer being exported
    integer, intent(in) :: vertex_idx
    !! Index into the network's vertex order selecting the target vertex
    type(onnx_node_type), intent(inout) :: nodes(:)
    !! Exported ONNX nodes
    integer, intent(inout) :: num_nodes
    !! Number of exported nodes

    ! Local variables
    integer :: src
    !! Index of the candidate source vertex
    integer :: src_layer
    !! Model index of the layer attached to the source vertex
    character(128) :: tag
    !! Name prefix of the target layer ("node_<id>")
    character(128) :: vertex_src, edge_src, edge_index_src, degree_src
    !! Names of the tensors to be aliased; blank means "nothing to emit"
    character(:), allocatable :: tail
    !! Output suffix of a chained source layer

    degree_src = ''
    edge_index_src = ''
    edge_src = ''
    vertex_src = ''

    write(tag, '("node_",I0)') network%model(layer_id)%layer%id

    associate(graph => network%auto_graph)
       do src = 1, graph%num_vertices
          ! Only vertices connected to the target vertex are of interest
          if(graph%adjacency( &
               src, network%vertex_order(vertex_idx)) /= 0)then
             src_layer = graph%vertex(src)%id

             associate(src_id => network%model(src_layer)%layer%id)
                if(any(graph%adjacency(:,src) /= 0))then
                   ! Source has incoming connections of its own: chain
                   ! from its (possibly activated) output tensor
                   select type(upstream => network%model(src_layer)%layer)
                   class is(learnable_layer_type)
                      if(upstream%activation%name /= 'none')then
                         tail = '_' // &
                              trim(adjustl(upstream%activation%name)) // &
                              '_output'
                      else
                         tail = '_output'
                      end if
                   class default
                      tail = '_output'
                   end select
                   write(vertex_src, '("node_",I0,A)') src_id, tail
                else
                   ! Source has no incoming connections: take all four
                   ! tensors from the corresponding input_<id>_* names
                   write(vertex_src, '("input_",I0,"_vertex")') src_id
                   write(edge_src, '("input_",I0,"_edge")') src_id
                   write(edge_index_src, '("input_",I0,"_edge_index")') &
                        src_id
                   write(degree_src, '("input_",I0,"_degree")') src_id
                end if
             end associate
          end if
       end do
    end associate

    ! Emission order is vertex, edge, edge_index, degree
    if(len_trim(vertex_src) > 0) &
         call emit_node('Identity', trim(tag)//'_rename_vertex', &
         trim(tag)//'_vertex_in', '', nodes, num_nodes, &
         in1=trim(vertex_src))

    if(len_trim(edge_src) > 0) &
         call emit_node('Identity', trim(tag)//'_rename_edge', &
         trim(tag)//'_edge_in', '', nodes, num_nodes, &
         in1=trim(edge_src))

    if(len_trim(edge_index_src) > 0) &
         call emit_node('Identity', trim(tag)//'_rename_edge_index', &
         trim(tag)//'_edge_index_in', '', nodes, num_nodes, &
         in1=trim(edge_index_src))

    if(len_trim(degree_src) > 0) &
         call emit_node('Identity', trim(tag)//'_rename_degree', &
         trim(tag)//'_degree_in', '', nodes, num_nodes, &
         in1=trim(degree_src))

  end subroutine emit_gnn_input_renames