Build a Duvenaud layer from an expanded-ONNX cluster.
**Arguments**

| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| `character(len=*)` | `intent(in)` | | | `prefix` | Layer node prefix (e.g. "node_2") |
| `type(onnx_node_type)` | `intent(in)` | | | `nodes(:)` | Parsed ONNX nodes |
| `integer` | `intent(in)` | | | `num_nodes` | Number of valid node entries |
| `type(onnx_initialiser_type)` | `intent(in)` | | | `inits(:)` | Parsed ONNX initialisers |
| `integer` | `intent(in)` | | | `num_inits` | Number of valid initialiser entries |
| `type(onnx_tensor_type)` | `intent(in)` | | | `inputs(:)` | Parsed ONNX graph input tensors |
| `integer` | `intent(in)` | | | `num_inputs` | Number of valid graph input entries |

**Return value:** `class(base_layer_type), allocatable :: layer` — Constructed Duvenaud layer
function build_duvenaud_onnx_expanded_gnn( &
     prefix, nodes, num_nodes, inits, &
     num_inits, inputs, num_inputs) &
     result(layer)
  !! Build a Duvenaud layer from an expanded-ONNX cluster.
  !!
  !! The cluster is identified by its node-name prefix:
  !!
  !! * message weights `{prefix}_t{t}_W`, 3D [num_deg, nv_out, nv_in+ne],
  !!   one row-major (nv_out x (nv_in+ne)) slice per degree;
  !!   timesteps are counted from t = 1 up to the first missing name;
  !! * readout weights `{prefix}_ro_t{t}_R`, 2D [n_out, nv_out];
  !! * optional constants `{prefix}_t1_min_deg` / `{prefix}_t1_max_deg`
  !!   (defaults: 1 and min_deg + num_deg - 1);
  !! * the edge-feature width is dims(2) of the graph input feeding the
  !!   `{prefix}_rename_edge` node (0 if either cannot be found).
  !!
  !! If the cluster is malformed, `stop_program` is called; should it
  !! return, this function returns with `layer` unallocated.
  use athena__duvenaud_msgpass_layer, only: &
       duvenaud_msgpass_layer_type
  use athena__onnx_nop_utils, only: &
       find_initialiser_by_name
  use athena__onnx_utils, only: row_to_col_major_2d
  implicit none

  ! Arguments
  character(*), intent(in) :: prefix
  !! Layer node prefix (e.g. "node_2")
  type(onnx_node_type), intent(in) :: nodes(:)
  !! Parsed ONNX nodes
  integer, intent(in) :: num_nodes
  !! Number of valid node entries
  type(onnx_initialiser_type), intent(in) :: inits(:)
  !! Parsed ONNX initialisers
  integer, intent(in) :: num_inits
  !! Number of valid initialiser entries
  type(onnx_tensor_type), intent(in) :: inputs(:)
  !! Parsed ONNX graph input tensors
  integer, intent(in) :: num_inputs
  !! Number of valid graph input entries

  class(base_layer_type), allocatable :: layer
  !! Constructed Duvenaud layer

  ! Local variables
  integer, parameter :: max_time_steps = 99
  !! Upper bound on the number of `_t{t}_W` timesteps searched for
  integer :: t, nts, idx, n_out
  integer :: num_deg, min_deg, max_deg
  integer :: ne_in, nv_in_first, total_in
  integer, allocatable :: nv_arr(:), ne_arr(:)
  character(128) :: init_name, rename_name
  character(64) :: msg_activation
  real(real32), allocatable :: col_data(:)
  integer :: i, rename_idx, slice_size, d, n_cols


  !---------------------------------------------------------------------------
  ! Count timesteps: consecutive `_t{t}_W` initialisers starting at t = 1
  !---------------------------------------------------------------------------
  nts = 0
  do t = 1, max_time_steps
     write(init_name, '(A,"_t",I0,"_W")') trim(prefix), t
     idx = find_initialiser_by_name( &
          trim(init_name), inits, num_inits)
     if(idx .le. 0) exit
     nts = nts + 1
  end do
  if(nts .eq. 0)then
     call stop_program( &
          'Duvenaud ONNX cluster has no weights' &
          // ' for ' // trim(prefix))
     ! Guard against stop_program returning: idx is not valid below
     return
  end if


  !---------------------------------------------------------------------------
  ! First message weight gives degree count and total input width
  ! 3D shape: [num_deg, nv_out, nv_in+ne]
  !---------------------------------------------------------------------------
  write(init_name, '(A,"_t1_W")') trim(prefix)
  idx = find_initialiser_by_name( &
       trim(init_name), inits, num_inits)
  num_deg = inits(idx)%dims(1)
  total_in = inits(idx)%dims(3)


  !---------------------------------------------------------------------------
  ! Degree bounds from (optional) constant initialisers
  !---------------------------------------------------------------------------
  write(init_name, '(A,"_t1_min_deg")') trim(prefix)
  idx = find_initialiser_by_name( &
       trim(init_name), inits, num_inits)
  min_deg = 1
  if(idx .gt. 0)then
     if(allocated(inits(idx)%data))then
        min_deg = nint(inits(idx)%data(1))
     end if
  end if

  write(init_name, '(A,"_t1_max_deg")') trim(prefix)
  idx = find_initialiser_by_name( &
       trim(init_name), inits, num_inits)
  max_deg = min_deg + num_deg - 1
  if(idx .gt. 0)then
     if(allocated(inits(idx)%data))then
        max_deg = nint(inits(idx)%data(1))
     end if
  end if


  !---------------------------------------------------------------------------
  ! Edge-feature width from the graph input feeding the rename Identity
  ! node `{prefix}_rename_edge`.
  ! The conditions are nested deliberately: Fortran does not guarantee
  ! short-circuit evaluation of .and., so nodes(rename_idx) and
  ! size(inputs(i)%dims) must only be evaluated once their guards hold.
  !---------------------------------------------------------------------------
  ne_in = 0
  rename_name = trim(prefix) // '_rename_edge'
  rename_idx = find_gnn_node( &
       nodes, num_nodes, trim(rename_name))
  if(rename_idx .gt. 0)then
     if(allocated(nodes(rename_idx)%inputs))then
        if(size(nodes(rename_idx)%inputs) .ge. 1)then
           do i = 1, num_inputs
              if(trim(inputs(i)%name) .ne. &
                   trim(nodes(rename_idx)%inputs(1))) cycle
              if(allocated(inputs(i)%dims))then
                 if(size(inputs(i)%dims) .ge. 2)then
                    ne_in = inputs(i)%dims(2)
                 end if
              end if
              exit
           end do
        end if
     end if
  end if


  !---------------------------------------------------------------------------
  ! Per-timestep vertex / edge feature widths
  !---------------------------------------------------------------------------
  allocate(nv_arr(nts + 1))
  allocate(ne_arr(nts + 1))
  ne_arr = ne_in

  ! First timestep: input vertex width is what remains after edges
  nv_in_first = total_in - ne_in
  nv_arr(1) = nv_in_first
  do t = 1, nts
     write(init_name, '(A,"_t",I0,"_W")') trim(prefix), t
     idx = find_initialiser_by_name( &
          trim(init_name), inits, num_inits)
     nv_arr(t+1) = inits(idx)%dims(2)
  end do


  !---------------------------------------------------------------------------
  ! Output width from the first readout weight, else last vertex width
  !---------------------------------------------------------------------------
  write(init_name, '(A,"_ro_t1_R")') trim(prefix)
  idx = find_initialiser_by_name( &
       trim(init_name), inits, num_inits)
  if(idx .gt. 0)then
     n_out = inits(idx)%dims(1)
  else
     n_out = nv_arr(nts + 1)
  end if

  msg_activation = detect_gnn_expanded_activation( &
       prefix, nodes, num_nodes)


  !---------------------------------------------------------------------------
  ! Construct the layer and copy weights (row-major -> column-major)
  !---------------------------------------------------------------------------
  block
    type(duvenaud_msgpass_layer_type) :: duv_layer

    duv_layer = duvenaud_msgpass_layer_type( &
         num_vertex_features = nv_arr, &
         num_edge_features = ne_arr, &
         num_time_steps = nts, &
         num_outputs = n_out, &
         min_vertex_degree = min_deg, &
         max_vertex_degree = max_deg, &
         message_activation = msg_activation &
    )

    do t = 1, nts
       ! Message weight: [num_deg, nv_out, nv_in+ne]
       write(init_name, '(A,"_t",I0,"_W")') trim(prefix), t
       idx = find_initialiser_by_name( &
            trim(init_name), inits, num_inits)
       if(allocated(inits(idx)%data) .and. &
            allocated(duv_layer%params))then
          n_cols = nv_arr(t) + ne_arr(1)
          slice_size = nv_arr(t+1) * n_cols
          ! A mismatch here means the inferred widths (e.g. ne_in) do
          ! not describe this tensor; converting anyway would read out
          ! of bounds or leave part of col_data uninitialised.
          if(size(inits(idx)%data) .ne. num_deg * slice_size)then
             call stop_program( &
                  'Duvenaud ONNX message weight size mismatch' &
                  // ' for ' // trim(init_name))
             return
          end if
          allocate(col_data(size(inits(idx)%data)))
          do d = 1, num_deg
             call row_to_col_major_2d( &
                  inits(idx)%data( &
                       (d-1)*slice_size+1 : d*slice_size), &
                  col_data( &
                       (d-1)*slice_size+1 : d*slice_size), &
                  nv_arr(t+1), n_cols)
          end do
          duv_layer%params(t)%val(:,1) = col_data
          deallocate(col_data)
       end if

       ! Readout weight: [n_out, nv_out]
       write(init_name, '(A,"_ro_t",I0,"_R")') trim(prefix), t
       idx = find_initialiser_by_name( &
            trim(init_name), inits, num_inits)
       if(idx .gt. 0)then
          if(allocated(inits(idx)%data) .and. &
               allocated(duv_layer%params))then
             if(size(inits(idx)%data) .ne. n_out * nv_arr(t+1))then
                call stop_program( &
                     'Duvenaud ONNX readout weight size mismatch' &
                     // ' for ' // trim(init_name))
                return
             end if
             allocate(col_data(size(inits(idx)%data)))
             call row_to_col_major_2d( &
                  inits(idx)%data, col_data, &
                  n_out, nv_arr(t+1))
             duv_layer%params(nts + t)%val(:,1) = col_data
             deallocate(col_data)
          end if
       end if
    end do

    allocate(layer, source=duv_layer)
  end block

end function build_duvenaud_onnx_expanded_gnn