build_kipf_onnx_expanded_gnn Function

public function build_kipf_onnx_expanded_gnn(prefix, nodes, num_nodes, inits, num_inits, inputs, num_inputs) result(layer)

Build a Kipf GCN layer from an expanded-ONNX cluster.

Arguments

Type | Intent | Optional Attributes | Name
character(len=*), intent(in) :: prefix

Layer node prefix (e.g. "node_2")

type(onnx_node_type), intent(in) :: nodes(:)

Parsed ONNX nodes

integer, intent(in) :: num_nodes

Number of valid node entries

type(onnx_initialiser_type), intent(in) :: inits(:)

Parsed ONNX initialisers

integer, intent(in) :: num_inits

Number of valid initialiser entries

type(onnx_tensor_type), intent(in) :: inputs(:)

Parsed ONNX graph input tensors

integer, intent(in) :: num_inputs

Number of valid graph input entries

Return Value class(base_layer_type), allocatable

Constructed Kipf GCN layer


Source Code

  function build_kipf_onnx_expanded_gnn( &
       prefix, nodes, num_nodes, inits, &
       num_inits, inputs, num_inputs) &
  result(layer)
    !! Build a Kipf GCN layer from an expanded-ONNX cluster.
    !!
    !! Per-timestep weight initialisers are looked up by the names
    !! `<prefix>_t1_W`, `<prefix>_t2_W`, ...; the first missing name
    !! ends the scan, so the steps must be numbered without gaps.
    !! Each weight has shape [nv_out, nv_in] and is converted with
    !! `row_to_col_major_2d` before being copied into the layer
    !! parameters.
    !!
    !! Validation (reported through `stop_program`):
    !!  - at least one `<prefix>_t1_W` initialiser must exist;
    !!  - the input width of step t must equal the output width of
    !!    step t-1;
    !!  - stored weight data must hold exactly nv_in*nv_out values.
    !!
    !! `inputs` and `num_inputs` are not referenced here (presumably
    !! kept for a uniform builder interface — confirm with callers).
    use athena__kipf_msgpass_layer, only: &
         kipf_msgpass_layer_type
    use athena__onnx_nop_utils, only: &
         find_initialiser_by_name
    use athena__onnx_utils, only: row_to_col_major_2d
    implicit none

    ! Arguments
    character(*), intent(in) :: prefix
    !! Layer node prefix (e.g. "node_2")
    type(onnx_node_type), intent(in) :: nodes(:)
    !! Parsed ONNX nodes
    integer, intent(in) :: num_nodes
    !! Number of valid node entries
    type(onnx_initialiser_type), intent(in) :: inits(:)
    !! Parsed ONNX initialisers
    integer, intent(in) :: num_inits
    !! Number of valid initialiser entries
    type(onnx_tensor_type), intent(in) :: inputs(:)
    !! Parsed ONNX graph input tensors
    integer, intent(in) :: num_inputs
    !! Number of valid graph input entries
    class(base_layer_type), allocatable :: layer
    !! Constructed Kipf GCN layer

    ! Local variables
    integer :: t, nts, idx
    integer, allocatable :: nv_arr(:)
    !! Vertex feature widths; step t maps nv_arr(t) -> nv_arr(t+1)
    integer, allocatable :: w_idx(:)
    !! Index into `inits` of each timestep's weight initialiser
    character(128) :: init_name
    character(256) :: err_msg
    character(64) :: msg_activation
    real(real32), allocatable :: col_data(:)

    ! Locate the contiguous run of `<prefix>_t{N}_W` initialisers.
    ! Every timestep needs its own initialiser, so there can be at
    ! most num_inits of them; this bound replaces a fixed cap that
    ! would silently truncate deeper models.
    allocate(w_idx(max(num_inits, 0)))
    nts = 0
    do t = 1, num_inits
       write(init_name, '(A,"_t",I0,"_W")') trim(prefix), t
       idx = find_initialiser_by_name( &
            trim(init_name), inits, num_inits)
       if(idx .le. 0) exit
       nts = nts + 1
       w_idx(nts) = idx
    end do

    if(nts .eq. 0)then
       call stop_program( &
            'Kipf ONNX cluster has no weights for ' &
            // trim(prefix))
       ! Defensive: never continue with no weights if stop_program
       ! returns control.
       return
    end if

    ! Build the vertex feature widths from the weight shapes
    ! (Kipf weight: [nv_out, nv_in]). The input width of each step
    ! must match the previous step's output width, because nv_arr(t)
    ! is used as the column count when reshaping step t's weights.
    allocate(nv_arr(nts + 1))
    nv_arr(1) = inits(w_idx(1))%dims(2)
    do t = 1, nts
       idx = w_idx(t)
       if(inits(idx)%dims(2) .ne. nv_arr(t))then
          write(err_msg, '(A,A,"_t",I0,A,I0,A,I0)') &
               'Kipf ONNX weight ', trim(prefix), t, &
               '_W has input dimension ', inits(idx)%dims(2), &
               ', expected ', nv_arr(t)
          call stop_program(trim(err_msg))
          return
       end if
       nv_arr(t+1) = inits(idx)%dims(1)
    end do

    msg_activation = detect_gnn_expanded_activation( &
         prefix, nodes, num_nodes)

    block
      type(kipf_msgpass_layer_type) :: kipf_layer
      integer :: n_weights

      kipf_layer = kipf_msgpass_layer_type( &
           num_vertex_features = nv_arr, &
           num_time_steps = nts, &
           activation = trim(msg_activation) &
      )

      ! Copy stored weights into the layer parameters. Steps whose
      ! initialiser carries no data keep the values set by the
      ! constructor.
      if(allocated(kipf_layer%params))then
         do t = 1, nts
            idx = w_idx(t)
            if(.not. allocated(inits(idx)%data)) cycle

            ! The reshape below assumes exactly nv_out*nv_in values.
            n_weights = nv_arr(t) * nv_arr(t+1)
            if(size(inits(idx)%data) .ne. n_weights)then
               write(err_msg, '(A,A,"_t",I0,A,I0,A,I0)') &
                    'Kipf ONNX weight ', trim(prefix), t, &
                    '_W has ', size(inits(idx)%data), &
                    ' values, expected ', n_weights
               call stop_program(trim(err_msg))
               return
            end if

            allocate(col_data(n_weights))
            call row_to_col_major_2d( &
                 inits(idx)%data, col_data, &
                 nv_arr(t+1), nv_arr(t))
            kipf_layer%params(t)%val(:,1) = col_data
            deallocate(col_data)
         end do
      end if

      allocate(layer, source=kipf_layer)
    end block

  end function build_kipf_onnx_expanded_gnn