Optimised backward pass for 3D max pooling
| Type | Intent | Optional | Attributes | Name |
|---|---|---|---|---|
| class(array_type) | intent(in) | | | this |
| real(kind=real32) | intent(in) | | dimension(:,:) | upstream_grad |
| real(kind=real32) | intent(out) | | dimension(:,:) | output |
pure subroutine get_partial_maxpool3d_val(this, upstream_grad, output)
  !! Optimised backward pass for 3D max pooling.
  !!
  !! For every pooling window the upstream gradient of the pooled element is
  !! added to the input position that holds the window maximum. On ties the
  !! first element met in scan order (i fastest, then j, then k) is chosen,
  !! because the comparison is strict. All other entries of `output` are zero.
  implicit none

  ! Arguments
  class(array_type), intent(in) :: this
  !! Pooled array. Reads `this%shape` (pooled extents), `this%adj_ja`
  !! (column 1: pool size, column 2: stride), `this%val` (only for its
  !! second extent) and `this%left_operand` (the pooling input).
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  !! Gradient with respect to the pooled values, indexed as
  !! (flattened pooled element, sample).
  real(real32), dimension(:,:), intent(out) :: output
  !! Gradient with respect to the pooling input, indexed as
  !! (flattened input element, sample).

  ! Local variables
  integer :: i, j, k, m, s
  integer :: i_step, j_step, k_step
  integer :: base_idx, in_idx, out_idx, max_idx
  integer :: input_h, input_hw
  integer :: channel_size_in, channel_size_out
  real(real32) :: pool_max, val_tmp
  integer, dimension(5) :: input_shape
  integer, dimension(3) :: pool_size, stride

  ! Unpack parameters
  ! NOTE(review): assumes this%left_operand%shape has four entries
  ! (three spatial extents + channels) so the constructor yields 5 values.
  input_shape = [ this%left_operand%shape, size(this%val, dim=2) ]
  pool_size = this%adj_ja(:,1)
  stride = this%adj_ja(:,2)

  ! Strides used to flatten (i, j, k, m) into the first array index
  input_h = input_shape(1)
  input_hw = input_h * input_shape(2)
  channel_size_in = input_hw * input_shape(3)
  channel_size_out = this%shape(1) * this%shape(2) * this%shape(3)

  output = 0._real32

  ! Only the sample (s) and channel (m) indices are iterated concurrently:
  ! each (s, m) pair writes to its own column and its own channel_size_in
  ! slab of `output` (assuming every window lies inside its channel), so
  ! those iterations are independent. Windows within one channel may
  ! overlap when stride < pool_size and can then select the same maximum
  ! location, so the accumulation over i, j, k has to run sequentially to
  ! satisfy the DO CONCURRENT independence requirement.
  do concurrent(s = 1:input_shape(5), m = 1:this%shape(4))
     do k = 1, this%shape(3)
        do j = 1, this%shape(2)
           do i = 1, this%shape(1)

              ! Zero-based flat offset of the window's first input element
              base_idx = (i-1)*stride(1) + ((j-1)*stride(2)) * input_h + &
                   ((k-1)*stride(3)) * input_hw + (m-1) * channel_size_in

              ! One-based flat index of the pooled element
              out_idx = i + (j-1) * this%shape(1) + &
                   (k-1) * this%shape(1)*this%shape(2) + &
                   (m-1) * channel_size_out

              ! Seed the search with the first element of the window
              max_idx = base_idx + 1
              pool_max = this%left_operand%val(max_idx, s)

              ! Scan the whole window. Revisiting the first element is
              ! harmless: the strict comparison can never replace the seed
              ! with itself, so no special-case skip is required.
              do k_step = 0, pool_size(3)-1
                 do j_step = 0, pool_size(2)-1
                    do i_step = 0, pool_size(1)-1
                       in_idx = base_idx + i_step + j_step * input_h + &
                            k_step * input_hw + 1
                       val_tmp = this%left_operand%val(in_idx, s)
                       if(val_tmp .gt. pool_max)then
                          pool_max = val_tmp
                          max_idx = in_idx
                       end if
                    end do
                 end do
              end do

              ! Route the gradient to the location of the maximum
              output(max_idx, s) = output(max_idx, s) + &
                   upstream_grad(out_idx, s)

           end do
        end do
     end do
  end do

end subroutine get_partial_maxpool3d_val