forward_neural_operator Subroutine

private subroutine forward_neural_operator(this, input)

Forward propagation for the neural operator layer

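Computes: $v = \sigma\left(W u + W_k \,\mathrm{mean}(u) + b\right)$, where the bias term $b$ is included only when the layer is configured to use a bias.

Here $\mathrm{mean}(u)$ is the global mean of the input (one scalar per sample), which approximates the integral operator.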

Type Bound

neural_operator_layer_type

Arguments

Type | Intent | Optional | Attributes | Name
class(neural_operator_layer_type), intent(inout) :: this

Instance of the neural operator layer

class(array_type), intent(in), dimension(:,:) :: input

Input values


Source Code

  subroutine forward_neural_operator(this, input)
    !! Run the forward pass of the neural operator layer.
    !!
    !! Pre-activation (the bias term is included only when use_bias is set):
    !!   z = W * u  +  W_k * mean(u)  +  b
    !! Output:
    !!   v = sigma(z), or v = z when the activation name is "none".
    !!
    !! The mean(u) term stands in for the integral operator.
    implicit none

    ! Arguments
    class(neural_operator_layer_type), intent(inout) :: this
    !! Neural operator layer being evaluated
    class(array_type), dimension(:,:), intent(in) :: input
    !! Input values (only input(1,1) is read)

    ! Local variables
    type(array_type), pointer :: local_term
    !! W * u
    type(array_type), pointer :: u_mean
    !! mean of input(1,1) along dim 1
    type(array_type), pointer :: integral_term
    !! W_k * mean(u)
    type(array_type), pointer :: pre_act
    !! running pre-activation sum
    type(array_type), pointer :: activated
    !! sigma(pre_act); only used when an activation is set


    ! Local (pointwise) contribution: W * u
    !---------------------------------------------------------------------------
    local_term => matmul(this%params(1), input(1,1))

    ! Mean-field contribution: W_k * mean(u)
    !   NOTE(review): assumes mean(..., dim=1) collapses the spatial/feature
    !   axis to one value per sample, so W_k acts as an [n_out x 1] matrix
    !   spreading that value across every output; confirm against the
    !   array_type definitions of mean and matmul.
    !---------------------------------------------------------------------------
    u_mean => mean(input(1,1), dim=1)
    integral_term => matmul(this%params(2), u_mean)

    ! Pre-activation: sum both contributions, then add the optional bias
    !---------------------------------------------------------------------------
    pre_act => local_term + integral_term
    if(this%use_bias) pre_act => pre_act + this%params(3)

    ! Activation and output storage
    !---------------------------------------------------------------------------
    call this%output(1,1)%zero_grad()
    if(trim(this%activation%name) .ne. "none")then
       ! store the pre-activation in z(1), mark it non-temporary, then apply
       ! the activation to it
       call this%z(1)%zero_grad()
       call this%z(1)%assign_and_deallocate_source(pre_act)
       this%z(1)%is_temporary = .false.
       activated => this%activation%apply(this%z(1))
       call this%output(1,1)%assign_and_deallocate_source(activated)
    else
       ! identity activation: the pre-activation is the output
       call this%output(1,1)%assign_and_deallocate_source(pre_act)
    end if
    this%output(1,1)%is_temporary = .false.

  end subroutine forward_neural_operator