Skip to content

Commit

Permalink
Add stochastic rounding
Browse files Browse the repository at this point in the history
  • Loading branch information
RAMitchell committed Sep 16, 2024
1 parent b378a77 commit b5193c0
Showing 1 changed file with 58 additions and 19 deletions.
77 changes: 58 additions & 19 deletions src/models/tree/build_tree.cu
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,23 @@ class GradientQuantiser {
scale.hess = local_abs_sum.hess == 0 ? 1 : max_int / local_abs_sum.hess;
}

// Round gradient and hessian to integers using stochastic rounding:
// round down, then bump up with probability equal to the fractional part,
// so the quantised value equals the scaled value in expectation (unbiased).
// For a sum of n samples the accumulated rounding error then grows as
// O(sqrt(n)), versus O(n) worst case for round-to-nearest.
// `seed` must be unique per gpair per boosting iteration; callers derive it
// via hash_combine so neighbouring samples get decorrelated random draws.
__device__ IntegerGPair QuantiseStochasticRounding(GPair value, int64_t seed) const
{
thrust::default_random_engine eng(seed);
thrust::uniform_real_distribution<double> dist(0.0, 1.0);
auto scaled_grad = value.grad * scale.grad;
auto scaled_hess = value.hess * scale.hess;
double grad_remainder = scaled_grad - floor(scaled_grad);
double hess_remainder = scaled_hess - floor(scaled_hess);
// Two sequential draws from the same engine: one for grad, one for hess.
IntegerGPair::value_type grad_quantised = floor(scaled_grad) + (dist(eng) < grad_remainder);
IntegerGPair::value_type hess_quantised = floor(scaled_hess) + (dist(eng) < hess_remainder);
return IntegerGPair{grad_quantised, hess_quantised};
}

__device__ GPair Dequantise(IntegerGPair value) const
Expand All @@ -103,14 +114,35 @@ class GradientQuantiser {
}
};

// Bob Jenkins' 32-bit integer mix: decorrelates nearby integer inputs so
// consecutive sample/output indices yield unrelated RNG seeds.
// The mixing is done in unsigned arithmetic (the canonical formulation):
// with signed int, `a + 0x7ed55d16` can overflow (UB) and `a << 12` on a
// negative value is undefined behaviour, while `a >> 19` would sign-extend
// instead of performing the intended logical shift.
__device__ int hash(int a)
{
unsigned int x = static_cast<unsigned int>(a);
x = (x + 0x7ed55d16u) + (x << 12);
x = (x ^ 0xc761c23cu) ^ (x >> 19);
x = (x + 0x165667b1u) + (x << 5);
x = (x + 0xd3a2646cu) ^ (x << 9);
x = (x + 0xfd7046c5u) + (x << 3);
x = (x ^ 0xb55a4f09u) ^ (x >> 16);
return static_cast<int>(x);
}

// Recursion terminator for the variadic hash_combine: with no values left
// to fold in, the accumulated seed is the final result.
__device__ int64_t hash_combine(int64_t seed)
{
return seed;
}

// Fold `v`, then `rest...`, into `seed` (boost::hash_combine style).
// hash(v) decorrelates the incoming value; the golden-ratio constant and the
// shifted-seed terms spread the existing entropy across all bits.
// Mixing is done in unsigned arithmetic: left-shifting a negative signed
// seed is undefined behaviour, and wrap-around on overflow is the intended
// semantics of this mixer anyway.
template <typename... Rest>
__device__ int64_t hash_combine(int64_t seed, const int& v, Rest... rest)
{
auto s = static_cast<unsigned long long>(seed);
s ^= (static_cast<unsigned int>(hash(v)) + 0x9e3779b9u) + (s << 6) + (s >> 2);
return hash_combine(static_cast<int64_t>(s), rest...);
}

__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
reduce_base_sums(legate::AccessorRO<double, 3> g,
legate::AccessorRO<double, 3> h,
size_t n_local_samples,
int64_t sample_offset,
legate::Buffer<IntegerGPair, 2> node_sums,
size_t n_outputs,
GradientQuantiser quantiser)
GradientQuantiser quantiser,
int64_t seed)
{
typedef cub::BlockReduce<IntegerGPair, THREADS_PER_BLOCK> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
Expand All @@ -119,10 +151,12 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)

int64_t sample_id = threadIdx.x + blockDim.x * blockIdx.x;

