get_partial_ono_decode_mixed_val Subroutine

pure subroutine get_partial_ono_decode_mixed_val(this, upstream_grad, output)

Computes the gradient dL/dx = Q^T @ upstream_grad; the result has shape [k, batch].

Arguments

Type | Intent | Optional | Attributes | Name
class(array_type), intent(in) :: this
real(kind=real32), intent(in), dimension(:,:) :: upstream_grad
real(kind=real32), intent(out), dimension(:,:) :: output

Source Code

  pure subroutine get_partial_ono_decode_mixed_val( &
       this, upstream_grad, output)
    !! dL/dx = Q^T @ upstream  [k, batch]
    !!
    !! Q (n x k) is rebuilt on every call from the first column of
    !! `this%right_operand%val`, reshaped to (n, k) and orthonormalised
    !! column by column with a Gram-Schmidt sweep in which each projection
    !! uses the already-updated column (modified Gram-Schmidt form).
    !! Columns whose residual norm does not exceed `norm_tol` are set to
    !! zero, so they contribute nothing to the product.
    implicit none
    class(array_type), intent(in) :: this
    !! Node providing `indices(1)` = n, `indices(2)` = k and the basis
    !! matrix via `right_operand%val(:,1)`
    real(real32), dimension(:,:), intent(in)  :: upstream_grad
    !! Upstream gradient; must have n rows to conform with Q^T in `matmul`
    real(real32), dimension(:,:), intent(out) :: output
    !! Receives Q^T @ upstream_grad, i.e. k rows by size(upstream_grad, 2)

    ! Columns with a norm at or below this value are treated as degenerate
    real(real32), parameter :: norm_tol = 1.0e-12_real32

    integer :: n, k, i, j
    real(real32), allocatable :: q(:,:)
    real(real32) :: norm_val, proj

    n = this%indices(1)
    k = this%indices(2)

    ! Load the basis straight into the working matrix; no separate copy of
    ! the un-orthogonalised input is needed afterwards
    allocate(q(n, k))
    q = reshape(this%right_operand%val(:,1), [n, k])

    ! Orthonormalise the columns in place
    do j = 1, k
       ! Remove the components along every previously finished column
       do i = 1, j - 1
          proj = dot_product(q(:,i), q(:,j))
          q(:,j) = q(:,j) - proj * q(:,i)
       end do
       norm_val = sqrt(dot_product(q(:,j), q(:,j)))
       if(norm_val > norm_tol)then
          q(:,j) = q(:,j) / norm_val
       else
          ! Linearly dependent (or zero) column: drop it instead of
          ! dividing by a vanishing norm
          q(:,j) = 0.0_real32
       end if
    end do

    ! Project the upstream gradient back onto the k basis directions
    output = matmul(transpose(q), upstream_grad)

    ! q is a local allocatable and is released automatically on return

  end subroutine get_partial_ono_decode_mixed_val