Rebuild the dynamic Laplace encoder/decoder bases from the current learnable pole values (params(1)).
Called at the start of each forward pass so that the computation graph always uses the current pole values. The rebuilt bases are not tracked for gradients (`requires_grad = .false.`). Gradients for the residues `beta` flow through the diffstruc `*` operator, and gradients for the bypass weights `W` flow through `matmul`.
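$$E(\mu)_{n,j} = \exp(-\mu_n t_j), \qquad t_j = \frac{j-1}{n_\text{in}-1}$$

$$D(\mu)_{i,n} = \exp(-\mu_n \tau_i), \qquad \tau_i = \frac{i-1}{n_\text{out}-1}$$

Here $n$ indexes the modes (poles), $j$ the inputs and $i$ the outputs. When $n_\text{in} = 1$ the grid collapses to $t_1 = 0$, and likewise $\tau_1 = 0$ when $n_\text{out} = 1$.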
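A minimal standalone sketch of the two formulas follows. It assumes plain `real(real32)` arrays in place of the library's `array_type` and involves no autodiff. The module `lno_basis_sketch` and the functions `uniform_grid`, `encoder_basis` and `decoder_basis` are hypothetical names used only for illustration. As in the routine documented here, the grid falls back to a single point at 0 when there is only one input or output.

```fortran
module lno_basis_sketch
  ! Hypothetical sketch (not the athena/diffstruc API): builds the Laplace
  ! encoder/decoder bases as plain 2-D arrays from a vector of poles mu.
  use, intrinsic :: iso_fortran_env, only: real32
  implicit none
  private
  public :: encoder_basis, decoder_basis

contains

  pure function uniform_grid(n) result(t)
    !! Normalised coordinates t_j = (j-1)/(n-1) on [0, 1]; t = 0 when n = 1
    integer, intent(in) :: n
    real(real32) :: t(n)
    integer :: j
    if (n > 1) then
      t = [(real(j-1, real32) / real(n-1, real32), j = 1, n)]
    else
      t = 0.0_real32
    end if
  end function uniform_grid

  pure function encoder_basis(mu, n_in) result(E)
    !! E(n, j) = exp(-mu_n * t_j), shape [size(mu) x n_in]
    real(real32), intent(in) :: mu(:)
    integer, intent(in) :: n_in
    real(real32) :: E(size(mu), n_in)
    real(real32) :: t(n_in)
    integer :: j
    t = uniform_grid(n_in)
    do j = 1, n_in
      E(:, j) = exp(-mu * t(j))   ! one column per input point
    end do
  end function encoder_basis

  pure function decoder_basis(mu, n_out) result(D)
    !! D(i, n) = exp(-mu_n * tau_i), shape [n_out x size(mu)]
    real(real32), intent(in) :: mu(:)
    integer, intent(in) :: n_out
    real(real32) :: D(n_out, size(mu))
    real(real32) :: tau(n_out)
    integer :: n
    tau = uniform_grid(n_out)
    do n = 1, size(mu)
      D(:, n) = exp(-mu(n) * tau) ! one column per mode
    end do
  end function decoder_basis

end module lno_basis_sketch
```

For example, poles `mu = [0, 2]` with `n_in = 3` give the grid t = [0, 0.5, 1], so `encoder_basis` returns a first row of ones and a second row of [1, exp(-1), exp(-2)].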
| Type | Intent | Optional | Attributes | Name | Description |
|---|---|---|---|---|---|
| `class(dynamic_lno_layer_type)` | `intent(in)` | | | `this` | Layer instance providing pole values |
Return value `type(array_type), dimension(2)`: the encoder (`bases(1)`) and decoder (`bases(2)`) basis tensors rebuilt from the poles.
```fortran
function get_bases_dynamic_lno(this) result(bases)
  !! Rebuild the dynamic Laplace encoder/decoder bases from the current
  !! learnable pole values (params(1)).
  !!
  !! Called at the start of each forward pass so that the computation graph
  !! always uses up-to-date poles. The rebuilt bases are non-tracked
  !! (requires_grad = .false.); gradient flow for the residues beta goes
  !! through the diffstruc * operator, and gradient flow for the bypass
  !! weights W goes through matmul.
  !!
  !! E(mu)[n,j] = exp(-mu_n * t_j),    t_j   = (j-1)/(n_in-1)
  !! D(mu)[i,n] = exp(-mu_n * tau_i),  tau_i = (i-1)/(n_out-1)
  implicit none

  ! Arguments
  class(dynamic_lno_layer_type), intent(in) :: this
  !! Layer instance providing pole values
  type(array_type), dimension(2) :: bases
  !! Encoder and decoder basis tensors rebuilt from poles

  ! Local variables
  integer :: j, k, i, idx
  !! Basis-construction loop indices and flattened index
  real(real32) :: s, t
  !! Pole value and normalised coordinate

  !---------------------------------------------------------------------------
  ! Encoder E [num_modes x num_inputs]
  !---------------------------------------------------------------------------
  call bases(1)%allocate( [this%num_modes, this%num_inputs, 1] )
  bases(1)%is_sample_dependent = .false.
  bases(1)%requires_grad = .false.
  bases(1)%fix_pointer = .true.
  bases(1)%is_temporary = .false.
  do j = 1, this%num_inputs
     if(this%num_inputs .gt. 1)then
        t = real(j-1, real32) / real(this%num_inputs-1, real32)
     else
        t = 0.0_real32
     end if
     do k = 1, this%num_modes
        s = this%params(1)%val(k, 1)
        idx = k + (j-1) * this%num_modes
        bases(1)%val(idx, 1) = exp(-s * t)
     end do
  end do

  !---------------------------------------------------------------------------
  ! Decoder D [num_outputs x num_modes]
  !---------------------------------------------------------------------------
  call bases(2)%allocate( [this%num_outputs, this%num_modes, 1] )
  bases(2)%is_sample_dependent = .false.
  bases(2)%requires_grad = .false.
  bases(2)%fix_pointer = .true.
  bases(2)%is_temporary = .false.
  do k = 1, this%num_modes
     s = this%params(1)%val(k, 1)
     do i = 1, this%num_outputs
        if(this%num_outputs .gt. 1)then
           t = real(i-1, real32) / real(this%num_outputs-1, real32)
        else
           t = 0.0_real32
        end if
        idx = i + (k-1) * this%num_outputs
        bases(2)%val(idx, 1) = exp(-s * t)
     end do
  end do

end function get_bases_dynamic_lno
```
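The routine writes each basis into a flat buffer (`val(idx, 1)`), with the mode index running fastest for the encoder: `idx = k + (j-1)*num_modes`. The hypothetical sketch below assumes that this buffer is meant to be read in Fortran's column-major order. It uses the module `lno_flat_layout_sketch` and plain arrays instead of `array_type`. The sketch fills a 1-D array with the same index formula and recovers the [num_modes x num_inputs] matrix with the intrinsic `reshape`.

```fortran
module lno_flat_layout_sketch
  ! Hypothetical sketch: mirrors the flattened indexing of the encoder basis,
  ! where entry (k, j) is stored at idx = k + (j-1)*num_modes.
  ! Plain real32 arrays only; array_type and autodiff are not modelled.
  use, intrinsic :: iso_fortran_env, only: real32
  implicit none
  private
  public :: encoder_flat, as_matrix

contains

  pure function encoder_flat(mu, n_in) result(flat)
    !! Encoder basis written into a 1-D buffer using the same index formula
    real(real32), intent(in) :: mu(:)
    integer, intent(in) :: n_in
    real(real32) :: flat(size(mu) * n_in)
    real(real32) :: t
    integer :: j, k, idx
    do j = 1, n_in
      if (n_in > 1) then
        t = real(j-1, real32) / real(n_in-1, real32)
      else
        t = 0.0_real32
      end if
      do k = 1, size(mu)
        idx = k + (j-1) * size(mu)   ! mode index k varies fastest
        flat(idx) = exp(-mu(k) * t)
      end do
    end do
  end function encoder_flat

  pure function as_matrix(flat, n_rows, n_cols) result(m)
    !! Column-major reshape recovers the [n_rows x n_cols] matrix view
    real(real32), intent(in) :: flat(:)
    integer, intent(in) :: n_rows, n_cols
    real(real32) :: m(n_rows, n_cols)
    m = reshape(flat, [n_rows, n_cols])
  end function as_matrix

end module lno_flat_layout_sketch
```

Because `reshape` fills its result in column-major order, no explicit index arithmetic is needed to read the buffer back. The decoder's `idx = i + (k-1)*num_outputs` follows the same convention, with the output index running fastest.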