athena_diffstruc_extd_sub_merge.f90 Source File


Source Code

submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule_merge
  !! Submodule containing implementations for extended diffstruc array operations

contains

!###############################################################################
  module function merge_scalar_over_channels(tsource, fsource, mask) result(output)
    !! 1D average pooling operation
    implicit none

    ! Arguments
    class(array_type), intent(in), target :: tsource
    real(real32), intent(in) :: fsource
    logical, dimension(:,:), intent(in) :: mask
    type(array_type), pointer :: output

    ! Local variables
    integer :: i, m, s
    integer :: num_elements, num_dims


    output => tsource%create_result()
    num_dims = size(tsource%shape)
    num_elements = product(tsource%shape(1:num_dims - 1))
    do concurrent(s = 1:size(tsource%val,2), m = 1: tsource%shape(num_dims))
       do concurrent(i=1:num_elements)
          if(mask(i,1))then
             output%val(i + (m-1) * num_elements,s) = tsource%val(i,s)
          else
             output%val(i + (m-1) * num_elements,s) = fsource
          end if
       end do
    end do
    output%mask = mask

    output%get_partial_left => get_partial_merge_scalar_over_channels
    output%get_partial_left_val => get_partial_merge_scalar_over_channels_val
    if(tsource%requires_grad)then
       output%requires_grad = .true.
       output%is_forward = tsource%is_forward
       output%operation = 'merge_over_channels'
       output%left_operand => tsource
    end if

  end function merge_scalar_over_channels
!-------------------------------------------------------------------------------
  function get_partial_merge_scalar_over_channels(this, upstream_grad) result(output)
    implicit none
    class(array_type), intent(inout) :: this
    type(array_type), intent(in) :: upstream_grad
    type(array_type) :: output

    output = merge_scalar_over_channels(upstream_grad, 0._real32, this%mask)

  end function get_partial_merge_scalar_over_channels
!-------------------------------------------------------------------------------
  pure subroutine get_partial_merge_scalar_over_channels_val( &
       this, upstream_grad, output &
  )
    implicit none
    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:,:), intent(out) :: output

    integer :: i, m, s
    integer :: num_elements, num_dims, num_channels

    num_dims = size(this%left_operand%shape)
    num_elements = product(this%left_operand%shape(1:num_dims - 1))
    num_channels = this%left_operand%shape(num_dims)

    do concurrent(s = 1:size(upstream_grad,2), m = 1: num_channels)
       do concurrent(i=1:num_elements)
          if(this%mask(i,1))then
             output(i + (m-1) * num_elements,s) = upstream_grad(i,s)
          else
             output(i + (m-1) * num_elements,s) = 0._real32
          end if
       end do
    end do

  end subroutine get_partial_merge_scalar_over_channels_val
!###############################################################################

end submodule athena__diffstruc_extd_submodule_merge