Emit ONNX JSON nodes for Duvenaud GNN layer
Decomposes the Duvenaud message passing layer into standard ONNX ops: Gather, Concat, ScatterElements, MatMul, Sigmoid/activation, Softmax, ReduceSum, Add, Div, Clip, Sub, etc.
This override is called by write_onnx instead of the standard node emission logic, making the ONNX export extensible for new GNN layer types.
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| class(duvenaud_msgpass_layer_type) | intent(in) | | | this | Instance of the layer |
| character(len=*) | intent(in) | | | prefix | Node name prefix (e.g. "node_2") |
| type(onnx_node_type) | intent(inout) | | dimension(:) | nodes | Accumulator for ONNX nodes |
| integer | intent(inout) | | | num_nodes | Current number of nodes |
| integer | intent(in) | | | max_nodes | Maximum node capacity |
| type(onnx_initialiser_type) | intent(inout) | | dimension(:) | inits | Accumulator for ONNX initialisers |
| integer | intent(inout) | | | num_inits | Current number of initialisers |
| integer | intent(in) | | | max_inits | Maximum initialiser capacity |
| character(len=*) | intent(in) | optional | | input_name | Unused sequential input name |
| logical | intent(in) | optional | | is_last_layer | Unused last-layer flag |
| integer | intent(in) | optional | | format | Unused export format selector |
subroutine emit_onnx_nodes_duvenaud( &
     this, prefix, &
     nodes, num_nodes, max_nodes, &
     inits, num_inits, max_inits, &
     input_name, is_last_layer, format &
)
  !! Write the ONNX JSON nodes that represent a Duvenaud GNN layer.
  !!
  !! The message-passing layer is lowered onto standard ONNX operators
  !! (Gather, Concat, ScatterElements, MatMul, Sigmoid/activation,
  !! Softmax, ReduceSum, Add, Div, Clip, Sub, ...).
  !!
  !! write_onnx invokes this override instead of its generic node
  !! emission, which lets new GNN layer types plug in their own export.
  !!
  !! Emission order:
  !!   1. one call to emit_duvenaud_timestep per message-passing step;
  !!   2. one call to emit_duvenaud_readout_impl;
  !!   3. emit_output_identity on the readout result name.
  use athena__onnx_msgpass_utils, only: emit_output_identity
  implicit none

  !---------------------------------------------------------------------
  ! Dummy arguments
  !---------------------------------------------------------------------
  class(duvenaud_msgpass_layer_type), intent(in) :: this
  !! Layer being exported
  character(*), intent(in) :: prefix
  !! Prefix used for generated node names (e.g. "node_2")
  type(onnx_node_type), intent(inout), dimension(:) :: nodes
  !! Growing list of emitted ONNX nodes
  integer, intent(inout) :: num_nodes
  !! Number of entries of nodes currently in use
  integer, intent(in) :: max_nodes
  !! Capacity of nodes
  type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
  !! Growing list of emitted ONNX initialisers
  integer, intent(inout) :: num_inits
  !! Number of entries of inits currently in use
  integer, intent(in) :: max_inits
  !! Capacity of inits
  character(*), optional, intent(in) :: input_name
  !! Not used by this layer (sequential-model input name)
  logical, optional, intent(in) :: is_last_layer
  !! Not used by this layer (last-layer flag)
  integer, optional, intent(in) :: format
  !! Not used by this layer (export format selector)

  !---------------------------------------------------------------------
  ! Locals
  !---------------------------------------------------------------------
  integer :: step
  !! Message-passing time step index
  character(128) :: step_vertex_name
  !! Name handed back by emit_duvenaud_timestep on each step
  character(128) :: readout_name
  !! Name handed back by emit_duvenaud_readout_impl (assumed output arg)

  ! Input tensor names (vertex/edge inputs, ...) are not arguments here.
  ! The caller is expected to have set them up already, and they are
  ! referenced through the prefix-based naming convention
  ! (prefix such as "node_2").

  !---------------------------------------------------------------------
  ! Message passing: one group of nodes per time step.
  ! Step `step` receives num_vertex_features(step-1) and
  ! num_vertex_features(step), presumably the input and output vertex
  ! widths. The edge width is always num_edge_features(0).
  ! NOTE(review): step_vertex_name is never read after the loop;
  ! presumably later stages rebuild names from prefix/step. Confirm.
  !---------------------------------------------------------------------
  do step = 1, this%num_time_steps
     call emit_duvenaud_timestep( &
          prefix, step, &
          this%num_vertex_features(step - 1), &
          this%num_edge_features(0), &
          this%num_vertex_features(step), &
          this%min_vertex_degree, &
          this%max_vertex_degree, &
          this%params(step)%val(:, 1), &
          this%activation%name, &
          nodes, num_nodes, max_nodes, &
          inits, num_inits, max_inits, &
          step_vertex_name &
     )
  end do

  !---------------------------------------------------------------------
  ! Readout stage
  !---------------------------------------------------------------------
  call emit_duvenaud_readout_impl( &
       prefix, this, &
       nodes, num_nodes, max_nodes, &
       inits, num_inits, max_inits, &
       readout_name &
  )

  ! Expose the readout result as this layer's output, so that
  ! downstream layers consume it.
  call emit_output_identity( &
       prefix, trim(readout_name), this%activation%name, &
       nodes, num_nodes &
  )

end subroutine emit_onnx_nodes_duvenaud