Creating Custom Loss Functions
You can implement a custom loss function by extending base_loss_type.
Base Loss Type
All loss functions extend base_loss_type and must implement the compute procedure.
Required Procedures
compute: calculates the loss from the predicted and expected values
Essential Structure
type, extends(base_loss_type) :: custom_loss_type
  ! Add custom parameters here
  real(real32) :: custom_param = 1.0_real32
contains
  procedure, pass(this) :: compute => compute_custom
end type custom_loss_type
Example: Focal Loss
The following complete example implements Focal Loss, a custom loss function for handling class imbalance:
module my_focal_loss
  use coreutils, only: real32
  use athena__loss, only: base_loss_type
  ! operator(+) is needed for "pt + this%epsilon" below
  use diffstruc, only: array_type, operator(+), operator(-), operator(*), &
       operator(**), log, sum, mean
  implicit none

  type, extends(base_loss_type) :: focal_loss_type
    real(real32) :: alpha = 0.25_real32 ! Balance factor
    real(real32) :: gamma = 2.0_real32  ! Focusing parameter
  contains
    procedure, pass(this) :: compute => compute_focal
  end type focal_loss_type

  ! Generic constructor, so callers can write focal_loss_type(alpha, gamma)
  interface focal_loss_type
    module procedure setup
  end interface focal_loss_type

contains

  function setup(alpha, gamma) result(loss)
    real(real32), optional, intent(in) :: alpha, gamma
    type(focal_loss_type) :: loss
    loss%name = "focal"
    if(present(alpha)) loss%alpha = alpha
    if(present(gamma)) loss%gamma = gamma
    loss%epsilon = 1.E-10_real32
  end function setup

  function compute_focal(this, predicted, expected) result(output)
    class(focal_loss_type), intent(in), target :: this
    type(array_type), dimension(:,:), intent(inout), target :: predicted
    type(array_type), dimension(size(predicted,1),size(predicted,2)), &
         intent(in) :: expected
    type(array_type), pointer :: output
    type(array_type) :: pt, focal_weight, loss_val

    ! Use the raw predictions; epsilon is added inside log() below
    ! to prevent log(0)
    pt = predicted
    ! Note: In practice, also clip: max(epsilon, min(1-epsilon, predicted))

    ! Compute focal weight: (1 - pt)^gamma
    focal_weight = (1.0_real32 - pt)**this%gamma

    ! Compute focal loss: -alpha * (1-pt)^gamma * log(pt),
    ! masked by the expected (target) values
    loss_val = -this%alpha * focal_weight * log(pt + this%epsilon) * expected

    ! Average over samples
    allocate(output)
    output = mean(sum(loss_val))
  end function compute_focal

end module my_focal_loss
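For reference, the quantity this example computes can be written out explicitly. The expression below is read off the code above rather than taken from the library documentation: N (number of samples) and C (number of classes) are notation introduced here, y is the expected array, p is the predicted array, and the 1/N normalisation assumes that the final mean(sum(...)) averages over samples, as its comment states.

```latex
\mathcal{L}_{\text{focal}}
  = -\frac{1}{N} \sum_{n=1}^{N} \sum_{c=1}^{C}
    \alpha \, y_{nc} \, (1 - p_{nc})^{\gamma} \, \log\!\left(p_{nc} + \epsilon\right)
```

With gamma = 0 and alpha = 1 this reduces to the usual cross-entropy. Larger gamma shrinks the contribution of well-classified examples, where p is close to 1.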
Example: Custom Regression Loss
A loss function that combines MSE with a penalty term:
module my_regularised_mse
  use coreutils, only: real32
  use athena__loss, only: base_loss_type
  ! operator(+) is needed for "mse_term + reg_term" below
  use diffstruc, only: array_type, operator(+), operator(-), operator(*), &
       operator(**), mean, sum, abs
  implicit none

  type, extends(base_loss_type) :: regularised_mse_type
    real(real32) :: lambda = 0.01_real32 ! Regularisation strength
  contains
    procedure, pass(this) :: compute => compute_regularised_mse
  end type regularised_mse_type

contains

  function compute_regularised_mse(this, predicted, expected) result(output)
    class(regularised_mse_type), intent(in), target :: this
    type(array_type), dimension(:,:), intent(inout), target :: predicted
    type(array_type), dimension(size(predicted,1),size(predicted,2)), &
         intent(in) :: expected
    type(array_type), pointer :: output
    type(array_type) :: mse_term, reg_term, diff

    ! Compute MSE
    diff = predicted - expected
    mse_term = mean((diff)**2)

    ! Add regularisation term (here, the mean of absolute predictions)
    reg_term = this%lambda * mean(abs(predicted))

    allocate(output)
    output = mse_term + reg_term
  end function compute_regularised_mse

end module my_regularised_mse
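Written out, the loss in this example is the mean squared error plus a scaled mean absolute prediction. The formula below is derived from the code above, assuming that mean averages over all M elements; M, the hatted prediction symbol and y are notation introduced here.

```latex
\mathcal{L}
  = \frac{1}{M} \sum_{i=1}^{M} \left(\hat{y}_i - y_i\right)^2
  + \lambda \, \frac{1}{M} \sum_{i=1}^{M} \left|\hat{y}_i\right|
```

Setting lambda = 0 recovers plain MSE, which gives a convenient sanity check when testing the implementation.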
Working with array_type
array_type supports automatic differentiation. Build your loss from the following operations so that gradients are tracked:
Basic Operations
! Arithmetic
result = a + b ! Addition
result = a - b ! Subtraction
result = a * b ! Multiplication
result = a / b ! Division
result = a ** 2 ! Power
! Functions
result = log(a) ! Natural logarithm
result = abs(a) ! Absolute value
result = exp(a) ! Exponential
result = sqrt(a) ! Square root
! Reductions
result = sum(a) ! Sum all elements
result = mean(a) ! Mean of elements
! Conditionals (preserve gradients)
result = merge(a, b, condition)
Best Practices
Numerical Stability
Add epsilon to prevent log(0) or division by zero
Clip extreme values when necessary
Use log-space computations when dealing with very small probabilities
Gradient Flow
Ensure all operations support backpropagation
Use merge instead of if-else for conditionals
Avoid operations that break gradient flow
Batch Processing
Loss should average over the batch dimension
Use mean for proper averaging
Consider sample weights if needed
Testing
Verify loss decreases during training
Check gradients are computed correctly
Test with known inputs and expected outputs
Compare against reference implementations
Documentation
Cite original papers for novel losses
Document the mathematical formulation
Explain when to use the loss function
Provide guidance on hyperparameter tuning
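The Testing items above (checking gradients, testing with known inputs, comparing against a reference) can be tried out without array_type at all. The sketch below is a plain-array reference for the focal loss formula used in the earlier example, with an analytic gradient and a central finite-difference gradient to compare against. The module name focal_reference and its three functions are hypothetical helpers written for this illustration and are not part of athena or diffstruc. It uses real64 from iso_fortran_env so that the finite-difference estimate is not dominated by rounding error, and it assumes arrays of shape (classes, samples).

```fortran
module focal_reference
  ! Hypothetical reference implementation on plain arrays (not athena API).
  use, intrinsic :: iso_fortran_env, only: real64
  implicit none
  private
  public :: focal_loss, focal_grad, focal_grad_fd

contains

  !> Focal loss averaged over samples; p and y have shape (classes, samples).
  pure function focal_loss(p, y, alpha, gamma, eps) result(loss)
    real(real64), intent(in) :: p(:,:), y(:,:)
    real(real64), intent(in) :: alpha, gamma, eps
    real(real64) :: loss
    loss = -alpha * sum(y * (1.0_real64 - p)**gamma * log(p + eps)) &
         / real(size(p, 2), real64)
  end function focal_loss

  !> Analytic gradient of focal_loss with respect to each element of p.
  pure function focal_grad(p, y, alpha, gamma, eps) result(grad)
    real(real64), intent(in) :: p(:,:), y(:,:)
    real(real64), intent(in) :: alpha, gamma, eps
    real(real64) :: grad(size(p, 1), size(p, 2))
    grad = -alpha * y * ( (1.0_real64 - p)**gamma / (p + eps) &
         - gamma * (1.0_real64 - p)**(gamma - 1.0_real64) * log(p + eps) ) &
         / real(size(p, 2), real64)
  end function focal_grad

  !> Central finite-difference estimate of the same gradient, step size h.
  pure function focal_grad_fd(p, y, alpha, gamma, eps, h) result(grad)
    real(real64), intent(in) :: p(:,:), y(:,:)
    real(real64), intent(in) :: alpha, gamma, eps, h
    real(real64) :: grad(size(p, 1), size(p, 2))
    real(real64) :: work(size(p, 1), size(p, 2))
    integer :: i, j
    work = p
    do j = 1, size(p, 2)
      do i = 1, size(p, 1)
        work(i, j) = p(i, j) + h
        grad(i, j) = focal_loss(work, y, alpha, gamma, eps)
        work(i, j) = p(i, j) - h
        grad(i, j) = (grad(i, j) - focal_loss(work, y, alpha, gamma, eps)) &
             / (2.0_real64 * h)
        work(i, j) = p(i, j) ! restore before perturbing the next element
      end do
    end do
  end function focal_grad_fd

end module focal_reference
```

A test would call focal_grad and focal_grad_fd on the same small input (with probabilities kept away from 0 and 1 so that p + h stays in range) and check that the largest absolute difference is small. The same values can then be compared against the output and gradients of the array_type implementation.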
Common Patterns
Classification Loss Template
function compute_classification(this, predicted, expected) result(output)
  ! ...
  ! Clip predictions to valid probability range
  p = max(this%epsilon, min(1.0_real32 - this%epsilon, predicted))
  ! Compute cross-entropy or similar
  loss = -expected * log(p)
  ! Average over samples
  output = mean(sum(loss))
end function compute_classification
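The clipping step in the template above is worth checking in isolation. The sketch below applies the same max/min clamp to plain real32 arrays; the module clipped_cross_entropy and its two functions are hypothetical names for this illustration and are not part of athena. Arrays are assumed to have shape (classes, samples).

```fortran
module clipped_cross_entropy
  ! Hypothetical plain-array illustration of the clipping step (not athena API).
  use, intrinsic :: iso_fortran_env, only: real32
  implicit none
  private
  public :: clip_probability, cross_entropy

contains

  !> Clamp a probability to [eps, 1 - eps] so that log() stays finite.
  elemental function clip_probability(p, eps) result(clipped)
    real(real32), intent(in) :: p, eps
    real(real32) :: clipped
    clipped = max(eps, min(1.0_real32 - eps, p))
  end function clip_probability

  !> Cross-entropy averaged over samples; p and y have shape (classes, samples).
  pure function cross_entropy(p, y, eps) result(loss)
    real(real32), intent(in) :: p(:,:), y(:,:)
    real(real32), intent(in) :: eps
    real(real32) :: loss
    loss = -sum(y * log(clip_probability(p, eps))) / real(size(p, 2), real32)
  end function cross_entropy

end module clipped_cross_entropy
```

One design point to keep in mind: in IEEE single precision the spacing between representable numbers just below 1 is about 6e-8, so with an epsilon of 1.E-10_real32 (the value set in the focal loss example) the expression 1.0_real32 - epsilon evaluates to exactly 1 and the upper clip has no effect. The lower clip still protects log(p). If the upper bound matters for your loss, for example when it also evaluates log(1 - p), an epsilon of roughly 1.E-7_real32 or larger is needed.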
Regression Loss Template
function compute_regression(this, predicted, expected) result(output)
  ! ...
  ! Compute error
  error = predicted - expected
  ! Apply loss function to error
  loss = error**2 ! Or abs(error), or custom function
  ! Average over samples
  output = mean(loss)
end function compute_regression
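As one example of the "custom function" option mentioned in the template, the sketch below computes a Huber loss, which is quadratic for small errors and linear for large ones. It is written for plain rank-1 real32 arrays using the intrinsic merge, not for array_type, so it serves as a reference for the arithmetic only. The module name huber_reference, the function huber_loss and the threshold argument delta are hypothetical and not part of athena.

```fortran
module huber_reference
  ! Hypothetical plain-array reference for a Huber loss (not athena API).
  use, intrinsic :: iso_fortran_env, only: real32
  implicit none
  private
  public :: huber_loss

contains

  !> Mean Huber loss: 0.5*e**2 where |e| <= delta, else delta*(|e| - 0.5*delta).
  pure function huber_loss(predicted, expected, delta) result(loss)
    real(real32), intent(in) :: predicted(:), expected(:)
    real(real32), intent(in) :: delta
    real(real32) :: loss
    real(real32) :: error(size(predicted))
    error = predicted - expected
    ! merge selects element-wise between the quadratic and linear branches
    loss = sum(merge(0.5_real32 * error**2, &
                     delta * (abs(error) - 0.5_real32 * delta), &
                     abs(error) <= delta)) / real(size(error), real32)
  end function huber_loss

end module huber_reference
```

For predictions [0.5, 3.0], targets [0.0, 0.0] and delta = 1, the two per-element losses are 0.125 (quadratic branch) and 2.5 (linear branch). Selecting between branches element-wise mirrors the earlier advice to use merge rather than if-else.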
See Also
diffstruc module documentation for array operations