get_partial_gno_agg_kernels_val Subroutine

pure subroutine get_partial_gno_agg_kernels_val(this, upstream_grad, output)

Gradient with respect to the edge kernels, accumulated in place into the caller-supplied `output` array.

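The aggregation is $m_i \mathrel{+}= \mathrm{reshape}(\kappa_e, [F_\text{out}, F_\text{in}])\, h_j$. Viewing $\kappa_e$ through this reshape, its gradient is $\nabla_{\kappa_e} = \mathrm{vec}\big(g_{:,i}\, h_j^{\top}\big)$, where $g_{:,i}$ is column $i$ of the upstream gradient.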

Arguments

Type | Intent | Optional | Attributes | Name
class(array_type), intent(in) :: this

Forward result node containing saved operands

real(kind=real32), intent(in), dimension(:,:) :: upstream_grad

Upstream gradient values

real(kind=real32), intent(out), dimension(:,:) :: output

Output gradient values for edge kernels


Source Code

  pure subroutine get_partial_gno_agg_kernels_val( &
       this, upstream_grad, output)
    !! Accumulate the gradient of the GNO aggregation with respect to the
    !! per-edge kernels into `output`.
    !!
    !! For every adjacency entry linking node i to neighbour j via edge e,
    !! the forward aggregation is
    !!   m_i += reshape(kappa_e, [F_out, F_in]) @ h_j
    !! so the kernel gradient, flattened column-major, is
    !!   grad_kappa_e += vec( upstream(:,i) @ h_j^T )
    !! Column c of that outer product is upstream(:,i) * h(c, j), which maps
    !! onto the contiguous slice (c-1)*F_out+1 : c*F_out of output(:, e).
    implicit none

    ! Arguments
    class(array_type), intent(in) :: this
    !! Forward result node; supplies node features (this%left_operand%val)
    !! and the row-pointer adjacency (this%indices, this%adj_ja)
    real(real32), dimension(:,:), intent(in)  :: upstream_grad
    !! Upstream gradient values, indexed (output feature, node)
    real(real32), dimension(:,:), intent(out) :: output
    !! Output gradient values for edge kernels, indexed
    !! (flattened kernel entry, edge); assumes at least F_out*F_in rows
    !! and a column for every edge index in this%adj_ja(2,:) - TODO confirm

    ! Local variables
    integer :: n_feat_in, n_feat_out, n_nodes
    !! Feature widths and node count inferred from the operands
    integer :: node, nbr, ptr, edge
    !! Node, neighbour, adjacency cursor and edge id for the current entry
    integer :: col, offset
    !! Kernel column and its starting offset (minus one) in the flat layout

    n_feat_in  = size(this%left_operand%val, 1)
    n_feat_out = size(upstream_grad, 1)
    n_nodes    = size(this%left_operand%val, 2)

    ! intent(out) leaves output undefined on entry. Contributions are summed,
    ! so an edge id that appears in several adjacency entries gets the sum.
    output = 0.0_real32

    node_loop: do node = 1, n_nodes
       edge_loop: do ptr = this%indices(node), this%indices(node + 1) - 1
          nbr  = this%adj_ja(1, ptr)
          edge = this%adj_ja(2, ptr)
          do col = 1, n_feat_in
             offset = (col - 1) * n_feat_out
             output(offset + 1 : offset + n_feat_out, edge) = &
                  output(offset + 1 : offset + n_feat_out, edge) + &
                  upstream_grad(:, node) * this%left_operand%val(col, nbr)
          end do
       end do edge_loop
    end do node_loop

  end subroutine get_partial_gno_agg_kernels_val