module athena__kipf_msgpass_layer
  !! Module implementing Kipf & Welling Graph Convolutional Network (GCN)
  !!
  !! This module implements the graph convolutional layer from Kipf & Welling
  !! (2017) with symmetric degree normalisation for semi-supervised learning.
  !!
  !! Mathematical operation:
  !! \[ H^{(l+1)} = \sigma\left( \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}
  !!    H^{(l)} W^{(l)} \right) \]
  !!
  !! where:
  !! * \( \tilde{A} = A + I \) (adjacency matrix with added self-loops)
  !! * \( \tilde{D} \) is the degree matrix of \( \tilde{A} \)
  !! * \( H^{(l)} \) is the node feature matrix at layer l
  !! * \( W^{(l)} \) is a learnable weight matrix
  !! * \( \sigma \) is the activation function
  !!
  !! The normalisation \( \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} \) ensures
  !! proper scaling by degree.
  !! Preserves graph structure, producing node-level (not graph-level) outputs.
  !!
  !! Reference: Kipf & Welling (2017), ICLR
  use coreutils, only: real32, stop_program
  use graphstruc, only: graph_type
  use athena__misc_types, only: base_actv_type, base_init_type, &
       onnx_attribute_type, onnx_node_type, onnx_initialiser_type, &
       onnx_tensor_type
  use diffstruc, only: array_type
  use athena__base_layer, only: base_layer_type
  use athena__msgpass_layer, only: msgpass_layer_type
  use athena__diffstruc_extd, only: kipf_propagate, kipf_update
  use diffstruc, only: matmul
  implicit none


  private

  public :: kipf_msgpass_layer_type
  public :: read_kipf_msgpass_layer


!-------------------------------------------------------------------------------
! Message passing layer
!-------------------------------------------------------------------------------
  type, extends(msgpass_layer_type) :: kipf_msgpass_layer_type
     !! Kipf & Welling GCN message passing layer.
     !!
     !! Holds one weight matrix per time step (stored in the inherited
     !! `params` array) and no bias.

     ! this is for chen 2021 et al
     ! type(array2d_type), dimension(:), allocatable :: edge_weight
     ! !! Weights for the edges
     ! type(array2d_type), dimension(:), allocatable :: vertex_weight
     ! !! Weights for the vertices
   contains
     procedure, pass(this) :: get_num_params => get_num_params_kipf
     !! Get the number of parameters for the message passing layer
     procedure, pass(this) :: set_hyperparams => set_hyperparams_kipf
     !! Set the hyperparameters for the message passing layer
     procedure, pass(this) :: init => init_kipf
     !! Initialise the message passing layer
     procedure, pass(this) :: print_to_unit => print_to_unit_kipf
     !! Print the message passing layer
     procedure, pass(this) :: read => read_kipf
     !! Read the message passing layer
     procedure, pass(this) :: update_message => update_message_kipf
     !! Update the message
     procedure, pass(this) :: update_readout => update_readout_kipf
     !! Update the readout
     procedure, pass(this) :: get_attributes => get_attributes_kipf
     !! Get the attributes of the layer (for ONNX export)
     procedure, pass(this) :: emit_onnx_nodes => emit_onnx_nodes_kipf
     !! Emit ONNX JSON nodes for Kipf GCN layer
  end type kipf_msgpass_layer_type

  ! Interface for setting up the MPNN layer
  !-----------------------------------------------------------------------------
  interface kipf_msgpass_layer_type
     !! Interface for setting up the MPNN layer
     module function layer_setup( &
          num_vertex_features, num_time_steps, &
          activation, &
          kernel_initialiser, &
          verbose &
     ) result(layer)
       !! Set up the message passing layer
       integer, dimension(:), intent(in) :: num_vertex_features
       !! Number of features
       integer, intent(in) :: num_time_steps
       !! Number of time steps
       class(*), optional, intent(in) :: activation, kernel_initialiser
       !! Activation function and kernel initialiser
       integer, optional, intent(in) :: verbose
       !! Verbosity level
       type(kipf_msgpass_layer_type) :: layer
       !! Instance of the message passing layer
     end function layer_setup
  end interface kipf_msgpass_layer_type



contains

