Train the network
This subroutine trains the network on the input data for a number of epochs. The input data is split into batches of size batch_size and the network is trained on each batch. The network is trained using the optimiser specified in the network object.
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| class(network_type) | intent(inout) | | | this | Instance of network |
| class(*) | intent(in) | | dimension(..) | input | Input data |
| class(*) | intent(in) | | dimension(:,:) | output | Output data |
| integer | intent(in) | | | num_epochs | Number of epochs |
| integer | intent(in) | optional | | batch_size | Batch size |
| real(kind=real32) | intent(in) | optional | | plateau_threshold | Plateau threshold |
| logical | intent(in) | optional | | shuffle_batches | Shuffle batches |
| integer | intent(in) | optional | | batch_print_step | Batch print step |
| integer | intent(in) | optional | | verbose | Verbosity level |
| integer | intent(in) | optional | | print_precision | Number of decimal places to print for training metrics |
| logical | intent(in) | optional | | scientific_print | Whether to print training metrics in scientific notation |
| logical | intent(in) | optional | | early_stopping | Whether to stop training early if convergence is detected |
| class(*) | intent(in) | optional | dimension(..) | val_input | Validation input data |
| class(*) | intent(in) | optional | dimension(:,:) | val_output | Validation expected output data |
module subroutine train( &
     this, input, output, num_epochs, batch_size, &
     plateau_threshold, shuffle_batches, batch_print_step, verbose, &
     print_precision, scientific_print, early_stopping, &
     val_input, val_output &
)
  !! Train the network
  !!
  !! This subroutine trains the network on the input data for a number of
  !! epochs. The input data is split into batches of size batch_size and
  !! the network is trained on each batch. The network is trained using
  !! the optimiser specified in the network object.
  !!
  !! If both val_input and val_output are provided, the validation loss
  !! (and the accuracy, when an accuracy method is associated) is
  !! evaluated one sample at a time at the end of every epoch. Providing
  !! only one of the two is reported as an error through stop_program.
  !!
  !! On exit, loss_val and accuracy_val of the network hold the values of
  !! the loss and accuracy metrics, and the batch size is set back to the
  !! training batch size.
  use athena__tools_infile, only: stop_check
  implicit none

  ! Arguments
  class(network_type), intent(inout) :: this
  !! Instance of network
  class(*), dimension(..), intent(in) :: input
  !! Input data
  class(*), dimension(:,:), intent(in) :: output
  !! Output data
  integer, intent(in) :: num_epochs
  !! Number of epochs
  integer, optional, intent(in) :: batch_size
  !! Batch size (must be positive); overwrites this%batch_size if present
  real(real32), optional, intent(in) :: plateau_threshold
  !! Plateau threshold (default 0)
  logical, optional, intent(in) :: shuffle_batches
  !! Shuffle batches (default .true.)
  integer, optional, intent(in) :: batch_print_step
  !! Batch print step (default 20); a value of 0 is treated as 1
  integer, optional, intent(in) :: verbose
  !! Verbosity level (default 0)
  integer, optional, intent(in) :: print_precision
  !! Number of decimal places to print for training metrics (default 3)
  logical, optional, intent(in) :: scientific_print
  !! Whether to print training metrics in scientific notation
  !! (default .false.)
  logical, optional, intent(in) :: early_stopping
  !! Whether to stop training early if convergence is detected
  !! (default .true.)
  class(*), dimension(..), optional, intent(in) :: val_input
  !! Validation input data
  class(*), dimension(:,:), optional, intent(in) :: val_output
  !! Validation expected output data

  ! Training parameters
  real(real32) :: batch_loss, batch_accuracy, avg_loss, avg_accuracy
  !! Loss and accuracy (avg_* accumulate the per-batch values of an epoch)

  ! learning parameters
  integer :: num_samples
  !! Number of training samples, as returned by save_input
  integer :: num_batches
  !! Number of batches
  integer :: converged
  !! Convergence flag
  integer :: window_width
  !! Length of convergence check window
  integer :: verbose_
  !! Verbosity level
  real(real32) :: plateau_threshold_
  !! Plateau threshold
  logical :: shuffle_batches_
  !! Shuffle batches
  logical :: use_accuracy
  !! Whether accuracy evaluation is available
  logical :: early_stopping_
  !! Whether to stop training early if convergence is detected

  ! Printing parameters
  integer :: batch_print_step_
  !! Batch print step
  integer :: print_precision_
  !! Number of decimal places for metric output
  logical :: scientific_print_
  !! Whether to print metrics in scientific notation
  character(len=64) :: loss_str, accuracy_str
  !! Formatted metrics for printing
  character(len=128) :: val_str
  !! Formatted validation metrics for printing

  ! Training loop variables
  integer :: epoch, batch, start_index, end_index
  !! Loop index
  integer, allocatable, dimension(:) :: batch_order
  !! Batch order
  integer :: i, j
  !! Loop index
  integer :: current_batch_size, target_batch_size
  !! Actual batch size for the current batch and the target batch size
  integer :: active_batch_size
  !! Batch size most recently passed to set_batch_size
  logical, allocatable :: mode_store(:)
  !! Storage for inference mode booleans
  class(*), allocatable :: data_poly(:,:)
  !! Samples of the current batch, passed to the forward pass
  ! NOTE: initialisation in the declaration gives this pointer the SAVE
  ! attribute; it is nullified again after every use below.
  type(array_type), pointer :: loss => null()
  !! Loss returned by loss_eval for the current batch/sample

  ! Validation variables
  logical :: use_validation
  !! Whether validation data is provided
  integer :: val_num_samples
  !! Number of validation samples
  integer :: val_sample
  !! Loop index for validation
  real(real32) :: val_loss, val_accuracy, val_loss_sum, val_accuracy_sum
  !! Validation loss and accuracy
