submodule(athena__network) athena__network_submodule
  !! Submodule containing implementations for the network module
#ifdef _OPENMP
  use omp_lib
#endif
  use coreutils, only: stop_program, print_warning, to_lower
  use athena__misc_ml, only: shuffle

  use athena__accuracy, only: categorical_score, mae_score, mse_score, r2_score
  use athena__base_layer, only: learnable_layer_type, merge_layer_type
#if defined(GFORTRAN)
  use athena__container_layer, only: container_reduction
#endif

  use athena__container_layer, only: &
       list_of_layer_types, allocate_list_of_layer_types, &
       list_of_onnx_layer_creators, allocate_list_of_onnx_layer_creators

  ! Layer types
  use athena__flatten_layer, only: flatten_layer_type
  use athena__add_layer, only: add_layer_type
  use athena__concat_layer, only: concat_layer_type
  use athena__input_layer,   only: input_layer_type
  use athena__msgpass_layer, only: msgpass_layer_type
  use athena__recurrent_layer, only: recurrent_layer_type

! #ifdef _OPENMP
!   !$omp declare reduction( &
!   !$omp& network_reduction : network_type:omp_out%network_reduction(omp_in)) &
!   !$omp& initialiser(omp_priv = omp_orig)
! #endif

contains

!###############################################################################
  module subroutine network_reduction(this, source)
    !! Accumulate the contents of another network into this one
    !!
    !! The values of the first two metrics are summed, and every learnable
    !! layer of this network is reduced with the layer held at the same
    !! position in the source network.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Network that receives the accumulated values
    type(network_type), intent(in) :: source
    !! Network whose values are added to this

    ! Local variables
    integer :: metric_idx, layer_idx
    !! Loop indices over metrics and layers

    ! Sum the values of the first two metrics
    do metric_idx = 1, 2
       this%metrics(metric_idx)%val = &
            this%metrics(metric_idx)%val + source%metrics(metric_idx)%val
    end do

    ! Only pairs of learnable layers take part in the reduction
    layer_loop: do layer_idx = 1, size(this%model)
       select type(receiving_layer => this%model(layer_idx)%layer)
       class is(learnable_layer_type)
          select type(incoming_layer => source%model(layer_idx)%layer)
          class is(learnable_layer_type)
             call receiving_layer%reduce(incoming_layer)
          end select
       end select
    end do layer_loop

  end subroutine network_reduction
!###############################################################################


!###############################################################################
  module subroutine network_copy(this, source)
    !! Procedure to copy a network
    !!
    !! Allocatable components are only copied when they are allocated in the
    !! source. When the source component is unallocated, the matching
    !! component of this network is deallocated instead, because an
    !! unallocated variable must not appear on the right-hand side of an
    !! assignment.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    type(network_type), intent(in), target :: source
    !! Instance of network to be copied


    this%metrics = source%metrics

    if(allocated(source%model))then
       this%model = source%model
    elseif(allocated(this%model))then
       deallocate(this%model)
    end if

    this%num_layers = source%num_layers
    this%batch_size = source%batch_size
    this%num_params = source%num_params
    this%num_outputs = source%num_outputs

    if(allocated(source%optimiser))then
       this%optimiser = source%optimiser
    elseif(allocated(this%optimiser))then
       deallocate(this%optimiser)
    end if

    if(allocated(source%vertex_order))then
       this%vertex_order = source%vertex_order
    elseif(allocated(this%vertex_order))then
       deallocate(this%vertex_order)
    end if

    if(allocated(source%root_vertices))then
       this%root_vertices = source%root_vertices
    elseif(allocated(this%root_vertices))then
       deallocate(this%root_vertices)
    end if

    if(allocated(source%leaf_vertices))then
       this%leaf_vertices = source%leaf_vertices
    elseif(allocated(this%leaf_vertices))then
       deallocate(this%leaf_vertices)
    end if

    ! NOTE(review): the declaration of the loss component is not visible
    ! here; if it is allocatable it needs the same guard as above - confirm
    this%loss = source%loss
    this%get_accuracy => source%get_accuracy
    this%auto_graph = source%auto_graph

  end subroutine network_copy
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * *  * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  module subroutine build_vertex_order(this)
    !! Generate the order of the layers in the network
    !!
    !! The order is generated by depth first search (DFS) on the graph of
    !! the network, with the searches started from the highest vertex index
    !! downwards.
    !!
    !! The order is assembled in a local array and only moved into the
    !! network once complete. Passing this%vertex_order directly to dfs
    !! would modify a component of the network through a second dummy
    !! argument while the network itself is also passed to dfs.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network

    ! Local variables
    integer :: i, order_index, num_vertices
    !! Loop index, number of vertices ordered so far, number of vertices
    logical, dimension(:), allocatable :: visited
    !! Array to store whether a vertex has been visited
    integer, dimension(:), allocatable :: order
    !! Order of the vertices as filled in by the search

    num_vertices = this%auto_graph%num_vertices
    allocate(visited(num_vertices), source=.false.)
    allocate(order(num_vertices), source=0)

    order_index = 0
    do i = num_vertices, 1, -1
       if(.not.visited(i)) call this%dfs( &
            i, visited, order, order_index &
       )
    end do

    ! move_alloc releases any previous vertex_order before taking ownership
    call move_alloc(order, this%vertex_order)

  end subroutine build_vertex_order
!###############################################################################


!###############################################################################
  recursive module subroutine dfs( &
       this, vertex_index, visited, order, order_index &
  )
    !! Recursive depth first search over the network graph
    !!
    !! The vertex is flagged as visited, the search then descends into every
    !! unvisited vertex that has a non-zero adjacency entry in the column of
    !! this vertex, and finally the vertex itself is appended to the order.
    implicit none

    ! Arguments
    class(network_type), intent(in) :: this
    !! Instance of network
    integer, intent(in) :: vertex_index
    !! Index of the vertex the search is currently at
    logical, dimension(this%auto_graph%num_vertices), intent(inout) :: visited
    !! Flags recording which vertices have been visited
    integer, dimension(this%auto_graph%num_vertices), intent(inout) :: order
    !! Vertices listed in the order in which their search completed
    integer, intent(inout) :: order_index
    !! Number of entries of the order array filled so far

    ! Local variables
    integer :: neighbour
    !! Index of the candidate vertex to descend into

    visited(vertex_index) = .true.

    neighbour_loop: do neighbour = 1, this%auto_graph%num_vertices
       ! Skip vertices that are not connected to the current vertex
       if(this%auto_graph%adjacency(neighbour,vertex_index).eq.0) &
            cycle neighbour_loop
       ! Skip vertices reached earlier (possibly by a previous recursion)
       if(visited(neighbour)) cycle neighbour_loop
       call this%dfs(neighbour, visited, order, order_index)
    end do neighbour_loop

    ! Record the vertex once all vertices reached from it are recorded
    order_index = order_index + 1
    order(order_index) = vertex_index

  end subroutine dfs
!###############################################################################


!###############################################################################
  module subroutine build_root_vertices(this)
    !! Collect the root vertices of the network
    !!
    !! A vertex is a root when every entry in its column of the adjacency
    !! matrix is zero. The indices are stored in ascending order.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network

    ! Local variables
    integer :: i, num_vertices
    !! Implied-do index and number of vertices in the graph

    num_vertices = this%auto_graph%num_vertices

    ! Keep the indices of the vertices whose adjacency column is all zero
    this%root_vertices = pack( &
         [ ( i, i = 1, num_vertices ) ], &
         [ ( all(this%auto_graph%adjacency(:,i).eq.0), i = 1, num_vertices ) ] &
    )

  end subroutine build_root_vertices
!###############################################################################


!###############################################################################
  module subroutine build_leaf_vertices(this)
    !! Collect the leaf (output) vertices of the network
    !!
    !! A vertex is a leaf when every entry in its row of the adjacency
    !! matrix is zero. The indices are stored in ascending order.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network

    ! Local variables
    integer :: i
    !! Loop index
    logical, dimension(this%auto_graph%num_vertices) :: is_leaf
    !! Whether each vertex has an all-zero adjacency row

    do i = 1, this%auto_graph%num_vertices
       is_leaf(i) = all(this%auto_graph%adjacency(i,:).eq.0)
    end do

    this%leaf_vertices = pack( &
         [ ( i, i = 1, this%auto_graph%num_vertices ) ], is_leaf &
    )

  end subroutine build_leaf_vertices
!###############################################################################





!##############################################################################!
! * * * * * * * * * * * * * * * * * * *  * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  module subroutine print(this, file)
    !! Print the network to a file
    !!
    !! A NETWORK_SETTINGS card is written first, followed by one card per
    !! layer in the order given by vertex_order. Each layer header holds the
    !! layer name, the operator of the edges entering the layer and the
    !! positions (within vertex_order) of the layers that feed into it.
    use coreutils, only: to_upper
    use athena__io_utils, only: athena__version__
    implicit none

    ! Arguments
    class(network_type), intent(in) :: this
    !! Instance of network
    character(*), intent(in) :: file
    !! File to print the network to

    ! Local variables
    integer :: l, v, e, vertex_index, unit
    !! Layer index, loop indices, position in vertex_order, file unit
    integer :: operator_in, operator_out
    !! Operator ids of the edges entering and leaving the layer
    character(3) :: operator_str
    !! String to store the operator
    character(256) :: suffix, fmt
    !! Suffix for the layer header and the format used to build it
    integer, dimension(:), allocatable :: input_list, output_list
    !! Positions of the layers feeding into and fed by the current layer

    open(newunit=unit,file=file,status='replace')

    write(unit,'("NETWORK_SETTINGS")')
    write(unit,'(3X,"ATHENA_VERSION = ",A)') trim(adjustl(athena__version__))
    if(allocated(this%name)) write(unit,'(3X,"NAME = ",A)') trim(adjustl(this%name))
    write(unit,'(3X,"EPOCH = ",I0)') this%epoch
    write(unit,'(3X,"BATCH_SIZE = ",I0)') this%batch_size
    write(unit,'(3X,"ACCURACY = ",F0.9)') this%accuracy_val
    write(unit,'(3X,"LOSS = ",F0.9)') this%loss_val
    if(allocated(this%accuracy_method))then
       write(unit,'(3X,"ACCURACY_METHOD = ",A)') trim(adjustl(this%accuracy_method))
    end if
    if(allocated(this%loss_method))then
       write(unit,'(3X,"LOSS_METHOD = ",A)') trim(adjustl(this%loss_method))
    end if
    if(allocated(this%optimiser))then
       write(unit,'(3X,"OPTIMISER: ",A)') trim(adjustl(this%optimiser%name))
       call this%optimiser%print_to_unit(unit=unit)
       write(unit,'(3X,"END OPTIMISER")')
    end if
    write(unit,'("END NETWORK_SETTINGS")')

    do v = 1, size(this%vertex_order,dim=1), 1
       l = this%vertex_order(v)
       operator_in = -1
       operator_out = -1
       allocate(input_list(0), output_list(0))
       do e = 1, this%auto_graph%num_edges
          ! Edges entering layer l store -l as their second index
          if(-this%auto_graph%edge(e)%index(2).eq.l)then
             if(operator_in.gt.0.and.this%auto_graph%edge(e)%id.ne.operator_in)then
                write(0,*) "WARNING: multiple operators for layer ", l
                write(0,*) "  using operator ", this%auto_graph%edge(e)%id
             end if
             operator_in = this%auto_graph%edge(e)%id
             vertex_index = &
                  findloc( this%vertex_order, this%auto_graph%edge(e)%index(1), 1 )
             input_list = [ input_list, vertex_index ]
          end if
          ! Edges leaving layer l store l as their first index
          if(this%auto_graph%edge(e)%index(1).eq.l)then
             if(operator_out.gt.0.and.this%auto_graph%edge(e)%id.ne.operator_out)then
                write(0,*) "WARNING: multiple operators for layer ", l
                write(0,*) "  using operator ", this%auto_graph%edge(e)%id
             end if
             ! Outgoing edges must not overwrite the incoming operator
             operator_out = this%auto_graph%edge(e)%id
             vertex_index = &
                  findloc( this%vertex_order, this%auto_graph%edge(e)%index(2), 1 )
             output_list = [ output_list, vertex_index ]
          end if
       end do

       suffix = ""
       select case(operator_in)
       case(1)
          operator_str = " ||"
       case(2)
          operator_str = " +"
       case(3)
          operator_str = " *"
       case default
          ! No incoming edge (e.g. a root layer): write the operator that
          ! the add procedure uses by default so the header stays well formed
          operator_str = " ||"
       end select
       ! get size of input_list and make the formatted string
       if(size(input_list).eq.0)then
          write(suffix,'(A," []")') trim(operator_str)
       else
          ! suffix is the internal file here, so it must not also appear in
          ! the output list; it is blank at this point in any case
          write(fmt,'("(A,"" ["",",I0,"(1X,I0),"" ]"")")') size(input_list)
          write(suffix,fmt) operator_str, input_list
       end if
       ! select case(operator_out)
       ! case(1)
       !    operator_str = " ||"
       ! case(2)
       !    operator_str = " +"
       ! case(3)
       !    operator_str = " *"
       ! end select
       ! if(size(output_list).gt.0)then
       !    write(fmt,'("(A,A,"" ["",",I0,"(1X,I0),"" ]"")")') size(output_list)
       !    write(suffix,fmt) trim(suffix), operator_str, output_list
       ! end if

       write(unit,'(A,A)') to_upper(trim(this%model(l)%layer%name)), trim(suffix)
       call this%model(l)%layer%print(unit=unit, print_header_footer=.false.)

       write(unit,'("END ",A)') to_upper(trim(this%model(l)%layer%name))
       deallocate(input_list, output_list)
    end do
    close(unit)

  end subroutine print
!###############################################################################


