get_partial_ono_decode_basis_val Subroutine

pure subroutine get_partial_ono_decode_basis_val(this, upstream_grad, output)

dL/dB per sample through Gram-Schmidt backward.

For decode y = Q @ x:

- dL/dQ from sample s: upstream(:,s) @ x(:,s)^T → [n, k]
- dL/dB from sample s: gs_backward(B, dL/dQ_s) → [n, k]

Arguments

Type | Intent | Optional | Attributes | Name
class(array_type), intent(in) :: this
real(kind=real32), intent(in), dimension(:,:) :: upstream_grad
real(kind=real32), intent(out), dimension(:,:) :: output

Source Code

  pure subroutine get_partial_ono_decode_basis_val( &
       this, upstream_grad, output)
    !! dL/dB per sample through Gram-Schmidt backward.
    !!
    !! For decode y = Q @ x:
    !!   dL/dQ from sample s: upstream(:,s) @ x(:,s)^T  → [n, k]
    !!   dL/dB from sample s: gs_backward(B, dL/dQ_s)   → [n, k]
    !!
    !! The basis B is read from the first column of
    !! `this%right_operand%val` and reshaped to [n, k], where
    !! n = this%indices(1) and k = this%indices(2). The coefficients x of
    !! sample s are read from `this%left_operand%val(1:k, s)`.
    !!
    !! Q and R are recomputed here from B with modified Gram-Schmidt. A
    !! column whose residual norm does not exceed `norm_tol` is treated as
    !! degenerate: its Q column is set to zero in the forward pass and its
    !! dL/dB column is set to zero in the backward pass.
    implicit none
    class(array_type), intent(in) :: this
    !! Node providing `indices`, `left_operand%val` (x) and
    !! `right_operand%val` (flattened B)
    real(real32), dimension(:,:), intent(in)  :: upstream_grad
    !! dL/dy; rows 1..n are read, one column per sample
    real(real32), dimension(:,:), intent(out) :: output
    !! Column s receives dL/dB of sample s flattened (column-major) to
    !! length n*k; the whole array is zeroed before the sample loop

    ! Norm threshold shared by the forward recomputation and the backward
    ! pass; both must use the same value so they agree on which columns
    ! are degenerate.
    real(real32), parameter :: norm_tol = 1.0e-12_real32

    integer :: n, k, s, i, j, num_samples
    real(real32), allocatable :: Q(:,:), R(:,:)
    real(real32), allocatable :: dQ_work(:,:), dB(:,:)
    real(real32), allocatable :: dv(:), v_recon(:)
    real(real32) :: norm_j, dprod, dR_ij

    n = this%indices(1)
    k = this%indices(2)
    num_samples = size(upstream_grad, 2)

    ! Recompute Q and R from B (modified Gram-Schmidt, performed in place
    ! on a copy of B held in Q)
    allocate(Q(n, k), R(k, k))
    Q = reshape(this%right_operand%val(:,1), [n, k])
    R = 0.0_real32
    do j = 1, k
       do i = 1, j - 1
          R(i,j) = dot_product(Q(:,i), Q(:,j))
          Q(:,j) = Q(:,j) - R(i,j) * Q(:,i)
       end do
       R(j,j) = sqrt(dot_product(Q(:,j), Q(:,j)))
       if(R(j,j) .gt. norm_tol)then
          Q(:,j) = Q(:,j) / R(j,j)
       else
          Q(:,j) = 0.0_real32
       end if
    end do

    allocate(dQ_work(n, k), dB(n, k))
    allocate(dv(n), v_recon(n))

    output = 0.0_real32

    do s = 1, num_samples
       ! dL/dQ for this sample: upstream(:,s) outer x(:,s)
       ! dQ[i_n, j_k] = upstream(i_n, s) * x(j_k, s)
       ! Written straight into the working copy, which also accumulates
       ! the gradient that later columns push back onto earlier q_i.
       do j = 1, k
          dQ_work(:, j) = upstream_grad(1:n, s) * this%left_operand%val(j, s)
       end do

       ! Gram-Schmidt backward: dQ -> dB
       ! Columns are visited last-to-first so that dQ_work(:,j) is complete
       ! (all contributions from columns > j added) before it is consumed.
       ! Every column of dB is assigned below, so no pre-zeroing is needed.
       do j = k, 1, -1
          norm_j = R(j, j)
          if(norm_j .le. norm_tol)then
             ! Degenerate column: forward pass produced q_j = 0
             dB(:,j) = 0.0_real32
             cycle
          end if

          ! Backward through normalization q = v / |v|
          dprod = dot_product(dQ_work(:,j), Q(:,j))
          dv = (dQ_work(:,j) - dprod * Q(:,j)) / norm_j

          ! Reconstruct v before normalization
          v_recon = norm_j * Q(:,j)

          ! Backward through projections (reverse order).
          ! Forward step i was: r = q_i . v_in ; v_out = v_in - r * q_i
          do i = j-1, 1, -1
             ! Undo step i to recover its input v_in
             v_recon = v_recon + R(i,j) * Q(:,i)
             dR_ij = -dot_product(dv, Q(:,i))
             ! q_i receives gradient from v_out (-r * dv) and from r
             dQ_work(:,i) = dQ_work(:,i) - R(i,j) * dv
             dQ_work(:,i) = dQ_work(:,i) + dR_ij * v_recon
             ! Gradient with respect to v_in
             dv = dv + dR_ij * Q(:,i)
          end do

          dB(:,j) = dv
       end do

       output(:, s) = reshape(dB, [n*k])
    end do

    ! Local allocatables are deallocated automatically on return.

  end subroutine get_partial_ono_decode_basis_val