module athena__loss
  !! Module containing loss function implementations
  !!
  !! This module implements loss functions that quantify the difference between
  !! model predictions and target values, guiding the optimisation process.
  !!
  !! Implemented loss functions:
  !!
  !! Mean Squared Error (MSE):
  !!   L = (1/N) Σ (y_pred - y_true)²
  !!   For regression, sensitive to outliers
  !!
  !! Mean Absolute Error (MAE):
  !!   L = (1/N) Σ |y_pred - y_true|
  !!   For regression, robust to outliers
  !!
  !! Binary Cross-Entropy:
  !!   L = -(1/N) Σ [y*log(ŷ) + (1-y)*log(1-ŷ)]
  !!   For binary classification (outputs in [0,1])
  !!
  !! Categorical Cross-Entropy:
  !!   L = -(1/N) Σ_i Σ_c y_{i,c} * log(ŷ_{i,c})
  !!   For multi-class classification with one-hot encoded targets
  !!
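  !! Sparse Categorical Cross-Entropy:
  !!   L = -(1/N) Σ log(ŷ_{i,c_i})
  !!   For multi-class with integer class labels
  !!   (listed for reference; no dedicated type is defined in this module)
  !!
  !! Negative Log-Likelihood (nll_loss_type):
  !!   As implemented, each array element contributes mean( -log(y - ŷ + ε) )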
  !!
  !! Huber Loss:
  !!   L = (1/N) Σ { 0.5*(y-ŷ)²           if |y-ŷ| ≤ δ
  !!               { δ*(|y-ŷ| - 0.5*δ)    otherwise
  !!   Combines MSE and MAE, robust to outliers while smooth near zero
  !!
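  !! where N is the number of samples, y is the true value, ŷ is the
  !! prediction, ε is the `epsilon` component of the loss type and δ is the
  !! Huber threshold (stored as the `gamma` component of huber_loss_type)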
  use coreutils, only: real32
  use diffstruc, only: array_type, operator(+), operator(-), &
       operator(*), operator(/), mean, sum, log, abs, merge, squared
  use athena__diffstruc_extd, only: huber
  implicit none


  private

  public :: base_loss_type
  public :: bce_loss_type
  public :: cce_loss_type
  public :: mae_loss_type
  public :: mse_loss_type
  public :: nll_loss_type
  public :: huber_loss_type


  type, abstract :: base_loss_type
     !! Abstract parent of every loss function in this module.
     !! Concrete types extend it and override the deferred `compute` binding.
     character(len=:), allocatable :: name
     !! Short identifier of the loss function (the constructors in this
     !! module set it to 'bce', 'cce', 'mae', 'mse', 'nll' or 'hub')
     real(real32) :: epsilon = 1.E-10_real32
     !! Small offset added to the argument of log() by the bce, cce and nll
     !! losses to prevent log(0)
     integer :: batch_index = 1
     !! Index of the batch to compute the loss for
     !! (not referenced by any procedure in this module)
     integer :: sample_index = 1
     !! Index of the sample to compute the loss for
     !! (not referenced by any procedure in this module)
   contains
     procedure(compute_base), deferred, pass(this) :: compute
     !! Compute the loss of a model from predicted and expected values;
     !! must be overridden with the `compute_base` interface
  end type base_loss_type

  interface
     module function compute_base(this, predicted, expected) result(output)
       !! Interface shared by every `compute` binding of the loss types:
       !! takes predicted and expected values and returns a pointer to the loss
       class(base_loss_type), intent(in), target :: this
       !! Instance of the loss function
       type(array_type), dimension(:,:), intent(inout), target :: predicted
       !! Predicted values
       type(array_type), dimension(size(predicted,1),size(predicted,2)), intent(in) :: &
            expected
       !! Expected values; declared with the same extents as `predicted`
       type(array_type), pointer :: output
       !! Loss value
     end function compute_base
  end interface

!-------------------------------------------------------------------------------

  type, extends(base_loss_type) :: bce_loss_type
     !! Binary cross entropy loss function
   contains
     procedure :: compute => compute_bce
     !! Compute the loss of a model; `compute_bce` sums
     !! mean(-expected * log(predicted + epsilon), dim=2) over the elements.
     !! NOTE(review): the (1-y)*log(1-ŷ) term of the module-header formula
     !! is not evaluated -- confirm this is intended
  end type bce_loss_type

  interface bce_loss_type
     !! Generic constructor for the binary cross entropy loss function
     module function setup_loss_bce() result(loss)
       !! Set up binary cross entropy loss function (takes no arguments)
       type(bce_loss_type) :: loss
       !! Binary cross entropy loss function with name set to 'bce'
     end function setup_loss_bce
  end interface bce_loss_type

!-------------------------------------------------------------------------------

  type, extends(base_loss_type) :: cce_loss_type
     !! Categorical cross entropy loss function
   contains
     procedure :: compute => compute_cce
     !! Compute the loss of a model; `compute_cce` sums
     !! -mean(sum(expected * log(predicted + epsilon), dim=1), dim=2)
     !! over the elements
  end type cce_loss_type

  interface cce_loss_type
     !! Generic constructor for the categorical cross entropy loss function
     module function setup_loss_cce() result(loss)
       !! Set up categorical cross entropy loss function (takes no arguments)
       type(cce_loss_type) :: loss
       !! Categorical cross entropy loss function with name set to 'cce'
     end function setup_loss_cce
  end interface cce_loss_type

!-------------------------------------------------------------------------------

  type, extends(base_loss_type) :: mae_loss_type
     !! Mean absolute error loss function
   contains
     procedure :: compute => compute_mae
     !! Compute the loss of a model; `compute_mae` sums
     !! mean(abs(predicted - expected)) / 2 over the elements.
     !! NOTE(review): the factor 1/2 is not part of the module-header
     !! formula -- confirm this scaling is intended
  end type mae_loss_type

  interface mae_loss_type
     !! Generic constructor for the mean absolute error loss function
     module function setup_loss_mae() result(loss)
       !! Set up mean absolute error loss function (takes no arguments)
       type(mae_loss_type) :: loss
       !! Mean absolute error loss function with name set to 'mae'
     end function setup_loss_mae
  end interface mae_loss_type

!-------------------------------------------------------------------------------

  type, extends(base_loss_type) :: mse_loss_type
     !! Mean squared error loss function
   contains
     procedure :: compute => compute_mse
     !! Compute the loss of a model; `compute_mse` sums
     !! mean(squared(predicted - expected)) / 2 over the elements
  end type mse_loss_type

  interface mse_loss_type
     !! Generic constructor for the mean squared error loss function
     module function setup_loss_mse() result(loss)
       !! Set up mean squared error loss function (takes no arguments)
       type(mse_loss_type) :: loss
       !! Mean squared error loss function with name set to 'mse'
     end function setup_loss_mse
  end interface mse_loss_type

!-------------------------------------------------------------------------------

  type, extends(base_loss_type) :: nll_loss_type
     !! Negative log likelihood loss function
   contains
     procedure :: compute => compute_nll
     !! Compute the loss of a model; `compute_nll` sums
     !! mean(-log(expected - predicted + epsilon)) over the elements
  end type nll_loss_type

  interface nll_loss_type
     !! Generic constructor for the negative log likelihood loss function
     module function setup_loss_nll() result(loss)
       !! Set up negative log likelihood loss function (takes no arguments)
       type(nll_loss_type) :: loss
       !! Negative log likelihood loss function with name set to 'nll'
     end function setup_loss_nll
  end interface nll_loss_type

!-------------------------------------------------------------------------------

  type, extends(base_loss_type) :: huber_loss_type
     !! Huber loss function
     real(real32) :: gamma = 1._real32
     !! Second argument passed to `huber` by `compute_huber`; presumably the
     !! threshold written as δ in the module-header formula. The constructor
     !! takes no arguments, so a value other than the default of 1 has to be
     !! assigned on the instance after construction
   contains
     procedure :: compute => compute_huber
     !! Compute the loss of a model; `compute_huber` sums
     !! mean(huber(predicted - expected, gamma)) over the elements
  end type huber_loss_type

  interface huber_loss_type
     !! Generic constructor for the huber loss function
     module function setup_loss_huber() result(loss)
       !! Set up huber loss function (takes no arguments)
       type(huber_loss_type) :: loss
       !! Huber loss function with name set to 'hub'
     end function setup_loss_huber
  end interface huber_loss_type

!-------------------------------------------------------------------------------



contains
!###############################################################################
  module function setup_loss_bce() result(loss)
    !! Construct a binary cross entropy loss object.
    !!
    !! Every component keeps its default initialisation; only the
    !! identifying name is assigned.
    implicit none

    ! Result
    type(bce_loss_type) :: loss
    !! Newly constructed binary cross entropy loss

    ! Constants
    character(len=*), parameter :: loss_name = 'bce'
    !! Identifier stored in the name component

    loss%name = loss_name

  end function setup_loss_bce
!-------------------------------------------------------------------------------
  module function setup_loss_cce() result(loss)
    !! Construct a categorical cross entropy loss object.
    !!
    !! Every component keeps its default initialisation; only the
    !! identifying name is assigned.
    implicit none

    ! Result
    type(cce_loss_type) :: loss
    !! Newly constructed categorical cross entropy loss

    ! Constants
    character(len=*), parameter :: loss_name = 'cce'
    !! Identifier stored in the name component

    loss%name = loss_name

  end function setup_loss_cce
!-------------------------------------------------------------------------------
  module function setup_loss_mae() result(loss)
    !! Construct a mean absolute error loss object.
    !!
    !! Every component keeps its default initialisation; only the
    !! identifying name is assigned.
    implicit none

    ! Result
    type(mae_loss_type) :: loss
    !! Newly constructed mean absolute error loss

    ! Constants
    character(len=*), parameter :: loss_name = 'mae'
    !! Identifier stored in the name component

    loss%name = loss_name

  end function setup_loss_mae
!-------------------------------------------------------------------------------
  module function setup_loss_mse() result(loss)
    !! Construct a mean squared error loss object.
    !!
    !! Every component keeps its default initialisation; only the
    !! identifying name is assigned.
    implicit none

    ! Result
    type(mse_loss_type) :: loss
    !! Newly constructed mean squared error loss

    ! Constants
    character(len=*), parameter :: loss_name = 'mse'
    !! Identifier stored in the name component

    loss%name = loss_name

  end function setup_loss_mse
!-------------------------------------------------------------------------------
  module function setup_loss_nll() result(loss)
    !! Construct a negative log likelihood loss object.
    !!
    !! Every component keeps its default initialisation; only the
    !! identifying name is assigned.
    implicit none

    ! Result
    type(nll_loss_type) :: loss
    !! Newly constructed negative log likelihood loss

    ! Constants
    character(len=*), parameter :: loss_name = 'nll'
    !! Identifier stored in the name component

    loss%name = loss_name

  end function setup_loss_nll
!-------------------------------------------------------------------------------
  module function setup_loss_huber() result(loss)
    !! Construct a huber loss object.
    !!
    !! Every component (including gamma) keeps its default initialisation;
    !! only the identifying name is assigned.
    implicit none

    ! Result
    type(huber_loss_type) :: loss
    !! Newly constructed huber loss

    ! Constants
    character(len=*), parameter :: loss_name = 'hub'
    !! Identifier stored in the name component

    loss%name = loss_name

  end function setup_loss_huber
!###############################################################################


!###############################################################################
  function compute_bce(this, predicted, expected) result(output)
    !! Evaluate the binary cross entropy loss.
    !!
    !! Each element contributes
    !! `mean(-expected * log(predicted + epsilon), dim=2)` and the
    !! contributions are added together. Element (1,1) always seeds the sum;
    !! any other element is skipped unless both its predicted and expected
    !! entries are flagged as allocated.
    !! NOTE(review): only the y*log(ŷ) term is evaluated; the
    !! (1-y)*log(1-ŷ) term of the module-header formula is absent -- confirm
    !! this is intended.
    implicit none

    ! Arguments
    class(bce_loss_type), intent(in), target :: this
    !! Binary cross entropy loss instance (supplies epsilon)
    type(array_type), dimension(:,:), intent(inout), target :: predicted
    !! Predicted values
    type(array_type), dimension(size(predicted,1),size(predicted,2)), intent(in) :: &
         expected
    !! Expected values
    type(array_type), pointer :: output
    !! Accumulated binary cross entropy loss

    ! Local variables
    integer :: irow, icol
    !! Indices along the first and second dimension of the inputs
    type(array_type), pointer :: term
    !! Contribution of a single element

    ! Seed the accumulator with the first element
    output => mean(-expected(1,1) * log(predicted(1,1) + this%epsilon), dim=2)

    ! Add the contribution of every remaining usable element
    do icol = 1, size(predicted, 2)
       do irow = 1, size(predicted, 1)
          ! Element (1,1) is already part of the accumulator
          if (irow == 1 .and. icol == 1) cycle
          if (predicted(irow,icol)%allocated .and. &
               expected(irow,icol)%allocated) then
             term => mean( &
                  -expected(irow,icol) * &
                  log(predicted(irow,icol) + this%epsilon), dim=2)
             output => output + term
          end if
       end do
    end do

  end function compute_bce
!###############################################################################


!###############################################################################
  function compute_cce(this, predicted, expected) result(output)
    !! Evaluate the categorical cross entropy loss.
    !!
    !! Each element contributes
    !! `-mean(sum(expected * log(predicted + epsilon), dim=1), dim=2)`.
    !! Element (1,1) always seeds the result; any other element is skipped
    !! unless both its predicted and expected entries are flagged as
    !! allocated, and is otherwise subtracted (un-negated) from the result.
    implicit none

    ! Arguments
    class(cce_loss_type), intent(in), target :: this
    !! Categorical cross entropy loss instance (supplies epsilon)
    type(array_type), dimension(:,:), intent(inout), target :: predicted
    !! Predicted values
    type(array_type), dimension(size(predicted,1),size(predicted,2)), intent(in) :: &
         expected
    !! Expected values
    type(array_type), pointer :: output
    !! Accumulated categorical cross entropy loss

    ! Local variables
    integer :: irow, icol
    !! Indices along the first and second dimension of the inputs
    type(array_type), pointer :: term
    !! Un-negated contribution of a single element

    ! Seed the accumulator with the (negated) first element
    output => -mean( sum( &
         expected(1,1) * log(predicted(1,1) + this%epsilon), &
         dim=1 ), dim=2)

    ! Subtract the un-negated contribution of every remaining usable element
    do icol = 1, size(predicted, 2)
       do irow = 1, size(predicted, 1)
          ! Element (1,1) is already part of the accumulator
          if (irow == 1 .and. icol == 1) cycle
          if (predicted(irow,icol)%allocated .and. &
               expected(irow,icol)%allocated) then
             term => mean( sum( &
                  expected(irow,icol) * &
                  log(predicted(irow,icol) + this%epsilon), &
                  dim=1 ), dim=2)
             output => output - term
          end if
       end do
    end do

  end function compute_cce
!###############################################################################


!###############################################################################
  function compute_mae(this, predicted, expected) result(output)
    !! Evaluate the mean absolute error loss.
    !!
    !! Each element contributes `mean(abs(predicted - expected)) / 2` and the
    !! contributions are added together. Element (1,1) always seeds the sum;
    !! any other element is skipped unless both its predicted and expected
    !! entries are flagged as allocated.
    !! NOTE(review): the factor 1/2 is not part of the module-header MAE
    !! formula -- confirm this scaling is intended.
    implicit none

    ! Arguments
    class(mae_loss_type), intent(in), target :: this
    !! Mean absolute error loss instance
    type(array_type), dimension(:,:), intent(inout), target :: predicted
    !! Predicted values
    type(array_type), dimension(size(predicted,1),size(predicted,2)), intent(in) :: &
         expected
    !! Expected values
    type(array_type), pointer :: output
    !! Accumulated mean absolute error

    ! Local variables
    integer :: irow, icol
    !! Indices along the first and second dimension of the inputs
    type(array_type), pointer :: term
    !! Contribution of a single element

    ! Seed the accumulator with the first element
    output => mean( abs( predicted(1,1) - expected(1,1) ) ) / 2._real32

    ! Add the contribution of every remaining usable element
    do icol = 1, size(predicted, 2)
       do irow = 1, size(predicted, 1)
          ! Element (1,1) is already part of the accumulator
          if (irow == 1 .and. icol == 1) cycle
          if (predicted(irow,icol)%allocated .and. &
               expected(irow,icol)%allocated) then
             term => mean( &
                  abs( predicted(irow,icol) - expected(irow,icol) ) &
                  ) / 2._real32
             output => output + term
          end if
       end do
    end do

  end function compute_mae
!###############################################################################


!###############################################################################
  function compute_mse(this, predicted, expected) result(output)
    !! Evaluate the mean squared error loss.
    !!
    !! Each element contributes `mean(squared(predicted - expected)) / 2` and
    !! the contributions are added together. Element (1,1) always seeds the
    !! sum; any other element is skipped unless both its predicted and
    !! expected entries are flagged as allocated.
    implicit none

    ! Arguments
    class(mse_loss_type), intent(in), target :: this
    !! Mean squared error loss instance
    type(array_type), dimension(:,:), intent(inout), target :: predicted
    !! Predicted values
    type(array_type), dimension(size(predicted,1),size(predicted,2)), intent(in) :: &
         expected
    !! Expected values
    type(array_type), pointer :: output
    !! Accumulated mean squared error loss

    ! Local variables
    integer :: irow, icol
    !! Indices along the first and second dimension of the inputs
    type(array_type), pointer :: term
    !! Contribution of a single element

    ! Seed the accumulator with the first element
    output => mean( squared( predicted(1,1) - expected(1,1) ) ) / 2._real32

    ! Add the contribution of every remaining usable element
    do icol = 1, size(predicted, 2)
       do irow = 1, size(predicted, 1)
          ! Element (1,1) is already part of the accumulator
          if (irow == 1 .and. icol == 1) cycle
          if (predicted(irow,icol)%allocated .and. &
               expected(irow,icol)%allocated) then
             term => mean( &
                  squared( predicted(irow,icol) - expected(irow,icol) ) &
                  ) / 2._real32
             output => output + term
          end if
       end do
    end do

  end function compute_mse
!###############################################################################


!###############################################################################
  function compute_nll(this, predicted, expected) result(output)
    !! Evaluate the negative log likelihood loss.
    !!
    !! Each element contributes `mean(-log(expected - predicted + epsilon))`
    !! and the contributions are added together. Element (1,1) always seeds
    !! the sum; any other element is skipped unless both its predicted and
    !! expected entries are flagged as allocated.
    implicit none

    ! Arguments
    class(nll_loss_type), intent(in), target :: this
    !! Negative log likelihood loss instance (supplies epsilon)
    type(array_type), dimension(:,:), intent(inout), target :: predicted
    !! Predicted values
    type(array_type), dimension(size(predicted,1),size(predicted,2)), intent(in) :: &
         expected
    !! Expected values
    type(array_type), pointer :: output
    !! Accumulated negative log likelihood loss

    ! Local variables
    integer :: irow, icol
    !! Indices along the first and second dimension of the inputs
    type(array_type), pointer :: term
    !! Contribution of a single element

    ! Seed the accumulator with the first element
    output => mean(-log(expected(1,1) - predicted(1,1) + this%epsilon) )

    ! Add the contribution of every remaining usable element
    do icol = 1, size(predicted, 2)
       do irow = 1, size(predicted, 1)
          ! Element (1,1) is already part of the accumulator
          if (irow == 1 .and. icol == 1) cycle
          if (predicted(irow,icol)%allocated .and. &
               expected(irow,icol)%allocated) then
             term => mean(-log( &
                  expected(irow,icol) - predicted(irow,icol) + this%epsilon) )
             output => output + term
          end if
       end do
    end do

  end function compute_nll
!###############################################################################


!###############################################################################
  function compute_huber(this, predicted, expected) result(output)
    !! Compute the huber loss of a model
    !!
    !! For each element the residual `predicted - expected` is passed to
    !! `huber` together with `this%gamma`, reduced with `mean`, and the
    !! per-element results are added together. Element (1,1) always seeds the
    !! sum; any other element is skipped unless both its predicted and
    !! expected entries are flagged as allocated.
    implicit none

    ! Arguments
    class(huber_loss_type), intent(in), target :: this
    !! Instance of the huber loss function (supplies gamma)
    type(array_type), dimension(:,:), intent(inout), target :: predicted
    !! Predicted values
    type(array_type), dimension(size(predicted,1),size(predicted,2)), intent(in) :: &
         expected
    !! Expected values
    type(array_type), pointer :: output
    !! Huber loss

    ! Local variables
    integer :: s, i
    !! Loop indices
    type(array_type), pointer :: ptr
    !! Residual (predicted - expected) of the element being processed

    ! Build the residual of the first element once and reuse it, exactly as
    ! the loop below does. Evaluating the subtraction a second time inside
    ! the huber() call would leave this first result unused.
    ptr => predicted(1,1) - expected(1,1)
    output => mean( huber(ptr, this%gamma) )

    ! Accumulate the remaining elements; (1,1) is skipped because it already
    ! seeded the result, so no separate single-element guard is needed.
    do s = 1, size(predicted,2)
       do i = 1, size(predicted,1)
          if(i.eq.1 .and. s.eq.1) cycle
          if(.not.predicted(i,s)%allocated .or. &
               .not.expected(i,s)%allocated) cycle
          ptr => predicted(i,s) - expected(i,s)

          output => output + mean( huber(ptr, this%gamma) )
       end do
    end do

    ! NOTE(review): `huber` is imported from athena__diffstruc_extd. An
    ! earlier, disabled draft evaluated the element-wise form directly as
    !   merge( 0.5*ptr**2, gamma*(abs(ptr) - 0.5*gamma), abs(ptr) <= gamma )
    ! which is presumably what `huber` computes -- confirm against that module.

  end function compute_huber
!###############################################################################


!###############################################################################
  module function compute_base(this, predicted, expected) result(output)
    !! Placeholder body for the `compute_base` interface used by the deferred
    !! `compute` binding of base_loss_type.
    !!
    !! It performs no calculation and does not read its arguments. The result
    !! is returned disassociated so that a caller can test it with
    !! `associated()` instead of receiving a pointer with undefined
    !! association status.
    implicit none

    ! Arguments
    class(base_loss_type), intent(in), target :: this
    !! Instance of the base loss function (unused)
    type(array_type), dimension(:,:), intent(inout), target :: predicted
    !! Predicted values (unused)
    type(array_type), dimension(size(predicted,1),size(predicted,2)), intent(in) :: &
         expected
    !! Expected values (unused)
    type(array_type), pointer :: output
    !! Loss value; always returned disassociated

    ! Give the pointer result a defined (disassociated) status
    output => null()

  end function compute_base
!###############################################################################

end module athena__loss