!###############################################################################
  module subroutine read(this, file)
    !! Read the network from a file
    !!
    !! The file is scanned line by line. A NETWORK_SETTINGS card is handed to
    !! read_network_settings; any other card is treated as a layer header of
    !! the form "NAME operator [ inputs ]", whose body is read by the reader
    !! registered for that layer name. Lines containing "=" outside of a
    !! card are skipped with a warning.
    use coreutils, only: icount
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    character(*), intent(in) :: file
    !! File to read the network from

    ! Local variables
    integer :: unit, stat, itmp1
    !! File unit, I/O status, temporary integer
    integer, dimension(:), allocatable :: input_list, output_list
    !! List of input and output layers
    character(256) :: buffer, err_msg, input_str, output_str
    !! Buffer for reading lines from file
    character(20) :: name, card
    !! Lowercase name of the layer and the card name as found in the file
    character(2) :: operator_in, operator_out
    !! Operator for the layer
    integer :: layer_index
    !! Index of the layer in the list of layer types


    if(.not.allocated(list_of_layer_types))then
       call allocate_list_of_layer_types()
    end if

    open(newunit=unit,file=file,action='read',iostat=stat)
    if(stat.ne.0)then
       call stop_program("unable to open network file '"//trim(file)//"'")
       return
    end if

    card_loop: do
       read(unit,'(A)',iostat=stat) buffer
       if(stat.lt.0)then
          exit card_loop
       elseif(stat.gt.0)then
          close(unit)
          call stop_program("error encountered in network read")
          return
       end if
       if(trim(adjustl(buffer)).eq."") cycle card_loop

       !! check if a tag line
       if(scan(buffer,'=').ne.0)then
          write(0,*) "WARNING: unexpected line in read file"
          write(0,*) trim(buffer)
          write(0,*) " skipping..."
          cycle card_loop
       end if

       !! check for card
       name = trim(adjustl(buffer(1:scan(buffer,' ')-1)))
       if(name.eq."NETWORK_SETTINGS")then
          call this%read_network_settings(unit)
          cycle card_loop
       end if
       ! Keep the card name as written; buffer is consumed token by token
       card = name

       ! Incoming operator and list of input layers
       buffer = trim(adjustl(buffer(scan(buffer,' ')+1:)))
       operator_in = trim(adjustl(buffer(1:scan(buffer,' ')-1)))
       buffer = trim(adjustl(buffer(scan(buffer,' ')+1:)))
       input_str = trim(adjustl(buffer(1:scan(buffer,']'))))
       if(scan(input_str,'[').ne.0)then
          input_str = &
               trim(adjustl(input_str(scan(input_str,'[')+1:scan(input_str,']')-1)))
          itmp1 = icount(input_str)
          allocate(input_list(itmp1))
          read(input_str,*) input_list
       else
          ! No list given: -1 refers to the previously added layer
          allocate(input_list, source = [-1])
       end if

       ! Outgoing operator and list of output layers (parsed, not yet used)
       buffer = trim(adjustl(buffer(scan(buffer,']')+1:)))
       operator_out = trim(adjustl(buffer(1:scan(buffer,' ')-1)))
       buffer = trim(adjustl(buffer(scan(buffer,' ')+1:)))
       output_str = trim(adjustl(buffer(1:scan(buffer,']'))))
       if(scan(output_str,'[').ne.0)then
          output_str = &
               trim(adjustl(output_str(scan(output_str,'[')+1:scan(output_str,']')-1)))
          itmp1 = icount(output_str)
          allocate(output_list(itmp1))
          read(output_str,*) output_list
       else
          allocate(output_list(0))
       end if

       name = trim(adjustl(to_lower(name)))
       layer_index = &
            findloc( &
                 [ list_of_layer_types(:)%name ], &
                 name, &
                 dim = 1 &
            )
       if(layer_index.eq.0)then
          ! Report the card name; buffer only holds leftovers by this point
          write(err_msg,'("unrecognised card ''",A)') trim(adjustl(card))
          close(unit)
          call stop_program(err_msg)
          return
       end if
       call this%add( &
            list_of_layer_types(layer_index)%read_ptr(unit), &
            input_list = input_list, &
            operator = operator_in &
       )
       if(allocated(input_list)) deallocate(input_list)
       if(allocated(output_list)) deallocate(output_list)
    end do card_loop
    close(unit)

  end subroutine read
!###############################################################################


!###############################################################################
  module subroutine read_network_settings(this, unit)
    !! Read the network settings from a file
    !!
    !! Lines are read from the unit until "END NETWORK_SETTINGS" is found.
    !! Each recognised tag is assigned to the matching component of the
    !! network; an OPTIMISER tag hands over to read_optimiser_settings.
    !! Lines without letters (other than "e") and lines starting with END
    !! are ignored; any other unrecognised line is an error.
    use athena__tools_infile, only: assign_val, assign_vec
    use coreutils, only: to_lower
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    integer, intent(in) :: unit
    !! File unit

    ! Local variables
    integer :: stat
    !! File status
    integer :: itmp1
    !! Temporary integer
    character(20) :: accuracy_method, loss_method
    !! Methods for accuracy and loss
    character(256) :: buffer, tag, err_msg, name_
    !! Buffer for reading lines, tag for identifying lines, error message


    ! Loop over tags in settings card
    !---------------------------------------------------------------------------
    accuracy_method = ""
    loss_method = ""
    tag_loop: do

       ! Check for end of file
       !------------------------------------------------------------------------
       read(unit,'(A)',iostat=stat) buffer
       if(stat.ne.0)then
          ! The network name may be unallocated here, so it is not used
          call stop_program( &
               "file encountered error (EoF?) before END NETWORK_SETTINGS" &
          )
          return
       end if
       if(trim(adjustl(buffer)).eq."") cycle tag_loop

       ! Check for end of settings card
       !------------------------------------------------------------------------
       if(trim(adjustl(buffer)).eq."END NETWORK_SETTINGS")then
          ! Step back so the closing line is re-read and checked below
          backspace(unit)
          exit tag_loop
       end if

       tag=trim(adjustl(buffer))
       if(scan(buffer,"=").ne.0) tag=trim(tag(:scan(tag,"=")-1))
       if(scan(buffer,":").ne.0) tag=trim(tag(:scan(tag,":")-1))

       ! Read parameters from save file
       !------------------------------------------------------------------------
       select case(trim(tag))
       case("ATHENA_VERSION")
          ! Ignore this tag, it is only for information
       case("NAME")
          call assign_val(buffer, name_, itmp1)
          if(len(trim(adjustl(name_))) .gt. 0)then
             this%name = trim(adjustl(name_))
          end if
       case("EPOCH")
          call assign_val(buffer, this%epoch, itmp1)
       case("BATCH_SIZE")
          call assign_val(buffer, this%batch_size, itmp1)
       case("ACCURACY")
          call assign_val(buffer, this%accuracy_val, itmp1)
       case("LOSS")
          call assign_val(buffer, this%loss_val, itmp1)
       case("ACCURACY_METHOD")
          call assign_val(buffer, accuracy_method, itmp1)
          call this%set_accuracy(accuracy_method)
       case("LOSS_METHOD")
          call assign_val(buffer, loss_method, itmp1)
          call this%set_loss(loss_method)
       case("OPTIMISER")
          ! The optimiser reader expects to start on the OPTIMISER line
          backspace(unit)
          call this%read_optimiser_settings(unit)
       case default
          ! Don't look for "e" due to scientific notation of numbers
          ! ... i.e. exponent (E+00)
          if(scan(to_lower(trim(adjustl(buffer))),&
               'abcdfghijklmnopqrstuvwxyz').eq.0)then
             cycle tag_loop
          elseif(tag(:3).eq.'END')then
             cycle tag_loop
          end if
          write(err_msg,'("Unrecognised line in input file: ",A)') &
               trim(adjustl(buffer))
          call stop_program(err_msg)
          return
       end select
    end do tag_loop


    ! Check for end of settings card
    !---------------------------------------------------------------------------
    read(unit,'(A)') buffer
    if(trim(adjustl(buffer)).ne."END NETWORK_SETTINGS")then
       write(0,*) trim(adjustl(buffer))
       write(err_msg,'("END NETWORK_SETTINGS not where expected")')
       call stop_program(err_msg)
       return
    end if

  end subroutine read_network_settings
!-------------------------------------------------------------------------------
  module subroutine read_optimiser_settings(this, unit)
    !! Read the optimiser settings from a file
    !!
    !! The next line of the unit is expected to hold the optimiser tag
    !! followed by the optimiser name. An optimiser of the named type is
    !! created and then reads its own settings from the unit. A missing
    !! name selects the base optimiser.
    use coreutils, only: to_lower
    use athena__optimiser, only: &
         sgd_optimiser_type, adam_optimiser_type, rmsprop_optimiser_type, &
         adagrad_optimiser_type, base_optimiser_type
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    integer, intent(in) :: unit
    !! File unit

    ! Local variables
    integer :: stat
    !! File status
    character(20) :: optimiser_name
    !! Name of the optimiser
    character(256) :: buffer, err_msg, tmp
    !! Buffer for reading lines, error message, leading tag of the line

    ! Read the line holding the optimiser tag and name
    read(unit,'(A)',iostat=stat) buffer
    if(stat.ne.0)then
       ! The network name may be unallocated here, so it is not used
       call stop_program( &
            "file encountered error (EoF?) before END OPTIMISER" &
       )
       return
    end if

    ! A line with no name after the tag ends the list-directed read early;
    ! treat that as an empty name so the base optimiser case can be reached
    optimiser_name = ""
    read(buffer,*,iostat=stat) tmp, optimiser_name
    if(stat.ne.0) optimiser_name = ""

    select case(trim(adjustl(to_lower(optimiser_name))))
    case("sgd")
       this%optimiser = sgd_optimiser_type()
    case("adam")
       this%optimiser = adam_optimiser_type()
    case("rmsprop")
       this%optimiser = rmsprop_optimiser_type()
    case("adagrad")
       this%optimiser = adagrad_optimiser_type()
    case("","base")
       this%optimiser = base_optimiser_type()
    case default
       write(err_msg,'("Unrecognised optimiser: ",A)') trim(adjustl(optimiser_name))
       call stop_program(err_msg)
       return
    end select
    call this%optimiser%read(unit)

  end subroutine read_optimiser_settings
!###############################################################################


!###############################################################################
  module subroutine build_from_onnx( &
       this, nodes, initialisers, inputs, value_info, verbose &
  )
    !! Build network from ONNX nodes and initialisers
    !!
    !! One input layer is added per ONNX input, using all but the first
    !! dimension of the input as its shape. Each node is then converted to a
    !! layer by the creator registered for its op_type and connected to the
    !! layers whose names match the inputs of the node.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    type(onnx_node_type), dimension(:), intent(in) :: nodes
    !! Array of ONNX nodes
    type(onnx_initialiser_type), dimension(:), intent(in) :: initialisers
    !! Array of ONNX initialisers
    type(onnx_tensor_type), dimension(:), intent(in) :: inputs
    !! Array of ONNX inputs
    type(onnx_tensor_type), dimension(:), intent(in) :: value_info
    !! Array of ONNX value infos
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: i, j, k, layer_index
    !! Loop indices and index of the layer creator
    integer :: verbose_
    !! Verbosity level (assigned below, not initialised at declaration, so
    !! that the variable does not acquire the save attribute)
    character(20) :: op_type
    !! op_type of the current node
    character(64) :: tmp_name
    !! Temporary name for matching
    character(256) :: err_msg
    !! Error message
    integer, dimension(:), allocatable :: input_shape
    !! Shape of input layer
    integer, dimension(:), allocatable :: input_list
    !! List of input layers
    type(onnx_initialiser_type), dimension(:), allocatable :: init_list
    !! List of initialisers for a specific node
    type(onnx_tensor_type), dimension(:), allocatable :: value_info_list
    !! List of value info tensors

    verbose_ = 0
    if(present(verbose)) verbose_ = verbose


    if(.not.allocated(list_of_onnx_layer_creators))then
       call allocate_list_of_onnx_layer_creators()
    end if

    ! Add one input layer per ONNX input, dropping the first dimension
    do i = 1, size(inputs)
       input_shape = inputs(i)%dims(2:size(inputs(i)%dims))

       call this%add( &
            input_layer_type(input_shape, index=i) &
       )
    end do

    ! Loop through nodes and create layers
    do i = 1, size(nodes)
       if(verbose_.gt.0) write(*,*) "Processing ONNX node: ", trim(nodes(i)%name), &
            " (", trim(nodes(i)%op_type), ")"
       op_type = trim(adjustl(nodes(i)%op_type))

       layer_index = &
            findloc( &
                 [ list_of_onnx_layer_creators(:)%op_type ], &
                 op_type, &
                 dim = 1 &
            )
       if(layer_index.eq.0)then
          write(err_msg,'("unrecognised op_type ''",A)') trim(adjustl(nodes(i)%op_type))
          call stop_program(err_msg)
          return
       end if

       ! find all input layers and initialisers for this node
       ! ... i.e. check over inputs for name matches
       allocate(init_list(0))
       allocate(input_list(0))
       allocate(value_info_list(0))
       do j = 1, size(nodes(i)%inputs)
          do k = 1, size(initialisers)
             if(trim(nodes(i)%inputs(j)) .eq. trim(initialisers(k)%name))then
                init_list = [ init_list, initialisers(k) ]
             end if
          end do
          do k = 1, size(inputs)
             if(trim(nodes(i)%inputs(j)) .eq. trim(inputs(k)%name))then
                input_list = [ input_list, k ]
             end if
          end do
          ! Strip a trailing "_output..." so the name can match a node name
          tmp_name = trim(nodes(i)%inputs(j))
          if(index(tmp_name, "_output").gt.0) &
               tmp_name = trim(tmp_name(:index(tmp_name, "_output")-1))
          do k = 1, size(nodes)
             if(trim(tmp_name) .eq. trim(nodes(k)%name))then
                ! Node layers are numbered after the input layers
                input_list = [ input_list, k + size(inputs) ]
             end if
          end do
       end do
       do j = 1, size(nodes(i)%outputs)
          do k = 1, size(value_info)
             if(trim(nodes(i)%outputs(j)) .eq. trim(value_info(k)%name))then
                value_info_list = [ value_info_list, value_info(k) ]
             end if
          end do
       end do
       if(size(init_list)+size(input_list).ne.size(nodes(i)%inputs))then
          if(verbose_.gt.0)then
             write(0,*) "WARNING: not all inputs found for node ", &
                  trim(nodes(i)%name)
          end if
       end if

       ! assume default operator

       call this%add( &
            list_of_onnx_layer_creators(layer_index)%create_ptr( &
                 nodes(i), init_list, value_info_list &
            ), &
            input_list = input_list &
            ! operator = operator_in &
       )
       deallocate(input_list)
       deallocate(init_list)
       deallocate(value_info_list)
    end do

    if(verbose_.gt.0) write(*,*) "ONNX model built with ", this%num_layers, " layers."

  end subroutine build_from_onnx
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * *  * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  module subroutine add(this, layer, input_list, output_list, operator)
    !! Append a layer to the network and connect it within the network graph
    !!
    !! The layer is copied into a new final slot of the model and a vertex
    !! is added to the graph for it. Edges are added from every layer given
    !! in input_list, or from the previously added layer when no list is
    !! given and the layer is not of type "inpt". Edges to every layer in
    !! output_list are added as well.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(base_layer_type), intent(in) :: layer
    !! Layer to add to the network
    integer, dimension(:), optional, intent(in) :: input_list
    !! List of input layers
    integer, dimension(:), optional, intent(in) :: output_list
    !! List of output layers
    class(*), optional, intent(in) :: operator
    !! Operator to use to connect the layers

    ! Local variables
    integer :: i, last, other_vertex
    !! Loop index, index of the new model slot, vertex at the far end of an edge
    integer :: operator_
    !! Identifier of the operator used to connect the layers
    character(256) :: err_msg
    !! Error message
    type(container_layer_type), allocatable, dimension(:) :: grown_model
    !! Enlarged copy of the model with room for the new layer


    ! Grow the model by one slot and place a copy of the layer in it
    !---------------------------------------------------------------------------
    if(allocated(this%model))then
       last = size(this%model,dim=1)
       allocate(grown_model(last+1))
       do i = 1, last
          allocate(grown_model(i)%layer, source=this%model(i)%layer)
       end do
       call move_alloc(grown_model, this%model)
       this%num_layers = this%num_layers + 1
    else
       this%model = [container_layer_type()]
       this%num_layers = 1
    end if
    last = size(this%model,dim=1)
    allocate(this%model(last)%layer, source=layer)
    this%model(last)%layer%id = this%num_layers


    ! Resolve the operator identifier (concatenation unless told otherwise)
    !---------------------------------------------------------------------------
    operator_ = 1
    if(present(operator))then
       select type(operator)
       type is(character(*))
          select case(trim(to_lower(operator)))
          case("*", "x", "mul", "multiply")
             operator_ = 3
          case("+", "add")
             operator_ = 2
          case("||", "concat", "concatenate", "append")
             operator_ = 1
          end select
       type is(integer)
          operator_ = operator
       end select
    end if
    if(operator_.ne.1 .and. operator_.ne.2)then
       call stop_program("invalid operator")
       return
    end if


    ! Register the layer in the graph and connect it
    !---------------------------------------------------------------------------
    ! edge_index(1) = index of the previous layer
    ! abs(edge_index(2)) = index of the current layer
    ! the -ve sign of edge_index(2) indicates that the edge goes from the
    !   previous layer to the current layer
    !   i.e. forward pass flows from positive to negative
    ! adjacency(i,:) is all of the layers that i feeds forward to
    ! adjacency(:,i) is all of the layers that feed forward to i
    !   (i.e. the backward pass)
    associate(graph => this%auto_graph)
       graph%directed = .true.
       call graph%add_vertex( &
            feature=[1._real32], id=this%num_layers, update_adjacency=.true. &
       )

       if(present(input_list))then
          input_loop: do i = 1, size(input_list), 1
             ! Non-zero indices must lie within (-num_vertices, num_vertices]
             if( input_list(i).ne.0 .and. ( &
                  input_list(i).le.-graph%num_vertices .or. &
                  input_list(i).gt.graph%num_vertices ) &
             )then
                write(err_msg, &
                     '("input vertex index ",I0," out of range (",I0,":",I0,")")' &
                ) &
                     input_list(i), &
                     -graph%num_vertices +1, &
                     graph%num_vertices
                call stop_program(err_msg)
                return
             end if

             if(input_list(i).gt.0)then
                ! Positive values are layer ids
                other_vertex = findloc( &
                     [graph%vertex(:)%id], &
                     input_list(i), 1 &
                )
             elseif(input_list(i).lt.0)then
                ! Negative values count back from the newly added vertex
                other_vertex = graph%num_vertices + input_list(i)
             else
                other_vertex = 0
             end if

             call graph%add_edge( &
                  index = [ other_vertex, -graph%num_vertices ], &
                  feature = [ 1._real32 ], &
                  id = operator_, &
                  update_adjacency = .true. &
             )
          end do input_loop
       elseif(graph%num_vertices.gt.1.and.trim(layer%type).ne."inpt")then
          ! No explicit inputs: connect from the previously added vertex
          call graph%add_edge( &
               index = [ graph%num_vertices - 1, -graph%num_vertices ], &
               feature = [ 1._real32 ], &
               id = operator_, &
               update_adjacency = .true. &
          )
       end if

       if(present(output_list))then
          output_loop: do i = 1, size(output_list), 1
             other_vertex = findloc( &
                  [graph%vertex(:)%id], &
                  output_list(i), 1 &
             )
             call graph%add_edge( &
                  index = [ graph%num_vertices, -other_vertex ], &
                  feature = [ 1._real32 ], &
                  id = operator_, &
                  update_adjacency = .true. &
             )
          end do output_loop
       end if
    end associate

  end subroutine add
!###############################################################################


!###############################################################################
  module function network_setup( &
       layers, optimiser, loss_method, accuracy_method, &
       metrics, batch_size &
  ) result(network)
    !! Construct a network from a list of layers
    !!
    !! Optional settings are applied first, the layers are then added in the
    !! order given, and the network is compiled when an optimiser is given.
    implicit none

    ! Arguments
    type(container_layer_type), dimension(:), intent(in) :: layers
    !! Layers to add to the network
    class(base_optimiser_type), optional, intent(in) :: optimiser
    !! Optimiser to use for training
    class(*), optional, intent(in) :: loss_method
    !! Loss method
    character(*), optional, intent(in) :: accuracy_method
    !! Accuracy method
    class(*), dimension(..), optional, intent(in) :: metrics
    !! Metrics
    integer, optional, intent(in) :: batch_size
    !! Batch size

    type(network_type) :: network
    !! Network being constructed

    ! Local variables
    integer :: layer_idx
    !! Index of the layer being added


    ! Apply the settings that were provided
    !---------------------------------------------------------------------------
    if(present(loss_method))then
       call network%set_loss(loss_method)
    end if
    if(present(accuracy_method))then
       call network%set_accuracy(accuracy_method)
    end if
    if(present(metrics))then
       call network%set_metrics(metrics)
    end if
    if(present(batch_size))then
       network%batch_size = batch_size
    end if
    network%auto_graph%directed = .true.


    ! Add the layers in the order supplied
    !---------------------------------------------------------------------------
    add_loop: do layer_idx = 1, size(layers)
       call network%add(layers(layer_idx)%layer)
    end do add_loop


    ! Compile only when an optimiser has been supplied
    !---------------------------------------------------------------------------
    if(present(optimiser))then
       call network%compile(optimiser)
    end if

  end function network_setup
!###############################################################################


!###############################################################################
  module subroutine set_metrics(this, metrics)
    !! Select the metrics that are active for the network
    !!
    !! The two metrics are reset to the keys "loss" and "accuracy", made
    !! inactive and given a threshold of 0.1. Metrics named in a character
    !! argument are then activated, or both entries are replaced when an
    !! array of two metric_dict_type values is given.
    use coreutils, only: to_lower
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(*), dimension(..), intent(in) :: metrics
    !! Metrics

    ! Local variables
    integer :: m
    !! Index over the supplied metric names


    ! Reset the metrics to their default state
    this%metrics(1)%key = "loss"
    this%metrics(2)%key = "accuracy"
    this%metrics%active = .false.
    this%metrics%threshold = 1.E-1_real32

    select rank(metrics)
#if defined(GFORTRAN)
    rank(0)
       select type(metrics)
       type is(character(*))
          ! Only built when GFORTRAN is defined: ifort cannot identify that
          ! the rank of metrics has been identified as scalar here
          where(to_lower(trim(metrics)).eq.this%metrics%key) &
               this%metrics%active = .true.
       end select
#endif
    rank(1)
       select type(metrics)
       type is(character(*))
          ! Activate every metric whose key matches one of the names
          do m = 1, size(metrics,1)
             where(to_lower(trim(metrics(m))).eq.this%metrics%key) &
                  this%metrics%active = .true.
          end do
       type is(metric_dict_type)
          if(size(metrics,1).ne.2)then
             call stop_program("invalid length array for metric_dict_type")
             return
          end if
          this%metrics(:2) = metrics(:2)
       end select
    rank default
       call stop_program("provided metrics rank in compile invalid")
       return
    end select

  end subroutine set_metrics
!###############################################################################


!###############################################################################
  module subroutine set_loss(this, loss_method, verbose)
    !! Set the loss method for the network
    !!
    !! loss_method is either an object extending base_loss_type, which is
    !! copied into the network, or a character name. Names are lowered
    !! with to_lower before matching, so both the short names
    !! (bce, cce, mae, mse, nll, hub) and the long aliases
    !! (e.g. "mean_squared_error") are accepted in any letter case.
    !! stop_program is called for an unknown name or an unsupported
    !! argument type.
    use coreutils, only: to_lower
    use athena__loss, only: &
         bce_loss_type, &
         cce_loss_type, &
         mae_loss_type, &
         mse_loss_type, &
         nll_loss_type, &
         huber_loss_type
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(*), intent(in) :: loss_method
    !! Loss method
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: verbose_
    !! Verbosity level
    character(len=:), allocatable :: loss_method_
    !! Loss method
    character(256) :: err_msg
    !! Error message


    if(present(verbose))then
       verbose_ = verbose
    else
       verbose_ = 0
    end if

    !---------------------------------------------------------------------------
    ! Set loss method
    !---------------------------------------------------------------------------
    select type(loss_method)
    class is(base_loss_type)
       this%loss = loss_method
       if(verbose_.gt.0) write(*,*) "Loss method: ", trim(loss_method%name)
       loss_method_ = trim(loss_method%name)
    type is(character(*))
       !------------------------------------------------------------------------
       ! Handle analogous definitions
       !------------------------------------------------------------------------
       ! Select on the lowered copy so that the long aliases are matched
       ! case-insensitively, exactly like the short names below
       loss_method_ = to_lower(loss_method)
       select case(loss_method_)
       case("binary_crossentropy")
          loss_method_ = "bce"
       case("categorical_crossentropy")
          loss_method_ = "cce"
       case("mean_absolute_error")
          loss_method_ = "mae"
       case("mean_squared_error")
          loss_method_ = "mse"
       case("negative_log_likelihood")
          loss_method_ = "nll"
       case("huber")
          loss_method_ = "hub"
       end select
       select case(loss_method_)
       case("bce")
          this%loss = bce_loss_type()
          if(verbose_.gt.0) write(*,*) "Loss method: Binary Cross Entropy"
       case("cce")
          this%loss = cce_loss_type()
          if(verbose_.gt.0) write(*,*) "Loss method: Categorical Cross Entropy"
       case("mae")
          this%loss = mae_loss_type()
          if(verbose_.gt.0) write(*,*) "Loss method: Mean Absolute Error"
       case("mse")
          this%loss = mse_loss_type()
          if(verbose_.gt.0) write(*,*) "Loss method: Mean Squared Error"
       case("nll")
          this%loss = nll_loss_type()
          if(verbose_.gt.0) write(*,*) "Loss method: Negative Log Likelihood"
       case("hub")
          this%loss = huber_loss_type()
          if(verbose_.gt.0) write(*,*) "Loss method: Huber"
       case default
          write(err_msg,'(A)') &
               "No loss method provided" // &
               achar(13) // achar(10) // &
               "Failed loss method: "//trim(loss_method_)
          call stop_program(trim(err_msg))
          return
       end select
    class default
       ! Without this branch loss_method_ would be referenced below while
       ! still unallocated
       call stop_program( &
            "loss_method must be a loss type or a character string" &
       )
       return
    end select
    this%loss_method = loss_method_

  end subroutine set_loss
!###############################################################################


!###############################################################################
  module subroutine set_accuracy(this, accuracy_method, verbose)
    !! Set the accuracy method for the network
    !!
    !! The name is lowered with to_lower before matching, so the short
    !! names (cat, mae, mse, rmse, r2) and the long aliases
    !! (e.g. "root_mean_squared_error", "r^2") are accepted in any letter
    !! case. The matching scoring procedure is bound to get_accuracy and
    !! the short name is stored in accuracy_method. stop_program is called
    !! for an unknown name.
    use coreutils, only: to_lower
    use athena__accuracy, only: &
         categorical_score, &
         mae_score, &
         mse_score, &
         rmse_score, &
         r2_score
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    character(*), intent(in) :: accuracy_method
    !! Accuracy method
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: verbose_
    !! Verbosity level
    character(len=:), allocatable :: accuracy_method_
    !! Accuracy method
    character(256) :: err_msg
    !! Error message


    if(present(verbose))then
       verbose_ = verbose
    else
       verbose_ = 0
    end if

    !---------------------------------------------------------------------------
    ! Handle analogous definitions
    !---------------------------------------------------------------------------
    ! Select on the lowered copy so that the long aliases are matched
    ! case-insensitively, exactly like the short names below
    accuracy_method_ = to_lower(accuracy_method)
    select case(accuracy_method_)
    case("categorical")
       accuracy_method_ = "cat"
    case("mean_absolute_error")
       accuracy_method_ = "mae"
    case("mean_squared_error")
       accuracy_method_ = "mse"
    case("root_mean_squared_error")
       accuracy_method_ = "rmse"
    case("r2", "r^2", "r squared")
       accuracy_method_ = "r2"
    end select

    !---------------------------------------------------------------------------
    ! Set accuracy method
    !---------------------------------------------------------------------------
    select case(accuracy_method_)
    case("cat")
       this%get_accuracy => categorical_score
       if(verbose_.gt.0) write(*,*) "Accuracy method: Categorical "
    case("mae")
       this%get_accuracy => mae_score
       if(verbose_.gt.0) write(*,*) "Accuracy method: Mean Absolute Error"
    case("mse")
       this%get_accuracy => mse_score
       if(verbose_.gt.0) write(*,*) "Accuracy method: Mean Squared Error"
    case("rmse")
       this%get_accuracy => rmse_score
       if(verbose_.gt.0) write(*,*) "Accuracy method: Root Mean Squared Error"
    case("r2")
       this%get_accuracy => r2_score
       if(verbose_.gt.0) write(*,*) "Accuracy method: R^2"
    case default
       write(err_msg,'(A)') &
            "No accuracy method provided" // &
            achar(13) // achar(10) // &
            "Failed accuracy method: "//trim(accuracy_method_)
       call stop_program(trim(err_msg))
       return
    end select
    this%accuracy_method = accuracy_method_

  end subroutine set_accuracy
!###############################################################################


!###############################################################################
  module subroutine reset(this)
    !! Clear the network back to an empty state.
    !!
    !! Counters are zeroed, the stored loss value is set to the largest
    !! real32 value, the metrics are re-initialised with only "loss"
    !! requested, all allocated components handled here are released, the
    !! accuracy procedure pointer is disassociated and the graph is
    !! replaced by a fresh directed graph.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network

    ! Scalar bookkeeping
    this%num_layers = 0
    this%num_outputs = 0
    this%batch_size = 0
    this%epoch = 0
    this%loss_val = huge(1._real32)
    this%accuracy_val = 0._real32

    ! Metrics: only the loss is requested after a reset
    call this%set_metrics(["loss"])

    ! Training components
    if(allocated(this%optimiser)) deallocate(this%optimiser)
    if(allocated(this%loss)) deallocate(this%loss)
    nullify(this%get_accuracy)

    ! Layers and graph topology
    if(allocated(this%model)) deallocate(this%model)
    if(allocated(this%root_vertices)) deallocate(this%root_vertices)
    if(allocated(this%leaf_vertices)) deallocate(this%leaf_vertices)
    if(allocated(this%vertex_order)) deallocate(this%vertex_order)
    this%auto_graph = graph_type(directed=.true.)

  end subroutine reset
!###############################################################################


!###############################################################################
  module subroutine compile( &
       this, optimiser, loss_method, accuracy_method, &
       metrics, batch_size, verbose &
  )
    !! Compile the network
    !!
    !! Steps, in order:
    !!  1. apply the optional metrics, loss and accuracy settings;
    !!  2. make sure every root vertex is an input layer, adding one
    !!     in front of any root layer that is not;
    !!  3. give layers with input rank zero the output rank of their parents;
    !!  4. insert flatten layers where a layer's output rank differs from
    !!     its child's input rank, and merge layers where a layer has more
    !!     than one parent;
    !!  5. initialise every layer that has no input shape yet;
    !!  6. count the network outputs and validate layer input sizes
    !!     against the layers feeding them;
    !!  7. initialise the optimiser;
    !!  8. cache the forward-pass navigation tables and the parameter
    !!     segment layout;
    !!  9. store the batch size, if provided.
    !!
    !! Every failure is reported through stop_program and followed by a
    !! return.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(base_optimiser_type), optional, intent(in) :: optimiser
    !! Optimiser to use for training
    class(*), optional, intent(in) :: loss_method
    !! Loss method
    character(*), optional, intent(in) :: accuracy_method
    !! Accuracy method
    class(*), dimension(..), optional, intent(in) :: metrics
    !! Metrics
    integer, optional, intent(in) :: batch_size
    !! Batch size
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: i, j, k, child_id, parent_id, layer_id, num_inputs, input_rank
    !! Loop index
    integer :: parent_vertex, vertex_idx
    !! Vertex indices
    integer :: layer_rank, parent_rank, operator
    !! Ranks of layers
    integer :: verbose_
    !! Verbosity level
    logical :: use_graph_input
    !! Boolean whether to use graph input
    logical :: l_flatten_child, l_set_input_shape
    !! Booleans whether to flatten child or set input shape
    integer, dimension(:), allocatable :: input_shape, &
         child_vertices, parent_vertices, output_ranks, parent_ids
    !! Shapes of the input and output of the layers
    integer, dimension(:,:), allocatable :: merge_shape
    !! Shapes of the inputs to merge layers
    class(base_layer_type), allocatable :: &
         t_input_layer, t_flatten_layer, t_merge_layer
    !! Temporary input, flatten, and merge layers


    !---------------------------------------------------------------------------
    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    ! Assigned at run time rather than in the declarations: initialising a
    ! local in its declaration gives it the SAVE attribute, which would
    ! carry the verbosity of one call over into the next
    verbose_ = 0
    if(present(verbose)) verbose_ = verbose
    use_graph_input = .false.


    !---------------------------------------------------------------------------
    ! Initialise metrics
    !---------------------------------------------------------------------------
    if(present(metrics)) call this%set_metrics(metrics)


    !---------------------------------------------------------------------------
    ! Initialise loss and accuracy methods
    !---------------------------------------------------------------------------
    if(present(loss_method)) call this%set_loss(loss_method, verbose_)
    if(present(accuracy_method)) &
         call this%set_accuracy(accuracy_method, verbose_)


    !---------------------------------------------------------------------------
    ! Check for input layers at root vertices
    !---------------------------------------------------------------------------
    this%auto_graph%directed = .true.
    call this%build_root_vertices()
    do i = 1, size(this%root_vertices)
       layer_id = this%auto_graph%vertex(this%root_vertices(i))%id
       if(.not.allocated(this%model(layer_id)%layer%input_shape))then
          call stop_program("input_shape of first layer not defined")
          return
       end if
       use_graph_input = .false.
       select type( root => this%model(layer_id)%layer)
       class is(input_layer_type)
          ! Root is already an input layer, nothing to add
          cycle
       class is(learnable_layer_type)
          input_shape = root%input_shape
          use_graph_input = root%use_graph_input
       class default
          input_shape = root%input_shape
       end select
       ! Put a new input layer in front of the root layer
       t_input_layer = input_layer_type(&
            input_shape = input_shape, &
            index = i, &
            use_graph_input = use_graph_input, &
            verbose=verbose_ &
       )
       call this%add( &
            t_input_layer, output_list = [ this%model(layer_id)%layer%id ] &
       )
       ! NEED TO CALL layer%init?
       deallocate(input_shape)
       deallocate(t_input_layer)
       ! The newly added input layer becomes the root
       this%root_vertices(i) = this%num_layers
       if(i.eq.1)then
          ! Edges whose source index is 0 are redirected to the first
          ! newly added input layer
          do j = 1, this%auto_graph%num_edges
             if(this%auto_graph%edge(j)%index(1).eq.0) &
                  this%auto_graph%edge(j)%index(1) = this%num_layers
          end do
       end if
    end do
    call this%auto_graph%generate_adjacency()


    !---------------------------------------------------------------------------
    ! Identify whether input is graph type
    !---------------------------------------------------------------------------
    if( &
         this%model( &
              this%auto_graph%vertex(this%root_vertices(1))%id &
         )%layer%use_graph_input &
    )then
       this%use_graph_input = .true.
    else
       this%use_graph_input = .false.
    end if


    !---------------------------------------------------------------------------
    ! Check for zero input rank layers
    !---------------------------------------------------------------------------
    do i = 1, size(this%auto_graph%vertex, dim = 1)
       layer_id = this%auto_graph%vertex(i)%id
       if(this%model(layer_id)%layer%input_rank.eq.0)then
          ! Layer ids of all parents of vertex i
          parent_ids = pack( &
               [ ( &
                    this%auto_graph%vertex(j)%id, &
                    j = 1, size(this%auto_graph%adjacency(:,i)) &
               ) ], &
               this%auto_graph%adjacency(:,i) .ne. 0 &
          )
          if(size(parent_ids).eq.0) cycle
          output_ranks = [ ( this%model(parent_ids(j))%layer%output_rank, &
               j=1,size(parent_ids) ) ]
          if(any(output_ranks.ne.output_ranks(1)))then
             write(0,*) output_ranks
             call stop_program( &
                  "input rank of layer "//trim(this%model(layer_id)%layer%name) // &
                  " is zero, but multiple parents with different output ranks" &
             )
             return
          end if
          ! Inherit the (common) output rank of the parents
          input_rank = this%model(parent_ids(1))%layer%output_rank
          call this%model(layer_id)%layer%set_rank( &
               input_rank = input_rank, &
               output_rank = input_rank &
          )
       end if
    end do


    !---------------------------------------------------------------------------
    ! Check for required flatten layers
    !---------------------------------------------------------------------------
    ! The vertex count can grow while looping (this%add is called inside),
    ! hence the manual counter checked against num_vertices on every pass
    i = 0
    flatten_loop: do
       i = i + 1
       if(i.gt.this%auto_graph%num_vertices) exit flatten_loop
       layer_id = this%auto_graph%vertex(i)%id

       ! get all child vertices
       child_vertices = pack( &
            [(j, j=1,size(this%auto_graph%adjacency(i,:)))], &
            this%auto_graph%adjacency(i,:) .ne. 0 &
       )
       child_loop: do j = 1, size(child_vertices)
          ! Get layer ID (needed for add() function's output_list parameter)
          child_id = this%auto_graph%vertex(child_vertices(j))%id
          if(trim(this%model(layer_id)%layer%type).eq."flat") cycle child_loop
          if( this%model(layer_id)%layer%output_rank .eq. &
               this%model(child_id)%layer%input_rank ) cycle child_loop
          if(this%model(layer_id)%layer%output_rank.eq.0) cycle child_loop

          ! get all parent vertices of the child vertex
          parent_vertices = pack( &
               [(k, k=1,size(this%auto_graph%adjacency(:,child_vertices(j))))], &
               this%auto_graph%adjacency(:,child_vertices(j)) .ne. 0 &
          )
          l_flatten_child = .true.
          do k = 1, size(parent_vertices)
             parent_id = this%auto_graph%vertex(parent_vertices(k))%id
             !check if ranks match, rather than input and output shapes
             ! NOTE(review): this compares against the parent's input_rank;
             ! confirm that output_rank was not intended
             if( this%model(layer_id)%layer%output_rank .ne. &
                  this%model(parent_id)%layer%input_rank &
             ) l_flatten_child = .false.
          end do
          t_flatten_layer = flatten_layer_type( &
               input_rank = this%model(layer_id)%layer%output_rank &
          )

          if(l_flatten_child)then
             ! add flatten layer in the place of the child layer
             operator = this%auto_graph%edge( &
                  this%auto_graph%adjacency(parent_vertices(1),child_vertices(j)) &
             )%id
             call this%auto_graph%remove_edges( &
                  indices = [ &
                       this%auto_graph%adjacency( &
                            parent_vertices(:),child_vertices(j) &
                       ) &
                  ] &
             )
             ! NOTE(review): input_list receives vertex indices here but
             ! layer ids in the other add() calls of this routine; confirm
             call this%add( &
                  t_flatten_layer, &
                  input_list=[parent_vertices(:)], output_list=[child_id], &
                  operator=operator &
             )
          else
             ! add flatten layer between the current layer and the child layer
             operator = this%auto_graph%edge( &
                  this%auto_graph%adjacency(i,child_vertices(j)) &
             )%id
             call this%auto_graph%remove_edges( &
                  indices = [this%auto_graph%adjacency(i,child_vertices(j))] &
             )
             call this%add( &
                  t_flatten_layer, input_list = [layer_id], &
                  output_list = [child_id], &
                  operator=operator &
             )
          end if
          deallocate(t_flatten_layer)
          deallocate(child_vertices)
          ! The graph has changed; move on to the next vertex
          cycle flatten_loop
       end do child_loop
       deallocate(child_vertices)
    end do flatten_loop
    call this%build_vertex_order()


    !---------------------------------------------------------------------------
    ! Check for required merge layers
    !---------------------------------------------------------------------------
    i = 0
    merge_loop: do
       i = i + 1
       if(i.gt.this%auto_graph%num_vertices) exit merge_loop
       layer_id = this%auto_graph%vertex(i)%id
       if(this%model(layer_id)%layer%type.eq."merg") cycle merge_loop

       ! get all parent vertices
       parent_vertices = pack( &
            [(j, j=1,size(this%auto_graph%adjacency(:,i)))], &
            this%auto_graph%adjacency(:,i) .ne. 0 &
       )
       ! A merge layer is only needed for more than one parent
       if(size(parent_vertices).le.1) cycle merge_loop

       ! get edge id for merge layer
       operator = this%auto_graph%edge( &
            this%auto_graph%adjacency(parent_vertices(1),i) &
       )%id

       ! remove edges from parents to this layer
       do j = 1, size(parent_vertices)
          call this%auto_graph%remove_edges( &
               indices = [this%auto_graph%adjacency(parent_vertices(j),i)] &
          )
       end do
       parent_ids = &
            [ ( &
                 this%auto_graph%vertex(parent_vertices(k))%id, &
                 k = 1, size(parent_vertices) &
            ) ]
       select case(operator)
       case(1) ! concatenate
          t_merge_layer = concat_layer_type( &
               input_layer_ids = parent_ids, &
               input_rank = this%model(layer_id)%layer%input_rank &
          )
       case(2) ! add
          t_merge_layer = add_layer_type( &
               input_layer_ids = parent_ids, &
               input_rank = this%model(layer_id)%layer%input_rank &
          )
          ! case(3) ! multiply
          !    t_merge_layer = multiply_layer_type( &
          !         input_layer_ids = parent_vertices &
          !    )
       case default
          write(0,*) "invalid merge operator: ", operator
          call stop_program("invalid merge operator")
          return
       end select
       t_merge_layer%use_graph_input = this%model(layer_id)%layer%use_graph_input
       t_merge_layer%use_graph_output = t_merge_layer%use_graph_input
       call this%add( &
            t_merge_layer, &
            input_list = parent_ids, &
            output_list = [layer_id] &
       )
       deallocate(t_merge_layer)
    end do merge_loop
    call this%build_vertex_order()


    ! Update number of layers
    !---------------------------------------------------------------------------
    this%num_layers = size(this%model,dim=1)



    !---------------------------------------------------------------------------
    ! Initialise layers
    !---------------------------------------------------------------------------
    do i = 1, size(this%vertex_order, dim = 1)
       vertex_idx = this%vertex_order(i)
       layer_id = this%auto_graph%vertex(vertex_idx)%id
       ! Only layers without an input shape are initialised here
       if(allocated(this%model(layer_id)%layer%input_shape))then
          l_set_input_shape = .false.
       else
          l_set_input_shape = .true.
       end if
       if(l_set_input_shape)then
          layer_rank = this%model(layer_id)%layer%input_rank
          parent_rank = 0

          select type( layer => this%model(layer_id)%layer )
          class is(merge_layer_type)
             ! loop over all parent layers
             allocate( &
                  merge_shape( &
                       this%model(layer_id)%layer%input_rank, &
                       size(layer%input_layer_ids) &
                  ) &
             )
             do k = 1, size(layer%input_layer_ids)
                merge_shape(:,k) = &
                     this%model(layer%input_layer_ids(k))%layer%output_shape
             end do
             input_shape = layer%calc_input_shape(merge_shape)
             deallocate(merge_shape)
          class default

             ! Accumulate the input shape from the output shapes of all
             ! parent layers
             allocate( &
                  input_shape(this%model(layer_id)%layer%input_rank), &
                  source = 0 &
             )
             do j = 1, this%auto_graph%num_vertices
                if(this%auto_graph%adjacency(j,vertex_idx).eq.0) cycle
                parent_id = this%auto_graph%vertex(j)%id
                parent_rank = this%model(parent_id)%layer%output_rank

                if(layer_rank .eq. parent_rank)then
                   input_shape(:) = input_shape(:) + &
                        this%model(parent_id)%layer%output_shape
                elseif(layer_rank .eq. 1)then
                   ! Rank-1 layer fed by a higher-rank parent: use the
                   ! total number of parent output elements
                   input_shape(1) = input_shape(1) + &
                        product( this%model(parent_id)%layer%output_shape )
                end if
             end do
          end select
          call this%model(layer_id)%layer%init( &
               input_shape = input_shape, &
               verbose = verbose_ &
          )
          deallocate(input_shape)
       end if
       if(verbose_.gt.0)then
          write(*,*) "layer: ", layer_id, this%model(layer_id)%layer%type
          write(*,*) this%model(layer_id)%layer%input_shape
          write(*,*) this%model(layer_id)%layer%output_shape
       end if
    end do


    ! Set number of outputs
    !---------------------------------------------------------------------------
    this%num_outputs = 0
    call this%build_leaf_vertices()
    do i = 1, size(this%leaf_vertices,1)
       this%num_outputs = this%num_outputs + &
            product( &
                 this%model( &
                      this%auto_graph%vertex(this%leaf_vertices(i))%id &
                 )%layer%output_shape &
            )
    end do
    if( &
         this%model( &
              this%auto_graph%vertex(this%leaf_vertices(1))%id &
         )%layer%use_graph_output &
    )then
       this%use_graph_output = .true.
    else
       this%use_graph_output = .false.
    end if


    !---------------------------------------------------------------------------
    ! Confirm input_shape of each layer matches data going into it
    !---------------------------------------------------------------------------
    do i = 1, size(this%vertex_order, dim = 1)
       vertex_idx = this%vertex_order(i)
       layer_id = this%auto_graph%vertex(vertex_idx)%id
       if(this%model(layer_id)%layer%type.eq."inpt") cycle

       ! Get all parent vertices that feed into this layer
       parent_vertices = pack( &
            [(j, j=1,size(this%auto_graph%adjacency(:,vertex_idx)))], &
            this%auto_graph%adjacency(:,vertex_idx) .ne. 0 &
       )
       if(size(parent_vertices).eq.0) cycle
       select type( layer => this%model(layer_id)%layer )
       class is(merge_layer_type)
          ! NOTE(review): the edge operator ids used when creating merge
          ! layers above are 1=concatenate, 2=add, while merge_mode is read
          ! below as 1=pointwise, 2=concatenate; confirm both encodings
          operator = layer%merge_mode
       class default
          if(size(parent_vertices).gt.1)then
             call stop_program( &
                  "layer "//trim(layer%name)// &
                  " is not a merge layer but has multiple inputs" &
             )
             return
          end if
          ! Exactly one parent: both modes below give the same size. Set
          ! the mode explicitly so that no stale or undefined value of
          ! operator is used
          operator = 1
       end select

       ! Calculate expected input size from parent layers
       num_inputs = 0
       do j = 1, size(parent_vertices)
          parent_vertex = parent_vertices(j)
          ! this%model is indexed by layer id, not by vertex index
          parent_id = this%auto_graph%vertex(parent_vertex)%id

          select case(operator)
          case(1) ! pointwise - all inputs should have same size
             if(num_inputs.eq.0)then
                if(this%model(layer_id)%layer%use_graph_input)then
                   num_inputs = this%model(parent_id)%layer%output_shape(1)
                else
                   num_inputs = product(this%model(parent_id)%layer%output_shape)
                end if
             end if
          case(2) ! concatenate
             if(this%model(layer_id)%layer%use_graph_input)then
                num_inputs = num_inputs + &
                     this%model(parent_id)%layer%output_shape(1)
             else
                num_inputs = num_inputs + &
                     product(this%model(parent_id)%layer%output_shape)
             end if
          end select
       end do

       ! Verify calculated input size matches layer's expected input size
       ! (a calculated size of zero means "unknown" and is not checked)
       if(this%model(layer_id)%layer%use_graph_input)then
          if(num_inputs.ne.this%model(layer_id)%layer%input_shape(1) .and. &
               num_inputs.ne.0)then
             write(*,*) "Expected:", num_inputs, "Got:", &
                  this%model(layer_id)%layer%input_shape(1)
             call stop_program( &
                  "input_shape of layer "//&
                  trim(this%model(layer_id)%layer%name)// &
                  " does not match data going into it" &
             )
             return
          end if
       else
          if(num_inputs.ne.product(this%model(layer_id)%layer%input_shape) .and. &
               num_inputs.ne.0)then
             write(*,*) "Expected:", num_inputs, "Got:", &
                  product(this%model(layer_id)%layer%input_shape)
             call stop_program( &
                  "input_shape of layer "//&
                  trim(this%model(layer_id)%layer%name)// &
                  " does not match data going into it" &
             )
             return
          end if
       end if

    end do

    !---------------------------------------------------------------------------
    ! Initialise optimiser
    !---------------------------------------------------------------------------
    this%num_params = this%get_num_params()
    if(present(optimiser))then
       this%optimiser = optimiser
    end if
    if(.not.allocated(this%optimiser))then
       call stop_program("No optimiser is defined for the network")
       return
    else
       call this%optimiser%init(num_params=this%num_params)
    end if


    !---------------------------------------------------------------------------
    ! Pre-compute forward pass navigation
    !---------------------------------------------------------------------------
    block
      integer :: nv, l_idx, v_idx, lid, parent_v
      nv = size(this%vertex_order, 1)
      if(allocated(this%fwd_layer_id))   deallocate(this%fwd_layer_id)
      if(allocated(this%fwd_num_inputs)) deallocate(this%fwd_num_inputs)
      if(allocated(this%fwd_parent_id))  deallocate(this%fwd_parent_id)
      if(allocated(this%fwd_layer_type)) deallocate(this%fwd_layer_type)
      allocate(this%fwd_layer_id(nv))
      allocate(this%fwd_num_inputs(nv))
      allocate(this%fwd_parent_id(nv))
      allocate(this%fwd_layer_type(nv))
      ! A parent id of 0 is kept for layers without exactly one parent
      this%fwd_parent_id = 0
      do l_idx = 1, nv
         v_idx = this%vertex_order(l_idx)
         lid = this%auto_graph%vertex(v_idx)%id
         this%fwd_layer_id(l_idx) = lid
         this%fwd_num_inputs(l_idx) = &
              count(this%auto_graph%adjacency(:,v_idx).gt.0)
         if(this%fwd_num_inputs(l_idx).eq.1)then
            ! With a single positive entry, maxloc locates the one parent
            parent_v = maxloc( &
                 this%auto_graph%adjacency(:,v_idx), dim=1)
            this%fwd_parent_id(l_idx) = &
                 this%auto_graph%vertex(parent_v)%id
         end if
         ! Determine layer type: 0=input, 1=merge, 2=default
         select type(layer => this%model(lid)%layer)
         class is(input_layer_type)
            this%fwd_layer_type(l_idx) = 0
         class is(merge_layer_type)
            this%fwd_layer_type(l_idx) = 1
         class default
            this%fwd_layer_type(l_idx) = 2
         end select
      end do
    end block


    !---------------------------------------------------------------------------
    ! Pre-compute parameter segment layout
    !---------------------------------------------------------------------------
    block
      integer :: l_idx, p_idx, seg_count, s_idx, e_idx
      ! First pass: count segments
      seg_count = 0
      do l_idx = 1, this%num_layers
         select type(current => this%model(l_idx)%layer)
         class is(learnable_layer_type)
            seg_count = seg_count + size(current%params)
         end select
      end do
      this%param_num_segments = seg_count
      if(allocated(this%param_seg_layer)) deallocate(this%param_seg_layer)
      if(allocated(this%param_seg_pidx))  deallocate(this%param_seg_pidx)
      if(allocated(this%param_seg_start)) deallocate(this%param_seg_start)
      if(allocated(this%param_seg_end))   deallocate(this%param_seg_end)
      allocate(this%param_seg_layer(seg_count))
      allocate(this%param_seg_pidx(seg_count))
      allocate(this%param_seg_start(seg_count))
      allocate(this%param_seg_end(seg_count))
      ! Second pass: fill layout
      ! Segments are laid out back to back: each one starts right after
      ! the end of the previous one
      seg_count = 0
      e_idx = 0
      do l_idx = 1, this%num_layers
         select type(current => this%model(l_idx)%layer)
         class is(learnable_layer_type)
            do p_idx = 1, size(current%params)
               seg_count = seg_count + 1
               s_idx = e_idx + 1
               e_idx = e_idx + size(current%params(p_idx)%val, 1)
               this%param_seg_layer(seg_count) = l_idx
               this%param_seg_pidx(seg_count) = p_idx
               this%param_seg_start(seg_count) = s_idx
               this%param_seg_end(seg_count) = e_idx
            end do
         end select
      end do
    end block


    !---------------------------------------------------------------------------
    ! Set batch size, if provided
    !---------------------------------------------------------------------------
    if(present(batch_size)) this%batch_size = batch_size

  end subroutine compile
!###############################################################################


!###############################################################################
  module subroutine set_batch_size(this, batch_size)
    !! Set the batch size for the network
    !!
    !! Only the network's own batch_size component is updated.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    integer, intent(in) :: batch_size
    !! Batch size


    this%batch_size = batch_size

  end subroutine set_batch_size
!###############################################################################


!###############################################################################
  module subroutine reset_state(this)
    !! Reset the hidden state of all layers in the network
    !!
    !! Only layers of class recurrent_layer_type are affected; every other
    !! layer is left untouched. Does nothing when the network holds no
    !! layers.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network

    ! Local variables
    integer :: l
    !! Loop index

    ! size() must not be applied to an unallocated array (e.g. after reset)
    if(.not.allocated(this%model)) return

    do l = 1, size(this%model, dim = 1)
       select type( layer => this%model(l)%layer )
       class is(recurrent_layer_type)
          call layer%reset_state()
       end select
    end do

  end subroutine reset_state
!###############################################################################


!###############################################################################
  module subroutine set_training_mode(this, mode_store, layer_indices)
    !! Put the network, or a selection of its layers, in training mode.
    !!
    !! A layer is in training mode when its inference flag is .false.
    !! Nothing is done (and mode_store is left unallocated) if the model
    !! has not been allocated.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    logical, dimension(:), allocatable, intent(out), optional :: mode_store
    !! Optional array receiving the inference flag each layer had before
    !! this call (.true. = inference, .false. = training)
    integer, dimension(:), intent(in), optional :: layer_indices
    !! Optional array of layer indices to set to training mode.
    !! If not provided, all layers will be set to training mode.

    ! Local variables
    integer :: l
    !! Loop index

    if(.not.allocated(this%model)) return
    if(present(mode_store)) allocate(mode_store(this%num_layers))
    do l = 1, this%num_layers
       ! Record the previous mode before it is (possibly) overwritten
       if(present(mode_store)) mode_store(l) = this%model(l)%layer%inference
       ! Only the requested layers are switched when a selection is given;
       ! layers outside the selection keep their current mode
       if(present(layer_indices))then
          if(any(layer_indices.eq.l))then
             this%model(l)%layer%inference = .false.
          end if
       else
          this%model(l)%layer%inference = .false.
       end if
    end do

  end subroutine set_training_mode
!-------------------------------------------------------------------------------
  module subroutine set_inference_mode(this, mode_store, layer_indices)
    !! Switch the network, or a selection of its layers, to inference mode.
    !!
    !! A layer is in inference mode when its inference flag is .true.
    !! The procedure returns immediately if the model is not allocated.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    logical, dimension(:), allocatable, intent(out), optional :: mode_store
    !! Optional array receiving the inference flag each layer had on entry
    integer, dimension(:), intent(in), optional :: layer_indices
    !! Optional list of layer indices to switch to inference mode.
    !! When absent, every layer is switched.

    ! Local variables
    integer :: idx
    !! Layer index
    logical :: selected
    !! Whether the current layer has to be switched

    if(.not.allocated(this%model)) return
    if(present(mode_store)) allocate(mode_store(this%num_layers))

    layer_loop: do idx = 1, this%num_layers
       if(present(mode_store))then
          mode_store(idx) = this%model(idx)%layer%inference
       end if
       ! Without a selection every layer qualifies
       selected = .true.
       if(present(layer_indices)) selected = any(layer_indices.eq.idx)
       if(selected) this%model(idx)%layer%inference = .true.
    end do layer_loop

  end subroutine set_inference_mode
!-------------------------------------------------------------------------------
  module subroutine restore_mode(this, mode_store)
    !! Copy a stored set of inference flags back onto the layers.
    !!
    !! Nothing happens if the model is not allocated. If the stored array
    !! does not have one entry per layer, stop_program is called and no
    !! layer is modified.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    logical, dimension(:), intent(in) :: mode_store
    !! Stored mode of each layer
    !! .true. = inference, .false. = training

    ! Local variables
    integer :: idx
    !! Layer index

    if(allocated(this%model))then
       if(size(mode_store) .eq. this%num_layers)then
          layer_loop: do idx = 1, this%num_layers
             this%model(idx)%layer%inference = mode_store(idx)
          end do layer_loop
       else
          call stop_program("mode_store size does not match number of layers")
       end if
    end if

  end subroutine restore_mode
!###############################################################################


!###############################################################################
  module function layer_from_id(this, id) result(layer)
    !! Get a pointer to the layer that carries the given ID
    !!
    !! The result is disassociated when no layer has the requested ID, so
    !! callers can test it with associated(). If more than one layer has
    !! the ID, stop_program is called and the function returns with the
    !! result still pointing at the first match.
    implicit none

    ! Arguments
    class(network_type), intent(in), target :: this
    !! Instance of network
    integer, intent(in) :: id
    !! Layer ID

    class(base_layer_type), pointer :: layer
    !! Layer

    ! Local variables
    integer :: i, num_found
    !! Loop index and number of layers found with the requested ID

    ! A pointer function result has undefined association status until it
    ! is assigned; nullify it so the "not found" case is well defined
    layer => null()

    num_found = 0
    do i = 1, size(this%model, dim = 1)
       if(this%model(i)%layer%id.eq.id)then
          if(num_found.ne.0)then
             call stop_program("multiple layers with same ID found")
             return
          end if
          layer => this%model(i)%layer
          num_found = num_found + 1
       end if
    end do

  end function layer_from_id
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * *  * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
#ifdef __flang__
  module function get_sample_flang( &
       input, start_index, end_index, batch_size &
  ) result(sample)
    !! Copy a range of samples out of a real array into a rank-2 array
    !!
    !! The last dimension of the input is sliced from start_index to
    !! end_index and all leading dimensions are flattened into the first
    !! dimension of the result. Ranks 2 to 6 are handled; any other rank
    !! yields a zero-sized result.
    implicit none

    ! Arguments
    integer, intent(in) :: start_index, end_index
    !! First and last sample index (inclusive)
    integer, intent(in) :: batch_size
    !! Batch size (not referenced in this procedure)
    real(real32), dimension(..), intent(in) :: input
    !! Input array

    real(real32), allocatable :: sample(:,:)
    !! Sample array

    ! Local variables
    integer :: num_features, num_samples
    !! Flattened size of the leading dimensions and number of samples

    num_samples = end_index - start_index + 1

    select rank(input)
    rank(2)
       sample = input(:, start_index:end_index)
    rank(3)
       num_features = size(input,1) * size(input,2)
       sample = reshape( &
            input(:, :, start_index:end_index), &
            [num_features, num_samples])
    rank(4)
       num_features = size(input,1) * size(input,2) * size(input,3)
       sample = reshape( &
            input(:, :, :, start_index:end_index), &
            [num_features, num_samples])
    rank(5)
       num_features = size(input,1) * size(input,2) * size(input,3) * &
            size(input,4)
       sample = reshape( &
            input(:, :, :, :, start_index:end_index), &
            [num_features, num_samples])
    rank(6)
       num_features = size(input,1) * size(input,2) * size(input,3) * &
            size(input,4) * size(input,5)
       sample = reshape( &
            input(:, :, :, :, :, start_index:end_index), &
            [num_features, num_samples])
    rank default
       allocate(sample(0, 0))
    end select

  end function get_sample_flang
#else
  module function get_sample_ptr( &
       input, start_index, end_index, batch_size &
  ) result(sample_ptr)
    !! Return a rank-2 pointer view onto a range of samples of a real array
    !!
    !! The last dimension of the input is sliced from start_index to
    !! end_index; the leading dimensions are remapped onto the first
    !! dimension of the pointer. Ranks 2 to 6 are handled; for any other
    !! rank the pointer is returned disassociated.
    implicit none

    ! Arguments
    integer, intent(in) :: start_index, end_index
    !! First and last sample index (inclusive)
    integer, intent(in) :: batch_size
    !! Batch size (not referenced in this procedure)
    real(real32), dimension(..), intent(in), target :: input
    !! Input array

    real(real32), pointer :: sample_ptr(:,:)
    !! Pointer to sample

    ! Local variables
    integer :: num_features, num_samples
    !! Flattened size of the leading dimensions and number of samples

    num_samples = end_index - start_index + 1

    select rank(input)
    rank(2)
       num_features = size(input,1)
       sample_ptr(1:num_features, 1:num_samples) => &
            input(:,start_index:end_index)
    rank(3)
       num_features = size(input,1) * size(input,2)
       sample_ptr(1:num_features, 1:num_samples) => &
            input(:,:,start_index:end_index)
    rank(4)
       num_features = size(input,1) * size(input,2) * size(input,3)
       sample_ptr(1:num_features, 1:num_samples) => &
            input(:,:,:,start_index:end_index)
    rank(5)
       num_features = size(input,1) * size(input,2) * size(input,3) * &
            size(input,4)
       sample_ptr(1:num_features, 1:num_samples) => &
            input(:,:,:,:,start_index:end_index)
    rank(6)
       num_features = size(input,1) * size(input,2) * size(input,3) * &
            size(input,4) * size(input,5)
       sample_ptr(1:num_features, 1:num_samples) => &
            input(:,:,:,:,:,start_index:end_index)
    rank default
       sample_ptr => null()
    end select

  end function get_sample_ptr
#endif
!-------------------------------------------------------------------------------
  module function get_sample_array( &
       input, start_index, end_index, batch_size, as_graph &
  ) result(sample)
    !! Extract a range of samples from a 2D array of array_type
    !!
    !! With as_graph set, the second dimension of the input is the sample
    !! dimension: columns start_index to end_index are copied into a result
    !! with batch_size columns. Otherwise the result has the same shape as
    !! the input and the sample range is taken from the second dimension
    !! of each element's val component.
    implicit none

    ! Arguments
    integer, intent(in) :: start_index, end_index
    !! First and last sample index (inclusive)
    integer, intent(in) :: batch_size
    !! Batch size
    class(array_type), dimension(:,:), intent(in) :: input
    !! Input array
    logical, intent(in) :: as_graph
    !! Boolean whether to treat the input as a graph

    type(array_type), dimension(:,:), allocatable :: sample
    !! Sample array

    ! Local variables
    integer :: row, col, out_col, num_samples
    !! Element indices, output column and number of samples in the range

    if(as_graph)then
       allocate(sample(size(input,1), batch_size))
       do col = start_index, end_index
          out_col = col - start_index + 1
          do row = 1, size(input,1)
             sample(row, out_col)%val = input(row, col)%val
          end do
       end do
    else
       num_samples = end_index - start_index + 1
       allocate(sample(size(input,1), size(input,2)))
       do col = 1, size(input,2)
          do row = 1, size(input,1)
             call sample(row,col)%allocate( &
                  array_shape=[input(row,col)%shape, num_samples])
             sample(row,col)%val = &
                  input(row,col)%val(:,start_index:end_index)
          end do
       end do
    end if

  end function get_sample_array
!-------------------------------------------------------------------------------
  module function get_sample_graph1d( &
       input, start_index, end_index, batch_size &
  ) result(sample)
    !! Extract a batch of graphs from a 1D array of graphs
    !!
    !! Elements start_index to end_index of the input are copied into the
    !! single row of the result.
    implicit none

    ! Arguments
    integer, intent(in) :: start_index, end_index
    !! First and last sample index (inclusive)
    integer, intent(in) :: batch_size
    !! Batch size
    class(graph_type), dimension(:), intent(in) :: input
    !! Input array

    type(graph_type), dimension(1, batch_size) :: sample
    !! Sample array

    sample(1,:) = input(start_index:end_index)

  end function get_sample_graph1d
!-------------------------------------------------------------------------------
  module function get_sample_graph2d( &
       input, start_index, end_index, batch_size &
  ) result(sample)
    !! Extract a batch of graphs from a 2D array of graphs
    !!
    !! Columns start_index to end_index of the input are copied into the
    !! result, keeping every row.
    implicit none

    ! Arguments
    integer, intent(in) :: start_index, end_index
    !! First and last sample index (inclusive)
    integer, intent(in) :: batch_size
    !! Batch size
    class(graph_type), dimension(:,:), intent(in) :: input
    !! Input array

    type(graph_type), dimension(size(input,1), batch_size) :: sample
    !! Sample array

    sample(:,:) = input(:,start_index:end_index)

  end function get_sample_graph2d
!###############################################################################


!###############################################################################
  pure module function get_num_params(this) result(num_params)
    !! Count the learnable parameters of the network
    !!
    !! The count is the sum, over every parameter array of every learnable
    !! layer, of the size of the first dimension of its val component.
    implicit none

    ! Arguments
    class(network_type), intent(in) :: this
    !! Instance of network
    integer :: num_params
    !! Number of parameters

    ! Local variables
    integer :: layer_idx, param_idx, layer_total
    !! Loop indices and per-layer subtotal

    num_params = 0
    layer_loop: do layer_idx = 1, this%num_layers
       select type(learnable => this%model(layer_idx)%layer)
       class is(learnable_layer_type)
          layer_total = 0
          do param_idx = 1, size(learnable%params)
             layer_total = layer_total + &
                  size(learnable%params(param_idx)%val, 1)
          end do
          num_params = num_params + layer_total
       end select
    end do layer_loop

  end function get_num_params
!###############################################################################


!###############################################################################
  pure module function get_params(this) result(params)
    !! Gather the learnable parameters of all layers into one flat array
    !!
    !! Parameters are packed in layer order, then in order of the layer's
    !! parameter arrays; the first column of each val component is copied.
    implicit none

    ! Arguments
    class(network_type), intent(in) :: this
    !! Instance of network
    real(real32), dimension(this%num_params) :: params
    !! Parameters

    ! Local variables
    integer :: layer_idx, param_idx, offset, length
    !! Loop indices, number of entries already written, and segment length

    offset = 0
    layer_loop: do layer_idx = 1, this%num_layers
       select type(learnable => this%model(layer_idx)%layer)
       class is(learnable_layer_type)
          do param_idx = 1, size(learnable%params)
             length = size(learnable%params(param_idx)%val, 1)
             params(offset+1:offset+length) = &
                  learnable%params(param_idx)%val(:,1)
             offset = offset + length
          end do
       end select
    end do layer_loop

  end function get_params
!###############################################################################


!###############################################################################
  module subroutine set_params(this, params)
    !! Scatter a flat parameter array back onto the learnable layers
    !!
    !! The flat array is consumed in layer order, then in order of the
    !! layer's parameter arrays; each segment overwrites the first column
    !! of the corresponding val component.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    real(real32), dimension(this%num_params), intent(in) :: params
    !! Parameters

    ! Local variables
    integer :: layer_idx, param_idx, offset, length
    !! Loop indices, number of entries already consumed, and segment length

    offset = 0
    layer_loop: do layer_idx = 1, this%num_layers
       select type(learnable => this%model(layer_idx)%layer)
       class is(learnable_layer_type)
          do param_idx = 1, size(learnable%params)
             length = size(learnable%params(param_idx)%val, 1)
             learnable%params(param_idx)%val(:,1) = &
                  params(offset+1:offset+length)
             offset = offset + length
          end do
       end select
    end do layer_loop

  end subroutine set_params
!###############################################################################


!###############################################################################
  pure module function get_gradients(this) result(gradients)
    !! Gather the gradients of all learnable parameters into one flat array
    !!
    !! The layout matches the one used for the parameters: layer order,
    !! then the order of each layer's parameter arrays, each segment having
    !! the size of the first dimension of the parameter's val component.
    !! For every parameter the gradient is averaged over the second
    !! dimension of its grad%val array. Parameters without an associated
    !! gradient contribute zeros. The optimiser's clipping is applied to
    !! the assembled array before it is returned.
    implicit none

    ! Arguments
    class(network_type), intent(in) :: this
    !! Instance of network
    real(real32), dimension(this%num_params) :: gradients
    !! Gradients

    ! Local variables
    integer :: l, i, start_idx, end_idx
    !! Loop indices and bounds of the current segment

    ! Define every entry up front so that segments belonging to parameters
    ! without a gradient are not passed to the clipping as undefined values
    gradients = 0._real32

    end_idx = 0
    do l = 1, this%num_layers
       select type(current => this%model(l)%layer)
       class is(learnable_layer_type)
          do i = 1, size(current%params)
             ! Always advance the segment bounds, even when there is no
             ! gradient, so later segments stay aligned with the parameters
             start_idx = end_idx + 1
             end_idx = end_idx + size(current%params(i)%val, 1)
             if(associated(current%params(i)%grad))then
                gradients(start_idx:end_idx) = &
                     sum(current%params(i)%grad%val, dim=2) / &
                     real(size(current%params(i)%grad%val, dim=2), real32)
             end if
          end do
       end select
    end do
    call this%optimiser%clip_dict%apply(size(gradients),gradients)

  end function get_gradients
!###############################################################################


!###############################################################################
  module subroutine set_gradients(this, gradients)
    !! Hand gradients over to every learnable layer
    !!
    !! A scalar is passed unchanged to each learnable layer. A rank-1 array
    !! is split into consecutive segments of num_params entries, one per
    !! learnable layer in layer order. Any other rank is ignored.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    real(real32), dimension(..), intent(in) :: gradients
    !! Gradients

    ! Local variables
    integer :: layer_idx, first, offset
    !! Layer index, first entry of the current segment, and number of
    !! entries covered so far (i.e. last entry of the current segment)

    offset = 0
    layer_loop: do layer_idx = 1, this%num_layers
       select type(learnable => this%model(layer_idx)%layer)
       class is(learnable_layer_type)
          first = offset + 1
          offset = offset + learnable%num_params
          select rank(gradients)
          rank(0)
             call learnable%set_gradients(gradients)
          rank(1)
             call learnable%set_gradients(gradients(first:offset))
          end select
       end select
    end do layer_loop

  end subroutine set_gradients
!###############################################################################


!###############################################################################
  module subroutine reset_gradients(this)
    !! Zero the gradient of every parameter of every learnable layer
    !!
    !! The work is delegated to the zero_grad procedure of each parameter
    !! array.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network

    ! Local variables
    integer :: layer_idx, param_idx
    !! Loop indices

    layer_loop: do layer_idx = 1, this%num_layers
       select type(learnable => this%model(layer_idx)%layer)
       class is(learnable_layer_type)
          param_loop: do param_idx = 1, size(learnable%params)
             call learnable%params(param_idx)%zero_grad()
          end do param_loop
       end select
    end do layer_loop

  end subroutine reset_gradients
!###############################################################################


!###############################################################################
  module function get_output_shape(this) result(output_shape)
    !! Get the shape of the combined output of the network
    !!
    !! The second entry is the batch size for graph output and 1 otherwise.
    !! The first entry accumulates the number of rows of the output of
    !! every leaf layer. If a leaf layer's output does not have the
    !! expected number of columns, stop_program is called and the function
    !! returns early.
    implicit none

    ! Arguments
    class(network_type), intent(in) :: this
    !! Instance of network
    integer, dimension(2) :: output_shape
    !! Output shape

    ! Local variables
    integer :: leaf, layer_idx
    !! Leaf index and index of the corresponding layer in the model


    ! array data: [ layer idx, empty ]
    ! graph data: [ vertex/edge idx, sample idx]

    ! NOTE(review): for graph output the row count starts at 2 and the rows
    ! of each leaf layer are then added on top - confirm this initial
    ! offset is intended.
    if(this%use_graph_output)then
       output_shape = [2, this%batch_size]
    else
       output_shape = [0, 1]
    end if

    leaf_loop: do leaf = 1, size(this%leaf_vertices,1)
       layer_idx = this%auto_graph%vertex(this%leaf_vertices(leaf))%id
       ! output_shape(2) is never modified inside the loop, so it still
       ! holds the expected number of columns
       if( &
            size(this%model(layer_idx)%layer%output,2) .ne. &
            output_shape(2) &
       )then
          if(this%use_graph_output)then
             call stop_program( &
                  "Inconsistent batch size in output layers" &
             )
          else
             call stop_program( &
                  "Inconsistent size of dimension 2 in output layers" &
             )
          end if
          return
       end if
       output_shape(1) = output_shape(1) + &
            size( this%model(layer_idx)%layer%output, 1 )
    end do leaf_loop

  end function get_output_shape
!-------------------------------------------------------------------------------
  module function get_output(this) result(output)
    !! Get the output of the network
    !!
    !! Collects the output arrays of all leaf layers into a single 2D
    !! array of array_type. For graph output the second dimension is the
    !! batch; otherwise the result has a single column.
    implicit none

    ! Arguments
    class(network_type), intent(in) :: this
    !! Instance of network
    type(array_type), dimension(:,:), allocatable :: output
    !! Output

    ! Local variables
    integer :: i, start_idx, end_idx, layer_idx, output_id
    !! Loop indices
    integer, dimension(2) :: output_shape
    !! Output shape
    integer, dimension(this%num_outputs) :: output_ids
    !! Number of output rows of each leaf layer, indexed by the layer's id
    !! NOTE(review): only the entries belonging to leaf layers are assigned
    !! below; this presumably relies on leaf layer ids being exactly
    !! 1..num_outputs - confirm.


    ! array data: [ layer idx, empty ]
    ! graph data: [ vertex/edge idx, sample idx]

    if(this%use_graph_output)then
       ! NOTE(review): the row count starts at 2 and the rows of every leaf
       ! layer are added on top - confirm this initial offset is intended.
       output_shape = [2, this%batch_size]
       ! First pass: validate the batch dimension of every leaf layer and
       ! record how many rows each one contributes
       do i = 1, size(this%leaf_vertices,1), 1
          layer_idx = this%auto_graph%vertex(this%leaf_vertices(i))%id
          if(size(this%model(layer_idx)%layer%output,2).ne.this%batch_size)then
             call stop_program( &
                  "Inconsistent batch size in output layers" &
             )
             return
          end if
          output_id = this%model(layer_idx)%layer%id
          output_ids(output_id) = size( this%model(layer_idx)%layer%output, 1 )
          output_shape(1) = output_shape(1) + output_ids(output_id)
       end do
       allocate(output(output_shape(1), output_shape(2)))
       ! Second pass: copy the rows of each leaf layer
       do i = 1, size(this%leaf_vertices,1)
          layer_idx = this%auto_graph%vertex(this%leaf_vertices(i))%id
          ! First destination row = rows recorded for all smaller ids + 1
          output_id = sum(output_ids(1:this%model(layer_idx)%layer%id-1)) + 1
          output(output_id,:) = this%model(layer_idx)%layer%output(1,:)
          ! Only a second row is copied when the layer has more than one
          if(output_ids(this%model(layer_idx)%layer%id).gt.1)then
             output(output_id+1,:) = this%model(layer_idx)%layer%output(2,:)
          end if
       end do
    else
       output_shape = [0, 1]
       ! First pass: every leaf layer must have a single output column;
       ! accumulate the total number of rows
       do i = 1, size(this%leaf_vertices,1)
          layer_idx = this%auto_graph%vertex(this%leaf_vertices(i))%id
          if(size(this%model(layer_idx)%layer%output,2).ne.1)then
             call stop_program( &
                  "Inconsistent size of dimension 2 in output layers" &
             )
             return
          end if
          output_shape(1) = &
               output_shape(1) + size( this%model(layer_idx)%layer%output, 1 )
          output_id = this%model(layer_idx)%layer%id
          output_ids(output_id) = size( this%model(layer_idx)%layer%output, 1 )
       end do
       allocate(output(output_shape(1), output_shape(2)))
       ! Second pass: stack the single column of each leaf layer, in the
       ! order of leaf_vertices
       start_idx = 1
       end_idx = 0
       do i = 1, size(this%leaf_vertices,1)
          layer_idx = this%auto_graph%vertex(this%leaf_vertices(i))%id
          output_id = this%model(layer_idx)%layer%id
          end_idx = end_idx + output_ids(output_id)
          output(start_idx:end_idx,1) = this%model(layer_idx)%layer%output(:,1)
          start_idx = end_idx + 1
       end do
    end if

  end function get_output
!-------------------------------------------------------------------------------
  module subroutine extract_output_real(this, output)
    !! Get the output of the network as real array
    !!
    !! Only networks with exactly one leaf vertex are supported. For any
    !! other number of leaf vertices a warning is printed and output is
    !! returned unallocated.
    implicit none

    ! Arguments
    class(network_type), intent(in) :: this
    !! Instance of network
    real(real32), dimension(..), allocatable, intent(out) :: output
    !! Output

    ! Local variables
    integer :: layer_id
    !! Layer ID

    ! Require exactly one leaf vertex; this also guards the access to
    ! leaf_vertices(1) below when there are no leaf vertices at all
    if(size(this%leaf_vertices,1).ne.1)then
       call print_warning("Output extraction to real array only works for single &
            &output networks")
       return
    end if

    ! Get output from the first (and only) leaf vertex
    layer_id = this%auto_graph%vertex(this%leaf_vertices(1))%id
    call this%model(layer_id)%layer%output(1,1)%extract(output)

  end subroutine extract_output_real
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * *  * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  module function accuracy_eval(this, output, start_index, end_index) &
       result(accuracy)
    !! Evaluate the accuracy of the network output against expected values
    !!
    !! The prediction is read from the layer stored at model index
    !! leaf_vertices(1). The accumulated accuracy is divided by the number
    !! of samples in the range start_index to end_index. Expected values of
    !! any type other than those handled below give an accuracy of zero.
    implicit none

    ! Arguments
    class(network_type), intent(in) :: this
    !! Instance of network
    class(*), dimension(:,:), intent(in) :: output
    !! Expected output
    integer, intent(in) :: start_index, end_index
    !! Start and end batch indices

    real(real32) :: accuracy
    !! Accuracy value

    ! Local variables
    integer :: sample, batch_pos, leaf, num_samples
    !! Sample index, position of the sample within the batch, model index
    !! of the leaf layer, and number of samples in the range

    leaf = this%leaf_vertices(1)
    num_samples = end_index - start_index + 1

    accuracy = 0._real32
    select type(output)
    type is(graph_type)
       sample_loop: do sample = start_index, end_index
          batch_pos = sample - start_index + 1
          ! Vertex features, normalised by the number of vertices
          accuracy = accuracy + sum( this%get_accuracy( &
               this%model(leaf)%layer%output(1,batch_pos)%val, &
               output(1,sample)%vertex_features &
          ) ) / output(1,sample)%num_vertices
          ! Edge features are only compared when the second entry of the
          ! layer's output shape is positive
          if(this%model(leaf)%layer%output_shape(2).gt.0)then
             accuracy = accuracy + sum( this%get_accuracy( &
                  this%model(leaf)%layer%output(2,batch_pos)%val, &
                  output(1,sample)%edge_features &
             ) ) / output(1,sample)%num_edges
          end if
       end do sample_loop
    type is(real(real32))
       accuracy = sum( this%get_accuracy( &
            this%model(leaf)%layer%output(1,1)%val, &
            output(:,start_index:end_index) &
       ) )
    type is(integer)
       accuracy = sum( this%get_accuracy( &
            this%model(leaf)%layer%output(1,1)%val, &
            real(output(:,start_index:end_index),real32) &
       ) )
    class is(array_type)
       accuracy = sum( this%get_accuracy( &
            this%model(leaf)%layer%output(1,1)%val, &
            output(1,1)%val(:,start_index:end_index) &
       ) )
    end select
    accuracy = accuracy / real(num_samples, real32)

  end function accuracy_eval
!###############################################################################


!###############################################################################
  module function loss_eval(this, start_index, end_index) result(loss)
    !! Get the loss for the output
    !!
    !! Compares the output of the layer stored at model index
    !! leaf_vertices(1) with the sample range start_index:end_index of the
    !! network's expected_array, using the network's loss object.
    implicit none

    ! Arguments
    class(network_type), intent(inout), target :: this
    !! Instance of network
    integer, intent(in) :: start_index, end_index
    !! Start and end batch indices

    type(array_type), pointer :: loss
    !! Loss value

    ! Local variables
    integer :: i, s
    !! Loop index
    type(array_type), pointer :: expected(:,:), predicted(:,:)
    !! Expected and predicted values passed to the loss computation


    if(this%use_graph_output)then
       ! Graph output: samples are the columns of expected_array, so a
       ! pointer view onto the requested columns is sufficient (no copy)
       expected(1:2, 1: end_index - start_index + 1) => &
            this%expected_array( :, start_index:end_index )
    else
       ! Array output: samples are the columns of each element's val, so a
       ! copy restricted to the requested columns is built
       ! NOTE(review): this pointer allocation is not deallocated in this
       ! procedure - confirm who owns it after the loss has been computed.
       allocate(expected(size(this%expected_array,1), size(this%expected_array,2)))
       do s = 1, size(this%expected_array,2)
          do i = 1, size(this%expected_array,1)
             ! NOTE(review): the last extent passed here is the full number
             ! of columns of val, while the assignment below copies only
             ! end_index - start_index + 1 columns - confirm this is
             ! intended.
             call expected(i,s)%allocate( &
                  array_shape = [ &
                       this%expected_array(i,s)%shape, &
                       size(this%expected_array(i,s)%val,2) &
                  ] &
             )
             expected(i,s)%val = this%expected_array(i,s)%val(:, &
                  start_index:end_index:1)
          end do
       end do
    end if

    predicted => this%model(this%leaf_vertices(1))%layer%output
    loss => this%loss%compute( &
         predicted, &
         expected &
    )

  end function loss_eval
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * *  * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  module subroutine forward_generic2d(this, input)
    !! Forward pass through the network for a generic 2D input
    !!
    !! Layers are visited in the order given by vertex_order. Layers
    !! without a parent receive the external input, layers with one parent
    !! receive a pointer to that parent's output, and layers with several
    !! parents receive a list of pointers to all parent outputs.
    !! Supported input types are graph_type, array_type (and extensions)
    !! and real(real32).
    implicit none

    ! Arguments
    class(network_type), intent(inout), target :: this
    !! Instance of network
    class(*), dimension(:,:), intent(in) :: input
    !! Input

    ! Local variables
    integer :: l, i, j, vertex_idx, layer_id, parent_id
    !! Loop index and vertex index
    integer :: num_input_layers
    !! Number of input layers
    type(array_type), pointer :: input_ptr(:,:)
    !! Pointer to the input of a layer with a single parent
    type(array_ptr_type), dimension(:), allocatable :: input_list
    !! Pointers to the inputs of a layer with several parents
    logical :: use_precomp
    !! Whether the pre-computed navigation arrays are available


    ! Nullify here instead of in the declaration: initialising a local
    ! variable in its declaration gives it the SAVE attribute, which would
    ! make the pointer shared between calls (and between OpenMP threads)
    input_ptr => null()

    ! Check that no graph references a vertex beyond its vertex count
    select type(input)
    type is(graph_type)
       do j = 1, this%batch_size
          if(any(input(1,j)%adj_ja(1,:).gt.input(1,j)%num_vertices))then
             call stop_program( &
                  "input graph has more vertices than expected" &
             )
          end if
       end do
    end select

    ! Use pre-computed navigation if available
    use_precomp = allocated(this%fwd_layer_id)

    ! Forward pass
    !---------------------------------------------------------------------------
    do l = 1, size(this%vertex_order,1)
       if(use_precomp)then
          layer_id = this%fwd_layer_id(l)
          num_input_layers = this%fwd_num_inputs(l)
       else
          vertex_idx = this%vertex_order(l)
          layer_id = this%auto_graph%vertex(vertex_idx)%id
          num_input_layers = count(this%auto_graph%adjacency(:,vertex_idx).gt.0)
       end if

       if(num_input_layers.eq.0)then
          ! No parent: the layer is fed directly from the external input
          select type(layer => this%model(layer_id)%layer)
          class is(input_layer_type)
             select type(input)
             type is(graph_type)
                call layer%set_input_graph( [ input(layer%index, :) ] )
                cycle
             class is(array_type)
                call layer%forward(input(layer%index:layer%index,:))
                do concurrent(i=1:size(layer%output,1), j=1:size(layer%output,2))
                   call layer%output(i,j)%set_requires_grad(.false.)
                end do
                cycle
             type is(real(real32))
                ! Wrap the raw real array in a temporary array_type
                allocate(input_ptr(1,1))
                call input_ptr(1,1)%allocate(shape(input))
                call input_ptr(1,1)%set(input)
                call layer%forward(input_ptr)
                call layer%output(1,1)%set_requires_grad(.false.)
                deallocate(input_ptr)
                input_ptr => null()
                cycle
             class default
                call stop_program( &
                     "input type for layer "// &
                     trim(layer%name) // &
                     " is not supported" &
                )
             end select
          class default
             return
          end select
       elseif(num_input_layers.eq.1)then
          ! Single parent: point at the parent's output
          if(use_precomp)then
             parent_id = this%fwd_parent_id(l)
          else
             vertex_idx = this%vertex_order(l)
             j = maxloc( &
                  this%auto_graph%adjacency(:,vertex_idx), dim=1)
             parent_id = this%auto_graph%vertex(j)%id
          end if
          input_ptr => this%model(parent_id)%layer%output
          select type(input)
          type is(graph_type)
             call this%model(layer_id)%layer%set_graph( [ input(1,:) ] )
          end select
       else
          ! Several parents: collect pointers to all parent outputs
          vertex_idx = this%vertex_order(l)
          allocate(input_list(num_input_layers))
          i = 0
          do j = 1, size(this%vertex_order,1)
             if(this%auto_graph%adjacency(j,vertex_idx).gt.0)then
                i = i + 1
                parent_id = this%auto_graph%vertex(j)%id
                input_list(i)%array => this%model(parent_id)%layer%output
             end if
          end do
       end if

       if(use_precomp)then
          ! fwd_layer_type equal to 1 selects the merge-layer path
          if(this%fwd_layer_type(l).eq.1)then
             select type(layer => this%model(layer_id)%layer)
             class is(merge_layer_type)
                call layer%combine(input_list)
             end select
             deallocate(input_list)
          else
             call this%model(layer_id)%layer%forward(input_ptr)
             input_ptr => null()
          end if
       else
          select type(layer => this%model(layer_id)%layer)
          class is(merge_layer_type)
             call layer%combine(input_list)
             deallocate(input_list)
          class default
             call layer%forward(input_ptr)
             input_ptr => null()
          end select
       end if

    end do

  end subroutine forward_generic2d
!-------------------------------------------------------------------------------
  module function forward_eval(this, input) result(output)
    !! Run a forward pass and return a pointer to the network output
    !!
    !! The result points at the output of the layer stored at model index
    !! leaf_vertices(1); no copy is made.
    implicit none

    ! Arguments
    class(network_type), intent(inout), target :: this
    !! Instance of network
    class(*), dimension(:,:), intent(in) :: input
    !! Input

    type(array_type), pointer :: output(:,:)
    !! Output

    ! Local variables
    integer :: leaf
    !! Model index of the first leaf

    call this%forward(input)
    leaf = this%leaf_vertices(1)
    output => this%model(leaf)%layer%output

  end function forward_eval
!-------------------------------------------------------------------------------
  module function forward_eval_multi(this, input) result(output)
    !! Run a forward pass and return pointers to the output of every leaf
    !!
    !! The result is a newly allocated array with one entry per leaf
    !! vertex; each entry points at the output of the layer stored at the
    !! corresponding model index.
    implicit none

    ! Arguments
    class(network_type), intent(inout), target :: this
    !! Instance of network
    class(*), dimension(:,:), intent(in) :: input
    !! Input

    type(array_ptr_type), pointer :: output(:)
    !! Output

    ! Local variables
    integer :: leaf, num_leaves
    !! Leaf index and number of leaf vertices

    call this%forward(input)

    num_leaves = size(this%leaf_vertices,1)
    allocate(output(num_leaves))
    leaf_loop: do leaf = 1, num_leaves
       output(leaf)%array => &
            this%model(this%leaf_vertices(leaf))%layer%output
    end do leaf_loop

  end function forward_eval_multi
!###############################################################################


!###############################################################################
  module subroutine update(this)
    !! Update the network
    !!
    !! Performs one optimiser step:
    !!  1. advance the optimiser iteration counter,
    !!  2. gather the parameters and gradients of every learnable layer
    !!     into two flat vectors (gradients with more than one column are
    !!     averaged over their second dimension),
    !!  3. apply the optimiser's gradient clipping,
    !!  4. call the optimiser's `minimise` on the flat vectors,
    !!  5. write the parameters back to the layers,
    !!  6. reset the gradients.
    !!
    !! If a parameter has no associated gradient, `stop_program` is called
    !! and the routine returns immediately without touching the parameters.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network

    ! Local variables
    real(real32), dimension(this%num_params) :: params, gradients
    !! Flattened parameters and gradients
    integer :: l, i, s, start_idx, end_idx, seg
    !! Loop indices, gradient column count and flat-vector bounds
    logical :: use_layout
    !! Whether the pre-computed parameter segment layout is available


    !---------------------------------------------------------------------------
    ! Increment optimiser iteration counter
    !---------------------------------------------------------------------------
    ! With per-epoch iteration the counter only advances the first time a
    ! new epoch number is seen; otherwise it advances on every call.
    if(this%optimiser%lr_decay%iterate_per_epoch)then
       if(this%epoch.gt.this%optimiser%epoch)then
          this%optimiser%epoch = this%epoch
          this%optimiser%iter = this%optimiser%iter + 1
       end if
    else
       this%optimiser%iter = this%optimiser%iter + 1
    end if


    !---------------------------------------------------------------------------
    ! Get learnable parameters and gradients (using pre-computed layout)
    !---------------------------------------------------------------------------
    use_layout = .false.
    if(this%param_num_segments.gt.0)then
       use_layout = allocated(this%param_seg_layer)
    end if

    if(use_layout)then
       do seg = 1, this%param_num_segments
          l = this%param_seg_layer(seg)
          i = this%param_seg_pidx(seg)
          start_idx = this%param_seg_start(seg)
          end_idx = this%param_seg_end(seg)
          select type(current => this%model(l)%layer)
          class is(learnable_layer_type)
             ! Check the gradient before touching it; return in case
             ! stop_program hands control back to the caller
             if(.not.associated(current%params(i)%grad))then
                call stop_program( &
                     "Gradient not allocated for parameters" &
                )
                return
             end if
             params(start_idx:end_idx) = current%params(i)%val(:,1)
             s = size(current%params(i)%grad%val,2)
             if(s.eq.1)then
                gradients(start_idx:end_idx) = &
                     current%params(i)%grad%val(:,1)
             else
                gradients(start_idx:end_idx) = &
                     sum(current%params(i)%grad%val, dim=2) / &
                     real(s, real32)
             end if
          end select
       end do
    else
       ! No layout available: walk the layers in order and pack each
       ! parameter set contiguously into the flat vectors
       start_idx = 0
       end_idx   = 0
       do l = 1, this%num_layers
          select type(current => this%model(l)%layer)
          class is(learnable_layer_type)
             do i = 1, size(current%params)
                if(.not.associated(current%params(i)%grad))then
                   call stop_program( &
                        "Gradient not allocated for parameters" &
                   )
                   return
                end if
                start_idx = end_idx + 1
                end_idx = end_idx + size(current%params(i)%val, 1)
                params(start_idx:end_idx) = current%params(i)%val(:,1)
                s = size(current%params(i)%grad%val,2)
                if(s.eq.1)then
                   gradients(start_idx:end_idx) = &
                        current%params(i)%grad%val(:,1)
                else
                   gradients(start_idx:end_idx) = &
                        sum(current%params(i)%grad%val, dim=2) / &
                        real(s, real32)
                end if
             end do
          end select
       end do
    end if
    call this%optimiser%clip_dict%apply(size(gradients),gradients)

    !---------------------------------------------------------------------------
    ! Update layers of learnable layer types
    !---------------------------------------------------------------------------
    call this%optimiser%minimise(params, gradients)

    ! Set params back using pre-computed layout
    if(use_layout)then
       do seg = 1, this%param_num_segments
          l = this%param_seg_layer(seg)
          i = this%param_seg_pidx(seg)
          start_idx = this%param_seg_start(seg)
          end_idx = this%param_seg_end(seg)
          select type(current => this%model(l)%layer)
          class is(learnable_layer_type)
             current%params(i)%val(:,1) = params(start_idx:end_idx)
          end select
       end do
    else
       call this%set_params(params)
    end if
    call this%reset_gradients()

  end subroutine update
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * *  * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  module subroutine nullify_graph(this)
    !! Call `nullify_graph` on the layer held by each of the first
    !! `this%num_layers` entries of `this%model`
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Network whose layers are visited

    ! Local variables
    integer :: layer_idx
    !! Index of the layer currently being processed

    do layer_idx = 1, this%num_layers
       call this%model(layer_idx)%layer%nullify_graph()
    end do

  end subroutine nullify_graph
!###############################################################################


!###############################################################################
  module subroutine post_epoch_hook(this, epoch, loss, accuracy)
    !! Default epoch hook — no-op.
    !! Override in a derived type to add custom per-epoch behaviour.
    !!
    !! This default implementation contains no executable statements:
    !! every argument is accepted and ignored, and `this` is left unchanged.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network (declared intent(inout) so an override may modify it)
    integer, intent(in) :: epoch
    !! Current epoch number
    real(real32), intent(in) :: loss
    !! Mean loss over the epoch
    real(real32), intent(in) :: accuracy
    !! Mean accuracy over the epoch

  end subroutine post_epoch_hook
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * *  * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  module function save_input_to_network( this, input ) result(num_samples)
    !! Save input to network
    !!
    !! Stores `input` in `this%input_array` (array or real input) or in
    !! `this%input_graph` (graph input), replacing whatever was stored
    !! before, and returns the number of samples found in `input`.
    !!
    !! If the number of samples is not positive the routine returns before
    !! touching the stored input.
    !!
    !! Accepted rank/type combinations:
    !!  - rank 0: `array_type`, `array_ptr_type`
    !!  - rank 1: `array_type`, `graph_type`
    !!  - rank 2: `array_type` (shallow assignment), `graph_type`,
    !!            `real(real32)` (reshaped to [num_inputs, num_samples])
    !!  - rank 3 to 5: `real(real32)`
    !! Anything else ends in a `stop_program` call.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(*), dimension(..), intent(in) :: input
    !! Input

    integer :: num_samples
    !! Number of samples

    ! Local variables
    integer :: i, j, l, ip, input_rank, num_inputs
    !! Loop index
    integer :: num_input_layers
    !! Number of input layers
    logical :: l_valid_rank_type
    !! Boolean whether rank type is valid
    character(256) :: err_msg
    !! Error message

    num_samples = get_num_samples(this, input)
    if(num_samples.le.0) return

    ! Record the rank once, up front, so that the final error message is
    ! well defined for every rank (including 0 and 1)
    input_rank = rank(input)

    num_input_layers = size(this%root_vertices, 1)
    if(allocated(this%input_array))then
       do i = 1, size(this%input_array, 1)
          do j = 1, size(this%input_array, 2)
             call this%input_array(i,j)%deallocate()
          end do
       end do
       deallocate(this%input_array)
    end if
    if(allocated(this%input_graph)) deallocate(this%input_graph)

    ! Determine the rank of the input
    !---------------------------------------------------------------------------
    ! Rank 0 and rank 1 inputs allocate their storage in the second
    ! select rank below; nothing is prepared for them here.
    select rank(input)
    rank(0)
    rank(1)
    rank(2)
       select type(input)
       class is(array_type)
          ! Rank-2 array_type input is stored element by element via
          ! assign_shallow and needs no further processing
          num_inputs = size(input(1,1)%val, 1)
          allocate(this%input_array(size(input,1), size(input,2)))
          do i = 1, size(input,1)
             do j = 1, size(input,2)
                call this%input_array(i,j)%assign_shallow(input(i,j))
             end do
          end do
          return
       class default
          num_inputs = size(input) / num_samples
          allocate(this%input_array(1,1))
          call this%input_array(1,1)%allocate(array_shape=[num_inputs, num_samples])
       end select
    rank default
       num_inputs = size(input) / num_samples
       allocate(this%input_array(1,1))
       call this%input_array(1,1)%allocate(array_shape=shape(input))
    end select
    l_valid_rank_type = .false.


    ! Process input based on its rank
    !---------------------------------------------------------------------------
    rank_select: select rank(input)
    rank(0)
       select type(input)
       type is(real)
          ! Scalar real input is not supported; fall through to the error
          exit rank_select
       class default
          l_valid_rank_type = .true.
       end select
       if(num_input_layers.ne.1)then
          call stop_program( &
               "number of input arrays does not match expected number of &
               &input layers" &
          )
          return
       end if
       select type(input)
       class is(array_type)
          allocate(this%input_array(1,1))
          call handle_array_type(input, this%input_array(1,1), num_samples)
       type is(array_ptr_type)
          allocate(this%input_array(size(input%array,1), size(input%array,2)))
          do i = 1, size(input%array,1)
             do j = 1, size(input%array,2)
                call handle_array_type( &
                     input%array(i,j), this%input_array(i,j), num_samples &
                )
             end do
          end do
       end select
    rank(1)
       select type(input)
       type is(real(real32))
          ! Rank-1 real input is not supported; fall through to the error
          exit rank_select
       type is(graph_type)
          ! Only the first row of input_graph is filled from the input
          allocate(this%input_graph(num_input_layers, num_samples))
          this%input_graph(1,:) = input(:)
          return
       class default
          l_valid_rank_type = .true.
       end select
       if(size(input,1).ne.num_input_layers)then
          call stop_program( &
               "number of input arrays does not match expected number of &
               &input layers" &
          )
          return
       end if
       select type(input)
       class is(array_type)
          allocate(this%input_array(1,size(input,1)))
          do l = 1, size(input,1)
             call handle_array_type(input(l), this%input_array(1,l), num_samples)
          end do
       type is(array_ptr_type)
          call stop_program("Use of array_ptr_type with rank 1 input not yet supported")
          return
          ! ip = 0
          ! do l = 1, size(input,1)
          !       do i = 1, size(input%array,1)
          !          ip = ip + 1
          !          do j = 1, size(input%array,2)
          !             call handle_array_type( &
          !                  input(l)%array(i,j), this%input_array(ip,j), num_samples &
          !             )
          !          end do
          !       end do
          ! end do
       end select
    rank(2)
       select type(input)
       type is(real(real32))
          this%input_array(1,1)%val = reshape(input, [num_inputs, num_samples])
          l_valid_rank_type = .true.
       type is(graph_type)
          ! NOTE(review): the placeholder input_array(1,1) allocated in the
          ! first select rank is left allocated on this path
          num_samples = size(input, dim=2)
          allocate(this%input_graph(num_input_layers, num_samples))
          this%input_graph(:,:) = input(:,:)
          return
       type is(array_type)
          ! Rank-2 array_type input has already returned from the first
          ! select rank, so this branch is not expected to run
          call stop_program("SHOULD NOT GET HERE")
          this%input_array = input
          l_valid_rank_type = .true.
       end select
    rank(3)
       select type(input)
       type is(real(real32))
          call this%input_array(1,1)%set(input)
          l_valid_rank_type = .true.
       end select
    rank(4)
       select type(input)
       type is(real(real32))
          call this%input_array(1,1)%set(input)
          l_valid_rank_type = .true.
       end select
    rank(5)
       select type(input)
       type is(real(real32))
          call this%input_array(1,1)%set(input)
          l_valid_rank_type = .true.
       end select
    end select rank_select

    if(.not.l_valid_rank_type)then
       write(err_msg,'("Unknown input type for rank ",I0)') input_rank
       call stop_program(err_msg)
       return
    end if

  contains

    function get_num_samples(network, input) result(num_samples)
      !! Get the number of samples in the input
      !!
      !! Returns 0 (after calling `stop_program`) for unsupported
      !! rank/type combinations.
      implicit none

      ! Arguments
      type(network_type), intent(in) :: network
      !! Instance of network
      class(*), dimension(..), intent(in) :: input
      !! Input
      integer :: num_samples
      !! Number of samples

      ! Local variables
      integer :: layer_id
      !! Layer ID
      logical :: use_graph_input
      !! Whether to use graph input

      num_samples = 0
      ! The graph-input flag is read from the layer of the first root vertex
      layer_id = network%auto_graph%vertex(network%root_vertices(1))%id
      use_graph_input = network%model(layer_id)%layer%use_graph_input
      select rank(input)
      rank(0)
         select type(input)
         class is(array_type)
            num_samples = size(input%val, 2)
         class is(array_ptr_type)
            num_samples = size(input%array(1,1)%val, 2)
         class default
            call stop_program("Unknown input type in get_num_samples for rank 0")
            return
         end select
      rank(1)
         select type(input)
         class is(array_type)
            if(use_graph_input)then
               num_samples = size(input)
            else
               num_samples = size(input(1)%val, 2)
            end if
         class is(array_ptr_type)
            if(use_graph_input)then
               num_samples = size(input(1)%array, 2)
            else
               num_samples = size(input(1)%array(1,1)%val, 2)
            end if
         class is(graph_type)
            num_samples = size(input, dim=1)
         type is(real)
            ! For real input the last dimension holds the samples
            num_samples = size(input, rank(input))
         class default
            call stop_program("Unknown input type in get_num_samples for rank 1")
            return
         end select
      rank(2)
         select type(input)
         class is(array_type)
            if(use_graph_input)then
               num_samples = size(input, 2)
            else
               num_samples = size(input(1,1)%val, 2)
            end if
         class is(graph_type)
            num_samples = size(input, dim=2)
         type is(real)
            num_samples = size(input, rank(input))
         class default
            call stop_program("Unknown input type in get_num_samples for rank 2")
            return
         end select
      rank(3)
         select type(input)
         type is(real)
            num_samples = size(input, rank(input))
         class default
            call stop_program("Unknown input type in get_num_samples for rank 3")
            return
         end select
      rank(4)
         select type(input)
         type is(real)
            num_samples = size(input, rank(input))
         class default
            call stop_program("Unknown input type in get_num_samples for rank 4")
            return
         end select
      rank(5)
         select type(input)
         type is(real)
            num_samples = size(input, rank(input))
         class default
            call stop_program("Unknown input type in get_num_samples for rank 5")
            return
         end select
      rank default
         call stop_program("Unknown input rank in get_num_samples")
         return
      end select

    end function get_num_samples


    subroutine handle_array_type(input, output, num_samples)
      !! Handle array type input
      !!
      !! Checks that `input%val` has `num_samples` columns, then allocates
      !! `output` with shape [product of the input shape, num_samples] and
      !! copies `input%val` into it.

      ! Arguments
      class(array_type), intent(in) :: input
      !! Input
      type(array_type), intent(out) :: output
      !! Output
      integer, intent(in) :: num_samples
      !! Number of samples

      if(size(input%val,2).ne.num_samples)then
         call stop_program("number of samples in input arrays do not match")
         return
      end if
      call output%allocate( array_shape = &
           [ product(input%shape(1:input%rank)), num_samples ] &
      )
      output%val = input%val
    end subroutine handle_array_type

  end function save_input_to_network
!-------------------------------------------------------------------------------
  module subroutine save_output_to_network( this, output )
    !! Store the expected output on the network
    !!
    !! Any existing `this%expected_array` is released element by element
    !! and then rebuilt from `output`:
    !!  - `graph_type`: two rows per column, row 1 holding the vertex
    !!    features and row 2 the edge features (row 2 is only filled when
    !!    the graph has a positive number of edge features); only
    !!    `output(1,:)` is read,
    !!  - `array_type`: one entry per element of `output`,
    !!  - `real` / `integer`: a single entry of shape
    !!    [size(output,1), size(output,2)].
    !! Any other type ends in a `stop_program` call.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(*), dimension(:,:), intent(in) :: output
    !! Expected output data

    ! Local variables
    integer :: row, col
    !! Indices along the first and second dimension
    integer :: num_rows, num_cols
    !! Extents of `output` along the first and second dimension

    ! Release whatever was stored previously
    if(allocated(this%expected_array))then
       do row = 1, size(this%expected_array, 1)
          do col = 1, size(this%expected_array, 2)
             call this%expected_array(row,col)%deallocate()
          end do
       end do
       deallocate(this%expected_array)
    end if

    num_rows = size(output, 1)
    num_cols = size(output, 2)

    select type(output)
    type is(graph_type)
       allocate(this%expected_array(2, num_cols))
       do col = 1, num_cols
          if(this%expected_array(1,col)%allocated) &
               call this%expected_array(1,col)%deallocate()
          if(this%expected_array(2,col)%allocated) &
               call this%expected_array(2,col)%deallocate()

          ! Vertex features
          call this%expected_array(1,col)%allocate( &
               array_shape = [ &
                    output(1,col)%num_vertex_features, &
                    output(1,col)%num_vertices &
               ] &
          )
          call this%expected_array(1,col)%zero_grad()
          call this%expected_array(1,col)%set_requires_grad(.false.)
          call this%expected_array(1,col)%set( output(1,col)%vertex_features )
          this%expected_array(1,col)%is_temporary = .false.

          ! Edge features, only when the graph has any
          if(output(1,col)%num_edge_features.gt.0)then
             call this%expected_array(2,col)%allocate( &
                  array_shape = [ &
                       output(1,col)%num_edge_features, &
                       output(1,col)%num_edges &
                  ] &
             )
             call this%expected_array(2,col)%set_requires_grad(.false.)
             call this%expected_array(2,col)%set( output(1,col)%edge_features )
             this%expected_array(2,col)%is_temporary = .false.
          end if
       end do
    class is(array_type)
       allocate(this%expected_array(num_rows, num_cols))
       do col = 1, num_cols
          do row = 1, num_rows
             if(this%expected_array(row,col)%allocated) &
                  call this%expected_array(row,col)%deallocate()
             call this%expected_array(row,col)%allocate( &
                  array_shape = [ &
                       output(row,col)%shape, size(output(row,col)%val,2) &
                  ] &
             )
             call this%expected_array(row,col)%set_requires_grad(.false.)
             call this%expected_array(row,col)%set( output(row,col)%val )
             this%expected_array(row,col)%is_temporary = .false.
          end do
       end do
    type is(real)
       allocate(this%expected_array(1,1))
       if(this%expected_array(1,1)%allocated) &
            call this%expected_array(1,1)%deallocate()
       call this%expected_array(1,1)%allocate( &
            array_shape = [ num_rows, num_cols ] &
       )
       call this%expected_array(1,1)%set_requires_grad(.false.)
       call this%expected_array(1,1)%set( output )
       this%expected_array(1,1)%is_temporary = .false.
    type is(integer)
       allocate(this%expected_array(1,1))
       if(this%expected_array(1,1)%allocated) &
            call this%expected_array(1,1)%deallocate()
       call this%expected_array(1,1)%allocate( &
            array_shape = [ num_rows, num_cols ] &
       )
       call this%expected_array(1,1)%set_requires_grad(.false.)
       ! Integer targets are converted to real32 on storage
       this%expected_array(1,1)%val = real(output, real32)
       this%expected_array(1,1)%is_temporary = .false.
    class default
       call stop_program("output type not supported in training")
    end select

  end subroutine save_output_to_network
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * *  * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  module subroutine train( &
       this, input, output, num_epochs, batch_size, &
       plateau_threshold, shuffle_batches, batch_print_step, verbose, &
       print_precision, scientific_print, early_stopping, &
       val_input, val_output &
  )
    !! Train the network
    !!
    !! This function trains the network on the input data for a number of
    !! epochs. The input data is split into batches of size batch_size and
    !! the network is trained on each batch. The network is trained using
    !! the optimiser specified in the network object.
    use athena__tools_infile, only: stop_check
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(*), dimension(..), intent(in) :: input
    !! Input data
    class(*), dimension(:,:), intent(in) :: output
    !! Output data
    integer, intent(in) :: num_epochs
    !! Number of epochs
    integer, optional, intent(in) :: batch_size
    !! Batch size
    real(real32), optional, intent(in) :: plateau_threshold
    !! Plateau threshold
    logical, optional, intent(in) :: shuffle_batches
    !! Shuffle batches
    integer, optional, intent(in) :: batch_print_step
    !! Batch print step
    integer, optional, intent(in) :: verbose
    !! Verbosity level
    integer, optional, intent(in) :: print_precision
    !! Number of decimal places to print for training metrics
    logical, optional, intent(in) :: scientific_print
    !! Whether to print training metrics in scientific notation
    logical, optional, intent(in) :: early_stopping
    !! Whether to stop training early if convergence is detected
    class(*), dimension(..), optional, intent(in) :: val_input
    !! Validation input data
    class(*), dimension(:,:), optional, intent(in) :: val_output
    !! Validation expected output data

    ! Training parameters
    real(real32) :: batch_loss, batch_accuracy, avg_loss, avg_accuracy
    !! Loss and accuracy

    ! learning parameters
    integer :: l, num_samples
    !! Loop index
    integer :: num_batches
    !! Number of batches
    integer :: converged
    !! Convergence flag
    integer :: window_width
    !! Length of convergence check window
    integer :: verbose_
    !! Verbosity level
    real(real32) :: plateau_threshold_
    !! Plateau threshold
    logical :: shuffle_batches_
    !! Shuffle batches
    logical :: use_accuracy
    !! Whether accuracy evaluation is available
    logical :: early_stopping_
    !! Whether to stop training early if convergence is detected

    ! Printing parameters
    integer :: batch_print_step_
    !! Batch print step
    integer :: print_precision_
    !! Number of decimal places for metric output
    logical :: scientific_print_
    !! Whether to print metrics in scientific notation
    character(len=64) :: loss_str, accuracy_str
    !! Formatted metrics for printing
    character(len=128) :: val_str
    !! Formatted validation metrics for printing

    ! Training loop variables
    integer :: epoch, batch, start_index, end_index
    !! Loop index
    integer, allocatable, dimension(:) :: batch_order
    !! Batch order

    integer :: i, j, s
    !! Loop index
    integer :: tmp_batch_size
    !! Temporary integer to store batch size during validation
    integer :: current_batch_size, target_batch_size
    !! Actual batch size for the current batch and the target batch size
    logical, allocatable :: mode_store(:)
    !! Storage for inference mode booleans

    class(*), allocatable :: data_poly(:,:)
    type(array_type), pointer :: loss => null()

    ! Validation variables
    logical :: use_validation
    !! Whether validation data is provided
    integer :: val_num_samples
    !! Number of validation samples
    integer :: val_sample
    !! Loop index for validation
    real(real32) :: val_loss, val_accuracy, val_loss_sum, val_accuracy_sum
    !! Validation loss and accuracy
#ifdef __INTEL_COMPILER
    type(array_type), pointer :: saved_expected_array(:,:) => null()
    !! Saved training expected output (for restoring after validation)
#else
    type(array_type), dimension(:,:), allocatable :: saved_expected_array
    !! Saved training expected output (for restoring after validation)
#endif

#ifdef _OPENMP
    type(network_type) :: this_copy
    !! Copy of network
#endif


    !---------------------------------------------------------------------------
    ! Check loss and accuracy methods are set
    !---------------------------------------------------------------------------
    if(.not.allocated(this%loss))then
       call stop_program("loss method not set")
       return
    end if
    use_accuracy = associated(this%get_accuracy)
    accuracy_str = ""
    val_str = ""

    !---------------------------------------------------------------------------
    ! Check validation data
    !---------------------------------------------------------------------------
    use_validation = present(val_input) .and. present(val_output)
    if(present(val_input) .neqv. present(val_output))then
       call stop_program( &
            "both val_input and val_output must be provided for validation" &
       )
       return
    end if


    !---------------------------------------------------------------------------
    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    verbose_ = 0
    batch_print_step_ = 20
    plateau_threshold_ = 0._real32
    shuffle_batches_ = .true.
    scientific_print_ = .false.
    print_precision_ = 3
    early_stopping_ = .true.
    if(present(plateau_threshold)) plateau_threshold_ = plateau_threshold
    if(present(shuffle_batches)) shuffle_batches_ = shuffle_batches
    if(present(batch_print_step)) batch_print_step_ = batch_print_step
    if(present(verbose)) verbose_ = verbose
    if(present(print_precision)) print_precision_ = max(print_precision, 0)
    if(present(scientific_print)) scientific_print_ = scientific_print
    if(present(batch_size)) this%batch_size = batch_size
    if(present(early_stopping)) early_stopping_ = early_stopping


    !---------------------------------------------------------------------------
    ! Initialise monitoring variables
    !---------------------------------------------------------------------------
    window_width = max(ceiling(500._real32/this%batch_size),1)
    do i = 1, size(this%metrics,dim=1)
       this%metrics(i)%window_width = window_width
    end do


    !---------------------------------------------------------------------------
    ! Save input and output to network
    !---------------------------------------------------------------------------
    num_samples = this%save_input( input )
    call this%save_output( output )
    if(size(output,2).ne.num_samples.and.this%use_graph_output)then
       call stop_program("number of samples in input and output do not match")
       return
    end if


    !---------------------------------------------------------------------------
    ! If parallel, initialise slices
    !---------------------------------------------------------------------------
    num_batches = (num_samples + this%batch_size - 1) / this%batch_size
    allocate(batch_order(num_batches))
    do batch = 1, num_batches
       batch_order(batch) = batch
    end do


    !---------------------------------------------------------------------------
    ! Set/reset batch size for training
    !---------------------------------------------------------------------------
    call this%set_batch_size(this%batch_size)
    target_batch_size = this%batch_size


    !---------------------------------------------------------------------------
    ! Enable training mode
    !---------------------------------------------------------------------------
    call this%set_training_mode(mode_store)


    epoch_loop: do epoch = 1, num_epochs
       this%epoch = epoch
       !------------------------------------------------------------------------
       ! Shuffle batch order at the start of each epoch
       !------------------------------------------------------------------------
       if(shuffle_batches_)then
          call shuffle(batch_order)
       end if

       avg_loss     = 0._real32
       avg_accuracy = 0._real32

       !------------------------------------------------------------------------
       ! Batch loop
       ! ... split data up into minibatches for training
       !------------------------------------------------------------------------
       batch_loop: do batch = 1, num_batches


          ! Set batch start and end index
          !---------------------------------------------------------------------
          start_index = (batch_order(batch) - 1) * this%batch_size + 1
          end_index = min(batch_order(batch) * this%batch_size, num_samples)
          current_batch_size = end_index - start_index + 1
          if(current_batch_size.ne.target_batch_size)then
             call this%set_batch_size(current_batch_size)
          end if


          ! Forward pass
          !---------------------------------------------------------------------
          select case(this%use_graph_input)
          case(.true.)
             data_poly = get_sample( &
                  this%input_graph, start_index, end_index, current_batch_size &
             )
          case default
             data_poly = get_sample( &
                  this%input_array, start_index, end_index, current_batch_size, &
                  as_graph = .false. &
             )
          end select
          call this%forward(data_poly)
          deallocate(data_poly)


          ! Backward pass
          !---------------------------------------------------------------------
          loss => this%loss_eval(start_index, end_index)
          loss%is_temporary = .false.
          call loss%grad_reverse(reset_graph=.true.)


          ! Compute loss and accuracy (for monitoring)
          !---------------------------------------------------------------------
          batch_loss = sum(loss%val)
          batch_accuracy = 0._real32
          if(use_accuracy)then
             batch_accuracy = this%accuracy_eval(output, start_index, end_index)
          end if


          ! Average metric over batch size and store
          !---------------------------------------------------------------------
          avg_loss = avg_loss + batch_loss
          if(use_accuracy)then
             avg_accuracy = avg_accuracy + batch_accuracy
          end if


          ! Update weights and biases using optimisation algorithm
          !---------------------------------------------------------------------
          call this%update()
          call loss%nullify_graph()
          deallocate(loss)
          nullify(loss)


          ! Print batch results
          !---------------------------------------------------------------------
          if(abs(verbose_).gt.0.and.&
               (batch.eq.1.or.abs(mod(batch,batch_print_step_)).lt.1.E-6))then
             loss_str = format_training_real( &
                  avg_loss / real(batch, real32), print_precision_, &
                  scientific_print_ &
             )
             if(use_accuracy)then
                accuracy_str = ", accuracy=" // trim(format_training_real( &
                     avg_accuracy / real(batch, real32), &
                     print_precision_, scientific_print_ &
                ))
             end if

             write(6,'("epoch=",I0,", batch=",I0,&
                  &", lr=",ES0.2,", loss=",A,A)' &
             ) &
                  this%epoch, batch, &
                  this%optimiser%lr_decay%get_lr( &
                       this%optimiser%learning_rate, this%optimiser%iter &
                  ), &
                  trim(loss_str), trim(accuracy_str)
          end if


          ! Check for user-name stop file
          !---------------------------------------------------------------------
          if(stop_check())then
             write(0,*) "STOPCAR ENCOUNTERED"
             write(0,*) "Exiting training loop..."
             exit epoch_loop
          end if

       end do batch_loop
       call this%metrics(1)%append(avg_loss / real(num_batches, real32))
       if(use_accuracy)then
          call this%metrics(2)%append(avg_accuracy / real(num_batches, real32))
       end if


       !------------------------------------------------------------------------
       ! Validation evaluation at end of epoch
       !------------------------------------------------------------------------
       val_str = ""
       if(use_validation)then

