get_orthogonality_metric Function

public function get_orthogonality_metric(this) result(metric)

Compute max(|Phi^T @ Phi - I|) as a measure of basis orthogonality

Type Bound

orthogonal_nop_block_type

Arguments

Type | Intent | Optional | Attributes | Name
class(orthogonal_nop_block_type), intent(in) :: this

Layer instance providing basis parameters

Return Value real(kind=real32)

Maximum absolute deviation from orthogonality


Source Code

  function get_orthogonality_metric(this) result(metric)
    !! Orthonormalise the stored basis weights with Gram-Schmidt and return
    !! max(|Q^T @ Q - I|) of the resulting matrix Q as a measure of basis
    !! orthogonality.
    !!
    !! The basis weights are read from the first column of
    !! `this%params(2)%val` and reshaped (column-major) into a
    !! (num_inputs x num_basis) matrix. Any column whose residual norm does
    !! not exceed the tolerance is set to zero, so its diagonal entry
    !! contributes a deviation of exactly 1. With num_basis = 0 the result
    !! is 0.
    implicit none

    ! Arguments
    class(orthogonal_nop_block_type), intent(in) :: this
    !! Layer instance providing basis parameters
    real(real32) :: metric
    !! Maximum absolute deviation from orthogonality

    ! Local variables
    real(real32), parameter :: norm_tol = 1.0e-12_real32
    !! Columns whose residual norm does not exceed this are zeroed
    integer :: n, k, i, j
    !! Matrix dimensions and traversal indices
    real(real32), allocatable :: orthogonal_basis(:,:)
    !! Basis matrix, orthonormalised in place column by column
    real(real32) :: norm_val, projection, val
    !! Gram-Schmidt scalars and current absolute deviation entry

    n = this%num_inputs
    k = this%num_basis

    ! Load the raw weights straight into the working matrix; no separate
    ! copy of the unorthogonalised basis is needed.
    allocate(orthogonal_basis(n, k))
    orthogonal_basis = reshape(this%params(2)%val(:,1), [n, k])

    ! Gram-Schmidt: each projection uses the already-updated column j.
    do j = 1, k
       do i = 1, j - 1
          projection = dot_product(orthogonal_basis(:,i), orthogonal_basis(:,j))
          orthogonal_basis(:,j) = orthogonal_basis(:,j) - &
               projection * orthogonal_basis(:,i)
       end do
       norm_val = sqrt(dot_product( &
            orthogonal_basis(:,j), orthogonal_basis(:,j)))
       if (norm_val > norm_tol) then
          orthogonal_basis(:,j) = orthogonal_basis(:,j) / norm_val
       else
          orthogonal_basis(:,j) = 0.0_real32
       end if
    end do

    ! max(|Q^T Q - I|): Q^T Q is symmetric, so scanning i <= j covers every
    ! distinct entry and yields the same maximum as the full k x k scan.
    metric = 0.0_real32
    do j = 1, k
       do i = 1, j
          val = dot_product(orthogonal_basis(:,i), orthogonal_basis(:,j))
          if (i == j) then
             val = abs(val - 1.0_real32)
          else
             val = abs(val)
          end if
          if (val > metric) metric = val
       end do
    end do

    ! orthogonal_basis is an unsaved local allocatable and is deallocated
    ! automatically on return.

  end function get_orthogonality_metric