module athena__initialiser_glorot
  !! Module containing the implementation of the Glorot initialiser
  !!
  !! This module implements Glorot (Xavier) initialisation, designed to
  !! maintain the variance of activations and gradients through layers
  !! with sigmoid/tanh activations.
  !!
  !! Mathematical operation:
  !!
  !! Uniform variant:
  !! \[ W \sim \mathcal{U}(-\text{limit}, \text{limit}), \quad \text{limit} = \sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}} \]
  !!
  !! Normal variant:
  !! \[ W \sim \mathcal{N}(0, \sigma^2), \quad \sigma = \sqrt{\frac{2}{n_{\text{in}} + n_{\text{out}}}} \]
  !!
  !! where \(n_{\text{in}}\) is fan-in, \(n_{\text{out}}\) is fan-out.
  !!
  !! Rationale: Maintains variance across layers, prevents vanishing/exploding
  !! gradients in deep networks
  !!
  !! Best for: Tanh, Sigmoid, Softmax activations
  !! Reference: Glorot & Bengio (2010), AISTATS
  use coreutils, only: real32, pi, stop_program
  use athena__misc_types, only: base_init_type
  implicit none


  private

  public :: glorot_uniform_init_type
  public :: glorot_normal_init_type


  type, extends(base_init_type) :: glorot_uniform_init_type
     !! Type for the Glorot initialiser (uniform)
   contains
     procedure, pass(this) :: initialise => glorot_uniform_initialise
     !! Initialise the weights and biases using the Glorot uniform distribution
  end type glorot_uniform_init_type

  type, extends(base_init_type) :: glorot_normal_init_type
     !! Type for the Glorot initialiser (normal)
   contains
     procedure, pass(this) :: initialise => glorot_normal_initialise
     !! Initialise the weights and biases using the Glorot normal distribution
  end type glorot_normal_init_type


  interface glorot_uniform_init_type
     module function initialiser_uniform_setup() result(initialiser)
       !! Interface for the Glorot uniform initialiser
       type(glorot_uniform_init_type) :: initialiser
       !! Glorot uniform initialiser object
     end function initialiser_uniform_setup
  end interface glorot_uniform_init_type

  interface glorot_normal_init_type
     module function initialiser_normal_setup() result(initialiser)
       !! Interface for the Glorot normal initialiser
       type(glorot_normal_init_type) :: initialiser
       !! Glorot normal initialiser object
     end function initialiser_normal_setup
  end interface glorot_normal_init_type



contains

!###############################################################################
  module function initialiser_uniform_setup() result(initialiser)
    !! Construct a Glorot uniform initialiser named "glorot_uniform"
    implicit none

    ! Arguments
    type(glorot_uniform_init_type) :: initialiser
    !! Glorot uniform initialiser object

    initialiser%name = "glorot_uniform"

  end function initialiser_uniform_setup
!-------------------------------------------------------------------------------
  module function initialiser_normal_setup() result(initialiser)
    !! Construct a Glorot normal initialiser named "glorot_normal"
    implicit none

    ! Arguments
    type(glorot_normal_init_type) :: initialiser
    !! Glorot normal initialiser object

    initialiser%name = "glorot_normal"

  end function initialiser_normal_setup
!###############################################################################


!###############################################################################
  subroutine glorot_uniform_initialise(this, input, fan_in, fan_out, spacing)
    !! Initialise the weights and biases using the Glorot uniform distribution
    !!
    !! Samples are drawn from U(-limit, limit) with
    !! limit = sqrt(6 / (fan_in + fan_out)). Both fan_in and fan_out are
    !! required. If either is missing, or they do not sum to a positive
    !! value, stop_program is called; should it return, input is left
    !! unassigned. Input ranks 0 to 6 are supported.
    implicit none

    ! Arguments
    class(glorot_uniform_init_type), intent(inout) :: this
    !! Instance of the Glorot initialiser (passed-object dummy; not read)
    real(real32), dimension(..), intent(out) :: input
    !! Weights and biases to initialise
    integer, optional, intent(in) :: fan_in, fan_out
    !! Number of input and output units
    integer, dimension(:), optional, intent(in) :: spacing
    !! Spacing of the input and output units (not used)

    ! Local variables
    logical :: valid
    !! Whether fan_in and fan_out passed validation
    real(real32) :: fan_sum
    !! fan_in + fan_out, computed in real arithmetic
    real(real32) :: limit
    !! Half-width of the uniform sampling interval
    real(real32), dimension(:), allocatable :: r
    !! Flat buffer of samples, in array element order


    ! Validate inputs (returns early instead of touching absent arguments)
    call validate_fans("glorot_uniform_initialise", fan_in, fan_out, &
         fan_sum, valid)
    if(.not.valid) return

    limit = sqrt(6._real32 / fan_sum)

    allocate(r(size(input)))
    call random_number(r)
    ! random_number samples [0, 1); map onto [-limit, limit)
    r = (2._real32 * r - 1._real32) * limit

    call assign_by_rank(input, r, "glorot_uniform_initialise")

  end subroutine glorot_uniform_initialise