#ifdef __INTEL_COMPILER
          ! Save training expected output. `ifx` crashes on the allocatable
          ! local declaration used with `move_alloc`, so keep an explicit
          ! pointer-backed copy here instead.
          if(allocated(this%expected_array))then
             allocate(saved_expected_array( &
                  size(this%expected_array, 1), size(this%expected_array, 2) &
             ))
             do i = 1, size(this%expected_array, 1)
                do j = 1, size(this%expected_array, 2)
                   call saved_expected_array(i,j)%allocate( &
                        source=this%expected_array(i,j) &
                   )
                   call this%expected_array(i,j)%deallocate()
                end do
             end do
             deallocate(this%expected_array)
          else
             nullify(saved_expected_array)
          end if
#else
          call move_alloc(this%expected_array, saved_expected_array)
#endif

          ! Save validation output to network
          call this%save_output( val_output )

          ! Save validation input
          val_num_samples = this%save_input( val_input )

          ! Set batch size to 1 and enable inference mode
          call this%set_batch_size(1)
          call this%set_inference_mode()

          ! Evaluate validation loss and accuracy
          val_loss_sum = 0._real32
          val_accuracy_sum = 0._real32
          do val_sample = 1, val_num_samples
             select case(this%use_graph_input)
             case(.true.)
                data_poly = get_sample( &
                     this%input_graph, val_sample, val_sample, 1 &
                )
             case default
                data_poly = get_sample_array( &
                     this%input_array, val_sample, val_sample, 1, &
                     as_graph = .false. &
                )
             end select
             call this%forward(data_poly)
             deallocate(data_poly)

             loss => this%loss_eval(val_sample, val_sample)
             val_loss_sum = val_loss_sum + sum(loss%val)
             call loss%nullify_graph()
             deallocate(loss)
             nullify(loss)

             if(use_accuracy)then
                val_accuracy_sum = val_accuracy_sum + &
                     this%accuracy_eval(val_output, val_sample, val_sample)
             end if
          end do

          val_loss = val_loss_sum / real(val_num_samples, real32)
          val_accuracy = val_accuracy_sum / real(val_num_samples, real32)

          ! Build validation print string
          val_str = ", val_loss=" // trim(format_training_real( &
               val_loss, print_precision_, scientific_print_ &
          ))
          if(use_accuracy)then
             val_str = trim(val_str) // ", val_accuracy=" // &
                  trim(format_training_real( &
                       val_accuracy, print_precision_, scientific_print_ &
                  ))
          end if

          ! Restore training expected output
          if(allocated(this%expected_array))then
             do i = 1, size(this%expected_array, 1)
                do j = 1, size(this%expected_array, 2)
                   call this%expected_array(i,j)%deallocate()
                end do
             end do
             deallocate(this%expected_array)
          end if
