Performance improvements of histogram kernel #152

Merged · 2 commits · Sep 6, 2024
11 changes: 11 additions & 0 deletions src/cpp_utils/cpp_utils.cuh
@@ -96,6 +96,17 @@ void SumAllReduce(legate::TaskContext context, T* x, int count, cudaStream_t str
#define DEFAULT_POLICY thrust::cuda::par
#endif

__device__ inline uint32_t ballot(bool inFlag, uint32_t mask = 0xffffffffu)
{
return __ballot_sync(mask, inFlag);
}

template <typename T>
__device__ inline T shfl(T val, int srcLane, int width = 32, uint32_t mask = 0xffffffffu)
{
return __shfl_sync(mask, val, srcLane, width);
}

class ThrustAllocator : public legate::ScopedAllocator {
public:
using value_type = char;
133 changes: 70 additions & 63 deletions src/models/tree/build_tree.cu
@@ -135,59 +135,65 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
}
}

template <typename TYPE, int ELEMENTS_PER_THREAD, int FEATURES_PER_BLOCK>
__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
fill_histogram(legate::AccessorRO<TYPE, 3> X,
size_t n_features,
int64_t sample_offset,
legate::AccessorRO<double, 3> g,
legate::AccessorRO<double, 3> h,
size_t n_outputs,
SparseSplitProposals<TYPE> split_proposals,
NodeBatch batch,
Histogram<IntegerGPair> histogram,
legate::Buffer<IntegerGPair, 2> node_sums,
GradientQuantiser quantiser)
template <typename TYPE, int TPB, int FEATURES_PER_WARP>
__global__ static void __launch_bounds__(TPB, MIN_CTAS_PER_SM)
fill_histogram_warp(legate::AccessorRO<TYPE, 3> X,
size_t n_features,
int64_t sample_offset,
legate::AccessorRO<double, 3> g,
legate::AccessorRO<double, 3> h,
size_t n_outputs,
SparseSplitProposals<TYPE> split_proposals,
NodeBatch batch,
Histogram<IntegerGPair> histogram,
legate::Buffer<IntegerGPair, 2> node_sums,
GradientQuantiser quantiser)
{
// block dimensions are (THREADS_PER_BLOCK, 1, 1)
// each thread processes ELEMENTS_PER_THREAD samples and FEATURES_PER_BLOCK features
// the features to process are defined via blockIdx.y

// further improvements:
// * quantize values to work with int instead of double

#pragma unroll
for (int32_t elementIdx = 0; elementIdx < ELEMENTS_PER_THREAD; ++elementIdx) {
// within each iteration a (THREADS_PER_BLOCK, FEATURES_PER_BLOCK)-block of
// data from X is processed.

// check if thread has actual work to do
int64_t idx = (blockIdx.x + elementIdx * gridDim.x) * THREADS_PER_BLOCK + threadIdx.x;
bool validThread = idx < batch.InstancesInBatch();
if (!validThread) continue;
auto [sampleNode, localSampleId] = batch.instances_begin[idx];
int64_t globalSampleId = localSampleId + sample_offset;

bool computeHistogram = ComputeHistogramBin(
sampleNode, node_sums, histogram.ContainsNode(BinaryTree::Parent(sampleNode)));

for (int32_t output = 0; output < n_outputs; output++) {
auto gpair_quantised =
quantiser.Quantise({g[{globalSampleId, 0, output}], h[{globalSampleId, 0, output}]});
for (int32_t featureIdx = 0; featureIdx < FEATURES_PER_BLOCK; featureIdx++) {
int32_t feature = featureIdx + blockIdx.y * FEATURES_PER_BLOCK;
if (computeHistogram && feature < n_features) {
auto x_value = X[{globalSampleId, feature, 0}];
auto bin_idx = split_proposals.FindBin(x_value, feature);

// bin_idx is the first sample that is larger than x_value
if (bin_idx != SparseSplitProposals<TYPE>::NOT_FOUND) {
Histogram<IntegerGPair>::atomic_add_type* addPosition =
reinterpret_cast<Histogram<IntegerGPair>::atomic_add_type*>(
&histogram[{sampleNode, output, bin_idx}]);
atomicAdd(addPosition, gpair_quantised.grad);
atomicAdd(addPosition + 1, gpair_quantised.hess);
}
constexpr int32_t WarpSize = 32;
const int32_t warp_id = threadIdx.x / WarpSize;
const int32_t lane_id = threadIdx.x % WarpSize;

const int32_t localIdx = blockIdx.x * TPB + warp_id * WarpSize + lane_id;

// prefetch sampleNode information for all 32 ids
auto [sampleNode_lane, localSampleId_lane] = (localIdx < batch.InstancesInBatch())
? batch.instances_begin[localIdx]
: cuda::std::make_tuple(-1, -1);
const bool computeHistogram =
localIdx < batch.InstancesInBatch() &&
ComputeHistogramBin(
sampleNode_lane, node_sums, histogram.ContainsNode(BinaryTree::Parent(sampleNode_lane)));

// mask contains all sample bits of the next 32 ids that need to be binned
auto lane_mask = ballot(computeHistogram);

// reverse to use __clz instead of __ffs
Contributor: What is the benefit of this? Just curious.

Contributor Author: The idea is to warp-synchronously jump to samples that need to be computed. Otherwise we would need to check every id explicitly, which would require 32 __shfl operations.

Contributor: I mean why __clz vs __ffs?

Contributor Author: I believe __clz is preferable for performance reasons -- I did not confirm in this scope though.
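To make the comparison concrete, here is a small sketch (an illustration of the discussion, not code from the PR) of two equivalent ways to walk the ballot mask. Both jump straight from one active lane to the next instead of probing all 32 lanes with shuffles; the PR uses the reversed-mask/__clz form, where clearing the processed bit reduces to a single shift.

// Hypothetical comparison -- not part of the PR.
// Variant A: __ffs returns the lowest set bit (1-based), so subtract 1.
__device__ void walk_mask_ffs(uint32_t active_mask)
{
  while (active_mask) {
    const int lane = __ffs(active_mask) - 1;  // lowest active lane
    // ... shfl the lane's sampleNode / sampleId and process them ...
    active_mask &= active_mask - 1;           // clear the lowest set bit
  }
}

// Variant B (as in the kernel above): reverse the mask once, then __clz
// yields the lane directly and a shift clears the processed bit.
__device__ void walk_mask_clz(uint32_t active_mask)
{
  uint32_t mask = __brev(active_mask);
  while (mask) {
    const int lane = __clz(mask);             // lowest active lane
    // ... shfl the lane's sampleNode / sampleId and process them ...
    mask &= (0x7fffffffu >> lane);            // drop that lane's bit
  }
}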

lane_mask = __brev(lane_mask);

while (lane_mask) {
Contributor: One way of getting the bit fiddling and intrinsics abstracted out:

template <typename T>
class WarpQueue {
 private:
  T x;
  int32_t mask;

 public:
  WarpQueue(T x, bool active) ...
  bool Empty() {}
  T Pop() {}
};

WarpQueue warp_items(local_items, active);
while (!warp_items.Empty()) {
  auto [node, sampleId] = warp_items.Pop();
}

Feel free to ignore.

Contributor Author: This might work for queuing the offsets, but introducing the different storage locations for the actual items would increase the overhead in registers.

Contributor: The suggestion is only for the offsets.
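Since the thread settles on queuing only the offsets, here is a minimal sketch of how such a WarpQueue might look (an assumed design building on the ballot/shfl helpers added in cpp_utils.cuh, not code from the PR). It hides the ballot/__brev/__clz bookkeeping and hands back the next active lane; callers still shuffle the actual items themselves.

// Hypothetical sketch -- not part of the PR. Queues only the lane offsets,
// as suggested above.
class WarpQueue {
 private:
  uint32_t mask_;  // reversed ballot of lanes that still have work

 public:
  __device__ explicit WarpQueue(bool active) : mask_(__brev(ballot(active))) {}

  __device__ bool Empty() const { return mask_ == 0; }

  // Returns the next active lane and removes it from the queue.
  // Every lane of the warp holds the same mask, so all agree on the result.
  __device__ uint32_t Pop()
  {
    const uint32_t lane = __clz(mask_);
    mask_ &= (0x7fffffffu >> lane);
    return lane;
  }
};

// The loop below would then read roughly:
//   WarpQueue warp_items(computeHistogram);
//   while (!warp_items.Empty()) {
//     const uint32_t lane_offset  = warp_items.Pop();
//     const int32_t sampleNode    = shfl(sampleNode_lane, lane_offset);
//     const int32_t localSampleId = shfl(localSampleId_lane, lane_offset);
//     ...
//   }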

// look for next lane_offset / sample to process within warp-batch
const uint32_t lane_offset = __clz(lane_mask);
const int32_t sampleNode = shfl(sampleNode_lane, lane_offset);
const int32_t localSampleId = shfl(localSampleId_lane, lane_offset);

// remove lane_offset bit from lane_mask for next iteration
lane_mask &= (0x7fffffff >> lane_offset);

auto feature_begin = blockIdx.y * FEATURES_PER_WARP;
auto feature_end = min(n_features, (size_t)feature_begin + FEATURES_PER_WARP);
for (int32_t feature = feature_begin + lane_id; feature < feature_end; feature += WarpSize) {
const int32_t bin_idx =
split_proposals.FindBin(X[{sample_offset + localSampleId, feature, 0}], feature);
for (int32_t output = 0; output < n_outputs; output++) {
// get same G/H from every thread in warp
auto gpair_quantised = quantiser.Quantise({g[{sample_offset + localSampleId, 0, output}],
h[{sample_offset + localSampleId, 0, output}]});
Histogram<IntegerGPair>::atomic_add_type* addPosition =
reinterpret_cast<Histogram<IntegerGPair>::atomic_add_type*>(
&histogram[{sampleNode, output, bin_idx}]);
if (bin_idx != SparseSplitProposals<TYPE>::NOT_FOUND) {
atomicAdd(addPosition, gpair_quantised.grad);
atomicAdd(addPosition + 1, gpair_quantised.hess);
}
}
}
@@ -621,17 +627,18 @@ struct TreeBuilder {
legate::AccessorRO<double, 3> h,
NodeBatch batch)
{
// TODO adjust kernel parameters dynamically
constexpr size_t elements_per_thread = 8;
constexpr size_t features_per_block = 16;

const size_t blocks_x =
(batch.InstancesInBatch() + THREADS_PER_BLOCK * elements_per_thread - 1) /
(THREADS_PER_BLOCK * elements_per_thread);
const size_t blocks_y = (num_features + features_per_block - 1) / features_per_block;
dim3 grid_shape = dim3(blocks_x, blocks_y, 1);
fill_histogram<TYPE, elements_per_thread, features_per_block>
<<<grid_shape, THREADS_PER_BLOCK, 0, stream>>>(X,
// warp kernel without additional caching / prefetching
const int threads_per_block = 256;
const size_t blocks_x = (batch.InstancesInBatch() + threads_per_block - 1) / threads_per_block;

// splitting the features to ensure better work distribution for large numbers of features;
// larger values allow better caching of g & h,
// while smaller values improve access of the split_proposals
const int features_per_warp = 64;
const size_t blocks_y = (num_features + features_per_warp - 1) / features_per_warp;
dim3 grid_shape = dim3(blocks_x, blocks_y, 1);
fill_histogram_warp<TYPE, threads_per_block, features_per_warp>
<<<grid_shape, threads_per_block, 0, stream>>>(X,
num_features,
X_shape.lo[0],
g,
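For intuition on the launch configuration above, a worked example with assumed sizes (the instance and feature counts are made up, not from the PR):

// Hypothetical sizes, for intuition only -- not from the PR.
const int threads_per_block = 256;
const int features_per_warp = 64;
const size_t instances_in_batch = 1'000'000;  // assumed batch.InstancesInBatch()
const size_t num_features = 256;              // assumed feature count

const size_t blocks_x = (instances_in_batch + threads_per_block - 1) / threads_per_block;  // 3907
const size_t blocks_y = (num_features + features_per_warp - 1) / features_per_warp;        // 4
dim3 grid_shape = dim3(blocks_x, blocks_y, 1);  // each y-slice of the grid covers 64 features,
                                                // and each warp bins 32 samples over that slice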