submodule(athena__network) athena__network_submodule
  !! Submodule containing implementations for the network module
#ifdef _OPENMP
  use omp_lib
#endif
  use coreutils, only: stop_program, print_warning, to_lower
  use athena__misc_ml, only: shuffle
  use athena__accuracy, only: categorical_score, mae_score, mse_score, r2_score
  use athena__base_layer, only: learnable_layer_type, merge_layer_type
#if defined(GFORTRAN)
  use athena__container_layer, only: container_reduction
#endif
  use athena__container_layer, only: &
       list_of_layer_types, allocate_list_of_layer_types, &
       list_of_onnx_layer_creators, allocate_list_of_onnx_layer_creators

  ! Layer types
  use athena__flatten_layer, only: flatten_layer_type
  use athena__add_layer, only: add_layer_type
  use athena__concat_layer, only: concat_layer_type
  use athena__input_layer, only: input_layer_type
  use athena__msgpass_layer, only: msgpass_layer_type
  use athena__recurrent_layer, only: recurrent_layer_type

  ! #ifdef _OPENMP
  ! !$omp declare reduction( &
  ! !$omp& network_reduction : network_type:omp_out%network_reduction(omp_in)) &
  ! !$omp& initialiser(omp_priv = omp_orig)
  ! #endif

contains

!###############################################################################
  module subroutine network_reduction(this, source)
    !! Procedure to add two networks together
    !!
    !! Adds the values of the first two metrics of source onto this and, for
    !! every layer position at which both networks hold a learnable layer,
    !! calls that layer's reduce procedure with the matching layer of source.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    type(network_type), intent(in) :: source
    !! Instance of network to be added to this

    ! Local variables
    integer :: i
    !! Loop index


    this%metrics(1)%val = this%metrics(1)%val + source%metrics(1)%val
    this%metrics(2)%val = this%metrics(2)%val + source%metrics(2)%val

    ! Nothing to reduce layer-wise if either network holds no layers
    if(.not.allocated(this%model) .or. .not.allocated(source%model)) return

    ! NOTE(review): assumes both networks share the same architecture
    ! (i.e. layer i of this matches layer i of source); only the common
    ! length is visited so a mismatch cannot index out of bounds
    do i = 1, min(size(this%model), size(source%model))
       select type(layer_this => this%model(i)%layer)
       class is(learnable_layer_type)
          select type(layer_source => source%model(i)%layer)
          class is(learnable_layer_type)
             call layer_this%reduce(layer_source)
          end select
       end select
    end do

  end subroutine network_reduction
!###############################################################################


!###############################################################################
  module subroutine network_copy(this, source)
    !! Procedure to copy a network
    !!
    !! Copies metrics, model, sizes, optimiser, graph ordering arrays, loss,
    !! accuracy procedure and graph from source into this. Allocatable
    !! components that are unallocated in source are left unallocated in
    !! this (referencing an unallocated allocatable in an assignment is not
    !! permitted, so each one is guarded).
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    type(network_type), intent(in), target :: source
    !! Instance of network to be copied


    this%metrics = source%metrics

    if(allocated(source%model))then
       this%model = source%model
    elseif(allocated(this%model))then
       deallocate(this%model)
    end if

    this%num_layers = source%num_layers
    this%batch_size = source%batch_size
    this%num_params = source%num_params
    this%num_outputs = source%num_outputs

    if(allocated(source%optimiser))then
       this%optimiser = source%optimiser
    elseif(allocated(this%optimiser))then
       deallocate(this%optimiser)
    end if

    if(allocated(source%vertex_order))then
       this%vertex_order = source%vertex_order
    elseif(allocated(this%vertex_order))then
       deallocate(this%vertex_order)
    end if

    if(allocated(source%root_vertices))then
       this%root_vertices = source%root_vertices
    elseif(allocated(this%root_vertices))then
       deallocate(this%root_vertices)
    end if

    if(allocated(source%leaf_vertices))then
       this%leaf_vertices = source%leaf_vertices
    elseif(allocated(this%leaf_vertices))then
       deallocate(this%leaf_vertices)
    end if

    if(allocated(source%loss))then
       this%loss = source%loss
    elseif(allocated(this%loss))then
       deallocate(this%loss)
    end if

    this%get_accuracy => source%get_accuracy
    this%auto_graph = source%auto_graph

  end subroutine network_copy
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  module subroutine build_vertex_order(this)
    !! Generate the order of the layers in the network
    !!
    !! The order is generated by depth first search (DFS) on the graph of
    !! the network, starting from the highest vertex index. dfs visits the
    !! vertices that feed into a vertex before recording the vertex itself,
    !! so (for an acyclic graph) each vertex is placed after all its inputs.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network

    ! Local variables
    integer :: i, order_index
    !! Loop index and number of vertices placed so far
    logical, dimension(this%auto_graph%num_vertices) :: visited
    !! Array to store whether a vertex has been visited
    integer, dimension(:), allocatable :: order
    !! Order being built; kept local so that no component of this is
    !! passed to dfs alongside this itself


    visited = .false.
    allocate(order(this%auto_graph%num_vertices), source=0)

    order_index = 0
    do i = this%auto_graph%num_vertices, 1, -1
       if(.not.visited(i)) call this%dfs(i, visited, order, order_index)
    end do

    if(allocated(this%vertex_order)) deallocate(this%vertex_order)
    call move_alloc(order, this%vertex_order)

  end subroutine build_vertex_order
!###############################################################################


!###############################################################################
  recursive module subroutine dfs( &
       this, vertex_index, visited, order, order_index &
  )
    !! Depth first search algorithm
    !!
    !! Recurses into every unvisited vertex i with adjacency(i,vertex_index)
    !! non-zero (i.e. every vertex feeding into vertex_index), then appends
    !! vertex_index to order (post-order).
    implicit none

    ! Arguments
    class(network_type), intent(in) :: this
    !! Instance of network
    integer, intent(in) :: vertex_index
    !! Index of the vertex to start the search from
    logical, dimension(this%auto_graph%num_vertices), intent(inout) :: visited
    !! Array to store whether a vertex has been visited
    integer, dimension(this%auto_graph%num_vertices), intent(inout) :: order
    !! Array to store the order of the vertices
    integer, intent(inout) :: order_index
    !! Index of the current vertex in the order array

    ! Local variables
    integer :: i
    !! Loop index


    visited(vertex_index) = .true.
    do i = 1, this%auto_graph%num_vertices, 1
       if(this%auto_graph%adjacency(i,vertex_index).ne.0)then
          if(.not.visited(i)) call this%dfs(i, visited, order, order_index)
       end if
    end do

    order_index = order_index + 1
    order(order_index) = vertex_index

  end subroutine dfs
!###############################################################################


!###############################################################################
  module subroutine build_root_vertices(this)
    !! Calculate the root vertices of the network
    !!
    !! A root vertex is one with an all-zero adjacency column, i.e. no other
    !! vertex feeds into it.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network

    ! Local variables
    integer :: i
    !! Loop index


    if(allocated(this%root_vertices)) deallocate(this%root_vertices)
    allocate(this%root_vertices(0))
    do i = 1, this%auto_graph%num_vertices
       if(all(this%auto_graph%adjacency(:,i).eq.0))then
          this%root_vertices = [this%root_vertices, i]
       end if
    end do

  end subroutine build_root_vertices
!###############################################################################


