module athena__msgpass_layer
  !! Types and procedure interfaces for message passing layers.
  !!
  !! This module holds declarations only. Every procedure below is declared
  !! through a module procedure interface body; none is implemented here.
  use coreutils, only: real32
  use graphstruc, only: graph_type
  use athena__base_layer, only: learnable_layer_type
  use athena__clipper, only: clip_type
  use diffstruc, only: array_type
  implicit none


  private

  public :: msgpass_layer_type


!-------------------------------------------------------------------------------
! Abstract message passing layer type
!-------------------------------------------------------------------------------
  type, abstract, extends(learnable_layer_type) :: msgpass_layer_type
     !! Abstract learnable layer that performs message passing.
     !!
     !! Intended for graph neural networks and other models built on
     !! message passing.
     !!
     !! Terminology: the literature uses "vertex" and "node" seemingly
     !! interchangeably for the individual elements of a graph, and "edge"
     !! for the connections between them. This type uses "vertex" for the
     !! elements and "edge" for the connections between vertices.
     integer, allocatable :: num_vertex_features(:)
     !! Number of vertex features for each time step
     integer, allocatable :: num_edge_features(:)
     !! Number of edge features for each time step
     integer :: num_time_steps
     !! Number of time steps
     integer :: num_output_vertex_features
     !! Number of output vertex features
     integer :: num_output_edge_features
     !! Number of output edge features
     integer :: num_outputs
     !! Number of outputs (used when the output is not a graph structure)
     integer, allocatable :: num_params_msg(:)
     !! Number of learnable parameters for each message
     integer :: num_params_readout
     !! Number of learnable parameters for the readout
   contains
     ! The passed-object dummy of every binding is its first argument,
     ! `this`, which is the default when no PASS attribute is written.
     procedure :: init => init_msgpass
     !! Initialise the message passing layer
     procedure :: set_graph => set_graph_msgpass
     !! Set the graph structure of the input data
     procedure :: get_num_params => get_num_params_msgpass
     !! Get the number of learnable parameters of the layer
     procedure :: forward => forward_msgpass
     !! Forward pass of the layer
     procedure(update_message_msgpass), deferred :: update_message
     !! Update the message (supplied by each concrete layer)
     procedure(update_readout_msgpass), deferred :: update_readout
     !! Update the readout (supplied by each concrete layer)

     ! Bindings that are currently disabled:
     ! procedure, pass(this) :: set_hyperparams => set_hyperparams_msgpass
     ! !! Set the hyperparameters for message passing layer
     ! procedure, pass(this) :: print => print_msgpass
     ! !! Print the message passing layer
     ! procedure, pass(this) :: read => read_msgpass
     ! !! Read the message passing layer
     ! procedure, pass(this) :: reduce => layer_reduction
     ! !! Reduce message passing layer
     ! procedure, pass(this) :: merge => layer_merge
     ! !! Merge message passing layer
  end type msgpass_layer_type


!-------------------------------------------------------------------------------
! Generic constructor
!-------------------------------------------------------------------------------
  interface msgpass_layer_type
     !! Generic interface used to set up an MPNN layer
     module function layer_setup(num_features, num_time_steps, verbose) &
          result(layer)
       !! Set up the MPNN layer
       ! TODO (carried over from the original author): make these
       ! arguments assumed rank.
       integer, intent(in) :: num_features(2)
       !! Number of features
       integer, intent(in) :: num_time_steps
       !! Number of time steps
       integer, intent(in), optional :: verbose
       !! Verbosity level
       class(msgpass_layer_type), allocatable :: layer
       !! Instance of the message passing layer
     end function layer_setup
  end interface msgpass_layer_type


