get_partial_conv3d_kernel_val Subroutine

pure subroutine get_partial_conv3d_kernel_val(this, upstream_grad, output)

Get the partial derivative with respect to the kernel for a 3D convolution (subroutine version).

Arguments

Type | Intent | Optional | Attributes | Name
class(array_type), intent(in) :: this
real(kind=real32), intent(in), dimension(:,:) :: upstream_grad
real(kind=real32), intent(out), dimension(:,:) :: output

Source Code

  pure subroutine get_partial_conv3d_kernel_val(this, upstream_grad, output)
    !! Gradient of a 3D convolution node with respect to its kernel weights
    !! (subroutine variant that writes into a caller-supplied array).
    !!
    !! Layout parameters are unpacked from the node:
    !!  * `this%indices(1)` / `this%indices(2)`: input-channel / filter counts
    !!  * `this%adj_ja(1:3,1)`: stride along (h, w, d)
    !!  * `this%adj_ja(1:3,2)`: dilation along (h, w, d)
    !!  * `this%adj_ja(1:3,3)`: kernel extents along (h, w, d)
    !!
    !! For every kernel weight (ph, pw, pd, ch, f), the sum over all columns
    !! of `upstream_grad` and all output positions (oh, ow, od) is formed of
    !!
    !!     upstream_grad(out voxel, col) * left_operand%val(in voxel, col)
    !!
    !! The input coordinate along each axis is
    !! `(o - 1)*stride + (p - 1)*dilation + 1`. Input coordinates outside
    !! `1..extent` are skipped, and no padding offset is applied.
    !!
    !! Flattened column-major layouts implied by the indexing:
    !!  * weights: (kernel_h, kernel_w, kernel_d, channels, filters)
    !!  * input:   (input_h, input_w, input_d, channels) per column
    !!  * upstream_grad: (output_h, output_w, output_d, filters) per column
    !!
    !! Each weight's sum is stored in column 1 of `output`; every other
    !! entry of `output` is set to zero.
    !! NOTE(review): assumes `upstream_grad` and `this%left_operand%val`
    !! have the same number of columns; this is not checked here.
    implicit none

    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:,:), intent(out) :: output

    ! Layout parameters
    integer :: n_in_ch, n_out_ch
    integer :: kh, kw, kd
    integer :: in_h, in_w, in_d, out_h, out_w, out_d
    integer, dimension(3) :: strd, dil
    integer :: in_plane, out_plane, in_vol, out_vol, n_samples

    ! Loop counters: filter, channel, kernel offsets, output position, column
    integer :: f, ch, pd, pw, ph
    integer :: od_pos, ow_pos, oh_pos, sample

    ! Derived coordinates / flattened offsets
    integer :: src_d, src_w, src_h
    integer :: w_idx
    integer :: in_base_c, out_base_f
    integer :: in_off_d, out_off_d, in_off_w, out_off_w
    real(real32) :: acc

    n_in_ch  = this%indices(1)
    n_out_ch = this%indices(2)
    strd = this%adj_ja(1:3,1)
    dil  = this%adj_ja(1:3,2)
    kh = this%adj_ja(1,3)
    kw = this%adj_ja(2,3)
    kd = this%adj_ja(3,3)

    in_h  = this%left_operand%shape(1)
    in_w  = this%left_operand%shape(2)
    in_d  = this%left_operand%shape(3)
    out_h = this%shape(1)
    out_w = this%shape(2)
    out_d = this%shape(3)

    in_plane  = in_h * in_w
    out_plane = out_h * out_w
    in_vol    = in_plane * in_d
    out_vol   = out_plane * out_d
    n_samples = size(upstream_grad, dim=2)

    output = 0._real32

    ! The kernel loops below visit weights in exactly their column-major
    ! storage order (ph fastest, f slowest), so a running counter yields the
    ! flattened weight index without recomputing it from all five offsets.
    w_idx = 0

    filter_loop: do f = 1, n_out_ch
       out_base_f = (f - 1) * out_vol
       channel_loop: do ch = 1, n_in_ch
          in_base_c = (ch - 1) * in_vol
          kdepth_loop: do pd = 1, kd
             kwidth_loop: do pw = 1, kw
                kheight_loop: do ph = 1, kh
                   w_idx = w_idx + 1
                   acc = 0._real32

                   ! Summation order (column, depth, width, height) matches
                   ! the reference implementation, so rounding is unchanged.
                   sample_loop: do sample = 1, n_samples
                      depth_loop: do od_pos = 1, out_d
                         src_d = (od_pos - 1)*strd(3) + (pd - 1)*dil(3) + 1
                         if (src_d < 1 .or. src_d > in_d) cycle depth_loop
                         in_off_d  = in_base_c  + (src_d  - 1)*in_plane
                         out_off_d = out_base_f + (od_pos - 1)*out_plane

                         width_loop: do ow_pos = 1, out_w
                            src_w = (ow_pos - 1)*strd(2) + (pw - 1)*dil(2) + 1
                            if (src_w < 1 .or. src_w > in_w) cycle width_loop
                            in_off_w  = in_off_d  + (src_w  - 1)*in_h
                            out_off_w = out_off_d + (ow_pos - 1)*out_h

                            height_loop: do oh_pos = 1, out_h
                               src_h = (oh_pos - 1)*strd(1) + &
                                    (ph - 1)*dil(1) + 1
                               if (src_h < 1 .or. src_h > in_h) &
                                    cycle height_loop
                               acc = acc + &
                                    upstream_grad(out_off_w + oh_pos, sample) * &
                                    this%left_operand%val(in_off_w + src_h, sample)
                            end do height_loop
                         end do width_loop
                      end do depth_loop
                   end do sample_loop

                   output(w_idx, 1) = acc
                end do kheight_loop
             end do kwidth_loop
          end do kdepth_loop
       end do channel_loop
    end do filter_loop

  end subroutine get_partial_conv3d_kernel_val