batchnorm_inference Module Function

module function batchnorm_inference(input, params, mean, variance, epsilon) result(output)

Arguments

Type Intent Optional Attributes Name
class(array_type), intent(in), target :: input
class(array_type), intent(in), target :: params
real(kind=real32), intent(in), dimension(:) :: mean
real(kind=real32), intent(in), dimension(:) :: variance
real(kind=real32), intent(in) :: epsilon

Return Value type(batchnorm_array_type), pointer


Source Code

  module function batchnorm_inference( &
       input, params, mean, variance, epsilon &
  ) result( output )
    !! Normalise `input` channel by channel using caller-supplied statistics.
    !!
    !! Each element of `input%val` has the mean of its channel subtracted,
    !! is multiplied by that channel's scale entry from `params`, divided by
    !! `sqrt(variance + epsilon)` and finally offset by the channel's shift
    !! entry from `params`. The channel index is the last entry of
    !! `input%shape`; all leading dimensions are treated as one flattened
    !! block per channel.
    implicit none
    class(array_type), intent(in), target :: input
    !! Values to normalise. The first dimension of `input%val` is indexed as
    !! the flattened `input%shape`; the second dimension is iterated in full
    !! (presumably the sample/batch index -- confirm against `array_type`).
    class(array_type), intent(in), target :: params
    !! Per-channel coefficients read from `params%val(:,1)`: entries
    !! `1:n` are the multiplicative scales and entries `n+1:2n` the additive
    !! shifts, where `n` is the number of channels.
    real(real32), dimension(:), intent(in) :: mean
    !! Mean subtracted from each channel (indexed by channel).
    real(real32), dimension(:), intent(in) :: variance
    !! Variance of each channel (indexed by channel).
    real(real32), intent(in) :: epsilon
    !! Constant added to the variance before the square root is taken.
    type(batchnorm_array_type), pointer :: output
    !! Newly allocated result holding the normalised values together with
    !! copies of `epsilon`, `mean` and `variance`.

    integer :: pos, chan, samp
    integer :: rank_in, n_channel, n_sample, chan_stride

    ! Create the result object and give it the input's shape with the extent
    ! of the second dimension of input%val appended.
    allocate(output)
    if(output%allocated) call output%deallocate()
    call output%allocate(array_shape = [ input%shape, size(input%val,2) ])

    ! Keep the statistics used for this evaluation on the result.
    output%epsilon = epsilon
    output%mean = mean
    output%variance = variance

    ! Extents used by the flattened indexing below.
    rank_in = size(input%shape)
    n_channel = input%shape(rank_in)
    n_sample = size(input%val,2)
    ! Number of elements belonging to one channel (1 for rank-1 input, as
    ! the product of an empty section is 1).
    chan_stride = product(input%shape(1:rank_in - 1))

    ! Every (channel, sample, position) triple is independent of the others,
    ! so a single concurrent construct covers the whole index space.
    do concurrent( &
         chan = 1:n_channel, samp = 1:n_sample, pos = 1:chan_stride )
       output%val(pos + (chan-1) * chan_stride, samp) = &
            params%val(chan,1) * &
            ( input%val(pos + (chan-1) * chan_stride, samp) - mean(chan) ) / &
            sqrt(variance(chan) + output%epsilon) + &
            params%val(chan + n_channel,1)
    end do

  end function batchnorm_inference