Apply gradients to parameters to minimise loss using Adam optimiser
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| class(adam_optimiser_type) | intent(inout) | | | this | Instance of the Adam optimiser |
| real(kind=real32) | intent(inout) | | dimension(:) | param | Parameters |
| real(kind=real32) | intent(inout) | | dimension(:) | gradient | Gradients |
pure subroutine minimise_adam(this, param, gradient)
  !! Apply gradients to parameters to minimise loss using Adam optimiser
  !!
  !! Updates the moment estimates stored in `this%m` and `this%v` from
  !! `gradient`, divides them by the bias corrections
  !! `1 - beta**this%iter`, and then updates `param` in place.
  !! When an L2 regulariser is attached, the penalty is applied either
  !! as a separate decay step (decoupled) or inside the numerator of the
  !! adaptive step (classical); any other case uses the plain Adam step.
  !!
  !! This procedure reads `this%iter` but does not modify it.
  implicit none

  ! Arguments
  class(adam_optimiser_type), intent(inout) :: this
  !! Instance of the Adam optimiser
  real(real32), dimension(:), intent(inout) :: param
  !! Parameters
  real(real32), dimension(:), intent(inout) :: gradient
  !! Gradients

  ! Local variables
  real(real32) :: learning_rate
  !! Learning rate returned by the decay scheme for this iteration
  real(real32) :: bias_correction1
  !! Bias correction applied to the first moment estimate
  real(real32) :: bias_correction2
  !! Bias correction applied to the second moment estimate
  logical :: has_regulariser
  !! True when regularisation is enabled and a regulariser is allocated

  ! Get the decayed learning rate for the current iteration
  learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter)

  ! A regulariser may only be referenced if it is actually allocated;
  ! the same flag guards both the regularise call and the update below
  has_regulariser = this%regularisation .and. allocated(this%regulariser)

  ! Apply regularisation
  if(has_regulariser) &
       call this%regulariser%regularise( param, gradient, learning_rate )

  ! Adaptive learning method: moving averages of the gradient and of
  ! the squared gradient
  this%m = this%beta1 * this%m + &
       (1._real32 - this%beta1) * gradient
  this%v = this%beta2 * this%v + &
       (1._real32 - this%beta2) * gradient * gradient

  ! Bias corrections
  ! NOTE(review): both are zero when this%iter is 0, which would make
  ! m_hat and v_hat divide by zero - presumably the caller advances
  ! this%iter before calling; confirm
  bias_correction1 = 1._real32 - this%beta1**this%iter
  bias_correction2 = 1._real32 - this%beta2**this%iter

  ! Update parameters
  associate( &
       m_hat => this%m / bias_correction1, &
       v_hat => this%v / bias_correction2 )
     if(has_regulariser)then
        select type(regulariser => this%regulariser)
        type is (l2_regulariser_type)
           if(regulariser%decoupled)then
              ! decoupled weight decay (AdamW): decay the parameters
              ! first, then take the plain Adam step
              param = param - learning_rate * regulariser%l2 * param
              param = param - learning_rate * ( &
                   m_hat / (sqrt(v_hat) + this%epsilon) )
           else
              ! classical L2 regularisation (included in gradient)
              ! NOTE(review): regularise() above was also handed the
              ! gradient - confirm the L2 term is not applied twice
              param = param - learning_rate * ( &
                   ( m_hat + regulariser%l2 * param ) / &
                   ( sqrt(v_hat) + this%epsilon ) &
                   )
           end if
        class default
           ! unknown regulariser - fall back to standard Adam
           param = param - learning_rate * ( &
                m_hat / (sqrt(v_hat) + this%epsilon) )
        end select
     else
        ! no regularisation - standard Adam
        param = param - learning_rate * ( &
             m_hat / (sqrt(v_hat) + this%epsilon) )
     end if
  end associate

end subroutine minimise_adam