!###############################################################################
  module subroutine build_leaf_vertices(this)
    !! Calculate the output vertices of the network
    !!
    !! A leaf vertex is one with an all-zero adjacency row, i.e. it feeds
    !! into no other vertex.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network

    ! Local variables
    integer :: i
    !! Loop index


    if(allocated(this%leaf_vertices)) deallocate(this%leaf_vertices)
    allocate(this%leaf_vertices(0))
    do i = 1, this%auto_graph%num_vertices
       if(all(this%auto_graph%adjacency(i,:).eq.0))then
          this%leaf_vertices = [this%leaf_vertices, i]
       end if
    end do

  end subroutine build_leaf_vertices
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  module subroutine print(this, file)
    !! Print the network to a file
    !!
    !! Writes a NETWORK_SETTINGS card followed by one card per layer, in the
    !! order given by this%vertex_order. Each layer header line has the form
    !!   <LAYER_NAME> <operator> [ <inputs> ]
    !! where <inputs> are positions in this%vertex_order (i.e. the order in
    !! which the layers are written, and hence re-added by read).
    use coreutils, only: to_upper
    use athena__io_utils, only: athena__version__
    implicit none

    ! Arguments
    class(network_type), intent(in) :: this
    !! Instance of network
    character(*), intent(in) :: file
    !! File to print the network to

    ! Local variables
    integer :: l, v, e, vertex_index, unit, stat
    !! Loop indices, file unit and I/O status
    integer :: operator_in, operator_out
    !! Operators on the incoming / outgoing edges of the layer
    character(3) :: operator_str
    !! String to store the operator
    character(256) :: suffix, fmt, err_msg
    !! Suffix for the layer, format string and error message
    integer, dimension(:), allocatable :: input_list, output_list
    !! Positions (in vertex_order) of the input and output layers


    open(newunit=unit, file=file, status='replace', action='write', &
         iostat=stat, iomsg=err_msg)
    if(stat.ne.0)then
       call stop_program("failed to open network file: "//trim(err_msg))
       return
    end if

    write(unit,'("NETWORK_SETTINGS")')
    write(unit,'(3X,"ATHENA_VERSION = ",A)') trim(adjustl(athena__version__))
    if(allocated(this%name)) &
         write(unit,'(3X,"NAME = ",A)') trim(adjustl(this%name))
    write(unit,'(3X,"EPOCH = ",I0)') this%epoch
    write(unit,'(3X,"BATCH_SIZE = ",I0)') this%batch_size
    write(unit,'(3X,"ACCURACY = ",F0.9)') this%accuracy_val
    write(unit,'(3X,"LOSS = ",F0.9)') this%loss_val
    if(allocated(this%accuracy_method))then
       write(unit,'(3X,"ACCURACY_METHOD = ",A)') &
            trim(adjustl(this%accuracy_method))
    end if
    if(allocated(this%loss_method))then
       write(unit,'(3X,"LOSS_METHOD = ",A)') trim(adjustl(this%loss_method))
    end if
    if(allocated(this%optimiser))then
       write(unit,'(3X,"OPTIMISER: ",A)') trim(adjustl(this%optimiser%name))
       call this%optimiser%print_to_unit(unit=unit)
       write(unit,'(3X,"END OPTIMISER")')
    end if
    write(unit,'("END NETWORK_SETTINGS")')

    do v = 1, size(this%vertex_order,dim=1), 1
       l = this%vertex_order(v)
       operator_in = -1
       operator_out = -1
       allocate(input_list(0), output_list(0))
       do e = 1, this%auto_graph%num_edges
          associate(this_edge => this%auto_graph%edge(e))
            ! incoming edge: destination (stored negated) is this layer
            if(-this_edge%index(2).eq.l)then
               if(operator_in.gt.0.and.this_edge%id.ne.operator_in)then
                  write(0,*) "WARNING: multiple operators for layer ", l
                  write(0,*) " using operator ", this_edge%id
               end if
               operator_in = this_edge%id
               vertex_index = findloc( this%vertex_order, this_edge%index(1), 1 )
               input_list = [ input_list, vertex_index ]
            end if
            ! outgoing edge: source is this layer; the edge id is the operator
            ! of the downstream layer, so differing ids here are legitimate
            ! and must not overwrite operator_in
            if(this_edge%index(1).eq.l)then
               operator_out = this_edge%id
               vertex_index = findloc( this%vertex_order, -this_edge%index(2), 1 )
               output_list = [ output_list, vertex_index ]
            end if
          end associate
       end do

       ! Default to the concatenate operator (add's default) so that layers
       ! without incoming edges still get a token read can parse
       operator_str = " ||"
       select case(operator_in)
       case(1)
          operator_str = " ||"
       case(2)
          operator_str = " +"
       case(3)
          operator_str = " *"
       end select

       ! get size of input_list and make the formatted string
       if(size(input_list).eq.0)then
          write(suffix,'(A," []")') trim(operator_str)
       else
          write(fmt,'("(A,"" ["",",I0,"(1X,I0),"" ]"")")') size(input_list)
          write(suffix,fmt) operator_str, input_list
       end if
       ! NOTE(review): output_list / operator_out are gathered but not
       ! written; read parses an optional output list if one is present

       write(unit,'(A,A)') to_upper(trim(this%model(l)%layer%name)), &
            trim(suffix)
       call this%model(l)%layer%print(unit=unit, print_header_footer=.false.)
       write(unit,'("END ",A)') to_upper(trim(this%model(l)%layer%name))
       deallocate(input_list, output_list)
    end do
    close(unit)

  end subroutine print
!###############################################################################


!###############################################################################
  module subroutine read(this, file)
    !! Read the network from a file
    !!
    !! Expects the layout produced by print: an optional NETWORK_SETTINGS
    !! card followed by one card per layer whose header line is
    !!   <LAYER_NAME> <operator> [ <inputs> ] [<operator> [ <outputs> ]]
    !! Each layer card body is read by the matching layer type's read_ptr and
    !! the layer is added with the parsed inputs and operator. The output
    !! list is parsed but not used.
    use coreutils, only: icount
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    character(*), intent(in) :: file
    !! File to read the network from

    ! Local variables
    integer :: unit, stat, itmp1
    !! File unit, I/O status and temporary integer
    integer, dimension(:), allocatable :: input_list, output_list
    !! List of input and output layers
    character(256) :: buffer, err_msg, input_str, output_str
    !! Buffer for reading lines from file
    character(20) :: name
    !! Name of the layer
    character(2) :: operator_in, operator_out
    !! Operator for the layer
    integer :: layer_index
    !! Index of the layer in the list of layer types


    if(.not.allocated(list_of_layer_types))then
       call allocate_list_of_layer_types()
    end if

    open(newunit=unit, file=file, action='read', status='old', &
         iostat=stat, iomsg=err_msg)
    if(stat.ne.0)then
       call stop_program("failed to open network file: "//trim(err_msg))
       return
    end if

    card_loop: do
       read(unit,'(A)',iostat=stat) buffer
       if(stat.lt.0)then
          exit card_loop
       elseif(stat.gt.0)then
          close(unit)
          call stop_program("error encountered in network read")
          return
       end if
       if(trim(adjustl(buffer)).eq."") cycle card_loop

       ! check if a tag line
       if(scan(buffer,'=').ne.0)then
          write(0,*) "WARNING: unexpected line in read file"
          write(0,*) trim(buffer)
          write(0,*) " skipping..."
          cycle card_loop
       end if

       ! check for card
       ! ... left-justify first so a leading blank does not give an empty name
       buffer = adjustl(buffer)
       name = trim(adjustl(buffer(1:scan(buffer,' ')-1)))
       if(name.eq."NETWORK_SETTINGS")then
          call this%read_network_settings(unit)
          cycle card_loop
       end if

       ! input operator and input list
       buffer = trim(adjustl(buffer(scan(buffer,' ')+1:)))
       operator_in = trim(adjustl(buffer(1:scan(buffer,' ')-1)))
       buffer = trim(adjustl(buffer(scan(buffer,' ')+1:)))
       input_str = trim(adjustl(buffer(1:scan(buffer,']'))))
       if(scan(input_str,'[').ne.0)then
          input_str = trim(adjustl( &
               input_str(scan(input_str,'[')+1:scan(input_str,']')-1) &
          ))
          itmp1 = icount(input_str)
          allocate(input_list(itmp1))
          read(input_str,*) input_list
       else
          allocate(input_list, source = [-1])
       end if

       ! (optional) output operator and output list
       buffer = trim(adjustl(buffer(scan(buffer,']')+1:)))
       operator_out = trim(adjustl(buffer(1:scan(buffer,' ')-1)))
       buffer = trim(adjustl(buffer(scan(buffer,' ')+1:)))
       output_str = trim(adjustl(buffer(1:scan(buffer,']'))))
       if(scan(output_str,'[').ne.0)then
          output_str = trim(adjustl( &
               output_str(scan(output_str,'[')+1:scan(output_str,']')-1) &
          ))
          itmp1 = icount(output_str)
          allocate(output_list(itmp1))
          read(output_str,*) output_list
       else
          allocate(output_list(0))
       end if

       name = trim(adjustl(to_lower(name)))
       layer_index = &
            findloc( &
                 [ list_of_layer_types(:)%name ], &
                 name, &
                 dim = 1 &
            )
       if(layer_index.eq.0)then
          write(err_msg,'("unrecognised card ''",A,"''")') trim(name)
          close(unit)
          call stop_program(err_msg)
          return
       end if
       call this%add( &
            list_of_layer_types(layer_index)%read_ptr(unit), &
            input_list = input_list, &
            operator = operator_in &
       )
       if(allocated(input_list)) deallocate(input_list)
       if(allocated(output_list)) deallocate(output_list)
    end do card_loop
    close(unit)

  end subroutine read
!###############################################################################


!###############################################################################
  module subroutine read_network_settings(this, unit)
    !! Read the network settings from a file
    !!
    !! Consumes tag lines of the NETWORK_SETTINGS card from unit up to and
    !! including "END NETWORK_SETTINGS", assigning the recognised tags.
    use athena__tools_infile, only: assign_val, assign_vec
    use coreutils, only: to_lower, to_upper, icount
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    integer, intent(in) :: unit
    !! File unit

    ! Local variables
    integer :: stat
    !! File status
    integer :: itmp1
    !! Temporary integer
    character(20) :: accuracy_method, loss_method
    !! Methods for accuracy and loss
    character(256) :: buffer, tag, err_msg, name_
    !! Buffer for reading lines, tag for identifying lines, error message


    ! Loop over tags in layer card
    !---------------------------------------------------------------------------
    accuracy_method = ""
    loss_method = ""
    tag_loop: do

       ! Check for end of file
       !------------------------------------------------------------------------
       read(unit,'(A)',iostat=stat) buffer
       if(stat.ne.0)then
          ! this%name may not be allocated yet, so name the card instead
          write(err_msg,'("file encountered error (EoF?) before END ",A)') &
               "NETWORK_SETTINGS"
          call stop_program(err_msg)
          return
       end if
       if(trim(adjustl(buffer)).eq."") cycle tag_loop

       ! Check for end of layer card
       !------------------------------------------------------------------------
       if(trim(adjustl(buffer)).eq."END NETWORK_SETTINGS")then
          backspace(unit)
          exit tag_loop
       end if

       tag=trim(adjustl(buffer))
       if(scan(buffer,"=").ne.0) tag=trim(tag(:scan(tag,"=")-1))
       if(scan(buffer,":").ne.0) tag=trim(tag(:scan(tag,":")-1))

       ! Read parameters from save file
       !------------------------------------------------------------------------
       select case(trim(tag))
       case("ATHENA_VERSION")
          ! Ignore this tag, it is only for information
       case("NAME")
          call assign_val(buffer, name_, itmp1)
          if(len(trim(adjustl(name_))) .gt. 0)then
             this%name = trim(adjustl(name_))
          end if
       case("EPOCH")
          call assign_val(buffer, this%epoch, itmp1)
       case("BATCH_SIZE")
          call assign_val(buffer, this%batch_size, itmp1)
       case("ACCURACY")
          call assign_val(buffer, this%accuracy_val, itmp1)
       case("LOSS")
          call assign_val(buffer, this%loss_val, itmp1)
       case("ACCURACY_METHOD")
          call assign_val(buffer, accuracy_method, itmp1)
          call this%set_accuracy(accuracy_method)
       case("LOSS_METHOD")
          call assign_val(buffer, loss_method, itmp1)
          call this%set_loss(loss_method)
       case("OPTIMISER")
          ! re-read the "OPTIMISER: <name>" line inside the optimiser reader
          backspace(unit)
          call this%read_optimiser_settings(unit)
       case default
          ! Don't look for "e" due to scientific notation of numbers
          ! ... i.e. exponent (E+00)
          if(scan(to_lower(trim(adjustl(buffer))),&
               'abcdfghijklmnopqrstuvwxyz').eq.0)then
             cycle tag_loop
          elseif(tag(:3).eq.'END')then
             cycle tag_loop
          end if
          write(err_msg,'("Unrecognised line in input file: ",A)') &
               trim(adjustl(buffer))
          call stop_program(err_msg)
          return
       end select
    end do tag_loop

    ! Check for end of layer card
    !---------------------------------------------------------------------------
    read(unit,'(A)') buffer
    if(trim(adjustl(buffer)).ne."END NETWORK_SETTINGS")then
       write(0,*) trim(adjustl(buffer))
       write(err_msg,'("END NETWORK_SETTINGS not where expected")')
       call stop_program(err_msg)
       return
    end if

  end subroutine read_network_settings
!-------------------------------------------------------------------------------
  module subroutine read_optimiser_settings(this, unit)
    !! Read the optimiser settings from a file
    !!
    !! Reads the "OPTIMISER: <name>" line at the current position of unit,
    !! allocates this%optimiser with the matching optimiser type, and passes
    !! the unit on to that optimiser's read procedure.
    use coreutils, only: to_lower, to_upper, icount
    use athena__optimiser, only: &
         sgd_optimiser_type, adam_optimiser_type, rmsprop_optimiser_type, &
         adagrad_optimiser_type, base_optimiser_type
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    integer, intent(in) :: unit
    !! File unit

    ! Local variables
    integer :: stat
    !! File status
    character(20) :: optimiser_name
    !! Name of the optimiser
    character(256) :: buffer, err_msg, tmp
    !! Buffer for reading lines, error message


    ! Read until end of optimiser settings
    read(unit,'(A)',iostat=stat) buffer
    if(stat.ne.0)then
       ! this%name may not be allocated yet, so name the card instead
       write(err_msg,'("file encountered error (EoF?) before END ",A)') &
            "OPTIMISER"
       call stop_program(err_msg)
       return
    end if

    ! If no name follows the "OPTIMISER:" token the list-directed read runs
    ! out of data; treat that as an empty name (handled below as "base")
    read(buffer,*,iostat=stat) tmp, optimiser_name
    if(stat.ne.0) optimiser_name = ""

    select case(trim(adjustl(to_lower(optimiser_name))))
    case("sgd")
       this%optimiser = sgd_optimiser_type()
    case("adam")
       this%optimiser = adam_optimiser_type()
    case("rmsprop")
       this%optimiser = rmsprop_optimiser_type()
    case("adagrad")
       this%optimiser = adagrad_optimiser_type()
    case("","base")
       this%optimiser = base_optimiser_type()
    case default
       write(err_msg,'("Unrecognised optimiser: ",A)') &
            trim(adjustl(optimiser_name))
       call stop_program(err_msg)
       return
    end select
    call this%optimiser%read(unit)

  end subroutine read_optimiser_settings
!###############################################################################


!###############################################################################
  module subroutine build_from_onnx( &
       this, nodes, initialisers, inputs, value_info, verbose &
  )
    !! Build network from ONNX nodes and initialisers
    !!
    !! Adds one input layer per entry of inputs (dropping the first, batch,
    !! dimension of its dims), then one layer per node using the creator whose
    !! op_type matches. A node's inputs are matched by name against the
    !! initialisers, the graph inputs (vertex ids 1..size(inputs)) and other
    !! nodes (a trailing "_output..." is stripped before matching; node k has
    !! vertex id k + size(inputs)).
    use coreutils, only: to_lower
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    type(onnx_node_type), dimension(:), intent(in) :: nodes
    !! Array of ONNX nodes
    type(onnx_initialiser_type), dimension(:), intent(in) :: initialisers
    !! Array of ONNX initialisers
    type(onnx_tensor_type), dimension(:), intent(in) :: inputs
    !! Array of ONNX inputs
    type(onnx_tensor_type), dimension(:), intent(in) :: value_info
    !! Array of ONNX value infos
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: i, j, k, layer_index
    !! Loop indices
    integer :: verbose_
    !! Verbosity level (assigned below, not at declaration, to avoid an
    !! implicit SAVE attribute)
    character(20) :: op_type
    !! op_type of the current node
    !! NOTE(review): no to_lower is applied, so matching against the layer
    !! creators is case-sensitive
    character(64) :: tmp_name
    !! Temporary name for matching
    character(256) :: err_msg
    !! Error message
    integer, dimension(:), allocatable :: input_shape
    !! Shape of input layer
    integer, dimension(:), allocatable :: input_list
    !! List of input layers
    type(onnx_initialiser_type), dimension(:), allocatable :: init_list
    !! List of initialisers for a specific node
    type(onnx_tensor_type), dimension(:), allocatable :: value_info_list
    !! List of value info tensors


    verbose_ = 0
    if(present(verbose)) verbose_ = verbose

    if(.not.allocated(list_of_onnx_layer_creators))then
       call allocate_list_of_onnx_layer_creators()
    end if

    do i = 1, size(inputs)
       input_shape = inputs(i)%dims(2:size(inputs(i)%dims))
       call this%add( &
            input_layer_type(input_shape, index=i) &
       )
    end do

    ! Loop through nodes and create layers
    do i = 1, size(nodes)
       if(verbose_.gt.0) write(*,*) "Processing ONNX node: ", &
            trim(nodes(i)%name), " (", trim(nodes(i)%op_type), ")"
       op_type = trim(adjustl(nodes(i)%op_type))
       layer_index = &
            findloc( &
                 [ list_of_onnx_layer_creators(:)%op_type ], &
                 op_type, &
                 dim = 1 &
            )
       if(layer_index.eq.0)then
          write(err_msg,'("unrecognised op_type ''",A,"''")') &
               trim(adjustl(nodes(i)%op_type))
          call stop_program(err_msg)
          return
       end if

       ! find all input layers and initialisers for this node
       ! ... i.e. check over inputs for name matches
       allocate(init_list(0))
       allocate(input_list(0))
       allocate(value_info_list(0))
       do j = 1, size(nodes(i)%inputs)
          do k = 1, size(initialisers)
             if(trim(nodes(i)%inputs(j)) .eq. trim(initialisers(k)%name))then
                init_list = [ init_list, initialisers(k) ]
             end if
          end do
          do k = 1, size(inputs)
             if(trim(nodes(i)%inputs(j)) .eq. trim(inputs(k)%name))then
                input_list = [ input_list, k ]
             end if
          end do
          tmp_name = trim(nodes(i)%inputs(j))
          if(index(tmp_name, "_output").gt.0) &
               tmp_name = trim(tmp_name(:index(tmp_name, "_output")-1))
          do k = 1, size(nodes)
             if(trim(tmp_name) .eq. trim(nodes(k)%name))then
                input_list = [ input_list, k + size(inputs) ]
             end if
          end do
       end do
       do j = 1, size(nodes(i)%outputs)
          do k = 1, size(value_info)
             if(trim(nodes(i)%outputs(j)) .eq. trim(value_info(k)%name))then
                value_info_list = [ value_info_list, value_info(k) ]
             end if
          end do
       end do
       if(size(init_list)+size(input_list).ne.size(nodes(i)%inputs))then
          if(verbose_.gt.0)then
             write(0,*) "WARNING: not all inputs found for node ", &
                  trim(nodes(i)%name)
          end if
       end if

       ! assume default operator
       call this%add( &
            list_of_onnx_layer_creators(layer_index)%create_ptr( &
                 nodes(i), init_list, value_info_list &
            ), &
            input_list = input_list &
       )
       deallocate(input_list)
       deallocate(init_list)
       deallocate(value_info_list)
    end do

    if(verbose_.gt.0) write(*,*) "ONNX model built with ", &
         this%num_layers, " layers."

  end subroutine build_from_onnx
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  module subroutine add(this, layer, input_list, output_list, operator)
    !! Add a layer to the network
    !!
    !! Appends a copy of layer to this%model, adds a graph vertex for it and
    !! connects it: to each entry of input_list (negative entries are
    !! relative to the new vertex, positive ones are vertex ids), or, if no
    !! input_list is given and the layer is not an input layer, to the
    !! previously added vertex; and from it to each id in output_list.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(base_layer_type), intent(in) :: layer
    !! Layer to add to the network
    integer, dimension(:), optional, intent(in) :: input_list
    !! List of input layers
    integer, dimension(:), optional, intent(in) :: output_list
    !! List of output layers
    class(*), optional, intent(in) :: operator
    !! Operator to use to connect the layers

    ! Local variables
    integer :: i, vertex_index
    !! Loop index
    integer :: operator_
    !! Operator to use to connect the layers
    character(256) :: err_msg
    !! Error message
    integer, dimension(2) :: vertex_indices
    !! Indices of the vertices to connect
    type(container_layer_type), allocatable, dimension(:) :: model
    !! Model to add the layer to


    if(.not.allocated(this%model))then
       this%model = [container_layer_type()]
       this%num_layers = 1
    else
       allocate(model(size(this%model,dim=1)+1))
       do i = 1, size(this%model,dim=1)
          allocate(model(i)%layer, source=this%model(i)%layer)
       end do
       call move_alloc(model, this%model)
       this%num_layers = this%num_layers + 1
    end if
    allocate(this%model(size(this%model,dim=1))%layer, source=layer)
    this%model(size(this%model,dim=1))%layer%id = this%num_layers

    ! operator codes: 1 = concatenate (default), 2 = add, 3 = multiply;
    ! unrecognised strings keep the default
    operator_ = 1
    if(present(operator))then
       select type(operator)
       type is(integer)
          operator_ = operator
       type is(character(*))
          select case(trim(to_lower(operator)))
          case("||", "concat", "concatenate", "append")
             operator_ = 1
          case("+", "add")
             operator_ = 2
          case("*", "x", "mul", "multiply")
             operator_ = 3
          end select
       end select
    end if
    ! NOTE(review): code 3 (multiply) is mapped above but rejected here;
    ! presumably multiply is not yet supported - confirm
    if(operator_.gt.2.or.operator_.lt.1)then
       call stop_program("invalid operator")
       return
    end if

    ! edge_index(1) = index of the previous layer
    ! abs(edge_index(2)) = index of the current layer
    ! the -ve sign of edge_index(2) indicates that the edge goes from the
    ! previous layer to the current layer
    ! i.e. forward pass flows from positive to negative
    ! adjacency(i,:) is all of the layers that i feeds forward to
    ! adjacency(:,i) is all of the layers that feed forward to i
    ! (i.e. the backward pass)
    this%auto_graph%directed = .true.
    call this%auto_graph%add_vertex( &
         feature=[1._real32], id=this%num_layers, update_adjacency=.true. &
    )
    if(present(input_list))then
       do i = 1, size(input_list), 1
          if(input_list(i).eq.0)then
             vertex_index = 0
          elseif( &
               input_list(i).le.-this%auto_graph%num_vertices .or. &
               input_list(i).gt.this%auto_graph%num_vertices &
          )then
             write(err_msg, &
                  '("input vertex index ",I0," out of range (",I0,":",I0,")")' &
             ) &
                  input_list(i), &
                  -this%auto_graph%num_vertices +1, &
                  this%auto_graph%num_vertices
             call stop_program(err_msg)
             return
          elseif(input_list(i).lt.0)then
             ! relative index: -1 is the vertex added just before this one
             vertex_index = this%auto_graph%num_vertices + input_list(i)
          else
             vertex_index = findloc( &
                  [this%auto_graph%vertex(:)%id], &
                  input_list(i), 1 &
             )
          end if
          vertex_indices = [ vertex_index, -this%auto_graph%num_vertices ]
          call this%auto_graph%add_edge( &
               index = vertex_indices, &
               feature = [ 1._real32 ], &
               id = operator_, &
               update_adjacency = .true. &
          )
       end do
    elseif(trim(layer%type).ne."inpt".and.this%auto_graph%num_vertices.gt.1)then
       vertex_indices = [ &
            this%auto_graph%num_vertices - 1, &
            -this%auto_graph%num_vertices &
       ]
       call this%auto_graph%add_edge( &
            index = vertex_indices, &
            feature = [ 1._real32 ], &
            id = operator_, &
            update_adjacency = .true. &
       )
    end if

    if(present(output_list))then
       do i = 1, size(output_list), 1
          vertex_index = findloc( &
               [this%auto_graph%vertex(:)%id], &
               output_list(i), 1 &
          )
          vertex_indices = [ this%auto_graph%num_vertices, -vertex_index ]
          call this%auto_graph%add_edge( &
               index = vertex_indices, &
               feature = [ 1._real32 ], &
               id = operator_, &
               update_adjacency = .true. &
          )
       end do
    end if

  end subroutine add
!###############################################################################


!###############################################################################
  module function network_setup( &
       layers, optimiser, loss_method, accuracy_method, &
       metrics, batch_size &
  ) result(network)
    !! Setup the network
    !!
    !! Applies the optional settings, adds the layers in order (each
    !! connected to the previous one by add's default rule) and compiles the
    !! network if an optimiser is given.
    implicit none

    ! Arguments
    type(container_layer_type), dimension(:), intent(in) :: layers
    !! Layers to add to the network
    class(base_optimiser_type), optional, intent(in) :: optimiser
    !! Optimiser to use for training
    class(*), optional, intent(in) :: loss_method
    !! Loss method
    character(*), optional, intent(in) :: accuracy_method
    !! Accuracy method
    class(*), dimension(..), optional, intent(in) :: metrics
    !! Metrics
    integer, optional, intent(in) :: batch_size
    !! Batch size

    type(network_type) :: network
    !! Network to setup

    ! Local variables
    integer :: l
    !! Loop index


    !---------------------------------------------------------------------------
    ! Handle optional arguments
    !---------------------------------------------------------------------------
    if(present(loss_method)) call network%set_loss(loss_method)
    if(present(accuracy_method)) call network%set_accuracy(accuracy_method)
    if(present(metrics)) call network%set_metrics(metrics)
    if(present(batch_size)) network%batch_size = batch_size
    network%auto_graph%directed = .true.


    !---------------------------------------------------------------------------
    ! Add layers to network
    !---------------------------------------------------------------------------
    do l = 1, size(layers)
       call network%add(layers(l)%layer)
    end do


    !---------------------------------------------------------------------------
    ! Compile network if optimiser present
    !---------------------------------------------------------------------------
    if(present(optimiser)) call network%compile(optimiser)

  end function network_setup
!###############################################################################


!###############################################################################
  module subroutine set_metrics(this, metrics)
    !! Set the metrics for the network
    !!
    !! Resets both metrics ("loss", "accuracy") to inactive with threshold
    !! 0.1, then either activates the metrics whose keys match the given
    !! names (case-insensitive) or, for a length-2 metric_dict_type array,
    !! copies it over.
    use coreutils, only: to_lower
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(*), dimension(..), intent(in) :: metrics
    !! Metrics

    ! Local variables
    integer :: i
    !! Loop index


    this%metrics%active = .false.
    this%metrics(1)%key = "loss"
    this%metrics(2)%key = "accuracy"
    this%metrics%threshold = 1.E-1_real32
    select rank(metrics)
#if defined(GFORTRAN)
    rank(0)
       select type(metrics)
       type is(character(*))
          ! ERROR: ifort cannot identify that the rank of metrics has been ...
          ! ... identified as scalar here
          where(to_lower(trim(metrics)).eq.this%metrics%key)
             this%metrics%active = .true.
          end where
       end select
#endif
    rank(1)
       select type(metrics)
       type is(character(*))
          do i=1,size(metrics,1)
             where(to_lower(trim(metrics(i))).eq.this%metrics%key)
                this%metrics%active = .true.
             end where
          end do
       type is(metric_dict_type)
          if(size(metrics,1).eq.2)then
             this%metrics(:2) = metrics(:2)
          else
             call stop_program("invalid length array for metric_dict_type")
             return
          end if
       end select
    rank default
       call stop_program("provided metrics rank in compile invalid")
       return
    end select

  end subroutine set_metrics
!###############################################################################


!###############################################################################
  module subroutine set_loss(this, loss_method, verbose)
    !! Set the loss method for the network
    !!
    !! loss_method is either an instance of a loss type (copied into
    !! this%loss) or a character name/alias, matched case-insensitively with
    !! surrounding blanks ignored. The resolved short name is stored in
    !! this%loss_method.
    use coreutils, only: to_lower
    use athena__loss, only: &
         bce_loss_type, &
         cce_loss_type, &
         mae_loss_type, &
         mse_loss_type, &
         nll_loss_type, &
         huber_loss_type
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(*), intent(in) :: loss_method
    !! Loss method
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: verbose_
    !! Verbosity level
    character(len=:), allocatable :: loss_method_
    !! Loss method
    character(256) :: err_msg
    !! Error message


    verbose_ = 0
    if(present(verbose)) verbose_ = verbose


    !---------------------------------------------------------------------------
    ! Set loss method
    !---------------------------------------------------------------------------
    select type(loss_method)
    class is(base_loss_type)
       this%loss = loss_method
       if(verbose_.gt.0) write(*,*) "Loss method: ", trim(loss_method%name)
       loss_method_ = trim(loss_method%name)
    type is(character(*))
       !------------------------------------------------------------------------
       ! Handle analogous definitions
       ! ... matched on the lower-cased name so aliases are case-insensitive
       !------------------------------------------------------------------------
       loss_method_ = to_lower(trim(adjustl(loss_method)))
       select case(loss_method_)
       case("binary_crossentropy")
          loss_method_ = "bce"
       case("categorical_crossentropy")
          loss_method_ = "cce"
       case("mean_absolute_error")
          loss_method_ = "mae"
       case("mean_squared_error")
          loss_method_ = "mse"
       case("negative_log_likelihood")
          loss_method_ = "nll"
       case("huber")
          loss_method_ = "hub"
       end select

       select case(loss_method_)
       case("bce")
          this%loss = bce_loss_type()
          if(verbose_.gt.0) write(*,*) "Loss method: Binary Cross Entropy"
       case("cce")
          this%loss = cce_loss_type()
          if(verbose_.gt.0) write(*,*) "Loss method: Categorical Cross Entropy"
       case("mae")
          this%loss = mae_loss_type()
          if(verbose_.gt.0) write(*,*) "Loss method: Mean Absolute Error"
       case("mse")
          this%loss = mse_loss_type()
          if(verbose_.gt.0) write(*,*) "Loss method: Mean Squared Error"
       case("nll")
          this%loss = nll_loss_type()
          if(verbose_.gt.0) write(*,*) "Loss method: Negative Log Likelihood"
       case("hub")
          this%loss = huber_loss_type()
          if(verbose_.gt.0) write(*,*) "Loss method: Huber"
       case default
          write(err_msg,'(A)') &
               "No loss method provided" // &
               achar(13) // achar(10) // &
               "Failed loss method: "//trim(loss_method_)
          call stop_program(trim(err_msg))
          return
       end select
    class default
       ! neither a loss type nor a name: nothing to resolve loss_method_ from
       call stop_program("Unrecognised loss method type")
       return
    end select

    this%loss_method = loss_method_

  end subroutine set_loss
!###############################################################################


!###############################################################################
  module subroutine set_accuracy(this, accuracy_method, verbose)
    !! Set the accuracy method for the network
    !!
    !! accuracy_method is a name/alias, matched case-insensitively with
    !! surrounding blanks ignored. Points this%get_accuracy at the matching
    !! score function and stores the resolved short name in
    !! this%accuracy_method.
    use coreutils, only: to_lower
    use athena__accuracy, only: &
         categorical_score, &
         mae_score, &
         mse_score, &
         rmse_score, &
         r2_score
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    character(*), intent(in) :: accuracy_method
    !! Accuracy method
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: verbose_
    !! Verbosity level
    character(len=:), allocatable :: accuracy_method_
    !! Accuracy method
    character(256) :: err_msg
    !! Error message


    verbose_ = 0
    if(present(verbose)) verbose_ = verbose


    !---------------------------------------------------------------------------
    ! Handle analogous definitions
    ! ... matched on the lower-cased name so aliases are case-insensitive
    !---------------------------------------------------------------------------
    accuracy_method_ = to_lower(trim(adjustl(accuracy_method)))
    select case(accuracy_method_)
    case("categorical")
       accuracy_method_ = "cat"
    case("mean_absolute_error")
       accuracy_method_ = "mae"
    case("mean_squared_error")
       accuracy_method_ = "mse"
    case("root_mean_squared_error")
       accuracy_method_ = "rmse"
    case("r2", "r^2", "r squared")
       accuracy_method_ = "r2"
    end select


    !---------------------------------------------------------------------------
    ! Set accuracy method
    !---------------------------------------------------------------------------
    select case(accuracy_method_)
    case("cat")
       this%get_accuracy => categorical_score
       if(verbose_.gt.0) write(*,*) "Accuracy method: Categorical "
    case("mae")
       this%get_accuracy => mae_score
       if(verbose_.gt.0) write(*,*) "Accuracy method: Mean Absolute Error"
    case("mse")
       this%get_accuracy => mse_score
       if(verbose_.gt.0) write(*,*) "Accuracy method: Mean Squared Error"
    case("rmse")
       this%get_accuracy => rmse_score
       if(verbose_.gt.0) write(*,*) "Accuracy method: Root Mean Squared Error"
    case("r2")
       this%get_accuracy => r2_score
       if(verbose_.gt.0) write(*,*) "Accuracy method: R^2"
    case default
       write(err_msg,'(A)') &
            "No accuracy method provided" // &
            achar(13) // achar(10) // &
            "Failed accuracy method: "//trim(accuracy_method_)
       call stop_program(trim(err_msg))
       return
    end select

    this%accuracy_method = accuracy_method_

  end subroutine set_accuracy
!###############################################################################


!###############################################################################
  module subroutine reset(this)
    !! Reset the network
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network

    this%epoch = 0
    this%accuracy_val = 0._real32
    this%loss_val = huge(1._real32)
    this%batch_size = 0
    this%num_layers = 0
    this%num_outputs = 0
    if(allocated(this%optimiser)) deallocate(this%optimiser)
    call this%set_metrics(["loss"])
    if(allocated(this%model)) deallocate(this%model)
    if(allocated(this%loss)) deallocate(this%loss)
    this%get_accuracy => null()
    if(allocated(this%vertex_order)) deallocate(this%vertex_order)
    if(allocated(this%leaf_vertices)) deallocate(this%leaf_vertices)
    if(allocated(this%root_vertices)) deallocate(this%root_vertices)
    this%auto_graph = graph_type(directed=.true.)
end subroutine reset !############################################################################### !############################################################################### module subroutine compile( & this, optimiser, loss_method, accuracy_method, & metrics, batch_size, verbose & ) !! Compile the network implicit none ! Arguments class(network_type), intent(inout) :: this !! Instance of network class(base_optimiser_type), optional, intent(in) :: optimiser !! Optimiser to use for training class(*), optional, intent(in) :: loss_method !! Loss method character(*), optional, intent(in) :: accuracy_method !! Accuracy method class(*), dimension(..), optional, intent(in) :: metrics !! Metrics integer, optional, intent(in) :: batch_size !! Batch size integer, optional, intent(in) :: verbose !! Verbosity level ! Local variables integer :: i, j, k, child_id, parent_id, layer_id, num_inputs, input_rank !! Loop index integer :: parent_vertex, vertex_idx !! Vertex indices integer :: layer_rank, parent_rank, operator !! Ranks of layers integer :: verbose_ = 0 !! Verbosity level logical :: use_graph_input = .false. !! Boolean whether to use graph input logical :: l_flatten_child, l_set_input_shape !! Booleans whether to flatten child or set input shape integer, dimension(:), allocatable :: input_shape, & child_vertices, parent_vertices, output_ranks, parent_ids !! Shapes of the input and output of the layers integer, dimension(:,:), allocatable :: merge_shape !! Shapes of the inputs to merge layers class(base_layer_type), allocatable :: & t_input_layer, t_flatten_layer, t_merge_layer !! Temporary input, flatten, and merge layers !--------------------------------------------------------------------------- ! Initialise optional arguments !--------------------------------------------------------------------------- if(present(verbose)) verbose_ = verbose !--------------------------------------------------------------------------- ! 
Initialise metrics !--------------------------------------------------------------------------- if(present(metrics)) call this%set_metrics(metrics) !--------------------------------------------------------------------------- ! Initialise loss and accuracy methods !--------------------------------------------------------------------------- if(present(loss_method)) call this%set_loss(loss_method, verbose_) if(present(accuracy_method)) & call this%set_accuracy(accuracy_method, verbose_) !--------------------------------------------------------------------------- ! Check for input layers at root vertices !--------------------------------------------------------------------------- this%auto_graph%directed = .true. call this%build_root_vertices() do i = 1, size(this%root_vertices) layer_id = this%auto_graph%vertex(this%root_vertices(i))%id if(.not.allocated(this%model(layer_id)%layer%input_shape))then call stop_program("input_shape of first layer not defined") return end if use_graph_input = .false. select type( root => this%model(layer_id)%layer) class is(input_layer_type) cycle class is(learnable_layer_type) input_shape = root%input_shape use_graph_input = root%use_graph_input class default input_shape = root%input_shape end select t_input_layer = input_layer_type(& input_shape = input_shape, & index = i, & use_graph_input = use_graph_input, & verbose=verbose_ & ) call this%add( & t_input_layer, output_list = [ this%model(layer_id)%layer%id ] & ) ! NEED TO CALL layer%init? deallocate(input_shape) deallocate(t_input_layer) this%root_vertices(i) = this%num_layers if(i.eq.1)then do j = 1, this%auto_graph%num_edges if(this%auto_graph%edge(j)%index(1).eq.0) & this%auto_graph%edge(j)%index(1) = this%num_layers end do end if end do call this%auto_graph%generate_adjacency() !--------------------------------------------------------------------------- ! 
Identify whether input is graph type !--------------------------------------------------------------------------- if( & this%model( & this%auto_graph%vertex(this%root_vertices(1))%id & )%layer%use_graph_input & )then this%use_graph_input = .true. else this%use_graph_input = .false. end if !--------------------------------------------------------------------------- ! Check for zero input rank layers !--------------------------------------------------------------------------- do i = 1, size(this%auto_graph%vertex, dim = 1) layer_id = this%auto_graph%vertex(i)%id if(this%model(layer_id)%layer%input_rank.eq.0)then parent_ids = pack( & [ ( & this%auto_graph%vertex(j)%id, & j = 1, size(this%auto_graph%adjacency(:,i)) & ) ], & this%auto_graph%adjacency(:,i) .ne. 0 & ) if(size(parent_ids).eq.0) cycle output_ranks = [ ( this%model(parent_ids(j))%layer%output_rank, & j=1,size(parent_ids) ) ] if(any(output_ranks.ne.output_ranks(1)))then write(0,*) output_ranks call stop_program( & "input rank of layer "//trim(this%model(layer_id)%layer%name) // & " is zero, but multiple parents with different output ranks" & ) return end if input_rank = this%model(parent_ids(1))%layer%output_rank call this%model(layer_id)%layer%set_rank( & input_rank = input_rank, & output_rank = input_rank & ) end if end do !--------------------------------------------------------------------------- ! Check for required flatten layers !--------------------------------------------------------------------------- i = 0 flatten_loop: do i = i + 1 if(i.gt.this%auto_graph%num_vertices) exit flatten_loop layer_id = this%auto_graph%vertex(i)%id ! get all child vertices child_vertices = pack( & [(j, j=1,size(this%auto_graph%adjacency(i,:)))], & this%auto_graph%adjacency(i,:) .ne. 0 & ) child_loop: do j = 1, size(child_vertices) ! 
Get layer ID (needed for add() function's output_list parameter) child_id = this%auto_graph%vertex(child_vertices(j))%id if(trim(this%model(layer_id)%layer%type).eq."flat") cycle child_loop if( this%model(layer_id)%layer%output_rank .eq. & this%model(child_id)%layer%input_rank ) cycle child_loop if(this%model(layer_id)%layer%output_rank.eq.0) cycle child_loop ! get all parent vertices of the child vertex parent_vertices = pack( & [(k, k=1,size(this%auto_graph%adjacency(:,child_vertices(j))))], & this%auto_graph%adjacency(:,child_vertices(j)) .ne. 0 & ) l_flatten_child = .true. do k = 1, size(parent_vertices) parent_id = this%auto_graph%vertex(parent_vertices(k))%id !check if ranks match, rather than input and output shapes if( this%model(layer_id)%layer%output_rank .ne. & this%model(parent_id)%layer%input_rank & ) l_flatten_child = .false. end do t_flatten_layer = flatten_layer_type( & input_rank = this%model(layer_id)%layer%output_rank & ) if(l_flatten_child)then ! add flatten layer in the place of the child layer operator = this%auto_graph%edge( & this%auto_graph%adjacency(parent_vertices(1),child_vertices(j)) & )%id call this%auto_graph%remove_edges( & indices = [ & this%auto_graph%adjacency( & parent_vertices(:),child_vertices(j) & ) & ] & ) call this%add( & t_flatten_layer, & input_list=[parent_vertices(:)], output_list=[child_id], & operator=operator & ) else ! 
add flatten layer between the current layer and the child layer operator = this%auto_graph%edge( & this%auto_graph%adjacency(i,child_vertices(j)) & )%id call this%auto_graph%remove_edges( & indices = [this%auto_graph%adjacency(i,child_vertices(j))] & ) call this%add( & t_flatten_layer, input_list = [layer_id], & output_list = [child_id], & operator=operator & ) end if deallocate(t_flatten_layer) deallocate(child_vertices) cycle flatten_loop end do child_loop deallocate(child_vertices) end do flatten_loop call this%build_vertex_order() !--------------------------------------------------------------------------- ! Check for required merge layers !--------------------------------------------------------------------------- i = 0 merge_loop: do i = i + 1 if(i.gt.this%auto_graph%num_vertices) exit merge_loop layer_id = this%auto_graph%vertex(i)%id if(this%model(layer_id)%layer%type.eq."merg") cycle merge_loop ! get all child vertices parent_vertices = pack( & [(j, j=1,size(this%auto_graph%adjacency(:,i)))], & this%auto_graph%adjacency(:,i) .ne. 0 & ) if(size(parent_vertices).le.1) cycle merge_loop ! get edge id for merge layer operator = this%auto_graph%edge( & this%auto_graph%adjacency(parent_vertices(1),i) & )%id ! remove edges from parents to this layer do j = 1, size(parent_vertices) call this%auto_graph%remove_edges( & indices = [this%auto_graph%adjacency(parent_vertices(j),i)] & ) end do parent_ids = & [ ( & this%auto_graph%vertex(parent_vertices(k))%id, & k = 1, size(parent_vertices) & ) ] select case(operator) case(1) ! concatenate t_merge_layer = concat_layer_type( & input_layer_ids = parent_ids, & input_rank = this%model(layer_id)%layer%input_rank & ) case(2) ! add t_merge_layer = add_layer_type( & input_layer_ids = parent_ids, & input_rank = this%model(layer_id)%layer%input_rank & ) ! case(3) ! multiply ! t_merge_layer = multiply_layer_type( & ! input_layer_ids = parent_vertices & ! 
) case default write(0,*) "invalid merge operator: ", operator call stop_program("invalid merge operator") return end select t_merge_layer%use_graph_input = this%model(layer_id)%layer%use_graph_input t_merge_layer%use_graph_output = t_merge_layer%use_graph_input call this%add( & t_merge_layer, & input_list = parent_ids, & output_list = [layer_id] & ) deallocate(t_merge_layer) end do merge_loop call this%build_vertex_order() ! Update number of layers !--------------------------------------------------------------------------- this%num_layers = size(this%model,dim=1) !--------------------------------------------------------------------------- ! Initialise layers !--------------------------------------------------------------------------- do i = 1, size(this%vertex_order, dim = 1) vertex_idx = this%vertex_order(i) layer_id = this%auto_graph%vertex(vertex_idx)%id if(allocated(this%model(layer_id)%layer%input_shape))then l_set_input_shape = .false. else l_set_input_shape = .true. end if if(l_set_input_shape)then layer_rank = this%model(layer_id)%layer%input_rank parent_rank = 0 select type( layer => this%model(layer_id)%layer ) class is(merge_layer_type) ! loop over all parent layers allocate( & merge_shape( & this%model(layer_id)%layer%input_rank, & size(layer%input_layer_ids) & ) & ) do k = 1, size(layer%input_layer_ids) merge_shape(:,k) = & this%model(layer%input_layer_ids(k))%layer%output_shape end do input_shape = layer%calc_input_shape(merge_shape) deallocate(merge_shape) class default allocate( & input_shape(this%model(layer_id)%layer%input_rank), & source = 0 & ) do j = 1, this%auto_graph%num_vertices if(this%auto_graph%adjacency(j,vertex_idx).eq.0) cycle parent_id = this%auto_graph%vertex(j)%id parent_rank = this%model(parent_id)%layer%output_rank if(layer_rank .eq. parent_rank)then input_shape(:) = input_shape(:) + & this%model(parent_id)%layer%output_shape elseif(layer_rank .eq. 
1)then input_shape(1) = input_shape(1) + & product( this%model(parent_id)%layer%output_shape ) end if end do end select call this%model(layer_id)%layer%init( & input_shape = input_shape, & verbose = verbose_ & ) deallocate(input_shape) end if if(verbose_.gt.0)then write(*,*) "layer: ", layer_id, this%model(layer_id)%layer%type write(*,*) this%model(layer_id)%layer%input_shape write(*,*) this%model(layer_id)%layer%output_shape end if end do ! Set number of outputs !--------------------------------------------------------------------------- this%num_outputs = 0 call this%build_leaf_vertices() do i = 1, size(this%leaf_vertices,1) this%num_outputs = this%num_outputs + & product( & this%model( & this%auto_graph%vertex(this%leaf_vertices(i))%id & )%layer%output_shape & ) end do if( & this%model( & this%auto_graph%vertex(this%leaf_vertices(1))%id & )%layer%use_graph_output & )then this%use_graph_output = .true. else this%use_graph_output = .false. end if !--------------------------------------------------------------------------- ! Confirm input_shape of each layer matches data going into it !--------------------------------------------------------------------------- do i = 1, size(this%vertex_order, dim = 1) vertex_idx = this%vertex_order(i) layer_id = this%auto_graph%vertex(vertex_idx)%id if(this%model(layer_id)%layer%type.eq."inpt") cycle ! Get all parent vertices that feed into this layer parent_vertices = pack( & [(j, j=1,size(this%auto_graph%adjacency(:,vertex_idx)))], & this%auto_graph%adjacency(:,vertex_idx) .ne. 0 & ) if(size(parent_vertices).eq.0) cycle select type( layer => this%model(layer_id)%layer ) class is(merge_layer_type) operator = layer%merge_mode class default if(size(parent_vertices).gt.1)then call stop_program( & "layer "//trim(layer%name)// & " is not a merge layer but has multiple inputs" & ) return end if end select ! 
Calculate expected input size from parent layers num_inputs = 0 do j = 1, size(parent_vertices) parent_vertex = parent_vertices(j) select case(operator) case(1) ! pointwise - all inputs should have same size if(num_inputs.eq.0)then if(this%model(layer_id)%layer%use_graph_input)then num_inputs = this%model(parent_vertex)%layer%output_shape(1) else num_inputs = product(this%model(parent_vertex)%layer%output_shape) end if end if case(2) ! concatenate if(this%model(layer_id)%layer%use_graph_input)then num_inputs = num_inputs + & this%model(parent_vertex)%layer%output_shape(1) else num_inputs = num_inputs + & product(this%model(parent_vertex)%layer%output_shape) end if end select end do ! Verify calculated input size matches layer's expected input size if(this%model(layer_id)%layer%use_graph_input)then if(num_inputs.ne.this%model(layer_id)%layer%input_shape(1) .and. & num_inputs.ne.0)then write(*,*) "Expected:", num_inputs, "Got:", & this%model(layer_id)%layer%input_shape(1) call stop_program( & "input_shape of layer "//& trim(this%model(layer_id)%layer%name)// & " does not match data going into it" & ) end if else if(num_inputs.ne.product(this%model(layer_id)%layer%input_shape) .and. & num_inputs.ne.0)then write(*,*) "Expected:", num_inputs, "Got:", & product(this%model(layer_id)%layer%input_shape) call stop_program( & "input_shape of layer "//& trim(this%model(layer_id)%layer%name)// & " does not match data going into it" & ) end if end if end do !--------------------------------------------------------------------------- ! Initialise optimiser !--------------------------------------------------------------------------- this%num_params = this%get_num_params() if(present(optimiser))then this%optimiser = optimiser end if if(.not.allocated(this%optimiser))then call stop_program("No optimiser is defined for the network") return else call this%optimiser%init(num_params=this%num_params) end if !--------------------------------------------------------------------------- ! 
Pre-compute forward pass navigation !--------------------------------------------------------------------------- block integer :: nv, l_idx, v_idx, lid, parent_v nv = size(this%vertex_order, 1) if(allocated(this%fwd_layer_id)) deallocate(this%fwd_layer_id) if(allocated(this%fwd_num_inputs)) deallocate(this%fwd_num_inputs) if(allocated(this%fwd_parent_id)) deallocate(this%fwd_parent_id) if(allocated(this%fwd_layer_type)) deallocate(this%fwd_layer_type) allocate(this%fwd_layer_id(nv)) allocate(this%fwd_num_inputs(nv)) allocate(this%fwd_parent_id(nv)) allocate(this%fwd_layer_type(nv)) this%fwd_parent_id = 0 do l_idx = 1, nv v_idx = this%vertex_order(l_idx) lid = this%auto_graph%vertex(v_idx)%id this%fwd_layer_id(l_idx) = lid this%fwd_num_inputs(l_idx) = & count(this%auto_graph%adjacency(:,v_idx).gt.0) if(this%fwd_num_inputs(l_idx).eq.1)then parent_v = maxloc( & this%auto_graph%adjacency(:,v_idx), dim=1) this%fwd_parent_id(l_idx) = & this%auto_graph%vertex(parent_v)%id end if ! Determine layer type: 0=input, 1=merge, 2=default select type(layer => this%model(lid)%layer) class is(input_layer_type) this%fwd_layer_type(l_idx) = 0 class is(merge_layer_type) this%fwd_layer_type(l_idx) = 1 class default this%fwd_layer_type(l_idx) = 2 end select end do end block !--------------------------------------------------------------------------- ! Pre-compute parameter segment layout !--------------------------------------------------------------------------- block integer :: l_idx, p_idx, seg_count, s_idx, e_idx ! 
First pass: count segments seg_count = 0 do l_idx = 1, this%num_layers select type(current => this%model(l_idx)%layer) class is(learnable_layer_type) seg_count = seg_count + size(current%params) end select end do this%param_num_segments = seg_count if(allocated(this%param_seg_layer)) deallocate(this%param_seg_layer) if(allocated(this%param_seg_pidx)) deallocate(this%param_seg_pidx) if(allocated(this%param_seg_start)) deallocate(this%param_seg_start) if(allocated(this%param_seg_end)) deallocate(this%param_seg_end) allocate(this%param_seg_layer(seg_count)) allocate(this%param_seg_pidx(seg_count)) allocate(this%param_seg_start(seg_count)) allocate(this%param_seg_end(seg_count)) ! Second pass: fill layout seg_count = 0 e_idx = 0 do l_idx = 1, this%num_layers select type(current => this%model(l_idx)%layer) class is(learnable_layer_type) do p_idx = 1, size(current%params) seg_count = seg_count + 1 s_idx = e_idx + 1 e_idx = e_idx + size(current%params(p_idx)%val, 1) this%param_seg_layer(seg_count) = l_idx this%param_seg_pidx(seg_count) = p_idx this%param_seg_start(seg_count) = s_idx this%param_seg_end(seg_count) = e_idx end do end select end do end block !--------------------------------------------------------------------------- ! Set batch size, if provided !--------------------------------------------------------------------------- if(present(batch_size)) this%batch_size = batch_size end subroutine compile !############################################################################### !############################################################################### module subroutine set_batch_size(this, batch_size) !! Set the batch size for the network implicit none ! Arguments class(network_type), intent(inout) :: this !! Instance of network integer, intent(in) :: batch_size !! Batch size ! Local variables integer :: l !! 
Loop index this%batch_size = batch_size end subroutine set_batch_size !############################################################################### !############################################################################### module subroutine reset_state(this) !! Reset the hidden state of all layers in the network implicit none ! Arguments class(network_type), intent(inout) :: this !! Instance of network ! Local variables integer :: l !! Loop index do l = 1, size(this%model, dim = 1) select type( layer => this%model(l)%layer ) class is(recurrent_layer_type) call layer%reset_state() end select end do end subroutine reset_state !############################################################################### !############################################################################### module subroutine set_training_mode(this, mode_store, layer_indices) !! Put the network in training mode. implicit none ! Arguments class(network_type), intent(inout) :: this !! Instance of network logical, dimension(:), allocatable, intent(out), optional :: mode_store !! Optional array to store the training mode of each layer integer, dimension(:), intent(in), optional :: layer_indices !! Optional array of layer indices to set to training mode. !! If not provided, all layers will be set to training mode. ! Local variables integer :: l !! Loop index if(.not.allocated(this%model)) return if(present(mode_store)) allocate(mode_store(this%num_layers)) do l = 1, this%num_layers if(present(mode_store)) mode_store(l) = this%model(l)%layer%inference this%model(l)%layer%inference = .false. if(present(layer_indices))then if(any(layer_indices.eq.l))then this%model(l)%layer%inference = .false. end if else this%model(l)%layer%inference = .false. end if end do end subroutine set_training_mode !------------------------------------------------------------------------------- module subroutine set_inference_mode(this, mode_store, layer_indices) !! Put the network in inference mode. 
implicit none ! Arguments class(network_type), intent(inout) :: this !! Instance of network logical, dimension(:), allocatable, intent(out), optional :: mode_store !! Optional array to store the training mode of each layer integer, dimension(:), intent(in), optional :: layer_indices !! Optional array of layer indices to set to inference mode. !! If not provided, all layers will be set to inference mode. ! Local variables integer :: l !! Loop index if(.not.allocated(this%model)) return if(present(mode_store)) allocate(mode_store(this%num_layers)) do l = 1, this%num_layers if(present(mode_store)) mode_store(l) = this%model(l)%layer%inference if(present(layer_indices))then if(any(layer_indices.eq.l))then this%model(l)%layer%inference = .true. end if else this%model(l)%layer%inference = .true. end if end do end subroutine set_inference_mode !------------------------------------------------------------------------------- module subroutine restore_mode(this, mode_store) !! Restore the training/inference mode of each layer from a stored array. implicit none ! Arguments class(network_type), intent(inout) :: this !! Instance of network logical, dimension(:), intent(in) :: mode_store !! Array storing the mode of each layer !! .true. = inference, .false. = training ! Local variables integer :: l !! Loop index if(.not.allocated(this%model)) return if(size(mode_store) .ne. this%num_layers)then call stop_program("mode_store size does not match number of layers") return end if do l = 1, this%num_layers this%model(l)%layer%inference = mode_store(l) end do end subroutine restore_mode !############################################################################### !############################################################################### module function layer_from_id(this, id) result(layer) !! Get layer from its ID implicit none ! Arguments class(network_type), intent(in), target :: this !! Instance of network integer, intent(in) :: id !! 
Layer ID class(base_layer_type), pointer :: layer !! Layer ! Local variables integer :: i, itmp1 !! Loop index itmp1 = 0 do i = 1, size(this%model, dim = 1) if(this%model(i)%layer%id.eq.id)then if(itmp1.ne.0)then call stop_program("multiple layers with same ID found") return end if layer => this%model(i)%layer itmp1 = itmp1 + 1 end if end do end function layer_from_id !############################################################################### !##############################################################################! ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ! !##############################################################################! !############################################################################### #ifdef __flang__ module function get_sample_flang( & input, start_index, end_index, batch_size & ) result(sample) !! Get samples of batch size from a real array implicit none ! Arguments integer, intent(in) :: start_index, end_index !! Start and end indices integer, intent(in) :: batch_size !! Batch size real(real32), dimension(..), intent(in) :: input !! Input array real(real32), allocatable :: sample(:,:) !! 
Sample array integer :: num_samples num_samples = end_index - start_index + 1 select rank(input) rank(2) sample = input(:, start_index:end_index) rank(3) sample = reshape(input(:, :, start_index:end_index), & [size(input,1) * size(input,2), num_samples]) rank(4) sample = reshape(input(:, :, :, start_index:end_index), & [size(input,1) * size(input,2) * size(input,3), num_samples]) rank(5) sample = reshape(input(:, :, :, :, start_index:end_index), & [size(input,1) * size(input,2) * size(input,3) * size(input,4), & num_samples]) rank(6) sample = reshape(input(:, :, :, :, :, start_index:end_index), & [size(input,1) * size(input,2) * size(input,3) * size(input,4) * & size(input,5), num_samples]) rank default allocate(sample(0, 0)) end select end function get_sample_flang #else module function get_sample_ptr( & input, start_index, end_index, batch_size & ) result(sample_ptr) !! Get samples of batch size from a real array implicit none ! Arguments integer, intent(in) :: start_index, end_index !! Start and end indices integer, intent(in) :: batch_size !! Batch size real(real32), dimension(..), intent(in), target :: input !! Input array real(real32), pointer :: sample_ptr(:,:) !! 
Pointer to sample select rank(input) rank(2) sample_ptr(1:size(input(:,1)),1:end_index-start_index+1) => & input(:,start_index:end_index) rank(3) sample_ptr(1:size(input(:,:,1)),1:end_index-start_index+1) => & input(:,:,start_index:end_index) rank(4) sample_ptr(1:size(input(:,:,:,1)),1:end_index-start_index+1) => & input(:,:,:,start_index:end_index) rank(5) sample_ptr(1:size(input(:,:,:,:,1)),1:end_index-start_index+1) => & input(:,:,:,:,start_index:end_index) rank(6) sample_ptr(1:size(input(:,:,:,:,:,1)),1:end_index-start_index+1) => & input(:,:,:,:,:,start_index:end_index) rank default sample_ptr => null() end select end function get_sample_ptr #endif !------------------------------------------------------------------------------- module function get_sample_array( & input, start_index, end_index, batch_size, as_graph & ) result(sample) !! Get samples of batch size from a derived type array implicit none ! Arguments integer, intent(in) :: start_index, end_index !! Start and end indices integer, intent(in) :: batch_size !! Batch size class(array_type), dimension(:,:), intent(in) :: input !! Input array logical, intent(in) :: as_graph !! Boolean whether to treat the input as a graph type(array_type), dimension(:,:), allocatable :: sample !! Sample array ! Local variables integer :: i, j !! Loop index if(as_graph)then allocate(sample(size(input,1), batch_size)) do i = 1, size(input,1) do j = start_index, end_index, 1 sample(i, j - start_index + 1)%val = input(i, j)%val end do end do else allocate(sample(size(input,1), size(input,2))) do i = 1, size(input,1) do j = 1, size(input,2) call sample(i,j)%allocate(array_shape=[input(i,j)%shape, & end_index - start_index + 1]) sample(i,j)%val = input(i,j)%val(:,start_index:end_index) end do end do end if end function get_sample_array !------------------------------------------------------------------------------- module function get_sample_graph1d( & input, start_index, end_index, batch_size & ) result(sample) !! 
!! Get samples of batch size from a graph
    !!
    !! end_index - start_index + 1 must equal batch_size for the assignment
    !! below to conform.
    implicit none

    ! Arguments
    integer, intent(in) :: start_index, end_index
    !! Start and end indices
    integer, intent(in) :: batch_size
    !! Batch size
    class(graph_type), dimension(:), intent(in) :: input
    !! Input array

    type(graph_type), dimension(1, batch_size) :: sample
    !! Sample array

    ! The rank-1 input is returned as a single-row (1 x batch_size) sample
    sample(1,1:batch_size) = input(start_index:end_index)
  end function get_sample_graph1d
!-------------------------------------------------------------------------------
  module function get_sample_graph2d( &
       input, start_index, end_index, batch_size &
  ) result(sample)
    !! Get samples of batch size from a graph
    !!
    !! Copies columns start_index..end_index of every row of input.
    !! end_index - start_index + 1 must equal batch_size.
    implicit none

    ! Arguments
    integer, intent(in) :: start_index, end_index
    !! Start and end indices
    integer, intent(in) :: batch_size
    !! Batch size
    class(graph_type), dimension(:,:), intent(in) :: input
    !! Input array

    type(graph_type), dimension(size(input,1), batch_size) :: sample
    !! Sample array

    sample(1:size(input,1),1:batch_size) = input(:,start_index:end_index)
  end function get_sample_graph2d
!###############################################################################


!###############################################################################
  pure module function get_num_params(this) result(num_params)
    !! Get the number of learnable parameters in the network
    !!
    !! Sums the first-dimension length of every parameter array of every
    !! learnable layer.
    implicit none

    ! Arguments
    class(network_type), intent(in) :: this
    !! Instance of network

    integer :: num_params
    !! Number of parameters

    ! Local variables
    integer :: l, i
    !! Loop indices

    num_params = 0
    do l = 1, this%num_layers
       select type(current => this%model(l)%layer)
       class is(learnable_layer_type)
          do i = 1, size(current%params)
             num_params = num_params + size(current%params(i)%val, 1)
          end do
       end select
    end do
  end function get_num_params
!###############################################################################


!###############################################################################
  pure module function get_params(this) result(params)
    !! Get learnable parameters
    !!
    !! Parameters are packed into one flat vector, one contiguous segment per
    !! parameter array, in layer order (column 1 of each val array).
    implicit none

    ! Arguments
    class(network_type), intent(in) :: this
    !! Instance of network

    real(real32), dimension(this%num_params) :: params
    !! Parameters

    ! Local variables
    integer :: l, i, start_idx, end_idx
    !! Loop index and segment bounds

    start_idx = 0
    end_idx = 0
    do l = 1, this%num_layers
       select type(current => this%model(l)%layer)
       class is(learnable_layer_type)
          do i = 1, size(current%params)
             start_idx = end_idx + 1
             end_idx = end_idx + size(current%params(i)%val, 1)
             params(start_idx:end_idx) = current%params(i)%val(:,1)
          end do
       end select
    end do
  end function get_params
!###############################################################################


!###############################################################################
  module subroutine set_params(this, params)
    !! Set learnable parameters
    !!
    !! Inverse of get_params: unpacks the flat vector back into column 1 of
    !! each parameter array using the same segment layout.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    real(real32), dimension(this%num_params), intent(in) :: params
    !! Parameters

    ! Local variables
    integer :: l, i, start_idx, end_idx
    !! Loop index and segment bounds

    start_idx = 0
    end_idx = 0
    do l = 1, this%num_layers
       select type(current => this%model(l)%layer)
       class is(learnable_layer_type)
          do i = 1, size(current%params)
             start_idx = end_idx + 1
             end_idx = end_idx + size(current%params(i)%val, 1)
             current%params(i)%val(:,1) = params(start_idx:end_idx)
          end do
       end select
    end do
  end subroutine set_params
!###############################################################################


!###############################################################################
  pure module function get_gradients(this) result(gradients)
    !! Get gradients
    !!
    !! Returns the gradients averaged over dimension 2 of each grad array,
    !! packed with the same segment layout as get_params. Parameters without
    !! an associated gradient contribute a zero-filled segment so that all
    !! later segments stay aligned with their parameters. The optimiser's
    !! clipping is applied to the packed vector before returning.
    implicit none

    ! Arguments
    class(network_type), intent(in) :: this
    !! Instance of network

    real(real32), dimension(this%num_params) :: gradients
    !! Gradients

    ! Local variables
    integer :: l, i, start_idx, end_idx
    !! Loop index and segment bounds

    ! Segments without a gradient must be defined (zero) rather than left
    ! unset in the function result.
    gradients = 0._real32
    start_idx = 0
    end_idx = 0
    do l = 1, this%num_layers
       select type(current => this%model(l)%layer)
       class is(learnable_layer_type)
          do i = 1, size(current%params)
             ! Always advance the segment, even without a gradient, so the
             ! layout matches get_params/set_params.
             start_idx = end_idx + 1
             end_idx = end_idx + size(current%params(i)%val, 1)
             if(.not.associated(current%params(i)%grad)) cycle
             gradients(start_idx:end_idx) = &
                  sum(current%params(i)%grad%val, dim=2) / &
                  real(size(current%params(i)%grad%val, dim=2), real32)
          end do
       end select
    end do

    call this%optimiser%clip_dict%apply(size(gradients),gradients)
  end function get_gradients
!###############################################################################


!###############################################################################
  module subroutine set_gradients(this, gradients)
    !! Set gradients
    !!
    !! A scalar (rank-0) value is passed to every learnable layer; a rank-1
    !! vector is split into consecutive segments of current%num_params per
    !! learnable layer. Other ranks are ignored.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    real(real32), dimension(..), intent(in) :: gradients
    !! Gradients

    ! Local variables
    integer :: l, start_idx, end_idx
    !! Loop index and segment bounds

    start_idx = 0
    end_idx = 0
    do l = 1, this%num_layers
       select type(current => this%model(l)%layer)
       class is(learnable_layer_type)
          start_idx = end_idx + 1
          end_idx = end_idx + current%num_params
          select rank(gradients)
          rank(0)
             call current%set_gradients(gradients)
          rank(1)
             call current%set_gradients(gradients(start_idx:end_idx))
          end select
       end select
    end do
  end subroutine set_gradients
!###############################################################################


!###############################################################################
  module subroutine reset_gradients(this)
    !! Reset gradients
    !!
    !! Calls zero_grad on every parameter of every learnable layer.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network

    ! Local variables
    integer :: l, i
    !! Loop indices

    do l = 1, this%num_layers
       select type(current => this%model(l)%layer)
       class is(learnable_layer_type)
          do i = 1, size(current%params)
             call current%params(i)%zero_grad()
          end do
       end select
    end do
  end subroutine reset_gradients
!###############################################################################


!###############################################################################
  module function get_output_shape(this) result(output_shape)
    !! Get the shape of the output of the network
    !!
    !! array data: [ summed rows of all leaf layer outputs, 1 ]
    !! graph data: [ vertex/edge idx, sample idx ]
    implicit none

    ! Arguments
    class(network_type), intent(in) :: this
    !! Instance of network

    integer, dimension(2) :: output_shape
    !! Output shape

    ! Local variables
    integer :: i, layer_idx
    !! Loop index and model index of the current leaf

    if(this%use_graph_output)then
       ! NOTE(review): the row count starts at 2 before leaf sizes are added,
       ! whereas the array branch starts at 0; get_output mirrors this.
       ! Confirm whether the extra offset is intended.
       output_shape = [2, this%batch_size]
       do i = 1, size(this%leaf_vertices,1), 1
          layer_idx = this%auto_graph%vertex(this%leaf_vertices(i))%id
          if(size(this%model(layer_idx)%layer%output,2).ne.this%batch_size)then
             call stop_program( &
                  "Inconsistent batch size in output layers" &
             )
             return
          end if
          output_shape(1) = output_shape(1) + &
               size( this%model(layer_idx)%layer%output, 1 )
       end do
    else
       output_shape = [0, 1]
       do i = 1, size(this%leaf_vertices,1)
          layer_idx = this%auto_graph%vertex(this%leaf_vertices(i))%id
          if(size(this%model(layer_idx)%layer%output,2).ne.1)then
             call stop_program( &
                  "Inconsistent size of dimension 2 in output layers" &
             )
             return
          end if
          output_shape(1) = &
               output_shape(1) + size( this%model(layer_idx)%layer%output, 1 )
       end do
    end if
  end function get_output_shape
!-------------------------------------------------------------------------------
  module function get_output(this) result(output)
    !! Get the output of the network
    !!
    !! Concatenates the outputs of all leaf layers along dimension 1. Each
    !! leaf's rows are placed according to its layer%id, which is assumed to
    !! lie in 1..num_outputs (it indexes output_ids) -- TODO confirm.
    implicit none

    ! Arguments
    class(network_type), intent(in) :: this
    !! Instance of network

    type(array_type), dimension(:,:), allocatable :: output
    !! Output

    ! Local variables
    integer :: i, start_idx, end_idx, layer_idx, output_id
    !! Loop indices
    integer, dimension(2) :: output_shape
    !! Output shape
    integer, dimension(this%num_outputs) :: output_ids
    !! Number of output rows per output ID

    ! Zero-initialise: the graph branch sums output_ids(1:id-1), which may
    ! include IDs that no leaf layer writes.
    output_ids = 0

    ! array data: [ layer idx, empty ]
    ! graph data: [ vertex/edge idx, sample idx]
    if(this%use_graph_output)then
       ! NOTE(review): starting row count of 2 matches get_output_shape;
       ! rows beyond those filled below stay unset. Confirm intent.
       output_shape = [2, this%batch_size]
       do i = 1, size(this%leaf_vertices,1), 1
          layer_idx = this%auto_graph%vertex(this%leaf_vertices(i))%id
          if(size(this%model(layer_idx)%layer%output,2).ne.this%batch_size)then
             call stop_program( &
                  "Inconsistent batch size in output layers" &
             )
             return
          end if
          output_id = this%model(layer_idx)%layer%id
          output_ids(output_id) = size( this%model(layer_idx)%layer%output, 1 )
          output_shape(1) = output_shape(1) + output_ids(output_id)
       end do
       allocate(output(output_shape(1), output_shape(2)))
       do i = 1, size(this%leaf_vertices,1)
          layer_idx = this%auto_graph%vertex(this%leaf_vertices(i))%id
          output_id = sum(output_ids(1:this%model(layer_idx)%layer%id-1)) + 1
          ! Row 1 holds vertex data; row 2 (when present) holds edge data
          output(output_id,:) = this%model(layer_idx)%layer%output(1,:)
          if(output_ids(this%model(layer_idx)%layer%id).gt.1)then
             output(output_id+1,:) = this%model(layer_idx)%layer%output(2,:)
          end if
       end do
    else
       output_shape = [0, 1]
       do i = 1, size(this%leaf_vertices,1)
          layer_idx = this%auto_graph%vertex(this%leaf_vertices(i))%id
          if(size(this%model(layer_idx)%layer%output,2).ne.1)then
             call stop_program( &
                  "Inconsistent size of dimension 2 in output layers" &
             )
             return
          end if
          output_shape(1) = &
               output_shape(1) + size( this%model(layer_idx)%layer%output, 1 )
          output_id = this%model(layer_idx)%layer%id
          output_ids(output_id) = size( this%model(layer_idx)%layer%output, 1 )
       end do
       allocate(output(output_shape(1), output_shape(2)))
       start_idx = 1
       end_idx = 0
       do i = 1, size(this%leaf_vertices,1)
          layer_idx = this%auto_graph%vertex(this%leaf_vertices(i))%id
          output_id = this%model(layer_idx)%layer%id
          end_idx = end_idx + output_ids(output_id)
          output(start_idx:end_idx,1) = this%model(layer_idx)%layer%output(:,1)
          start_idx = end_idx + 1
       end do
    end if
  end function get_output
!-------------------------------------------------------------------------------
  module subroutine extract_output_real(this, output)
    !! Get the output of the network as real array
    !!
    !! Only supported for networks with a single leaf vertex; otherwise a
    !! warning is printed and output is left unallocated (it is intent(out)).
    implicit none

    ! Arguments
    class(network_type), intent(in) :: this
    !! Instance of network
    real(real32), dimension(..), allocatable, intent(out) :: output
    !! Output

    ! Local variables
    integer :: layer_id
    !! Model index of the leaf layer

    if(size(this%leaf_vertices,1).gt.1)then
       call print_warning("Output extraction to real array only works for single &
            &output networks")
       return
    end if

    ! Get output from the first (and only) leaf vertex
    layer_id = this%auto_graph%vertex(this%leaf_vertices(1))%id
    call this%model(layer_id)%layer%output(1,1)%extract(output)
  end subroutine extract_output_real
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  module function accuracy_eval(this, output, start_index, end_index) &
       result(accuracy)
    !! Get the accuracy for the output
    !!
    !! Compares the first leaf layer's output against the expected output
    !! for samples start_index..end_index and returns the mean over samples.
    !! For graph data, vertex (and, when the layer has edge outputs, edge)
    !! accuracies are normalised per graph before averaging.
    implicit none

    ! Arguments
    class(network_type), intent(in) :: this
    !! Instance of network
    class(*), dimension(:,:), intent(in) :: output
    !! Expected output
    integer, intent(in) :: start_index, end_index
    !! Start and end batch indices

    real(real32) :: accuracy
    !! Accuracy value

    ! Local variables
    integer :: s, s_idx
    !! Sample index and its position within the batch
    integer :: leaf_id
    !! Model index of the first leaf layer

    ! Map the leaf vertex to its model index, as get_output does
    leaf_id = this%auto_graph%vertex(this%leaf_vertices(1))%id

    accuracy = 0._real32
    select type(output)
    type is(graph_type)
       do s = start_index, end_index, 1
          s_idx = s - start_index + 1
          accuracy = accuracy + sum( this%get_accuracy( &
               this%model(leaf_id)%layer%output(1,s_idx)%val, &
               output(1,s)%vertex_features &
          ) ) / output(1,s)%num_vertices
          if(this%model(leaf_id)%layer%output_shape(2).gt.0)then
             accuracy = accuracy + sum( this%get_accuracy( &
                  this%model(leaf_id)%layer%output(2,s_idx)%val, &
                  output(1,s)%edge_features &
             ) ) / output(1,s)%num_edges
          end if
       end do
    type is(real(real32))
       accuracy = sum( &
            this%get_accuracy( &
                 this%model(leaf_id)%layer%output(1,1)%val, &
                 output(:,start_index:end_index:1) &
            ))
    type is(integer)
       accuracy = sum( &
            this%get_accuracy( &
                 this%model(leaf_id)%layer%output(1,1)%val, &
                 real(output(:,start_index:end_index:1),real32) &
            ))
    class is(array_type)
       accuracy = sum( &
            this%get_accuracy( &
                 this%model(leaf_id)%layer%output(1,1)%val, &
                 output(1,1)%val(:,start_index:end_index:1) &
            ))
    end select
    accuracy = accuracy / real(end_index - start_index + 1, real32)
  end function accuracy_eval
!###############################################################################


!###############################################################################
  module function loss_eval(this, start_index, end_index) result(loss)
    !! Get the loss for the output
    !!
    !! Builds the expected data for samples start_index..end_index from
    !! this%expected_array and evaluates this%loss against the first leaf
    !! layer's output.
    implicit none

    ! Arguments
    class(network_type), intent(inout), target :: this
    !! Instance of network
    integer, intent(in) :: start_index, end_index
    !! Start and end batch indices

    type(array_type), pointer :: loss
    !! Loss value

    ! Local variables
    integer :: i, s
    !! Loop indices
    integer :: leaf_id
    !! Model index of the first leaf layer
    type(array_type), pointer :: expected(:,:), predicted(:,:)
    !! Expected and predicted data for the batch

    if(this%use_graph_output)then
       ! Graph data: alias the batch columns directly, no copy
       expected(1:2, 1: end_index - start_index + 1) => &
            this%expected_array( :, start_index:end_index )
    else
       ! NOTE(review): this copy is never deallocated here. It may still be
       ! referenced by the loss computation graph, so freeing it is left to
       ! whoever owns that graph -- confirm to avoid a per-batch leak.
       allocate(expected( &
            size(this%expected_array,1), size(this%expected_array,2)))
       do s = 1, size(this%expected_array,2)
          do i = 1, size(this%expected_array,1)
             call expected(i,s)%allocate( &
                  array_shape = [ &
                       this%expected_array(i,s)%shape, &
                       size(this%expected_array(i,s)%val,2) &
                  ] &
             )
             expected(i,s)%val = this%expected_array(i,s)%val(:, &
                  start_index:end_index:1)
          end do
       end do
    end if

    ! Map the leaf vertex to its model index, as get_output does
    leaf_id = this%auto_graph%vertex(this%leaf_vertices(1))%id
    predicted => this%model(leaf_id)%layer%output

    loss => this%loss%compute( &
         predicted, &
         expected &
    )
  end function loss_eval
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  module subroutine forward_generic2d(this, input)
    !! Forward pass for array derived type input
    !!
    !! Visits layers in this%vertex_order. Input layers consume the supplied
    !! input; single-parent layers forward their parent's output; merge layers
    !! combine all parents' outputs. Pre-computed navigation arrays
    !! (fwd_layer_id, fwd_num_inputs, fwd_parent_id, fwd_layer_type) are used
    !! when allocated.
    implicit none

    ! Arguments
    class(network_type), intent(inout), target :: this
    !! Instance of network
    class(*), dimension(:,:), intent(in) :: input
    !! Input

    ! Local variables
    integer :: l, i, j, vertex_idx, layer_id, parent_id
    !! Loop index and vertex index
    integer :: num_input_layers
    !! Number of input layers
    type(array_type), pointer :: input_ptr(:,:)
    !! Input for a single-parent layer (no initialiser: that would imply SAVE)
    type(array_ptr_type), dimension(:), allocatable :: input_list
    !! Inputs for a merge layer
    logical :: use_precomp
    !! Whether pre-computed navigation is available

    nullify(input_ptr)

    ! Reject graphs whose adjacency references non-existent vertices
    select type(input)
    type is(graph_type)
       do j = 1, size(input,2)
          if(any(input(1,j)%adj_ja(1,:).gt.input(1,j)%num_vertices))then
             call stop_program( &
                  "input graph has more vertices than expected" &
             )
             return
          end if
       end do
    end select

    ! Use pre-computed navigation if available
    use_precomp = allocated(this%fwd_layer_id)

    ! Forward pass
    !---------------------------------------------------------------------------
    do l = 1, size(this%vertex_order,1)
       if(use_precomp)then
          layer_id = this%fwd_layer_id(l)
          num_input_layers = this%fwd_num_inputs(l)
       else
          vertex_idx = this%vertex_order(l)
          layer_id = this%auto_graph%vertex(vertex_idx)%id
          num_input_layers = count(this%auto_graph%adjacency(:,vertex_idx).gt.0)
       end if

       if(num_input_layers.eq.0)then
          select type(layer => this%model(layer_id)%layer)
          class is(input_layer_type)
             select type(input)
             type is(graph_type)
                call layer%set_input_graph( [ input(layer%index, :) ] )
                cycle
             class is(array_type)
                call layer%forward(input(layer%index:layer%index,:))
                do concurrent(i=1:size(layer%output,1), j=1:size(layer%output,2))
                   call layer%output(i,j)%set_requires_grad(.false.)
                end do
                cycle
             type is(real(real32))
                ! Wrap the raw array in a temporary array_type
                allocate(input_ptr(1,1))
                call input_ptr(1,1)%allocate(shape(input))
                call input_ptr(1,1)%set(input)
                call layer%forward(input_ptr)
                call layer%output(1,1)%set_requires_grad(.false.)
                deallocate(input_ptr)
                nullify(input_ptr)
                cycle
             class default
                call stop_program( &
                     "input type for layer "// &
                     trim(layer%name) // &
                     " is not supported" &
                )
                ! Do not fall through to forward() with a null input
                return
             end select
          class default
             ! NOTE(review): a parentless non-input layer ends the whole
             ! forward pass here; confirm this is intended.
             return
          end select
       elseif(num_input_layers.eq.1)then
          if(use_precomp)then
             parent_id = this%fwd_parent_id(l)
          else
             vertex_idx = this%vertex_order(l)
             j = maxloc( &
                  this%auto_graph%adjacency(:,vertex_idx), dim=1)
             parent_id = this%auto_graph%vertex(j)%id
          end if
          input_ptr => this%model(parent_id)%layer%output
          select type(input)
          type is(graph_type)
             call this%model(layer_id)%layer%set_graph( [ input(1,:) ] )
          end select
       else
          vertex_idx = this%vertex_order(l)
          allocate(input_list(num_input_layers))
          i = 0
          do j = 1, size(this%vertex_order,1)
             if(this%auto_graph%adjacency(j,vertex_idx).gt.0)then
                i = i + 1
                parent_id = this%auto_graph%vertex(j)%id
                input_list(i)%array => this%model(parent_id)%layer%output
             end if
          end do
       end if

       if(use_precomp)then
          ! fwd_layer_type == 1 marks a merge layer
          if(this%fwd_layer_type(l).eq.1)then
             select type(layer => this%model(layer_id)%layer)
             class is(merge_layer_type)
                call layer%combine(input_list)
             end select
             deallocate(input_list)
          else
             call this%model(layer_id)%layer%forward(input_ptr)
             nullify(input_ptr)
          end if
       else
          select type(layer => this%model(layer_id)%layer)
          class is(merge_layer_type)
             call layer%combine(input_list)
             deallocate(input_list)
          class default
             call layer%forward(input_ptr)
             nullify(input_ptr)
          end select
       end if
    end do
  end subroutine forward_generic2d
!-------------------------------------------------------------------------------
  module function forward_eval(this, input) result(output)
    !! Forward pass for evaluation
    !!
    !! Returns a pointer to the first leaf layer's output array.
    implicit none

    ! Arguments
    class(network_type), intent(inout), target :: this
    !! Instance of network
    class(*), dimension(:,:), intent(in) :: input
    !! Input

    type(array_type), pointer :: output(:,:)
    !! Output

    ! Local variables
    integer :: leaf_id
    !! Model index of the first leaf layer

    call this%forward(input)
    leaf_id = this%auto_graph%vertex(this%leaf_vertices(1))%id
    output => this%model(leaf_id)%layer%output
  end function forward_eval
!-------------------------------------------------------------------------------
  module function forward_eval_multi(this, input) result(output)
    !! Forward pass for evaluation
    !!
    !! Returns a newly allocated list of pointers, one per leaf layer output;
    !! the caller owns (and should deallocate) the list itself.
    implicit none

    ! Arguments
    class(network_type), intent(inout), target :: this
    !! Instance of network
    class(*), dimension(:,:), intent(in) :: input
    !! Input

    type(array_ptr_type), pointer :: output(:)
    !! Output

    ! Local variables
    integer :: l, leaf_id
    !! Loop index and model index of the current leaf

    call this%forward(input)

    allocate(output(size(this%leaf_vertices,1)))
    do l = 1, size(this%leaf_vertices,1)
       leaf_id = this%auto_graph%vertex(this%leaf_vertices(l))%id
       output(l)%array => this%model(leaf_id)%layer%output
    end do
  end function forward_eval_multi
!###############################################################################


!###############################################################################
  module subroutine update(this)
    !! Update the network
    !!
    !! Advances the optimiser iteration counter, gathers parameters and
    !! batch-averaged gradients into flat vectors, clips gradients, runs the
    !! optimiser, writes parameters back and resets gradients.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network

    ! Local variables
    real(real32), dimension(this%num_params) :: params, gradients
    !! Parameters and gradients
    integer :: l, i, s, start_idx, end_idx, seg
    !! Loop indices and segment bounds

    !---------------------------------------------------------------------------
    ! Increment optimiser iteration counter
    !---------------------------------------------------------------------------
    if(this%optimiser%lr_decay%iterate_per_epoch)then
       if(this%epoch.gt.this%optimiser%epoch)then
          this%optimiser%epoch = this%epoch
          this%optimiser%iter = this%optimiser%iter + 1
       end if
    else
       this%optimiser%iter = this%optimiser%iter + 1
    end if

    !---------------------------------------------------------------------------
    ! Get learnable parameters and gradients (using pre-computed layout)
    !---------------------------------------------------------------------------
    if(this%param_num_segments.gt.0 .and. &
         allocated(this%param_seg_layer))then
       do seg = 1, this%param_num_segments
          l = this%param_seg_layer(seg)
          i = this%param_seg_pidx(seg)
          start_idx = this%param_seg_start(seg)
          end_idx = this%param_seg_end(seg)
          select type(current => this%model(l)%layer)
          class is(learnable_layer_type)
             params(start_idx:end_idx) = current%params(i)%val(:,1)
             if(.not.associated(current%params(i)%grad))then
                call stop_program( &
                     "Gradient not allocated for parameters" &
                )
                ! Avoid dereferencing the missing gradient below
                return
             end if
             s = size(current%params(i)%grad%val,2)
             if(s.eq.1)then
                gradients(start_idx:end_idx) = &
                     current%params(i)%grad%val(:,1)
             else
                gradients(start_idx:end_idx) = &
                     sum(current%params(i)%grad%val, dim=2) / &
                     real(s, real32)
             end if
          end select
       end do
    else
       start_idx = 0
       end_idx = 0
       do l = 1, this%num_layers
          select type(current => this%model(l)%layer)
          class is(learnable_layer_type)
             do i = 1, size(current%params)
                start_idx = end_idx + 1
                end_idx = end_idx + size(current%params(i)%val, 1)
                params(start_idx:end_idx) = current%params(i)%val(:,1)
                if(.not.associated(current%params(i)%grad))then
                   call stop_program( &
                        "Gradient not allocated for parameters" &
                   )
                   ! Avoid dereferencing the missing gradient below
                   return
                end if
                select case(size(current%params(i)%grad%val,2))
                case(1)
                   gradients(start_idx:end_idx) = &
                        current%params(i)%grad%val(:,1)
                case default
                   gradients(start_idx:end_idx) = [ &
                        sum(current%params(i)%grad%val, dim=2) / &
                        real( &
                             size(current%params(i)%grad%val, dim=2), &
                             real32) &
                   ]
                end select
             end do
          end select
       end do
    end if
    call this%optimiser%clip_dict%apply(size(gradients),gradients)

    !---------------------------------------------------------------------------
    ! Update layers of learnable layer types
    !---------------------------------------------------------------------------
    call this%optimiser%minimise(params, gradients)

    ! Set params back using pre-computed layout
    if(this%param_num_segments.gt.0 .and. &
         allocated(this%param_seg_layer))then
       do seg = 1, this%param_num_segments
          l = this%param_seg_layer(seg)
          i = this%param_seg_pidx(seg)
          start_idx = this%param_seg_start(seg)
          end_idx = this%param_seg_end(seg)
          select type(current => this%model(l)%layer)
          class is(learnable_layer_type)
             current%params(i)%val(:,1) = params(start_idx:end_idx)
          end select
       end do
    else
       call this%set_params(params)
    end if

    call this%reset_gradients()
  end subroutine update
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  module subroutine nullify_graph(this)
    !! Nullify the input graph
    !!
    !! Calls nullify_graph on every layer of the model.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network

    ! Local variables
    integer :: l
    !! Loop index

    do l = 1, this%num_layers
       call this%model(l)%layer%nullify_graph()
    end do
  end subroutine nullify_graph
!###############################################################################


!###############################################################################
  module subroutine post_epoch_hook(this, epoch, loss, accuracy)
    !! Default epoch hook — no-op.
    !!
    !! Override in a derived type to add custom per-epoch behaviour.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    integer, intent(in) :: epoch
    !! Current epoch number
    real(real32), intent(in) :: loss
    !! Mean loss over the epoch
    real(real32), intent(in) :: accuracy
    !! Mean accuracy over the epoch

  end subroutine post_epoch_hook
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  module function save_input_to_network( this, input ) result(num_samples)
    !! Save input to network
    !!
    !! Stores the input either in this%input_array (array data) or
    !! this%input_graph (graph data) and returns the number of samples.
    !! Returns early without storing anything if no samples are found.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(*), dimension(..), intent(in) :: input
    !! Input

    integer :: num_samples
    !! Number of samples

    ! Local variables
    integer :: i, j, l, input_rank, num_inputs
    !! Loop indices, input rank and features per sample
    integer :: num_input_layers
    !! Number of input layers
    logical :: l_valid_rank_type
    !! Boolean whether rank type is valid
    character(256) :: err_msg
    !! Error message

    ! Record the rank up front so the final error message is always defined,
    ! including for rank-0/rank-1 inputs.
    input_rank = rank(input)

    num_samples = get_num_samples(this, input)
    if(num_samples.le.0) return
    num_input_layers = size(this%root_vertices, 1)

    ! Release any previously stored input
    if(allocated(this%input_array))then
       do i = 1, size(this%input_array, 1)
          do j = 1, size(this%input_array, 2)
             call this%input_array(i,j)%deallocate()
          end do
       end do
       deallocate(this%input_array)
    end if
    if(allocated(this%input_graph)) deallocate(this%input_graph)

    ! Determine the rank of the input and pre-allocate storage
    !---------------------------------------------------------------------------
    select rank(input)
    rank(0)
    rank(1)
    rank(2)
       select type(input)
       class is(array_type)
          num_inputs = size(input(1,1)%val, 1)
          allocate(this%input_array(size(input,1), size(input,2)))
          do i = 1, size(input,1)
             do j = 1, size(input,2)
                call this%input_array(i,j)%assign_shallow(input(i,j))
             end do
          end do
          return
       class default
          num_inputs = size(input) / num_samples
          allocate(this%input_array(1,1))
          call this%input_array(1,1)%allocate( &
               array_shape=[num_inputs, num_samples])
       end select
    rank default
       num_inputs = size(input) / num_samples
       allocate(this%input_array(1,1))
       call this%input_array(1,1)%allocate(array_shape=shape(input))
    end select

    l_valid_rank_type = .false.

    ! Process input based on its rank
    !---------------------------------------------------------------------------
    rank_select: select rank(input)
    rank(0)
       select type(input)
       type is(real)
          exit rank_select
       class default
          l_valid_rank_type = .true.
       end select
       if(num_input_layers.ne.1)then
          call stop_program( &
               "number of input arrays does not match expected number of &
               &input layers" &
          )
          return
       end if
       select type(input)
       class is(array_type)
          allocate(this%input_array(1,1))
          call handle_array_type(input, this%input_array(1,1), num_samples)
       type is(array_ptr_type)
          allocate(this%input_array( &
               size(input%array,1), size(input%array,2)))
          do i = 1, size(input%array,1)
             do j = 1, size(input%array,2)
                call handle_array_type( &
                     input%array(i,j), this%input_array(i,j), num_samples &
                )
             end do
          end do
       end select
    rank(1)
       select type(input)
       type is(real(real32))
          exit rank_select
       type is(graph_type)
          allocate(this%input_graph(num_input_layers, num_samples))
          this%input_graph(1,:) = input(:)
          return
       class default
          l_valid_rank_type = .true.
       end select
       if(size(input,1).ne.num_input_layers)then
          call stop_program( &
               "number of input arrays does not match expected number of &
               &input layers" &
          )
          return
       end if
       select type(input)
       class is(array_type)
          allocate(this%input_array(1,size(input,1)))
          do l = 1, size(input,1)
             call handle_array_type( &
                  input(l), this%input_array(1,l), num_samples)
          end do
       type is(array_ptr_type)
          ! TODO: support rank-1 array_ptr_type input (one row of
          ! input_array per array in each element)
          call stop_program("Use of array_ptr_type with rank 1 input not yet supported")
          return
       end select
    rank(2)
       select type(input)
       type is(real(real32))
          this%input_array(1,1)%val = reshape(input, [num_inputs, num_samples])
          l_valid_rank_type = .true.
       type is(graph_type)
          num_samples = size(input, dim=2)
          allocate(this%input_graph(num_input_layers, num_samples))
          this%input_graph(:,:) = input(:,:)
          return
       type is(array_type)
          ! Handled (and returned from) in the first select rank
          call stop_program("SHOULD NOT GET HERE")
          this%input_array = input
          l_valid_rank_type = .true.
       end select
    rank(3)
       select type(input)
       type is(real(real32))
          call this%input_array(1,1)%set(input)
          l_valid_rank_type = .true.
       end select
    rank(4)
       select type(input)
       type is(real(real32))
          call this%input_array(1,1)%set(input)
          l_valid_rank_type = .true.
       end select
    rank(5)
       select type(input)
       type is(real(real32))
          call this%input_array(1,1)%set(input)
          l_valid_rank_type = .true.
       end select
    end select rank_select

    if(.not.l_valid_rank_type)then
       write(err_msg,'("Unknown input type for rank ",I0)') input_rank
       call stop_program(err_msg)
       return
    end if

  contains

    function get_num_samples(network, input) result(num_samples)
      !! Get the number of samples in the input
      !!
      !! For plain real arrays the sample count is the extent of the last
      !! dimension; for array_type data it is the second dimension of val,
      !! unless the root layer takes graph input.
      implicit none

      ! Arguments
      type(network_type), intent(in) :: network
      !! Instance of network
      class(*), dimension(..), intent(in) :: input
      !! Input

      integer :: num_samples
      !! Number of samples

      ! Local variables
      integer :: layer_id
      !! Layer ID
      logical :: use_graph_input
      !! Whether to use graph input

      num_samples = 0
      layer_id = network%auto_graph%vertex(network%root_vertices(1))%id
      use_graph_input = network%model(layer_id)%layer%use_graph_input
      select rank(input)
      rank(0)
         select type(input)
         class is(array_type)
            num_samples = size(input%val, 2)
         class is(array_ptr_type)
            num_samples = size(input%array(1,1)%val, 2)
         class default
            call stop_program("Unknown input type in get_num_samples for rank 0")
            return
         end select
      rank(1)
         select type(input)
         class is(array_type)
            if(use_graph_input)then
               num_samples = size(input)
            else
               num_samples = size(input(1)%val, 2)
            end if
         class is(array_ptr_type)
            if(use_graph_input)then
               num_samples = size(input(1)%array, 2)
            else
               num_samples = size(input(1)%array(1,1)%val, 2)
            end if
         class is(graph_type)
            num_samples = size(input, dim=1)
         type is(real)
            num_samples = size(input, rank(input))
         class default
            call stop_program("Unknown input type in get_num_samples for rank 1")
            return
         end select
      rank(2)
         select type(input)
         class is(array_type)
            if(use_graph_input)then
               num_samples = size(input, 2)
            else
               num_samples = size(input(1,1)%val, 2)
            end if
         class is(graph_type)
            num_samples = size(input, dim=2)
         type is(real)
            num_samples = size(input, rank(input))
         class default
            call stop_program("Unknown input type in get_num_samples for rank 2")
            return
         end select
      rank(3)
         select type(input)
         type is(real)
            num_samples = size(input, rank(input))
         class default
            call stop_program("Unknown input type in get_num_samples for rank 3")
            return
         end select
      rank(4)
         select type(input)
         type is(real)
            num_samples = size(input, rank(input))
         class default
            call stop_program("Unknown input type in get_num_samples for rank 4")
            return
         end select
      rank(5)
         select type(input)
         type is(real)
            num_samples = size(input, rank(input))
         class default
            call stop_program("Unknown input type in get_num_samples for rank 5")
            return
         end select
      rank default
         call stop_program("Unknown input rank in get_num_samples")
         return
      end select
    end function get_num_samples

    subroutine handle_array_type(input, output, num_samples)
      !! Handle array type input
      !!
      !! Copies input into output as a 2D [features, num_samples] array after
      !! checking the sample count matches.

      ! Arguments
      class(array_type), intent(in) :: input
      !! Input
      type(array_type), intent(out) :: output
      !! Output
      integer, intent(in) :: num_samples
      !! Number of samples

      if(size(input%val,2).ne.num_samples)then
         call stop_program("number of samples in input arrays do not match")
         return
      end if

      call output%allocate( array_shape = &
           [ product(input%shape(1:input%rank)), num_samples ] &
      )
      output%val = input%val
    end subroutine handle_array_type

  end function save_input_to_network
!-------------------------------------------------------------------------------
  module subroutine save_output_to_network( this, output )
    !! Save output to network
    !!
    !! Stores the expected output in this%expected_array with gradient
    !! tracking disabled. Graph data uses row 1 for vertex features and row 2
    !! for edge features (left unallocated when a graph has none).
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(*), dimension(:,:), intent(in) :: output
    !! Output

    ! Local variables
    integer :: i, j, s
    !! Loop indices

    ! Release any previously stored expected output
    if(allocated(this%expected_array))then
       do i = 1, size(this%expected_array, 1)
          do j = 1, size(this%expected_array, 2)
             call this%expected_array(i,j)%deallocate()
          end do
       end do
       deallocate(this%expected_array)
    end if

    select type(output)
    type is(graph_type)
       allocate(this%expected_array(2,size(output,2)))
       do s = 1, size(output,2)
          if(this%expected_array(1,s)%allocated) &
               call this%expected_array(1,s)%deallocate()
          if(this%expected_array(2,s)%allocated) &
               call this%expected_array(2,s)%deallocate()
          call this%expected_array(1,s)%allocate( &
               array_shape = [ &
                    output(1,s)%num_vertex_features, output(1,s)%num_vertices &
               ] &
          )
          call this%expected_array(1,s)%zero_grad()
          call this%expected_array(1,s)%set_requires_grad(.false.)
          call this%expected_array(1,s)%set( output(1,s)%vertex_features )
          this%expected_array(1,s)%is_temporary = .false.
          if(output(1,s)%num_edge_features.le.0) cycle
          call this%expected_array(2,s)%allocate( &
               array_shape = [ &
                    output(1,s)%num_edge_features, output(1,s)%num_edges &
               ] &
          )
          call this%expected_array(2,s)%set_requires_grad(.false.)
          call this%expected_array(2,s)%set( output(1,s)%edge_features )
          this%expected_array(2,s)%is_temporary = .false.
       end do
    class is(array_type)
       allocate(this%expected_array(size(output,1),size(output,2)))
       do s = 1, size(output,2)
          do i = 1, size(output,1)
             if(this%expected_array(i,s)%allocated) &
                  call this%expected_array(i,s)%deallocate()
             call this%expected_array(i,s)%allocate( &
                  array_shape = [ &
                       output(i,s)%shape, size(output(i,s)%val,2) &
                  ] &
             )
             call this%expected_array(i,s)%set_requires_grad(.false.)
             call this%expected_array(i,s)%set( output(i,s)%val )
             this%expected_array(i,s)%is_temporary = .false.
          end do
       end do
    type is(real)
       allocate(this%expected_array(1,1))
       if(this%expected_array(1,1)%allocated) &
            call this%expected_array(1,1)%deallocate()
       call this%expected_array(1,1)%allocate( &
            array_shape = [ size(output,1), size(output,2) ] &
       )
       call this%expected_array(1,1)%set_requires_grad(.false.)
       call this%expected_array(1,1)%set( output )
       this%expected_array(1,1)%is_temporary = .false.
    type is(integer)
       allocate(this%expected_array(1,1))
       if(this%expected_array(1,1)%allocated) &
            call this%expected_array(1,1)%deallocate()
       call this%expected_array(1,1)%allocate( &
            array_shape = [ size(output,1), size(output,2) ] &
       )
       call this%expected_array(1,1)%set_requires_grad(.false.)
       this%expected_array(1,1)%val = real(output, real32)
       this%expected_array(1,1)%is_temporary = .false.
    class default
       call stop_program("output type not supported in training")
    end select
  end subroutine save_output_to_network
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
!##############################################################################!
!###############################################################################
  module subroutine train( &
       this, input, output, num_epochs, batch_size, &
       plateau_threshold, shuffle_batches, batch_print_step, verbose, &
       print_precision, scientific_print, early_stopping, &
       val_input, val_output &
  )
    !! Train the network
    !!
    !! This function trains the network on the input data for a number of
    !! epochs. The input data is split into batches of size batch_size and
    !! the network is trained on each batch. The network is trained using
    !! the optimiser specified in the network object.
    !!
    !! When both val_input and val_output are supplied, the network is also
    !! evaluated on the validation data (sample by sample, in inference mode)
    !! at the end of every epoch; the training expected output, input and
    !! batch size are restored afterwards.
    use athena__tools_infile, only: stop_check
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(*), dimension(..), intent(in) :: input
    !! Input data
    class(*), dimension(:,:), intent(in) :: output
    !! Output data
    integer, intent(in) :: num_epochs
    !! Number of epochs
    integer, optional, intent(in) :: batch_size
    !! Batch size
    real(real32), optional, intent(in) :: plateau_threshold
    !! Plateau threshold
    logical, optional, intent(in) :: shuffle_batches
    !! Shuffle batches
    integer, optional, intent(in) :: batch_print_step
    !! Batch print step (values below 1 are treated as 1)
    integer, optional, intent(in) :: verbose
    !! Verbosity level
    integer, optional, intent(in) :: print_precision
    !! Number of decimal places to print for training metrics
    logical, optional, intent(in) :: scientific_print
    !! Whether to print training metrics in scientific notation
    logical, optional, intent(in) :: early_stopping
    !! Whether to stop training early if convergence is detected
    class(*), dimension(..), optional, intent(in) :: val_input
    !! Validation input data
    class(*), dimension(:,:), optional, intent(in) :: val_output
    !! Validation expected output data

    ! Training parameters
    real(real32) :: batch_loss, batch_accuracy, avg_loss, avg_accuracy
    !! Loss and accuracy

    ! learning parameters
    integer :: num_samples
    !! Number of training samples
    integer :: num_batches
    !! Number of batches
    integer :: converged
    !! Convergence flag
    integer :: window_width
    !! Length of convergence check window
    integer :: verbose_
    !! Verbosity level
    real(real32) :: plateau_threshold_
    !! Plateau threshold
    logical :: shuffle_batches_
    !! Shuffle batches
    logical :: use_accuracy
    !! Whether accuracy evaluation is available
    logical :: early_stopping_
    !! Whether to stop training early if convergence is detected

    ! Printing parameters
    integer :: batch_print_step_
    !! Batch print step
    integer :: print_precision_
    !! Number of decimal places for metric output
    logical :: scientific_print_
    !! Whether to print metrics in scientific notation
    character(len=64) :: loss_str, accuracy_str
    !! Formatted metrics for printing
    character(len=128) :: val_str
    !! Formatted validation metrics for printing

    ! Training loop variables
    integer :: epoch, batch, start_index, end_index
    !! Loop index
    integer, allocatable, dimension(:) :: batch_order
    !! Batch order
    integer :: i, j
    !! Loop index
    integer :: tmp_batch_size
    !! Temporary integer to store batch size during validation
    integer :: current_batch_size, target_batch_size
    !! Actual batch size for the current batch and the target batch size
    integer :: active_batch_size
    !! Batch size the network is currently configured for
    logical, allocatable :: mode_store(:)
    !! Storage for inference mode booleans
    class(*), allocatable :: data_poly(:,:)
    type(array_type), pointer :: loss
    !! Loss of the current batch/sample (nullified at entry, not at
    !! declaration, to avoid an implicit SAVE attribute)

    ! Validation variables
    logical :: use_validation
    !! Whether validation data is provided
    integer :: val_num_samples
    !! Number of validation samples
    integer :: val_sample
    !! Loop index for validation
    real(real32) :: val_loss, val_accuracy, val_loss_sum, val_accuracy_sum
    !! Validation loss and accuracy
#ifdef __INTEL_COMPILER
    type(array_type), pointer :: saved_expected_array(:,:)
    !! Saved training expected output (for restoring after validation)
#else
    type(array_type), dimension(:,:), allocatable :: saved_expected_array
    !! Saved training expected output (for restoring after validation)
#endif
#ifdef _OPENMP
    type(network_type) :: this_copy
    !! Copy of network
#endif


    nullify(loss)
#ifdef __INTEL_COMPILER
    nullify(saved_expected_array)
#endif

    !---------------------------------------------------------------------------
    ! Check loss and accuracy methods are set
    !---------------------------------------------------------------------------
    if(.not.allocated(this%loss))then
       call stop_program("loss method not set")
       return
    end if
    use_accuracy = associated(this%get_accuracy)
    accuracy_str = ""
    val_str = ""


    !---------------------------------------------------------------------------
    ! Check validation data
    !---------------------------------------------------------------------------
    use_validation = present(val_input) .and. present(val_output)
    if(present(val_input) .neqv. present(val_output))then
       call stop_program( &
            "both val_input and val_output must be provided for validation" &
       )
       return
    end if


    !---------------------------------------------------------------------------
    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    verbose_ = 0
    batch_print_step_ = 20
    plateau_threshold_ = 0._real32
    shuffle_batches_ = .true.
    scientific_print_ = .false.
    print_precision_ = 3
    early_stopping_ = .true.
    if(present(plateau_threshold)) plateau_threshold_ = plateau_threshold
    if(present(shuffle_batches)) shuffle_batches_ = shuffle_batches
    ! A step of zero would make mod(batch, step) undefined
    if(present(batch_print_step)) batch_print_step_ = max(batch_print_step, 1)
    if(present(verbose)) verbose_ = verbose
    if(present(print_precision)) print_precision_ = max(print_precision, 0)
    if(present(scientific_print)) scientific_print_ = scientific_print
    if(present(batch_size)) this%batch_size = batch_size
    if(present(early_stopping)) early_stopping_ = early_stopping


    !---------------------------------------------------------------------------
    ! Initialise monitoring variables
    !---------------------------------------------------------------------------
    window_width = max(ceiling(500._real32/this%batch_size),1)
    do i = 1, size(this%metrics,dim=1)
       this%metrics(i)%window_width = window_width
    end do


    !---------------------------------------------------------------------------
    ! Save input and output to network
    !---------------------------------------------------------------------------
    num_samples = this%save_input( input )
    call this%save_output( output )
    if(size(output,2).ne.num_samples.and.this%use_graph_output)then
       call stop_program("number of samples in input and output do not match")
       return
    end if


    !---------------------------------------------------------------------------
    ! If parallel, initialise slices
    !---------------------------------------------------------------------------
    num_batches = (num_samples + this%batch_size - 1) / this%batch_size
    allocate(batch_order(num_batches))
    do batch = 1, num_batches
       batch_order(batch) = batch
    end do


    !---------------------------------------------------------------------------
    ! Set/reset batch size for training
    !---------------------------------------------------------------------------
    call this%set_batch_size(this%batch_size)
    target_batch_size = this%batch_size
    active_batch_size = target_batch_size


    !---------------------------------------------------------------------------
    ! Enable training mode
    !---------------------------------------------------------------------------
    call this%set_training_mode(mode_store)
    epoch_loop: do epoch = 1, num_epochs
       this%epoch = epoch
       !------------------------------------------------------------------------
       ! Shuffle batch order at the start of each epoch
       !------------------------------------------------------------------------
       if(shuffle_batches_)then
          call shuffle(batch_order)
       end if

       avg_loss     = 0._real32
       avg_accuracy = 0._real32

       !------------------------------------------------------------------------
       ! Batch loop
       ! ... split data up into minibatches for training
       !------------------------------------------------------------------------
       batch_loop: do batch = 1, num_batches

          ! Set batch start and end index
          !---------------------------------------------------------------------
          ! Slice indices use the fixed target batch size so they stay correct
          ! after the network has been resized for a shorter batch.
          start_index = (batch_order(batch) - 1) * target_batch_size + 1
          end_index = min(batch_order(batch) * target_batch_size, num_samples)
          current_batch_size = end_index - start_index + 1
          ! Resize whenever the network is not already configured for this
          ! batch: the short batch can appear anywhere once batches are
          ! shuffled, and the following full batch must restore the size.
          if(current_batch_size.ne.active_batch_size)then
             call this%set_batch_size(current_batch_size)
             active_batch_size = current_batch_size
          end if


          ! Forward pass
          !---------------------------------------------------------------------
          select case(this%use_graph_input)
          case(.true.)
             data_poly = get_sample( &
                  this%input_graph, start_index, end_index, current_batch_size &
             )
          case default
             data_poly = get_sample( &
                  this%input_array, start_index, end_index, current_batch_size, &
                  as_graph = .false. &
             )
          end select
          call this%forward(data_poly)
          deallocate(data_poly)


          ! Backward pass
          !---------------------------------------------------------------------
          loss => this%loss_eval(start_index, end_index)
          loss%is_temporary = .false.
          call loss%grad_reverse(reset_graph=.true.)


          ! Compute loss and accuracy (for monitoring)
          !---------------------------------------------------------------------
          batch_loss = sum(loss%val)
          batch_accuracy = 0._real32
          if(use_accuracy)then
             batch_accuracy = this%accuracy_eval(output, start_index, end_index)
          end if


          ! Average metric over batch size and store
          !---------------------------------------------------------------------
          avg_loss = avg_loss + batch_loss
          if(use_accuracy)then
             avg_accuracy = avg_accuracy + batch_accuracy
          end if


          ! Update weights and biases using optimisation algorithm
          !---------------------------------------------------------------------
          call this%update()
          call loss%nullify_graph()
          deallocate(loss)
          nullify(loss)


          ! Print batch results
          !---------------------------------------------------------------------
          if(abs(verbose_).gt.0.and. &
               (batch.eq.1.or.mod(batch,batch_print_step_).eq.0))then
             loss_str = format_training_real( &
                  avg_loss / real(batch, real32), print_precision_, &
                  scientific_print_ &
             )
             if(use_accuracy)then
                accuracy_str = ", accuracy=" // trim(format_training_real( &
                     avg_accuracy / real(batch, real32), &
                     print_precision_, scientific_print_ &
                ))
             end if
             write(6,'("epoch=",I0,", batch=",I0,&
                  &", lr=",ES0.2,", loss=",A,A)' &
             ) &
                  this%epoch, batch, &
                  this%optimiser%lr_decay%get_lr( &
                       this%optimiser%learning_rate, this%optimiser%iter &
                  ), &
                  trim(loss_str), trim(accuracy_str)
          end if


          ! Check for user-name stop file
          !---------------------------------------------------------------------
          if(stop_check())then
             write(0,*) "STOPCAR ENCOUNTERED"
             write(0,*) "Exiting training loop..."
             exit epoch_loop
          end if

       end do batch_loop

       call this%metrics(1)%append(avg_loss / real(num_batches, real32))
       if(use_accuracy)then
          call this%metrics(2)%append(avg_accuracy / real(num_batches, real32))
       end if


       !------------------------------------------------------------------------
       ! Validation evaluation at end of epoch
       !------------------------------------------------------------------------
       val_str = ""
       if(use_validation)then
#ifdef __INTEL_COMPILER
          ! Save training expected output. `ifx` crashes on the allocatable
          ! local declaration used with `move_alloc`, so keep an explicit
          ! pointer-backed copy here instead.
          if(allocated(this%expected_array))then
             allocate(saved_expected_array( &
                  size(this%expected_array, 1), size(this%expected_array, 2) &
             ))
             do i = 1, size(this%expected_array, 1)
                do j = 1, size(this%expected_array, 2)
                   call saved_expected_array(i,j)%allocate( &
                        source=this%expected_array(i,j) &
                   )
                   call this%expected_array(i,j)%deallocate()
                end do
             end do
             deallocate(this%expected_array)
          else
             nullify(saved_expected_array)
          end if
#else
          call move_alloc(this%expected_array, saved_expected_array)
#endif

          ! Save validation output to network
          call this%save_output( val_output )

          ! Save validation input
          val_num_samples = this%save_input( val_input )

          ! Set batch size to 1 and enable inference mode
          call this%set_batch_size(1)
          call this%set_inference_mode()

          ! Evaluate validation loss and accuracy
          val_loss_sum = 0._real32
          val_accuracy_sum = 0._real32
          do val_sample = 1, val_num_samples
             select case(this%use_graph_input)
             case(.true.)
                data_poly = get_sample( &
                     this%input_graph, val_sample, val_sample, 1 &
                )
             case default
                data_poly = get_sample_array( &
                     this%input_array, val_sample, val_sample, 1, &
                     as_graph = .false. &
                )
             end select
             call this%forward(data_poly)
             deallocate(data_poly)

             loss => this%loss_eval(val_sample, val_sample)
             val_loss_sum = val_loss_sum + sum(loss%val)
             call loss%nullify_graph()
             deallocate(loss)
             nullify(loss)

             if(use_accuracy)then
                val_accuracy_sum = val_accuracy_sum + &
                     this%accuracy_eval(val_output, val_sample, val_sample)
             end if
          end do
          val_loss = val_loss_sum / real(val_num_samples, real32)
          val_accuracy = val_accuracy_sum / real(val_num_samples, real32)

          ! Build validation print string
          val_str = ", val_loss=" // trim(format_training_real( &
               val_loss, print_precision_, scientific_print_ &
          ))
          if(use_accuracy)then
             val_str = trim(val_str) // ", val_accuracy=" // &
                  trim(format_training_real( &
                       val_accuracy, print_precision_, scientific_print_ &
                  ))
          end if

          ! Restore training expected output
          if(allocated(this%expected_array))then
             do i = 1, size(this%expected_array, 1)
                do j = 1, size(this%expected_array, 2)
                   call this%expected_array(i,j)%deallocate()
                end do
             end do
             deallocate(this%expected_array)
          end if
#ifdef __INTEL_COMPILER
          ! `ifx` does not support `move_alloc` in this context
          if(associated(saved_expected_array))then
             allocate(this%expected_array( &
                  size(saved_expected_array, 1), size(saved_expected_array, 2) &
             ))
             do i = 1, size(saved_expected_array, 1)
                do j = 1, size(saved_expected_array, 2)
                   call this%expected_array(i,j)%allocate( &
                        source=saved_expected_array(i,j) &
                   )
                   call saved_expected_array(i,j)%deallocate()
                end do
             end do
             deallocate(saved_expected_array)
             nullify(saved_expected_array)
          end if
#else
          call move_alloc(saved_expected_array, this%expected_array)
#endif

          ! Restore training input
          num_samples = this%save_input( input )

          ! Restore training batch size and inference mode
          call this%set_batch_size(target_batch_size)
          active_batch_size = target_batch_size
          call this%set_training_mode()
       end if


       ! Print epoch summary results
       !------------------------------------------------------------------------
       ! verbose_ == 0: a single summary line per epoch that also carries any
       ! validation metrics. verbose_ > 0: the batch lines already report the
       ! training loss, so only the validation metrics are printed here.
       if(verbose_.eq.0)then
          loss_str = format_training_real( &
               this%metrics(1)%val, print_precision_, scientific_print_ &
          )
          if(use_accuracy)then
             accuracy_str = ", accuracy=" // trim(format_training_real( &
                  this%metrics(2)%val, print_precision_, scientific_print_ &
             ))
          end if
          write(6,'("epoch=",I0,&
               &", lr=",ES0.2,", loss=",A,A,A)' &
          ) &
               this%epoch, &
               this%optimiser%lr_decay%get_lr( &
                    this%optimiser%learning_rate, this%optimiser%iter &
               ), &
               trim(loss_str), trim(accuracy_str), trim(val_str)
       elseif(use_validation.and.verbose_.gt.0)then
          write(6,'("epoch=",I0,A)') this%epoch, trim(val_str)
       end if


       !------------------------------------------------------------------------
       ! Per-epoch callback (e.g. W&B logging via wandb_network_type)
       !------------------------------------------------------------------------
       call this%post_epoch_hook( &
            this%epoch, &
            this%metrics(1)%val, &
            this%metrics(2)%val &
       )


       !------------------------------------------------------------------------
       ! Check for convergence and stop early if enabled
       ! When validation data is provided, check validation loss for plateau
       !------------------------------------------------------------------------
       if(early_stopping_)then
          if(use_validation)then
             if(val_loss .lt. plateau_threshold_ .and. &
                  plateau_threshold_ .gt. 0._real32) exit epoch_loop
          else
             call this%metrics(1)%check(plateau_threshold_, converged)
             if(converged.ne.0) exit epoch_loop
          end if
          if(use_accuracy)then
             call this%metrics(2)%check(plateau_threshold_, converged)
             if(converged.ne.0) exit epoch_loop
          end if
       end if

    end do epoch_loop


    ! Final epoch metrics
    if(use_accuracy)then
       this%accuracy_val = this%metrics(2)%val
    else
       this%accuracy_val = 0._real32
    end if
    this%loss_val = this%metrics(1)%val


    !---------------------------------------------------------------------------
    ! Restore training/inference mode
    !---------------------------------------------------------------------------
    call this%restore_mode(mode_store)

  end subroutine train
