module athena__diffstruc_extd
  !! Module for extended differential structure types for Athena
  !!
  !! Provides derived types built on top of `diffstruc`'s `array_type` and
  !! the explicit interfaces of the differentiable operations used by the
  !! Athena layers. Every operation is declared as a separate module
  !! procedure (`module function` inside an interface block) and this module
  !! has no `contains` section, so the bodies live in submodule(s) of this
  !! module. Any change to a dummy argument below must be mirrored there.
  use coreutils, only: real32
  use diffstruc, only: array_type
  use athena__misc_types, only: facets_type
  implicit none


  private

  public :: array_ptr_type
  public :: add_layers, concat_layers
  public :: add_bias
  public :: piecewise, softmax, swish
  public :: huber
  public :: avgpool1d, avgpool2d, avgpool3d
  public :: maxpool1d, maxpool2d, maxpool3d
  public :: pad1d, pad2d, pad3d
  public :: merge_over_channels
  public :: batchnorm_array_type, batchnorm, batchnorm_inference
  public :: conv1d, conv2d, conv3d
  public :: kipf_propagate, kipf_update
  public :: duvenaud_propagate, duvenaud_update
  public :: gno_kernel_eval, gno_aggregate
  public :: lno_encode, lno_decode, elem_scale
  public :: ono_encode, ono_decode


!-------------------------------------------------------------------------------
! Derived types
!-------------------------------------------------------------------------------
  type, extends(array_type) :: batchnorm_array_type
     !! array_type extended with the statistics returned by
     !! batchnorm / batchnorm_inference
     real(real32), dimension(:), allocatable :: mean
     !! Mean vector (presumably one entry per channel -- TODO confirm)
     real(real32), dimension(:), allocatable :: variance
     !! Variance vector (presumably same layout as mean -- TODO confirm)
     real(real32) :: epsilon
     !! Stability constant (presumably the epsilon argument of
     !! batchnorm / batchnorm_inference -- TODO confirm)
  end type batchnorm_array_type


!-------------------------------------------------------------------------------
! Array container types
!-------------------------------------------------------------------------------
  type :: array_ptr_type
     !! Wrapper around a pointer to a rank-2 array of array_type objects, so
     !! that a list of such arrays can be passed as one argument
     !! (see add_layers / concat_layers).
     type(array_type), pointer :: array(:,:) => null()
     !! Default-initialised to null() so that associated() is well defined
     !! on freshly declared, allocated, or intent(out) instances.
  end type array_ptr_type


!-------------------------------------------------------------------------------
! Operator interfaces
!-------------------------------------------------------------------------------
  interface add_layers
     module function add_array_ptr(a, idx1, idx2) result(c)
       !! Add the arrays referenced by a list of array pointers.
       !! NOTE(review): the roles of idx1/idx2 (indices into each
       !! a(i)%array(:,:)?) are not visible here -- confirm against the
       !! implementing submodule.
       type(array_ptr_type), dimension(:), intent(in) :: a
       integer, intent(in) :: idx1, idx2
       type(array_type), pointer :: c
     end function add_array_ptr
  end interface

  interface concat_layers
     module function concat_array_ptr(a, idx1, idx2, dim) result(c)
       !! Concatenate the arrays referenced by a list of array pointers
       !! along dimension `dim`.
       !! NOTE(review): idx1/idx2 presumably play the same role as in
       !! add_layers -- confirm against the implementing submodule.
       type(array_ptr_type), dimension(:), intent(in) :: a
       integer, intent(in) :: idx1, idx2, dim
       type(array_type), pointer :: c
     end function concat_array_ptr
  end interface


