module function batchnorm_inference( &
     input, params, mean, variance, epsilon &
) result( output )
  !! Normalise `input` channel by channel using the supplied statistics.
  !!
  !! The channel count is the last entry of `input%shape`. Every element x
  !! belonging to channel c is mapped to
  !!
  !!   params%val(c,1) * (x - mean(c)) / sqrt(variance(c) + epsilon)
  !!        + params%val(c + n_channels, 1)
  !!
  !! `epsilon`, `mean` and `variance` are also copied onto the result.
  implicit none

  ! Arguments
  class(array_type), intent(in), target :: input
  !! Values to normalise. The first dimension of `val` is a flattened index
  !! in which the channel varies slowest; the second dimension is looped
  !! over in full (presumably the sample/batch index - confirm with caller).
  class(array_type), intent(in), target :: params
  !! Column 1 of `val` holds the per-channel multiplier in rows
  !! 1..n_channels and the per-channel offset in the following n_channels
  !! rows (presumably the learned scale and shift of the layer).
  real(real32), dimension(:), intent(in) :: mean
  !! Per-channel value subtracted from the input.
  real(real32), dimension(:), intent(in) :: variance
  !! Per-channel variance; epsilon is added before the square root.
  real(real32), intent(in) :: epsilon
  !! Constant added to the variance inside the square root.

  type(batchnorm_array_type), pointer :: output
  !! Newly allocated result holding the normalised values.

  ! Local variables
  integer :: chan
  !! Channel index
  integer :: rank_in, n_chan, chan_len, n_col
  !! Rank of the input shape, number of channels, number of elements per
  !! channel, and extent of the second dimension of input%val

  ! Create the result object and size it like the input
  allocate(output)
  if (output%allocated) call output%deallocate()
  n_col = size(input%val, 2)
  call output%allocate(array_shape = [ input%shape, n_col ])

  ! Keep the statistics used for this evaluation on the result
  output%epsilon  = epsilon
  output%mean     = mean
  output%variance = variance

  ! Channels are the last shape entry; everything before it is flattened
  rank_in  = size(input%shape)
  n_chan   = input%shape(rank_in)
  chan_len = product(input%shape(1:rank_in - 1))

  ! Rows (chan-1)*chan_len+1 .. chan*chan_len of val belong to channel
  ! `chan`; transform that whole slab in one array expression.
  do concurrent (chan = 1:n_chan)
     output%val((chan - 1) * chan_len + 1 : chan * chan_len, 1:n_col) = &
          params%val(chan, 1) * ( &
          input%val((chan - 1) * chan_len + 1 : chan * chan_len, 1:n_col) - &
          mean(chan) ) / sqrt(variance(chan) + output%epsilon) + &
          params%val(chan + n_chan, 1)
  end do

end function batchnorm_inference