Set the loss method for the network
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| class(network_type) | intent(inout) | | | this | Instance of network |
| class(*) | intent(in) | | | loss_method | Loss method |
| integer | intent(in) | optional | | verbose | Verbosity level |
module subroutine set_loss(this, loss_method, verbose)
  !! Set the loss method for the network
  !!
  !! `loss_method` is either an object extending `base_loss_type`, which is
  !! copied into `this%loss` as-is, or a character string naming a built-in
  !! loss. Strings are passed through `to_lower`, stripped of surrounding
  !! blanks, and may be given as a long name (e.g. "mean_squared_error") or
  !! as the matching short code (e.g. "mse").
  !!
  !! On success `this%loss` and `this%loss_method` are both updated. If the
  !! string is not recognised, or `loss_method` has an unsupported type,
  !! `stop_program` is called and the routine returns without touching
  !! `this%loss_method`.
  use coreutils, only: to_lower
  use athena__loss, only: &
       bce_loss_type, &
       cce_loss_type, &
       mae_loss_type, &
       mse_loss_type, &
       nll_loss_type, &
       huber_loss_type
  implicit none

  ! Arguments
  class(network_type), intent(inout) :: this
  !! Instance of network
  class(*), intent(in) :: loss_method
  !! Loss method
  integer, optional, intent(in) :: verbose
  !! Verbosity level

  ! Local variables
  integer :: verbose_
  !! Verbosity level
  character(len=:), allocatable :: loss_method_
  !! Loss method
  character(len=:), allocatable :: err_msg
  !! Error message (deferred length so long names cannot overflow it)


  if(present(verbose))then
     verbose_ = verbose
  else
     verbose_ = 0
  end if


  select type(loss_method)
  class is(base_loss_type)
     ! A ready-made loss object: store it directly.
     this%loss = loss_method
     if(verbose_.gt.0) write(*,*) "Loss method: ", trim(loss_method%name)
     loss_method_ = trim(loss_method%name)

  type is(character(*))
     ! Normalise once; every comparison below uses the normalised copy.
     loss_method_ = trim(adjustl(to_lower(loss_method)))

     !-----------------------------------------------------------------------
     ! Handle analogous definitions
     !-----------------------------------------------------------------------
     ! Dispatch on the lower-cased string (not the raw argument) so that
     ! long names are recognised regardless of the caller's capitalisation.
     select case(loss_method_)
     case("binary_crossentropy")
        loss_method_ = "bce"
     case("categorical_crossentropy")
        loss_method_ = "cce"
     case("mean_absolute_error")
        loss_method_ = "mae"
     case("mean_squared_error")
        loss_method_ = "mse"
     case("negative_log_likelihood")
        loss_method_ = "nll"
     case("huber")
        loss_method_ = "hub"
     end select

     !-----------------------------------------------------------------------
     ! Set loss method
     !-----------------------------------------------------------------------
     select case(loss_method_)
     case("bce")
        this%loss = bce_loss_type()
        if(verbose_.gt.0) write(*,*) "Loss method: Binary Cross Entropy"
     case("cce")
        this%loss = cce_loss_type()
        if(verbose_.gt.0) write(*,*) "Loss method: Categorical Cross Entropy"
     case("mae")
        this%loss = mae_loss_type()
        if(verbose_.gt.0) write(*,*) "Loss method: Mean Absolute Error"
     case("mse")
        this%loss = mse_loss_type()
        if(verbose_.gt.0) write(*,*) "Loss method: Mean Squared Error"
     case("nll")
        this%loss = nll_loss_type()
        if(verbose_.gt.0) write(*,*) "Loss method: Negative Log Likelihood"
     case("hub")
        this%loss = huber_loss_type()
        if(verbose_.gt.0) write(*,*) "Loss method: Huber"
     case default
        err_msg = &
             "No loss method provided" // &
             achar(13) // achar(10) // &
             "Failed loss method: "//trim(loss_method_)
        call stop_program(trim(err_msg))
        ! stop_program is defined elsewhere; return in case it hands
        ! control back to the caller instead of terminating.
        return
     end select

  class default
     ! Neither a loss object nor a string: loss_method_ was never set, so
     ! bail out before it is read below.
     call stop_program( &
          "Unsupported type for loss_method; expected a character string "// &
          "or an object extending base_loss_type" &
     )
     return
  end select

  this%loss_method = loss_method_

end subroutine set_loss