!-------------------------------------------------------------------------------
! Interfaces of the procedures bound to msgpass_layer_type
!-------------------------------------------------------------------------------
  interface

     !--------------------------------------------------------------------------
     ! Initialisation
     !--------------------------------------------------------------------------
     module subroutine init_msgpass(this, input_shape, verbose)
       !! Initialise the message passing layer
       class(msgpass_layer_type), intent(inout) :: this
       !! Instance of the message passing layer
       integer, intent(in) :: input_shape(:)
       !! Input shape
       integer, intent(in), optional :: verbose
       !! Verbosity level
     end subroutine init_msgpass

     !--------------------------------------------------------------------------
     ! Learnable parameters and graph handling
     !--------------------------------------------------------------------------
     pure module function get_num_params_msgpass(this) result(num_params)
       !! Get the number of learnable parameters of the layer
       class(msgpass_layer_type), intent(in) :: this
       !! Instance of the message passing layer
       integer :: num_params
       !! Number of learnable parameters
     end function get_num_params_msgpass

     module subroutine set_graph_msgpass(this, graph)
       !! Set the graph structure of the input data
       class(msgpass_layer_type), intent(inout) :: this
       !! Instance of the layer
       type(graph_type), intent(in) :: graph(:)
       !! Graph structure of input data
     end subroutine set_graph_msgpass

     !--------------------------------------------------------------------------
     ! Forward pass
     !--------------------------------------------------------------------------
     module subroutine forward_msgpass(this, input)
       !! Forward pass of the message passing layer
       class(msgpass_layer_type), intent(inout) :: this
       !! Instance of the layer type
       class(array_type), intent(in) :: input(:,:)
       !! Input data (i.e. vertex and edge features)
     end subroutine forward_msgpass

     !--------------------------------------------------------------------------
     ! Interfaces named by the deferred bindings
     !--------------------------------------------------------------------------
     ! NOTE(review): these two are written as `module subroutine` interface
     ! bodies, not inside an `abstract interface` block - TODO confirm that
     ! this is intended.
     module subroutine update_message_msgpass(this, input)
       !! Update the message
       class(msgpass_layer_type), intent(inout), target :: this
       !! Instance of the message passing layer
       class(array_type), intent(in), target :: input(:,:)
       !! Input data (i.e. vertex and edge features)
     end subroutine update_message_msgpass

     module subroutine update_readout_msgpass(this)
       !! Update the readout
       class(msgpass_layer_type), intent(inout), target :: this
       !! Instance of the message passing layer
     end subroutine update_readout_msgpass

  end interface


!-------------------------------------------------------------------------------
! Disabled interfaces, kept for reference
!-------------------------------------------------------------------------------
  ! interface
  !    !! Interfaces for reducing and merging layers
  !    module subroutine layer_reduction(this, rhs)
  !      !! Reduce the layer
  !      class(msgpass_layer_type), intent(inout) :: this
  !      !! Instance of the message passing layer
  !      class(learnable_layer_type), intent(in) :: rhs
  !      !! Instance of the learnable layer (expects a message passing layer)
  !    end subroutine layer_reduction
  !    module subroutine layer_merge(this, input)
  !      !! Merge the layer
  !      class(msgpass_layer_type), intent(inout) :: this
  !      !! Instance of the message passing layer
  !      class(learnable_layer_type), intent(in) :: input
  !      !! Instance of the learnable layer (expects a message passing layer)
  !    end subroutine layer_merge
  ! end interface
  !
  ! interface
  !    module subroutine print_msgpass(this, file)
  !      !! Print the message passing layer
  !      class(msgpass_layer_type), intent(in) :: this
  !      !! Instance of the message passing layer
  !      character(*), intent(in) :: file
  !      !! File to print to
  !    end subroutine print_msgpass
  !    module subroutine read_msgpass(this, unit, verbose)
  !      !! Read the message passing layer
  !      class(msgpass_layer_type), intent(inout) :: this
  !      !! Instance of the message passing layer
  !      integer, intent(in) :: unit
  !      !! Unit to read from
  !      integer, optional, intent(in) :: verbose
  !      !! Verbosity level
  !    end subroutine read_msgpass
  !    module subroutine set_hyperparams_msgpass( &
  !         this, num_features, num_time_steps, num_outputs, verbose &
  !    )
  !      !! Set the hyperparameters for the message passing layer
  !      class(msgpass_layer_type), intent(inout) :: this
  !      !! Instance of the message passing layer
  !      integer, dimension(2), intent(in) :: num_features
  !      !! Number of features
  !      integer, intent(in) :: num_time_steps
  !      !! Number of time steps
  !      integer, intent(in) :: num_outputs
  !      !! Number of outputs
  !      integer, optional, intent(in) :: verbose
  !      !! Verbosity level
  !    end subroutine set_hyperparams_msgpass
  ! end interface

end module athena__msgpass_layer