!###############################################################################


!###############################################################################
  function format_training_real(value, decimals, scientific) result(formatted)
    !! Format a training metric with a configurable number of decimal places.
    !!
    !! Uses F0.d for fixed notation and ESw.dE2 for scientific notation; the
    !! decimal count is clamped to the range [0, 30] and the result is
    !! left-adjusted.
    implicit none

    ! Arguments
    real(real32), intent(in) :: value
    !! Value to format
    integer, intent(in) :: decimals
    !! Number of decimal places
    logical, intent(in) :: scientific
    !! Whether to use scientific notation
    character(len=64) :: formatted
    !! Formatted string

    ! Local variables
    character(len=16) :: fmt
    !! Internal write format
    integer :: width_
    !! Field width for scientific formatting
    integer :: decimals_
    !! Clamped decimal count


    decimals_ = min(max(decimals, 0), 30)
    if(scientific)then
       ! sign + leading digit + point + decimals + 4-character exponent
       width_ = max(decimals_ + 8, 14)
       write(fmt,'("(ES",I0,".",I0,"E2)")') width_, decimals_
    else
       write(fmt,'("(F0.",I0,")")') decimals_
    end if

    write(formatted, fmt) value
    formatted = adjustl(formatted)

  end function format_training_real
!###############################################################################


