maxpool3d Module Function

module function maxpool3d(input, pool_size, stride) result(output)

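3D max pooling operation: takes the maximum over each `pool_size` window, moved by `stride`, across the three spatial dimensions of `input`.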

Arguments

Type Intent Optional Attributes Name
type(array_type), intent(in), target :: input
integer, intent(in), dimension(3) :: pool_size
integer, intent(in), dimension(3) :: stride

Return Value: type(array_type), pointer — the pooled result, with shape [H_out, W_out, D_out, C, B].


Source Code

  module function maxpool3d(input, pool_size, stride) result(output)
    !! 3D max pooling operation.
    !!
    !! Slides a pool_size(1) x pool_size(2) x pool_size(3) window over the
    !! first three (spatial) dimensions of `input`, moving it by `stride`,
    !! and stores the maximum of each window. Each channel (dimension 4 of
    !! `input%shape`) and each sample (column of `input%val`) is pooled
    !! independently.
    !!
    !! Only windows lying entirely inside the input are produced. Along any
    !! spatial dimension where the window does not fit (or pool_size < 1),
    !! the output extent is zero. `stride` must be positive.
    implicit none

    ! Arguments
    type(array_type), intent(in), target :: input
    !! Input; input%val(:, s) holds sample s flattened column-major with
    !! extents input%shape(1:4) = [H, W, D, C]
    integer, dimension(3), intent(in) :: pool_size
    !! Window extent along each spatial dimension
    integer, dimension(3), intent(in) :: stride
    !! Window step along each spatial dimension
    type(array_type), pointer :: output
    !! Pooled result with shape [H_out, W_out, D_out, C, B]

    ! Local variables
    integer :: i, j, k, m, s
    integer :: i_step, j_step, k_step
    integer :: stride_idx, idx, offset
    integer :: row_in, plane_in
    integer :: channel_size_in, channel_size_out
    real(real32) :: pool_max
    integer, dimension(5) :: output_shape

    ! output_shape = [H_out, W_out, D_out, C, B]
    ! Integer division truncates toward zero, so (extent - pool)/stride + 1
    ! is still >= 1 when the window overhangs the input by less than one
    ! stride (e.g. extent 2, pool 3, stride 2 -> 1). Such a window would
    ! read past the input, so those dimensions get zero output extent.
    output_shape(1:3) = merge( &
         (input%shape(1:3) - pool_size) / stride + 1, &
         0, &
         input%shape(1:3) .ge. pool_size .and. pool_size .ge. 1 )
    output_shape(4) = input%shape(4)
    output_shape(5) = size(input%val, dim=2)

    output => input%create_result(array_shape = output_shape)

    ! Loop-invariant strides of the flattened input/output layouts
    row_in = input%shape(1)
    plane_in = input%shape(1) * input%shape(2)
    channel_size_in = plane_in * input%shape(3)
    channel_size_out = output_shape(1) * output_shape(2) * output_shape(3)

    do concurrent( &
         s = 1:output_shape(5), &
         m = 1:output_shape(4), &
         k = 1:output_shape(3), &
         j = 1:output_shape(2), &
         i = 1:output_shape(1))

       ! Flat input index of the window's first (corner) element
       stride_idx = (i-1) * stride(1) + &
            (j-1) * stride(2) * row_in + &
            (k-1) * stride(3) * plane_in + &
            (m-1) * channel_size_in + 1
       ! Flat index of this output element
       idx = i + (j-1) * output_shape(1) + &
            (k-1) * output_shape(1) * output_shape(2) + &
            (m-1) * channel_size_out

       ! Seed with the corner element (no sentinel needed). Re-visiting the
       ! corner below is harmless: x .gt. x is always false.
       pool_max = input%val(stride_idx, s)
       do k_step = 0, pool_size(3)-1
          do j_step = 0, pool_size(2)-1
             do i_step = 0, pool_size(1)-1
                offset = i_step + j_step * row_in + k_step * plane_in
                if(input%val(stride_idx + offset, s) .gt. pool_max) &
                     pool_max = input%val(stride_idx + offset, s)
             end do
          end do
       end do

       output%val(idx, s) = pool_max
    end do

    ! Store the pooling hyper-parameters on the result
    ! (presumably read back by get_partial_maxpool3d — confirm)
    allocate(output%adj_ja(3,2))
    output%adj_ja(:,1) = pool_size
    output%adj_ja(:,2) = stride

    output%get_partial_left => get_partial_maxpool3d
    output%get_partial_left_val => get_partial_maxpool3d_val
    if(input%requires_grad)then
       output%requires_grad = .true.
       output%is_forward = input%is_forward
       output%operation = 'maxpool3d'
       output%left_operand => input
    end if

  end function maxpool3d