#ifdef __INTEL_COMPILER
          ! `ifx` does not support `move_alloc` in this context
          if(associated(saved_expected_array))then
             allocate(this%expected_array( &
                  size(saved_expected_array, 1), size(saved_expected_array, 2) &
             ))
             do i = 1, size(saved_expected_array, 1)
                do j = 1, size(saved_expected_array, 2)
                   call this%expected_array(i,j)%allocate( &
                        source=saved_expected_array(i,j) &
                   )
                   call saved_expected_array(i,j)%deallocate()
                end do
             end do
             deallocate(saved_expected_array)
             nullify(saved_expected_array)
          end if
#else
          call move_alloc(saved_expected_array, this%expected_array)
#endif

          ! Restore training input
          num_samples = this%save_input( input )

          ! Restore training batch size and inference mode
          call this%set_batch_size(target_batch_size)
          call this%set_training_mode()

       end if


       ! Print epoch summary results
       !------------------------------------------------------------------------
       if(use_validation.and.verbose_.ge.0)then
          write(6,'("epoch=",I0,A)') this%epoch, trim(val_str)
       elseif(verbose_.eq.0)then
          loss_str = format_training_real( &
               this%metrics(1)%val, print_precision_, scientific_print_ &
          )
          if(use_accuracy)then
             accuracy_str = ", accuracy=" // trim(format_training_real( &
                  this%metrics(2)%val, print_precision_, scientific_print_ &
             ))
          end if
          write(6,'("epoch=",I0,&
               &", lr=",ES0.2,", loss=",A,A,A)' &
          ) &
               this%epoch, &
               this%optimiser%lr_decay%get_lr( &
                    this%optimiser%learning_rate, this%optimiser%iter &
               ), &
               trim(loss_str), trim(accuracy_str), trim(val_str)
       end if


       !------------------------------------------------------------------------
       ! Per-epoch callback (e.g. W&B logging via wandb_network_type)
       !------------------------------------------------------------------------
       call this%post_epoch_hook( &
            this%epoch, &
            this%metrics(1)%val, &
            this%metrics(2)%val &
       )

       !------------------------------------------------------------------------
       ! Check for convergence and stop early if enabled
       ! When validation data is provided, check validation loss for plateau
       !------------------------------------------------------------------------
       if(early_stopping_)then
          if(use_validation)then
             if(val_loss .lt. plateau_threshold_ .and. &
                  plateau_threshold_ .gt. 0._real32) exit epoch_loop
          else
             call this%metrics(1)%check(plateau_threshold_, converged)
             if(converged.ne.0) exit epoch_loop
          end if
          if(use_accuracy)then
             call this%metrics(2)%check(plateau_threshold_, converged)
             if(converged.ne.0) exit epoch_loop
          end if
       end if

    end do epoch_loop

    ! Final epoch metrics
    if(use_accuracy)then
       this%accuracy_val = this%metrics(2)%val
    else
       this%accuracy_val = 0._real32
    end if
    this%loss_val     = this%metrics(1)%val


    !---------------------------------------------------------------------------
    ! Restore training/inference mode
    !---------------------------------------------------------------------------
    call this%restore_mode(mode_store)

  end subroutine train
