Get partial derivative wrt kernel for 3D convolution (subroutine version)
| Type | Intent | Optional | Attributes | | Name | |
|---|---|---|---|---|---|---|
| class(array_type), | intent(in) | | | :: | this | |
| real(kind=real32), | intent(in), | | dimension(:,:) | :: | upstream_grad | |
| real(kind=real32), | intent(out), | | dimension(:,:) | :: | output | |
pure subroutine get_partial_conv3d_kernel_val(this, upstream_grad, output)
  !! Get partial derivative wrt kernel for 3D convolution (subroutine version)
  !!
  !! For each kernel element, the products of `upstream_grad` and the
  !! matching entries of `this%left_operand%val` are summed over every
  !! output position and over every column of `upstream_grad`. The total is
  !! written to the first column of `output`; all other entries of `output`
  !! are set to zero.
  implicit none

  ! Arguments
  class(array_type), intent(in) :: this
  !! Object supplying the layout metadata (`indices`, `adj_ja`, `shape`)
  !! and the input operand (`left_operand`)
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  !! Incoming gradient, indexed (flattened output position, column)
  real(real32), dimension(:,:), intent(out) :: output
  !! Kernel gradient, indexed (flattened kernel position, column);
  !! only column 1 receives accumulated values

  ! Local variables
  integer :: oh, ow, od          ! position in the output volume
  integer :: th, tw, td          ! kernel tap along each axis
  integer :: ih, iw, id          ! position in the input volume
  integer :: ch, filt, col       ! input channel, filter, column of upstream_grad
  integer :: kernel_pos
  integer :: in_h, in_w, in_d
  integer :: ker_h, ker_w, ker_d
  integer :: out_h, out_w, out_d
  integer :: n_channels, n_filters, n_cols
  integer :: in_plane, in_volume
  integer :: out_plane, out_volume
  integer :: ker_plane, ker_volume
  integer :: filt_ker_base, ker_base
  integer :: in_chan_base, out_chan_base
  integer :: in_depth_base, out_depth_base
  integer :: in_row_base, out_row_base
  integer, dimension(3) :: step, dil
  real(real32) :: acc

  ! Unpack the convolution parameters stored on `this`
  n_channels = this%indices(1)
  n_filters  = this%indices(2)
  step = this%adj_ja(1:3, 1)
  dil  = this%adj_ja(1:3, 2)
  ker_h = this%adj_ja(1, 3)
  ker_w = this%adj_ja(2, 3)
  ker_d = this%adj_ja(3, 3)

  in_h = this%left_operand%shape(1)
  in_w = this%left_operand%shape(2)
  in_d = this%left_operand%shape(3)
  out_h = this%shape(1)
  out_w = this%shape(2)
  out_d = this%shape(3)

  ! Sizes used to flatten (h, w, d, channel) into a single first index;
  ! the first axis varies fastest
  in_plane   = in_h * in_w
  in_volume  = in_plane * in_d
  out_plane  = out_h * out_w
  out_volume = out_plane * out_d
  ker_plane  = ker_h * ker_w
  ker_volume = ker_plane * ker_d
  n_cols     = size(upstream_grad, dim=2)

  output = 0._real32

  filter_loop: do filt = 1, n_filters
    filt_ker_base = (filt - 1) * ker_volume * n_channels
    out_chan_base = (filt - 1) * out_volume

    channel_loop: do ch = 1, n_channels
      ker_base     = filt_ker_base + (ch - 1) * ker_volume
      in_chan_base = (ch - 1) * in_volume

      tap_d_loop: do td = 1, ker_d
        tap_w_loop: do tw = 1, ker_w
          tap_h_loop: do th = 1, ker_h

            kernel_pos = th + (tw - 1) * ker_h + (td - 1) * ker_plane + ker_base
            acc = 0._real32

            column_loop: do col = 1, n_cols
              depth_loop: do od = 1, out_d
                id = (od - 1) * step(3) + (td - 1) * dil(3) + 1
                ! Skip taps that fall outside the input along the third axis
                if (id < 1 .or. id > in_d) cycle depth_loop
                in_depth_base  = (id - 1) * in_plane + in_chan_base
                out_depth_base = (od - 1) * out_plane + out_chan_base

                width_loop: do ow = 1, out_w
                  iw = (ow - 1) * step(2) + (tw - 1) * dil(2) + 1
                  ! Skip taps that fall outside the input along the second axis
                  if (iw < 1 .or. iw > in_w) cycle width_loop
                  in_row_base  = (iw - 1) * in_h + in_depth_base
                  out_row_base = (ow - 1) * out_h + out_depth_base

                  height_loop: do oh = 1, out_h
                    ih = (oh - 1) * step(1) + (th - 1) * dil(1) + 1
                    ! Skip taps that fall outside the input along the first axis
                    if (ih < 1 .or. ih > in_h) cycle height_loop
                    acc = acc + &
                         upstream_grad(oh + out_row_base, col) * &
                         this%left_operand%val(ih + in_row_base, col)
                  end do height_loop
                end do width_loop
              end do depth_loop
            end do column_loop

            ! The sum over all columns is stored in column 1 only
            output(kernel_pos, 1) = acc

          end do tap_h_loop
        end do tap_w_loop
      end do tap_d_loop
    end do channel_loop
  end do filter_loop

end subroutine get_partial_conv3d_kernel_val