set_accuracy Module Subroutine

module subroutine set_accuracy(this, accuracy_method, verbose)

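Uses: coreutils (to_lower); athena__accuracy (categorical_score, mae_score, mse_score, rmse_score, r2_score)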

Set the accuracy method for the network.

Arguments

Type | Intent | Optional | Attributes | Name
class(network_type), intent(inout) :: this

Instance of network

character(len=*), intent(in) :: accuracy_method

Name of the accuracy method, e.g. "cat" (categorical), "mae", "mse", "rmse" or "r2"

integer, intent(in), optional :: verbose

Verbosity level (defaults to 0; values greater than 0 print the selected method)


Source Code

  module subroutine set_accuracy(this, accuracy_method, verbose)
    !! Set the accuracy method for the network.
    !!
    !! The method name is matched case-insensitively, ignoring leading and
    !! trailing blanks. Accepted names (short name first, then aliases):
    !!
    !! - "cat"  | "categorical"
    !! - "mae"  | "mean_absolute_error"
    !! - "mse"  | "mean_squared_error"
    !! - "rmse" | "root_mean_squared_error"
    !! - "r2"   | "r^2" | "r squared"
    !!
    !! On success, `this%get_accuracy` is associated with the matching score
    !! procedure from `athena__accuracy`, and `this%accuracy_method` is set to
    !! the short name. An unrecognised name is reported through
    !! `stop_program`. If that call returns, this routine returns without
    !! modifying `this`.
    use coreutils, only: to_lower
    use athena__accuracy, only: &
         categorical_score, &
         mae_score, &
         mse_score, &
         rmse_score, &
         r2_score
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    character(*), intent(in) :: accuracy_method
    !! Accuracy method name (case-insensitive; see above for accepted values)
    integer, optional, intent(in) :: verbose
    !! Verbosity level (default 0); values > 0 print the selected method

    ! Local variables
    integer :: verbose_
    !! Verbosity level
    character(len=:), allocatable :: accuracy_method_
    !! Normalised (lower-case, blank-stripped) accuracy method name


    if(present(verbose))then
       verbose_ = verbose
    else
       verbose_ = 0
    end if

    !---------------------------------------------------------------------------
    ! Handle analogous definitions
    !---------------------------------------------------------------------------
    ! Normalise once, then match aliases against the NORMALISED string so the
    ! lookup is case-insensitive for long-form aliases as well as short names.
    accuracy_method_ = trim(adjustl(to_lower(accuracy_method)))
    select case(accuracy_method_)
    case("categorical")
       accuracy_method_ = "cat"
    case("mean_absolute_error")
       accuracy_method_ = "mae"
    case("mean_squared_error")
       accuracy_method_ = "mse"
    case("root_mean_squared_error")
       accuracy_method_ = "rmse"
    case("r2", "r^2", "r squared")
       accuracy_method_ = "r2"
    end select

    !---------------------------------------------------------------------------
    ! Set accuracy method
    !---------------------------------------------------------------------------
    select case(accuracy_method_)
    case("cat")
       this%get_accuracy => categorical_score
       if(verbose_.gt.0) write(*,*) "Accuracy method: Categorical "
    case("mae")
       this%get_accuracy => mae_score
       if(verbose_.gt.0) write(*,*) "Accuracy method: Mean Absolute Error"
    case("mse")
       this%get_accuracy => mse_score
       if(verbose_.gt.0) write(*,*) "Accuracy method: Mean Squared Error"
    case("rmse")
       this%get_accuracy => rmse_score
       if(verbose_.gt.0) write(*,*) "Accuracy method: Root Mean Squared Error"
    case("r2")
       this%get_accuracy => r2_score
       if(verbose_.gt.0) write(*,*) "Accuracy method: R^2"
    case default
       ! Pass the message as a deferred-length expression. An internal write
       ! into a fixed-size buffer would raise a run-time I/O error when the
       ! supplied name is longer than the buffer.
       call stop_program( &
            "No accuracy method provided" // &
            achar(13) // achar(10) // &
            "Failed accuracy method: " // trim(accuracy_method_) &
       )
       return
    end select
    this%accuracy_method = accuracy_method_

  end subroutine set_accuracy