!###############################################################################


!###############################################################################
  function format_training_real(value, decimals, scientific) result(formatted)
    !! Format a training metric with a configurable number of decimal places.
    !!
    !! The requested number of decimals is clamped to the range [0, 30].
    !! The value is written with an `ES` edit descriptor (two-digit exponent)
    !! when `scientific` is true and with `F0.d` otherwise, and the result is
    !! left-justified in a 64-character string.
    !!
    !! A fixed-point rendering can need more than 64 characters (a large
    !! magnitude combined with many decimals). In that case the value is
    !! rendered in scientific notation instead of letting the internal write
    !! fail. If no rendering succeeds the result is filled with asterisks.
    implicit none

    ! Arguments
    real(real32), intent(in) :: value
    !! Value to format
    integer, intent(in) :: decimals
    !! Number of decimal places
    logical, intent(in) :: scientific
    !! Whether to use scientific notation

    character(len=64) :: formatted
    !! Formatted string

    ! Local variables
    character(len=16) :: fmt
    !! Internal write format
    integer :: width_
    !! Field width for scientific formatting
    integer :: decimals_
    !! Clamped decimal count
    integer :: ios
    !! Status of the internal write

    decimals_ = min(max(decimals, 0), 30)
    ! sign + digit + point + decimals + "E+dd" needs decimals_ + 7 characters
    width_ = max(decimals_ + 8, 14)
    formatted = ""

    if(scientific)then
       write(fmt,'("(ES",I0,".",I0,"E2)")') width_, decimals_
    else
       write(fmt,'("(F0.",I0,")")') decimals_
    end if
    write(formatted, fmt, iostat=ios) value

    ! Fixed-point output did not fit in the buffer: use the bounded-width
    ! scientific form instead of aborting
    if(ios.ne.0 .and. .not.scientific)then
       write(fmt,'("(ES",I0,".",I0,"E2)")') width_, decimals_
       write(formatted, fmt, iostat=ios) value
    end if
    if(ios.ne.0) formatted = repeat("*", len(formatted))

    formatted = adjustl(formatted)

  end function format_training_real
