Update the message passing layer
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| class(array_type) | intent(in) | | target | a | Aggregated neighbour features |
| class(array_type) | intent(in) | | target | weight | Packed degree-conditioned weight tensor |
| integer | intent(in) | | dimension(:) | adj_ia | CSR row pointers |
| integer | intent(in) | | | min_degree | Minimum degree bucket (lower clamp applied to each node's degree) |
| integer | intent(in) | | | max_degree | Maximum degree bucket (upper clamp applied to each node's degree) |
**Return value:** `type(array_type), pointer :: c` — Degree-conditioned updated feature tensor
module function duvenaud_update(a, weight, adj_ia, min_degree, max_degree) result(c)
  !! Update the message passing layer.
  !!
  !! For every node v (column v of a%val) the node degree
  !! adj_ia(v+1) - adj_ia(v) is clamped to [min_degree, max_degree] and
  !! shifted to a 1-based bucket index d. The d-th weight matrix, stored as
  !! a contiguous slice of the packed weight buffer weight%val(:,1), is then
  !! applied to a%val(:,v) / d to produce column v of the result.
  implicit none

  ! Arguments
  class(array_type), intent(in), target :: a
  !! Aggregated neighbour features, one column per node
  class(array_type), intent(in), target :: weight
  !! Packed degree-conditioned weight tensor
  integer, dimension(:), intent(in) :: adj_ia
  !! CSR row pointers; indexed up to size(a%val,2)+1
  integer, intent(in) :: min_degree, max_degree
  !! Minimum and maximum degree buckets
  type(array_type), pointer :: c
  !! Degree-conditioned updated feature tensor

  ! Local variables
  integer :: v
  !! Node (column) index
  integer :: d
  !! 1-based degree bucket index of the current node
  integer :: n_out, n_in
  !! Rows / columns of one degree-specific weight matrix
  integer :: interval
  !! Number of packed weight entries per degree bucket
  real(real32), pointer :: w_ptr(:,:)
  !! 2D view over the selected degree-specific weight matrix

  ! Shape of a single per-bucket weight matrix (loop invariant)
  n_out = weight%shape(1)
  n_in = weight%shape(2)
  interval = n_out * n_in

  c => a%create_result(array_shape=[n_out, size(a%val,2)])

  do v = 1, size(a%val,2)
     ! Clamp the node degree into [min_degree, max_degree], then shift it
     ! so the smallest bucket maps to index 1.
     d = max(min_degree, min(adj_ia(v+1) - adj_ia(v), max_degree)) - &
          min_degree + 1

     ! Rank-remap bucket d's contiguous slice of the packed buffer into an
     ! n_out x n_in matrix view (no copy).
     w_ptr(1:n_out, 1:n_in) => &
          weight%val(interval*(d-1)+1:interval*d, 1)

     ! NOTE(review): the input is scaled by the bucket index d, which equals
     ! the clamped degree only when min_degree == 1 -- confirm this matches
     ! the intended normalisation and the backward routines.
     c%val(:,v) = matmul(w_ptr, a%val(:,v) / real(d, real32))
  end do

  ! Keep the CSR row pointers on the result (presumably so the backward
  ! routines can recompute each node's degree bucket).
  c%indices = adj_ia

  ! Left operand is the weight tensor, right operand is the feature tensor.
  c%get_partial_left => get_partial_duvenaud_update_weight
  c%get_partial_right => get_partial_duvenaud_update
  c%get_partial_left_val => get_partial_duvenaud_update_weight_val
  c%get_partial_right_val => get_partial_duvenaud_update_val

  if(a%requires_grad .or. weight%requires_grad)then
     c%requires_grad = .true.
     c%is_forward = a%is_forward .or. weight%is_forward
     c%operation = 'duvenaud_update'
     c%right_operand => a
     c%left_operand => weight
     ! Result takes ownership of operands flagged as temporary
     c%owns_right_operand = a%is_temporary
     c%owns_left_operand = weight%is_temporary
  end if

end function duvenaud_update