!###############################################################################


!###############################################################################
  subroutine glorot_normal_initialise(this, input, fan_in, fan_out, spacing)
    !! Initialise the weights and biases using the Glorot normal distribution
    !!
    !! Samples are drawn from N(0, sigma^2) with
    !! sigma = sqrt(2 / (fan_in + fan_out)), using the cosine branch of the
    !! Box-Muller transform. The uniform variate u1 is clamped to at least
    !! 1e-7, which bounds |W| at roughly 5.7 sigma. Both fan_in and fan_out
    !! are required. If either is missing, or they do not sum to a positive
    !! value, stop_program is called; should it return, input is left
    !! unassigned. Input ranks 0 to 6 are supported.
    implicit none

    ! Arguments
    class(glorot_normal_init_type), intent(inout) :: this
    !! Instance of the Glorot initialiser (passed-object dummy; not read)
    real(real32), dimension(..), intent(out) :: input
    !! Weights to initialise
    integer, optional, intent(in) :: fan_in, fan_out
    !! Number of input and output units
    integer, dimension(:), optional, intent(in) :: spacing
    !! Spacing of the input and output units (not used here, included for compatibility)

    ! Local variables
    logical :: valid
    !! Whether fan_in and fan_out passed validation
    integer :: n
    !! Number of elements in the input array
    real(real32) :: fan_sum
    !! fan_in + fan_out, computed in real arithmetic
    real(real32) :: sigma
    !! Standard deviation of the sampled normal distribution
    real(real32), dimension(:), allocatable :: u1, u2, z
    !! Uniform variates and the resulting normal samples


    ! Validate inputs (returns early instead of touching absent arguments)
    call validate_fans("glorot_normal_initialise", fan_in, fan_out, &
         fan_sum, valid)
    if(.not.valid) return

    sigma = sqrt(2._real32 / fan_sum)

    n = size(input)
    allocate(u1(n), u2(n))
    call random_number(u1)
    call random_number(u2)
    ! random_number can return exactly 0; keep log(u1) finite
    u1 = max(u1, 1.E-7_real32)

    ! Box-Muller transform for normal distribution
    z = sigma * ( sqrt(-2._real32 * log(u1)) * cos(2._real32 * pi * u2) )

    call assign_by_rank(input, z, "glorot_normal_initialise")

  end subroutine glorot_normal_initialise
!###############################################################################


!###############################################################################
  subroutine validate_fans(caller, fan_in, fan_out, fan_sum, valid)
    !! Check that fan_in and fan_out are present and sum to a positive value
    !!
    !! On failure, stop_program is called with a message prefixed by caller,
    !! and valid is set to .false. so the caller can return without
    !! referencing absent optional arguments.
    implicit none

    ! Arguments
    character(len=*), intent(in) :: caller
    !! Name of the calling procedure, used to prefix error messages
    integer, optional, intent(in) :: fan_in, fan_out
    !! Number of input and output units
    real(real32), intent(out) :: fan_sum
    !! real(fan_in) + real(fan_out) on success, zero otherwise
    logical, intent(out) :: valid
    !! .true. only if all checks passed

    valid = .false.
    fan_sum = 0._real32

    if(.not.present(fan_in))then
       call stop_program(caller // ": fan_in not present")
       return
    end if
    if(.not.present(fan_out))then
       call stop_program(caller // ": fan_out not present")
       return
    end if

    ! Sum in real arithmetic to avoid integer overflow for very large fans
    fan_sum = real(fan_in, real32) + real(fan_out, real32)
    if(fan_sum .le. 0._real32)then
       ! A non-positive sum would divide by zero or take sqrt of a negative
       call stop_program(caller // ": fan_in + fan_out must be positive")
       return
    end if

    valid = .true.

  end subroutine validate_fans
!-------------------------------------------------------------------------------
  subroutine assign_by_rank(input, values, caller)
    !! Copy a flat buffer into an assumed-rank array in array element order
    !!
    !! values must hold exactly size(input) elements. Ranks 0 to 6 are
    !! supported; any other rank is reported via stop_program.
    implicit none

    ! Arguments
    real(real32), dimension(..), intent(out) :: input
    !! Destination array
    real(real32), dimension(:), intent(in) :: values
    !! Source values, in array element order
    character(len=*), intent(in) :: caller
    !! Name of the calling procedure, used to prefix error messages

    select rank(input)
    rank(0)
       input = values(1)
    rank(1)
       input = values
    rank(2)
       input = reshape(values, shape(input))
    rank(3)
       input = reshape(values, shape(input))
    rank(4)
       input = reshape(values, shape(input))
    rank(5)
       input = reshape(values, shape(input))
    rank(6)
       input = reshape(values, shape(input))
    rank default
       ! Previously fell through silently, leaving input undefined
       call stop_program(caller // ": unsupported input rank")
    end select

  end subroutine assign_by_rank
!###############################################################################

end module athena__initialiser_glorot