module athena__initialiser_he
  !! Module containing the implementation of the He initialiser
  !!
  !! This module implements He (Kaiming/MSRA) initialisation, designed for
  !! layers with ReLU activation to prevent vanishing/exploding gradients.
  !!
  !! Mathematical operation:
  !!
  !! Uniform variant:
  !! \[ W \sim \mathcal{U}(-\text{limit}, \text{limit}), \quad \text{limit} = \sqrt{\frac{6}{n_{\text{in}}}} \]
  !!
  !! Normal variant:
  !! \[ W \sim \mathcal{N}(0, \sigma^2), \quad \sigma = \sqrt{\frac{2}{n_{\text{in}}}} \]
  !!
  !! where \(n_{\text{in}}\) is the number of input units (fan-in).
  !!
  !! Rationale: Maintains variance through ReLU layers
  !! \(\text{Var}(\text{output}) \approx \text{Var}(\text{input})\)
  !!
  !! Best for: ReLU, Leaky ReLU, PReLU activations
  !! Reference: He et al. (2015), ICCV, arXiv:1502.01852
  use coreutils, only: real32, pi, to_lower, stop_program
  use athena__misc_types, only: base_init_type
  implicit none


  private

  public :: he_uniform_init_type, he_normal_init_type


  type, extends(base_init_type) :: he_uniform_init_type
     !! Type for the He initialiser (uniform)
     integer, private :: mode = 1
   contains
     procedure, pass(this) :: initialise => he_uniform_initialise
     !! Initialise the weights and biases using the He uniform distribution
  end type he_uniform_init_type

  type, extends(base_init_type) :: he_normal_init_type
     !! Type for the He initialiser (normal)
     integer, private :: mode = 1
   contains
     procedure, pass(this) :: initialise => he_normal_initialise
     !! Initialise the weights and biases using the He normal distribution
  end type he_normal_init_type


  interface he_uniform_init_type
     module function initialiser_uniform_setup(scale, mode) result(initialiser)
       !! Interface for the He uniform initialiser
       real(real32), intent(in), optional :: scale
       !! Scaling factor (default: 1.0)
       character(len=*), intent(in), optional :: mode
       !! Mode for calculating the scaling factor (default: "fan_in")
       type(he_uniform_init_type) :: initialiser
       !! He uniform initialiser object
     end function initialiser_uniform_setup
  end interface he_uniform_init_type

  interface he_normal_init_type
     module function initialiser_normal_setup(scale, mode) result(initialiser)
       !! Interface for the He normal initialiser
       real(real32), intent(in), optional :: scale
       !! Scaling factor (default: 1.0)
       character(len=*), intent(in), optional :: mode
       !! Mode for calculating the scaling factor (default: "fan_in")
       type(he_normal_init_type) :: initialiser
       !! He normal initialiser object
     end function initialiser_normal_setup
  end interface he_normal_init_type



contains

!###############################################################################
  module function initialiser_uniform_setup(scale, mode) result(initialiser)
    implicit none
    ! Arguments
    real(real32), intent(in), optional :: scale
    !! Scaling factor (default: 1.0)
    character(len=*), intent(in), optional :: mode
    !! Mode for calculating the scaling factor (default: "fan_in")
    type(he_uniform_init_type) :: initialiser
    !! He uniform initialiser object

    ! Local variables
    character(len=20) :: mode_
    !! Mode for calculating the scaling factor

    initialiser%name = "he_uniform"
    if(present(scale)) initialiser%scale = scale
    if(present(mode))then
       mode_ = to_lower(trim(mode))
       select case(mode_)
       case("fan_in")
          initialiser%mode = 1
       case("fan_out")
          initialiser%mode = 2
       case default
          call stop_program("initialiser_setup: invalid mode")
       end select
    end if

  end function initialiser_uniform_setup
!-------------------------------------------------------------------------------
  module function initialiser_normal_setup(scale, mode) result(initialiser)
    implicit none
    ! Arguments
    real(real32), intent(in), optional :: scale
    !! Scaling factor (default: 1.0)
    character(len=*), intent(in), optional :: mode
    !! Mode for calculating the scaling factor (default: "fan_in")
    type(he_normal_init_type) :: initialiser
    !! He normal initialiser object

    ! Local variables
    character(len=20) :: mode_
    !! Mode for calculating the scaling factor

    initialiser%name = "he_normal"
    if(present(scale)) initialiser%scale = scale
    if(present(mode))then
       mode_ = to_lower(trim(mode))
       select case(mode_)
       case("fan_in")
          initialiser%mode = 1
       case("fan_out")
          initialiser%mode = 2
       case default
          call stop_program("initialiser_setup: invalid mode")
       end select
    end if
  end function initialiser_normal_setup
