Computes dL/dB for each sample by back-propagating through the modified Gram-Schmidt orthonormalization of the basis B.
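For the encoding $y = Q^\top u$, where $Q \in \mathbb{R}^{n \times k}$ is the Gram-Schmidt orthonormalization of the basis $B \in \mathbb{R}^{n \times k}$ and $g$ denotes `upstream_grad`:

- the gradient with respect to $Q$ from sample $s$ is the outer product $\dfrac{\partial L}{\partial Q^{(s)}} = u_{:,s}\, g_{:,s}^{\top} \in \mathbb{R}^{n \times k}$;
- the gradient with respect to $B$ from sample $s$ is obtained by running Gram-Schmidt backward on it, $\dfrac{\partial L}{\partial B^{(s)}} = \mathrm{gs\_backward}\!\left(B, \dfrac{\partial L}{\partial Q^{(s)}}\right) \in \mathbb{R}^{n \times k}$.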
| Type | Intent | Optional | Attributes | Name |
|---|---|---|---|---|
| `class(array_type)` | `intent(in)` | | | `this` |
| `real(kind=real32)` | `intent(in)` | | `dimension(:,:)` | `upstream_grad` |
| `real(kind=real32)` | `intent(out)` | | `dimension(:,:)` | `output` |
pure subroutine get_partial_ono_encode_basis_val( &
     this, upstream_grad, output)
  !! dL/dB per sample, back-propagated through modified Gram-Schmidt (MGS).
  !!
  !! The forward encoding is y = Q^T u, where Q (n x k) is the MGS
  !! orthonormalisation of the basis B (n x k). Q and R are recomputed
  !! here from B. Then, for every sample s:
  !!
  !!   dL/dQ_s = u(:,s) * upstream_grad(:,s)^T        -> [n, k]
  !!   dL/dB_s = MGS backward applied to dL/dQ_s       -> [n, k]
  !!
  !! dL/dB_s is stored flattened column-major in output(1:n*k, s).
  !! Columns whose MGS residual norm is <= norm_tol are set to zero in
  !! the forward recompute and therefore receive zero gradient.
  implicit none

  class(array_type), intent(in) :: this
    !! Operation node: indices(1:2) = [n, k]; right_operand%val(:,1)
    !! holds B flattened column-major; left_operand%val is read as
    !! u(1:n, s) for each sample s.
  real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! dL/dy, read as upstream_grad(1:k, s); one column per sample.
  real(real32), dimension(:,:), intent(out) :: output
    !! dL/dB per sample, written to output(1:n*k, s); all other entries
    !! are zeroed.

  real(real32), parameter :: norm_tol = 1.0e-12_real32
    !! Residual norm at or below which an MGS column is treated as
    !! degenerate. Shared by the forward recompute and the backward pass
    !! so the two can never disagree.

  integer :: n, k, num_samples, s, i, j
  real(real32), allocatable :: Q(:,:), R(:,:), dQ(:,:), dv(:), v_recon(:)
  real(real32) :: norm_j, dprod, dR_ij

  n = this%indices(1)
  k = this%indices(2)
  num_samples = size(upstream_grad, 2)

  !---------------------------------------------------------------------
  ! Forward recompute of B = Q R by modified Gram-Schmidt.
  ! Q starts as B and is orthogonalised / normalised column by column.
  ! sqrt(dot_product(...)) is kept (not norm2) to reproduce R exactly.
  !---------------------------------------------------------------------
  allocate(Q(n, k), R(k, k))
  Q = reshape(this%right_operand%val(:,1), [n, k])
  R = 0.0_real32
  do j = 1, k
     do i = 1, j - 1
        R(i,j) = dot_product(Q(:,i), Q(:,j))
        Q(:,j) = Q(:,j) - R(i,j) * Q(:,i)
     end do
     R(j,j) = sqrt(dot_product(Q(:,j), Q(:,j)))
     if (R(j,j) > norm_tol) then
        Q(:,j) = Q(:,j) / R(j,j)
     else
        Q(:,j) = 0.0_real32
     end if
  end do

  allocate(dQ(n, k), dv(n), v_recon(n))

  ! Zero up front: degenerate columns keep a zero gradient without any
  ! extra work inside the loops.
  output = 0.0_real32

  sample_loop: do s = 1, num_samples

     ! dL/dQ for this sample: outer product u(:,s) * upstream_grad(:,s)^T.
     ! dQ is then accumulated in place by the backward sweep below.
     do j = 1, k
        dQ(:,j) = this%left_operand%val(1:n, s) * upstream_grad(j, s)
     end do

     ! MGS backward, last column first: when column j is reached, dQ(:,j)
     ! already contains the contributions of every later column that was
     ! projected against q_j.
     column_loop: do j = k, 1, -1
        norm_j = R(j,j)
        ! q_j was forced to zero in the forward pass, so nothing flows back.
        if (norm_j <= norm_tol) cycle column_loop

        ! Backward through q_j = v / ||v||.
        dprod = dot_product(dQ(:,j), Q(:,j))
        dv = (dQ(:,j) - dprod * Q(:,j)) / norm_j

        ! v after all projections, recovered as ||v|| * q_j.
        v_recon = norm_j * Q(:,j)

        ! Backward through v <- v - R(i,j) q_i with R(i,j) = q_i . v,
        ! undoing the projections in reverse order.
        do i = j - 1, 1, -1
           ! v as it was before projection i was subtracted
           v_recon = v_recon + R(i,j) * Q(:,i)
           dR_ij = -dot_product(dv, Q(:,i))
           dQ(:,i) = dQ(:,i) - R(i,j) * dv
           dQ(:,i) = dQ(:,i) + dR_ij * v_recon
           dv = dv + dR_ij * Q(:,i)
        end do

        ! Column j of dL/dB, flattened column-major into output(:, s).
        output((j - 1) * n + 1 : j * n, s) = dv
     end do column_loop

  end do sample_loop

  ! Allocatable locals are deallocated automatically on return.
end subroutine get_partial_ono_encode_basis_val