create_from_onnx_kipf_layer Function

public function create_from_onnx_kipf_layer(meta_key, meta_value, inits, verbose) result(layer)

Build a Kipf GCN message-passing layer from ONNX metadata and initialisers, and return it.

Arguments

Type | Intent | Optional Attributes | Name
character(len=*), intent(in) :: meta_key

GNN metadata key (e.g. "athena_gnn_node_1")

character(len=*), intent(in) :: meta_value

Semicolon-separated GNN metadata value string

type(onnx_initialiser_type), intent(in), dimension(:) :: inits

ONNX initialisers (valid entries only)

integer, intent(in), optional :: verbose

Verbosity level

Return Value class(base_layer_type), allocatable

Constructed Kipf GCN layer


Source Code

  function create_from_onnx_kipf_layer( &
       meta_key, meta_value, inits, verbose &
  ) result(layer)
    !! Build Kipf GCN layer from ONNX metadata and return layer
    !!
    !! The metadata value is a semicolon-separated list of `key=value`
    !! tokens. Recognised keys are:
    !!   * `num_time_steps`      - integer (default 1)
    !!   * `num_vertex_features` - space-separated integers (default [1])
    !!   * `message_activation`  - activation name (default 'sigmoid')
    !! Tokens without `=` and unrecognised keys are ignored.
    !!
    !! For each time step `t` the message weight is looked up in `inits`
    !! under the name `<prefix>_t<t>_W`. Here `<prefix>` is the part of
    !! `meta_key` after `athena_gnn_`, or the whole key if that marker is
    !! absent. The data is passed through `row_to_col_major_2d` with the
    !! dimensions `num_vertex_features(t+1)` and `num_vertex_features(t)`
    !! (assumed rows, cols - TODO confirm against onnx_utils). It is then
    !! copied into the first column of the layer's parameter `t`.
    !!
    !! A weight is skipped, keeping whatever the constructor set, when:
    !!   * it is missing (reported when `verbose > 0`);
    !!   * its dimensions cannot be determined;
    !!   * its size does not match the parameter.
    !! The last two cases are reported on `error_unit`.
    use, intrinsic :: iso_fortran_env, only: error_unit
    use athena__kipf_msgpass_layer, only: kipf_msgpass_layer_type
    use athena__onnx_utils, only: row_to_col_major_2d, &
         parse_space_separated_ints
    implicit none

    ! Arguments
    character(*), intent(in) :: meta_key
    !! GNN metadata key (e.g. "athena_gnn_node_1")
    character(*), intent(in) :: meta_value
    !! Semicolon-separated GNN metadata value string
    type(onnx_initialiser_type), dimension(:), intent(in) :: inits
    !! ONNX initialisers (valid entries only)
    integer, optional, intent(in) :: verbose
    !! Verbosity level (> 0 reports initialisers that are not found)
    class(base_layer_type), allocatable :: layer
    !! Constructed Kipf GCN layer

    ! Local constants
    character(*), parameter :: gnn_key_marker = 'athena_gnn_'
    !! Marker stripped from meta_key to obtain the initialiser prefix
    character(*), parameter :: proc_name = 'create_from_onnx_kipf_layer'

    ! Local variables
    integer :: nts, nts_read
    integer, allocatable :: nv_arr(:)
    character(:), allocatable :: msg_activation
    character(:), allocatable :: meta_str, token, key, val
    character(:), allocatable :: gnn_prefix, init_name
    character(16) :: t_str
    integer :: t, k, pos, pos2, ios, verbose_
    integer :: match, n_in, n_out, n_data
    real(real32), allocatable :: col_data(:)

    verbose_ = 0
    if(present(verbose)) verbose_ = verbose

    ! Defaults, overridden by any keys present in the metadata string
    nts = 1
    msg_activation = 'sigmoid'

    ! Deferred-length copy so long metadata strings are not truncated
    meta_str = trim(meta_value)
    pos = 1
    do while(pos .le. len(meta_str))
       pos2 = index(meta_str(pos:), ';')
       if(pos2 .eq. 0)then
          token = meta_str(pos:)
          pos = len(meta_str) + 1
       else
          token = meta_str(pos:pos+pos2-2)
          pos = pos + pos2
       end if
       k = index(token, '=')
       if(k .eq. 0) cycle
       key = trim(adjustl(token(1:k-1)))
       val = trim(adjustl(token(k+1:)))
       select case(key)
       case('num_time_steps')
          ! Read into a temporary: after a failed read the target is
          ! undefined, so only commit a successfully parsed value
          read(val, *, iostat=ios) nts_read
          if(ios .eq. 0)then
             nts = nts_read
          else
             write(error_unit,'(A)') proc_name // &
                  ': WARNING - could not parse num_time_steps "' // &
                  val // '"; keeping previous value'
          end if
       case('num_vertex_features')
          call parse_space_separated_ints(val, nv_arr)
       case('message_activation')
          msg_activation = val
       end select
    end do

    if(.not.allocated(nv_arr)) allocate(nv_arr(1), source=1)

    ! Derive initialiser name prefix from the metadata key
    pos = index(meta_key, gnn_key_marker)
    if(pos .gt. 0)then
       gnn_prefix = trim(meta_key(pos+len(gnn_key_marker):))
    else
       gnn_prefix = trim(meta_key)
    end if

    block
      type(kipf_msgpass_layer_type) :: kipf_layer

      kipf_layer = kipf_msgpass_layer_type( &
           num_vertex_features = nv_arr, &
           num_time_steps = nts, &
           activation = msg_activation &
      )

      time_loop: do t = 1, nts
         ! Message weight: <prefix>_t{t}_W
         write(t_str, '(I0)') t
         init_name = gnn_prefix // '_t' // trim(t_str) // '_W'

         ! Locate the first initialiser with this name
         match = 0
         do k = 1, size(inits)
            if(trim(inits(k)%name) .eq. init_name)then
               match = k
               exit
            end if
         end do

         if(match .eq. 0)then
            if(verbose_ .gt. 0) write(*,'(A)') proc_name // &
                 ': initialiser "' // init_name // &
                 '" not found; keeping constructor weights'
            cycle time_loop
         end if

         ! Nothing to copy from, or nowhere to copy into
         if(.not.allocated(inits(match)%data)) cycle time_loop
         if(.not.allocated(kipf_layer%params)) cycle time_loop

         ! Guard the nv_arr(t+1) / params(t) accesses below
         if(size(kipf_layer%params) .lt. t .or. &
              size(nv_arr) .lt. t + 1)then
            write(error_unit,'(A)') proc_name // ': WARNING - "' // &
                 init_name // '" has no matching layer parameter ' // &
                 'or vertex-feature dimensions; skipped'
            cycle time_loop
         end if

         n_out = nv_arr(t+1)
         n_in = nv_arr(t)
         n_data = size(inits(match)%data)

         ! Refuse a non-conforming copy instead of corrupting memory
         if(n_out * n_in .ne. n_data .or. &
              size(kipf_layer%params(t)%val, 1) .ne. n_data)then
            write(error_unit,'(A)') proc_name // ': WARNING - "' // &
                 init_name // '" size does not match the layer ' // &
                 'parameter; skipped'
            cycle time_loop
         end if

         allocate(col_data(n_data))
         call row_to_col_major_2d( &
              inits(match)%data, col_data, n_out, n_in)
         kipf_layer%params(t)%val(:,1) = col_data
         deallocate(col_data)
      end do time_loop

      allocate(layer, source=kipf_layer)
    end block

  end function create_from_onnx_kipf_layer