Emit ONNX nodes for one Duvenaud message passing time step.
| Type | Intent | Optional | Attributes | | Name | |
|---|---|---|---|---|---|---|
| character(len=*), | intent(in) | | | :: | prefix | |
| integer, | intent(in) | | | :: | t | |
| integer, | intent(in) | | | :: | nv_in | |
| integer, | intent(in) | | | :: | ne_in | |
| integer, | intent(in) | | | :: | nv_out | |
| integer, | intent(in) | | | :: | min_degree | |
| integer, | intent(in) | | | :: | max_degree | |
| real(kind=real32), | intent(in) | | | :: | weight_data(:) | |
| character(len=*), | intent(in) | | | :: | activation_name | |
| type(onnx_node_type), | intent(inout), | | dimension(:) | :: | nodes | |
| integer, | intent(inout) | | | :: | num_nodes | |
| integer, | intent(in) | | | :: | max_nodes | |
| type(onnx_initialiser_type), | intent(inout), | | dimension(:) | :: | inits | |
| integer, | intent(inout) | | | :: | num_inits | |
| integer, | intent(in) | | | :: | max_inits | |
| character(len=128), | intent(out) | | | :: | vertex_out | |
subroutine emit_duvenaud_timestep( &
     prefix, t, nv_in, ne_in, nv_out, &
     min_degree, max_degree, weight_data, activation_name, &
     nodes, num_nodes, max_nodes, &
     inits, num_inits, max_inits, vertex_out)
  !! Emit ONNX nodes for one Duvenaud message passing time step.
  !!
  !! The nodes are appended in this order: three int64 index constants,
  !! three edge-index component extractions ('src', 'eidx', 'tgt'),
  !! two Gather nodes (vertex and edge features), one Concat node,
  !! the scatter aggregation, the degree-specific update and, unless
  !! `activation_name` is 'none', an activation node.
  !! All tensor and node names are derived from `<prefix>_t<t>`.
  use athena__onnx_utils, only: emit_node, emit_constant_int64, &
       emit_activation_node
  use athena__onnx_msgpass_utils, only: get_timestep_output_name, &
       emit_edge_index_component, emit_scatter_aggregator
  implicit none

  ! Arguments
  character(len=*), intent(in) :: prefix
  !! Layer name prefix used to build every tensor/node name
  integer, intent(in) :: t
  !! Time step index; appears in names as the suffix "_t<t>"
  integer, intent(in) :: nv_in, ne_in, nv_out
  !! Sizes forwarded to the helpers; nv_in + ne_in is passed as the
  !! message width (presumably vertex/edge feature counts - confirm)
  integer, intent(in) :: min_degree, max_degree
  !! Degree range forwarded to emit_duvenaud_degree_update
  real(real32), intent(in) :: weight_data(:)
  !! Weights forwarded to emit_duvenaud_degree_update
  character(len=*), intent(in) :: activation_name
  !! Activation to emit after the update; 'none' skips the node
  type(onnx_node_type), intent(inout), dimension(:) :: nodes
  !! Node list that the emitted nodes are appended to
  integer, intent(inout) :: num_nodes
  !! Running node count, advanced by the emit helpers
  integer, intent(in) :: max_nodes
  !! Forwarded to emit_activation_node only
  type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
  !! Initialiser list passed through to the emit helpers
  integer, intent(inout) :: num_inits
  !! Running initialiser count, passed through to the emit helpers
  integer, intent(in) :: max_inits
  !! Not referenced in this routine; kept for interface compatibility
  character(len=128), intent(out) :: vertex_out
  !! Name of the tensor holding this time step's vertex output

  ! Local variables
  character(len=128) :: step_tag
  character(len=128) :: idx0_name, idx1_name, idx2_name
  character(len=128) :: src_feat_name, edge_feat_name
  character(len=128) :: vertex_in, edge_in, edge_index_in, degree_in
  character(len=128) :: src_idx, edge_idx, target_idx
  character(len=128) :: msg_name, aggr_name, sq_out

  character(len=*), parameter :: onnx_axis0_attr = &
       ' "attribute": [{"name": "axis", "i": "0", "type": "INT"}]'
  character(len=*), parameter :: onnx_concat_axis1_attr = &
       ' "attribute": [{"name": "axis", "i": "1", "type": "INT"}]'

  ! Common name stem for everything emitted in this time step.
  write(step_tag, '(A,"_t",I0)') trim(prefix), t

  ! Input tensor names follow the convention set during write_onnx.
  ! Edge, edge_index and degree always refer to the original graph
  ! inputs of the layer; only the vertex input changes between steps.
  write(edge_in, '(A,"_edge_in")') trim(prefix)
  write(edge_index_in, '(A,"_edge_index_in")') trim(prefix)
  write(degree_in, '(A,"_degree_in")') trim(prefix)

  ! Default vertex input is the layer input (used when t == 1); later
  ! steps take the output name of the previous time step instead.
  write(vertex_in, '(A,"_vertex_in")') trim(prefix)
  if (t > 1) then
    call get_timestep_output_name( &
         prefix, t - 1, activation_name, '_sq_out', '_sq', vertex_in)
  end if

  ! --- Step 1: index constants and edge_index components ---
  ! Constants 0, 1 and 2 select the 'src', 'eidx' and 'tgt' components.
  write(idx0_name, '(A,"_idx0")') trim(step_tag)
  write(idx1_name, '(A,"_idx1")') trim(step_tag)
  write(idx2_name, '(A,"_idx2")') trim(step_tag)

  call emit_constant_int64(trim(idx0_name), [0], [1], &
       nodes, num_nodes, inits, num_inits)
  call emit_constant_int64(trim(idx1_name), [1], [1], &
       nodes, num_nodes, inits, num_inits)
  call emit_constant_int64(trim(idx2_name), [2], [1], &
       nodes, num_nodes, inits, num_inits)

  call emit_edge_index_component( &
       step_tag, edge_index_in, trim(idx0_name), 'src', &
       src_idx, nodes, num_nodes)
  call emit_edge_index_component( &
       step_tag, edge_index_in, trim(idx1_name), 'eidx', &
       edge_idx, nodes, num_nodes)
  call emit_edge_index_component( &
       step_tag, edge_index_in, trim(idx2_name), 'tgt', &
       target_idx, nodes, num_nodes)

  ! --- Step 2: gather source-vertex features and edge features ---
  write(src_feat_name, '(A,"_src_feat")') trim(step_tag)
  write(edge_feat_name, '(A,"_edge_feat")') trim(step_tag)

  call emit_node('Gather', trim(step_tag)//'_gather_vfeat', &
       trim(src_feat_name), onnx_axis0_attr, nodes, num_nodes, &
       in1=trim(vertex_in), in2=trim(src_idx))
  call emit_node('Gather', trim(step_tag)//'_gather_efeat', &
       trim(edge_feat_name), onnx_axis0_attr, nodes, num_nodes, &
       in1=trim(edge_in), in2=trim(edge_idx))

  ! --- Step 3: concatenate the two gathered tensors along axis 1 ---
  write(msg_name, '(A,"_msg")') trim(step_tag)
  call emit_node('Concat', trim(step_tag)//'_concat_msg', &
       trim(msg_name), onnx_concat_axis1_attr, nodes, num_nodes, &
       in1=trim(src_feat_name), in2=trim(edge_feat_name))

  ! --- Step 4: scatter-add messages onto their target vertices ---
  call emit_scatter_aggregator( &
       step_tag, vertex_in, target_idx, msg_name, nv_in + ne_in, &
       nodes, num_nodes, inits, num_inits, aggr_name)

  ! --- Step 5: degree-specific weight application ---
  call emit_duvenaud_degree_update( &
       step_tag, degree_in, min_degree, max_degree, &
       nv_in + ne_in, nv_out, weight_data, aggr_name, &
       nodes, num_nodes, inits, num_inits, sq_out)

  ! --- Step 6: optional activation ---
  if (trim(activation_name) == 'none') then
    ! No activation requested: the update output is the step output.
    vertex_out = trim(sq_out)
  else
    call emit_activation_node(activation_name, trim(step_tag)//'_sq', &
         trim(sq_out), nodes, num_nodes, max_nodes)
    ! The step output is the first output of the last node in the list.
    vertex_out = trim(nodes(num_nodes)%outputs(1))
  end if

end subroutine emit_duvenaud_timestep