get_partial_conv2d_input_val Subroutine

pure subroutine get_partial_conv2d_input_val(this, upstream_grad, output)

Get partial derivative wrt input for 2D convolution (subroutine version)

Arguments

Type Intent Optional Attributes Name
class(array_type), intent(in) :: this
real(kind=real32), intent(in), dimension(:,:) :: upstream_grad
real(kind=real32), intent(out), dimension(:,:) :: output

Source Code

  pure subroutine get_partial_conv2d_input_val(this, upstream_grad, output)
    !! Get partial derivative wrt input for 2D convolution (subroutine version)
    !!
    !! Every entry of `upstream_grad` is scattered back onto the input
    !! positions covered by the kernel window of the corresponding output
    !! element, weighted by the kernel value read from
    !! `this%right_operand%val(:,1)`.
    !!
    !! Layout used by the indexing below:
    !! - rows of `upstream_grad` are flattened as (output_h, output_w, filter),
    !!   first index varying fastest; each column is one sample
    !! - rows of `output` are flattened as (input_h, input_w, channel),
    !!   first index varying fastest; columns are indexed by the same sample
    !!   index as `upstream_grad`
    !! - the kernel is flattened as (kernel_h, kernel_w, channel, filter)
    !!
    !! Only the sample index is iterated with `do concurrent`: different
    !! filters and overlapping kernel windows accumulate into the same
    !! element of `output`, so those loops must run in order within a sample.
    implicit none

    class(array_type), intent(in) :: this
    !! Convolution node; supplies channel/filter counts (`indices`),
    !! stride, dilation and kernel size (`adj_ja`), the input shape
    !! (`left_operand%shape`) and the kernel values (`right_operand%val`)
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Gradient with respect to the convolution output
    real(real32), dimension(:,:), intent(out) :: output
    !! Gradient with respect to the convolution input

    ! Local variables
    real(real32), parameter :: grad_tol = 1.e-30_real32
    !! Upstream gradients with magnitude not above this are skipped
    integer :: s
    integer :: input_h, input_w, kernel_h, kernel_w
    integer :: output_h, output_w, num_channels, num_filters
    integer, dimension(2) :: stride, dilation
    integer :: channel_size_in, channel_size_out, kernel_channel_size


    ! Unpack parameters
    num_channels = this%indices(1)
    num_filters = this%indices(2)
    stride = this%adj_ja(1:2,1)
    dilation = this%adj_ja(1:2,2)
    kernel_h = this%adj_ja(1,3)
    kernel_w = this%adj_ja(2,3)

    input_h = this%left_operand%shape(1)
    input_w = this%left_operand%shape(2)
    output_h = this%shape(1)
    output_w = this%shape(2)
    channel_size_in  = input_h * input_w
    channel_size_out = output_h * output_w
    kernel_channel_size = kernel_h * kernel_w

    output = 0._real32

    ! Samples are independent (each one owns a column of output), so only
    ! this index is safe to iterate concurrently.
    do concurrent(s = 1:size(upstream_grad, dim=2))
       per_sample: block
          ! Scratch variables are block-local so each iteration has its own
          integer :: i, j, ki, kj, c_in, c_out
          integer :: i_in, j_in, k_idx, out_idx, in_idx
          integer :: in_base_idx, k_base_idx
          real(real32) :: grad_val, kernel_val

          do c_out = 1, num_filters
             do j = 1, output_w
                do i = 1, output_h
                   out_idx = i + (j-1)*output_h + (c_out-1)*channel_size_out
                   grad_val = upstream_grad(out_idx, s)

                   ! Skip (near-)zero gradients; they contribute nothing
                   if(abs(grad_val) .gt. grad_tol)then
                      do c_in = 1, num_channels
                         in_base_idx = (c_in - 1) * channel_size_in
                         k_base_idx = (c_in - 1) * kernel_channel_size + &
                              (c_out - 1) * kernel_channel_size * num_channels

                         do kj = 1, kernel_w
                            j_in = (j - 1) * stride(2) + &
                                 (kj - 1) * dilation(2) + 1
                            if(j_in .ge. 1 .and. j_in .le. input_w)then
                               do ki = 1, kernel_h
                                  i_in = (i - 1) * stride(1) + &
                                       (ki - 1) * dilation(1) + 1
                                  if(i_in .ge. 1 .and. i_in .le. input_h)then
                                     in_idx = i_in + (j_in - 1) * input_h + &
                                          in_base_idx
                                     ! Kernel is read at the spatially mirrored
                                     ! position (kernel_h-ki+1, kernel_w-kj+1).
                                     ! NOTE(review): presumably matches the
                                     ! kernel orientation of the forward pass
                                     ! - confirm against the forward routine.
                                     k_idx = (kernel_h - ki + 1) + &
                                          (kernel_w - kj) * kernel_h + &
                                          k_base_idx
                                     kernel_val = this%right_operand%val(k_idx, 1)
                                     output(in_idx, s) = output(in_idx, s) + &
                                          grad_val * kernel_val
                                  end if
                               end do
                            end if
                         end do
                      end do
                   end if
                end do
             end do
          end do
       end block per_sample
    end do

  end subroutine get_partial_conv2d_input_val