!-------------------------------------------------------------------------------
! Activation functions and other operations
!-------------------------------------------------------------------------------
  interface
     module function add_bias(input, bias, dim, dim_act_on_shape) result(output)
       !! Add `bias` to `input` along dimension `dim`.
       !! `dim_act_on_shape` is optional; its default is set by the
       !! implementing submodule.
       class(array_type), intent(in), target :: input
       class(array_type), intent(in), target :: bias
       integer, intent(in) :: dim
       logical, intent(in), optional :: dim_act_on_shape
       type(array_type), pointer :: output
     end function add_bias
  end interface

  interface piecewise
     module function piecewise_array(input, gradient, limit) result( output )
       !! Piecewise activation parameterised by `gradient` and `limit`.
       class(array_type), intent(in), target :: input
       real(real32), intent(in) :: gradient
       real(real32), intent(in) :: limit
       type(array_type), pointer :: output
     end function piecewise_array
  end interface

  interface softmax
     module function softmax_array(input, dim) result(output)
       !! Softmax of `input` taken over dimension `dim`.
       class(array_type), intent(in), target :: input
       integer, intent(in) :: dim
       type(array_type), pointer :: output
     end function softmax_array
  end interface

  interface swish
     module function swish_array(input, beta) result(output)
       !! Swish activation with parameter `beta`.
       class(array_type), intent(in), target :: input
       real(real32), intent(in) :: beta
       type(array_type), pointer :: output
     end function swish_array
  end interface


!-------------------------------------------------------------------------------
! Loss functions
!-------------------------------------------------------------------------------
  interface huber
     module function huber_array(delta, gamma) result( output )
       !! Huber loss applied to `delta`, with threshold `gamma`.
       class(array_type), intent(in), target :: delta
       real(real32), intent(in) :: gamma
       type(array_type), pointer :: output
     end function huber_array
  end interface


