gno_aggregate Module Function

module function gno_aggregate(features, edge_kernels, adj_ia, adj_ja, F_in, F_out) result(c)

Aggregate neighbour messages using pre-computed per-edge kernels.

For each node $i$: $m_i = \sum_{j \in \mathcal{N}(i)} \operatorname{reshape}(\kappa_e, [F_{out}, F_{in}])\, h_j$

where e is the edge index corresponding to (i, j).

- left_operand → features [F_in, num_vertices]
- right_operand → edge_kernels [F_out*F_in, num_edges]
- output → [F_out, num_vertices]

Arguments

Type | Intent | Optional | Attributes | Name
class(array_type), intent(in), target :: features

Node features [F_in, num_vertices]

class(array_type), intent(in), target :: edge_kernels

Per-edge kernel values [F_out*F_in, num_edges]

integer, intent(in), dimension(:) :: adj_ia

CSR row pointers

integer, intent(in), dimension(:,:) :: adj_ja

CSR entries: row 1 holds the neighbour (column) node index, row 2 the corresponding edge index

integer, intent(in) :: F_in

Input feature dimension (rows of features)

integer, intent(in) :: F_out

Output feature dimension (rows of the result)

Return Value type(array_type), pointer

Aggregated node output tensor


Source Code

  module function gno_aggregate( &
       features, edge_kernels, adj_ia, adj_ja, F_in, F_out &
  ) result(c)
    !! Sum per-edge transformed neighbour features into each node.
    !!
    !! Node i owns the CSR entries adj_ia(i) .. adj_ia(i+1)-1.
    !! Each entry k names a neighbour j = adj_ja(1, k) and an edge
    !! e = adj_ja(2, k). The edge's kernel column edge_kernels%val(:, e)
    !! is reshaped to an [F_out, F_in] matrix and applied to
    !! features%val(:, j):
    !!
    !!   m_i = sum_{j in N(i)} reshape(kappa_e, [F_out, F_in]) @ h_j
    !!
    !! Operand layout:
    !!   left_operand  -> features      [F_in, num_vertices]
    !!   right_operand -> edge_kernels  [F_out*F_in, num_edges]
    !!   result        -> [F_out, num_vertices]
    !!
    !! NOTE(review): size(adj_ia) >= num_vertices + 1 and
    !! size(edge_kernels%val, 1) == F_out*F_in are assumed, not checked.
    implicit none

    ! Arguments
    class(array_type), intent(in), target :: features
    !! Node features [F_in, num_vertices]
    class(array_type), intent(in), target :: edge_kernels
    !! Per-edge kernel values [F_out*F_in, num_edges]
    integer, dimension(:), intent(in) :: adj_ia
    !! CSR row pointers (indexed up to num_vertices + 1)
    integer, dimension(:,:), intent(in) :: adj_ja
    !! CSR entries: row 1 = neighbour node index, row 2 = edge index
    integer, intent(in) :: F_in, F_out
    !! Input / output feature widths
    type(array_type), pointer :: c
    !! Aggregated node output tensor [F_out, num_vertices]

    ! Local variables
    integer :: n_nodes, node, nz, nbr, eid
    !! Node count and CSR traversal indices

    n_nodes = size(features%val, 2)
    c => features%create_result(array_shape=[F_out, n_nodes])

    ! Store the connectivity on the result.
    ! Presumably the get_partial_gno_agg_* routines read it back
    ! to traverse the same edges.
    c%indices = adj_ia
    c%adj_ja  = adj_ja

    ! Accumulate one column of the output per node.
    c%val = 0.0_real32
    node_loop: do node = 1, n_nodes
       edge_loop: do nz = adj_ia(node), adj_ia(node + 1) - 1
          nbr = adj_ja(1, nz)
          eid = adj_ja(2, nz)
          c%val(:, node) = c%val(:, node) + matmul( &
               reshape(edge_kernels%val(:, eid), [F_out, F_in]), &
               features%val(:, nbr) )
       end do edge_loop
    end do node_loop

    ! The backward hooks are attached in every case.
    c%get_partial_left      => get_partial_gno_agg_features
    c%get_partial_left_val  => get_partial_gno_agg_features_val
    c%get_partial_right     => get_partial_gno_agg_kernels
    c%get_partial_right_val => get_partial_gno_agg_kernels_val

    ! Operand links are recorded only when some input needs gradients.
    if (.not. features%requires_grad .and. &
         .not. edge_kernels%requires_grad) return

    c%requires_grad      = .true.
    c%is_forward         = features%is_forward .or. edge_kernels%is_forward
    c%operation          = 'gno_aggregate'
    c%left_operand       => features
    c%right_operand      => edge_kernels
    c%owns_left_operand  = features%is_temporary
    c%owns_right_operand = edge_kernels%is_temporary

  end function gno_aggregate