pure subroutine get_partial_maxpool2d_val(this, upstream_grad, output)
  !! Gradient of 2D max pooling with respect to its input.
  !!
  !! Each element of `upstream_grad` is routed to the input element that
  !! held the maximum of the corresponding pooling window. When several
  !! elements tie, the first one visited (column-major within the window)
  !! wins. Input elements that are never a window maximum receive zero.
  !! When windows overlap (stride < pool size), one input element may be
  !! the maximum of several windows and accumulates all of their gradients.
  implicit none
  ! Arguments
  class(array_type), intent(in) :: this
  !! Pooling node. Reads `left_operand%val` / `left_operand%shape` (input),
  !! `shape(1:3)` (pooled extents), `adj_ja(:,1)` (pool size) and
  !! `adj_ja(:,2)` (stride).
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  !! Gradient w.r.t. the pooled output, indexed (flattened output, sample).
  real(real32), dimension(:,:), intent(out) :: output
  !! Gradient w.r.t. the pooled input, indexed (flattened input, sample).

  ! Local variables
  integer :: m, s
  integer :: input_h, num_samples
  integer :: channel_size_in, channel_size_out
  integer, dimension(2) :: pool_size, stride

  ! Unpack parameters
  pool_size = this%adj_ja(:,1)
  stride = this%adj_ja(:,2)
  input_h = this%left_operand%shape(1)
  channel_size_in = input_h * this%left_operand%shape(2)
  channel_size_out = this%shape(1) * this%shape(2)
  num_samples = size(this%val, dim=2)

  output = 0._real32

  ! Only samples (columns) and channels (disjoint slices of a column) are
  ! iterated concurrently. The spatial loops stay sequential because, with
  ! overlapping windows, two output positions can select the same input
  ! element. Accumulating into that element from different `do concurrent`
  ! iterations would be non-conforming and a data race when parallelised.
  ! NOTE(review): assumes every window lies inside its own channel slice,
  ! i.e. (this%shape(k)-1)*stride(k) + pool_size(k) <= input extent k.
  do concurrent(s = 1:num_samples, m = 1:this%shape(3))
     block
       ! Declared here so each concurrent iteration has its own copies
       integer :: i, j, i_step, j_step
       integer :: base_idx, in_idx, out_idx, max_idx
       real(real32) :: pool_max, val_tmp

       do j = 1, this%shape(2)
          do i = 1, this%shape(1)
             ! Zero-based flat offset of the window's first input element
             base_idx = (i-1) * stride(1) + (j-1) * stride(2) * input_h + &
                  (m-1) * channel_size_in
             out_idx = i + (j-1) * this%shape(1) + (m-1) * channel_size_out

             ! Locate the window maximum, seeded with its first element.
             ! Strict ">" keeps the first occurrence on ties.
             max_idx = base_idx + 1
             pool_max = this%left_operand%val(max_idx, s)
             do j_step = 0, pool_size(2) - 1
                do i_step = 0, pool_size(1) - 1
                   in_idx = base_idx + i_step + j_step * input_h + 1
                   val_tmp = this%left_operand%val(in_idx, s)
                   if (val_tmp .gt. pool_max) then
                      pool_max = val_tmp
                      max_idx = in_idx
                   end if
                end do
             end do

             ! Route the upstream gradient to the max location (accumulate,
             ! since overlapping windows may share a maximum)
             output(max_idx, s) = output(max_idx, s) + &
                  upstream_grad(out_idx, s)
          end do
       end do
     end block
  end do
end subroutine get_partial_maxpool2d_val