athena_onnx_msgpass_utils.f90 Source File


Source Code

module athena__onnx_msgpass_utils
  !! Shared ONNX builder helpers for message-passing layers.
  !!
  !! This module factors out the repeated edge-index extraction,
  !! scatter accumulation, weight export, and output naming logic used by
  !! the Duvenaud and Kipf message-passing layers.
  use coreutils, only: real32
  use athena__misc_types, only: onnx_node_type, onnx_initialiser_type, &
       onnx_tensor_type
  use athena__onnx_utils, only: emit_node, emit_squeeze_node, &
       emit_constant_int64, emit_constant_float, &
       emit_constant_of_shape_float, emit_activation_node, &
       col_to_row_major_2d
  implicit none

  private

  public :: emit_msgpass_graph_inputs
  public :: emit_output_identity
  public :: get_timestep_output_name
  public :: emit_edge_index_component
  public :: emit_scatter_aggregator
  public :: emit_weight_initialiser_2d
  public :: emit_weight_initialiser_3d

  character(len=*), parameter :: onnx_axis0_attr = &
       '        "attribute": [{"name": "axis", "i": "0", "type": "INT"}]'
  character(len=*), parameter :: onnx_concat_axis0_attr = &
       '        "attribute": [{"name": "axis", "i": "0", "type": "INT"}]'
  character(len=*), parameter :: onnx_concat_axis1_attr = &
       '        "attribute": [{"name": "axis", "i": "1", "type": "INT"}]'
  character(len=*), parameter :: onnx_softmax_axis0_attr = &
       '        "attribute": [{"name": "axis", "i": "0", "type": "INT"}]'
  character(len=*), parameter :: onnx_transpose_10_attr = &
       '        "attribute": [{"name": "perm", "ints": ["1", "0"], ' // &
       '"type": "INTS"}]'
  character(len=*), parameter :: onnx_reduce_sum_attr = &
       '        "attribute": [{"name": "keepdims", "i": "0", ' // &
       '"type": "INT"}]'
  character(len=*), parameter :: onnx_cast_float_attr = &
       '        "attribute": [{"name": "to", "i": "1", "type": "INT"}]'
  character(len=*), parameter :: onnx_cast_int64_attr = &
       '        "attribute": [{"name": "to", "i": "7", "type": "INT"}]'
  character(len=*), parameter :: onnx_scatter_add_attr = &
       '        "attribute": [' // &
       '{"name": "axis", "i": "0", "type": "INT"}, ' // &
       '{"name": "reduction", "s": "YWRk", "type": "STRING"}]'

contains


!###############################################################################
  subroutine emit_msgpass_graph_inputs(prefix, input_shape, graph_inputs, &
       num_inputs)
    !! Emit the standard graph input tensors used by message-passing layers.
    !!
    !! Adds vertex features, optional edge features, edge_index, and degree.
    implicit none

    ! Arguments
    character(*), intent(in) :: prefix
    !! Input name prefix (e.g. "input_1")
    integer, dimension(:), intent(in) :: input_shape
    !! Layer input shape [num_vertex_features, num_edge_features]
    type(onnx_tensor_type), intent(inout), dimension(:) :: graph_inputs
    !! Accumulator for graph inputs
    integer, intent(inout) :: num_inputs
    !! Current number of graph inputs

    ! Vertex features: [num_nodes, nv]
    call add_graph_input_tensor( &
         graph_inputs, num_inputs, trim(prefix)//'_vertex', 1, &
         -1, 'num_nodes', input_shape(1), '')

    ! Edge features: [num_edges, ne]
    if(input_shape(2) .gt. 0)then
       call add_graph_input_tensor( &
            graph_inputs, num_inputs, trim(prefix)//'_edge', 1, &
            -1, 'num_edges', input_shape(2), '')
    end if

    ! Edge index: [3, num_csr_entries]
    call add_graph_input_tensor( &
         graph_inputs, num_inputs, trim(prefix)//'_edge_index', 7, &
         3, '', -1, 'num_csr_entries')

    ! Node degree: [num_nodes]
    call add_graph_input_tensor( &
         graph_inputs, num_inputs, trim(prefix)//'_degree', 7, &
         -1, 'num_nodes')

  end subroutine emit_msgpass_graph_inputs
!###############################################################################


!###############################################################################
  subroutine emit_output_identity(prefix, source_name, activation_name, &
       nodes, num_nodes)
    !! Emit a final Identity node using the standard ATHENA output naming.
    implicit none

    ! Arguments
    character(*), intent(in) :: prefix
    !! Layer node prefix
    character(*), intent(in) :: source_name
    !! Source tensor to rename
    character(*), intent(in) :: activation_name
    !! Final activation name used in the exported output suffix
    type(onnx_node_type), intent(inout), dimension(:) :: nodes
    !! Accumulator for ONNX nodes
    integer, intent(inout) :: num_nodes
    !! Current number of nodes

    ! Local variables
    character(:), allocatable :: suffix

    suffix = '_output'
    if(trim(activation_name) .ne. 'none')then
       suffix = '_' // trim(adjustl(activation_name)) // '_output'
    end if
    call emit_node('Identity', trim(prefix)//'_identity', &
         trim(prefix)//trim(suffix), '', nodes, num_nodes, &
         in1=trim(source_name))

  end subroutine emit_output_identity
