module athena__metrics
  !! Module containing functions to compute the accuracy of a model
  !!
  !! This module contains a derived type for storing and handling metric data
  use coreutils, only: real32, stop_program
  implicit none


  private

  public :: metric_dict_type
  public :: metric_dict_alloc


  type :: metric_dict_type
     !! Type for storing and handling metric data
     character(10) :: key
     !! Key for the metric
     real(real32) :: val
     !! Value of the metric
     logical :: active
     !! Flag to indicate if the metric is active
     real(real32) :: threshold
     !! Threshold for the metric
     integer :: window_width
     !! Window width for checking convergence
     integer :: num_entries
     !! Number of entries in the history
     real(real32), allocatable, dimension(:) :: history
     !! History of the metric
   contains
     procedure :: check => metric_dict_check
     !! Check if the metric has converged
     procedure :: add_t_t => metric_dict_add
     !! Add two metric_dict_type together
     procedure :: append => append_value
     !! Append a value to the history of the metric
     generic :: operator(+) => add_t_t
     !! Overload the addition operator
  end type metric_dict_type



contains

!###############################################################################
  elemental function metric_dict_add(a, b) result(output)
    !! Operation to add two metric_dict_type together
    implicit none

    ! Arguments
    class(metric_dict_type), intent(in) :: a,b
    !! Instances of metric data
    type(metric_dict_type) :: output
    !! Sum of the metric data

    output%key = a%key
    output%val = a%val + b%val
    output%threshold = a%threshold
    output%active = a%active
    if(allocated(a%history)) output%history = a%history
    output%num_entries = a%num_entries

  end function metric_dict_add
!###############################################################################


!###############################################################################
  subroutine metric_dict_alloc(input, source, length)
    !! Allocate memory for a metric_dict_type
    implicit none

    ! Arguments
    type(metric_dict_type), dimension(:), intent(out) :: input
    !! Instance of metric data
    type(metric_dict_type), dimension(:), optional, intent(in) :: source
    !! Source of the metric data to copy
    integer, optional, intent(in) :: length
    !! Length of the metric data

    ! Local variables
    integer :: i
    !! Loop index


    if(present(length))then
       do i=1,size(input,dim=1)
          allocate(input(i)%history(length))
       end do
    else
       if(present(source))then
          do i=1, size(input,dim=1)
             input(i)%key = source(i)%key
             allocate(input(i)%history(size(source(i)%history,dim=1)))
             input(i)%threshold = source(i)%threshold
          end do
       else
          call stop_program( &
               "metric_dict_alloc requires either a source or length" &
          )
       end if
    end if
    input%num_entries = 0

  end subroutine metric_dict_alloc
!###############################################################################


!###############################################################################
  subroutine append_value(this, value)
    !! Append a value to the history of the metric
    implicit none

    ! Arguments
    class(metric_dict_type), intent(inout) :: this
    !! Instance of metric data
    real(real32), intent(in) :: value
    !! Value to append

    ! Local variables
    integer :: new_size

    this%val = value
    if(.not.allocated(this%history))then
       allocate(this%history(this%window_width), source = -huge(1._real32))
       this%history(this%window_width) = value
       this%num_entries = 0
    elseif(this%num_entries .lt. this%window_width)then
       this%history(this%num_entries) = value
    else
       this%history = [ this%history, value ]
    end if
    this%num_entries = this%num_entries + 1

  end subroutine append_value
!###############################################################################


!###############################################################################
  subroutine metric_dict_check(this,plateau_threshold,converged)
    !! Check if the metric has converged
    implicit none

    ! Arguments
    class(metric_dict_type), intent(inout) :: this
    !! Instance of metric data
    real(real32), intent(in) :: plateau_threshold
    !! Threshold for plateau
    integer, intent(out) :: converged
    !! Boolean whether the metric has converged

    ! Local variables
    integer :: window_width
    !! Width of the convergence check window
    integer :: window_ubound, window_lbound
    !! Upper and lower bounds of the window

    converged = 0
    window_width = min(this%window_width, this%num_entries)
    if(window_width .le. 0)then
       call stop_program("Window width is zero or negative")
       return
    end if
    window_ubound = this%num_entries
    window_lbound = window_ubound - window_width + 1
    if(this%active)then
       if( &
            ( &
                 trim(this%key).eq."loss".and.&
                 abs( &
                      sum( this%history(window_lbound:window_ubound) ) &
                 ) / window_width.lt.&
                 this%threshold &
            ) .or. &
            ( &
                 trim(this%key).eq."accuracy".and.&
                 abs( &
                      sum( 1._real32 - this%history(window_lbound:window_ubound) ) &
                 ) / window_width.lt.&
                 this%threshold &
            ) &
       )then
          write(6,*) &
               "Convergence achieved, "//trim(this%key)//" threshold reached"
          write(6,*) "Exiting training loop"
          converged = 1
       elseif( &
            all( abs(this%history(window_lbound:window_ubound) - this%val) .lt. &
                 plateau_threshold &
            ) &
       )then
          write(0,'("ERROR: ",A," has remained constant for ",I0," runs")') &
               trim(this%key), size(this%history,dim=1)
          write(0,*) this%history
          write(0,*) "Exiting..."
          converged = -1
       end if
    end if

  end subroutine metric_dict_check
!###############################################################################

end module athena__metrics