!-------------------------------------------------------------------------------
! Layer operations
!-------------------------------------------------------------------------------
  interface
     module function avgpool1d(input, pool_size, stride) result(output)
       !! 1D average pooling
       type(array_type), intent(in), target :: input
       integer, intent(in) :: pool_size
       integer, intent(in) :: stride
       type(array_type), pointer :: output
     end function avgpool1d

     module function avgpool2d(input, pool_size, stride) result(output)
       !! 2D average pooling (one pool size / stride per spatial dimension)
       type(array_type), intent(in), target :: input
       integer, dimension(2), intent(in) :: pool_size
       integer, dimension(2), intent(in) :: stride
       type(array_type), pointer :: output
     end function avgpool2d

     module function avgpool3d(input, pool_size, stride) result(output)
       !! 3D average pooling (one pool size / stride per spatial dimension)
       type(array_type), intent(in), target :: input
       integer, dimension(3), intent(in) :: pool_size
       integer, dimension(3), intent(in) :: stride
       type(array_type), pointer :: output
     end function avgpool3d
  end interface

  interface
     module function maxpool1d(input, pool_size, stride) result(output)
       !! 1D max pooling
       type(array_type), intent(in), target :: input
       integer, intent(in) :: pool_size
       integer, intent(in) :: stride
       type(array_type), pointer :: output
     end function maxpool1d

     module function maxpool2d(input, pool_size, stride) result(output)
       !! 2D max pooling (one pool size / stride per spatial dimension)
       type(array_type), intent(in), target :: input
       integer, dimension(2), intent(in) :: pool_size
       integer, dimension(2), intent(in) :: stride
       type(array_type), pointer :: output
     end function maxpool2d

     module function maxpool3d(input, pool_size, stride) result(output)
       !! 3D max pooling (one pool size / stride per spatial dimension)
       type(array_type), intent(in), target :: input
       integer, dimension(3), intent(in) :: pool_size
       integer, dimension(3), intent(in) :: stride
       type(array_type), pointer :: output
     end function maxpool3d
  end interface

  interface
     module function pad1d(input, facets, pad_size, imethod) result(output)
       !! 1D padding; `imethod` selects the padding method (codes defined
       !! by the implementing submodule)
       type(array_type), intent(in), target :: input
       type(facets_type), intent(in) :: facets
       integer, intent(in) :: pad_size
       integer, intent(in) :: imethod
       type(array_type), pointer :: output
     end function pad1d

     module function pad2d(input, facets, pad_size, imethod) result(output)
       !! 2D padding (one facets entry / pad size per spatial dimension)
       type(array_type), intent(in), target :: input
       type(facets_type), dimension(2), intent(in) :: facets
       integer, dimension(2), intent(in) :: pad_size
       integer, intent(in) :: imethod
       type(array_type), pointer :: output
     end function pad2d

     module function pad3d(input, facets, pad_size, imethod) result(output)
       !! 3D padding (one facets entry / pad size per spatial dimension)
       type(array_type), intent(in), target :: input
       type(facets_type), dimension(3), intent(in) :: facets
       integer, dimension(3), intent(in) :: pad_size
       integer, intent(in) :: imethod
       type(array_type), pointer :: output
     end function pad3d
  end interface

  interface merge_over_channels
     module function merge_scalar_over_channels(tsource, fsource, mask) &
          result(output)
       !! Select between `tsource` and the scalar `fsource` using a rank-2
       !! logical `mask` (merge-like semantics, applied over channels).
       class(array_type), intent(in), target :: tsource
       real(real32), intent(in) :: fsource
       logical, dimension(:,:), intent(in) :: mask
       type(array_type), pointer :: output
     end function merge_scalar_over_channels
  end interface

  interface
     module function batchnorm( &
          input, params, momentum, mean, variance, epsilon &
     ) result( output )
       !! Batch normalisation (training form, with `momentum`)
       class(array_type), intent(in), target :: input
       class(array_type), intent(in), target :: params
       real(real32), intent(in) :: momentum
       real(real32), dimension(:), intent(in) :: mean
       real(real32), dimension(:), intent(in) :: variance
       real(real32), intent(in) :: epsilon
       type(batchnorm_array_type), pointer :: output
     end function batchnorm

     module function batchnorm_inference( &
          input, params, mean, variance, epsilon &
     ) result( output )
       !! Batch normalisation (inference form, no momentum)
       class(array_type), intent(in), target :: input
       class(array_type), intent(in), target :: params
       real(real32), dimension(:), intent(in) :: mean
       real(real32), dimension(:), intent(in) :: variance
       real(real32), intent(in) :: epsilon
       type(batchnorm_array_type), pointer :: output
     end function batchnorm_inference
  end interface

  interface
     module function conv1d(input, kernel, stride, dilation) result(output)
       !! 1D convolution
       type(array_type), intent(in), target :: input
       type(array_type), intent(in), target :: kernel
       integer, intent(in) :: stride
       integer, intent(in) :: dilation
       type(array_type), pointer :: output
     end function conv1d

     module function conv2d(input, kernel, stride, dilation) result(output)
       !! 2D convolution (one stride / dilation per spatial dimension)
       type(array_type), intent(in), target :: input
       type(array_type), intent(in), target :: kernel
       integer, dimension(2), intent(in) :: stride
       integer, dimension(2), intent(in) :: dilation
       type(array_type), pointer :: output
     end function conv2d

     module function conv3d(input, kernel, stride, dilation) result(output)
       !! 3D convolution (one stride / dilation per spatial dimension)
       type(array_type), intent(in), target :: input
       type(array_type), intent(in), target :: kernel
       integer, dimension(3), intent(in) :: stride
       integer, dimension(3), intent(in) :: dilation
       type(array_type), pointer :: output
     end function conv3d
  end interface


!-------------------------------------------------------------------------------
! Graph / message-passing operations
!-------------------------------------------------------------------------------
  interface
     module function kipf_propagate(vertex_features, adj_ia, adj_ja) result(c)
       !! Propagate values from one autodiff array to another
       !! NOTE(review): adj_ia / adj_ja look like a CSR-style adjacency
       !! (row pointers / column entries) -- confirm against the submodule.
       class(array_type), intent(in), target :: vertex_features
       integer, dimension(:), intent(in) :: adj_ia
       integer, dimension(:,:), intent(in) :: adj_ja
       type(array_type), pointer :: c
     end function kipf_propagate

     module function kipf_update(a, weight, adj_ia) result(c)
       !! Update the message passing layer
       class(array_type), intent(in), target :: a
       class(array_type), intent(in), target :: weight
       integer, dimension(:), intent(in) :: adj_ia
       type(array_type), pointer :: c
     end function kipf_update
  end interface

  interface
     module function duvenaud_propagate( &
          vertex_features, edge_features, adj_ia, adj_ja &
     ) result(c)
       !! Duvenaud message passing function
       class(array_type), intent(in), target :: vertex_features
       class(array_type), intent(in), target :: edge_features
       integer, dimension(:), intent(in) :: adj_ia
       integer, dimension(:,:), intent(in) :: adj_ja
       type(array_type), pointer :: c
     end function duvenaud_propagate

     module function duvenaud_update( &
          a, weight, adj_ia, min_degree, max_degree &
     ) result(c)
       !! Duvenaud update function
       class(array_type), intent(in), target :: a
       class(array_type), intent(in), target :: weight
       integer, dimension(:), intent(in) :: adj_ia
       integer, intent(in) :: min_degree, max_degree
       type(array_type), pointer :: c
     end function duvenaud_update
  end interface


