athena_lr_decay.f90 Source File


Source Code

module athena__learning_rate_decay
  !! Module containing learning rate decay types and procedures
  !!
  !! Every decay type exposes a type-bound `get_lr` procedure that maps a
  !! learning rate and an integer iteration count onto a decayed learning rate:
  !!
  !! - `base_lr_decay_type`: `lr` (no decay)
  !! - `exp_lr_decay_type`:  `lr * exp(-decay_rate * iteration)`
  !! - `step_lr_decay_type`: `lr * decay_rate ** (iteration / decay_steps)`
  !! - `inv_lr_decay_type`:  `lr * (1 + decay_rate * iteration) ** (-decay_power)`
  use coreutils, only: real32
  implicit none


  private

  public :: base_lr_decay_type
  public :: exp_lr_decay_type
  public :: step_lr_decay_type
  public :: inv_lr_decay_type


  real(real32), parameter :: default_exp_decay_rate = 0.9_real32
  !! Decay rate used by the exponential constructor when none is supplied
  real(real32), parameter :: default_step_decay_rate = 0.1_real32
  !! Decay rate used by the step constructor when none is supplied
  integer, parameter :: default_decay_steps = 100
  !! Interval length used by the step decay type when none is supplied
  real(real32), parameter :: default_inv_decay_rate = 0.001_real32
  !! Decay rate used by the inverse constructor when none is supplied
  real(real32), parameter :: default_decay_power = 1._real32
  !! Decay power used by the inverse decay type when none is supplied


!-------------------------------------------------------------------------------

  type base_lr_decay_type
     !! Type for learning rate decay
     !!
     !! The base type applies no decay; extended types override `get_lr`.
     !! All components are default-initialised so that an object declared
     !! without going through a constructor is still fully defined.
     character(len=20) :: name = ""
     !! Name of the learning rate decay type
     real(real32) :: initial_learning_rate = 0._real32
     !! Initial learning rate.
     !! No procedure in this module reads or assigns this component.
     real(real32) :: decay_rate = 0._real32
     !! Decay rate for learning rate
     logical :: iterate_per_epoch = .false.
     !! Whether to iterate learning rate decay per epoch
   contains
     procedure :: get_lr => lr_decay_none
     !! Procedure to get the learning rate
  end type base_lr_decay_type

  interface base_lr_decay_type
     !! Interface for base learning rate decay type
     module function setup_lr_decay_base() result(lr_decay)
       !! Set up base learning rate decay type
       type(base_lr_decay_type) :: lr_decay
       !! Base learning rate decay type
     end function setup_lr_decay_base
  end interface base_lr_decay_type

!-------------------------------------------------------------------------------

  type, extends(base_lr_decay_type) :: exp_lr_decay_type
     !! Type for exponential learning rate decay
   contains
     procedure :: get_lr => lr_decay_exp
     !! Procedure to get the learning rate
  end type exp_lr_decay_type

  interface exp_lr_decay_type
     !! Interface for exponential learning rate decay type
     module function setup_lr_decay_exp(decay_rate) result(lr_decay)
       !! Set up exponential learning rate decay type
       real(real32), optional, intent(in) :: decay_rate
       !! Decay rate for learning rate
       type(exp_lr_decay_type) :: lr_decay
       !! Exponential learning rate decay type
     end function setup_lr_decay_exp
  end interface exp_lr_decay_type

!-------------------------------------------------------------------------------

  type, extends(base_lr_decay_type) :: step_lr_decay_type
     !! Type for step learning rate decay
     integer :: decay_steps = default_decay_steps
     !! Number of steps for learning rate decay.
     !! Values below 1 are treated as 1 when the learning rate is evaluated.
   contains
     procedure :: get_lr => lr_decay_step
     !! Procedure to get the learning rate
  end type step_lr_decay_type

  interface step_lr_decay_type
     !! Interface for step learning rate decay type
     module function setup_lr_decay_step(decay_rate, decay_steps) &
          result(lr_decay)
       !! Set up step learning rate decay type
       real(real32), optional, intent(in) :: decay_rate
       !! Decay rate for learning rate
       integer, optional, intent(in) :: decay_steps
       !! Number of steps for learning rate decay
       type(step_lr_decay_type) :: lr_decay
       !! Step learning rate decay type
     end function setup_lr_decay_step
  end interface step_lr_decay_type

