Batch normalisation operation
| Type | Intent | Optional | Attributes | | Name |
|---|---|---|---|---|---|
| class(array_type), | intent(in), | | target | :: | input |
| class(array_type), | intent(in), | | target | :: | params |
| real(kind=real32), | intent(in) | | | :: | momentum |
| real(kind=real32), | intent(in), | | dimension(:) | :: | mean |
| real(kind=real32), | intent(in), | | dimension(:) | :: | variance |
| real(kind=real32), | intent(in) | | | :: | epsilon |

**Return value:** `type(batchnorm_array_type), pointer`
module function batchnorm( &
     input, params, momentum, mean, variance, epsilon &
) result( output )
  !! Batch normalisation operation.
  !!
  !! Each channel of `input` is normalised with the mean and (biased)
  !! variance computed over all of that channel's entries in the batch.
  !! The per-channel affine transform stored in `params` is then applied:
  !!
  !!   out = params%val(c,1) * (x - mu_c) / sqrt(var_c + epsilon)
  !!         + params%val(c + C, 1)
  !!
  !! Here C is the last entry of `input%shape` (the channel count).
  !!
  !! The statistics stored on the result are set as follows. When `momentum`
  !! exceeds `min_momentum`, `output%mean`/`output%variance` are the blend
  !! `momentum * old + (1 - momentum) * batch`, using the supplied
  !! `mean`/`variance` as `old`. Otherwise they are the batch statistics.
  !! The normalisation itself always uses the batch statistics.
  !!
  !! If either `input` or `params` requires gradients, the result is linked
  !! to them as operands of a 'batchnorm' operation.
  implicit none

  ! Arguments
  class(array_type), intent(in), target :: input
    !! Data to normalise. The last entry of `input%shape` is the channel
    !! axis. The second dimension of `input%val` is treated as the batch.
  class(array_type), intent(in), target :: params
    !! Column 1 holds the per-channel scale in rows 1:C, followed by the
    !! per-channel shift in rows C+1:2C.
  real(real32), intent(in) :: momentum
    !! Weight given to the previous statistics when blending.
  real(real32), dimension(:), intent(in) :: mean
    !! Previous per-channel means (NOTE(review): assumed size C; not checked).
  real(real32), dimension(:), intent(in) :: variance
    !! Previous per-channel variances (NOTE(review): assumed size C; not checked).
  real(real32), intent(in) :: epsilon
    !! Added to the variance before the square root for numerical stability.

  type(batchnorm_array_type), pointer :: output
    !! Newly allocated result holding the normalised values.

  ! Local variables
  real(real32), parameter :: min_momentum = 1.E-8_real32
    !! Momenta at or below this value disable blending with the old stats.
  integer :: c, i, s
  integer :: num_dims, num_channels, num_elements, num_samples
  real(real32) :: norm

  allocate(output)
  if(output%allocated) call output%deallocate()
  call output%allocate(array_shape = [ input%shape, size(input%val,2) ])
  output%epsilon = epsilon
  output%mean = mean
  output%variance = variance

  num_dims = size(input%shape)
  num_channels = input%shape(num_dims)
  ! Number of entries belonging to one channel of one sample; with the
  ! channel as the last (slowest-varying) axis these are contiguous rows.
  num_elements = product(input%shape(1:num_dims - 1))
  num_samples = size(input%val, 2)
  ! Convert each factor before multiplying: the default-integer product
  ! num_elements * num_samples can overflow for large batches.
  norm = real(num_elements, real32) * real(num_samples, real32)

  do concurrent(c = 1:num_channels)
     ! Declaring the per-channel temporaries in a BLOCK gives every
     ! iteration its own copy. Procedure-scope scalars would instead rely
     ! on the compiler privatising them.
     block
       integer :: offset
       real(real32) :: mu, var, std_dev, weight, bias

       offset = (c - 1) * num_elements

       ! Two-pass batch statistics for this channel.
       mu = sum(input%val(offset + 1:offset + num_elements, :)) / norm
       var = sum( &
            (input%val(offset + 1:offset + num_elements, :) - mu) ** 2 &
       ) / norm

       if(momentum .gt. min_momentum)then
          output%mean(c) = momentum * mean(c) + &
               (1._real32 - momentum) * mu
          output%variance(c) = momentum * variance(c) + &
               (1._real32 - momentum) * var
       else
          output%mean(c) = mu
          output%variance(c) = var
       end if

       ! Loop-invariant per channel. The evaluation order below matches
       ! weight * (x - mu) / sqrt(var + eps) + bias exactly.
       std_dev = sqrt(var + output%epsilon)
       weight = params%val(c, 1)
       bias = params%val(c + num_channels, 1)

       do concurrent(s = 1:num_samples, i = 1:num_elements)
          output%val(offset + i, s) = &
               weight * ( input%val(offset + i, s) - mu ) / std_dev + bias
       end do
     end block
  end do

  output%get_partial_left => get_partial_batchnorm_left
  output%get_partial_left_val => get_partial_batchnorm_left_val
  output%get_partial_right => get_partial_batchnorm_right
  output%get_partial_right_val => get_partial_batchnorm_right_val

  ! Record the operation in the graph only when a gradient may be needed.
  if(input%requires_grad .or. params%requires_grad)then
     output%requires_grad = .true.
     output%is_forward = input%is_forward .or. params%is_forward
     output%operation = 'batchnorm'
     output%left_operand => input
     output%right_operand => params
  end if
end function batchnorm