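Aggregate neighbour messages using pre-computed per-edge kernels.

For each node $i$:

$$m_i = \sum_{j \in N(i)} \operatorname{reshape}\big(\kappa_e, [F_{out}, F_{in}]\big)\, h_j$$

where $h_j$ is the feature vector of neighbour $j$ and $e$ is the index of the edge corresponding to the pair $(i, j)$. The reshape uses Fortran's column-major element order.

- left_operand → `features` [F_in, num_vertices]
- right_operand → `edge_kernels` [F_out*F_in, num_edges]
- output → [F_out, num_vertices]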
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| class(array_type) | intent(in) | | target | features | Node features [F_in, num_vertices] |
| class(array_type) | intent(in) | | target | edge_kernels | Per-edge kernel values [F_out*F_in, num_edges] |
| integer | intent(in) | | dimension(:) | adj_ia | CSR row pointers: the entries of node i are adj_ia(i) .. adj_ia(i+1)-1, so at least num_vertices + 1 values are read |
| integer | intent(in) | | dimension(:,:) | adj_ja | CSR entries: adj_ja(1, k) is the neighbour node index and adj_ja(2, k) is the edge index into edge_kernels |
| integer | intent(in) | | | F_in | Input feature dimension |
| integer | intent(in) | | | F_out | Output feature dimension |
**Return value** — `type(array_type), pointer :: c`: the aggregated node output tensor, of shape [F_out, num_vertices].
module function gno_aggregate( &
     features, edge_kernels, adj_ia, adj_ja, F_in, F_out &
) result(c)
  !! Aggregate neighbour messages using pre-computed per-edge kernels.
  !!
  !! For every vertex v, each CSR entry p in adj_ia(v) .. adj_ia(v+1)-1
  !! contributes the message
  !!
  !!   reshape(edge_kernels%val(:, e), [F_out, F_in]) @ features%val(:, u)
  !!
  !! where u = adj_ja(1, p) is the neighbour vertex and e = adj_ja(2, p) is
  !! the edge index. The messages are summed into column v of the result.
  !!
  !! left_operand  -> features      [F_in, num_vertices]
  !! right_operand -> edge_kernels  [F_out*F_in, num_edges]
  !! output        -> c             [F_out, num_vertices]
  implicit none

  ! Arguments
  class(array_type), intent(in), target :: features
  !! Node features [F_in, num_vertices]
  class(array_type), intent(in), target :: edge_kernels
  !! Per-edge kernel values [F_out*F_in, num_edges]
  integer, dimension(:), intent(in) :: adj_ia
  !! CSR row pointers; vertex v owns entries adj_ia(v) .. adj_ia(v+1)-1
  integer, dimension(:,:), intent(in) :: adj_ja
  !! CSR entries: row 1 = neighbour vertex, row 2 = edge index
  integer, intent(in) :: F_in, F_out
  !! Input / output feature dimensions
  type(array_type), pointer :: c
  !! Aggregated node output tensor [F_out, num_vertices]

  ! Local variables
  integer :: n_vertices, v, p, nbr, edge
  !! Vertex count, vertex / CSR-entry counters, neighbour and edge indices

  n_vertices = size(features%val, 2)
  c => features%create_result(array_shape=[F_out, n_vertices])
  c%val = 0.0_real32

  vertex_loop: do v = 1, n_vertices
     entry_loop: do p = adj_ia(v), adj_ia(v+1) - 1
        nbr  = adj_ja(1, p)
        edge = adj_ja(2, p)
        ! View the flat kernel column (column-major) as an [F_out, F_in]
        ! matrix and apply it to the neighbour's feature vector.
        c%val(:, v) = c%val(:, v) + matmul( &
             reshape(edge_kernels%val(:, edge), [F_out, F_in]), &
             features%val(:, nbr) )
     end do entry_loop
  end do vertex_loop

  ! Store the adjacency on the result.
  ! NOTE(review): presumably read back by the get_partial_gno_agg_* procedures — confirm.
  c%indices = adj_ia
  c%adj_ja  = adj_ja

  c%get_partial_left      => get_partial_gno_agg_features
  c%get_partial_right     => get_partial_gno_agg_kernels
  c%get_partial_left_val  => get_partial_gno_agg_features_val
  c%get_partial_right_val => get_partial_gno_agg_kernels_val

  ! Only wire up the graph node when at least one operand needs gradients.
  if (.not. features%requires_grad .and. .not. edge_kernels%requires_grad) return

  c%requires_grad      = .true.
  c%is_forward         = features%is_forward .or. edge_kernels%is_forward
  c%operation          = 'gno_aggregate'
  c%left_operand       => features
  c%right_operand      => edge_kernels
  c%owns_left_operand  = features%is_temporary
  c%owns_right_operand = edge_kernels%is_temporary

end function gno_aggregate