Save input to network
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| `class(network_type)` | `intent(inout)` | | | `this` | Instance of network |
| `class(*)` | `intent(in)` | | `dimension(..)` | `input` | Input |

**Return value:** `integer :: num_samples` — Number of samples
module function save_input_to_network( this, input ) result(num_samples)
  !! Save input to network.
  !!
  !! Stores `input` in `this%input_array` (intrinsic real data or
  !! `array_type` / `array_ptr_type` wrappers) or in `this%input_graph`
  !! (`graph_type` data), discarding any previously stored input first.
  !! Forms handled below:
  !!  - rank 0: `array_type` or `array_ptr_type` (requires exactly one
  !!    input layer)
  !!  - rank 1: `graph_type`, or `array_type` with one element per input
  !!    layer (`array_ptr_type` is rejected as not yet supported)
  !!  - rank 2: `real(real32)` reshaped to [num_inputs, num_samples],
  !!    `graph_type` (layers x samples), or `array_type` (shallow copy)
  !!  - rank 3-5: `real(real32)`, stored with its original shape
  !! Returns the number of samples reported by `get_num_samples`; when that
  !! is not positive the routine returns immediately and any previously
  !! stored input is left untouched. Errors are reported via `stop_program`.
  implicit none

  ! Arguments
  class(network_type), intent(inout) :: this !! Instance of network
  class(*), dimension(..), intent(in) :: input !! Input
  integer :: num_samples !! Number of samples

  ! Local variables
  integer :: i, j, l !! Loop indices
  integer :: input_rank !! Rank of the input (used in error reporting)
  integer :: num_inputs !! Number of input values per sample (rank-2 real input)
  integer :: num_input_layers !! Number of input layers
  logical :: l_valid_rank_type !! Boolean whether rank type is valid
  character(256) :: err_msg !! Error message


  num_samples = get_num_samples(this, input)
  if(num_samples.le.0) return

  ! Record the rank once, up front, so the error message at the end always
  ! reports a defined value whichever branch leads there.
  input_rank = rank(input)
  num_input_layers = size(this%root_vertices, 1)


  ! Discard any previously stored input
  !---------------------------------------------------------------------------
  if(allocated(this%input_array))then
     do j = 1, size(this%input_array, 2)
        do i = 1, size(this%input_array, 1)
           call this%input_array(i,j)%deallocate()
        end do
     end do
     deallocate(this%input_array)
  end if
  if(allocated(this%input_graph)) deallocate(this%input_graph)


  ! Pre-allocate storage for intrinsic-typed input of rank >= 2
  ! (rank-2 array_type input is stored here directly and returns early)
  !---------------------------------------------------------------------------
  select rank(input)
  rank(0)
  rank(1)
  rank(2)
     select type(input)
     class is(array_type)
        allocate(this%input_array(size(input,1), size(input,2)))
        do j = 1, size(input,2)
           do i = 1, size(input,1)
              call this%input_array(i,j)%assign_shallow(input(i,j))
           end do
        end do
        return
     type is(graph_type)
        ! Graph input is stored in this%input_graph below. Like the rank-1
        ! graph path, it needs no (uninitialised) input_array buffer.
     class default
        num_inputs = size(input) / num_samples
        allocate(this%input_array(1,1))
        call this%input_array(1,1)%allocate( &
             array_shape=[num_inputs, num_samples] &
        )
     end select
  rank default
     allocate(this%input_array(1,1))
     call this%input_array(1,1)%allocate(array_shape=shape(input))
  end select


  l_valid_rank_type = .false.
  ! Process input based on its rank
  !---------------------------------------------------------------------------
  rank_select: select rank(input)
  rank(0)
     select type(input)
     type is(real)
        exit rank_select
     class default
        l_valid_rank_type = .true.
     end select
     if(num_input_layers.ne.1)then
        call stop_program( &
             "number of input arrays does not match expected number of &
             &input layers" &
        )
        return
     end if
     select type(input)
     class is(array_type)
        allocate(this%input_array(1,1))
        call handle_array_type(input, this%input_array(1,1), num_samples)
     type is(array_ptr_type)
        allocate(this%input_array(size(input%array,1), size(input%array,2)))
        do j = 1, size(input%array,2)
           do i = 1, size(input%array,1)
              call handle_array_type( &
                   input%array(i,j), this%input_array(i,j), num_samples &
              )
           end do
        end do
     end select

  rank(1)
     select type(input)
     type is(real(real32))
        ! Rank-1 real input is not supported; reported as invalid below.
        exit rank_select
     type is(graph_type)
        allocate(this%input_graph(num_input_layers, num_samples))
        this%input_graph(1,:) = input(:)
        return
     class default
        l_valid_rank_type = .true.
     end select
     if(size(input,1).ne.num_input_layers)then
        call stop_program( &
             "number of input arrays does not match expected number of &
             &input layers" &
        )
        return
     end if
     select type(input)
     class is(array_type)
        allocate(this%input_array(1,size(input,1)))
        do l = 1, size(input,1)
           call handle_array_type(input(l), this%input_array(1,l), num_samples)
        end do
     type is(array_ptr_type)
        call stop_program("Use of array_ptr_type with rank 1 input not yet supported")
        return
        ! Sketch of a possible implementation (re-enabling it requires an
        ! integer counter `ip` to be declared):
        ! ip = 0
        ! do l = 1, size(input,1)
        !    do i = 1, size(input%array,1)
        !       ip = ip + 1
        !       do j = 1, size(input%array,2)
        !          call handle_array_type( &
        !               input(l)%array(i,j), this%input_array(ip,j), num_samples &
        !          )
        !       end do
        !    end do
        ! end do
     end select

  rank(2)
     select type(input)
     type is(real(real32))
        this%input_array(1,1)%val = reshape(input, [num_inputs, num_samples])
        l_valid_rank_type = .true.
     type is(graph_type)
        num_samples = size(input, dim=2)
        allocate(this%input_graph(num_input_layers, num_samples))
        this%input_graph(:,:) = input(:,:)
        return
     type is(array_type)
        ! Unreachable: rank-2 array_type input returned during pre-allocation.
        call stop_program("SHOULD NOT GET HERE")
        this%input_array = input
        l_valid_rank_type = .true.
     end select

  rank(3)
     select type(input)
     type is(real(real32))
        call this%input_array(1,1)%set(input)
        l_valid_rank_type = .true.
     end select

  rank(4)
     select type(input)
     type is(real(real32))
        call this%input_array(1,1)%set(input)
        l_valid_rank_type = .true.
     end select

  rank(5)
     select type(input)
     type is(real(real32))
        call this%input_array(1,1)%set(input)
        l_valid_rank_type = .true.
     end select
  end select rank_select

  if(.not.l_valid_rank_type)then
     write(err_msg,'("Unknown input type for rank ",I0)') input_rank
     call stop_program(trim(err_msg))
     return
  end if

