he_normal_initialise Subroutine

private subroutine he_normal_initialise(this, input, fan_in, fan_out, spacing)

Initialise the weights and biases using the He normal distribution

Type Bound

he_normal_init_type

Arguments

Type | Intent | Optional | Attributes | Name
class(he_normal_init_type), intent(inout) :: this

Instance of the He initialiser

real(kind=real32), intent(out), dimension(..) :: input

Weights and biases to initialise

integer, intent(in), optional :: fan_in

Number of input units (fan-in)

integer, intent(in), optional :: fan_out

Number of output units (fan-out)

integer, intent(in), optional, dimension(:) :: spacing

Spacing of the input and output units (not used)


Source Code

  subroutine he_normal_initialise(this, input, fan_in, fan_out, spacing)
    !! Initialise the weights and biases using the He normal distribution.
    !!
    !! Every element of `input` is drawn from N(0, sigma**2) with
    !! sigma = this%scale * sqrt(2 / fan). The fan is taken from:
    !!   - fan_in  when this%mode == 1
    !!   - fan_out when this%mode == 2
    !! Normal deviates are produced from the intrinsic uniform generator
    !! via the Box-Muller transform (cosine branch).
    implicit none

    ! Arguments
    class(he_normal_init_type), intent(inout) :: this
    !! Instance of the He initialiser
    real(real32), dimension(..), intent(out) :: input
    !! Weights and biases to initialise (ranks 0 to 6 are supported)
    integer, optional, intent(in) :: fan_in
    !! Number of input units (required when this%mode == 1)
    integer, optional, intent(in) :: fan_out
    !! Number of output units (required when this%mode == 2)
    integer, dimension(:), optional, intent(in) :: spacing
    !! Spacing of the input and output units (not used)

    ! Local variables
    integer :: n
    !! Number of elements in the input array
    integer :: fan
    !! Fan value selected according to this%mode
    real(real32) :: sigma
    !! Standard deviation of the sampled distribution
    real(real32), dimension(:), allocatable :: u1, u2, z
    !! Temporary arrays for the random numbers

    ! Select the fan required by the mode. Only the optional argument that
    ! is actually needed is checked. Referencing an absent optional
    ! argument is not permitted.
    ! The `return` after each stop_program guards against stop_program
    ! returning instead of terminating.
    ! NOTE(review): confirm stop_program's behaviour.
    select case(this%mode)
    case(1)
       if(.not.present(fan_in))then
          call stop_program("he_normal_initialise: fan_in not present")
          return
       end if
       fan = fan_in
    case(2)
       if(.not.present(fan_out))then
          call stop_program("he_normal_initialise: fan_out not present")
          return
       end if
       fan = fan_out
    case default
       call stop_program("he_normal_initialise: invalid mode")
       return
    end select

    ! A non-positive fan would give an infinite or NaN standard deviation.
    if(fan .le. 0)then
       call stop_program("he_normal_initialise: fan must be positive")
       return
    end if

    sigma = this%scale * sqrt(2._real32 / real(fan, real32))

    n = size(input)
    allocate(u1(n), u2(n), z(n))

    call random_number(u1)
    call random_number(u2)
    ! random_number can return exactly 0, so clamp u1 away from zero
    ! to keep log(u1) finite.
    u1 = max(u1, 1.E-7_real32)

    ! Box-Muller transform: standard normal deviates, scaled to sigma.
    z = sigma * sqrt(-2._real32 * log(u1)) * cos(2._real32 * pi * u2)

    ! Copy the deviates into the assumed-rank dummy. Elements are filled
    ! in array element (column-major) order.
    select rank(input)
    rank(0)
       input = z(1)
    rank(1)
       input = z
    rank(2)
       input = reshape(z, shape(input))
    rank(3)
       input = reshape(z, shape(input))
    rank(4)
       input = reshape(z, shape(input))
    rank(5)
       input = reshape(z, shape(input))
    rank(6)
       input = reshape(z, shape(input))
    rank default
       ! Without this branch a higher-rank input would be left undefined.
       call stop_program("he_normal_initialise: unsupported rank of input")
    end select

    ! u1, u2 and z are local allocatables and are deallocated
    ! automatically on return.

  end subroutine he_normal_initialise