get_partial_conv3d_input_val Subroutine

pure subroutine get_partial_conv3d_input_val(this, upstream_grad, output)

Get the partial derivative with respect to the input for a 3D convolution (subroutine version).

Arguments

Type Intent Optional Attributes Name
class(array_type), intent(in) :: this
real(kind=real32), intent(in), dimension(:,:) :: upstream_grad
real(kind=real32), intent(out), dimension(:,:) :: output

Source Code

  pure subroutine get_partial_conv3d_input_val(this, upstream_grad, output)
    !! Get partial derivative wrt input for 3D convolution (subroutine version)
    !!
    !! Each entry of `upstream_grad` is scattered back onto every input
    !! voxel that its kernel window covered. It is weighted by the matching
    !! kernel value, read from column 1 of `this%right_operand%val`.
    !! Window taps that land outside the input volume are skipped. The
    !! window-origin formula carries no padding offset.
    !!
    !! Parameters read from `this`:
    !!   indices(1)     -> number of input channels
    !!   indices(2)     -> number of filters
    !!   adj_ja(1:3, 1) -> stride        (h, w, d)
    !!   adj_ja(1:3, 2) -> dilation      (h, w, d)
    !!   adj_ja(1:3, 3) -> kernel extent (h, w, d)
    !! All flattened volumes are indexed with h varying fastest, then w,
    !! then d, then channel (or filter).
    implicit none

    class(array_type), intent(in) :: this
    !! Convolution node; input shape comes from `left_operand`, kernel
    !! values from `right_operand`
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Gradient wrt the convolution output; one column per sample
    real(real32), dimension(:,:), intent(out) :: output
    !! Gradient wrt the convolution input; one column per sample.
    !! NOTE(review): the caller is assumed to size this as
    !! (in_h*in_w*in_d*channels, samples); no shape check is done here.

    ! Loop counters: sample, input channel, filter, output voxel, kernel tap
    integer :: smp, ch, flt, oh, ow, od, th, tw, td
    ! Input voxel coordinates and flattened row indices
    integer :: ph, pw, pd, src_row, dst_row, wt_row
    integer :: in_h, in_w, in_d, ker_h, ker_w, ker_d
    integer :: out_h, out_w, out_d
    integer :: n_ch, n_flt
    ! Strides of the flattened layouts
    integer :: in_plane, in_vol, out_plane, out_vol, ker_plane, ker_vol
    ! Per-channel / per-filter base offsets, hoisted out of the tap loops
    integer :: dst_ch_off, wt_ch_off, wt_off
    integer, dimension(3) :: step, gap
    real(real32) :: g

    n_ch  = this%indices(1)
    n_flt = this%indices(2)
    step  = this%adj_ja(1:3,1)
    gap   = this%adj_ja(1:3,2)
    ker_h = this%adj_ja(1,3)
    ker_w = this%adj_ja(2,3)
    ker_d = this%adj_ja(3,3)

    in_h  = this%left_operand%shape(1)
    in_w  = this%left_operand%shape(2)
    in_d  = this%left_operand%shape(3)
    out_h = this%shape(1)
    out_w = this%shape(2)
    out_d = this%shape(3)

    in_plane  = in_h * in_w
    in_vol    = in_plane * in_d
    out_plane = out_h * out_w
    out_vol   = out_plane * out_d
    ker_plane = ker_h * ker_w
    ker_vol   = ker_plane * ker_d

    output = 0._real32

    ! The loop nest order fixes the order in which contributions are
    ! summed into each output element. Keep it unchanged if bitwise
    ! reproducibility of the gradient matters.
    sample_loop: do smp = 1, size(upstream_grad, dim=2)
       channel_loop: do ch = 1, n_ch
          dst_ch_off = ( ch - 1 ) * in_vol
          wt_ch_off  = ( ch - 1 ) * ker_vol
          do od = 1, out_d
             do ow = 1, out_w
                do oh = 1, out_h
                   do flt = 1, n_flt
                      src_row = oh + ( ow - 1 ) * out_h + &
                           ( od - 1 ) * out_plane + ( flt - 1 ) * out_vol
                      g = upstream_grad(src_row, smp)
                      wt_off = wt_ch_off + ( flt - 1 ) * ker_vol * n_ch

                      do td = 1, ker_d
                         pd = ( od - 1 ) * step(3) + ( td - 1 ) * gap(3) + 1
                         if( pd .lt. 1 .or. pd .gt. in_d ) cycle
                         do tw = 1, ker_w
                            pw = ( ow - 1 ) * step(2) + ( tw - 1 ) * gap(2) + 1
                            if( pw .lt. 1 .or. pw .gt. in_w ) cycle
                            do th = 1, ker_h
                               ph = ( oh - 1 ) * step(1) + &
                                    ( th - 1 ) * gap(1) + 1
                               if( ph .lt. 1 .or. ph .gt. in_h ) cycle
                               dst_row = dst_ch_off + ( pd - 1 ) * in_plane + &
                                    ( pw - 1 ) * in_h + ph
                               wt_row = wt_off + ( td - 1 ) * ker_plane + &
                                    ( tw - 1 ) * ker_h + th
                               output(dst_row, smp) = output(dst_row, smp) + &
                                    g * this%right_operand%val(wt_row, 1)
                            end do
                         end do
                      end do
                   end do
                end do
             end do
          end do
       end do channel_loop
    end do sample_loop

  end subroutine get_partial_conv3d_input_val