!###############################################################################
  module subroutine test( &
       this, input, output, verbose &
  )
    !! Test the network
    !!
    !! Evaluates the network one sample at a time in inference mode. The
    !! summed loss/accuracy are left in this%metrics(1:2)%val and the
    !! per-sample averages are stored in this%loss_val and this%accuracy_val.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(*), dimension(..), intent(in) :: input
    !! Input data
    class(*), dimension(:,:), intent(in) :: output
    !! Output data
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: l, sample, num_samples
    !! Loop index
    integer :: verbose_
    !! Verbosity level
    logical :: use_accuracy
    !! Whether accuracy evaluation is available
    real(real32) :: acc_val, loss_val
    !! Loss and accuracy
    class(*), allocatable, dimension(:,:) :: data_poly
    !! Polymorphic data array
    type(array_type), pointer :: loss
    !! Loss (nullified at entry to avoid an implicit SAVE attribute)
    logical, allocatable :: mode_store(:)
    !! Storage for inference mode booleans


    nullify(loss)

    !---------------------------------------------------------------------------
    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    if(present(verbose))then
       verbose_ = verbose
    else
       verbose_ = 0
    end if
    use_accuracy = associated(this%get_accuracy)

    do l = 1, size(this%metrics,dim=1)
       this%metrics(l)%val = 0._real32
    end do
    loss_val = 0._real32
    acc_val = 0._real32
    num_samples = this%save_input( input )


    !---------------------------------------------------------------------------
    ! Reset batch size for testing
    !---------------------------------------------------------------------------
    call this%set_batch_size(1)


    !---------------------------------------------------------------------------
    ! Enable inference mode
    !---------------------------------------------------------------------------
    call this%set_inference_mode(mode_store)


    !---------------------------------------------------------------------------
    ! Testing loop
    !---------------------------------------------------------------------------
    test_loop1: do sample = 1, num_samples

       ! Forward pass
       !------------------------------------------------------------------------
       select case(this%use_graph_input)
       case(.true.)
          data_poly = get_sample( &
               this%input_graph, sample, sample, 1 &
          )
       case default
          data_poly = get_sample_array( &
               this%input_array, sample, sample, 1, &
               as_graph = .false. &
          )
       end select
       call this%forward(data_poly)
       deallocate(data_poly)


       ! Compute loss and accuracy (for monitoring)
       !------------------------------------------------------------------------
       loss => this%loss_eval(sample, sample)
       loss_val = sum(loss%val)
       call loss%nullify_graph()
       deallocate(loss)
       nullify(loss)
       if(use_accuracy)then
          acc_val = this%accuracy_eval(output, sample, sample)
          this%metrics(2)%val = this%metrics(2)%val + acc_val
       end if
       this%metrics(1)%val = this%metrics(1)%val + loss_val

    end do test_loop1


    ! Normalise metrics by number of samples
    !---------------------------------------------------------------------------
    if(use_accuracy)then
       this%accuracy_val = this%metrics(2)%val / real(num_samples, real32)
    else
       this%accuracy_val = 0._real32
    end if
    this%loss_val = this%metrics(1)%val / real(num_samples, real32)


    !---------------------------------------------------------------------------
    ! Restore training/inference mode
    !---------------------------------------------------------------------------
    call this%restore_mode(mode_store)

  end subroutine test
