2D max pooling operation
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| type(array_type), | intent(in), | target | :: | input | ||
| integer, | intent(in), | dimension(2) | :: | pool_size | ||
| integer, | intent(in), | dimension(2) | :: | stride | ||
module function maxpool2d(input, pool_size, stride) result(output)
  !! 2D max pooling operation
  !!
  !! Slides a window of `pool_size(1)` x `pool_size(2)` elements over the
  !! first two dimensions of `input`, moving by `stride(1)` / `stride(2)`
  !! elements between windows, and stores the largest value of each window
  !! in `output`. The third dimension and the second dimension of
  !! `input%val` are carried through unchanged.
  !!
  !! Neither `pool_size` nor `stride` is validated here: `stride` is used as
  !! a divisor, so both of its elements must be non-zero, and the window is
  !! assumed to fit inside the input.
  implicit none

  ! Arguments
  type(array_type), intent(in), target :: input
  !! Input array. `input%val(:, s)` is read as a flattened column-major
  !! (shape(1), shape(2), shape(3)) block
  !! (presumably height, width, channels for sample `s` - TODO confirm)
  integer, dimension(2), intent(in) :: pool_size
  !! Extent of the pooling window along the first and second dimensions
  integer, dimension(2), intent(in) :: stride
  !! Step between consecutive windows along the first and second dimensions
  type(array_type), pointer :: output
  !! Pooled result, obtained from `input%create_result`

  ! Local variables
  integer :: i, j, m, s
  !! Loop indices over output dim 1, dim 2, dim 3 and columns of `val`
  integer :: input_h
  !! Extent of the first input dimension
  integer :: channel_size_in, channel_size_out
  !! Number of elements in one (dim 1, dim 2) slice of the input / output
  integer, dimension(4) :: output_shape
  !! Shape of the pooled result

  ! Integer division drops any partial window at the trailing edge
  output_shape = [ &
       (input%shape(1) - pool_size(1)) / stride(1) + 1, &
       (input%shape(2) - pool_size(2)) / stride(2) + 1, &
       input%shape(3), &
       size(input%val, dim=2)]
  output => input%create_result(array_shape = output_shape)

  ! Pre-compute the flattening factors once, outside the loop
  input_h = input%shape(1)
  channel_size_in = input_h * input%shape(2)
  channel_size_out = output_shape(1) * output_shape(2)

  do concurrent( &
       s = 1:output_shape(4), &
       m = 1:output_shape(3), &
       j = 1:output_shape(2), &
       i = 1:output_shape(1))
    ! Per-iteration scratch variables are declared in a block so that every
    ! iteration owns its own copies rather than sharing procedure-scope
    ! scalars across (potentially concurrent) iterations.
    block
      integer :: i_step, j_step
      integer :: base_idx, stride_idx, idx
      real(real32) :: pool_max

      ! Zero-based flat offset of the window's first element in input%val(:, s)
      base_idx = (i - 1) * stride(1) + ((j - 1) * stride(2)) * input_h + &
           (m - 1) * channel_size_in
      ! One-based flat index of this output element in output%val(:, s)
      idx = i + (j - 1) * output_shape(1) + (m - 1) * channel_size_out

      ! Seed the maximum with the first element of the window
      pool_max = input%val(base_idx + 1, s)

      ! Scan the whole window. Re-visiting the seed element is harmless
      ! because a value is never strictly greater than itself, so no
      ! special-case branch is needed in the inner loop.
      do j_step = 0, pool_size(2) - 1
        do i_step = 0, pool_size(1) - 1
          stride_idx = base_idx + i_step + j_step * input_h + 1
          if(input%val(stride_idx, s) .gt. pool_max) &
               pool_max = input%val(stride_idx, s)
        end do
      end do

      output%val(idx, s) = pool_max
    end block
  end do

  ! Record the pooling parameters on the result: column 1 holds pool_size,
  ! column 2 holds stride (presumably read back by the partial-derivative
  ! procedures attached below - confirm against their implementation)
  allocate(output%adj_ja(2,2))
  output%adj_ja(:,1) = pool_size
  output%adj_ja(:,2) = stride

  output%get_partial_left => get_partial_maxpool2d
  output%get_partial_left_val => get_partial_maxpool2d_val

  ! Link the result to its operand only when the input takes part in
  ! gradient tracking
  if(input%requires_grad)then
    output%requires_grad = .true.
    output%is_forward = input%is_forward
    output%operation = 'maxpool'
    output%left_operand => input
  end if

end function maxpool2d