Initialise the weights and biases using the Glorot normal distribution
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| class(glorot_normal_init_type) | intent(inout) | | | this | Instance of the Glorot initialiser |
| real(kind=real32) | intent(out) | | dimension(..) | input | Weights to initialise |
| integer | intent(in) | optional | | fan_in | Number of input units (required in practice: the routine stops if it is absent) |
| integer | intent(in) | optional | | fan_out | Number of output units (required in practice: the routine stops if it is absent) |
| integer | intent(in) | optional | dimension(:) | spacing | Spacing of the input and output units (not used here, included for compatibility) |
subroutine glorot_normal_initialise(this, input, fan_in, fan_out, spacing)
  !! Initialise the weights and biases using the Glorot normal distribution.
  !!
  !! Every element of `input` is drawn independently from N(0, sigma**2)
  !! with sigma = sqrt(2 / (fan_in + fan_out)). Normal deviates are made
  !! with the cosine branch of the Box-Muller transform, which consumes one
  !! pair of uniform deviates from `random_number` per element.
  !!
  !! Errors reported through `stop_program`:
  !!  - `fan_in` or `fan_out` absent. They are optional only for interface
  !!    compatibility; the Glorot scale needs both.
  !!  - `fan_in + fan_out` not positive. sigma would be Inf or NaN.
  !!  - `input` of a rank other than 0-6. Nothing would be assigned.
  implicit none

  ! Arguments
  class(glorot_normal_init_type), intent(inout) :: this
  !! Instance of the Glorot initialiser (its components are not referenced)
  real(real32), dimension(..), intent(out) :: input
  !! Weights to initialise (rank 0 to 6)
  integer, optional, intent(in) :: fan_in
  !! Number of input units (must be present)
  integer, optional, intent(in) :: fan_out
  !! Number of output units (must be present)
  integer, dimension(:), optional, intent(in) :: spacing
  !! Spacing of the input and output units (not used here, included for compatibility)

  ! Local variables
  real(real32), parameter :: min_uniform = 1.E-7_real32
  !! Lower bound applied to u1 so that log(u1) in Box-Muller stays finite
  integer :: n
  !! Number of elements in the input array
  real(real32) :: sigma
  !! Standard deviation of the target normal distribution
  real(real32), dimension(:), allocatable :: u1, u2, z
  !! Uniform deviates (u1, u2) and the resulting normal samples (z)


  ! fan_in and fan_out are optional only to match the shared initialiser
  ! interface; the Glorot scale cannot be computed without them.
  ! If stop_program ever hands control back, the returns stop absent
  ! arguments from being referenced below.
  if(.not.present(fan_in))then
     call stop_program("glorot_normal_initialise: fan_in not present")
     return
  end if
  if(.not.present(fan_out))then
     call stop_program("glorot_normal_initialise: fan_out not present")
     return
  end if
  ! Guard the division below: a zero or negative total fan would give an
  ! infinite or NaN standard deviation.
  if(fan_in + fan_out .le. 0)then
     call stop_program( &
          "glorot_normal_initialise: fan_in + fan_out must be positive" &
     )
     return
  end if

  sigma = sqrt(2._real32 / real(fan_in + fan_out, real32))

  n = size(input)
  allocate(u1(n), u2(n), z(n))
  call random_number(u1)
  call random_number(u2)

  ! random_number samples [0, 1), so u1 can be exactly 0. Clamp it away
  ! from zero to keep log(u1) finite.
  u1 = max(u1, min_uniform)

  ! Box-Muller transform (cosine branch), scaled to the Glorot std. dev.
  z = sigma * (sqrt(-2._real32 * log(u1)) * cos(2._real32 * pi * u2))

  ! Assign according to rank. reshape fills input in array-element
  ! (column-major) order.
  select rank(input)
  rank(0)
     input = z(1)
  rank(1)
     input = z
  rank(2)
     input = reshape(z, shape(input))
  rank(3)
     input = reshape(z, shape(input))
  rank(4)
     input = reshape(z, shape(input))
  rank(5)
     input = reshape(z, shape(input))
  rank(6)
     input = reshape(z, shape(input))
  rank default
     ! Without this branch a higher-rank input would be left undefined
     call stop_program("glorot_normal_initialise: unsupported input rank")
  end select

  ! u1, u2 and z are local allocatables, so they are deallocated
  ! automatically on return.
end subroutine glorot_normal_initialise