!###############################################################################


!###############################################################################
  module function predict_real( &
       this, input, verbose &
  ) result(output)
    !! Predict the output for a real-valued input
    !!
    !! The last dimension of input is treated as the batch (sample) dimension.
    !! The whole batch is passed through the network in one forward pass and
    !! the first output array of the first leaf vertex is returned.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    real(real32), dimension(..), intent(in) :: input
    !! Input
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    real(real32), dimension(:,:), allocatable :: output
    !! Output
    integer :: verbose_, batch_size
    !! Verbosity level and batch size
    logical, allocatable :: mode_store(:)
    !! Storage for inference mode booleans


    !---------------------------------------------------------------------------
    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    verbose_ = 0
    if(present(verbose)) verbose_ = verbose

    ! The batch (sample) dimension is the last dimension of the input
    batch_size = size(input, dim=rank(input))


    !---------------------------------------------------------------------------
    ! Reset batch size for testing
    !---------------------------------------------------------------------------
    call this%set_batch_size(batch_size)


    !---------------------------------------------------------------------------
    ! Enable inference mode
    !---------------------------------------------------------------------------
    call this%set_inference_mode(mode_store)


    !---------------------------------------------------------------------------
    ! Predict
    !---------------------------------------------------------------------------
    call this%forward(get_sample(input, 1, batch_size, batch_size))
    output = this%model(this%leaf_vertices(1))%layer%output(1,1)%val


    !---------------------------------------------------------------------------
    ! Restore training/inference mode
    !---------------------------------------------------------------------------
    call this%restore_mode(mode_store)

  end function predict_real
