predict_generic Module Function

module function predict_generic(this, input, verbose, output_as_graph) result(output)

Predict the output for a generic input

Arguments

Type | Intent | Optional Attributes | Name
class(network_type), intent(inout) :: this

Instance of network

class(*), intent(in), dimension(:,:) :: input

Input samples (generic; graph_type samples are required when output_as_graph is true)

integer, intent(in), optional :: verbose

Verbosity level

logical, intent(in), optional :: output_as_graph

Whether to return the prediction as graph_type (default: false); the network must use graph output

Return Value class(*), dimension(:,:), allocatable

Predicted output


Source Code

  module function predict_generic( this, input, verbose, output_as_graph ) &
       result(output)
    !! Predict the output of the network for a generic (unlimited
    !! polymorphic) batch of input samples.
    !!
    !! The input is stored with `save_input`, and the batch size is set to
    !! the number of samples it reports. A forward pass then runs in
    !! inference mode. The previous training/inference mode is restored
    !! before returning.
    !!
    !! The result has the shape reported by `get_output_shape`. Its dynamic
    !! type is `graph_type` when `output_as_graph` is true, and
    !! `array_type` otherwise.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(*), dimension(:,:), intent(in) :: input
    !! Input samples; must be graph_type when output_as_graph is true
    integer, intent(in), optional :: verbose
    !! Verbosity level (default 0; currently not read by this procedure)
    logical, intent(in), optional :: output_as_graph
    !! Return the output as graph_type (default .false.); requires the
    !! network to use graph output

    class(*), dimension(:,:), allocatable :: output
    !! Predicted output

    ! Local variables
    integer :: l, s, i
    !! Loop indices: leaf vertex, sample, and layer-output row
    integer :: leaf_id
    !! Model index of the current leaf vertex (graph output path)
    integer :: layer_id
    !! Model index of the current leaf vertex (array output path)
    integer :: num_samples
    !! Number of samples
    integer :: verbose_
    !! Verbosity level (resolved default; not used further yet)
    logical :: output_as_graph_
    !! Output as graph boolean (resolved default)
    integer, dimension(2) :: output_shape
    !! Output shape, as reported by get_output_shape
    logical, allocatable :: mode_store(:)
    !! Storage for inference mode booleans, restored on exit


    !---------------------------------------------------------------------------
    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    if(present(verbose))then
       verbose_ = verbose
    else
       verbose_ = 0
    end if

    if(present(output_as_graph))then
       output_as_graph_ = output_as_graph
    else
       output_as_graph_ = .false.
    end if
    if(output_as_graph_.and..not.this%use_graph_output)then
       call stop_program("output_as_graph is true but network does not use &
            &graph output")
    end if


    !---------------------------------------------------------------------------
    ! Set number of samples for predicting
    !---------------------------------------------------------------------------
    num_samples = this%save_input( input )
    call this%set_batch_size(num_samples)


    !---------------------------------------------------------------------------
    ! Enable inference mode
    !---------------------------------------------------------------------------
    call this%set_inference_mode(mode_store)

    !---------------------------------------------------------------------------
    ! Forward pass
    !---------------------------------------------------------------------------
    select case(this%use_graph_input)
    case(.true.)
       call this%forward(this%input_graph)
    case default
       call this%forward(this%input_array)
    end select


    !---------------------------------------------------------------------------
    ! Allocate output data
    !---------------------------------------------------------------------------
    ! The shape is queried once and shared by both output representations.
    output_shape = this%get_output_shape()
    if(output_as_graph_)then
       allocate(output(output_shape(1), output_shape(2)), source = graph_type())
       select type(output)
       type is(graph_type)
          select type(input)
          type is(graph_type)
             ! Samples are in the outer loop so that output(l,s) is written
             ! in column-major (contiguous) order.
             do s = 1, num_samples
                do l = 1, size(this%leaf_vertices)
                   ! NOTE(review): this path indexes this%model directly by
                   ! the leaf vertex index, while the array path maps it
                   ! through auto_graph%vertex(...)%id. Confirm which mapping
                   ! is intended.
                   leaf_id = this%leaf_vertices(l)
                   ! Graph topology (vertex/edge counts) is copied from the
                   ! corresponding input sample.
                   output(l,s)%num_vertices = input(1,s)%num_vertices
                   output(l,s)%num_edges = input(1,s)%num_edges
                   output(l,s)%num_vertex_features = &
                        this%model(leaf_id)%layer%output_shape(1)
                   output(l,s)%num_edge_features = &
                        this%model(leaf_id)%layer%output_shape(2)
                   output(l,s)%vertex_features = &
                        this%model(leaf_id)%layer%output(1,s)%val
                   ! A leaf layer with a single output row produces no edge
                   ! features of its own, so the input edge features are
                   ! passed through unchanged.
                   if(size(this%model(leaf_id)%layer%output,1).eq.1)then
                      output(l,s)%edge_features = input(1,s)%edge_features
                   else
                      output(l,s)%edge_features = &
                           this%model(leaf_id)%layer%output(2,s)%val
                   end if
                end do
             end do
          class default
             call stop_program("input is not of type graph_type")
          end select
       class default
          call stop_program("allocation of output as graph_type failed")
       end select
    else
       allocate(output(output_shape(1), output_shape(2)), source = array_type())
       select type(output)
       type is(array_type)
          do l = 1, size(this%leaf_vertices)
             layer_id = this%auto_graph%vertex(this%leaf_vertices(l))%id
             ! NOTE(review): output rows start at 1 again for every leaf
             ! vertex. With more than one leaf vertex, later leaves overwrite
             ! earlier ones. Confirm against get_output_shape whether the
             ! outputs of several leaves should instead be stacked along the
             ! first dimension.
             ! Samples are in the outer loop for column-major access.
             do s = 1, size(this%model(layer_id)%layer%output, 2)
                do i = 1, size(this%model(layer_id)%layer%output, 1)
                   output(i,s) = this%model(layer_id)%layer%output(i,s)
                end do
             end do
          end do
       class default
          call stop_program("allocation of output as array_type failed")
       end select
    end if


    !---------------------------------------------------------------------------
    ! Restore training/inference mode
    !---------------------------------------------------------------------------
    call this%restore_mode(mode_store)

  end function predict_generic