get_partial_merge_scalar_over_channels_val Subroutine

pure subroutine get_partial_merge_scalar_over_channels_val(this, upstream_grad, output)

Arguments

Type | Intent | Optional | Attributes | Name
class(array_type), intent(in) :: this
real(kind=real32), intent(in), dimension(:,:) :: upstream_grad
real(kind=real32), intent(out), dimension(:,:) :: output

Source Code

  pure subroutine get_partial_merge_scalar_over_channels_val( &
       this, upstream_grad, output &
  )
    !! Backward pass of the "merge scalar over channels" operation with
    !! respect to its value operand.
    !!
    !! The trailing entry of `this%left_operand%shape` is treated as the
    !! channel count. The product of the remaining entries is the number of
    !! elements per channel. For each column `col` of `upstream_grad` and
    !! each channel `chan`, the row block of `output` belonging to that
    !! channel is filled as follows:
    !!   - `upstream_grad(idx, col)` where `this%mask(idx, 1)` is true;
    !!   - `0` where it is false.
    !!
    !! Only column 1 of `this%mask` is consulted. The same rows
    !! `1:elems_per_chan` of `upstream_grad` are reused for every channel.
    !!
    !! NOTE(review): `upstream_grad` is indexed without a channel offset.
    !! This is correct only if the upstream gradient is per element rather
    !! than per element x channel. Confirm against the forward operation.
    !! Presumably `output` is sized (elems_per_chan*n_chan, size(upstream_grad,2));
    !! entries outside that range are not assigned here.
    implicit none
    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:,:), intent(out) :: output

    integer :: idx, chan, col
    integer :: rank_lhs, elems_per_chan, n_chan

    ! The last dimension of the left operand indexes channels; everything
    ! before it is flattened into a per-channel element count.
    rank_lhs = size(this%left_operand%shape)
    n_chan = this%left_operand%shape(rank_lhs)
    elems_per_chan = product(this%left_operand%shape(1:rank_lhs - 1))

    ! Every (idx, chan, col) triple writes a distinct output element, so
    ! all three indices can share a single concurrent loop.
    do concurrent( &
         col = 1:size(upstream_grad, 2), &
         chan = 1:n_chan, &
         idx = 1:elems_per_chan &
    )
       if (.not. this%mask(idx, 1)) then
          output(idx + (chan - 1) * elems_per_chan, col) = 0._real32
       else
          output(idx + (chan - 1) * elems_per_chan, col) = upstream_grad(idx, col)
       end if
    end do

  end subroutine get_partial_merge_scalar_over_channels_val