Athena interface to Weights & Biases (wandb) experiment tracking.
This module re-exports the full public API of the wf (wandb-fortran)
package, so user code needs only a single `use athena_wandb` statement.
When the library is compiled with the preprocessor macro _WANDB
(i.e. -D_WANDB is passed to the compiler), this module also provides
wandb_network_type — a thin network_type extension that overrides the
post-epoch hook, so that the inherited train method automatically logs the
epoch loss and accuracy to W&B at the end of each epoch.
use athena_wandb
call wandb_init(project="my-project", name="run-01")
call wandb_config_set("learning_rate", 0.001_real32)
do epoch = 1, 100
! ... training ...
call wandb_log("loss", loss_val, step=epoch)
end do
call wandb_finish()
Using wandb_network_type requires -D_WANDB at compile time (with fpm >= 0.13.0,
build with --features wandb):
use athena
use athena_wandb ! provides wandb_network_type
type(wandb_network_type) :: net
! ... add layers, compile ...
call net%wandb_setup(project="my-project", name="run-01")
call net%train(x, y, num_epochs=50, verbose=0) ! logs per epoch
call wandb_finish()
Extension of network_type with automatic W&B metric logging.
| Type | Visibility | Attributes | Name | Initial | Description |
|---|---|---|---|---|---|
| character(len=:) | public | allocatable | accuracy_method | | Name of the accuracy method |
| real(kind=real32) | public | | accuracy_val | | Accuracy of the network |
| type(graph_type) | public | | auto_graph | | Graph structure for the network |
| integer | public | | batch_size | 0 | Batch size |
| integer | public | | epoch | 0 | Epoch number |
| type(array_type) | public | dimension(:,:), allocatable | expected_array | | Expected output array for the network |
| integer | public | dimension(:), allocatable | fwd_layer_id | | Layer ID for each vertex in forward order |
| integer | public | dimension(:), allocatable | fwd_layer_type | | Layer type: 0=input, 1=merge, 2=default |
| integer | public | dimension(:), allocatable | fwd_num_inputs | | Number of input layers for each vertex in forward order |
| integer | public | dimension(:), allocatable | fwd_parent_id | | Parent layer ID for single-input vertices |
| procedure(compute_accuracy_function) | public | nopass, pointer | get_accuracy | => null() | Pointer to accuracy function |
| type(array_type) | public | dimension(:,:), allocatable | input_array | | Input array for the network |
| type(graph_type) | public | dimension(:,:), allocatable | input_graph | | Input graph for the network |
| integer | public | dimension(:), allocatable | leaf_vertices | | Output (leaf) vertices |
| logical | public | | log_batch_metrics | .false. | Reserved for future use: log per-batch metrics in addition to epoch metrics |
| class(base_loss_type) | public | allocatable | loss | | Loss method for the network |
| character(len=:) | public | allocatable | loss_method | | Name of the loss method |
| real(kind=real32) | public | | loss_val | | Loss of the network |
| type(metric_dict_type) | public | dimension(2) | metrics | | Metrics for the network |
| type(container_layer_type) | public | allocatable, dimension(:) | model | | Model layers |
| character(len=:) | public | allocatable | name | | Name of the network |
| integer | public | | num_layers | 0 | Number of layers |
| integer | public | | num_outputs | 0 | Number of outputs |
| integer | public | | num_params | 0 | Number of parameters |
| class(base_optimiser_type) | public | allocatable | optimiser | | Optimiser for the network |
| integer | public | | param_num_segments | 0 | Number of parameter segments |
| integer | public | dimension(:), allocatable | param_seg_end | | End offset in flat parameter array |
| integer | public | dimension(:), allocatable | param_seg_layer | | Layer index for each parameter segment |
| integer | public | dimension(:), allocatable | param_seg_pidx | | Parameter index within that layer for each segment |
| integer | public | dimension(:), allocatable | param_seg_start | | Start offset in flat parameter array |
| integer | public | dimension(:), allocatable | root_vertices | | Root vertices |
| logical | public | | use_graph_input | .false. | Boolean flag for graph input |
| logical | public | | use_graph_output | .false. | Boolean flag for graph output |
| integer | public | dimension(:), allocatable | vertex_order | | Order of vertices |
| character(len=:) | public | allocatable | wandb_project | | W&B project name (set by `wandb_setup`) |
| character(len=:) | public | allocatable | wandb_run_name | | W&B run display name (set by `wandb_setup`) |

