Predict the output for a generic input
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| class(network_type) | intent(inout) | | | this | Instance of network |
| class(*) | intent(in) | | dimension(:,:) | input | Input graph (or array) data |
| integer | intent(in) | optional | | verbose | Verbosity level |
| logical | intent(in) | optional | | output_as_graph | Whether to return the output as `graph_type` (the network must use graph output) |

**Return value:** `class(*), dimension(:,:), allocatable` — Predicted output
module function predict_generic( this, input, verbose, output_as_graph ) &
     result(output)
  !! Predict the output for a generic input
  !!
  !! The input is stored on the network via `save_input`, the batch size is
  !! set to the number of samples it reports, inference mode is enabled and
  !! a forward pass is run. The previous training/inference mode is restored
  !! before returning.
  !!
  !! If `output_as_graph` is true, `output` is allocated as `graph_type` with
  !! one row per leaf vertex; otherwise it is allocated as `array_type` and
  !! the outputs of all leaf layers are stacked along the first dimension.
  implicit none

  ! Arguments
  class(network_type), intent(inout) :: this
  !! Instance of network
  class(*), dimension(:,:), intent(in) :: input
  !! Input graph (or array) data
  integer, intent(in), optional :: verbose
  !! Verbosity level (default 0)
  logical, intent(in), optional :: output_as_graph
  !! Boolean whether to output as graph (default .false.)
  class(*), dimension(:,:), allocatable :: output
  !! Predicted output

  ! Local variables
  integer :: l, s, i, j, layer_id
  !! Loop indices, row offset, and model index of the current leaf layer
  integer :: num_rows
  !! Number of output rows produced by the current leaf layer
  integer :: num_samples
  !! Number of samples
  integer :: verbose_
  !! Verbosity level
  !! NOTE(review): verbose_ is not used further in this procedure
  logical :: output_as_graph_
  !! Output as graph boolean
  integer, dimension(2) :: output_shape
  !! Output shape
  logical, allocatable :: mode_store(:)
  !! Storage for inference mode booleans


  !---------------------------------------------------------------------------
  ! Initialise optional arguments
  !---------------------------------------------------------------------------
  if(present(verbose))then
     verbose_ = verbose
  else
     verbose_ = 0
  end if
  if(present(output_as_graph))then
     output_as_graph_ = output_as_graph
  else
     output_as_graph_ = .false.
  end if
  if(output_as_graph_.and..not.this%use_graph_output)then
     call stop_program("output_as_graph is true but network does not use &
          &graph output")
  end if


  !---------------------------------------------------------------------------
  ! Set number of samples for predicting
  !---------------------------------------------------------------------------
  num_samples = this%save_input( input )
  call this%set_batch_size(num_samples)


  !---------------------------------------------------------------------------
  ! Enable inference mode
  !---------------------------------------------------------------------------
  call this%set_inference_mode(mode_store)


  !---------------------------------------------------------------------------
  ! Forward pass
  !---------------------------------------------------------------------------
  select case(this%use_graph_input)
  case(.true.)
     call this%forward(this%input_graph)
  case default
     call this%forward(this%input_array)
  end select


  !---------------------------------------------------------------------------
  ! Allocate output data
  !---------------------------------------------------------------------------
  output_shape = this%get_output_shape()
  if(output_as_graph_)then
     allocate(output(output_shape(1), output_shape(2)), source = graph_type())
     select type(output)
     type is(graph_type)
        select type(input)
        type is(graph_type)
           ! NOTE(review): here the model is indexed directly by the leaf
           ! vertex number, whereas the array branch below maps it through
           ! this%auto_graph%vertex(...)%id -- confirm which is intended.
           do s = 1, num_samples
              do l = 1, size(this%leaf_vertices)
                 layer_id = this%leaf_vertices(l)
                 output(l,s)%num_vertices = input(1,s)%num_vertices
                 output(l,s)%num_edges = input(1,s)%num_edges
                 output(l,s)%num_vertex_features = &
                      this%model(layer_id)%layer%output_shape(1)
                 output(l,s)%num_edge_features = &
                      this%model(layer_id)%layer%output_shape(2)
                 output(l,s)%vertex_features = &
                      this%model(layer_id)%layer%output(1,s)%val
                 if(size(this%model(layer_id)%layer%output,1).eq.1)then
                    ! Single-entry leaf output (presumably vertex features
                    ! only): copy the edge features from the input graph.
                    output(l,s)%edge_features = input(1,s)%edge_features
                 else
                    output(l,s)%edge_features = &
                         this%model(layer_id)%layer%output(2,s)%val
                 end if
              end do
           end do
        class default
           call stop_program("input is not of type graph_type")
        end select
     class default
        call stop_program("allocation of output as graph_type failed")
     end select
  else
     allocate(output(output_shape(1), output_shape(2)), source = array_type())
     select type(output)
     type is(array_type)
        ! Stack the outputs of every leaf layer along the first dimension.
        ! j is the number of rows already filled by previous leaves; it must
        ! NOT be reset per leaf, otherwise later leaves overwrite earlier ones.
        j = 0
        leaf_loop: do l = 1, size(this%leaf_vertices)
           layer_id = this%auto_graph%vertex(this%leaf_vertices(l))%id
           num_rows = size(this%model(layer_id)%layer%output, 1)
           ! Guard against writing outside the allocated output array.
           if( j + num_rows .gt. size(output, 1) .or. &
                size(this%model(layer_id)%layer%output, 2) .gt. &
                size(output, 2) )then
              call stop_program("leaf layer outputs exceed network output &
                   &shape")
              exit leaf_loop
           end if
           do s = 1, size(this%model(layer_id)%layer%output, 2)
              do i = 1, num_rows
                 output(j + i, s) = this%model(layer_id)%layer%output(i, s)
              end do
           end do
           j = j + num_rows
        end do leaf_loop
     class default
        call stop_program("allocation of output as array_type failed")
     end select
  end if


  !---------------------------------------------------------------------------
  ! Restore training/inference mode
  !---------------------------------------------------------------------------
  call this%restore_mode(mode_store)

end function predict_generic