athena__orthogonal_attention_layer Module

Module containing implementation of an Orthogonal Attention layer

This module implements the Orthogonal Attention mechanism from "Improved Operator Learning by Orthogonal Attention" (Xiao et al., 2024).

Instead of softmax attention, this layer projects queries and keys onto a learned orthonormal basis of dimension $k$ (`num_basis`), giving a linear-cost approximation to the attention kernel.

Given input $X$: [display equation lost in rendering; consult the module-level doc comment in the source]

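The orthogonal basis $\Phi$ ($k$ = `num_basis` basis vectors) is obtained by QR decomposition of the learnable weights $B$; `get_bases` performs this orthogonalisation with modified Gram-Schmidt.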

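The attention output is: [display equation lost in rendering; consult the module-level doc comment in the source]

The layer output is: [display equation lost in rendering; consult the module-level doc comment in the source]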

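Parameters (learnable; the symbols were lost in rendering, so the names below are partly inferred and should be confirmed against the module source):

  - three projection weight matrices, presumably the query, key and value projections $W_Q$, $W_K$, $W_V$
  - $B$ (basis, orthogonalised)
  - the bypass weights (bypass)
  - the bias (optional bias)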



Interfaces

  • private module function layer_setup(num_outputs, num_basis, key_dim, num_inputs, use_bias, activation, kernel_initialiser, bias_initialiser, verbose) result(layer)

    Arguments

    Type IntentOptional Attributes Name
    integer, intent(in) :: num_outputs
    integer, intent(in) :: num_basis
    integer, intent(in), optional :: key_dim
    integer, intent(in), optional :: num_inputs
    logical, intent(in), optional :: use_bias
    class(*), intent(in), optional :: activation
    class(*), intent(in), optional :: kernel_initialiser
    class(*), intent(in), optional :: bias_initialiser
    integer, intent(in), optional :: verbose

    Return Value type(orthogonal_attention_layer_type)


Derived Types

Type for an Orthogonal Attention layer

Components

Type Visibility Attributes Name Initial
class(base_actv_type), public, allocatable :: activation

Activation function

class(base_init_type), public, allocatable :: bias_init

Initialiser object for the bias

character(len=14), public :: bias_initialiser = ''

Bias initialiser, stored as a character string

integer, public, allocatable, dimension(:) :: bias_shape

Shape of biases

type(graph_type), public, allocatable, dimension(:) :: graph

Graph structure of input data

integer, public :: id

Unique identifier

logical, public :: inference = .false.

Inference mode

integer, public :: input_rank = 0

Rank of input data

integer, public, allocatable, dimension(:) :: input_shape

Input shape

class(base_init_type), public, allocatable :: kernel_init

Initialiser object for the kernel

character(len=14), public :: kernel_initialiser = ''

Kernel initialiser, stored as a character string

integer, public :: key_dim = 0

Dimension of query/key projections (d_k)

character(len=:), public, allocatable :: name

Layer name

integer, public :: num_basis = 0

Number of orthogonal basis functions (k)

integer, public :: num_inputs = 0

Number of input features / discretisation points

integer, public :: num_outputs = 0

Number of output features / discretisation points

integer, public :: num_params = 0

Number of learnable parameters

class(array_type), public, allocatable, dimension(:,:) :: output

Output

integer, public :: output_rank = 0

Rank of output data

integer, public, allocatable, dimension(:) :: output_shape

Output shape

type(array_type), public, allocatable, dimension(:) :: params

Learnable parameters

character(len=20), public :: subtype = repeat(" ", 20)
character(len=4), public :: type = 'base'

Layer type

logical, public :: use_bias = .false.

Layer has bias

logical, public :: use_graph_input = .false.

Use graph input

logical, public :: use_graph_output = .false.

Use graph output

integer, public, allocatable, dimension(:,:) :: weight_shape

Shape of weights

type(array_type), public, dimension(1) :: z

Temporary array for pre-activation values

Constructor

private module function layer_setup (num_outputs, num_basis, key_dim, num_inputs, use_bias, activation, kernel_initialiser, bias_initialiser, verbose)

Finalization Procedures

final :: finalise_ono_attn

Type-Bound Procedures

procedure, public :: add_t_t => add_learnable

Add two layers

procedure, public, pass(this) :: build_from_onnx => build_from_onnx_base

Build layer from ONNX node and initialiser

procedure, public, pass(this) :: emit_onnx_graph_inputs => emit_onnx_graph_inputs_base

Emit graph input tensor declarations for this layer

procedure, public, pass(this) :: emit_onnx_nodes => emit_onnx_nodes_base

Emit ONNX JSON nodes for this layer (format-aware and polymorphic)

procedure, public, pass(this) :: extract_output => extract_output_base

Extract the output of the layer as a standard real array

procedure, public, pass(this) :: forward => forward_ono_attn
procedure, public, pass(this) :: forward_eval => forward_eval_base

Forward pass of layer and return output for evaluation

procedure, public, pass(this) :: get_attributes => get_attributes_ono_attn
procedure, public, pass(this) :: get_bases => get_bases_ono_attn
procedure, public, pass(this) :: get_gradients

Get parameter gradients of layer

procedure, public, pass(this) :: get_num_params => get_num_params_ono_attn
procedure, public, pass(this) :: get_params

Get learnable parameters of layer

procedure, public, pass(this) :: init => init_ono_attn
procedure, public, pass(this) :: nullify_graph => nullify_graph_base

Nullify the forward pass data of the layer to free memory

Read more…
generic, public :: operator(+) => add_t_t

Operator overloading for addition

procedure, public, pass(this) :: print => print_base

Print the layer to a file with additional information

procedure, public, pass(this) :: print_to_unit => print_to_unit_ono_attn
procedure, public, pass(this) :: read => read_ono_attn
procedure, public, pass(this) :: reduce => reduce_learnable

Merge another learnable layer into this one

procedure, public, pass(this) :: set_gradients

Set parameter gradients of layer

procedure, public, pass(this) :: set_graph => set_graph_base

Set the graph structure of the input data (adjacency and edge weighting)

procedure, public, pass(this) :: set_hyperparams => set_hyperparams_ono_attn
procedure, public, pass(this) :: set_params

Set learnable parameters of layer

procedure, public, pass(this) :: set_rank => set_rank_base

Set the input and output ranks of the layer

procedure, public, pass(this) :: set_shape => set_shape_base

Set the input shape of the layer


Functions

public function read_orthogonal_attention_layer(unit, verbose) result(layer)

Read an orthogonal attention layer from file and return it

Arguments

Type IntentOptional Attributes Name
integer, intent(in) :: unit

Input unit number

integer, intent(in), optional :: verbose

Verbosity level

Return Value class(base_layer_type), allocatable

Allocated base-layer instance containing the result

private function get_attributes_ono_attn(this) result(attributes)

Return list of orthogonal attention attributes for ONNX export

Arguments

Type IntentOptional Attributes Name
class(orthogonal_attention_layer_type), intent(in) :: this

Instance of the orthogonal attention layer

Return Value type(onnx_attribute_type), allocatable, dimension(:)

List of attributes for ONNX export

private function get_bases_ono_attn(this) result(phi)

Orthogonalise the basis matrix B using modified Gram-Schmidt

Arguments

Type IntentOptional Attributes Name
class(orthogonal_attention_layer_type), intent(in) :: this

Layer instance providing basis parameters

Return Value type(array_type)

Orthogonalised basis matrix packed in an array_type

private pure function get_num_params_ono_attn(this) result(num_params)

Return the number of learnable parameters for the layer

Arguments

Type IntentOptional Attributes Name
class(orthogonal_attention_layer_type), intent(in) :: this

Layer instance

Return Value integer

Total number of learnable parameters

private module function layer_setup(num_outputs, num_basis, key_dim, num_inputs, use_bias, activation, kernel_initialiser, bias_initialiser, verbose) result(layer)

Arguments

Type IntentOptional Attributes Name
integer, intent(in) :: num_outputs

Number of output features

integer, intent(in) :: num_basis

Number of orthogonal basis vectors

integer, intent(in), optional :: key_dim

Query/key projection dimension

integer, intent(in), optional :: num_inputs

Number of input features when known at construction time

logical, intent(in), optional :: use_bias

Whether to allocate a bias term

class(*), intent(in), optional :: activation

Activation function specification

class(*), intent(in), optional :: kernel_initialiser

Kernel initialiser specification

class(*), intent(in), optional :: bias_initialiser

Bias initialiser specification

integer, intent(in), optional :: verbose

Verbosity level

Return Value type(orthogonal_attention_layer_type)

Constructed orthogonal attention layer


Subroutines

private subroutine finalise_ono_attn(this)

Finalise the orthogonal attention layer

Arguments

Type IntentOptional Attributes Name
type(orthogonal_attention_layer_type), intent(inout) :: this

Layer instance to release

private subroutine forward_ono_attn(this, input)

Forward propagation for the Orthogonal Attention layer

Read more…

Arguments

Type IntentOptional Attributes Name
class(orthogonal_attention_layer_type), intent(inout) :: this

Layer instance to execute

class(array_type), intent(in), dimension(:,:) :: input

Input batch tensor collection

private subroutine init_ono_attn(this, input_shape, verbose)

Initialise parameter storage and output buffers for the layer

Arguments

Type IntentOptional Attributes Name
class(orthogonal_attention_layer_type), intent(inout) :: this

Layer instance to initialise

integer, intent(in), dimension(:) :: input_shape

Input shape used to infer num_inputs

integer, intent(in), optional :: verbose

Verbosity level

private subroutine print_to_unit_ono_attn(this, unit)

Print orthogonal attention layer settings and parameters to a unit

Arguments

Type IntentOptional Attributes Name
class(orthogonal_attention_layer_type), intent(in) :: this

Layer instance to print

integer, intent(in) :: unit

Output unit number

private subroutine read_ono_attn(this, unit, verbose)

Arguments

Type IntentOptional Attributes Name
class(orthogonal_attention_layer_type), intent(inout) :: this

Layer instance to populate from file data

integer, intent(in) :: unit

Input unit number

integer, intent(in), optional :: verbose

Verbosity level

private subroutine set_hyperparams_ono_attn(this, num_outputs, num_basis, key_dim, use_bias, activation, kernel_initialiser, bias_initialiser, verbose)

Arguments

Type IntentOptional Attributes Name
class(orthogonal_attention_layer_type), intent(inout) :: this

Layer instance to configure

integer, intent(in) :: num_outputs

Number of output features

integer, intent(in) :: num_basis

Number of orthogonal basis vectors

integer, intent(in) :: key_dim

Query/key projection dimension

logical, intent(in) :: use_bias

Whether to use a bias term

class(base_actv_type), intent(in), allocatable :: activation

Activation function object

class(base_init_type), intent(in), allocatable :: kernel_initialiser

Kernel initialiser object

class(base_init_type), intent(in), allocatable :: bias_initialiser

Bias initialiser object

integer, intent(in), optional :: verbose

Verbosity level