!-------------------------------------------------------------------------------

  type, extends(base_lr_decay_type) :: inv_lr_decay_type
     !! Type for inverse learning rate decay
     real(real32) :: decay_power = default_decay_power
     !! Power for learning rate decay
   contains
     procedure :: get_lr => lr_decay_inv
     !! Procedure to get the learning rate
  end type inv_lr_decay_type

  interface inv_lr_decay_type
     !! Interface for inverse learning rate decay type
     module function setup_lr_decay_inv(decay_rate, decay_power) &
          result(lr_decay)
       !! Set up inverse learning rate decay type
       real(real32), optional, intent(in) :: decay_rate, decay_power
       !! Decay rate and decay power for learning rate
       type(inv_lr_decay_type) :: lr_decay
       !! Inverse learning rate decay type
     end function setup_lr_decay_inv
  end interface inv_lr_decay_type



contains

!###############################################################################
  module function setup_lr_decay_base() result(lr_decay)
    !! Set up base learning rate decay type
    !!
    !! The result is named "base" and carries a zero decay rate.
    implicit none

    ! Output variable
    type(base_lr_decay_type) :: lr_decay
    !! Instance of the base learning rate decay type

    lr_decay%name = "base"
    lr_decay%decay_rate = 0._real32

  end function setup_lr_decay_base
!-------------------------------------------------------------------------------
  module function setup_lr_decay_exp(decay_rate) result(lr_decay)
    !! Set up exponential learning rate decay type
    !!
    !! The result is named "exp". An absent decay rate falls back to
    !! `default_exp_decay_rate`.
    implicit none

    ! Arguments
    real(real32), optional, intent(in) :: decay_rate
    !! Decay rate for learning rate
    type(exp_lr_decay_type) :: lr_decay
    !! Exponential learning rate decay type

    lr_decay%name = "exp"
    lr_decay%decay_rate = default_exp_decay_rate
    if(present(decay_rate)) lr_decay%decay_rate = decay_rate

  end function setup_lr_decay_exp
!-------------------------------------------------------------------------------
  module function setup_lr_decay_step(decay_rate, decay_steps) result(lr_decay)
    !! Set up step learning rate decay type
    !!
    !! The result is named "step" and has `iterate_per_epoch` switched on.
    !! Absent arguments fall back to `default_step_decay_rate` and
    !! `default_decay_steps`.
    implicit none

    ! Arguments
    real(real32), optional, intent(in) :: decay_rate
    !! Decay rate for learning rate
    integer, optional, intent(in) :: decay_steps
    !! Number of steps for learning rate decay
    type(step_lr_decay_type) :: lr_decay
    !! Step learning rate decay type

    lr_decay%name = "step"
    lr_decay%decay_rate = default_step_decay_rate
    if(present(decay_rate)) lr_decay%decay_rate = decay_rate
    lr_decay%decay_steps = default_decay_steps
    if(present(decay_steps)) lr_decay%decay_steps = decay_steps
    lr_decay%iterate_per_epoch = .true.

  end function setup_lr_decay_step