#ifdef __INTEL_COMPILER
  type(array_type), pointer :: saved_expected_array(:,:) => null()
  !! Saved training expected output (for restoring after validation)
#else
  type(array_type), dimension(:,:), allocatable :: saved_expected_array
  !! Saved training expected output (for restoring after validation)
#endif

#ifdef _OPENMP
  type(network_type) :: this_copy
  !! Copy of network (not referenced in this routine)
#endif


  !---------------------------------------------------------------------------
  ! Check loss and accuracy methods are set
  !---------------------------------------------------------------------------
  if(.not.allocated(this%loss))then
     call stop_program("loss method not set")
     return
  end if
  use_accuracy = associated(this%get_accuracy)
  accuracy_str = ""
  val_str = ""


  !---------------------------------------------------------------------------
  ! Check validation data
  !---------------------------------------------------------------------------
  use_validation = present(val_input) .and. present(val_output)
  if(present(val_input) .neqv. present(val_output))then
     call stop_program( &
          "both val_input and val_output must be provided for validation" &
     )
     return
  end if


  !---------------------------------------------------------------------------
  ! Initialise optional arguments
  !---------------------------------------------------------------------------
  verbose_ = 0
  batch_print_step_ = 20
  plateau_threshold_ = 0._real32
  shuffle_batches_ = .true.
  scientific_print_ = .false.
  print_precision_ = 3
  early_stopping_ = .true.
  if(present(plateau_threshold)) plateau_threshold_ = plateau_threshold
  if(present(shuffle_batches)) shuffle_batches_ = shuffle_batches
  ! A step of zero would make mod(batch, step) undefined; the sign of the
  ! step does not affect whether the remainder is zero.
  if(present(batch_print_step)) &
       batch_print_step_ = max(abs(batch_print_step), 1)
  if(present(verbose)) verbose_ = verbose
  if(present(print_precision)) print_precision_ = max(print_precision, 0)
  if(present(scientific_print)) scientific_print_ = scientific_print
  if(present(batch_size)) this%batch_size = batch_size
  if(present(early_stopping)) early_stopping_ = early_stopping

  ! The batch size is used as a divisor below
  if(this%batch_size.lt.1)then
     call stop_program("batch size must be a positive integer")
     return
  end if


  !---------------------------------------------------------------------------
  ! Initialise monitoring variables
  !---------------------------------------------------------------------------
  window_width = max(ceiling(500._real32/this%batch_size),1)
  do i = 1, size(this%metrics,dim=1)
     this%metrics(i)%window_width = window_width
  end do


  !---------------------------------------------------------------------------
  ! Save input and output to network
  !---------------------------------------------------------------------------
  num_samples = this%save_input( input )
  call this%save_output( output )
  if(size(output,2).ne.num_samples.and.this%use_graph_output)then
     call stop_program("number of samples in input and output do not match")
     return
  end if


  !---------------------------------------------------------------------------
  ! Determine the number of batches (rounding up, so the final batch may be
  ! smaller than the others) and the initial batch order
  !---------------------------------------------------------------------------
  num_batches = (num_samples + this%batch_size - 1) / this%batch_size
  allocate(batch_order(num_batches))
  do batch = 1, num_batches
     batch_order(batch) = batch
  end do


  !---------------------------------------------------------------------------
  ! Set/reset batch size for training
  !---------------------------------------------------------------------------
  call this%set_batch_size(this%batch_size)
  target_batch_size = this%batch_size
  active_batch_size = target_batch_size


  !---------------------------------------------------------------------------
  ! Enable training mode
  !---------------------------------------------------------------------------
  call this%set_training_mode(mode_store)

  epoch_loop: do epoch = 1, num_epochs
     this%epoch = epoch

     !------------------------------------------------------------------------
     ! Shuffle batch order at the start of each epoch
     !------------------------------------------------------------------------
     if(shuffle_batches_)then
        call shuffle(batch_order)
     end if

     avg_loss = 0._real32
     avg_accuracy = 0._real32

     !------------------------------------------------------------------------
     ! Batch loop
     ! ... split data up into minibatches for training
     !------------------------------------------------------------------------
     batch_loop: do batch = 1, num_batches

        ! Set batch start and end index
        ! ... indices are derived from the target batch size so that they
        ! ... do not depend on the size set for a previous (shorter) batch
        !---------------------------------------------------------------------
        start_index = (batch_order(batch) - 1) * target_batch_size + 1
        end_index = min(batch_order(batch) * target_batch_size, num_samples)
        current_batch_size = end_index - start_index + 1
        ! Resize whenever the size differs from the one currently set, which
        ! also switches back to the full size after a shorter batch
        if(current_batch_size.ne.active_batch_size)then
           call this%set_batch_size(current_batch_size)
           active_batch_size = current_batch_size
        end if


        ! Forward pass
        !---------------------------------------------------------------------
        select case(this%use_graph_input)
        case(.true.)
           data_poly = get_sample( &
                this%input_graph, start_index, end_index, current_batch_size &
           )
        case default
           data_poly = get_sample( &
                this%input_array, start_index, end_index, current_batch_size, &
                as_graph = .false. &
           )
        end select
        call this%forward(data_poly)
        deallocate(data_poly)


        ! Backward pass
        !---------------------------------------------------------------------
        loss => this%loss_eval(start_index, end_index)
        loss%is_temporary = .false.
        call loss%grad_reverse(reset_graph=.true.)


        ! Compute loss and accuracy (for monitoring)
        !---------------------------------------------------------------------
        batch_loss = sum(loss%val)
        batch_accuracy = 0._real32
        if(use_accuracy)then
           batch_accuracy = this%accuracy_eval(output, start_index, end_index)
        end if


        ! Accumulate metrics over the batches of this epoch
        !---------------------------------------------------------------------
        avg_loss = avg_loss + batch_loss
        if(use_accuracy)then
           avg_accuracy = avg_accuracy + batch_accuracy
        end if


        ! Update weights and biases using optimisation algorithm
        !---------------------------------------------------------------------
        call this%update()
        call loss%nullify_graph()
        deallocate(loss)
        nullify(loss)


        ! Print batch results (running mean over the batches seen so far)
        !---------------------------------------------------------------------
        if(verbose_.ne.0.and.&
             (batch.eq.1.or.mod(batch,batch_print_step_).eq.0))then
           loss_str = format_training_real( &
                avg_loss / real(batch, real32), print_precision_, &
                scientific_print_ &
           )
           if(use_accuracy)then
              accuracy_str = ", accuracy=" // trim(format_training_real( &
                   avg_accuracy / real(batch, real32), &
                   print_precision_, scientific_print_ &
              ))
           end if
           write(6,'("epoch=",I0,", batch=",I0,", lr=",ES0.2,", loss=",A,A)') &
                this%epoch, batch, &
                this%optimiser%lr_decay%get_lr( &
                     this%optimiser%learning_rate, this%optimiser%iter &
                ), &
                trim(loss_str), trim(accuracy_str)
        end if


        ! Check for user-named stop file
        !---------------------------------------------------------------------
        if(stop_check())then
           write(0,*) "STOPCAR ENCOUNTERED"
           write(0,*) "Exiting training loop..."
           exit epoch_loop
        end if

     end do batch_loop

     call this%metrics(1)%append(avg_loss / real(num_batches, real32))
     if(use_accuracy)then
        call this%metrics(2)%append(avg_accuracy / real(num_batches, real32))
     end if


     !------------------------------------------------------------------------
     ! Validation evaluation at end of epoch
     !------------------------------------------------------------------------
     val_str = ""
     if(use_validation)then
