Get partial derivative wrt input for 3D convolution (subroutine version)
| Type | Intent | Optional | Attributes | | Name | |
|---|---|---|---|---|---|---|
| class(array_type), | intent(in) | | | :: | this | |
| real(kind=real32), | intent(in), | | dimension(:,:) | :: | upstream_grad | |
| real(kind=real32), | intent(out), | | dimension(:,:) | :: | output | |
pure subroutine get_partial_conv3d_input_val(this, upstream_grad, output)
  !! Get partial derivative wrt input for 3D convolution (subroutine version)
  !!
  !! Every element of `upstream_grad` is multiplied by each kernel weight
  !! that touched it and the product is accumulated into the matching
  !! element of `output`.  Rows of `upstream_grad` and `output` are
  !! flattened (axis 1, axis 2, axis 3, channel/filter) indices with
  !! axis 1 varying fastest; the same column index is used for both.
  !!
  !! Kernel weights are read from column 1 of `this%right_operand%val`,
  !! flattened as (axis 1, axis 2, axis 3, input channel, filter).
  !! The output-to-input index mapping contains no padding term; kernel
  !! taps that land outside the input extents are skipped.
  implicit none

  ! Arguments
  class(array_type), intent(in) :: this
  !! Node providing the convolution parameters (`indices`, `adj_ja`),
  !! its own `shape`, and the `left_operand` / `right_operand` operands
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  !! Gradient with respect to the convolution result
  real(real32), dimension(:,:), intent(out) :: output
  !! Gradient with respect to the convolution input (zeroed, then accumulated)

  ! Local variables
  integer :: sample, ch, filt
  integer :: o1, o2, o3
  integer :: t1, t2, t3
  integer :: p1, p2, p3
  integer :: n_ch, n_filt
  integer :: in_n1, in_n2, in_n3
  integer :: out_n1, out_n2, out_n3
  integer :: ker_n1, ker_n2, ker_n3
  integer :: in_plane, in_vol
  integer :: out_plane, out_vol
  integer :: ker_plane, ker_vol
  integer :: row_out, row_in, row_ker
  integer, dimension(3) :: step, dil
  real(real32) :: g

  ! Channel and filter counts
  n_ch   = this%indices(1)
  n_filt = this%indices(2)

  ! Columns of adj_ja: 1 = stride, 2 = dilation, 3 = kernel extents
  step   = this%adj_ja(1:3, 1)
  dil    = this%adj_ja(1:3, 2)
  ker_n1 = this%adj_ja(1, 3)
  ker_n2 = this%adj_ja(2, 3)
  ker_n3 = this%adj_ja(3, 3)

  ! Spatial extents of the input (left operand) and of this node
  in_n1  = this%left_operand%shape(1)
  in_n2  = this%left_operand%shape(2)
  in_n3  = this%left_operand%shape(3)
  out_n1 = this%shape(1)
  out_n2 = this%shape(2)
  out_n3 = this%shape(3)

  ! Strides of the flattened row indices, computed once
  in_plane  = in_n1  * in_n2
  in_vol    = in_plane  * in_n3
  out_plane = out_n1 * out_n2
  out_vol   = out_plane * out_n3
  ker_plane = ker_n1 * ker_n2
  ker_vol   = ker_plane * ker_n3

  output = 0._real32

  sample_loop: do sample = 1, size(upstream_grad, dim=2)
    channel_loop: do ch = 1, n_ch
      do o3 = 1, out_n3
        do o2 = 1, out_n2
          do o1 = 1, out_n1
            filter_loop: do filt = 1, n_filt

              row_out = o1 + (o2 - 1) * out_n1 + (o3 - 1) * out_plane + &
                   (filt - 1) * out_vol
              g = upstream_grad(row_out, sample)

              do t3 = 1, ker_n3
                p3 = (o3 - 1) * step(3) + (t3 - 1) * dil(3) + 1
                if (p3 < 1 .or. p3 > in_n3) cycle

                do t2 = 1, ker_n2
                  p2 = (o2 - 1) * step(2) + (t2 - 1) * dil(2) + 1
                  if (p2 < 1 .or. p2 > in_n2) cycle

                  do t1 = 1, ker_n1
                    p1 = (o1 - 1) * step(1) + (t1 - 1) * dil(1) + 1
                    if (p1 < 1 .or. p1 > in_n1) cycle

                    row_in = p1 + (p2 - 1) * in_n1 + (p3 - 1) * in_plane + &
                         (ch - 1) * in_vol
                    row_ker = t1 + (t2 - 1) * ker_n1 + (t3 - 1) * ker_plane + &
                         (ch - 1) * ker_vol + (filt - 1) * ker_vol * n_ch

                    output(row_in, sample) = output(row_in, sample) + &
                         g * this%right_operand%val(row_ker, 1)
                  end do
                end do
              end do

            end do filter_loop
          end do
        end do
      end do
    end do channel_loop
  end do sample_loop

end subroutine get_partial_conv3d_input_val