!-------------------------------------------------------------------------------
  module function setup_lr_decay_inv(decay_rate, decay_power) result(lr_decay)
    !! Set up inverse learning rate decay type
    !!
    !! The result is named "inv". Absent arguments fall back to
    !! `default_inv_decay_rate` and `default_decay_power`.
    implicit none

    ! Arguments
    real(real32), optional, intent(in) :: decay_rate, decay_power
    !! Decay rate and decay power for learning rate
    type(inv_lr_decay_type) :: lr_decay
    !! Inverse learning rate decay type

    lr_decay%name = "inv"
    lr_decay%decay_rate = default_inv_decay_rate
    if(present(decay_rate)) lr_decay%decay_rate = decay_rate
    lr_decay%decay_power = default_decay_power
    if(present(decay_power)) lr_decay%decay_power = decay_power

  end function setup_lr_decay_inv
!###############################################################################


!###############################################################################
  pure function lr_decay_none(this, learning_rate, iteration) result(output)
    !! Get the learning rate for the base decay type
    !!
    !! Returns `learning_rate` unchanged.
    implicit none

    ! Arguments
    class(base_lr_decay_type), intent(in) :: this
    !! Instance of the base learning rate decay type (unused)
    real(real32), intent(in) :: learning_rate
    !! Learning rate to be returned
    integer, intent(in) :: iteration
    !! Iteration number (unused; present so that overriding procedures
    !! share the same interface)
    real(real32) :: output
    !! Learning rate

    output = learning_rate

  end function lr_decay_none
!-------------------------------------------------------------------------------
  pure function lr_decay_exp(this, learning_rate, iteration) result(output)
    !! Get the learning rate for the exponential decay type
    !!
    !! Returns `learning_rate * exp(-decay_rate * iteration)`.
    implicit none

    ! Arguments
    class(exp_lr_decay_type), intent(in) :: this
    !! Instance of the exponential learning rate decay type
    real(real32), intent(in) :: learning_rate
    !! Learning rate to which the decay factor is applied
    integer, intent(in) :: iteration
    !! Iteration number
    real(real32) :: output
    !! Decayed learning rate

    output = learning_rate * &
         exp( -( real(iteration, real32) * this%decay_rate ) )

  end function lr_decay_exp
!-------------------------------------------------------------------------------
  pure function lr_decay_step(this, learning_rate, iteration) result(output)
    !! Get the learning rate for the step decay type
    !!
    !! Returns `learning_rate * decay_rate ** (iteration / decay_steps)`,
    !! where the division is an integer division.
    implicit none

    ! Arguments
    class(step_lr_decay_type), intent(in) :: this
    !! Instance of the step learning rate decay type
    real(real32), intent(in) :: learning_rate
    !! Learning rate to which the decay factor is applied
    integer, intent(in) :: iteration
    !! Iteration number
    real(real32) :: output
    !! Decayed learning rate

    ! Local variables
    integer :: num_decays
    !! Number of completed decay intervals

    ! Integer division is intentional: the factor only changes once a full
    ! interval of decay_steps iterations has elapsed.
    ! The interval is clamped to at least one iteration so that a zero or
    ! negative decay_steps (the component is public and unvalidated) cannot
    ! trigger an integer division by zero.
    num_decays = iteration / max(this%decay_steps, 1)
    output = learning_rate * this%decay_rate ** num_decays

  end function lr_decay_step
!-------------------------------------------------------------------------------
  pure function lr_decay_inv(this, learning_rate, iteration) result(output)
    !! Get the learning rate for the inverse decay type
    !!
    !! Returns `learning_rate * (1 + decay_rate * iteration) ** (-decay_power)`.
    implicit none

    ! Arguments
    class(inv_lr_decay_type), intent(in) :: this
    !! Instance of the inverse learning rate decay type
    real(real32), intent(in) :: learning_rate
    !! Learning rate to which the decay factor is applied
    integer, intent(in) :: iteration
    !! Iteration number
    real(real32) :: output
    !! Decayed learning rate

    output = learning_rate * &
         ( 1._real32 + this%decay_rate * real(iteration, real32) ) ** &
         ( - this%decay_power )

  end function lr_decay_inv
!###############################################################################

end module athena__learning_rate_decay