gno_kernel_eval Module Function

module function gno_kernel_eval(coords, kernel_params, adj_ia, adj_ja, coord_dim, kernel_hidden, F_in, F_out) result(c)

Evaluate the GNO kernel MLP on every directed edge in the graph.

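For each edge-feature column $e$ (a column of `coords`), compute:

$$\mathbf{dx} = \mathrm{edge\_features}(:, e) \in \mathbb{R}^{d}$$
$$\mathbf{h} = \mathrm{ReLU}\left(U\,\mathbf{dx} + \mathbf{b}_u\right) \in \mathbb{R}^{H}$$
$$\boldsymbol{\kappa}_e = V\,\mathbf{h} + \mathbf{b}_v \in \mathbb{R}^{F_{out} F_{in}}$$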

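Kernel params layout (a flat column of size $Hd + H + FH + F$, where $F = F_{out} F_{in}$):

- `U`   : `params(1 : H*d)`, reshaped to `[H, d]`
- `b_u` : `params(H*d+1 : H*d+H)`
- `V`   : `params(H*d+H+1 : H*d+H+F*H)`, reshaped to `[F, H]`
- `b_v` : `params(H*d+H+F*H+1 : H*d+H+F*H+F)`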

Arguments

Type | Intent | Optional | Attributes | Name
class(array_type), intent(in), target :: coords

Edge features / relative coordinates [d, num_edges]

class(array_type), intent(in), target :: kernel_params

Packed kernel parameters [H*d + H + F*H + F, 1], with F = F_out*F_in

integer, intent(in), dimension(:) :: adj_ia

CSR row pointers (size num_vertices + 1). Not referenced by this routine.

integer, intent(in), dimension(:,:) :: adj_ja

CSR column indices (adj_ja(1,:) = neighbour index). Not referenced by this routine.

integer, intent(in) :: coord_dim

Coordinate (edge-feature) dimension d; metadata for unpacking kernel_params

integer, intent(in) :: kernel_hidden

Hidden width H of the kernel MLP; metadata for unpacking kernel_params

integer, intent(in) :: F_in

Number of input feature channels F_in; metadata for unpacking kernel_params (kernel output width is F = F_out*F_in)

integer, intent(in) :: F_out

Number of output feature channels F_out; metadata for unpacking kernel_params (kernel output width is F = F_out*F_in)

Return Value type(array_type), pointer

Output per-edge kernel values, shape [F_out*F_in, num_edges]


Source Code

  module function gno_kernel_eval( &
       coords, kernel_params, adj_ia, adj_ja, &
       coord_dim, kernel_hidden, F_in, F_out &
  ) result(c)
    !! Evaluate the GNO kernel MLP on every directed edge in the graph.
    !!
    !! For each edge feature column e, compute:
    !!   dx      = coords%val(:,e)                     [d]
    !!   hidden  = relu( U @ dx + b_u )                [H]
    !!   kappa_e = V @ hidden + b_v                    [F], F = F_out*F_in
    !!
    !! Kernel params layout (flat column, size H*d + H + F*H + F):
    !!   U   : params(1 : H*d)                        reshaped to [H, d]
    !!   b_u : params(H*d+1 : H*d+H)
    !!   V   : params(H*d+H+1 : H*d+H+F*H)            reshaped to [F, H]
    !!   b_v : params(H*d+H+F*H+1 : H*d+H+F*H+F)
    !!
    !! Only the first H*d + H + F*H + F rows of kernel_params are read.
    !! Execution stops with an error if coords does not have exactly
    !! coord_dim rows, or if kernel_params has fewer rows than the packed
    !! layout requires (both cases previously led to non-conforming array
    !! operations).
    !!
    !! adj_ia / adj_ja are not referenced by this routine.
    !! NOTE(review): presumably kept for interface parity with other GNO
    !! procedures - confirm.
    implicit none

    ! Arguments
    class(array_type), intent(in), target :: coords
    !! Edge features / relative coordinates [d, num_edges]
    class(array_type), intent(in), target :: kernel_params
    !! Packed kernel parameters [H*d + H + F*H + F, 1]
    integer, dimension(:), intent(in)  :: adj_ia
    !! CSR row pointers (size num_vertices + 1); unused here
    integer, dimension(:,:), intent(in) :: adj_ja
    !! CSR column indices (adj_ja(1,:) = neighbour index); unused here
    integer, intent(in) :: coord_dim, kernel_hidden, F_in, F_out
    !! Metadata for unpacking kernel_params
    type(array_type), pointer :: c
    !! Output per-edge kernel values [F_out*F_in, num_edges]

    ! Local variables
    integer :: num_e, d, H, F, e
    !! Edge count, unpacked dimensions and edge loop index
    integer :: off_U, off_bu, off_V, off_bv, n_params
    !! Flat offsets for packed kernel parameter blocks, and total packed size
    real(real32), allocatable :: U(:,:), b_u(:), V(:,:), b_v(:)
    !! Unpacked kernel parameter tensors
    real(real32), allocatable :: hidden(:)
    !! Per-edge hidden activation buffer

    d = coord_dim
    H = kernel_hidden
    F = F_out * F_in        ! kernel output width
    num_e = size(coords%val, 2)

    ! ---- Packed-layout offsets -----------------------------------------------
    off_U    = 0
    off_bu   = H * d
    off_V    = off_bu + H
    off_bv   = off_V + F * H
    n_params = off_bv + F

    ! ---- Validate inputs against the packed layout ---------------------------
    ! A mismatch here would otherwise surface as non-conforming matmul/add
    ! operations inside the edge loop.
    if (size(coords%val, 1) /= d) then
       error stop 'gno_kernel_eval: coords must have coord_dim rows'
    end if
    if (size(kernel_params%val, 1) < n_params) then
       error stop 'gno_kernel_eval: kernel_params too short for packed kernel'
    end if

    ! ---- Unpack kernel params ------------------------------------------------
    allocate(U(H, d)); U = reshape(kernel_params%val(off_U+1:off_bu, 1), [H, d])
    allocate(b_u(H));  b_u = kernel_params%val(off_bu+1:off_V, 1)
    allocate(V(F, H)); V = reshape(kernel_params%val(off_V+1:off_bv, 1), [F, H])
    ! Bounded slice: b_v must have exactly F entries to conform with V @ hidden.
    allocate(b_v(F));  b_v = kernel_params%val(off_bv+1:n_params, 1)

    ! ---- Forward: evaluate kernel on every edge ------------------------------
    c => coords%create_result(array_shape=[F, num_e])
    allocate(hidden(H))

    do e = 1, num_e
       hidden = matmul(U, coords%val(:, e)) + b_u
       hidden = max(hidden, 0.0_real32)          ! ReLU
       c%val(:, e) = matmul(V, hidden) + b_v
    end do

    deallocate(hidden, U, b_u, V, b_v)

    ! ---- Store metadata for backward -----------------------------------------
    ! Presumably read back by the get_partial_gno_kernel_* procedures.
    allocate(c%indices(4))
    c%indices = [d, H, F_in, F_out]

    c%get_partial_left     => get_partial_gno_kernel_coords
    c%get_partial_right    => get_partial_gno_kernel_params
    c%get_partial_left_val => get_partial_gno_kernel_coords_val
    c%get_partial_right_val => get_partial_gno_kernel_params_val
    if(coords%requires_grad .or. kernel_params%requires_grad)then
       c%requires_grad    = .true.
       c%is_forward       = coords%is_forward .or. kernel_params%is_forward
       c%operation        = 'gno_kernel_eval'
       c%left_operand     => coords
       c%right_operand    => kernel_params
       c%owns_left_operand  = coords%is_temporary
       c%owns_right_operand = kernel_params%is_temporary
    end if

  end function gno_kernel_eval