Message Passing Neural Networks

Complete examples demonstrating graph neural networks (GNNs) using message passing layers for learning on graph-structured data.

This tutorial covers the example/msgpass_chemical and example/msgpass_euler examples from the athena repository.

Overview

Message passing neural networks operate on graph data by:

  • Aggregating information from neighboring nodes

  • Updating node features based on local graph structure

  • Learning graph-level or node-level representations

  • Preserving graph topology through the message passing layers
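To make these steps concrete, here is a minimal sketch of one message passing round in plain Fortran. It is not athena's implementation. Each vertex sums the features of its incoming neighbours, and then a shared weight matrix and a ReLU update every vertex. The edge-list layout (row 1 = source, row 2 = target), the sum aggregation and the ReLU are assumptions chosen for illustration.

```fortran
! Illustrative sketch only: not athena's internal message passing code.
module msgpass_sketch
  use, intrinsic :: iso_fortran_env, only: real32
  implicit none
contains
  !> Aggregate: sum the features of each vertex's incoming neighbours.
  pure function aggregate_sum(features, edges) result(messages)
    real(real32), intent(in) :: features(:,:)   ! [num_features, num_vertices]
    integer, intent(in) :: edges(:,:)           ! [2, num_edges]: source, target
    real(real32) :: messages(size(features,1), size(features,2))
    integer :: e

    messages = 0.0_real32
    do e = 1, size(edges, 2)
       ! The message from the source vertex is added to the target vertex
       messages(:, edges(2,e)) = messages(:, edges(2,e)) + features(:, edges(1,e))
    end do
  end function aggregate_sum

  !> Update: combine own features with aggregated messages, then apply ReLU.
  pure function update_relu(features, messages, weight) result(new_features)
    real(real32), intent(in) :: features(:,:), messages(:,:)
    real(real32), intent(in) :: weight(:,:)     ! [num_outputs, num_features]
    real(real32) :: new_features(size(weight,1), size(features,2))

    new_features = max(0.0_real32, matmul(weight, features + messages))
  end function update_relu
end module msgpass_sketch
```

For a path graph 1-2-3 with scalar features [1, 2, 3] and both edge directions listed, aggregate_sum returns [2, 4, 2]; repeating the round lets information travel one further hop each time.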

These examples demonstrate two different applications:

  1. Chemical graphs (msgpass_chemical): Predicting molecular energy from atomic structure

  2. Euler flow (msgpass_euler): Predicting steady-state fluid flow over geometry

Chemical Graph Example

Predicting Molecular Energy

The example/msgpass_chemical example uses Duvenaud message passing [1] to predict the total energy of a molecule from its graph. It is intended as an illustration; for production use, consider more advanced architectures.

Network Architecture

! Message passing layer with graph readout
call network%add(duvenaud_msgpass_layer_type( &
     num_time_steps=4, &
     num_vertex_features=[num_atom_features], &
     num_edge_features=[num_bond_features], &
     num_outputs=10, &
     kernel_initialiser="glorot_normal", &
     readout_activation="softmax", &
     min_vertex_degree=1, &
     max_vertex_degree=10))

! Dense layers for prediction
call network%add(full_layer_type( &
     num_inputs=10, &
     num_outputs=128, &
     activation="leaky_relu", &
     kernel_initialiser="he_normal"))

call network%add(full_layer_type( &
     num_outputs=64, &
     activation="leaky_relu"))

call network%add(full_layer_type( &
     num_outputs=1, &
     activation="leaky_relu"))

Architecture components:

  1. Duvenaud message passing: Aggregates neighbour atom and bond features over num_time_steps=4 rounds

  2. Graph readout: Reduces each graph to a fixed-size vector of num_outputs=10 values by applying a softmax to each vertex's readout and summing over vertices, as described in [1]

  3. Dense layers: Map the aggregated graph representation to a scalar energy prediction
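The softmax readout can be sketched as follows, based on the description in [1]: each vertex state is multiplied by a weight matrix and passed through a softmax, and the per-vertex vectors are summed into one fixed-size fingerprint. This is plain Fortran rather than athena's duvenaud_msgpass_layer_type, and the assumption is that the layer's readout behaves this way.

```fortran
! Sketch of a Duvenaud-style softmax readout (assumed to mirror [1]).
module readout_sketch
  use, intrinsic :: iso_fortran_env, only: real32
  implicit none
contains
  !> Numerically stable softmax: subtract the maximum before exponentiating.
  pure function softmax(z) result(p)
    real(real32), intent(in) :: z(:)
    real(real32) :: p(size(z))
    p = exp(z - maxval(z))
    p = p / sum(p)
  end function softmax

  !> Sum softmax(W h_v) over all vertices v to get a fixed-size vector.
  pure function softmax_readout(features, weight) result(fingerprint)
    real(real32), intent(in) :: features(:,:)  ! [num_features, num_vertices]
    real(real32), intent(in) :: weight(:,:)    ! [num_outputs, num_features]
    real(real32) :: fingerprint(size(weight,1))
    integer :: v

    fingerprint = 0.0_real32
    do v = 1, size(features, 2)
       fingerprint = fingerprint + softmax(matmul(weight, features(:, v)))
    end do
  end function softmax_readout
end module readout_sketch
```

Because each per-vertex softmax sums to 1, the fingerprint entries always sum to the number of vertices, while its length stays num_outputs regardless of molecule size.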

Complete Program Structure

program msgpass_chemical_example
  use athena
  use read_chemical_graphs, only: read_extxyz_db
  implicit none

  type(network_type) :: network
  type(graph_type), allocatable :: graphs_in(:,:)
  type(array_type) :: output(1,1)
  type(metric_dict_type) :: metric_dict(2)
  class(clip_type), allocatable :: clip

  integer, parameter :: num_epochs = 100
  integer, parameter :: batch_size = 8
  integer, parameter :: num_time_steps = 4
  integer :: seed = 42
  integer :: i
  real(real32) :: output_min, output_max

  ! Load chemical graphs from XYZ format
  call read_extxyz_db("database.xyz", graphs_in, output)

  ! Add self-loops and convert to sparse format
  do i = 1, size(graphs_in, 2)
    call graphs_in(1,i)%add_self_loops()
    if (.not. graphs_in(1,i)%is_sparse) then
      call graphs_in(1,i)%convert_to_sparse()
    end if
  end do

  ! Initialise random seed
  call random_setup(seed, restart=.false.)

  ! Build network (see architecture above)

  ! Compile with gradient clipping
  allocate(clip, source=clip_type(clip_norm=0.1_real32))

  metric_dict%active = .false.
  metric_dict(1)%key = "loss"
  metric_dict(2)%key = "accuracy"
  metric_dict%threshold = 0.1

  call network%compile( &
       optimiser=adam_optimiser_type( &
            clip_dict=clip, &
            learning_rate=0.01_real32), &
       loss_method="mse", &
       accuracy_method="mse", &
       metrics=metric_dict, &
       batch_size=batch_size, &
       verbose=1)

  ! Normalise outputs
  output_min = minval(output(1,1)%val)
  output_max = maxval(output(1,1)%val)
  output(1,1)%val = (output(1,1)%val - output_min) / &
                    (output_max - output_min)

  ! Train network
  call network%train( &
       graphs_in, &
       output, &
       num_epochs=num_epochs, &
       shuffle_batches=.true.)

  ! Test and save
  call network%test(graphs_in, output)
  call network%print(file="network.txt")

end program msgpass_chemical_example
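Since the program rescales the energies to [0, 1] before training, predictions come out on that scale too. The pair of helpers below (hypothetical names, not part of athena) keeps the minimum and maximum so predictions can be mapped back to energies. It also guards against division by zero when all targets are equal, which the inline normalisation in the program does not.

```fortran
! Hypothetical helpers for min-max scaling of regression targets.
module minmax_scaling
  use, intrinsic :: iso_fortran_env, only: real32
  implicit none
contains
  !> Scale values to [0, 1] in place; return the min/max needed to undo it.
  subroutine minmax_normalise(values, vmin, vmax)
    real(real32), intent(inout) :: values(:)
    real(real32), intent(out) :: vmin, vmax
    vmin = minval(values)
    vmax = maxval(values)
    if (vmax > vmin) then
       values = (values - vmin) / (vmax - vmin)
    else
       values = 0.0_real32   ! all targets equal: avoid dividing by zero
    end if
  end subroutine minmax_normalise

  !> Map a normalised prediction back to the original scale.
  elemental function minmax_denormalise(scaled, vmin, vmax) result(value)
    real(real32), intent(in) :: scaled, vmin, vmax
    real(real32) :: value
    value = scaled * (vmax - vmin) + vmin
  end function minmax_denormalise
end module minmax_scaling
```

Assuming output(1,1)%val is a real(real32) array of shape [1, num_samples], it could be passed as output(1,1)%val(1,:) to minmax_normalise, and minmax_denormalise applied to the network's predictions afterwards.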

Graph Data Format

Chemical graphs contain:

type(graph_type) :: molecule

! Vertex (atom) features, shape [num_features, num_atoms]
! e.g. atomic number, valence, hybridisation
molecule%num_vertex_features = 6
! molecule%vertex_features(:, atom_id) holds the features of one atom

! Edge (bond) features, shape [num_features, num_edges]
! e.g. bond type, bond order, ring membership
molecule%num_edge_features = 4
! molecule%edge_features(:, bond_id) holds the features of one bond

! Sparse adjacency representation
! molecule%adjacency_matrix stores the connectivity;
! molecule%is_sparse is .true. after convert_to_sparse()
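The [num_features, num_atoms] layout puts each atom's features in one contiguous column, which suits Fortran's column-major storage. The sketch below fills such columns using a hypothetical 6-feature encoding (one-hot over H, C, N, O, F plus an "other" slot). This encoding is an assumption for illustration; athena's read_extxyz_db may build atom features differently.

```fortran
! Hypothetical one-hot atom encoding giving num_vertex_features = 6.
module atom_features_sketch
  use, intrinsic :: iso_fortran_env, only: real32
  implicit none
  character(len=2), parameter :: elements(5) = ["H ", "C ", "N ", "O ", "F "]
contains
  !> Return one column of vertex features for an element symbol.
  pure function one_hot_element(symbol) result(feat)
    character(len=*), intent(in) :: symbol
    real(real32) :: feat(size(elements) + 1)
    integer :: k

    feat = 0.0_real32
    feat(size(feat)) = 1.0_real32        ! default to the "other" slot
    do k = 1, size(elements)
       if (trim(symbol) == trim(elements(k))) then
          feat = 0.0_real32
          feat(k) = 1.0_real32
          exit
       end if
    end do
  end function one_hot_element
end module atom_features_sketch
```

For water, columns 1 to 3 of a [6, 3] feature array would be one_hot_element("O"), one_hot_element("H") and one_hot_element("H").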

Euler Flow Example

Predicting Steady-State Flow

The example/msgpass_euler example uses Kipf message passing (Graph Convolutional Network) [2] to predict steady-state fluid flow from initial conditions. Unlike the Duvenaud layer in the chemical example, the Kipf layer outputs node-level features and so preserves the graph structure. This example uses skip connections via concatenation to improve information flow.

Network Architecture

! First message passing layer
call network%add(kipf_msgpass_layer_type( &
     num_time_steps=1, &
     num_vertex_features=[3, 6], &  ! [input, output] dimensions
     activation="softmax", &
     kernel_initialiser="he_normal"))

! Second layer with concatenation
call network%add(kipf_msgpass_layer_type( &
     num_time_steps=1, &
     num_vertex_features=[9, 14], &  ! 6 + 3 = 9 from concatenation
     activation="softmax"), &
     input_list=[0, -1], &           ! Concatenate layer 0 and previous
     operator="concatenate")

! Additional layers continuing pattern
call network%add(kipf_msgpass_layer_type( &
     num_time_steps=1, &
     num_vertex_features=[17, 32], &
     activation="softmax"), &
     input_list=[0, -1], &
     operator="concatenate")

! ... more layers with increasing then decreasing dimensions ...

! Final layer
call network%add(kipf_msgpass_layer_type( &
     num_time_steps=1, &
     num_vertex_features=[17, 7], &  ! Output: 7 flow features
     activation="swish"), &
     input_list=[0, -1], &
     operator="concatenate")
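Getting num_vertex_features right by hand is error-prone once skip connections are involved. The helper below (hypothetical, not part of athena) computes each layer's input width when the raw input is concatenated with the previous layer's output. It reproduces the 9 and 17 shown above from 3 input features. If the omitted layers follow the same pattern as the output-width sequence 6 → 14 → 32 → 64 → 32 → 14 → 7, their input widths would be 35, 67 and 35.

```fortran
! Hypothetical bookkeeping helper for concatenated skip connections.
module concat_widths
  implicit none
contains
  !> Input width of every layer when layer l (l > 1) receives the raw input
  !> concatenated with layer l-1's output, as with input_list=[0, -1].
  pure function layer_input_widths(num_inputs, output_widths) result(in_widths)
    integer, intent(in) :: num_inputs
    integer, intent(in) :: output_widths(:)
    integer :: in_widths(size(output_widths))
    integer :: l

    if (size(output_widths) == 0) return
    in_widths(1) = num_inputs                  ! first layer sees the input only
    do l = 2, size(output_widths)
       in_widths(l) = output_widths(l-1) + num_inputs
    end do
  end function layer_input_widths
end module concat_widths
```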

Key architecture features:

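  1. Expand-then-contract widths: Vertex features grow and then shrink (6 → 14 → 32 → 64 → 32 → 14 → 7), loosely resembling a U-Net, although the skip connections here come from the input rather than from mirrored layers

  2. Skip connections: Every layer after the first concatenates the original input with the previous layer's output via input_list=[0, -1]

  3. Kipf message passing: Each graph convolution normalises aggregated features by node degree (symmetric normalisation in [2])

  4. Stacked aggregation: With num_time_steps=1, each layer aggregates over immediate neighbours, so stacking layers lets information propagate several hops through the graph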
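In [2], a GCN layer computes H' = σ(D̂^(-1/2) Â D̂^(-1/2) H W) with Â = A + I, where D̂ holds the degrees of Â. The sketch below implements only the propagation term D̂^(-1/2) Â D̂^(-1/2) H on an edge list. The weight matrix W and the activation σ are left out. It follows the formula in [2] and is not taken from athena's kipf_msgpass_layer_type.

```fortran
! Sketch of Kipf & Welling symmetric degree normalisation (no weights).
module gcn_sketch
  use, intrinsic :: iso_fortran_env, only: real32
  implicit none
contains
  !> out(:, i) = sum over j in N(i) and i itself of h_j / sqrt(d_i * d_j).
  !> Assumes each undirected edge is listed in both directions, no self-loops.
  pure function gcn_propagate(features, edges) result(out)
    real(real32), intent(in) :: features(:,:)   ! [num_features, num_vertices]
    integer, intent(in) :: edges(:,:)           ! [2, num_edges]: source, target
    real(real32) :: out(size(features,1), size(features,2))
    real(real32) :: degree(size(features,2))
    integer :: e, i, j

    degree = 1.0_real32                          ! the self-loop counts towards degree
    do e = 1, size(edges, 2)
       degree(edges(2,e)) = degree(edges(2,e)) + 1.0_real32
    end do

    ! Self-loop term: h_i / d_i
    do i = 1, size(features, 2)
       out(:, i) = features(:, i) / degree(i)
    end do
    ! Neighbour terms: h_j / sqrt(d_i * d_j)
    do e = 1, size(edges, 2)
       j = edges(1, e)
       i = edges(2, e)
       out(:, i) = out(:, i) + features(:, j) / sqrt(degree(i) * degree(j))
    end do
  end function gcn_propagate
end module gcn_sketch
```

Dividing by sqrt(d_i d_j) rather than by d_i alone keeps high-degree vertices from dominating their neighbours' updates while still bounding the magnitude of the aggregated features.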

Complete Program Structure

program msgpass_euler_example
  use athena
  use read_euler, only: read_graph
  implicit none

  type(network_type) :: network
  type(graph_type), allocatable :: graphs_in(:,:), graphs_out(:,:)
  type(graph_type), allocatable :: graphs_predicted(:,:)
  class(clip_type), allocatable :: clip

  integer, parameter :: num_epochs = 200
  integer, parameter :: batch_size = 2
  integer :: seed = 1
  integer :: i

  ! Load graph data from files
  allocate(graphs_in(1, 2), graphs_out(1, 2))

  do i = 1, 2
    call read_graph( &
         vertex_file="bump_nodeData_in_"//trim(str(i))//".txt", &
         edge_file="bump_edgeData_1.txt", &
         graph=graphs_in(1,i))

    call read_graph( &
         vertex_file="bump_nodeData_out_"//trim(str(i))//".txt", &
         edge_file="bump_edgeData_1.txt", &
         graph=graphs_out(1,i))
  end do

  ! Initialise random seed
  call random_setup(seed, restart=.false.)

  ! Build network (see architecture above)

  ! Compile with gradient clipping and learning rate decay
  allocate(clip, source=clip_type(-1.0_real32, 1.0_real32))

  call network%compile( &
       optimiser=adam_optimiser_type( &
            clip_dict=clip, &
            learning_rate=0.02_real32, &
            lr_decay=exp_lr_decay_type(0.001_real32)), &
       loss_method="mse", &
       accuracy_method="mse", &
       batch_size=batch_size, &
       verbose=1)

  ! Train on graph pairs
  call network%train( &
       graphs_in, &
       graphs_out, &
       num_epochs=num_epochs)

  ! Test and predict
  call network%test(graphs_in, graphs_out)
  graphs_predicted = network%predict(graphs_in)

  ! Save results
  call network%print(file="network.txt")

end program msgpass_euler_example
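To inspect individual results in graphs_predicted, a per-graph error over all vertices and features is often useful. The helper below is a hypothetical sketch. It assumes vertex_features is a plain real(real32) array of shape [num_features, num_nodes], as the shape comments under Output formats suggest. athena's "mse" loss may scale differently (for example by a factor of 1/2).

```fortran
! Hypothetical per-graph mean squared error over node features.
module node_mse_sketch
  use, intrinsic :: iso_fortran_env, only: real32
  implicit none
contains
  !> Mean of squared differences over every vertex and every feature.
  pure function node_mse(predicted, target) result(mse)
    real(real32), intent(in) :: predicted(:,:)  ! [num_features, num_vertices]
    real(real32), intent(in) :: target(:,:)     ! same shape as predicted
    real(real32) :: mse
    mse = sum((predicted - target)**2) / real(size(target), real32)
  end function node_mse
end module node_mse_sketch
```

Under that assumption, node_mse(graphs_predicted(1,s)%vertex_features, graphs_out(1,s)%vertex_features) would give the error for sample s.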

Message Passing Types

Currently, athena supports two message passing layer types: the Duvenaud layer and the Kipf layer. Custom message passing schemes can be implemented by extending msgpass_layer_type.

Duvenaud Message Passing

Designed for molecular graphs:

duvenaud_msgpass_layer_type( &
     num_time_steps=4, &          ! Number of message passing iterations
     num_vertex_features=[...], & ! Node feature dimensions
     num_edge_features=[...], &   ! Edge feature dimensions
     num_outputs=10, &            ! Graph-level output dimension
     readout_activation="softmax" & ! Activation applied in the graph readout
)

Characteristics:

  • Considers edge features in message passing

  • Performs graph-level readout (reduces entire graph to vector)

  • Suitable for graph-level prediction tasks
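In [1], each vertex's own state is summed with its neighbours' states and multiplied by a weight matrix chosen according to the vertex's degree. The sketch below shows one such update. Several details are assumptions: that min_vertex_degree and max_vertex_degree bound the range of these degree-specific matrices in athena, the sigmoid activation, and the omission of edge features.

```fortran
! Sketch of one Duvenaud-style update with degree-specific weights.
module duvenaud_sketch
  use, intrinsic :: iso_fortran_env, only: real32
  implicit none
contains
  !> h_v <- sigmoid(H_deg(v) (h_v + sum of neighbour h_u)); edge features omitted.
  pure function duvenaud_update(features, edges, weights) result(new_features)
    real(real32), intent(in) :: features(:,:)    ! [F, num_vertices]
    integer, intent(in) :: edges(:,:)            ! [2, num_edges]: source, target
    real(real32), intent(in) :: weights(:,:,:)   ! [F_out, F, max_degree]
    real(real32) :: new_features(size(weights,1), size(features,2))
    real(real32) :: summed(size(features,1), size(features,2))
    integer :: degree(size(features,2))
    integer :: e, v, d

    summed = features
    degree = 0
    do e = 1, size(edges, 2)
       summed(:, edges(2,e)) = summed(:, edges(2,e)) + features(:, edges(1,e))
       degree(edges(2,e)) = degree(edges(2,e)) + 1
    end do

    do v = 1, size(features, 2)
       ! Clamp the degree to the available matrices (isolated vertices use matrix 1)
       d = max(1, min(degree(v), size(weights, 3)))
       new_features(:, v) = 1.0_real32 / &
            (1.0_real32 + exp(-matmul(weights(:,:,d), summed(:,v))))
    end do
  end function duvenaud_update
end module duvenaud_sketch
```

Using one matrix per degree lets the layer treat, for example, a terminal atom and a branching atom differently even when their summed features are similar.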

Graph Convolutional Network (Kipf)

Designed for node-level tasks on general graphs [2]:

kipf_msgpass_layer_type( &
     num_time_steps=1, &          ! Usually 1 per layer
     num_vertex_features=[in, out], & ! [input_dim, output_dim]
     activation="softmax" &
)

Characteristics:

  • Degree-normalised aggregation

  • Node-level outputs (preserves graph structure)

  • Can be stacked with skip connections

  • Suitable for node-level prediction tasks
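Stacking with skip connections relies on concatenating vertex features along the feature dimension, so [F1, n] and [F2, n] become [F1 + F2, n] for the same n vertices. The sketch below shows that operation in plain Fortran. The ordering (earlier input first) is an assumption about how operator="concatenate" arranges its inputs.

```fortran
! Sketch of feature-wise concatenation used by skip connections.
module concat_sketch
  use, intrinsic :: iso_fortran_env, only: real32
  implicit none
contains
  !> Stack two vertex-feature arrays; both must have the same vertex count.
  pure function concat_features(a, b) result(c)
    real(real32), intent(in) :: a(:,:)   ! [F1, num_vertices]
    real(real32), intent(in) :: b(:,:)   ! [F2, num_vertices]
    real(real32) :: c(size(a,1) + size(b,1), size(a,2))
    c(1:size(a,1), :) = a
    c(size(a,1)+1:, :) = b
  end function concat_features
end module concat_sketch
```

Concatenating the 3 input features with a layer's 6 outputs gives the 9 input features of the second Kipf layer in the Euler example.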

Training on Graphs

Key Differences from Standard Networks

Graph batching:

! Samples are stored as a 2D array of graphs
type(graph_type), allocatable :: graphs(:,:)
! Shape: [1, num_samples]

! Call train with graphs directly
call network%train(graphs_in, graphs_out, num_epochs=100)

Sparse representation:

! Add self-loops, then switch to sparse storage for efficiency
call graph%add_self_loops()        ! Add an edge from each vertex to itself
call graph%convert_to_sparse()     ! Use sparse matrix format
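These two calls can be illustrated with a compressed sparse row (CSR) sketch. One self-loop is added per vertex, and each row then lists a vertex's neighbours, so only existing edges are stored. The csr_adjacency type and build_csr_with_self_loops are hypothetical. athena's internal sparse layout may differ, and the sketch assumes the input edge list contains no self-loops yet.

```fortran
! Hypothetical CSR adjacency builder; not athena's internal format.
module csr_sketch
  implicit none
  !> Neighbours of vertex i are col_idx(row_ptr(i) : row_ptr(i+1)-1).
  type :: csr_adjacency
     integer, allocatable :: row_ptr(:), col_idx(:)
  end type csr_adjacency
contains
  function build_csr_with_self_loops(num_vertices, edges) result(adj)
    integer, intent(in) :: num_vertices
    integer, intent(in) :: edges(:,:)          ! [2, num_edges]: row, column
    type(csr_adjacency) :: adj
    integer :: counts(num_vertices), fill(num_vertices)
    integer :: e, v

    counts = 1                                 ! one self-loop per vertex
    do e = 1, size(edges, 2)
       counts(edges(1,e)) = counts(edges(1,e)) + 1
    end do

    allocate(adj%row_ptr(num_vertices + 1))
    adj%row_ptr(1) = 1
    do v = 1, num_vertices
       adj%row_ptr(v+1) = adj%row_ptr(v) + counts(v)
    end do
    allocate(adj%col_idx(adj%row_ptr(num_vertices+1) - 1))

    ! Self-loop first in each row, then the listed neighbours in input order
    do v = 1, num_vertices
       adj%col_idx(adj%row_ptr(v)) = v
       fill(v) = adj%row_ptr(v) + 1
    end do
    do e = 1, size(edges, 2)
       adj%col_idx(fill(edges(1,e))) = edges(2,e)
       fill(edges(1,e)) = fill(edges(1,e)) + 1
    end do
  end function build_csr_with_self_loops
end module csr_sketch
```

For the path 1-2-3 with both directions listed, row_ptr is [1, 3, 6, 8] and col_idx is [1, 2, 2, 1, 3, 3, 2]: 7 stored entries instead of the 9 of a dense 3 × 3 matrix, a saving that grows quickly with graph size.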

Output formats:

! Graph-level output (one scalar per graph)
type(array_type) :: output(1,1)
! output(1,1)%val has shape [1, num_samples]

! Node-level output (features per node)
type(graph_type), allocatable :: output_graphs(:,:)
! output_graphs(1,s)%vertex_features has shape [num_features, num_nodes]

Other important considerations when training message passing networks include:

  • Gradient clipping: stabilises training on graphs (both examples pass a clip_type to the optimiser)

  • Learning rate decay: improves convergence (the Euler example uses exp_lr_decay_type)
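Both techniques are simple to state. The sketch below shows clipping by global norm (which clip_norm=0.1 in the chemical example suggests), clipping by value (the Euler example passes the pair -1.0, 1.0, assumed here to be lower and upper bounds) and an exponential schedule lr(t) = lr0 * exp(-k t). The exact semantics of athena's clip_type and exp_lr_decay_type are assumptions, not confirmed by this tutorial.

```fortran
! Sketches of gradient clipping and exponential learning rate decay.
module training_aids_sketch
  use, intrinsic :: iso_fortran_env, only: real32
  implicit none
contains
  !> Clip by norm: rescale g so that its L2 norm is at most max_norm.
  pure function clip_by_norm(g, max_norm) result(clipped)
    real(real32), intent(in) :: g(:), max_norm
    real(real32) :: clipped(size(g))
    real(real32) :: norm
    norm = norm2(g)
    if (norm > max_norm) then
       clipped = g * (max_norm / norm)   ! keep direction, shrink magnitude
    else
       clipped = g
    end if
  end function clip_by_norm

  !> Clip by value: bound every component to [lo, hi].
  pure function clip_by_value(g, lo, hi) result(clipped)
    real(real32), intent(in) :: g(:), lo, hi
    real(real32) :: clipped(size(g))
    clipped = min(max(g, lo), hi)
  end function clip_by_value

  !> Exponential decay of the learning rate with step count.
  pure function exp_decay_lr(lr0, k, step) result(lr)
    real(real32), intent(in) :: lr0, k
    integer, intent(in) :: step
    real(real32) :: lr
    lr = lr0 * exp(-k * real(step, real32))
  end function exp_decay_lr
end module training_aids_sketch
```

Clipping by norm preserves the gradient's direction, whereas clipping by value can change it by truncating only the largest components.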

Key Takeaways

The message passing examples illustrate how to build and train graph neural networks using athena. The main points to consider when working with message passing NNs are:

  1. Graph structure matters: Message passing leverages connectivity information

  2. Sparse is faster: Use sparse representation for large graphs

  3. Gradient clipping is essential: It prevents exploding gradients in deep message passing networks

  4. Skip connections help: Concatenating early features improves information flow

  5. Layer choice matters: Duvenaud for graph-level, Kipf for node-level predictions

When to Use Message Passing NNs

Good for:

  • Molecular property prediction

  • Social network analysis

  • Traffic/flow prediction on networks

  • Point cloud processing

  • Physics simulations on meshes

Not ideal for:

  • Regular grid data (use CNNs instead)

  • Sequential data (use RNNs instead)

  • When the graph structure is unknown

  • Very large graphs (>100k nodes)

See Also

Footnotes