!-------------------------------------------------------------------------------
! Neural-operator operations
!-------------------------------------------------------------------------------
  interface
     module function gno_kernel_eval( &
          coords, kernel_params, adj_ia, adj_ja, &
          coord_dim, kernel_hidden, F_in, F_out &
     ) result(c)
       !! Evaluate GNO kernel MLP on every edge
       class(array_type), intent(in), target :: coords
       class(array_type), intent(in), target :: kernel_params
       integer, dimension(:), intent(in) :: adj_ia
       integer, dimension(:,:), intent(in) :: adj_ja
       integer, intent(in) :: coord_dim, kernel_hidden, F_in, F_out
       type(array_type), pointer :: c
     end function gno_kernel_eval

     module function gno_aggregate( &
          features, edge_kernels, adj_ia, adj_ja, F_in, F_out &
     ) result(c)
       !! Aggregate neighbour messages using per-edge kernels
       class(array_type), intent(in), target :: features
       class(array_type), intent(in), target :: edge_kernels
       integer, dimension(:), intent(in) :: adj_ia
       integer, dimension(:,:), intent(in) :: adj_ja
       integer, intent(in) :: F_in, F_out
       type(array_type), pointer :: c
     end function gno_aggregate
  end interface

  interface
     module function lno_encode( &
          input, poles, num_inputs, num_modes &
     ) result(c)
       !! Encode input via Laplace basis: E(mu) @ u
       class(array_type), intent(in), target :: input
       class(array_type), intent(in), target :: poles
       integer, intent(in) :: num_inputs, num_modes
       type(array_type), pointer :: c
     end function lno_encode

     module function lno_decode( &
          spectral, poles, num_outputs, num_modes &
     ) result(c)
       !! Decode via Laplace basis: D(mu) @ spectral
       class(array_type), intent(in), target :: spectral
       class(array_type), intent(in), target :: poles
       integer, intent(in) :: num_outputs, num_modes
       type(array_type), pointer :: c
     end function lno_decode
  end interface

  interface
     module function elem_scale(input, scale) result(c)
       !! Element-wise multiply: out[i,s] = input[i,s] * scale[i,1]
       !! Correctly handles non-sample-dependent scale vectors.
       class(array_type), intent(in), target :: input
       class(array_type), intent(in), target :: scale
       type(array_type), pointer :: c
     end function elem_scale
  end interface

  interface
     module function ono_encode( &
          input, basis_weights, num_inputs, num_basis &
     ) result(c)
       !! Encode via orthogonal basis: Q(B)^T @ u
       class(array_type), intent(in), target :: input
       class(array_type), intent(in), target :: basis_weights
       integer, intent(in) :: num_inputs, num_basis
       type(array_type), pointer :: c
     end function ono_encode

     module function ono_decode( &
          mixed, basis_weights, num_inputs, num_basis &
     ) result(c)
       !! Decode via orthogonal basis: Q(B) @ mixed
       class(array_type), intent(in), target :: mixed
       class(array_type), intent(in), target :: basis_weights
       integer, intent(in) :: num_inputs, num_basis
       type(array_type), pointer :: c
     end function ono_decode
  end interface
!-------------------------------------------------------------------------------

end module athena__diffstruc_extd