!###############################################################################
  pure function get_num_params_kipf(this) result(num_params)
    !! Get the number of parameters for the message passing layer
    !!
    !! The layer owns one weight matrix per time step and no bias, so the
    !! total is the sum over time steps t of
    !! num_vertex_features(t-1) * num_vertex_features(t).
    !! This procedure is based on code from the neural-fortran library
    implicit none

    ! Arguments
    class(kipf_msgpass_layer_type), intent(in) :: this
    !! Instance of the message passing layer
    integer :: num_params
    !! Number of parameters

    ! Local variables
    integer :: t
    !! Loop index

    num_params = 0
    do t = 1, this%num_time_steps
       num_params = num_params + &
            this%num_vertex_features(t-1) * this%num_vertex_features(t)
    end do

  end function get_num_params_kipf
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  module function layer_setup( &
       num_vertex_features, num_time_steps, &
       activation, &
       kernel_initialiser, &
       verbose &
  ) result(layer)
    !! Set up the message passing layer
    !!
    !! Builds the activation and (optionally) the kernel initialiser objects,
    !! forwards them to `set_hyperparams`, then calls `init`.
    !! When `activation` is absent the "none" activation is used.
    !! When `kernel_initialiser` is absent it is left unallocated and
    !! `set_hyperparams` selects a default from the activation name.
    use athena__activation, only: activation_setup
    use athena__initialiser, only: initialiser_setup
    implicit none

    ! Arguments
    integer, dimension(:), intent(in) :: num_vertex_features
    !! Number of features
    integer, intent(in) :: num_time_steps
    !! Number of time steps
    class(*), optional, intent(in) :: activation, kernel_initialiser
    !! Activation function and kernel initialiser
    integer, optional, intent(in) :: verbose
    !! Verbosity level
    type(kipf_msgpass_layer_type) :: layer
    !! Instance of the message passing layer

    ! Local variables
    integer :: verbose_
    !! Verbosity level
    class(base_actv_type), allocatable :: activation_
    !! Activation function object
    class(base_init_type), allocatable :: kernel_initialiser_
    !! Kernel initialisers

    ! Assigned at run time: an initialiser in the declaration would give the
    ! variable the SAVE attribute and leak the value into later calls.
    verbose_ = 0
    if(present(verbose)) verbose_ = verbose


    !---------------------------------------------------------------------------
    ! Set activation functions based on input name
    !---------------------------------------------------------------------------
    if(present(activation))then
       activation_ = activation_setup(activation)
    else
       activation_ = activation_setup("none")
    end if


    !---------------------------------------------------------------------------
    ! Define weights (kernels) and biases initialisers
    !---------------------------------------------------------------------------
    if(present(kernel_initialiser))then
       kernel_initialiser_ = initialiser_setup(kernel_initialiser)
    end if


    !---------------------------------------------------------------------------
    ! Set hyperparameters
    !---------------------------------------------------------------------------
    call layer%set_hyperparams( &
         num_vertex_features = num_vertex_features, &
         num_time_steps = num_time_steps, &
         activation = activation_, &
         kernel_initialiser = kernel_initialiser_, &
         verbose = verbose_ &
    )


    !---------------------------------------------------------------------------
    ! Initialise layer shape
    !---------------------------------------------------------------------------
    call layer%init(input_shape=[layer%num_vertex_features(0), 0])

  end function layer_setup
!###############################################################################


!###############################################################################
  subroutine set_hyperparams_kipf( &
       this, &
       num_vertex_features, &
       num_time_steps, &
       activation, &
       kernel_initialiser, &
       verbose &
  )
    !! Set the hyperparameters for the message passing layer
    !!
    !! `num_vertex_features` may hold a single value (used for every time
    !! step) or num_time_steps + 1 values (index 0 is the input width).
    !! Edge feature counts are set to zero for every time step.
    !! An unallocated `activation` selects "none"; an unallocated
    !! `kernel_initialiser` selects the default for the chosen activation.
    use athena__activation, only: activation_setup
    use athena__initialiser, only: get_default_initialiser, initialiser_setup
    implicit none

    ! Arguments
    class(kipf_msgpass_layer_type), intent(inout) :: this
    !! Instance of the message passing layer
    integer, dimension(:), intent(in) :: num_vertex_features
    !! Number of vertex features
    integer, intent(in) :: num_time_steps
    !! Number of time steps
    class(base_actv_type), allocatable, intent(in) :: activation
    !! Activation function
    class(base_init_type), allocatable, intent(in) :: kernel_initialiser
    !! Kernel initialiser
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: t
    !! Loop index
    character(len=256) :: buffer
    !! Name of the default initialiser

    this%name = 'kipf'
    this%type = 'msgp'
    this%input_rank = 2
    this%output_rank = 2
    this%use_graph_output = .true.
    this%num_time_steps = num_time_steps

    if(allocated(this%num_vertex_features)) &
         deallocate(this%num_vertex_features)
    if(allocated(this%num_edge_features)) &
         deallocate(this%num_edge_features)

    if(size(num_vertex_features, 1) .eq. 1)then
       allocate( &
            this%num_vertex_features(0:num_time_steps), &
            source = num_vertex_features(1) &
       )
    elseif(size(num_vertex_features, 1) .eq. num_time_steps + 1)then
       allocate( &
            this%num_vertex_features(0:this%num_time_steps), &
            source = num_vertex_features &
       )
    else
       call stop_program( &
            "Error: num_vertex_features must be a scalar or a vector of &
            &length num_time_steps + 1" &
       )
       ! Do not continue with an unallocated feature list if stop_program
       ! returns (the other error paths in this module also return).
       return
    end if
    allocate( this%num_edge_features(0:this%num_time_steps), source = 0 )
    this%use_graph_input = .true.

    if(allocated(this%activation)) deallocate(this%activation)
    if(.not.allocated(activation))then
       this%activation = activation_setup("none")
    else
       allocate(this%activation, source=activation)
    end if

    if(allocated(this%kernel_init)) deallocate(this%kernel_init)
    if(.not.allocated(kernel_initialiser))then
       buffer = get_default_initialiser(this%activation%name)
       this%kernel_init = initialiser_setup(buffer)
    else
       allocate(this%kernel_init, source=kernel_initialiser)
    end if

    if(present(verbose))then
       if(abs(verbose).gt.0)then
          write(*,'("KIPF activation function: ",A)') &
               trim(this%activation%name)
          write(*,'("KIPF kernel initialiser: ",A)') &
               trim(this%kernel_init%name)
       end if
    end if

    ! Number of weights per time step (no bias terms)
    if(allocated(this%num_params_msg)) deallocate(this%num_params_msg)
    allocate(this%num_params_msg(1:this%num_time_steps))
    do t = 1, this%num_time_steps
       this%num_params_msg(t) = &
            this%num_vertex_features(t-1) * this%num_vertex_features(t)
    end do

    ! Shapes are re-derived by the next call to init
    if(allocated(this%input_shape)) deallocate(this%input_shape)
    if(allocated(this%output_shape)) deallocate(this%output_shape)

  end subroutine set_hyperparams_kipf
