get_partial_avgpool2d_val Subroutine

pure subroutine get_partial_avgpool2d_val(this, upstream_grad, output)

Optimised backward pass for 2D average pooling

Arguments

Type Intent Optional Attributes Name
class(array_type), intent(in) :: this
real(kind=real32), intent(in), dimension(:,:) :: upstream_grad
real(kind=real32), intent(out), dimension(:,:) :: output

Source Code

  pure subroutine get_partial_avgpool2d_val(this, upstream_grad, output)
    !! Optimised backward pass for 2D average pooling
    !!
    !! Every element of the upstream gradient is scaled by
    !! 1 / (pool_size(1) * pool_size(2)) and accumulated into each input
    !! element covered by the pooling window that produced it.
    !!
    !! Both gradient arrays are addressed with a flattened first index
    !! (column-major over the first three dimensions) and one column per
    !! entry along dim 2 of this%val.
    implicit none

    ! Arguments
    class(array_type), intent(in) :: this
    !! Pooling result; supplies the output shape (this%shape), the input
    !! shape (this%left_operand%shape), and the pool size / stride, which
    !! are read from columns 1 and 2 of this%adj_ja
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Gradient with respect to the pooled output (flattened index, column)
    real(real32), dimension(:,:), intent(out) :: output
    !! Gradient with respect to the pooling input (flattened index, column)

    ! Local variables
    integer :: i, j, m, s
    integer :: i_step, j_step
    integer :: base_idx, in_idx, out_idx, input_h
    integer :: out_h, out_w, num_channels
    integer :: channel_size_in, channel_size_out
    real(real32) :: pool_norm, grad_val
    integer, dimension(4) :: input_shape
    integer, dimension(2) :: pool_size, stride

    ! Unpack parameters
    input_shape = [ this%left_operand%shape, size(this%val, dim=2) ]
    pool_size = this%adj_ja(:,1)
    stride    = this%adj_ja(:,2)
    input_h = input_shape(1)
    out_h = this%shape(1)
    out_w = this%shape(2)
    num_channels = this%shape(3)
    channel_size_in = input_h * input_shape(2)
    channel_size_out = out_h * out_w

    output = 0._real32

    ! Each window element receives an equal share of the upstream gradient
    pool_norm = 1.0_real32 / real(pool_size(1) * pool_size(2), real32)

    ! Only the column index s is iterated concurrently: every iteration
    ! writes exclusively to output(:, s), so iterations are independent.
    ! Window positions (i, j) must be visited sequentially because
    ! neighbouring windows overlap whenever stride < pool_size, in which
    ! case several positions accumulate into the same output element.
    do concurrent(s = 1:input_shape(4))
       do m = 1, num_channels
          do j = 1, out_w
             do i = 1, out_h

                ! Zero-based flattened offset of the window's first element
                base_idx = (i-1) * stride(1) + &
                     ((j-1) * stride(2)) * input_h + &
                     (m-1) * channel_size_in
                ! One-based flattened index of the pooled element
                out_idx = i + (j-1) * out_h + (m-1) * channel_size_out
                grad_val = upstream_grad(out_idx, s) * pool_norm

                ! Distribute gradient over pooling window
                do j_step = 0, pool_size(2) - 1
                   do i_step = 0, pool_size(1) - 1
                      in_idx = base_idx + i_step + j_step * input_h + 1
                      output(in_idx, s) = output(in_idx, s) + grad_val
                   end do
                end do

             end do
          end do
       end do
    end do

  end subroutine get_partial_avgpool2d_val