Build Duvenaud message-passing layer from ONNX metadata and return layer
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| character(len=*), | intent(in) | :: | meta_key |
GNN metadata key (e.g. "athena_gnn_node_1") |
||
| character(len=*), | intent(in) | :: | meta_value |
Semicolon-separated GNN metadata value string |
||
| type(onnx_initialiser_type), | intent(in), | dimension(:) | :: | inits |
ONNX initialisers (valid entries only) |
|
| integer, | intent(in), | optional | :: | verbose |
Verbosity level |
Constructed Duvenaud message-passing layer
function create_from_onnx_duvenaud_layer( & meta_key, meta_value, inits, verbose & ) result(layer) !! Build Duvenaud message-passing layer from ONNX metadata and return layer use athena__duvenaud_msgpass_layer, only: duvenaud_msgpass_layer_type use athena__onnx_utils, only: row_to_col_major_2d, & parse_space_separated_ints implicit none ! Arguments character(*), intent(in) :: meta_key !! GNN metadata key (e.g. "athena_gnn_node_1") character(*), intent(in) :: meta_value !! Semicolon-separated GNN metadata value string type(onnx_initialiser_type), dimension(:), intent(in) :: inits !! ONNX initialisers (valid entries only) integer, optional, intent(in) :: verbose !! Verbosity level class(base_layer_type), allocatable :: layer !! Constructed Duvenaud message-passing layer ! Local variables integer :: nts, n_out, min_deg, max_deg, num_deg integer, allocatable :: nv_arr(:), ne_arr(:) character(64) :: msg_activation character(128) :: gnn_prefix integer :: t, k, pos, pos2, verbose_ character(256) :: meta_str, token, key, val character(128) :: init_prefix real(real32), allocatable :: col_data(:) verbose_ = 0 if(present(verbose)) verbose_ = verbose ! Parse hyperparameters from the metadata value string meta_str = meta_value nts = 1 min_deg = 1 max_deg = 10 n_out = 1 msg_activation = 'sigmoid' pos = 1 do while(pos .le. len_trim(meta_str)) pos2 = index(meta_str(pos:), ';') if(pos2 .eq. 0)then token = meta_str(pos:len_trim(meta_str)) pos = len_trim(meta_str) + 1 else token = meta_str(pos:pos+pos2-2) pos = pos + pos2 end if k = index(token, '=') if(k .eq. 0) cycle key = trim(adjustl(token(1:k-1))) val = trim(adjustl(token(k+1:))) select case(trim(key)) case('num_time_steps') read(val, *) nts case('min_vertex_degree') read(val, *) min_deg case('max_vertex_degree') read(val, *) max_deg case('num_vertex_features') call parse_space_separated_ints(val, nv_arr) case('num_edge_features') call parse_space_separated_ints(val, ne_arr) case('num_outputs') read(val, *) n_out case('message_activation') msg_activation = trim(val) end select end do if(.not.allocated(nv_arr)) allocate(nv_arr(1), source=1) if(.not.allocated(ne_arr)) allocate(ne_arr(1), source=0) num_deg = max_deg - min_deg + 1 ! Derive initialiser name prefix from the metadata key gnn_prefix = trim(meta_key) pos = index(gnn_prefix, 'athena_gnn_') if(pos .gt. 0) gnn_prefix = gnn_prefix(pos+11:) block type(duvenaud_msgpass_layer_type) :: duvenaud_layer duvenaud_layer = duvenaud_msgpass_layer_type( & num_vertex_features = nv_arr, & num_edge_features = ne_arr, & num_time_steps = nts, & num_outputs = n_out, & min_vertex_degree = min_deg, & max_vertex_degree = max_deg, & message_activation = msg_activation & ) do t = 1, nts ! Message weight: node_X_t{t}_W write(init_prefix, '(A,"_t",I0,"_W")') trim(gnn_prefix), t do k = 1, size(inits) if(trim(inits(k)%name) .eq. trim(init_prefix))then if(allocated(inits(k)%data) .and. & allocated(duvenaud_layer%params))then allocate(col_data(size(inits(k)%data))) block integer :: d, slice_size slice_size = nv_arr(t+1) * (nv_arr(t) + ne_arr(1)) do d = 1, num_deg call row_to_col_major_2d( & inits(k)%data((d-1)*slice_size+1:d*slice_size), & col_data((d-1)*slice_size+1:d*slice_size), & nv_arr(t+1), nv_arr(t) + ne_arr(1)) end do end block duvenaud_layer%params(t)%val(:,1) = col_data deallocate(col_data) end if exit end if end do ! Readout weight: node_X_ro_t{t}_R write(init_prefix, '(A,"_ro_t",I0,"_R")') trim(gnn_prefix), t do k = 1, size(inits) if(trim(inits(k)%name) .eq. trim(init_prefix))then if(allocated(inits(k)%data) .and. & allocated(duvenaud_layer%params))then allocate(col_data(size(inits(k)%data))) call row_to_col_major_2d( & inits(k)%data, col_data, n_out, nv_arr(t+1)) duvenaud_layer%params(nts + t)%val(:,1) = col_data deallocate(col_data) end if exit end if end do end do allocate(layer, source=duvenaud_layer) end block end function create_from_onnx_duvenaud_layer