!###############################################################################


!###############################################################################
  subroutine add_graph_input_tensor( &
       graph_inputs, num_inputs, name, elem_type, &
       dim1, dim_param1, dim2, dim_param2)
    !! Add one graph input tensor declaration to the ONNX input list.
    implicit none

    ! Arguments
    type(onnx_tensor_type), intent(inout), dimension(:) :: graph_inputs
    integer, intent(inout) :: num_inputs
    character(*), intent(in) :: name
    integer, intent(in) :: elem_type, dim1
    character(*), intent(in) :: dim_param1
    integer, optional, intent(in) :: dim2
    character(*), optional, intent(in) :: dim_param2

    num_inputs = num_inputs + 1
    graph_inputs(num_inputs)%name = trim(name)
    graph_inputs(num_inputs)%elem_type = elem_type
    if(present(dim2))then
       allocate(graph_inputs(num_inputs)%dims(2))
       allocate(graph_inputs(num_inputs)%dim_params(2))
       graph_inputs(num_inputs)%dims = [ dim1, dim2 ]
       graph_inputs(num_inputs)%dim_params(1) = dim_param1
       graph_inputs(num_inputs)%dim_params(2) = dim_param2
    else
       allocate(graph_inputs(num_inputs)%dims(1))
       allocate(graph_inputs(num_inputs)%dim_params(1))
       graph_inputs(num_inputs)%dims(1) = dim1
       graph_inputs(num_inputs)%dim_params(1) = dim_param1
    end if

  end subroutine add_graph_input_tensor
!###############################################################################


!###############################################################################
  subroutine get_timestep_output_name( &
       prefix, t, activation_name, inactive_suffix, activation_suffix, output)
    !! Build the canonical ONNX output name for one exported timestep.
    implicit none

    ! Arguments
    character(*), intent(in) :: prefix
    integer, intent(in) :: t
    character(*), intent(in) :: activation_name
    character(*), intent(in) :: inactive_suffix, activation_suffix
    character(128), intent(out) :: output

    ! Local variables
    character(128) :: step_prefix

    write(step_prefix, '(A,"_t",I0)') trim(prefix), t
    if(trim(activation_name) .ne. 'none')then
       output = trim(step_prefix)
       if(len_trim(activation_suffix) .gt. 0)then
          output = trim(output) // trim(activation_suffix)
       end if
       output = trim(output) // '_' // trim(adjustl(activation_name)) // &
            '_output'
    else
       output = trim(step_prefix) // trim(inactive_suffix)
    end if

  end subroutine get_timestep_output_name
!###############################################################################


!###############################################################################
  subroutine emit_edge_index_component( &
       tp, edge_index_in, index_name, tag, component_out, &
       nodes, num_nodes)
    !! Gather one edge_index row and squeeze it into a vector.
    implicit none

    ! Arguments
    character(*), intent(in) :: tp, edge_index_in, index_name, tag
    character(128), intent(out) :: component_out
    type(onnx_node_type), intent(inout), dimension(:) :: nodes
    integer, intent(inout) :: num_nodes

    ! Local variables
    character(128) :: raw_name

    write(raw_name, '(A,"_",A,"_raw")') trim(tp), trim(tag)
    write(component_out, '(A,"_",A)') trim(tp), trim(tag)
    call emit_node('Gather', trim(tp)//'_gather_'//trim(tag), &
         trim(raw_name), onnx_axis0_attr, nodes, num_nodes, &
         in1=trim(edge_index_in), in2=trim(index_name))
    call emit_squeeze_node(trim(tp)//'_sq_'//trim(tag), &
         trim(raw_name), trim(tp)//'_idx0', trim(component_out), &
         nodes, num_nodes)

  end subroutine emit_edge_index_component
!###############################################################################


