submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule_conv !! Submodule containing implementations for extended diffstruc array operations contains !############################################################################### module function conv1d(input, kernel, stride, dilation) result(output) !! 1D convolution operation implicit none ! Arguments type(array_type), intent(in), target :: input type(array_type), intent(in), target :: kernel integer, intent(in) :: stride integer, intent(in) :: dilation type(array_type), pointer :: output ! Local variables integer :: i, k, c_in, c_out, s integer :: i_in, k_idx integer :: input_h, kernel_h, output_h, num_channels, num_filters real(real32) :: conv_sum integer, dimension(3) :: output_shape ! Extract dimensions ! input: [H_in, C_in, B] ! kernel: [K, C_in, C_out] input_h = input%shape(1) num_channels = input%shape(2) kernel_h = kernel%shape(1) num_filters = kernel%shape(3) ! Calculate output dimensions output_h = (input_h - dilation*(kernel_h - 1) - 1) / & stride + 1 output_shape = [output_h, num_filters, size(input%val, dim=2)] output => input%create_result(array_shape = output_shape) output%val = 0._real32 ! Perform convolution do concurrent(s = 1:output_shape(3), c_out = 1:num_filters, & i = 1:output_h) conv_sum = 0._real32 do c_in = 1, num_channels do k = 1, kernel_h i_in = ( i - 1 ) * stride + ( k - 1 ) * dilation + 1 if(i_in .ge. 1 .and. i_in .le. input_h)then k_idx = k + ( c_in - 1 ) * kernel_h + & ( c_out - 1 ) * kernel_h * num_channels conv_sum = conv_sum + & input%val(i_in + ( c_in - 1 ) * input_h, s) * & kernel%val(k_idx, 1) end if end do end do output%val(i + (c_out-1)*output_h, s) = conv_sum end do ! Store parameters for backward pass allocate(output%indices(2)) output%indices(1) = num_channels output%indices(2) = num_filters allocate(output%adj_ja(1,3)) output%adj_ja(1,1) = stride output%adj_ja(1,2) = dilation output%adj_ja(1,3) = kernel_h output%get_partial_left => get_partial_conv1d_input output%get_partial_right => get_partial_conv1d_kernel output%get_partial_left_val => get_partial_conv1d_input_val output%get_partial_right_val => get_partial_conv1d_kernel_val if(input%requires_grad .or. kernel%requires_grad)then output%requires_grad = .true. output%is_forward = input%is_forward output%operation = 'conv1d' output%left_operand => input output%right_operand => kernel end if end function conv1d !------------------------------------------------------------------------------- function get_partial_conv1d_input(this, upstream_grad) result(output) !! Get partial derivative wrt input for 1D convolution implicit none class(array_type), intent(inout) :: this type(array_type), intent(in) :: upstream_grad type(array_type) :: output call output%allocate(array_shape = [ this%left_operand%shape, & size(this%left_operand%val, dim=2) ]) call this%get_partial_left_val(upstream_grad%val, output%val) end function get_partial_conv1d_input !------------------------------------------------------------------------------- function get_partial_conv1d_kernel(this, upstream_grad) result(output) !! Get partial derivative wrt kernel for 1D convolution implicit none class(array_type), intent(inout) :: this type(array_type), intent(in) :: upstream_grad type(array_type) :: output call output%allocate(array_shape = [ this%right_operand%shape, 1 ]) call this%get_partial_right_val(upstream_grad%val, output%val) end function get_partial_conv1d_kernel !------------------------------------------------------------------------------- pure subroutine get_partial_conv1d_input_val(this, upstream_grad, output) !! Get partial derivative wrt input for 1D convolution (subroutine version) implicit none class(array_type), intent(in) :: this real(real32), dimension(:,:), intent(in) :: upstream_grad real(real32), dimension(:,:), intent(out) :: output ! Local variables integer :: i, k, c_in, c_out, s integer :: i_in, k_idx, out_idx integer :: input_h, kernel_h, output_h, num_channels, num_filters integer :: stride, dilation real(real32) :: grad_val, kernel_val ! Unpack parameters num_channels = this%indices(1) num_filters = this%indices(2) stride = this%adj_ja(1,1) dilation = this%adj_ja(1,2) kernel_h = this%adj_ja(1,3) input_h = this%left_operand%shape(1) output_h = this%shape(1) output = 0._real32 ! Parallelised over batch, channels, and output positions do concurrent(s = 1:size(upstream_grad, dim=2), c_in = 1:num_channels, & i = 1:output_h, c_out = 1:num_filters) out_idx = i + (c_out-1)*output_h grad_val = upstream_grad(out_idx, s) if(abs(grad_val) .gt. 1.e-30_real32)then do k = 1, kernel_h i_in = ( i - 1 ) * stride + ( k - 1 ) * dilation + 1 if(i_in .ge. 1 .and. i_in .le. input_h)then k_idx = k + ( c_in - 1 ) * kernel_h + & ( c_out - 1 ) * kernel_h * num_channels kernel_val = this%right_operand%val(k_idx, 1) output(i_in + ( c_in - 1 ) * input_h, s) = & output(i_in + ( c_in - 1 ) * input_h, s) + & grad_val * kernel_val end if end do end if end do end subroutine get_partial_conv1d_input_val !------------------------------------------------------------------------------- pure subroutine get_partial_conv1d_kernel_val(this, upstream_grad, output) !! Get partial derivative wrt kernel for 1D convolution (subroutine version) implicit none class(array_type), intent(in) :: this real(real32), dimension(:,:), intent(in) :: upstream_grad real(real32), dimension(:,:), intent(out) :: output ! Local variables integer :: i, k, c_in, c_out, s integer :: i_in, k_idx, out_idx integer :: input_h, kernel_h, output_h, num_channels, num_filters integer :: stride, dilation real(real32) :: grad_sum ! Unpack parameters num_channels = this%indices(1) num_filters = this%indices(2) stride = this%adj_ja(1,1) dilation = this%adj_ja(1,2) kernel_h = this%adj_ja(1,3) input_h = this%left_operand%shape(1) output_h = this%shape(1) output = 0._real32 ! Parallelised over filters, channels, and kernel positions do concurrent(c_out = 1:num_filters, c_in = 1:num_channels, k = 1:kernel_h) k_idx = k + ( c_in - 1 ) * kernel_h + & ( c_out - 1 ) * kernel_h * num_channels grad_sum = 0._real32 do s = 1, size(upstream_grad, dim=2) do i = 1, output_h i_in = ( i - 1 ) * stride + ( k - 1 ) * dilation + 1 if(i_in .ge. 1 .and. i_in .le. input_h)then out_idx = i + ( c_out - 1 ) * output_h grad_sum = grad_sum + upstream_grad(out_idx, s) * & this%left_operand%val(i_in + ( c_in - 1 ) * input_h, s) end if end do end do output(k_idx, 1) = grad_sum end do end subroutine get_partial_conv1d_kernel_val !############################################################################### !############################################################################### module function conv2d(input, kernel, stride, dilation) result(output) !! 2D convolution operation implicit none ! Arguments type(array_type), intent(in), target :: input type(array_type), intent(in), target :: kernel integer, dimension(2), intent(in) :: stride integer, dimension(2), intent(in) :: dilation type(array_type), pointer :: output ! Local variables integer :: i, j, ki, kj, c_in, c_out, s integer :: i_in, j_in, k_idx, out_idx, in_idx, in_base_idx, k_base_idx integer :: input_h, input_w, kernel_h, kernel_w integer :: output_h, output_w, num_channels, num_filters integer :: channel_size_in, channel_size_out, kernel_channel_size integer :: dil_kernel_h_m1, dil_kernel_w_m1 real(real32) :: conv_sum integer, dimension(4) :: output_shape ! Extract dimensions ! input: [H_in, W_in, C_in, B] ! kernel: [K_h, K_w, C_in, C_out] input_h = input%shape(1) input_w = input%shape(2) num_channels = input%shape(3) kernel_h = kernel%shape(1) kernel_w = kernel%shape(2) num_filters = kernel%shape(4) ! Pre-compute common values channel_size_in = input_h * input_w kernel_channel_size = kernel_h * kernel_w dil_kernel_h_m1 = dilation(1) * (kernel_h - 1) dil_kernel_w_m1 = dilation(2) * (kernel_w - 1) ! Calculate output dimensions output_h = (input_h - dil_kernel_h_m1 - 1) / stride(1) + 1 output_w = (input_w - dil_kernel_w_m1 - 1) / stride(2) + 1 output_shape = [output_h, output_w, num_filters, & size(input%val, dim=2)] output => input%create_result(array_shape = output_shape) output%val = 0._real32 channel_size_out = output_h * output_w ! Perform convolution do concurrent(s = 1:output_shape(4), c_out = 1:num_filters, & j = 1:output_w, i = 1:output_h) conv_sum = 0._real32 do c_in = 1, num_channels in_base_idx = (c_in - 1) * channel_size_in k_base_idx = (c_in - 1) * kernel_channel_size + & (c_out - 1) * kernel_channel_size * num_channels do kj = 1, kernel_w j_in = (j - 1) * stride(2) + (kj - 1) * dilation(2) + 1 if(j_in .ge. 1 .and. j_in .le. input_w)then do ki = 1, kernel_h i_in = (i - 1) * stride(1) + (ki - 1) * dilation(1) + 1 if(i_in .ge. 1 .and. i_in .le. input_h)then in_idx = i_in + (j_in - 1) * input_h + in_base_idx k_idx = ki + (kj - 1) * kernel_h + k_base_idx conv_sum = conv_sum + input%val(in_idx, s) * & kernel%val(k_idx, 1) end if end do end if end do end do out_idx = i + (j - 1) * output_h + (c_out - 1) * channel_size_out output%val(out_idx, s) = conv_sum end do ! Store parameters for backward pass allocate(output%indices(2)) output%indices(1) = num_channels output%indices(2) = num_filters allocate(output%adj_ja(2,3)) output%adj_ja(1:2,1) = stride output%adj_ja(1:2,2) = dilation output%adj_ja(1,3) = kernel_h output%adj_ja(2,3) = kernel_w output%get_partial_left => get_partial_conv2d_input output%get_partial_right => get_partial_conv2d_kernel output%get_partial_left_val => get_partial_conv2d_input_val output%get_partial_right_val => get_partial_conv2d_kernel_val if(input%requires_grad .or. kernel%requires_grad)then output%requires_grad = .true. output%is_forward = input%is_forward output%operation = 'conv2d' output%left_operand => input output%right_operand => kernel end if end function conv2d !------------------------------------------------------------------------------- function get_partial_conv2d_input(this, upstream_grad) result(output) !! Get partial derivative wrt input for 2D convolution implicit none class(array_type), intent(inout) :: this type(array_type), intent(in) :: upstream_grad type(array_type) :: output call output%allocate(array_shape = [ this%left_operand%shape, & size(this%left_operand%val, dim=2) ]) call this%get_partial_left_val(upstream_grad%val, output%val) end function get_partial_conv2d_input !------------------------------------------------------------------------------- function get_partial_conv2d_kernel(this, upstream_grad) result(output) !! Get partial derivative wrt kernel for 2D convolution implicit none class(array_type), intent(inout) :: this type(array_type), intent(in) :: upstream_grad type(array_type) :: output call output%allocate(array_shape = [ this%right_operand%shape, 1 ]) call this%get_partial_right_val(upstream_grad%val, output%val) end function get_partial_conv2d_kernel !------------------------------------------------------------------------------- pure subroutine get_partial_conv2d_input_val(this, upstream_grad, output) !! Get partial derivative wrt input for 2D convolution (subroutine version) implicit none class(array_type), intent(in) :: this real(real32), dimension(:,:), intent(in) :: upstream_grad real(real32), dimension(:,:), intent(out) :: output ! Local variables integer :: i, j, ki, kj, c_in, c_out, s integer :: i_in, j_in, k_idx, out_idx, in_idx integer :: in_base_idx, k_base_idx, kernel_channel_size integer :: input_h, input_w, kernel_h, kernel_w integer :: output_h, output_w, num_channels, num_filters integer, dimension(2) :: stride, dilation integer :: channel_size_in, channel_size_out real(real32) :: grad_val, kernel_val ! Unpack parameters num_channels = this%indices(1) num_filters = this%indices(2) stride = this%adj_ja(1:2,1) dilation = this%adj_ja(1:2,2) kernel_h = this%adj_ja(1,3) kernel_w = this%adj_ja(2,3) input_h = this%left_operand%shape(1) input_w = this%left_operand%shape(2) output_h = this%shape(1) output_w = this%shape(2) channel_size_in = input_h * input_w channel_size_out = output_h * output_w kernel_channel_size = kernel_h * kernel_w output = 0._real32 ! Parallelised over batch, output channels, output spatial dims do concurrent(s = 1:size(upstream_grad, dim=2), c_out = 1:num_filters, & j = 1:output_w, i = 1:output_h) out_idx = i + (j-1)*output_h + (c_out-1)*channel_size_out grad_val = upstream_grad(out_idx, s) if(abs(grad_val) .gt. 1.e-30_real32)then do c_in = 1, num_channels in_base_idx = (c_in - 1) * channel_size_in k_base_idx = (c_in - 1) * kernel_channel_size + & (c_out - 1) * kernel_channel_size * num_channels do kj = 1, kernel_w j_in = (j - 1) * stride(2) + (kj - 1) * dilation(2) + 1 if(j_in .ge. 1 .and. j_in .le. input_w)then do ki = 1, kernel_h i_in = (i - 1) * stride(1) + (ki - 1) * dilation(1) + 1 if(i_in .ge. 1 .and. i_in .le. input_h)then in_idx = i_in + (j_in - 1) * input_h + in_base_idx k_idx = (kernel_h - ki + 1) + & (kernel_w - kj) * kernel_h + k_base_idx kernel_val = this%right_operand%val(k_idx, 1) output(in_idx, s) = output(in_idx, s) + & grad_val * kernel_val end if end do end if end do end do end if end do end subroutine get_partial_conv2d_input_val !------------------------------------------------------------------------------- pure subroutine get_partial_conv2d_kernel_val(this, upstream_grad, output) !! Get partial derivative wrt kernel for 2D convolution (subroutine version) implicit none class(array_type), intent(in) :: this real(real32), dimension(:,:), intent(in) :: upstream_grad real(real32), dimension(:,:), intent(out) :: output ! Local variables integer :: i, j, ki, kj, c_in, c_out, s integer :: i_in, j_in, k_idx, out_idx, in_idx integer :: in_base_idx, out_base_idx, k_base_idx, kernel_channel_size integer :: input_h, input_w, kernel_h, kernel_w integer :: output_h, output_w, num_channels, num_filters integer, dimension(2) :: stride, dilation integer :: channel_size_in, channel_size_out real(real32) :: grad_sum ! Unpack parameters num_channels = this%indices(1) num_filters = this%indices(2) stride = this%adj_ja(1:2,1) dilation = this%adj_ja(1:2,2) kernel_h = this%adj_ja(1,3) kernel_w = this%adj_ja(2,3) input_h = this%left_operand%shape(1) input_w = this%left_operand%shape(2) output_h = this%shape(1) output_w = this%shape(2) channel_size_in = input_h * input_w channel_size_out = output_h * output_w kernel_channel_size = kernel_h * kernel_w output = 0._real32 ! Parallelised over filters, channels, and kernel dimensions do concurrent(c_out = 1:num_filters, c_in = 1:num_channels, & kj = 1:kernel_w, ki = 1:kernel_h) out_base_idx = (c_out - 1) * channel_size_out in_base_idx = (c_in - 1) * channel_size_in k_base_idx = (c_in - 1) * kernel_channel_size + & (c_out - 1) * kernel_channel_size * num_channels k_idx = ki + (kj - 1) * kernel_h + k_base_idx grad_sum = 0._real32 do s = 1, size(upstream_grad, dim=2) do j = 1, output_w j_in = (j - 1) * stride(2) + (kj - 1) * dilation(2) + 1 if(j_in .ge. 1 .and. j_in .le. input_w)then do i = 1, output_h i_in = (i - 1) * stride(1) + (ki - 1) * dilation(1) + 1 if(i_in .ge. 1 .and. i_in .le. input_h)then in_idx = i_in + (j_in - 1) * input_h + in_base_idx out_idx = i + (j - 1) * output_h + out_base_idx grad_sum = grad_sum + & upstream_grad(out_idx, s) * this%left_operand%val(in_idx, s) end if end do end if end do end do output(k_idx, 1) = grad_sum end do end subroutine get_partial_conv2d_kernel_val !############################################################################### !############################################################################### module function conv3d(input, kernel, stride, dilation) result(output) !! 3D convolution operation implicit none ! Arguments type(array_type), intent(in), target :: input type(array_type), intent(in), target :: kernel integer, dimension(3), intent(in) :: stride integer, dimension(3), intent(in) :: dilation type(array_type), pointer :: output ! Local variables integer :: i, j, k, ki, kj, kk, c_in, c_out, s integer :: i_in, j_in, k_in, k_idx, out_idx, in_idx integer :: input_h, input_w, input_d, kernel_h, kernel_w, kernel_d integer :: output_h, output_w, output_d integer :: num_channels, num_filters integer :: channel_size_in, channel_size_out real(real32) :: conv_sum integer, dimension(5) :: output_shape ! Extract dimensions ! input: [H_in, W_in, D_in, C_in, B] ! kernel: [K_h, K_w, K_d, C_in, C_out] input_h = input%shape(1) input_w = input%shape(2) input_d = input%shape(3) num_channels = input%shape(4) kernel_h = kernel%shape(1) kernel_w = kernel%shape(2) kernel_d = kernel%shape(3) num_filters = kernel%shape(5) ! Calculate output dimensions output_h = (input_h - dilation(1)*(kernel_h - 1) - 1) / & stride(1) + 1 output_w = (input_w - dilation(2)*(kernel_w - 1) - 1) / & stride(2) + 1 output_d = (input_d - dilation(3)*(kernel_d - 1) - 1) / & stride(3) + 1 output_shape = [output_h, output_w, output_d, num_filters, & size(input%val, dim=2)] output => input%create_result(array_shape = output_shape) output%val = 0._real32 channel_size_in = input_h * input_w * input_d channel_size_out = output_h * output_w * output_d ! Perform convolution - optimised with do concurrent do concurrent(s = 1:output_shape(5), c_out = 1:num_filters, & k = 1:output_d, j = 1:output_w, i = 1:output_h) conv_sum = 0._real32 do c_in = 1, num_channels do kk = 1, kernel_d k_in = ( k - 1 ) * stride(3) + ( kk - 1 ) * dilation(3) + 1 if(k_in .ge. 1 .and. k_in .le. input_d)then do kj = 1, kernel_w j_in = ( j - 1 ) * stride(2) + (kj - 1) * dilation(2) + 1 if(j_in .ge. 1 .and. j_in .le. input_w)then do ki = 1, kernel_h i_in = ( i - 1 ) * stride(1) + & ( ki - 1 ) * dilation(1) + 1 if(i_in .ge. 1 .and. i_in .le. input_h)then in_idx = i_in + ( j_in - 1 ) * input_h + & ( k_in - 1 ) * input_h * input_w + & ( c_in - 1 ) * channel_size_in k_idx = ki + ( kj - 1 ) * kernel_h + & ( kk - 1 ) * kernel_h * kernel_w + & ( c_in - 1 ) * kernel_h * kernel_w * & kernel_d + & ( c_out - 1 ) * kernel_h * kernel_w * & kernel_d * num_channels conv_sum = conv_sum + input%val(in_idx, s) * & kernel%val(k_idx, 1) end if end do end if end do end if end do end do out_idx = i + ( j - 1 ) * output_h + & ( k - 1 ) * output_h * output_w + & ( c_out - 1 ) * channel_size_out output%val(out_idx, s) = conv_sum end do ! Store parameters for backward pass allocate(output%indices(2)) output%indices(1) = num_channels output%indices(2) = num_filters allocate(output%adj_ja(3,3)) output%adj_ja(1:3,1) = stride output%adj_ja(1:3,2) = dilation output%adj_ja(1,3) = kernel_h output%adj_ja(2,3) = kernel_w output%adj_ja(3,3) = kernel_d output%get_partial_left => get_partial_conv3d_input output%get_partial_right => get_partial_conv3d_kernel output%get_partial_left_val => get_partial_conv3d_input_val output%get_partial_right_val => get_partial_conv3d_kernel_val if(input%requires_grad .or. kernel%requires_grad)then output%requires_grad = .true. output%is_forward = input%is_forward output%operation = 'conv3d' output%left_operand => input output%right_operand => kernel end if end function conv3d !------------------------------------------------------------------------------- function get_partial_conv3d_input(this, upstream_grad) result(output) !! Get partial derivative wrt input for 3D convolution implicit none class(array_type), intent(inout) :: this type(array_type), intent(in) :: upstream_grad type(array_type) :: output call output%allocate(array_shape = [ this%left_operand%shape, & size(this%left_operand%val, dim=2) ]) call this%get_partial_left_val(upstream_grad%val, output%val) end function get_partial_conv3d_input !------------------------------------------------------------------------------- function get_partial_conv3d_kernel(this, upstream_grad) result(output) !! Get partial derivative wrt kernel for 3D convolution implicit none class(array_type), intent(inout) :: this type(array_type), intent(in) :: upstream_grad type(array_type) :: output call output%allocate(array_shape = [ this%right_operand%shape, 1 ]) call this%get_partial_right_val(upstream_grad%val, output%val) end function get_partial_conv3d_kernel !------------------------------------------------------------------------------- pure subroutine get_partial_conv3d_input_val(this, upstream_grad, output) !! Get partial derivative wrt input for 3D convolution (subroutine version) implicit none class(array_type), intent(in) :: this real(real32), dimension(:,:), intent(in) :: upstream_grad real(real32), dimension(:,:), intent(out) :: output ! Local variables integer :: i, j, k, ki, kj, kk, c_in, c_out, s integer :: i_in, j_in, k_in, k_idx, out_idx, in_idx integer :: input_h, input_w, input_d, kernel_h, kernel_w, kernel_d integer :: output_h, output_w, output_d integer :: num_channels, num_filters integer :: channel_size_in, channel_size_out integer, dimension(3) :: stride, dilation real(real32) :: grad_val, kernel_val ! Unpack parameters num_channels = this%indices(1) num_filters = this%indices(2) stride = this%adj_ja(1:3,1) dilation = this%adj_ja(1:3,2) kernel_h = this%adj_ja(1,3) kernel_w = this%adj_ja(2,3) kernel_d = this%adj_ja(3,3) input_h = this%left_operand%shape(1) input_w = this%left_operand%shape(2) input_d = this%left_operand%shape(3) output_h = this%shape(1) output_w = this%shape(2) output_d = this%shape(3) output = 0._real32 channel_size_in = input_h * input_w * input_d channel_size_out = output_h * output_w * output_d do s = 1, size(upstream_grad, dim=2) do c_in = 1, num_channels do k = 1, output_d do j = 1, output_w do i = 1, output_h do c_out = 1, num_filters out_idx = i + ( j - 1 ) * output_h + & ( k - 1 ) * output_h * output_w + & ( c_out - 1 ) * channel_size_out grad_val = upstream_grad(out_idx, s) do kk = 1, kernel_d k_in = ( k - 1 ) * stride(3) + & ( kk - 1 ) * dilation(3) + 1 if( k_in .ge. 1 .and. k_in .le. input_d )then do kj = 1, kernel_w j_in = ( j - 1 ) * stride(2) + & ( kj - 1 ) * dilation(2) + 1 if( j_in .ge. 1 .and. j_in .le. input_w )then do ki = 1, kernel_h i_in = ( i - 1 ) * stride(1) + & ( ki - 1 ) * dilation(1) + 1 if( i_in .ge. 1 .and. & i_in .le. input_h )then in_idx = i_in + & ( j_in - 1 ) * input_h + & ( k_in - 1 ) * input_h * input_w + & ( c_in - 1 ) * channel_size_in k_idx = ki + ( kj - 1 ) * kernel_h + & ( kk - 1 ) * kernel_h * kernel_w + & ( c_in - 1 ) * kernel_h * & kernel_w * kernel_d + & ( c_out - 1 ) * kernel_h * & kernel_w * kernel_d * num_channels kernel_val = this%right_operand%val(k_idx, 1) output(in_idx, s) = & output(in_idx, s) + & grad_val * kernel_val end if end do end if end do end if end do end do end do end do end do end do end do end subroutine get_partial_conv3d_input_val !------------------------------------------------------------------------------- pure subroutine get_partial_conv3d_kernel_val(this, upstream_grad, output) !! Get partial derivative wrt kernel for 3D convolution (subroutine version) implicit none class(array_type), intent(in) :: this real(real32), dimension(:,:), intent(in) :: upstream_grad real(real32), dimension(:,:), intent(out) :: output ! Local variables integer :: i, j, k, ki, kj, kk, c_in, c_out, s integer :: i_in, j_in, k_in, k_idx, out_idx, in_idx integer :: input_h, input_w, input_d, kernel_h, kernel_w, kernel_d integer :: output_h, output_w, output_d integer :: num_channels, num_filters integer :: channel_size_in, channel_size_out integer, dimension(3) :: stride, dilation real(real32) :: grad_sum ! Unpack parameters num_channels = this%indices(1) num_filters = this%indices(2) stride = this%adj_ja(1:3,1) dilation = this%adj_ja(1:3,2) kernel_h = this%adj_ja(1,3) kernel_w = this%adj_ja(2,3) kernel_d = this%adj_ja(3,3) input_h = this%left_operand%shape(1) input_w = this%left_operand%shape(2) input_d = this%left_operand%shape(3) output_h = this%shape(1) output_w = this%shape(2) output_d = this%shape(3) output = 0._real32 channel_size_in = input_h * input_w * input_d channel_size_out = output_h * output_w * output_d do c_out = 1, num_filters do c_in = 1, num_channels do kk = 1, kernel_d do kj = 1, kernel_w do ki = 1, kernel_h k_idx = ki + (kj-1)*kernel_h + & (kk-1)*kernel_h*kernel_w + & (c_in-1)*kernel_h*kernel_w*kernel_d + & (c_out-1)*kernel_h*kernel_w*kernel_d*num_channels grad_sum = 0._real32 do s = 1, size(upstream_grad, dim=2) do k = 1, output_d k_in = (k-1)*stride(3) + (kk-1)*dilation(3) + 1 if(k_in .ge. 1 .and. k_in .le. input_d)then do j = 1, output_w j_in = (j-1)*stride(2) + & (kj-1)*dilation(2) + 1 if(j_in .ge. 1 .and. j_in .le. input_w)then do i = 1, output_h i_in = (i-1)*stride(1) + & (ki-1)*dilation(1) + 1 if(i_in .ge. 1 .and. i_in .le. input_h)then in_idx = i_in + (j_in-1)*input_h + & (k_in-1)*input_h*input_w + & (c_in-1)*channel_size_in out_idx = i + (j-1)*output_h + & (k-1)*output_h*output_w + & (c_out-1)*channel_size_out grad_sum = grad_sum + & upstream_grad(out_idx, s) * & this%left_operand%val(in_idx, s) end if end do end if end do end if end do end do output(k_idx, 1) = grad_sum end do end do end do end do end do end subroutine get_partial_conv3d_kernel_val !############################################################################### end submodule athena__diffstruc_extd_submodule_conv