Merge branch 'main' of github.com:rapidsai/legateboost into scan-kernel
RAMitchell committed Sep 20, 2024
2 parents fca55a1 + 2f42ebb commit 3f9fb59
Showing 2 changed files with 86 additions and 34 deletions.
14 changes: 11 additions & 3 deletions src/cpp_utils/cpp_utils.h
@@ -109,8 +109,8 @@ constexpr decltype(auto) type_dispatch_float(legate::Type::Code code, Functor&&
type_dispatch<float, double>(code, f, std::forward<Fnargs>(args)...);
}

template <typename T>
void SumAllReduce(legate::TaskContext context, T* x, int count)
template <typename T, typename OpT>
void AllReduce(legate::TaskContext context, T* x, int count, OpT op)
{
auto domain = context.get_launch_domain();
size_t num_ranks = domain.get_volume();
@@ -138,7 +138,9 @@ void SumAllReduce(legate::TaskContext context, T* x, int count)
// Sum partials
std::vector<T> partials(items_per_rank, 0.0);
for (size_t j = 0; j < items_per_rank; j++) {
for (size_t i = 0; i < num_ranks; i++) { partials[j] += recvbuf[i * items_per_rank + j]; }
for (size_t i = 0; i < num_ranks; i++) {
partials[j] = op(partials[j], recvbuf[i * items_per_rank + j]);
}
}

result = legate::comm::coll::collAllgather(
@@ -147,6 +149,12 @@ void SumAllReduce(legate::TaskContext context, T* x, int count)
std::copy(recvbuf.begin(), recvbuf.begin() + count, x);
}

template <typename T>
void SumAllReduce(legate::TaskContext context, T* x, int count)
{
AllReduce(context, x, count, std::plus<T>());
}

/**
* @brief Turns linear index into multi-dimension index. Similar to numpy unravel.
*/
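
Note on the change above: SumAllReduce is now a thin wrapper over a generic AllReduce that accepts any binary reduction functor. A minimal usage sketch, assuming the surrounding Legate task provides the context and <algorithm> is available; the variable names here are illustrative, not from the repository:

// Reduce a 2-element buffer across all ranks with max instead of sum.
// `context` is the current legate::TaskContext; `local_vals` is this rank's data.
double local_vals[2] = {local_grad_abs_sum, local_hess_abs_sum};  // hypothetical locals
AllReduce(context, local_vals, 2, [](double a, double b) { return std::max(a, b); });
// SumAllReduce(context, local_vals, 2) remains available as the std::plus<T>() case.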
106 changes: 75 additions & 31 deletions src/models/tree/build_tree.cu
@@ -46,7 +46,7 @@ struct NodeBatch {
};

class GradientQuantiser {
IntegerGPair scale;
GPair scale;
GPair inverse_scale;

public:
@@ -74,28 +74,38 @@ class GradientQuantiser {
int num_outputs = g_shape.hi[2] - g_shape.lo[2] + 1;
std::size_t n = (g_shape.hi[0] - g_shape.lo[0] + 1) * num_outputs;
auto zip_gpair = thrust::make_transform_iterator(counting, GetAbsGPair{num_outputs, g, h});
GPair abs_sum =
GPair local_abs_sum =
thrust::reduce(policy, zip_gpair, zip_gpair + n, GPair{0.0, 0.0}, thrust::plus<GPair>());
SumAllReduce(context, reinterpret_cast<double*>(&abs_sum), 2);
// Take the max of the local sums
AllReduce(context, reinterpret_cast<double*>(&local_abs_sum), 2, [](double a, double b) {
return std::max(a, b);
});

// We will quantise values between -max_int and max_int
// Double precision can exactly represent integers in this range
// So we can go back and forth between double and int64_t without overflow
int64_t double_max_int = 1ll << 51;
int64_t max_int =
std::min<int64_t>(double_max_int, std::numeric_limits<IntegerGPair::value_type>::max());
scale.grad = abs_sum.grad == 0 ? 1 : max_int / abs_sum.grad;
scale.hess = abs_sum.hess == 0 ? 1 : max_int / abs_sum.hess;
int64_t max_int = std::numeric_limits<int32_t>::max();
scale.grad = local_abs_sum.grad == 0 ? 1 : max_int / local_abs_sum.grad;
scale.hess = local_abs_sum.hess == 0 ? 1 : max_int / local_abs_sum.hess;
inverse_scale.grad = 1.0 / scale.grad;
inverse_scale.hess = 1.0 / scale.hess;
}

__device__ IntegerGPair Quantise(GPair value) const
// Round gradient and hessian using stochastic rounding
// Thus the expected value of the quantised value is unbiased
// Also the expected error grows as O(1/sqrt(n)) where n is the number of samples
// Vs. O(1/n) for round nearest
// The seed here should be unique for each gpair over each boosting iteration
// Use a hash combine function to generate the seed
__device__ IntegerGPair QuantiseStochasticRounding(GPair value, int64_t seed) const
{
IntegerGPair result;
result.grad = value.grad * scale.grad;
result.hess = value.hess * scale.hess;
return result;
thrust::default_random_engine eng(seed);
thrust::uniform_real_distribution<double> dist(0.0, 1.0);
auto scaled_grad = value.grad * scale.grad;
auto scaled_hess = value.hess * scale.hess;
double grad_remainder = scaled_grad - floor(scaled_grad);
double hess_remainder = scaled_hess - floor(scaled_hess);
IntegerGPair::value_type grad_quantised = floor(scaled_grad) + (dist(eng) < grad_remainder);
IntegerGPair::value_type hess_quantised = floor(scaled_hess) + (dist(eng) < hess_remainder);
return IntegerGPair{grad_quantised, hess_quantised};
}

__device__ GPair Dequantise(IntegerGPair value) const
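
The stochastic rounding used by QuantiseStochasticRounding above can be summarised with a small host-side sketch (standard-library RNG stands in for thrust; this is an illustration, not code from the commit): the scaled value is rounded up with probability equal to its fractional part, so the quantised value is unbiased in expectation.

#include <cmath>
#include <cstdint>
#include <random>

// Round `scaled` to an integer such that E[result] == scaled.
int64_t stochastic_round(double scaled, std::mt19937_64& rng)
{
  std::uniform_real_distribution<double> dist(0.0, 1.0);
  const double remainder = scaled - std::floor(scaled);
  return static_cast<int64_t>(std::floor(scaled)) + (dist(rng) < remainder);
}
// e.g. stochastic_round(2.3, rng) yields 2 with probability 0.7 and 3 with probability 0.3.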
@@ -107,6 +117,28 @@ class GradientQuantiser {
}
};

// Hash function fmix64 from MurmurHash3
__device__ int64_t hash(int64_t k)
{
k ^= k >> 33;
k *= 0xff51afd7ed558ccd;
k ^= k >> 33;
k *= 0xc4ceb9fe1a85ec53;
k ^= k >> 33;
return k;
}

__device__ int64_t hash_combine(int64_t seed) { return seed; }

// Hash combine from boost
// This function is used to combine several random seeds e.g. a 3d index
template <typename... Rest>
__device__ int64_t hash_combine(int64_t seed, const int64_t& v, Rest... rest)
{
seed ^= hash(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
return hash_combine(seed, rest...);
}

template <int BLOCK_THREADS>
__global__ static void __launch_bounds__(BLOCK_THREADS)
reduce_base_sums(legate::AccessorRO<double, 3> g,
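
For exposition, a host-side sketch of the hash/hash_combine pair introduced above (names and the use of unsigned arithmetic are choices made here so the deliberate wrap-around is well defined; the constants match the device code):

#include <cstdint>

// Host transcription of the device-side fmix64 finaliser above.
inline uint64_t fmix64_host(uint64_t k)
{
  k ^= k >> 33;
  k *= 0xff51afd7ed558ccdULL;
  k ^= k >> 33;
  k *= 0xc4ceb9fe1a85ec53ULL;
  k ^= k >> 33;
  return k;
}

inline uint64_t hash_combine_host(uint64_t seed) { return seed; }

template <typename... Rest>
inline uint64_t hash_combine_host(uint64_t seed, uint64_t v, Rest... rest)
{
  seed ^= fmix64_host(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  return hash_combine_host(seed, rest...);
}

// The kernels below call hash_combine(seed, p[0], p[2]) so each gradient pair
// gets a distinct, reproducible seed for stochastic rounding in every iteration.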
@@ -115,7 +147,8 @@ __global__ static void __launch_bounds__(BLOCK_THREADS)
int64_t sample_offset,
legate::Buffer<IntegerGPair, 2> node_sums,
size_t n_outputs,
GradientQuantiser quantiser)
GradientQuantiser quantiser,
int64_t seed)
{
typedef cub::BlockReduce<IntegerGPair, BLOCK_THREADS> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
@@ -124,10 +157,13 @@ __global__ static void __launch_bounds__(BLOCK_THREADS)

int64_t sample_id = threadIdx.x + blockDim.x * blockIdx.x;

double grad = sample_id < n_local_samples ? g[{sample_id + sample_offset, 0, output}] : 0.0;
double hess = sample_id < n_local_samples ? h[{sample_id + sample_offset, 0, output}] : 0.0;
legate::Point<3> p = {sample_id + sample_offset, 0, output};
double grad = sample_id < n_local_samples ? g[p] : 0.0;
double hess = sample_id < n_local_samples ? h[p] : 0.0;

IntegerGPair blocksum = BlockReduce(temp_storage).Sum(quantiser.Quantise({grad, hess}));
auto quantised =
quantiser.QuantiseStochasticRounding({grad, hess}, hash_combine(seed, p[0], p[2]));
IntegerGPair blocksum = BlockReduce(temp_storage).Sum(quantised);

if (threadIdx.x == 0) {
atomicAdd(
Expand All @@ -151,7 +187,8 @@ __global__ static void __launch_bounds__(TPB, 4)
NodeBatch batch,
Histogram<IntegerGPair> histogram,
legate::Buffer<IntegerGPair, 2> node_sums,
GradientQuantiser quantiser)
GradientQuantiser quantiser,
int64_t seed)
{
constexpr int32_t WarpSize = 32;
const int32_t warp_id = threadIdx.x / WarpSize;
@@ -190,12 +227,16 @@ __global__ static void __launch_bounds__(TPB, 4)
split_proposals.FindBin(X[{sample_offset + localSampleId, feature, 0}], feature);
for (int32_t output = 0; output < n_outputs; output++) {
// get same G/H from every thread in warp
auto gpair_quantised = quantiser.Quantise({g[{sample_offset + localSampleId, 0, output}],
h[{sample_offset + localSampleId, 0, output}]});
Histogram<IntegerGPair>::atomic_add_type* addPosition =
reinterpret_cast<Histogram<IntegerGPair>::atomic_add_type*>(
&histogram[{sampleNode, output, bin_idx}]);
legate::Point<3> p = {sample_offset + localSampleId, feature, output};
auto gpair_quantised =
quantiser.QuantiseStochasticRounding({g[p], h[p]}, hash_combine(seed, p[0], p[2]));
auto* addPosition = reinterpret_cast<typename IntegerGPair::value_type*>(
&histogram[{sampleNode, output, bin_idx}]);

if (bin_idx != SparseSplitProposals<TYPE>::NOT_FOUND) {
Histogram<IntegerGPair>::atomic_add_type* addPosition =
reinterpret_cast<Histogram<IntegerGPair>::atomic_add_type*>(
&histogram[{sampleNode, output, bin_idx}]);
atomicAdd(addPosition, gpair_quantised.grad);
atomicAdd(addPosition + 1, gpair_quantised.hess);
}
@@ -641,7 +682,8 @@ struct TreeBuilder {
legate::Rect<3> X_shape,
legate::AccessorRO<double, 3> g,
legate::AccessorRO<double, 3> h,
NodeBatch batch)
NodeBatch batch,
int64_t seed)
{
// warp kernel without additional caching / prefetching
const int threads_per_block = 256;
@@ -664,10 +706,11 @@ struct TreeBuilder {
batch,
histogram,
tree.node_sums,
quantiser);
quantiser,
seed);

CHECK_CUDA_STREAM(stream);
static_assert(sizeof(GPair) == 2 * sizeof(double), "GPair must be 2 doubles");

SumAllReduce(context,
reinterpret_cast<Histogram<IntegerGPair>::value_type::value_type*>(
histogram.Ptr(batch.node_idx_begin)),
@@ -713,7 +756,8 @@ struct TreeBuilder {
legate::AccessorRO<double, 3> g,
legate::AccessorRO<double, 3> h,
legate::Rect<3> g_shape,
double alpha)
double alpha,
int64_t seed)
{
const int kBlockThreads = 256;
const size_t blocks = (num_rows + kBlockThreads - 1) / kBlockThreads;
@@ -875,15 +919,15 @@ struct build_tree_fn {
split_proposals,
quantiser);

builder.InitialiseRoot(context, tree, g_accessor, h_accessor, g_shape, alpha);
builder.InitialiseRoot(context, tree, g_accessor, h_accessor, g_shape, alpha, seed);

for (int depth = 0; depth < max_depth; ++depth) {
auto batches = builder.PrepareBatches(depth, thrust_exec_policy);
for (auto batch : batches) {
auto histogram = builder.GetHistogram(batch);

builder.ComputeHistogram(
histogram, context, tree, X_accessor, X_shape, g_accessor, h_accessor, batch);
histogram, context, tree, X_accessor, X_shape, g_accessor, h_accessor, batch, seed);

builder.PerformBestSplit(tree, histogram, alpha, batch);
}