Optimise Split Finding Kernel (#156)
RAMitchell authored Sep 10, 2024
1 parent dd01d09 commit 3e9cede
Showing 3 changed files with 78 additions and 68 deletions.
2 changes: 1 addition & 1 deletion src/models/tree/build_tree.cc
@@ -320,7 +320,7 @@ struct TreeBuilder {
if (left_sum[0].hess <= 0.0 || right_sum[0].hess <= 0.0) continue;
tree.AddSplit(node_id,
best_feature,
split_proposals.split_proposals[{best_bin}],
split_proposals.split_proposals[best_bin],
left_leaf,
right_leaf,
best_gain,
134 changes: 67 additions & 67 deletions src/models/tree/build_tree.cu
@@ -47,6 +47,7 @@ struct NodeBatch {

class GradientQuantiser {
IntegerGPair scale;
GPair inverse_scale;

public:
struct GetAbsGPair {
Expand Down Expand Up @@ -83,8 +84,10 @@ class GradientQuantiser {
int64_t double_max_int = 1ll << 51;
int64_t max_int =
std::min<int64_t>(double_max_int, std::numeric_limits<IntegerGPair::value_type>::max());
scale.grad = abs_sum.grad == 0 ? 1 : max_int / abs_sum.grad;
scale.hess = abs_sum.hess == 0 ? 1 : max_int / abs_sum.hess;
scale.grad = abs_sum.grad == 0 ? 1 : max_int / abs_sum.grad;
scale.hess = abs_sum.hess == 0 ? 1 : max_int / abs_sum.hess;
inverse_scale.grad = 1.0 / scale.grad;
inverse_scale.hess = 1.0 / scale.hess;
}

__device__ IntegerGPair Quantise(GPair value) const
@@ -98,8 +101,8 @@
__device__ GPair Dequantise(IntegerGPair value) const
{
GPair result;
result.grad = double(value.grad) / scale.grad;
result.hess = double(value.hess) / scale.hess;
result.grad = value.grad * inverse_scale.grad;
result.hess = value.hess * inverse_scale.hess;
return result;
}
};
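
Note on the Dequantise change above: the old hot path did two floating-point divisions per call, and the commit precomputes the reciprocal of each scale once in the constructor so Dequantise becomes two multiplications. A minimal standalone sketch of the idea, using simplified stand-ins for the project's GPair/IntegerGPair types:

    #include <cstdint>

    struct GPair { double grad, hess; };
    struct IntegerGPair { int64_t grad, hess; };

    class Quantiser {
      IntegerGPair scale;
      GPair inverse_scale;  // computed once, reused on every Dequantise call

     public:
      explicit Quantiser(IntegerGPair s) : scale(s)
      {
        inverse_scale.grad = 1.0 / static_cast<double>(scale.grad);
        inverse_scale.hess = 1.0 / static_cast<double>(scale.hess);
      }

      GPair Dequantise(IntegerGPair v) const
      {
        // Two multiplies instead of two divides on the hot path.
        return {v.grad * inverse_scale.grad, v.hess * inverse_scale.hess};
      }
    };
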
@@ -272,29 +275,26 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK)
// Key/value pair to simplify reduction
struct GainFeaturePair {
double gain;
int feature;
int feature_sample_idx;
int bin_idx;

__device__ void operator=(const GainFeaturePair& other)
{
gain = other.gain;
feature = other.feature;
feature_sample_idx = other.feature_sample_idx;
gain = other.gain;
bin_idx = other.bin_idx;
}

__device__ bool operator==(const GainFeaturePair& other) const
{
return gain == other.gain && feature == other.feature &&
feature_sample_idx == other.feature_sample_idx;
return gain == other.gain && bin_idx == other.bin_idx;
}

__device__ bool operator>(const GainFeaturePair& other) const { return gain > other.gain; }

__device__ bool operator<(const GainFeaturePair& other) const { return gain < other.gain; }
};
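
Note: this pair type exists so that CUB's block-wide reduction can select the best gain and its bin index together; with the feature id now recovered after the fact (see FindFeature below), the pair shrinks from three fields to two. A minimal sketch of the reduction pattern, assuming a fixed block size and a simplified pair type (cub::Max resolves through the pair's comparison operators):

    #include <cub/cub.cuh>

    struct GainBinPair {
      double gain;
      int bin_idx;
      __device__ bool operator>(const GainBinPair& other) const { return gain > other.gain; }
      __device__ bool operator<(const GainBinPair& other) const { return gain < other.gain; }
    };

    template <int BLOCK_THREADS>
    __global__ void block_argmax(const double* gains, int n, GainBinPair* out)
    {
      using BlockReduce = cub::BlockReduce<GainBinPair, BLOCK_THREADS>;
      __shared__ typename BlockReduce::TempStorage temp_storage;

      // Each thread scans a strided subset of bins and keeps its local best.
      GainBinPair best{0.0, -1};
      for (int i = threadIdx.x; i < n; i += BLOCK_THREADS) {
        if (gains[i] > best.gain) { best = {gains[i], i}; }
      }

      // Block-wide argmax; every thread must participate in the reduction.
      GainBinPair block_best = BlockReduce(temp_storage).Reduce(best, cub::Max());
      if (threadIdx.x == 0) { *out = block_best; }
    }
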

template <typename TYPE>
__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
template <typename TYPE, int BLOCK_THREADS>
__global__ static void __launch_bounds__(BLOCK_THREADS)
perform_best_split(Histogram<IntegerGPair> histogram,
size_t n_features,
size_t n_outputs,
@@ -316,66 +316,64 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
__shared__ typename BlockReduce::TempStorage temp_storage;

__shared__ double node_best_gain;
__shared__ int node_best_feature;
__shared__ int node_best_bin_idx;

double thread_best_gain = 0;
int thread_best_feature = -1;
int thread_best_bin_idx = -1;

for (int feature_id = 0; feature_id < n_features; feature_id++) {
auto [feature_start, feature_end] = split_proposals.FeatureRange(feature_id);

for (int bin_idx = feature_start + threadIdx.x; bin_idx < feature_end; bin_idx += blockDim.x) {
double gain = 0;
for (int output = 0; output < n_outputs; ++output) {
auto node_sum = node_sums[{node_id, output}];
auto left_sum = histogram[{node_id, output, bin_idx}];
auto right_sum = node_sum - left_sum;
auto [G, H] = quantiser.Dequantise(node_sum);
auto [G_L, H_L] = quantiser.Dequantise(left_sum);
auto [G_R, H_R] = quantiser.Dequantise(right_sum);

if (H_L <= 0.0 || H_R <= 0.0) {
gain = 0;
break;
}
double reg = std::max(eps, alpha); // Regularisation term
gain += 0.5 * ((G_L * G_L) / (H_L + reg) + (G_R * G_R) / (H_R + reg) - (G * G) / (H + reg));
}
if (gain > thread_best_gain) {
thread_best_gain = gain;
thread_best_feature = feature_id;
thread_best_bin_idx = bin_idx;
for (int bin_idx = threadIdx.x; bin_idx < split_proposals.histogram_size;
bin_idx += BLOCK_THREADS) {
double gain = 0;
for (int output = 0; output < n_outputs; ++output) {
auto node_sum = node_sums[{node_id, output}];
auto left_sum = histogram[{node_id, output, bin_idx}];
auto right_sum = node_sum - left_sum;
if (left_sum.hess <= 0.0 || right_sum.hess <= 0.0) {
gain = 0;
break;
}
double reg = std::max(eps, alpha); // Regularisation term
auto [G, H] = quantiser.Dequantise(node_sum);
gain -= (G * G) / (H + reg);
auto [G_L, H_L] = quantiser.Dequantise(left_sum);

gain += (G_L * G_L) / (H_L + reg);
auto [G_R, H_R] = quantiser.Dequantise(right_sum);
gain += (G_R * G_R) / (H_R + reg);
}
gain *= 0.5;
if (gain > thread_best_gain) {
thread_best_gain = gain;
thread_best_bin_idx = bin_idx;
}
}

// SYNC BEST GAIN TO FULL BLOCK/NODE
GainFeaturePair thread_best_pair{thread_best_gain, thread_best_feature, thread_best_bin_idx};
GainFeaturePair thread_best_pair{thread_best_gain, thread_best_bin_idx};
GainFeaturePair node_best_pair =
BlockReduce(temp_storage).Reduce(thread_best_pair, cub::Max(), THREADS_PER_BLOCK);
BlockReduce(temp_storage).Reduce(thread_best_pair, cub::Max(), BLOCK_THREADS);
if (threadIdx.x == 0) {
node_best_gain = node_best_pair.gain;
node_best_feature = node_best_pair.feature;
node_best_bin_idx = node_best_pair.feature_sample_idx;
node_best_bin_idx = node_best_pair.bin_idx;
}
__syncthreads();

if (node_best_gain > eps) {
for (int output = threadIdx.x; output < n_outputs; output += blockDim.x) {
auto node_sum = node_sums[{node_id, output}];
auto left_sum = histogram[{node_id, output, node_best_bin_idx}];
auto right_sum = node_sum - left_sum;
int node_best_feature = split_proposals.FindFeature(node_best_bin_idx);
for (int output = threadIdx.x; output < n_outputs; output += BLOCK_THREADS) {
auto node_sum = node_sums[{node_id, output}];
auto left_sum = histogram[{node_id, output, node_best_bin_idx}];
auto right_sum = node_sum - left_sum;
node_sums[{BinaryTree::LeftChild(node_id), output}] = left_sum;
node_sums[{BinaryTree::RightChild(node_id), output}] = right_sum;

auto [G_L, H_L] = quantiser.Dequantise(left_sum);
auto [G_R, H_R] = quantiser.Dequantise(right_sum);
tree_leaf_value[{BinaryTree::LeftChild(node_id), output}] =
CalculateLeafValue(G_L, H_L, alpha);

int left_child = BinaryTree::LeftChild(node_id);
int right_child = BinaryTree::RightChild(node_id);
tree_leaf_value[{left_child, output}] = CalculateLeafValue(G_L, H_L, alpha);
tree_leaf_value[{right_child, output}] = CalculateLeafValue(G_R, H_R, alpha);
node_sums[{left_child, output}] = left_sum;
node_sums[{right_child, output}] = right_sum;
auto [G_R, H_R] = quantiser.Dequantise(right_sum);
tree_leaf_value[{BinaryTree::RightChild(node_id), output}] =
CalculateLeafValue(G_R, H_R, alpha);

if (output == 0) {
tree_feature[node_id] = node_best_feature;
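
The restructuring above is the heart of the optimisation: rather than a nested loop over features and each feature's bin range (via FeatureRange), every thread now strides over the flattened histogram, and only the single winning bin is mapped back to its feature, once, after the block reduction, via FindFeature. A host-side sketch of the two iteration shapes, with a made-up CSR-style row_pointers layout:

    #include <vector>

    int main()
    {
      // row_pointers[f] is where feature f's bins start; sizes are made up.
      std::vector<int> row_pointers = {0, 4, 9, 12};  // 3 features, 12 bins total
      int histogram_size = row_pointers.back();

      // Before: per-feature loop; work per feature varies with its bin count.
      for (std::size_t f = 0; f + 1 < row_pointers.size(); ++f) {
        for (int bin = row_pointers[f]; bin < row_pointers[f + 1]; ++bin) {
          // evaluate gain for (f, bin)
        }
      }

      // After: one flat loop over all bins; no per-bin feature id is needed,
      // since only the best bin is mapped back to its feature later.
      for (int bin = 0; bin < histogram_size; ++bin) {
        // evaluate gain for bin
      }
    }
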
@@ -673,19 +671,21 @@ struct TreeBuilder {
double alpha,
NodeBatch batch)
{
perform_best_split<<<batch.NodesInBatch(), THREADS_PER_BLOCK, 0, stream>>>(histogram,
num_features,
num_outputs,
split_proposals,
eps,
alpha,
tree.leaf_value,
tree.node_sums,
tree.feature,
tree.split_value,
tree.gain,
batch,
quantiser);
const int kBlockThreads = 256;
perform_best_split<T, kBlockThreads>
<<<batch.NodesInBatch(), kBlockThreads, 0, stream>>>(histogram,
num_features,
num_outputs,
split_proposals,
eps,
alpha,
tree.leaf_value,
tree.node_sums,
tree.feature,
tree.split_value,
tree.gain,
batch,
quantiser);
CHECK_CUDA_STREAM(stream);
}
void InitialiseRoot(legate::TaskContext context,
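
Note on the launch-site change: the block size is now a compile-time template parameter, so the strided loops inside the kernel and the __launch_bounds__ annotation agree on the same constant; dropping the MIN_CTAS_PER_SM argument also presumably removes an occupancy floor, leaving the compiler free in its register allocation. A minimal sketch of the pattern with a hypothetical kernel:

    constexpr int kBlockThreads = 256;

    template <typename T, int BLOCK_THREADS>
    __global__ void __launch_bounds__(BLOCK_THREADS)
      scale_kernel(const T* in, T* out, int n, T factor)
    {
      // BLOCK_THREADS is a compile-time constant: the stride folds into the
      // generated code and matches the __launch_bounds__ limit exactly.
      for (int i = blockIdx.x * BLOCK_THREADS + threadIdx.x; i < n;
           i += gridDim.x * BLOCK_THREADS) {
        out[i] = in[i] * factor;
      }
    }

    // Launch with the same constant as both template argument and block dim:
    //   scale_kernel<float, kBlockThreads>
    //     <<<num_blocks, kBlockThreads, 0, stream>>>(in, out, n, 2.0f);
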
10 changes: 10 additions & 0 deletions src/models/tree/build_tree.h
@@ -143,6 +143,16 @@ class SparseSplitProposals {
}
#endif

#ifdef __CUDACC__
__host__ __device__ int FindFeature(int bin_idx) const
{
// Binary search for the feature
return thrust::upper_bound(
thrust::seq, row_pointers.ptr(0), row_pointers.ptr(num_features), bin_idx) -
row_pointers.ptr(0) - 1;
}
#endif

__host__ __device__ std::tuple<int, int> FeatureRange(int feature) const
{
return std::make_tuple(row_pointers[feature], row_pointers[feature + 1]);
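
The new FindFeature above maps a flat bin index back to the feature that owns it by binary-searching the CSR-style row_pointers offsets. A host-side sketch of the same logic, with std::upper_bound standing in for the sequential thrust::upper_bound call and made-up offsets:

    #include <algorithm>
    #include <cassert>
    #include <vector>

    int FindFeature(const std::vector<int>& row_pointers, int bin_idx)
    {
      // upper_bound returns the first offset strictly greater than bin_idx;
      // the owning feature is one position earlier.
      auto it = std::upper_bound(row_pointers.begin(), row_pointers.end(), bin_idx);
      return static_cast<int>(it - row_pointers.begin()) - 1;
    }

    int main()
    {
      std::vector<int> row_pointers = {0, 4, 9, 12};  // features 0..2
      assert(FindFeature(row_pointers, 0) == 0);
      assert(FindFeature(row_pointers, 4) == 1);   // first bin of feature 1
      assert(FindFeature(row_pointers, 11) == 2);  // last bin of feature 2
    }
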
