train Module Subroutine

module subroutine train(this, input, output, num_epochs, batch_size, plateau_threshold, shuffle_batches, batch_print_step, verbose, print_precision, scientific_print, early_stopping, val_input, val_output)

Train the network

This subroutine trains the network on the input data for a number of epochs. The input data is split into batches of size batch_size, and the network is trained on each batch using the optimiser specified in the network object.

Arguments

Type | Intent | Optional Attributes | Name
class(network_type), intent(inout) :: this

Instance of network

class(*), intent(in), dimension(..) :: input

Input data

class(*), intent(in), dimension(:,:) :: output

Output data

integer, intent(in) :: num_epochs

Number of epochs

integer, intent(in), optional :: batch_size

Batch size

real(kind=real32), intent(in), optional :: plateau_threshold

Plateau threshold

logical, intent(in), optional :: shuffle_batches

Shuffle batches

integer, intent(in), optional :: batch_print_step

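Batch print step: when verbose is non-zero, training progress is printed for the first batch and then every batch_print_step batches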

integer, intent(in), optional :: verbose

Verbosity level

integer, intent(in), optional :: print_precision

Number of decimal places to print for training metrics

logical, intent(in), optional :: scientific_print

Whether to print training metrics in scientific notation

logical, intent(in), optional :: early_stopping

Whether to stop training early if convergence is detected

class(*), intent(in), optional, dimension(..) :: val_input

Validation input data

class(*), intent(in), optional, dimension(:,:) :: val_output

Validation expected output data


Source Code

  module subroutine train( &
       this, input, output, num_epochs, batch_size, &
       plateau_threshold, shuffle_batches, batch_print_step, verbose, &
       print_precision, scientific_print, early_stopping, &
       val_input, val_output &
  )
    !! Train the network
    !!
    !! This subroutine trains the network on the input data for a number of
    !! epochs. The input data is split into batches of size batch_size and
    !! the network is trained on each batch. The network is trained using
    !! the optimiser specified in the network object.
    !!
    !! Batches are visited in an (optionally shuffled) order every epoch.
    !! When the number of samples is not a multiple of the batch size, the
    !! trailing batch is smaller; the network is reconfigured for that size
    !! only while that batch is processed and is switched back for the next
    !! full batch (and before returning).
    !!
    !! If validation data is supplied, the network is evaluated on it one
    !! sample at a time, in inference mode, at the end of every epoch. The
    !! training data, expected output, batch size and training mode are
    !! restored afterwards.
    use athena__tools_infile, only: stop_check
    use, intrinsic :: iso_fortran_env, only: output_unit, error_unit
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(*), dimension(..), intent(in) :: input
    !! Input data
    class(*), dimension(:,:), intent(in) :: output
    !! Output data
    integer, intent(in) :: num_epochs
    !! Number of epochs
    integer, optional, intent(in) :: batch_size
    !! Batch size
    real(real32), optional, intent(in) :: plateau_threshold
    !! Plateau threshold
    logical, optional, intent(in) :: shuffle_batches
    !! Shuffle batches
    integer, optional, intent(in) :: batch_print_step
    !! Batch print step
    integer, optional, intent(in) :: verbose
    !! Verbosity level
    integer, optional, intent(in) :: print_precision
    !! Number of decimal places to print for training metrics
    logical, optional, intent(in) :: scientific_print
    !! Whether to print training metrics in scientific notation
    logical, optional, intent(in) :: early_stopping
    !! Whether to stop training early if convergence is detected
    class(*), dimension(..), optional, intent(in) :: val_input
    !! Validation input data
    class(*), dimension(:,:), optional, intent(in) :: val_output
    !! Validation expected output data

    ! Training parameters
    real(real32) :: batch_loss, batch_accuracy, avg_loss, avg_accuracy
    !! Loss and accuracy

    ! Learning parameters
    integer :: num_samples
    !! Number of training samples
    integer :: num_batches
    !! Number of batches
    integer :: converged
    !! Convergence flag
    integer :: window_width
    !! Length of convergence check window
    integer :: verbose_
    !! Verbosity level
    real(real32) :: plateau_threshold_
    !! Plateau threshold
    logical :: shuffle_batches_
    !! Shuffle batches
    logical :: use_accuracy
    !! Whether accuracy evaluation is available
    logical :: early_stopping_
    !! Whether to stop training early if convergence is detected

    ! Printing parameters
    integer :: batch_print_step_
    !! Batch print step (always >= 1)
    integer :: print_precision_
    !! Number of decimal places for metric output
    logical :: scientific_print_
    !! Whether to print metrics in scientific notation
    character(len=64) :: loss_str, accuracy_str
    !! Formatted metrics for printing
    character(len=128) :: val_str
    !! Formatted validation metrics for printing

    ! Training loop variables
    integer :: epoch, batch, start_index, end_index
    !! Loop index
    integer, allocatable, dimension(:) :: batch_order
    !! Batch order

    integer :: i
    !! Loop index
    integer :: current_batch_size, target_batch_size
    !! Actual batch size for the current batch and the target batch size
    integer :: active_batch_size
    !! Batch size the network is currently configured for
    logical, allocatable :: mode_store(:)
    !! Storage for inference mode booleans

    class(*), allocatable :: data_poly(:,:)
    !! Sample/batch of input data passed to the forward pass
    type(array_type), pointer :: loss
    !! Loss of the current batch/sample. It is nullified at entry rather
    !! than initialised with `=> null()`, which would imply SAVE.

    ! Validation variables
    logical :: use_validation
    !! Whether validation data is provided
    integer :: val_num_samples
    !! Number of validation samples
    integer :: val_sample
    !! Loop index for validation
    real(real32) :: val_loss, val_accuracy, val_loss_sum, val_accuracy_sum
    !! Validation loss and accuracy
#ifdef __INTEL_COMPILER
    integer :: j
    !! Loop index (only needed for the explicit ifx copy below)
    type(array_type), pointer :: saved_expected_array(:,:)
    !! Saved training expected output (for restoring after validation)
#else
    type(array_type), dimension(:,:), allocatable :: saved_expected_array
    !! Saved training expected output (for restoring after validation)
#endif


    nullify(loss)
#ifdef __INTEL_COMPILER
    nullify(saved_expected_array)
#endif

    !---------------------------------------------------------------------------
    ! Check loss and accuracy methods are set
    !---------------------------------------------------------------------------
    if(.not.allocated(this%loss))then
       call stop_program("loss method not set")
       return
    end if
    use_accuracy = associated(this%get_accuracy)
    accuracy_str = ""
    val_str = ""

    !---------------------------------------------------------------------------
    ! Check validation data
    !---------------------------------------------------------------------------
    use_validation = present(val_input) .and. present(val_output)
    if(present(val_input) .neqv. present(val_output))then
       call stop_program( &
            "both val_input and val_output must be provided for validation" &
       )
       return
    end if


    !---------------------------------------------------------------------------
    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    verbose_ = 0
    batch_print_step_ = 20
    plateau_threshold_ = 0._real32
    shuffle_batches_ = .true.
    scientific_print_ = .false.
    print_precision_ = 3
    early_stopping_ = .true.
    if(present(plateau_threshold)) plateau_threshold_ = plateau_threshold
    if(present(shuffle_batches)) shuffle_batches_ = shuffle_batches
    ! `mod` takes the sign of its first argument, so a negative step already
    ! behaved like its absolute value. Zero would be an integer division by
    ! zero inside `mod`, so it is clamped to 1 (print every batch).
    if(present(batch_print_step)) &
         batch_print_step_ = max(abs(batch_print_step), 1)
    if(present(verbose)) verbose_ = verbose
    if(present(print_precision)) print_precision_ = max(print_precision, 0)
    if(present(scientific_print)) scientific_print_ = scientific_print
    if(present(batch_size)) this%batch_size = batch_size
    if(present(early_stopping)) early_stopping_ = early_stopping

    ! The batch size is used as a divisor below
    if(this%batch_size .le. 0)then
       call stop_program("batch size must be a positive integer")
       return
    end if


    !---------------------------------------------------------------------------
    ! Initialise monitoring variables
    !---------------------------------------------------------------------------
    window_width = max(ceiling(500._real32/this%batch_size),1)
    do i = 1, size(this%metrics,dim=1)
       this%metrics(i)%window_width = window_width
    end do


    !---------------------------------------------------------------------------
    ! Save input and output to network
    !---------------------------------------------------------------------------
    num_samples = this%save_input( input )
    call this%save_output( output )
    if(size(output,2).ne.num_samples.and.this%use_graph_output)then
       call stop_program("number of samples in input and output do not match")
       return
    end if


    !---------------------------------------------------------------------------
    ! Build the batch visiting order (ceiling division: the last batch may be
    ! smaller than the batch size)
    !---------------------------------------------------------------------------
    num_batches = (num_samples + this%batch_size - 1) / this%batch_size
    allocate(batch_order(num_batches))
    batch_order = [ (batch, batch = 1, num_batches) ]


    !---------------------------------------------------------------------------
    ! Set/reset batch size for training
    ! target_batch_size is fixed for the whole call and is used for indexing,
    ! because set_batch_size may change this%batch_size
    !---------------------------------------------------------------------------
    call this%set_batch_size(this%batch_size)
    target_batch_size = this%batch_size
    active_batch_size = target_batch_size


    !---------------------------------------------------------------------------
    ! Enable training mode
    !---------------------------------------------------------------------------
    call this%set_training_mode(mode_store)


    epoch_loop: do epoch = 1, num_epochs
       this%epoch = epoch
       !------------------------------------------------------------------------
       ! Shuffle batch order at the start of each epoch
       !------------------------------------------------------------------------
       if(shuffle_batches_)then
          call shuffle(batch_order)
       end if

       avg_loss     = 0._real32
       avg_accuracy = 0._real32

       !------------------------------------------------------------------------
       ! Batch loop
       ! ... split data up into minibatches for training
       !------------------------------------------------------------------------
       batch_loop: do batch = 1, num_batches


          ! Set batch start and end index
          ! The network is reconfigured whenever the size differs from the
          ! size it is currently set to, so that a full batch following the
          ! (shuffled) short trailing batch is run at the full size again
          !---------------------------------------------------------------------
          start_index = (batch_order(batch) - 1) * target_batch_size + 1
          end_index = min(batch_order(batch) * target_batch_size, num_samples)
          current_batch_size = end_index - start_index + 1
          if(current_batch_size.ne.active_batch_size)then
             call this%set_batch_size(current_batch_size)
             active_batch_size = current_batch_size
          end if


          ! Forward pass
          !---------------------------------------------------------------------
          select case(this%use_graph_input)
          case(.true.)
             data_poly = get_sample( &
                  this%input_graph, start_index, end_index, current_batch_size &
             )
          case default
             data_poly = get_sample( &
                  this%input_array, start_index, end_index, current_batch_size, &
                  as_graph = .false. &
             )
          end select
          call this%forward(data_poly)
          deallocate(data_poly)


          ! Backward pass
          !---------------------------------------------------------------------
          loss => this%loss_eval(start_index, end_index)
          loss%is_temporary = .false.
          call loss%grad_reverse(reset_graph=.true.)


          ! Compute loss and accuracy (for monitoring)
          !---------------------------------------------------------------------
          batch_loss = sum(loss%val)
          batch_accuracy = 0._real32
          if(use_accuracy)then
             batch_accuracy = this%accuracy_eval(output, start_index, end_index)
          end if


          ! Accumulate metrics (averaged over the number of batches later)
          !---------------------------------------------------------------------
          avg_loss = avg_loss + batch_loss
          if(use_accuracy)then
             avg_accuracy = avg_accuracy + batch_accuracy
          end if


          ! Update weights and biases using optimisation algorithm
          !---------------------------------------------------------------------
          call this%update()
          call loss%nullify_graph()
          deallocate(loss)
          nullify(loss)


          ! Print batch results (first batch, then every batch_print_step_)
          !---------------------------------------------------------------------
          if(abs(verbose_).gt.0 .and. &
               (batch.eq.1 .or. mod(batch, batch_print_step_).eq.0))then
             loss_str = format_training_real( &
                  avg_loss / real(batch, real32), print_precision_, &
                  scientific_print_ &
             )
             if(use_accuracy)then
                accuracy_str = ", accuracy=" // trim(format_training_real( &
                     avg_accuracy / real(batch, real32), &
                     print_precision_, scientific_print_ &
                ))
             end if

             write(output_unit,'("epoch=",I0,", batch=",I0,&
                  &", lr=",ES0.2,", loss=",A,A)' &
             ) &
                  this%epoch, batch, &
                  this%optimiser%lr_decay%get_lr( &
                       this%optimiser%learning_rate, this%optimiser%iter &
                  ), &
                  trim(loss_str), trim(accuracy_str)
          end if


          ! Check whether the user has requested a stop (see stop_check)
          !---------------------------------------------------------------------
          if(stop_check())then
             write(error_unit,*) "STOPCAR ENCOUNTERED"
             write(error_unit,*) "Exiting training loop..."
             exit epoch_loop
          end if

       end do batch_loop
       call this%metrics(1)%append(avg_loss / real(num_batches, real32))
       if(use_accuracy)then
          call this%metrics(2)%append(avg_accuracy / real(num_batches, real32))
       end if


       !------------------------------------------------------------------------
       ! Validation evaluation at end of epoch
       !------------------------------------------------------------------------
       val_str = ""
       if(use_validation)then

#ifdef __INTEL_COMPILER
          ! Save training expected output. `ifx` crashes on the allocatable
          ! local declaration used with `move_alloc`, so keep an explicit
          ! pointer-backed copy here instead.
          if(allocated(this%expected_array))then
             allocate(saved_expected_array( &
                  size(this%expected_array, 1), size(this%expected_array, 2) &
             ))
             do j = 1, size(this%expected_array, 2)
                do i = 1, size(this%expected_array, 1)
                   call saved_expected_array(i,j)%allocate( &
                        source=this%expected_array(i,j) &
                   )
                   call this%expected_array(i,j)%deallocate()
                end do
             end do
             deallocate(this%expected_array)
          else
             nullify(saved_expected_array)
          end if
#else
          call move_alloc(this%expected_array, saved_expected_array)
#endif

          ! Save validation output to network
          call this%save_output( val_output )

          ! Save validation input
          val_num_samples = this%save_input( val_input )

          ! Set batch size to 1 and enable inference mode
          call this%set_batch_size(1)
          active_batch_size = 1
          call this%set_inference_mode()

          ! Evaluate validation loss and accuracy, one sample at a time
          val_loss_sum = 0._real32
          val_accuracy_sum = 0._real32
          do val_sample = 1, val_num_samples
             select case(this%use_graph_input)
             case(.true.)
                data_poly = get_sample( &
                     this%input_graph, val_sample, val_sample, 1 &
                )
             case default
                data_poly = get_sample_array( &
                     this%input_array, val_sample, val_sample, 1, &
                     as_graph = .false. &
                )
             end select
             call this%forward(data_poly)
             deallocate(data_poly)

             loss => this%loss_eval(val_sample, val_sample)
             val_loss_sum = val_loss_sum + sum(loss%val)
             call loss%nullify_graph()
             deallocate(loss)
             nullify(loss)

             if(use_accuracy)then
                val_accuracy_sum = val_accuracy_sum + &
                     this%accuracy_eval(val_output, val_sample, val_sample)
             end if
          end do

          val_loss = val_loss_sum / real(val_num_samples, real32)
          val_accuracy = val_accuracy_sum / real(val_num_samples, real32)

          ! Build validation print string
          val_str = ", val_loss=" // trim(format_training_real( &
               val_loss, print_precision_, scientific_print_ &
          ))
          if(use_accuracy)then
             val_str = trim(val_str) // ", val_accuracy=" // &
                  trim(format_training_real( &
                       val_accuracy, print_precision_, scientific_print_ &
                  ))
          end if

          ! Restore training expected output
          if(allocated(this%expected_array))then
             do i = 1, size(this%expected_array, 1)
                call this%expected_array(i,:)%deallocate()
             end do
             deallocate(this%expected_array)
          end if
#ifdef __INTEL_COMPILER
          ! `ifx` does not support `move_alloc` in this context
          if(associated(saved_expected_array))then
             allocate(this%expected_array( &
                  size(saved_expected_array, 1), size(saved_expected_array, 2) &
             ))
             do j = 1, size(saved_expected_array, 2)
                do i = 1, size(saved_expected_array, 1)
                   call this%expected_array(i,j)%allocate( &
                        source=saved_expected_array(i,j) &
                   )
                   call saved_expected_array(i,j)%deallocate()
                end do
             end do
             deallocate(saved_expected_array)
             nullify(saved_expected_array)
          end if
#else
          call move_alloc(saved_expected_array, this%expected_array)
#endif

          ! Restore training input
          num_samples = this%save_input( input )

          ! Restore training batch size and training mode
          call this%set_batch_size(target_batch_size)
          active_batch_size = target_batch_size
          call this%set_training_mode()

       end if


       ! Print epoch summary results
       ! With validation data and verbose >= 0 only the validation metrics are
       ! printed; without validation, verbose == 0 prints the training metrics
       !------------------------------------------------------------------------
       if(use_validation.and.verbose_.ge.0)then
          write(output_unit,'("epoch=",I0,A)') this%epoch, trim(val_str)
       elseif(verbose_.eq.0)then
          loss_str = format_training_real( &
               this%metrics(1)%val, print_precision_, scientific_print_ &
          )
          if(use_accuracy)then
             accuracy_str = ", accuracy=" // trim(format_training_real( &
                  this%metrics(2)%val, print_precision_, scientific_print_ &
             ))
          end if
          write(output_unit,'("epoch=",I0,&
               &", lr=",ES0.2,", loss=",A,A,A)' &
          ) &
               this%epoch, &
               this%optimiser%lr_decay%get_lr( &
                    this%optimiser%learning_rate, this%optimiser%iter &
               ), &
               trim(loss_str), trim(accuracy_str), trim(val_str)
       end if


       !------------------------------------------------------------------------
       ! Per-epoch callback (e.g. W&B logging via wandb_network_type)
       !------------------------------------------------------------------------
       call this%post_epoch_hook( &
            this%epoch, &
            this%metrics(1)%val, &
            this%metrics(2)%val &
       )

       !------------------------------------------------------------------------
       ! Check for convergence and stop early if enabled
       ! When validation data is provided, stop once the validation loss drops
       ! below a positive plateau threshold
       !------------------------------------------------------------------------
       if(early_stopping_)then
          if(use_validation)then
             if(val_loss .lt. plateau_threshold_ .and. &
                  plateau_threshold_ .gt. 0._real32) exit epoch_loop
          else
             call this%metrics(1)%check(plateau_threshold_, converged)
             if(converged.ne.0) exit epoch_loop
          end if
          if(use_accuracy)then
             call this%metrics(2)%check(plateau_threshold_, converged)
             if(converged.ne.0) exit epoch_loop
          end if
       end if

    end do epoch_loop

    ! Final epoch metrics
    if(use_accuracy)then
       this%accuracy_val = this%metrics(2)%val
    else
       this%accuracy_val = 0._real32
    end if
    this%loss_val     = this%metrics(1)%val


    !---------------------------------------------------------------------------
    ! Leave the network configured for the full training batch size, even if
    ! the last batch processed was the short trailing batch
    !---------------------------------------------------------------------------
    if(active_batch_size.ne.target_batch_size)then
       call this%set_batch_size(target_batch_size)
    end if


    !---------------------------------------------------------------------------
    ! Restore training/inference mode
    !---------------------------------------------------------------------------
    call this%restore_mode(mode_store)

  end subroutine train