contains

  function get_num_samples(network, input) result(num_samples)
    !! Get the number of samples in the input.
    !!
    !! The count is taken from the last dimension of intrinsic real input,
    !! from the second dimension of a wrapped `val` array, or from the outer
    !! array size for graph input (and for array input when the layer behind
    !! the first root vertex uses graph input). Unsupported rank/type
    !! combinations are reported via `stop_program` and yield 0.
    implicit none

    ! Arguments
    type(network_type), intent(in) :: network !! Instance of network
    class(*), dimension(..), intent(in) :: input !! Input
    integer :: num_samples !! Number of samples

    ! Local variables
    integer :: layer_id !! Layer ID
    logical :: use_graph_input !! Whether to use graph input


    num_samples = 0
    layer_id = network%auto_graph%vertex(network%root_vertices(1))%id
    use_graph_input = network%model(layer_id)%layer%use_graph_input

    select rank(input)
    rank(0)
       select type(input)
       class is(array_type)
          num_samples = size(input%val, 2)
       class is(array_ptr_type)
          num_samples = size(input%array(1,1)%val, 2)
       class default
          call stop_program("Unknown input type in get_num_samples for rank 0")
          return
       end select
    rank(1)
       select type(input)
       class is(array_type)
          if(use_graph_input)then
             num_samples = size(input)
          else
             num_samples = size(input(1)%val, 2)
          end if
       class is(array_ptr_type)
          if(use_graph_input)then
             num_samples = size(input(1)%array, 2)
          else
             num_samples = size(input(1)%array(1,1)%val, 2)
          end if
       class is(graph_type)
          num_samples = size(input, dim=1)
       type is(real)
          num_samples = size(input, rank(input))
       class default
          call stop_program("Unknown input type in get_num_samples for rank 1")
          return
       end select
    rank(2)
       select type(input)
       class is(array_type)
          if(use_graph_input)then
             num_samples = size(input, 2)
          else
             num_samples = size(input(1,1)%val, 2)
          end if
       class is(graph_type)
          num_samples = size(input, dim=2)
       type is(real)
          num_samples = size(input, rank(input))
       class default
          call stop_program("Unknown input type in get_num_samples for rank 2")
          return
       end select
    rank(3)
       select type(input)
       type is(real)
          num_samples = size(input, rank(input))
       class default
          call stop_program("Unknown input type in get_num_samples for rank 3")
          return
       end select
    rank(4)
       select type(input)
       type is(real)
          num_samples = size(input, rank(input))
       class default
          call stop_program("Unknown input type in get_num_samples for rank 4")
          return
       end select
    rank(5)
       select type(input)
       type is(real)
          num_samples = size(input, rank(input))
       class default
          call stop_program("Unknown input type in get_num_samples for rank 5")
          return
       end select
    rank default
       call stop_program("Unknown input rank in get_num_samples")
       return
    end select
  end function get_num_samples

  subroutine handle_array_type(input, output, num_samples)
    !! Handle array type input.
    !!
    !! Allocates `output` with shape
    !! [product(input%shape(1:input%rank)), num_samples] and copies
    !! `input%val` into it. The first extent is presumably the per-sample
    !! feature count (TODO confirm against array_type). Reports an error and
    !! returns early if the second extent of `input%val` differs from
    !! `num_samples`.
    implicit none

    ! Arguments
    class(array_type), intent(in) :: input !! Input
    type(array_type), intent(out) :: output !! Output
    integer, intent(in) :: num_samples !! Number of samples

    if(size(input%val,2).ne.num_samples)then
       call stop_program("number of samples in input arrays do not match")
       return
    end if
    call output%allocate( array_shape = &
         [ product(input%shape(1:input%rank)), num_samples ] &
    )
    output%val = input%val
  end subroutine handle_array_type

end function save_input_to_network