!###############################################################################


!###############################################################################
  module function predict_graph1d( this, input, verbose ) result(output)
    !! Predict the output for a graph input
    !!
    !! Returns one graph per (leaf vertex, sample). Edge features are taken
    !! from the leaf layer when it produces them, otherwise copied from input.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    type(graph_type), dimension(:), intent(in) :: input
    !! Input graph
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: l, s
    !! Loop index
    type(graph_type), dimension(size(this%leaf_vertices),size(input)) :: output
    !! Output graph
    integer :: verbose_, batch_size
    !! Verbosity level (assigned at runtime; an initialiser would imply SAVE)
    logical, allocatable :: mode_store(:)
    !! Storage for inference mode booleans


    !---------------------------------------------------------------------------
    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    verbose_ = 0
    if(present(verbose)) verbose_ = verbose


    !---------------------------------------------------------------------------
    ! Reset batch size for testing
    !---------------------------------------------------------------------------
    batch_size = size(input)
    call this%set_batch_size(batch_size)


    !---------------------------------------------------------------------------
    ! Enable inference mode
    !---------------------------------------------------------------------------
    call this%set_inference_mode(mode_store)


    !---------------------------------------------------------------------------
    ! Predict
    !---------------------------------------------------------------------------
    call this%forward(get_sample(input, 1, batch_size, batch_size))

    do l = 1, size(this%leaf_vertices)
       do s = 1, batch_size
          output(l,s)%num_vertices = input(s)%num_vertices
          output(l,s)%num_edges = input(s)%num_edges
          output(l,s)%num_vertex_features = this%model( &
               this%leaf_vertices(l) &
          )%layer%output_shape(1)
          output(l,s)%num_edge_features = this%model( &
               this%leaf_vertices(l) &
          )%layer%output_shape(2)
          output(l,s)%vertex_features = this%model( &
               this%leaf_vertices(l) &
          )%layer%output(1,s)%val
          if(size(this%model(this%leaf_vertices(l))%layer%output,1).eq.1)then
             output(l,s)%edge_features = input(s)%edge_features
          else
             output(l,s)%edge_features = this%model( &
                  this%leaf_vertices(l) &
             )%layer%output(2,s)%val
          end if
       end do
    end do


    !---------------------------------------------------------------------------
    ! Restore training/inference mode
    !---------------------------------------------------------------------------
    call this%restore_mode(mode_store)

  end function predict_graph1d
