Forward pass for array derived type input.

| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| `class(network_type)` | `intent(inout)` | | `target` | `this` | Instance of network |
| `class(*)` | `intent(in)` | | `dimension(:,:)` | `input` | Input data; input layers accept `graph_type`, `array_type` (or extensions) and `real(real32)` |
module subroutine forward_generic2d(this, input)
  !! Forward pass for array derived type input.
  !!
  !! Validates graph-typed input, then walks the network vertices in the
  !! order given by `this%vertex_order`:
  !!  - a vertex with no parents is fed directly from `input` (input layer);
  !!  - a vertex with one parent is passed that parent's output;
  !!  - a vertex with several parents gathers all parent outputs into a
  !!    list that is handed to a merge layer's `combine`.
  !! Navigation uses the pre-computed `fwd_*` tables when
  !! `this%fwd_layer_id` is allocated, otherwise it is derived from the
  !! adjacency matrix of `this%auto_graph`.
  implicit none

  ! Arguments
  class(network_type), intent(inout), target :: this
  !! Instance of network
  class(*), dimension(:,:), intent(in) :: input
  !! Input data. Input layers accept graph_type, array_type (or extensions)
  !! and real(real32); any other type is reported as unsupported.

  ! Local variables
  integer :: l, i, j
  !! Loop indices
  integer :: vertex_idx, layer_id, parent_id
  !! Graph vertex index, model layer index and parent layer index
  integer :: num_input_layers
  !! Number of layers feeding the current vertex
  type(array_type), pointer :: input_ptr(:,:)
  !! Input handed to a single-input layer.
  !! Deliberately has no `=> null()` initialiser: on a procedure local that
  !! would imply the SAVE attribute. It is nullified explicitly below.
  type(array_ptr_type), dimension(:), allocatable :: input_list
  !! Parent outputs handed to a multi-input (merge) layer
  logical :: use_precomp
  !! Whether pre-computed navigation tables are available


  nullify(input_ptr)

  ! Reject graphs whose adjacency references vertices beyond num_vertices
  select type(input)
  type is(graph_type)
     ! NOTE(review): only row 1 of `input` is checked, and the loop bound is
     ! this%batch_size (assumed <= size(input,2)) - TODO confirm intended
     do j = 1, this%batch_size
        if(any(input(1,j)%adj_ja(1,:).gt.input(1,j)%num_vertices))then
           call stop_program( &
                "input graph has more vertices than expected" &
           )
           ! Do not run the forward pass on invalid input should
           ! stop_program return instead of terminating
           return
        end if
     end do
  end select

  ! Use pre-computed navigation if available
  use_precomp = allocated(this%fwd_layer_id)


  ! Forward pass
  !---------------------------------------------------------------------------
  vertex_loop: do l = 1, size(this%vertex_order,1)

     ! Identify the layer at this position and how many layers feed it
     if(use_precomp)then
        layer_id = this%fwd_layer_id(l)
        num_input_layers = this%fwd_num_inputs(l)
     else
        vertex_idx = this%vertex_order(l)
        layer_id = this%auto_graph%vertex(vertex_idx)%id
        num_input_layers = &
             count(this%auto_graph%adjacency(:,vertex_idx).gt.0)
     end if

     if(num_input_layers.eq.0)then
        ! No parents: feed the layer straight from the user input
        select type(layer => this%model(layer_id)%layer)
        class is(input_layer_type)
           select type(input)
           type is(graph_type)
              call layer%set_input_graph( [ input(layer%index, :) ] )
              cycle vertex_loop
           class is(array_type)
              call layer%forward(input(layer%index:layer%index,:))
              ! User input is not differentiated through
              do concurrent( &
                   i = 1:size(layer%output,1), j = 1:size(layer%output,2))
                 call layer%output(i,j)%set_requires_grad(.false.)
              end do
              cycle vertex_loop
           type is(real(real32))
              ! Wrap the raw real array in a temporary 1x1 array_type
              allocate(input_ptr(1,1))
              call input_ptr(1,1)%allocate(shape(input))
              call input_ptr(1,1)%set(input)
              call layer%forward(input_ptr)
              call layer%output(1,1)%set_requires_grad(.false.)
              deallocate(input_ptr)
              nullify(input_ptr)
              cycle vertex_loop
           class default
              call stop_program( &
                   "input type for layer "// &
                   trim(layer%name) // &
                   " is not supported" &
              )
              ! Avoid calling forward on a null input should stop_program
              ! return instead of terminating
              return
           end select
        class default
           ! NOTE(review): a parent-less vertex that is not an input layer
           ! ends the whole forward pass (return, not cycle) - confirm
           return
        end select
     elseif(num_input_layers.eq.1)then
        ! Single parent: pass its output through
        if(use_precomp)then
           parent_id = this%fwd_parent_id(l)
        else
           vertex_idx = this%vertex_order(l)
           j = maxloc( &
                this%auto_graph%adjacency(:,vertex_idx), dim=1)
           parent_id = this%auto_graph%vertex(j)%id
        end if
        input_ptr => this%model(parent_id)%layer%output
        select type(input)
        type is(graph_type)
           call this%model(layer_id)%layer%set_graph( [ input(1,:) ] )
        end select
     else
        ! Several parents: collect pointers to every parent's output
        vertex_idx = this%vertex_order(l)
        allocate(input_list(num_input_layers))
        i = 0
        do j = 1, size(this%vertex_order,1)
           if(this%auto_graph%adjacency(j,vertex_idx).gt.0)then
              i = i + 1
              parent_id = this%auto_graph%vertex(j)%id
              input_list(i)%array => this%model(parent_id)%layer%output
           end if
        end do
     end if

     ! Dispatch: merge layers combine the parent list, others run forward
     if(use_precomp)then
        if(this%fwd_layer_type(l).eq.1)then
           select type(layer => this%model(layer_id)%layer)
           class is(merge_layer_type)
              call layer%combine(input_list)
           end select
        else
           call this%model(layer_id)%layer%forward(input_ptr)
        end if
     else
        select type(layer => this%model(layer_id)%layer)
        class is(merge_layer_type)
           call layer%combine(input_list)
        class default
           call layer%forward(input_ptr)
        end select
     end if

     ! Reset per-vertex state. Guarded so that a list which was never
     ! allocated is not deallocated, and a list that was not consumed by a
     ! merge layer cannot leak into (and break) the next allocation.
     if(allocated(input_list)) deallocate(input_list)
     nullify(input_ptr)

  end do vertex_loop

end subroutine forward_generic2d