batchnorm Module Function

module function batchnorm(input, params, momentum, mean, variance, epsilon) result(output)

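Batch normalisation operation: normalises each channel of `input` using the mean and variance computed over the current batch, then applies the per-channel scale and shift stored in `params`, and returns the updated running mean and variance alongside the result.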

Arguments

Type | Intent | Optional | Attributes | Name
class(array_type), intent(in), target :: input
class(array_type), intent(in), target :: params
real(kind=real32), intent(in) :: momentum
real(kind=real32), intent(in), dimension(:) :: mean
real(kind=real32), intent(in), dimension(:) :: variance
real(kind=real32), intent(in) :: epsilon

Return value: `type(batchnorm_array_type), pointer`


Source Code

  module function batchnorm( &
       input, params, momentum, mean, variance, epsilon &
  ) result( output )
    !! Batch normalisation operation.
    !!
    !! For every channel c (the last entry of input%shape) the mean and the
    !! biased (divide-by-N) variance are taken over all rows of input%val
    !! that belong to that channel and over every column of input%val.
    !! Each element x of the channel is then mapped to
    !!
    !!   y = params%val(c,1) * (x - mu) / sqrt(var + epsilon)
    !!       + params%val(c + num_channels, 1)
    !!
    !! i.e. rows 1:num_channels of params hold the scale and rows
    !! num_channels+1:2*num_channels hold the shift (column 1 only).
    !!
    !! Running statistics stored on the result are updated as
    !!   running = momentum * previous + (1 - momentum) * batch
    !! when momentum > 1e-8; otherwise the batch statistics replace them.
    !!
    !! If either operand requires gradients, the operands are recorded on
    !! the result for use by the get_partial_* procedures bound below.
    implicit none

    ! Arguments
    class(array_type), intent(in), target :: input
    !! Input values. Rows of input%val are grouped channel by channel
    !! (channel is the slowest-varying entry of input%shape).
    !! NOTE(review): the second dimension of input%val is presumably the
    !! batch/sample axis -- confirm against array_type.
    class(array_type), intent(in), target :: params
    !! Per-channel scale (rows 1:C) followed by per-channel shift
    !! (rows C+1:2C), read from column 1; C = number of channels.
    real(real32), intent(in) :: momentum
    !! Weight given to the previous running statistics in the update.
    real(real32), dimension(:), intent(in) :: mean
    !! Previous running mean, one entry per channel.
    real(real32), dimension(:), intent(in) :: variance
    !! Previous running variance, one entry per channel.
    real(real32), intent(in) :: epsilon
    !! Constant added to the variance before taking the square root.
    type(batchnorm_array_type), pointer :: output
    !! Newly allocated result holding the normalised values and the
    !! updated running statistics.

    ! Local variables
    integer :: i, c, s
    integer :: num_elements, num_dims, num_channels, num_samples
    real(real32) :: norm

    allocate(output)
    if(output%allocated) call output%deallocate()
    call output%allocate(array_shape = [ input%shape, size(input%val,2) ])
    output%epsilon = epsilon
    output%mean = mean
    output%variance = variance

    num_dims = size(input%shape)
    num_channels = input%shape(num_dims)
    num_samples = size(input%val,2)
    ! Number of rows of input%val belonging to a single channel.
    num_elements = product(input%shape(1:num_dims - 1))
    norm = real(num_elements * num_samples, real32)

    do concurrent(c = 1:num_channels)
       ! Declaring the per-channel scalars inside a block gives every
       ! iteration its own copy, rather than relying on the compiler to
       ! privatise procedure-scope variables shared across iterations.
       block
         real(real32) :: mu, var, std
         integer :: first, last

         first = (c - 1) * num_elements + 1
         last = c * num_elements

         ! Biased (population) statistics over this channel's rows and
         ! all columns.
         mu = sum(input%val(first:last,:)) / norm
         var = sum( (input%val(first:last,:) - mu) ** 2 ) / norm

         if(momentum .gt. 1.E-8_real32)then
            output%mean(c) = &
                 momentum * mean(c) + (1._real32 - momentum) * mu
            output%variance(c) = &
                 momentum * variance(c) + (1._real32 - momentum) * var
         else
            output%mean(c) = mu
            output%variance(c) = var
         end if

         ! Invariant for the element loop; computed once per channel.
         std = sqrt(var + output%epsilon)
         do concurrent(s = 1:num_samples, i = 1:num_elements)
            output%val(i + (c-1) * num_elements, s) = &
                 params%val(c,1) * &
                 ( input%val(i + (c-1) * num_elements, s) - mu ) / &
                 std + params%val(c + num_channels, 1)
         end do
       end block
    end do

    ! The partial-derivative procedures are always bound; the operands are
    ! only recorded when a gradient is required.
    output%get_partial_left => get_partial_batchnorm_left
    output%get_partial_left_val => get_partial_batchnorm_left_val
    output%get_partial_right => get_partial_batchnorm_right
    output%get_partial_right_val => get_partial_batchnorm_right_val
    if(input%requires_grad .or. params%requires_grad)then
       output%requires_grad = .true.
       output%is_forward = input%is_forward .or. params%is_forward
       output%operation = 'batchnorm'
       output%left_operand => input
       output%right_operand => params
    end if

  end function batchnorm