#ifdef __INTEL_COMPILER
        ! Save training expected output. `ifx` crashes on the allocatable
        ! local declaration used with `move_alloc`, so keep an explicit
        ! pointer-backed copy here instead.
        if(allocated(this%expected_array))then
           allocate(saved_expected_array( &
                size(this%expected_array, 1), size(this%expected_array, 2) &
           ))
           do i = 1, size(this%expected_array, 1)
              do j = 1, size(this%expected_array, 2)
                 call saved_expected_array(i,j)%allocate( &
                      source=this%expected_array(i,j) &
                 )
                 call this%expected_array(i,j)%deallocate()
              end do
           end do
           deallocate(this%expected_array)
        else
           nullify(saved_expected_array)
        end if
#else
        call move_alloc(this%expected_array, saved_expected_array)
#endif

        ! Save validation output to network
        call this%save_output( val_output )

        ! Save validation input
        val_num_samples = this%save_input( val_input )

        ! Set batch size to 1 and enable inference mode
        call this%set_batch_size(1)
        call this%set_inference_mode()

        ! Evaluate validation loss and accuracy
        val_loss_sum = 0._real32
        val_accuracy_sum = 0._real32
        do val_sample = 1, val_num_samples
           select case(this%use_graph_input)
           case(.true.)
              data_poly = get_sample( &
                   this%input_graph, val_sample, val_sample, 1 &
              )
           case default
              data_poly = get_sample_array( &
                   this%input_array, val_sample, val_sample, 1, &
                   as_graph = .false. &
              )
           end select
           call this%forward(data_poly)
           deallocate(data_poly)

           loss => this%loss_eval(val_sample, val_sample)
           val_loss_sum = val_loss_sum + sum(loss%val)
           call loss%nullify_graph()
           deallocate(loss)
           nullify(loss)

           if(use_accuracy)then
              val_accuracy_sum = val_accuracy_sum + &
                   this%accuracy_eval(val_output, val_sample, val_sample)
           end if
        end do
        val_loss = val_loss_sum / real(val_num_samples, real32)
        val_accuracy = val_accuracy_sum / real(val_num_samples, real32)

        ! Build validation print string
        val_str = ", val_loss=" // trim(format_training_real( &
             val_loss, print_precision_, scientific_print_ &
        ))
        if(use_accuracy)then
           val_str = trim(val_str) // ", val_accuracy=" // &
                trim(format_training_real( &
                     val_accuracy, print_precision_, scientific_print_ &
                ))
        end if

        ! Restore training expected output
        if(allocated(this%expected_array))then
           do i = 1, size(this%expected_array, 1)
              do j = 1, size(this%expected_array, 2)
                 call this%expected_array(i,j)%deallocate()
              end do
           end do
           deallocate(this%expected_array)
        end if
