lno_decode Module Function

module function lno_decode(spectral, poles, num_outputs, num_modes) result(c)

Decode through the Laplace basis built from learnable poles.

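Forward: $y = D(\mu)\,x$, where $y$ has shape $[n_\mathrm{out}, \mathrm{batch}]$, $D_{i,m} = \exp(-\mu_m \tau_i)$ and $\tau_i = \frac{i-1}{n_\mathrm{out}-1}$ (with $\tau_1 = 0$ when $n_\mathrm{out} = 1$).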

Operands:
- left_operand → spectral coefficients $x$, shape [M, batch]
- right_operand → poles $\mu$, shape [M, 1]
- output → decoded tensor, shape [n_out, batch]

Arguments

Type | Intent | Optional | Attributes | Name
class(array_type), intent(in), target :: spectral

Spectral tensor [M, batch]

class(array_type), intent(in), target :: poles

Learnable poles [M, 1]

integer, intent(in) :: num_outputs

Output dimension: number of decoded points (n_out)

integer, intent(in) :: num_modes

Number of Laplace modes (M)

Return Value type(array_type), pointer

Decoded output tensor


Source Code

  module function lno_decode( &
       spectral, poles, num_outputs, num_modes &
  ) result(c)
    !! Project Laplace-mode coefficients back onto the output grid.
    !!
    !! The decoder basis is built from the learnable poles mu:
    !!
    !!   D(i, m) = exp(-mu_m * tau_i),   tau_i = (i - 1) / (n_out - 1)
    !!
    !! (tau_1 = 0 when n_out = 1), and the decoded output is y = D @ x.
    !!
    !! Autodiff wiring:
    !!   left_operand  = spectral x  [M, batch]
    !!   right_operand = poles mu    [M, 1]
    !!   result        = decoded y   [n_out, batch]
    implicit none

    ! Arguments
    class(array_type), intent(in), target :: spectral
    !! Spectral coefficients x, expected shape [num_modes, batch]
    class(array_type), intent(in), target :: poles
    !! Learnable poles mu, expected shape [num_modes, 1]
    integer, intent(in) :: num_outputs, num_modes
    !! Number of output points (n_out) and number of modes (M)
    type(array_type), pointer :: c
    !! Decoded output tensor [num_outputs, batch]

    ! Local variables
    integer :: batch_size
    !! Number of columns (samples) in the spectral input
    integer :: mode, j
    !! Mode index and grid-point index
    real(real32) :: pole
    !! Current pole value, converted to real32
    real(real32), allocatable :: tau(:)
    !! Normalised output coordinates, length num_outputs
    real(real32), allocatable :: basis(:,:)
    !! Decoder basis D, shape [num_outputs, num_modes]

    batch_size = size(spectral%val, 2)

    ! The normalised coordinates do not depend on the mode, so build them once
    allocate(tau(num_outputs))
    if(num_outputs .le. 1)then
       ! Single (or empty) grid: avoid dividing by zero, the point sits at 0
       tau = 0.0_real32
    else
       tau(:) = [( real(j - 1, real32) / real(num_outputs - 1, real32), &
            j = 1, num_outputs )]
    end if

    ! One basis column per pole: D(:, mode) = exp(-mu_mode * tau)
    allocate(basis(num_outputs, num_modes))
    do mode = 1, num_modes
       pole = poles%val(mode, 1)
       basis(:, mode) = exp(-pole * tau)
    end do

    ! Forward pass: y = D @ x
    c => spectral%create_result(array_shape=[num_outputs, batch_size])
    c%val = matmul(basis, spectral%val)

    ! Stash n_out and M on the result for the backward pass
    allocate(c%indices(2))
    c%indices = [num_outputs, num_modes]

    ! Partial-derivative callbacks w.r.t. spectral (left) and poles (right)
    c%get_partial_left      => get_partial_lno_decode_spectral
    c%get_partial_right     => get_partial_lno_decode_poles
    c%get_partial_left_val  => get_partial_lno_decode_spectral_val
    c%get_partial_right_val => get_partial_lno_decode_poles_val

    ! Link into the graph only when at least one input is tracked
    if(spectral%requires_grad .or. poles%requires_grad)then
       c%requires_grad      = .true.
       c%is_forward         = spectral%is_forward .or. poles%is_forward
       c%operation          = 'lno_decode'
       c%left_operand       => spectral
       c%right_operand      => poles
       c%owns_left_operand  = spectral%is_temporary
       c%owns_right_operand = poles%is_temporary
    end if

    ! tau and basis are local allocatables, released automatically on return
  end function lno_decode