emit_onnx_nodes_kipf Subroutine

private subroutine emit_onnx_nodes_kipf(this, prefix, nodes, num_nodes, max_nodes, inits, num_inits, max_inits, input_name, is_last_layer, format)

Emit ONNX JSON nodes for a Kipf GCN layer.

Decomposes the Kipf message passing layer into standard ONNX ops: Gather, ScatterElements, Mul, Pow, MatMul, activation

Kipf GCN: $H^{(l+1)} = \sigma\left(\tilde{D}^{-1/2}\,\tilde{A}\,\tilde{D}^{-1/2}\,H^{(l)}\,W^{(l)}\right)$

Decomposed per timestep:

1. Extract source/target indices from `edge_index`.
2. Gather source vertex features.
3. Compute the normalisation coefficient $c = (\deg_{\mathrm{src}} \cdot \deg_{\mathrm{tgt}})^{-1/2}$.
4. Scale the source features by the coefficient.
5. Scatter-add the scaled features to the target vertices.
6. MatMul with the weight $W$ (transposed).
7. Apply the activation.

Type Bound

kipf_msgpass_layer_type

Arguments

Type Intent Optional Attributes Name
class(kipf_msgpass_layer_type), intent(in) :: this

Instance of the layer

character(len=*), intent(in) :: prefix

Node name prefix (e.g. "node_2")

type(onnx_node_type), intent(inout), dimension(:) :: nodes

Accumulator for ONNX nodes

integer, intent(inout) :: num_nodes

Current number of nodes

integer, intent(in) :: max_nodes

Maximum capacity

type(onnx_initialiser_type), intent(inout), dimension(:) :: inits

Accumulator for ONNX initialisers

integer, intent(inout) :: num_inits

Current number of initialisers

integer, intent(in) :: max_inits

Maximum capacity

character(len=*), intent(in), optional :: input_name

Unused sequential input name

logical, intent(in), optional :: is_last_layer

Unused last-layer flag

integer, intent(in), optional :: format

Unused export format selector


Source Code

  subroutine emit_onnx_nodes_kipf( &
       this, prefix, &
       nodes, num_nodes, max_nodes, &
       inits, num_inits, max_inits, &
       input_name, is_last_layer, format &
  )
    !! Emit the ONNX (JSON) node graph for a Kipf GCN message-passing layer.
    !!
    !! The layer rule
    !!   H^(l+1) = sigma(D~^(-1/2) A~ D~^(-1/2) H^(l) W^(l))
    !! is expressed with standard ONNX operators only: Gather,
    !! ScatterElements, Mul, Pow, MatMul and the layer's activation.
    !!
    !! emit_kipf_timestep is called once per time step. Each step emits:
    !!   1. split edge_index into source/target indices
    !!   2. gather the source-vertex features
    !!   3. form the coefficient (deg_src * deg_tgt)^(-0.5)
    !!   4. scale the gathered features by that coefficient
    !!   5. scatter-add the scaled features onto the target vertices
    !!   6. multiply by the (transposed) weight W
    !!   7. apply the activation
    !! An output identity node is then attached to the output tensor name
    !! produced by the final step. No readout follows, because the layer's
    !! output is node-level.
    !!
    !! NOTE(review): this assumes this%num_time_steps >= 1. With zero steps,
    !! last_vertex_name is never set before it is passed to
    !! emit_output_identity. Confirm that callers guarantee at least one step.
    use athena__onnx_msgpass_utils, only: emit_output_identity
    implicit none

    ! Arguments
    class(kipf_msgpass_layer_type), intent(in) :: this
    !! Layer being exported
    character(*), intent(in) :: prefix
    !! Prefix for every emitted node name (e.g. "node_2")

    type(onnx_node_type), intent(inout), dimension(:) :: nodes
    !! Node accumulator; new nodes are appended to it
    integer, intent(inout) :: num_nodes
    !! Number of entries of nodes currently in use
    integer, intent(in) :: max_nodes
    !! Capacity of nodes

    type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
    !! Initialiser accumulator; new initialisers are appended to it
    integer, intent(inout) :: num_inits
    !! Number of entries of inits currently in use
    integer, intent(in) :: max_inits
    !! Capacity of inits

    character(*), optional, intent(in) :: input_name
    !! Not referenced (sequential-model input name)
    logical, optional, intent(in) :: is_last_layer
    !! Not referenced (last-layer flag)
    integer, optional, intent(in) :: format
    !! Not referenced (export format selector)

    ! Local variables
    integer :: step
    !! Time step being emitted
    character(128) :: last_vertex_name
    !! Output tensor name reported by the most recent emit_kipf_timestep call

    associate( act_name => this%activation%name )

      do step = 1, this%num_time_steps
         ! Step `step` receives the feature counts num_vertex_features(step-1)
         ! and num_vertex_features(step), presumably the input and output
         ! widths. It also receives column 1 of this step's parameter array,
         ! presumably the flattened weights (TODO confirm).
         call emit_kipf_timestep( &
              prefix, step, &
              this%num_vertex_features(step-1), &
              this%num_vertex_features(step), &
              this%params(step)%val(:,1), &
              act_name, &
              nodes, num_nodes, max_nodes, &
              inits, num_inits, max_inits, &
              last_vertex_name &
         )
      end do

      ! Node-level output: expose the last step's tensor directly,
      ! with no readout stage in between.
      call emit_output_identity( &
           prefix, trim(last_vertex_name), act_name, &
           nodes, num_nodes &
      )

    end associate

  end subroutine emit_onnx_nodes_kipf