minimise_adam Subroutine

private pure subroutine minimise_adam(this, param, gradient)

Apply gradients to parameters to minimise loss using the Adam optimiser.

Type Bound

adam_optimiser_type

Arguments

Type | Intent | Optional | Attributes | Name
class(adam_optimiser_type), intent(inout) :: this

Instance of the Adam optimiser

real(kind=real32), intent(inout), dimension(:) :: param

Parameters to be updated in place

real(kind=real32), intent(inout), dimension(:) :: gradient

Gradients of the loss with respect to the parameters (passed to the regulariser, which may modify them)


Source Code

  pure subroutine minimise_adam(this, param, gradient)
    !! Apply gradients to parameters to minimise loss using Adam optimiser.
    !!
    !! Updates the first (m) and second (v) moment estimates stored on
    !! `this`, bias-corrects them using `this%iter`, and takes one Adam
    !! step on `param`. When an `l2_regulariser_type` is attached, its
    !! `decoupled` flag selects either decoupled weight decay (AdamW) or
    !! a coupled L2 term in the update.
    !!
    !! NOTE(review): `this%iter` is read but not advanced here; presumably
    !! the caller increments it before each call -- confirm. With
    !! `this%iter == 0` both bias corrections are zero and the update
    !! divides by zero.
    implicit none

    ! Arguments
    class(adam_optimiser_type), intent(inout) :: this
    !! Instance of the Adam optimiser (its moments m and v are updated)
    real(real32), dimension(:), intent(inout) :: param
    !! Parameters, updated in place
    real(real32), dimension(:), intent(inout) :: gradient
    !! Gradients; passed to the regulariser, which may modify them

    ! Local variables
    real(real32) :: learning_rate
    !! Learning rate after the decay schedule is applied
    real(real32) :: bias_correction1, bias_correction2
    !! Bias-correction denominators for the first and second moments
    logical :: regularised
    !! True only when regularisation is requested AND a regulariser exists
    logical :: apply_adam_step
    !! Cleared when the coupled-L2 branch has already updated param


    ! Decay learning rate (the iteration counter is not modified here)
    learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter)

    ! Use one guard for every access to the regulariser, so an enabled but
    ! unallocated regulariser is never referenced. allocated() is safe to
    ! call on an unallocated component.
    regularised = this%regularisation .and. allocated(this%regulariser)

    ! Apply regularisation hook
    if(regularised) &
         call this%regulariser%regularise( param, gradient, learning_rate )

    ! Adaptive learning method: moving averages of g and g*g
    this%m = this%beta1 * this%m + &
         (1._real32 - this%beta1) * gradient
    this%v = this%beta2 * this%v + &
         (1._real32 - this%beta2) * gradient * gradient

    ! Bias corrections
    bias_correction1 = 1._real32 - this%beta1**this%iter
    bias_correction2 = 1._real32 - this%beta2**this%iter

    ! Update parameters
    apply_adam_step = .true.
    associate( &
         m_hat => this%m / bias_correction1, &
         v_hat => this%v / bias_correction2 )
       if(regularised)then
          ! NOTE(review): the regularise hook above is called AND an L2 term
          ! is applied again below. Confirm that regularise does not already
          ! add the L2 penalty to `gradient`; otherwise it is applied twice.
          select type(regulariser => this%regulariser)
          type is (l2_regulariser_type)
             if(regulariser%decoupled)then
                ! Decoupled weight decay (AdamW): shrink the parameters,
                ! then take the plain Adam step below
                param = param - learning_rate * regulariser%l2 * param
             else
                ! Coupled L2: add the penalty to m_hat, scaled by the
                ! adaptive denominator; this replaces the plain step
                param = param - learning_rate * ( &
                     ( m_hat + regulariser%l2 * param ) / &
                     ( sqrt(v_hat) + this%epsilon ) )
                apply_adam_step = .false.
             end if
          class default
             ! Unknown regulariser type: only the plain Adam step applies
          end select
       end if

       ! Standard Adam step
       if(apply_adam_step) &
            param = param - learning_rate * &
            ( m_hat / (sqrt(v_hat) + this%epsilon) )
    end associate
  end subroutine minimise_adam