| Type-bound procedure | Description |
|---|---|
| procedure, public, pass(this) :: accuracy_eval | Get the accuracy for the output |
| procedure, public, pass(this) :: add | Add a layer to the network |
| procedure, public, pass(this) :: build_from_onnx | Build network from ONNX nodes and initialisers |
| procedure, public, pass(this) :: compile | Compile the network |
| procedure, public, pass(this) :: copy => network_copy | Copy a network |
| procedure, public, pass(this) :: extract_output => extract_output_real | Extract network output as real array (only works for single output layer models) |
| procedure, public, pass(this) :: forward => forward_generic2d | Forward pass for generic 2D input |
| procedure, public, pass(this) :: forward_eval | Forward pass and return pointer to output (only works for single output layer models) |
| procedure, public, pass(this) :: get_gradients | Get gradients of learnable parameters |
| procedure, public, pass(this) :: get_num_params | Get number of learnable parameters in the network |
| procedure, public, pass(this) :: get_output | Get the output of the network |
| procedure, public, pass(this) :: get_output_shape | Get the output shape of the network |
| procedure, public, pass(this) :: get_params | Get learnable parameters |
| generic, public :: inverse_design => inverse_design_real, inverse_design_array_0d, inverse_design_array_2d | Optimise input to match a target output |
| procedure, public, pass(this) :: layer_from_id | Get the layer of the network from its ID |
| procedure, public, pass(this) :: loss_eval | Get the loss for the output |
| procedure, public, pass(this) :: nullify_graph | Nullify graph data in the network to free memory |
| procedure, public :: post_epoch_hook => wandb_post_epoch_hook | Override the base no-op hook to log epoch metrics to W&B |
| generic, public :: predict => predict_real, predict_graph1d, predict_graph2d, predict_array, predict_array_from_real | Predict function for different input types |
| procedure, public, pass(this) :: predict_array | Predict array type output for a generic input |
| procedure, public, pass(this) :: predict_array_from_real | Return predicted results as array from supplied inputs using the trained network |
| procedure, public, pass(this) :: predict_generic | Predict generic type output for a generic input |
| procedure, public, pass(this) :: predict_graph1d | |
| procedure, public, pass(this) :: predict_graph2d | Return predicted results from supplied inputs using the trained network (graph input) |
| procedure, public, pass(this) :: predict_real | Return predicted results from supplied inputs using the trained network |
| procedure, public, pass(this) :: print | Print the network to file |
| procedure, public, pass(this) :: print_summary | Print a summary of the network architecture |
| procedure, public, pass(this) :: read | Read the network from a file |
| procedure, public, pass(this) :: reduce => network_reduction | Reduce two networks down to one (i.e. add two networks - parallel) |
| procedure, public, pass(this) :: reset | Reset the network |
| procedure, public, pass(this) :: reset_gradients | Reset learnable parameter gradients |
| procedure, public, pass(this) :: reset_state | Reset hidden state of recurrent layers |
| procedure, public, pass(this) :: save_input => save_input_to_network | Convert and save polymorphic input to array or graph |
| procedure, public, pass(this) :: save_output => save_output_to_network | Convert and save polymorphic output to array or graph |
| procedure, public, pass(this) :: set_accuracy | Set network accuracy method |
| procedure, public, pass(this) :: set_batch_size | Set batch size |
| procedure, public, pass(this) :: set_gradients | Set learnable parameter gradients |
| procedure, public, pass(this) :: set_inference_mode | Set inference mode for layers with training/inference-specific behaviour |
| procedure, public, pass(this) :: set_loss | Set network loss method |
| procedure, public, pass(this) :: set_metrics | Set network metrics |
| procedure, public, pass(this) :: set_params | Set learnable parameters |
| procedure, public, pass(this) :: set_training_mode | Set training mode for layers with training/inference-specific behaviour |
| procedure, public, pass(this) :: test | Test the network |
| procedure, public, pass(this) :: train | Train the network |
| procedure, public, pass(this) :: update | Update the learnable parameters of the network based on gradients |
| procedure, public :: wandb_setup | Initialise the W&B run and store project / run-name metadata |
wandb_post_epoch_hook — called automatically at the end of each training epoch. Logs "loss" and "accuracy" to the current W&B run.
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| class(wandb_network_type) | intent(inout) | | | this | Network instance |
| integer | intent(in) | | | epoch | Current epoch number (1-based) |
| real(kind=real32) | intent(in) | | | loss | Mean loss over the epoch |
| real(kind=real32) | intent(in) | | | accuracy | Mean accuracy over the epoch |
wandb_setup — initialise a Weights & Biases run for this network.
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| class(wandb_network_type) | intent(inout) | | | this | Network instance |
| character(len=*) | intent(in) | | | project | W&B project name |
| character(len=*) | intent(in) | optional | | name | W&B run display name (optional) |