!-------------------------------------------------------------------------------
  module function predict_graph2d( this, input, verbose ) result(output)
    !! Predict the output for a graph input
    !!
    !! Only the first row of input (input(1,:)) supplies vertex/edge counts and
    !! fallback edge features; the batch size is size(input, 2).
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    type(graph_type), dimension(:,:), intent(in) :: input
    !! Input graph
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: l, s
    !! Loop index
    type(graph_type), dimension(size(this%leaf_vertices),size(input,dim=2)) :: &
         output
    !! Output graph
    integer :: verbose_, batch_size
    !! Verbosity level (assigned at runtime; an initialiser would imply SAVE)
    logical, allocatable :: mode_store(:)
    !! Storage for inference mode booleans


    !---------------------------------------------------------------------------
    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    verbose_ = 0
    if(present(verbose)) verbose_ = verbose


    !---------------------------------------------------------------------------
    ! Reset batch size for testing
    !---------------------------------------------------------------------------
    batch_size = size(input, 2)
    call this%set_batch_size(batch_size)


    !---------------------------------------------------------------------------
    ! Enable inference mode
    !---------------------------------------------------------------------------
    call this%set_inference_mode(mode_store)


    !---------------------------------------------------------------------------
    ! Predict
    !---------------------------------------------------------------------------
    call this%forward(get_sample(input, 1, batch_size, batch_size))

    do l = 1, size(this%leaf_vertices)
       do s = 1, batch_size
          output(l,s)%num_vertices = input(1,s)%num_vertices
          output(l,s)%num_edges = input(1,s)%num_edges
          output(l,s)%num_vertex_features = this%model( &
               this%leaf_vertices(l) &
          )%layer%output_shape(1)
          output(l,s)%num_edge_features = this%model( &
               this%leaf_vertices(l) &
          )%layer%output_shape(2)
          output(l,s)%vertex_features = this%model( &
               this%leaf_vertices(l) &
          )%layer%output(1,s)%val
          if(size(this%model(this%leaf_vertices(l))%layer%output,1).eq.1)then
             output(l,s)%edge_features = input(1,s)%edge_features
          else
             output(l,s)%edge_features = this%model( &
                  this%leaf_vertices(l) &
             )%layer%output(2,s)%val
          end if
       end do
    end do


    !---------------------------------------------------------------------------
    ! Restore training/inference mode
    !---------------------------------------------------------------------------
    call this%restore_mode(mode_store)

  end function predict_graph2d
!###############################################################################


!###############################################################################
  module function predict_array_from_real( this, input, output_as_array, verbose ) &
       result(output)
    !! Predict the output for a generic input, returned as array_type values
    !!
    !! Only output_as_array = .true. is supported. The result is a copy of the
    !! output of the first leaf vertex.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(*), dimension(..), intent(in) :: input
    !! Input data
    logical, intent(in) :: output_as_array
    !! Whether to output as array
    integer, intent(in), optional :: verbose
    !! Verbosity level
    type(array_type), dimension(:,:), allocatable :: output
    !! Predicted output

    ! Local variables
    integer :: s, i
    !! Loop index
    integer :: num_samples
    !! Number of samples
    integer :: verbose_
    !! Verbosity level
    logical, allocatable :: mode_store(:)
    !! Storage for inference mode booleans


    !---------------------------------------------------------------------------
    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    if(present(verbose))then
       verbose_ = verbose
    else
       verbose_ = 0
    end if
    if(.not.output_as_array)then
       call stop_program("predict_array_from_real: output_as_array must be true")
       return
    end if


    !---------------------------------------------------------------------------
    ! Set number of samples for predicting
    !---------------------------------------------------------------------------
    num_samples = this%save_input( input )
    ! call this%set_batch_size(num_samples)


    !---------------------------------------------------------------------------
    ! Enable inference mode
    !---------------------------------------------------------------------------
    call this%set_inference_mode(mode_store)


    !---------------------------------------------------------------------------
    ! Forward pass
    !---------------------------------------------------------------------------
    select case(this%use_graph_input)
    case(.true.)
       call this%forward(this%input_graph)
    case default
       call this%forward(this%input_array)
    end select


    !---------------------------------------------------------------------------
    ! Allocate output data
    !---------------------------------------------------------------------------
    associate(leaf_output => this%model(this%leaf_vertices(1))%layer%output)
      allocate(output(size(leaf_output, 1), size(leaf_output, 2)))
      do s = 1, size(leaf_output, 2)
         do i = 1, size(leaf_output, 1)
            output(i,s) = leaf_output(i,s)
         end do
      end do
    end associate


    !---------------------------------------------------------------------------
    ! Restore training/inference mode
    !---------------------------------------------------------------------------
    call this%restore_mode(mode_store)

  end function predict_array_from_real
