forward_generic2d Module Subroutine

module subroutine forward_generic2d(this, input)

Forward pass of the network for a rank-2 polymorphic input array (`graph_type`, `array_type`, or `real(real32)` elements)

Arguments

Type | Intent | Optional | Attributes | Name
class(network_type), intent(inout), target :: this

Instance of network

class(*), intent(in), dimension(:,:) :: input

Input


Source Code

  module subroutine forward_generic2d(this, input)
    !! Forward pass for a rank-2 polymorphic input array.
    !!
    !! Vertices are visited in the order stored in `this%vertex_order`.
    !! For each vertex the number of incoming edges selects the action:
    !!
    !! - 0 inputs: the layer must be an `input_layer_type`; it is fed
    !!   directly from `input` (row `layer%index`) and the loop moves on.
    !!   Any other layer class ends the forward pass early.
    !! - 1 input: the parent layer's `output` is forwarded to the layer.
    !! - more inputs: the parents' outputs are gathered into a list that
    !!   is handed to the layer's `combine` procedure (merge layers).
    !!
    !! When `this%fwd_layer_id` is allocated, the pre-computed navigation
    !! arrays (`fwd_layer_id`, `fwd_num_inputs`, `fwd_parent_id`,
    !! `fwd_layer_type`) are used instead of querying `this%auto_graph`.
    implicit none

    ! Arguments
    class(network_type), intent(inout), target :: this
    !! Instance of network
    class(*), dimension(:,:), intent(in) :: input
    !! Input. Dynamic types handled here: `graph_type`, `array_type`
    !! (and extensions) and `real(real32)`.

    ! Local variables
    integer :: l, i, j, vertex_idx, layer_id, parent_id
    !! Loop index and vertex index
    integer :: num_input_layers
    !! Number of input layers
    type(array_type), pointer :: input_ptr(:,:)
    !! Alias of the single parent output (or a temporary wrapper for
    !! raw real input). Deliberately NOT initialised in the declaration:
    !! that would give it the implicit SAVE attribute.
    type(array_ptr_type), dimension(:), allocatable :: input_list
    !! Parent outputs gathered for layers with several inputs
    logical :: use_precomp
    !! Whether pre-computed navigation arrays are available


    nullify(input_ptr)

    ! Validate graph input: no entry of the first row of adj_ja may
    ! exceed the number of vertices of the graph
    select type(input)
    type is(graph_type)
       do j = 1, this%batch_size
          if(any(input(1,j)%adj_ja(1,:).gt.input(1,j)%num_vertices))then
             call stop_program( &
                  "input graph has more vertices than expected" &
             )
          end if
       end do
    end select

    ! Use pre-computed navigation if available
    use_precomp = allocated(this%fwd_layer_id)

    ! Forward pass
    !---------------------------------------------------------------------------
    do l = 1, size(this%vertex_order,1)
       ! Identify the layer of this vertex and how many parents feed it
       if(use_precomp)then
          layer_id = this%fwd_layer_id(l)
          num_input_layers = this%fwd_num_inputs(l)
       else
          vertex_idx = this%vertex_order(l)
          layer_id = this%auto_graph%vertex(vertex_idx)%id
          num_input_layers = count(this%auto_graph%adjacency(:,vertex_idx).gt.0)
       end if

       if(num_input_layers.eq.0)then
          ! Source vertex: feed it directly from the user-supplied input
          select type(layer => this%model(layer_id)%layer)
          class is(input_layer_type)
             select type(input)
             type is(graph_type)
                call layer%set_input_graph( [ input(layer%index, :) ] )
                cycle
             class is(array_type)
                call layer%forward(input(layer%index:layer%index,:))
                do concurrent(i=1:size(layer%output,1), j=1:size(layer%output,2))
                   call layer%output(i,j)%set_requires_grad(.false.)
                end do
                cycle
             type is(real(real32))
                ! Wrap the raw real data in a temporary 1x1 array_type
                allocate(input_ptr(1,1))
                call input_ptr(1,1)%allocate(shape(input))
                call input_ptr(1,1)%set(input)
                call layer%forward(input_ptr)
                call layer%output(1,1)%set_requires_grad(.false.)
                deallocate(input_ptr)
                nullify(input_ptr)
                cycle
             class default
                call stop_program( &
                     "input type for layer "// &
                     trim(layer%name) // &
                     " is not supported" &
                )
                ! stop_program is expected to terminate; should it ever
                ! return, do not forward a disassociated pointer below
                return
             end select
          class default
             ! A vertex without parents that is not an input layer ends
             ! the forward pass
             return
          end select
       elseif(num_input_layers.eq.1)then
          ! Single parent: alias its output
          if(use_precomp)then
             parent_id = this%fwd_parent_id(l)
          else
             vertex_idx = this%vertex_order(l)
             j = maxloc( &
                  this%auto_graph%adjacency(:,vertex_idx), dim=1)
             parent_id = this%auto_graph%vertex(j)%id
          end if
          input_ptr => this%model(parent_id)%layer%output
          select type(input)
          type is(graph_type)
             call this%model(layer_id)%layer%set_graph( [ input(1,:) ] )
          end select
       else
          ! Several parents: collect pointers to all of their outputs
          vertex_idx = this%vertex_order(l)
          allocate(input_list(num_input_layers))
          i = 0
          do j = 1, size(this%vertex_order,1)
             if(this%auto_graph%adjacency(j,vertex_idx).gt.0)then
                i = i + 1
                parent_id = this%auto_graph%vertex(j)%id
                input_list(i)%array => this%model(parent_id)%layer%output
             end if
          end do
       end if

       ! Run the layer
       if(use_precomp)then
          ! NOTE(review): fwd_layer_type == 1 is routed to the merge
          ! path; presumably it flags merge layers - confirm where the
          ! array is populated
          if(this%fwd_layer_type(l).eq.1)then
             select type(layer => this%model(layer_id)%layer)
             class is(merge_layer_type)
                call layer%combine(input_list)
             end select
          else
             call this%model(layer_id)%layer%forward(input_ptr)
          end if
       else
          select type(layer => this%model(layer_id)%layer)
          class is(merge_layer_type)
             call layer%combine(input_list)
          class default
             call layer%forward(input_ptr)
          end select
       end if

       ! Per-iteration cleanup, independent of the branch taken above,
       ! so the next allocate(input_list) never hits an allocated array
       ! and no stale association survives into the next vertex
       if(allocated(input_list)) deallocate(input_list)
       nullify(input_ptr)

    end do

  end subroutine forward_generic2d