save_output_to_network Module Subroutine

module subroutine save_output_to_network(this, output)

Store the given output in the network as its expected output (the target values used during training).

Arguments

Type | Intent | Optional | Attributes | Name
class(network_type), intent(inout) :: this

Instance of network

class(*), intent(in), dimension(:,:) :: output

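Output to store. Supported types: graph_type, array_type (and its extensions), default real, and integer; any other type stops the program.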


Source Code

  module subroutine save_output_to_network( this, output )
    !! Store the supplied output in the network's expected_array.
    !!
    !! Any previously stored expected_array is released first. The layout
    !! of the new expected_array depends on the dynamic type of output:
    !!
    !! * graph_type   - shape (2, size(output,2)). Row 1 holds the vertex
    !!                  features, row 2 the edge features; row 2 is left
    !!                  unallocated when a graph has no edge features.
    !!                  Only output(1,:) is read.
    !! * array_type   - same shape as output; each element is a copy of the
    !!                  corresponding output element's values.
    !! * real/integer - shape (1,1); a single array holding output
    !!                  (integer data converted to real32).
    !!
    !! Every stored array has requires_grad disabled and is marked as not
    !! temporary. Any other dynamic type stops the program.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(*), dimension(:,:), intent(in) :: output
    !! Output

    ! Local variables
    integer :: i, j, s
    !! Loop indices

    ! Release any previously stored expected output.
    ! Only elements flagged as allocated are deallocated: the graph branch
    ! below can leave expected_array(2,s) unallocated, and this matches the
    ! guarded deallocate calls used elsewhere in this procedure.
    ! NOTE(review): assumes array_type default-initialises %allocated to
    ! .false. - confirm in the array_type definition.
    if(allocated(this%expected_array))then
       do j = 1, size(this%expected_array, 2)
          do i = 1, size(this%expected_array, 1)
             if(this%expected_array(i,j)%allocated) &
                  call this%expected_array(i,j)%deallocate()
          end do
       end do
       deallocate(this%expected_array)
    end if

    ! expected_array is freshly allocated in every branch below, so its
    ! elements need no "already allocated?" check before allocate().
    select type(output)
    type is(graph_type)
       allocate(this%expected_array(2,size(output,2)))
       do s = 1, size(output,2)
          ! Row 1: vertex features, shaped (num_vertex_features, num_vertices)
          call this%expected_array(1,s)%allocate( &
               array_shape = [ &
                    output(1,s)%num_vertex_features, output(1,s)%num_vertices &
               ] &
          )
          call this%expected_array(1,s)%zero_grad()
          call this%expected_array(1,s)%set_requires_grad(.false.)
          call this%expected_array(1,s)%set( output(1,s)%vertex_features )
          this%expected_array(1,s)%is_temporary = .false.

          ! Row 2: edge features, shaped (num_edge_features, num_edges);
          ! skipped (left unallocated) when the graph carries none.
          if(output(1,s)%num_edge_features.le.0) cycle
          call this%expected_array(2,s)%allocate( &
               array_shape = [ &
                    output(1,s)%num_edge_features, output(1,s)%num_edges &
               ] &
          )
          call this%expected_array(2,s)%set_requires_grad(.false.)
          call this%expected_array(2,s)%set( output(1,s)%edge_features )
          this%expected_array(2,s)%is_temporary = .false.
       end do
    class is(array_type)
       allocate(this%expected_array(size(output,1),size(output,2)))
       do s = 1, size(output,2)
          do i = 1, size(output,1)
             ! Shape = element shape plus the second extent of its val.
             call this%expected_array(i,s)%allocate( &
                  array_shape = [ &
                       output(i,s)%shape, size(output(i,s)%val,2) &
                  ] &
             )
             call this%expected_array(i,s)%set_requires_grad(.false.)
             call this%expected_array(i,s)%set( output(i,s)%val )
             this%expected_array(i,s)%is_temporary = .false.
          end do
       end do
    type is(real)
       allocate(this%expected_array(1,1))
       call this%expected_array(1,1)%allocate( &
            array_shape = [ size(output,1), size(output,2) ] &
       )
       call this%expected_array(1,1)%set_requires_grad(.false.)
       call this%expected_array(1,1)%set( output )
       this%expected_array(1,1)%is_temporary = .false.
    type is(integer)
       allocate(this%expected_array(1,1))
       call this%expected_array(1,1)%allocate( &
            array_shape = [ size(output,1), size(output,2) ] &
       )
       call this%expected_array(1,1)%set_requires_grad(.false.)
       ! Integer targets are stored as real32 values.
       this%expected_array(1,1)%val = real(output, real32)
       this%expected_array(1,1)%is_temporary = .false.
    class default
       call stop_program("output type not supported in training")
    end select

  end subroutine save_output_to_network