Skip to content

Commit

Permalink
Fix for svd API (#1190)
Browse files Browse the repository at this point in the history
Authors:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1190
  • Loading branch information
lowener authored Feb 4, 2023
1 parent a426bc9 commit bc764b8
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 63 deletions.
3 changes: 2 additions & 1 deletion cpp/include/raft/linalg/detail/svd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ void svdQR(raft::device_resources const& handle,
stream));

// Transpose the right singular vector back
if (trans_right) raft::linalg::transpose(right_sing_vecs, n_cols, stream);
if (trans_right && right_sing_vecs != nullptr)
raft::linalg::transpose(right_sing_vecs, n_cols, stream);

RAFT_CUDA_TRY(cudaGetLastError());

Expand Down
117 changes: 72 additions & 45 deletions cpp/include/raft/linalg/svd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -191,45 +191,42 @@ bool evaluateSVDByL2Norm(raft::device_resources const& handle,
* matrix using QR decomposition
* @tparam ValueType value type of parameters
* @tparam IndexType index type of parameters
* @tparam UType std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> @c
* U_in
* @tparam VType std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> @c
* V_in
* @param[in] handle raft::device_resources
* @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N)
* @param[out] sing_vals singular values raft::device_vector_view of shape (K)
* @param[out] U_in std::optional left singular values of raft::device_matrix_view with layout
* @param[out] U std::optional left singular values of raft::device_matrix_view with layout
* raft::col_major and dimensions (m, n)
* @param[out] V_in std::optional right singular values of raft::device_matrix_view with
* @param[out] V std::optional right singular values of raft::device_matrix_view with
* layout raft::col_major and dimensions (n, n)
*/
template <typename ValueType, typename IndexType, typename UType, typename VType>
void svd_qr(raft::device_resources const& handle,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> in,
raft::device_vector_view<ValueType, IndexType> sing_vals,
UType&& U_in,
VType&& V_in)
template <typename ValueType, typename IndexType>
void svd_qr(
raft::device_resources const& handle,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> in,
raft::device_vector_view<ValueType, IndexType> sing_vals,
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> U = std::nullopt,
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> V = std::nullopt)
{
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> U =
std::forward<UType>(U_in);
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> V =
std::forward<VType>(V_in);
ValueType* left_sing_vecs_ptr = nullptr;
ValueType* right_sing_vecs_ptr = nullptr;

if (U) {
RAFT_EXPECTS(in.extent(0) == U.value().extent(0) && in.extent(1) == U.value().extent(1),
"U should have dimensions m * n");
left_sing_vecs_ptr = U.value().data_handle();
}
if (V) {
RAFT_EXPECTS(in.extent(1) == V.value().extent(0) && in.extent(1) == V.value().extent(1),
"V should have dimensions n * n");
right_sing_vecs_ptr = V.value().data_handle();
}
svdQR(handle,
const_cast<ValueType*>(in.data_handle()),
in.extent(0),
in.extent(1),
sing_vals.data_handle(),
U.value().data_handle(),
V.value().data_handle(),
left_sing_vecs_ptr,
right_sing_vecs_ptr,
false,
U.has_value(),
V.has_value(),
Expand All @@ -243,57 +240,62 @@ void svd_qr(raft::device_resources const& handle,
*
* Please see above for documentation of `svd_qr`.
*/
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 3>>
void svd_qr(Args... args)
template <typename ValueType, typename IndexType, typename UType, typename VType>
void svd_qr(raft::device_resources const& handle,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> in,
raft::device_vector_view<ValueType, IndexType> sing_vals,
UType&& U_in = std::nullopt,
VType&& V_in = std::nullopt)
{
svd_qr(std::forward<Args>(args)..., std::nullopt, std::nullopt);
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> U =
std::forward<UType>(U_in);
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> V =
std::forward<VType>(V_in);

svd_qr(handle, in, sing_vals, U, V);
}

/**
* @brief singular value decomposition (SVD) on a column major
* matrix using QR decomposition. Right singular vector matrix is transposed before returning
* @tparam ValueType value type of parameters
* @tparam IndexType index type of parameters
* @tparam UType std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> @c
* U_in
* @tparam VType std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> @c
* V_in
* @param[in] handle raft::device_resources
* @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N)
* @param[out] sing_vals singular values raft::device_vector_view of shape (K)
* @param[out] U_in std::optional left singular values of raft::device_matrix_view with layout
* @param[out] U std::optional left singular values of raft::device_matrix_view with layout
* raft::col_major and dimensions (m, n)
* @param[out] V_in std::optional right singular values of raft::device_matrix_view with
* @param[out] V std::optional right singular values of raft::device_matrix_view with
* layout raft::col_major and dimensions (n, n)
*/
template <typename ValueType, typename IndexType, typename UType, typename VType>
template <typename ValueType, typename IndexType>
void svd_qr_transpose_right_vec(
raft::device_resources const& handle,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> in,
raft::device_vector_view<ValueType, IndexType> sing_vals,
UType&& U_in,
VType&& V_in)
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> U = std::nullopt,
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> V = std::nullopt)
{
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> U =
std::forward<UType>(U_in);
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> V =
std::forward<VType>(V_in);
ValueType* left_sing_vecs_ptr = nullptr;
ValueType* right_sing_vecs_ptr = nullptr;

if (U) {
RAFT_EXPECTS(in.extent(0) == U.value().extent(0) && in.extent(1) == U.value().extent(1),
"U should have dimensions m * n");
left_sing_vecs_ptr = U.value().data_handle();
}
if (V) {
RAFT_EXPECTS(in.extent(1) == V.value().extent(0) && in.extent(1) == V.value().extent(1),
"V should have dimensions n * n");
right_sing_vecs_ptr = V.value().data_handle();
}
svdQR(handle,
const_cast<ValueType*>(in.data_handle()),
in.extent(0),
in.extent(1),
sing_vals.data_handle(),
U.value().data_handle(),
V.value().data_handle(),
left_sing_vecs_ptr,
right_sing_vecs_ptr,
true,
U.has_value(),
V.has_value(),
Expand All @@ -307,10 +309,20 @@ void svd_qr_transpose_right_vec(
*
* Please see above for documentation of `svd_qr_transpose_right_vec`.
*/
template <typename... Args, typename = std::enable_if_t<sizeof...(Args) == 3>>
void svd_qr_transpose_right_vec(Args... args)
template <typename ValueType, typename IndexType, typename UType, typename VType>
void svd_qr_transpose_right_vec(
raft::device_resources const& handle,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> in,
raft::device_vector_view<ValueType, IndexType> sing_vals,
UType&& U_in = std::nullopt,
VType&& V_in = std::nullopt)
{
svd_qr_transpose_right_vec(std::forward<Args>(args)..., std::nullopt, std::nullopt);
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> U =
std::forward<UType>(U_in);
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> V =
std::forward<VType>(V_in);

svd_qr_transpose_right_vec(handle, in, sing_vals, U, V);
}

/**
Expand All @@ -320,7 +332,7 @@ void svd_qr_transpose_right_vec(Args... args)
* @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N)
* @param[out] S singular values raft::device_vector_view of shape (K)
* @param[out] V right singular values of raft::device_matrix_view with layout
* raft::col_major and dimensions (m, n)
* raft::col_major and dimensions (n, n)
* @param[out] U optional left singular values of raft::device_matrix_view with layout
* raft::col_major and dimensions (m, n)
*/
Expand All @@ -332,38 +344,52 @@ void svd_eig(
raft::device_matrix_view<ValueType, IndexType, raft::col_major> V,
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> U = std::nullopt)
{
ValueType* left_sing_vecs_ptr = nullptr;
if (U) {
RAFT_EXPECTS(in.extent(0) == U.value().extent(0) && in.extent(1) == U.value().extent(1),
"U should have dimensions m * n");
left_sing_vecs_ptr = U.value().data_handle();
}
RAFT_EXPECTS(in.extent(0) == V.extent(0) && in.extent(1) == V.extent(1),
RAFT_EXPECTS(in.extent(1) == V.extent(0) && in.extent(1) == V.extent(1),
"V should have dimensions n * n");
svdEig(handle,
const_cast<ValueType*>(in.data_handle()),
in.extent(0),
in.extent(1),
S.data_handle(),
U.value().data_handle(),
V.value().data_handle(),
left_sing_vecs_ptr,
V.data_handle(),
U.has_value(),
handle.get_stream());
}

template <typename ValueType, typename IndexType, typename UType>
void svd_eig(raft::device_resources const& handle,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> in,
raft::device_vector_view<ValueType, IndexType> S,
raft::device_matrix_view<ValueType, IndexType, raft::col_major> V,
UType&& U = std::nullopt)
{
std::optional<raft::device_matrix_view<ValueType, IndexType, raft::col_major>> U_optional =
std::forward<UType>(U);
svd_eig(handle, in, S, V, U_optional);
}

/**
* @brief reconstruct a matrix use left and right singular vectors and
* singular values
* @param[in] handle raft::device_resources
* @param[in] U left singular values of raft::device_matrix_view with layout
* raft::col_major and dimensions (m, k)
* @param[in] S singular values raft::device_vector_view of shape (k, k)
* @param[in] S square matrix with singular values on its diagonal of shape (k, k)
* @param[in] V right singular values of raft::device_matrix_view with layout
* raft::col_major and dimensions (k, n)
* @param[out] out output raft::device_matrix_view with layout raft::col_major of shape (m, n)
*/
template <typename ValueType, typename IndexType>
void svd_reconstruction(raft::device_resources const& handle,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> U,
raft::device_vector_view<const ValueType, IndexType> S,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> S,
raft::device_matrix_view<const ValueType, IndexType, raft::col_major> V,
raft::device_matrix_view<ValueType, IndexType, raft::col_major> out)
{
Expand All @@ -380,6 +406,7 @@ void svd_reconstruction(raft::device_resources const& handle,
const_cast<ValueType*>(U.data_handle()),
const_cast<ValueType*>(S.data_handle()),
const_cast<ValueType*>(V.data_handle()),
out.data_handle(),
out.extent(0),
out.extent(1),
S.extent(0),
Expand Down
64 changes: 47 additions & 17 deletions cpp/test/linalg/svd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "../test_utils.cuh"
#include <gtest/gtest.h>
#include <raft/linalg/init.cuh>
#include <raft/linalg/svd.cuh>
#include <raft/matrix/matrix.cuh>
#include <raft/util/cuda_utils.cuh>
Expand Down Expand Up @@ -56,6 +57,49 @@ class SvdTest : public ::testing::TestWithParam<SvdInputs<T>> {
}

protected:
void test_API()
{
auto data_view = raft::make_device_matrix_view<const T, int, raft::col_major>(
data.data(), params.n_row, params.n_col);
auto sing_vals_view = raft::make_device_vector_view<T, int>(sing_vals_qr.data(), params.n_col);
auto left_eig_vectors_view = raft::make_device_matrix_view<T, int, raft::col_major>(
left_eig_vectors_qr.data(), params.n_row, params.n_col);
auto right_eig_vectors_view = raft::make_device_matrix_view<T, int, raft::col_major>(
right_eig_vectors_trans_qr.data(), params.n_col, params.n_col);
raft::linalg::svd_eig(handle, data_view, sing_vals_view, right_eig_vectors_view, std::nullopt);
raft::linalg::svd_qr(handle, data_view, sing_vals_view);
raft::linalg::svd_qr(
handle, data_view, sing_vals_view, std::make_optional(left_eig_vectors_view));
raft::linalg::svd_qr(
handle, data_view, sing_vals_view, std::nullopt, std::make_optional(right_eig_vectors_view));
raft::linalg::svd_qr_transpose_right_vec(handle, data_view, sing_vals_view);
raft::linalg::svd_qr_transpose_right_vec(
handle, data_view, sing_vals_view, std::make_optional(left_eig_vectors_view));
raft::linalg::svd_qr_transpose_right_vec(
handle, data_view, sing_vals_view, std::nullopt, std::make_optional(right_eig_vectors_view));
}

void test_qr()
{
auto data_view = raft::make_device_matrix_view<const T, int, raft::col_major>(
data.data(), params.n_row, params.n_col);
auto sing_vals_qr_view =
raft::make_device_vector_view<T, int>(sing_vals_qr.data(), params.n_col);
auto left_eig_vectors_qr_view =
std::optional(raft::make_device_matrix_view<T, int, raft::col_major>(
left_eig_vectors_qr.data(), params.n_row, params.n_col));
auto right_eig_vectors_trans_qr_view =
std::make_optional(raft::make_device_matrix_view<T, int, raft::col_major>(
right_eig_vectors_trans_qr.data(), params.n_col, params.n_col));

svd_qr_transpose_right_vec(handle,
data_view,
sing_vals_qr_view,
left_eig_vectors_qr_view,
right_eig_vectors_trans_qr_view);
handle.sync_stream(stream);
}

void SetUp() override
{
int len = params.len;
Expand All @@ -78,23 +122,9 @@ class SvdTest : public ::testing::TestWithParam<SvdInputs<T>> {
raft::update_device(right_eig_vectors_ref.data(), right_eig_vectors_ref_h, right_evl, stream);
raft::update_device(sing_vals_ref.data(), sing_vals_ref_h, params.n_col, stream);

auto data_view = raft::make_device_matrix_view<const T, int, raft::col_major>(
data.data(), params.n_row, params.n_col);
auto sing_vals_qr_view =
raft::make_device_vector_view<T, int>(sing_vals_qr.data(), params.n_col);
std::optional<raft::device_matrix_view<T, int, raft::col_major>> left_eig_vectors_qr_view =
raft::make_device_matrix_view<T, int, raft::col_major>(
left_eig_vectors_qr.data(), params.n_row, params.n_col);
std::optional<raft::device_matrix_view<T, int, raft::col_major>>
right_eig_vectors_trans_qr_view = raft::make_device_matrix_view<T, int, raft::col_major>(
right_eig_vectors_trans_qr.data(), params.n_col, params.n_col);

svd_qr_transpose_right_vec(handle,
data_view,
sing_vals_qr_view,
left_eig_vectors_qr_view,
right_eig_vectors_trans_qr_view);
handle.sync_stream(stream);
test_API();
raft::update_device(data.data(), data_h, len, stream);
test_qr();
}

protected:
Expand Down

0 comments on commit bc764b8

Please sign in to comment.