metric_dict_check Subroutine

private subroutine metric_dict_check(this, plateau_threshold, converged)

Check if the metric has converged

Type Bound

metric_dict_type

Arguments

Type | Intent | Optional | Attributes | Name
class(metric_dict_type), intent(inout) :: this

Instance of metric data

real(kind=real32), intent(in) :: plateau_threshold

Threshold for plateau

integer, intent(out) :: converged

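Convergence status flag: 0 = not converged, 1 = threshold reached (converged), -1 = metric has plateaued (remained constant within the plateau threshold)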


Source Code

  subroutine metric_dict_check(this,plateau_threshold,converged)
    !! Check if the metric has converged
    !!
    !! Examines the most recent entries of the metric history (at most
    !! `this%window_width` of them, fewer if fewer have been recorded)
    !! and reports the outcome through `converged`:
    !!
    !! -  0 : no decision (also returned when the metric is inactive, or
    !!        when the window is empty and `stop_program` returns),
    !! -  1 : the windowed mean has dropped below `this%threshold`
    !!        (mean of the values for key "loss", mean of one minus the
    !!        values for key "accuracy"),
    !! - -1 : every entry in the window lies within `plateau_threshold`
    !!        of `this%val`, i.e. the metric has stopped changing.
    !!
    !! Keys other than "loss" and "accuracy" are only tested for a plateau.
    use, intrinsic :: iso_fortran_env, only: output_unit, error_unit
    implicit none

    ! Arguments
    class(metric_dict_type), intent(inout) :: this
    !! Instance of metric data
    real(real32), intent(in) :: plateau_threshold
    !! Threshold for plateau
    integer, intent(out) :: converged
    !! Convergence status: 0 = undecided, 1 = converged, -1 = plateaued

    ! Local variables
    integer :: window_width
    !! Width of the convergence check window
    integer :: window_ubound, window_lbound
    !! Upper and lower bounds of the window
    logical :: threshold_reached
    !! Whether the windowed mean satisfies the convergence threshold

    converged = 0

    ! Never look further back than the number of recorded entries.
    ! NOTE(review): while fewer than this%window_width entries exist, both
    ! checks below run on a shortened window - confirm this is intended.
    window_width = min(this%window_width, this%num_entries)
    if(window_width .le. 0)then
       call stop_program("Window width is zero or negative")
       ! stop_program may return control; bail out with converged = 0
       return
    end if

    ! Inactive metrics are never tested
    if(.not.this%active) return

    window_ubound = this%num_entries
    window_lbound = window_ubound - window_width + 1

    associate( window => this%history(window_lbound:window_ubound) )

      ! Threshold test on the windowed mean; the quantity averaged
      ! depends on which metric this entry represents
      select case(trim(this%key))
      case("loss")
         threshold_reached = &
              abs( sum( window ) ) / window_width .lt. this%threshold
      case("accuracy")
         threshold_reached = &
              abs( sum( 1._real32 - window ) ) / window_width .lt. &
              this%threshold
      case default
         threshold_reached = .false.
      end select

      if(threshold_reached)then
         write(output_unit,*) &
              "Convergence achieved, "//trim(this%key)//" threshold reached"
         write(output_unit,*) "Exiting training loop"
         converged = 1
      elseif( all( abs(window - this%val) .lt. plateau_threshold ) )then
         ! Report the number of runs and the values that were actually
         ! inspected (the window), not the full extent of the history array
         write(error_unit, &
              '("ERROR: ",A," has remained constant for ",I0," runs")') &
              trim(this%key), window_width
         write(error_unit,*) window
         write(error_unit,*) "Exiting..."
         converged = -1
      end if

    end associate

  end subroutine metric_dict_check