save_input_to_network Module Function

module function save_input_to_network(this, input) result(num_samples)

Save input to network

Arguments

Type | Intent | Optional | Attributes | Name
class(network_type), intent(inout) :: this

Instance of network

class(*), intent(in), dimension(..) :: input

Input

Return Value integer

Number of samples


Source Code

  module function save_input_to_network( this, input ) result(num_samples)
    !! Save input to network
    !!
    !! Any input stored by a previous call is released first. The new input
    !! is then stored in `this%input_array` (array-like input) or in
    !! `this%input_graph` (graph input). If `get_num_samples` cannot
    !! determine a positive number of samples, the function returns
    !! immediately without touching the stored input.
    implicit none

    ! Arguments
    class(network_type), intent(inout) :: this
    !! Instance of network
    class(*), dimension(..), intent(in) :: input
    !! Input

    integer :: num_samples
    !! Number of samples

    ! Local variables
    integer :: i, j, l
    !! Loop indices
    integer :: input_rank
    !! Rank of the input (used when reporting an unsupported input)
    integer :: num_inputs
    !! Number of input elements per sample
    integer :: num_input_layers
    !! Number of input layers
    logical :: l_valid_rank_type
    !! Boolean whether rank type is valid
    character(256) :: err_msg
    !! Error message

    num_samples = get_num_samples(this, input)
    if(num_samples.le.0) return
    num_input_layers = size(this%root_vertices, 1)

    ! Record the rank once, up front, so that the error report at the end of
    ! this procedure is well defined for every rank (previously it was only
    ! set on some branches and could be printed uninitialised).
    input_rank = rank(input)

    ! Release any input stored by a previous call
    !---------------------------------------------------------------------------
    if(allocated(this%input_array))then
       do j = 1, size(this%input_array, 2)
          do i = 1, size(this%input_array, 1)
             call this%input_array(i,j)%deallocate()
          end do
       end do
       deallocate(this%input_array)
    end if
    if(allocated(this%input_graph)) deallocate(this%input_graph)

    ! Determine the rank of the input and pre-allocate storage where needed
    !---------------------------------------------------------------------------
    select rank(input)
    rank(0)
       ! Storage for scalar input is allocated in the processing step below
    rank(1)
       ! Storage for rank 1 input is allocated in the processing step below
    rank(2)
       select type(input)
       class is(array_type)
          ! A 2D set of array containers is stored element by element via
          ! assign_shallow; nothing further is required for this case
          allocate(this%input_array(size(input,1), size(input,2)))
          do j = 1, size(input,2)
             do i = 1, size(input,1)
                call this%input_array(i,j)%assign_shallow(input(i,j))
             end do
          end do
          return
       class default
          ! NOTE(review): this branch is also taken for rank 2 graph_type
          ! input, so input_array(1,1) is allocated in that case as well
          num_inputs = size(input) / num_samples
          allocate(this%input_array(1,1))
          call this%input_array(1,1)%allocate(array_shape=[num_inputs, num_samples])
       end select
    rank default
       num_inputs = size(input) / num_samples
       allocate(this%input_array(1,1))
       call this%input_array(1,1)%allocate(array_shape=shape(input))
    end select
    l_valid_rank_type = .false.


    ! Process input based on its rank
    !---------------------------------------------------------------------------
    rank_select: select rank(input)
    rank(0)
       select type(input)
       type is(real(real32))
          ! Scalar real input is not supported; report it below
          exit rank_select
       class default
          l_valid_rank_type = .true.
       end select
       if(num_input_layers.ne.1)then
          call stop_program( &
               "number of input arrays does not match expected number of &
               &input layers" &
          )
          return
       end if
       select type(input)
       class is(array_type)
          allocate(this%input_array(1,1))
          call handle_array_type(input, this%input_array(1,1), num_samples)
       type is(array_ptr_type)
          allocate(this%input_array(size(input%array,1), size(input%array,2)))
          do j = 1, size(input%array,2)
             do i = 1, size(input%array,1)
                call handle_array_type( &
                     input%array(i,j), this%input_array(i,j), num_samples &
                )
             end do
          end do
       end select
    rank(1)
       select type(input)
       type is(real(real32))
          ! Rank 1 real input is not supported; report it below
          exit rank_select
       type is(graph_type)
          allocate(this%input_graph(num_input_layers, num_samples))
          ! NOTE(review): only the first row is filled; rows 2 and above are
          ! left unassigned when num_input_layers > 1
          this%input_graph(1,:) = input(:)
          return
       class default
          l_valid_rank_type = .true.
       end select
       if(size(input,1).ne.num_input_layers)then
          call stop_program( &
               "number of input arrays does not match expected number of &
               &input layers" &
          )
          return
       end if
       select type(input)
       class is(array_type)
          ! One array container per input layer
          allocate(this%input_array(1,size(input,1)))
          do l = 1, size(input,1)
             call handle_array_type(input(l), this%input_array(1,l), num_samples)
          end do
       type is(array_ptr_type)
          ! Flattening a rank 1 set of array_ptr_type into input_array has
          ! not been implemented
          call stop_program("Use of array_ptr_type with rank 1 input not yet supported")
          return
       end select
    rank(2)
       select type(input)
       type is(real(real32))
          this%input_array(1,1)%val = reshape(input, [num_inputs, num_samples])
          l_valid_rank_type = .true.
       type is(graph_type)
          num_samples = size(input, dim=2)
          allocate(this%input_graph(num_input_layers, num_samples))
          this%input_graph(:,:) = input(:,:)
          return
       type is(array_type)
          ! Rank 2 array_type input returns from the allocation step above,
          ! so reaching this branch indicates a logic error
          call stop_program("SHOULD NOT GET HERE")
          this%input_array = input
          l_valid_rank_type = .true.
       end select
    rank(3)
       select type(input)
       type is(real(real32))
          call this%input_array(1,1)%set(input)
          l_valid_rank_type = .true.
       end select
    rank(4)
       select type(input)
       type is(real(real32))
          call this%input_array(1,1)%set(input)
          l_valid_rank_type = .true.
       end select
    rank(5)
       select type(input)
       type is(real(real32))
          call this%input_array(1,1)%set(input)
          l_valid_rank_type = .true.
       end select
    end select rank_select

    if(.not.l_valid_rank_type)then
       write(err_msg,'("Unknown input type for rank ",I0)') input_rank
       call stop_program(err_msg)
       return
    end if

  contains

    function get_num_samples(network, input) result(num_samples)
      !! Get the number of samples in the input
      !!
      !! Returns 0 (after calling stop_program) when the rank/type
      !! combination of the input is not recognised.
      implicit none

      ! Arguments
      type(network_type), intent(in) :: network
      !! Instance of network
      class(*), dimension(..), intent(in) :: input
      !! Input
      integer :: num_samples
      !! Number of samples

      ! Local variables
      integer :: layer_id
      !! Layer ID
      logical :: use_graph_input
      !! Whether to use graph input

      num_samples = 0
      ! The first root vertex decides whether the input is treated as graph
      ! input
      layer_id = network%auto_graph%vertex(network%root_vertices(1))%id
      use_graph_input = network%model(layer_id)%layer%use_graph_input
      select rank(input)
      rank(0)
         select type(input)
         class is(array_type)
            num_samples = size(input%val, 2)
         class is(array_ptr_type)
            num_samples = size(input%array(1,1)%val, 2)
         class default
            call stop_program("Unknown input type in get_num_samples for rank 0")
            return
         end select
      rank(1)
         select type(input)
         class is(array_type)
            if(use_graph_input)then
               num_samples = size(input)
            else
               num_samples = size(input(1)%val, 2)
            end if
         class is(array_ptr_type)
            if(use_graph_input)then
               num_samples = size(input(1)%array, 2)
            else
               num_samples = size(input(1)%array(1,1)%val, 2)
            end if
         class is(graph_type)
            num_samples = size(input, dim=1)
         type is(real(real32))
            ! For plain real data the last dimension indexes the samples
            num_samples = size(input, rank(input))
         class default
            call stop_program("Unknown input type in get_num_samples for rank 1")
            return
         end select
      rank(2)
         select type(input)
         class is(array_type)
            if(use_graph_input)then
               num_samples = size(input, 2)
            else
               num_samples = size(input(1,1)%val, 2)
            end if
         class is(graph_type)
            num_samples = size(input, dim=2)
         type is(real(real32))
            num_samples = size(input, rank(input))
         class default
            call stop_program("Unknown input type in get_num_samples for rank 2")
            return
         end select
      rank(3)
         select type(input)
         type is(real(real32))
            num_samples = size(input, rank(input))
         class default
            call stop_program("Unknown input type in get_num_samples for rank 3")
            return
         end select
      rank(4)
         select type(input)
         type is(real(real32))
            num_samples = size(input, rank(input))
         class default
            call stop_program("Unknown input type in get_num_samples for rank 4")
            return
         end select
      rank(5)
         select type(input)
         type is(real(real32))
            num_samples = size(input, rank(input))
         class default
            call stop_program("Unknown input type in get_num_samples for rank 5")
            return
         end select
      rank default
         call stop_program("Unknown input rank in get_num_samples")
         return
      end select

    end function get_num_samples


    subroutine handle_array_type(input, output, num_samples)
      !! Handle array type input
      !!
      !! Checks that the second dimension of `input%val` matches
      !! `num_samples`, then allocates `output` as a 2D
      !! (features, samples) container and copies the values across.
      implicit none

      ! Arguments
      class(array_type), intent(in) :: input
      !! Input
      type(array_type), intent(out) :: output
      !! Output
      integer, intent(in) :: num_samples
      !! Number of samples

      if(size(input%val,2).ne.num_samples)then
         call stop_program("number of samples in input arrays do not match")
         return
      end if
      ! Collapse all leading dimensions of the input into a single one
      call output%allocate( array_shape = &
           [ product(input%shape(1:input%rank)), num_samples ] &
      )
      output%val = input%val
    end subroutine handle_array_type

  end function save_input_to_network