get_partial_duvenaud_update_val Subroutine

pure subroutine get_partial_duvenaud_update_val(this, upstream_grad, output)

In-place value gradient for duvenaud_update input features.

Arguments

Type | Intent | Optional Attributes | Name
class(array_type), intent(in) :: this

Forward result node containing saved operands

real(kind=real32), intent(in), dimension(:,:) :: upstream_grad

Upstream gradient values

real(kind=real32), intent(out), dimension(:,:) :: output

Output gradient values for input features


Source Code

  pure subroutine get_partial_duvenaud_update_val( &
       this, upstream_grad, output &
  )
    !! In-place value gradient for duvenaud_update input features.
    !!
    !! For every column v of upstream_grad, a degree bucket is chosen from
    !! this%indices(v+1) - this%indices(v), clamped to the range
    !! [this%left_operand%indices(1), this%left_operand%indices(2)].
    !! The bucket's weight block is read from column 1 of
    !! this%left_operand%val and reshaped to
    !! (num_output_features, num_input_features). Then
    !!   output(:,v) = matmul(upstream_grad(:,v), W_bucket) / bucket
    !! where bucket is the 1-based bucket index (not the raw degree).
    implicit none

    ! Arguments
    class(array_type), intent(in) :: this
    !! Forward result node containing saved operands
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Upstream gradient values, one column per vertex
    real(real32), dimension(:,:), intent(out) :: output
    !! Output gradient values for input features, one column per vertex.
    !! Expected to have this%right_operand%shape(1) rows (TODO: confirm at caller).

    ! Local variables
    integer :: v
    !! Vertex (column) loop index
    integer :: interval, num_output_features, num_input_features
    !! Flattened size of one weight block and its matrix dimensions
    integer :: min_degree, max_degree
    !! Degree bucket limits

    ! The loop below only writes columns 1:size(upstream_grad,2). Zeroing
    ! first keeps any remaining columns of output defined.
    output = 0._real32
    num_output_features = size(upstream_grad,1)
    num_input_features = this%right_operand%shape(1)
    interval = num_output_features * num_input_features
    min_degree = this%left_operand%indices(1)
    max_degree = this%left_operand%indices(2)

    do concurrent(v = 1:size(upstream_grad,2))
       ! The temporaries live in a BLOCK so every iteration owns its own copy.
       ! Procedure-scope variables written inside DO CONCURRENT may be treated
       ! as shared when a compiler parallelises or offloads the loop.
       block
         integer :: bucket
         ! 1-based index of the clamped degree bucket
         integer :: offset
         ! Offset of this bucket's weight block in left_operand%val
         bucket = max( &
              min_degree, &
              min(this%indices(v+1) - this%indices(v), max_degree) &
         ) - min_degree + 1
         offset = (bucket - 1) * interval
         output(:,v) = matmul( &
              upstream_grad(:,v), &
              reshape( &
                   this%left_operand%val(offset+1:offset+interval, 1), &
                   [num_output_features, num_input_features] &
              ) &
         ) / real(bucket, real32)
       end block
    end do

  end subroutine get_partial_duvenaud_update_val