get_partial_maxpool1d_val Subroutine

pure subroutine get_partial_maxpool1d_val(this, upstream_grad, output)

Optimised backward pass for 1D max pooling

Arguments

Type Intent Optional Attributes Name
class(array_type), intent(in) :: this
real(kind=real32), intent(in), dimension(:,:) :: upstream_grad
real(kind=real32), intent(out), dimension(:,:) :: output

Source Code

  pure subroutine get_partial_maxpool1d_val(this, upstream_grad, output)
    !! Backward pass for 1D max pooling.
    !!
    !! Each upstream gradient value is routed to the input element that held
    !! the maximum of its pooling window (the first such element on ties);
    !! every other input position receives zero. When windows overlap
    !! (stride < pool size) and several of them select the same element,
    !! their gradients are summed there.
    !!
    !! Index layout, as implied by the arithmetic below (TODO confirm against
    !! the forward pass):
    !!  * this%left_operand%val(:, s) stores channel m in the contiguous block
    !!    (m-1)*input_h + 1 : m*input_h, with input_h = this%left_operand%shape(1)
    !!  * upstream_grad(:, s) stores output position i of channel m at
    !!    i + (m-1)*this%shape(1)
    !!  * this%adj_ja(1,1) is the pool size and this%adj_ja(1,2) the stride
    !!  * every window is assumed to lie inside its channel, i.e.
    !!    (this%shape(1)-1)*stride + pool_size <= input_h
    implicit none

    ! Arguments
    class(array_type), intent(in) :: this
    !! Pooling node; its left operand holds the forward-pass input
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Gradient w.r.t. the pooled output, one column per sample
    real(real32), dimension(:,:), intent(out) :: output
    !! Gradient w.r.t. the pooling input, one column per sample

    ! Local variables
    integer :: s, m, i, p
    integer :: num_samples, num_channels, input_h, out_h
    integer :: pool_len, stride_len
    integer :: in_offset, out_offset, win_start, max_idx
    real(real32) :: pool_max

    input_h      = this%left_operand%shape(1)
    out_h        = this%shape(1)
    num_channels = this%shape(2)
    num_samples  = size(this%val, dim=2)
    pool_len     = this%adj_ja(1,1)
    stride_len   = this%adj_ja(1,2)

    output = 0._real32

    ! Only the (sample, channel) pairs run concurrently: each pair owns a
    ! disjoint slice of output. The window loop stays sequential because
    ! overlapping windows may pick the same input element, and accumulating
    ! into it from different DO CONCURRENT iterations is not permitted (it
    ! becomes a data race once the loop is parallelised).
    do concurrent(s = 1:num_samples, m = 1:num_channels)
       in_offset  = (m - 1) * input_h
       out_offset = (m - 1) * out_h
       do i = 1, out_h
          win_start = in_offset + (i - 1) * stride_len + 1

          ! First maximum in the window (strict > keeps the earliest index)
          max_idx  = win_start
          pool_max = this%left_operand%val(win_start, s)
          do p = win_start + 1, win_start + pool_len - 1
             if (this%left_operand%val(p, s) > pool_max) then
                pool_max = this%left_operand%val(p, s)
                max_idx  = p
             end if
          end do

          output(max_idx, s) = output(max_idx, s) + upstream_grad(out_offset + i, s)
       end do
    end do

  end subroutine get_partial_maxpool1d_val