predict_graph1d Module Function

module function predict_graph1d(this, input, verbose) result(output)

Predict the network output for a batch of graph inputs.

Arguments

Type    Intent    Optional Attributes    Name
class(network_type), intent(inout) :: this

Instance of network

type(graph_type), intent(in), dimension(:) :: input

Input graphs; each element is one sample of the batch

integer, intent(in), optional :: verbose

Verbosity level

Return Value type(graph_type), dimension(size(this%leaf_vertices),size(input))

Output graphs; element (l, s) holds the output of leaf vertex l for sample s


Source Code

  module function predict_graph1d( this, input, verbose ) result(output)
    !! Predict the network output for a batch of graph inputs.
    !!
    !! The network batch size is set to size(input) (and left at that value
    !! on return), the network is switched into inference mode, the whole
    !! batch is passed forward in a single call, and the previous
    !! training/inference mode is restored afterwards. For every leaf vertex
    !! and every sample, an output graph is assembled from that leaf layer's
    !! output. Vertex/edge counts are copied from the corresponding input
    !! graph.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    type(graph_type), dimension(:), intent(in) :: input
    !! Input graphs, one per sample in the batch
    integer, optional, intent(in) :: verbose
    !! Verbosity level (accepted but not currently used by this procedure)

    ! Local variables
    integer :: l, s
    !! Loop indices over leaf vertices (l) and samples (s)
    type(graph_type), dimension(size(this%leaf_vertices),size(input)) :: output
    !! Output graphs; element (l,s) is the output of leaf vertex l for sample s
    integer :: verbose_, batch_size
    !! Verbosity level and number of samples in the batch
    logical, allocatable :: mode_store(:)
    !! Storage for inference mode booleans


    !---------------------------------------------------------------------------
    ! Initialise optional arguments
    !---------------------------------------------------------------------------
    ! verbose_ is assigned here instead of being initialised in its
    ! declaration: a declaration initialiser implies SAVE, which would let a
    ! verbosity passed in one call leak into later calls that omit it.
    verbose_ = 0
    if(present(verbose)) verbose_ = verbose

    !---------------------------------------------------------------------------
    ! Reset batch size for testing
    !---------------------------------------------------------------------------
    batch_size = size(input)
    call this%set_batch_size(batch_size)


    !---------------------------------------------------------------------------
    ! Enable inference mode
    !---------------------------------------------------------------------------
    call this%set_inference_mode(mode_store)


    !---------------------------------------------------------------------------
    ! Predict
    !---------------------------------------------------------------------------
    call this%forward(get_sample(input, 1, batch_size, batch_size))

    ! Inner loop runs over the first index of output(l,s) (column-major order)
    do s = 1, batch_size
       do l = 1, size(this%leaf_vertices)
          associate( layer => this%model(this%leaf_vertices(l))%layer )
            output(l,s)%num_vertices = input(s)%num_vertices
            output(l,s)%num_edges = input(s)%num_edges
            output(l,s)%num_vertex_features = layer%output_shape(1)
            output(l,s)%vertex_features = layer%output(1,s)%val
            if(size(layer%output,1).eq.1)then
               ! The layer produced only vertex features: pass the input edge
               ! features through unchanged, and keep the feature count
               ! consistent with the array that was actually copied.
               output(l,s)%num_edge_features = input(s)%num_edge_features
               output(l,s)%edge_features = input(s)%edge_features
            else
               output(l,s)%num_edge_features = layer%output_shape(2)
               output(l,s)%edge_features = layer%output(2,s)%val
            end if
          end associate
       end do
    end do


    !---------------------------------------------------------------------------
    ! Restore training/inference mode
    !---------------------------------------------------------------------------
    call this%restore_mode(mode_store)

  end function predict_graph1d