diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 6898e528a3..d6cf851aef 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -249,7 +249,7 @@ if(OPENMP_FOUND) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") endif(OPENMP_FOUND) -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr") if(${CMAKE_VERSION} VERSION_LESS "3.17.0") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++14") diff --git a/cpp/src/dbscan/adjgraph/algo.cuh b/cpp/src/dbscan/adjgraph/algo.cuh index f80a827d7c..24a9f3f720 100644 --- a/cpp/src/dbscan/adjgraph/algo.cuh +++ b/cpp/src/dbscan/adjgraph/algo.cuh @@ -24,7 +24,7 @@ #include "../common.cuh" #include "pack.h" -#include <sparse/csr.cuh> +#include <sparse/convert/csr.cuh> using namespace thrust; @@ -54,7 +54,7 @@ void launcher(const raft::handle_t &handle, Pack data, Index_ batchSize, int minPts = data.minPts; Index_ *vd = data.vd; - MLCommon::Sparse::csr_adj_graph_batched( + raft::sparse::convert::csr_adj_graph_batched( data.ex_scan, data.N, data.adjnnz, batchSize, data.adj, data.adj_graph, stream, [core_pts, minPts, vd] __device__(Index_ row, Index_ start_idx, diff --git a/cpp/src/dbscan/runner.cuh b/cpp/src/dbscan/runner.cuh index 1f0a5c7f8d..785b93897e 100644 --- a/cpp/src/dbscan/runner.cuh +++ b/cpp/src/dbscan/runner.cuh @@ -140,7 +140,7 @@ size_t run(const raft::handle_t& handle, Type_f* x, Index_ N, Index_ D, temp += exScanSize; // Running VertexDeg - MLCommon::Sparse::WeakCCState state(xa, fa, m); + raft::sparse::WeakCCState state(xa, fa, m); MLCommon::device_buffer<Index_> adj_graph(handle.get_device_allocator(), stream); @@ -190,7 +190,7 @@ size_t run(const raft::handle_t& handle, Type_f* x, Index_ N, Index_ D, CUML_LOG_DEBUG("--> Computing connected components"); start_time = raft::curTimeMillis(); - MLCommon::Sparse::weak_cc_batched( + raft::sparse::weak_cc_batched( labels, ex_scan, adj_graph.data(), curradjlen, N, startVertexId, nPoints, &state, stream, [core_pts, startVertexId, nPoints] __device__(Index_ global_id) { diff --git a/cpp/src/knn/knn_sparse.cu b/cpp/src/knn/knn_sparse.cu index 0a73f54211..bbaf6d9c1c 100644 --- a/cpp/src/knn/knn_sparse.cu +++ b/cpp/src/knn/knn_sparse.cu @@ -19,7 +19,7 @@ #include #include -#include +#include #include @@ -40,7 +40,7 @@ void brute_force_knn(raft::handle_t &handle, const int *idx_indptr, cusparseHandle_t cusparse_handle = handle.get_cusparse_handle(); cudaStream_t stream = handle.get_stream(); - MLCommon::Sparse::Selection::brute_force_knn( + raft::sparse::selection::brute_force_knn( idx_indptr, idx_indices, idx_data, idx_nnz, n_idx_rows, n_idx_cols, query_indptr, query_indices, query_data, query_nnz, n_query_rows, n_query_cols, output_indices, output_dists, k, cusparse_handle, d_alloc, diff --git a/cpp/src/spectral/spectral.cu b/cpp/src/spectral/spectral.cu index ea7f43a075..6db945698b 100644 --- a/cpp/src/spectral/spectral.cu +++ b/cpp/src/spectral/spectral.cu @@ -17,7 +17,7 @@ #include #include -#include +#include namespace ML { @@ -38,10 +38,8 @@ namespace Spectral { */ void fit_embedding(const raft::handle_t &handle, int *rows, int *cols, float *vals, int nnz, int n, int n_components, float *out) { - const auto &impl = handle; - MLCommon::Spectral::fit_embedding( - impl.get_cusparse_handle(), rows, cols, vals, nnz, n, n_components, out, - handle.get_device_allocator(), handle.get_stream()); + raft::sparse::spectral::fit_embedding(handle, rows, cols, vals, nnz, n, + n_components, out); } } // namespace Spectral } // namespace ML
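The spectral change above is the pattern repeated throughout this patch: instead of unpacking the handle into a cuSPARSE handle, allocator, and stream at every call site, the raft::handle_t is now passed through whole. A minimal caller sketch (the wrapper name is illustrative, not from this patch; `out` is assumed preallocated to n * n_components):

// Hypothetical caller of the new cuML wrapper shown above.
void embed_affinity_graph(const raft::handle_t &handle, int *rows, int *cols,
                          float *vals, int nnz, int n, int n_components,
                          float *out) {
  // The handle supplies the cuSPARSE handle, device allocator, and stream.
  ML::Spectral::fit_embedding(handle, rows, cols, vals, nnz, n, n_components,
                              out);
}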
diff --git a/cpp/src/tsne/distances.cuh b/cpp/src/tsne/distances.cuh index 7dc31e20dc..4baa50aaea 100644 --- a/cpp/src/tsne/distances.cuh +++ b/cpp/src/tsne/distances.cuh @@ -21,7 +21,8 @@ #include #include #include -#include +#include +#include #include @@ -82,7 +83,7 @@ template <> void get_distances(const raft::handle_t &handle, manifold_sparse_inputs_t &input, knn_graph &k_graph, cudaStream_t stream) { - MLCommon::Sparse::Selection::brute_force_knn( + raft::sparse::selection::brute_force_knn( input.indptr, input.indices, input.data, input.nnz, input.n, input.d, input.indptr, input.indices, input.data, input.nnz, input.n, input.d, k_graph.knn_indices, k_graph.knn_dists, k_graph.n_neighbors, @@ -135,17 +136,16 @@ void normalize_distances(const value_idx n, value_t *distances, * @param[in] handle: The GPU handle. */ template -void symmetrize_perplexity( - float *P, value_idx *indices, const value_idx n, const int k, - const value_t exaggeration, - MLCommon::Sparse::COO<value_t, value_idx> *COO_Matrix, cudaStream_t stream, - const raft::handle_t &handle) { +void symmetrize_perplexity(float *P, value_idx *indices, const value_idx n, + const int k, const value_t exaggeration, + raft::sparse::COO<value_t, value_idx> *COO_Matrix, + cudaStream_t stream, const raft::handle_t &handle) { // Perform (P + P.T) / P_sum * early_exaggeration const value_t div = exaggeration / (2.0f * n); raft::linalg::scalarMultiply(P, P, div, n * k, stream); // Symmetrize to form P + P.T - MLCommon::Sparse::from_knn_symmetrize_matrix( + raft::sparse::linalg::from_knn_symmetrize_matrix( indices, P, n, k, COO_Matrix, stream, handle.get_device_allocator()); } diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh index ddcb66e5fb..cc8eef9daf 100644 --- a/cpp/src/tsne/tsne_runner.cuh +++ b/cpp/src/tsne/tsne_runner.cuh @@ -203,7 +203,7 @@ class TSNE_runner { const bool initialize_embeddings; bool barnes_hut; - MLCommon::Sparse::COO<value_t, value_idx> COO_Matrix; + raft::sparse::COO<value_t, value_idx> COO_Matrix; value_idx n, p; value_t *Y; }; diff --git a/cpp/src/umap/fuzzy_simpl_set/naive.cuh b/cpp/src/umap/fuzzy_simpl_set/naive.cuh index d40f127524..08d9b27b75 100644 --- a/cpp/src/umap/fuzzy_simpl_set/naive.cuh +++ b/cpp/src/umap/fuzzy_simpl_set/naive.cuh @@ -23,8 +23,10 @@ #include #include +#include #include #include +#include #include @@ -276,8 +278,8 @@ void smooth_knn_dist(int n, const value_idx *knn_indices, * @param stream cuda stream to use for device operations */ template -void launcher(int n, const value_idx *knn_indices, const float *knn_dists, - int n_neighbors, MLCommon::Sparse::COO<value_t> *out, +void launcher(int n, const value_idx *knn_indices, const value_t *knn_dists, + int n_neighbors, raft::sparse::COO<value_t> *out, UMAPParams *params, std::shared_ptr d_alloc, cudaStream_t stream) { /** @@ -292,7 +294,7 @@ void launcher(int n, const value_idx *knn_indices, const float *knn_dists, n, knn_indices, knn_dists, rhos.data(), sigmas.data(), params, n_neighbors, params->local_connectivity, d_alloc, stream); - MLCommon::Sparse::COO<value_t> in(d_alloc, stream, n * n_neighbors, n, n); + raft::sparse::COO<value_t> in(d_alloc, stream, n * n_neighbors, n, n); // check for logging in order to avoid the potentially costly `arr2Str` call! if (ML::Logger::get().shouldLogFor(CUML_LEVEL_DEBUG)) { @@ -329,7 +331,7 @@ void launcher(int n, const value_idx *knn_indices, const float *knn_dists, * one via a fuzzy union. (Symmetrize knn graph).
*/ float set_op_mix_ratio = params->set_op_mix_ratio; - MLCommon::Sparse::coo_symmetrize( + raft::sparse::linalg::coo_symmetrize( &in, out, [set_op_mix_ratio] __device__(int row, int col, value_t result, value_t transpose) { @@ -340,7 +342,7 @@ void launcher(int n, const value_idx *knn_indices, const float *knn_dists, }, d_alloc, stream); - MLCommon::Sparse::coo_sort(out, d_alloc, stream); + raft::sparse::op::coo_sort(out, d_alloc, stream); } } // namespace Naive } // namespace FuzzySimplSet diff --git a/cpp/src/umap/fuzzy_simpl_set/runner.cuh b/cpp/src/umap/fuzzy_simpl_set/runner.cuh index aba1bbf883..6836865bac 100644 --- a/cpp/src/umap/fuzzy_simpl_set/runner.cuh +++ b/cpp/src/umap/fuzzy_simpl_set/runner.cuh @@ -41,7 +41,7 @@ using namespace ML; */ template void run(int n, const value_idx *knn_indices, const T *knn_dists, - int n_neighbors, MLCommon::Sparse::COO *coo, UMAPParams *params, + int n_neighbors, raft::sparse::COO *coo, UMAPParams *params, std::shared_ptr alloc, cudaStream_t stream, int algorithm = 0) { switch (algorithm) { diff --git a/cpp/src/umap/init_embed/runner.cuh b/cpp/src/umap/init_embed/runner.cuh index 5045c8c8af..c3a4dbdaa5 100644 --- a/cpp/src/umap/init_embed/runner.cuh +++ b/cpp/src/umap/init_embed/runner.cuh @@ -32,7 +32,7 @@ using namespace ML; template void run(const raft::handle_t &handle, int n, int d, const value_idx *knn_indices, const T *knn_dists, - MLCommon::Sparse::COO *coo, UMAPParams *params, T *embedding, + raft::sparse::COO *coo, UMAPParams *params, T *embedding, cudaStream_t stream, int algo = 0) { switch (algo) { /** diff --git a/cpp/src/umap/init_embed/spectral_algo.cuh b/cpp/src/umap/init_embed/spectral_algo.cuh index 5f6175ee95..3ec13fdd2c 100644 --- a/cpp/src/umap/init_embed/spectral_algo.cuh +++ b/cpp/src/umap/init_embed/spectral_algo.cuh @@ -43,8 +43,7 @@ using namespace ML; template void launcher(const raft::handle_t &handle, int n, int d, const value_idx *knn_indices, const T *knn_dists, - MLCommon::Sparse::COO *coo, UMAPParams *params, - T *embedding) { + raft::sparse::COO *coo, UMAPParams *params, T *embedding) { cudaStream_t stream = handle.get_stream(); ASSERT(n > params->n_components, diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index 1fd7c15356..d940849a9a 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include @@ -85,7 +85,7 @@ void launcher(const raft::handle_t &handle, const ML::UMAPParams *params, std::shared_ptr d_alloc, cudaStream_t stream) { - MLCommon::Sparse::Selection::brute_force_knn( + raft::sparse::selection::brute_force_knn( inputsA.indptr, inputsA.indices, inputsA.data, inputsA.nnz, inputsA.n, inputsA.d, inputsB.indptr, inputsB.indices, inputsB.data, inputsB.nnz, inputsB.n, inputsB.d, out.knn_indices, out.knn_dists, n_neighbors, diff --git a/cpp/src/umap/runner.cuh b/cpp/src/umap/runner.cuh index 6c48521e9d..9a26c6e266 100644 --- a/cpp/src/umap/runner.cuh +++ b/cpp/src/umap/runner.cuh @@ -36,8 +36,11 @@ #include #include +#include +#include #include -#include +#include +#include #include @@ -51,7 +54,6 @@ namespace FuzzySimplSetImpl = FuzzySimplSet::Naive; namespace SimplSetEmbedImpl = SimplSetEmbed::Algo; using namespace ML; -using namespace MLCommon::Sparse; template __global__ void init_transform(int *indices, T *weights, int n, @@ -126,7 +128,7 @@ void _fit(const raft::handle_t &handle, const umap_inputs &inputs, CUML_LOG_DEBUG("Done. 
Calling fuzzy simplicial set"); ML::PUSH_RANGE("umap::simplicial_set"); - COO rgraph_coo(d_alloc, stream); + raft::sparse::COO rgraph_coo(d_alloc, stream); FuzzySimplSet::run( inputs.n, knn_graph.knn_indices, knn_graph.knn_dists, k, &rgraph_coo, params, d_alloc, stream); @@ -135,8 +137,8 @@ void _fit(const raft::handle_t &handle, const umap_inputs &inputs, /** * Remove zeros from simplicial set */ - COO cgraph_coo(d_alloc, stream); - MLCommon::Sparse::coo_remove_zeros(&rgraph_coo, &cgraph_coo, + raft::sparse::COO cgraph_coo(d_alloc, stream); + raft::sparse::op::coo_remove_zeros(&rgraph_coo, &cgraph_coo, d_alloc, stream); ML::POP_RANGE(); @@ -209,8 +211,8 @@ void _fit_supervised(const raft::handle_t &handle, const umap_inputs &inputs, * Allocate workspace for fuzzy simplicial set. */ ML::PUSH_RANGE("umap::simplicial_set"); - COO rgraph_coo(d_alloc, stream); - COO tmp_coo(d_alloc, stream); + raft::sparse::COO rgraph_coo(d_alloc, stream); + raft::sparse::COO tmp_coo(d_alloc, stream); /** * Run Fuzzy simplicial set @@ -221,10 +223,10 @@ void _fit_supervised(const raft::handle_t &handle, const umap_inputs &inputs, &tmp_coo, params, d_alloc, stream); CUDA_CHECK(cudaPeekAtLastError()); - MLCommon::Sparse::coo_remove_zeros(&tmp_coo, &rgraph_coo, + raft::sparse::op::coo_remove_zeros(&tmp_coo, &rgraph_coo, d_alloc, stream); - COO final_coo(d_alloc, stream); + raft::sparse::COO final_coo(d_alloc, stream); /** * If target metric is 'categorical', perform @@ -247,10 +249,10 @@ void _fit_supervised(const raft::handle_t &handle, const umap_inputs &inputs, /** * Remove zeros */ - MLCommon::Sparse::coo_sort(&final_coo, d_alloc, stream); + raft::sparse::op::coo_sort(&final_coo, d_alloc, stream); - COO ocoo(d_alloc, stream); - MLCommon::Sparse::coo_remove_zeros(&final_coo, &ocoo, d_alloc, + raft::sparse::COO ocoo(d_alloc, stream); + raft::sparse::op::coo_remove_zeros(&final_coo, &ocoo, d_alloc, stream); ML::POP_RANGE(); @@ -366,7 +368,8 @@ void _transform(const raft::handle_t &handle, const umap_inputs &inputs, * Allocate workspace for fuzzy simplicial set. 
*/ - COO graph_coo(d_alloc, stream, nnz, inputs.n, inputs.n); + raft::sparse::COO graph_coo(d_alloc, stream, nnz, inputs.n, + inputs.n); FuzzySimplSetImpl::compute_membership_strength_kernel <<>>(knn_graph.knn_indices, knn_graph.knn_dists, @@ -378,9 +381,9 @@ void _transform(const raft::handle_t &handle, const umap_inputs &inputs, MLCommon::device_buffer row_ind(d_alloc, stream, inputs.n); MLCommon::device_buffer ia(d_alloc, stream, inputs.n); - MLCommon::Sparse::sorted_coo_to_csr(&graph_coo, row_ind.data(), d_alloc, - stream); - MLCommon::Sparse::coo_row_count(&graph_coo, ia.data(), stream); + raft::sparse::convert::sorted_coo_to_csr(&graph_coo, row_ind.data(), d_alloc, + stream); + raft::sparse::linalg::coo_degree(&graph_coo, ia.data(), stream); MLCommon::device_buffer vals_normed(d_alloc, stream, graph_coo.nnz); CUDA_CHECK(cudaMemsetAsync(vals_normed.data(), 0, @@ -388,7 +391,7 @@ void _transform(const raft::handle_t &handle, const umap_inputs &inputs, CUML_LOG_DEBUG("Performing L1 normalization"); - MLCommon::Sparse::csr_row_normalize_l1( + raft::sparse::linalg::csr_row_normalize_l1( row_ind.data(), graph_coo.vals(), graph_coo.nnz, graph_coo.n_rows, vals_normed.data(), stream); @@ -402,7 +405,7 @@ void _transform(const raft::handle_t &handle, const umap_inputs &inputs, CUDA_CHECK(cudaPeekAtLastError()); /** - * Go through COO values and set everything that's less than + * Go through raft::sparse::COO values and set everything that's less than * vals.max() / params->n_epochs to 0.0 */ thrust::device_ptr d_ptr = @@ -437,8 +440,8 @@ void _transform(const raft::handle_t &handle, const umap_inputs &inputs, /** * Remove zeros */ - MLCommon::Sparse::COO comp_coo(d_alloc, stream); - MLCommon::Sparse::coo_remove_zeros(&graph_coo, &comp_coo, + raft::sparse::COO comp_coo(d_alloc, stream); + raft::sparse::op::coo_remove_zeros(&graph_coo, &comp_coo, d_alloc, stream); ML::PUSH_RANGE("umap::optimization"); diff --git a/cpp/src/umap/simpl_set_embed/algo.cuh b/cpp/src/umap/simpl_set_embed/algo.cuh index 8543a5e634..81b242cdd5 100644 --- a/cpp/src/umap/simpl_set_embed/algo.cuh +++ b/cpp/src/umap/simpl_set_embed/algo.cuh @@ -31,6 +31,8 @@ #include #include "optimize_batch_kernel.cuh" +#include + #pragma once namespace UMAPAlgo { @@ -194,7 +196,7 @@ void optimize_layout(T *head_embedding, int head_n, T *tail_embedding, * and their 1-skeletons. 
*/ template -void launcher(int m, int n, MLCommon::Sparse::COO *in, UMAPParams *params, +void launcher(int m, int n, raft::sparse::COO *in, UMAPParams *params, T *embedding, std::shared_ptr d_alloc, cudaStream_t stream) { int nnz = in->nnz; @@ -228,8 +230,8 @@ void launcher(int m, int n, MLCommon::Sparse::COO *in, UMAPParams *params, }, stream); - MLCommon::Sparse::COO out(d_alloc, stream); - MLCommon::Sparse::coo_remove_zeros(in, &out, d_alloc, stream); + raft::sparse::COO out(d_alloc, stream); + raft::sparse::op::coo_remove_zeros(in, &out, d_alloc, stream); MLCommon::device_buffer epochs_per_sample(d_alloc, stream, out.nnz); CUDA_CHECK( diff --git a/cpp/src/umap/simpl_set_embed/runner.cuh b/cpp/src/umap/simpl_set_embed/runner.cuh index c8b95b0842..59c3b4c812 100644 --- a/cpp/src/umap/simpl_set_embed/runner.cuh +++ b/cpp/src/umap/simpl_set_embed/runner.cuh @@ -28,7 +28,7 @@ namespace SimplSetEmbed { using namespace ML; template -void run(int m, int n, MLCommon::Sparse::COO *coo, UMAPParams *params, +void run(int m, int n, raft::sparse::COO *coo, UMAPParams *params, T *embedding, std::shared_ptr alloc, cudaStream_t stream, int algorithm = 0) { switch (algorithm) { diff --git a/cpp/src/umap/supervised.cuh b/cpp/src/umap/supervised.cuh index 5fe50c8cf5..b2d9a5414b 100644 --- a/cpp/src/umap/supervised.cuh +++ b/cpp/src/umap/supervised.cuh @@ -34,8 +34,12 @@ #include #include +#include #include -#include +#include +#include +#include +#include #include @@ -47,8 +51,6 @@ namespace Supervised { using namespace ML; -using namespace MLCommon::Sparse; - template __global__ void fast_intersection_kernel(int *rows, int *cols, T *vals, int nnz, T *target, float unknown_dist = 1.0, @@ -65,21 +67,23 @@ __global__ void fast_intersection_kernel(int *rows, int *cols, T *vals, int nnz, } template -void reset_local_connectivity(COO *in_coo, COO *out_coo, +void reset_local_connectivity(raft::sparse::COO *in_coo, + raft::sparse::COO *out_coo, std::shared_ptr d_alloc, cudaStream_t stream // size = nnz*2 ) { MLCommon::device_buffer row_ind(d_alloc, stream, in_coo->n_rows); - MLCommon::Sparse::sorted_coo_to_csr(in_coo, row_ind.data(), d_alloc, stream); + raft::sparse::convert::sorted_coo_to_csr(in_coo, row_ind.data(), d_alloc, + stream); // Perform l_inf normalization - MLCommon::Sparse::csr_row_normalize_max( + raft::sparse::linalg::csr_row_normalize_max( row_ind.data(), in_coo->vals(), in_coo->nnz, in_coo->n_rows, in_coo->vals(), stream); CUDA_CHECK(cudaPeekAtLastError()); - MLCommon::Sparse::coo_symmetrize( + raft::sparse::linalg::coo_symmetrize( in_coo, out_coo, [] __device__(int row, int col, T result, T transpose) { T prod_matrix = result * transpose; @@ -98,11 +102,9 @@ void reset_local_connectivity(COO *in_coo, COO *out_coo, * data. 
*/ template -void categorical_simplicial_set_intersection(COO *graph_coo, - value_t *target, - cudaStream_t stream, - float far_dist = 5.0, - float unknown_dist = 1.0) { +void categorical_simplicial_set_intersection( + raft::sparse::COO *graph_coo, value_t *target, cudaStream_t stream, + float far_dist = 5.0, float unknown_dist = 1.0) { dim3 grid(raft::ceildiv(graph_coo->nnz, TPB_X), 1, 1); dim3 blk(TPB_X, 1, 1); fast_intersection_kernel<<>>( @@ -120,13 +122,13 @@ __global__ void sset_intersection_kernel( if (row < m) { int start_idx_res = result_ind[row]; - int stop_idx_res = MLCommon::Sparse::get_stop_idx(row, m, nnz, result_ind); + int stop_idx_res = raft::sparse::get_stop_idx(row, m, nnz, result_ind); int start_idx1 = row_ind1[row]; - int stop_idx1 = MLCommon::Sparse::get_stop_idx(row, m, nnz1, row_ind1); + int stop_idx1 = raft::sparse::get_stop_idx(row, m, nnz1, row_ind1); int start_idx2 = row_ind2[row]; - int stop_idx2 = MLCommon::Sparse::get_stop_idx(row, m, nnz2, row_ind2); + int stop_idx2 = raft::sparse::get_stop_idx(row, m, nnz2, row_ind2); for (int j = start_idx_res; j < stop_idx_res; j++) { int col = result_cols[j]; @@ -164,13 +166,14 @@ __global__ void sset_intersection_kernel( */ template void general_simplicial_set_intersection( - int *row1_ind, COO *in1, int *row2_ind, COO *in2, COO *result, - float weight, std::shared_ptr d_alloc, cudaStream_t stream) { + int *row1_ind, raft::sparse::COO *in1, int *row2_ind, + raft::sparse::COO *in2, raft::sparse::COO *result, float weight, + std::shared_ptr d_alloc, cudaStream_t stream) { MLCommon::device_buffer result_ind(d_alloc, stream, in1->n_rows); CUDA_CHECK( cudaMemsetAsync(result_ind.data(), 0, in1->n_rows * sizeof(int), stream)); - int result_nnz = MLCommon::Sparse::csr_add_calc_inds( + int result_nnz = raft::sparse::linalg::csr_add_calc_inds( row1_ind, in1->cols(), in1->vals(), in1->nnz, row2_ind, in2->cols(), in2->vals(), in2->nnz, in1->n_rows, result_ind.data(), d_alloc, stream); @@ -179,14 +182,14 @@ void general_simplicial_set_intersection( /** * Element-wise sum of two simplicial sets */ - MLCommon::Sparse::csr_add_finalize( + raft::sparse::linalg::csr_add_finalize( row1_ind, in1->cols(), in1->vals(), in1->nnz, row2_ind, in2->cols(), in2->vals(), in2->nnz, in1->n_rows, result_ind.data(), result->cols(), result->vals(), stream); //@todo: Write a wrapper function for this - MLCommon::Sparse::csr_to_coo(result_ind.data(), result->n_rows, - result->rows(), result->nnz, stream); + raft::sparse::convert::csr_to_coo( + result_ind.data(), result->n_rows, result->rows(), result->nnz, stream); thrust::device_ptr d_ptr1 = thrust::device_pointer_cast(in1->vals()); T min1 = *(thrust::min_element(thrust::cuda::par.on(stream), d_ptr1, @@ -212,8 +215,9 @@ void general_simplicial_set_intersection( } template -void perform_categorical_intersection(T *y, COO *rgraph_coo, - COO *final_coo, UMAPParams *params, +void perform_categorical_intersection(T *y, raft::sparse::COO *rgraph_coo, + raft::sparse::COO *final_coo, + UMAPParams *params, std::shared_ptr d_alloc, cudaStream_t stream) { float far_dist = 1.0e12; // target weight @@ -223,8 +227,9 @@ void perform_categorical_intersection(T *y, COO *rgraph_coo, categorical_simplicial_set_intersection(rgraph_coo, y, stream, far_dist); - COO comp_coo(d_alloc, stream); - coo_remove_zeros(rgraph_coo, &comp_coo, d_alloc, stream); + raft::sparse::COO comp_coo(d_alloc, stream); + raft::sparse::op::coo_remove_zeros(rgraph_coo, &comp_coo, d_alloc, + stream); reset_local_connectivity(&comp_coo, final_coo, 
d_alloc, stream); @@ -233,9 +238,9 @@ void perform_categorical_intersection(T *y, COO *rgraph_coo, template void perform_general_intersection(const raft::handle_t &handle, value_t *y, - COO *rgraph_coo, - COO *final_coo, UMAPParams *params, - cudaStream_t stream) { + raft::sparse::COO *rgraph_coo, + raft::sparse::COO *final_coo, + UMAPParams *params, cudaStream_t stream) { auto d_alloc = handle.get_device_allocator(); /** @@ -272,7 +277,7 @@ void perform_general_intersection(const raft::handle_t &handle, value_t *y, /** * Compute fuzzy simplicial set */ - COO ygraph_coo(d_alloc, stream); + raft::sparse::COO ygraph_coo(d_alloc, stream); FuzzySimplSet::run( rgraph_coo->n_rows, y_knn_indices.data(), y_knn_dists.data(), @@ -297,15 +302,16 @@ void perform_general_intersection(const raft::handle_t &handle, value_t *y, CUDA_CHECK(cudaMemsetAsync(yrow_ind.data(), 0, ygraph_coo.n_rows * sizeof(int), stream)); - COO cygraph_coo(d_alloc, stream); - coo_remove_zeros(&ygraph_coo, &cygraph_coo, d_alloc, stream); + raft::sparse::COO cygraph_coo(d_alloc, stream); + raft::sparse::op::coo_remove_zeros(&ygraph_coo, &cygraph_coo, + d_alloc, stream); - MLCommon::Sparse::sorted_coo_to_csr(&cygraph_coo, yrow_ind.data(), d_alloc, - stream); - MLCommon::Sparse::sorted_coo_to_csr(rgraph_coo, xrow_ind.data(), d_alloc, - stream); + raft::sparse::convert::sorted_coo_to_csr(&cygraph_coo, yrow_ind.data(), + d_alloc, stream); + raft::sparse::convert::sorted_coo_to_csr(rgraph_coo, xrow_ind.data(), d_alloc, + stream); - COO result_coo(d_alloc, stream); + raft::sparse::COO result_coo(d_alloc, stream); general_simplicial_set_intersection( xrow_ind.data(), rgraph_coo, yrow_ind.data(), &cygraph_coo, &result_coo, params->target_weights, d_alloc, stream); @@ -313,8 +319,9 @@ void perform_general_intersection(const raft::handle_t &handle, value_t *y, /** * Remove zeros */ - COO out(d_alloc, stream); - coo_remove_zeros(&result_coo, &out, d_alloc, stream); + raft::sparse::COO out(d_alloc, stream); + raft::sparse::op::coo_remove_zeros(&result_coo, &out, d_alloc, + stream); reset_local_connectivity(&out, final_coo, d_alloc, stream); diff --git a/cpp/src_prims/selection/columnWiseSort.cuh b/cpp/src_prims/selection/columnWiseSort.cuh index d80e5dd9f1..6f1563c3d8 100644 --- a/cpp/src_prims/selection/columnWiseSort.cuh +++ b/cpp/src_prims/selection/columnWiseSort.cuh @@ -175,8 +175,7 @@ template void sortColumnsPerRow(const InType *in, OutType *out, int n_rows, int n_columns, bool &bAllocWorkspace, void *workspacePtr, size_t &workspaceSize, cudaStream_t stream, - InType *sortedKeys = nullptr, bool ascending = true) { - ASSERT(ascending, "Descending sort not implemented yet"); + InType *sortedKeys = nullptr) { // assume non-square row-major matrices // current use-case: KNN, trustworthiness scores // output : either sorted indices or sorted indices and input values @@ -232,17 +231,10 @@ void sortColumnsPerRow(const InType *in, OutType *out, int n_rows, OutType *tmpValIn = nullptr; int *tmpOffsetBuffer = nullptr; - if (ascending) { - // first call is to get size of workspace - CUDA_CHECK(cub::DeviceSegmentedRadixSort::SortPairs( - workspacePtr, workspaceSize, in, sortedKeys, tmpValIn, out, - totalElements, numSegments, tmpOffsetBuffer, tmpOffsetBuffer + 1)); - } else { - // first call is to get size of workspace - CUDA_CHECK(cub::DeviceSegmentedRadixSort::SortPairsDescending( - workspacePtr, workspaceSize, in, sortedKeys, tmpValIn, out, - totalElements, numSegments, tmpOffsetBuffer, tmpOffsetBuffer + 1)); - } + // first call is to get size 
of workspace + CUDA_CHECK(cub::DeviceSegmentedRadixSort::SortPairs( + workspacePtr, workspaceSize, in, sortedKeys, tmpValIn, out, + totalElements, numSegments, tmpOffsetBuffer, tmpOffsetBuffer + 1)); bAllocWorkspace = true; // more staging space for temp output of keys if (!sortedKeys) @@ -283,17 +275,10 @@ void sortColumnsPerRow(const InType *in, OutType *out, int n_rows, CUDA_CHECK( layoutSortOffset(dSegmentOffsets, n_columns, numSegments, stream)); - if (ascending) { - CUDA_CHECK(cub::DeviceSegmentedRadixSort::SortPairs( - workspacePtr, workspaceSize, in, sortedKeys, dValuesIn, out, - totalElements, numSegments, dSegmentOffsets, dSegmentOffsets + 1, 0, - sizeof(InType) * 8, stream)); - } else { - CUDA_CHECK(cub::DeviceSegmentedRadixSort::SortPairsDescending( - workspacePtr, workspaceSize, in, sortedKeys, dValuesIn, out, - totalElements, numSegments, dSegmentOffsets, dSegmentOffsets + 1, 0, - sizeof(InType) * 8, stream)); - } + CUDA_CHECK(cub::DeviceSegmentedRadixSort::SortPairs( + workspacePtr, workspaceSize, in, sortedKeys, dValuesIn, out, + totalElements, numSegments, dSegmentOffsets, dSegmentOffsets + 1, 0, + sizeof(InType) * 8, stream)); } } else { // batched per row device wide sort @@ -337,15 +322,9 @@ void sortColumnsPerRow(const InType *in, OutType *out, int n_rows, OutType *rowOut = reinterpret_cast( (size_t)out + (i * sizeof(OutType) * (size_t)n_columns)); - if (ascending) { - CUDA_CHECK(cub::DeviceRadixSort::SortPairs( - workspacePtr, workspaceSize, rowIn, sortedKeys, dValuesIn, rowOut, - n_columns)); - } else { - CUDA_CHECK(cub::DeviceRadixSort::SortPairsDescending( - workspacePtr, workspaceSize, rowIn, sortedKeys, dValuesIn, rowOut, - n_columns)); - } + CUDA_CHECK(cub::DeviceRadixSort::SortPairs(workspacePtr, workspaceSize, + rowIn, sortedKeys, dValuesIn, + rowOut, n_columns)); if (userKeyOutputBuffer) sortedKeys = reinterpret_cast( diff --git a/cpp/src_prims/sparse/convert/coo.cuh b/cpp/src_prims/sparse/convert/coo.cuh new file mode 100644 index 0000000000..21ea45a0ef --- /dev/null +++ b/cpp/src_prims/sparse/convert/coo.cuh @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include + +#include +#include + +#include +#include + +namespace raft { namespace sparse { namespace convert { + +template <int TPB_X, typename value_idx> +__global__ void csr_to_coo_kernel(const value_idx *row_ind, value_idx m, + value_idx *coo_rows, value_idx nnz) { + // row-based matrix 1 thread per row + value_idx row = (blockIdx.x * TPB_X) + threadIdx.x; + if (row < m) { + value_idx start_idx = row_ind[row]; + value_idx stop_idx = get_stop_idx(row, m, nnz, row_ind); + for (value_idx i = start_idx; i < stop_idx; i++) coo_rows[i] = row; + } +} + +/** + * @brief Convert a CSR row_ind array to a COO rows array + * @param row_ind: Input CSR row_ind array + * @param m: size of row_ind array + * @param coo_rows: Output COO row array + * @param nnz: size of output COO row array + * @param stream: cuda stream to use + */ +template <typename value_idx, int TPB_X = 32> +void csr_to_coo(const value_idx *row_ind, value_idx m, value_idx *coo_rows, + value_idx nnz, cudaStream_t stream) { + // @TODO: Use cusparse for this. + dim3 grid(raft::ceildiv(m, (value_idx)TPB_X), 1, 1); + dim3 blk(TPB_X, 1, 1); + + csr_to_coo_kernel<TPB_X> + <<<grid, blk, 0, stream>>>(row_ind, m, coo_rows, nnz); + + CUDA_CHECK(cudaGetLastError()); +} + +}; // end NAMESPACE convert +}; // end NAMESPACE sparse +}; // end NAMESPACE raft \ No newline at end of file
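A short usage sketch for the new csr_to_coo prim, assuming `indptr` is a device array of m row offsets and `d_alloc`/`stream` are in scope; the example values are made up:

// Expand CSR row offsets into explicit COO row indices.
// For indptr = [0, 2, 2, 5] with m = 4 and nnz = 6, the kernel writes
// coo_rows = [0, 0, 2, 2, 2, 3] (the last row's extent is closed by nnz).
raft::mr::device::buffer<int> coo_rows(d_alloc, stream, nnz);
raft::sparse::convert::csr_to_coo(indptr, m, coo_rows.data(), nnz, stream);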
diff --git a/cpp/src_prims/sparse/convert/csr.cuh b/cpp/src_prims/sparse/convert/csr.cuh new file mode 100644 index 0000000000..0f5ce6d10f --- /dev/null +++ b/cpp/src_prims/sparse/convert/csr.cuh @@ -0,0 +1,189 @@ +/* + * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace raft { namespace sparse { namespace convert { + +template <typename value_t> +void coo_to_csr(const raft::handle_t &handle, const int *srcRows, + const int *srcCols, const value_t *srcVals, int nnz, int m, + int *dst_offsets, int *dstCols, value_t *dstVals) { + auto stream = handle.get_stream(); + auto cusparseHandle = handle.get_cusparse_handle(); + auto d_alloc = handle.get_device_allocator(); + raft::mr::device::buffer<int> dstRows(d_alloc, stream, nnz); + CUDA_CHECK(cudaMemcpyAsync(dstRows.data(), srcRows, sizeof(int) * nnz, + cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(dstCols, srcCols, sizeof(int) * nnz, + cudaMemcpyDeviceToDevice, stream)); + auto buffSize = raft::sparse::cusparsecoosort_bufferSizeExt( + cusparseHandle, m, m, nnz, srcRows, srcCols, stream); + raft::mr::device::buffer<char> pBuffer(d_alloc, stream, buffSize); + raft::mr::device::buffer<int> P(d_alloc, stream, nnz); + CUSPARSE_CHECK( + cusparseCreateIdentityPermutation(cusparseHandle, nnz, P.data())); + raft::sparse::cusparsecoosortByRow(cusparseHandle, m, m, nnz, dstRows.data(), + dstCols, P.data(), pBuffer.data(), stream); + raft::sparse::cusparsegthr(cusparseHandle, nnz, srcVals, dstVals, P.data(), + stream); + raft::sparse::cusparsecoo2csr(cusparseHandle, dstRows.data(), nnz, m, + dst_offsets, stream); + CUDA_CHECK(cudaDeviceSynchronize()); +} + +/** + * @brief Constructs an adjacency graph CSR row_ind_ptr array from + * a row_ind array and adjacency array. + * @tparam Index_ the numeric type of the index arrays + * @tparam TPB_X the number of threads to use per block for kernels + * @tparam Lambda function for fused operation in the adj_graph construction + * @param row_ind the input CSR row_ind array + * @param total_rows number of vertices in graph + * @param nnz number of non-zeros + * @param batchSize number of vertices in current batch + * @param adj an adjacency array (size batchSize x total_rows) + * @param row_ind_ptr output CSR row_ind_ptr for adjacency graph + * @param stream cuda stream to use + * @param fused_op: the fused operation + */ +template <typename Index_, int TPB_X = 32, + typename Lambda = auto (Index_, Index_, Index_) -> void> +void csr_adj_graph_batched(const Index_ *row_ind, Index_ total_rows, Index_ nnz, + Index_ batchSize, const bool *adj, + Index_ *row_ind_ptr, cudaStream_t stream, + Lambda fused_op) { + op::csr_row_op( + row_ind, batchSize, nnz, + [fused_op, adj, total_rows, row_ind_ptr, batchSize, nnz] __device__( + Index_ row, Index_ start_idx, Index_ stop_idx) { + fused_op(row, start_idx, stop_idx); + Index_ k = 0; + for (Index_ i = 0; i < total_rows; i++) { + // @todo: uncoalesced mem accesses! + if (adj[batchSize * i + row]) { + row_ind_ptr[start_idx + k] = i; + k += 1; + } + } + }, + stream); +} + +template <typename Index_, int TPB_X = 32, + typename Lambda = auto (Index_, Index_, Index_) -> void> +void csr_adj_graph_batched(const Index_ *row_ind, Index_ total_rows, Index_ nnz, + Index_ batchSize, const bool *adj, + Index_ *row_ind_ptr, cudaStream_t stream) { + csr_adj_graph_batched( + row_ind, total_rows, nnz, batchSize, adj, row_ind_ptr, stream, + [] __device__(Index_ row, Index_ start_idx, Index_ stop_idx) {}); +} + +/** + * @brief Constructs an adjacency graph CSR row_ind_ptr array from + * a row_ind array and adjacency array.
+ * @tparam Index_ the numeric type of the index arrays + * @tparam TPB_X the number of threads to use per block for kernels + * @param row_ind the input CSR row_ind array + * @param total_rows number of total vertices in graph + * @param nnz number of non-zeros + * @param adj an adjacency array + * @param row_ind_ptr output CSR row_ind_ptr for adjacency graph + * @param stream cuda stream to use + * @param fused_op the fused operation + */ +template <typename Index_, int TPB_X = 32, + typename Lambda = auto (Index_, Index_, Index_) -> void> +void csr_adj_graph(const Index_ *row_ind, Index_ total_rows, Index_ nnz, + const bool *adj, Index_ *row_ind_ptr, cudaStream_t stream, + Lambda fused_op) { + csr_adj_graph_batched(row_ind, total_rows, nnz, total_rows, + adj, row_ind_ptr, stream, fused_op); +} + +/** + * @brief Generate the row indices array for a sorted COO matrix + * + * @param rows: COO rows array + * @param nnz: size of COO rows array + * @param row_ind: output row indices array + * @param m: number of rows in dense matrix + * @param d_alloc device allocator for temporary buffers + * @param stream: cuda stream to use + */ +template <typename T> +void sorted_coo_to_csr(const T *rows, int nnz, T *row_ind, int m, + std::shared_ptr<raft::mr::device::allocator> d_alloc, + cudaStream_t stream) { + raft::mr::device::buffer<T> row_counts(d_alloc, stream, m); + + CUDA_CHECK(cudaMemsetAsync(row_counts.data(), 0, m * sizeof(T), stream)); + + linalg::coo_degree<32>(rows, nnz, row_counts.data(), stream); + + // create csr compressed row index from row counts + thrust::device_ptr<T> row_counts_d = + thrust::device_pointer_cast(row_counts.data()); + thrust::device_ptr<T> c_ind_d = thrust::device_pointer_cast(row_ind); + exclusive_scan(thrust::cuda::par.on(stream), row_counts_d, row_counts_d + m, + c_ind_d); +} + +/** + * @brief Generate the row indices array for a sorted COO matrix + * + * @param coo: Input COO matrix + * @param row_ind: output row indices array + * @param d_alloc device allocator for temporary buffers + * @param stream: cuda stream to use + */ +template <typename T> +void sorted_coo_to_csr(COO<T> *coo, int *row_ind, + std::shared_ptr<raft::mr::device::allocator> d_alloc, + cudaStream_t stream) { + sorted_coo_to_csr(coo->rows(), coo->nnz, row_ind, coo->n_rows, d_alloc, + stream); +} + +}; // end NAMESPACE convert +}; // end NAMESPACE sparse +}; // end NAMESPACE raft \ No newline at end of file
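Usage sketch for the relocated sorted_coo_to_csr, assuming `in_coo` is a row-sorted raft::sparse::COO<float> and `d_alloc`/`stream` are in scope:

// Build the m-element CSR row index (exclusive scan of per-row counts).
// Note the array carries no trailing nnz entry, which is why consumers
// such as csr_to_coo above take nnz as a separate argument.
raft::mr::device::buffer<int> row_ind(d_alloc, stream, in_coo.n_rows);
raft::sparse::convert::sorted_coo_to_csr(&in_coo, row_ind.data(), d_alloc,
                                         stream);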
diff --git a/cpp/src_prims/sparse/convert/dense.cuh b/cpp/src_prims/sparse/convert/dense.cuh new file mode 100644 index 0000000000..772596f6df --- /dev/null +++ b/cpp/src_prims/sparse/convert/dense.cuh @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include +#include + +#include + +namespace raft { namespace sparse { namespace convert { + +template <typename value_t> +__global__ void csr_to_dense_warp_per_row_kernel(int n_cols, + const value_t *csrVal, + const int *csrRowPtr, + const int *csrColInd, + value_t *a) { + int row = blockIdx.x; + int tid = threadIdx.x; + + int colStart = csrRowPtr[row]; + int colEnd = csrRowPtr[row + 1]; + int rowNnz = colEnd - colStart; + + for (int i = tid; i < rowNnz; i += blockDim.x) { + int colIdx = colStart + i; + if (colIdx < colEnd) { + int col = csrColInd[colIdx]; + a[row * n_cols + col] = csrVal[colIdx]; + } + } +} + +/** + * Convert CSR arrays to a dense matrix in either row- + * or column-major format. A custom kernel is used when + * row-major output is desired since cusparse does not + * output row-major. + * @tparam value_idx : data type of the CSR index arrays + * @tparam value_t : data type of the CSR value array + * @param[in] handle : cusparse handle for conversion + * @param[in] nrows : number of rows in CSR + * @param[in] ncols : number of columns in CSR + * @param[in] csr_indptr : CSR row index pointer array + * @param[in] csr_indices : CSR column indices array + * @param[in] csr_data : CSR data array + * @param[in] lda : Leading dimension (used for col-major only) + * @param[out] out : Dense output array of size nrows * ncols + * @param[in] stream : Cuda stream for ordering events + * @param[in] row_major : Is row-major output desired? + */ +template <typename value_idx, typename value_t> +void csr_to_dense(cusparseHandle_t handle, value_idx nrows, value_idx ncols, + const value_idx *csr_indptr, const value_idx *csr_indices, + const value_t *csr_data, value_idx lda, value_t *out, + cudaStream_t stream, bool row_major = true) { + if (!row_major) { + /** + * If we need col-major, use cusparse. + */ + cusparseMatDescr_t out_mat; + CUSPARSE_CHECK(cusparseCreateMatDescr(&out_mat)); + CUSPARSE_CHECK(cusparseSetMatIndexBase(out_mat, CUSPARSE_INDEX_BASE_ZERO)); + CUSPARSE_CHECK(cusparseSetMatType(out_mat, CUSPARSE_MATRIX_TYPE_GENERAL)); + + CUSPARSE_CHECK(raft::sparse::cusparsecsr2dense( + handle, nrows, ncols, out_mat, csr_data, csr_indptr, csr_indices, out, + lda, stream)); + + CUSPARSE_CHECK_NO_THROW(cusparseDestroyMatDescr(out_mat)); + + } else { + int blockdim = block_dim(ncols); + CUDA_CHECK( + cudaMemsetAsync(out, 0, nrows * ncols * sizeof(value_t), stream)); + csr_to_dense_warp_per_row_kernel<<<nrows, blockdim, 0, stream>>>( + ncols, csr_data, csr_indptr, csr_indices, out); + } +} + +}; // end NAMESPACE convert +}; // end NAMESPACE sparse +}; // end NAMESPACE raft \ No newline at end of file diff --git a/cpp/src_prims/sparse/coo.cuh b/cpp/src_prims/sparse/coo.cuh index 79da8dc18d..520f29d292 100644 --- a/cpp/src_prims/sparse/coo.cuh +++ b/cpp/src_prims/sparse/coo.cuh @@ -14,12 +14,11 @@ * limitations under the License. */ -#include -#include "csr.cuh" - +#include #include - -#include +#include +#include +#include #include @@ -28,8 +27,6 @@ #include #include -#include -#include #include #include @@ -37,8 +34,8 @@ #pragma once -namespace MLCommon { -namespace Sparse { +namespace raft { +namespace sparse { /** @brief A Container object for sparse coordinate. There are two motivations * behind using a container for COO arrays.
@@ -61,9 +58,9 @@ namespace Sparse { template class COO { protected: - device_buffer rows_arr; - device_buffer cols_arr; - device_buffer vals_arr; + raft::mr::device::buffer rows_arr; + raft::mr::device::buffer cols_arr; + raft::mr::device::buffer vals_arr; public: Index_Type nnz; @@ -74,7 +71,7 @@ class COO { * @param d_alloc: the device allocator to use for the underlying buffers * @param stream: CUDA stream to use */ - COO(std::shared_ptr d_alloc, cudaStream_t stream) + COO(std::shared_ptr d_alloc, cudaStream_t stream) : rows_arr(d_alloc, stream, 0), cols_arr(d_alloc, stream, 0), vals_arr(d_alloc, stream, 0), @@ -90,8 +87,9 @@ class COO { * @param n_rows: number of rows in the dense matrix * @param n_cols: number of cols in the dense matrix */ - COO(device_buffer &rows, device_buffer &cols, - device_buffer &vals, Index_Type nnz, Index_Type n_rows = 0, + COO(raft::mr::device::buffer &rows, + raft::mr::device::buffer &cols, + raft::mr::device::buffer &vals, Index_Type nnz, Index_Type n_rows = 0, Index_Type n_cols = 0) : rows_arr(rows), cols_arr(cols), @@ -108,7 +106,7 @@ class COO { * @param n_cols: number of cols in the dense matrix * @param init: initialize arrays with zeros */ - COO(std::shared_ptr d_alloc, cudaStream_t stream, + COO(std::shared_ptr d_alloc, cudaStream_t stream, Index_Type nnz, Index_Type n_rows = 0, Index_Type n_cols = 0, bool init = true) : rows_arr(d_alloc, stream, nnz), @@ -257,731 +255,5 @@ class COO { } }; -/** - * @brief Sorts the arrays that comprise the coo matrix - * by row. - * - * @param m number of rows in coo matrix - * @param n number of cols in coo matrix - * @param nnz number of non-zeros - * @param rows rows array from coo matrix - * @param cols cols array from coo matrix - * @param vals vals array from coo matrix - * @param d_alloc device allocator for temporary buffers - * @param stream: cuda stream to use - */ -template -void coo_sort(int m, int n, int nnz, int *rows, int *cols, T *vals, - std::shared_ptr d_alloc, cudaStream_t stream) { - cusparseHandle_t handle = NULL; - - size_t pBufferSizeInBytes = 0; - - CUSPARSE_CHECK(cusparseCreate(&handle)); - CUSPARSE_CHECK(cusparseSetStream(handle, stream)); - CUSPARSE_CHECK(cusparseXcoosort_bufferSizeExt(handle, m, n, nnz, rows, cols, - &pBufferSizeInBytes)); - - device_buffer d_P(d_alloc, stream, nnz); - device_buffer pBuffer(d_alloc, stream, pBufferSizeInBytes); - - CUSPARSE_CHECK(cusparseCreateIdentityPermutation(handle, nnz, d_P.data())); - - CUSPARSE_CHECK(cusparseXcoosortByRow(handle, m, n, nnz, rows, cols, - d_P.data(), pBuffer.data())); - - device_buffer vals_sorted(d_alloc, stream, nnz); - - CUSPARSE_CHECK(raft::sparse::cusparsegthr( - handle, nnz, vals, vals_sorted.data(), d_P.data(), stream)); - - CUDA_CHECK(cudaStreamSynchronize(stream)); - - raft::copy(vals, vals_sorted.data(), nnz, stream); - - CUSPARSE_CHECK(cusparseDestroy(handle)); -} - -/** - * @brief Sort the underlying COO arrays by row - * @tparam T: the type name of the underlying value array - * @param in: COO to sort by row - * @param d_alloc device allocator for temporary buffers - * @param stream: the cuda stream to use - */ -template -void coo_sort(COO *const in, std::shared_ptr d_alloc, - cudaStream_t stream) { - coo_sort(in->n_rows, in->n_cols, in->nnz, in->rows(), in->cols(), - in->vals(), d_alloc, stream); -} - -template -__global__ void coo_remove_zeros_kernel(const int *rows, const int *cols, - const T *vals, int nnz, int *crows, - int *ccols, T *cvals, int *ex_scan, - int *cur_ex_scan, int m) { - int row = (blockIdx.x * 
TPB_X) + threadIdx.x; - - if (row < m) { - int start = cur_ex_scan[row]; - int stop = MLCommon::Sparse::get_stop_idx(row, m, nnz, cur_ex_scan); - int cur_out_idx = ex_scan[row]; - - for (int idx = start; idx < stop; idx++) { - if (vals[idx] != 0.0) { - crows[cur_out_idx] = rows[idx]; - ccols[cur_out_idx] = cols[idx]; - cvals[cur_out_idx] = vals[idx]; - ++cur_out_idx; - } - } - } -} - -template -__global__ void coo_remove_scalar_kernel(const int *rows, const int *cols, - const T *vals, int nnz, int *crows, - int *ccols, T *cvals, int *ex_scan, - int *cur_ex_scan, int m, T scalar) { - int row = (blockIdx.x * TPB_X) + threadIdx.x; - - if (row < m) { - int start = cur_ex_scan[row]; - int stop = MLCommon::Sparse::get_stop_idx(row, m, nnz, cur_ex_scan); - int cur_out_idx = ex_scan[row]; - - for (int idx = start; idx < stop; idx++) { - if (vals[idx] != scalar) { - crows[cur_out_idx] = rows[idx]; - ccols[cur_out_idx] = cols[idx]; - cvals[cur_out_idx] = vals[idx]; - ++cur_out_idx; - } - } - } -} - -/** - * @brief Count all the rows in the coo row array and place them in the - * results matrix, indexed by row. - * - * @tparam TPB_X: number of threads to use per block - * @param rows the rows array of the coo matrix - * @param nnz the size of the rows array - * @param results array to place results - */ -template -__global__ void coo_row_count_kernel(const int *rows, int nnz, int *results) { - int row = (blockIdx.x * TPB_X) + threadIdx.x; - if (row < nnz) { - raft::myAtomicAdd(results + rows[row], 1); - } -} - -/** - * @brief Count the number of values for each row - * @tparam TPB_X: number of threads to use per block - * @param rows: rows array of the COO matrix - * @param nnz: size of the rows array - * @param results: output result array - * @param stream: cuda stream to use - */ -template -void coo_row_count(const int *rows, int nnz, int *results, - cudaStream_t stream) { - dim3 grid_rc(raft::ceildiv(nnz, TPB_X), 1, 1); - dim3 blk_rc(TPB_X, 1, 1); - - coo_row_count_kernel - <<>>(rows, nnz, results); - CUDA_CHECK(cudaGetLastError()); -} - -/** - * @brief Count the number of values for each row - * @tparam TPB_X: number of threads to use per block - * @tparam T: type name of underlying values array - * @param in: input COO object for counting rows - * @param results: output array with row counts (size=in->n_rows) - * @param stream: cuda stream to use - */ -template -void coo_row_count(COO *in, int *results, cudaStream_t stream) { - dim3 grid_rc(raft::ceildiv(in->nnz, TPB_X), 1, 1); - dim3 blk_rc(TPB_X, 1, 1); - - coo_row_count_kernel - <<>>(in->rows(), in->nnz, results); - CUDA_CHECK(cudaGetLastError()); -} - -template -__global__ void coo_row_count_nz_kernel(const int *rows, const T *vals, int nnz, - int *results) { - int row = (blockIdx.x * TPB_X) + threadIdx.x; - if (row < nnz && vals[row] != 0.0) { - raft::myAtomicAdd(results + rows[row], 1); - } -} - -template -__global__ void coo_row_count_scalar_kernel(const int *rows, const T *vals, - int nnz, T scalar, int *results) { - int row = (blockIdx.x * TPB_X) + threadIdx.x; - if (row < nnz && vals[row] != scalar) { - raft::myAtomicAdd(results + rows[row], 1); - } -} - -/** - * @brief Count the number of values for each row matching a particular scalar - * @tparam TPB_X: number of threads to use per block - * @tparam T: the type name of the underlying value arrays - * @param in: Input COO array - * @param scalar: scalar to match for counting rows - * @param results: output row counts - * @param stream: cuda stream to use - */ -template -void 
coo_row_count_scalar(COO *in, T scalar, int *results, - cudaStream_t stream) { - dim3 grid_rc(raft::ceildiv(in->nnz, TPB_X), 1, 1); - dim3 blk_rc(TPB_X, 1, 1); - - coo_row_count_scalar_kernel<<>>( - in->rows(), in->vals(), in->nnz, scalar, results); - CUDA_CHECK(cudaGetLastError()); -} - -/** - * @brief Count the number of values for each row matching a particular scalar - * @tparam TPB_X: number of threads to use per block - * @tparam T: the type name of the underlying value arrays - * @param rows: Input COO row array - * @param vals: Input COO val arrays - * @param nnz: size of input COO arrays - * @param scalar: scalar to match for counting rows - * @param results: output row counts - * @param stream: cuda stream to use - */ -template -void coo_row_count_scalar(const int *rows, const T *vals, int nnz, T scalar, - int *results, cudaStream_t stream = 0) { - dim3 grid_rc(raft::ceildiv(nnz, TPB_X), 1, 1); - dim3 blk_rc(TPB_X, 1, 1); - - coo_row_count_scalar_kernel - <<>>(rows, vals, nnz, scalar, results); - CUDA_CHECK(cudaGetLastError()); -} - -/** - * @brief Count the number of nonzeros for each row - * @tparam TPB_X: number of threads to use per block - * @tparam T: the type name of the underlying value arrays - * @param rows: Input COO row array - * @param vals: Input COO val arrays - * @param nnz: size of input COO arrays - * @param results: output row counts - * @param stream: cuda stream to use - */ -template -void coo_row_count_nz(const int *rows, const T *vals, int nnz, int *results, - cudaStream_t stream) { - dim3 grid_rc(raft::ceildiv(nnz, TPB_X), 1, 1); - dim3 blk_rc(TPB_X, 1, 1); - - coo_row_count_nz_kernel - <<>>(rows, vals, nnz, results); - CUDA_CHECK(cudaGetLastError()); -} - -/** - * @brief Count the number of nonzero values for each row - * @tparam TPB_X: number of threads to use per block - * @tparam T: the type name of the underlying value arrays - * @param in: Input COO array - * @param results: output row counts - * @param stream: cuda stream to use - */ -template -void coo_row_count_nz(COO *in, int *results, cudaStream_t stream) { - dim3 grid_rc(raft::ceildiv(in->nnz, TPB_X), 1, 1); - dim3 blk_rc(TPB_X, 1, 1); - - coo_row_count_nz_kernel - <<>>(in->rows(), in->vals(), in->nnz, results); - CUDA_CHECK(cudaGetLastError()); -} - -/** - * @brief Removes the values matching a particular scalar from a COO formatted sparse matrix. 
- * - * @param rows: input array of rows (size n) - * @param cols: input array of cols (size n) - * @param vals: input array of vals (size n) - * @param nnz: size of current rows/cols/vals arrays - * @param crows: compressed array of rows - * @param ccols: compressed array of cols - * @param cvals: compressed array of vals - * @param cnnz: array of non-zero counts per row - * @param cur_cnnz array of counts per row - * @param scalar: scalar to remove from arrays - * @param n: number of rows in dense matrix - * @param d_alloc device allocator for temporary buffers - * @param stream: cuda stream to use - */ -template -void coo_remove_scalar(const int *rows, const int *cols, const T *vals, int nnz, - int *crows, int *ccols, T *cvals, int *cnnz, - int *cur_cnnz, T scalar, int n, - std::shared_ptr d_alloc, - cudaStream_t stream) { - device_buffer ex_scan(d_alloc, stream, n); - device_buffer cur_ex_scan(d_alloc, stream, n); - - CUDA_CHECK(cudaMemsetAsync(ex_scan.data(), 0, n * sizeof(int), stream)); - CUDA_CHECK(cudaMemsetAsync(cur_ex_scan.data(), 0, n * sizeof(int), stream)); - - thrust::device_ptr dev_cnnz = thrust::device_pointer_cast(cnnz); - thrust::device_ptr dev_ex_scan = - thrust::device_pointer_cast(ex_scan.data()); - thrust::exclusive_scan(thrust::cuda::par.on(stream), dev_cnnz, dev_cnnz + n, - dev_ex_scan); - CUDA_CHECK(cudaPeekAtLastError()); - - thrust::device_ptr dev_cur_cnnz = thrust::device_pointer_cast(cur_cnnz); - thrust::device_ptr dev_cur_ex_scan = - thrust::device_pointer_cast(cur_ex_scan.data()); - thrust::exclusive_scan(thrust::cuda::par.on(stream), dev_cur_cnnz, - dev_cur_cnnz + n, dev_cur_ex_scan); - CUDA_CHECK(cudaPeekAtLastError()); - - dim3 grid(raft::ceildiv(n, TPB_X), 1, 1); - dim3 blk(TPB_X, 1, 1); - - coo_remove_scalar_kernel<<>>( - rows, cols, vals, nnz, crows, ccols, cvals, dev_ex_scan.get(), - dev_cur_ex_scan.get(), n, scalar); - CUDA_CHECK(cudaPeekAtLastError()); -} - -/** - * @brief Removes the values matching a particular scalar from a COO formatted sparse matrix. - * - * @param in: input COO matrix - * @param out: output COO matrix - * @param scalar: scalar to remove from arrays - * @param d_alloc device allocator for temporary buffers - * @param stream: cuda stream to use - */ -template -void coo_remove_scalar(COO *in, COO *out, T scalar, - std::shared_ptr d_alloc, - cudaStream_t stream) { - device_buffer row_count_nz(d_alloc, stream, in->n_rows); - device_buffer row_count(d_alloc, stream, in->n_rows); - - CUDA_CHECK( - cudaMemsetAsync(row_count_nz.data(), 0, in->n_rows * sizeof(int), stream)); - CUDA_CHECK( - cudaMemsetAsync(row_count.data(), 0, in->n_rows * sizeof(int), stream)); - - MLCommon::Sparse::coo_row_count(in->rows(), in->nnz, row_count.data(), - stream); - CUDA_CHECK(cudaPeekAtLastError()); - - MLCommon::Sparse::coo_row_count_scalar( - in->rows(), in->vals(), in->nnz, scalar, row_count_nz.data(), stream); - CUDA_CHECK(cudaPeekAtLastError()); - - thrust::device_ptr d_row_count_nz = - thrust::device_pointer_cast(row_count_nz.data()); - int out_nnz = thrust::reduce(thrust::cuda::par.on(stream), d_row_count_nz, - d_row_count_nz + in->n_rows); - - out->allocate(out_nnz, in->n_rows, in->n_cols, false, stream); - - coo_remove_scalar(in->rows(), in->cols(), in->vals(), in->nnz, - out->rows(), out->cols(), out->vals(), - row_count_nz.data(), row_count.data(), scalar, - in->n_rows, d_alloc, stream); - CUDA_CHECK(cudaPeekAtLastError()); -} - -/** - * @brief Removes zeros from a COO formatted sparse matrix. 
- * - * @param in: input COO matrix - * @param out: output COO matrix - * @param d_alloc device allocator for temporary buffers - * @param stream: cuda stream to use - */ -template -void coo_remove_zeros(COO *in, COO *out, - std::shared_ptr d_alloc, - cudaStream_t stream) { - coo_remove_scalar(in, out, T(0.0), d_alloc, stream); -} - -template -__global__ void from_knn_graph_kernel(const long *knn_indices, - const T *knn_dists, int m, int k, - int *rows, int *cols, T *vals) { - int row = (blockIdx.x * TPB_X) + threadIdx.x; - if (row < m) { - for (int i = 0; i < k; i++) { - rows[row * k + i] = row; - cols[row * k + i] = knn_indices[row * k + i]; - vals[row * k + i] = knn_dists[row * k + i]; - } - } -} - -/** - * @brief Converts a knn graph, defined by index and distance matrices, - * into COO format. - * - * @param knn_indices: knn index array - * @param knn_dists: knn distance array - * @param m: number of vertices in graph - * @param k: number of nearest neighbors - * @param rows: output COO row array - * @param cols: output COO col array - * @param vals: output COO val array - */ -template -void from_knn(const long *knn_indices, const T *knn_dists, int m, int k, - int *rows, int *cols, T *vals) { - dim3 grid(raft::ceildiv(m, 32), 1, 1); - dim3 blk(32, 1, 1); - from_knn_graph_kernel<32, T> - <<>>(knn_indices, knn_dists, m, k, rows, cols, vals); - CUDA_CHECK(cudaGetLastError()); -} - -/** - * Converts a knn graph, defined by index and distance matrices, - * into COO format. - * @param knn_indices: KNN index array (size m * k) - * @param knn_dists: KNN dist array (size m * k) - * @param m: number of vertices in graph - * @param k: number of nearest neighbors - * @param out: The output COO graph from the KNN matrices - * @param stream: CUDA stream to use - */ -template -void from_knn(const long *knn_indices, const T *knn_dists, int m, int k, - COO *out, cudaStream_t stream) { - out->allocate(m * k, m, m, true, stream); - - from_knn(knn_indices, knn_dists, m, k, out->rows(), out->cols(), out->vals()); -} - -/** - * @brief Generate the row indices array for a sorted COO matrix - * - * @param rows: COO rows array - * @param nnz: size of COO rows array - * @param row_ind: output row indices array - * @param m: number of rows in dense matrix - * @param d_alloc device allocator for temporary buffers - * @param stream: cuda stream to use - */ -template -void sorted_coo_to_csr(const T *rows, int nnz, T *row_ind, int m, - std::shared_ptr d_alloc, - cudaStream_t stream) { - device_buffer row_counts(d_alloc, stream, m); - - CUDA_CHECK(cudaMemsetAsync(row_counts.data(), 0, m * sizeof(T), stream)); - - coo_row_count<32>(rows, nnz, row_counts.data(), stream); - - // create csr compressed row index from row counts - thrust::device_ptr row_counts_d = - thrust::device_pointer_cast(row_counts.data()); - thrust::device_ptr c_ind_d = thrust::device_pointer_cast(row_ind); - exclusive_scan(thrust::cuda::par.on(stream), row_counts_d, row_counts_d + m, - c_ind_d); -} - -/** - * @brief Generate the row indices array for a sorted COO matrix - * - * @param coo: Input COO matrix - * @param row_ind: output row indices array - * @param d_alloc device allocator for temporary buffers - * @param stream: cuda stream to use - */ -template -void sorted_coo_to_csr(COO *coo, int *row_ind, - std::shared_ptr d_alloc, - cudaStream_t stream) { - sorted_coo_to_csr(coo->rows(), coo->nnz, row_ind, coo->n_rows, d_alloc, - stream); -} - -template -__global__ void coo_symmetrize_kernel(int *row_ind, int *rows, int *cols, - T *vals, int 
*orows, int *ocols, T *ovals, - int n, int cnnz, Lambda reduction_op) { - int row = (blockIdx.x * TPB_X) + threadIdx.x; - - if (row < n) { - int start_idx = row_ind[row]; // each thread processes one row - int stop_idx = MLCommon::Sparse::get_stop_idx(row, n, cnnz, row_ind); - - int row_nnz = 0; - int out_start_idx = start_idx * 2; - - for (int idx = 0; idx < stop_idx - start_idx; idx++) { - int cur_row = rows[idx + start_idx]; - int cur_col = cols[idx + start_idx]; - T cur_val = vals[idx + start_idx]; - - int lookup_row = cur_col; - int t_start = row_ind[lookup_row]; // Start at - int t_stop = MLCommon::Sparse::get_stop_idx(lookup_row, n, cnnz, row_ind); - - T transpose = 0.0; - - bool found_match = false; - for (int t_idx = t_start; t_idx < t_stop; t_idx++) { - // If we find a match, let's get out of the loop. We won't - // need to modify the transposed value, since that will be - // done in a different thread. - if (cols[t_idx] == cur_row && rows[t_idx] == cur_col) { - // If it exists already, set transposed value to existing value - transpose = vals[t_idx]; - found_match = true; - break; - } - } - - // Custom reduction op on value and its transpose, which enables - // specialized weighting. - // If only simple X+X.T is desired, this op can just sum - // the two values. - T res = reduction_op(cur_row, cur_col, cur_val, transpose); - - // if we didn't find an exact match, we need to add - // the computed res into our current matrix to guarantee - // symmetry. - // Note that if we did find a match, we don't need to - // compute `res` on it here because it will be computed - // in a different thread. - if (!found_match && vals[idx] != 0.0) { - orows[out_start_idx + row_nnz] = cur_col; - ocols[out_start_idx + row_nnz] = cur_row; - ovals[out_start_idx + row_nnz] = res; - ++row_nnz; - } - - if (res != 0.0) { - orows[out_start_idx + row_nnz] = cur_row; - ocols[out_start_idx + row_nnz] = cur_col; - ovals[out_start_idx + row_nnz] = res; - ++row_nnz; - } - } - } -} - -/** - * @brief takes a COO matrix which may not be symmetric and symmetrizes - * it, running a custom reduction function against the each value - * and its transposed value. - * - * @param in: Input COO matrix - * @param out: Output symmetrized COO matrix - * @param reduction_op: a custom reduction function - * @param d_alloc device allocator for temporary buffers - * @param stream: cuda stream to use - */ -template -void coo_symmetrize(COO *in, COO *out, - Lambda reduction_op, // two-argument reducer - std::shared_ptr d_alloc, - cudaStream_t stream) { - dim3 grid(raft::ceildiv(in->n_rows, TPB_X), 1, 1); - dim3 blk(TPB_X, 1, 1); - - ASSERT(!out->validate_mem(), "Expecting unallocated COO for output"); - - device_buffer in_row_ind(d_alloc, stream, in->n_rows); - - sorted_coo_to_csr(in, in_row_ind.data(), d_alloc, stream); - - out->allocate(in->nnz * 2, in->n_rows, in->n_cols, true, stream); - - coo_symmetrize_kernel<<>>( - in_row_ind.data(), in->rows(), in->cols(), in->vals(), out->rows(), - out->cols(), out->vals(), in->n_rows, in->nnz, reduction_op); - CUDA_CHECK(cudaPeekAtLastError()); -} - -/** - * @brief Find how much space needed in each row. - * We look through all datapoints and increment the count for each row. 
-/**
- * @brief Find how much space is needed in each row.
- * We look through all datapoints and increment the count for each row.
- *
- * @param data: Input knn distances (n, k)
- * @param indices: Input knn indices (n, k)
- * @param n: Number of rows
- * @param k: Number of nearest neighbors
- * @param row_sizes: Input empty row sum 1 array (n)
- * @param row_sizes2: Input empty row sum 2 array (n) for faster reduction
- */
-template <typename value_idx = int64_t, typename value_t = float>
-__global__ static void symmetric_find_size(const value_t *restrict data,
-                                           const value_idx *restrict indices,
-                                           const value_idx n, const int k,
-                                           value_idx *restrict row_sizes,
-                                           value_idx *restrict row_sizes2) {
-  const auto row = blockIdx.x * blockDim.x + threadIdx.x;  // for every row
-  const auto j =
-    blockIdx.y * blockDim.y + threadIdx.y;  // for every item in row
-  if (row >= n || j >= k) return;
-
-  const auto col = indices[row * k + j];
-  if (j % 2)
-    atomicAdd(&row_sizes[col], (value_idx)1);
-  else
-    atomicAdd(&row_sizes2[col], (value_idx)1);
-}
-
-/**
- * @brief Compute the final row sizes: row_sizes += row_sizes2 + k.
- * Reduction step for the symmetric_find_size kernel; the counts are split
- * across two arrays to reduce atomic contention.
- *
- * @param n: Number of rows
- * @param k: Number of nearest neighbors
- * @param row_sizes: Input row sum 1 array (n)
- * @param row_sizes2: Input row sum 2 array (n) for faster reduction
- */
-template <typename value_idx>
-__global__ static void reduce_find_size(const value_idx n, const int k,
-                                        value_idx *restrict row_sizes,
-                                        const value_idx *restrict row_sizes2) {
-  const auto i = (blockIdx.x * blockDim.x) + threadIdx.x;
-  if (i >= n) return;
-  row_sizes[i] += (row_sizes2[i] + k);
-}
-
-/**
- * @brief Perform the data + data.T operation.
- * Can only run once the row sizes of data + data.T have been determined.
- *
- * @param edges: Input per-row write offsets (n), the exclusive scan of the
- * row sizes
- * @param data: Input knn distances (n, k)
- * @param indices: Input knn indices (n, k)
- * @param VAL: Output values for data + data.T
- * @param COL: Output column indices for data + data.T
- * @param ROW: Output row indices for data + data.T
- * @param n: Number of rows
- * @param k: Number of nearest neighbors
- */
-template <typename value_idx = int64_t, typename value_t = float>
-__global__ static void symmetric_sum(value_idx *restrict edges,
-                                     const value_t *restrict data,
-                                     const value_idx *restrict indices,
-                                     value_t *restrict VAL,
-                                     value_idx *restrict COL,
-                                     value_idx *restrict ROW,
-                                     const value_idx n, const int k) {
-  const auto row = blockIdx.x * blockDim.x + threadIdx.x;  // for every row
-  const auto j =
-    blockIdx.y * blockDim.y + threadIdx.y;  // for every item in row
-  if (row >= n || j >= k) return;
-
-  const auto col = indices[row * k + j];
-  const auto original = atomicAdd(&edges[row], (value_idx)1);
-  const auto transpose = atomicAdd(&edges[col], (value_idx)1);
-
-  VAL[transpose] = VAL[original] = data[row * k + j];
-  // Notice ROW and COL are swapped for the transposed entry
-  ROW[original] = row;
-  COL[original] = col;
-
-  ROW[transpose] = col;
-  COL[transpose] = row;
-}
-
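To make the edges bookkeeping concrete: symmetric_sum consumes per-row write offsets produced by an exclusive scan of the row sizes (step (4) of the driver below). A host-side illustration with made-up counts (C++17; the device code does the same thing with thrust::exclusive_scan):

#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical per-row entry counts after reduce_find_size (n = 4)
  std::vector<int> row_sizes = {3, 2, 4, 1};
  std::vector<int> edges(row_sizes.size());

  // edges[i] = sum of row_sizes[0..i)
  std::exclusive_scan(row_sizes.begin(), row_sizes.end(), edges.begin(), 0);

  // Prints "0 3 5 9": row i starts writing at offset edges[i], and each
  // atomicAdd(&edges[i], 1) in symmetric_sum claims the next free slot
  for (int e : edges) std::printf("%d ", e);
  return 0;
}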
-/**
- * @brief Perform data + data.T on raw KNN data.
- * The following steps are invoked:
- * (1) Find how much space is needed in each row
- * (2) Compute the final space needed (n*k + sum(row_sizes)) == 2*n*k
- * (3) Allocate the new space
- * (4) Prepare the edges for each new row
- * (5) Perform the final data + data.T operation
- * (6) Return the summed up VAL, COL, ROW
- *
- * @param knn_indices: Input knn indices (n, k)
- * @param knn_dists: Input knn distances (n, k)
- * @param n: Number of rows
- * @param k: Number of nearest neighbors
- * @param out: Output COO Matrix class
- * @param stream: Input cuda stream
- * @param d_alloc device allocator for temporary buffers
- */
-template <typename value_idx = int64_t, typename value_t = float,
-          int TPB_X = 32, int TPB_Y = 32>
-void from_knn_symmetrize_matrix(const value_idx *restrict knn_indices,
-                                const value_t *restrict knn_dists,
-                                const value_idx n, const int k,
-                                COO<value_t, value_idx> *out,
-                                cudaStream_t stream,
-                                std::shared_ptr<deviceAllocator> d_alloc) {
-  // (1) Find how much space is needed in each row
-  // We look through all datapoints and increment the count for each row.
-  const dim3 threadsPerBlock(TPB_X, TPB_Y);
-  const dim3 numBlocks(raft::ceildiv(n, (value_idx)TPB_X),
-                       raft::ceildiv(k, TPB_Y));
-
-  // These arrays are reused as transpose_edges, original_edges in step (4)
-  device_buffer<value_idx> row_sizes(d_alloc, stream, n);
-  CUDA_CHECK(
-    cudaMemsetAsync(row_sizes.data(), 0, sizeof(value_idx) * n, stream));
-
-  device_buffer<value_idx> row_sizes2(d_alloc, stream, n);
-  CUDA_CHECK(
-    cudaMemsetAsync(row_sizes2.data(), 0, sizeof(value_idx) * n, stream));
-
-  symmetric_find_size<<<numBlocks, threadsPerBlock, 0, stream>>>(
-    knn_dists, knn_indices, n, k, row_sizes.data(), row_sizes2.data());
-  CUDA_CHECK(cudaPeekAtLastError());
-
-  reduce_find_size<<<raft::ceildiv(n, (value_idx)1024), 1024, 0, stream>>>(
-    n, k, row_sizes.data(), row_sizes2.data());
-  CUDA_CHECK(cudaPeekAtLastError());
-
-  // (2) Compute the final space needed (n*k + sum(row_sizes)) == 2*n*k
-  // Notice we don't do any merging and leave the result as 2*NNZ
-  const auto NNZ = 2 * n * k;
-
-  // (3) Allocate the new space
-  out->allocate(NNZ, n, n, true, stream);
-
-  // (4) Prepare the edges for each new row
-  // This mirrors a CSR matrix's row pointer, where the maximum bounds for
-  // each row are calculated as the cumulative rolling sum of the previous
-  // rows. Notice we reuse the old row_sizes2 memory.
-  value_idx *edges = row_sizes2.data();
-  thrust::device_ptr<value_idx> __edges = thrust::device_pointer_cast(edges);
-  thrust::device_ptr<value_idx> __row_sizes =
-    thrust::device_pointer_cast(row_sizes.data());
-
-  // Rolling cumulative sum
-  thrust::exclusive_scan(thrust::cuda::par.on(stream), __row_sizes,
-                         __row_sizes + n, __edges);
-
-  // (5) Perform the final data + data.T operation in tandem with the memcpy
-  symmetric_sum<<<numBlocks, threadsPerBlock, 0, stream>>>(
-    edges, knn_dists, knn_indices, out->vals(), out->cols(), out->rows(), n,
-    k);
-  CUDA_CHECK(cudaPeekAtLastError());
-}
-
-}; // namespace Sparse
-}; // namespace MLCommon
+}; // namespace sparse
+}; // namespace raft
diff --git a/cpp/src_prims/sparse/csr.cuh b/cpp/src_prims/sparse/csr.cuh
index e43bbd850d..b312810cf5 100644
--- a/cpp/src_prims/sparse/csr.cuh
+++ b/cpp/src_prims/sparse/csr.cuh
@@ -22,6 +22,7 @@
 #include
 #include
 #include
+#include
 #include
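Before the diff continues through csr.cuh: the symmetrization path deleted from coo.cuh above was typically driven end to end as follows. This is a hypothetical caller; the function name, buffer setup, and include path are assumptions, not code from this PR:

#include <memory>
#include <sparse/coo.cuh>  // path assumed

void symmetrize_knn(const int64_t *knn_indices, const float *knn_dists,
                    int64_t n, int k,
                    std::shared_ptr<MLCommon::deviceAllocator> d_alloc,
                    cudaStream_t stream) {
  // Output graph; from_knn_symmetrize_matrix allocates it to 2 * n * k entries
  MLCommon::Sparse::COO<float, int64_t> sym(d_alloc, stream);
  MLCommon::Sparse::from_knn_symmetrize_matrix(knn_indices, knn_dists, n, k,
                                               &sym, stream, d_alloc);
  // sym now holds both (i, j) and (j, i) for every KNN edge, with the
  // same distance value on each
}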