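Emit the ONNX node block that builds a zero-initialised aggregation buffer, expands the target indices to the shape of the messages, and scatter-adds the messages into that buffer.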
| Type | Intent | Optional | Attributes | Name |
|---|---|---|---|---|
| character(len=*) | intent(in) | | | tp |
| character(len=*) | intent(in) | | | vertex_in |
| character(len=*) | intent(in) | | | target_in |
| character(len=*) | intent(in) | | | message_in |
| integer | intent(in) | | | feature_dim |
| type(onnx_node_type) | intent(inout) | | dimension(:) | nodes |
| integer | intent(inout) | | | num_nodes |
| type(onnx_initialiser_type) | intent(inout) | | dimension(:) | inits |
| integer | intent(inout) | | | num_inits |
| character(len=128) | intent(out) | | | aggr_out |
subroutine emit_scatter_aggregator( &
     tp, vertex_in, target_in, message_in, feature_dim, &
     nodes, num_nodes, inits, num_inits, aggr_out)
  !! Emit the zero-initialise, expand, and scatter-add aggregation block.
  !!
  !! Nodes are emitted in this order:
  !!
  !!  1. Shape(vertex_in), then Gather at index 0 (onnx_axis0_attr):
  !!     the leading dimension of vertex_in, used as the row count.
  !!  2. Concat([row count, feature_dim]): shape of the aggregation buffer.
  !!  3. ConstantOfShape with value 0.0: the zero-filled buffer.
  !!  4. Unsqueeze(target_in, axes=[1]), then Expand to Shape(message_in):
  !!     target indices broadcast to the shape of the messages.
  !!  5. ScatterElements(buffer, expanded indices, message_in) carrying
  !!     onnx_scatter_add_attr (presumably reduction="add" -- confirm).
  !!
  !! Every emitted node name and output tensor name starts with trim(tp).
  implicit none

  character(*), intent(in) :: tp
    !! Prefix for every node and tensor name emitted here.
  character(*), intent(in) :: vertex_in
    !! Tensor whose leading dimension sets the number of buffer rows.
  character(*), intent(in) :: target_in
    !! Target-index tensor; assumed 1-D (one index per message) -- TODO confirm.
  character(*), intent(in) :: message_in
    !! Messages accumulated into the buffer at the target rows.
  integer, intent(in) :: feature_dim
    !! Second dimension of the aggregation buffer.
  type(onnx_node_type), intent(inout), dimension(:) :: nodes
    !! Node list handed to the emit_* helpers.
  integer, intent(inout) :: num_nodes
    !! Node count handed to the emit_* helpers.
  type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
    !! Initialiser list handed to the constant-emitting helpers.
  integer, intent(inout) :: num_inits
    !! Initialiser count handed to the constant-emitting helpers.
  character(128), intent(out) :: aggr_out
    !! Name of the ScatterElements output tensor: trim(tp)//'_aggr'.

  ! Tensor names, all derived from the prefix tp.
  character(128) :: vshape, nn_index, nn_count
  character(128) :: fdim, buf_shape, buf_zero
  character(128) :: tgt_col, unsq_axes, msg_dims, tgt_full

  vshape    = prefixed_name('_vshape')
  nn_index  = prefixed_name('_nnodes_idx')
  nn_count  = prefixed_name('_nnodes')
  fdim      = prefixed_name('_feat_dim')
  buf_shape = prefixed_name('_aggr_shape')
  buf_zero  = prefixed_name('_zeros')
  tgt_col   = prefixed_name('_tgt_us')
  unsq_axes = prefixed_name('_us_ax1')
  msg_dims  = prefixed_name('_msg_shape')
  tgt_full  = prefixed_name('_tgt_exp')
  aggr_out  = prefixed_name('_aggr')

  ! Row count: element 0 of Shape(vertex_in).
  call emit_node('Shape', trim(tp)//'_shape_v', trim(vshape), '', &
       nodes, num_nodes, in1=trim(vertex_in))
  call emit_constant_int64(trim(nn_index), [0], [1], &
       nodes, num_nodes, inits, num_inits)
  call emit_node('Gather', trim(tp)//'_gather_nn', trim(nn_count), &
       onnx_axis0_attr, nodes, num_nodes, &
       in1=trim(vshape), in2=trim(nn_index))

  ! Zero buffer of shape [row count, feature_dim].
  call emit_constant_int64(trim(fdim), [feature_dim], [1], &
       nodes, num_nodes, inits, num_inits)
  call emit_node('Concat', trim(tp)//'_cat_shape', trim(buf_shape), &
       onnx_concat_axis0_attr, nodes, num_nodes, &
       in1=trim(nn_count), in2=trim(fdim))
  ! trim(buf_zero) is exactly trim(tp)//'_zeros'; it is passed both as the
  ! leading name argument and as the output tensor name.
  call emit_constant_of_shape_float(trim(buf_zero), trim(buf_shape), &
       0.0_real32, trim(buf_zero), nodes, num_nodes, inits, num_inits)

  ! Scatter indices: target_in gains an axis at position 1 and is then
  ! broadcast to the full shape of message_in, as ScatterElements needs
  ! indices of the same shape as its updates.
  call emit_constant_int64(trim(unsq_axes), [1], [1], &
       nodes, num_nodes, inits, num_inits)
  call emit_node('Unsqueeze', trim(tp)//'_us_tgt', trim(tgt_col), '', &
       nodes, num_nodes, in1=trim(target_in), in2=trim(unsq_axes))
  call emit_node('Shape', trim(tp)//'_shape_msg', trim(msg_dims), '', &
       nodes, num_nodes, in1=trim(message_in))
  call emit_node('Expand', trim(tp)//'_expand_tgt', trim(tgt_full), '', &
       nodes, num_nodes, in1=trim(tgt_col), in2=trim(msg_dims))

  ! Accumulate the messages into the zero buffer at the target rows.
  call emit_node('ScatterElements', trim(tp)//'_scatter_add', &
       trim(aggr_out), onnx_scatter_add_attr, nodes, num_nodes, &
       in1=trim(buf_zero), in2=trim(tgt_full), in3=trim(message_in))

contains

  function prefixed_name(suffix) result(label)
    !! trim(tp) followed by suffix, blank-padded to 128 characters.
    !! Uses an internal write, so a name longer than 128 characters
    !! raises the same runtime error as writing it into a 128-char local.
    character(*), intent(in) :: suffix
    character(128) :: label
    write(label, '(A,A)') trim(tp), suffix
  end function prefixed_name

end subroutine emit_scatter_aggregator