get_partial_maxpool2d_val Subroutine

pure subroutine get_partial_maxpool2d_val(this, upstream_grad, output)

Arguments

Type | Intent | Optional | Attributes | Name
class(array_type), intent(in) :: this
real(kind=real32), intent(in), dimension(:,:) :: upstream_grad
real(kind=real32), intent(out), dimension(:,:) :: output

Source Code

  pure subroutine get_partial_maxpool2d_val(this, upstream_grad, output)
    !! Scatter the upstream gradient of a 2D max-pooling result back onto
    !! the positions of the pooled operand that held each window maximum.
    !!
    !! For every pooling window the largest entry of
    !! `this%left_operand%val` is located and the matching element of
    !! `upstream_grad` is added to `output` at that entry's index.
    !! Ties are resolved in favour of the first entry met when scanning the
    !! window with the first spatial index varying fastest (strict `.gt.`).
    !!
    !! Index layout, as used by the arithmetic below: the first dimension
    !! of each 2D array is a flattened (dim 1, dim 2, channel) index with
    !! dim 1 varying fastest; the second dimension is the sample index.
    implicit none

    ! Arguments
    class(array_type), intent(in) :: this
    !! Pooling result; `this%left_operand` is the pooled operand,
    !! `this%adj_ja(:,1)` the window size and `this%adj_ja(:,2)` the stride.
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Gradient w.r.t. the pooled output, indexed like `this%shape`.
    real(real32), dimension(:,:), intent(out) :: output
    !! Gradient w.r.t. the pooled operand; zero except at window maxima.

    ! Local variables
    integer :: i, j, m, s
    integer :: i_step, j_step
    integer :: base_idx, in_idx, out_idx, max_idx, input_h
    real(real32) :: pool_max, val_tmp
    integer :: channel_size_in, channel_size_out
    integer, dimension(4) :: input_shape
    integer, dimension(2) :: pool_size, stride

    ! Unpack parameters
    input_shape = [ this%left_operand%shape, size(this%val, dim=2) ]
    pool_size = this%adj_ja(:,1)
    stride    = this%adj_ja(:,2)
    input_h = input_shape(1)
    channel_size_in = input_h * input_shape(2)
    channel_size_out = this%shape(1) * this%shape(2)

    output = 0._real32

    ! Only samples and channels are iterated concurrently: each (m, s)
    ! pair accumulates into its own slice of `output`. Window positions
    ! (i, j) must be visited sequentially, because overlapping windows
    ! (stride < pool_size) can share a maximum and would otherwise update
    ! the same `output` element from different concurrent iterations.
    do concurrent(s = 1:input_shape(4), m = 1:this%shape(3))
       do j = 1, this%shape(2)
          do i = 1, this%shape(1)

             ! Zero-based offset of the window's first element and
             ! one-based index of the matching pooled-output element
             base_idx = (i-1) * stride(1) + &
                  ((j-1) * stride(2)) * input_h + &
                  (m-1) * channel_size_in
             out_idx = i + (j-1) * this%shape(1) + &
                  (m-1) * channel_size_out

             ! Initialise the running maximum with the first element
             max_idx = base_idx + 1
             pool_max = this%left_operand%val(max_idx, s)

             ! Scan the whole window; the first element is revisited but
             ! can never be strictly greater than itself, so no special
             ! case is needed for it
             do j_step = 0, pool_size(2) - 1
                do i_step = 0, pool_size(1) - 1
                   in_idx = base_idx + i_step + j_step * input_h + 1
                   val_tmp = this%left_operand%val(in_idx, s)

                   if(val_tmp .gt. pool_max)then
                      pool_max = val_tmp
                      max_idx = in_idx
                   end if
                end do
             end do

             ! Accumulate gradient at the max location
             output(max_idx, s) = output(max_idx, s) + &
                  upstream_grad(out_idx, s)
          end do
       end do
    end do

  end subroutine get_partial_maxpool2d_val