module athena__initialiser
  !! Module containing functions to set up initialisers
  !!
  !! This module contains functions to set up initialisers for the weights and
  !! biases of a neural network model
  !! Examples of initialisers in Keras: https://keras.io/api/layers/initializers/
  use coreutils, only: stop_program, to_lower
  use athena__misc_types, only: base_init_type
  use athena__initialiser_glorot, only: &
       glorot_uniform_init_type, glorot_normal_init_type
  use athena__initialiser_he, only: he_uniform_init_type, he_normal_init_type
  use athena__initialiser_lecun, only: &
       lecun_uniform_init_type, lecun_normal_init_type
  use athena__initialiser_ones, only: ones_init_type
  use athena__initialiser_zeros, only: zeros_init_type
  use athena__initialiser_ident, only: ident_init_type
  use athena__initialiser_gaussian, only: gaussian_init_type
  implicit none


  private

  public :: initialiser_setup, get_default_initialiser


contains

!###############################################################################
  function get_default_initialiser(activation, is_bias) result(name)
    !! Get the name of the default initialiser for an activation function
    !!
    !! If `is_bias` is present and true, "zeros" is returned regardless of
    !! the activation. If it is absent or false, the name is selected from
    !! the activation name:
    !!   "selu"                      -> "lecun_normal"
    !!   any name containing "elu"   -> "he_uniform"
    !!   "batch"                     -> "gaussian"
    !!   anything else               -> "glorot_uniform"
    !! The activation name is compared as given (no case conversion is done
    !! in this function).
    implicit none

    ! Arguments
    character(*), intent(in) :: activation
    !! Activation function
    logical, optional, intent(in) :: is_bias
    !! Boolean whether initialiser is for bias

    character(:), allocatable :: name
    !! Name of the default initialiser


    !---------------------------------------------------------------------------
    ! If bias, use default initialiser of zero
    !---------------------------------------------------------------------------
    ! Only return early when the flag is actually true; a present-but-false
    ! flag must fall through, otherwise the result would be left unallocated.
    if(present(is_bias))then
       if(is_bias)then
          name = "zeros"
          return
       end if
    end if


    !---------------------------------------------------------------------------
    ! Set default initialiser based on activation
    !---------------------------------------------------------------------------
    ! "selu" is tested first because it would also match the "elu" substring
    ! check below.
    if(trim(activation).eq."selu")then
       name = "lecun_normal"
    elseif(index(activation,"elu").ne.0)then
       name = "he_uniform"
    elseif(trim(activation).eq."batch")then
       name = "gaussian"
    else
       name = "glorot_uniform"
    end if

  end function get_default_initialiser
!###############################################################################


!###############################################################################
  function initialiser_setup(input, error) result(initialiser)
    !! Set up the initialiser function
    !!
    !! `input` may be either an object extending `base_init_type`, which is
    !! copied into the result, or a character string naming an initialiser.
    !! Names are passed through `to_lower` and trimmed before being matched.
    !!
    !! On an unrecognised name or unsupported input type:
    !!   - if `error` is present it is set to -1 and the function returns;
    !!   - otherwise `stop_program` is called with a descriptive message.
    !! In both failure cases `initialiser` is not allocated by this function.
    implicit none

    ! Arguments
    class(base_init_type), allocatable :: initialiser
    !! Initialiser function
    class(*), intent(in) :: input
    !! Name of initialiser or initialiser object
    integer, optional, intent(out) :: error
    !! Error code (0 on success, -1 on unknown name or input type)

    ! Local variables
    character(:), allocatable :: init_name
    !! Lower-case, trimmed name of the requested initialiser


    !---------------------------------------------------------------------------
    ! Initialise error code
    !---------------------------------------------------------------------------
    ! error is intent(out), so it is undefined on entry; give it a defined
    ! success value so callers can always test it after the call.
    if(present(error)) error = 0


    !---------------------------------------------------------------------------
    ! Set initialiser function
    !---------------------------------------------------------------------------
    select type(input)
    class is(base_init_type)
       initialiser = input
    type is(character(*))
       init_name = trim(to_lower(input))
       select case(init_name)
       case("glorot_uniform")
          initialiser = glorot_uniform_init_type()
       case("glorot_normal")
          initialiser = glorot_normal_init_type()
       case("he_uniform")
          initialiser = he_uniform_init_type()
       case("he_normal")
          initialiser = he_normal_init_type()
       case("lecun_uniform")
          initialiser = lecun_uniform_init_type()
       case("lecun_normal")
          initialiser = lecun_normal_init_type()
       case("ones")
          initialiser = ones_init_type()
       case("zeros")
          initialiser = zeros_init_type()
       case("ident")
          initialiser = ident_init_type()
       case("gaussian")
          initialiser = gaussian_init_type()
       case("normal")
          initialiser = gaussian_init_type(name="normal")
       case default
          if(present(error))then
             error = -1
          else
             ! Message is built by concatenation so that an arbitrarily long
             ! name cannot overflow a fixed-length buffer.
             call stop_program( &
                  "Incorrect initialiser name given '"//init_name//"'" &
             )
          end if
          ! Explicit return kept in case stop_program hands control back
          return
       end select
    class default
       if(present(error))then
          error = -1
       else
          call stop_program("Unknown input type given for initialiser setup")
       end if
       ! Explicit return kept in case stop_program hands control back
       return
    end select

  end function initialiser_setup
!###############################################################################

end module athena__initialiser