!###############################################################################


!###############################################################################
  module subroutine test( &
       this, input, output, verbose &
  )
    !! Evaluate the network on a labelled data set.
    !!
    !! Every sample is passed through the network on its own (batch size 1)
    !! in inference mode. The per-sample loss and, when an accuracy procedure
    !! is associated, the per-sample accuracy are summed into `metrics(1)%val`
    !! and `metrics(2)%val`. Their averages over all samples are stored in
    !! `loss_val` and `accuracy_val` (`accuracy_val` is 0 when no accuracy
    !! procedure is associated).
    !!
    !! The batch size is left at 1 on return; the inference/training mode is
    !! restored to what it was on entry.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(*), dimension(..), intent(in) :: input
    !! Input data
    class(*), dimension(:,:), intent(in) :: output
    !! Expected output data
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: l, sample, num_samples
    !! Loop indices and number of samples
    integer :: verbose_
    !! Verbosity level
    logical :: use_accuracy
    !! Whether accuracy evaluation is available
    real(real32) :: acc_val, loss_val
    !! Accuracy and loss of the current sample
    class(*), allocatable, dimension(:,:) :: data_poly
    !! Polymorphic data array
    type(array_type), pointer :: loss => null()
    !! Loss
    logical, allocatable :: mode_store(:)
    !! Storage for inference mode booleans


    !---------------------------------------------------------------------------
    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    if(present(verbose))then
       verbose_ = verbose
    else
       verbose_ = 0
    end if
    use_accuracy = associated(this%get_accuracy)

    do l = 1, size(this%metrics,dim=1)
       this%metrics(l)%val = 0._real32
    end do
    loss_val  = 0._real32
    acc_val = 0._real32


    !---------------------------------------------------------------------------
    ! Store the test data on the network
    !---------------------------------------------------------------------------
    ! loss_eval only receives sample indices, so the expected values have to
    ! be held by the network (as train does for its validation data).
    ! Without this the loss would be evaluated against whatever expected
    ! output a previous call left behind.
    call this%save_output( output )
    num_samples = this%save_input( input )


    !---------------------------------------------------------------------------
    ! Reset batch size for testing
    !---------------------------------------------------------------------------
    call this%set_batch_size(1)


    !---------------------------------------------------------------------------
    ! Enable inference mode
    !---------------------------------------------------------------------------
    call this%set_inference_mode(mode_store)


    !---------------------------------------------------------------------------
    ! Testing loop
    !---------------------------------------------------------------------------
    test_loop1: do sample = 1, num_samples

       ! Forward pass
       !------------------------------------------------------------------------
       select case(this%use_graph_input)
       case(.true.)
          data_poly = get_sample( &
               this%input_graph, sample, sample, 1 &
          )
       case default
          data_poly = get_sample_array( &
               this%input_array, sample, sample, 1, &
               as_graph = .false. &
          )
       end select
       call this%forward(data_poly)
       deallocate(data_poly)


       ! Compute loss and accuracy (for monitoring)
       !------------------------------------------------------------------------
       loss => this%loss_eval(sample, sample)
       loss_val = sum(loss%val)
       ! Release the loss object before the next sample is evaluated
       call loss%nullify_graph()
       deallocate(loss)
       nullify(loss)
       if(use_accuracy)then
          acc_val = this%accuracy_eval(output, sample, sample)
          this%metrics(2)%val = this%metrics(2)%val + acc_val
       end if
       this%metrics(1)%val = this%metrics(1)%val + loss_val

    end do test_loop1


    ! Normalise metrics by number of samples
    !---------------------------------------------------------------------------
    if(use_accuracy)then
       this%accuracy_val = this%metrics(2)%val / real(num_samples, real32)
    else
       this%accuracy_val = 0._real32
    end if
    this%loss_val     = this%metrics(1)%val / real(num_samples, real32)


    !---------------------------------------------------------------------------
    ! Restore training/inference mode
    !---------------------------------------------------------------------------
    call this%restore_mode(mode_store)

  end subroutine test
