Emit the degree-dependent weight selection and update block.
| Type | Intent | Optional | Attributes | Name |
|---|---|---|---|---|
| character(len=*) | intent(in) | | | tp |
| character(len=*) | intent(in) | | | degree_in |
| integer | intent(in) | | | min_degree |
| integer | intent(in) | | | max_degree |
| integer | intent(in) | | | feature_dim |
| integer | intent(in) | | | nv_out |
| real(kind=real32) | intent(in) | | | weight_data(:) |
| character(len=*) | intent(in) | | | aggr_in |
| type(onnx_node_type) | intent(inout) | | dimension(:) | nodes |
| integer | intent(inout) | | | num_nodes |
| type(onnx_initialiser_type) | intent(inout) | | dimension(:) | inits |
| integer | intent(inout) | | | num_inits |
| character(len=128) | intent(out) | | | sq_out |
subroutine emit_duvenaud_degree_update( &
     tp, degree_in, min_degree, max_degree, feature_dim, nv_out, &
     weight_data, aggr_in, nodes, num_nodes, inits, num_inits, sq_out)
  !! Emit the degree-dependent weight selection and update block.
  !!
  !! Appends the following to `nodes` / `inits`, in this emission order:
  !!
  !! * float constants holding `min_degree` and `max_degree` (dims [1]);
  !! * Cast(`degree_in` -> float) followed by Clip to [min_degree, max_degree];
  !! * Sub(min_degree) and Cast(-> int64), giving a zero-based bank index;
  !! * a 3D weight-bank initialiser whose leading extent is
  !!   max_degree - min_degree + 1, then Gather along axis 0 with that index;
  !! * Unsqueeze(clipped degree, axis 1) and Div(`aggr_in`, that tensor);
  !! * Unsqueeze(axis 2), MatMul(selected weights, ...), Squeeze(axis 2).
  !!
  !! The name of the final squeezed tensor is returned in `sq_out`.
  !!
  !! NOTE(review): the Div uses the *clipped* degree. If min_degree <= 0, a
  !! node with degree 0 divides by zero, and degrees above max_degree are
  !! divided by max_degree. Confirm this matches the native layer.
  use athena__onnx_utils, only: emit_node, emit_squeeze_node, &
       emit_constant_int64, emit_constant_float
  use athena__onnx_msgpass_utils, only: emit_weight_initialiser_3d
  implicit none

  ! Arguments
  character(len=*), intent(in) :: tp
  !! Prefix used for every emitted node, tensor and initialiser name.
  character(len=*), intent(in) :: degree_in
  !! Name of an existing degree tensor (presumably one entry per node).
  character(len=*), intent(in) :: aggr_in
  !! Name of an existing aggregated-feature tensor
  !! (presumably [num_graph_nodes, feature_dim]; TODO confirm).
  integer, intent(in) :: min_degree
  !! Lower clip bound; also the degree that maps to bank index 0.
  integer, intent(in) :: max_degree
  !! Upper clip bound.
  integer, intent(in) :: feature_dim
  !! Third extent of the weight bank.
  integer, intent(in) :: nv_out
  !! Second extent of the weight bank.
  real(real32), intent(in) :: weight_data(:)
  !! Flat weight-bank values; layout is defined by
  !! emit_weight_initialiser_3d.
  type(onnx_node_type), intent(inout), dimension(:) :: nodes
  integer, intent(inout) :: num_nodes
  type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
  integer, intent(inout) :: num_inits
  character(128), intent(out) :: sq_out
  !! Name of the output tensor produced by the final Squeeze.

  ! Local variables
  integer :: n_bank
  character(len=128) :: lo_name, hi_name
  character(len=128) :: deg_f_name, deg_clip_name
  character(len=128) :: idx_f_name, idx_name
  character(len=128) :: bank_name, bank_sel_name
  character(len=128) :: ax1_name, deg_col_name, norm_name
  character(len=128) :: ax2_name, norm_col_name, mm_name

  ! ONNX JSON attribute fragments (Cast "to": 1 = FLOAT, 7 = INT64).
  character(len=*), parameter :: attr_gather_axis0 = &
       ' "attribute": [{"name": "axis", "i": "0", "type": "INT"}]'
  character(len=*), parameter :: attr_cast_to_float = &
       ' "attribute": [{"name": "to", "i": "1", "type": "INT"}]'
  character(len=*), parameter :: attr_cast_to_int64 = &
       ' "attribute": [{"name": "to", "i": "7", "type": "INT"}]'

  ! Build every tensor name from the prefix before emitting anything.
  write(lo_name,       '(2A)') trim(tp), '_min_deg'
  write(hi_name,       '(2A)') trim(tp), '_max_deg'
  write(deg_f_name,    '(2A)') trim(tp), '_deg_f'
  write(deg_clip_name, '(2A)') trim(tp), '_deg_clip'
  write(idx_f_name,    '(2A)') trim(tp), '_deg_idx_f'
  write(idx_name,      '(2A)') trim(tp), '_deg_idx'
  write(bank_name,     '(2A)') trim(tp), '_W'
  write(bank_sel_name, '(2A)') trim(tp), '_W_sel'
  write(ax1_name,      '(2A)') trim(tp), '_us_ax1_deg'
  write(deg_col_name,  '(2A)') trim(tp), '_deg_us'
  write(norm_name,     '(2A)') trim(tp), '_aggr_norm'
  write(ax2_name,      '(2A)') trim(tp), '_us_ax2'
  write(norm_col_name, '(2A)') trim(tp), '_aggr_us'
  write(mm_name,       '(2A)') trim(tp), '_matmul_out'
  write(sq_out,        '(2A)') trim(tp), '_sq_out'

  ! One weight matrix per supported degree.
  n_bank = max_degree - min_degree + 1

  ! --- Clip the degree into the supported bucket interval -----------------
  call emit_constant_float(trim(lo_name), [real(min_degree, real32)], [1], &
       nodes, num_nodes, inits, num_inits)
  call emit_constant_float(trim(hi_name), [real(max_degree, real32)], [1], &
       nodes, num_nodes, inits, num_inits)
  call emit_node('Cast', trim(tp)//'_cast_deg', trim(deg_f_name), &
       attr_cast_to_float, nodes, num_nodes, in1=trim(degree_in))
  call emit_node('Clip', trim(tp)//'_clip_deg', trim(deg_clip_name), '', &
       nodes, num_nodes, &
       in1=trim(deg_f_name), in2=trim(lo_name), in3=trim(hi_name))

  ! --- Zero-based bank index = clipped degree - min_degree, as int64 ------
  call emit_node('Sub', trim(tp)//'_sub_mindeg', trim(idx_f_name), '', &
       nodes, num_nodes, in1=trim(deg_clip_name), in2=trim(lo_name))
  call emit_node('Cast', trim(tp)//'_cast_degidx', trim(idx_name), &
       attr_cast_to_int64, nodes, num_nodes, in1=trim(idx_f_name))

  ! --- Weight bank initialiser and per-index selection --------------------
  call emit_weight_initialiser_3d(trim(bank_name), n_bank, nv_out, &
       feature_dim, weight_data, inits, num_inits)
  call emit_node('Gather', trim(tp)//'_gather_W', trim(bank_sel_name), &
       attr_gather_axis0, nodes, num_nodes, &
       in1=trim(bank_name), in2=trim(idx_name))

  ! --- Normalise by the clipped degree (broadcast via axis-1 unsqueeze) ---
  call emit_constant_int64(trim(ax1_name), [1], [1], &
       nodes, num_nodes, inits, num_inits)
  call emit_node('Unsqueeze', trim(tp)//'_us_deg', trim(deg_col_name), '', &
       nodes, num_nodes, in1=trim(deg_clip_name), in2=trim(ax1_name))
  call emit_node('Div', trim(tp)//'_div_deg', trim(norm_name), '', &
       nodes, num_nodes, in1=trim(aggr_in), in2=trim(deg_col_name))

  ! --- Batched MatMul: add a trailing axis, multiply, then drop it again --
  call emit_constant_int64(trim(ax2_name), [2], [1], &
       nodes, num_nodes, inits, num_inits)
  call emit_node('Unsqueeze', trim(tp)//'_us_aggr', trim(norm_col_name), '', &
       nodes, num_nodes, in1=trim(norm_name), in2=trim(ax2_name))
  call emit_node('MatMul', trim(tp)//'_matmul', trim(mm_name), '', &
       nodes, num_nodes, in1=trim(bank_sel_name), in2=trim(norm_col_name))
  ! The axis-2 constant is reused as the Squeeze axes input.
  call emit_squeeze_node(trim(tp)//'_sq_mm', trim(mm_name), trim(ax2_name), &
       trim(sq_out), nodes, num_nodes)

end subroutine emit_duvenaud_degree_update