From 3e9cede7695ae1b95ec6db5bb22f5549de2191dc Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 10 Sep 2024 13:04:29 +0200 Subject: [PATCH] Optimise Split Finding Kernel (#156) --- src/models/tree/build_tree.cc | 2 +- src/models/tree/build_tree.cu | 134 +++++++++++++++++----------------- src/models/tree/build_tree.h | 10 +++ 3 files changed, 78 insertions(+), 68 deletions(-) diff --git a/src/models/tree/build_tree.cc b/src/models/tree/build_tree.cc index 87420e77..96ce8b4c 100644 --- a/src/models/tree/build_tree.cc +++ b/src/models/tree/build_tree.cc @@ -320,7 +320,7 @@ struct TreeBuilder { if (left_sum[0].hess <= 0.0 || right_sum[0].hess <= 0.0) continue; tree.AddSplit(node_id, best_feature, - split_proposals.split_proposals[{best_bin}], + split_proposals.split_proposals[best_bin], left_leaf, right_leaf, best_gain, diff --git a/src/models/tree/build_tree.cu b/src/models/tree/build_tree.cu index 22cdc8b7..85de7019 100644 --- a/src/models/tree/build_tree.cu +++ b/src/models/tree/build_tree.cu @@ -47,6 +47,7 @@ struct NodeBatch { class GradientQuantiser { IntegerGPair scale; + GPair inverse_scale; public: struct GetAbsGPair { @@ -83,8 +84,10 @@ class GradientQuantiser { int64_t double_max_int = 1ll << 51; int64_t max_int = std::min(double_max_int, std::numeric_limits::max()); - scale.grad = abs_sum.grad == 0 ? 1 : max_int / abs_sum.grad; - scale.hess = abs_sum.hess == 0 ? 1 : max_int / abs_sum.hess; + scale.grad = abs_sum.grad == 0 ? 1 : max_int / abs_sum.grad; + scale.hess = abs_sum.hess == 0 ? 1 : max_int / abs_sum.hess; + inverse_scale.grad = 1.0 / scale.grad; + inverse_scale.hess = 1.0 / scale.hess; } __device__ IntegerGPair Quantise(GPair value) const @@ -98,8 +101,8 @@ class GradientQuantiser { __device__ GPair Dequantise(IntegerGPair value) const { GPair result; - result.grad = double(value.grad) / scale.grad; - result.hess = double(value.hess) / scale.hess; + result.grad = value.grad * inverse_scale.grad; + result.hess = value.hess * inverse_scale.hess; return result; } }; @@ -272,20 +275,17 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK) // Key/value pair to simplify reduction struct GainFeaturePair { double gain; - int feature; - int feature_sample_idx; + int bin_idx; __device__ void operator=(const GainFeaturePair& other) { - gain = other.gain; - feature = other.feature; - feature_sample_idx = other.feature_sample_idx; + gain = other.gain; + bin_idx = other.bin_idx; } __device__ bool operator==(const GainFeaturePair& other) const { - return gain == other.gain && feature == other.feature && - feature_sample_idx == other.feature_sample_idx; + return gain == other.gain && bin_idx == other.bin_idx; } __device__ bool operator>(const GainFeaturePair& other) const { return gain > other.gain; } @@ -293,8 +293,8 @@ struct GainFeaturePair { __device__ bool operator<(const GainFeaturePair& other) const { return gain < other.gain; } }; -template -__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) +template +__global__ static void __launch_bounds__(BLOCK_THREADS) perform_best_split(Histogram histogram, size_t n_features, size_t n_outputs, @@ -316,66 +316,64 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) __shared__ typename BlockReduce::TempStorage temp_storage; __shared__ double node_best_gain; - __shared__ int node_best_feature; __shared__ int node_best_bin_idx; double thread_best_gain = 0; - int thread_best_feature = -1; int thread_best_bin_idx = -1; - for (int feature_id = 0; feature_id < n_features; feature_id++) { - auto [feature_start, feature_end] = split_proposals.FeatureRange(feature_id); - - for (int bin_idx = feature_start + threadIdx.x; bin_idx < feature_end; bin_idx += blockDim.x) { - double gain = 0; - for (int output = 0; output < n_outputs; ++output) { - auto node_sum = node_sums[{node_id, output}]; - auto left_sum = histogram[{node_id, output, bin_idx}]; - auto right_sum = node_sum - left_sum; - auto [G, H] = quantiser.Dequantise(node_sum); - auto [G_L, H_L] = quantiser.Dequantise(left_sum); - auto [G_R, H_R] = quantiser.Dequantise(right_sum); - - if (H_L <= 0.0 || H_R <= 0.0) { - gain = 0; - break; - } - double reg = std::max(eps, alpha); // Regularisation term - gain += 0.5 * ((G_L * G_L) / (H_L + reg) + (G_R * G_R) / (H_R + reg) - (G * G) / (H + reg)); - } - if (gain > thread_best_gain) { - thread_best_gain = gain; - thread_best_feature = feature_id; - thread_best_bin_idx = bin_idx; + for (int bin_idx = threadIdx.x; bin_idx < split_proposals.histogram_size; + bin_idx += BLOCK_THREADS) { + double gain = 0; + for (int output = 0; output < n_outputs; ++output) { + auto node_sum = node_sums[{node_id, output}]; + auto left_sum = histogram[{node_id, output, bin_idx}]; + auto right_sum = node_sum - left_sum; + if (left_sum.hess <= 0.0 || right_sum.hess <= 0.0) { + gain = 0; + break; } + double reg = std::max(eps, alpha); // Regularisation term + auto [G, H] = quantiser.Dequantise(node_sum); + gain -= (G * G) / (H + reg); + auto [G_L, H_L] = quantiser.Dequantise(left_sum); + + gain += (G_L * G_L) / (H_L + reg); + auto [G_R, H_R] = quantiser.Dequantise(right_sum); + gain += (G_R * G_R) / (H_R + reg); + } + gain *= 0.5; + if (gain > thread_best_gain) { + thread_best_gain = gain; + thread_best_bin_idx = bin_idx; } } // SYNC BEST GAIN TO FULL BLOCK/NODE - GainFeaturePair thread_best_pair{thread_best_gain, thread_best_feature, thread_best_bin_idx}; + GainFeaturePair thread_best_pair{thread_best_gain, thread_best_bin_idx}; GainFeaturePair node_best_pair = - BlockReduce(temp_storage).Reduce(thread_best_pair, cub::Max(), THREADS_PER_BLOCK); + BlockReduce(temp_storage).Reduce(thread_best_pair, cub::Max(), BLOCK_THREADS); if (threadIdx.x == 0) { node_best_gain = node_best_pair.gain; - node_best_feature = node_best_pair.feature; - node_best_bin_idx = node_best_pair.feature_sample_idx; + node_best_bin_idx = node_best_pair.bin_idx; } __syncthreads(); if (node_best_gain > eps) { - for (int output = threadIdx.x; output < n_outputs; output += blockDim.x) { - auto node_sum = node_sums[{node_id, output}]; - auto left_sum = histogram[{node_id, output, node_best_bin_idx}]; - auto right_sum = node_sum - left_sum; + int node_best_feature = split_proposals.FindFeature(node_best_bin_idx); + for (int output = threadIdx.x; output < n_outputs; output += BLOCK_THREADS) { + auto node_sum = node_sums[{node_id, output}]; + auto left_sum = histogram[{node_id, output, node_best_bin_idx}]; + auto right_sum = node_sum - left_sum; + node_sums[{BinaryTree::LeftChild(node_id), output}] = left_sum; + node_sums[{BinaryTree::RightChild(node_id), output}] = right_sum; + auto [G_L, H_L] = quantiser.Dequantise(left_sum); - auto [G_R, H_R] = quantiser.Dequantise(right_sum); + tree_leaf_value[{BinaryTree::LeftChild(node_id), output}] = + CalculateLeafValue(G_L, H_L, alpha); - int left_child = BinaryTree::LeftChild(node_id); - int right_child = BinaryTree::RightChild(node_id); - tree_leaf_value[{left_child, output}] = CalculateLeafValue(G_L, H_L, alpha); - tree_leaf_value[{right_child, output}] = CalculateLeafValue(G_R, H_R, alpha); - node_sums[{left_child, output}] = left_sum; - node_sums[{right_child, output}] = right_sum; + auto [G_R, H_R] = quantiser.Dequantise(right_sum); + tree_leaf_value[{BinaryTree::RightChild(node_id), output}] = + CalculateLeafValue(G_R, H_R, alpha); if (output == 0) { tree_feature[node_id] = node_best_feature; @@ -673,19 +671,21 @@ struct TreeBuilder { double alpha, NodeBatch batch) { - perform_best_split<<>>(histogram, - num_features, - num_outputs, - split_proposals, - eps, - alpha, - tree.leaf_value, - tree.node_sums, - tree.feature, - tree.split_value, - tree.gain, - batch, - quantiser); + const int kBlockThreads = 256; + perform_best_split + <<>>(histogram, + num_features, + num_outputs, + split_proposals, + eps, + alpha, + tree.leaf_value, + tree.node_sums, + tree.feature, + tree.split_value, + tree.gain, + batch, + quantiser); CHECK_CUDA_STREAM(stream); } void InitialiseRoot(legate::TaskContext context, diff --git a/src/models/tree/build_tree.h b/src/models/tree/build_tree.h index 0dcdb442..96d5fc97 100644 --- a/src/models/tree/build_tree.h +++ b/src/models/tree/build_tree.h @@ -143,6 +143,16 @@ class SparseSplitProposals { } #endif +#ifdef __CUDACC__ + __host__ __device__ int FindFeature(int bin_idx) const + { + // Binary search for the feature + return thrust::upper_bound( + thrust::seq, row_pointers.ptr(0), row_pointers.ptr(num_features), bin_idx) - + row_pointers.ptr(0) - 1; + } +#endif + __host__ __device__ std::tuple FeatureRange(int feature) const { return std::make_tuple(row_pointers[feature], row_pointers[feature + 1]);