double grad = sample_id < n_local_samples ? g[{sample_id + sample_offset, 0, output}] : 0.0;
double hess = sample_id < n_local_samples ? h[{sample_id + sample_offset, 0, output}] : 0.0;
legate::Point<3> p = {sample_id + sample_offset, 0, output};
double grad = sample_id < n_local_samples ? g[p] : 0.0;
double hess = sample_id < n_local_samples ? h[p] : 0.0;

auto quantised = quantiser.Quantise({grad, hess});
auto quantised =
quantiser.QuantiseStochasticRounding({grad, hess}, hash_combine(seed, p[0], p[2]));
IntegerGPair blocksum = BlockReduce(temp_storage).Sum(quantised);

if (threadIdx.x == 0) {
Expand All @@ -147,7 +181,8 @@ __global__ static void __launch_bounds__(TPB, MIN_CTAS_PER_SM)
NodeBatch batch,
Histogram<IntegerGPair> histogram,
legate::Buffer<IntegerGPair, 2> node_sums,
GradientQuantiser quantiser)
GradientQuantiser quantiser,
int64_t seed)
{
constexpr int32_t WarpSize = 32;
const int32_t warp_id = threadIdx.x / WarpSize;
Expand Down Expand Up @@ -186,9 +221,10 @@ __global__ static void __launch_bounds__(TPB, MIN_CTAS_PER_SM)
split_proposals.FindBin(X[{sample_offset + localSampleId, feature, 0}], feature);
for (int32_t output = 0; output < n_outputs; output++) {
// get same G/H from every thread in warp
auto gpair_quantised = quantiser.Quantise({g[{sample_offset + localSampleId, 0, output}],
h[{sample_offset + localSampleId, 0, output}]});
auto* addPosition = reinterpret_cast<typename IntegerGPair::value_type*>(
legate::Point<3> p = {sample_offset + localSampleId, feature, output};
auto gpair_quantised =
quantiser.QuantiseStochasticRounding({g[p], h[p]}, hash_combine(seed, p[0], p[2]));
auto* addPosition = reinterpret_cast<typename IntegerGPair::value_type*>(
&histogram[{sampleNode, output, bin_idx}]);

if (bin_idx != SparseSplitProposals<TYPE>::NOT_FOUND) {
Expand Down Expand Up @@ -627,7 +663,8 @@ struct TreeBuilder {
legate::Rect<3> X_shape,
legate::AccessorRO<double, 3> g,
legate::AccessorRO<double, 3> h,
NodeBatch batch)
NodeBatch batch,
int64_t seed)
{
// warp kernel without additional caching / prefetching
const int threads_per_block = 256;
Expand All @@ -650,7 +687,8 @@ struct TreeBuilder {
batch,
histogram,
tree.node_sums,
quantiser);
quantiser,
seed);

CHECK_CUDA_STREAM(stream);

Expand Down Expand Up @@ -712,12 +750,13 @@ struct TreeBuilder {
legate::AccessorRO<double, 3> g,
legate::AccessorRO<double, 3> h,
legate::Rect<3> g_shape,
double alpha)
double alpha,
int64_t seed)
{
const size_t blocks = (num_rows + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
dim3 grid_shape = dim3(blocks, num_outputs);
reduce_base_sums<<<grid_shape, THREADS_PER_BLOCK, 0, stream>>>(
g, h, num_rows, g_shape.lo[0], tree.node_sums, num_outputs, quantiser);
g, h, num_rows, g_shape.lo[0], tree.node_sums, num_outputs, quantiser, seed);
CHECK_CUDA_STREAM(stream);

SumAllReduce(
Expand Down Expand Up @@ -873,15 +912,15 @@ struct build_tree_fn {
split_proposals,
quantiser);

builder.InitialiseRoot(context, tree, g_accessor, h_accessor, g_shape, alpha);
builder.InitialiseRoot(context, tree, g_accessor, h_accessor, g_shape, alpha, seed);

for (int depth = 0; depth < max_depth; ++depth) {
auto batches = builder.PrepareBatches(depth, thrust_exec_policy);
for (auto batch : batches) {
auto histogram = builder.GetHistogram(batch);

builder.ComputeHistogram(
histogram, context, tree, X_accessor, X_shape, g_accessor, h_accessor, batch);
histogram, context, tree, X_accessor, X_shape, g_accessor, h_accessor, batch, seed);

builder.PerformBestSplit(tree, histogram, alpha, batch);
}
Expand Down

0 comments on commit b5193c0

Please sign in to comment.