#ifdef __INTEL_COMPILER
        ! `ifx` does not support `move_alloc` in this context
        if(associated(saved_expected_array))then
           allocate(this%expected_array( &
                size(saved_expected_array, 1), size(saved_expected_array, 2) &
           ))
           do i = 1, size(saved_expected_array, 1)
              do j = 1, size(saved_expected_array, 2)
                 call this%expected_array(i,j)%allocate( &
                      source=saved_expected_array(i,j) &
                 )
                 call saved_expected_array(i,j)%deallocate()
              end do
           end do
           deallocate(saved_expected_array)
           nullify(saved_expected_array)
        end if
#else
        call move_alloc(saved_expected_array, this%expected_array)
#endif

        ! Restore training input
        num_samples = this%save_input( input )

        ! Restore training batch size and training mode
        call this%set_batch_size(target_batch_size)
        active_batch_size = target_batch_size
        call this%set_training_mode()
     end if


     ! Print epoch summary results
     ! ... with validation data only the validation metrics are printed;
     ! ... otherwise the training metrics are printed when verbose is zero
     !------------------------------------------------------------------------
     if(use_validation.and.verbose_.ge.0)then
        write(6,'("epoch=",I0,A)') this%epoch, trim(val_str)
     elseif(verbose_.eq.0)then
        loss_str = format_training_real( &
             this%metrics(1)%val, print_precision_, scientific_print_ &
        )
        if(use_accuracy)then
           accuracy_str = ", accuracy=" // trim(format_training_real( &
                this%metrics(2)%val, print_precision_, scientific_print_ &
           ))
        end if
        write(6,'("epoch=",I0,", lr=",ES0.2,", loss=",A,A,A)') &
             this%epoch, &
             this%optimiser%lr_decay%get_lr( &
                  this%optimiser%learning_rate, this%optimiser%iter &
             ), &
             trim(loss_str), trim(accuracy_str), trim(val_str)
     end if


     !------------------------------------------------------------------------
     ! Per-epoch callback (e.g. W&B logging via wandb_network_type)
     !------------------------------------------------------------------------
     call this%post_epoch_hook( &
          this%epoch, &
          this%metrics(1)%val, &
          this%metrics(2)%val &
     )


     !------------------------------------------------------------------------
     ! Check for convergence and stop early if enabled
     ! When validation data is provided, check validation loss for plateau
     !------------------------------------------------------------------------
     if(early_stopping_)then
        if(use_validation)then
           if(val_loss .lt. plateau_threshold_ .and. &
                plateau_threshold_ .gt. 0._real32) exit epoch_loop
        else
           call this%metrics(1)%check(plateau_threshold_, converged)
           if(converged.ne.0) exit epoch_loop
        end if
        if(use_accuracy)then
           call this%metrics(2)%check(plateau_threshold_, converged)
           if(converged.ne.0) exit epoch_loop
        end if
     end if

  end do epoch_loop

  ! Final epoch metrics
  if(use_accuracy)then
     this%accuracy_val = this%metrics(2)%val
  else
     this%accuracy_val = 0._real32
  end if
  this%loss_val = this%metrics(1)%val


  !---------------------------------------------------------------------------
  ! Restore the training batch size if the last processed batch was shorter
  !---------------------------------------------------------------------------
  if(active_batch_size.ne.target_batch_size)then
     call this%set_batch_size(target_batch_size)
  end if


  !---------------------------------------------------------------------------
  ! Restore training/inference mode
  !---------------------------------------------------------------------------
  call this%restore_mode(mode_store)

end subroutine train