!###############################################################################
  subroutine emit_scatter_aggregator( &
       tp, vertex_in, target_in, message_in, feature_dim, &
       nodes, num_nodes, inits, num_inits, aggr_out)
    !! Emit the zero-initialise, expand, and scatter-add aggregation block.
    implicit none

    ! Arguments
    character(*), intent(in) :: tp, vertex_in, target_in, message_in
    integer, intent(in) :: feature_dim
    type(onnx_node_type), intent(inout), dimension(:) :: nodes
    integer, intent(inout) :: num_nodes
    type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
    integer, intent(inout) :: num_inits
    character(128), intent(out) :: aggr_out

    ! Local variables
    character(128) :: shape_name, nnodes_idx, nnodes_name
    character(128) :: feat_dim_name, aggr_shape, zeros_name
    character(128) :: target_us, axes1_name, msg_shape, target_exp

    ! Get num_nodes from shape of vertex_in.
    write(shape_name, '(A,"_vshape")') trim(tp)
    call emit_node('Shape', trim(tp)//'_shape_v', &
         trim(shape_name), '', nodes, num_nodes, &
         in1=trim(vertex_in))

    write(nnodes_idx, '(A,"_nnodes_idx")') trim(tp)
    call emit_constant_int64(trim(nnodes_idx), [0], [1], &
         nodes, num_nodes, inits, num_inits)

    write(nnodes_name, '(A,"_nnodes")') trim(tp)
    call emit_node('Gather', trim(tp)//'_gather_nn', &
         trim(nnodes_name), onnx_axis0_attr, nodes, num_nodes, &
         in1=trim(shape_name), in2=trim(nnodes_idx))

    ! Concat [num_nodes, feature_dim] to create the scatter target shape.
    write(feat_dim_name, '(A,"_feat_dim")') trim(tp)
    call emit_constant_int64(trim(feat_dim_name), [feature_dim], [1], &
         nodes, num_nodes, inits, num_inits)

    write(aggr_shape, '(A,"_aggr_shape")') trim(tp)
    call emit_node('Concat', trim(tp)//'_cat_shape', &
         trim(aggr_shape), onnx_concat_axis0_attr, nodes, num_nodes, &
         in1=trim(nnodes_name), in2=trim(feat_dim_name))

    ! ConstantOfShape creates the zero-filled aggregation buffer.
    write(zeros_name, '(A,"_zeros")') trim(tp)
    call emit_constant_of_shape_float(trim(tp)//'_zeros', &
         trim(aggr_shape), 0.0_real32, trim(zeros_name), &
         nodes, num_nodes, inits, num_inits)

    write(target_us, '(A,"_tgt_us")') trim(tp)
    write(axes1_name, '(A,"_us_ax1")') trim(tp)
    call emit_constant_int64(trim(axes1_name), [1], [1], &
         nodes, num_nodes, inits, num_inits)
    call emit_node('Unsqueeze', trim(tp)//'_us_tgt', &
         trim(target_us), '', nodes, num_nodes, &
         in1=trim(target_in), in2=trim(axes1_name))

    ! Expand target indices to match the message rank for ScatterElements.
    write(msg_shape, '(A,"_msg_shape")') trim(tp)
    call emit_node('Shape', trim(tp)//'_shape_msg', &
         trim(msg_shape), '', nodes, num_nodes, &
         in1=trim(message_in))

    write(target_exp, '(A,"_tgt_exp")') trim(tp)
    call emit_node('Expand', trim(tp)//'_expand_tgt', &
         trim(target_exp), '', nodes, num_nodes, &
         in1=trim(target_us), in2=trim(msg_shape))

    ! Scatter-add edge messages into the target-vertex slots.
    write(aggr_out, '(A,"_aggr")') trim(tp)
    call emit_node('ScatterElements', trim(tp)//'_scatter_add', &
         trim(aggr_out), onnx_scatter_add_attr, nodes, num_nodes, &
         in1=trim(zeros_name), in2=trim(target_exp), in3=trim(message_in))

  end subroutine emit_scatter_aggregator
!###############################################################################


!###############################################################################
  subroutine emit_weight_initialiser_2d( &
       name, nrows, ncols, weight_data, inits, num_inits)
    !! Store a 2D weight matrix as an ONNX initialiser in row-major order.
    implicit none

    ! Arguments
    character(*), intent(in) :: name
    integer, intent(in) :: nrows, ncols
    real(real32), intent(in) :: weight_data(:)
    type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
    integer, intent(inout) :: num_inits

    num_inits = num_inits + 1
    inits(num_inits)%name = trim(name)
    inits(num_inits)%data_type = 1
    allocate(inits(num_inits)%dims(2))
    inits(num_inits)%dims = [ nrows, ncols ]
    allocate(inits(num_inits)%data(size(weight_data)))
    call col_to_row_major_2d(weight_data, inits(num_inits)%data, nrows, ncols)

  end subroutine emit_weight_initialiser_2d
!###############################################################################


!###############################################################################
  subroutine emit_weight_initialiser_3d( &
       name, nslices, nrows, ncols, weight_data, inits, num_inits)
    !! Store a stacked bank of 2D weight matrices as one ONNX tensor.
    implicit none

    ! Arguments
    character(*), intent(in) :: name
    integer, intent(in) :: nslices, nrows, ncols
    real(real32), intent(in) :: weight_data(:)
    type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
    integer, intent(inout) :: num_inits

    num_inits = num_inits + 1
    inits(num_inits)%name = trim(name)
    inits(num_inits)%data_type = 1
    allocate(inits(num_inits)%dims(3))
    inits(num_inits)%dims = [ nslices, nrows, ncols ]
    allocate(inits(num_inits)%data(size(weight_data)))

    ! Transpose each 2D slice from column-major to row-major before export.
    block
      integer :: d, slice_size
      slice_size = nrows * ncols
      do d = 1, nslices
         call col_to_row_major_2d( &
              weight_data((d-1)*slice_size+1 : d*slice_size), &
              inits(num_inits)%data((d-1)*slice_size+1 : d*slice_size), &
              nrows, ncols)
      end do
    end block

  end subroutine emit_weight_initialiser_3d
!###############################################################################

end module athena__onnx_msgpass_utils