!###############################################################################


!###############################################################################
  subroutine init_kipf(this, input_shape, verbose)
    !! Initialise the message passing layer
    !!
    !! Sets the input/output shapes, allocates one weight array per time step
    !! with shape [num_vertex_features(t), num_vertex_features(t-1), 1] and
    !! fills it using the layer's kernel initialiser.
    implicit none

    ! Arguments
    class(kipf_msgpass_layer_type), intent(inout) :: this
    !! Instance of the message passing layer
    integer, dimension(:), intent(in) :: input_shape
    !! Input shape
    integer, optional, intent(in) :: verbose
    !! Verbosity level (accepted for interface compatibility; not used here)

    ! Local variables
    integer :: t
    !! Loop index


    !---------------------------------------------------------------------------
    ! Initialise number of inputs
    !---------------------------------------------------------------------------
    if(.not.allocated(this%input_shape)) call this%set_shape(input_shape)
    this%output_shape = [this%num_vertex_features(this%num_time_steps), 0]
    this%num_params = this%get_num_params()
    if(allocated(this%weight_shape)) deallocate(this%weight_shape)
    if(allocated(this%bias_shape)) deallocate(this%bias_shape)

    allocate(this%weight_shape(2,this%num_time_steps))
    do t = 1, this%num_time_steps
       this%weight_shape(:,t) = &
            [ this%num_vertex_features(t), this%num_vertex_features(t-1) ]
    end do


    !---------------------------------------------------------------------------
    ! Allocate weight, weight steps (velocities), output, and activation
    !---------------------------------------------------------------------------
    if(allocated(this%params)) deallocate(this%params)
    allocate(this%params(this%num_time_steps))
    do t = 1, this%num_time_steps
       call this%params(t)%allocate( &
            array_shape = [ this%weight_shape(:,t), 1 ] &
       )
       call this%params(t)%set_requires_grad(.true.)
       this%params(t)%is_sample_dependent = .false.
       this%params(t)%is_temporary = .false.
       this%params(t)%fix_pointer = .true.
    end do


    !---------------------------------------------------------------------------
    ! Initialise weights (kernels)
    !---------------------------------------------------------------------------
    do t = 1, this%num_time_steps
       call this%kernel_init%initialise( &
            this%params(t)%val(:,1), &
            fan_in = this%num_vertex_features(t-1), &
            fan_out = this%num_vertex_features(t), &
            spacing = [ this%num_vertex_features(t) ] &
       )
    end do


    !---------------------------------------------------------------------------
    ! Allocate arrays
    !---------------------------------------------------------------------------
    ! The output is (re)allocated on demand in update_message
    if(allocated(this%output)) deallocate(this%output)

  end subroutine init_kipf
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!###############################################################################
  subroutine print_to_unit_kipf(this, unit)
    !! Print kipf message passing layer to unit
    !!
    !! Writes NUM_TIME_STEPS, NUM_VERTEX_FEATURES, the activation card (unless
    !! the activation is 'none') and a WEIGHTS ... END WEIGHTS card holding
    !! the weights of each time step, five values per line.
    implicit none

    ! Arguments
    class(kipf_msgpass_layer_type), intent(in) :: this
    !! Instance of the message passing layer
    integer, intent(in) :: unit
    !! File unit

    ! Local variables
    integer :: t
    !! Loop index
    character(100) :: fmt
    !! Format string


    ! Write initial parameters
    !---------------------------------------------------------------------------
    write(unit,'(3X,"NUM_TIME_STEPS = ",I0)') this%num_time_steps
    ! Build a format with one integer field per entry of num_vertex_features
    write(fmt,'("(3X,""NUM_VERTEX_FEATURES ="",",I0,"(1X,I0))")') &
         this%num_time_steps + 1
    write(unit,fmt) this%num_vertex_features

    if(this%activation%name .ne. 'none')then
       call this%activation%print_to_unit(unit)
    end if


    ! Write learned parameters
    !---------------------------------------------------------------------------
    write(unit,'("WEIGHTS")')
    do t = 1, this%num_time_steps, 1
       write(unit,'(5(E16.8E2))') this%params(t)%val
    end do
    write(unit,'("END WEIGHTS")')

  end subroutine print_to_unit_kipf
