get_partial_pad2d_val Subroutine

pure subroutine get_partial_pad2d_val(this, upstream_grad, output)

Get the partial derivative of the pad2d operation with respect to its (unpadded) input — raw-array version.

Arguments

Type | Intent | Optional | Attributes | Name
class(array_type), intent(in) :: this
real(kind=real32), intent(in), dimension(:,:) :: upstream_grad
real(kind=real32), intent(out), dimension(:,:) :: output

Source Code

  pure subroutine get_partial_pad2d_val(this, upstream_grad, output)
    !! Get the partial derivative for the pad2d operation - raw array version
    !!
    !! Maps the gradient of the padded result back onto the unpadded input.
    !! Every interior (non-padding) element of `upstream_grad` is copied to
    !! its matching input position in `output`. Only when the padding mode
    !! `this%indices(1)` lies in 3..5 are the corner/edge helpers invoked;
    !! presumably they fold the padded-border gradients back onto the input
    !! (confirm against those helpers). For any other mode, this routine
    !! does not propagate border gradients.
    implicit none

    ! Arguments
    class(array_type), intent(in) :: this
    !! pad2d node; this%left_operand%shape gives the unpadded (H, W, C) extents,
    !! this%indices(2:3) the per-side padding widths along dims 1 and 2
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Gradient w.r.t. the padded result, one column per sample
    real(real32), dimension(:,:), intent(out) :: output
    !! Gradient w.r.t. the unpadded input, one column per sample

    ! Local variables
    integer :: row, col, chan, sample
    integer :: pad_h, pad_w
    integer :: in_h, in_w, n_chan, n_samp
    integer :: out_h, out_w
    integer, dimension(4) :: shape_in

    ! Packed (H, W, C, N) description of the unpadded input; the helpers
    ! below receive it in exactly this layout.
    shape_in = [ this%left_operand%shape, size(upstream_grad, dim=2) ]
    in_h   = shape_in(1)
    in_w   = shape_in(2)
    n_chan = shape_in(3)
    n_samp = shape_in(4)

    ! Symmetric padding: each spatial extent grows by twice its pad width.
    pad_h = this%indices(2)
    pad_w = this%indices(3)
    out_h = in_h + 2 * pad_h
    out_w = in_w + 2 * pad_w

    output = 0._real32

    ! Interior region: input element (row, col, chan) lives at
    ! (row + pad_h, col + pad_w, chan) of the padded result. Both sides use
    ! column-major flattening of their own (H, W, C) extents.
    do concurrent( &
         sample = 1:n_samp, &
         chan = 1:n_chan, &
         col = 1:in_w, &
         row = 1:in_h)
       output(row + (col - 1) * in_h + (chan - 1) * in_h * in_w, sample) = &
            upstream_grad( &
                 (row + pad_h) + (col + pad_w - 1) * out_h + &
                 (chan - 1) * out_h * out_w, &
                 sample)
    end do

    ! Border region: only padding modes 3..5 route padded-cell gradients
    ! back through the helpers (NOTE(review): the mode meanings are defined
    ! by the forward pad2d op, not visible here).
    select case(this%indices(1))
    case(3:5)
       call accumulate_corner_gradients_2d_val( &
            upstream_grad, output, shape_in, this%indices, this%adj_ja &
       )
       call accumulate_edge_gradients_2d_val( &
            upstream_grad, output, shape_in, this%indices, this%adj_ja &
       )
    end select

  end subroutine get_partial_pad2d_val