conv3d Module Function

module function conv3d(input, kernel, stride, dilation) result(output)

3D convolution operation

Arguments

Type | Intent | Optional | Attributes | Name
type(array_type), intent(in), target :: input
type(array_type), intent(in), target :: kernel
integer, intent(in), dimension(3) :: stride
integer, intent(in), dimension(3) :: dilation

Return Value type(array_type), pointer


Source Code

  module function conv3d(input, kernel, stride, dilation) result(output)
    !! 3D convolution operation.
    !!
    !! Each output element is the sum, over every input channel and kernel
    !! tap, of an input value multiplied by the matching kernel weight.
    !! The output-size formula contains no padding term, so any padding is
    !! expected to have been applied to `input` beforehand.
    !!
    !! `input%val(:, s)` is read as a column-major flattening of
    !! [H_in, W_in, D_in, C_in] for sample `s`; `kernel%val(:, 1)` is read as
    !! a column-major flattening of [K_h, K_w, K_d, C_in, C_out].
    !!
    !! Execution stops with an error message if any stride is smaller than 1
    !! or if the kernel's channel count differs from the input's.
    implicit none

    ! Arguments
    type(array_type), intent(in), target :: input
    !! Input data; `shape(1:4)` gives [H_in, W_in, D_in, C_in] and the second
    !! dimension of `val` indexes the samples
    type(array_type), intent(in), target :: kernel
    !! Kernel weights with shape [K_h, K_w, K_d, C_in, C_out]
    integer, dimension(3), intent(in) :: stride
    !! Step between consecutive output positions along H, W and D
    integer, dimension(3), intent(in) :: dilation
    !! Spacing between consecutive kernel taps along H, W and D
    type(array_type), pointer :: output
    !! Result obtained from `input%create_result`, with shape
    !! [H_out, W_out, D_out, C_out, number of samples]

    ! Local variables
    integer :: i, j, k, ki, kj, kk, c_in, c_out, s
    integer :: i_in, j_in, k_in, in_base, k_base, out_idx
    integer :: input_h, input_w, input_d, kernel_h, kernel_w, kernel_d
    integer :: output_h, output_w, output_d
    integer :: num_channels, num_filters
    integer :: channel_size_in, channel_size_out
    integer :: input_plane, output_plane
    integer :: kernel_plane, kernel_volume, filter_size
    real(real32) :: conv_sum
    integer, dimension(5) :: output_shape

    ! Extract dimensions
    ! input: [H_in, W_in, D_in, C_in, B]
    ! kernel: [K_h, K_w, K_d, C_in, C_out]
    input_h = input%shape(1)
    input_w = input%shape(2)
    input_d = input%shape(3)
    num_channels = input%shape(4)
    kernel_h = kernel%shape(1)
    kernel_w = kernel%shape(2)
    kernel_d = kernel%shape(3)
    num_filters = kernel%shape(5)

    ! Validate arguments before they are used as divisors or index strides
    if(any(stride .lt. 1))then
       ! stride is an integer divisor below; zero would divide by zero
       error stop "conv3d: stride must be >= 1 in every dimension"
    end if
    if(kernel%shape(4) .ne. num_channels)then
       ! the kernel offsets below assume C_in matches the input, otherwise
       ! the wrong (possibly out-of-bounds) weights would be read
       error stop "conv3d: kernel and input channel counts differ"
    end if

    ! Calculate output dimensions (integer division truncates)
    output_h = (input_h - dilation(1)*(kernel_h - 1) - 1) / &
         stride(1) + 1
    output_w = (input_w - dilation(2)*(kernel_w - 1) - 1) / &
         stride(2) + 1
    output_d = (input_d - dilation(3)*(kernel_d - 1) - 1) / &
         stride(3) + 1
    output_shape = [output_h, output_w, output_d, num_filters, &
         size(input%val, dim=2)]

    output => input%create_result(array_shape = output_shape)
    output%val = 0._real32

    ! Loop-invariant strides of the flattened (column-major) layouts
    input_plane = input_h * input_w
    channel_size_in = input_plane * input_d
    output_plane = output_h * output_w
    channel_size_out = output_plane * output_d
    kernel_plane = kernel_h * kernel_w
    kernel_volume = kernel_plane * kernel_d
    filter_size = kernel_volume * num_channels

    ! Perform convolution - optimised with do concurrent
    ! Each iteration writes exactly one element of output%val; the scalar
    ! temporaries are assigned before they are read within an iteration.
    do concurrent(s = 1:output_shape(5), c_out = 1:num_filters, &
         k = 1:output_d, j = 1:output_w, i = 1:output_h)

       conv_sum = 0._real32
       do c_in = 1, num_channels
          do kk = 1, kernel_d
             k_in = ( k - 1 ) * stride(3) + ( kk - 1 ) * dilation(3) + 1
             ! Taps that fall outside the input are skipped
             if(k_in .ge. 1 .and. k_in .le. input_d)then
                do kj = 1, kernel_w
                   j_in = ( j - 1 ) * stride(2) + &
                        ( kj - 1 ) * dilation(2) + 1
                   if(j_in .ge. 1 .and. j_in .le. input_w)then
                      ! Offsets shared by every tap along the H direction
                      in_base = ( j_in - 1 ) * input_h + &
                           ( k_in - 1 ) * input_plane + &
                           ( c_in - 1 ) * channel_size_in
                      k_base = ( kj - 1 ) * kernel_h + &
                           ( kk - 1 ) * kernel_plane + &
                           ( c_in - 1 ) * kernel_volume + &
                           ( c_out - 1 ) * filter_size
                      do ki = 1, kernel_h
                         i_in = ( i - 1 ) * stride(1) + &
                              ( ki - 1 ) * dilation(1) + 1
                         if(i_in .ge. 1 .and. i_in .le. input_h)then
                            conv_sum = conv_sum + &
                                 input%val(in_base + i_in, s) * &
                                 kernel%val(k_base + ki, 1)
                         end if
                      end do
                   end if
                end do
             end if
          end do
       end do
       out_idx = i + ( j - 1 ) * output_h + &
            ( k - 1 ) * output_plane + &
            ( c_out - 1 ) * channel_size_out
       output%val(out_idx, s) = conv_sum
    end do

    ! Store parameters for backward pass
    ! indices: (1) input channels, (2) filters
    allocate(output%indices(2))
    output%indices(1) = num_channels
    output%indices(2) = num_filters
    ! adj_ja columns: (1) stride, (2) dilation, (3) kernel extents
    allocate(output%adj_ja(3,3))
    output%adj_ja(1:3,1) = stride
    output%adj_ja(1:3,2) = dilation
    output%adj_ja(1,3) = kernel_h
    output%adj_ja(2,3) = kernel_w
    output%adj_ja(3,3) = kernel_d

    output%get_partial_left => get_partial_conv3d_input
    output%get_partial_right => get_partial_conv3d_kernel
    output%get_partial_left_val => get_partial_conv3d_input_val
    output%get_partial_right_val => get_partial_conv3d_kernel_val
    ! Operand links are recorded only when a gradient is requested
    if(input%requires_grad .or. kernel%requires_grad)then
       output%requires_grad = .true.
       output%is_forward = input%is_forward
       output%operation = 'conv3d'
       output%left_operand => input
       output%right_operand => kernel
    end if

  end function conv3d