get_partial_avgpool3d_val Subroutine

pure subroutine get_partial_avgpool3d_val(this, upstream_grad, output)

Optimised backward pass for 3D average pooling

Arguments

Type Intent Optional Attributes Name
class(array_type), intent(in) :: this
real(kind=real32), intent(in), dimension(:,:) :: upstream_grad
real(kind=real32), intent(out), dimension(:,:) :: output

Source Code

  pure subroutine get_partial_avgpool3d_val(this, upstream_grad, output)
    !! Optimised backward pass for 3D average pooling.
    !!
    !! Each upstream gradient value is scaled by 1/product(pool_size) and
    !! added to every input element of the pooling window it came from.
    !! Contributions from overlapping windows (stride < pool_size) are
    !! summed into the same output element.
    !!
    !! Flattened layouts, as implied by the index arithmetic below:
    !!  - upstream_grad(out_idx, s): out_idx walks this%shape(1:4) in
    !!    column-major order; s is the column (sample) index.
    !!  - output(in_idx, s): in_idx walks this%left_operand%shape in
    !!    column-major order; s is the column (sample) index.
    !! Pool sizes are read from this%adj_ja(:,1), strides from this%adj_ja(:,2).
    implicit none

    ! Arguments
    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:,:), intent(out) :: output

    ! Local variables
    integer :: i, j, k, m, s
    integer :: i_step, j_step, k_step
    integer :: base_idx, in_idx, out_idx, input_h, input_hw
    integer :: channel_size_in, channel_size_out
    real(real32) :: pool_norm, grad_val
    integer, dimension(5) :: input_shape
    integer, dimension(3) :: pool_size, stride

    ! Unpack parameters
    input_shape = [ this%left_operand%shape, size(this%val, dim=2) ]
    pool_size = this%adj_ja(:,1)
    stride    = this%adj_ja(:,2)
    input_h = input_shape(1)
    input_hw = input_h * input_shape(2)
    channel_size_in = input_hw * input_shape(3)
    channel_size_out = this%shape(1) * this%shape(2) * this%shape(3)

    output = 0._real32

    pool_norm = 1.0_real32 / real(product(pool_size), real32)

    ! Only (sample, channel) pairs run concurrently. Each pair writes a
    ! disjoint region of output: s selects the column and
    ! (m-1)*channel_size_in selects the channel slice. This assumes every
    ! pooling window lies inside the input volume, which the index
    ! arithmetic already requires.
    ! The spatial loops stay sequential because overlapping windows
    ! (stride < pool_size) accumulate into the same output element, and
    ! DO CONCURRENT iterations must not define an element another
    ! iteration references.
    do concurrent (s = 1:input_shape(5), m = 1:this%shape(4))
       do k = 1, this%shape(3)
          do j = 1, this%shape(2)
             do i = 1, this%shape(1)
                ! Zero-based offset of the window's first element in output
                base_idx = (i-1)*stride(1) + ((j-1)*stride(2)) * input_h + &
                     ((k-1)*stride(3)) * input_hw + (m-1) * channel_size_in
                ! One-based flattened position of this pooled value
                out_idx = i + (j-1) * this%shape(1) + &
                     (k-1) * this%shape(1)*this%shape(2) + &
                     (m-1) * channel_size_out
                grad_val = upstream_grad(out_idx, s) * pool_norm

                ! Distribute the averaged gradient over the pooling window
                do k_step = 0, pool_size(3)-1
                   do j_step = 0, pool_size(2)-1
                      do i_step = 0, pool_size(1)-1
                         in_idx = base_idx + i_step + j_step * input_h + &
                              k_step * input_hw + 1
                         output(in_idx, s) = output(in_idx, s) + grad_val
                      end do
                   end do
                end do
             end do
          end do
       end do
    end do

  end subroutine get_partial_avgpool3d_val