Orthogonal Attention Layer
orthogonal_attention_layer_type
orthogonal_attention_layer_type(
    num_outputs,
    num_basis,
    key_dim=...,
    num_inputs=...,
    use_bias=.true.,
    activation="none",
    kernel_initialiser=...,
    bias_initialiser=...
)
The orthogonal_attention_layer_type derived type provides a stabilised orthogonal attention layer.
It uses a learned low-rank orthonormal basis to construct a global spectral representation, applies
normalised per-mode attention in that basis, and combines this with a local bypass:
The attention operation is defined in three stages: projection to an orthogonal basis, stable attention weighting, and reconstruction:
where:
\(\mathbf{u} \in \mathbb{R}^{n_{in}}\) is the input sampled on a grid
\(\mathbf{\Phi} \in \mathbb{R}^{n_{in} \times k}\) is the learned orthonormal basis obtained from basis weights \(\mathbf{B}\)
\(\mathbf{c}\) are spectral coefficients in the orthogonal basis
\(\mathbf{W}_Q, \mathbf{W}_K \in \mathbb{R}^{d_k \times n_{in}}\) are the query and key projection weights
\(\mathbf{a} \in \mathbb{R}^{k}\) are normalised per-basis attention weights
\(\mathbf{W}_V \in \mathbb{R}^{n_{out} \times n_{in}}\) is the value projection matrix
\(\mathbf{W} \in \mathbb{R}^{n_{out} \times n_{in}}\) is the local bypass matrix
\(\mathbf{b} \in \mathbb{R}^{n_{out}}\) is the bias vector when use_bias=.true.
\(k\) is num_basis and \(d_k\) is key_dim
\(\odot\) denotes element-wise multiplication
\(\sigma\) is the activation function
This formulation differs from a standard dot-product attention mechanism in that attention is applied directly to orthogonal spectral coefficients rather than to pairwise token interactions. Bounding the interactions with \(\tanh\) and normalising them with a softmax keeps the attention weights positive and summing to one, which aids numerical stability, while the residual spectral update preserves information across basis modes.
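To make the stability argument concrete, here is a short derivation under two assumptions that this page does not spell out: that the coefficients are obtained as \(\mathbf{c} = \mathbf{\Phi}^\top \mathbf{u}\), and that the per-mode scores fed to the softmax are direct outputs of \(\tanh\) with no further scaling. The symbols \(s_j\) (per-mode score) and \(\mathbf{I}_k\) (the \(k \times k\) identity) are introduced here for illustration only.

\[
\mathbf{\Phi}^\top \mathbf{\Phi} = \mathbf{I}_k
\;\Rightarrow\;
\|\mathbf{c}\|_2 = \|\mathbf{\Phi}^\top \mathbf{u}\|_2 \le \|\mathbf{u}\|_2
\]

\[
s_j \in (-1, 1), \quad
a_j = \frac{e^{s_j}}{\sum_{i=1}^{k} e^{s_i}}
\;\Rightarrow\;
\sum_{j=1}^{k} a_j = 1, \quad
\frac{1}{1 + (k-1)\,e^{2}} < a_j < \frac{1}{1 + (k-1)\,e^{-2}}
\]

Under these assumptions the ratio \(a_i / a_j = e^{s_i - s_j}\) stays below \(e^2 \approx 7.39\), so the weighting cannot collapse onto a single basis function, and every exponential is evaluated on an argument in \((-1, 1)\), so it cannot overflow.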
Arguments
num_outputs (integer): Number of output discretisation points.
num_basis (integer): Number of orthogonal basis functions.
key_dim (integer): Dimension of the query and key projections. If not provided, it defaults to num_basis.
num_inputs (integer): Number of input discretisation points. If not provided, it is inferred when the layer is initialised.
use_bias (logical): If .false., the layer will not use a bias term. Default: .true.
activation (class(*)): Activation function for the layer.
    Accepts character(*) or class(base_actv_type).
    See Activation Functions for available options.
    Default: none_actv_type.
kernel_initialiser (class(*)): Initialiser for \(\mathbf{W}_Q\), \(\mathbf{W}_K\), \(\mathbf{W}_V\), \(\mathbf{B}\), and \(\mathbf{W}\) (see Initialisers).
    If activation is selu_actv_type, default: lecun_normal_init_type.
    If activation is a version of relu_actv_type, default: he_normal_init_type.
    For all other activations, default: glorot_uniform_init_type.
bias_initialiser (class(*)): Initialiser for the biases (see Initialisers). Default: zeros_init_type.
Shape
Input: (num_inputs, batch_size).
Output: (num_outputs, batch_size).
Parameters
The layer contains the following learnable parameters:
W_Q: Query projection matrix of shape (key_dim, num_inputs).
W_K: Key projection matrix of shape (key_dim, num_inputs).
W_V: Value projection matrix of shape (num_outputs, num_inputs).
B: Basis weight matrix of shape (num_inputs, num_basis).
W: Local bypass matrix of shape (num_outputs, num_inputs).
b: Bias vector of shape (num_outputs) when use_bias=.true.
The following tensor is derived from the basis weights and rebuilt during forward propagation:
Phi: Orthogonal basis of shape (num_inputs, num_basis).
Total learnable parameters:
With bias: 2 * key_dim * num_inputs + 2 * num_outputs * num_inputs + num_inputs * num_basis + num_outputs
Without bias: 2 * key_dim * num_inputs + 2 * num_outputs * num_inputs + num_inputs * num_basis
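As a quick check of these formulas, the count can be worked through for one configuration. The values used here (num_inputs = 128, num_outputs = 128, num_basis = 16, key_dim = 16) are chosen to match the basic example in the Examples section; nothing else is assumed.

\[
\begin{aligned}
\mathbf{W}_Q, \mathbf{W}_K &: \; 2 \times 16 \times 128 = 4096 \\
\mathbf{W}_V, \mathbf{W} &: \; 2 \times 128 \times 128 = 32768 \\
\mathbf{B} &: \; 128 \times 16 = 2048 \\
\mathbf{b} &: \; 128 \\
\text{total (with bias)} &: \; 4096 + 32768 + 2048 + 128 = 39040 \\
\text{total (without bias)} &: \; 39040 - 128 = 38912
\end{aligned}
\]

The two \(n_{out} \times n_{in}\) matrices dominate the count in this configuration; num_basis and key_dim only enter through terms that are linear in num_inputs.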
Examples
Basic orthogonal attention block:
use athena

type(network_type) :: network

! Single attention block mapping 128 input points to 128 output points
call network%add(orthogonal_attention_layer_type( &
     num_inputs=128, &
     num_outputs=128, &
     num_basis=16, &
     key_dim=16, &
     activation="relu" &
     ))
Orthogonal attention stack with dense readout:
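! First block: num_inputs is given explicitly
call network%add(orthogonal_attention_layer_type( &
     num_inputs=256, &
     num_outputs=256, &
     num_basis=32, &
     key_dim=32, &
     activation="swish" &
     ))

! Second block: num_inputs is omitted, so it is inferred when the layer is initialised
call network%add(orthogonal_attention_layer_type( &
     num_outputs=128, &
     num_basis=16, &
     key_dim=16, &
     activation="swish" &
     ))

! Dense readout to a single output value
call network%add(full_layer_type( &
     num_outputs=1, &
     activation="none" &
     ))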
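The shapes through this stack can be traced with the convention from the Shape section. This is a sketch under two assumptions: that the second attention block's num_inputs is inferred as 256 from the block before it, and that full_layer_type uses the same (features, batch_size) layout. \(N\) denotes batch_size and is introduced here for illustration only.

\[
(256, N)
\;\xrightarrow{\text{attention 1}}\;
(256, N)
\;\xrightarrow{\text{attention 2}}\;
(128, N)
\;\xrightarrow{\text{dense}}\;
(1, N)
\]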
Notes
num_basis controls the rank of the orthogonal projection used to approximate the global interaction.
key_dim controls the size of the exposed query and key parameterisation, even though the present forward path uses the orthogonal projection form.
This layer is useful when you want an operator-style global coupling block without fixing a spectral basis analytically.
See Also
orthogonal_nop_block_type - Orthogonal neural operator block with spectral mixing on the same learned basis
neural_operator_layer_type - Simpler mean-field neural operator layer
fixed_lno_layer_type - Laplace neural operator layer with fixed encoder/decoder bases and spectral mixing
full_layer_type - Standard dense layer