get_partial_duvenaud_update_weight_val Subroutine

pure subroutine get_partial_duvenaud_update_weight_val(this, upstream_grad, output)

In-place value gradient for duvenaud_update packed weights.

Arguments

Type Intent Optional Attributes Name
class(array_type), intent(in) :: this

Forward result node containing saved operands

real(kind=real32), intent(in), dimension(:,:) :: upstream_grad

Upstream gradient values

real(kind=real32), intent(out), dimension(:,:) :: output

Output gradient values for packed weights


Source Code

  pure subroutine get_partial_duvenaud_update_weight_val( &
       this, upstream_grad, output &
  )
    !! In-place value gradient for duvenaud_update packed weights.
    !!
    !! Column 1 of `output` holds one num_output_features x
    !! num_input_features block per degree bucket. Each block is stored
    !! column-major, and the blocks follow one another bucket by bucket.
    !! For every column `v` of `upstream_grad`, the outer product
    !! upstream_grad(:,v) * right_operand%val(:,v) is divided by the
    !! 1-based bucket index of `v`. The result is accumulated into that
    !! bucket's block.
    !!
    !! The loop over `v` is deliberately a plain `do`, not `do concurrent`.
    !! Columns that fall into the same degree bucket accumulate into the
    !! same output elements. That cross-iteration dependency is not
    !! permitted inside a `do concurrent` construct.
    implicit none

    ! Arguments
    class(array_type), intent(in) :: this
    !! Forward result node containing saved operands
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Upstream gradient values, indexed (output feature, v)
    real(real32), dimension(:,:), intent(out) :: output
    !! Output gradient values for packed weights. The whole array is zeroed,
    !! then only column 1 is accumulated into. Presumably it needs at least
    !! interval * (max_degree - min_degree + 1) rows (TODO: confirm
    !! against the caller).

    ! Local variables
    integer :: v, j
    !! Column index into upstream_grad and input-feature index
    integer :: first
    !! Zero-based start of the current output column segment
    integer :: d_offset, d_val
    !! Degree offset and degree bucket index
    integer :: interval, num_output_features, num_input_features
    !! Flattening interval and matrix dimensions
    integer :: min_degree, max_degree
    !! Degree bucket limits
    real(real32) :: denom
    !! Per-column divisor (the bucket index as a real)

    output = 0._real32
    num_output_features = size(upstream_grad, 1)
    num_input_features = this%right_operand%shape(1)
    interval = num_output_features * num_input_features
    min_degree = this%left_operand%indices(1)
    max_degree = this%left_operand%indices(2)

    do v = 1, size(upstream_grad, 2)
       ! Count from consecutive entries of this%indices (presumably
       ! CSR-style neighbour offsets; verify). Clamp the count to
       ! [min_degree, max_degree] and shift it to a 1-based bucket index.
       d_val = max( &
            min_degree, &
            min(this%indices(v+1) - this%indices(v), max_degree) &
       ) - min_degree + 1
       d_offset = (d_val - 1) * interval
       ! NOTE(review): the divisor is the bucket index. It equals the clamped
       ! degree only when min_degree == 1. Confirm this matches the forward
       ! pass.
       denom = real(d_val, real32)
       do j = 1, num_input_features
          ! Rows first+1 .. first+num_output_features form column j of this
          ! bucket's block. They are contiguous, and every (i, j) pair maps
          ! to a distinct element.
          first = d_offset + num_output_features * (j - 1)
          output(first+1:first+num_output_features, 1) = &
               output(first+1:first+num_output_features, 1) + &
               upstream_grad(:,v) * this%right_operand%val(j,v) / denom
       end do
    end do

  end subroutine get_partial_duvenaud_update_weight_val