dL/dB per sample through Gram-Schmidt backward.
For decode `y = Q @ x`: the gradient dL/dQ from sample s is the outer product `upstream(:,s) @ x(:,s)^T`, of shape [n, k]; the gradient dL/dB from sample s is `gs_backward(B, dL/dQ_s)`, also of shape [n, k].
| Type | Intent | Optional | Attributes | Name |
|---|---|---|---|---|
| class(array_type) | intent(in) | | | this |
| real(kind=real32) | intent(in) | | dimension(:,:) | upstream_grad |
| real(kind=real32) | intent(out) | | dimension(:,:) | output |
pure subroutine get_partial_ono_decode_basis_val( &
     this, upstream_grad, output)
  !! dL/dB per sample through Gram-Schmidt backward.
  !!
  !! For decode y = Q @ x:
  !!   dL/dQ from sample s: upstream(:,s) @ x(:,s)^T  -> [n, k]
  !!   dL/dB from sample s: gs_backward(B, dL/dQ_s)   -> [n, k]
  !!
  !! Values read from `this`:
  !!   indices(1), indices(2)  : n and k
  !!   right_operand%val(:,1)  : B, reshaped (column-major) to [n, k]
  !!   left_operand%val(j, s)  : x, component j of sample s
  !!
  !! Q and R are recomputed here from B with modified Gram-Schmidt (each
  !! projection coefficient is taken against the partially reduced column).
  !! A column whose residual norm does not exceed `tol` is treated as
  !! degenerate: its Q column is zero and its dL/dB column is zero.
  implicit none

  ! Arguments
  class(array_type), intent(in) :: this
  !! Node supplying the sizes, the basis B and the coefficients x
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  !! Upstream gradient; rows 1..n of column s are used for sample s
  real(real32), dimension(:,:), intent(out) :: output
  !! Column s receives dL/dB of sample s flattened to length n*k, so
  !! size(output, 1) must equal n*k; any column beyond
  !! size(upstream_grad, 2) is left at zero

  ! Local variables
  real(real32), parameter :: tol = 1.0e-12_real32
  !! Norm threshold at or below which a column is treated as degenerate
  integer :: n, k, s, i, j, num_samples
  real(real32), allocatable :: q_mat(:,:), r_mat(:,:)
  real(real32), allocatable :: grad_q(:,:), grad_b(:,:)
  real(real32), allocatable :: dv(:), v_recon(:)
  real(real32) :: norm_j, dprod, dr_ij

  n = this%indices(1)
  k = this%indices(2)
  num_samples = size(upstream_grad, 2)

  ! Forward pass: recompute Q and R from B (Q starts as a copy of B and is
  ! orthonormalised in place, column by column)
  allocate(q_mat(n, k), r_mat(k, k))
  q_mat = reshape(this%right_operand%val(:,1), [n, k])
  r_mat = 0.0_real32
  do j = 1, k
     do i = 1, j - 1
        r_mat(i, j) = dot_product(q_mat(:, i), q_mat(:, j))
        q_mat(:, j) = q_mat(:, j) - r_mat(i, j) * q_mat(:, i)
     end do
     r_mat(j, j) = sqrt(dot_product(q_mat(:, j), q_mat(:, j)))
     if (r_mat(j, j) > tol) then
        q_mat(:, j) = q_mat(:, j) / r_mat(j, j)
     else
        q_mat(:, j) = 0.0_real32
     end if
  end do

  allocate(grad_q(n, k), grad_b(n, k))
  allocate(dv(n), v_recon(n))

  output = 0.0_real32
  do s = 1, num_samples

     ! dL/dQ for this sample: outer product of upstream(:,s) and x(:,s),
     ! i.e. grad_q(i, j) = upstream(i, s) * x(j, s).  grad_q is updated in
     ! place below as gradient flows from later columns into earlier ones.
     do j = 1, k
        grad_q(:, j) = upstream_grad(1:n, s) * this%left_operand%val(j, s)
     end do

     ! Gram-Schmidt backward: dL/dQ -> dL/dB, last column first because
     ! column j of Q depends on columns 1..j-1
     grad_b = 0.0_real32
     do j = k, 1, -1
        norm_j = r_mat(j, j)

        ! Degenerate column: its gradient stays at the zero set above
        if (norm_j <= tol) cycle

        ! Backward through normalisation q = v / |v|: remove the component
        ! of the gradient along q, then scale by 1/|v|
        dprod = dot_product(grad_q(:, j), q_mat(:, j))
        dv = (grad_q(:, j) - dprod * q_mat(:, j)) / norm_j

        ! Residual vector just before normalisation
        v_recon = norm_j * q_mat(:, j)

        ! Backward through the projections in reverse order.  After the
        ! first statement v_recon is the vector that projection i saw as
        ! input, which is what the gradient w.r.t. q_i needs.
        do i = j - 1, 1, -1
           v_recon = v_recon + r_mat(i, j) * q_mat(:, i)
           dr_ij = -dot_product(dv, q_mat(:, i))
           grad_q(:, i) = grad_q(:, i) - r_mat(i, j) * dv
           grad_q(:, i) = grad_q(:, i) + dr_ij * v_recon
           dv = dv + dr_ij * q_mat(:, i)
        end do

        grad_b(:, j) = dv
     end do

     output(:, s) = reshape(grad_b, [n*k])
  end do

  ! Allocatable locals are released automatically on return.
end subroutine get_partial_ono_decode_basis_val