emit_scatter_aggregator Subroutine

public subroutine emit_scatter_aggregator(tp, vertex_in, target_in, message_in, feature_dim, nodes, num_nodes, inits, num_inits, aggr_out)

Emit the aggregation block: zero-initialise a per-vertex buffer, expand the target indices to the message shape, and scatter-add the messages into the buffer.

Arguments

Type | Intent | Optional Attributes | Name
character(len=*), intent(in) :: tp
character(len=*), intent(in) :: vertex_in
character(len=*), intent(in) :: target_in
character(len=*), intent(in) :: message_in
integer, intent(in) :: feature_dim
type(onnx_node_type), intent(inout), dimension(:) :: nodes
integer, intent(inout) :: num_nodes
type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
integer, intent(inout) :: num_inits
character(len=128), intent(out) :: aggr_out

Source Code

  subroutine emit_scatter_aggregator( &
       tp, vertex_in, target_in, message_in, feature_dim, &
       nodes, num_nodes, inits, num_inits, aggr_out)
    !! Emit the nodes that sum per-edge messages into a per-vertex buffer.
    !!
    !! Graph nodes, in emission order:
    !!   1. Row count: Shape(vertex_in), then Gather entry 0 of that shape.
    !!   2. Buffer shape: Concat of [row count, feature_dim].
    !!   3. Zero buffer: ConstantOfShape of the buffer shape, filled with 0.0.
    !!   4. Index tensor: Unsqueeze target_in on axis 1, then Expand it to
    !!      Shape(message_in) so it matches the messages element-for-element.
    !!   5. ScatterElements of message_in into the zero buffer at those
    !!      indices (scatter-add, via onnx_scatter_add_attr).
    !!
    !! Every tensor and node name emitted here is built from the prefix `tp`.
    !!
    !! Arguments:
    !!   tp          - name prefix for everything emitted here.
    !!   vertex_in   - tensor whose leading dimension gives the row count.
    !!   target_in   - target-vertex index tensor (presumably 1-D, one entry
    !!                 per edge -- TODO confirm).
    !!   message_in  - per-edge message tensor being aggregated.
    !!   feature_dim - size of the buffer's second dimension.
    !!   nodes, num_nodes - node list and count handed to the emit helpers
    !!                 (presumably appended to in place).
    !!   inits, num_inits - initialiser list and count handed to the
    !!                 constant helpers.
    !!   aggr_out    - on return, the name of the ScatterElements output.
    implicit none

    ! Arguments
    character(*), intent(in) :: tp, vertex_in, target_in, message_in
    integer, intent(in) :: feature_dim
    type(onnx_node_type), intent(inout), dimension(:) :: nodes
    integer, intent(inout) :: num_nodes
    type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
    integer, intent(inout) :: num_inits
    character(128), intent(out) :: aggr_out

    ! Names of the intermediate tensors created below.
    character(128) :: vert_shape, row_idx, row_count
    character(128) :: feat_width, buf_shape, zero_buf
    character(128) :: unsq_axes, tgt_col, msg_dims, tgt_idx

    vert_shape = prefixed('_vshape')
    row_idx    = prefixed('_nnodes_idx')
    row_count  = prefixed('_nnodes')
    feat_width = prefixed('_feat_dim')
    buf_shape  = prefixed('_aggr_shape')
    zero_buf   = prefixed('_zeros')
    tgt_col    = prefixed('_tgt_us')
    unsq_axes  = prefixed('_us_ax1')
    msg_dims   = prefixed('_msg_shape')
    tgt_idx    = prefixed('_tgt_exp')

    ! Step 1: row count = entry 0 of Shape(vertex_in).
    call emit_node('Shape', trim(tp)//'_shape_v', &
         trim(vert_shape), '', nodes, num_nodes, &
         in1=trim(vertex_in))
    call emit_constant_int64(trim(row_idx), [0], [1], &
         nodes, num_nodes, inits, num_inits)
    call emit_node('Gather', trim(tp)//'_gather_nn', &
         trim(row_count), onnx_axis0_attr, nodes, num_nodes, &
         in1=trim(vert_shape), in2=trim(row_idx))

    ! Step 2: buffer shape = [row count, feature_dim].
    call emit_constant_int64(trim(feat_width), [feature_dim], [1], &
         nodes, num_nodes, inits, num_inits)
    call emit_node('Concat', trim(tp)//'_cat_shape', &
         trim(buf_shape), onnx_concat_axis0_attr, nodes, num_nodes, &
         in1=trim(row_count), in2=trim(feat_width))

    ! Step 3: zero-filled buffer of that shape.
    call emit_constant_of_shape_float(trim(tp)//'_zeros', &
         trim(buf_shape), 0.0_real32, trim(zero_buf), &
         nodes, num_nodes, inits, num_inits)

    ! Step 4: lift target indices to a column on axis 1, then broadcast
    ! them to the message shape; ScatterElements needs indices and updates
    ! of identical shape.
    call emit_constant_int64(trim(unsq_axes), [1], [1], &
         nodes, num_nodes, inits, num_inits)
    call emit_node('Unsqueeze', trim(tp)//'_us_tgt', &
         trim(tgt_col), '', nodes, num_nodes, &
         in1=trim(target_in), in2=trim(unsq_axes))
    call emit_node('Shape', trim(tp)//'_shape_msg', &
         trim(msg_dims), '', nodes, num_nodes, &
         in1=trim(message_in))
    call emit_node('Expand', trim(tp)//'_expand_tgt', &
         trim(tgt_idx), '', nodes, num_nodes, &
         in1=trim(tgt_col), in2=trim(msg_dims))

    ! Step 5: scatter-add the messages into the zero buffer.
    aggr_out = prefixed('_aggr')
    call emit_node('ScatterElements', trim(tp)//'_scatter_add', &
         trim(aggr_out), onnx_scatter_add_attr, nodes, num_nodes, &
         in1=trim(zero_buf), in2=trim(tgt_idx), in3=trim(message_in))

  contains

    function prefixed(suffix) result(label)
      !! Return trim(tp)//suffix, blank-padded to 128 characters.
      character(*), intent(in) :: suffix
      character(128) :: label

      ! Internal write (not assignment) so an over-long name fails loudly
      ! instead of being silently truncated.
      write(label, '(2A)') trim(tp), suffix
    end function prefixed

  end subroutine emit_scatter_aggregator