set_loss Module Subroutine

module subroutine set_loss(this, loss_method, verbose)

Uses

Set the loss method for the network

Arguments

Type Intent Optional Attributes Name
class(network_type), intent(inout) :: this

Instance of network

class(*), intent(in) :: loss_method

Loss method

integer, intent(in), optional :: verbose

Verbosity level


Source Code

  module subroutine set_loss(this, loss_method, verbose)
    !! Set the loss method for the network.
    !!
    !! `loss_method` may be either:
    !!  - an object of (a type extending) `base_loss_type`, which is copied
    !!    into `this%loss` as-is and whose `name` component is recorded in
    !!    `this%loss_method`; or
    !!  - a character string naming the loss. Matching is case-insensitive and
    !!    ignores leading/trailing blanks. Both the short names
    !!    ("bce", "cce", "mae", "mse", "nll", "hub") and the long-form
    !!    aliases ("binary_crossentropy", "categorical_crossentropy",
    !!    "mean_absolute_error", "mean_squared_error",
    !!    "negative_log_likelihood", "huber") are accepted; the short name
    !!    is recorded in `this%loss_method`.
    !!
    !! An unrecognised loss name, or an argument of any other type, is
    !! reported through `stop_program`; in that case the procedure returns
    !! without updating `this%loss_method`.
    use coreutils, only: to_lower
    use athena__loss, only: &
         bce_loss_type, &
         cce_loss_type, &
         mae_loss_type, &
         mse_loss_type, &
         nll_loss_type, &
         huber_loss_type
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(*), intent(in) :: loss_method
    !! Loss method (loss object or character name)
    integer, optional, intent(in) :: verbose
    !! Verbosity level (default 0 = silent)

    ! Local variables
    integer :: verbose_
    !! Verbosity level
    character(len=:), allocatable :: loss_method_
    !! Normalised (short) loss method name


    if(present(verbose))then
       verbose_ = verbose
    else
       verbose_ = 0
    end if

    !---------------------------------------------------------------------------
    ! Set loss method
    !---------------------------------------------------------------------------
    select type(loss_method)
    class is(base_loss_type)
       this%loss = loss_method
       if(verbose_.gt.0) write(*,*) "Loss method: ", trim(loss_method%name)
       loss_method_ = trim(loss_method%name)
    type is(character(*))
       ! Normalise the user string: case-insensitive, surrounding blanks ignored
       loss_method_ = trim(adjustl(to_lower(loss_method)))

       !------------------------------------------------------------------------
       ! Handle analogous definitions
       !------------------------------------------------------------------------
       ! Map long-form aliases onto the short names. This must test the
       ! normalised string (not the raw argument) so that mixed-case input
       ! such as "Mean_Squared_Error" is also recognised.
       select case(loss_method_)
       case("binary_crossentropy")
          loss_method_ = "bce"
       case("categorical_crossentropy")
          loss_method_ = "cce"
       case("mean_absolute_error")
          loss_method_ = "mae"
       case("mean_squared_error")
          loss_method_ = "mse"
       case("negative_log_likelihood")
          loss_method_ = "nll"
       case("huber")
          loss_method_ = "hub"
       end select

       select case(loss_method_)
       case("bce")
          this%loss = bce_loss_type()
          if(verbose_.gt.0) write(*,*) "Loss method: Binary Cross Entropy"
       case("cce")
          this%loss = cce_loss_type()
          if(verbose_.gt.0) write(*,*) "Loss method: Categorical Cross Entropy"
       case("mae")
          this%loss = mae_loss_type()
          if(verbose_.gt.0) write(*,*) "Loss method: Mean Absolute Error"
       case("mse")
          this%loss = mse_loss_type()
          if(verbose_.gt.0) write(*,*) "Loss method: Mean Squared Error"
       case("nll")
          this%loss = nll_loss_type()
          if(verbose_.gt.0) write(*,*) "Loss method: Negative Log Likelihood"
       case("hub")
          this%loss = huber_loss_type()
          if(verbose_.gt.0) write(*,*) "Loss method: Huber"
       case default
          ! Build the message as a deferred-length expression rather than an
          ! internal write into a fixed buffer, so an over-long user string
          ! cannot trigger an end-of-record runtime error.
          call stop_program( &
               "No loss method provided" // &
               achar(13) // achar(10) // &
               "Failed loss method: " // loss_method_ &
          )
          return
       end select
    class default
       ! Without this branch loss_method_ would stay unallocated and the
       ! assignment below would reference an unallocated allocatable.
       call stop_program( &
            "Unsupported loss method argument: expected a character " // &
            "string or an instance of base_loss_type" &
       )
       return
    end select
    this%loss_method = loss_method_

  end subroutine set_loss