duvenaud_update Module Function

module function duvenaud_update(a, weight, adj_ia, min_degree, max_degree) result(c)

Update the message passing layer

Arguments

Type | Intent | Optional | Attributes | Name
class(array_type), intent(in), target :: a

Aggregated neighbour features

class(array_type), intent(in), target :: weight

Packed degree-conditioned weight tensor

integer, intent(in), dimension(:) :: adj_ia

CSR row pointers

integer, intent(in) :: min_degree

Minimum degree bucket; vertex degrees below this value are clamped up to it

integer, intent(in) :: max_degree

Maximum degree bucket; vertex degrees above this value are clamped down to it

Return Value type(array_type), pointer

Degree-conditioned updated feature tensor


Source Code

  module function duvenaud_update(a, weight, adj_ia, min_degree, max_degree) result(c)
    !! Update the message passing layer
    !!
    !! For every column `v` of `a%val`, the vertex degree is taken from the
    !! CSR row pointers as `adj_ia(v+1) - adj_ia(v)`, clamped to
    !! `[min_degree, max_degree]` and shifted to a 1-based bucket index `d`.
    !! Bucket `d` selects a run of `weight%shape(1) * weight%shape(2)`
    !! consecutive values in column 1 of `weight%val`; that run is viewed as
    !! a `weight%shape(1)` x `weight%shape(2)` matrix and multiplied with
    !! column `v` of `a%val` divided by `d`.
    !!
    !! `adj_ia` must hold at least `size(a%val,2) + 1` entries, and `matmul`
    !! requires `size(a%val,1)` to equal `weight%shape(2)`; neither is
    !! checked here.
    implicit none

    ! Arguments
    class(array_type), intent(in), target :: a
    !! Aggregated neighbour features
    class(array_type), intent(in), target :: weight
    !! Packed degree-conditioned weight tensor
    integer, dimension(:), intent(in) :: adj_ia
    !! CSR row pointers
    integer, intent(in) :: min_degree, max_degree
    !! Minimum and maximum degree buckets
    type(array_type), pointer :: c
    !! Degree-conditioned updated feature tensor

    ! Local variables
    integer :: v, d
    !! Vertex loop index and 1-based degree bucket index
    integer :: num_out, num_in, num_vertices
    !! Rows and columns of one weight matrix, and number of columns of a%val
    integer :: interval
    !! Flat parameter interval for one degree bucket
    real(real32), pointer :: w_ptr(:,:)
    !! 2D view over selected degree-specific weight matrix

    ! Loop-invariant extents, evaluated once
    num_out = weight%shape(1)
    num_in = weight%shape(2)
    num_vertices = size(a%val,2)
    interval = num_out * num_in

    c => a%create_result(array_shape=[num_out, num_vertices])

    do v = 1, num_vertices
       ! Clamp the degree into [min_degree, max_degree], then shift so the
       ! lowest bucket has index 1
       d = max( min_degree, min( adj_ia(v+1) - adj_ia(v), max_degree ) ) - &
            min_degree + 1
       ! Bounds-remap the rank-1 slice of packed weights for bucket d onto a
       ! rank-2 view; no data is copied
       w_ptr(1:num_out, 1:num_in) => &
            weight%val(interval*(d-1)+1:interval*d,1)
       ! NOTE(review): the divisor is the bucket index d, not the raw vertex
       ! degree - confirm this matches the partial derivative procedures
       c%val(:,v) = matmul(w_ptr, a%val(:,v) / real(d, real32))
    end do
    c%indices = adj_ia

    ! Partial-derivative procedure pointers are attached unconditionally
    c%get_partial_left => get_partial_duvenaud_update_weight
    c%get_partial_right => get_partial_duvenaud_update
    c%get_partial_left_val => get_partial_duvenaud_update_weight_val
    c%get_partial_right_val => get_partial_duvenaud_update_val

    ! Operand links are only recorded when either input requires a gradient
    if(a%requires_grad .or. weight%requires_grad)then
       c%requires_grad = .true.
       c%is_forward = a%is_forward .or. weight%is_forward
       c%operation = 'duvenaud_update'
       c%right_operand => a
       c%left_operand => weight
       c%owns_right_operand = a%is_temporary
       c%owns_left_operand = weight%is_temporary
    end if

  end function duvenaud_update