Evaluate the GNO kernel MLP on every directed edge in the graph.
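For each edge feature column $e$, compute $dx = \text{edge\_features}(:, e) \in \mathbb{R}^{d}$, then $\text{hidden} = \mathrm{relu}(U\,dx + b_u) \in \mathbb{R}^{H}$, and finally $\kappa_e = V\,\text{hidden} + b_v \in \mathbb{R}^{F_{out} F_{in}}$.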
Kernel params layout: a flat column of size $Hd + H + FH + F$, where $F = F_{out} F_{in}$. `U` occupies `params(1 : H*d)`, `b_u` occupies `params(H*d+1 : H*d+H)`, `V` occupies `params(H*d+H+1 : H*d+H+F*H)`, and `b_v` occupies `params(H*d+H+F*H+1 : end)`.
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| class(array_type), | intent(in), | target | :: | coords |
Edge features / relative coordinates [d, num_edges] |
|
| class(array_type), | intent(in), | target | :: | kernel_params |
Packed kernel parameters `[H*d + H + F*H + F, 1]` |
|
| integer, | intent(in), | dimension(:) | :: | adj_ia |
CSR row pointers (size num_vertices + 1) |
|
| integer, | intent(in), | dimension(:,:) | :: | adj_ja |
CSR column indices (adj_ja(1,:) = neighbour index) |
|
| integer, | intent(in) | :: | coord_dim |
Metadata for unpacking kernel_params |
||
| integer, | intent(in) | :: | kernel_hidden |
Metadata for unpacking kernel_params |
||
| integer, | intent(in) | :: | F_in |
Metadata for unpacking kernel_params |
||
| integer, | intent(in) | :: | F_out |
Metadata for unpacking kernel_params |
Output per-edge kernel values
module function gno_kernel_eval( &
     coords, kernel_params, adj_ia, adj_ja, &
     coord_dim, kernel_hidden, F_in, F_out &
) result(c)
  !! Evaluate the GNO kernel MLP on every directed edge in the graph.
  !!
  !! For each edge feature column e, compute:
  !!   dx      = coords(:, e)                   [d]
  !!   hidden  = relu( U @ dx + b_u )           [H]
  !!   kappa_e = V @ hidden + b_v               [F_out*F_in]
  !!
  !! Kernel params layout (flat column, size H*d + H + F*H + F, with
  !! F = F_out*F_in; U and V are reshaped column-major):
  !!   U   : params(1             : H*d)          -> [H, d]
  !!   b_u : params(H*d+1         : H*d+H)
  !!   V   : params(H*d+H+1       : H*d+H+F*H)    -> [F, H]
  !!   b_v : params(H*d+H+F*H+1   : H*d+H+F*H+F)
  !!
  !! The result c has shape [F, num_edges]; column e holds kappa_e.
  !! Execution stops with an error if size(coords%val, 1) /= coord_dim
  !! or if kernel_params%val does not have exactly H*d + H + F*H + F rows.
  implicit none

  ! Arguments
  class(array_type), intent(in), target :: coords
  !! Edge features / relative coordinates [d, num_edges]
  class(array_type), intent(in), target :: kernel_params
  !! Packed kernel parameters [H*d + H + F*H + F, 1]
  integer, dimension(:), intent(in) :: adj_ia
  !! CSR row pointers (size num_vertices + 1); not referenced here
  integer, dimension(:,:), intent(in) :: adj_ja
  !! CSR column indices (adj_ja(1,:) = neighbour index); not referenced here
  integer, intent(in) :: coord_dim, kernel_hidden, F_in, F_out
  !! Metadata for unpacking kernel_params
  type(array_type), pointer :: c
  !! Output per-edge kernel values [F_out*F_in, num_edges]

  ! Local variables
  integer :: num_e, d, H, F, e
  !! Edge count, unpacked dimensions and edge loop index
  integer :: off_U, off_bu, off_V, off_bv, n_params
  !! Flat offsets of the packed parameter blocks and total packed length
  real(real32), allocatable :: U(:,:), b_u(:), V(:,:), b_v(:)
  !! Unpacked kernel parameter tensors
  real(real32), allocatable :: hidden(:)
  !! Per-edge hidden activation buffer

  d = coord_dim
  H = kernel_hidden
  F = F_out * F_in              ! kernel output width
  num_e = size(coords%val, 2)

  ! ---- Packed-parameter offsets -------------------------------------------
  off_U    = 0
  off_bu   = off_U + H * d
  off_V    = off_bu + H
  off_bv   = off_V + F * H
  n_params = off_bv + F

  ! ---- Validate shapes ----------------------------------------------------
  ! Without these checks a mismatch makes the slices / matmul / additions
  ! below non-conformable, which is undefined behaviour when the code is
  ! compiled without runtime bounds checking.
  if (size(coords%val, 1) /= d) then
     error stop 'gno_kernel_eval: size(coords%val, 1) /= coord_dim'
  end if
  if (size(kernel_params%val, 1) /= n_params) then
     error stop 'gno_kernel_eval: size(kernel_params%val, 1) /= H*d + H + F*H + F'
  end if

  ! ---- Unpack kernel params -----------------------------------------------
  allocate(U(H, d), b_u(H), V(F, H), b_v(F))
  U   = reshape(kernel_params%val(off_U + 1:off_bu, 1), [H, d])
  b_u = kernel_params%val(off_bu + 1:off_V, 1)
  V   = reshape(kernel_params%val(off_V + 1:off_bv, 1), [F, H])
  b_v = kernel_params%val(off_bv + 1:n_params, 1)

  ! ---- Forward: evaluate kernel on every edge -----------------------------
  c => coords%create_result(array_shape=[F, num_e])
  allocate(hidden(H))
  do e = 1, num_e
     ! coords%val(:, e) is a contiguous column, so no copy buffer is needed
     hidden = max(matmul(U, coords%val(:, e)) + b_u, 0.0_real32)   ! ReLU
     c%val(:, e) = matmul(V, hidden) + b_v
  end do
  ! U, b_u, V, b_v and hidden are local allocatables: freed automatically
  ! on return.

  ! ---- Store metadata for backward ----------------------------------------
  ! indices = [d, H, F_in, F_out] lets the backward routines re-derive the
  ! packed layout without the original call arguments.
  allocate(c%indices(4))
  c%indices = [d, H, F_in, F_out]
  c%get_partial_left => get_partial_gno_kernel_coords
  c%get_partial_right => get_partial_gno_kernel_params
  c%get_partial_left_val => get_partial_gno_kernel_coords_val
  c%get_partial_right_val => get_partial_gno_kernel_params_val
  if (coords%requires_grad .or. kernel_params%requires_grad) then
     c%requires_grad = .true.
     c%is_forward = coords%is_forward .or. kernel_params%is_forward
     c%operation = 'gno_kernel_eval'
     c%left_operand => coords
     c%right_operand => kernel_params
     c%owns_left_operand = coords%is_temporary
     c%owns_right_operand = kernel_params%is_temporary
  end if

end function gno_kernel_eval