!###############################################################################


!###############################################################################
  subroutine he_uniform_initialise(this, input, fan_in, fan_out, spacing)
    !! Initialise the weights and biases using the He uniform distribution
    implicit none

    ! Arguments
    class(he_uniform_init_type), intent(inout) :: this
    !! Instance of the Glorot initialiser
    real(real32), dimension(..), intent(out) :: input
    !! Weights and biases to initialise
    integer, optional, intent(in) :: fan_in, fan_out
    !! Number of input and output units
    integer, dimension(:), optional, intent(in) :: spacing
    !! Spacing of the input and output units (not used)

    ! Local variables
    integer :: n
    !! Number of elements in the input array
    real(real32) :: limit
    !! Scaling factor
    real(real32), dimension(:), allocatable :: r
    !! Temporary uniform random numbers

    if(.not.present(fan_in)) &
         call stop_program("he_uniform_initialise: fan_in not present")

    select case(this%mode)
    case(1)
       limit = this%scale * sqrt(6._real32 / real(fan_in, real32))
    case(2)
       limit = this%scale * sqrt(6._real32 / real(fan_out, real32))
    case default
       call stop_program("he_uniform_initialise: invalid mode")
    end select
    n = size(input)
    allocate(r(n))
    call random_number(r)
    r = (2._real32 * r - 1._real32) * limit

    ! Assign according to rank
    select rank(input)
    rank(0)
       input = r(1)
    rank(1)
       input = r
    rank(2)
       input = reshape(r, shape(input))
    rank(3)
       input = reshape(r, shape(input))
    rank(4)
       input = reshape(r, shape(input))
    rank(5)
       input = reshape(r, shape(input))
    rank(6)
       input = reshape(r, shape(input))
    end select

    deallocate(r)
  end subroutine he_uniform_initialise
!###############################################################################


!###############################################################################
  subroutine he_normal_initialise(this, input, fan_in, fan_out, spacing)
    !! Initialise the weights and biases using the He normal distribution
    implicit none

    ! Arguments
    class(he_normal_init_type), intent(inout) :: this
    !! Instance of the He initialiser
    real(real32), dimension(..), intent(out) :: input
    !! Weights and biases to initialise
    integer, optional, intent(in) :: fan_in, fan_out
    !! Number of input and output parameters
    integer, dimension(:), optional, intent(in) :: spacing
    !! Spacing of the input and output units (not used)

    ! Local variables
    integer :: n
    !! Number of elements in the input array
    real(real32) :: sigma
    !! Scaling factor
    real(real32), dimension(:), allocatable :: u1, u2, z
    !! Temporary arrays for the random numbers

    if(.not.present(fan_in)) &
         call stop_program("he_normal_initialise: fan_in not present")

    select case(this%mode)
    case(1)
       sigma = this%scale * sqrt(2._real32/real(fan_in,real32))
    case(2)
       sigma = this%scale * sqrt(2._real32/real(fan_out,real32))
    case default
       call stop_program("he_uniform_initialise: invalid mode")
    end select
    n = size(input)
    allocate(u1(n), u2(n), z(n))

    call random_number(u1)
    call random_number(u2)
    where (u1 .lt. 1.E-7_real32)
       u1 = 1.E-7_real32
    end where

    ! Box-Muller transform
    z = sqrt(-2._real32 * log(u1)) * cos(2._real32 * pi * u2)
    z = sigma * z

    select rank(input)
    rank(0)
       input = z(1)
    rank(1)
       input = z
    rank(2)
       input = reshape(z, shape(input))
    rank(3)
       input = reshape(z, shape(input))
    rank(4)
       input = reshape(z, shape(input))
    rank(5)
       input = reshape(z, shape(input))
    rank(6)
       input = reshape(z, shape(input))
    end select

    deallocate(u1, u2, z)

  end subroutine he_normal_initialise
!###############################################################################

end module athena__initialiser_he
