init_fixed_lno Subroutine

private subroutine init_fixed_lno(this, input_shape, verbose)

Initialise the learnable parameter storage, the fixed encoder/decoder bases, and the output buffers.

Type Bound

fixed_lno_layer_type

Arguments

Type Intent Optional Attributes Name
class(fixed_lno_layer_type), intent(inout) :: this

Layer instance to initialise

integer, intent(in), dimension(:) :: input_shape

Input shape used to infer num_inputs (applied only if the layer's input shape has not already been set)

integer, intent(in), optional :: verbose

Verbosity level


Source Code

  subroutine init_fixed_lno(this, input_shape, verbose)
    !! Initialise parameter storage, fixed bases and output buffers
    !!
    !! Steps performed:
    !!  1. set the input/output shapes and the parameter count
    !!     (input_shape is only applied if the layer has no input shape yet);
    !!  2. allocate the learnable parameters
    !!       params(1): R [num_modes x num_modes]
    !!       params(2): W [num_outputs x num_inputs]
    !!       params(3): b [num_outputs]  (only when use_bias is set)
    !!     and fill them via kernel_init / bias_init;
    !!  3. build the fixed, non-trainable (requires_grad = .false.) bases
    !!       encoder E(k,j) = exp(-k*pi * t_j),   t_j   = (j-1)/(n_in-1)
    !!       decoder D(i,k) = exp(-k*pi * tau_i), tau_i = (i-1)/(n_out-1)
    !!     (the normalised coordinate is 0 when the dimension has size 1);
    !!  4. reallocate the output buffer as 1x1 and release z(1) if allocated.
    implicit none

    ! Arguments
    class(fixed_lno_layer_type), intent(inout) :: this
    !! Layer instance to initialise
    integer, dimension(:), intent(in) :: input_shape
    !! Input shape used to infer num_inputs
    integer, optional, intent(in) :: verbose
    !! Verbosity level

    ! Local variables
    integer :: num_inputs, j, k, i, idx
    !! Effective fan-in size and basis-construction indices
    integer :: verbose_
    !! Effective verbosity level
    real(real32) :: s, t
    !! Spectral pole value and normalised coordinate

    ! Assign at run time rather than in the declaration: an initialiser on a
    ! local variable implies SAVE, so a value supplied on one call would
    ! otherwise leak into later calls that omit `verbose`.
    verbose_ = 0
    if(present(verbose)) verbose_ = verbose

    !---------------------------------------------------------------------------
    ! Set shapes
    !---------------------------------------------------------------------------
    if(.not.allocated(this%input_shape)) call this%set_shape(input_shape)
    this%num_inputs = this%input_shape(1)
    this%output_shape = [this%num_outputs]
    this%num_params = this%get_num_params()


    !---------------------------------------------------------------------------
    ! Allocate learnable parameters
    !
    ! params(1): R  spectral mixing weights [num_modes x num_modes]
    ! params(2): W  local bypass weights    [num_outputs x num_inputs]
    ! params(3): b  bias                    [num_outputs]  (optional)
    !---------------------------------------------------------------------------
    ! Release any previous shape table, matching the guards used below for
    ! the bases and the output buffer.
    if(allocated(this%weight_shape)) deallocate(this%weight_shape)
    allocate(this%weight_shape(2,2))
    this%weight_shape(:,1) = [ this%num_modes, this%num_modes ]
    this%weight_shape(:,2) = [ this%num_outputs, this%num_inputs ]

    ! NOTE(review): params is allocated unconditionally, so calling this
    ! routine on a layer whose params are already allocated will fail at run
    ! time - confirm whether re-initialisation is meant to be supported.
    if(this%use_bias)then
       this%bias_shape = [ this%num_outputs ]
       allocate(this%params(3))
    else
       allocate(this%params(2))
    end if

    ! R: spectral mixing weights
    call this%params(1)%allocate([this%num_modes, this%num_modes, 1])
    call this%params(1)%set_requires_grad(.true.)
    this%params(1)%fix_pointer = .true.
    this%params(1)%is_sample_dependent = .false.
    this%params(1)%is_temporary = .false.

    ! W: local bypass weights
    call this%params(2)%allocate([this%num_outputs, this%num_inputs, 1])
    call this%params(2)%set_requires_grad(.true.)
    this%params(2)%fix_pointer = .true.
    this%params(2)%is_sample_dependent = .false.
    this%params(2)%is_temporary = .false.

    ! Fan-in passed to the initialisers counts the bias as an extra input
    ! when a bias is present.
    num_inputs = this%num_inputs
    if(this%use_bias)then
       num_inputs = this%num_inputs + 1
       call this%params(3)%allocate([this%bias_shape, 1])
       call this%params(3)%set_requires_grad(.true.)
       this%params(3)%fix_pointer = .true.
       this%params(3)%is_sample_dependent = .false.
       this%params(3)%is_temporary = .false.
    end if


    !---------------------------------------------------------------------------
    ! Initialise learnable parameters
    !---------------------------------------------------------------------------
    call this%kernel_init%initialise( &
         this%params(1)%val(:,1), &
         fan_in = this%num_modes, fan_out = this%num_modes, &
         spacing = [ this%num_modes ] &
    )
    call this%kernel_init%initialise( &
         this%params(2)%val(:,1), &
         fan_in = num_inputs, fan_out = this%num_outputs, &
         spacing = [ this%num_outputs ] &
    )
    if(this%use_bias)then
       call this%bias_init%initialise( &
            this%params(3)%val(:,1), &
            fan_in = num_inputs, fan_out = this%num_outputs &
       )
    end if


    !---------------------------------------------------------------------------
    ! Build fixed encoder basis E [num_modes x num_inputs]
    !   E(k,j) = exp(-s_k * t_j)
    !   s_k = k * pi,  t_j = (j-1)/(n_in-1)
    !---------------------------------------------------------------------------
    if(this%encoder_basis%allocated) call this%encoder_basis%deallocate()
    call this%encoder_basis%allocate( &
         [this%num_modes, this%num_inputs, 1])
    this%encoder_basis%is_sample_dependent = .false.
    this%encoder_basis%requires_grad = .false.
    this%encoder_basis%fix_pointer = .true.
    this%encoder_basis%is_temporary = .false.

    ! Column-major flat index: k (mode) varies fastest.
    do j = 1, this%num_inputs
       if(this%num_inputs .gt. 1)then
          t = real(j-1, real32) / real(this%num_inputs-1, real32)
       else
          t = 0.0_real32
       end if
       do k = 1, this%num_modes
          s = real(k, real32) * pi
          idx = k + (j-1) * this%num_modes
          this%encoder_basis%val(idx, 1) = exp(-s * t)
       end do
    end do


    !---------------------------------------------------------------------------
    ! Build fixed decoder basis D [num_outputs x num_modes]
    !   D(i,k) = exp(-s_k * tau_i)
    !   s_k = k * pi,  tau_i = (i-1)/(n_out-1)
    !---------------------------------------------------------------------------
    if(this%decoder_basis%allocated) call this%decoder_basis%deallocate()
    call this%decoder_basis%allocate( &
         [this%num_outputs, this%num_modes, 1])
    this%decoder_basis%is_sample_dependent = .false.
    this%decoder_basis%requires_grad = .false.
    this%decoder_basis%fix_pointer = .true.
    this%decoder_basis%is_temporary = .false.

    ! Column-major flat index: i (output) varies fastest; t holds tau_i.
    do k = 1, this%num_modes
       s = real(k, real32) * pi
       do i = 1, this%num_outputs
          if(this%num_outputs .gt. 1)then
             t = real(i-1, real32) / real(this%num_outputs-1, real32)
          else
             t = 0.0_real32
          end if
          idx = i + (k-1) * this%num_outputs
          this%decoder_basis%val(idx, 1) = exp(-s * t)
       end do
    end do


    !---------------------------------------------------------------------------
    ! Allocate output arrays
    !---------------------------------------------------------------------------
    if(allocated(this%output)) deallocate(this%output)
    allocate(this%output(1,1))
    if(this%z(1)%allocated) call this%z(1)%deallocate()

  end subroutine init_fixed_lno