emit_kipf_timestep Subroutine

private subroutine emit_kipf_timestep(prefix, t, nv_in, nv_out, weight_data, activation_name, nodes, num_nodes, max_nodes, inits, num_inits, max_inits, vertex_out)

Emit ONNX nodes for one Kipf GCN time step.

Arguments

Type Intent Optional Attributes Name
character(len=*), intent(in) :: prefix
integer, intent(in) :: t
integer, intent(in) :: nv_in
integer, intent(in) :: nv_out
real(kind=real32), intent(in) :: weight_data(:)
character(len=*), intent(in) :: activation_name
type(onnx_node_type), intent(inout), dimension(:) :: nodes
integer, intent(inout) :: num_nodes
integer, intent(in) :: max_nodes
type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
integer, intent(inout) :: num_inits
integer, intent(in) :: max_inits
character(len=128), intent(out) :: vertex_out

Source Code

  subroutine emit_kipf_timestep( &
       prefix, t, nv_in, nv_out, weight_data, activation_name, &
       nodes, num_nodes, max_nodes, &
       inits, num_inits, max_inits, vertex_out)
    !! Append the ONNX nodes and initialisers for a single Kipf GCN step.
    !!
    !! Every name created here is prefixed with "<prefix>_t<t>". The
    !! emitted graph is:
    !!   1. int64 constants 0 and 2, handed to emit_edge_index_component
    !!      to derive the source / target index tensors from
    !!      "<prefix>_edge_index_in";
    !!   2. a Gather of the source-vertex features, and a per-edge
    !!      coefficient (deg[src] * deg[tgt])**(-0.5) built from
    !!      "<prefix>_degree_in" cast to float, unsqueezed on axis 1 and
    !!      multiplied into the gathered features;
    !!   3. scatter aggregation of those messages via
    !!      emit_scatter_aggregator using the target indices;
    !!   4. MatMul of the aggregate with the transpose of the weight
    !!      initialiser (created with dimensions passed as nv_out, nv_in);
    !!   5. an activation node, unless activation_name is 'none'.
    !!
    !! NOTE(review): max_inits is not referenced in this routine, and
    !! max_nodes is only forwarded to emit_activation_node; any capacity
    !! checking must happen inside the helpers — confirm.
    use athena__onnx_utils, only: emit_node, emit_constant_int64, &
         emit_constant_float, emit_activation_node
    use athena__onnx_msgpass_utils, only: get_timestep_output_name, &
         emit_edge_index_component, emit_scatter_aggregator, &
         emit_weight_initialiser_2d
    implicit none

    ! Arguments
    character(len=*), intent(in) :: prefix
      !! Base name shared by all tensors and nodes of this layer.
    integer, intent(in) :: t
      !! Time-step index; when t > 1 the input is the output of step t-1.
    integer, intent(in) :: nv_in, nv_out
      !! Input and output vertex feature widths.
    real(real32), intent(in) :: weight_data(:)
      !! Flattened weight values forwarded to emit_weight_initialiser_2d.
    character(len=*), intent(in) :: activation_name
      !! Activation to append, or 'none' to skip it.
    type(onnx_node_type), intent(inout), dimension(:) :: nodes
      !! Node list that the emitted nodes are appended to.
    integer, intent(inout) :: num_nodes
      !! Number of nodes currently in use in nodes.
    integer, intent(in) :: max_nodes
      !! Capacity of nodes (only forwarded to emit_activation_node).
    type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
      !! Initialiser list that emitted initialisers are appended to.
    integer, intent(inout) :: num_inits
      !! Number of initialisers currently in use in inits.
    integer, intent(in) :: max_inits
      !! Capacity of inits (unused in this routine).
    character(len=128), intent(out) :: vertex_out
      !! Name of the tensor holding this step's vertex output.

    ! Local tensor names, one per produced tensor.
    character(len=128) :: step_tag
    character(len=128) :: vertex_in, edge_index_in, degree_in
    character(len=128) :: row0_const, row2_const, src_idx, target_idx
    character(len=128) :: src_feat, deg_float, deg_src, deg_tgt, deg_prod
    character(len=128) :: neg_half, coeff, unsq_axis, coeff_unsq
    character(len=128) :: scaled_feat, aggr_name
    character(len=128) :: weight_name, weight_t_name, matmul_out

    ! JSON attribute snippets attached to individual nodes.
    character(len=*), parameter :: gather_axis0_attr = &
         '        "attribute": [{"name": "axis", "i": "0", "type": "INT"}]'
    character(len=*), parameter :: transpose_perm10_attr = &
         '        "attribute": [{"name": "perm", "ints": ["1", "0"], ' // &
         '"type": "INTS"}]'
    character(len=*), parameter :: cast_to_float_attr = &
         '        "attribute": [{"name": "to", "i": "1", "type": "INT"}]'

    ! Layer-level input names, and the per-step tag.
    write(step_tag, '(A,"_t",I0)') trim(prefix), t
    write(vertex_in, '(A,"_vertex_in")') trim(prefix)
    write(edge_index_in, '(A,"_edge_index_in")') trim(prefix)
    write(degree_in, '(A,"_degree_in")') trim(prefix)
    ! After the first step, the vertex input is the previous step's output.
    if (t > 1) then
       call get_timestep_output_name( &
            prefix, t-1, activation_name, '_mm_out', '', vertex_in)
    end if

    ! Derive every per-step tensor name before emitting anything.
    write(row0_const, '(A,"_idx0")') trim(step_tag)
    write(row2_const, '(A,"_idx2")') trim(step_tag)
    write(src_feat, '(A,"_src_feat")') trim(step_tag)
    write(deg_float, '(A,"_deg_f")') trim(step_tag)
    write(deg_src, '(A,"_deg_src")') trim(step_tag)
    write(deg_tgt, '(A,"_deg_tgt")') trim(step_tag)
    write(deg_prod, '(A,"_deg_prod")') trim(step_tag)
    write(neg_half, '(A,"_neg_half")') trim(step_tag)
    write(coeff, '(A,"_coeff")') trim(step_tag)
    write(unsq_axis, '(A,"_us_ax1")') trim(step_tag)
    write(coeff_unsq, '(A,"_coeff_us")') trim(step_tag)
    write(scaled_feat, '(A,"_scaled_feat")') trim(step_tag)
    write(weight_name, '(A,"_W")') trim(step_tag)
    write(weight_t_name, '(A,"_Wt")') trim(step_tag)
    write(matmul_out, '(A,"_mm_out")') trim(step_tag)

    ! (1) Source/target index tensors taken from the edge index.
    !     NOTE(review): 'src' uses constant 0 and 'tgt' uses constant 2;
    !     presumably these select rows of edge_index_in — confirm layout.
    call emit_constant_int64(trim(row0_const), [0], [1], &
         nodes, num_nodes, inits, num_inits)
    call emit_constant_int64(trim(row2_const), [2], [1], &
         nodes, num_nodes, inits, num_inits)
    call emit_edge_index_component( &
         step_tag, edge_index_in, trim(row0_const), 'src', src_idx, &
         nodes, num_nodes)
    call emit_edge_index_component( &
         step_tag, edge_index_in, trim(row2_const), 'tgt', target_idx, &
         nodes, num_nodes)

    ! (2) Per-edge messages scaled by (deg[src] * deg[tgt])**(-0.5).
    call emit_node('Gather', trim(step_tag)//'_gather_vfeat', &
         trim(src_feat), gather_axis0_attr, nodes, num_nodes, &
         in1=trim(vertex_in), in2=trim(src_idx))
    call emit_node('Cast', trim(step_tag)//'_cast_deg', &
         trim(deg_float), cast_to_float_attr, nodes, num_nodes, &
         in1=trim(degree_in))
    call emit_node('Gather', trim(step_tag)//'_gather_deg_src', &
         trim(deg_src), gather_axis0_attr, nodes, num_nodes, &
         in1=trim(deg_float), in2=trim(src_idx))
    call emit_node('Gather', trim(step_tag)//'_gather_deg_tgt', &
         trim(deg_tgt), gather_axis0_attr, nodes, num_nodes, &
         in1=trim(deg_float), in2=trim(target_idx))
    call emit_node('Mul', trim(step_tag)//'_mul_deg', &
         trim(deg_prod), '', nodes, num_nodes, &
         in1=trim(deg_src), in2=trim(deg_tgt))
    call emit_constant_float(trim(neg_half), [-0.5_real32], [1], &
         nodes, num_nodes, inits, num_inits)
    call emit_node('Pow', trim(step_tag)//'_pow_coeff', &
         trim(coeff), '', nodes, num_nodes, &
         in1=trim(deg_prod), in2=trim(neg_half))
    ! Add a trailing axis to the coefficient so it broadcasts against
    ! the gathered feature rows.
    call emit_constant_int64(trim(unsq_axis), [1], [1], &
         nodes, num_nodes, inits, num_inits)
    call emit_node('Unsqueeze', trim(step_tag)//'_us_coeff', &
         trim(coeff_unsq), '', nodes, num_nodes, &
         in1=trim(coeff), in2=trim(unsq_axis))
    call emit_node('Mul', trim(step_tag)//'_mul_coeff', &
         trim(scaled_feat), '', nodes, num_nodes, &
         in1=trim(src_feat), in2=trim(coeff_unsq))

    ! (3) Aggregate the scaled messages onto their target vertices.
    call emit_scatter_aggregator( &
         step_tag, vertex_in, target_idx, trim(scaled_feat), nv_in, &
         nodes, num_nodes, inits, num_inits, aggr_name)

    ! (4) Linear transform: aggregate @ transpose(W).
    call emit_weight_initialiser_2d( &
         trim(weight_name), nv_out, nv_in, weight_data, inits, num_inits)
    call emit_node('Transpose', trim(step_tag)//'_transpose_W', &
         trim(weight_t_name), transpose_perm10_attr, nodes, num_nodes, &
         in1=trim(weight_name))
    call emit_node('MatMul', trim(step_tag)//'_matmul', &
         trim(matmul_out), '', nodes, num_nodes, &
         in1=trim(aggr_name), in2=trim(weight_t_name))

    ! (5) Optional activation; otherwise the MatMul result is the output.
    if (activation_name == 'none') then
       vertex_out = trim(matmul_out)
    else
       call emit_activation_node(activation_name, trim(step_tag), &
            trim(matmul_out), nodes, num_nodes, max_nodes)
       ! The output name is read from the most recently appended node,
       ! presumably the activation just emitted.
       vertex_out = trim(nodes(num_nodes)%outputs(1))
    end if

  end subroutine emit_kipf_timestep