Module implementing Duvenaud message passing for molecular graphs
This module implements the graph neural network architecture from Duvenaud et al. (2015) for learning on molecular graphs with both vertex (node) and edge features.
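Mathematical operation (per time step t):
$$ h_v^{(t+1)} = \sigma\left( H_t^{\deg(v)} \sum_{w \in \mathcal{N}(v)} M\left(h_w^{(t)}, e_{vw}\right) \right) $$
Graph readout (aggregation to fixed-size vector):
$$ y = \sum_{t} \sum_{v} \mathrm{softmax}\left( W_t \, h_v^{(t)} \right) $$
where $M$ is a learned message function, $\sigma$ is the activation function, $\mathcal{N}(v)$ are the neighbors of $v$, $e_{vw}$ are edge features, $H_t^{d}$ are degree-specific weight matrices (with $W_t$ the readout weights), and $D$ is the max vertex degree, so that $d = \deg(v) \le D$.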
Reference: Duvenaud et al. (2015), NeurIPS
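The per-time-step update described above can be illustrated with a small language-neutral sketch. The Python below is not the module's Fortran implementation: the function name `message_passing_step`, the dictionary-based graph layout, the use of plain concatenation as the message, and the clamping of the degree into the supported range are all assumptions made for illustration.

```python
import math


def sigmoid(x):
    """Logistic sigmoid, the default message activation named in this module."""
    return 1.0 / (1.0 + math.exp(-x))


def message_passing_step(h, edges, weights, min_degree, max_degree):
    """Apply one hypothetical Duvenaud-style update to every vertex.

    h       : dict mapping vertex -> list of vertex features
    edges   : dict mapping (v, w) -> edge features, stored in both directions
    weights : dict mapping degree -> matrix (list of rows) of shape
              (n_out, n_vertex_features + n_edge_features)
    """
    neighbours = {v: [] for v in h}
    for (v, w) in edges:
        neighbours[v].append(w)

    h_next = {}
    for v, nbrs in neighbours.items():
        # Assumed message: concatenate neighbour features with edge features,
        # then sum the concatenated vectors over all neighbours.
        message = None
        for w in nbrs:
            contribution = h[w] + edges[(v, w)]
            if message is None:
                message = list(contribution)
            else:
                message = [m + c for m, c in zip(message, contribution)]
        # Assumed behaviour: clamp the degree into the supported range.
        degree = min(max(len(nbrs), min_degree), max_degree)
        matrix = weights[degree]
        if message is None:
            message = [0.0] * len(matrix[0])
        h_next[v] = [
            sigmoid(sum(a * x for a, x in zip(row, message))) for row in matrix
        ]
    return h_next


# Path graph 0 - 1 - 2 with one vertex feature and one edge feature.
h = {0: [1.0], 1: [2.0], 2: [3.0]}
edges = {(0, 1): [0.5], (1, 0): [0.5], (1, 2): [0.5], (2, 1): [0.5]}
weights = {1: [[1.0, 0.0]], 2: [[0.0, 1.0]]}
h_next = message_passing_step(h, edges, weights, min_degree=1, max_degree=2)
```

In this example the two end vertices have degree 1 and use one weight matrix, while the middle vertex has degree 2 and uses the other, which is the point of keeping a separate matrix per degree.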
| Type | Visibility | Attributes | Name | Initial | |||
|---|---|---|---|---|---|---|---|
| character(len=*), | private, | parameter | :: | default_message_actv_name | = | "sigmoid" | |
| character(len=*), | private, | parameter | :: | default_readout_actv_name | = | "softmax" |
Interface for setting up the MPNN layer
Set up the message passing layer
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| integer, | intent(in), | dimension(:) | :: | num_vertex_features |
Number of vertex features |
|
| integer, | intent(in), | dimension(:) | :: | num_edge_features |
Number of edge features |
|
| integer, | intent(in) | :: | num_time_steps |
Number of time steps |
||
| integer, | intent(in) | :: | max_vertex_degree |
Maximum vertex degree |
||
| integer, | intent(in) | :: | num_outputs |
Number of outputs |
||
| integer, | intent(in), | optional | :: | min_vertex_degree |
Minimum vertex degree |
|
| class(*), | intent(in), | optional | :: | message_activation |
Message and readout activation functions |
|
| class(*), | intent(in), | optional | :: | readout_activation |
Message and readout activation functions |
|
| character(len=*), | intent(in), | optional | :: | kernel_initialiser |
Kernel initialiser |
|
| integer, | intent(in), | optional | :: | verbose |
Verbosity level |
Instance of the message passing layer
| Type | Visibility | Attributes | Name | Initial | |||
|---|---|---|---|---|---|---|---|
| class(base_actv_type), | public, | allocatable | :: | activation |
Activation function |
||
| class(base_actv_type), | public, | allocatable | :: | activation_readout |
Readout activation function |
||
| class(base_init_type), | public, | allocatable | :: | bias_init |
Initialisers for kernel and bias |
||
| character(len=14), | public | :: | bias_initialiser | = | '' |
Initialisers for kernel and bias |
|
| integer, | public, | allocatable, dimension(:) | :: | bias_shape |
Shape of biases |
||
| type(graph_type), | public, | allocatable, dimension(:) | :: | graph |
Graph structure of input data |
||
| integer, | public | :: | id |
Unique identifier |
|||
| logical, | public | :: | inference | = | .false. |
Inference mode |
|
| integer, | public | :: | input_rank | = | 0 |
Rank of input data |
|
| integer, | public, | allocatable, dimension(:) | :: | input_shape |
Input shape |
||
| class(base_init_type), | public, | allocatable | :: | kernel_init |
Initialisers for kernel and bias |
||
| character(len=14), | public | :: | kernel_initialiser | = | '' |
Initialisers for kernel and bias |
|
| integer, | public | :: | max_vertex_degree | = | 0 |
Maximum vertex degree |
|
| integer, | public | :: | min_vertex_degree | = | 1 |
Minimum vertex degree |
|
| character(len=:), | public, | allocatable | :: | name |
Layer name |
||
| integer, | public, | dimension(:), allocatable | :: | num_edge_features |
Number of edge features for each time step |
||
| integer, | public | :: | num_output_edge_features |
Number of output edge features |
|||
| integer, | public | :: | num_output_vertex_features |
Number of output vertex features |
|||
| integer, | public | :: | num_outputs |
Number of outputs (if output is not graph structure) |
|||
| integer, | public | :: | num_params | = | 0 |
Number of learnable parameters |
|
| integer, | public, | dimension(:), allocatable | :: | num_params_msg |
Number of learnable parameters for each message |
||
| integer, | public | :: | num_params_readout |
Number of learnable parameters for the readout |
|||
| integer, | public | :: | num_time_steps |
Number of time steps |
|||
| integer, | public, | dimension(:), allocatable | :: | num_vertex_features |
Number of vertex features for each time step |
||
| class(array_type), | public, | allocatable, dimension(:,:) | :: | output |
Output |
||
| integer, | public | :: | output_rank | = | 0 |
Rank of output data |
|
| integer, | public, | allocatable, dimension(:) | :: | output_shape |
Output shape |
||
| type(array_type), | public, | allocatable, dimension(:) | :: | params |
Learnable parameters |
||
| character(len=20), | public | :: | subtype | = | repeat(" ", 20) | ||
| character(len=4), | public | :: | type | = | 'base' |
Layer type |
|
| logical, | public | :: | use_bias | = | .false. |
Layer has bias |
|
| logical, | public | :: | use_graph_input | = | .false. |
Use graph input |
|
| logical, | public | :: | use_graph_output | = | .false. |
Use graph output |
|
| integer, | public, | allocatable, dimension(:,:) | :: | weight_shape |
Shape of weights |
||
| type(array_type), | public, | allocatable, dimension(:,:) | :: | z | |||
| type(array_type), | public, | allocatable, dimension(:,:) | :: | z_readout |
Input gradients |
Interface for setting up the MPNN layer
| private module function layer_setup (num_vertex_features, num_edge_features, num_time_steps, max_vertex_degree, num_outputs, min_vertex_degree, message_activation, readout_activation, kernel_initialiser, verbose) | Set up the message passing layer |
| final :: finalise_duvenaud | Finalise the message passing layer |
| procedure, public :: add_t_t => add_learnable | Add two layers |
| procedure, public, pass(this) :: build_from_onnx => build_from_onnx_base | Build layer from ONNX node and initialiser |
| procedure, public, pass(this) :: emit_onnx_graph_inputs => emit_onnx_graph_inputs_duvenaud | Emit graph input tensor declarations for Duvenaud GNN layer |
| procedure, public, pass(this) :: emit_onnx_nodes => emit_onnx_nodes_duvenaud | Emit ONNX JSON nodes for Duvenaud GNN layer |
| procedure, public, pass(this) :: extract_output => extract_output_base | Extract the output of the layer as a standard real array |
| procedure, public, pass(this) :: forward => forward_msgpass | Forward pass for message passing layer |
| procedure, public, pass(this) :: forward_eval => forward_eval_base | Forward pass of layer and return output for evaluation |
| procedure, public, pass(this) :: get_attributes => get_attributes_duvenaud | Get the attributes of the layer (for ONNX export) |
| procedure, public, pass(this) :: get_gradients | Get parameter gradients of layer |
| procedure, public, pass(this) :: get_num_params => get_num_params_duvenaud | Get the number of parameters for the message passing layer |
| procedure, public, pass(this) :: get_params | Get learnable parameters of layer |
| procedure, public, pass(this) :: init => init_duvenaud | Initialise the message passing layer |
| procedure, public, pass(this) :: nullify_graph => nullify_graph_base | Nullify the forward pass data of the layer to free memory |
| generic, public :: operator(+) => add_t_t | Operator overloading for addition |
| procedure, public, pass(this) :: print => print_base | Print the layer to a file with additional information |
| procedure, public, pass(this) :: print_to_unit => print_to_unit_duvenaud | Print the message passing layer to a unit |
| procedure, public, pass(this) :: read => read_duvenaud | Read the message passing layer |
| procedure, public, pass(this) :: reduce => reduce_learnable | Merge another learnable layer into this one |
| procedure, public, pass(this) :: set_gradients | Set parameter gradients of layer |
| procedure, public, pass(this) :: set_graph => set_graph_duvenaud | Set the graph for the message passing layer |
| procedure, public, pass(this) :: set_hyperparams => set_hyperparams_duvenaud | Set the hyperparameters for the message passing layer |
| procedure, public, pass(this) :: set_params | Set learnable parameters of layer |
| procedure, public, pass(this) :: set_rank => set_rank_base | Set the input and output ranks of the layer |
| procedure, public, pass(this) :: set_shape => set_shape_base | Set the input shape of the layer |
| procedure, public, pass(this) :: update_message => update_message_duvenaud | Update the message |
| procedure, public, pass(this) :: update_readout => update_readout_duvenaud | Update the readout |
Read a Duvenaud message passing layer from file and return the layer
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| integer, | intent(in) | :: | unit |
Unit number |
||
| integer, | intent(in), | optional | :: | verbose |
Verbosity level |
Instance of the message passing layer
Get the attributes of the Duvenaud message passing layer (for ONNX export)
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| class(duvenaud_msgpass_layer_type), | intent(in) | :: | this |
Instance of the layer |
Attributes of the layer
Get the number of parameters for the message passing layer
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| class(duvenaud_msgpass_layer_type), | intent(in) | :: | this |
Instance of the message passing layer |
Number of parameters
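The type carries three related counters: `num_params`, `num_params_msg` (one entry per message) and `num_params_readout`. The sketch below shows one way such counts could decompose. The formulas are assumptions for illustration, not taken from the library: one weight matrix per supported degree and time step, one readout matrix per time step, no biases, and `num_vertex_features` holding one more entry than there are time steps. The function name `count_params` is hypothetical.

```python
def count_params(num_vertex_features, num_edge_features,
                 min_degree, max_degree, num_outputs):
    """Hypothetical parameter count for a degree-specific message passing layer."""
    # Assumption: one weight matrix for each degree in [min_degree, max_degree].
    num_degrees = max_degree - min_degree + 1
    # Assumption: entry t holds the input width of step t, entry t + 1 its output width.
    num_time_steps = len(num_vertex_features) - 1

    num_params_msg = [
        (num_vertex_features[t] + num_edge_features[t])
        * num_vertex_features[t + 1]
        * num_degrees
        for t in range(num_time_steps)
    ]
    # Assumption: each time step has its own readout matrix of shape
    # (num_outputs, vertex features after that step).
    num_params_readout = sum(nv * num_outputs for nv in num_vertex_features[1:])

    num_params = sum(num_params_msg) + num_params_readout
    return num_params_msg, num_params_readout, num_params


msg_counts, readout_count, total = count_params(
    num_vertex_features=[4, 8, 8],
    num_edge_features=[2, 2],
    min_degree=1,
    max_degree=4,
    num_outputs=10,
)
```

Under these assumptions the total is the sum of the per-message counts plus the readout count; the value actually returned by `get_num_params` may be computed differently.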
Set up the message passing layer
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| integer, | intent(in), | dimension(:) | :: | num_vertex_features |
Number of vertex features |
|
| integer, | intent(in), | dimension(:) | :: | num_edge_features |
Number of edge features |
|
| integer, | intent(in) | :: | num_time_steps |
Number of time steps |
||
| integer, | intent(in) | :: | max_vertex_degree |
Maximum vertex degree |
||
| integer, | intent(in) | :: | num_outputs |
Number of outputs |
||
| integer, | intent(in), | optional | :: | min_vertex_degree |
Minimum vertex degree |
|
| class(*), | intent(in), | optional | :: | message_activation |
Message and readout activation functions |
|
| class(*), | intent(in), | optional | :: | readout_activation |
Message and readout activation functions |
|
| character(len=*), | intent(in), | optional | :: | kernel_initialiser |
Kernel initialiser |
|
| integer, | intent(in), | optional | :: | verbose |
Verbosity level |
Instance of the message passing layer
Emit the degree-dependent weight selection and update block.
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| character(len=*), | intent(in) | :: | tp | |||
| character(len=*), | intent(in) | :: | degree_in | |||
| integer, | intent(in) | :: | min_degree | |||
| integer, | intent(in) | :: | max_degree | |||
| integer, | intent(in) | :: | feature_dim | |||
| integer, | intent(in) | :: | nv_out | |||
| real(kind=real32), | intent(in) | :: | weight_data(:) | |||
| character(len=*), | intent(in) | :: | aggr_in | |||
| type(onnx_node_type), | intent(inout), | dimension(:) | :: | nodes | ||
| integer, | intent(inout) | :: | num_nodes | |||
| type(onnx_initialiser_type), | intent(inout), | dimension(:) | :: | inits | ||
| integer, | intent(inout) | :: | num_inits | |||
| character(len=128), | intent(out) | :: | sq_out |
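The arguments above (`min_degree`, `max_degree`, `feature_dim`, `nv_out` and a flat `weight_data` array) suggest that the weights for all degrees are packed into one array and a block is picked per vertex degree. The sketch below illustrates that idea only. The degree-major, row-major memory layout, the clamping of out-of-range degrees, and the function name `select_degree_weights` are assumptions, not the layout used by the Fortran routine.

```python
def select_degree_weights(weight_data, degree, min_degree, max_degree,
                          feature_dim, nv_out):
    """Return the (nv_out x feature_dim) weight block for a given vertex degree."""
    block = feature_dim * nv_out
    num_degrees = max_degree - min_degree + 1
    if len(weight_data) != num_degrees * block:
        raise ValueError("weight_data does not match the degree range and shape")

    # Assumed behaviour: degrees outside the supported range are clamped.
    clamped = min(max(degree, min_degree), max_degree)
    start = (clamped - min_degree) * block
    flat = weight_data[start:start + block]

    # Assumed layout: each block is stored row by row.
    return [flat[r * feature_dim:(r + 1) * feature_dim] for r in range(nv_out)]


weight_data = [1.0, 0.0, 0.0, 1.0]
w_deg1 = select_degree_weights(weight_data, 1, 1, 2, feature_dim=2, nv_out=1)
w_deg2 = select_degree_weights(weight_data, 2, 1, 2, feature_dim=2, nv_out=1)
w_deg5 = select_degree_weights(weight_data, 5, 1, 2, feature_dim=2, nv_out=1)
```

Here a vertex of degree 5 falls back to the block for the largest supported degree, which is one plausible policy; the library may instead reject such input.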
Emit ONNX nodes for Duvenaud readout
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| character(len=*), | intent(in) | :: | prefix | |||
| class(duvenaud_msgpass_layer_type), | intent(in) | :: | layer | |||
| type(onnx_node_type), | intent(inout), | dimension(:) | :: | nodes | ||
| integer, | intent(inout) | :: | num_nodes | |||
| integer, | intent(in) | :: | max_nodes | |||
| type(onnx_initialiser_type), | intent(inout), | dimension(:) | :: | inits | ||
| integer, | intent(inout) | :: | num_inits | |||
| integer, | intent(in) | :: | max_inits | |||
| character(len=128), | intent(out) | :: | readout_output |
Emit one Duvenaud readout timestep.
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| character(len=*), | intent(in) | :: | prefix | |||
| character(len=*), | intent(in) | :: | activation_name | |||
| integer, | intent(in) | :: | t | |||
| integer, | intent(in) | :: | nv | |||
| integer, | intent(in) | :: | no | |||
| real(kind=real32), | intent(in) | :: | weight_data(:) | |||
| type(onnx_node_type), | intent(inout), | dimension(:) | :: | nodes | ||
| integer, | intent(inout) | :: | num_nodes | |||
| type(onnx_initialiser_type), | intent(inout), | dimension(:) | :: | inits | ||
| integer, | intent(inout) | :: | num_inits | |||
| character(len=128), | intent(out) | :: | step_sum |
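A single readout time step reduces all vertex features at that step to one fixed-size vector. The following is a minimal sketch, assuming the readout applies a weight matrix to each vertex, passes the result through softmax (the default readout activation named in this module), and sums over vertices. The function name `readout_step` and the list-of-lists data layout are hypothetical.

```python
import math


def softmax(z):
    """Numerically stable softmax over a list of scores."""
    shift = max(z)
    exps = [math.exp(x - shift) for x in z]
    total = sum(exps)
    return [e / total for e in exps]


def readout_step(h, weight):
    """Sum softmax(weight @ h_v) over all vertices for one time step.

    h      : list of vertex feature vectors, each of length nv
    weight : matrix (list of rows) of shape (no, nv)
    """
    step_sum = [0.0] * len(weight)
    for features in h:
        scores = [sum(w * x for w, x in zip(row, features)) for row in weight]
        step_sum = [s + p for s, p in zip(step_sum, softmax(scores))]
    return step_sum


h = [[1.0, 0.0], [0.0, 1.0]]
weight = [[1.0, 0.0], [0.0, 1.0]]
step_sum = readout_step(h, weight)
```

Because each softmax output sums to one, the entries of `step_sum` add up to the number of vertices. Under the same assumptions, the full readout would add such vectors across all time steps.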
Emit ONNX nodes for one Duvenaud message passing time step.
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| character(len=*), | intent(in) | :: | prefix | |||
| integer, | intent(in) | :: | t | |||
| integer, | intent(in) | :: | nv_in | |||
| integer, | intent(in) | :: | ne_in | |||
| integer, | intent(in) | :: | nv_out | |||
| integer, | intent(in) | :: | min_degree | |||
| integer, | intent(in) | :: | max_degree | |||
| real(kind=real32), | intent(in) | :: | weight_data(:) | |||
| character(len=*), | intent(in) | :: | activation_name | |||
| type(onnx_node_type), | intent(inout), | dimension(:) | :: | nodes | ||
| integer, | intent(inout) | :: | num_nodes | |||
| integer, | intent(in) | :: | max_nodes | |||
| type(onnx_initialiser_type), | intent(inout), | dimension(:) | :: | inits | ||
| integer, | intent(inout) | :: | num_inits | |||
| integer, | intent(in) | :: | max_inits | |||
| character(len=128), | intent(out) | :: | vertex_out |
Emit graph input tensor declarations for Duvenaud GNN layer
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| class(duvenaud_msgpass_layer_type), | intent(in) | :: | this |
Instance of the layer |
||
| character(len=*), | intent(in) | :: | prefix |
Input name prefix (e.g. "input_1") |
||
| type(onnx_tensor_type), | intent(inout), | dimension(:) | :: | graph_inputs |
Accumulator for graph input tensor declarations |
|
| integer, | intent(inout) | :: | num_inputs |
Current number of graph input declarations |
Emit ONNX JSON nodes for Duvenaud GNN layer
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| class(duvenaud_msgpass_layer_type), | intent(in) | :: | this |
Instance of the layer |
||
| character(len=*), | intent(in) | :: | prefix |
Node name prefix (e.g. "node_2") |
||
| type(onnx_node_type), | intent(inout), | dimension(:) | :: | nodes |
Accumulator for ONNX nodes |
|
| integer, | intent(inout) | :: | num_nodes |
Current number of nodes |
||
| integer, | intent(in) | :: | max_nodes |
Maximum capacity |
||
| type(onnx_initialiser_type), | intent(inout), | dimension(:) | :: | inits |
Accumulator for ONNX initialisers |
|
| integer, | intent(inout) | :: | num_inits |
Current number of initialisers |
||
| integer, | intent(in) | :: | max_inits |
Maximum capacity |
||
| character(len=*), | intent(in), | optional | :: | input_name |
Unused sequential input name |
|
| logical, | intent(in), | optional | :: | is_last_layer |
Unused last-layer flag |
|
| integer, | intent(in), | optional | :: | format |
Unused export format selector |
Finalise the message passing layer
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| type(duvenaud_msgpass_layer_type), | intent(inout) | :: | this |
Instance of the message passing layer |
Initialise the message passing layer
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| class(duvenaud_msgpass_layer_type), | intent(inout) | :: | this |
Instance of the message passing layer |
||
| integer, | intent(in), | dimension(:) | :: | input_shape |
Input shape |
|
| integer, | intent(in), | optional | :: | verbose |
Verbosity level |
Print the Duvenaud message passing layer to unit
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| class(duvenaud_msgpass_layer_type), | intent(in) | :: | this |
Instance of the message passing layer |
||
| integer, | intent(in) | :: | unit |
Unit number to print to |
Read the message passing layer
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| class(duvenaud_msgpass_layer_type), | intent(inout) | :: | this |
Instance of the message passing layer |
||
| integer, | intent(in) | :: | unit |
Unit to read from |
||
| integer, | intent(in), | optional | :: | verbose |
Verbosity level |
Set the graph structure of the input data
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| class(duvenaud_msgpass_layer_type), | intent(inout) | :: | this |
Instance of the layer |
||
| type(graph_type), | intent(in), | dimension(:) | :: | graph |
Graph structure of input data |
Set the hyperparameters for the message passing layer
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| class(duvenaud_msgpass_layer_type), | intent(inout) | :: | this |
Instance of the message passing layer |
||
| integer, | intent(in), | dimension(:) | :: | num_vertex_features |
Number of vertex features |
|
| integer, | intent(in), | dimension(:) | :: | num_edge_features |
Number of edge features |
|
| integer, | intent(in) | :: | min_vertex_degree |
Minimum vertex degree |
||
| integer, | intent(in) | :: | max_vertex_degree |
Maximum vertex degree |
||
| integer, | intent(in) | :: | num_time_steps |
Number of time steps |
||
| integer, | intent(in) | :: | num_outputs |
Number of outputs |
||
| class(base_actv_type), | intent(in), | allocatable | :: | message_activation |
Message and readout activation functions |
|
| class(base_actv_type), | intent(in), | allocatable | :: | readout_activation |
Message and readout activation functions |
|
| class(base_init_type), | intent(in), | allocatable | :: | kernel_initialiser |
Kernel and bias initialisers |
|
| integer, | intent(in), | optional | :: | verbose |
Verbosity level |
Update the message
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| class(duvenaud_msgpass_layer_type), | intent(inout), | target | :: | this |
Instance of the message passing layer |
|
| class(array_type), | intent(in), | dimension(:,:), target | :: | input |
Input to the message passing layer |
Update the readout
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| class(duvenaud_msgpass_layer_type), | intent(inout), | target | :: | this |
Instance of the message passing layer |