conv2d Module Function

module function conv2d(input, kernel, stride, dilation) result(output)

2D convolution operation

Arguments

Type Intent Optional Attributes Name
type(array_type), intent(in), target :: input
type(array_type), intent(in), target :: kernel
integer, intent(in), dimension(2) :: stride
integer, intent(in), dimension(2) :: dilation

Return Value type(array_type), pointer


Source Code

  module function conv2d(input, kernel, stride, dilation) result(output)
    !! 2D convolution operation (no padding is applied)
    !!
    !! Every filter of `kernel` is slid over every sample of `input` and the
    !! weighted sums are written to a newly created result array.
    !!
    !! Data layout, as used by the indexing in this procedure:
    !! - `input%shape` is read as [H_in, W_in, C_in, ...]; `input%val(:, s)`
    !!   holds sample `s` flattened with the height index varying fastest,
    !!   then width, then channel.
    !! - `kernel%shape` is read as [K_h, K_w, C_in, C_out]; the weights are
    !!   taken from `kernel%val(:, 1)`, flattened in the same order.
    !! - The result is created with shape [H_out, W_out, C_out, B], where
    !!   B = size(input%val, dim=2) and, using integer division,
    !!   H_out = (H_in - dilation(1)*(K_h - 1) - 1) / stride(1) + 1
    !!   W_out = (W_in - dilation(2)*(K_w - 1) - 1) / stride(2) + 1
    !!
    !! The procedure stops with an error message if any stride or dilation
    !! value is smaller than 1, or if the channel count of the kernel
    !! (`kernel%shape(3)`) differs from that of the input (`input%shape(3)`).
    !!
    !! NOTE(review): if the dilated kernel is larger than the input, the
    !! numerator in the formulas above is negative. Fortran integer division
    !! truncates toward zero, so the resulting extent can be 1 (a partial
    !! window) or non-positive depending on the stride. The bounds tests in
    !! the loop skip taps that fall outside the input in that situation.
    !! Confirm whether this is the intended handling.
    implicit none

    ! Arguments
    type(array_type), intent(in), target :: input
    !! Input data, shape [H_in, W_in, C_in, B]
    type(array_type), intent(in), target :: kernel
    !! Convolution weights, shape [K_h, K_w, C_in, C_out]
    integer, dimension(2), intent(in) :: stride
    !! Step of the window along height (1) and width (2); must be >= 1
    integer, dimension(2), intent(in) :: dilation
    !! Spacing between kernel taps along height (1) and width (2); must be >= 1
    type(array_type), pointer :: output
    !! Newly created result array

    ! Local variables
    integer :: i, j, ki, kj, c_in, c_out, s
    integer :: i_in, j_in, k_idx, out_idx, in_idx, in_base_idx, k_base_idx
    integer :: input_h, input_w, kernel_h, kernel_w
    integer :: output_h, output_w, num_channels, num_filters
    integer :: channel_size_in, channel_size_out, kernel_channel_size
    integer :: dil_kernel_h_m1, dil_kernel_w_m1
    real(real32) :: conv_sum
    integer, dimension(4) :: output_shape

    ! Reject parameters that would make the output-size division invalid
    ! (stride = 0) or meaningless (non-positive stride or dilation)
    if(any(stride .lt. 1))then
       error stop 'conv2d: stride values must be positive'
    end if
    if(any(dilation .lt. 1))then
       error stop 'conv2d: dilation values must be positive'
    end if

    ! Extract dimensions
    ! input: [H_in, W_in, C_in, B]
    ! kernel: [K_h, K_w, C_in, C_out]
    input_h = input%shape(1)
    input_w = input%shape(2)
    num_channels = input%shape(3)
    kernel_h = kernel%shape(1)
    kernel_w = kernel%shape(2)
    num_filters = kernel%shape(4)

    ! The kernel offsets below use the input channel count as the stride
    ! between filters, so the two channel counts have to agree
    if(kernel%shape(3) .ne. num_channels)then
       error stop 'conv2d: kernel and input channel counts differ'
    end if

    ! Pre-compute common values
    channel_size_in = input_h * input_w
    kernel_channel_size = kernel_h * kernel_w
    ! Extent covered by the dilated kernel, minus one, along each axis
    dil_kernel_h_m1 = dilation(1) * (kernel_h - 1)
    dil_kernel_w_m1 = dilation(2) * (kernel_w - 1)

    ! Calculate output dimensions
    output_h = (input_h - dil_kernel_h_m1 - 1) / stride(1) + 1
    output_w = (input_w - dil_kernel_w_m1 - 1) / stride(2) + 1
    ! The number of samples is taken from the second extent of input%val
    output_shape = [output_h, output_w, num_filters, &
         size(input%val, dim=2)]

    output => input%create_result(array_shape = output_shape)
    output%val = 0._real32

    channel_size_out = output_h * output_w

    ! Perform convolution
    ! Each (sample, filter, output column, output row) combination writes
    ! its own element of output%val
    do concurrent(s = 1:output_shape(4), c_out = 1:num_filters, &
         j = 1:output_w, i = 1:output_h)
       conv_sum = 0._real32
       do c_in = 1, num_channels
          ! Offset of channel c_in within one flattened input sample
          in_base_idx = (c_in - 1) * channel_size_in
          ! Offset of the (c_in, c_out) kernel slice within kernel%val(:, 1)
          k_base_idx = (c_in - 1) * kernel_channel_size + &
               (c_out - 1) * kernel_channel_size * num_channels
          do kj = 1, kernel_w
             ! Input column addressed by kernel column kj
             j_in = (j - 1) * stride(2) + (kj - 1) * dilation(2) + 1
             if(j_in .ge. 1 .and. j_in .le. input_w)then
                do ki = 1, kernel_h
                   ! Input row addressed by kernel row ki
                   i_in = (i - 1) * stride(1) + (ki - 1) * dilation(1) + 1
                   if(i_in .ge. 1 .and. i_in .le. input_h)then
                      in_idx = i_in + (j_in - 1) * input_h + in_base_idx
                      k_idx = ki + (kj - 1) * kernel_h + k_base_idx
                      conv_sum = conv_sum + input%val(in_idx, s) * &
                           kernel%val(k_idx, 1)
                   end if
                end do
             end if
          end do
       end do
       out_idx = i + (j - 1) * output_h + (c_out - 1) * channel_size_out
       output%val(out_idx, s) = conv_sum
    end do

    ! Store parameters for backward pass
    ! indices: (1) number of input channels, (2) number of filters
    allocate(output%indices(2))
    output%indices(1) = num_channels
    output%indices(2) = num_filters
    ! adj_ja columns: (1) stride, (2) dilation, (3) kernel height and width
    allocate(output%adj_ja(2,3))
    output%adj_ja(1:2,1) = stride
    output%adj_ja(1:2,2) = dilation
    output%adj_ja(1,3) = kernel_h
    output%adj_ja(2,3) = kernel_w


    ! Attach the partial-derivative procedures; left refers to the input
    ! operand and right to the kernel operand
    output%get_partial_left => get_partial_conv2d_input
    output%get_partial_right => get_partial_conv2d_kernel
    output%get_partial_left_val => get_partial_conv2d_input_val
    output%get_partial_right_val => get_partial_conv2d_kernel_val
    ! Record the operation and its operands only when a gradient is needed
    if(input%requires_grad .or. kernel%requires_grad)then
       output%requires_grad = .true.
       output%is_forward = input%is_forward
       output%operation = 'conv2d'
       output%left_operand => input
       output%right_operand => kernel
    end if

  end function conv2d