Build a Kipf GCN layer from ONNX metadata and return it.
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| character(len=*) | intent(in) | | | meta_key | GNN metadata key (e.g. "athena_gnn_node_1") |
| character(len=*) | intent(in) | | | meta_value | Semicolon-separated GNN metadata value string |
| type(onnx_initialiser_type) | intent(in) | | dimension(:) | inits | ONNX initialisers (valid entries only) |
| integer | intent(in) | optional | | verbose | Verbosity level |
**Return value:** `layer` — `class(base_layer_type), allocatable` — the constructed Kipf GCN layer.
function create_from_onnx_kipf_layer( &
     meta_key, meta_value, inits, verbose &
) result(layer)
  !! Build a Kipf GCN layer from ONNX metadata and return it.
  !!
  !! `meta_value` is a semicolon-separated list of `key=value` tokens.
  !! Recognised keys (tokens without '=' and unknown keys are ignored):
  !!   - `num_time_steps`      : integer step count            (default 1)
  !!   - `num_vertex_features` : space-separated integer list  (default [1])
  !!   - `message_activation`  : activation name         (default 'sigmoid')
  !!
  !! For each step t the initialiser named `<prefix>_t<t>_W` is loaded,
  !! where `<prefix>` is `meta_key` with everything up to and including
  !! 'athena_gnn_' removed (e.g. "athena_gnn_node_1" -> "node_1").
  !! Its data is reordered with `row_to_col_major_2d` and copied into
  !! `params(t)%val(:,1)` of the new layer. An initialiser whose size does
  !! not match the expected weight size, or the destination slot, is
  !! skipped (reported when verbose > 0) instead of being copied out of
  !! bounds.
  use athena__kipf_msgpass_layer, only: kipf_msgpass_layer_type
  use athena__onnx_utils, only: row_to_col_major_2d, &
       parse_space_separated_ints
  implicit none

  ! Arguments
  character(*), intent(in) :: meta_key
  !! GNN metadata key (e.g. "athena_gnn_node_1")
  character(*), intent(in) :: meta_value
  !! Semicolon-separated GNN metadata value string
  type(onnx_initialiser_type), dimension(:), intent(in) :: inits
  !! ONNX initialisers (valid entries only)
  integer, optional, intent(in) :: verbose
  !! Verbosity level; values > 0 report skipped weight initialisers
  class(base_layer_type), allocatable :: layer
  !! Constructed Kipf GCN layer

  ! Local variables
  character(*), parameter :: gnn_marker = 'athena_gnn_'
  integer :: nts
  integer, allocatable :: nv_arr(:)
  character(64) :: msg_activation
  ! Deferred-length so long metadata values/keys are never truncated
  character(:), allocatable :: meta_str, token, key, val
  character(:), allocatable :: gnn_prefix, init_name
  character(32) :: step_str
  integer :: t, k, pos, pos2, verbose_
  integer :: n_in, n_out, n_elem
  real(real32), allocatable :: col_data(:)

  verbose_ = 0
  if(present(verbose)) verbose_ = verbose

  !-----------------------------------------------------------------------
  ! Parse hyperparameters from the metadata value string
  !-----------------------------------------------------------------------
  nts = 1
  msg_activation = 'sigmoid'
  meta_str = trim(meta_value)
  pos = 1
  do while(pos .le. len(meta_str))
     pos2 = index(meta_str(pos:), ';')
     if(pos2 .eq. 0)then
        ! Final token runs to the end of the string
        token = meta_str(pos:)
        pos = len(meta_str) + 1
     else
        token = meta_str(pos:pos+pos2-2)
        pos = pos + pos2
     end if
     k = index(token, '=')
     if(k .eq. 0) cycle
     key = trim(adjustl(token(1:k-1)))
     val = trim(adjustl(token(k+1:)))
     select case(key)
     case('num_time_steps')
        read(val, *) nts
     case('num_vertex_features')
        call parse_space_separated_ints(val, nv_arr)
     case('message_activation')
        msg_activation = val
     end select
  end do
  if(.not.allocated(nv_arr)) allocate(nv_arr(1), source=1)

  !-----------------------------------------------------------------------
  ! Derive the initialiser name prefix from the metadata key
  !-----------------------------------------------------------------------
  pos = index(meta_key, gnn_marker)
  if(pos .gt. 0)then
     gnn_prefix = trim(meta_key(pos+len(gnn_marker):))
  else
     gnn_prefix = trim(meta_key)
  end if

  block
    type(kipf_msgpass_layer_type) :: kipf_layer

    kipf_layer = kipf_msgpass_layer_type( &
         num_vertex_features = nv_arr, &
         num_time_steps = nts, &
         activation = trim(msg_activation) &
    )

    step_loop: do t = 1, nts
       ! Message weight initialiser for step t: <prefix>_t<t>_W
       write(step_str, '(I0)') t
       init_name = gnn_prefix // '_t' // trim(step_str) // '_W'

       init_loop: do k = 1, size(inits)
          if(trim(inits(k)%name) .ne. init_name) cycle init_loop
          ! Only the first initialiser with a matching name is considered
          if(.not.allocated(inits(k)%data)) exit init_loop
          if(.not.allocated(kipf_layer%params)) exit init_loop

          ! Separate test: Fortran .or./.and. do not short-circuit, so
          ! params(t) must only be touched once t is known to be in range
          if(t .gt. size(kipf_layer%params) .or. size(nv_arr) .lt. 1)then
             if(verbose_ .gt. 0) write(*,'(A)') &
                  'WARNING: no Kipf parameter slot for initialiser ' // &
                  init_name // '; weights not loaded'
             exit init_loop
          end if

          ! Clamp to the last listed width so a short num_vertex_features
          ! list (e.g. the single-entry default) is never indexed past its
          ! end; any resulting shape mismatch is rejected just below.
          n_in  = nv_arr(min(t,     size(nv_arr)))
          n_out = nv_arr(min(t + 1, size(nv_arr)))
          n_elem = n_in * n_out
          if(size(inits(k)%data) .ne. n_elem .or. &
               size(kipf_layer%params(t)%val, 1) .ne. n_elem)then
             if(verbose_ .gt. 0) write(*,'(A)') &
                  'WARNING: shape mismatch for Kipf initialiser ' // &
                  init_name // '; weights not loaded'
             exit init_loop
          end if

          allocate(col_data(n_elem))
          ! Reorder row-major initialiser data to column-major; dimension
          ! arguments are passed as (nv(t+1), nv(t)) exactly as before —
          ! presumably (rows, cols); confirm against row_to_col_major_2d.
          call row_to_col_major_2d(inits(k)%data, col_data, n_out, n_in)
          kipf_layer%params(t)%val(:,1) = col_data
          deallocate(col_data)
          exit init_loop
       end do init_loop
    end do step_loop

    allocate(layer, source=kipf_layer)
  end block
end function create_from_onnx_kipf_layer