athena_wandb.F90 Source File


Source Code

module athena_wandb
  !! Athena interface to Weights & Biases (wandb) experiment tracking.
  !!
  !! When the library is compiled with the preprocessor macro `_WANDB`
  !! (i.e. `-D_WANDB` is passed to the compiler), this module re-exports the
  !! `wf` (wandb-fortran) API listed in the `public` statements below, so
  !! that user code only needs a single `use athena_wandb`.  It also provides
  !! `wandb_network_type` — a thin `network_type` extension whose `train`
  !! method automatically logs epoch loss and accuracy to W&B after each
  !! epoch.
  !!
  !! Without `_WANDB` the module still compiles but exports nothing, so both
  !! usage patterns shown below require `-D_WANDB`.
  !!
  !! ## Usage (basic wandb API)
  !!
  !! ```fortran
  !! use athena_wandb
  !! call wandb_init(project="my-project", name="run-01")
  !! call wandb_config_set("learning_rate", 0.001_real32)
  !! do epoch = 1, 100
  !!    ! ... training ...
  !!    call wandb_log("loss", loss_val, step=epoch)
  !! end do
  !! call wandb_finish()
  !! ```
  !!
  !! ## Usage (automatic logging via wandb_network_type)
  !!
  !! Requires `-D_WANDB` at compile time (for fpm >= 0.13.0, build with
  !! `--features wandb`).
  !!
  !! ```fortran
  !! use athena
  !! use athena_wandb          ! provides wandb_network_type
  !! type(wandb_network_type) :: net
  !! ! ... add layers, compile ...
  !! call net%wandb_setup(project="my-project", name="run-01")
  !! call net%train(x, y, num_epochs=50, verbose=0)   ! logs per epoch
  !! call wandb_finish()
  !! ```
  !!
#ifdef _WANDB
  use wf, only: &
       wandb_init,              &
       wandb_log,               &
       wandb_finish,            &
       wandb_shutdown,          &
       wandb_config_set,        &
       wandb_config_get,        &
       wandb_sweep_config_type, &
       wandb_sweep,             &
       wandb_agent,             &
       wandb_sweep_start_agent, &
       wandb_sweep_next_params, &
       wandb_sweep_run_done
  use athena__network, only: network_type
  use coreutils, only: real32
#endif
  implicit none

  private

#ifdef _WANDB
  !---------------------------------------------------------------------------
  ! Re-export wandb-fortran public API
  !---------------------------------------------------------------------------
  public :: wandb_init
  public :: wandb_log
  public :: wandb_finish
  public :: wandb_shutdown
  public :: wandb_config_set
  public :: wandb_config_get
  public :: wandb_sweep_config_type
  public :: wandb_sweep
  public :: wandb_agent
  public :: wandb_sweep_start_agent
  public :: wandb_sweep_next_params
  public :: wandb_sweep_run_done


  !---------------------------------------------------------------------------
  ! wandb_network_type
  !---------------------------------------------------------------------------
  public :: wandb_network_type

  type, extends(network_type) :: wandb_network_type
     !! Extension of `network_type` with automatic W&B metric logging.
     !!
     !! Call `wandb_setup` once before training to initialise the W&B run,
     !! then call `train` as normal.  The inherited `post_epoch_hook` binding
     !! is overridden so that each invocation logs two scalars, using the
     !! epoch number as the W&B step:
     !!   - "loss"     — the `loss` value handed to the hook
     !!   - "accuracy" — the `accuracy` value handed to the hook
     !!
     !! NOTE(review): the hook is expected to be called once per epoch by the
     !! base `train` loop with epoch-mean values; that loop lives in
     !! `athena__network` and is not visible from this module — confirm there.
     !!
     character(len=:), allocatable :: wandb_project
     !! W&B project name; assigned (trailing blanks trimmed) by `wandb_setup`
     !! and unallocated until then
     character(len=:), allocatable :: wandb_run_name
     !! W&B run display name; assigned by `wandb_setup` (trimmed `name`, or
     !! the empty string when no name is supplied) and unallocated until then
     logical :: log_batch_metrics = .false.
     !! Reserved for future use: log per-batch metrics in addition to epoch
     !! metrics.  No procedure in this module reads it yet.
   contains
     procedure :: post_epoch_hook => wandb_post_epoch_hook
     !! Override of the inherited `post_epoch_hook` binding; logs epoch
     !! metrics to W&B
     procedure :: wandb_setup
     !! Initialise the W&B run and store project / run-name metadata
  end type wandb_network_type

contains

!###############################################################################
  subroutine wandb_setup(this, project, name)
    !! Initialise a Weights & Biases run for this network.
    !!
    !! Call this once before `train`.  It is a thin wrapper around
    !! `wandb_init` that also stores the project and run name on the type so
    !! they are available for diagnostics.
    !!
    !! Trailing blanks are stripped from `project` and `name` before they are
    !! stored, and those stored values are what is forwarded to `wandb_init`.
    !! This keeps the metadata on the type identical to what W&B receives,
    !! even when the caller passes blank-padded fixed-length strings.
    implicit none

    ! Arguments
    class(wandb_network_type), intent(inout) :: this
    !! Network instance; `wandb_project` and `wandb_run_name` are overwritten
    character(*), intent(in) :: project
    !! W&B project name
    character(*), intent(in), optional :: name
    !! W&B run display name (optional)

    this%wandb_project = trim(project)
    if(present(name))then
       this%wandb_run_name = trim(name)
       ! Forward the trimmed copies, not the raw dummies, so the run is
       ! created under exactly the names recorded on the type.
       call wandb_init( &
            project=this%wandb_project, &
            name=this%wandb_run_name)
    else
       ! No display name supplied: record an empty name and call
       ! `wandb_init` without the `name` argument.
       this%wandb_run_name = ""
       call wandb_init(project=this%wandb_project)
    end if

  end subroutine wandb_setup
!###############################################################################


!###############################################################################
  subroutine wandb_post_epoch_hook(this, epoch, loss, accuracy)
    !! Per-epoch logging hook bound to `wandb_network_type%post_epoch_hook`.
    !!
    !! Forwards the two epoch metrics to the current W&B run through
    !! `wandb_log`: first "loss", then "accuracy", each tagged with `epoch`
    !! as the step.
    implicit none

    ! Arguments
    class(wandb_network_type), intent(inout) :: this
    !! Network instance (passed-object dummy; not referenced in the body)
    integer, intent(in) :: epoch
    !! Epoch number, forwarded to W&B as the `step`
    real(real32), intent(in) :: loss, accuracy
    !! Epoch loss and epoch accuracy values to log

    call wandb_log("loss", loss, step=epoch)
    call wandb_log("accuracy", accuracy, step=epoch)

  end subroutine wandb_post_epoch_hook
!###############################################################################

#endif

end module athena_wandb