emit_duvenaud_degree_update Subroutine

private subroutine emit_duvenaud_degree_update(tp, degree_in, min_degree, max_degree, feature_dim, nv_out, weight_data, aggr_in, nodes, num_nodes, inits, num_inits, sq_out)

Emit the ONNX nodes and initialisers for the degree-dependent weight selection and update block.

Arguments

Type | Intent | Optional | Attributes | Name
character(len=*), intent(in) :: tp
character(len=*), intent(in) :: degree_in
integer, intent(in) :: min_degree
integer, intent(in) :: max_degree
integer, intent(in) :: feature_dim
integer, intent(in) :: nv_out
real(kind=real32), intent(in) :: weight_data(:)
character(len=*), intent(in) :: aggr_in
type(onnx_node_type), intent(inout), dimension(:) :: nodes
integer, intent(inout) :: num_nodes
type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
integer, intent(inout) :: num_inits
character(len=128), intent(out) :: sq_out

Source Code

  subroutine emit_duvenaud_degree_update( &
       tp, degree_in, min_degree, max_degree, feature_dim, nv_out, &
       weight_data, aggr_in, nodes, num_nodes, inits, num_inits, sq_out)
    !! Emit the degree-dependent weight selection and update block.
    !!
    !! Appends ONNX nodes and initialisers that pick, per entry of
    !! `degree_in`, one weight matrix from a degree-indexed bank and apply it
    !! to the degree-normalised aggregate `aggr_in`:
    !!
    !!   deg   = Clip(Cast<float>(degree_in), min_degree, max_degree)
    !!   W_sel = Gather(W, Cast<int64>(deg - min_degree), axis=0)
    !!   out   = Squeeze(MatMul(W_sel, Unsqueeze(aggr_in / Unsqueeze(deg, 1), 2)), 2)
    !!
    !! Every emitted node/tensor name is prefixed with `tp`; the name of the
    !! final squeezed tensor is returned in `sq_out`.
    !!
    !! NOTE(review): assumes degree_in is rank-1 [N] and aggr_in is
    !! [N, feature_dim] so the Div/MatMul broadcasts line up — TODO confirm.
    !! NOTE(review): the division uses the clipped degree, so with
    !! min_degree <= 0 a zero-degree entry would divide by zero — confirm
    !! callers guarantee min_degree >= 1.
    !! NOTE(review): the Clip bounds are emitted with dims [1]; the ONNX spec
    !! asks for rank-0 scalars — confirm the target runtime accepts this.
    use athena__onnx_utils, only: emit_node, emit_squeeze_node, &
         emit_constant_int64, emit_constant_float
    use athena__onnx_msgpass_utils, only: emit_weight_initialiser_3d
    implicit none

    ! Arguments
    character(*), intent(in) :: tp
    !! Prefix applied to every emitted node and tensor name.
    character(*), intent(in) :: degree_in
    !! Name of the tensor holding per-entry degrees (cast to float here).
    integer, intent(in) :: min_degree, max_degree
    !! Inclusive degree range covered by the weight bank.
    integer, intent(in) :: feature_dim, nv_out
    !! Third and second dimensions of the weight bank, respectively.
    real(real32), intent(in) :: weight_data(:)
    !! Flat weight values forwarded to emit_weight_initialiser_3d with dims
    !! (max_degree - min_degree + 1, nv_out, feature_dim).
    character(*), intent(in) :: aggr_in
    !! Name of the aggregated-feature tensor that is divided by degree.
    type(onnx_node_type), intent(inout), dimension(:) :: nodes
    integer, intent(inout) :: num_nodes
    type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
    integer, intent(inout) :: num_inits
    character(128), intent(out) :: sq_out
    !! Name of the tensor produced by the closing Squeeze node.

    ! Tensor names; each is "<tp><suffix>" written into a 128-char buffer.
    character(128) :: lo_bound, hi_bound
    character(128) :: deg_real, deg_clamped, slot_real, slot
    character(128) :: bank, bank_pick
    character(128) :: ax1, ax2, deg_col, aggr_scaled, aggr_col, prod

    ! Node attribute JSON fragments (ONNX Cast "to": 1 = FLOAT, 7 = INT64).
    character(len=*), parameter :: gather_axis0 = &
         '        "attribute": [{"name": "axis", "i": "0", "type": "INT"}]'
    character(len=*), parameter :: cast_to_float = &
         '        "attribute": [{"name": "to", "i": "1", "type": "INT"}]'
    character(len=*), parameter :: cast_to_int64 = &
         '        "attribute": [{"name": "to", "i": "7", "type": "INT"}]'

    !-- Resolve every tensor name before wiring the graph ------------------
    call prefixed_name('_min_deg',    lo_bound)
    call prefixed_name('_max_deg',    hi_bound)
    call prefixed_name('_deg_f',      deg_real)
    call prefixed_name('_deg_clip',   deg_clamped)
    call prefixed_name('_deg_idx_f',  slot_real)
    call prefixed_name('_deg_idx',    slot)
    call prefixed_name('_W',          bank)
    call prefixed_name('_W_sel',      bank_pick)
    call prefixed_name('_us_ax1_deg', ax1)
    call prefixed_name('_deg_us',     deg_col)
    call prefixed_name('_aggr_norm',  aggr_scaled)
    call prefixed_name('_us_ax2',     ax2)
    call prefixed_name('_aggr_us',    aggr_col)
    call prefixed_name('_matmul_out', prod)
    call prefixed_name('_sq_out',     sq_out)

    !-- Clamp degrees into [min_degree, max_degree] ------------------------
    call emit_constant_float(trim(lo_bound), [real(min_degree, real32)], [1], &
         nodes, num_nodes, inits, num_inits)
    call emit_constant_float(trim(hi_bound), [real(max_degree, real32)], [1], &
         nodes, num_nodes, inits, num_inits)
    call emit_node('Cast', trim(tp)//'_cast_deg', trim(deg_real), &
         cast_to_float, nodes, num_nodes, in1=trim(degree_in))
    call emit_node('Clip', trim(tp)//'_clip_deg', trim(deg_clamped), '', &
         nodes, num_nodes, &
         in1=trim(deg_real), in2=trim(lo_bound), in3=trim(hi_bound))

    !-- Offset by min_degree so the bank is indexed from zero --------------
    call emit_node('Sub', trim(tp)//'_sub_mindeg', trim(slot_real), '', &
         nodes, num_nodes, in1=trim(deg_clamped), in2=trim(lo_bound))
    call emit_node('Cast', trim(tp)//'_cast_degidx', trim(slot), &
         cast_to_int64, nodes, num_nodes, in1=trim(slot_real))

    !-- Weight bank initialiser and per-entry selection along axis 0 -------
    call emit_weight_initialiser_3d(trim(bank), &
         max_degree - min_degree + 1, nv_out, feature_dim, weight_data, &
         inits, num_inits)
    call emit_node('Gather', trim(tp)//'_gather_W', trim(bank_pick), &
         gather_axis0, nodes, num_nodes, &
         in1=trim(bank), in2=trim(slot))

    !-- Normalise the aggregate by the clamped degree (as a column) --------
    call emit_constant_int64(trim(ax1), [1], [1], &
         nodes, num_nodes, inits, num_inits)
    call emit_node('Unsqueeze', trim(tp)//'_us_deg', trim(deg_col), '', &
         nodes, num_nodes, in1=trim(deg_clamped), in2=trim(ax1))
    call emit_node('Div', trim(tp)//'_div_deg', trim(aggr_scaled), '', &
         nodes, num_nodes, in1=trim(aggr_in), in2=trim(deg_col))

    !-- Batched MatMul against the selected weights, then drop axis 2 ------
    ! The same axis-2 constant feeds both the Unsqueeze and the Squeeze.
    call emit_constant_int64(trim(ax2), [2], [1], &
         nodes, num_nodes, inits, num_inits)
    call emit_node('Unsqueeze', trim(tp)//'_us_aggr', trim(aggr_col), '', &
         nodes, num_nodes, in1=trim(aggr_scaled), in2=trim(ax2))
    call emit_node('MatMul', trim(tp)//'_matmul', trim(prod), '', &
         nodes, num_nodes, in1=trim(bank_pick), in2=trim(aggr_col))
    call emit_squeeze_node(trim(tp)//'_sq_mm', trim(prod), trim(ax2), &
         trim(sq_out), nodes, num_nodes)

  contains

    subroutine prefixed_name(suffix, buf)
      !! Write trim(tp)//suffix into a 128-character name buffer.
      !! An over-long prefix makes the internal write fail at run time.
      character(*), intent(in) :: suffix
      character(128), intent(out) :: buf
      write(buf, '(A,A)') trim(tp), suffix
    end subroutine prefixed_name

  end subroutine emit_duvenaud_degree_update