From 0cc01820ec11267f26053bd4e8d57b0ec46dc29e Mon Sep 17 00:00:00 2001
From: viclafargue
Date: Tue, 14 Sep 2021 17:31:44 +0200
Subject: [PATCH 01/12] TSNE KL divergence

---
 cpp/include/cuml/manifold/tsne.h    | 40 +++++++++++----------
 cpp/src/tsne/barnes_hut_kernels.cuh |  6 +++-
 cpp/src/tsne/barnes_hut_tsne.cuh    | 23 +++++++-----
 cpp/src/tsne/exact_kernels.cuh      | 20 ++++++++---
 cpp/src/tsne/exact_tsne.cuh         | 36 +++++++++++++------
 cpp/src/tsne/fft_kernels.cuh        | 11 +++---
 cpp/src/tsne/fft_tsne.cuh           | 30 ++++++++++------
 cpp/src/tsne/tsne.cu                | 55 +++++++++++++++--------------
 cpp/src/tsne/tsne_runner.cuh        | 10 +++---
 python/cuml/manifold/t_sne.pyx      | 26 ++++++++------
 10 files changed, 158 insertions(+), 99 deletions(-)

diff --git a/cpp/include/cuml/manifold/tsne.h b/cpp/include/cuml/manifold/tsne.h
index 6bd9ecb953..f18c86ca28 100644
--- a/cpp/include/cuml/manifold/tsne.h
+++ b/cpp/include/cuml/manifold/tsne.h
@@ -117,6 +117,7 @@ struct TSNEParams {
 * @param[in] knn_indices Array containing nearest neighbors indices.
 * @param[in] knn_dists Array containing nearest neighbors distances.
 * @param[in] params Parameters for TSNE model
+ * @return The Kullback–Leibler divergence
 *
 * The CUDA implementation is derived from the excellent CannyLabs open source
 * implementation here: https://github.com/CannyLab/tsne-cuda/. The CannyLabs
 * code is licensed according to the conditions in
 * cpp/src/tsne/cannylabs_tsne_license.txt. A full description of their
 * approach is available in their article t-SNE-CUDA: GPU-Accelerated t-SNE and
 * its Applications to Modern Data (https://arxiv.org/abs/1807.11824).
 */
-void TSNE_fit(const raft::handle_t& handle,
-              float* X,
-              float* Y,
-              int n,
-              int p,
-              int64_t* knn_indices,
-              float* knn_dists,
-              TSNEParams& params);
+float TSNE_fit(const raft::handle_t& handle,
+               float* X,
+               float* Y,
+               int n,
+               int p,
+               int64_t* knn_indices,
+               float* knn_dists,
+               TSNEParams& params);

/**
 * @brief Dimensionality reduction via TSNE using either Barnes Hut O(NlogN)
 *        or brute force O(N^2).
 *
 * @param[in] handle The GPU handle.
 * @param[in] indptr indptr of CSR dataset.
 * @param[in] indices indices of CSR dataset.
 * @param[in] data data of CSR dataset.
 * @param[out] Y The column-major final embedding in device memory
 * @param[in] nnz The number of non-zero entries in the CSR.
 * @param[in] n Number of rows in data X.
 * @param[in] p Number of columns in data X.
 * @param[in] knn_indices Array containing nearest neighbors indices.
 * @param[in] knn_dists Array containing nearest neighbors distances.
 * @param[in] params Parameters for TSNE model
+ * @return The Kullback–Leibler divergence
 *
 * The CUDA implementation is derived from the excellent CannyLabs open source
 * implementation here: https://github.com/CannyLab/tsne-cuda/. The CannyLabs
 * code is licensed according to the conditions in
 * cpp/src/tsne/cannylabs_tsne_license.txt. A full description of their
 * approach is available in their article t-SNE-CUDA: GPU-Accelerated t-SNE and
 * its Applications to Modern Data (https://arxiv.org/abs/1807.11824).
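+ * The returned value is an estimate of KL(P || Q) = sum_ij P_ij * log(P_ij / Q_ij),
+ * accumulated over the k-NN graph edges on the final optimization iteration.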
 */
-void TSNE_fit_sparse(const raft::handle_t& handle,
-                     int* indptr,
-                     int* indices,
-                     float* data,
-                     float* Y,
-                     int nnz,
-                     int n,
-                     int p,
-                     int* knn_indices,
-                     float* knn_dists,
-                     TSNEParams& params);
+float TSNE_fit_sparse(const raft::handle_t& handle,
+                      int* indptr,
+                      int* indices,
+                      float* data,
+                      float* Y,
+                      int nnz,
+                      int n,
+                      int p,
+                      int* knn_indices,
+                      float* knn_dists,
+                      TSNEParams& params);

}  // namespace ML
diff --git a/cpp/src/tsne/barnes_hut_kernels.cuh b/cpp/src/tsne/barnes_hut_kernels.cuh
index 2680d3b456..521558d3a5 100644
--- a/cpp/src/tsne/barnes_hut_kernels.cuh
+++ b/cpp/src/tsne/barnes_hut_kernels.cuh
@@ -681,6 +681,7 @@ __global__ void attractive_kernel_bh(const value_t* restrict VAL,
                                     const value_t* restrict Y2,
                                     value_t* restrict attract1,
                                     value_t* restrict attract2,
+                                    value_t* restrict kl_divergences,
                                     const value_idx NNZ)
{
  const auto index = (blockIdx.x * blockDim.x) + threadIdx.x;
@@ -694,7 +695,9 @@ __global__ void attractive_kernel_bh(const value_t* restrict VAL,
  // As a sum of squares, SED is mathematically >= 0. There might be a source of
  // NaNs upstream though, so until we find and fix them, enforce that trait.
  if (!(squared_euclidean_dist >= 0)) squared_euclidean_dist = 0.0f;
-  const value_t PQ = __fdividef(VAL[index], squared_euclidean_dist + 1.0f);
+  const value_t P  = VAL[index];
+  const value_t Q  = __fdividef(1.0f, squared_euclidean_dist + 1.0f);  // without normalization
+  const value_t PQ = P * Q;

  // TODO: Calculate Kullback-Leibler divergence
  // TODO: Convert attractive forces to CSR format
@@ -702,6 +705,7 @@ __global__ void attractive_kernel_bh(const value_t* restrict VAL,
  // Apply forces
  atomicAdd(&attract1[i], PQ * (Y1[i] - Y1[j]));
  atomicAdd(&attract2[i], PQ * (Y2[i] - Y2[j]));
+  atomicAdd(&kl_divergences[i], P * log(P / Q));
}

/**
diff --git a/cpp/src/tsne/barnes_hut_tsne.cuh b/cpp/src/tsne/barnes_hut_tsne.cuh
index 8170d89eb4..fb170b1494 100644
--- a/cpp/src/tsne/barnes_hut_tsne.cuh
+++ b/cpp/src/tsne/barnes_hut_tsne.cuh
@@ -36,14 +36,14 @@ namespace TSNE {
 */
template <typename value_idx, typename value_t>
-void Barnes_Hut(value_t* VAL,
-                const value_idx* COL,
-                const value_idx* ROW,
-                const value_idx NNZ,
-                const raft::handle_t& handle,
-                value_t* Y,
-                const value_idx n,
-                const TSNEParams& params)
+value_t Barnes_Hut(value_t* VAL,
+                   const value_idx* COL,
+                   const value_idx* ROW,
+                   const value_idx NNZ,
+                   const raft::handle_t& handle,
+                   value_t* Y,
+                   const value_idx n,
+                   const TSNEParams& params)
{
  cudaStream_t stream = handle.get_stream();

@@ -120,6 +120,8 @@ void Barnes_Hut(value_t* VAL,
    raft::copy(YY.data() + nnodes + 1, Y + n, n, stream);
  }

+  rmm::device_uvector<value_t> kl_divergences(n, stream);
+
  // Set cache levels for faster algorithm execution
  //---------------------------------------------------
  CUDA_CHECK(
@@ -273,6 +275,7 @@ void Barnes_Hut(value_t* VAL,
    YY.data() + nnodes + 1,
    attr_forces.data(),
    attr_forces.data() + n,
+    kl_divergences.data(),
    NNZ);
  CUDA_CHECK(cudaPeekAtLastError());
  END_TIMER(attractive_time);
@@ -302,6 +305,10 @@ void Barnes_Hut(value_t* VAL,
  // Copy final YY into true output Y
  raft::copy(Y, YY.data(), n, stream);
  raft::copy(Y + n, YY.data() + nnodes + 1, n, stream);
+
+  value_t kl_div =
+    thrust::reduce(handle.get_thrust_policy(), kl_divergences.begin(), kl_divergences.end());
+  return kl_div;
}

}  // namespace TSNE
diff --git a/cpp/src/tsne/exact_kernels.cuh b/cpp/src/tsne/exact_kernels.cuh
index 177f57008e..968875e914 100644
--- a/cpp/src/tsne/exact_kernels.cuh
+++ b/cpp/src/tsne/exact_kernels.cuh
@@ -172,6 +172,7 @@ __global__ void attractive_kernel(const value_t* restrict VAL,
                                  const value_t* restrict Y,
                                  const value_t* restrict norm,
                                  value_t* restrict attract,
+                                  value_t* restrict kl_divergences,
                                  const value_idx NNZ,
                                  const value_idx n,
                                  const value_idx dim,
@@ -192,11 +193,15 @@ __global__ void attractive_kernel(const value_t* restrict VAL,

  // TODO: Calculate Kullback-Leibler divergence
  // #863
-  const value_t PQ = VAL[index] * __powf((1.0f + euclidean_d * recp_df), df_power);  // P*Q
+  const value_t P  = VAL[index];
+  const value_t Q  = __powf((1.0f + euclidean_d * recp_df), df_power);  // without normalization
+  const value_t PQ = P * Q;

  // Apply forces
-  for (int k = 0; k < dim; k++)
+  for (int k = 0; k < dim; k++) {
    raft::myAtomicAdd(&attract[k * n + i], PQ * (Y[k * n + i] - Y[k * n + j]));
+    raft::myAtomicAdd(&kl_divergences[i], P * log(P / Q));
+  }
}

/****************************************/
@@ -211,6 +216,7 @@ __global__ void attractive_kernel_2d(const value_t* restrict VAL,
                                     const value_t* restrict norm,
                                     value_t* restrict attract1,
                                     value_t* restrict attract2,
+                                     value_t* restrict kl_divergences,
                                     const value_idx NNZ)
{
  const auto index = (blockIdx.x * blockDim.x) + threadIdx.x;
@@ -224,11 +230,14 @@ __global__ void attractive_kernel_2d(const value_t* restrict VAL,

  // TODO: Calculate Kullback-Leibler divergence
  // #863
-  const value_t PQ = __fdividef(VAL[index], (1.0f + euclidean_d));  // P*Q
+  const value_t P  = VAL[index];
+  const value_t Q  = __fdividef(1.0f, (1.0f + euclidean_d));  // without normalization
+  const value_t PQ = P * Q;

  // Apply forces
  raft::myAtomicAdd(&attract1[i], PQ * (Y1[i] - Y1[j]));
  raft::myAtomicAdd(&attract2[i], PQ * (Y2[i] - Y2[j]));
+  raft::myAtomicAdd(&kl_divergences[i], P * log(P / Q));
}

/****************************************/
@@ -239,6 +248,7 @@ void attractive_forces(const value_t* restrict VAL,
                       const value_t* restrict Y,
                       const value_t* restrict norm,
                       value_t* restrict attract,
+                       value_t* restrict kl_divergences,
                       const value_idx NNZ,
                       const value_idx n,
                       const value_idx dim,
@@ -253,12 +263,12 @@ void attractive_forces(const value_t* restrict VAL,
  // For general embedding dimensions
  if (dim != 2) {
    attractive_kernel<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0, stream>>>(
-      VAL, COL, ROW, Y, norm, attract, NNZ, n, dim, df_power, recp_df);
+      VAL, COL, ROW, Y, norm, attract, kl_divergences, NNZ, n, dim, df_power, recp_df);
  }
  // For special case dim == 2
  else {
    attractive_kernel_2d<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0, stream>>>(
-      VAL, COL, ROW, Y, Y + n, norm, attract, attract + n, NNZ);
+      VAL, COL, ROW, Y, Y + n, norm, attract, attract + n, kl_divergences, NNZ);
  }
  CUDA_CHECK(cudaPeekAtLastError());
}
diff --git a/cpp/src/tsne/exact_tsne.cuh b/cpp/src/tsne/exact_tsne.cuh
index b65e269f7d..5bb4f5477b 100644
--- a/cpp/src/tsne/exact_tsne.cuh
+++ b/cpp/src/tsne/exact_tsne.cuh
@@ -35,14 +35,14 @@ namespace TSNE {
 * @param[in] params: Parameters for TSNE model.
 */
template <typename value_idx, typename value_t>
-void Exact_TSNE(value_t* VAL,
-                const value_idx* COL,
-                const value_idx* ROW,
-                const value_idx NNZ,
-                const raft::handle_t& handle,
-                value_t* Y,
-                const value_idx n,
-                const TSNEParams& params)
+value_t Exact_TSNE(value_t* VAL,
+                   const value_idx* COL,
+                   const value_idx* ROW,
+                   const value_idx NNZ,
+                   const raft::handle_t& handle,
+                   value_t* Y,
+                   const value_idx n,
+                   const TSNEParams& params)
{
  cudaStream_t stream = handle.get_stream();
  const value_idx dim = params.dim;
@@ -69,6 +69,7 @@ void Exact_TSNE(value_t* VAL,
  thrust::fill(thrust::cuda::par.on(stream), begin, begin + n * dim, 1.0f);

  rmm::device_uvector<value_t> gradient(n * dim, stream);
+  rmm::device_uvector<value_t> kl_divergences(n, stream);

  //---------------------------------------------------
  // Calculate degrees of freedom
@@ -97,8 +98,19 @@ void Exact_TSNE(value_t* VAL,
      raft::linalg::rowNorm(norm.data(), Y, dim, n, raft::linalg::L2Norm, false, stream);

      // Compute attractive forces
-      TSNE::attractive_forces(
-        VAL, COL, ROW, Y, norm.data(), attract.data(), NNZ, n, dim, df_power, recp_df, stream);
+      TSNE::attractive_forces(VAL,
+                              COL,
+                              ROW,
+                              Y,
+                              norm.data(),
+                              attract.data(),
+                              kl_divergences.data(),
+                              NNZ,
+                              n,
+                              dim,
+                              df_power,
+                              recp_df,
+                              stream);
      // Compute repulsive forces
      const float Z = TSNE::repulsive_forces(
        Y, repel.data(), norm.data(), Z_sum.data(), n, dim, df_power, recp_df, stream);
@@ -139,6 +151,10 @@ void Exact_TSNE(value_t* VAL,
      if (iter % 100 == 0) { CUML_LOG_DEBUG("Z at iter = %d = %f", iter, Z); }
    }
  }
+
+  value_t kl_div =
+    thrust::reduce(handle.get_thrust_policy(), kl_divergences.begin(), kl_divergences.end());
+  return kl_div;
}

}  // namespace TSNE
diff --git a/cpp/src/tsne/fft_kernels.cuh b/cpp/src/tsne/fft_kernels.cuh
index fbc2aabb96..c939dfbffc 100644
--- a/cpp/src/tsne/fft_kernels.cuh
+++ b/cpp/src/tsne/fft_kernels.cuh
@@ -307,6 +307,7 @@ __global__ void compute_repulsive_forces_kernel(
template <typename value_idx, typename value_t>
__global__ void compute_Pij_x_Qij_kernel(value_t* __restrict__ attr_forces,
+                                         value_t* __restrict__ kl_divergences,
                                         const value_t* __restrict__ pij,
                                         const value_idx* __restrict__ coo_rows,
                                         const value_idx* __restrict__ coo_cols,
@@ -327,11 +328,13 @@ __global__ void compute_Pij_x_Qij_kernel(value_t* __restrict__ attr_forces,

  value_t dx = ix - jx;
  value_t dy = iy - jy;
-  value_t denom = 1 + (dx * dx) + (dy * dy);
+  const value_t P  = pij[TID];
+  const value_t Q  = __fdividef(1.0f, 1 + (dx * dx) + (dy * dy));  // without normalization
+  const value_t PQ = P * Q;

-  value_t pijqij = pij[TID] / denom;
-  atomicAdd(attr_forces + i, pijqij * dx);
-  atomicAdd(attr_forces + num_points + i, pijqij * dy);
+  atomicAdd(attr_forces + i, PQ * dx);
+  atomicAdd(attr_forces + num_points + i, PQ * dy);
+  if (kl_divergences) atomicAdd(kl_divergences + i, P * log(P / Q));
}

template
diff --git a/cpp/src/tsne/fft_tsne.cuh b/cpp/src/tsne/fft_tsne.cuh
index b7e6b54009..8f448fdec7 100644
--- a/cpp/src/tsne/fft_tsne.cuh
+++ b/cpp/src/tsne/fft_tsne.cuh
@@ -152,14 +152,14 @@ std::pair<value_t, value_t> min_max(const value_t* Y, const value_idx n, cudaStr
 * @param[in] params: Parameters for TSNE model.
 */
template <typename value_idx, typename value_t>
-void FFT_TSNE(value_t* VAL,
-              const value_idx* COL,
-              const value_idx* ROW,
-              const value_idx NNZ,
-              const raft::handle_t& handle,
-              value_t* Y,
-              const value_idx n,
-              const TSNEParams& params)
+value_t FFT_TSNE(value_t* VAL,
+                 const value_idx* COL,
+                 const value_idx* ROW,
+                 const value_idx NNZ,
+                 const raft::handle_t& handle,
+                 value_t* Y,
+                 const value_idx n,
+                 const TSNEParams& params)
{
  auto stream        = handle.get_stream();
  auto thrust_policy = handle.get_thrust_policy();
@@ -334,6 +334,7 @@ void FFT_TSNE(value_t* VAL,
    random_vector(Y, 0.0000f, 0.0001f, n * 2, stream, params.random_state);
  }

+  value_t kl_div = 0;
  for (int iter = 0; iter < params.max_iter; iter++) {
    // Compute charges Q_ij
    {
@@ -513,8 +514,16 @@ void FFT_TSNE(value_t* VAL,
    // Compute attractive forces
    {
      auto num_blocks = raft::ceildiv(NNZ, (value_idx)NTHREADS_1024);
-      FFT::compute_Pij_x_Qij_kernel<<<num_blocks, NTHREADS_1024, 0, stream>>>(
-        attractive_forces_device.data(), VAL, ROW, COL, Y, n, NNZ);
+      if (iter != params.max_iter - 1) {  // not last iter
+        FFT::compute_Pij_x_Qij_kernel<<<num_blocks, NTHREADS_1024, 0, stream>>>(
+          attractive_forces_device.data(), (value_t*)nullptr, VAL, ROW, COL, Y, n, NNZ);
+      } else {  // last iteration
+        rmm::device_uvector<value_t> kl_divergences(n, stream);
+        FFT::compute_Pij_x_Qij_kernel<<<num_blocks, NTHREADS_1024, 0, stream>>>(
+          attractive_forces_device.data(), kl_divergences.data(), VAL, ROW, COL, Y, n, NNZ);
+        kl_div =
+          thrust::reduce(handle.get_thrust_policy(), kl_divergences.begin(), kl_divergences.end());
+      }
    }

    // Apply Forces
@@ -572,6 +581,7 @@ void FFT_TSNE(value_t* VAL,
  CUFFT_TRY(cufftDestroy(plan_kernel_tilde));
  CUFFT_TRY(cufftDestroy(plan_dft));
  CUFFT_TRY(cufftDestroy(plan_idft));
+  return kl_div;
}

}  // namespace TSNE
diff --git a/cpp/src/tsne/tsne.cu b/cpp/src/tsne/tsne.cu
index 8f389d7437..41d83b7270 100644
--- a/cpp/src/tsne/tsne.cu
+++ b/cpp/src/tsne/tsne.cu
@@ -20,24 +20,24 @@
namespace ML {

template <typename tsne_input, typename knn_indices_t, typename value_t>
-void _fit(const raft::handle_t& handle,
-          tsne_input& input,
-          knn_graph<knn_indices_t, value_t>& k_graph,
-          TSNEParams& params)
+value_t _fit(const raft::handle_t& handle,
+             tsne_input& input,
+             knn_graph<knn_indices_t, value_t>& k_graph,
+             TSNEParams& params)
{
  TSNE_runner<tsne_input, knn_indices_t, value_t> runner(handle, input, k_graph, params);
-  runner.run();
+  return runner.run();  // returns the Kullback–Leibler divergence
}

-void TSNE_fit(const raft::handle_t& handle,
-              float* X,
-              float* Y,
-              int n,
-              int p,
-              int64_t* knn_indices,
-              float* knn_dists,
-              TSNEParams& params)
+float TSNE_fit(const raft::handle_t& handle,
+               float* X,
+               float* Y,
+               int n,
+               int p,
+               int64_t* knn_indices,
+               float* knn_dists,
+               TSNEParams& params)
{
  ASSERT(n > 0 && p > 0 && params.dim > 0 && params.n_neighbors > 0 && X != NULL && Y != NULL,
         "Wrong input args");

  manifold_dense_inputs_t<float> input(X, Y, n, p);
  knn_graph<knn_indices_dense_t, float> k_graph(n, params.n_neighbors, knn_indices, knn_dists);

-  _fit<manifold_dense_inputs_t<float>, knn_indices_dense_t, float>(handle, input, k_graph, params);
+  return _fit<manifold_dense_inputs_t<float>, knn_indices_dense_t, float>(
+    handle, input, k_graph, params);
+  // returns the Kullback–Leibler divergence
}

-void TSNE_fit_sparse(const raft::handle_t& handle,
-                     int* indptr,
-                     int* indices,
-                     float* data,
-                     float* Y,
-                     int nnz,
-                     int n,
-                     int p,
-                     int* knn_indices,
-                     float* knn_dists,
-                     TSNEParams& params)
+float TSNE_fit_sparse(const raft::handle_t& handle,
+                      int* indptr,
+                      int* indices,
+                      float* data,
+                      float* Y,
+                      int nnz,
+                      int n,
+                      int p,
+                      int* knn_indices,
+                      float* knn_dists,
+                      TSNEParams& params)
{
  ASSERT(n > 0 && p > 0 && params.dim > 0 && params.n_neighbors > 0 && indptr != NULL &&
           indices != NULL && data != NULL && Y != NULL,
         "Wrong input args");

  manifold_sparse_inputs_t<knn_indices_sparse_t, float> input(indptr, indices, data, Y, nnz, n, p);
  knn_graph<knn_indices_sparse_t, float> k_graph(n, params.n_neighbors, knn_indices, knn_dists);

-  _fit<manifold_sparse_inputs_t<knn_indices_sparse_t, float>, knn_indices_sparse_t, float>(
+  return _fit<manifold_sparse_inputs_t<knn_indices_sparse_t, float>, knn_indices_sparse_t, float>(
    handle, input, k_graph, params);
+  // returns the Kullback–Leibler divergence
}

}  // namespace ML
diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh
index 3ff8f322bc..811684538d 100644
--- a/cpp/src/tsne/tsne_runner.cuh
+++ b/cpp/src/tsne/tsne_runner.cuh
@@ -74,7 +74,7 @@ class TSNE_runner {
                     " might be a bit strange...");
    }

-  void run()
+  value_t run()
  {
    distance_and_perplexity();

@@ -86,11 +86,11 @@ class TSNE_runner {

    switch (params.algorithm) {
      case TSNE_ALGORITHM::BARNES_HUT:
-        TSNE::Barnes_Hut(VAL, COL, ROW, NNZ, handle, Y, n, params);
-        break;
-      case TSNE_ALGORITHM::FFT: TSNE::FFT_TSNE(VAL, COL, ROW, NNZ, handle, Y, n, params); break;
-      case TSNE_ALGORITHM::EXACT: TSNE::Exact_TSNE(VAL, COL, ROW, NNZ, handle, Y, n, params); break;
+        return TSNE::Barnes_Hut(VAL, COL, ROW, NNZ, handle, Y, n, params);
+      case TSNE_ALGORITHM::FFT: return TSNE::FFT_TSNE(VAL, COL, ROW, NNZ, handle, Y, n, params);
+      case TSNE_ALGORITHM::EXACT: return TSNE::Exact_TSNE(VAL, COL, ROW, NNZ, handle, Y, n, params);
    }
+    return 0;
  }

 private:
diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx
index 03aa7d8c6b..5bd41adee0 100644
--- a/python/cuml/manifold/t_sne.pyx
+++ b/python/cuml/manifold/t_sne.pyx
@@ -84,7 +84,7 @@ cdef extern from "cuml/manifold/tsne.h" namespace "ML":

cdef extern from "cuml/manifold/tsne.h" namespace "ML":

-    cdef void TSNE_fit(
+    cdef float TSNE_fit(
        handle_t &handle,
        float *X,
        float *Y,
@@ -94,7 +94,7 @@ cdef extern from "cuml/manifold/tsne.h" namespace "ML":
        float* knn_dists,
        TSNEParams &params) except +

-    cdef void TSNE_fit_sparse(
+    cdef float TSNE_fit_sparse(
        const handle_t &handle,
        int *indptr,
        int *indices,
@@ -493,8 +493,9 @@ class TSNE(Base,
        cdef TSNEParams* params = \
            self._build_tsne_params(algo)

+        cdef float kl_divergence = 0
        if self.sparse_fit:
-            TSNE_fit_sparse(handle_[0],
+            kl_divergence = TSNE_fit_sparse(handle_[0],
                            self.X_m.indptr.ptr,
                            self.X_m.indices.ptr,
                            self.X_m.data.ptr,
@@ -506,18 +507,21 @@ class TSNE(Base,
                            knn_dists_raw,
                            deref(params))
        else:
-            TSNE_fit(handle_[0],
-                     self.X_m.ptr,
-                     embed_ptr,
-                     n,
-                     p,
-                     knn_indices_raw,
-                     knn_dists_raw,
-                     deref(params))
+            kl_divergence = TSNE_fit(handle_[0],
+                                     self.X_m.ptr,
+                                     embed_ptr,
+                                     n,
+                                     p,
+                                     knn_indices_raw,
+                                     knn_dists_raw,
+                                     deref(params))

        self.handle.sync()
        free(params)

+        self.kl_divergence_ = kl_divergence
+        if self.verbose:
+            print("[t-SNE] KL divergence: {}".format(kl_divergence))
        return self

    @generate_docstring(convert_dtype_cast='np.float32',
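A minimal sketch of a C++ caller consuming the new return value (not part of the
patch; assumes a prepared raft::handle_t and device buffers, with the k-NN graph
computed internally when the last two pointers are NULL, which the ASSERT above
permits):

    ML::TSNEParams params;
    params.dim = 2;
    float kl = ML::TSNE_fit(handle, X, Y, n, p, nullptr, nullptr, params);
    // kl estimates sum_ij P_ij * log(P_ij / Q_ij) from the final iteration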
From 6ebd360e4021326865cec412c34a8226d2067bd4 Mon Sep 17 00:00:00 2001
From: viclafargue
Date: Wed, 15 Sep 2021 19:03:00 +0200
Subject: [PATCH 02/12] Adding normalization of Q

---
 cpp/src/tsne/barnes_hut_kernels.cuh | 19 ++++++------
 cpp/src/tsne/barnes_hut_tsne.cuh    | 18 ++++++++++++-
 cpp/src/tsne/exact_kernels.cuh      | 40 ++++++++++++++++------------
 cpp/src/tsne/exact_tsne.cuh         | 20 ++++++++++++++-
 cpp/src/tsne/fft_kernels.cuh        | 16 +++++++-----
 cpp/src/tsne/fft_tsne.cuh           | 27 +++++++++++++++----
 cpp/src/tsne/utils.cuh              | 23 +++++++++++++++++
 7 files changed, 124 insertions(+), 39 deletions(-)

diff --git a/cpp/src/tsne/barnes_hut_kernels.cuh b/cpp/src/tsne/barnes_hut_kernels.cuh
index 521558d3a5..334bf5ba5a 100644
--- a/cpp/src/tsne/barnes_hut_kernels.cuh
+++ b/cpp/src/tsne/barnes_hut_kernels.cuh
@@ -681,7 +681,8 @@ __global__ void attractive_kernel_bh(const value_t* restrict VAL,
                                     const value_t* restrict Y2,
                                     value_t* restrict attract1,
                                     value_t* restrict attract2,
-                                    value_t* restrict kl_divergences,
+                                    value_t* restrict Qs,
+                                    value_t* restrict Qs_norm,
                                     const value_idx NNZ)
{
  const auto index = (blockIdx.x * blockDim.x) + threadIdx.x;
@@ -695,17 +696,19 @@ __global__ void attractive_kernel_bh(const value_t* restrict VAL,
  // As a sum of squares, SED is mathematically >= 0. There might be a source of
  // NaNs upstream though, so until we find and fix them, enforce that trait.
  if (!(squared_euclidean_dist >= 0)) squared_euclidean_dist = 0.0f;
-  const value_t P  = VAL[index];
-  const value_t Q  = __fdividef(1.0f, squared_euclidean_dist + 1.0f);  // without normalization
-  const value_t PQ = P * Q;
-
-  // TODO: Calculate Kullback-Leibler divergence
-  // TODO: Convert attractive forces to CSR format
+  const value_t PQ = __fdividef(VAL[index], squared_euclidean_dist + 1.0f);

  // Apply forces
  atomicAdd(&attract1[i], PQ * (Y1[i] - Y1[j]));
  atomicAdd(&attract2[i], PQ * (Y2[i] - Y2[j]));
-  atomicAdd(&kl_divergences[i], P * log(P / Q));
+
+  if (Qs) {  // check if KL div calculation is necessary
+    const value_t Q_unnormalized = __expf(-squared_euclidean_dist);
+    Qs[index]                    = Q_unnormalized;
+    atomicAdd(&Qs_norm[i], Q_unnormalized);
+  }
+
+  // TODO: Convert attractive forces to CSR format
}

/**
diff --git a/cpp/src/tsne/barnes_hut_tsne.cuh b/cpp/src/tsne/barnes_hut_tsne.cuh
index fb170b1494..05465d8bb7 100644
--- a/cpp/src/tsne/barnes_hut_tsne.cuh
+++ b/cpp/src/tsne/barnes_hut_tsne.cuh
@@ -120,6 +120,8 @@ value_t Barnes_Hut(value_t* VAL,
    raft::copy(YY.data() + nnodes + 1, Y + n, n, stream);
  }

+  rmm::device_uvector<value_t> Qs(NNZ, stream);
+  rmm::device_uvector<value_t> Qs_norm(n, stream);
  rmm::device_uvector<value_t> kl_divergences(n, stream);

  // Set cache levels for faster algorithm execution
@@ -267,6 +269,13 @@ value_t Barnes_Hut(value_t* VAL,
  START_TIMER;
  // TODO: Calculate Kullback-Leibler divergence
  // For general embedding dimensions
+  bool last_iter = iter == params.max_iter - 1;
+  if (last_iter) {
+    CUDA_CHECK(cudaMemsetAsync(Qs_norm.data(), 0, Qs_norm.size() * sizeof(value_t), stream));
+    CUDA_CHECK(
+      cudaMemsetAsync(kl_divergences.data(), 0, kl_divergences.size() * sizeof(value_t), stream));
+  }
+
  BH::attractive_kernel_bh<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0, stream>>>(
    VAL,
    COL,
    ROW,
    YY.data(),
    YY.data() + nnodes + 1,
    attr_forces.data(),
    attr_forces.data() + n,
-    kl_divergences.data(),
+    last_iter ? Qs.data() : nullptr,
+    last_iter ? Qs_norm.data() : nullptr,
    NNZ);
  CUDA_CHECK(cudaPeekAtLastError());
  END_TIMER(attractive_time);
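+  // KL(P || Q) is only needed once, so the per-edge terms are evaluated on the
+  // last iteration: Qs holds exp(-d_ij^2) per edge and Qs_norm accumulates its
+  // row-wise sum, which compute_kl_div uses to normalize Q.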
+  if (last_iter) {
+    compute_kl_div<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0, stream>>>(
+      VAL, ROW, Qs.data(), Qs_norm.data(), kl_divergences.data(), NNZ);
+    CUDA_CHECK(cudaPeekAtLastError());
+  }
+
  START_TIMER;
  BH::IntegrationKernel<<<blocks * FACTOR6, THREADS6, 0, stream>>>(learning_rate,
                                                                   momentum,
diff --git a/cpp/src/tsne/exact_kernels.cuh b/cpp/src/tsne/exact_kernels.cuh
index 968875e914..0bc653a9da 100644
--- a/cpp/src/tsne/exact_kernels.cuh
+++ b/cpp/src/tsne/exact_kernels.cuh
@@ -172,7 +172,8 @@ __global__ void attractive_kernel(const value_t* restrict VAL,
                                  const value_t* restrict Y,
                                  const value_t* restrict norm,
                                  value_t* restrict attract,
-                                  value_t* restrict kl_divergences,
+                                  value_t* restrict Qs,
+                                  value_t* restrict Qs_norm,
                                  const value_idx NNZ,
                                  const value_idx n,
                                  const value_idx dim,
@@ -192,11 +193,15 @@ __global__ void attractive_kernel(const value_t* restrict VAL,
    d += Y[k * n + i] * Y[k * n + j];
  const value_t euclidean_d = -2.0f * d + norm[i] + norm[j];

-  // TODO: Calculate Kullback-Leibler divergence
-  // #863
-  const value_t P  = VAL[index];
-  const value_t Q  = __powf((1.0f + euclidean_d * recp_df), df_power);  // without normalization
-  const value_t PQ = P * Q;
+  const value_t PQ = __fdividef(VAL[index], __powf((1.0f + euclidean_d * recp_df), df_power));

  // Apply forces
  for (int k = 0; k < dim; k++) {
    raft::myAtomicAdd(&attract[k * n + i], PQ * (Y[k * n + i] - Y[k * n + j]));
-    raft::myAtomicAdd(&kl_divergences[i], P * log(P / Q));
+  }
+
+  if (Qs) {  // check if KL div calculation is necessary
+    const value_t Q_unnormalized = __expf(-euclidean_d);
+    Qs[index]                    = Q_unnormalized;
+    atomicAdd(&Qs_norm[i], Q_unnormalized);
  }
}

@@ -211,7 +216,8 @@ __global__ void attractive_kernel_2d(const value_t* restrict VAL,
                                     const value_t* restrict norm,
                                     value_t* restrict attract1,
                                     value_t* restrict attract2,
-                                     value_t* restrict kl_divergences,
+                                     value_t* restrict Qs,
+                                     value_t* restrict Qs_norm,
                                     const value_idx NNZ)
{
  const auto index = (blockIdx.x * blockDim.x) + threadIdx.x;
@@ -224,11 +230,14 @@ __global__ void attractive_kernel_2d(const value_t* restrict VAL,
  // TODO: can provide any distance ie cosine
  // #862
  const value_t euclidean_d = norm[i] + norm[j] - 2.0f * (Y1[i] * Y1[j] + Y2[i] * Y2[j]);
-
-  // TODO: Calculate Kullback-Leibler divergence
-  // #863
-  const value_t P  = VAL[index];
-  const value_t Q  = __fdividef(1.0f, (1.0f + euclidean_d));  // without normalization
-  const value_t PQ = P * Q;
+  const value_t PQ          = __fdividef(VAL[index], (1.0f + euclidean_d));

  // Apply forces
  raft::myAtomicAdd(&attract1[i], PQ * (Y1[i] - Y1[j]));
  raft::myAtomicAdd(&attract2[i], PQ * (Y2[i] - Y2[j]));
-  raft::myAtomicAdd(&kl_divergences[i], P * log(P / Q));
+
+  if (Qs) {  // check if KL div calculation is necessary
+    const value_t Q_unnormalized = __expf(-euclidean_d);
+    Qs[index]                    = Q_unnormalized;
+    atomicAdd(&Qs_norm[i], Q_unnormalized);
+  }
}

/****************************************/
@@ -248,7 +251,8 @@ void attractive_forces(const value_t* restrict VAL,
                       const value_t* restrict Y,
                       const value_t* restrict norm,
                       value_t* restrict attract,
-                       value_t* restrict kl_divergences,
+                       value_t* restrict Qs,
+                       value_t* restrict Qs_norm,
                       const value_idx NNZ,
                       const value_idx n,
                       const value_idx dim,
@@ -263,12 +267,12 @@ void attractive_forces(const value_t* restrict VAL,
  // For general embedding dimensions
  if (dim != 2) {
    attractive_kernel<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0, stream>>>(
-      VAL, COL, ROW, Y, norm, attract, kl_divergences, NNZ, n, dim, df_power, recp_df);
+      VAL, COL, ROW, Y, norm, attract, Qs, Qs_norm, NNZ, n, dim, df_power, recp_df);
  }
  // For special case dim == 2
  else {
    attractive_kernel_2d<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0, stream>>>(
-      VAL, COL, ROW, Y, Y + n, norm, attract, attract + n, kl_divergences, NNZ);
+      VAL, COL, ROW, Y, Y + n, norm, attract, attract + n, Qs, Qs_norm, NNZ);
  }
  CUDA_CHECK(cudaPeekAtLastError());
}
diff --git a/cpp/src/tsne/exact_tsne.cuh b/cpp/src/tsne/exact_tsne.cuh
index 5bb4f5477b..3bf1eb22e9 100644
--- a/cpp/src/tsne/exact_tsne.cuh
+++ b/cpp/src/tsne/exact_tsne.cuh
@@ -69,6 +69,9 @@ value_t Exact_TSNE(value_t* VAL,
  thrust::fill(thrust::cuda::par.on(stream), begin, begin + n * dim, 1.0f);

  rmm::device_uvector<value_t> gradient(n * dim, stream);
+
+  rmm::device_uvector<value_t> Qs(NNZ, stream);
+  rmm::device_uvector<value_t> Qs_norm(n, stream);
  rmm::device_uvector<value_t> kl_divergences(n, stream);

  //---------------------------------------------------
@@ -97,6 +100,14 @@ value_t Exact_TSNE(value_t* VAL,
      // Get row norm of Y
      raft::linalg::rowNorm(norm.data(), Y, dim, n, raft::linalg::L2Norm, false, stream);

+      bool last_iter = iter == params.max_iter - 1;
+
+      if (last_iter) {
+        CUDA_CHECK(cudaMemsetAsync(Qs_norm.data(), 0, Qs_norm.size() * sizeof(value_t), stream));
+        CUDA_CHECK(
+          cudaMemsetAsync(kl_divergences.data(), 0, kl_divergences.size() * sizeof(value_t), stream));
+      }
+
      // Compute attractive forces
      TSNE::attractive_forces(VAL,
                              COL,
@@ -104,13 +115,20 @@ value_t Exact_TSNE(value_t* VAL,
                              Y,
                              norm.data(),
                              attract.data(),
-                              kl_divergences.data(),
+                              last_iter ? Qs.data() : nullptr,
+                              last_iter ? Qs_norm.data() : nullptr,
                              NNZ,
                              n,
                              dim,
                              df_power,
                              recp_df,
                              stream);
+
+      if (last_iter) {
+        compute_kl_div<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0, stream>>>(
+          VAL, ROW, Qs.data(), Qs_norm.data(), kl_divergences.data(), NNZ);
+      }
+
      // Compute repulsive forces
      const float Z = TSNE::repulsive_forces(
        Y, repel.data(), norm.data(), Z_sum.data(), n, dim, df_power, recp_df, stream);
diff --git a/cpp/src/tsne/fft_kernels.cuh b/cpp/src/tsne/fft_kernels.cuh
index c939dfbffc..90b30a4f17 100644
--- a/cpp/src/tsne/fft_kernels.cuh
+++ b/cpp/src/tsne/fft_kernels.cuh
@@ -307,7 +307,8 @@ __global__ void compute_repulsive_forces_kernel(
template <typename value_idx, typename value_t>
__global__ void compute_Pij_x_Qij_kernel(value_t* __restrict__ attr_forces,
-                                         value_t* __restrict__ kl_divergences,
+                                         value_t* __restrict__ Qs,
+                                         value_t* __restrict__ Qs_norm,
                                         const value_t* __restrict__ pij,
                                         const value_idx* __restrict__ coo_rows,
                                         const value_idx* __restrict__ coo_cols,
@@ -328,13 +329,16 @@ __global__ void compute_Pij_x_Qij_kernel(value_t* __restrict__ attr_forces,
  value_t dx = ix - jx;
  value_t dy = iy - jy;
-  const value_t P  = pij[TID];
-  const value_t Q  = __fdividef(1.0f, 1 + (dx * dx) + (dy * dy));  // without normalization
-  const value_t PQ = P * Q;
-
+  const value_t squared_euclidean_dist = (dx * dx) + (dy * dy);
+  const value_t PQ                     = __fdividef(pij[TID], 1 + squared_euclidean_dist);
  atomicAdd(attr_forces + i, PQ * dx);
  atomicAdd(attr_forces + num_points + i, PQ * dy);
-  if (kl_divergences) atomicAdd(kl_divergences + i, P * log(P / Q));
+
+  if (Qs) {  // check if KL div calculation is necessary
+    const value_t Q_unnormalized = __expf(-squared_euclidean_dist);
+    Qs[TID]                      = Q_unnormalized;
+    atomicAdd(&Qs_norm[i], Q_unnormalized);
+  }
}

template
diff --git a/cpp/src/tsne/fft_tsne.cuh b/cpp/src/tsne/fft_tsne.cuh
index 8f448fdec7..8d122380ec 100644
--- a/cpp/src/tsne/fft_tsne.cuh
+++ b/cpp/src/tsne/fft_tsne.cuh
@@ -514,15 +514,32 @@ value_t FFT_TSNE(value_t* VAL,
    // Compute attractive forces
    {
      auto num_blocks = raft::ceildiv(NNZ, (value_idx)NTHREADS_1024);
-      if (iter != params.max_iter - 1) {  // not last iter
-        FFT::compute_Pij_x_Qij_kernel<<<num_blocks, NTHREADS_1024, 0, stream>>>(
-          attractive_forces_device.data(), (value_t*)nullptr, VAL, ROW, COL, Y, n, NNZ);
-      } else {  // last iteration
+      bool last_iter  = iter == params.max_iter - 1;
+      if (last_iter) {
+        rmm::device_uvector<value_t> Qs(NNZ, stream);
+        rmm::device_uvector<value_t> Qs_norm(n, stream);
        rmm::device_uvector<value_t> kl_divergences(n, stream);
+        CUDA_CHECK(cudaMemsetAsync(Qs_norm.data(), 0, Qs_norm.size() * sizeof(value_t), stream));
+        CUDA_CHECK(cudaMemsetAsync(
+          kl_divergences.data(), 0, kl_divergences.size() * sizeof(value_t), stream));
+
        FFT::compute_Pij_x_Qij_kernel<<<num_blocks, NTHREADS_1024, 0, stream>>>(
-          attractive_forces_device.data(), kl_divergences.data(), VAL, ROW, COL, Y, n, NNZ);
+          attractive_forces_device.data(), Qs.data(), Qs_norm.data(), VAL, ROW, COL, Y, n, NNZ);
+        compute_kl_div<<<num_blocks, NTHREADS_1024, 0, stream>>>(
+          VAL, ROW, Qs.data(), Qs_norm.data(), kl_divergences.data(), NNZ);
        kl_div =
          thrust::reduce(handle.get_thrust_policy(), kl_divergences.begin(), kl_divergences.end());
+      } else {
+        FFT::compute_Pij_x_Qij_kernel<<<num_blocks, NTHREADS_1024, 0, stream>>>(
+          attractive_forces_device.data(),
+          (value_t*)nullptr,
+          (value_t*)nullptr,
+          VAL,
+          ROW,
+          COL,
+          Y,
+          n,
+          NNZ);
      }
    }
diff --git a/cpp/src/tsne/utils.cuh b/cpp/src/tsne/utils.cuh
index 583e7719f9..2dbcbee3cd 100644
--- a/cpp/src/tsne/utils.cuh
+++ b/cpp/src/tsne/utils.cuh
@@ -176,3 +176,26 @@ __global__ void min_max_kernel(
    atomicMax(max, block_max);
  }
}
+
+/**
+ * Compute KL divergence
+ */
+template <typename value_idx, typename value_t>
+__global__ void compute_kl_div(const value_t* restrict Ps,
+                               const value_idx* restrict ROW,
+                               value_t* restrict Qs,
+                               value_t* restrict Qs_norm,
+                               value_t* restrict kl_divergences,
+                               const value_idx NNZ)
+{
+  const auto index = (blockIdx.x * blockDim.x) + threadIdx.x;
+  if (index >= NNZ) return;
+  const auto i = ROW[index];
+
+  if (Qs) {  // check if KL div calculation is necessary
+    const value_t P = max(Ps[index], FLT_EPSILON);
+    const value_t Q = max(__fdividef(Qs[index], Qs_norm[i]), FLT_EPSILON);
+
+    kl_divergences[index] = P * __logf(__fdividef(P, Q));
+  }
+}
\ No newline at end of file
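For KL(P || Q) = sum_ij p_ij log(p_ij / q_ij) to be well defined, p and q must be
proper distributions over the same edge set. The next patch therefore drops the
row-normalized exp(-d^2) estimate above in favour of the Student-t kernel
q_ij ∝ (1 + d_ij^2)^-1 that t-SNE actually optimizes, normalized globally over
the NNZ k-NN edges.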
From d020066dc74a4dccd95a5bb00d1f16d91dd8651d Mon Sep 17 00:00:00 2001
From: viclafargue
Date: Fri, 24 Sep 2021 18:51:08 +0200
Subject: [PATCH 03/12] Compute KL div

---
 cpp/src/tsne/barnes_hut_kernels.cuh | 15 ++++++------
 cpp/src/tsne/barnes_hut_tsne.cuh    | 22 ++++++-----------
 cpp/src/tsne/distances.cuh          | 38 -----------------------------
 cpp/src/tsne/exact_kernels.cuh      | 27 +++++++++-----------
 cpp/src/tsne/exact_tsne.cuh         | 23 ++++++-----------
 cpp/src/tsne/fft_kernels.cuh        | 10 +++-----
 cpp/src/tsne/fft_tsne.cuh           | 29 +++++++---------------
 cpp/src/tsne/tsne_runner.cuh        | 16 ++++++------
 cpp/src/tsne/utils.cuh              | 14 +++--------
 9 files changed, 59 insertions(+), 135 deletions(-)

diff --git a/cpp/src/tsne/barnes_hut_kernels.cuh b/cpp/src/tsne/barnes_hut_kernels.cuh
index 334bf5ba5a..3982c4bdfb 100644
--- a/cpp/src/tsne/barnes_hut_kernels.cuh
+++ b/cpp/src/tsne/barnes_hut_kernels.cuh
@@ -682,7 +682,6 @@ __global__ void attractive_kernel_bh(const value_t* restrict VAL,
                                     value_t* restrict attract1,
                                     value_t* restrict attract2,
                                     value_t* restrict Qs,
-                                    value_t* restrict Qs_norm,
                                     const value_idx NNZ)
{
  const auto index = (blockIdx.x * blockDim.x) + threadIdx.x;
@@ -695,17 +695,16 @@ __global__ void attractive_kernel_bh(const value_t* restrict VAL,
  // As a sum of squares, SED is mathematically >= 0. There might be a source of
  // NaNs upstream though, so until we find and fix them, enforce that trait.
  if (!(squared_euclidean_dist >= 0)) squared_euclidean_dist = 0.0f;
-  const value_t PQ = __fdividef(VAL[index], squared_euclidean_dist + 1.0f);
+  const value_t dof = 1.0f;
+  const value_t PQ  = __fdividef(VAL[index], squared_euclidean_dist + dof);

  // Apply forces
-  atomicAdd(&attract1[i], PQ * (Y1[i] - Y1[j]));
-  atomicAdd(&attract2[i], PQ * (Y2[i] - Y2[j]));
+  atomicAdd(&attract1[i], PQ * y1d);
+  atomicAdd(&attract2[i], PQ * y2d);

-  if (Qs) {  // check if KL div calculation is necessary
-    const value_t Q_unnormalized = __expf(-squared_euclidean_dist);
-    Qs[index]                    = Q_unnormalized;
-    atomicAdd(&Qs_norm[i], Q_unnormalized);
+  if (Qs) {  // when computing KL div
+    value_t Q_unnormalized = __fdividef(dof, dof + squared_euclidean_dist);
+    Qs[index]              = Q_unnormalized;
  }

  // TODO: Convert attractive forces to CSR format
diff --git a/cpp/src/tsne/barnes_hut_tsne.cuh b/cpp/src/tsne/barnes_hut_tsne.cuh
index 05465d8bb7..2305e9cc30 100644
--- a/cpp/src/tsne/barnes_hut_tsne.cuh
+++ b/cpp/src/tsne/barnes_hut_tsne.cuh
@@ -17,6 +17,7 @@

#include
#include
+#include
#include "barnes_hut_kernels.cuh"
#include "utils.cuh"

@@ -120,9 +121,9 @@ value_t Barnes_Hut(value_t* VAL,
    raft::copy(YY.data() + nnodes + 1, Y + n, n, stream);
  }

-  rmm::device_uvector<value_t> Qs(NNZ, stream);
-  rmm::device_uvector<value_t> Qs_norm(n, stream);
-  rmm::device_uvector<value_t> kl_divergences(n, stream);
+  rmm::device_uvector<value_t> tmp(NNZ, stream);
+  value_t* Qs      = tmp.data();
+  value_t* KL_divs = tmp.data();

  // Set cache levels for faster algorithm execution
  //---------------------------------------------------
@@ -263,18 +264,12 @@ value_t Barnes_Hut(value_t* VAL,
  START_TIMER;
  BH::Find_Normalization<<<1, 1, 0, stream>>>(Z_norm.data(), n);
  CUDA_CHECK(cudaPeekAtLastError());
-
  END_TIMER(Reduction_time);

  START_TIMER;
  // TODO: Calculate Kullback-Leibler divergence
  // For general embedding dimensions
  bool last_iter = iter == params.max_iter - 1;
-  if (last_iter) {
-    CUDA_CHECK(cudaMemsetAsync(Qs_norm.data(), 0, Qs_norm.size() * sizeof(value_t), stream));
-    CUDA_CHECK(
-      cudaMemsetAsync(kl_divergences.data(), 0, kl_divergences.size() * sizeof(value_t), stream));
-  }

  BH::attractive_kernel_bh<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0, stream>>>(
    VAL,
    COL,
    ROW,
    YY.data(),
    YY.data() + nnodes + 1,
    attr_forces.data(),
    attr_forces.data() + n,
-    last_iter ? Qs.data() : nullptr,
-    last_iter ? Qs_norm.data() : nullptr,
+    last_iter ? Qs : nullptr,
    NNZ);
  CUDA_CHECK(cudaPeekAtLastError());
  END_TIMER(attractive_time);

  if (last_iter) {
-    compute_kl_div<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0, stream>>>(
-      VAL, ROW, Qs.data(), Qs_norm.data(), kl_divergences.data(), NNZ);
+    raft::linalg::scalarMultiply(Qs, Qs, Z_norm.value(stream), NNZ, stream);
+    compute_kl_div<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0, stream>>>(
+      VAL, Qs, KL_divs, NNZ);
    CUDA_CHECK(cudaPeekAtLastError());
  }

@@ -302,8 +317,7 @@ value_t Barnes_Hut(value_t* VAL,
  raft::copy(Y, YY.data(), n, stream);
  raft::copy(Y + n, YY.data() + nnodes + 1, n, stream);

-  value_t kl_div =
-    thrust::reduce(handle.get_thrust_policy(), kl_divergences.begin(), kl_divergences.end());
+  value_t kl_div = thrust::reduce(handle.get_thrust_policy(), KL_divs, KL_divs + NNZ);
  return kl_div;
}

diff --git a/cpp/src/tsne/distances.cuh b/cpp/src/tsne/distances.cuh
index 15acd2dda6..e33d76f2a6 100644
--- a/cpp/src/tsne/distances.cuh
+++ b/cpp/src/tsne/distances.cuh
@@ -137,40 +137,6 @@ void get_distances(const raft::handle_t& handle,
  throw raft::exception("Sparse TSNE does not support 64-bit integer indices yet.");
}

-/**
- * @brief Find the maximum element in the distances matrix, then divide all entries by this.
- *        This promotes exp(distances) to not explode.
- * @param[in] n: The number of rows in the data X.
- * @param[in] distances: The output sorted distances from KNN.
- * @param[in] n_neighbors: The number of nearest neighbors you want.
- * @param[in] stream: The GPU stream.
- */
-template <typename value_idx, typename value_t>
-void normalize_distances(const value_idx n,
-                         value_t* distances,
-                         const int n_neighbors,
-                         cudaStream_t stream)
-{
-  // Now D / max(abs(D)) to allow exp(D) to not explode
-
-  auto policy = rmm::exec_policy(stream);
-
-  auto functional_abs = [] __device__(const value_t& x) { return abs(x); };
-
-  value_t maxNorm = thrust::transform_reduce(
-    policy, distances, distances + n * n_neighbors, functional_abs, 0.0f, thrust::maximum<value_t>());
-
-  if (maxNorm == 0.0f) { maxNorm = 1.0f; }
-
-  thrust::constant_iterator<value_t> division_iterator(1.0f / maxNorm);
-  thrust::transform(policy,
-                    distances,
-                    distances + n * n_neighbors,
-                    division_iterator,
-                    distances,
-                    thrust::multiplies<value_t>());
-}
-
/**
 * @brief Performs P + P.T.
 * @param[in] P: The perplexity matrix (n, k)
@@ -191,10 +157,6 @@ void symmetrize_perplexity(float* P,
                           cudaStream_t stream,
                           const raft::handle_t& handle)
{
-  // Perform (P + P.T) / P_sum * early_exaggeration
-  const value_t div = 1.0f / (2.0f * n);
-  raft::linalg::scalarMultiply(P, P, div, n * k, stream);
-
  // Symmetrize to form P + P.T
  raft::sparse::linalg::from_knn_symmetrize_matrix(
    indices, P, n, k, COO_Matrix, stream);
diff --git a/cpp/src/tsne/exact_kernels.cuh b/cpp/src/tsne/exact_kernels.cuh
index 0bc653a9da..10d031ac1e 100644
--- a/cpp/src/tsne/exact_kernels.cuh
+++ b/cpp/src/tsne/exact_kernels.cuh
@@ -173,7 +173,6 @@ __global__ void attractive_kernel(const value_t* restrict VAL,
                                  const value_t* restrict norm,
                                  value_t* restrict attract,
                                  value_t* restrict Qs,
-                                  value_t* restrict Qs_norm,
                                  const value_idx NNZ,
                                  const value_idx n,
                                  const value_idx dim,
@@ -192,17 +191,15 @@ __global__ void attractive_kernel(const value_t* restrict VAL,
    d += Y[k * n + i] * Y[k * n + j];
  const value_t euclidean_d = -2.0f * d + norm[i] + norm[j];

-  const value_t PQ = __fdividef(VAL[index], __powf((1.0f + euclidean_d * recp_df), df_power));
-
+  const value_t Q_unnormalized = __powf((euclidean_d * recp_df) + 1.0f, df_power);
+  const value_t PQ             = __fdividef(VAL[index], Q_unnormalized);
  // Apply forces
  for (int k = 0; k < dim; k++) {
    raft::myAtomicAdd(&attract[k * n + i], PQ * (Y[k * n + i] - Y[k * n + j]));
  }

-  if (Qs) {  // check if KL div calculation is necessary
-    const value_t Q_unnormalized = __expf(-euclidean_d);
-    Qs[index]                    = Q_unnormalized;
-    atomicAdd(&Qs_norm[i], Q_unnormalized);
+  if (Qs) {  // when computing KL div
+    Qs[index] = Q_unnormalized;
  }
}

@@ -216,7 +216,6 @@ __global__ void attractive_kernel_2d(const value_t* restrict VAL,
                                     value_t* restrict attract1,
                                     value_t* restrict attract2,
                                     value_t* restrict Qs,
-                                     value_t* restrict Qs_norm,
                                     const value_idx NNZ)
{
  const auto index = (blockIdx.x * blockDim.x) + threadIdx.x;
@@ -227,16 +226,16 @@ __global__ void attractive_kernel_2d(const value_t* restrict VAL,
  // TODO: can provide any distance ie cosine
  // #862
  const value_t euclidean_d = norm[i] + norm[j] - 2.0f * (Y1[i] * Y1[j] + Y2[i] * Y2[j]);
-  const value_t PQ          = __fdividef(VAL[index], (1.0f + euclidean_d));
+
+  const value_t Q_unnormalized = 1.0f + euclidean_d;
+  const value_t PQ             = __fdividef(VAL[index], Q_unnormalized);

  // Apply forces
  raft::myAtomicAdd(&attract1[i], PQ * (Y1[i] - Y1[j]));
  raft::myAtomicAdd(&attract2[i], PQ * (Y2[i] - Y2[j]));

-  if (Qs) {  // check if KL div calculation is necessary
-    const value_t Q_unnormalized = __expf(-euclidean_d);
-    Qs[index]                    = Q_unnormalized;
-    atomicAdd(&Qs_norm[i], Q_unnormalized);
+  if (Qs) {  // when computing KL div
+    Qs[index] = Q_unnormalized;
  }
}

@@ -252,7 +251,6 @@ void attractive_forces(const value_t* restrict VAL,
                       const value_t* restrict norm,
                       value_t* restrict attract,
                       value_t* restrict Qs,
-                       value_t* restrict Qs_norm,
                       const value_idx NNZ,
                       const value_idx n,
                       const value_idx dim,
@@ -267,12 +262,12 @@ void attractive_forces(const value_t* restrict VAL,
  // For general embedding dimensions
  if (dim != 2) {
    attractive_kernel<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0, stream>>>(
-      VAL, COL, ROW, Y, norm, attract, Qs, Qs_norm, NNZ, n, dim, df_power, recp_df);
+      VAL, COL, ROW, Y, norm, attract, Qs, NNZ, n, dim, df_power, recp_df);
  }
  // For special case dim == 2
  else {
    attractive_kernel_2d<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0, stream>>>(
-      VAL, COL, ROW, Y, Y + n, norm, attract, attract + n, Qs, Qs_norm, NNZ);
+      VAL, COL, ROW, Y, Y + n, norm, attract, attract + n, Qs, NNZ);
  }
  CUDA_CHECK(cudaPeekAtLastError());
}
diff --git a/cpp/src/tsne/exact_tsne.cuh b/cpp/src/tsne/exact_tsne.cuh
index 3bf1eb22e9..c0e03822d8 100644
--- a/cpp/src/tsne/exact_tsne.cuh
+++ b/cpp/src/tsne/exact_tsne.cuh
@@ -70,9 +70,9 @@ value_t Exact_TSNE(value_t* VAL,

  rmm::device_uvector<value_t> gradient(n * dim, stream);

-  rmm::device_uvector<value_t> Qs(NNZ, stream);
-  rmm::device_uvector<value_t> Qs_norm(n, stream);
-  rmm::device_uvector<value_t> kl_divergences(n, stream);
+  rmm::device_uvector<value_t> tmp(NNZ, stream);
+  value_t* Qs      = tmp.data();
+  value_t* KL_divs = tmp.data();

  //---------------------------------------------------
  // Calculate degrees of freedom
@@ -102,12 +102,6 @@ value_t Exact_TSNE(value_t* VAL,

      bool last_iter = iter == params.max_iter - 1;

-      if (last_iter) {
-        CUDA_CHECK(cudaMemsetAsync(Qs_norm.data(), 0, Qs_norm.size() * sizeof(value_t), stream));
-        CUDA_CHECK(
-          cudaMemsetAsync(kl_divergences.data(), 0, kl_divergences.size() * sizeof(value_t), stream));
-      }
-
      // Compute attractive forces
      TSNE::attractive_forces(VAL,
                              COL,
@@ -115,8 +109,7 @@ value_t Exact_TSNE(value_t* VAL,
                              Y,
                              norm.data(),
                              attract.data(),
-                              last_iter ? Qs.data() : nullptr,
-                              last_iter ? Qs_norm.data() : nullptr,
+                              last_iter ? Qs : nullptr,
                              NNZ,
                              n,
                              dim,
                              df_power,
                              recp_df,
                              stream);
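+      // Q is renormalized over the NNZ edges below so the per-edge KL terms
+      // are computed between proper distributions.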
      if (last_iter) {
+        value_t Q_sum = thrust::reduce(rmm::exec_policy(stream), Qs, Qs + NNZ);
+        raft::linalg::scalarMultiply(Qs, Qs, 1.0f / Q_sum, NNZ, stream);
        compute_kl_div<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0, stream>>>(
-          VAL, ROW, Qs.data(), Qs_norm.data(), kl_divergences.data(), NNZ);
+          VAL, Qs, KL_divs, NNZ);
      }

@@ -169,9 +164,7 @@ value_t Exact_TSNE(value_t* VAL,
      if (iter % 100 == 0) { CUML_LOG_DEBUG("Z at iter = %d = %f", iter, Z); }
    }
  }
-
-  value_t kl_div =
-    thrust::reduce(handle.get_thrust_policy(), kl_divergences.begin(), kl_divergences.end());
+  value_t kl_div = thrust::reduce(handle.get_thrust_policy(), KL_divs, KL_divs + NNZ);
  return kl_div;
}

diff --git a/cpp/src/tsne/fft_kernels.cuh b/cpp/src/tsne/fft_kernels.cuh
index 90b30a4f17..d353c2f1a7 100644
--- a/cpp/src/tsne/fft_kernels.cuh
+++ b/cpp/src/tsne/fft_kernels.cuh
@@ -308,7 +308,6 @@ __global__ void compute_repulsive_forces_kernel(
template <typename value_idx, typename value_t>
__global__ void compute_Pij_x_Qij_kernel(value_t* __restrict__ attr_forces,
                                         value_t* __restrict__ Qs,
-                                         value_t* __restrict__ Qs_norm,
                                         const value_t* __restrict__ pij,
                                         const value_idx* __restrict__ coo_rows,
                                         const value_idx* __restrict__ coo_cols,
@@ -330,14 +329,13 @@ __global__ void compute_Pij_x_Qij_kernel(value_t* __restrict__ attr_forces,
  value_t dy = iy - jy;

  const value_t squared_euclidean_dist = (dx * dx) + (dy * dy);
-  const value_t PQ                     = __fdividef(pij[TID], 1 + squared_euclidean_dist);
+  const value_t Q_unnormalized         = 1 + squared_euclidean_dist;
+  const value_t PQ                     = __fdividef(pij[TID], Q_unnormalized);
  atomicAdd(attr_forces + i, PQ * dx);
  atomicAdd(attr_forces + num_points + i, PQ * dy);

-  if (Qs) {  // check if KL div calculation is necessary
-    const value_t Q_unnormalized = __expf(-squared_euclidean_dist);
-    Qs[TID]                      = Q_unnormalized;
-    atomicAdd(&Qs_norm[i], Q_unnormalized);
+  if (Qs) {  // when computing KL div
+    Qs[TID] = Q_unnormalized;
  }
}

diff --git a/cpp/src/tsne/fft_tsne.cuh b/cpp/src/tsne/fft_tsne.cuh
index 8d122380ec..afd741e4c6 100644
--- a/cpp/src/tsne/fft_tsne.cuh
+++ b/cpp/src/tsne/fft_tsne.cuh
@@ -516,30 +516,19 @@ value_t FFT_TSNE(value_t* VAL,
      auto num_blocks = raft::ceildiv(NNZ, (value_idx)NTHREADS_1024);
      bool last_iter  = iter == params.max_iter - 1;
      if (last_iter) {
-        rmm::device_uvector<value_t> Qs(NNZ, stream);
-        rmm::device_uvector<value_t> Qs_norm(n, stream);
-        rmm::device_uvector<value_t> kl_divergences(n, stream);
-        CUDA_CHECK(cudaMemsetAsync(Qs_norm.data(), 0, Qs_norm.size() * sizeof(value_t), stream));
-        CUDA_CHECK(cudaMemsetAsync(
-          kl_divergences.data(), 0, kl_divergences.size() * sizeof(value_t), stream));
+        rmm::device_uvector<value_t> tmp(NNZ, stream);
+        value_t* Qs      = tmp.data();
+        value_t* KL_divs = tmp.data();

        FFT::compute_Pij_x_Qij_kernel<<<num_blocks, NTHREADS_1024, 0, stream>>>(
-          attractive_forces_device.data(), Qs.data(), Qs_norm.data(), VAL, ROW, COL, Y, n, NNZ);
-        compute_kl_div<<<num_blocks, NTHREADS_1024, 0, stream>>>(
-          VAL, ROW, Qs.data(), Qs_norm.data(), kl_divergences.data(), NNZ);
-        kl_div =
-          thrust::reduce(handle.get_thrust_policy(), kl_divergences.begin(), kl_divergences.end());
+          attractive_forces_device.data(), Qs, VAL, ROW, COL, Y, n, NNZ);
+        value_t Q_sum = thrust::reduce(rmm::exec_policy(stream), Qs, Qs + NNZ);
+        raft::linalg::scalarMultiply(Qs, Qs, 1.0f / Q_sum, NNZ, stream);
+        compute_kl_div<<<num_blocks, NTHREADS_1024, 0, stream>>>(VAL, Qs, KL_divs, NNZ);
+        kl_div = thrust::reduce(handle.get_thrust_policy(), KL_divs, KL_divs + NNZ);
      } else {
        FFT::compute_Pij_x_Qij_kernel<<<num_blocks, NTHREADS_1024, 0, stream>>>(
-          attractive_forces_device.data(),
-          (value_t*)nullptr,
-          (value_t*)nullptr,
-          VAL,
-          ROW,
-          COL,
-          Y,
-          n,
-          NNZ);
+          attractive_forces_device.data(), (value_t*)nullptr, VAL, ROW, COL, Y, n, NNZ);
      }
    }

diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh
index 811684538d..05065b4059 100644
--- a/cpp/src/tsne/tsne_runner.cuh
+++ b/cpp/src/tsne/tsne_runner.cuh
@@ -133,14 +133,6 @@ class TSNE_runner {
    //---------------------------------------------------
    END_TIMER(DistancesTime);

-    START_TIMER;
-    //---------------------------------------------------
-    // Normalize distances
-    CUML_LOG_DEBUG("Now normalizing distances so exp(D) doesn't explode.");
-    TSNE::normalize_distances(n, k_graph.knn_dists, params.n_neighbors, stream);
-    //---------------------------------------------------
-    END_TIMER(NormalizeTime);
-
    START_TIMER;
    //---------------------------------------------------
    // Optimal perplexity
@@ -158,6 +150,14 @@ class TSNE_runner {
    //---------------------------------------------------
    END_TIMER(PerplexityTime);

+    START_TIMER;
+    //---------------------------------------------------
+    // Normalize perplexity to prepare for symmetrization
+    value_t P_sum = thrust::reduce(rmm::exec_policy(stream), P.begin(), P.end());
+    raft::linalg::scalarMultiply(P.data(), P.data(), 1.0f / (2.0f * P_sum), P.size(), stream);
+    //---------------------------------------------------
+    END_TIMER(NormalizeTime);
+
    START_TIMER;
    //---------------------------------------------------
    // Convert data to COO layout
@@ -172,12 +180,15 @@ class TSNE_runner {
    END_TIMER(SymmetrizeTime);
  }

+ public:
+  raft::sparse::COO<value_t, value_idx> COO_Matrix;
+
+ private:
  const raft::handle_t& handle;
  tsne_input& input;
  knn_graph<knn_indices_t, value_t>& k_graph;
  TSNEParams& params;

-  raft::sparse::COO<value_t, value_idx> COO_Matrix;
  value_idx n, p;
  value_t* Y;
};
diff --git a/cpp/src/tsne/utils.cuh b/cpp/src/tsne/utils.cuh
index 2dbcbee3cd..bc0488e8a9 100644
--- a/cpp/src/tsne/utils.cuh
+++ b/cpp/src/tsne/utils.cuh
@@ -182,20 +182,14 @@ __global__ void min_max_kernel(
 */
template <typename value_idx, typename value_t>
__global__ void compute_kl_div(const value_t* restrict Ps,
-                               const value_idx* restrict ROW,
                               value_t* restrict Qs,
-                               value_t* restrict Qs_norm,
                               value_t* restrict kl_divergences,
                               const value_idx NNZ)
{
  const auto index = (blockIdx.x * blockDim.x) + threadIdx.x;
  if (index >= NNZ) return;
-  const auto i = ROW[index];
+  const value_t P = Ps[index];
+  const value_t Q = max(Qs[index], FLT_EPSILON);

-  if (Qs) {  // check if KL div calculation is necessary
-    const value_t P = max(Ps[index], FLT_EPSILON);
-    const value_t Q = max(__fdividef(Qs[index], Qs_norm[i]), FLT_EPSILON);
-
-    kl_divergences[index] = P * __logf(__fdividef(P, Q));
-  }
-}
\ No newline at end of file
+  kl_divergences[index] = P * __logf(__fdividef(max(P, FLT_EPSILON), Q));
+}
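In the Barnes-Hut fix below, Z_norm (the repulsive normalization estimated over
all point pairs by the tree traversal) is no longer reused for the KL term: P is
only defined on the k-NN edges, so Q is instead normalized by its own sum over
those NNZ edges, putting both distributions on a common support.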
From 9e3c81e851bc3e85e038f99d1856174c67cc440f Mon Sep 17 00:00:00 2001
From: viclafargue
Date: Tue, 28 Sep 2021 16:49:22 +0200
Subject: [PATCH 04/12] Fix barnes hut

---
 cpp/src/tsne/barnes_hut_tsne.cuh | 5 ++++-
 cpp/src/tsne/distances.cuh       | 8 --------
 cpp/src/tsne/exact_tsne.cuh      | 2 ++
 python/cuml/test/test_tsne.py    | 4 +++-
 4 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/cpp/src/tsne/barnes_hut_tsne.cuh b/cpp/src/tsne/barnes_hut_tsne.cuh
index 2305e9cc30..7b5e36a1bf 100644
--- a/cpp/src/tsne/barnes_hut_tsne.cuh
+++ b/cpp/src/tsne/barnes_hut_tsne.cuh
@@ -285,7 +285,10 @@ value_t Barnes_Hut(value_t* VAL,
  END_TIMER(attractive_time);

  if (last_iter) {
-    raft::linalg::scalarMultiply(Qs, Qs, Z_norm.value(stream), NNZ, stream);
+    value_t Q_sum = thrust::reduce(rmm::exec_policy(stream), Qs, Qs + NNZ);
+    raft::linalg::scalarMultiply(Qs, Qs, 1.0f / Q_sum, NNZ, stream);
+    value_t P_sum = thrust::reduce(rmm::exec_policy(stream), VAL, VAL + NNZ);
+    raft::linalg::scalarMultiply(VAL, VAL, 1.0f / P_sum, NNZ, stream);
    compute_kl_div<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0, stream>>>(
      VAL, Qs, KL_divs, NNZ);
    CUDA_CHECK(cudaPeekAtLastError());
diff --git a/cpp/src/tsne/distances.cuh b/cpp/src/tsne/distances.cuh
index e33d76f2a6..fa2ea5fc5a 100644
--- a/cpp/src/tsne/distances.cuh
+++ b/cpp/src/tsne/distances.cuh
@@ -66,14 +66,6 @@ void get_distances(const raft::handle_t& handle,
  std::vector<float*> input_vec = {input.X};
  std::vector<int> sizes_vec    = {input.n};

-  /**
-   * std::vector &input, std::vector &sizes,
-   IntType D, float *search_items, IntType n, int64_t *res_I,
-   float *res_D, IntType k,
-   std::shared_ptr allocator,
-   cudaStream_t userStream,
-   */
-
  raft::spatial::knn::brute_force_knn(handle,
                                      input_vec,
                                      sizes_vec,
diff --git a/cpp/src/tsne/exact_tsne.cuh b/cpp/src/tsne/exact_tsne.cuh
index c0e03822d8..9a98c4d641 100644
--- a/cpp/src/tsne/exact_tsne.cuh
+++ b/cpp/src/tsne/exact_tsne.cuh
@@ -120,6 +120,8 @@ value_t Exact_TSNE(value_t* VAL,
      if (last_iter) {
        value_t Q_sum = thrust::reduce(rmm::exec_policy(stream), Qs, Qs + NNZ);
        raft::linalg::scalarMultiply(Qs, Qs, 1.0f / Q_sum, NNZ, stream);
+        value_t P_sum = thrust::reduce(rmm::exec_policy(stream), VAL, VAL + NNZ);
+        raft::linalg::scalarMultiply(VAL, VAL, 1.0f / P_sum, NNZ, stream);
        compute_kl_div<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0, stream>>>(
          VAL, Qs, KL_divs, NNZ);
      }
diff --git a/python/cuml/test/test_tsne.py b/python/cuml/test/test_tsne.py
index 8158543666..ed30f03536 100644
--- a/python/cuml/test/test_tsne.py
+++ b/python/cuml/test/test_tsne.py
@@ -139,8 +139,10 @@ def test_tsne_knn_parameters(dataset, type_knn_graph, method):
    validate_embedding(X, embed)


+from sklearn.manifold import TSNE as skTSNE
+
@pytest.mark.parametrize('dataset', test_datasets.values())
-@pytest.mark.parametrize('method', ['fft', 'barnes_hut'])
+@pytest.mark.parametrize('method', ['fft'])
def test_tsne(dataset, method):
    """
    This tests how TSNE handles a lot of input data across time.

From ff83ab99b721194f7c20f93ee3297096bc0477cf Mon Sep 17 00:00:00 2001
From: viclafargue
Date: Thu, 30 Sep 2021 15:48:03 +0200
Subject: [PATCH 05/12] Fixing Q in kernels + adding Gtests

---
 cpp/src/tsne/barnes_hut_kernels.cuh |  21 +--
 cpp/src/tsne/barnes_hut_tsne.cuh    |   7 +-
 cpp/src/tsne/distances.cuh          |  13 ++
 cpp/src/tsne/exact_kernels.cuh      |  40 +++---
 cpp/src/tsne/exact_tsne.cuh         |   7 +-
 cpp/src/tsne/fft_kernels.cuh        |  15 +-
 cpp/src/tsne/fft_tsne.cuh           |   8 +-
 cpp/src/tsne/tsne_runner.cuh        |  13 +-
 cpp/src/tsne/utils.cuh              |   2 +-
 cpp/test/sg/tsne_test.cu            | 204 +++++++++++++++++++---------
 python/cuml/test/test_tsne.py       |   4 +-
 11 files changed, 223 insertions(+), 111 deletions(-)

diff --git a/cpp/src/tsne/barnes_hut_kernels.cuh b/cpp/src/tsne/barnes_hut_kernels.cuh
index 3982c4bdfb..c00a10e5ca 100644
--- a/cpp/src/tsne/barnes_hut_kernels.cuh
+++ b/cpp/src/tsne/barnes_hut_kernels.cuh
@@ -682,29 +682,32 @@ __global__ void attractive_kernel_bh(const value_t* restrict VAL,
                                     value_t* restrict attract1,
                                     value_t* restrict attract2,
                                     value_t* restrict Qs,
-                                    const value_idx NNZ)
+                                    const value_idx NNZ,
+                                    const value_t dof)
{
  const auto index = (blockIdx.x * blockDim.x) + threadIdx.x;
  if (index >= NNZ) return;
  const auto i = ROW[index];
  const auto j = COL[index];

-  const value_t y1d = Y1[i] - Y1[j];
-  const value_t y2d = Y2[i] - Y2[j];
-  value_t squared_euclidean_dist = y1d * y1d + y2d * y2d;
+  const value_t y1d = Y1[i] - Y1[j];
+  const value_t y2d = Y2[i] - Y2[j];
+  value_t dist      = y1d * y1d + y2d * y2d;
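+  // With dof degrees of freedom, the unnormalized Student-t kernel computed
+  // below is q = (dof / (dof + dist))^((dof + 1) / 2); dof == 1 recovers the
+  // classic 1 / (1 + dist) of standard t-SNE.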
- if (!(squared_euclidean_dist >= 0)) squared_euclidean_dist = 0.0f; - const value_t dof = 1.0f; - const value_t PQ = __fdividef(VAL[index], squared_euclidean_dist + dof); + if (!(dist >= 0)) dist = 0.0f; + + const value_t exponent = (dof + 1.0) / 2.0; + const value_t P = VAL[index]; + const value_t Q = __powf(dof / (dof + dist), exponent); + const value_t PQ = P * Q; // Apply forces atomicAdd(&attract1[i], PQ * y1d); atomicAdd(&attract2[i], PQ * y2d); if (Qs) { // when computing KL div - value_t Q_unnormalized = __fdividef(dof, dof + squared_euclidean_dist); - Qs[index] = Q_unnormalized; + Qs[index] = Q; } // TODO: Convert attractive forces to CSR format diff --git a/cpp/src/tsne/barnes_hut_tsne.cuh b/cpp/src/tsne/barnes_hut_tsne.cuh index 7b5e36a1bf..b4865cdef1 100644 --- a/cpp/src/tsne/barnes_hut_tsne.cuh +++ b/cpp/src/tsne/barnes_hut_tsne.cuh @@ -280,15 +280,16 @@ value_t Barnes_Hut(value_t* VAL, attr_forces.data(), attr_forces.data() + n, last_iter ? Qs : nullptr, - NNZ); + NNZ, + fmaxf(params.dim - 1, 1)); CUDA_CHECK(cudaPeekAtLastError()); END_TIMER(attractive_time); if (last_iter) { - value_t Q_sum = thrust::reduce(rmm::exec_policy(stream), Qs, Qs + NNZ); - raft::linalg::scalarMultiply(Qs, Qs, 1.0f / Q_sum, NNZ, stream); value_t P_sum = thrust::reduce(rmm::exec_policy(stream), VAL, VAL + NNZ); raft::linalg::scalarMultiply(VAL, VAL, 1.0f / P_sum, NNZ, stream); + value_t Q_sum = thrust::reduce(rmm::exec_policy(stream), Qs, Qs + NNZ); + raft::linalg::scalarMultiply(Qs, Qs, 1.0f / Q_sum, NNZ, stream); compute_kl_div<<>>( VAL, Qs, KL_divs, NNZ); CUDA_CHECK(cudaPeekAtLastError()); diff --git a/cpp/src/tsne/distances.cuh b/cpp/src/tsne/distances.cuh index fa2ea5fc5a..77fcdaf577 100644 --- a/cpp/src/tsne/distances.cuh +++ b/cpp/src/tsne/distances.cuh @@ -129,6 +129,19 @@ void get_distances(const raft::handle_t& handle, throw raft::exception("Sparse TSNE does not support 64-bit integer indices yet."); } +template +void normalize_distances(value_t* distances, const size_t total_nn, cudaStream_t stream) +{ + auto abs_f = [] __device__(const value_t& x) { return abs(x); }; + value_t maxNorm = thrust::transform_reduce(rmm::exec_policy(stream), + distances, + distances + total_nn, + abs_f, + 0.0f, + thrust::maximum()); + raft::linalg::scalarMultiply(distances, distances, 1.0f / maxNorm, total_nn, stream); +} + /** * @brief Performs P + P.T. 
* @param[in] P: The perplexity matrix (n, k) diff --git a/cpp/src/tsne/exact_kernels.cuh b/cpp/src/tsne/exact_kernels.cuh index 10d031ac1e..b40c3f4cec 100644 --- a/cpp/src/tsne/exact_kernels.cuh +++ b/cpp/src/tsne/exact_kernels.cuh @@ -176,8 +176,7 @@ __global__ void attractive_kernel(const value_t* restrict VAL, const value_idx NNZ, const value_idx n, const value_idx dim, - const float df_power, // -(df + 1)/2) - const float recp_df) // 1 / df + const value_t dof) { const auto index = (blockIdx.x * blockDim.x) + threadIdx.x; if (index >= NNZ) return; @@ -186,20 +185,24 @@ __global__ void attractive_kernel(const value_t* restrict VAL, // Euclidean distances // TODO: can provide any distance ie cosine // #862 - value_t d = 0; + value_t dist = 0; for (int k = 0; k < dim; k++) - d += Y[k * n + i] * Y[k * n + j]; - const value_t euclidean_d = -2.0f * d + norm[i] + norm[j]; + dist += Y[k * n + i] * Y[k * n + j]; + dist = norm[i] + norm[j] - 2.0f * dist; + + const value_t exponent = (dof + 1.0) / 2.0; + + const value_t P = VAL[index]; + const value_t Q = __powf(dof / (dof + dist), exponent); + const value_t PQ = P * Q; - const value_t Q_unnormalized = __powf((euclidean_d * recp_df) + 1.0f, df_power); - const value_t PQ = __fdividef(VAL[index], Q_unnormalized); // Apply forces for (int k = 0; k < dim; k++) { raft::myAtomicAdd(&attract[k * n + i], PQ * (Y[k * n + i] - Y[k * n + j])); } if (Qs) { // when computing KL div - Qs[index] = Q_unnormalized; + Qs[index] = Q; } } @@ -216,7 +219,8 @@ __global__ void attractive_kernel_2d(const value_t* restrict VAL, value_t* restrict attract1, value_t* restrict attract2, value_t* restrict Qs, - const value_idx NNZ) + const value_idx NNZ, + const value_t dof) { const auto index = (blockIdx.x * blockDim.x) + threadIdx.x; if (index >= NNZ) return; @@ -225,17 +229,20 @@ __global__ void attractive_kernel_2d(const value_t* restrict VAL, // Euclidean distances // TODO: can provide any distance ie cosine // #862 - const value_t euclidean_d = norm[i] + norm[j] - 2.0f * (Y1[i] * Y1[j] + Y2[i] * Y2[j]); + const value_t dist = norm[i] + norm[j] - 2.0f * (Y1[i] * Y1[j] + Y2[i] * Y2[j]); + + const value_t exponent = (dof + 1.0) / 2.0; - const value_t Q_unnormalized = 1.0f + euclidean_d; - const value_t PQ = __fdividef(VAL[index], Q_unnormalized); + const value_t P = VAL[index]; + const value_t Q = __powf(dof / (dof + dist), exponent); + const value_t PQ = P * Q; // Apply forces raft::myAtomicAdd(&attract1[i], PQ * (Y1[i] - Y1[j])); raft::myAtomicAdd(&attract2[i], PQ * (Y2[i] - Y2[j])); if (Qs) { // when computing KL div - Qs[index] = Q_unnormalized; + Qs[index] = Q; } } @@ -251,8 +258,7 @@ void attractive_forces(const value_t* restrict VAL, const value_idx NNZ, const value_idx n, const value_idx dim, - const float df_power, // -(df + 1)/2) - const float recp_df, // 1 / df + const value_t dof, cudaStream_t stream) { CUDA_CHECK(cudaMemsetAsync(attract, 0, sizeof(value_t) * n * dim, stream)); @@ -262,12 +268,12 @@ void attractive_forces(const value_t* restrict VAL, // For general embedding dimensions if (dim != 2) { attractive_kernel<<>>( - VAL, COL, ROW, Y, norm, attract, Qs, NNZ, n, dim, df_power, recp_df); + VAL, COL, ROW, Y, norm, attract, Qs, NNZ, n, dim, dof); } // For special case dim == 2 else { attractive_kernel_2d<<>>( - VAL, COL, ROW, Y, Y + n, norm, attract, attract + n, Qs, NNZ); + VAL, COL, ROW, Y, Y + n, norm, attract, attract + n, Qs, NNZ, dof); } CUDA_CHECK(cudaPeekAtLastError()); } diff --git a/cpp/src/tsne/exact_tsne.cuh b/cpp/src/tsne/exact_tsne.cuh index 
9a98c4d641..2976e7be75 100644
--- a/cpp/src/tsne/exact_tsne.cuh
+++ b/cpp/src/tsne/exact_tsne.cuh
@@ -113,15 +113,14 @@ value_t Exact_TSNE(value_t* VAL,
                              NNZ,
                              n,
                              dim,
-                             df_power,
-                             recp_df,
+                             fmaxf(params.dim - 1, 1),
                              stream);
 
       if (last_iter) {
-        value_t Q_sum = thrust::reduce(rmm::exec_policy(stream), Qs, Qs + NNZ);
-        raft::linalg::scalarMultiply(Qs, Qs, 1.0f / Q_sum, NNZ, stream);
         value_t P_sum = thrust::reduce(rmm::exec_policy(stream), VAL, VAL + NNZ);
         raft::linalg::scalarMultiply(VAL, VAL, 1.0f / P_sum, NNZ, stream);
+        value_t Q_sum = thrust::reduce(rmm::exec_policy(stream), Qs, Qs + NNZ);
+        raft::linalg::scalarMultiply(Qs, Qs, 1.0f / Q_sum, NNZ, stream);
         compute_kl_div<<>>(
           VAL, Qs, KL_divs, NNZ);
       }
diff --git a/cpp/src/tsne/fft_kernels.cuh b/cpp/src/tsne/fft_kernels.cuh
index d353c2f1a7..8be17847f4 100644
--- a/cpp/src/tsne/fft_kernels.cuh
+++ b/cpp/src/tsne/fft_kernels.cuh
@@ -313,7 +313,8 @@ __global__ void compute_Pij_x_Qij_kernel(value_t* __restrict__ attr_forces,
                                          const value_idx* __restrict__ coo_cols,
                                          const value_t* __restrict__ points,
                                          const value_idx num_points,
-                                         const value_idx num_nonzero)
+                                         const value_idx num_nonzero,
+                                         const value_t dof)
 {
   const value_idx TID = threadIdx.x + blockIdx.x * blockDim.x;
   if (TID >= num_nonzero) return;
@@ -328,14 +329,18 @@ __global__ void compute_Pij_x_Qij_kernel(value_t* __restrict__ attr_forces,
 
   value_t dx = ix - jx;
   value_t dy = iy - jy;
-  const value_t squared_euclidean_dist = (dx * dx) + (dy * dy);
-  const value_t Q_unnormalized = 1 + squared_euclidean_dist;
-  const value_t PQ = __fdividef(pij[TID], Q_unnormalized);
+  const value_t dist = (dx * dx) + (dy * dy);
+  const value_t exponent = (dof + 1.0) / 2.0;
+
+  const value_t P = pij[TID];
+  const value_t Q = __powf(dof / (dof + dist), exponent);
+  const value_t PQ = P * Q;
+
   atomicAdd(attr_forces + i, PQ * dx);
   atomicAdd(attr_forces + num_points + i, PQ * dy);
 
   if (Qs) {  // when computing KL div
-    Qs[TID] = Q_unnormalized;
+    Qs[TID] = Q;
   }
 }
diff --git a/cpp/src/tsne/fft_tsne.cuh b/cpp/src/tsne/fft_tsne.cuh
index afd741e4c6..e29d051600 100644
--- a/cpp/src/tsne/fft_tsne.cuh
+++ b/cpp/src/tsne/fft_tsne.cuh
@@ -514,21 +514,21 @@ value_t FFT_TSNE(value_t* VAL,
     // Compute attractive forces
     {
       auto num_blocks = raft::ceildiv(NNZ, (value_idx)NTHREADS_1024);
-      bool last_iter = iter == params.max_iter - 1;
-      if (last_iter) {
+      const float dof = fmaxf(params.dim - 1, 1);  // degrees of freedom
+      if (iter == params.max_iter - 1) {  // last iteration
         rmm::device_uvector<value_t> tmp(NNZ, stream);
         value_t* Qs      = tmp.data();
         value_t* KL_divs = tmp.data();
         FFT::compute_Pij_x_Qij_kernel<<<num_blocks, NTHREADS_1024, 0, stream>>>(
-          attractive_forces_device.data(), Qs, VAL, ROW, COL, Y, n, NNZ);
+          attractive_forces_device.data(), Qs, VAL, ROW, COL, Y, n, NNZ, dof);
         value_t Q_sum = thrust::reduce(rmm::exec_policy(stream), Qs, Qs + NNZ);
         raft::linalg::scalarMultiply(Qs, Qs, 1.0f / Q_sum, NNZ, stream);
         compute_kl_div<<<num_blocks, NTHREADS_1024, 0, stream>>>(VAL, Qs, KL_divs, NNZ);
         kl_div = thrust::reduce(handle.get_thrust_policy(), KL_divs, KL_divs + NNZ);
       } else {
         FFT::compute_Pij_x_Qij_kernel<<<num_blocks, NTHREADS_1024, 0, stream>>>(
-          attractive_forces_device.data(), (value_t*)nullptr, VAL, ROW, COL, Y, n, NNZ);
+          attractive_forces_device.data(), (value_t*)nullptr, VAL, ROW, COL, Y, n, NNZ, dof);
       }
     }
diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh
index 05065b4059..9bbebb6d5a 100644
--- a/cpp/src/tsne/tsne_runner.cuh
+++ b/cpp/src/tsne/tsne_runner.cuh
@@ -133,6 +133,14 @@ class TSNE_runner {
     //---------------------------------------------------
     END_TIMER(DistancesTime);
 
+    START_TIMER;
+    //---------------------------------------------------
+    // Normalize distances
+    CUML_LOG_DEBUG("Now normalizing distances so exp(D) doesn't explode.");
+    TSNE::normalize_distances(k_graph.knn_dists, n * params.n_neighbors, stream);
+    //---------------------------------------------------
+    END_TIMER(NormalizeTime);
+
     START_TIMER;
     //---------------------------------------------------
     // Optimal perplexity
@@ -172,12 +180,15 @@ class TSNE_runner {
     END_TIMER(SymmetrizeTime);
   }
 
+ public:
+  raft::sparse::COO<value_t, value_idx> COO_Matrix;
+
+ private:
   const raft::handle_t& handle;
   tsne_input& input;
   knn_graph<value_idx, value_t>& k_graph;
   TSNEParams& params;
 
-  raft::sparse::COO<value_t, value_idx> COO_Matrix;
   value_idx n, p;
   value_t* Y;
 };
diff --git a/cpp/src/tsne/utils.cuh b/cpp/src/tsne/utils.cuh
index bc0488e8a9..bd6f53a467 100644
--- a/cpp/src/tsne/utils.cuh
+++ b/cpp/src/tsne/utils.cuh
@@ -182,7 +182,7 @@ __global__ void min_max_kernel(
  */
 template <typename value_idx, typename value_t>
 __global__ void compute_kl_div(const value_t* restrict Ps,
-                               value_t* restrict Qs,
+                               const value_t* restrict Qs,
                                value_t* restrict kl_divergences,
                                const value_idx NNZ)
 {
diff --git a/cpp/test/sg/tsne_test.cu b/cpp/test/sg/tsne_test.cu
index 973b4b3f48..e3de3ec750 100644
--- a/cpp/test/sg/tsne_test.cu
+++ b/cpp/test/sg/tsne_test.cu
@@ -16,6 +16,7 @@
 #include 
 #include 
+#include 
 #include 
 #include 
@@ -27,8 +28,9 @@
 #include 
 #include 
 #include 
-#include 
 #include 
+#include 
+#include 
 #include 
 
 using namespace MLCommon;
@@ -42,89 +44,156 @@ struct TSNEInput {
   double trustworthiness_threshold;
 };
 
+float get_kl_div(TSNEParams& params,
+                 raft::sparse::COO<float, int64_t>& input_matrix,
+                 float* emb_dists,
+                 size_t n,
+                 cudaStream_t stream)
+{
+  const size_t total_nn = 2 * n * params.n_neighbors;
+
+  rmm::device_uvector<float> Qs_vec(total_nn, stream);
+  float* Ps      = input_matrix.vals();
+  float* Qs      = Qs_vec.data();
+  float* KL_divs = Qs;
+
+  // Normalize Ps
+  float P_sum = thrust::reduce(rmm::exec_policy(stream), Ps, Ps + total_nn);
+  raft::linalg::scalarMultiply(Ps, Ps, 1.0f / P_sum, total_nn, stream);
+
+  // Build Qs
+  auto get_emb_dist = [=] __device__(const int64_t i, const int64_t j) {
+    return emb_dists[i * n + j];
+  };
+  raft::linalg::map(Qs, total_nn, get_emb_dist, stream, input_matrix.rows(), input_matrix.cols());
+
+  const float dof      = fmaxf(params.dim - 1, 1);  // degrees of freedom
+  const float exponent = (dof + 1.0) / 2.0;
+  raft::linalg::unaryOp(
+    Qs,
+    Qs,
+    total_nn,
+    [=] __device__(float dist) { return __powf(dof / (dof + dist), exponent); },
+    stream);
+  float Q_sum = thrust::reduce(rmm::exec_policy(stream), Qs, Qs + total_nn);
+  raft::linalg::scalarMultiply(Qs, Qs, 1.0f / Q_sum, total_nn, stream);
+
+  compute_kl_div<<>>(
+    Ps, Qs, KL_divs, total_nn);
+  float kl_div = thrust::reduce(rmm::exec_policy(stream), KL_divs, KL_divs + total_nn);
+  return kl_div;
+}
+
 class TSNETest : public ::testing::TestWithParam<TSNEInput> {
  protected:
-  void assert_score(double score, const char* test, const double threshold)
+  struct TSNEResults;
+
+  void assert_results(const char* test, TSNEResults& results)
   {
-    printf("%s", test);
-    printf("score = %f\n", score);
-    ASSERT_TRUE(threshold < score);
+    std::cout << "Testing " << test << ":" << std::endl;
+    std::cout << "\ttrustworthiness = " << results.trustworthiness << std::endl;
+    std::cout << "\tkl_div = " << results.kl_div << std::endl;
+    std::cout << "\tkl_div_ref = " << results.kl_div_ref << std::endl;
+    ASSERT_TRUE(results.trustworthiness > trustworthiness_threshold);
+    double kl_div_tol = 0.2;
+    ASSERT_TRUE(results.kl_div_ref - kl_div_tol < results.kl_div &&
+                results.kl_div < results.kl_div_ref + kl_div_tol);
+    std::cout << std::endl;
   }
 
-  double runTest(TSNE_ALGORITHM algo, bool knn = false)
+  TSNEResults runTest(TSNE_ALGORITHM algo, bool knn = false)
   {
     raft::handle_t handle;
+    auto stream = handle.get_stream();
+    TSNEResults results;
 
-    // Allocate memory
-    rmm::device_uvector<float> X_d(n * p, handle.get_stream());
-    raft::update_device(X_d.data(), dataset.data(), n * p, handle.get_stream());
-    CUDA_CHECK(cudaStreamSynchronize(handle.get_stream()));
-
-    rmm::device_uvector<float> Y_d(n * 2, handle.get_stream());
-
-    rmm::device_uvector<int64_t> knn_indices(n * 90, handle.get_stream());
-
-    rmm::device_uvector<float> knn_dists(n * 90, handle.get_stream());
-
-    manifold_dense_inputs_t<float> input(X_d.data(), Y_d.data(), n, p);
-    knn_graph<int64_t, float> k_graph(n, 90, knn_indices.data(), knn_dists.data());
-
-    if (knn) TSNE::get_distances(handle, input, k_graph, handle.get_stream());
-
-    CUDA_CHECK(cudaStreamSynchronize(handle.get_stream()));
-
+    // Setup parameters
+    model_params.algorithm     = algo;
+    model_params.dim           = 2;
     model_params.n_neighbors   = 90;
     model_params.min_grad_norm = 1e-12;
     model_params.verbosity     = CUML_LEVEL_DEBUG;
-    model_params.algorithm     = algo;
 
-    TSNE_fit(handle,
-             X_d.data(),                       // X
-             Y_d.data(),                       // embeddings
-             n,                                // n_pts
-             p,                                // n_ftr
-             knn ? knn_indices.data() : NULL,  // knn_indices
-             knn ? knn_dists.data() : NULL,    // knn_dists
-             model_params);                    // model parameters
-
-    float* embeddings_h = (float*)malloc(sizeof(float) * n * 2);
+    // Allocate memory
+    rmm::device_uvector<float> X_d(n * p, stream);
+    raft::update_device(X_d.data(), dataset.data(), n * p, stream);
+    rmm::device_uvector<float> Y_d(n * model_params.dim, stream);
+    rmm::device_uvector<int64_t> input_indices(0, stream);
+    rmm::device_uvector<float> input_dists(0, stream);
+    rmm::device_uvector<float> pw_emb_dists(n * n, stream);
+
+    // Run TSNE
+    manifold_dense_inputs_t<float> input(X_d.data(), Y_d.data(), n, p);
+    knn_graph<int64_t, float> k_graph(n, model_params.n_neighbors, nullptr, nullptr);
+
+    if (knn) {
+      input_indices.resize(n * model_params.n_neighbors, stream);
+      input_dists.resize(n * model_params.n_neighbors, stream);
+      k_graph.knn_indices = input_indices.data();
+      k_graph.knn_dists   = input_dists.data();
+      TSNE::get_distances(handle, input, k_graph, stream);
+    }
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+    TSNE_runner<manifold_dense_inputs_t<float>, knn_indices_dense_t, float> runner(
+      handle, input, k_graph, model_params);
+    results.kl_div = runner.run();
+
+    // Compute embedding's pairwise distances
+    pairwise_distance(handle,
+                      Y_d.data(),
+                      Y_d.data(),
+                      pw_emb_dists.data(),
+                      n,
+                      n,
+                      model_params.dim,
+                      raft::distance::DistanceType::L2Expanded,
+                      stream);
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    // Compute theoretical KL div
+    results.kl_div_ref =
+      get_kl_div(model_params, runner.COO_Matrix, pw_emb_dists.data(), n, stream);
+
+    // Transfer embeddings
+    float* embeddings_h = (float*)malloc(sizeof(float) * n * model_params.dim);
     assert(embeddings_h != NULL);
-    raft::update_host(&embeddings_h[0], Y_d.data(), n * 2, handle.get_stream());
-    CUDA_CHECK(cudaStreamSynchronize(handle.get_stream()));
-
+    raft::update_host(embeddings_h, Y_d.data(), n * model_params.dim, stream);
+    CUDA_CHECK(cudaStreamSynchronize(stream));
     // Move embeddings to host.
     // This can be used for printing if needed.
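+    // Note: Y_d holds the embedding in column-major order (all n values of
+    // one output dimension, then the next); the loop below transposes it on
+    // the host into row-major (C-contiguous) order.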
     int k = 0;
-    float C_contiguous_embedding[n * 2];
+    float C_contiguous_embedding[n * model_params.dim];
     for (int i = 0; i < n; i++) {
-      for (int j = 0; j < 2; j++)
+      for (int j = 0; j < model_params.dim; j++)
         C_contiguous_embedding[k++] = embeddings_h[j * n + i];
     }
-
     // Move transposed embeddings back to device, as trustworthiness requires C contiguous format
-    raft::update_device(Y_d.data(), C_contiguous_embedding, n * 2, handle.get_stream());
-
+    raft::update_device(Y_d.data(), C_contiguous_embedding, n * model_params.dim, stream);
+    CUDA_CHECK(cudaStreamSynchronize(stream));
     free(embeddings_h);
 
-    // Test trustworthiness
-    return trustworthiness_score<float, raft::distance::DistanceType::L2SqrtUnexpanded>(
-      handle, X_d.data(), Y_d.data(), n, p, 2, 5);
+    // Produce trustworthiness score
+    results.trustworthiness =
+      trustworthiness_score<float, raft::distance::DistanceType::L2SqrtUnexpanded>(
+        handle, X_d.data(), Y_d.data(), n, p, model_params.dim, 5);
+
+    return results;
   }
 
   void basicTest()
   {
-    printf("BH\n");
+    std::cout << "Running BH:" << std::endl;
     score_bh = runTest(TSNE_ALGORITHM::BARNES_HUT);
-    printf("EXACT\n");
+    std::cout << "Running EXACT:" << std::endl;
     score_exact = runTest(TSNE_ALGORITHM::EXACT);
-    printf("FFT\n");
+    std::cout << "Running FFT:" << std::endl;
     score_fft = runTest(TSNE_ALGORITHM::FFT);
-    printf("KNN BH\n");
+    std::cout << "Running KNN BH:" << std::endl;
     knn_score_bh = runTest(TSNE_ALGORITHM::BARNES_HUT, true);
-    printf("KNN EXACT\n");
+    std::cout << "Running KNN EXACT:" << std::endl;
    knn_score_exact = runTest(TSNE_ALGORITHM::EXACT, true);
-    printf("KNN FFT\n");
+    std::cout << "Running KNN FFT:" << std::endl;
    knn_score_fft = runTest(TSNE_ALGORITHM::FFT, true);
   }
 
@@ -145,12 +214,19 @@ class TSNETest : public ::testing::TestWithParam<TSNEInput> {
   TSNEParams model_params;
   std::vector<float> dataset;
   int n, p;
-  double score_bh;
-  double score_exact;
-  double score_fft;
-  double knn_score_bh;
-  double knn_score_exact;
-  double knn_score_fft;
+
+  struct TSNEResults {
+    double trustworthiness;
+    double kl_div_ref;
+    double kl_div;
+  };
+
+  TSNEResults score_bh;
+  TSNEResults score_exact;
+  TSNEResults score_fft;
+  TSNEResults knn_score_bh;
+  TSNEResults knn_score_exact;
+  TSNEResults knn_score_fft;
   double trustworthiness_threshold;
 };
 
@@ -163,12 +239,12 @@ const std::vector<TSNEInput> inputs = {
 
 typedef TSNETest TSNETestF;
 TEST_P(TSNETestF, Result)
 {
-  assert_score(score_bh, "bh\n", trustworthiness_threshold);
-  assert_score(score_exact, "exact\n", trustworthiness_threshold);
-  assert_score(score_fft, "fft\n", trustworthiness_threshold);
-  assert_score(knn_score_bh, "knn_bh\n", trustworthiness_threshold);
-  assert_score(knn_score_exact, "knn_exact\n", trustworthiness_threshold);
-  assert_score(knn_score_fft, "knn_fft\n", trustworthiness_threshold);
+  assert_results("BH", score_bh);
+  assert_results("EXACT", score_exact);
+  assert_results("FFT", score_fft);
+  assert_results("KNN BH", knn_score_bh);
+  assert_results("KNN EXACT", knn_score_exact);
+  assert_results("KNN FFT", knn_score_fft);
 }
 
 INSTANTIATE_TEST_CASE_P(TSNETests, TSNETestF, ::testing::ValuesIn(inputs));
diff --git a/python/cuml/test/test_tsne.py b/python/cuml/test/test_tsne.py
index ed30f03536..8158543666 100644
--- a/python/cuml/test/test_tsne.py
+++ b/python/cuml/test/test_tsne.py
@@ -139,10 +139,8 @@ def test_tsne_knn_parameters(dataset, type_knn_graph, method):
     validate_embedding(X, embed)
 
 
-from sklearn.manifold import TSNE as skTSNE
-
 @pytest.mark.parametrize('dataset', test_datasets.values())
-@pytest.mark.parametrize('method', ['fft'])
+@pytest.mark.parametrize('method', ['fft', 'barnes_hut'])
 def test_tsne(dataset, method):
     """
     This
tests how TSNE handles a lot of input data across time. From 8eb451c8fbd33a801e3454cde0f0c81e57cebe79 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Thu, 30 Sep 2021 15:59:05 +0200 Subject: [PATCH 06/12] Fix Python style --- python/cuml/manifold/t_sne.pyx | 37 ++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx index 5bd41adee0..1dda76e250 100644 --- a/python/cuml/manifold/t_sne.pyx +++ b/python/cuml/manifold/t_sne.pyx @@ -496,25 +496,28 @@ class TSNE(Base, cdef float kl_divergence = 0 if self.sparse_fit: kl_divergence = TSNE_fit_sparse(handle_[0], - self.X_m.indptr.ptr, - self.X_m.indices.ptr, - self.X_m.data.ptr, - embed_ptr, - self.X_m.nnz, - n, - p, - knn_indices_raw, - knn_dists_raw, - deref(params)) + + self.X_m.indptr.ptr, + + self.X_m.indices.ptr, + + self.X_m.data.ptr, + embed_ptr, + self.X_m.nnz, + n, + p, + knn_indices_raw, + knn_dists_raw, + deref(params)) else: kl_divergence = TSNE_fit(handle_[0], - self.X_m.ptr, - embed_ptr, - n, - p, - knn_indices_raw, - knn_dists_raw, - deref(params)) + self.X_m.ptr, + embed_ptr, + n, + p, + knn_indices_raw, + knn_dists_raw, + deref(params)) self.handle.sync() free(params) From 115e0d41466ec53035d0df8398f8be66a0f1ff53 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 8 Oct 2021 19:14:13 +0200 Subject: [PATCH 07/12] documentation + warning --- python/cuml/manifold/t_sne.pyx | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx index 1dda76e250..fe8f661ab7 100644 --- a/python/cuml/manifold/t_sne.pyx +++ b/python/cuml/manifold/t_sne.pyx @@ -204,6 +204,12 @@ class TSNE(Base, module level, `cuml.global_settings.output_type`. See :ref:`output-data-type-configuration` for more info. + Attributes + ---------- + kl_divergence_ : float + Kullback-Leibler divergence after optimization. An experimental + feature at this time. + References ----------- .. [1] `van der Maaten, L.J.P. @@ -522,7 +528,7 @@ class TSNE(Base, self.handle.sync() free(params) - self.kl_divergence_ = kl_divergence + self._kl_divergence_ = kl_divergence if self.verbose: print("[t-SNE] KL divergence: {}".format(kl_divergence)) return self @@ -581,6 +587,18 @@ class TSNE(Base, params.algorithm = algo return params + @property + def kl_divergence_(self): + if self.method == 'barnes_hut': + warnings.warn("The calculation of the Kullback-Leibler " + "divergence is still an experimental feature " + "while using the Barnes Hut algorithm.") + return self._kl_divergence_ + + @kl_divergence_.setter + def kl_divergence_(self, value): + self._kl_divergence_ = value + def __del__(self): if hasattr(self, "embedding_"): From 273b685d6a5f193cbcc3e880e74e64f278b393c3 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 8 Oct 2021 19:17:57 +0200 Subject: [PATCH 08/12] Update TSNE license --- cpp/src/tsne/cannylab/bh.cu | 1742 ++++++++--------------- cpp/src/tsne/cannylabs_tsne_license.txt | 85 +- 2 files changed, 631 insertions(+), 1196 deletions(-) diff --git a/cpp/src/tsne/cannylab/bh.cu b/cpp/src/tsne/cannylab/bh.cu index 776f68aa15..d280ae6f76 100644 --- a/cpp/src/tsne/cannylab/bh.cu +++ b/cpp/src/tsne/cannylab/bh.cu @@ -1,37 +1,41 @@ /* -CUDA BarnesHut v3.1: Simulation of the gravitational forces -in a galactic cluster using the Barnes-Hut n-body algorithm +ECL-BH v4.5: Simulation of the gravitational forces in a star cluster using +the Barnes-Hut n-body algorithm. 
-Copyright (c) 2013, Texas State University-San Marcos. All rights reserved. +Copyright (c) 2010-2020 Texas State University. All rights reserved. -Redistribution and use in source and binary forms, with or without modification, -are permitted for academic, research, experimental, or personal use provided that -the following conditions are met: +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: - * Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - * Neither the name of Texas State University-San Marcos nor the names of its - contributors may be used to endorse or promote products derived from this - software without specific prior written permission. - -For all other uses, please contact the Office for Commercialization and Industry -Relations at Texas State University-San Marcos . + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of Texas State University nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED -IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE -OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -OF THE POSSIBILITY OF SUCH DAMAGE. - -Author: Martin Burtscher +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL TEXAS STATE UNIVERSITY BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +Authors: Martin Burtscher and Sahar Azimi + +URL: The latest version of this code is available at +https://userweb.cs.txstate.edu/~burtscher/research/ECL-BH/. + +Publication: This work is described in detail in the following paper. +Martin Burtscher and Keshav Pingali. An Efficient CUDA Implementation of the +Tree-based Barnes Hut n-Body Algorithm. Chapter 6 in GPU Computing Gems +Emerald Edition, pp. 75-92. January 2011. 
*/ @@ -40,118 +44,38 @@ Author: Martin Burtscher #include #include #include -#include -#include -#include - -#ifndef NO_ZMQ - #include -#endif - -#ifndef CUDART_VERSION -#error CUDART_VERSION Undefined! -#endif - -// #ifdef __KEPLER__ - -// #define GPU_ARCH "KEPLER" - -// // thread count -// #define THREADS1 1024 /* must be a power of 2 */ -// #define THREADS2 1024 -// #define THREADS3 768 -// #define THREADS4 128 -// #define THREADS5 1024 -// #define THREADS6 1024 - -// // block count = factor * #SMs -// #define FACTOR1 2 -// #define FACTOR2 2 -// #define FACTOR3 1 /* must all be resident at the same time */ -// #define FACTOR4 4 /* must all be resident at the same time */ -// #define FACTOR5 2 -// #define FACTOR6 2 - -// #elif __MAXWELL__ - -// #define GPU_ARCH "MAXWELL" - -// // thread count -// #define THREADS1 512 /* must be a power of 2 */ -// #define THREADS2 512 -// #define THREADS3 128 -// #define THREADS4 64 -// #define THREADS5 256 -// #define THREADS6 1024 - -// // block count = factor * #SMs -// #define FACTOR1 3 -// #define FACTOR2 3 -// #define FACTOR3 6 /* must all be resident at the same time */ -// #define FACTOR4 6 /* must all be resident at the same time */ -// #define FACTOR5 5 -// #define FACTOR6 1 - -// #elif __PASCAL__ - -#define GPU_ARCH "PASCAL" - -// thread count -#define THREADS1 512 /* must be a power of 2 */ -#define THREADS2 512 -#define THREADS3 768 -#define THREADS4 128 + +// threads per block +#define THREADS1 1024 /* must be a power of 2 */ +#define THREADS2 1024 +#define THREADS3 768 /* shared-memory limited on some devices */ +#define THREADS4 1024 #define THREADS5 1024 #define THREADS6 1024 -#define THREADS7 1024 // block count = factor * #SMs -#define FACTOR1 3 -#define FACTOR2 3 +#define FACTOR1 2 +#define FACTOR2 2 #define FACTOR3 1 /* must all be resident at the same time */ -#define FACTOR4 4 /* must all be resident at the same time */ +#define FACTOR4 1 /* must all be resident at the same time */ #define FACTOR5 2 #define FACTOR6 2 -#define FACTOR7 1 - -// #else - -// #define GPU_ARCH "UNKNOWN" - -// // thread count -// #define THREADS1 512 /* must be a power of 2 */ -// #define THREADS2 512 -// #define THREADS3 128 -// #define THREADS4 64 -// #define THREADS5 256 -// #define THREADS6 1024 - -// // block count = factor * #SMs -// #define FACTOR1 3 -// #define FACTOR2 3 -// #define FACTOR3 6 /* must all be resident at the same time */ -// #define FACTOR4 6 /* must all be resident at the same time */ -// #define FACTOR5 5 -// #define FACTOR6 1 - -// #endif #define WARPSIZE 32 #define MAXDEPTH 32 -__device__ volatile int stepd, bottomd, maxdepthd; +__device__ volatile int stepd, bottomd; __device__ unsigned int blkcntd; __device__ volatile float radiusd; + /******************************************************************************/ /*** initialize memory ********************************************************/ /******************************************************************************/ -__global__ void InitializationKernel(int * __restrict errd) +__global__ void InitializationKernel() { - *errd = 0; stepd = -1; - maxdepthd = 1; blkcntd = 0; } @@ -162,87 +86,94 @@ __global__ void InitializationKernel(int * __restrict errd) __global__ __launch_bounds__(THREADS1, FACTOR1) -void BoundingBoxKernel(int nnodesd, - int nbodiesd, - volatile int * __restrict startd, - volatile int * __restrict childd, - volatile float * __restrict massd, - volatile float * __restrict posxd, - volatile float * __restrict posyd, - volatile float * __restrict maxxd, - 
volatile float * __restrict maxyd, - volatile float * __restrict minxd, - volatile float * __restrict minyd) +void BoundingBoxKernel(const int nnodesd, const int nbodiesd, int* const __restrict__ startd, int* const __restrict__ childd, float4* const __restrict__ posMassd, float3* const __restrict__ maxd, float3* const __restrict__ mind) { - register int i, j, k, inc; - register float val, minx, maxx, miny, maxy; - __shared__ volatile float sminx[THREADS1], smaxx[THREADS1], sminy[THREADS1], smaxy[THREADS1]; + int i, j, k, inc; + float val; + __shared__ volatile float sminx[THREADS1], smaxx[THREADS1], sminy[THREADS1], smaxy[THREADS1], sminz[THREADS1], smaxz[THREADS1]; + float3 min, max; // initialize with valid data (in case #bodies < #threads) - minx = maxx = posxd[0]; - miny = maxy = posyd[0]; + const float4 p0 = posMassd[0]; + min.x = max.x = p0.x; + min.y = max.y = p0.y; + min.z = max.z = p0.z; // scan all bodies i = threadIdx.x; inc = THREADS1 * gridDim.x; for (j = i + blockIdx.x * THREADS1; j < nbodiesd; j += inc) { - val = posxd[j]; - minx = fminf(minx, val); - maxx = fmaxf(maxx, val); - val = posyd[j]; - miny = fminf(miny, val); - maxy = fmaxf(maxy, val); + const float4 p = posMassd[j]; + val = p.x; + min.x = fminf(min.x, val); + max.x = fmaxf(max.x, val); + val = p.y; + min.y = fminf(min.y, val); + max.y = fmaxf(max.y, val); + val = p.z; + min.z = fminf(min.z, val); + max.z = fmaxf(max.z, val); } // reduction in shared memory - sminx[i] = minx; - smaxx[i] = maxx; - sminy[i] = miny; - smaxy[i] = maxy; + sminx[i] = min.x; + smaxx[i] = max.x; + sminy[i] = min.y; + smaxy[i] = max.y; + sminz[i] = min.z; + smaxz[i] = max.z; for (j = THREADS1 / 2; j > 0; j /= 2) { __syncthreads(); if (i < j) { k = i + j; - sminx[i] = minx = fminf(minx, sminx[k]); - smaxx[i] = maxx = fmaxf(maxx, smaxx[k]); - sminy[i] = miny = fminf(miny, sminy[k]); - smaxy[i] = maxy = fmaxf(maxy, smaxy[k]); + sminx[i] = min.x = fminf(min.x, sminx[k]); + smaxx[i] = max.x = fmaxf(max.x, smaxx[k]); + sminy[i] = min.y = fminf(min.y, sminy[k]); + smaxy[i] = max.y = fmaxf(max.y, smaxy[k]); + sminz[i] = min.z = fminf(min.z, sminz[k]); + smaxz[i] = max.z = fmaxf(max.z, smaxz[k]); } } // write block result to global memory if (i == 0) { k = blockIdx.x; - minxd[k] = minx; - maxxd[k] = maxx; - minyd[k] = miny; - maxyd[k] = maxy; + mind[k] = min; + maxd[k] = max; __threadfence(); inc = gridDim.x - 1; if (inc == atomicInc(&blkcntd, inc)) { // I'm the last block, so combine all block results for (j = 0; j <= inc; j++) { - minx = fminf(minx, minxd[j]); - maxx = fmaxf(maxx, maxxd[j]); - miny = fminf(miny, minyd[j]); - maxy = fmaxf(maxy, maxyd[j]); + float3 minp = mind[j]; + float3 maxp = maxd[j]; + min.x = fminf(min.x, minp.x); + max.x = fmaxf(max.x, maxp.x); + min.y = fminf(min.y, minp.y); + max.y = fmaxf(max.y, maxp.y); + min.z = fminf(min.z, minp.z); + max.z = fmaxf(max.z, maxp.z); } - // compute 'radius' - radiusd = fmaxf(maxx - minx, maxy - miny) * 0.5f + 1e-5f; + // compute radius + val = fmaxf(max.x - min.x, max.y - min.y); + radiusd = fmaxf(val, max.z - min.z) * 0.5f; // create root node k = nnodesd; bottomd = k; - massd[k] = -1.0f; startd[k] = 0; - posxd[k] = (minx + maxx) * 0.5f; - posyd[k] = (miny + maxy) * 0.5f; - k *= 4; - for (i = 0; i < 4; i++) childd[k + i] = -1; + float4 p; + p.x = (min.x + max.x) * 0.5f; + p.y = (min.y + max.y) * 0.5f; + p.z = (min.z + max.z) * 0.5f; + p.w = -1.0f; + posMassd[k] = p; + k *= 8; + for (i = 0; i < 8; i++) childd[k + i] = -1; stepd++; } @@ -256,12 +187,12 @@ void BoundingBoxKernel(int 
nnodesd, __global__ __launch_bounds__(1024, 1) -void ClearKernel1(int nnodesd, int nbodiesd, volatile int * __restrict childd) +void ClearKernel1(const int nnodesd, const int nbodiesd, int* const __restrict__ childd) { - register int k, inc, top, bottom; + int k, inc, top, bottom; - top = 4 * nnodesd; - bottom = 4 * nbodiesd; + top = 8 * nnodesd; + bottom = 8 * nbodiesd; inc = blockDim.x * gridDim.x; k = (bottom & (-WARPSIZE)) + threadIdx.x + blockIdx.x * blockDim.x; // align to warp size if (k < bottom) k += inc; @@ -276,130 +207,126 @@ void ClearKernel1(int nnodesd, int nbodiesd, volatile int * __restrict childd) __global__ __launch_bounds__(THREADS2, FACTOR2) -void TreeBuildingKernel(int nnodesd, - int nbodiesd, - volatile int * __restrict errd, - volatile int * __restrict childd, - volatile float * __restrict posxd, - volatile float * __restrict posyd) +void TreeBuildingKernel(const int nnodesd, const int nbodiesd, volatile int* const __restrict__ childd, const float4* const __restrict__ posMassd) { - register int i, j, depth, localmaxdepth, skip, inc; - register float x, y, r; - register float px, py; - register float dx, dy; - register int ch, n, cell, locked, patch; - register float radius, rootx, rooty; + int i, j, depth, skip, inc; + float x, y, z, r; + float dx, dy, dz; + int ch, n, cell, locked, patch; + float radius; // cache root data - radius = radiusd; - rootx = posxd[nnodesd]; - rooty = posyd[nnodesd]; + radius = radiusd * 0.5f; + const float4 root = posMassd[nnodesd]; - localmaxdepth = 1; skip = 1; inc = blockDim.x * gridDim.x; i = threadIdx.x + blockIdx.x * blockDim.x; // iterate over all bodies assigned to thread while (i < nbodiesd) { - // if (TID == 0) - // printf("\tStarting\n"); + const float4 p = posMassd[i]; if (skip != 0) { // new body, so start traversing at root skip = 0; - px = posxd[i]; - py = posyd[i]; n = nnodesd; depth = 1; - r = radius * 0.5f; - dx = dy = -r; + r = radius; + dx = dy = dz = -r; j = 0; // determine which child to follow - if (rootx < px) {j = 1; dx = r;} - if (rooty < py) {j |= 2; dy = r;} - x = rootx + dx; - y = rooty + dy; + if (root.x < p.x) {j = 1; dx = r;} + if (root.y < p.y) {j |= 2; dy = r;} + if (root.z < p.z) {j |= 4; dz = r;} + x = root.x + dx; + y = root.y + dy; + z = root.z + dz; } // follow path to leaf cell - ch = childd[n*4+j]; + ch = childd[n*8+j]; while (ch >= nbodiesd) { n = ch; depth++; r *= 0.5f; - dx = dy = -r; + dx = dy = dz = -r; j = 0; // determine which child to follow - if (x < px) {j = 1; dx = r;} - if (y < py) {j |= 2; dy = r;} + if (x < p.x) {j = 1; dx = r;} + if (y < p.y) {j |= 2; dy = r;} + if (z < p.z) {j |= 4; dz = r;} x += dx; y += dy; - ch = childd[n*4+j]; + z += dz; + ch = childd[n*8+j]; } + if (ch != -2) { // skip if child pointer is locked and try again later - locked = n*4+j; + locked = n*8+j; if (ch == -1) { - if (-1 == atomicCAS((int *)&childd[locked], -1, i)) { // if null, just insert the new body - localmaxdepth = max(depth, localmaxdepth); + if (-1 == atomicCAS((int*)&childd[locked], -1, i)) { // if null, just insert the new body i += inc; // move on to next body skip = 1; } - } else { // there already is a body in this position - if (ch == atomicCAS((int *)&childd[locked], ch, -2)) { // try to lock + } else { // there already is a body at this position + if (ch == atomicCAS((int*)&childd[locked], ch, -2)) { // try to lock patch = -1; - // create new cell(s) and insert the old and new body + const float4 chp = posMassd[ch]; + // create new cell(s) and insert the old and new bodies do { depth++; + 
if (depth > MAXDEPTH) {printf("ERROR: maximum depth exceeded (bodies are too close together)\n"); asm("trap;");} - cell = atomicSub((int *)&bottomd, 1) - 1; - if (cell <= nbodiesd) { - *errd = 1; - bottomd = nnodesd; - } + cell = atomicSub((int*)&bottomd, 1) - 1; + if (cell <= nbodiesd) {printf("ERROR: out of cell memory\n"); asm("trap;");} if (patch != -1) { - childd[n*4+j] = cell; + childd[n*8+j] = cell; } patch = max(patch, cell); + j = 0; - if (x < posxd[ch]) j = 1; - if (y < posyd[ch]) j |= 2; - childd[cell*4+j] = ch; + if (x < chp.x) j = 1; + if (y < chp.y) j |= 2; + if (z < chp.z) j |= 4; + childd[cell*8+j] = ch; + n = cell; r *= 0.5f; - dx = dy = -r; + dx = dy = dz = -r; j = 0; - if (x < px) {j = 1; dx = r;} - if (y < py) {j |= 2; dy = r;} + if (x < p.x) {j = 1; dx = r;} + if (y < p.y) {j |= 2; dy = r;} + if (z < p.z) {j |= 4; dz = r;} x += dx; y += dy; - ch = childd[n*4+j]; + z += dz; + + ch = childd[n*8+j]; // repeat until the two bodies are different children - } while (ch >= 0 && r > 1e-10); // add radius check because bodies that are very close together can cause this to fail... there is some error condition here that I'm not entirely sure of (not just when two bodies are equal) - childd[n*4+j] = i; + } while (ch >= 0); + childd[n*8+j] = i; - localmaxdepth = max(depth, localmaxdepth); i += inc; // move on to next body skip = 2; } } } + __syncthreads(); // optional barrier for performance __threadfence(); if (skip == 2) { childd[locked] = patch; } } - // record maximum tree depth - atomicMax((int *)&maxdepthd, localmaxdepth); } __global__ __launch_bounds__(1024, 1) -void ClearKernel2(int nnodesd, volatile int * __restrict startd, volatile float * __restrict massd) +void ClearKernel2(const int nnodesd, int* const __restrict__ startd, float4* const __restrict__ posMassd) { - register int k, inc, bottom; + int k, inc, bottom; bottom = bottomd; inc = blockDim.x * gridDim.x; @@ -408,7 +335,7 @@ void ClearKernel2(int nnodesd, volatile int * __restrict startd, volatile float // iterate over all cells assigned to thread while (k < nnodesd) { - massd[k] = -1.0f; + posMassd[k].w = -1.0f; startd[k] = -1; k += inc; } @@ -421,64 +348,67 @@ void ClearKernel2(int nnodesd, volatile int * __restrict startd, volatile float __global__ __launch_bounds__(THREADS3, FACTOR3) -void SummarizationKernel(const int nnodesd, - const int nbodiesd, - volatile int * __restrict countd, - const int * __restrict childd, - volatile float * __restrict massd, - volatile float * __restrict posxd, - volatile float * __restrict posyd) +void SummarizationKernel(const int nnodesd, const int nbodiesd, volatile int* const __restrict__ countd, const int* const __restrict__ childd, volatile float4* const __restrict__ posMassd) { - register int i, j, k, ch, inc, cnt, bottom, flag; - register float m, cm, px, py; - __shared__ int child[THREADS3 * 4]; - __shared__ float mass[THREADS3 * 4]; + int i, j, k, ch, inc, cnt, bottom; + float m, cm, px, py, pz; + __shared__ int child[THREADS3 * 8]; + __shared__ float mass[THREADS3 * 8]; bottom = bottomd; inc = blockDim.x * gridDim.x; k = (bottom & (-WARPSIZE)) + threadIdx.x + blockIdx.x * blockDim.x; // align to warp size if (k < bottom) k += inc; - register int restart = k; - for (j = 0; j < 5; j++) { // wait-free pre-passes + int restart = k; + for (j = 0; j < 3; j++) { // wait-free pre-passes // iterate over all cells assigned to thread while (k <= nnodesd) { - if (massd[k] < 0.0f) { - for (i = 0; i < 4; i++) { - ch = childd[k*4+i]; + if (posMassd[k].w < 0.0f) { + for (i = 0; i < 8; 
i++) { + ch = childd[k*8+i]; child[i*THREADS3+threadIdx.x] = ch; // cache children - if ((ch >= nbodiesd) && ((mass[i*THREADS3+threadIdx.x] = massd[ch]) < 0.0f)) { + if ((ch >= nbodiesd) && ((mass[i*THREADS3+threadIdx.x] = posMassd[ch].w) < 0.0f)) { break; } } - if (i == 4) { + if (i == 8) { // all children are ready cm = 0.0f; px = 0.0f; py = 0.0f; + pz = 0.0f; cnt = 0; - for (i = 0; i < 4; i++) { + for (i = 0; i < 8; i++) { ch = child[i*THREADS3+threadIdx.x]; if (ch >= 0) { + // four reads due to missing copy constructor for "volatile float4" + const float chx = posMassd[ch].x; + const float chy = posMassd[ch].y; + const float chz = posMassd[ch].z; + const float chw = posMassd[ch].w; if (ch >= nbodiesd) { // count bodies (needed later) m = mass[i*THREADS3+threadIdx.x]; cnt += countd[ch]; } else { - m = massd[ch]; + m = chw; cnt++; } // add child's contribution cm += m; - px += posxd[ch] * m; - py += posyd[ch] * m; + px += chx * m; + py += chy * m; + pz += chz * m; } } countd[k] = cnt; m = 1.0f / cm; - posxd[k] = px * m; - posyd[k] = py * m; - __threadfence(); // make sure data are visible before setting mass - massd[k] = cm; + // four writes due to missing copy constructor for "volatile float4" + posMassd[k].x = px * m; + posMassd[k].y = py * m; + posMassd[k].z = pz * m; + __threadfence(); + posMassd[k].w = cm; } } k += inc; // move on to next cell @@ -486,27 +416,26 @@ void SummarizationKernel(const int nnodesd, k = restart; } - flag = 0; j = 0; // iterate over all cells assigned to thread while (k <= nnodesd) { - if (massd[k] >= 0.0f) { + if (posMassd[k].w >= 0.0f) { k += inc; } else { if (j == 0) { - j = 4; - for (i = 0; i < 4; i++) { - ch = childd[k*4+i]; + j = 8; + for (i = 0; i < 8; i++) { + ch = childd[k*8+i]; child[i*THREADS3+threadIdx.x] = ch; // cache children - if ((ch < nbodiesd) || ((mass[i*THREADS3+threadIdx.x] = massd[ch]) >= 0.0f)) { + if ((ch < nbodiesd) || ((mass[i*THREADS3+threadIdx.x] = posMassd[ch].w) >= 0.0f)) { j--; } } } else { - j = 4; - for (i = 0; i < 4; i++) { + j = 8; + for (i = 0; i < 8; i++) { ch = child[i*THREADS3+threadIdx.x]; - if ((ch < nbodiesd) || (mass[i*THREADS3+threadIdx.x] >= 0.0f) || ((mass[i*THREADS3+threadIdx.x] = massd[ch]) >= 0.0f)) { + if ((ch < nbodiesd) || (mass[i*THREADS3+threadIdx.x] >= 0.0f) || ((mass[i*THREADS3+threadIdx.x] = posMassd[ch].w) >= 0.0f)) { j--; } } @@ -517,37 +446,41 @@ void SummarizationKernel(const int nnodesd, cm = 0.0f; px = 0.0f; py = 0.0f; + pz = 0.0f; cnt = 0; - for (i = 0; i < 4; i++) { + for (i = 0; i < 8; i++) { ch = child[i*THREADS3+threadIdx.x]; if (ch >= 0) { + // four reads due to missing copy constructor for "volatile float4" + const float chx = posMassd[ch].x; + const float chy = posMassd[ch].y; + const float chz = posMassd[ch].z; + const float chw = posMassd[ch].w; if (ch >= nbodiesd) { // count bodies (needed later) m = mass[i*THREADS3+threadIdx.x]; cnt += countd[ch]; } else { - m = massd[ch]; + m = chw; cnt++; } // add child's contribution cm += m; - px += posxd[ch] * m; - py += posyd[ch] * m; + px += chx * m; + py += chy * m; + pz += chz * m; } } countd[k] = cnt; m = 1.0f / cm; - posxd[k] = px * m; - posyd[k] = py * m; - flag = 1; + // four writes due to missing copy constructor for "volatile float4" + posMassd[k].x = px * m; + posMassd[k].y = py * m; + posMassd[k].z = pz * m; + __threadfence(); + posMassd[k].w = cm; + k += inc; } } - __syncthreads(); - // __threadfence(); - if (flag != 0) { - massd[k] = cm; - k += inc; - flag = 0; - } } } @@ -558,9 +491,9 @@ void SummarizationKernel(const int nnodesd, 
__global__ __launch_bounds__(THREADS4, FACTOR4) -void SortKernel(int nnodesd, int nbodiesd, int * __restrict sortd, int * __restrict countd, volatile int * __restrict startd, int * __restrict childd) +void SortKernel(const int nnodesd, const int nbodiesd, int* const __restrict__ sortd, const int* const __restrict__ countd, volatile int* const __restrict__ startd, int* const __restrict__ childd) { - register int i, j, k, ch, dec, start, bottom; + int i, j, k, ch, dec, start, bottom; bottom = bottomd; dec = blockDim.x * gridDim.x; @@ -571,13 +504,13 @@ void SortKernel(int nnodesd, int nbodiesd, int * __restrict sortd, int * __restr start = startd[k]; if (start >= 0) { j = 0; - for (i = 0; i < 4; i++) { - ch = childd[k*4+i]; + for (i = 0; i < 8; i++) { + ch = childd[k*8+i]; if (ch >= 0) { if (i != j) { // move children to front (needed later for speed) - childd[k*4+i] = -1; - childd[k*4+j] = ch; + childd[k*8+i] = -1; + childd[k*8+j] = ch; } j++; if (ch >= nbodiesd) { @@ -593,6 +526,7 @@ void SortKernel(int nnodesd, int nbodiesd, int * __restrict sortd, int * __restr } k -= dec; // move on to next cell } + __syncthreads(); // optional barrier for performance } } @@ -603,120 +537,107 @@ void SortKernel(int nnodesd, int nbodiesd, int * __restrict sortd, int * __restr __global__ __launch_bounds__(THREADS5, FACTOR5) -void ForceCalculationKernel(int nnodesd, - int nbodiesd, - volatile int * __restrict errd, - float theta, - float epssqd, // correction for zero distance - volatile int * __restrict sortd, - volatile int * __restrict childd, - volatile float * __restrict massd, - volatile float * __restrict posxd, - volatile float * __restrict posyd, - volatile float * __restrict velxd, - volatile float * __restrict velyd, - volatile float * __restrict normd) +void ForceCalculationKernel(const int nnodesd, const int nbodiesd, const float dthfd, const float itolsqd, const float epssqd, const int* const __restrict__ sortd, const int* const __restrict__ childd, const float4* const __restrict__ posMassd, float2* const __restrict__ veld, float4* const __restrict__ accVeld) { - register int i, j, k, n, depth, base, sbase, diff, pd, nd; - register float px, py, vx, vy, dx, dy, normsum, tmp, mult; + int i, j, k, n, depth, base, sbase, diff, pd, nd; + float ax, ay, az, dx, dy, dz, tmp; __shared__ volatile int pos[MAXDEPTH * THREADS5/WARPSIZE], node[MAXDEPTH * THREADS5/WARPSIZE]; __shared__ float dq[MAXDEPTH * THREADS5/WARPSIZE]; if (0 == threadIdx.x) { - dq[0] = (radiusd * radiusd) / (theta * theta); - for (i = 1; i < maxdepthd; i++) { - dq[i] = dq[i - 1] * 0.25f; // radius is halved every level of tree so squared radius is quartered - dq[i - 1] += epssqd; + tmp = radiusd * 2; + // precompute values that depend only on tree level + dq[0] = tmp * tmp * itolsqd; + for (i = 1; i < MAXDEPTH; i++) { + dq[i] = dq[i - 1] * 0.25f; + dq[i - 1] += epssqd; } dq[i - 1] += epssqd; - - if (maxdepthd > MAXDEPTH) { - *errd = maxdepthd; - } } __syncthreads(); - if (maxdepthd <= MAXDEPTH) { - // figure out first thread in each warp (lane 0) - base = threadIdx.x / WARPSIZE; - sbase = base * WARPSIZE; - j = base * MAXDEPTH; + // figure out first thread in each warp (lane 0) + base = threadIdx.x / WARPSIZE; + sbase = base * WARPSIZE; + j = base * MAXDEPTH; + + diff = threadIdx.x - sbase; + // make multiple copies to avoid index calculations later + if (diff < MAXDEPTH) { + dq[diff+j] = dq[diff]; + } + __syncthreads(); - diff = threadIdx.x - sbase; - // make multiple copies to avoid index calculations later - if (diff < MAXDEPTH) { 
- dq[diff+j] = dq[diff]; + // iterate over all bodies assigned to thread + for (k = threadIdx.x + blockIdx.x * blockDim.x; k < nbodiesd; k += blockDim.x * gridDim.x) { + i = sortd[k]; // get permuted/sorted index + // cache position info + const float4 pi = posMassd[i]; + + ax = 0.0f; + ay = 0.0f; + az = 0.0f; + + // initialize iteration stack, i.e., push root node onto stack + depth = j; + if (sbase == threadIdx.x) { + pos[j] = 0; + node[j] = nnodesd * 8; } - __syncthreads(); - __threadfence_block(); - - // iterate over all bodies assigned to thread - for (k = threadIdx.x + blockIdx.x * blockDim.x; k < nbodiesd; k += blockDim.x * gridDim.x) { - i = sortd[k]; // get permuted/sorted index - // cache position info - px = posxd[i]; - py = posyd[i]; - - vx = 0.0f; - vy = 0.0f; - normsum = 0.0f; - - // initialize iteration stack, i.e., push root node onto stack - depth = j; - if (sbase == threadIdx.x) { - pos[j] = 0; - node[j] = nnodesd * 4; - } - do { - // stack is not empty - pd = pos[depth]; - nd = node[depth]; - while (pd < 4) { - // node on top of stack has more children to process - n = childd[nd + pd]; // load child pointer - pd++; - - if (n >= 0) { - dx = px - posxd[n]; - dy = py - posyd[n]; - tmp = dx*dx + dy*dy + epssqd; // distance squared plus small constant to prevent zeros - #if (CUDART_VERSION >= 9000) - if ((n < nbodiesd) || __all_sync(__activemask(), tmp >= dq[depth])) { // check if all threads agree that cell is far enough away (or is a body) - #else - if ((n < nbodiesd) || __all(tmp >= dq[depth])) { // check if all threads agree that cell is far enough away (or is a body) - #endif - // from bhtsne - sptree.cpp - tmp = 1 / (1 + tmp); - mult = massd[n] * tmp; - normsum += mult; - mult *= tmp; - vx += dx * mult; - vy += dy * mult; - } else { - // push cell onto stack - if (sbase == threadIdx.x) { // maybe don't push and inc if last child - pos[depth] = pd; - node[depth] = nd; - } - depth++; - pd = 0; - nd = n * 4; - } + do { + // stack is not empty + pd = pos[depth]; + nd = node[depth]; + while (pd < 8) { + // node on top of stack has more children to process + n = childd[nd + pd]; // load child pointer + pd++; + + if (n >= 0) { + const float4 pn = posMassd[n]; + dx = pn.x - pi.x; + dy = pn.y - pi.y; + dz = pn.z - pi.z; + tmp = dx*dx + (dy*dy + (dz*dz + epssqd)); // compute distance squared (plus softening) + if ((n < nbodiesd) || __all_sync(0xffffffff, tmp >= dq[depth])) { // check if all threads agree that cell is far enough away (or is a body) + tmp = rsqrtf(tmp); // compute distance + tmp = pn.w * tmp * tmp * tmp; + ax += dx * tmp; + ay += dy * tmp; + az += dz * tmp; } else { - pd = 4; // early out because all remaining children are also zero + // push cell onto stack + if (sbase == threadIdx.x) { + pos[depth] = pd; + node[depth] = nd; + } + depth++; + pd = 0; + nd = n * 8; } + } else { + pd = 8; // early out because all remaining children are also zero } - depth--; // done with this level - } while (depth >= j); - - if (stepd >= 0) { - // update velocity - velxd[i] += vx; - velyd[i] += vy; - normd[i] = normsum - 1.0f; // subtract one for self computation (qii) } + depth--; // done with this level + } while (depth >= j); + + float4 acc = accVeld[i]; + if (stepd > 0) { + // update velocity + float2 v = veld[i]; + v.x += (ax - acc.x) * dthfd; + v.y += (ay - acc.y) * dthfd; + acc.w += (az - acc.z) * dthfd; + veld[i] = v; } + + // save computed acceleration + acc.x = ax; + acc.y = ay; + acc.z = az; + accVeld[i] = acc; } } @@ -724,808 +645,343 @@ void ForceCalculationKernel(int 
nnodesd, /******************************************************************************/ /*** advance bodies ***********************************************************/ /******************************************************************************/ -// Edited to add momentum, repulsive, attr forces, etc. + __global__ __launch_bounds__(THREADS6, FACTOR6) -void IntegrationKernel(int N, - int nnodes, - float eta, - float norm, - float momentum, - float exaggeration, - volatile float * __restrict pts, // (nnodes + 1) x 2 - volatile float * __restrict attr_forces, // (N x 2) - volatile float * __restrict rep_forces, // (nnodes + 1) x 2 - volatile float * __restrict gains, - volatile float * __restrict old_forces) // (N x 2) +void IntegrationKernel(const int nbodiesd, const float dtimed, const float dthfd, float4* const __restrict__ posMass, float2* const __restrict__ veld, float4* const __restrict__ accVeld) { - register int i, inc; - register float dx, dy, ux, uy, gx, gy; + int i, inc; + float dvelx, dvely, dvelz; + float velhx, velhy, velhz; // iterate over all bodies assigned to thread inc = blockDim.x * gridDim.x; - for (i = threadIdx.x + blockIdx.x * blockDim.x; i < N; i += inc) { - ux = old_forces[i]; - uy = old_forces[N + i]; - gx = gains[i]; - gy = gains[N + i]; - dx = exaggeration*attr_forces[i] - (rep_forces[i] / norm); - dy = exaggeration*attr_forces[i + N] - (rep_forces[nnodes + 1 + i] / norm); - - gx = (signbit(dx) != signbit(ux)) ? gx + 0.2 : gx * 0.8; - gy = (signbit(dy) != signbit(uy)) ? gy + 0.2 : gy * 0.8; - gx = (gx < 0.01) ? 0.01 : gx; - gy = (gy < 0.01) ? 0.01 : gy; - - ux = momentum * ux - eta * gx * dx; - uy = momentum * uy - eta * gy * dy; - - pts[i] += ux; - pts[i + nnodes + 1] += uy; - - old_forces[i] = ux; - old_forces[N + i] = uy; - gains[i] = gx; - gains[N + i] = gy; - } + for (i = threadIdx.x + blockIdx.x * blockDim.x; i < nbodiesd; i += inc) { + // integrate + float4 acc = accVeld[i]; + dvelx = acc.x * dthfd; + dvely = acc.y * dthfd; + dvelz = acc.z * dthfd; + + float2 v = veld[i]; + velhx = v.x + dvelx; + velhy = v.y + dvely; + velhz = acc.w + dvelz; + + float4 p = posMass[i]; + p.x += velhx * dtimed; + p.y += velhy * dtimed; + p.z += velhz * dtimed; + posMass[i] = p; + + v.x = velhx + dvelx; + v.y = velhy + dvely; + acc.w = velhz + dvelz; + veld[i] = v; + accVeld[i] = acc; + } } - -/******************************************************************************/ -/*** compute attractive force *************************************************/ /******************************************************************************/ -__global__ -void csr2coo(int N, int nnz, - volatile int * __restrict pijRowPtr, - volatile int * __restrict pijColInd, - volatile int * __restrict indices) -{ - register int TID, i, j, start, end; - TID = threadIdx.x + blockIdx.x * blockDim.x; - if (TID >= nnz) return; - start = 0; end = N + 1; - i = (N + 1) >> 1; - while (end - start > 1) { - j = pijRowPtr[i]; - end = (j <= TID) ? end : i; - start = (j > TID) ? 
start : i; - i = (start + end) >> 1; - } - j = pijColInd[TID]; - indices[2*TID] = i; - indices[2*TID+1] = j; -} -__global__ -void ComputePijKernel(const unsigned int N, - const unsigned int K, - float * __restrict pij, - const float * __restrict sqdist, - const float * __restrict betas) +static void CudaTest(const char* const msg) { - register int TID, i, j; - register float dist, beta; - - TID = threadIdx.x + blockIdx.x * blockDim.x; - if (TID >= N * K) return; - i = TID / K; - j = TID % K; + cudaError_t e; - beta = betas[i]; - dist = sqdist[TID]; - pij[TID] = (j == 0 && dist == 0.0f) ? 0.0f : __expf(-beta * dist); // condition deals with evaluation of pii + cudaDeviceSynchronize(); + if (cudaSuccess != (e = cudaGetLastError())) { + fprintf(stderr, "%s: %d\n", msg, e); + fprintf(stderr, "%s\n", cudaGetErrorString(e)); + exit(-1); + } } -__global__ -void ComputePijxQijKernel(int N, int nnz, int nnodes, - volatile int * indices, - volatile float * __restrict pij, - volatile float * __restrict forceProd, - volatile float * __restrict pts) -{ - register int TID, i, j; //, inc; - register float ix, iy, jx, jy, dx, dy; - TID = threadIdx.x + blockIdx.x * blockDim.x; - // inc = blockDim.x * gridDim.x; - // for (TID = threadIdx.x + blockIdx.x * blockDim.x; TID < nnz; TID += inc) { - if (TID >= nnz) return; - i = indices[2*TID]; - j = indices[2*TID+1]; - ix = pts[i]; iy = pts[nnodes + 1 + i]; - jx = pts[j]; jy = pts[nnodes + 1 + j]; - dx = ix - jx; - dy = iy - jy; - forceProd[TID] = pij[TID] * 1 / (1 + dx*dx + dy*dy); - // } -} -__global__ -void PerplexitySearchKernel(const unsigned int N, - const float perplexity_target, - const float eps, - float * __restrict betas, - float * __restrict lower_bound, - float * __restrict upper_bound, - int * __restrict found, - const float * __restrict neg_entropy, - const float * __restrict row_sum) -{ - register int i, is_found; - register float perplexity, neg_ent, sum_P, pdiff, beta, min_beta, max_beta; - i = threadIdx.x + blockIdx.x * blockDim.x; - if (i >= N) return; - - neg_ent = neg_entropy[i]; - sum_P = row_sum[i]; - beta = betas[i]; - - min_beta = lower_bound[i]; - max_beta = upper_bound[i]; - - perplexity = (neg_ent / sum_P) + __logf(sum_P); - pdiff = perplexity - __logf(perplexity_target); - is_found = (pdiff < eps && - pdiff < eps); - if (!is_found) { - if (pdiff > 0) { - min_beta = beta; - beta = (max_beta == FLT_MAX || max_beta == -FLT_MAX) ? beta * 2.0f : (beta + max_beta) / 2.0f; - } else { - max_beta = beta; - beta = (min_beta == -FLT_MAX || min_beta == FLT_MAX) ? 
beta / 2.0f : (beta + min_beta) / 2.0f; - } - lower_bound[i] = min_beta; - upper_bound[i] = max_beta; - betas[i] = beta; - } - found[i] = is_found; -} -// computes unnormalized attractive forces -void computeAttrForce(int N, - int nnz, - int nnodes, - int attr_forces_grid_size, - int attr_forces_block_size, - cusparseHandle_t &handle, - cusparseMatDescr_t &descr, - thrust::device_vector &sparsePij, - thrust::device_vector &pijRowPtr, // (N + 1)-D vector, should be constant L - thrust::device_vector &pijColInd, // NxL matrix (same shape as sparsePij) - thrust::device_vector &forceProd, // NxL matrix - thrust::device_vector &pts, // (nnodes + 1) x 2 matrix - thrust::device_vector &forces, // N x 2 matrix - thrust::device_vector &ones, - thrust::device_vector &indices) // N x 2 matrix of ones -{ - // Computes pij*qij for each i,j - ComputePijxQijKernel<<>>(N, nnz, nnodes, - thrust::raw_pointer_cast(indices.data()), - thrust::raw_pointer_cast(sparsePij.data()), - thrust::raw_pointer_cast(forceProd.data()), - thrust::raw_pointer_cast(pts.data())); - // ComputePijxQijKernel<<>>(N, nnz, nnodes, - // thrust::raw_pointer_cast(indices.data()), - // thrust::raw_pointer_cast(sparsePij.data()), - // thrust::raw_pointer_cast(forceProd.data()), - // thrust::raw_pointer_cast(pts.data())); - GpuErrorCheck(cudaDeviceSynchronize()); - - // compute forces_i = sum_j pij*qij*normalization*yi - float alpha = 1.0f; - float beta = 0.0f; - CusparseSafeCall(cusparseScsrmm(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, - N, 2, N, nnz, &alpha, descr, - thrust::raw_pointer_cast(forceProd.data()), - thrust::raw_pointer_cast(pijRowPtr.data()), - thrust::raw_pointer_cast(pijColInd.data()), - thrust::raw_pointer_cast(ones.data()), - N, &beta, thrust::raw_pointer_cast(forces.data()), - N)); - GpuErrorCheck(cudaDeviceSynchronize()); - thrust::transform(forces.begin(), forces.begin() + N, pts.begin(), forces.begin(), thrust::multiplies()); - thrust::transform(forces.begin() + N, forces.end(), pts.begin() + nnodes + 1, forces.begin() + N, thrust::multiplies()); - - // compute forces_i = forces_i - sum_j pij*qij*normalization*yj - alpha = -1.0f; - beta = 1.0f; - CusparseSafeCall(cusparseScsrmm(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, - N, 2, N, nnz, &alpha, descr, - thrust::raw_pointer_cast(forceProd.data()), - thrust::raw_pointer_cast(pijRowPtr.data()), - thrust::raw_pointer_cast(pijColInd.data()), - thrust::raw_pointer_cast(pts.data()), - nnodes + 1, &beta, thrust::raw_pointer_cast(forces.data()), - N)); - GpuErrorCheck(cudaDeviceSynchronize()); - +/******************************************************************************/ -} +// random number generator (based on SPLASH-2 code at https://github.com/staceyson/splash2/blob/master/codes/apps/barnes/util.C) -// TODO: Add -1 notification here... 
and how to deal with it if it happens -// TODO: Maybe think about getting FAISS to return integers (long-term todo) -__global__ void postprocess_matrix(float* matrix, - long* long_indices, - int* indices, - unsigned int N_POINTS, - unsigned int K) -{ - register int TID = threadIdx.x + blockIdx.x * blockDim.x; - if (TID >= N_POINTS*K) return; +static int randx = 7; - // Set pij to 0 for each of the broken values - Note: this should be handled in the ComputePijKernel now - // if (matrix[TID] == 1.0f) matrix[TID] = 0.0f; - indices[TID] = (int) long_indices[TID]; - return; -} -thrust::device_vector search_perplexity(cublasHandle_t &handle, - thrust::device_vector &knn_distances, - const float perplexity_target, - const float eps, - const unsigned int N, - const unsigned int K) +static double drnd() { - // use beta instead of sigma (this matches the bhtsne code but not the paper) - // beta is just multiplicative instead of divisive (changes the way binary search works) - thrust::device_vector betas(N, 1.0f); - thrust::device_vector lbs(N, 0.0f); - thrust::device_vector ubs(N, 1000.0f); - thrust::device_vector pij(N*K); - thrust::device_vector entropy(N*K); - thrust::device_vector found(N); - - const unsigned int BLOCKSIZE1 = 1024; - const unsigned int NBLOCKS1 = iDivUp(N * K, BLOCKSIZE1); - - const unsigned int BLOCKSIZE2 = 128; - const unsigned int NBLOCKS2 = iDivUp(N, BLOCKSIZE2); - - int iters = 0; - int all_found = 0; - thrust::device_vector row_sum; - do { - // compute Gaussian Kernel row - ComputePijKernel<<>>(N, K, thrust::raw_pointer_cast(pij.data()), - thrust::raw_pointer_cast(knn_distances.data()), - thrust::raw_pointer_cast(betas.data())); - GpuErrorCheck(cudaDeviceSynchronize()); - - // compute entropy of current row - row_sum = tsnecuda::util::ReduceSum(handle, pij, K, N, 0); - thrust::transform(pij.begin(), pij.end(), entropy.begin(), tsnecuda::util::FunctionalEntropy()); - auto neg_entropy = tsnecuda::util::ReduceAlpha(handle, entropy, K, N, -1.0f, 0); - - // binary search for beta - PerplexitySearchKernel<<>>(N, perplexity_target, eps, - thrust::raw_pointer_cast(betas.data()), - thrust::raw_pointer_cast(lbs.data()), - thrust::raw_pointer_cast(ubs.data()), - thrust::raw_pointer_cast(found.data()), - thrust::raw_pointer_cast(neg_entropy.data()), - thrust::raw_pointer_cast(row_sum.data())); - GpuErrorCheck(cudaDeviceSynchronize()); - all_found = thrust::reduce(found.begin(), found.end(), 1, thrust::minimum()); - iters++; - } while (!all_found && iters < 200); - // TODO: Warn if iters == 200 because perplexity not found? - - tsnecuda::util::BroadcastMatrixVector(pij, row_sum, K, N, thrust::divides(), 1, 1.0f); - return pij; + const int lastrand = randx; + randx = (1103515245 * randx + 12345) & 0x7FFFFFFF; + return (double)lastrand / 2147483648.0; } -void BHTSNE::tsne(cublasHandle_t &dense_handle, cusparseHandle_t &sparse_handle, BHTSNE::Options &opt) { - - // Check the validity of the options file - if (!opt.validate()) { - std::cout << "E: Invalid options file. Terminating." 
-void BHTSNE::tsne(cublasHandle_t &dense_handle, cusparseHandle_t &sparse_handle, BHTSNE::Options &opt) {
-
-    // Check the validity of the options file
-    if (!opt.validate()) {
-        std::cout << "E: Invalid options file. Terminating." << std::endl;
-        return;
-    }
-
-    // Setup some return information if we're working on snapshots
-    int snap_interval;
-    int snap_num = 0;
-    if (opt.return_style == BHTSNE::RETURN_STYLE::SNAPSHOT) {
-        snap_interval = opt.iterations / (opt.num_snapshots-1);
-    }
+/******************************************************************************/
-
-    // Setup clock information
-    auto end_time = std::chrono::high_resolution_clock::now();
-    auto start_time = std::chrono::high_resolution_clock::now();
-    int times[30]; for (int i = 0; i < 30; i++) times[i] = 0;
-
-    // Allocate Memory for the KNN problem and do some configuration
-    start_time = std::chrono::high_resolution_clock::now();
-
-    // Setup CUDA configs
-    cudaFuncSetCacheConfig(BoundingBoxKernel, cudaFuncCachePreferShared);
-    cudaFuncSetCacheConfig(TreeBuildingKernel, cudaFuncCachePreferL1);
-    cudaFuncSetCacheConfig(ClearKernel1, cudaFuncCachePreferL1);
-    cudaFuncSetCacheConfig(ClearKernel2, cudaFuncCachePreferL1);
-    cudaFuncSetCacheConfig(SummarizationKernel, cudaFuncCachePreferShared);
-    cudaFuncSetCacheConfig(SortKernel, cudaFuncCachePreferL1);
-    #ifdef __KEPLER__
-    cudaFuncSetCacheConfig(ForceCalculationKernel, cudaFuncCachePreferEqual);
-    #else
-    cudaFuncSetCacheConfig(ForceCalculationKernel, cudaFuncCachePreferL1);
-    #endif
-    cudaFuncSetCacheConfig(IntegrationKernel, cudaFuncCachePreferL1);
-    cudaFuncSetCacheConfig(ComputePijxQijKernel, cudaFuncCachePreferShared);
-
-    // Allocate some memory
-    const unsigned int K = opt.n_neighbors < opt.n_points ? opt.n_neighbors : opt.n_points - 1;
-    float *knn_distances = new float[opt.n_points*K];
-    memset(knn_distances, 0, opt.n_points * K * sizeof(float));
-    long *knn_indices = new long[opt.n_points*K]; // Allocate memory for the indices on the CPU
-
-    end_time = std::chrono::high_resolution_clock::now();
-    times[0] = std::chrono::duration_cast<std::chrono::microseconds>(end_time-start_time).count();
-
-    // Compute the KNNs and distances
-    start_time = std::chrono::high_resolution_clock::now();
-
-    // Do KNN Call
-    tsnecuda::util::KNearestNeighbors(knn_indices, knn_distances, opt.points, opt.n_dims, opt.n_points, K);
-
-    end_time = std::chrono::high_resolution_clock::now();
-    times[1] = std::chrono::duration_cast<std::chrono::microseconds>(end_time-start_time).count();
-
-    // Copy the distances to the GPU and compute Pij
-    start_time = std::chrono::high_resolution_clock::now();
-
-    // Allocate device distance memory
-    thrust::device_vector<float> d_knn_distances(knn_distances, knn_distances + (opt.n_points * K));
-    tsnecuda::util::MaxNormalizeDeviceVector(d_knn_distances); // Here, the extra 0s floating around won't matter
-    thrust::device_vector<float> d_pij = search_perplexity(dense_handle, d_knn_distances, opt.perplexity, opt.perplexity_search_epsilon, opt.n_points, K);
-
-    // Clean up distance memory
-    d_knn_distances.clear();
-    d_knn_distances.shrink_to_fit();
-
-    // Copy the distances back to the GPU
-    thrust::device_vector<long> d_knn_indices_long(knn_indices, knn_indices + opt.n_points*K);
-    thrust::device_vector<int> d_knn_indices(opt.n_points*K);
-
-    // Post-process the floating point matrix
-    const int NBLOCKS_PP = iDivUp(opt.n_points*K, 128);
-    postprocess_matrix<<< NBLOCKS_PP, 128 >>>(thrust::raw_pointer_cast(d_pij.data()),
-                        thrust::raw_pointer_cast(d_knn_indices_long.data()),
-                        thrust::raw_pointer_cast(d_knn_indices.data()), opt.n_points, K);
-    cudaDeviceSynchronize();
-
-    // Clean up extra memory
-    d_knn_indices_long.clear();
-    d_knn_indices_long.shrink_to_fit();
-    delete[] knn_distances;
-    delete[] knn_indices;
-
-    end_time = std::chrono::high_resolution_clock::now();
-    times[2] = std::chrono::duration_cast<std::chrono::microseconds>(end_time-start_time).count();
-
-    // Symmetrize the Pij matrix
-    start_time = std::chrono::high_resolution_clock::now();
-
-    // Construct sparse matrix descriptor
-    cusparseMatDescr_t descr;
-    cusparseCreateMatDescr(&descr);
-    cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL);
-    cusparseSetMatIndexBase(descr,CUSPARSE_INDEX_BASE_ZERO);
-
-    // Compute the symmetrized matrix
-    thrust::device_vector<float> sparsePij; // Device
-    thrust::device_vector<int> pijRowPtr; // Device
-    thrust::device_vector<int> pijColInd; // Device
-    tsnecuda::util::SymmetrizeMatrix(sparse_handle,
-            d_pij, d_knn_indices, sparsePij, pijColInd, pijRowPtr, opt.n_points, K, opt.magnitude_factor);
-
-    // Clear some old memory
-    d_knn_indices.clear();
-    d_knn_indices.shrink_to_fit();
-    d_pij.clear();
-    d_pij.shrink_to_fit();
-
-    end_time = std::chrono::high_resolution_clock::now();
-    times[3] = std::chrono::duration_cast<std::chrono::microseconds>(end_time-start_time).count();
-
-
-    // Do setup for Barnes-Hut computation
-    start_time = std::chrono::high_resolution_clock::now();
-
-    // Compute the CUDA device properties
-    cudaDeviceProp deviceProp;
-    cudaGetDeviceProperties(&deviceProp, 0);
-    if (deviceProp.warpSize != WARPSIZE) {
-        fprintf(stderr, "Warp size must be %d\n", deviceProp.warpSize);
-        exit(-1);
-    }
-    int blocks = deviceProp.multiProcessorCount;
-    std::cout << "Multiprocessor Count: " << blocks << std::endl;
-    std::cout << "GPU Architecture: " << GPU_ARCH << std::endl;
-
-    // Figure out the number of nodes needed for the BH tree
-    int nnodes = opt.n_points * 2;
-    if (nnodes < 1024*blocks) nnodes = 1024*blocks;
-    while ((nnodes & (WARPSIZE-1)) != 0) nnodes++;
-    nnodes--;
-
-    opt.n_nodes = nnodes;
-
-    std::cout << "Number of nodes chosen: " << nnodes << std::endl;
-
-    int attr_forces_block_size;
-    int attr_forces_min_grid_size;
-    int attr_forces_grid_size;
-    cudaOccupancyMaxPotentialBlockSize( &attr_forces_min_grid_size, &attr_forces_block_size, ComputePijxQijKernel, 0, 0);
-    attr_forces_grid_size = (sparsePij.size() + attr_forces_block_size - 1) / attr_forces_block_size;
-    std::cout << "Autotuned attractive force kernel - Grid size: " << attr_forces_grid_size << " Block Size: " << attr_forces_block_size << std::endl;
-
-
-    // Allocate memory for the barnes hut implementations
-    thrust::device_vector<float> forceProd(sparsePij.size());
-    thrust::device_vector<float> rep_forces((nnodes + 1) * 2, 0);
-    thrust::device_vector<float> attr_forces(opt.n_points * 2, 0);
-    thrust::device_vector<float> gains(opt.n_points * 2, 1);
-    thrust::device_vector<float> old_forces(opt.n_points * 2, 0); // for momentum
-    thrust::device_vector<int> errl(1);
-    thrust::device_vector<int> startl(nnodes + 1);
-    thrust::device_vector<int> childl((nnodes + 1) * 4);
-    thrust::device_vector<float> massl(nnodes + 1, 1.0); // TODO: probably don't need massl
-    thrust::device_vector<int> countl(nnodes + 1);
-    thrust::device_vector<int> sortl(nnodes + 1);
-    thrust::device_vector<float> norml(nnodes + 1);
-    thrust::device_vector<float> maxxl(blocks * FACTOR1);
-    thrust::device_vector<float> maxyl(blocks * FACTOR1);
-    thrust::device_vector<float> minxl(blocks * FACTOR1);
-    thrust::device_vector<float> minyl(blocks * FACTOR1);
-    thrust::device_vector<float> ones(opt.n_points * 2, 1); // This is for reduce summing, etc.
-    thrust::device_vector<int> indices(sparsePij.size()*2);
-
-    // Compute the indices setup
-    const int SBS = 1024;
-    const int NBS = iDivUp(sparsePij.size(), SBS);
-    csr2coo<<<NBS, SBS>>>(opt.n_points, sparsePij.size(),
-                          thrust::raw_pointer_cast(pijRowPtr.data()),
-                          thrust::raw_pointer_cast(pijColInd.data()),
-                          thrust::raw_pointer_cast(indices.data()));
-    GpuErrorCheck(cudaDeviceSynchronize());
-
-    // Point initialization
-    thrust::device_vector<float> pts((nnodes + 1) * 2);
-    thrust::device_vector<float> random_vec(pts.size());
-
-    if (opt.initialization == BHTSNE::TSNE_INIT::UNIFORM) { // Random uniform initialization
-        pts = tsnecuda::util::RandomDeviceVectorInRange((nnodes+1)*2, -100, 100);
-    } else if (opt.initialization == BHTSNE::TSNE_INIT::GAUSSIAN) { // Random gaussian initialization
-        std::default_random_engine generator;
-        std::normal_distribution<float> distribution1(0.0, 1.0);
-        thrust::host_vector<float> h_pts(opt.n_points);
-        for (int i = 0; i < opt.n_points; i++)
-            h_pts[i] = 0.0001 * distribution1(generator);
-        thrust::copy(h_pts.begin(), h_pts.end(), pts.begin());
-        for (int i = 0; i < opt.n_points; i++)
-            h_pts[i] = 0.0001 * distribution1(generator);
-        thrust::copy(h_pts.begin(), h_pts.end(), pts.begin()+nnodes+1);
-    } else if (opt.initialization == BHTSNE::TSNE_INIT::RESUME) { // Preinit from vector
-        // Load from vector
-        if(opt.preinit_data != nullptr) {
-            thrust::copy(opt.preinit_data, opt.preinit_data+(nnodes+1)*2, pts.begin());
-        }
-        else {
-            std::cout << "E: Invalid initialization. Initialization points are null." << std::endl;
-        }
-    } else if (opt.initialization == BHTSNE::TSNE_INIT::VECTOR) { // Preinit from vector points only
-        // Copy the pre-init data
-        if(opt.preinit_data != nullptr) {
-            thrust::copy(opt.preinit_data, opt.preinit_data+opt.n_points, pts.begin());
-            thrust::copy(opt.preinit_data+opt.n_points+1, opt.preinit_data+opt.n_points*2 , pts.begin()+(nnodes+1));
-            tsnecuda::util::GaussianNormalizeDeviceVector(dense_handle, pts, (nnodes+1), 2);
-        }
-        else {
-            std::cout << "E: Invalid initialization. Initialization points are null." << std::endl;
-        }
-    } else { // Invalid initialization
-        std::cout << "E: Invalid initialization type specified." << std::endl;
-        exit(1);
-    }
+int main(int argc, char* argv[])
+{
+  int i, run, blocks;
+  int nnodes, nbodies, step, timesteps;
+  double runtime;
+  float dtime, dthf, epssq, itolsq;
+  float time, timing[7];
+  cudaEvent_t start, stop;
+
+  float4 *accVel;
+  float2 *vel;
+  int *sortl, *childl, *countl, *startl;
+  float4 *accVell;
+  float2 *vell;
+  float3 *maxl, *minl;
+  float4 *posMassl;
+  float4 *posMass;
+  double rsc, vsc, r, v, x, y, z, sq, scale;
+
+  // perform some checks
+
+  printf("ECL-BH v4.5\n");
+  printf("Copyright (c) 2010-2020 Texas State University\n");
+  fflush(stdout);
+
+  if (argc != 4) {
+    fprintf(stderr, "\n");
+    fprintf(stderr, "arguments: number_of_bodies number_of_timesteps device\n");
+    exit(-1);
+  }
-
-    // Initialize the learning rates and momentums
-    float eta = opt.learning_rate;
-    float momentum = opt.pre_exaggeration_momentum;
-    float norm;
-
-    // These variables currently govern the tolerance (whether it recurses on a cell)
-    float theta = opt.theta;
-    float epssq = opt.epssq;
-
-    // Initialize the GPU tree memory
-    InitializationKernel<<<1, 1>>>(thrust::raw_pointer_cast(errl.data()));
-    GpuErrorCheck(cudaDeviceSynchronize());
-
-    end_time = std::chrono::high_resolution_clock::now();
-    times[4] = std::chrono::duration_cast<std::chrono::microseconds>(end_time-start_time).count();
-
-    // Dump file
-    float *host_ys = nullptr;
-    std::ofstream dump_file;
-    if (opt.get_dump_points()) {
-        dump_file.open(opt.get_dump_file());
-        host_ys = new float[(nnodes + 1) * 2];
-        dump_file << opt.n_points << " " << 2 << std::endl;
-    }
+  int deviceCount;
+  cudaGetDeviceCount(&deviceCount);
+  if (deviceCount == 0) {
+    fprintf(stderr, "There is no device supporting CUDA\n");
+    exit(-1);
+  }
-
-    #ifndef NO_ZMQ
-
-    bool send_zmq = opt.get_use_interactive();
-    zmq::context_t context(1);
-    zmq::socket_t publisher(context, ZMQ_REQ);
-    if (opt.get_use_interactive()) {
-
-        // Try to connect to the socket
-        if (opt.verbosity >= 1)
-            std::cout << "Initializing Connection...." << std::endl;
-        publisher.setsockopt(ZMQ_RCVTIMEO, opt.get_viz_timeout());
-        publisher.setsockopt(ZMQ_SNDTIMEO, opt.get_viz_timeout());
-        if (opt.verbosity >= 1)
-            std::cout << "Waiting for connection to visualization for 10 secs...." << std::endl;
-        publisher.connect(opt.get_viz_server());
-
-        // Send the number of points we should be expecting to the server
-        std::string message = std::to_string(opt.n_points);
-        send_zmq = publisher.send(message.c_str(), message.length());
-
-        // Wait for server reply
-        zmq::message_t request;
-        send_zmq = publisher.recv (&request);
-
-        // If there's a time-out, don't bother.
-        if (send_zmq) {
-            if (opt.verbosity >= 1)
-                std::cout << "Visualization connected!" << std::endl;
-        } else {
-            std::cout << "No Visualization Terminal, continuing..." << std::endl;
-            send_zmq = false;
-        }
-    }
-    #endif
-
-    #ifdef NO_ZMQ
-    if (opt.get_use_interactive())
-        std::cout << "This version is not built with ZMQ for interative viz. Rebuild with WITH_ZMQ=TRUE for viz." << std::endl;
-    #endif
-
-    // Support for infinite iteration
-    float attr_exaggeration = opt.early_exaggeration;
-
-    // Random noise handling
-    std::default_random_engine generator;
-    std::normal_distribution<float> distribution1(0.0, 1.0);
-    thrust::host_vector<float> h_pts(opt.n_points*2);
-    thrust::device_vector<float> rand_noise(opt.n_points*2);
-
-    for (int step = 0; step != opt.iterations; step++) {
-
-        // Setup learning rate schedule
-        if (step == opt.force_magnify_iters) {
-            momentum = opt.post_exaggeration_momentum;
-            attr_exaggeration = 1.0f;
-        }
+  const int dev = atoi(argv[3]);
+  if ((dev < 0) || (deviceCount <= dev)) {
+    fprintf(stderr, "There is no device %d\n", dev);
+    exit(-1);
+  }
+  cudaSetDevice(dev);
-
-        // Do Force Reset
-        start_time = std::chrono::high_resolution_clock::now();
-
-        thrust::fill(attr_forces.begin(), attr_forces.end(), 0);
-        thrust::fill(rep_forces.begin(), rep_forces.end(), 0);
-
-        end_time = std::chrono::high_resolution_clock::now();
-        times[5] += std::chrono::duration_cast<std::chrono::microseconds>(end_time-start_time).count();
-
-
-        // Bounding box kernel
-        start_time = std::chrono::high_resolution_clock::now();
-
-        BoundingBoxKernel<<<blocks * FACTOR1, THREADS1>>>(nnodes,
-                                                          opt.n_points,
-                                                          thrust::raw_pointer_cast(startl.data()),
-                                                          thrust::raw_pointer_cast(childl.data()),
-                                                          thrust::raw_pointer_cast(massl.data()),
-                                                          thrust::raw_pointer_cast(pts.data()),
-                                                          thrust::raw_pointer_cast(pts.data() + nnodes + 1),
-                                                          thrust::raw_pointer_cast(maxxl.data()),
-                                                          thrust::raw_pointer_cast(maxyl.data()),
-                                                          thrust::raw_pointer_cast(minxl.data()),
-                                                          thrust::raw_pointer_cast(minyl.data()));
-
-        GpuErrorCheck(cudaDeviceSynchronize());
-
-        end_time = std::chrono::high_resolution_clock::now();
-        times[6] += std::chrono::duration_cast<std::chrono::microseconds>(end_time-start_time).count();
-
-        // Tree Building
-        start_time = std::chrono::high_resolution_clock::now();
-
-        ClearKernel1<<<blocks * 1, 1024>>>(nnodes, opt.n_points, thrust::raw_pointer_cast(childl.data()));
-        TreeBuildingKernel<<<blocks * FACTOR2, THREADS2>>>(nnodes, opt.n_points, thrust::raw_pointer_cast(errl.data()),
-                                                           thrust::raw_pointer_cast(childl.data()),
-                                                           thrust::raw_pointer_cast(pts.data()),
-                                                           thrust::raw_pointer_cast(pts.data() + nnodes + 1));
-        ClearKernel2<<<blocks * 1, 1024>>>(nnodes, thrust::raw_pointer_cast(startl.data()), thrust::raw_pointer_cast(massl.data()));
-        GpuErrorCheck(cudaDeviceSynchronize());
-
-        end_time = std::chrono::high_resolution_clock::now();
-        times[7] += std::chrono::duration_cast<std::chrono::microseconds>(end_time-start_time).count();
-
-        // Tree Summarization
-        start_time = std::chrono::high_resolution_clock::now();
-
-        SummarizationKernel<<<blocks * FACTOR3, THREADS3>>>(nnodes, opt.n_points, thrust::raw_pointer_cast(countl.data()),
-                                                            thrust::raw_pointer_cast(childl.data()),
-                                                            thrust::raw_pointer_cast(massl.data()),
-                                                            thrust::raw_pointer_cast(pts.data()),
-                                                            thrust::raw_pointer_cast(pts.data() + nnodes + 1));
-        GpuErrorCheck(cudaDeviceSynchronize());
-
-        end_time = std::chrono::high_resolution_clock::now();
-        times[8] += std::chrono::duration_cast<std::chrono::microseconds>(end_time-start_time).count();
-
-        // Force sorting
-        start_time = std::chrono::high_resolution_clock::now();
-
-        SortKernel<<<blocks * FACTOR4, THREADS4>>>(nnodes, opt.n_points, thrust::raw_pointer_cast(sortl.data()),
-                                                   thrust::raw_pointer_cast(countl.data()),
-                                                   thrust::raw_pointer_cast(startl.data()),
-                                                   thrust::raw_pointer_cast(childl.data()));
-        GpuErrorCheck(cudaDeviceSynchronize());
-
-        end_time = std::chrono::high_resolution_clock::now();
-        times[9] += std::chrono::duration_cast<std::chrono::microseconds>(end_time-start_time).count();
-
-        // Repulsive force calculation
-        start_time = std::chrono::high_resolution_clock::now();
-
-        ForceCalculationKernel<<<blocks * FACTOR5, THREADS5>>>(nnodes, opt.n_points, thrust::raw_pointer_cast(errl.data()),
-                                                               theta, epssq,
-                                                               thrust::raw_pointer_cast(sortl.data()),
-                                                               thrust::raw_pointer_cast(childl.data()),
-                                                               thrust::raw_pointer_cast(massl.data()),
-                                                               thrust::raw_pointer_cast(pts.data()),
-                                                               thrust::raw_pointer_cast(pts.data() + nnodes + 1),
-                                                               thrust::raw_pointer_cast(rep_forces.data()),
-                                                               thrust::raw_pointer_cast(rep_forces.data() + nnodes + 1),
-                                                               thrust::raw_pointer_cast(norml.data()));
-        GpuErrorCheck(cudaDeviceSynchronize());
-
-        end_time = std::chrono::high_resolution_clock::now();
-        times[10] += std::chrono::duration_cast<std::chrono::microseconds>(end_time-start_time).count();
-
-        // Attractive Force Computation
-        start_time = std::chrono::high_resolution_clock::now();
-
-        // compute attractive forces
-        computeAttrForce(opt.n_points, sparsePij.size(), nnodes, attr_forces_grid_size, attr_forces_block_size, sparse_handle, descr, sparsePij, pijRowPtr, pijColInd, forceProd, pts, attr_forces, ones, indices);
-        GpuErrorCheck(cudaDeviceSynchronize());
-
-        end_time = std::chrono::high_resolution_clock::now();
-        times[11] += std::chrono::duration_cast<std::chrono::microseconds>(end_time-start_time).count();
-
-
-        // Move the particles
-        start_time = std::chrono::high_resolution_clock::now();
-
-        // Compute the normalization constant
-        norm = thrust::reduce(norml.begin(), norml.end(), 0.0f, thrust::plus<float>());
-
-        // Integrate
-        IntegrationKernel<<<blocks * FACTOR6, THREADS6>>>(opt.n_points, nnodes, eta, norm, momentum, attr_exaggeration,
-                                                          thrust::raw_pointer_cast(pts.data()),
-                                                          thrust::raw_pointer_cast(attr_forces.data()),
-                                                          thrust::raw_pointer_cast(rep_forces.data()),
-                                                          thrust::raw_pointer_cast(gains.data()),
-                                                          thrust::raw_pointer_cast(old_forces.data()));
-        for (int i = 0; i < opt.n_points*2; i++)
-            h_pts[i] = 0.001 * distribution1(generator);
-        GpuErrorCheck(cudaDeviceSynchronize());
-        thrust::copy(h_pts.begin(), h_pts.end(), rand_noise.begin());
-
-        // Compute the gradient norm
-        tsnecuda::util::SquareDeviceVector(attr_forces, old_forces);
-        thrust::transform(attr_forces.begin(), attr_forces.begin()+opt.n_points,
-                          attr_forces.begin()+opt.n_points, attr_forces.begin(), thrust::plus<float>());
-        tsnecuda::util::SqrtDeviceVector(attr_forces, attr_forces);
-        float grad_norm = thrust::reduce(attr_forces.begin(), attr_forces.begin()+opt.n_points, 0.0f, thrust::plus<float>()) / opt.n_points;
-
-        if (opt.verbosity >= 1 && step % opt.print_interval == 0)
-            std::cout << "[Step " << step << "] Average Gradient Norm: " << grad_norm << std::endl;
-
-        // Add some random noise to the points
-        thrust::transform(pts.begin(), pts.begin()+opt.n_points, rand_noise.begin(), pts.begin(), thrust::plus<float>());
-        thrust::transform(pts.begin()+nnodes+1, pts.begin()+nnodes+1+opt.n_points, rand_noise.begin()+opt.n_points, pts.begin()+nnodes+1, thrust::plus<float>());
-
-        end_time = std::chrono::high_resolution_clock::now();
-        times[12] += std::chrono::duration_cast<std::chrono::microseconds>(end_time-start_time).count();
-
-        #ifndef NO_ZMQ
-        if (send_zmq) {
-            zmq::message_t message(sizeof(float)*opt.n_points*2);
-            thrust::copy(pts.begin(), pts.begin()+opt.n_points, static_cast<float*>(message.data()));
-            thrust::copy(pts.begin()+nnodes+1, pts.begin()+nnodes+1+opt.n_points, static_cast<float*>(message.data())+opt.n_points);
-            bool res = false;
-            res = publisher.send(message);
-            zmq::message_t request;
-            res = publisher.recv(&request);
-            if (!res) {
-                std::cout << "Server Disconnected, Not sending anymore for this session." << std::endl;
-            }
-            send_zmq = res;
-        }
-        #endif
+  cudaDeviceProp deviceProp;
+  cudaGetDeviceProperties(&deviceProp, dev);
+  if ((deviceProp.major == 9999) && (deviceProp.minor == 9999)) {
+    fprintf(stderr, "There is no CUDA capable device\n");
+    exit(-1);
+  }
+  if (deviceProp.major < 3) {
+    fprintf(stderr, "Need at least compute capability 3.0\n");
+    exit(-1);
+  }
+  if (deviceProp.warpSize != WARPSIZE) {
+    fprintf(stderr, "Warp size must be %d\n", deviceProp.warpSize);
+    exit(-1);
+  }
-
-        if (opt.get_dump_points() && step % opt.get_dump_interval() == 0) {
-            thrust::copy(pts.begin(), pts.end(), host_ys);
-            for (int i = 0; i < opt.n_points; i++) {
-                dump_file << host_ys[i] << " " << host_ys[i + nnodes + 1] << std::endl;
-            }
-        }
+  blocks = deviceProp.multiProcessorCount;
+  const int mTSM = deviceProp.maxThreadsPerMultiProcessor;
+  printf("gpu: %s with %d SMs and %d mTpSM (%.1f MHz and %.1f MHz)\n", deviceProp.name, blocks, mTSM, deviceProp.clockRate * 0.001, deviceProp.memoryClockRate * 0.001);
-
-        // Handle snapshoting
-        if (opt.return_style == BHTSNE::RETURN_STYLE::SNAPSHOT && step % snap_interval == 0 && opt.return_data != nullptr) {
-            thrust::copy(pts.begin(),
-                         pts.begin()+opt.n_points,
-                         snap_num*opt.n_points*2 + opt.return_data);
-            thrust::copy(pts.begin()+nnodes+1,
-                         pts.begin()+nnodes+1+opt.n_points,
-                         snap_num*opt.n_points*2 + opt.return_data+opt.n_points);
-            snap_num += 1;
-        }
+  if ((WARPSIZE <= 0) || (WARPSIZE & (WARPSIZE-1) != 0)) {
+    fprintf(stderr, "Warp size must be greater than zero and a power of two\n");
+    exit(-1);
+  }
+  if (MAXDEPTH > WARPSIZE) {
+    fprintf(stderr, "MAXDEPTH must be less than or equal to WARPSIZE\n");
+    exit(-1);
+  }
+  if ((THREADS1 <= 0) || (THREADS1 & (THREADS1-1) != 0)) {
+    fprintf(stderr, "THREADS1 must be greater than zero and a power of two\n");
+    exit(-1);
+  }
+
+  // set L1/shared memory configuration
+  cudaFuncSetCacheConfig(BoundingBoxKernel, cudaFuncCachePreferShared);
+  cudaFuncSetCacheConfig(TreeBuildingKernel, cudaFuncCachePreferL1);
+  cudaFuncSetCacheConfig(ClearKernel1, cudaFuncCachePreferL1);
+  cudaFuncSetCacheConfig(ClearKernel2, cudaFuncCachePreferL1);
+  cudaFuncSetCacheConfig(SummarizationKernel, cudaFuncCachePreferShared);
+  cudaFuncSetCacheConfig(SortKernel, cudaFuncCachePreferL1);
+  cudaFuncSetCacheConfig(ForceCalculationKernel, cudaFuncCachePreferEqual);
+  cudaFuncSetCacheConfig(IntegrationKernel, cudaFuncCachePreferL1);
+
+  cudaGetLastError();  // reset error value
+  for (run = 0; run < 1; run++) {  // in case multiple runs are desired for timing
+    for (i = 0; i < 7; i++) timing[i] = 0.0f;
+
+    nbodies = atoi(argv[1]);
+    if (nbodies < 1) {
+      fprintf(stderr, "nbodies is too small: %d\n", nbodies);
+      exit(-1);
+    }
-
-    // Clean up the dump file if we are dumping points
-    if (opt.get_dump_points()){
-        delete[] host_ys;
-        dump_file.close();
+    if (nbodies > (1 << 30)) {
+      fprintf(stderr, "nbodies is too large: %d\n", nbodies);
+      exit(-1);
    }
-
-    // With verbosity 2, print the timing data
-    if (opt.verbosity >= 2) {
-        int p1_time = times[0] + times[1] + times[2] + times[3];
-        int p2_time = times[4] + times[5] + times[6] + times[7] + times[8] + times[9] + times[10] + times[11] + times[12];
-        std::cout << "Timing data: " << std::endl;
-        std::cout << "\t Phase 1 (" << p1_time << "us):" << std::endl;
-        std::cout << "\t\tKernel Setup: " << times[0] << "us" << std::endl;
-        std::cout << "\t\tKNN Computation: " << times[1] << "us" << std::endl;
-        std::cout << "\t\tPIJ Computation: " << times[2] << "us" << std::endl;
-        std::cout << "\t\tPIJ Symmetrization: " << times[3] << "us" << std::endl;
-        std::cout << "\t Phase 2 (" << p2_time << "us):" << std::endl;
-        std::cout << "\t\tKernel Setup: " << times[4] << "us" << std::endl;
-        std::cout << "\t\tForce Reset: " << times[5] << "us" << std::endl;
-        std::cout << "\t\tBounding Box: " << times[6] << "us" << std::endl;
-        std::cout << "\t\tTree Building: " << times[7] << "us" << std::endl;
-        std::cout << "\t\tTree Summarization: " << times[8] << "us" << std::endl;
-        std::cout << "\t\tSorting: " << times[9] << "us" << std::endl;
-        std::cout << "\t\tRepulsive Force Calculation: " << times[10] << "us" << std::endl;
-        std::cout << "\t\tAttractive Force Calculation: " << times[11] << "us" << std::endl;
-        std::cout << "\t\tIntegration: " << times[12] << "us" << std::endl;
-        std::cout << "Total Time: " << p1_time + p2_time << "us" << std::endl << std::endl;
+    nnodes = nbodies * 2;
+    if (nnodes < 1024*blocks) nnodes = 1024*blocks;
+    while ((nnodes & (WARPSIZE-1)) != 0) nnodes++;
+    nnodes--;
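For reference, the node-count bound just computed can be written as (a reading of the code above, assuming WARPSIZE W is the power of two checked earlier):

    n_{\mathrm{nodes}} = \left\lceil \frac{\max(2\,n_{\mathrm{bodies}},\; 1024\,n_{\mathrm{SM}})}{W} \right\rceil W \;-\; 1

The tree arrays are then allocated with nnodes + 1 entries, so nnodes is the highest valid cell index, and the warp-multiple sizing keeps the arrays aligned for the warp-centric kernels.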
+
+    timesteps = atoi(argv[2]);
+    dtime = 0.025;  dthf = dtime * 0.5f;
+    epssq = 0.05 * 0.05;
+    itolsq = 1.0f / (0.5 * 0.5);
+
+    // allocate memory
+
+    if (run == 0) {
+      printf("configuration: %d bodies, %d time steps\n", nbodies, timesteps);
+
+      accVel = (float4*)malloc(sizeof(float4) * nbodies);
+      if (accVel == NULL) {fprintf(stderr, "cannot allocate accVel\n");  exit(-1);}
+      vel = (float2*)malloc(sizeof(float2) * nbodies);
+      if (vel == NULL) {fprintf(stderr, "cannot allocate vel\n");  exit(-1);}
+      posMass = (float4*)malloc(sizeof(float4) * nbodies);
+      if (posMass == NULL) {fprintf(stderr, "cannot allocate posMass\n");  exit(-1);}
+
+      if (cudaSuccess != cudaMalloc((void **)&childl, sizeof(int) * (nnodes+1) * 8)) fprintf(stderr, "could not allocate childd\n");  CudaTest("couldn't allocate childd");
+      if (cudaSuccess != cudaMalloc((void **)&vell, sizeof(float2) * (nnodes+1))) fprintf(stderr, "could not allocate veld\n");  CudaTest("couldn't allocate veld");
+      if (cudaSuccess != cudaMalloc((void **)&accVell, sizeof(float4) * (nnodes+1))) fprintf(stderr, "could not allocate accVeld\n");  CudaTest("couldn't allocate accVeld");
+      if (cudaSuccess != cudaMalloc((void **)&countl, sizeof(int) * (nnodes+1))) fprintf(stderr, "could not allocate countd\n");  CudaTest("couldn't allocate countd");
+      if (cudaSuccess != cudaMalloc((void **)&startl, sizeof(int) * (nnodes+1))) fprintf(stderr, "could not allocate startd\n");  CudaTest("couldn't allocate startd");
+      if (cudaSuccess != cudaMalloc((void **)&sortl, sizeof(int) * (nnodes+1))) fprintf(stderr, "could not allocate sortd\n");  CudaTest("couldn't allocate sortd");
+
+      if (cudaSuccess != cudaMalloc((void **)&posMassl, sizeof(float4) * (nnodes+1))) fprintf(stderr, "could not allocate posMassd\n");  CudaTest("couldn't allocate posMassd");
+
+      if (cudaSuccess != cudaMalloc((void **)&maxl, sizeof(float3) * blocks * FACTOR1)) fprintf(stderr, "could not allocate maxd\n");  CudaTest("couldn't allocate maxd");
+      if (cudaSuccess != cudaMalloc((void **)&minl, sizeof(float3) * blocks * FACTOR1)) fprintf(stderr, "could not allocate mind\n");  CudaTest("couldn't allocate mind");
+    }
-
-    // std::cout << FACTOR1 << "," << FACTOR2 << "," << FACTOR3 << "," << FACTOR4 << "," << FACTOR5 << "," << FACTOR6 << std::endl;
-    if (opt.verbosity >= 1) std::cout << "Fin." << std::endl;
-
-    // Handle a once off return type
-    if (opt.return_style == BHTSNE::RETURN_STYLE::ONCE && opt.return_data != nullptr) {
-        thrust::copy(pts.begin(), pts.begin()+opt.n_points, opt.return_data);
-        thrust::copy(pts.begin()+nnodes+1, pts.begin()+nnodes+1+opt.n_points, opt.return_data+opt.n_points);
-    }
+
+    // generate input (based on SPLASH-2 code at https://github.com/staceyson/splash2/blob/master/codes/apps/barnes/code.C)
+
+    rsc = (3 * 3.1415926535897932384626433832795) / 16;
+    vsc = sqrt(1.0 / rsc);
+    for (i = 0; i < nbodies; i++) {
+      float4 p;
+      p.w = 1.0 / nbodies;
+      r = 1.0 / sqrt(pow(drnd()*0.999, -2.0/3.0) - 1);
+      do {
+        x = drnd()*2.0 - 1.0;
+        y = drnd()*2.0 - 1.0;
+        z = drnd()*2.0 - 1.0;
+        sq = x*x + y*y + z*z;
+      } while (sq > 1.0);
+      scale = rsc * r / sqrt(sq);
+      p.x = x * scale;
+      p.y = y * scale;
+      p.z = z * scale;
+      posMass[i] = p;
+
+      do {
+        x = drnd();
+        y = drnd() * 0.1;
+      } while (y > x*x * pow(1 - x*x, 3.5));
+      v = x * sqrt(2.0 / sqrt(1 + r*r));
+      do {
+        x = drnd()*2.0 - 1.0;
+        y = drnd()*2.0 - 1.0;
+        z = drnd()*2.0 - 1.0;
+        sq = x*x + y*y + z*z;
+      } while (sq > 1.0);
+      scale = vsc * v / sqrt(sq);
+      float2 v;
+      v.x = x * scale;
+      v.y = y * scale;
+      accVel[i].w = z * scale;
+      vel[i] = v;
+    }
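The loop above is the SPLASH-2 Plummer-model initializer. As a sketch of the math it implements (my reading of the code, not text from the patch): the radius inverts the Plummer cumulative mass fraction m(r) = r^3 (1 + r^2)^{-3/2} at a uniform deviate u (the 0.999 factor truncates the infinite tail), the speed is von Neumann rejection-sampled from g(x) proportional to x^2 (1 - x^2)^{7/2}, and both directions come from rejection sampling on the unit sphere (the do/while sq > 1.0 loops):

    r = \left( u^{-2/3} - 1 \right)^{-1/2}, \qquad
    v = x\,\sqrt{2}\,\left(1 + r^2\right)^{-1/4}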
-
-    // Handle snapshoting
-    if (opt.return_style == BHTSNE::RETURN_STYLE::SNAPSHOT && opt.return_data != nullptr) {
-        thrust::copy(pts.begin(), pts.begin()+opt.n_points, snap_num*opt.n_points*2 + opt.return_data);
-        thrust::copy(pts.begin()+nnodes+1, pts.begin()+nnodes+1+opt.n_points, snap_num*opt.n_points*2 + opt.return_data+opt.n_points);
-    }
+
+    if (cudaSuccess != cudaMemcpy(accVell, accVel, sizeof(float4) * nbodies, cudaMemcpyHostToDevice)) fprintf(stderr, "copying of vel to device failed\n");  CudaTest("vel copy to device failed");
+    if (cudaSuccess != cudaMemcpy(vell, vel, sizeof(float2) * nbodies, cudaMemcpyHostToDevice)) fprintf(stderr, "copying of vel to device failed\n");  CudaTest("vel copy to device failed");
+    if (cudaSuccess != cudaMemcpy(posMassl, posMass, sizeof(float4) * nbodies, cudaMemcpyHostToDevice)) fprintf(stderr, "copying of posMass to device failed\n");  CudaTest("posMass copy to device failed");
+
+    // run timesteps (launch GPU kernels)
+
+    cudaEventCreate(&start);  cudaEventCreate(&stop);
+    struct timeval starttime, endtime;
+    gettimeofday(&starttime, NULL);
+
+    cudaEventRecord(start, 0);
+    InitializationKernel<<<1, 1>>>();
+    cudaEventRecord(stop, 0);  cudaEventSynchronize(stop);  cudaEventElapsedTime(&time, start, stop);
+    timing[0] += time;
+    //CudaTest("kernel 0 launch failed");
+
+    for (step = 0; step < timesteps; step++) {
+      cudaEventRecord(start, 0);
+      BoundingBoxKernel<<<blocks * FACTOR1, THREADS1>>>(nnodes, nbodies, startl, childl, posMassl, maxl, minl);
+      cudaEventRecord(stop, 0);  cudaEventSynchronize(stop);  cudaEventElapsedTime(&time, start, stop);
+      timing[1] += time;
+      //CudaTest("kernel 1 launch failed");
+
+      cudaEventRecord(start, 0);
+      ClearKernel1<<<blocks * 1, 1024>>>(nnodes, nbodies, childl);
+      TreeBuildingKernel<<<blocks * FACTOR2, THREADS2>>>(nnodes, nbodies, childl, posMassl);
+      ClearKernel2<<<blocks * 1, 1024>>>(nnodes, startl, posMassl);
+      cudaEventRecord(stop, 0);  cudaEventSynchronize(stop);  cudaEventElapsedTime(&time, start, stop);
+      timing[2] += time;
+      //CudaTest("kernel 2 launch failed");
+
+      cudaEventRecord(start, 0);
+      SummarizationKernel<<<blocks * FACTOR3, THREADS3>>>(nnodes, nbodies, countl, childl, posMassl);
+      cudaEventRecord(stop, 0);  cudaEventSynchronize(stop);  cudaEventElapsedTime(&time, start, stop);
+      timing[3] += time;
+      //CudaTest("kernel 3 launch failed");
+
+      cudaEventRecord(start, 0);
+      SortKernel<<<blocks * FACTOR4, THREADS4>>>(nnodes, nbodies, sortl, countl, startl, childl);
+      cudaEventRecord(stop, 0);  cudaEventSynchronize(stop);  cudaEventElapsedTime(&time, start, stop);
+      timing[4] += time;
+      //CudaTest("kernel 4 launch failed");
+
+      cudaEventRecord(start, 0);
+      ForceCalculationKernel<<<blocks * FACTOR5, THREADS5>>>(nnodes, nbodies, dthf, itolsq, epssq, sortl, childl, posMassl, vell, accVell);
+      cudaEventRecord(stop, 0);  cudaEventSynchronize(stop);  cudaEventElapsedTime(&time, start, stop);
+      timing[5] += time;
+      //CudaTest("kernel 5 launch failed");
+
+      cudaEventRecord(start, 0);
+      IntegrationKernel<<<blocks * FACTOR6, THREADS6>>>(nbodies, dtime, dthf, posMassl, vell, accVell);
+      cudaEventRecord(stop, 0);  cudaEventSynchronize(stop);  cudaEventElapsedTime(&time, start, stop);
+      timing[6] += time;
+      //CudaTest("kernel 6 launch failed");
+    }
+    CudaTest("kernel launch failed");
+    cudaEventDestroy(start);  cudaEventDestroy(stop);
-
-    // Return some final values
-    opt.trained = true;
-    opt.trained_norm = norm;
+
+    gettimeofday(&endtime, NULL);
+    runtime = (endtime.tv_sec + endtime.tv_usec/1000000.0 - starttime.tv_sec - starttime.tv_usec/1000000.0);
-
-    return;
-}
+
+    printf("runtime: %.4lf s (", runtime);
+    time = 0;
+    for (i = 1; i < 7; i++) {
+      printf(" %.1f ", timing[i]);
+      time += timing[i];
+    }
+    printf(") = %.1f ms\n", time);
+  }
+
+  // transfer final result back to CPU
+  if (cudaSuccess != cudaMemcpy(accVel, accVell, sizeof(float4) * nbodies, cudaMemcpyDeviceToHost)) fprintf(stderr, "copying of accVel from device failed\n");  CudaTest("vel copy from device failed");
+  if (cudaSuccess != cudaMemcpy(vel, vell, sizeof(float2) * nbodies, cudaMemcpyDeviceToHost)) fprintf(stderr, "copying of vel from device failed\n");  CudaTest("vel copy from device failed");
+  if (cudaSuccess != cudaMemcpy(posMass, posMassl, sizeof(float4) * nbodies, cudaMemcpyDeviceToHost)) fprintf(stderr, "copying of posMass from device failed\n");  CudaTest("posMass copy from device failed");
+
+  // print output
+  i = 0;
+//  for (i = 0; i < nbodies; i++) {
+    printf("%.2e %.2e %.2e\n", posMass[i].x, posMass[i].y, posMass[i].z);
+//  }
+
+  free(vel);
+  free(accVel);
+  free(posMass);
+
+  cudaFree(childl);
+  cudaFree(vell);
+  cudaFree(accVell);
+  cudaFree(countl);
+  cudaFree(startl);
+  cudaFree(sortl);
+  cudaFree(posMassl);
+  cudaFree(maxl);
+  cudaFree(minl);
+
+  return 0;
+}
\ No newline at end of file
diff --git a/cpp/src/tsne/cannylabs_tsne_license.txt b/cpp/src/tsne/cannylabs_tsne_license.txt
index 8e5adfac73..1928cc07c3 100644
--- a/cpp/src/tsne/cannylabs_tsne_license.txt
+++ b/cpp/src/tsne/cannylabs_tsne_license.txt
@@ -1,60 +1,39 @@
 /*
-CUDA BarnesHut v3.1: Simulation of the gravitational forces
-in a galactic cluster using the Barnes-Hut n-body algorithm
-Copyright (c) 2013, Texas State University-San Marcos. All rights reserved.
-Redistribution and use in source and binary forms, with or without modification,
-are permitted for academic, research, experimental, or personal use provided that
-the following conditions are met:
- * Redistributions of source code must retain the above copyright notice,
-   this list of conditions and the following disclaimer.
- * Redistributions in binary form must reproduce the above copyright notice,
-   this list of conditions and the following disclaimer in the documentation
-   and/or other materials provided with the distribution.
- * Neither the name of Texas State University-San Marcos nor the names of its
-   contributors may be used to endorse or promote products derived from this
-   software without specific prior written permission.
-For all other uses, please contact the Office for Commercialization and Industry
-Relations at Texas State University-San Marcos .
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED
-IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
-INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
-BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
-LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
-OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
-OF THE POSSIBILITY OF SUCH DAMAGE.
-Author: Martin Burtscher
-*/
-
-
-BSD 3-Clause License
+ECL-BH v4.5: Simulation of the gravitational forces in a star cluster using
+the Barnes-Hut n-body algorithm.
-Copyright (c) 2018, Regents of the University of California
-All rights reserved.
+Copyright (c) 2010-2020 Texas State University. All rights reserved.
 
 Redistribution and use in source and binary forms, with or without modification,
 are permitted provided that the following conditions are met:
-* Redistributions of source code must retain the above copyright notice, this
-  list of conditions and the following disclaimer.
+ * Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+ * Neither the name of Texas State University nor the names of its
+   contributors may be used to endorse or promote products derived from
+   this software without specific prior written permission.
-
-* Redistributions in binary form must reproduce the above copyright notice,
-  this list of conditions and the following disclaimer in the documentation
-  and/or other materials provided with the distribution.
-
-* Neither the name of the copyright holder nor the names of its
-  contributors may be used to endorse or promote products derived from
-  this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL TEXAS STATE UNIVERSITY BE LIABLE FOR ANY
+DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+Authors: Martin Burtscher and Sahar Azimi
+
+URL: The latest version of this code is available at
+https://userweb.cs.txstate.edu/~burtscher/research/ECL-BH/.
+
+Publication: This work is described in detail in the following paper.
+Martin Burtscher and Keshav Pingali. An Efficient CUDA Implementation of the
+Tree-based Barnes Hut n-Body Algorithm. Chapter 6 in GPU Computing Gems
+Emerald Edition, pp. 75-92. January 2011.
+*/

From f5c85893fd5ed37a25b896c78ce418b87eb7ad61 Mon Sep 17 00:00:00 2001
From: viclafargue
Date: Tue, 12 Oct 2021 15:45:32 +0200
Subject: [PATCH 09/12] Requested changes

---
 cpp/include/cuml/manifold/tsne.h    | 42 ++++++++++++----------
 cpp/src/tsne/barnes_hut_kernels.cuh |  7 ++--
 cpp/src/tsne/barnes_hut_tsne.cuh    | 10 ++----
 cpp/src/tsne/distances.cuh          |  7 ++++
 cpp/src/tsne/exact_kernels.cuh      |  8 ++---
 cpp/src/tsne/exact_tsne.cuh         | 12 ++-----
 cpp/src/tsne/fft_kernels.cuh        |  5 ++-
 cpp/src/tsne/fft_tsne.cuh           |  5 +--
 cpp/src/tsne/tsne.cu                | 50 ++++++++++++++------------
 cpp/src/tsne/utils.cuh              | 44 +++++++++++++++++++----
 cpp/test/sg/tsne_test.cu            | 27 +++++++-------
 python/cuml/manifold/t_sne.pyx      | 56 +++++++++++++++--------------
 12 files changed, 153 insertions(+), 120 deletions(-)

diff --git a/cpp/include/cuml/manifold/tsne.h b/cpp/include/cuml/manifold/tsne.h
index f18c86ca28..4747fe934c 100644
--- a/cpp/include/cuml/manifold/tsne.h
+++ b/cpp/include/cuml/manifold/tsne.h
@@ -117,6 +117,7 @@ struct TSNEParams {
  * @param[in] knn_indices Array containing nearest neighors indices.
  * @param[in] knn_dists Array containing nearest neighors distances.
  * @param[in] params Parameters for TSNE model
+ * @param[in] kl_div (optional) KL divergence
  * @return The Kullback–Leibler divergence
  *
  * The CUDA implementation is derived from the excellent CannyLabs open source
@@ -126,14 +127,15 @@ struct TSNEParams {
  * approach is available in their article t-SNE-CUDA: GPU-Accelerated t-SNE and
  * its Applications to Modern Data (https://arxiv.org/abs/1807.11824).
  */
-float TSNE_fit(const raft::handle_t& handle,
-               float* X,
-               float* Y,
-               int n,
-               int p,
-               int64_t* knn_indices,
-               float* knn_dists,
-               TSNEParams& params);
+void TSNE_fit(const raft::handle_t& handle,
+              float* X,
+              float* Y,
+              int n,
+              int p,
+              int64_t* knn_indices,
+              float* knn_dists,
+              TSNEParams& params,
+              float* kl_div = nullptr);
 
 /**
  * @brief Dimensionality reduction via TSNE using either Barnes Hut O(NlogN)
@@ -150,6 +152,7 @@ void TSNE_fit(const raft::handle_t& handle,
  * @param[in] knn_indices Array containing nearest neighors indices.
  * @param[in] knn_dists Array containing nearest neighors distances.
  * @param[in] params Parameters for TSNE model
+ * @param[in] kl_div (optional) KL divergence
  * @return The Kullback–Leibler divergence
  *
  * The CUDA implementation is derived from the excellent CannyLabs open source
 * implementation here: https://github.com/CannyLab/tsne-cuda/. The CannyLabs
 * code is licensed according to the conditions in
 * cpp/src/tsne/cannylabs_tsne_license.txt. A full description of their
 * approach is available in their article t-SNE-CUDA: GPU-Accelerated t-SNE and
 * its Applications to Modern Data (https://arxiv.org/abs/1807.11824).
 */
-float TSNE_fit_sparse(const raft::handle_t& handle,
-                      int* indptr,
-                      int* indices,
-                      float* data,
-                      float* Y,
-                      int nnz,
-                      int n,
-                      int p,
-                      int* knn_indices,
-                      float* knn_dists,
-                      TSNEParams& params);
+void TSNE_fit_sparse(const raft::handle_t& handle,
+                     int* indptr,
+                     int* indices,
+                     float* data,
+                     float* Y,
+                     int nnz,
+                     int n,
+                     int p,
+                     int* knn_indices,
+                     float* knn_dists,
+                     TSNEParams& params,
+                     float* kl_div = nullptr);
 
 }  // namespace ML
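With the signature change above, callers that previously consumed the return value now pass an optional out-pointer. Below is a minimal sketch of a C++ call site, assuming a configured raft::handle_t and pre-allocated device buffers; the names d_X/d_Y and the nullptr KNN arguments are illustrative assumptions, not part of the patch:

    #include <cuml/manifold/tsne.h>

    float run_tsne(const raft::handle_t& handle, float* d_X, float* d_Y, int n, int p)
    {
      ML::TSNEParams params;  // library defaults for dim, n_neighbors, iterations, ...
      float kl_div = 0.0f;    // written by TSNE_fit because we pass its address
      ML::TSNE_fit(handle, d_X, d_Y, n, p,
                   /*knn_indices=*/nullptr, /*knn_dists=*/nullptr,  // let cuML compute KNN
                   params, &kl_div);
      return kl_div;          // passing nullptr instead skips the divergence copy
    }

The defaulted float* kl_div = nullptr keeps existing call sites source-compatible while still exposing the divergence to the Python layer.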
diff --git a/cpp/src/tsne/barnes_hut_kernels.cuh b/cpp/src/tsne/barnes_hut_kernels.cuh
index c00a10e5ca..a0bfccf2a6 100644
--- a/cpp/src/tsne/barnes_hut_kernels.cuh
+++ b/cpp/src/tsne/barnes_hut_kernels.cuh
@@ -697,10 +697,9 @@ __global__ void attractive_kernel_bh(const value_t* restrict VAL,
   // NaNs upstream though, so until we find and fix them, enforce that trait.
   if (!(dist >= 0)) dist = 0.0f;
 
-  const value_t exponent = (dof + 1.0) / 2.0;
-  const value_t P        = VAL[index];
-  const value_t Q        = __powf(dof / (dof + dist), exponent);
-  const value_t PQ       = P * Q;
+  const value_t P  = VAL[index];
+  const value_t Q  = compute_q(dist, dof);
+  const value_t PQ = P * Q;
 
   // Apply forces
   atomicAdd(&attract1[i], PQ * y1d);
diff --git a/cpp/src/tsne/barnes_hut_tsne.cuh b/cpp/src/tsne/barnes_hut_tsne.cuh
index b4865cdef1..42416375a6 100644
--- a/cpp/src/tsne/barnes_hut_tsne.cuh
+++ b/cpp/src/tsne/barnes_hut_tsne.cuh
@@ -48,6 +48,8 @@ value_t Barnes_Hut(value_t* VAL,
 {
   cudaStream_t stream = handle.get_stream();
 
+  value_t kl_div = 0;
+
   // Get device properites
   //---------------------------------------------------
   const int blocks = raft::getMultiProcessorCount();
@@ -286,12 +288,7 @@ value_t Barnes_Hut(value_t* VAL,
     END_TIMER(attractive_time);
 
     if (last_iter) {
-      value_t P_sum = thrust::reduce(rmm::exec_policy(stream), VAL, VAL + NNZ);
-      raft::linalg::scalarMultiply(VAL, VAL, 1.0f / P_sum, NNZ, stream);
-      value_t Q_sum = thrust::reduce(rmm::exec_policy(stream), Qs, Qs + NNZ);
-      raft::linalg::scalarMultiply(Qs, Qs, 1.0f / Q_sum, NNZ, stream);
-      compute_kl_div<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0, stream>>>(
-        VAL, Qs, KL_divs, NNZ);
+      kl_div = compute_kl_div(VAL, Qs, KL_divs, NNZ, stream);
       CUDA_CHECK(cudaPeekAtLastError());
     }
 
@@ -321,7 +318,6 @@ value_t Barnes_Hut(value_t* VAL,
   raft::copy(Y, YY.data(), n, stream);
   raft::copy(Y + n, YY.data() + nnodes + 1, n, stream);
 
-  value_t kl_div = thrust::reduce(handle.get_thrust_policy(), KL_divs, KL_divs + NNZ);
   return kl_div;
 }
diff --git a/cpp/src/tsne/distances.cuh b/cpp/src/tsne/distances.cuh
index 77fcdaf577..6aae76f3d1 100644
--- a/cpp/src/tsne/distances.cuh
+++ b/cpp/src/tsne/distances.cuh
@@ -129,6 +129,13 @@ void get_distances(const raft::handle_t& handle,
   throw raft::exception("Sparse TSNE does not support 64-bit integer indices yet.");
 }
 
+/**
+ * @brief Find the maximum element in the distances matrix, then divide all entries by this.
+ * This promotes exp(distances) to not explode.
+ * @param[in] distances: The output sorted distances from KNN
+ * @param[in] total_nn: The number of rows in the data X
+ * @param[in] stream: The GPU stream
+ */
 template <typename value_t>
 void normalize_distances(value_t* distances, const size_t total_nn, cudaStream_t stream)
 {
diff --git a/cpp/src/tsne/exact_kernels.cuh b/cpp/src/tsne/exact_kernels.cuh
index b40c3f4cec..c8275c8e94 100644
--- a/cpp/src/tsne/exact_kernels.cuh
+++ b/cpp/src/tsne/exact_kernels.cuh
@@ -190,10 +190,8 @@ __global__ void attractive_kernel(const value_t* restrict VAL,
     dist += Y[k * n + i] * Y[k * n + j];
   dist = norm[i] + norm[j] - 2.0f * dist;
 
-  const value_t exponent = (dof + 1.0) / 2.0;
-  const value_t P        = VAL[index];
-  const value_t Q        = __powf(dof / (dof + dist), exponent);
+  const value_t P  = VAL[index];
+  const value_t Q  = compute_q(dist, dof);
   const value_t PQ = P * Q;
 
   // Apply forces
@@ -231,10 +229,8 @@ __global__ void attractive_kernel_2d(const value_t* restrict VAL,
   // #862
   const value_t dist = norm[i] + norm[j] - 2.0f * (Y1[i] * Y1[j] + Y2[i] * Y2[j]);
 
-  const value_t exponent = (dof + 1.0) / 2.0;
-  const value_t P        = VAL[index];
-  const value_t Q        = __powf(dof / (dof + dist), exponent);
+  const value_t P  = VAL[index];
+  const value_t Q  = compute_q(dist, dof);
   const value_t PQ = P * Q;
 
   // Apply forces
diff --git a/cpp/src/tsne/exact_tsne.cuh b/cpp/src/tsne/exact_tsne.cuh
index 2976e7be75..d19b9e1307 100644
--- a/cpp/src/tsne/exact_tsne.cuh
+++ b/cpp/src/tsne/exact_tsne.cuh
@@ -45,6 +45,7 @@ value_t Exact_TSNE(value_t* VAL,
                    const TSNEParams& params)
 {
   cudaStream_t stream = handle.get_stream();
+  value_t kl_div      = 0;
   const value_idx dim = params.dim;
 
   if (params.initialize_embeddings)
@@ -116,14 +117,7 @@ value_t Exact_TSNE(value_t* VAL,
                                          fmaxf(params.dim - 1, 1),
                                          stream);
 
-      if (last_iter) {
-        value_t P_sum = thrust::reduce(rmm::exec_policy(stream), VAL, VAL + NNZ);
-        raft::linalg::scalarMultiply(VAL, VAL, 1.0f / P_sum, NNZ, stream);
-        value_t Q_sum = thrust::reduce(rmm::exec_policy(stream), Qs, Qs + NNZ);
-        raft::linalg::scalarMultiply(Qs, Qs, 1.0f / Q_sum, NNZ, stream);
-        compute_kl_div<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0, stream>>>(
-          VAL, Qs, KL_divs, NNZ);
-      }
+      if (last_iter) { kl_div = compute_kl_div(VAL, Qs, KL_divs, NNZ, stream); }
 
       // Compute repulsive forces
       const float Z = TSNE::repulsive_forces(
@@ -165,7 +159,7 @@ value_t Exact_TSNE(value_t* VAL,
       if (iter % 100 == 0) { CUML_LOG_DEBUG("Z at iter = %d = %f", iter, Z); }
     }
   }
-  value_t kl_div = thrust::reduce(handle.get_thrust_policy(), KL_divs, KL_divs + NNZ);
+
   return kl_div;
 }
diff --git a/cpp/src/tsne/fft_kernels.cuh b/cpp/src/tsne/fft_kernels.cuh
index 8be17847f4..79508783b4 100644
--- a/cpp/src/tsne/fft_kernels.cuh
+++ b/cpp/src/tsne/fft_kernels.cuh
@@ -329,11 +329,10 @@ __global__ void compute_Pij_x_Qij_kernel(value_t* __restrict__ attr_forces,
   value_t dx = ix - jx;
   value_t dy = iy - jy;
-  const value_t dist     = (dx * dx) + (dy * dy);
-  const value_t exponent = (dof + 1.0) / 2.0;
+  const value_t dist = (dx * dx) + (dy * dy);
 
   const value_t P = pij[TID];
-  const value_t Q = __powf(dof / (dof + dist), exponent);
+  const value_t Q = compute_q(dist, dof);
   const value_t PQ = P * Q;
 
   atomicAdd(attr_forces + i, PQ * dx);
diff --git a/cpp/src/tsne/fft_tsne.cuh b/cpp/src/tsne/fft_tsne.cuh
index e29d051600..366cc0d409 100644
--- a/cpp/src/tsne/fft_tsne.cuh
+++ b/cpp/src/tsne/fft_tsne.cuh
@@ -522,10 +522,7 @@ value_t FFT_TSNE(value_t* VAL,
       FFT::compute_Pij_x_Qij_kernel<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0, stream>>>(
         attractive_forces_device.data(), Qs, VAL, ROW, COL, Y, n, NNZ, dof);
-      value_t Q_sum = thrust::reduce(rmm::exec_policy(stream), Qs, Qs + NNZ);
-      raft::linalg::scalarMultiply(Qs, Qs, 1.0f / Q_sum, NNZ, stream);
-      compute_kl_div<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0, stream>>>(VAL, Qs, KL_divs, NNZ);
-      kl_div = thrust::reduce(handle.get_thrust_policy(), KL_divs, KL_divs + NNZ);
+      kl_div = compute_kl_div(VAL, Qs, KL_divs, NNZ, stream);
     } else {
       FFT::compute_Pij_x_Qij_kernel<<<raft::ceildiv(NNZ, (value_idx)1024), 1024, 0, stream>>>(
         attractive_forces_device.data(), (value_t*)nullptr, VAL, ROW, COL, Y, n, NNZ, dof);
diff --git a/cpp/src/tsne/tsne.cu b/cpp/src/tsne/tsne.cu
index 41d83b7270..eef47545a3 100644
--- a/cpp/src/tsne/tsne.cu
+++ b/cpp/src/tsne/tsne.cu
@@ -30,14 +30,15 @@ value_t _fit(const raft::handle_t& handle,
   return runner.run();  // returns the Kullback–Leibler divergence
 }
 
-float TSNE_fit(const raft::handle_t& handle,
-               float* X,
-               float* Y,
-               int n,
-               int p,
-               int64_t* knn_indices,
-               float* knn_dists,
-               TSNEParams& params)
+void TSNE_fit(const raft::handle_t& handle,
+              float* X,
+              float* Y,
+              int n,
+              int p,
+              int64_t* knn_indices,
+              float* knn_dists,
+              TSNEParams& params,
+              float* kl_div)
 {
   ASSERT(n > 0 && p > 0 && params.dim > 0 && params.n_neighbors > 0 && X != NULL && Y != NULL,
          "Wrong input args");
 
   manifold_dense_inputs_t<float> input(X, Y, n, p);
   knn_graph<knn_indices_dense_t, float> k_graph(n, params.n_neighbors, knn_indices, knn_dists);
 
-  return _fit<manifold_dense_inputs_t<float>, knn_indices_dense_t, float>(
+  float kl_div_v = _fit<manifold_dense_inputs_t<float>, knn_indices_dense_t, float>(
     handle, input, k_graph, params);
-  // returns the Kullback–Leibler divergence
+
+  if (kl_div) { *kl_div = kl_div_v; }
 }
 
-float TSNE_fit_sparse(const raft::handle_t& handle,
-                      int* indptr,
-                      int* indices,
-                      float* data,
-                      float* Y,
-                      int nnz,
-                      int n,
-                      int p,
-                      int* knn_indices,
-                      float* knn_dists,
-                      TSNEParams& params)
+void TSNE_fit_sparse(const raft::handle_t& handle,
+                     int* indptr,
+                     int* indices,
+                     float* data,
+                     float* Y,
+                     int nnz,
+                     int n,
+                     int p,
+                     int* knn_indices,
+                     float* knn_dists,
+                     TSNEParams& params,
+                     float* kl_div)
 {
   ASSERT(n > 0 && p > 0 && params.dim > 0 && params.n_neighbors > 0 && indptr != NULL &&
            indices != NULL && data != NULL && Y != NULL,
          "Wrong input args");
 
   manifold_sparse_inputs_t<int, float> input(indptr, indices, data, Y, nnz, n, p);
   knn_graph<int, float> k_graph(n, params.n_neighbors, knn_indices, knn_dists);
 
-  return _fit<manifold_sparse_inputs_t<int, float>, knn_indices_sparse_t, float>(
+  float kl_div_v = _fit<manifold_sparse_inputs_t<int, float>, knn_indices_sparse_t, float>(
     handle, input, k_graph, params);
-  // returns the Kullback–Leibler divergence
+
+  if (kl_div) { *kl_div = kl_div_v; }
 }
 
 }  // namespace ML
diff --git a/cpp/src/tsne/utils.cuh b/cpp/src/tsne/utils.cuh
index bd6f53a467..1c03bc460e 100644
--- a/cpp/src/tsne/utils.cuh
+++ b/cpp/src/tsne/utils.cuh
@@ -22,6 +22,7 @@
 #include
 #include
+#include <raft/linalg/eltwise.cuh>
 #include
 #include
@@ -178,18 +179,49 @@ __global__ void min_max_kernel(
 }
 
 /**
- * Compute KL divergence
+ * CUDA kernel to compute KL divergence
 */
 template <typename value_idx, typename value_t>
-__global__ void compute_kl_div(const value_t* restrict Ps,
-                               const value_t* restrict Qs,
-                               value_t* restrict kl_divergences,
-                               const value_idx NNZ)
+__global__ void compute_kl_div_k(const value_t* restrict Ps,
+                                 const value_t* restrict Qs,
+                                 value_t* restrict KL_divs,
+                                 const value_idx NNZ)
 {
   const auto index = (blockIdx.x * blockDim.x) + threadIdx.x;
   if (index >= NNZ) return;
   const value_t P = Ps[index];
   const value_t Q = max(Qs[index], FLT_EPSILON);
 
-  kl_divergences[index] = P * __logf(__fdividef(max(P, FLT_EPSILON), Q));
+  KL_divs[index] = P * __logf(__fdividef(max(P, FLT_EPSILON), Q));
 }
 
+/**
+ * Compute KL divergence
+ */
+template <typename value_t>
+value_t compute_kl_div(value_t* restrict Ps,
+                       value_t* restrict Qs,
+                       value_t* restrict KL_divs,
+                       const size_t NNZ,
+                       cudaStream_t stream)
+{
+  value_t P_sum = thrust::reduce(rmm::exec_policy(stream), Ps, Ps + NNZ);
+  raft::linalg::scalarMultiply(Ps, Ps, 1.0f / P_sum, NNZ, stream);
+
+  value_t Q_sum = thrust::reduce(rmm::exec_policy(stream), Qs, Qs + NNZ);
+  raft::linalg::scalarMultiply(Qs, Qs, 1.0f / Q_sum, NNZ, stream);
+
+  const size_t block = 128;
+  const size_t grid  = raft::ceildiv(NNZ, block);
+  compute_kl_div_k<<<grid, block, 0, stream>>>(Ps, Qs, KL_divs, NNZ);
+
+  return thrust::reduce(rmm::exec_policy(stream), KL_divs, KL_divs + NNZ);
+}
+
+template <typename value_t>
+__device__ value_t compute_q(value_t dist, value_t dof)
+{
+  const value_t exponent = (dof + 1.0) / 2.0;
+  const value_t Q        = __powf(dof / (dof + dist), exponent);
+  return Q;
+}
\ No newline at end of file
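In symbols, what compute_q and compute_kl_div evaluate (with d_ij the squared distance between embedded points, a = dof, and both distributions renormalized over the NNZ stored pairs before the reduction):

    q_{ij} \propto \left(\frac{a}{a + d_{ij}}\right)^{(a+1)/2}, \qquad
    \mathrm{KL}(P \,\|\, Q) = \sum_{ij} p_{ij} \log\frac{p_{ij}}{q_{ij}}, \qquad
    \sum_{ij} p_{ij} = \sum_{ij} q_{ij} = 1

For a = 1 the kernel reduces to the Student-t similarity 1 / (1 + d_{ij}) of the original t-SNE, and the FLT_EPSILON clamps in compute_kl_div_k guard the logarithm against zero-valued P or Q entries.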
diff --git a/cpp/test/sg/tsne_test.cu b/cpp/test/sg/tsne_test.cu
index e3de3ec750..a25c4d8e4d 100644
--- a/cpp/test/sg/tsne_test.cu
+++ b/cpp/test/sg/tsne_test.cu
@@ -75,12 +75,8 @@ float get_kl_div(TSNEParams& params,
     total_nn,
     [=] __device__(float dist) { return __powf(dof / (dof + dist), exponent); },
     stream);
-  float Q_sum = thrust::reduce(rmm::exec_policy(stream), Qs, Qs + total_nn);
-  raft::linalg::scalarMultiply(Qs, Qs, 1.0f / Q_sum, total_nn, stream);
-  compute_kl_div<<<raft::ceildiv(total_nn, (size_t)1024), 1024, 0, stream>>>(
-    Ps, Qs, KL_divs, total_nn);
 
-  float kl_div = thrust::reduce(rmm::exec_policy(stream), KL_divs, KL_divs + total_nn);
+  float kl_div = compute_kl_div(Ps, Qs, KL_divs, total_nn, stream);
   return kl_div;
 }
 
@@ -90,15 +86,20 @@ class TSNETest : public ::testing::TestWithParam<TSNEInput> {
 
   void assert_results(const char* test, TSNEResults& results)
   {
-    std::cout << "Testing " << test << ":" << std::endl;
-    std::cout << "\ttrustworthiness = " << results.trustworthiness << std::endl;
-    std::cout << "\tkl_div = " << results.kl_div << std::endl;
-    std::cout << "\tkl_div_ref = " << results.kl_div_ref << std::endl;
-    ASSERT_TRUE(results.trustworthiness > trustworthiness_threshold);
+    bool test_tw      = results.trustworthiness > trustworthiness_threshold;
     double kl_div_tol = 0.2;
-    ASSERT_TRUE(results.kl_div_ref - kl_div_tol < results.kl_div &&
-                results.kl_div < results.kl_div_ref + kl_div_tol);
-    std::cout << std::endl;
+    bool test_kl_div  = results.kl_div_ref - kl_div_tol < results.kl_div &&
+                       results.kl_div < results.kl_div_ref + kl_div_tol;
+
+    if (!test_tw || !test_kl_div) {
+      std::cout << "Testing " << test << ":" << std::endl;
+      std::cout << "\ttrustworthiness = " << results.trustworthiness << std::endl;
+      std::cout << "\tkl_div = " << results.kl_div << std::endl;
+      std::cout << "\tkl_div_ref = " << results.kl_div_ref << std::endl;
+      std::cout << std::endl;
+    }
+    ASSERT_TRUE(test_tw);
+    ASSERT_TRUE(test_kl_div);
   }
 
   TSNEResults runTest(TSNE_ALGORITHM algo, bool knn = false)
diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx
index fe8f661ab7..b47a077ff5 100644
--- a/python/cuml/manifold/t_sne.pyx
+++ b/python/cuml/manifold/t_sne.pyx
@@ -84,7 +84,7 @@ cdef extern from "cuml/manifold/tsne.h" namespace "ML":
 
 cdef extern from "cuml/manifold/tsne.h" namespace "ML":
 
-    cdef float TSNE_fit(
+    cdef void TSNE_fit(
         handle_t &handle,
         float *X,
         float *Y,
         int n,
         int p,
         int64_t* knn_indices,
         float* knn_dists,
-        TSNEParams &params) except +
+        TSNEParams &params,
+        float* kl_div) except +
 
-    cdef float TSNE_fit_sparse(
+    cdef void TSNE_fit_sparse(
         const handle_t &handle,
         int *indptr,
         int *indices,
         float *data,
         float *Y,
         int nnz,
         int n,
         int p,
         int* knn_indices,
         float* knn_dists,
-        TSNEParams &params) except +
+        TSNEParams &params,
+        float* kl_div) except +
 
 
 class TSNE(Base,
@@ -501,29 +503,31 @@ class TSNE(Base,
 
         cdef float kl_divergence = 0
         if self.sparse_fit:
-            kl_divergence = TSNE_fit_sparse(handle_[0],
-                                            <int*><uintptr_t>
-                                            self.X_m.indptr.ptr,
-                                            <int*><uintptr_t>
-                                            self.X_m.indices.ptr,
-                                            <float*><uintptr_t>
-                                            self.X_m.data.ptr,
-                                            embed_ptr,
-                                            self.X_m.nnz,
-                                            n,
-                                            p,
-                                            knn_indices_raw,
-                                            knn_dists_raw,
-                                            deref(params))
+            TSNE_fit_sparse(handle_[0],
+                            <int*><uintptr_t>
+                            self.X_m.indptr.ptr,
+                            <int*><uintptr_t>
+                            self.X_m.indices.ptr,
+                            <float*><uintptr_t>
+                            self.X_m.data.ptr,
+                            embed_ptr,
+                            self.X_m.nnz,
+                            n,
+                            p,
+                            knn_indices_raw,
+                            knn_dists_raw,
+                            deref(params),
+                            &kl_divergence)
         else:
-            kl_divergence = TSNE_fit(handle_[0],
-                                     <float*><uintptr_t> self.X_m.ptr,
-                                     embed_ptr,
-                                     n,
-                                     p,
-                                     knn_indices_raw,
-                                     knn_dists_raw,
-                                     deref(params))
+            TSNE_fit(handle_[0],
+                     <float*><uintptr_t> self.X_m.ptr,
+                     embed_ptr,
+                     n,
+                     p,
+                     knn_indices_raw,
+                     knn_dists_raw,
+                     deref(params),
+                     &kl_divergence)
 
         self.handle.sync()
         free(params)

From 1340b9e2da8e8d266e7734154b45667f74df3122 Mon Sep 17 00:00:00 2001
From: viclafargue
Date: Mon, 25 Oct 2021 17:20:36 +0200
Subject: [PATCH 10/12] Add copyright check exemption

---
 ci/checks/copyright.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ci/checks/copyright.py b/ci/checks/copyright.py
index 0a432c45e5..22de533bcc 100644
--- a/ci/checks/copyright.py
+++ b/ci/checks/copyright.py
@@ -38,7 +38,7 @@
     re.compile(r"[.]flake8[.]cython$"),
     re.compile(r"meta[.]yaml$")
 ]
-ExemptFiles = []
+ExemptFiles = ['cpp/src/tsne/cannylab/bh.cu']
 
 # this will break starting at year 10000, which is probably OK :)
 CheckSimple = re.compile(

From 8177848b2f0368599b79e0e840dd71e935b3531b Mon Sep 17 00:00:00 2001
From: viclafargue
Date: Tue, 26 Oct 2021 12:36:15 +0200
Subject: [PATCH 11/12] Fix doc + fix copyright check issue

---
 ci/checks/copyright.py           | 3 ++-
 cpp/include/cuml/manifold/tsne.h | 2 --
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/ci/checks/copyright.py b/ci/checks/copyright.py
index 22de533bcc..30a0ddf519 100644
--- a/ci/checks/copyright.py
+++ b/ci/checks/copyright.py
@@ -198,7 +198,8 @@ def checkCopyright_main():
 
     (args, dirs) = argparser.parse_known_args()
     try:
-        ExemptFiles = [re.compile(pathName) for pathName in args.exclude]
+        ExemptFiles = ExemptFiles + [pathName for pathName in args.exclude]
+        ExemptFiles = [re.compile(file) for file in ExemptFiles]
     except re.error as reException:
         print("Regular expression error:")
         print(reException)
diff --git a/cpp/include/cuml/manifold/tsne.h b/cpp/include/cuml/manifold/tsne.h
index 4747fe934c..5fdfb9a7c9 100644
--- a/cpp/include/cuml/manifold/tsne.h
+++ b/cpp/include/cuml/manifold/tsne.h
@@ -117,7 +117,6 @@ struct TSNEParams {
  * @param[in] knn_indices Array containing nearest neighors indices.
  * @param[in] knn_dists Array containing nearest neighors distances.
  * @param[in] params Parameters for TSNE model
- * @param[in] kl_div (optional) KL divergence
  * @return The Kullback–Leibler divergence
  *
  * The CUDA implementation is derived from the excellent CannyLabs open source
@@ -151,7 +150,6 @@ void TSNE_fit(const raft::handle_t& handle,
  * @param[in] knn_indices Array containing nearest neighors indices.
  * @param[in] knn_dists Array containing nearest neighors distances.
  * @param[in] params Parameters for TSNE model
- * @param[in] kl_div (optional) KL divergence
  * @return The Kullback–Leibler divergence
  *
  * The CUDA implementation is derived from the excellent CannyLabs open source

From 41dc683c714fb45cde410fde94935cfea97fd2cd Mon Sep 17 00:00:00 2001
From: viclafargue
Date: Wed, 27 Oct 2021 17:17:01 +0200
Subject: [PATCH 12/12] Doc fix

---
 cpp/include/cuml/manifold/tsne.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cpp/include/cuml/manifold/tsne.h b/cpp/include/cuml/manifold/tsne.h
index 5fdfb9a7c9..a72ece863c 100644
--- a/cpp/include/cuml/manifold/tsne.h
+++ b/cpp/include/cuml/manifold/tsne.h
@@ -117,7 +117,7 @@ struct TSNEParams {
  * @param[in] knn_indices Array containing nearest neighors indices.
  * @param[in] knn_dists Array containing nearest neighors distances.
  * @param[in] params Parameters for TSNE model
- * @return The Kullback–Leibler divergence
+ * @param[out] kl_div (optional) KL divergence output
 *
 * The CUDA implementation is derived from the excellent CannyLabs open source
 * implementation here: https://github.com/CannyLab/tsne-cuda/. The CannyLabs
@@ -151,7 +151,7 @@ void TSNE_fit(const raft::handle_t& handle,
  * @param[in] knn_indices Array containing nearest neighors indices.
  * @param[in] knn_dists Array containing nearest neighors distances.
  * @param[in] params Parameters for TSNE model
- * @return The Kullback–Leibler divergence
+ * @param[out] kl_div (optional) KL divergence output
 *
 * The CUDA implementation is derived from the excellent CannyLabs open source
 * implementation here: https://github.com/CannyLab/tsne-cuda/. The CannyLabs