!###############################################################################


!###############################################################################
  subroutine read_kipf(this, unit, verbose)
    !! Read the message passing layer
    !!
    !! Scans the layer card up to "END <NAME>", collecting the hyperparameter
    !! tags, then re-initialises the layer and, if a WEIGHTS card was seen,
    !! rewinds to it and loads the weights for every time step.
    use, intrinsic :: iso_fortran_env, only: error_unit
    use athena__tools_infile, only: assign_val, assign_vec, get_val, move
    use coreutils, only: to_lower, to_upper, icount
    use athena__activation, only: read_activation
    use athena__initialiser, only: initialiser_setup
    implicit none

    ! Arguments
    class(kipf_msgpass_layer_type), intent(inout) :: this
    !! Instance of the message passing layer
    integer, intent(in) :: unit
    !! Unit to read from
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: stat
    !! Status of read
    integer :: verbose_
    !! Verbosity level
    integer :: t, j, k, c, itmp1, iline
    !! Loop variables and temporary integer
    integer :: num_time_steps
    !! Number of time steps
    character(14) :: kernel_initialiser_name
    !! Initialisers
    class(base_actv_type), allocatable :: activation
    !! Activation function
    class(base_init_type), allocatable :: kernel_initialiser
    !! Initialisers
    integer, dimension(:), allocatable :: num_vertex_features
    !! Number of vertex features
    character(256) :: buffer, tag, err_msg
    !! Buffer, tag, and error message
    real(real32), allocatable, dimension(:) :: data_list
    !! Data list
    integer :: param_line, final_line
    !! Parameter line number


    ! Initialise defaults and optional arguments
    !---------------------------------------------------------------------------
    ! These are assigned at run time rather than in the declarations so that
    ! values from a previously read layer do not persist (implicit SAVE).
    verbose_ = 0
    num_time_steps = 0
    kernel_initialiser_name = ''
    if(present(verbose)) verbose_ = verbose


    ! Loop over tags in layer card
    !---------------------------------------------------------------------------
    iline = 0
    param_line = 0
    final_line = 0
    tag_loop: do

       ! Check for end of file
       !------------------------------------------------------------------------
       read(unit,'(A)',iostat=stat) buffer
       if(stat.ne.0)then
          write(err_msg,'("file encountered error (EoF?) before END ",A)') &
               to_upper(this%name)
          call stop_program(err_msg)
          return
       end if
       if(trim(adjustl(buffer)).eq."") cycle tag_loop

       ! Check for end of layer card
       !------------------------------------------------------------------------
       if(trim(adjustl(buffer)).eq."END "//to_upper(trim(this%name)))then
          final_line = iline
          backspace(unit)
          exit tag_loop
       end if

       iline = iline + 1
       tag=trim(adjustl(buffer))
       if(scan(buffer,"=").ne.0) tag=trim(tag(:scan(tag,"=")-1))

       ! Read parameters from file
       !------------------------------------------------------------------------
       select case(trim(tag))
       case("NUM_TIME_STEPS")
          call assign_val(buffer, num_time_steps, itmp1)
       case("NUM_VERTEX_FEATURES")
          itmp1 = icount(get_val(buffer))
          ! A repeated tag replaces the earlier value instead of failing
          ! in allocate.
          if(allocated(num_vertex_features)) deallocate(num_vertex_features)
          allocate(num_vertex_features(itmp1), source=0)
          call assign_vec(buffer, num_vertex_features, itmp1)
       case("ACTIVATION")
          ! Hand the ACTIVATION line back to the activation reader
          iline = iline - 1
          backspace(unit)
          activation = read_activation(unit, iline)
       case("KERNEL_INITIALISER", "KERNEL_INIT", "KERNEL_INITIALIZER", &
            "KERNEL_INITIALisER")
          call assign_val(buffer, kernel_initialiser_name, itmp1)
       case("WEIGHTS")
          ! Weights are loaded from file below, so skip random initialisation
          kernel_initialiser_name = 'zeros'
          param_line = iline
       case default
          ! Don't look for "e" due to scientific notation of numbers
          ! ... i.e. exponent (E+00)
          if(scan(to_lower(trim(adjustl(buffer))),&
               'abcdfghijklmnopqrstuvwxyz').eq.0)then
             cycle tag_loop
          elseif(tag(:3).eq.'END')then
             cycle tag_loop
          end if
          write(err_msg,'("Unrecognised line in input file: ",A)') &
               trim(adjustl(buffer))
          call stop_program(err_msg)
          return
       end select
    end do tag_loop
    kernel_initialiser = initialiser_setup(kernel_initialiser_name)


    ! Set hyperparameters and initialise layer
    !---------------------------------------------------------------------------
    if(.not.allocated(num_vertex_features))then
       call stop_program( &
            "NUM_VERTEX_FEATURES not found in "// &
            to_upper(trim(this%name))//" layer card" &
       )
       return
    end if
    if(num_time_steps.gt.0 .and. &
         num_time_steps.ne.size(num_vertex_features,1)-1)then
       ! Single literal: a "//" inside the format string would be parsed as
       ! two slash edit descriptors and break the internal write.
       write(err_msg,'("NUM_TIME_STEPS = ",I0,&
            &" does not match length of NUM_VERTEX_FEATURES = ",I0)') &
            num_time_steps, size(num_vertex_features,1)-1
       call stop_program(err_msg)
       return
    end if
    call this%set_hyperparams( &
         num_time_steps = num_time_steps, &
         num_vertex_features = num_vertex_features, &
         activation = activation, &
         kernel_initialiser = kernel_initialiser, &
         verbose = verbose_ &
    )
    call this%init(input_shape=[this%num_vertex_features(0), 0])


    ! Check if WEIGHTS card was found
    !---------------------------------------------------------------------------
    if(param_line.eq.0)then
       write(error_unit,*) "WARNING: WEIGHTS card in "// &
            to_upper(trim(this%name))//" not found"
    else
       ! Step back from the END line to just after the WEIGHTS tag
       call move(unit, param_line - iline, iostat=stat)
       do t = 1, this%num_time_steps
          allocate(data_list(this%num_params_msg(t)), source=0._real32)
          c = 1
          k = 1
          data_concat_loop: do while(c.le.this%num_params_msg(t))
             read(unit,'(A)',iostat=stat) buffer
             if(stat.ne.0) exit data_concat_loop
             ! Never read more values than there is room left for
             k = min(icount(buffer), this%num_params_msg(t) - c + 1)
             read(buffer,*,iostat=stat) (data_list(j),j=c,c+k-1)
             c = c + k
          end do data_concat_loop
          this%params(t)%val(:,1) = data_list(1:this%num_params_msg(t))
          deallocate(data_list)
       end do

       ! Check for end of weights card
       !------------------------------------------------------------------------
       read(unit,'(A)') buffer
       if(trim(adjustl(buffer)).ne."END WEIGHTS")then
          write(error_unit,*) trim(adjustl(buffer))
          call stop_program("END WEIGHTS not where expected")
          return
       end if
    end if


    !---------------------------------------------------------------------------
    ! Check for end of layer card
    !---------------------------------------------------------------------------
    read(unit,'(A)') buffer
    if(trim(adjustl(buffer)).ne."END "//to_upper(trim(this%name)))then
       write(error_unit,*) trim(adjustl(buffer))
       write(err_msg,'("END ",A," not where expected")') to_upper(this%name)
       call stop_program(err_msg)
       return
    end if

  end subroutine read_kipf
!###############################################################################


!###############################################################################
  function read_kipf_msgpass_layer(unit, verbose) result(layer)
    !! Read kipf message passing layer from file and return layer
    !!
    !! A placeholder layer (one time step, zero features) is constructed first
    !! and then overwritten by the layer's `read` procedure.
    implicit none

    ! Arguments
    integer, intent(in) :: unit
    !! Unit number
    integer, optional, intent(in) :: verbose
    !! Verbosity level
    class(base_layer_type), allocatable :: layer
    !! Instance of the message passing layer

    ! Local variables
    integer :: verbose_
    !! Verbosity level

    ! Assigned at run time to avoid the implicit SAVE attribute
    verbose_ = 0
    if(present(verbose)) verbose_ = verbose
    allocate(layer, source = kipf_msgpass_layer_type( &
         num_time_steps = 1, &
         num_vertex_features = [ 0, 0 ] &
    ))
    call layer%read(unit, verbose=verbose_)

  end function read_kipf_msgpass_layer
!###############################################################################


!###############################################################################
  function get_attributes_kipf(this) result(attributes)
    !! Get the attributes of the Kipf GCN layer (for ONNX export)
    !!
    !! Returns three attributes: `num_time_steps` (int),
    !! `num_vertex_features` (space-separated ints) and
    !! `message_activation` (string).
    implicit none

    ! Arguments
    class(kipf_msgpass_layer_type), intent(in) :: this
    !! Instance of the message passing layer
    type(onnx_attribute_type), allocatable, dimension(:) :: attributes
    !! Attributes for ONNX export

    ! Local variables
    integer :: t
    !! Loop index
    character(256) :: buffer
    !! Buffer for converting attributes to strings
    character(32) :: item
    !! Text form of a single integer

    allocate(attributes(3))

    write(buffer, '(I0)') this%num_time_steps
    attributes(1) = onnx_attribute_type( &
         name='num_time_steps', type='int', val=trim(buffer))

    ! Build the list by concatenation: an internal file must not also appear
    ! in the output list of the write statement that targets it.
    buffer = ''
    do t = 0, this%num_time_steps
       write(item, '(I0)') this%num_vertex_features(t)
       if(t .eq. 0)then
          buffer = item
       else
          buffer = trim(buffer)//' '//trim(item)
       end if
    end do
    attributes(2) = onnx_attribute_type( &
         name='num_vertex_features', type='ints', val=trim(buffer))

    attributes(3) = onnx_attribute_type( &
         name='message_activation', type='string', &
         val=trim(this%activation%name))

  end function get_attributes_kipf
!###############################################################################


!###############################################################################
  subroutine emit_onnx_nodes_kipf( &
       this, prefix, &
       nodes, num_nodes, max_nodes, &
       inits, num_inits, max_inits, &
       input_name, is_last_layer, format &
  )
    !! Emit ONNX JSON nodes for Kipf GCN layer
    !!
    !! Decomposes the Kipf message passing layer into standard ONNX ops:
    !!   Gather, ScatterElements, Mul, Pow, MatMul, activation
    !!
    !! Kipf GCN: H^(l+1) = sigma(D~^(-1/2) A~ D~^(-1/2) H^(l) W^(l))
    !! Decomposed per timestep:
    !!   1. Extract source/target indices from edge_index
    !!   2. Gather source vertex features
    !!   3. Compute normalisation coeff = (deg_src * deg_tgt)^(-0.5)
    !!   4. Scale source features by coefficient
    !!   5. Scatter-add to target vertices
    !!   6. MatMul with weight W (transposed)
    !!   7. Apply activation
    use athena__onnx_msgpass_utils, only: emit_output_identity
    implicit none

    ! Arguments
    class(kipf_msgpass_layer_type), intent(in) :: this
    !! Instance of the layer
    character(*), intent(in) :: prefix
    !! Node name prefix (e.g. "node_2")
    type(onnx_node_type), intent(inout), dimension(:) :: nodes
    !! Accumulator for ONNX nodes
    integer, intent(inout) :: num_nodes
    !! Current number of nodes
    integer, intent(in) :: max_nodes
    !! Maximum capacity
    type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
    !! Accumulator for ONNX initialisers
    integer, intent(inout) :: num_inits
    !! Current number of initialisers
    integer, intent(in) :: max_inits
    !! Maximum capacity
    character(*), optional, intent(in) :: input_name
    !! Unused sequential input name
    logical, optional, intent(in) :: is_last_layer
    !! Unused last-layer flag
    integer, optional, intent(in) :: format
    !! Unused export format selector

    ! Local variables
    integer :: t
    !! Time-step index
    character(128) :: cur_vertex_name
    !! Current timestep output tensor name

    ! Keep the name defined even if the loop below does not execute.
    ! NOTE(review): falls back to the layer's vertex input name, matching the
    ! name emit_kipf_timestep uses for its first input - confirm this is the
    ! desired output for a layer with no time steps.
    write(cur_vertex_name, '(A,"_vertex_in")') trim(prefix)

    do t = 1, this%num_time_steps
       call emit_kipf_timestep( &
            prefix, t, &
            this%num_vertex_features(t-1), &
            this%num_vertex_features(t), &
            this%params(t)%val(:,1), &
            this%activation%name, &
            nodes, num_nodes, max_nodes, &
            inits, num_inits, max_inits, &
            cur_vertex_name &
       )
    end do

    ! Kipf produces node-level output (no readout).
    call emit_output_identity( &
         prefix, trim(cur_vertex_name), this%activation%name, &
         nodes, num_nodes)

  end subroutine emit_onnx_nodes_kipf
!###############################################################################


!###############################################################################
  subroutine emit_kipf_timestep( &
       prefix, t, nv_in, nv_out, weight_data, activation_name, &
       nodes, num_nodes, max_nodes, &
       inits, num_inits, max_inits, vertex_out)
    !! Emit ONNX nodes for one Kipf GCN time step.
    !!
    !! The input tensor is "<prefix>_vertex_in" for the first time step and
    !! the previous time step's output otherwise. On return `vertex_out`
    !! holds the name of the tensor produced by this time step (the
    !! activation output, or the MatMul output when the activation is 'none').
    use athena__onnx_utils, only: emit_node, emit_constant_int64, &
         emit_constant_float, emit_activation_node
    use athena__onnx_msgpass_utils, only: get_timestep_output_name, &
         emit_edge_index_component, emit_scatter_aggregator, &
         emit_weight_initialiser_2d
    implicit none

    ! Arguments
    character(*), intent(in) :: prefix
    !! Node name prefix of the layer
    integer, intent(in) :: t, nv_in, nv_out
    !! Time-step index, and input/output vertex feature counts
    real(real32), intent(in) :: weight_data(:)
    !! Flattened weight values for this time step
    character(*), intent(in) :: activation_name
    !! Name of the activation function ('none' for no activation node)
    type(onnx_node_type), intent(inout), dimension(:) :: nodes
    !! Accumulator for ONNX nodes
    integer, intent(inout) :: num_nodes
    !! Current number of nodes
    integer, intent(in) :: max_nodes
    !! Maximum capacity
    type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
    !! Accumulator for ONNX initialisers
    integer, intent(inout) :: num_inits
    !! Current number of initialisers
    integer, intent(in) :: max_inits
    !! Maximum capacity
    character(128), intent(out) :: vertex_out
    !! Name of the output tensor of this time step

    ! Local variables
    character(128) :: tp, tmp1, tmp2, tmp3, tmp4, tmp6, tmp7
    !! Time-step prefix and scratch tensor names
    character(128) :: vertex_in, edge_index_in, degree_in
    !! Names of the layer input tensors
    character(128) :: src_idx, target_idx, aggr_name
    !! Names of the index and aggregated tensors
    character(len=*), parameter :: onnx_axis0_attr = &
         ' "attribute": [{"name": "axis", "i": "0", "type": "INT"}]'
    character(len=*), parameter :: onnx_transpose_10_attr = &
         ' "attribute": [{"name": "perm", "ints": ["1", "0"], ' // &
         '"type": "INTS"}]'
    character(len=*), parameter :: onnx_cast_float_attr = &
         ' "attribute": [{"name": "to", "i": "1", "type": "INT"}]'

    write(tp, '(A,"_t",I0)') trim(prefix), t
    write(vertex_in, '(A,"_vertex_in")') trim(prefix)
    write(edge_index_in, '(A,"_edge_index_in")') trim(prefix)
    write(degree_in, '(A,"_degree_in")') trim(prefix)

    ! Later time steps consume the previous time step's output
    if(t .gt. 1)then
       call get_timestep_output_name( &
            prefix, t-1, activation_name, '_mm_out', '', vertex_in)
    end if

    ! --- Step 1: Extract source and target indices from edge_index ---
    write(tmp1, '(A,"_idx0")') trim(tp)
    call emit_constant_int64(trim(tmp1), [0], [1], &
         nodes, num_nodes, inits, num_inits)
    write(tmp2, '(A,"_idx2")') trim(tp)
    call emit_constant_int64(trim(tmp2), [2], [1], &
         nodes, num_nodes, inits, num_inits)

    call emit_edge_index_component( &
         tp, edge_index_in, trim(tmp1), 'src', src_idx, nodes, num_nodes)
    call emit_edge_index_component( &
         tp, edge_index_in, trim(tmp2), 'tgt', target_idx, nodes, num_nodes)

    ! --- Step 2: Gather source features and compute normalisation ---
    write(tmp1, '(A,"_src_feat")') trim(tp)
    call emit_node('Gather', trim(tp)//'_gather_vfeat', &
         trim(tmp1), onnx_axis0_attr, nodes, num_nodes, &
         in1=trim(vertex_in), in2=trim(src_idx))

    write(tmp2, '(A,"_deg_f")') trim(tp)
    call emit_node('Cast', trim(tp)//'_cast_deg', &
         trim(tmp2), onnx_cast_float_attr, nodes, num_nodes, &
         in1=trim(degree_in))

    write(tmp4, '(A,"_deg_src")') trim(tp)
    call emit_node('Gather', trim(tp)//'_gather_deg_src', &
         trim(tmp4), onnx_axis0_attr, nodes, num_nodes, &
         in1=trim(tmp2), in2=trim(src_idx))

    write(tmp6, '(A,"_deg_tgt")') trim(tp)
    call emit_node('Gather', trim(tp)//'_gather_deg_tgt', &
         trim(tmp6), onnx_axis0_attr, nodes, num_nodes, &
         in1=trim(tmp2), in2=trim(target_idx))

    write(tmp7, '(A,"_deg_prod")') trim(tp)
    call emit_node('Mul', trim(tp)//'_mul_deg', &
         trim(tmp7), '', nodes, num_nodes, &
         in1=trim(tmp4), in2=trim(tmp6))

    write(tmp2, '(A,"_neg_half")') trim(tp)
    call emit_constant_float(trim(tmp2), [-0.5_real32], [1], &
         nodes, num_nodes, inits, num_inits)

    write(tmp3, '(A,"_coeff")') trim(tp)
    call emit_node('Pow', trim(tp)//'_pow_coeff', &
         trim(tmp3), '', nodes, num_nodes, &
         in1=trim(tmp7), in2=trim(tmp2))

    ! Unsqueeze coeff for broadcasting and scale the source features.
    write(tmp4, '(A,"_coeff_us")') trim(tp)
    write(tmp6, '(A,"_us_ax1")') trim(tp)
    call emit_constant_int64(trim(tmp6), [1], [1], &
         nodes, num_nodes, inits, num_inits)
    call emit_node('Unsqueeze', trim(tp)//'_us_coeff', &
         trim(tmp4), '', nodes, num_nodes, &
         in1=trim(tmp3), in2=trim(tmp6))

    write(tmp2, '(A,"_scaled_feat")') trim(tp)
    call emit_node('Mul', trim(tp)//'_mul_coeff', &
         trim(tmp2), '', nodes, num_nodes, &
         in1=trim(tmp1), in2=trim(tmp4))

    ! --- Step 3: Scatter-add normalised messages to target vertices ---
    call emit_scatter_aggregator( &
         tp, vertex_in, target_idx, trim(tmp2), nv_in, &
         nodes, num_nodes, inits, num_inits, aggr_name)

    ! --- Step 4: MatMul with weight W ---
    ! W is emitted as (nv_out, nv_in) and transposed before the product
    write(tmp1, '(A,"_W")') trim(tp)
    call emit_weight_initialiser_2d( &
         trim(tmp1), nv_out, nv_in, weight_data, inits, num_inits)

    write(tmp2, '(A,"_Wt")') trim(tp)
    call emit_node('Transpose', trim(tp)//'_transpose_W', &
         trim(tmp2), onnx_transpose_10_attr, nodes, num_nodes, &
         in1=trim(tmp1))

    write(tmp3, '(A,"_mm_out")') trim(tp)
    call emit_node('MatMul', trim(tp)//'_matmul', &
         trim(tmp3), '', nodes, num_nodes, &
         in1=trim(aggr_name), in2=trim(tmp2))

    ! --- Step 5: Activation ---
    if(trim(activation_name) .ne. 'none')then
       call emit_activation_node(activation_name, trim(tp), trim(tmp3), &
            nodes, num_nodes, max_nodes)
       ! The activation node is the most recently emitted node
       vertex_out = trim(nodes(num_nodes)%outputs(1))
    else
       vertex_out = trim(tmp3)
    end if

  end subroutine emit_kipf_timestep
!###############################################################################


!##############################################################################!
! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
!##############################################################################!


!##############################################################################!
  subroutine update_message_kipf(this, input)
    !! Update the message
    !!
    !! For every sample (second dimension of `input`) the vertex features in
    !! input(1,s) are passed through num_time_steps rounds of
    !! kipf_propagate -> matmul with the time-step weights -> activation.
    !! The final result is stored in this%output(1,s).
    implicit none

    ! Arguments
    class(kipf_msgpass_layer_type), intent(inout), target :: this
    !! Instance of the message passing layer
    class(array_type), dimension(:,:), intent(in), target :: input
    !! Input to the message passing layer

    ! Local variables
    integer :: s, t
    !! Batch index, time step
    type(array_type), pointer :: ptr1, ptr2, ptr3
    !! Pointers to arrays

    ! (Re)allocate the output when the number of samples changes
    if(allocated(this%output))then
       if(size(this%output,2).ne.size(input,2)) deallocate(this%output)
    end if
    if(.not.allocated(this%output)) allocate(this%output(1,size(input,2)))

    do s = 1, size(input,2)
       ptr1 => input(1,s)
       do t = 1, this%num_time_steps
          ptr2 => kipf_propagate( &
               ptr1, &
               this%graph(s)%adj_ia, this%graph(s)%adj_ja &
          )
          ! this%z(t,s) = kipf_update( &
          !      this%message(t,s), this%params(t), this%graph(s)%adj_ia &
          ! )
          ptr3 => matmul( this%params(t), ptr2 )
          ptr1 => this%activation%apply( ptr3 )
       end do
       call this%output(1,s)%zero_grad()
       call this%output(1,s)%assign_and_deallocate_source(ptr1)
       this%output(1,s)%is_temporary = .false.
    end do

  end subroutine update_message_kipf
!###############################################################################


!###############################################################################
  subroutine update_readout_kipf(this)
    !! Update the readout (empty for node-level output)
    implicit none

    ! Arguments
    class(kipf_msgpass_layer_type), intent(inout), target :: this
    !! Instance of the message passing layer

  end subroutine update_readout_kipf
!###############################################################################

end module athena__kipf_msgpass_layer