maxpool2d Module Function

module function maxpool2d(input, pool_size, stride) result(output)

2D max pooling operation

Arguments

Type | Intent | Optional | Attributes | Name
type(array_type), intent(in), target :: input
integer, intent(in), dimension(2) :: pool_size
integer, intent(in), dimension(2) :: stride

Return Value: type(array_type), pointer


Source Code

  module function maxpool2d(input, pool_size, stride) result(output)
    !! 2D max pooling operation
    !!
    !! Slides a pool_size(1) x pool_size(2) window over the first two
    !! dimensions of `input`, moving by stride(1) and stride(2), and stores
    !! the largest value of each window in the result. The output extent is
    !! (extent - pool_size) / stride + 1 along each pooled dimension, so only
    !! windows lying completely inside the input are evaluated (no padding).
    !!
    !! `input%val` is addressed as val(flat_index, sample), where flat_index
    !! runs over input%shape(1) fastest, then input%shape(2), then
    !! input%shape(3); the result is written with the same layout.
    !!
    !! The returned array records pool_size and stride in `adj_ja`, and has
    !! its partial-derivative procedure pointers set. When `input` requires a
    !! gradient, the result is additionally linked to `input` as its left
    !! operand with operation 'maxpool'.
    !!
    !! NOTE(review): pool_size and stride are not validated here; a zero
    !! stride divides by zero and a window larger than the input yields a
    !! non-positive output extent -- confirm callers guarantee valid values.
    implicit none

    ! Arguments
    type(array_type), intent(in), target :: input
    !! Array to be pooled
    integer, dimension(2), intent(in) :: pool_size
    !! Window extent along the first and second dimension
    integer, dimension(2), intent(in) :: stride
    !! Window step along the first and second dimension
    type(array_type), pointer :: output
    !! Pooled array, created through input%create_result

    ! Local variables
    integer :: i, j, m, s, i_step, j_step
    integer :: base_idx, row_idx, stride_idx, idx, input_h
    real(real32) :: pool_max
    integer :: channel_size_in, channel_size_out
    integer, dimension(4) :: output_shape

    ! Pooled extents, unchanged third dimension, and sample count
    output_shape = [ &
         (input%shape(1) - pool_size(1)) / stride(1) + 1, &
         (input%shape(2) - pool_size(2)) / stride(2) + 1, &
         input%shape(3), &
         size(input%val, dim=2)]
    output => input%create_result(array_shape = output_shape)

    ! Pre-compute the flat-index strides once, outside the loop
    input_h = input%shape(1)
    channel_size_in = input_h * input%shape(2)
    channel_size_out = output_shape(1) * output_shape(2)

    do concurrent(&
         s = 1:output_shape(4), &
         m = 1:output_shape(3), &
         j = 1:output_shape(2), &
         i = 1:output_shape(1))

       ! Zero-based flat offset of the window's first element in the input,
       ! and one-based flat index of the target element in the output
       base_idx = (i-1)*stride(1) + ((j-1)*stride(2)) * input_h + &
            (m-1) * channel_size_in
       idx = i + (j - 1) * output_shape(1) + (m - 1) * channel_size_out

       ! Seed the running maximum with the window's first element
       pool_max = input%val(base_idx + 1, s)

       ! Scan the whole window. The seed element is visited again, which is
       ! harmless (a value is never greater than itself) and avoids testing
       ! for it on every iteration of the innermost loop.
       do j_step = 0, pool_size(2)-1
          ! One-based flat index of the first element in this window column
          row_idx = base_idx + j_step * input_h + 1
          do i_step = 0, pool_size(1)-1
             stride_idx = row_idx + i_step
             if(input%val(stride_idx, s) .gt. pool_max) &
                  pool_max = input%val(stride_idx, s)
          end do
       end do

       output%val(idx, s) = pool_max
    end do

    ! Keep the pooling geometry with the result: column 1 holds pool_size,
    ! column 2 holds stride (presumably read back by the partial-derivative
    ! procedures assigned below -- confirm against their implementation)
    allocate(output%adj_ja(2,2))
    output%adj_ja(:,1) = pool_size
    output%adj_ja(:,2) = stride

    output%get_partial_left => get_partial_maxpool2d
    output%get_partial_left_val => get_partial_maxpool2d_val
    if(input%requires_grad)then
       output%requires_grad = .true.
       output%is_forward = input%is_forward
       output%operation = 'maxpool'
       output%left_operand => input
    end if

  end function maxpool2d