Module containing the implementation of an Orthogonal Attention layer.
This module implements the Orthogonal Attention mechanism from "Improved Operator Learning by Orthogonal Attention" (Xiao et al., 2024).
Instead of softmax attention, this layer projects queries and keys onto a learned orthonormal basis of dimension $k$ (`num_basis`), giving a linear-cost approximation to the attention kernel.
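Given input $\mathbf{X}$, queries, keys and values are formed by learned linear projections, with the query/key projections of dimension $d_k$ (`key_dim`):

$$ \mathbf{Q} = \mathbf{W}_Q \mathbf{X}, \quad \mathbf{K} = \mathbf{W}_K \mathbf{X}, \quad \mathbf{V} = \mathbf{W}_V \mathbf{X} $$

The orthogonal basis $\hat{\mathbf{B}}$, made up of $k$ (`num_basis`) orthonormal vectors, is obtained by QR decomposition (computed with modified Gram-Schmidt) of learnable weights $\mathbf{B}$:

$$ \mathbf{B} = \hat{\mathbf{B}} \mathbf{R}, \qquad \hat{\mathbf{B}}^{\top} \hat{\mathbf{B}} = \mathbf{I}_k $$

The attention output is:

$$ \mathbf{A} = \mathbf{V} \left( \hat{\mathbf{B}}^{\top} \mathbf{K} \right)^{\top} \left( \hat{\mathbf{B}}^{\top} \mathbf{Q} \right) $$

The layer output is:

$$ \mathbf{Y} = \sigma\left( \mathbf{A} + \mathbf{W}_b \mathbf{X} + \mathbf{b} \right) $$

where $\sigma$ is the activation function and the bias $\mathbf{b}$ is only present when `use_bias` is true.

NOTE(review): the mathematical symbols and display equations in this section were lost when the page was extracted. The forms above are reconstructed from the surrounding descriptions (query/key projections, orthogonalised basis, bypass, optional bias, activation); confirm the exact expressions, symbol names and matrix orientation against the module source.

Parameters (learnable):

- $\mathbf{W}_Q$ (query projection)
- $\mathbf{W}_K$ (key projection)
- $\mathbf{W}_V$ (value projection)
- $\mathbf{B}$ (basis, orthogonalised)
- $\mathbf{W}_b$ (bypass)
- $\mathbf{b}$ (optional bias)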
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| integer, | intent(in) | :: | num_outputs | |||
| integer, | intent(in) | :: | num_basis | |||
| integer, | intent(in), | optional | :: | key_dim | ||
| integer, | intent(in), | optional | :: | num_inputs | ||
| logical, | intent(in), | optional | :: | use_bias | ||
| class(*), | intent(in), | optional | :: | activation | ||
| class(*), | intent(in), | optional | :: | kernel_initialiser | ||
| class(*), | intent(in), | optional | :: | bias_initialiser | ||
| integer, | intent(in), | optional | :: | verbose |
Type for an Orthogonal Attention layer
| Type | Visibility | Attributes | Name | Initial | |||
|---|---|---|---|---|---|---|---|
| class(base_actv_type), | public, | allocatable | :: | activation |
Activation function |
||
| class(base_init_type), | public, | allocatable | :: | bias_init |
Initialiser object for the bias |
||
| character(len=14), | public | :: | bias_initialiser | = | '' |
Name of the bias initialiser (character string) |
|
| integer, | public, | allocatable, dimension(:) | :: | bias_shape |
Shape of biases |
||
| type(graph_type), | public, | allocatable, dimension(:) | :: | graph |
Graph structure of input data |
||
| integer, | public | :: | id |
Unique identifier |
|||
| logical, | public | :: | inference | = | .false. |
Inference mode |
|
| integer, | public | :: | input_rank | = | 0 |
Rank of input data |
|
| integer, | public, | allocatable, dimension(:) | :: | input_shape |
Input shape |
||
| class(base_init_type), | public, | allocatable | :: | kernel_init |
Initialiser object for the kernel |
||
| character(len=14), | public | :: | kernel_initialiser | = | '' |
Name of the kernel initialiser (character string) |
|
| integer, | public | :: | key_dim | = | 0 |
Dimension of query/key projections (d_k) |
|
| character(len=:), | public, | allocatable | :: | name |
Layer name |
||
| integer, | public | :: | num_basis | = | 0 |
Number of orthogonal basis functions (k) |
|
| integer, | public | :: | num_inputs | = | 0 |
Number of input features / discretisation points |
|
| integer, | public | :: | num_outputs | = | 0 |
Number of output features / discretisation points |
|
| integer, | public | :: | num_params | = | 0 |
Number of learnable parameters |
|
| class(array_type), | public, | allocatable, dimension(:,:) | :: | output |
Output |
||
| integer, | public | :: | output_rank | = | 0 |
Rank of output data |
|
| integer, | public, | allocatable, dimension(:) | :: | output_shape |
Output shape |
||
| type(array_type), | public, | allocatable, dimension(:) | :: | params |
Learnable parameters |
||
| character(len=20), | public | :: | subtype | = | repeat(" ", 20) | ||
| character(len=4), | public | :: | type | = | 'base' |
Layer type |
|
| logical, | public | :: | use_bias | = | .false. |
Layer has bias |
|
| logical, | public | :: | use_graph_input | = | .false. |
Use graph input |
|
| logical, | public | :: | use_graph_output | = | .false. |
Use graph output |
|
| integer, | public, | allocatable, dimension(:,:) | :: | weight_shape |
Shape of weights |
||
| type(array_type), | public, | dimension(1) | :: | z |
Temporary array for pre-activation values |
| private module function layer_setup (num_outputs, num_basis, key_dim, num_inputs, use_bias, activation, kernel_initialiser, bias_initialiser, verbose) |
| final :: finalise_ono_attn |
| procedure, public :: add_t_t => add_learnable | Add two layers |
| procedure, public, pass(this) :: build_from_onnx => build_from_onnx_base | Build layer from ONNX node and initialiser |
| procedure, public, pass(this) :: emit_onnx_graph_inputs => emit_onnx_graph_inputs_base | Emit graph input tensor declarations for this layer |
| procedure, public, pass(this) :: emit_onnx_nodes => emit_onnx_nodes_base | Emit ONNX JSON nodes for this layer (format-aware and polymorphic) |
| procedure, public, pass(this) :: extract_output => extract_output_base | Extract the output of the layer as a standard real array |
| procedure, public, pass(this) :: forward => forward_ono_attn | |
| procedure, public, pass(this) :: forward_eval => forward_eval_base | Forward pass of layer and return output for evaluation |
| procedure, public, pass(this) :: get_attributes => get_attributes_ono_attn | |
| procedure, public, pass(this) :: get_bases => get_bases_ono_attn | |
| procedure, public, pass(this) :: get_gradients | Get parameter gradients of layer |
| procedure, public, pass(this) :: get_num_params => get_num_params_ono_attn | |
| procedure, public, pass(this) :: get_params | Get learnable parameters of layer |
| procedure, public, pass(this) :: init => init_ono_attn | |
| procedure, public, pass(this) :: nullify_graph => nullify_graph_base | Nullify the forward pass data of the layer to free memory |
| generic, public :: operator(+) => add_t_t | Operator overloading for addition |
| procedure, public, pass(this) :: print => print_base | Print the layer to a file with additional information |
| procedure, public, pass(this) :: print_to_unit => print_to_unit_ono_attn | |
| procedure, public, pass(this) :: read => read_ono_attn | |
| procedure, public, pass(this) :: reduce => reduce_learnable | Merge another learnable layer into this one |
| procedure, public, pass(this) :: set_gradients | Set parameter gradients of layer |
| procedure, public, pass(this) :: set_graph => set_graph_base | Set the graph structure of the input data (adjacency and edge weighting) |
| procedure, public, pass(this) :: set_hyperparams => set_hyperparams_ono_attn | |
| procedure, public, pass(this) :: set_params | Set learnable parameters of layer |
| procedure, public, pass(this) :: set_rank => set_rank_base | Set the input and output ranks of the layer |
| procedure, public, pass(this) :: set_shape => set_shape_base | Set the input shape of the layer |
Read an orthogonal attention layer from file and return it
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| integer, | intent(in) | :: | unit |
Input unit number |
||
| integer, | intent(in), | optional | :: | verbose |
Verbosity level |
Allocated base-layer instance containing the result
Return list of orthogonal attention attributes for ONNX export
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| class(orthogonal_attention_layer_type), | intent(in) | :: | this |
Instance of the orthogonal attention layer |
List of attributes for ONNX export
Orthogonalise the basis matrix B using modified Gram-Schmidt
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| class(orthogonal_attention_layer_type), | intent(in) | :: | this |
Layer instance providing basis parameters |
Orthogonalised basis matrix packed in an array_type
Return the number of learnable parameters for the layer
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| class(orthogonal_attention_layer_type), | intent(in) | :: | this |
Layer instance |
Total number of learnable parameters
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| integer, | intent(in) | :: | num_outputs |
Number of output features |
||
| integer, | intent(in) | :: | num_basis |
Number of orthogonal basis vectors |
||
| integer, | intent(in), | optional | :: | key_dim |
Query/key projection dimension |
|
| integer, | intent(in), | optional | :: | num_inputs |
Number of input features when known at construction time |
|
| logical, | intent(in), | optional | :: | use_bias |
Whether to allocate a bias term |
|
| class(*), | intent(in), | optional | :: | activation |
Activation function specification |
|
| class(*), | intent(in), | optional | :: | kernel_initialiser |
Kernel initialiser specification |
|
| class(*), | intent(in), | optional | :: | bias_initialiser |
Bias initialiser specification |
|
| integer, | intent(in), | optional | :: | verbose |
Verbosity level |
Constructed orthogonal attention layer
Finalise the orthogonal attention layer
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| type(orthogonal_attention_layer_type), | intent(inout) | :: | this |
Layer instance to release |
Forward propagation for the Orthogonal Attention layer
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| class(orthogonal_attention_layer_type), | intent(inout) | :: | this |
Layer instance to execute |
||
| class(array_type), | intent(in), | dimension(:,:) | :: | input |
Input batch tensor collection |
Initialise parameter storage and output buffers for the layer
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| class(orthogonal_attention_layer_type), | intent(inout) | :: | this |
Layer instance to initialise |
||
| integer, | intent(in), | dimension(:) | :: | input_shape |
Input shape used to infer num_inputs |
|
| integer, | intent(in), | optional | :: | verbose |
Verbosity level |
Print orthogonal attention layer settings and parameters to a unit
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| class(orthogonal_attention_layer_type), | intent(in) | :: | this |
Layer instance to print |
||
| integer, | intent(in) | :: | unit |
Output unit number |
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| class(orthogonal_attention_layer_type), | intent(inout) | :: | this |
Layer instance to populate from file data |
||
| integer, | intent(in) | :: | unit |
Input unit number |
||
| integer, | intent(in), | optional | :: | verbose |
Verbosity level |
| Type | Intent | Optional | Attributes | Name | ||
|---|---|---|---|---|---|---|
| class(orthogonal_attention_layer_type), | intent(inout) | :: | this |
Layer instance to configure |
||
| integer, | intent(in) | :: | num_outputs |
Number of output features |
||
| integer, | intent(in) | :: | num_basis |
Number of orthogonal basis vectors |
||
| integer, | intent(in) | :: | key_dim |
Query/key projection dimension |
||
| logical, | intent(in) | :: | use_bias |
Whether to use a bias term |
||
| class(base_actv_type), | intent(in), | allocatable | :: | activation |
Activation function object |
|
| class(base_init_type), | intent(in), | allocatable | :: | kernel_initialiser |
Kernel initialiser object |
|
| class(base_init_type), | intent(in), | allocatable | :: | bias_initialiser |
Bias initialiser object |
|
| integer, | intent(in), | optional | :: | verbose |
Verbosity level |