build_duvenaud_onnx_expanded_gnn Function

public function build_duvenaud_onnx_expanded_gnn(prefix, nodes, num_nodes, inits, num_inits, inputs, num_inputs) result(layer)

Build a Duvenaud layer from an expanded-ONNX cluster.

Arguments

Type | Intent | Optional Attributes | Name
character(len=*), intent(in) :: prefix

Layer node prefix (e.g. "node_2")

type(onnx_node_type), intent(in) :: nodes(:)

Parsed ONNX nodes

integer, intent(in) :: num_nodes

Number of valid node entries

type(onnx_initialiser_type), intent(in) :: inits(:)

Parsed ONNX initialisers

integer, intent(in) :: num_inits

Number of valid initialiser entries

type(onnx_tensor_type), intent(in) :: inputs(:)

Parsed ONNX graph input tensors

integer, intent(in) :: num_inputs

Number of valid graph input entries

Return Value class(base_layer_type), allocatable

Constructed Duvenaud layer


Source Code

  function build_duvenaud_onnx_expanded_gnn( &
       prefix, nodes, num_nodes, inits, &
       num_inits, inputs, num_inputs) &
  result(layer)
    !! Build a Duvenaud layer from an expanded-ONNX cluster.
    !!
    !! The cluster is located through initialisers named
    !!   `{prefix}_t{t}_W`      message weights, t = 1, 2, ...
    !!   `{prefix}_t1_min_deg`  optional minimum vertex degree
    !!   `{prefix}_t1_max_deg`  optional maximum vertex degree
    !!   `{prefix}_ro_t{t}_R`   optional readout weights
    !! and an optional node `{prefix}_rename_edge`, whose first
    !! input names the graph input used to size edge features.
    !! Weight data is converted with row_to_col_major_2d before
    !! being copied into the layer parameters.
    use athena__duvenaud_msgpass_layer, only: &
         duvenaud_msgpass_layer_type
    use athena__onnx_nop_utils, only: &
         find_initialiser_by_name
    use athena__onnx_utils, only: row_to_col_major_2d
    implicit none

    ! Arguments
    character(*), intent(in) :: prefix
    !! Layer node prefix (e.g. "node_2")
    type(onnx_node_type), intent(in) :: nodes(:)
    !! Parsed ONNX nodes
    integer, intent(in) :: num_nodes
    !! Number of valid node entries
    type(onnx_initialiser_type), intent(in) :: inits(:)
    !! Parsed ONNX initialisers
    integer, intent(in) :: num_inits
    !! Number of valid initialiser entries
    type(onnx_tensor_type), intent(in) :: inputs(:)
    !! Parsed ONNX graph input tensors
    integer, intent(in) :: num_inputs
    !! Number of valid graph input entries
    class(base_layer_type), allocatable :: layer
    !! Constructed Duvenaud layer

    ! Local variables
    integer :: t, nts, idx, n_out
    integer :: num_deg, min_deg, max_deg
    integer :: ne_in, total_in, n_cols
    integer, allocatable :: nv_arr(:), ne_arr(:)
    character(128) :: init_name, rename_name
    character(64) :: msg_activation
    real(real32), allocatable :: col_data(:)
    integer :: i, rename_idx, slice_size, d

    ! Count timesteps: consecutive {prefix}_t1_W, _t2_W, ...
    ! There is no fixed upper bound; the search ends at the first
    ! missing index, which must exist because the list is finite.
    nts = 0
    do
       write(init_name, '(A,"_t",I0,"_W")') &
            trim(prefix), nts + 1
       idx = find_initialiser_by_name( &
            trim(init_name), inits, num_inits)
       if(idx .le. 0) exit
       nts = nts + 1
    end do

    if(nts .eq. 0)then
       call stop_program( &
            'Duvenaud ONNX cluster has no weights' &
            // ' for ' // trim(prefix))
       ! Defensive: never index inits with an invalid idx
       ! should stop_program return control.
       return
    end if

    ! First weight init gives the dims.
    ! Expected 3D shape: [num_deg, nv_out, nv_in+ne]
    write(init_name, '(A,"_t1_W")') trim(prefix)
    idx = find_initialiser_by_name( &
         trim(init_name), inits, num_inits)
    if(size(inits(idx)%dims) .lt. 3)then
       call stop_program( &
            'Duvenaud ONNX weight ' // trim(init_name) &
            // ' is not rank 3')
       return
    end if
    num_deg = inits(idx)%dims(1)
    total_in = inits(idx)%dims(3)

    ! Degree bounds from optional constant inits
    write(init_name, '(A,"_t1_min_deg")') trim(prefix)
    idx = find_initialiser_by_name( &
         trim(init_name), inits, num_inits)
    min_deg = 1
    if(idx .gt. 0)then
       if(allocated(inits(idx)%data))then
          min_deg = nint(inits(idx)%data(1))
       end if
    end if

    write(init_name, '(A,"_t1_max_deg")') trim(prefix)
    idx = find_initialiser_by_name( &
         trim(init_name), inits, num_inits)
    max_deg = min_deg + num_deg - 1
    if(idx .gt. 0)then
       if(allocated(inits(idx)%data))then
          max_deg = nint(inits(idx)%data(1))
       end if
    end if

    ! Determine ne from graph inputs via the rename
    ! node {prefix}_rename_edge: its first input names the
    ! graph input tensor; dims(2) of that tensor is used as ne.
    ! Fortran does not guarantee short-circuit evaluation of
    ! .and., so each guard is nested before the access it protects.
    ne_in = 0
    rename_name = trim(prefix) // '_rename_edge'
    rename_idx = find_gnn_node( &
         nodes, num_nodes, trim(rename_name))
    if(rename_idx .gt. 0)then
       if(allocated(nodes(rename_idx)%inputs))then
          if(size(nodes(rename_idx)%inputs) .ge. 1)then
             do i = 1, num_inputs
                if(trim(inputs(i)%name) .ne. &
                     trim(nodes(rename_idx)%inputs(1))) cycle
                if(allocated(inputs(i)%dims))then
                   if(size(inputs(i)%dims) .ge. 2)then
                      ne_in = inputs(i)%dims(2)
                   end if
                end if
                exit
             end do
          end if
       end if
    end if

    ! Vertex feature sizes per timestep boundary; edge sizes
    ! are the same at every step.
    allocate(nv_arr(nts + 1))
    allocate(ne_arr(nts + 1))
    ne_arr = ne_in

    ! First timestep: infer nv_in from total - ne
    nv_arr(1) = total_in - ne_in

    do t = 1, nts
       write(init_name, '(A,"_t",I0,"_W")') &
            trim(prefix), t
       idx = find_initialiser_by_name( &
            trim(init_name), inits, num_inits)
       nv_arr(t+1) = inits(idx)%dims(2)
    end do

    ! Get num_outputs from readout weight (falls back to the
    ! final vertex feature size when no readout is present)
    write(init_name, '(A,"_ro_t1_R")') trim(prefix)
    idx = find_initialiser_by_name( &
         trim(init_name), inits, num_inits)
    if(idx .gt. 0)then
       n_out = inits(idx)%dims(1)
    else
       n_out = nv_arr(nts + 1)
    end if

    msg_activation = detect_gnn_expanded_activation( &
         prefix, nodes, num_nodes)

    block
      type(duvenaud_msgpass_layer_type) :: duv_layer

      duv_layer = duvenaud_msgpass_layer_type( &
           num_vertex_features = nv_arr, &
           num_edge_features = ne_arr, &
           num_time_steps = nts, &
           num_outputs = n_out, &
           min_vertex_degree = min_deg, &
           max_vertex_degree = max_deg, &
           message_activation = msg_activation &
      )

      do t = 1, nts
         ! Message weight: [num_deg, nv_out, nv_in+ne],
         ! converted one degree slice at a time.
         write(init_name, '(A,"_t",I0,"_W")') &
              trim(prefix), t
         idx = find_initialiser_by_name( &
              trim(init_name), inits, num_inits)
         n_cols = nv_arr(t) + ne_in
         slice_size = nv_arr(t+1) * n_cols
         if(allocated(inits(idx)%data) .and. &
              allocated(duv_layer%params))then
            ! Every degree slice must lie inside the data buffer
            ! and every converted element must be defined.
            if(size(inits(idx)%data) .ne. &
                 num_deg * slice_size)then
               call stop_program( &
                    'Duvenaud ONNX weight ' &
                    // trim(init_name) &
                    // ' has unexpected size')
               return
            end if
            allocate(col_data(num_deg * slice_size))
            do d = 1, num_deg
               call row_to_col_major_2d( &
                    inits(idx)%data( &
                         (d-1)*slice_size+1 : &
                         d*slice_size), &
                    col_data( &
                         (d-1)*slice_size+1 : &
                         d*slice_size), &
                    nv_arr(t+1), n_cols)
            end do
            duv_layer%params(t)%val(:,1) = col_data
            deallocate(col_data)
         end if

         ! Readout weight: [n_out, nv_out]
         write(init_name, '(A,"_ro_t",I0,"_R")') &
              trim(prefix), t
         idx = find_initialiser_by_name( &
              trim(init_name), inits, num_inits)
         if(idx .gt. 0)then
            if(allocated(inits(idx)%data) .and. &
                 allocated(duv_layer%params))then
               allocate(col_data( &
                    size(inits(idx)%data)))
               call row_to_col_major_2d( &
                    inits(idx)%data, col_data, &
                    n_out, nv_arr(t+1))
               duv_layer%params(nts + t)%val(:,1) = &
                    col_data
               deallocate(col_data)
            end if
         end if
      end do

      allocate(layer, source=duv_layer)
    end block

  end function build_duvenaud_onnx_expanded_gnn