emit_onnx_nodes_duvenaud Subroutine

private subroutine emit_onnx_nodes_duvenaud(this, prefix, nodes, num_nodes, max_nodes, inits, num_inits, max_inits, input_name, is_last_layer, format)

Emit ONNX JSON nodes for a Duvenaud GNN layer.

Decomposes the Duvenaud message passing layer into standard ONNX ops: Gather, Concat, ScatterElements, MatMul, Sigmoid/activation, Softmax, ReduceSum, Add, Div, Clip, Sub, etc.

This override is called by write_onnx instead of the standard node emission logic, making the ONNX export extensible for new GNN layer types.

Type Bound

duvenaud_msgpass_layer_type

Arguments

Type | Intent | Optional Attributes | Name
class(duvenaud_msgpass_layer_type), intent(in) :: this

Instance of the layer

character(len=*), intent(in) :: prefix

Node name prefix (e.g. "node_2")

type(onnx_node_type), intent(inout), dimension(:) :: nodes

Accumulator for ONNX nodes

integer, intent(inout) :: num_nodes

Current number of nodes

integer, intent(in) :: max_nodes

Maximum capacity

type(onnx_initialiser_type), intent(inout), dimension(:) :: inits

Accumulator for ONNX initialisers

integer, intent(inout) :: num_inits

Current number of initialisers

integer, intent(in) :: max_inits

Maximum capacity

character(len=*), intent(in), optional :: input_name

Sequential-model input name (accepted for interface compatibility; unused)

logical, intent(in), optional :: is_last_layer

Unused last-layer flag

integer, intent(in), optional :: format

Unused export format selector


Source Code

  subroutine emit_onnx_nodes_duvenaud( &
       this, prefix, &
       nodes, num_nodes, max_nodes, &
       inits, num_inits, max_inits, &
       input_name, is_last_layer, format &
  )
    !! Emit ONNX JSON nodes for a Duvenaud GNN layer.
    !!
    !! The layer is lowered onto standard ONNX operators (e.g. Gather,
    !! Concat, ScatterElements, MatMul, Sigmoid/activation, Softmax,
    !! ReduceSum, Add, Div, Clip, Sub).
    !!
    !! Emission proceeds in three stages:
    !!   1. one emit_duvenaud_timestep call per message-passing step
    !!      (1 .. this%num_time_steps),
    !!   2. a single emit_duvenaud_readout_impl call for the readout,
    !!   3. an output identity node (emit_output_identity) that exposes
    !!      the readout result to downstream layers.
    !!
    !! This override is called by write_onnx instead of the standard
    !! node emission logic, making the ONNX export extensible for new
    !! GNN layer types.
    !!
    !! The optional sequential-model arguments (input_name,
    !! is_last_layer, format) are accepted for interface compatibility
    !! and are not referenced in this routine.
    use athena__onnx_msgpass_utils, only: emit_output_identity
    implicit none

    ! Arguments
    class(duvenaud_msgpass_layer_type), intent(in) :: this
    !! Instance of the layer
    character(*), intent(in) :: prefix
    !! Node name prefix (e.g. "node_2")
    type(onnx_node_type), intent(inout), dimension(:) :: nodes
    !! Accumulator for ONNX nodes
    integer, intent(inout) :: num_nodes
    !! Current number of nodes
    integer, intent(in) :: max_nodes
    !! Maximum capacity
    type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
    !! Accumulator for ONNX initialisers
    integer, intent(inout) :: num_inits
    !! Current number of initialisers
    integer, intent(in) :: max_inits
    !! Maximum capacity
    character(*), optional, intent(in) :: input_name
    !! Unused sequential input name
    logical, optional, intent(in) :: is_last_layer
    !! Unused last-layer flag
    integer, optional, intent(in) :: format
    !! Unused export format selector

    ! Local variables
    integer :: step
    !! Current message-passing time step (1-based)
    character(128) :: step_vertex_name
    !! Vertex-output name written by each time step (not read afterwards here)
    character(128) :: readout_name
    !! Name of the tensor produced by the readout stage

    ! NOTE(review): the graph input tensors (vertex/edge features) are not
    ! passed in explicitly; the helpers presumably derive their names from
    ! `prefix` and the calling context -- confirm against write_onnx.

    ! --- Stage 1: message-passing time steps -------------------------------
    timestep_loop: do step = 1, this%num_time_steps
       ! num_vertex_features(step - 1) and num_vertex_features(step) are
       ! presumably this step's input/output vertex widths; the edge width
       ! is always taken from index 0.
       call emit_duvenaud_timestep( &
            prefix, step, &
            this%num_vertex_features(step - 1), &
            this%num_edge_features(0), &
            this%num_vertex_features(step), &
            this%min_vertex_degree, &
            this%max_vertex_degree, &
            this%params(step)%val(:,1), &
            this%activation%name, &
            nodes, num_nodes, max_nodes, &
            inits, num_inits, max_inits, &
            step_vertex_name &
       )
    end do timestep_loop

    ! --- Stage 2: readout ---------------------------------------------------
    call emit_duvenaud_readout_impl( &
         prefix, this, &
         nodes, num_nodes, max_nodes, &
         inits, num_inits, max_inits, &
         readout_name &
    )

    ! --- Stage 3: expose the readout as this layer's output ----------------
    call emit_output_identity( &
         prefix, trim(readout_name), this%activation%name, &
         nodes, num_nodes &
    )

  end subroutine emit_onnx_nodes_duvenaud