!###############################################################################


!###############################################################################
  module function predict_real( &
       this, input, verbose &
  ) result(output)
    !! Run the network in inference mode on real-valued input.
    !!
    !! The extent of the last dimension of `input` is used as the batch size.
    !! The whole input is passed through the network in a single forward
    !! pass, and the values stored in element (1,1) of the first leaf layer's
    !! output are returned.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    real(real32), dimension(..), intent(in) :: input
    !! Input data (assumed rank)
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Result
    real(real32), dimension(:,:), allocatable :: output
    !! Predicted values

    ! Local variables
    integer :: verbosity
    !! Verbosity level
    integer :: num_in_batch
    !! Number of samples, taken from the last dimension of the input
    integer :: leaf_id
    !! Index of the first leaf layer in the model
    logical, allocatable :: saved_modes(:)
    !! Storage for inference mode booleans


    !---------------------------------------------------------------------------
    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    verbosity = 0
    if(present(verbose)) verbosity = verbose


    !---------------------------------------------------------------------------
    ! Batch size is the extent of the last dimension, whatever the rank
    !---------------------------------------------------------------------------
    num_in_batch = size(input, dim=rank(input))
    call this%set_batch_size(num_in_batch)


    !---------------------------------------------------------------------------
    ! Switch to inference mode, remembering the previous mode
    !---------------------------------------------------------------------------
    call this%set_inference_mode(saved_modes)


    !---------------------------------------------------------------------------
    ! Forward pass over the whole batch and extraction of the prediction
    !---------------------------------------------------------------------------
    call this%forward(get_sample(input, 1, num_in_batch, num_in_batch))

    leaf_id = this%leaf_vertices(1)
    output = this%model(leaf_id)%layer%output(1,1)%val


    !---------------------------------------------------------------------------
    ! Put the network back into its previous mode
    !---------------------------------------------------------------------------
    call this%restore_mode(saved_modes)

  end function predict_real
!###############################################################################


!###############################################################################
  module function predict_graph1d( this, input, verbose ) result(output)
    !! Predict the output graphs for a 1D array of input graphs.
    !!
    !! The number of input graphs is used as the batch size and the whole
    !! array is passed through the network in one forward pass in inference
    !! mode. One output graph is built per leaf layer and per input graph:
    !! vertex and edge counts are copied from the input graph, feature counts
    !! are read from the leaf layer's `output_shape`, and vertex features are
    !! read from row 1 of the leaf layer's output. Edge features are read from
    !! row 2 of the leaf layer's output, or copied from the input graph when
    !! the leaf layer's output has a single row.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    type(graph_type), dimension(:), intent(in) :: input
    !! Input graph
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: l, s
    !! Loop index
    type(graph_type), dimension(size(this%leaf_vertices),size(input)) :: output
    !! Output graph
    integer :: verbose_, batch_size
    !! Verbosity level and batch size
    integer :: leaf
    !! Model index of the current leaf layer
    logical, allocatable :: mode_store(:)
    !! Storage for inference mode booleans


    !---------------------------------------------------------------------------
    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    ! Assigned at run time: initialising in the declaration would give the
    ! variable the SAVE attribute and carry its value over between calls
    verbose_ = 0
    if(present(verbose)) verbose_ = verbose

    !---------------------------------------------------------------------------
    ! Reset batch size for testing
    !---------------------------------------------------------------------------
    batch_size = size(input)
    call this%set_batch_size(batch_size)


    !---------------------------------------------------------------------------
    ! Enable inference mode
    !---------------------------------------------------------------------------
    call this%set_inference_mode(mode_store)


    !---------------------------------------------------------------------------
    ! Predict
    !---------------------------------------------------------------------------
    call this%forward(get_sample(input, 1, batch_size, batch_size))

    do l = 1, size(this%leaf_vertices)
       leaf = this%leaf_vertices(l)
       do s = 1, batch_size
          output(l,s)%num_vertices = input(s)%num_vertices
          output(l,s)%num_edges = input(s)%num_edges
          output(l,s)%num_vertex_features = &
               this%model(leaf)%layer%output_shape(1)
          output(l,s)%num_edge_features = &
               this%model(leaf)%layer%output_shape(2)
          output(l,s)%vertex_features = this%model(leaf)%layer%output(1,s)%val
          if(size(this%model(leaf)%layer%output,1).eq.1)then
             ! Layer produced no edge output: keep the input edge features
             output(l,s)%edge_features = input(s)%edge_features
          else
             output(l,s)%edge_features = &
                  this%model(leaf)%layer%output(2,s)%val
          end if
       end do
    end do


    !---------------------------------------------------------------------------
    ! Restore training/inference mode
    !---------------------------------------------------------------------------
    call this%restore_mode(mode_store)

  end function predict_graph1d
!-------------------------------------------------------------------------------
  module function predict_graph2d( this, input, verbose ) result(output)
    !! Predict the output graphs for a 2D array of input graphs.
    !!
    !! The extent of the second dimension of `input` is used as the batch
    !! size and the whole array is passed through the network in one forward
    !! pass in inference mode. One output graph is built per leaf layer and
    !! per sample: vertex and edge counts are copied from `input(1,s)`,
    !! feature counts are read from the leaf layer's `output_shape`, and
    !! vertex features are read from row 1 of the leaf layer's output. Edge
    !! features are read from row 2 of the leaf layer's output, or copied from
    !! `input(1,s)` when the leaf layer's output has a single row.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    type(graph_type), dimension(:,:), intent(in) :: input
    !! Input graph
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: l, s
    !! Loop index
    type(graph_type), dimension(size(this%leaf_vertices),size(input,dim=2)) :: &
         output
    !! Output graph
    integer :: verbose_, batch_size
    !! Verbosity level and batch size
    integer :: leaf
    !! Model index of the current leaf layer
    logical, allocatable :: mode_store(:)
    !! Storage for inference mode booleans


    !---------------------------------------------------------------------------
    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    ! Assigned at run time: initialising in the declaration would give the
    ! variable the SAVE attribute and carry its value over between calls
    verbose_ = 0
    if(present(verbose)) verbose_ = verbose

    !---------------------------------------------------------------------------
    ! Reset batch size for testing
    !---------------------------------------------------------------------------
    batch_size = size(input, 2)
    call this%set_batch_size(batch_size)


    !---------------------------------------------------------------------------
    ! Enable inference mode
    !---------------------------------------------------------------------------
    call this%set_inference_mode(mode_store)


    !---------------------------------------------------------------------------
    ! Predict
    !---------------------------------------------------------------------------
    call this%forward(get_sample(input, 1, batch_size, batch_size))

    do l = 1, size(this%leaf_vertices)
       leaf = this%leaf_vertices(l)
       do s = 1, batch_size
          output(l,s)%num_vertices = input(1,s)%num_vertices
          output(l,s)%num_edges = input(1,s)%num_edges
          output(l,s)%num_vertex_features = &
               this%model(leaf)%layer%output_shape(1)
          output(l,s)%num_edge_features = &
               this%model(leaf)%layer%output_shape(2)
          output(l,s)%vertex_features = this%model(leaf)%layer%output(1,s)%val
          if(size(this%model(leaf)%layer%output,1).eq.1)then
             ! Layer produced no edge output: keep the input edge features
             output(l,s)%edge_features = input(1,s)%edge_features
          else
             output(l,s)%edge_features = &
                  this%model(leaf)%layer%output(2,s)%val
          end if
       end do
    end do


    !---------------------------------------------------------------------------
    ! Restore training/inference mode
    !---------------------------------------------------------------------------
    call this%restore_mode(mode_store)

  end function predict_graph2d
!###############################################################################


!###############################################################################
  module function predict_array_from_real( this, input, output_as_array, verbose ) &
       result(output)
    !! Predict from generic input and return the first leaf layer's output
    !! as a 2D array of `array_type`.
    !!
    !! `output_as_array` must be true; otherwise `stop_program` is called and
    !! the function returns without running the network. The input is stored
    !! on the network with `save_input`, a forward pass is run in inference
    !! mode on the stored graph or array input, and every element of the
    !! first leaf layer's output is copied into the result.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(*), dimension(..), intent(in) :: input
    !! Input data
    logical, intent(in) :: output_as_array
    !! Whether to output as array
    integer, intent(in), optional :: verbose
    !! Verbosity level

    ! Result
    type(array_type), dimension(:,:), allocatable :: output
    !! Predicted output

    ! Local variables
    integer :: row, col
    !! Loop indices over the leaf layer output
    integer :: num_rows, num_cols
    !! Extents of the leaf layer output
    integer :: leaf_id
    !! Model index of the first leaf layer
    integer :: num_samples
    !! Number of samples
    integer :: verbosity
    !! Verbosity level
    logical, allocatable :: saved_modes(:)
    !! Storage for inference mode booleans


    !---------------------------------------------------------------------------
    ! Optional arguments and argument validation
    !---------------------------------------------------------------------------
    verbosity = 0
    if(present(verbose)) verbosity = verbose

    if(.not.output_as_array)then
       call stop_program("predict_array_from_real: output_as_array must be true")
       return
    end if


    !---------------------------------------------------------------------------
    ! Store the input on the network
    !---------------------------------------------------------------------------
    num_samples = this%save_input( input )


    !---------------------------------------------------------------------------
    ! Forward pass in inference mode
    !---------------------------------------------------------------------------
    call this%set_inference_mode(saved_modes)

    if(this%use_graph_input)then
       call this%forward(this%input_graph)
    else
       call this%forward(this%input_array)
    end if


    !---------------------------------------------------------------------------
    ! Copy the first leaf layer's output into the result
    !---------------------------------------------------------------------------
    leaf_id = this%leaf_vertices(1)
    num_rows = size(this%model(leaf_id)%layer%output, 1)
    num_cols = size(this%model(leaf_id)%layer%output, 2)

    allocate(output(num_rows, num_cols))
    do col = 1, num_cols
       do row = 1, num_rows
          output(row,col) = this%model(leaf_id)%layer%output(row,col)
       end do
    end do


    !---------------------------------------------------------------------------
    ! Put the network back into its previous mode
    !---------------------------------------------------------------------------
    call this%restore_mode(saved_modes)

  end function predict_array_from_real
!###############################################################################


!###############################################################################
  module function predict_array( this, input, verbose ) &
       result(output)
    !! Predict the output for `array_type` input.
    !!
    !! The input is stored on the network with `save_input`, a forward pass
    !! is run in inference mode on the stored graph or array input, and the
    !! outputs of all leaf layers are gathered into a single 2D result whose
    !! shape is given by `get_output_shape`. The rows of each leaf layer's
    !! output are appended, in leaf order, along the first dimension of the
    !! result; the second dimension indexes the columns of the layer output.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(array_type), dimension(..), intent(in) :: input
    !! Input data
    integer, intent(in), optional :: verbose
    !! Verbosity level

    type(array_type), dimension(:,:), allocatable :: output
    !! Predicted output

    ! Local variables
    integer :: l, s, i, j, layer_id
    !! Loop indices, running output row and model index of a leaf layer
    integer :: num_samples
    !! Number of samples
    integer :: verbose_
    !! Verbosity level
    integer, dimension(2) :: output_shape
    !! Output shape
    logical, allocatable :: mode_store(:)
    !! Storage for inference mode booleans


    !---------------------------------------------------------------------------
    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    if(present(verbose))then
       verbose_ = verbose
    else
       verbose_ = 0
    end if


    !---------------------------------------------------------------------------
    ! Set number of samples for predicting
    !---------------------------------------------------------------------------
    num_samples = this%save_input( input )
    ! call this%set_batch_size(num_samples)


    !---------------------------------------------------------------------------
    ! Enable inference mode
    !---------------------------------------------------------------------------
    call this%set_inference_mode(mode_store)

    !---------------------------------------------------------------------------
    ! Forward pass
    !---------------------------------------------------------------------------
    select case(this%use_graph_input)
    case(.true.)
       call this%forward(this%input_graph)
    case default
       call this%forward(this%input_array)
    end select


    !---------------------------------------------------------------------------
    ! Allocate output data
    !---------------------------------------------------------------------------
    output_shape = this%get_output_shape()
    allocate(output(output_shape(1), output_shape(2)))
    ! The running row index is initialised once, outside the leaf loop, so
    ! that each leaf layer is appended after the previous one instead of
    ! overwriting it from row 1.
    ! NOTE(review): assumes get_output_shape sizes the first dimension for the
    ! rows of all leaf layers combined - confirm.
    j = 0
    do l = 1, size(this%leaf_vertices)
       layer_id = this%auto_graph%vertex(this%leaf_vertices(l))%id
       do i = 1, size(this%model(layer_id)%layer%output, 1)
          j = j + 1
          do s = 1, size(this%model(layer_id)%layer%output, 2)
             output(j,s) = this%model(layer_id)%layer%output(i,s)
          end do
       end do
    end do


    !---------------------------------------------------------------------------
    ! Restore training/inference mode
    !---------------------------------------------------------------------------
    call this%restore_mode(mode_store)

  end function predict_array
