is_onnx_expanded_gnn_graph Function

function is_onnx_expanded_gnn_graph(nodes, num_nodes) result(output)

Return true when the parsed ONNX graph contains expanded-ONNX GNN patterns that ATHENA can collapse back into native message passing layers.

Arguments

Type Intent Optional Attributes Name
type(onnx_node_type), intent(in) :: nodes(:)

Parsed ONNX nodes

integer, intent(in) :: num_nodes

Number of valid node entries

Return Value logical

Whether the graph contains recognizable expanded-ONNX GNN patterns


Source Code

  function is_onnx_expanded_gnn_graph(nodes, num_nodes) result(output)
    !! Report whether any layer encoded in the parsed ONNX node names is
    !! accepted by one of the registered expanded-ONNX GNN classifiers, i.e.
    !! whether ATHENA can fold those node groups back into native message
    !! passing layers.
    !!
    !! Side effect: if the creator list in athena__container_layer is not
    !! yet allocated, allocate_list_of_onnx_expanded_gnn_layer_creators is
    !! called to populate it.
    use athena__container_layer, only: &
         list_of_onnx_expanded_gnn_layer_creators, &
         allocate_list_of_onnx_expanded_gnn_layer_creators
    implicit none

    ! Arguments
    type(onnx_node_type), intent(in) :: nodes(:)
    !! Parsed ONNX nodes
    integer, intent(in) :: num_nodes
    !! Number of valid node entries; only nodes(1:num_nodes) are scanned
    !! (assumed not to exceed size(nodes) -- not checked here)

    logical :: output
    !! .true. once any classifier accepts any layer prefix; .false.
    !! otherwise, including when no node name yields a layer id

    ! Local variables
    integer, allocatable :: seen_ids(:)
    !! Distinct layer ids, kept in first-seen order
    integer :: inode, ilayer, icreator
    !! Loop counters over nodes, distinct ids, and registered creators
    integer :: node_layer_id, parse_status
    !! Layer id parsed from a node name and the parser's status value
    character(len=32) :: candidate_prefix
    !! Name prefix "node_<id>" offered to each classifier

    if(.not.allocated(list_of_onnx_expanded_gnn_layer_creators)) &
         call allocate_list_of_onnx_expanded_gnn_layer_creators()

    output = .false.
    seen_ids = [integer ::]

    ! Pass 1: collect each distinct layer id found in the node names.
    do inode = 1, num_nodes
       call parse_any_node_layer_id( &
            nodes(inode)%name, node_layer_id, parse_status)
       ! NOTE(review): a non-positive status is treated as "this name
       ! carries no layer id"; confirm against parse_any_node_layer_id.
       if(parse_status .le. 0) cycle
       if(any(seen_ids .eq. node_layer_id)) cycle
       seen_ids = [seen_ids, node_layer_id]
    end do

    ! Pass 2: offer every layer prefix to every classifier; stop on first hit.
    layer_scan: do ilayer = 1, size(seen_ids)
       write(candidate_prefix, '("node_",I0)') seen_ids(ilayer)
       do icreator = 1, size( &
            list_of_onnx_expanded_gnn_layer_creators)
          if(list_of_onnx_expanded_gnn_layer_creators( &
               icreator)%classify_ptr( &
               candidate_prefix, nodes, num_nodes))then
             output = .true.
             exit layer_scan
          end if
       end do
    end do layer_scan

  end function is_onnx_expanded_gnn_graph