emit_msgpass_graph_inputs Subroutine

public subroutine emit_msgpass_graph_inputs(prefix, input_shape, graph_inputs, num_inputs)

Emit the standard graph input tensors used by message-passing layers.

Adds vertex features, optional edge features, edge_index, and degree.

Arguments

Type | Intent | Optional | Attributes | Name
character(len=*), intent(in) :: prefix

Input name prefix (e.g. "input_1")

integer, intent(in), dimension(:) :: input_shape

Layer input shape [num_vertex_features, num_edge_features]

type(onnx_tensor_type), intent(inout), dimension(:) :: graph_inputs

Accumulator for graph inputs

integer, intent(inout) :: num_inputs

Current number of graph inputs


Source Code

  subroutine emit_msgpass_graph_inputs(prefix, input_shape, graph_inputs, &
       num_inputs)
    !! Emit the standard graph input tensors used by message-passing layers.
    !!
    !! Registers, in this order, the tensors `<prefix>_vertex`,
    !! `<prefix>_edge` (only when a positive edge-feature width is given),
    !! `<prefix>_edge_index` and `<prefix>_degree`, each through a call to
    !! `add_graph_input_tensor`.
    implicit none

    ! Arguments
    character(*), intent(in) :: prefix
    !! Input name prefix (e.g. "input_1"); trailing blanks are trimmed
    integer, dimension(:), intent(in) :: input_shape
    !! Layer input shape [num_vertex_features, num_edge_features].
    !! Must hold at least the vertex width; a missing or non-positive
    !! second entry means "no edge features".
    type(onnx_tensor_type), intent(inout), dimension(:) :: graph_inputs
    !! Accumulator for graph inputs
    integer, intent(inout) :: num_inputs
    !! Current number of graph inputs (forwarded to
    !! `add_graph_input_tensor`, which presumably advances it - confirm
    !! against that routine)

    ! Local constants
    ! NOTE(review): the values equal ONNX TensorProto.DataType FLOAT (1)
    ! and INT64 (7); this assumes the fourth argument of
    ! add_graph_input_tensor is the element type - TODO confirm.
    integer, parameter :: elem_float = 1
    integer, parameter :: elem_int64 = 7

    ! Vertex features: [num_nodes, nv]
    ! (-1 paired with a dimension name looks like a symbolic/dynamic
    ! extent - verify against add_graph_input_tensor)
    call add_graph_input_tensor( &
         graph_inputs, num_inputs, trim(prefix)//'_vertex', elem_float, &
         -1, 'num_nodes', input_shape(1), '')

    ! Edge features: [num_edges, ne]
    ! The size test is nested rather than combined with .and. because
    ! Fortran does not guarantee short-circuit evaluation; this keeps
    ! input_shape(2) from being read when only one entry was supplied.
    if(size(input_shape) .ge. 2)then
       if(input_shape(2) .gt. 0)then
          call add_graph_input_tensor( &
               graph_inputs, num_inputs, trim(prefix)//'_edge', elem_float, &
               -1, 'num_edges', input_shape(2), '')
       end if
    end if

    ! Edge index: [3, num_csr_entries]
    call add_graph_input_tensor( &
         graph_inputs, num_inputs, trim(prefix)//'_edge_index', elem_int64, &
         3, '', -1, 'num_csr_entries')

    ! Node degree: [num_nodes]
    ! Only one dimension is passed here; the remaining arguments of
    ! add_graph_input_tensor are left off.
    call add_graph_input_tensor( &
         graph_inputs, num_inputs, trim(prefix)//'_degree', elem_int64, &
         -1, 'num_nodes')

  end subroutine emit_msgpass_graph_inputs