Emit ONNX nodes for one Kipf GCN time step.
| Type | Intent | Optional | Attributes | Name |
|---|---|---|---|---|
| character(len=*) | intent(in) | | | prefix |
| integer | intent(in) | | | t |
| integer | intent(in) | | | nv_in |
| integer | intent(in) | | | nv_out |
| real(kind=real32) | intent(in) | | | weight_data(:) |
| character(len=*) | intent(in) | | | activation_name |
| type(onnx_node_type) | intent(inout) | | dimension(:) | nodes |
| integer | intent(inout) | | | num_nodes |
| integer | intent(in) | | | max_nodes |
| type(onnx_initialiser_type) | intent(inout) | | dimension(:) | inits |
| integer | intent(inout) | | | num_inits |
| integer | intent(in) | | | max_inits |
| character(len=128) | intent(out) | | | vertex_out |
subroutine emit_kipf_timestep( &
     prefix, t, nv_in, nv_out, weight_data, activation_name, &
     nodes, num_nodes, max_nodes, &
     inits, num_inits, max_inits, vertex_out)
  !! Emit ONNX nodes for one Kipf GCN time step.
  !!
  !! For every edge (src -> tgt) the emitted subgraph forms the message
  !!   x[src] * (deg[src] * deg[tgt])**(-1/2)
  !! and scatter-adds the messages onto the target vertices. It then
  !! multiplies the aggregate by the transpose of the weight initialiser.
  !! Finally it applies the requested activation, unless activation_name
  !! is 'none'.
  !!
  !! All node and tensor names are built from the stem '<prefix>_t<t>',
  !! so successive time steps do not collide.
  !!
  !! The vertex input tensor depends on the time step:
  !!   - t == 1: the graph input '<prefix>_vertex_in';
  !!   - t > 1: the tensor name returned by get_timestep_output_name
  !!     for step t-1.
  !!
  !! NOTE(review): max_inits is not consulted in this routine. max_nodes is
  !! only forwarded to emit_activation_node. Any other capacity checking is
  !! up to the emit_* helpers - confirm.
  use athena__onnx_utils, only: emit_node, emit_constant_int64, &
       emit_constant_float, emit_activation_node
  use athena__onnx_msgpass_utils, only: get_timestep_output_name, &
       emit_edge_index_component, emit_scatter_aggregator, &
       emit_weight_initialiser_2d
  implicit none

  ! Arguments
  character(*), intent(in) :: prefix
  !! Name prefix shared by every tensor/node emitted for this layer.
  integer, intent(in) :: t
  !! Time-step index (1-based; t > 1 chains onto step t-1).
  integer, intent(in) :: nv_in, nv_out
  !! Input / output vertex feature widths.
  real(real32), intent(in) :: weight_data(:)
  !! Flattened weight values handed to emit_weight_initialiser_2d
  !! as an (nv_out, nv_in) initialiser.
  character(*), intent(in) :: activation_name
  !! Activation to append, or 'none' to skip it.
  type(onnx_node_type), intent(inout), dimension(:) :: nodes
  !! Node list; new nodes are appended after position num_nodes.
  integer, intent(inout) :: num_nodes
  !! Number of nodes currently used in nodes.
  integer, intent(in) :: max_nodes
  !! Capacity forwarded to emit_activation_node.
  type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
  !! Initialiser list; new initialisers are appended.
  integer, intent(inout) :: num_inits
  !! Number of initialisers currently used in inits.
  integer, intent(in) :: max_inits
  !! Currently unused by this routine.
  character(128), intent(out) :: vertex_out
  !! Name of the tensor holding this time step's vertex output.

  ! Local variables
  character(128) :: tp
  !! Per-time-step name stem '<prefix>_t<t>'.
  character(128) :: vertex_in, edge_index_in, degree_in
  character(128) :: src_idx, target_idx, aggr_name
  character(128) :: idx0_name, idx2_name
  character(128) :: src_feat_name, deg_float_name
  character(128) :: deg_src_name, deg_tgt_name, deg_prod_name
  character(128) :: neg_half_name, coeff_name
  character(128) :: unsq_axis_name, coeff_unsq_name, scaled_feat_name
  character(128) :: weight_name, weight_t_name, matmul_out_name

  character(len=*), parameter :: onnx_axis0_attr = &
       ' "attribute": [{"name": "axis", "i": "0", "type": "INT"}]'
  character(len=*), parameter :: onnx_transpose_10_attr = &
       ' "attribute": [{"name": "perm", "ints": ["1", "0"], ' // &
       '"type": "INTS"}]'
  ! ONNX data type 1 is FLOAT.
  character(len=*), parameter :: onnx_cast_float_attr = &
       ' "attribute": [{"name": "to", "i": "1", "type": "INT"}]'

  write(tp, '(A,"_t",I0)') trim(prefix), t
  write(vertex_in, '(A,"_vertex_in")') trim(prefix)
  write(edge_index_in, '(A,"_edge_index_in")') trim(prefix)
  write(degree_in, '(A,"_degree_in")') trim(prefix)

  ! After the first step, read the output tensor of step t-1 instead of
  ! the graph's vertex input.
  if(t .gt. 1)then
     call get_timestep_output_name( &
          prefix, t-1, activation_name, '_mm_out', '', vertex_in)
  end if

  ! --- Step 1: Extract source and target indices from edge_index ---
  ! NOTE(review): components 0 and 2 are selected (component 1 is
  ! skipped); confirm this matches the edge_index layout.
  write(idx0_name, '(A,"_idx0")') trim(tp)
  call emit_constant_int64(trim(idx0_name), [0], [1], &
       nodes, num_nodes, inits, num_inits)
  write(idx2_name, '(A,"_idx2")') trim(tp)
  call emit_constant_int64(trim(idx2_name), [2], [1], &
       nodes, num_nodes, inits, num_inits)
  call emit_edge_index_component( &
       tp, edge_index_in, trim(idx0_name), 'src', src_idx, nodes, num_nodes)
  call emit_edge_index_component( &
       tp, edge_index_in, trim(idx2_name), 'tgt', target_idx, nodes, num_nodes)

  ! --- Step 2: Gather source features and compute normalisation ---
  write(src_feat_name, '(A,"_src_feat")') trim(tp)
  call emit_node('Gather', trim(tp)//'_gather_vfeat', &
       trim(src_feat_name), onnx_axis0_attr, nodes, num_nodes, &
       in1=trim(vertex_in), in2=trim(src_idx))

  ! Degrees are cast to float so they can be raised to a fractional power.
  write(deg_float_name, '(A,"_deg_f")') trim(tp)
  call emit_node('Cast', trim(tp)//'_cast_deg', &
       trim(deg_float_name), onnx_cast_float_attr, nodes, num_nodes, &
       in1=trim(degree_in))

  write(deg_src_name, '(A,"_deg_src")') trim(tp)
  call emit_node('Gather', trim(tp)//'_gather_deg_src', &
       trim(deg_src_name), onnx_axis0_attr, nodes, num_nodes, &
       in1=trim(deg_float_name), in2=trim(src_idx))

  write(deg_tgt_name, '(A,"_deg_tgt")') trim(tp)
  call emit_node('Gather', trim(tp)//'_gather_deg_tgt', &
       trim(deg_tgt_name), onnx_axis0_attr, nodes, num_nodes, &
       in1=trim(deg_float_name), in2=trim(target_idx))

  write(deg_prod_name, '(A,"_deg_prod")') trim(tp)
  call emit_node('Mul', trim(tp)//'_mul_deg', &
       trim(deg_prod_name), '', nodes, num_nodes, &
       in1=trim(deg_src_name), in2=trim(deg_tgt_name))

  ! coeff = (deg[src] * deg[tgt]) ** (-0.5)
  write(neg_half_name, '(A,"_neg_half")') trim(tp)
  call emit_constant_float(trim(neg_half_name), [-0.5_real32], [1], &
       nodes, num_nodes, inits, num_inits)

  write(coeff_name, '(A,"_coeff")') trim(tp)
  call emit_node('Pow', trim(tp)//'_pow_coeff', &
       trim(coeff_name), '', nodes, num_nodes, &
       in1=trim(deg_prod_name), in2=trim(neg_half_name))

  ! Unsqueeze coeff on axis 1 so it broadcasts across the feature
  ! dimension, then scale the gathered source features.
  write(coeff_unsq_name, '(A,"_coeff_us")') trim(tp)
  write(unsq_axis_name, '(A,"_us_ax1")') trim(tp)
  call emit_constant_int64(trim(unsq_axis_name), [1], [1], &
       nodes, num_nodes, inits, num_inits)
  call emit_node('Unsqueeze', trim(tp)//'_us_coeff', &
       trim(coeff_unsq_name), '', nodes, num_nodes, &
       in1=trim(coeff_name), in2=trim(unsq_axis_name))

  write(scaled_feat_name, '(A,"_scaled_feat")') trim(tp)
  call emit_node('Mul', trim(tp)//'_mul_coeff', &
       trim(scaled_feat_name), '', nodes, num_nodes, &
       in1=trim(src_feat_name), in2=trim(coeff_unsq_name))

  ! --- Step 3: Scatter-add normalised messages to target vertices ---
  call emit_scatter_aggregator( &
       tp, vertex_in, target_idx, trim(scaled_feat_name), nv_in, &
       nodes, num_nodes, inits, num_inits, aggr_name)

  ! --- Step 4: MatMul with weight W ---
  ! W is registered with dimensions (nv_out, nv_in). It is transposed
  ! so that aggr @ W^T maps nv_in features to nv_out features.
  write(weight_name, '(A,"_W")') trim(tp)
  call emit_weight_initialiser_2d( &
       trim(weight_name), nv_out, nv_in, weight_data, inits, num_inits)

  write(weight_t_name, '(A,"_Wt")') trim(tp)
  call emit_node('Transpose', trim(tp)//'_transpose_W', &
       trim(weight_t_name), onnx_transpose_10_attr, nodes, num_nodes, &
       in1=trim(weight_name))

  write(matmul_out_name, '(A,"_mm_out")') trim(tp)
  call emit_node('MatMul', trim(tp)//'_matmul', &
       trim(matmul_out_name), '', nodes, num_nodes, &
       in1=trim(aggr_name), in2=trim(weight_t_name))

  ! --- Step 5: Activation ---
  if(trim(activation_name) .ne. 'none')then
     call emit_activation_node(activation_name, trim(tp), &
          trim(matmul_out_name), nodes, num_nodes, max_nodes)
     ! Assumes emit_activation_node appends its node last, so the newest
     ! node's first output is this step's result - confirm.
     vertex_out = trim(nodes(num_nodes)%outputs(1))
  else
     vertex_out = trim(matmul_out_name)
  end if
end subroutine emit_kipf_timestep