glorot_uniform_initialise Subroutine

private subroutine glorot_uniform_initialise(this, input, fan_in, fan_out, spacing)

Initialise the weights and biases using the Glorot uniform distribution, i.e. by sampling every element from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)).

Type Bound

glorot_uniform_init_type

Arguments

Type Intent Optional Attributes Name
class(glorot_uniform_init_type), intent(inout) :: this

Instance of the Glorot initialiser

real(kind=real32), intent(out), dimension(..) :: input

Weights and biases to initialise

integer, intent(in), optional :: fan_in

Number of input units (fan-in)

integer, intent(in), optional :: fan_out

Number of output units (fan-out)

integer, intent(in), optional, dimension(:) :: spacing

Spacing of the input and output units (not used)


Source Code

  subroutine glorot_uniform_initialise(this, input, fan_in, fan_out, spacing)
    !! Initialise the weights and biases using the Glorot uniform distribution
    !!
    !! Every element of input is drawn independently from U(-limit, limit),
    !! where limit = sqrt(6 / (fan_in + fan_out)).
    !! Both fan_in and fan_out must be present and their sum must be positive.
    !! Supported input ranks are 0 to 6; any other rank is reported through
    !! stop_program instead of leaving input undefined.
    implicit none

    ! Arguments
    class(glorot_uniform_init_type), intent(inout) :: this
    !! Instance of the Glorot initialiser
    real(real32), dimension(..), intent(out) :: input
    !! Weights and biases to initialise
    integer, optional, intent(in) :: fan_in
    !! Number of input units
    integer, optional, intent(in) :: fan_out
    !! Number of output units
    integer, dimension(:), optional, intent(in) :: spacing
    !! Spacing of the input and output units (not used)

    ! Local variables
    integer :: n
    !! Number of elements in the input array
    real(real32) :: limit
    !! Half-width of the uniform sampling interval
    real(real32), dimension(:), allocatable :: r
    !! Temporary uniform random numbers

    ! Validate inputs.
    ! Each failure returns immediately so that an absent optional argument is
    ! never referenced below, even if stop_program hands control back to the
    ! caller. NOTE(review): confirm whether stop_program can return.
    if(.not.present(fan_in))then
       call stop_program("glorot_uniform_initialise: fan_in not present")
       return
    end if
    if(.not.present(fan_out))then
       call stop_program("glorot_uniform_initialise: fan_out not present")
       return
    end if
    ! A non-positive sum would make limit Inf (division by zero) or NaN
    ! (square root of a negative), silently filling input with garbage
    if(fan_in + fan_out .le. 0)then
       call stop_program( &
            "glorot_uniform_initialise: fan_in + fan_out must be positive")
       return
    end if

    limit = sqrt(6._real32 / real(fan_in + fan_out, real32))
    n = size(input)
    allocate(r(n))
    call random_number(r)
    ! Map U[0,1) onto [-limit, limit)
    r = (2._real32 * r - 1._real32) * limit

    ! Assign according to rank; reshape fills input in array element order
    select rank(input)
    rank(0)
       input = r(1)
    rank(1)
       input = r
    rank(2)
       input = reshape(r, shape(input))
    rank(3)
       input = reshape(r, shape(input))
    rank(4)
       input = reshape(r, shape(input))
    rank(5)
       input = reshape(r, shape(input))
    rank(6)
       input = reshape(r, shape(input))
    rank default
       ! Any rank not listed above is unsupported; report it rather than
       ! silently returning with the intent(out) input left undefined
       call stop_program("glorot_uniform_initialise: unsupported input rank")
    end select

    deallocate(r)
  end subroutine glorot_uniform_initialise