Emit Identity nodes that rename GNN inputs to the expected convention.
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| `class(network_type)` | `intent(in)` | | | `network` | Instance of the network |
| `integer` | `intent(in)` | | | `layer_id` | Layer identifier (index into `network%model`) |
| `integer` | `intent(in)` | | | `vertex_idx` | Vertex position (index into `network%vertex_order`) |
| `type(onnx_node_type)` | `intent(inout)` | | | `nodes(:)` | Exported ONNX nodes |
| `integer` | `intent(inout)` | | | `num_nodes` | Number of exported nodes |
subroutine emit_gnn_input_renames( &
     network, layer_id, vertex_idx, nodes, num_nodes)
  !! Emit Identity nodes that rename GNN inputs to the expected convention.
  !!
  !! Each graph vertex `src` with a non-zero entry
  !! `network%auto_graph%adjacency(src, network%vertex_order(vertex_idx))`
  !! is treated as a source of this layer:
  !!
  !! * if `network%auto_graph%adjacency(:, src)` is entirely zero, the
  !!   source is a graph input and provides the `input_<id>_vertex`,
  !!   `input_<id>_edge`, `input_<id>_edge_index` and `input_<id>_degree`
  !!   tensor names;
  !! * otherwise only the vertex tensor name is provided:
  !!   `node_<id>_output`, or `node_<id>_<activation>_output` when the
  !!   source layer is a `learnable_layer_type` whose activation name is
  !!   not `'none'`.
  !!
  !! A later source overwrites any name set by an earlier one. For every
  !! name that is non-blank after the scan, an Identity node named
  !! `<prefix>_rename_<kind>` is emitted with output `<prefix>_<kind>_in`,
  !! where `<prefix>` is `node_<id of network%model(layer_id)%layer>`.
  !! Kinds are emitted in the order vertex, edge, edge_index, degree.
  use athena__onnx_utils, only: emit_node
  implicit none

  ! Arguments
  class(network_type), intent(in) :: network
  !! Instance of the network
  integer, intent(in) :: layer_id, vertex_idx
  !! Layer identifier and vertex position
  type(onnx_node_type), intent(inout) :: nodes(:)
  !! Exported ONNX nodes
  integer, intent(inout) :: num_nodes
  !! Number of exported nodes

  ! Local constants: one slot per renamed tensor kind
  integer, parameter :: n_slots = 4
  !! Number of tensor kinds that may be renamed
  integer, parameter :: slot_vertex = 1, slot_edge = 2, &
       slot_edge_index = 3, slot_degree = 4
  !! Slot indices, listed in emission order
  character(len=18), parameter :: rename_tags(n_slots) = [ &
       character(len=18) :: '_rename_vertex', '_rename_edge', &
       '_rename_edge_index', '_rename_degree' ]
  !! Suffixes of the emitted Identity node names
  character(len=14), parameter :: output_tags(n_slots) = [ &
       character(len=14) :: '_vertex_in', '_edge_in', &
       '_edge_index_in', '_degree_in' ]
  !! Suffixes of the renamed output tensor names

  ! Local variables
  integer :: src, slot, model_idx, src_id
  !! Loop indices, source model index and source layer identifier
  character(128) :: prefix
  !! Name prefix shared by this layer's nodes
  character(128) :: sources(n_slots)
  !! Source tensor name per slot (blank when no source set it)
  character(:), allocatable :: suffix
  !! Activation-dependent suffix for chained vertex inputs

  write(prefix, '("node_",I0)') network%model(layer_id)%layer%id
  sources = ''

  ! Gather the tensor names supplied by every vertex feeding this layer.
  do src = 1, network%auto_graph%num_vertices
     model_idx = network%auto_graph%vertex(src)%id
     if(network%auto_graph%adjacency( &
          src, network%vertex_order(vertex_idx)) .eq. 0) cycle
     src_id = network%model(model_idx)%layer%id

     if(any(network%auto_graph%adjacency(:,src) .ne. 0))then
        ! Chained layer: only its (possibly activated) output feeds the
        ! vertex slot; the other slots are left untouched.
        suffix = '_output'
        select type(prev => network%model(model_idx)%layer)
        class is(learnable_layer_type)
           if(prev%activation%name .ne. 'none') &
                suffix = '_' // trim(adjustl(prev%activation%name)) // &
                '_output'
        end select
        write(sources(slot_vertex), '("node_",I0,A)') src_id, suffix
     else
        ! Source with an all-zero adjacency column: fill all four slots.
        write(sources(slot_vertex), '("input_",I0,"_vertex")') src_id
        write(sources(slot_edge), '("input_",I0,"_edge")') src_id
        write(sources(slot_edge_index), '("input_",I0,"_edge_index")') &
             src_id
        write(sources(slot_degree), '("input_",I0,"_degree")') src_id
     end if
  end do

  ! Emit one Identity node per populated slot, in slot order.
  do slot = 1, n_slots
     if(len_trim(sources(slot)) .eq. 0) cycle
     call emit_node('Identity', trim(prefix)//trim(rename_tags(slot)), &
          trim(prefix)//trim(output_tags(slot)), '', nodes, num_nodes, &
          in1=trim(sources(slot)))
  end do

end subroutine emit_gnn_input_renames