!###############################################################################


!###############################################################################
  module function predict_generic( this, input, verbose, output_as_graph ) &
       result(output)
    !! Predict the output for a generic (unlimited polymorphic) 2D input.
    !!
    !! The input is stored on the network with `save_input`, the batch size
    !! is set to the returned number of samples, and a forward pass is run in
    !! inference mode. The result is allocated with the shape returned by
    !! `get_output_shape` and has dynamic type
    !!  - `graph_type` when `output_as_graph` is true: one graph per leaf
    !!    layer and sample, built from the leaf layer's output and the
    !!    structure of `input(1,s)` (the input must then be `graph_type`);
    !!  - `array_type` otherwise: the rows of every leaf layer's output are
    !!    appended, in leaf order, along the first dimension.
    !!
    !! `stop_program` is called if `output_as_graph` is true but the network
    !! does not use graph output, or if the input is not of `graph_type` when
    !! graph output is requested.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(*), dimension(:,:), intent(in) :: input
    !! Input data
    integer, intent(in), optional :: verbose
    !! Verbosity level
    logical, intent(in), optional :: output_as_graph
    !! Boolean whether to output as graph

    class(*), dimension(:,:), allocatable :: output
    !! Predicted output

    ! Local variables
    integer :: l, s, i, j, layer_id
    !! Loop indices, running output row and model index of a leaf layer
    integer :: num_samples
    !! Number of samples
    integer :: verbose_
    !! Verbosity level
    logical :: output_as_graph_
    !! Output as graph boolean
    integer, dimension(2) :: output_shape
    !! Output shape
    logical, allocatable :: mode_store(:)
    !! Storage for inference mode booleans


    !---------------------------------------------------------------------------
    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    if(present(verbose))then
       verbose_ = verbose
    else
       verbose_ = 0
    end if

    if(present(output_as_graph))then
       output_as_graph_ = output_as_graph
    else
       output_as_graph_ = .false.
    end if
    if(output_as_graph_.and..not.this%use_graph_output)then
       call stop_program("output_as_graph is true but network does not use &
            &graph output")
    end if


    !---------------------------------------------------------------------------
    ! Set number of samples for predicting
    !---------------------------------------------------------------------------
    num_samples = this%save_input( input )
    call this%set_batch_size(num_samples)


    !---------------------------------------------------------------------------
    ! Enable inference mode
    !---------------------------------------------------------------------------
    call this%set_inference_mode(mode_store)

    !---------------------------------------------------------------------------
    ! Forward pass
    !---------------------------------------------------------------------------
    select case(this%use_graph_input)
    case(.true.)
       call this%forward(this%input_graph)
    case default
       call this%forward(this%input_array)
    end select


    !---------------------------------------------------------------------------
    ! Allocate output data
    !---------------------------------------------------------------------------
    output_shape = this%get_output_shape()
    if(output_as_graph_)then
       allocate(output(output_shape(1), output_shape(2)), source = graph_type())
       select type(output)
       type is(graph_type)
          select type(input)
          type is(graph_type)
             do l = 1, size(this%leaf_vertices)
                do s = 1, num_samples
                   output(l,s)%num_vertices = input(1,s)%num_vertices
                   output(l,s)%num_edges = input(1,s)%num_edges
                   output(l,s)%num_vertex_features = this%model( &
                        this%leaf_vertices(l) &
                   )%layer%output_shape(1)
                   output(l,s)%num_edge_features = this%model( &
                        this%leaf_vertices(l) &
                   )%layer%output_shape(2)
                   output(l,s)%vertex_features = this%model( &
                        this%leaf_vertices(l) &
                   )%layer%output(1,s)%val
                   if(size(this%model(this%leaf_vertices(l))%layer%output,1).eq.1)then
                      ! Layer produced no edge output: keep the input edge
                      ! features
                      output(l,s)%edge_features = input(1,s)%edge_features
                   else
                      output(l,s)%edge_features = this%model( &
                           this%leaf_vertices(l) &
                      )%layer%output(2,s)%val
                   end if
                end do
             end do
          class default
             call stop_program("input is not of type graph_type")
          end select
       class default
          call stop_program("allocation of output as graph_type failed")
       end select
    else
       allocate(output(output_shape(1), output_shape(2)), source = array_type())
       select type(output)
       type is(array_type)
          ! The running row index is initialised once, outside the leaf loop,
          ! so that each leaf layer is appended after the previous one
          ! instead of overwriting it from row 1.
          ! NOTE(review): assumes get_output_shape sizes the first dimension
          ! for the rows of all leaf layers combined - confirm.
          j = 0
          do l = 1, size(this%leaf_vertices)
             layer_id = this%auto_graph%vertex(this%leaf_vertices(l))%id
             do i = 1, size(this%model(layer_id)%layer%output, 1)
                j = j + 1
                do s = 1, size(this%model(layer_id)%layer%output, 2)
                   output(j,s) = this%model(layer_id)%layer%output(i,s)
                end do
             end do
          end do
       end select
    end if


    !---------------------------------------------------------------------------
    ! Restore training/inference mode
    !---------------------------------------------------------------------------
    call this%restore_mode(mode_store)

  end function predict_generic
!###############################################################################


!###############################################################################
  module subroutine print_summary(this)
    !! Print a tabular summary of the network architecture to standard output.
    !!
    !! One row is written per layer, visiting layers in the order stored in
    !! vertex_order. Each row holds the layer name and subtype, the layer
    !! output shape and the value returned by the layer's get_num_params.
    !! A footer reports the number of root (input) and leaf (output)
    !! vertices and the sum of the per-layer parameter counts.
    implicit none

    ! Arguments
    class(network_type), intent(in) :: this
    !! Instance of network

    ! Local variables
    integer :: i, vertex_idx
    !! Loop index and vertex index
    integer :: total_params
    !! Running sum of parameter counts over all layers
    integer :: layer_params
    !! Parameters in current layer
    character(len=80) :: line
    !! Line separator
    character(len=40) :: layer_name
    !! Layer name
    character(len=30) :: output_shape_str
    !! Output shape string
    character(len=20) :: param_str
    !! Parameter count string
    character(len=20) :: index_str
    !! Layer position rendered as text (used for unnamed layers)
    character(len=128) :: shape_buf
    !! Scratch buffer for rendering the output shape
    character(len=*), parameter :: row_fmt = '(A35, A25, A15)'
    !! Column layout shared by the header and every layer row

    line = repeat('_', 80)

    ! Print header
    write(*,*)
    write(*,'(A)') line
    write(*,'(A)') 'Model Summary'
    write(*,'(A)') line
    write(*,row_fmt) 'Layer (type)', 'Output Shape', 'Param #'
    write(*,'(A)') repeat('=', 80)

    ! Initialise parameter count
    total_params = 0

    ! Print each layer
    do i = 1, this%num_layers
       vertex_idx = this%vertex_order(i)
       associate(layer => this%model(vertex_idx)%layer)
          ! Get layer name
          ! Built by assignment rather than an internal write so that an
          ! over-long name is truncated instead of raising an end-of-record
          ! error on the fixed-length buffer.
          if(allocated(layer%name))then
             layer_name = trim(layer%name) // ' (' // &
                  trim(layer%subtype) // ')'
          else
             write(index_str, '(I0)') i
             layer_name = 'layer_' // trim(index_str) // ' (' // &
                  trim(layer%subtype) // ')'
          end if

          ! Get output shape string
          if(allocated(layer%output_shape))then
             ! Unlimited repeat with a colon stops after the last item, so
             ! no trailing separator is written; the closing bracket is
             ! appended afterwards because nothing may follow an unlimited
             ! format item.
             write(shape_buf, '("(",*(I0,:,", "))') layer%output_shape
             output_shape_str = trim(shape_buf) // ')'
          else
             output_shape_str = '(Not set)'
          end if

          ! Get parameter count
          layer_params = layer%get_num_params()
          total_params = total_params + layer_params
          ! Non-positive counts are displayed as 0
          write(param_str, '(I0)') max(layer_params, 0)

          ! Print layer information
          write(*,row_fmt) adjustl(trim(layer_name)), &
               adjustl(trim(output_shape_str)), adjustl(trim(param_str))
       end associate
    end do

    ! Print footer
    write(*,'(A)') repeat('=', 80)
    write(*,'(A,I0)') 'Number of input vertices: ', size(this%root_vertices)
    write(*,'(A,I0)') 'Number of output vertices: ', size(this%leaf_vertices)
    write(*,'(A,I0)') 'Total trainable params: ', total_params
    write(*,'(A)') line
    write(*,*)

  end subroutine print_summary
!###############################################################################


!###############################################################################
  module function inverse_design_real( &
       this, target, x_init, optimiser, steps &
  ) result(x_opt)
    !! Inverse design for plain real-valued inputs.
    !!
    !! Packs the target and the initial input into 1x1 array_type
    !! containers, runs inverse_design_array_2d on them and returns the
    !! val component of the single resulting element.
    implicit none

    ! Arguments
    class(network_type), intent(inout), target :: this
    !! Instance of the network
    real(real32), dimension(:,:), intent(in) :: target
    !! Target output values
    real(real32), dimension(:,:), intent(in) :: x_init
    !! Initial input values
    class(base_optimiser_type), optional, intent(in) :: optimiser
    !! Optimiser for input updates (defaults to network optimiser)
    integer, intent(in) :: steps
    !! Number of optimisation iterations
    real(real32), dimension(size(x_init,1), size(x_init,2)) :: x_opt
    !! Optimised input

    ! Local variables
    type(array_type), pointer :: wrapped_target(:,:)
    !! 1x1 container holding the target
    type(array_type), pointer :: wrapped_init(:,:)
    !! 1x1 container holding the initial input
    type(array_type), pointer :: wrapped_opt(:,:)
    !! 1x1 container receiving the optimised input


    !---------------------------------------------------------------------------
    ! Wrap the real arrays in array_type containers
    !---------------------------------------------------------------------------
    allocate(wrapped_target(1,1))
    allocate(wrapped_init(1,1))
    allocate(wrapped_opt(1,1))
    call wrapped_target(1,1)%allocate(source=target)
    call wrapped_init(1,1)%allocate(source=x_init)

    !---------------------------------------------------------------------------
    ! Run the array_type implementation and unwrap its result
    !---------------------------------------------------------------------------
    wrapped_opt = this%inverse_design_array_2d( &
         wrapped_target, wrapped_init, optimiser, steps &
    )
    x_opt = wrapped_opt(1,1)%val

    !---------------------------------------------------------------------------
    ! Release the contents of the containers, then the containers themselves
    !---------------------------------------------------------------------------
    call wrapped_target(1,1)%deallocate()
    call wrapped_init(1,1)%deallocate()
    call wrapped_opt(1,1)%deallocate()
    deallocate(wrapped_target)
    deallocate(wrapped_init)
    deallocate(wrapped_opt)

  end function inverse_design_real
!###############################################################################


!###############################################################################
  module function inverse_design_array_0d( &
       this, target, x_init, optimiser, steps &
  ) result(x_opt)
    !! Inverse design for a single array_type input and target.
    !!
    !! Places the scalar target and initial input into 1x1 array_type
    !! containers, runs inverse_design_array_2d on them and returns the
    !! single element of the result.
    implicit none

    ! Arguments
    class(network_type), intent(inout), target :: this
    !! Instance of the network
    type(array_type), intent(in) :: target
    !! Target output values
    type(array_type), intent(in) :: x_init
    !! Initial input values
    class(base_optimiser_type), optional, intent(in) :: optimiser
    !! Optimiser for input updates (defaults to network optimiser)
    integer, intent(in) :: steps
    !! Number of optimisation iterations
    type(array_type) :: x_opt
    !! Optimised input

    ! Local variables
    type(array_type), pointer :: wrapped_target(:,:)
    !! 1x1 container holding the target
    type(array_type), pointer :: wrapped_init(:,:)
    !! 1x1 container holding the initial input
    type(array_type), pointer :: wrapped_opt(:,:)
    !! 1x1 container receiving the optimised input


    !---------------------------------------------------------------------------
    ! Wrap the scalar array_type arguments in 1x1 containers
    !---------------------------------------------------------------------------
    allocate(wrapped_target(1,1))
    allocate(wrapped_init(1,1))
    allocate(wrapped_opt(1,1))
    call wrapped_target(1,1)%allocate(source=target)
    call wrapped_init(1,1)%allocate(source=x_init)

    !---------------------------------------------------------------------------
    ! Run the 2D implementation and take its single element
    !---------------------------------------------------------------------------
    wrapped_opt = this%inverse_design_array_2d( &
         wrapped_target, wrapped_init, optimiser, steps &
    )
    x_opt = wrapped_opt(1,1)

    !---------------------------------------------------------------------------
    ! Release the contents of the containers, then the containers themselves
    !---------------------------------------------------------------------------
    call wrapped_target(1,1)%deallocate()
    call wrapped_init(1,1)%deallocate()
    call wrapped_opt(1,1)%deallocate()
    deallocate(wrapped_target)
    deallocate(wrapped_init)
    deallocate(wrapped_opt)

  end function inverse_design_array_0d
!###############################################################################


!###############################################################################
  module function inverse_design_array_2d( &
       this, target, x_init, optimiser, steps &
  ) result(x_opt)
    !! Optimise the input so the network output matches a target.
    !!
    !! Core inverse-design routine. Starting from x_init, each step runs a
    !! forward pass, evaluates the network loss against target, runs the
    !! reverse pass and hands the flattened input and its gradient to an
    !! optimiser, whose updated values are copied back into the input.
    !! Network parameters are saved before the loop and restored after it,
    !! and the training/inference mode is restored on exit.
    !!
    !! If no loss is allocated, "mse" is selected. Only a single sample is
    !! supported. If steps is less than 1 the loop body never runs and the
    !! result is a copy of x_init.
    implicit none

    ! Arguments
    class(network_type), intent(inout), target :: this
    !! Instance of the network
    type(array_type), dimension(:,:), intent(in) :: target
    !! Target output values
    type(array_type), dimension(:,:), intent(in) :: x_init
    !! Initial input values
    class(base_optimiser_type), optional, intent(in) :: optimiser
    !! Optimiser for input updates (defaults to network optimiser)
    integer, intent(in) :: steps
    !! Number of optimisation iterations
    type(array_type), dimension(size(x_init,1), size(x_init,2)) :: x_opt
    !! Optimised input

    ! Local variables
    integer :: step, i, j, k
    !! Loop indices
    integer :: offset
    !! Number of entries of the flat vectors already filled or consumed
    integer :: root_id
    !! Root layer id
    integer :: num_x
    !! Total number of optimised input elements
    integer :: num_samples
    !! Number of samples in the input
    integer :: num_elements
    !! Leading dimension of the array currently being (un)flattened
    integer :: num_components
    !! Number of graph input arrays optimised (1: vertex only, 2: + edge)
    logical :: use_edge_features
    !! Whether edge features are used in the input
    logical :: has_grad
    !! Whether the current root output carries a gradient
    type(array_type), pointer :: loss
    !! Loss pointer
    class(base_optimiser_type), allocatable :: opt
    !! Local optimiser instance
    real(real32), allocatable :: x_flat(:), x_grad(:)
    !! Flat input vector and gradient
    logical, allocatable :: mode_store(:)
    !! Storage for inference mode booleans
    real(real32), allocatable :: saved_params(:)
    !! Saved network parameters


    !---------------------------------------------------------------------------
    ! Ensure the network has a loss function
    !---------------------------------------------------------------------------
    if(.not.allocated(this%loss))then
       call this%set_loss("mse")
    end if


    !---------------------------------------------------------------------------
    ! Get number of input elements
    !---------------------------------------------------------------------------
    num_x = 0
    use_edge_features = .false.
    if(this%use_graph_input)then
       num_samples = size(x_init, dim=2)
       num_x = size(x_init(1,1)%val) ! vertex features
       ! determine if edge features are used by checking the output shape of
       ! the input layer
       ! NOTE(review): this indexes model with root_vertices(1) directly,
       ! whereas the rest of the routine uses the auto_graph vertex id
       ! (root_id) - confirm the two always coincide.
       if(size(this%model(this%root_vertices(1))%layer%output_shape,dim=1).eq.2)then
          use_edge_features = .true.
          num_x = num_x + size(x_init(2,1)%val) ! edge features
       end if
    else
       num_samples = size(x_init(1,1)%val, dim=2)
       do i = 1, size(x_init,1)
          do j = 1, size(x_init,2)
             num_x = num_x + size(x_init(i,j)%val,dim=1)
          end do
       end do
    end if
    num_components = 1
    if(use_edge_features) num_components = 2
    x_opt = x_init
    if(num_samples.gt.1)then
       call stop_program( &
            "inverse_design_array_2d: batch size greater than 1 not supported" &
       )
    end if


    !---------------------------------------------------------------------------
    ! Set up optimiser for input variables
    !---------------------------------------------------------------------------
    if(present(optimiser))then
       allocate(opt, source=optimiser)
    else
       allocate(opt, source=base_optimiser_type( &
            learning_rate=this%optimiser%learning_rate))
    end if
    call opt%init_gradients(num_x)
    opt%iter = 0


    !---------------------------------------------------------------------------
    ! Pre-allocate flat arrays used in the optimisation loop
    !---------------------------------------------------------------------------
    allocate(x_flat(num_x))
    allocate(x_grad(num_x))


    !---------------------------------------------------------------------------
    ! Ensure training mode is active so the full graph is built
    !---------------------------------------------------------------------------
    call this%set_training_mode(mode_store)


    !---------------------------------------------------------------------------
    ! Get root layer id
    !---------------------------------------------------------------------------
    root_id = this%auto_graph%vertex(this%root_vertices(1))%id
    call this%set_batch_size(num_samples)


    !---------------------------------------------------------------------------
    ! Save network parameters so they can be restored afterwards
    !---------------------------------------------------------------------------
    allocate(saved_params(this%num_params))
    saved_params = this%get_params()


    !---------------------------------------------------------------------------
    ! Optimisation loop
    !---------------------------------------------------------------------------
    do step = 1, steps

       ! Forward pass with current x
       call this%forward(x_opt)

       ! Enable gradient tracking on the input layer output
       if(this%use_graph_input)then
          do k = 1, num_components
             call this%model(root_id)%layer%output(k,1)%set_requires_grad(.true.)
          end do
       else
          do i = 1, size(x_opt,1)
             do j = 1, size(x_opt,2)
                call this%model(root_id)%layer%output(i,j)%set_requires_grad(.true.)
             end do
          end do
       end if

       ! Compute loss via the network's loss function
       call this%save_output(target)
       loss => this%loss_eval(1, num_samples)

       ! Backward pass
       call loss%grad_reverse()

       ! Flatten the current input and its gradient.
       ! x_flat is always refreshed from x_opt so the optimiser never sees
       ! undefined values; entries of x_grad stay zero for any root output
       ! that carries no gradient.
       x_grad = 0._real32
       offset = 0
       if(this%use_graph_input)then
          do k = 1, num_components
             ! each component (vertex, edge) has its own feature count
             num_elements = size(x_opt(k,1)%val, dim=1)
             has_grad = associated(this%model(root_id)%layer%output(k,1)%grad)
             do i = 1, size(x_opt(k,1)%val, dim=2)
                x_flat(offset+1:offset+num_elements) = x_opt(k,1)%val(:,i)
                if(has_grad)then
                   x_grad(offset+1:offset+num_elements) = &
                        this%model(root_id)%layer%output(k,1)%grad%val(:,i)
                end if
                offset = offset + num_elements
             end do
          end do
       else
          do i = 1, size(x_opt,1)
             do j = 1, size(x_opt,2)
                num_elements = size(x_opt(i,j)%val, dim=1)
                x_flat(offset+1:offset+num_elements) = x_opt(i,j)%val(:,1)
                if(associated(this%model(root_id)%layer%output(i,j)%grad))then
                   x_grad(offset+1:offset+num_elements) = &
                        this%model(root_id)%layer%output(i,j)%grad%val(:,1)
                end if
                offset = offset + num_elements
             end do
          end do
       end if

       ! Update x using the optimiser (not the model weights)
       opt%iter = opt%iter + 1
       call opt%minimise(x_flat, x_grad)

       ! Convert flat x back to array form (same traversal order as above)
       offset = 0
       if(this%use_graph_input)then
          do k = 1, num_components
             num_elements = size(x_opt(k,1)%val, dim=1)
             do i = 1, size(x_opt(k,1)%val, dim=2)
                x_opt(k,1)%val(:,i) = x_flat(offset+1:offset+num_elements)
                offset = offset + num_elements
             end do
          end do
       else
          do i = 1, size(x_opt,1)
             do j = 1, size(x_opt,2)
                num_elements = size(x_opt(i,j)%val, dim=1)
                x_opt(i,j)%val(:,1) = x_flat(offset+1:offset+num_elements)
                offset = offset + num_elements
             end do
          end do
       end if

       ! Clean up computation graph
       call loss%nullify_graph()
       deallocate(loss)
       nullify(loss)

       ! Reset network parameter gradients so they remain unchanged
       call this%reset_gradients()

    end do


    !---------------------------------------------------------------------------
    ! Restore training/inference mode
    !---------------------------------------------------------------------------
    call this%restore_mode(mode_store)


    !---------------------------------------------------------------------------
    ! Restore network parameters to ensure model is unchanged
    !---------------------------------------------------------------------------
    call this%set_params(saved_params)

  end function inverse_design_array_2d
!###############################################################################

end submodule athena__network_submodule