!###############################################################################


!###############################################################################
  module function predict_array( this, input, verbose ) &
       result(output)
    !! Predict the output for an array_type input
    !!
    !! Output rows from successive leaf vertices are stacked along the first
    !! dimension of the result.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(array_type), dimension(..), intent(in) :: input
    !! Input data
    integer, intent(in), optional :: verbose
    !! Verbosity level
    type(array_type), dimension(:,:), allocatable :: output
    !! Predicted output

    ! Local variables
    integer :: l, s, i, j, layer_id
    !! Loop index
    integer :: num_samples
    !! Number of samples
    integer :: verbose_
    !! Verbosity level
    integer, dimension(2) :: output_shape
    !! Output shape
    logical, allocatable :: mode_store(:)
    !! Storage for inference mode booleans


    !---------------------------------------------------------------------------
    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    if(present(verbose))then
       verbose_ = verbose
    else
       verbose_ = 0
    end if


    !---------------------------------------------------------------------------
    ! Set number of samples for predicting
    !---------------------------------------------------------------------------
    num_samples = this%save_input( input )
    ! call this%set_batch_size(num_samples)


    !---------------------------------------------------------------------------
    ! Enable inference mode
    !---------------------------------------------------------------------------
    call this%set_inference_mode(mode_store)


    !---------------------------------------------------------------------------
    ! Forward pass
    !---------------------------------------------------------------------------
    select case(this%use_graph_input)
    case(.true.)
       call this%forward(this%input_graph)
    case default
       call this%forward(this%input_array)
    end select


    !---------------------------------------------------------------------------
    ! Allocate output data
    !---------------------------------------------------------------------------
    ! NOTE(review): assumes get_output_shape() counts the output rows of all
    ! leaf vertices in its first entry, so rows from each leaf are stacked.
    output_shape = this%get_output_shape()
    allocate(output(output_shape(1), output_shape(2)))
    j = 0
    do l = 1, size(this%leaf_vertices)
       layer_id = this%auto_graph%vertex(this%leaf_vertices(l))%id
       do i = 1, size(this%model(layer_id)%layer%output, 1)
          j = j + 1
          do s = 1, size(this%model(layer_id)%layer%output, 2)
             output(j,s) = this%model(layer_id)%layer%output(i,s)
          end do
       end do
    end do


    !---------------------------------------------------------------------------
    ! Restore training/inference mode
    !---------------------------------------------------------------------------
    call this%restore_mode(mode_store)

  end function predict_array
!###############################################################################


!###############################################################################
  module function predict_generic( this, input, verbose, output_as_graph ) &
       result(output)
    !! Predict the output for a generic input
    !!
    !! Returns graph_type values when output_as_graph is true (requires graph
    !! input and graph output), otherwise array_type values with output rows
    !! from successive leaf vertices stacked along the first dimension.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(*), dimension(:,:), intent(in) :: input
    !! Input data
    integer, intent(in), optional :: verbose
    !! Verbosity level
    logical, intent(in), optional :: output_as_graph
    !! Boolean whether to output as graph
    class(*), dimension(:,:), allocatable :: output
    !! Predicted output

    ! Local variables
    integer :: l, s, i, j, layer_id
    !! Loop index
    integer :: num_samples
    !! Number of samples
    integer :: verbose_
    !! Verbosity level
    logical :: output_as_graph_
    !! Output as graph boolean
    integer, dimension(2) :: output_shape
    !! Output shape
    logical, allocatable :: mode_store(:)
    !! Storage for inference mode booleans


    !---------------------------------------------------------------------------
    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    if(present(verbose))then
       verbose_ = verbose
    else
       verbose_ = 0
    end if
    if(present(output_as_graph))then
       output_as_graph_ = output_as_graph
    else
       output_as_graph_ = .false.
    end if
    if(output_as_graph_.and..not.this%use_graph_output)then
       call stop_program("output_as_graph is true but network does not use &
            &graph output")
       return
    end if


    !---------------------------------------------------------------------------
    ! Set number of samples for predicting
    !---------------------------------------------------------------------------
    num_samples = this%save_input( input )
    call this%set_batch_size(num_samples)


    !---------------------------------------------------------------------------
    ! Enable inference mode
    !---------------------------------------------------------------------------
    call this%set_inference_mode(mode_store)


    !---------------------------------------------------------------------------
    ! Forward pass
    !---------------------------------------------------------------------------
    select case(this%use_graph_input)
    case(.true.)
       call this%forward(this%input_graph)
    case default
       call this%forward(this%input_array)
    end select


    !---------------------------------------------------------------------------
    ! Allocate output data
    !---------------------------------------------------------------------------
    output_shape = this%get_output_shape()
    if(output_as_graph_)then
       allocate(output(output_shape(1), output_shape(2)), source = graph_type())
       select type(output)
       type is(graph_type)
          select type(input)
          type is(graph_type)
             do l = 1, size(this%leaf_vertices)
                do s = 1, num_samples
                   output(l,s)%num_vertices = input(1,s)%num_vertices
                   output(l,s)%num_edges = input(1,s)%num_edges
                   output(l,s)%num_vertex_features = this%model( &
                        this%leaf_vertices(l) &
                   )%layer%output_shape(1)
                   output(l,s)%num_edge_features = this%model( &
                        this%leaf_vertices(l) &
                   )%layer%output_shape(2)
                   output(l,s)%vertex_features = this%model( &
                        this%leaf_vertices(l) &
                   )%layer%output(1,s)%val
                   if(size(this%model(this%leaf_vertices(l))%layer%output,1).eq.1)then
                      output(l,s)%edge_features = input(1,s)%edge_features
                   else
                      output(l,s)%edge_features = this%model( &
                           this%leaf_vertices(l) &
                      )%layer%output(2,s)%val
                   end if
                end do
             end do
          class default
             call stop_program("input is not of type graph_type")
          end select
       class default
          call stop_program("allocation of output as graph_type failed")
       end select
    else
       ! NOTE(review): assumes output_shape(1) counts the output rows of all
       ! leaf vertices, so rows from each leaf are stacked.
       allocate(output(output_shape(1), output_shape(2)), source = array_type())
       select type(output)
       type is(array_type)
          j = 0
          do l = 1, size(this%leaf_vertices)
             layer_id = this%auto_graph%vertex(this%leaf_vertices(l))%id
             do i = 1, size(this%model(layer_id)%layer%output, 1)
                j = j + 1
                do s = 1, size(this%model(layer_id)%layer%output, 2)
                   output(j,s) = this%model(layer_id)%layer%output(i,s)
                end do
             end do
          end do
       end select
    end if


    !---------------------------------------------------------------------------
    ! Restore training/inference mode
    !---------------------------------------------------------------------------
    call this%restore_mode(mode_store)

  end function predict_generic
!###############################################################################


!###############################################################################
  module subroutine print_summary(this)
    !! Print a summary of the network architecture
    implicit none

    ! Arguments
    class(network_type), intent(in) :: this
    !! Instance of network

    ! Local variables
    integer :: i, vertex_idx
    !! Loop index and vertex index
    integer :: total_params
    !! Parameter counts
    integer :: layer_params
    !! Parameters in current layer
    character(len=80) :: line
    !! Line separator
    character(len=40) :: layer_name
    !! Layer name
    character(len=30) :: output_shape_str
    !! Output shape string
    character(len=20) :: param_str
    !! Parameter count string
    character(len=100) :: fmt
    !! Format string

    line = repeat('_', 80)

    ! Print header
    write(*,*)
    write(*,'(A)') line
    write(*,'(A)') 'Model Summary'
    write(*,'(A)') line
    write(*,'(A35, A25, A15)') 'Layer (type)', 'Output Shape', 'Param #'
    write(*,'(A)') repeat('=', 80)

    ! Initialise parameter count
    total_params = 0

    ! Print each layer
    do i = 1, this%num_layers
       vertex_idx = this%vertex_order(i)
       associate(layer => this%model(vertex_idx)%layer)
         !
Get layer name if(allocated(layer%name))then write(layer_name, '(A," (",A,")")') & trim(layer%name), trim(layer%subtype) else write(layer_name, '(A,I0," (",A,")")') & 'layer_', i, trim(layer%subtype) end if ! Get output shape string if(allocated(layer%output_shape))then ! write the general format for output shape write(fmt,'("(""(""",A,"I0,"")"")")') & repeat('I0,", "', size(layer%output_shape)-1) write(output_shape_str, fmt) layer%output_shape else output_shape_str = '(Not set)' end if ! Get parameter count layer_params = layer%get_num_params() total_params = total_params + layer_params if(layer_params .gt. 0)then write(param_str, '(I0)') layer_params else param_str = '0' end if ! Print layer information write(*,'(A35, A25, A15)') adjustl(trim(layer_name)), & adjustl(trim(output_shape_str)), adjustl(trim(param_str)) end associate end do ! Print footer write(*,'(A)') repeat('=', 80) write(*,'(A,I0)') 'Number of input vertices: ', size(this%root_vertices) write(*,'(A,I0)') 'Number of output vertices: ', size(this%leaf_vertices) write(*,'(A,I0)') 'Total trainable params: ', total_params write(*,'(A)') line write(*,*) end subroutine print_summary !############################################################################### !############################################################################### module function inverse_design_real( & this, target, x_init, optimiser, steps & ) result(x_opt) !! Optimise input to match a target output (real inputs). !! Wraps the array_type implementation after converting real arrays. implicit none ! Arguments class(network_type), intent(inout), target :: this !! Instance of the network real(real32), dimension(:,:), intent(in) :: target !! Target output values real(real32), dimension(:,:), intent(in) :: x_init !! Initial input values class(base_optimiser_type), optional, intent(in) :: optimiser !! Optimiser for input updates (defaults to network optimiser) integer, intent(in) :: steps !! 
Number of optimisation iterations real(real32), dimension(size(x_init,1), size(x_init,2)) :: x_opt !! Optimised input ! Local variables type(array_type), pointer :: target_arr(:,:), x_init_arr(:,:), x_opt_arr(:,:) !! Working input and target as array_type !--------------------------------------------------------------------------- ! Convert real arrays to array_type !--------------------------------------------------------------------------- allocate(target_arr(1,1), x_init_arr(1,1), x_opt_arr(1,1)) call target_arr(1,1)%allocate(source=target) call x_init_arr(1,1)%allocate(source=x_init) !--------------------------------------------------------------------------- ! Delegate to array_type implementation !--------------------------------------------------------------------------- x_opt_arr = this%inverse_design_array_2d( & target_arr, x_init_arr, optimiser, steps & ) x_opt = x_opt_arr(1,1)%val call target_arr(1,1)%deallocate() call x_init_arr(1,1)%deallocate() call x_opt_arr(1,1)%deallocate() deallocate(target_arr, x_init_arr, x_opt_arr) end function inverse_design_real !############################################################################### !############################################################################### module function inverse_design_array_0d( & this, target, x_init, optimiser, steps & ) result(x_opt) !! Optimise the input so the network output matches a target. !! Wraps the array_type implementation after converting to 2D array. implicit none ! Arguments class(network_type), intent(inout), target :: this !! Instance of the network type(array_type), intent(in) :: target !! Target output values type(array_type), intent(in) :: x_init !! Initial input values class(base_optimiser_type), optional, intent(in) :: optimiser !! Optimiser for input updates (defaults to network optimiser) integer, intent(in) :: steps !! Number of optimisation iterations type(array_type) :: x_opt !! Optimised input ! 
Local variables type(array_type), pointer :: target_arr(:,:), x_init_arr(:,:), x_opt_arr(:,:) !--------------------------------------------------------------------------- ! Convert real arrays to array_type !--------------------------------------------------------------------------- allocate(target_arr(1,1), x_init_arr(1,1), x_opt_arr(1,1)) call target_arr(1,1)%allocate(source=target) call x_init_arr(1,1)%allocate(source=x_init) !--------------------------------------------------------------------------- ! Delegate to array_type implementation !--------------------------------------------------------------------------- x_opt_arr = this%inverse_design_array_2d( & target_arr, x_init_arr, optimiser, steps & ) x_opt = x_opt_arr(1,1) call target_arr(1,1)%deallocate() call x_init_arr(1,1)%deallocate() call x_opt_arr(1,1)%deallocate() deallocate(target_arr, x_init_arr, x_opt_arr) end function inverse_design_array_0d !############################################################################### !############################################################################### module function inverse_design_array_2d( & this, target, x_init, optimiser, steps & ) result(x_opt) !! Optimise the input so the network output matches a target. !! Wraps the array_type implementation after converting to 2D array. implicit none ! Arguments class(network_type), intent(inout), target :: this !! Instance of the network type(array_type), dimension(:,:), intent(in) :: target !! Target output values type(array_type), dimension(:,:), intent(in) :: x_init !! Initial input values class(base_optimiser_type), optional, intent(in) :: optimiser !! Optimiser for input updates (defaults to network optimiser) integer, intent(in) :: steps !! Number of optimisation iterations type(array_type), dimension(size(x_init,1), size(x_init,2)) :: x_opt !! Optimised input ! Local variables integer :: step, i, j, itmp1, root_id, num_x, num_samples, num_elements !! 
Loop index, root layer id, number of input elements logical :: use_edge_features !! Whether edge features are used in the input type(array_type), pointer :: loss !! Loss pointer class(base_optimiser_type), allocatable :: opt !! Local optimiser instance real(real32), allocatable :: x_flat(:), x_grad(:) !! Flat input vector and gradient logical, allocatable :: mode_store(:) !! Storage for inference mode booleans real(real32), allocatable :: saved_params(:) !! Saved network parameters !--------------------------------------------------------------------------- ! Ensure the network has a loss function !--------------------------------------------------------------------------- if(.not.allocated(this%loss))then call this%set_loss("mse") end if !--------------------------------------------------------------------------- ! Get number of input elements !--------------------------------------------------------------------------- num_x = 0 use_edge_features = .false. if(this%use_graph_input)then num_samples = size(x_init, dim=2) num_x = size(x_init(1,1)%val) ! vertex features ! determine if edge features are used by checking the output shape of the input layer if(size(this%model(this%root_vertices(1))%layer%output_shape,dim=1).eq.2)then use_edge_features = .true. num_x = num_x + size(x_init(2,1)%val) ! edge features end if else num_samples = size(x_init(1,1)%val, dim=2) do i = 1, size(x_init,1) do j = 1, size(x_init,2) num_x = num_x + size(x_init(i,j)%val,dim=1) end do end do end if x_opt = x_init if(num_samples.gt.1)then call stop_program( & "inverse_design_array_2d: batch size greater than 1 not supported" & ) end if !--------------------------------------------------------------------------- ! 
Set up optimiser for input variables !--------------------------------------------------------------------------- if(present(optimiser))then allocate(opt, source=optimiser) else allocate(opt, source=base_optimiser_type( & learning_rate=this%optimiser%learning_rate)) end if call opt%init_gradients(num_x) opt%iter = 0 !--------------------------------------------------------------------------- ! Pre-allocate flat arrays used in the optimisation loop !--------------------------------------------------------------------------- allocate(x_flat(num_x)) allocate(x_grad(num_x)) !--------------------------------------------------------------------------- ! Ensure training mode is active so the full graph is built !--------------------------------------------------------------------------- call this%set_training_mode(mode_store) !--------------------------------------------------------------------------- ! Get root layer id !--------------------------------------------------------------------------- root_id = this%auto_graph%vertex(this%root_vertices(1))%id call this%set_batch_size(num_samples) !--------------------------------------------------------------------------- ! Save network parameters so they can be restored afterwards !--------------------------------------------------------------------------- allocate(saved_params(this%num_params)) saved_params = this%get_params() !--------------------------------------------------------------------------- ! Optimisation loop !--------------------------------------------------------------------------- do step = 1, steps ! Forward pass with current x call this%forward(x_opt) ! Enable gradient tracking on the input layer output if(this%use_graph_input)then call this%model(root_id)%layer%output(1,1)%set_requires_grad(.true.) if(use_edge_features)then call this%model(root_id)%layer%output(2,1)%set_requires_grad(.true.) 
end if else do i = 1, size(x_opt,1) do j = 1, size(x_opt,2) call this%model(root_id)%layer%output(i,j)%set_requires_grad(.true.) end do end do end if ! Compute loss via the network's loss function call this%save_output(target) loss => this%loss_eval(1, num_samples) ! Backward pass call loss%grad_reverse() ! Extract gradient w.r.t. input itmp1 = 0 if(associated(this%model(root_id)%layer%output(1,1)%grad))then if(this%use_graph_input)then num_elements = size(x_opt(1,1)%val, dim=1) do i = 1, size(x_opt(1,1)%val, dim=2) itmp1 = itmp1 + 1 x_grad(itmp1:itmp1+num_elements-1) = & this%model(root_id)%layer%output(1,1)%grad%val(:,i) x_flat(itmp1:itmp1+num_elements-1) = & x_opt(1,1)%val(:,i) itmp1 = itmp1 + num_elements - 1 end do if(use_edge_features)then num_elements = size(x_opt(1,1)%val, dim=1) do i = 1, size(x_opt(2,1)%val, dim=2) itmp1 = itmp1 + 1 x_grad(itmp1:itmp1+num_elements-1) = & this%model(root_id)%layer%output(2,1)%grad%val(:,i) x_flat(itmp1:itmp1+num_elements-1) = & x_opt(2,1)%val(:,i) itmp1 = itmp1 + num_elements - 1 end do end if else do i = 1, size(x_opt,1) do j = 1, size(x_opt,2) num_elements = size(x_opt(i,j)%val, dim=1) itmp1 = itmp1 + 1 x_grad(itmp1:itmp1+num_elements-1) = & this%model(root_id)%layer%output(i,j)%grad%val(:,1) x_flat(itmp1:itmp1+num_elements-1) = & x_opt(i,j)%val(:,1) itmp1 = itmp1 + num_elements - 1 end do end do end if else x_grad = 0._real32 end if ! Update x using the optimiser (not the model weights) opt%iter = opt%iter + 1 call opt%minimise(x_flat, x_grad) ! 
Convert flat x back to array form itmp1 = 0 if(this%use_graph_input)then do i = 1, size(x_opt(1,1)%val, dim=2) itmp1 = itmp1 + 1 x_opt(1,1)%val(:,i) = x_flat(itmp1:itmp1+size(x_opt(1,1)%val, dim=1)-1) itmp1 = itmp1 + size(x_opt(1,1)%val, dim=1) - 1 end do if(use_edge_features)then do i = 1, size(x_opt(2,1)%val, dim=2) itmp1 = itmp1 + 1 x_opt(2,1)%val(:,i) = x_flat(itmp1:itmp1+size(x_opt(2,1)%val, dim=1)-1) itmp1 = itmp1 + size(x_opt(2,1)%val, dim=1) - 1 end do end if else do i = 1, size(x_opt,1) do j = 1, size(x_opt,2) itmp1 = itmp1 + 1 x_opt(i,j)%val(:,1) = x_flat(itmp1:itmp1+size(x_opt(i,j)%val, dim=1)-1) itmp1 = itmp1 + size(x_opt(i,j)%val, dim=1) - 1 end do end do end if ! Clean up computation graph call loss%nullify_graph() deallocate(loss) nullify(loss) ! Reset network parameter gradients so they remain unchanged call this%reset_gradients() end do !--------------------------------------------------------------------------- ! Restore training/inference mode !--------------------------------------------------------------------------- call this%restore_mode(mode_store) !--------------------------------------------------------------------------- ! Restore network parameters to ensure model is unchanged !--------------------------------------------------------------------------- call this%set_params(saved_params) end function inverse_design_array_2d !############################################################################### end submodule athena__network_submodule