diff --git a/src/cpp_utils/cpp_utils.cuh b/src/cpp_utils/cpp_utils.cuh
index 34e03437..f8dc3bf3 100644
--- a/src/cpp_utils/cpp_utils.cuh
+++ b/src/cpp_utils/cpp_utils.cuh
@@ -96,6 +96,17 @@ void SumAllReduce(legate::TaskContext context, T* x, int count, cudaStream_t str
 #define DEFAULT_POLICY thrust::cuda::par
 #endif
 
+__device__ inline uint32_t ballot(bool inFlag, uint32_t mask = 0xffffffffu)
+{
+  return __ballot_sync(mask, inFlag);
+}
+
+template <typename T>
+__device__ inline T shfl(T val, int srcLane, int width = 32, uint32_t mask = 0xffffffffu)
+{
+  return __shfl_sync(mask, val, srcLane, width);
+}
+
 class ThrustAllocator : public legate::ScopedAllocator {
  public:
   using value_type = char;
diff --git a/src/models/tree/build_tree.cu b/src/models/tree/build_tree.cu
index aa5f2ad0..a28d346c 100644
--- a/src/models/tree/build_tree.cu
+++ b/src/models/tree/build_tree.cu
@@ -135,59 +135,65 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
   }
 }
 
-template <typename TYPE, int ELEMENTS_PER_THREAD, int FEATURES_PER_BLOCK>
-__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
-  fill_histogram(legate::AccessorRO<TYPE, 3> X,
-                 size_t n_features,
-                 int64_t sample_offset,
-                 legate::AccessorRO<double, 3> g,
-                 legate::AccessorRO<double, 3> h,
-                 size_t n_outputs,
-                 SparseSplitProposals<TYPE> split_proposals,
-                 NodeBatch batch,
-                 Histogram<IntegerGPair> histogram,
-                 legate::Buffer<IntegerGPair, 2> node_sums,
-                 GradientQuantiser quantiser)
+template <typename TYPE, int TPB, int FEATURES_PER_WARP>
+__global__ static void __launch_bounds__(TPB, MIN_CTAS_PER_SM)
+  fill_histogram_warp(legate::AccessorRO<TYPE, 3> X,
+                      size_t n_features,
+                      int64_t sample_offset,
+                      legate::AccessorRO<double, 3> g,
+                      legate::AccessorRO<double, 3> h,
+                      size_t n_outputs,
+                      SparseSplitProposals<TYPE> split_proposals,
+                      NodeBatch batch,
+                      Histogram<IntegerGPair> histogram,
+                      legate::Buffer<IntegerGPair, 2> node_sums,
+                      GradientQuantiser quantiser)
 {
-  // block dimensions are (THREADS_PER_BLOCK, 1, 1)
-  // each thread processes ELEMENTS_PER_THREAD samples and FEATURES_PER_BLOCK features
-  // the features to process are defined via blockIdx.y
-
-  // further improvements:
-  // * quantize values to work with int instead of double
-
-#pragma unroll
-  for (int32_t elementIdx = 0; elementIdx < ELEMENTS_PER_THREAD; ++elementIdx) {
-    // within each iteration a (THREADS_PER_BLOCK, FEATURES_PER_BLOCK)-block of
-    // data from X is processed.
-
-    // check if thread has actual work to do
-    int64_t idx = (blockIdx.x + elementIdx * gridDim.x) * THREADS_PER_BLOCK + threadIdx.x;
-    bool validThread = idx < batch.InstancesInBatch();
-    if (!validThread) continue;
-    auto [sampleNode, localSampleId] = batch.instances_begin[idx];
-    int64_t globalSampleId = localSampleId + sample_offset;
-
-    bool computeHistogram = ComputeHistogramBin(
-      sampleNode, node_sums, histogram.ContainsNode(BinaryTree::Parent(sampleNode)));
-
-    for (int32_t output = 0; output < n_outputs; output++) {
-      auto gpair_quantised =
-        quantiser.Quantise({g[{globalSampleId, 0, output}], h[{globalSampleId, 0, output}]});
-      for (int32_t featureIdx = 0; featureIdx < FEATURES_PER_BLOCK; featureIdx++) {
-        int32_t feature = featureIdx + blockIdx.y * FEATURES_PER_BLOCK;
-        if (computeHistogram && feature < n_features) {
-          auto x_value = X[{globalSampleId, feature, 0}];
-          auto bin_idx = split_proposals.FindBin(x_value, feature);
-
-          // bin_idx is the first sample that is larger than x_value
-          if (bin_idx != SparseSplitProposals<TYPE>::NOT_FOUND) {
-            Histogram<IntegerGPair>::atomic_add_type* addPosition =
-              reinterpret_cast<Histogram<IntegerGPair>::atomic_add_type*>(
-                &histogram[{sampleNode, output, bin_idx}]);
-            atomicAdd(addPosition, gpair_quantised.grad);
-            atomicAdd(addPosition + 1, gpair_quantised.hess);
-          }
+  constexpr int32_t WarpSize = 32;
+  const int32_t warp_id      = threadIdx.x / WarpSize;
+  const int32_t lane_id      = threadIdx.x % WarpSize;
+
+  const int32_t localIdx = blockIdx.x * TPB + warp_id * WarpSize + lane_id;
+
+  // prefetch sampleNode information for all 32 ids
+  auto [sampleNode_lane, localSampleId_lane] = (localIdx < batch.InstancesInBatch())
+                                                 ? batch.instances_begin[localIdx]
+                                                 : cuda::std::make_tuple(-1, -1);
+  const bool computeHistogram =
+    localIdx < batch.InstancesInBatch() &&
+    ComputeHistogramBin(
+      sampleNode_lane, node_sums, histogram.ContainsNode(BinaryTree::Parent(sampleNode_lane)));
+
+  // mask contains one bit per lane of the warp whose sample needs to be binned
+  auto lane_mask = ballot(computeHistogram);
+
+  // reverse to use __clz instead of __ffs
+  lane_mask = __brev(lane_mask);
+
+  while (lane_mask) {
+    // look for the next lane_offset / sample to process within the warp-batch
+    const uint32_t lane_offset  = __clz(lane_mask);
+    const int32_t sampleNode    = shfl(sampleNode_lane, lane_offset);
+    const int32_t localSampleId = shfl(localSampleId_lane, lane_offset);
+
+    // remove the lane_offset bit from lane_mask for the next iteration
+    lane_mask &= (0x7fffffff >> lane_offset);
+
+    auto feature_begin = blockIdx.y * FEATURES_PER_WARP;
+    auto feature_end   = min(n_features, (size_t)feature_begin + FEATURES_PER_WARP);
+    for (int32_t feature = feature_begin + lane_id; feature < feature_end; feature += WarpSize) {
+      const int32_t bin_idx =
+        split_proposals.FindBin(X[{sample_offset + localSampleId, feature, 0}], feature);
+      for (int32_t output = 0; output < n_outputs; output++) {
+        // every thread in the warp reads the same g/h for this sample
+        auto gpair_quantised = quantiser.Quantise({g[{sample_offset + localSampleId, 0, output}],
+                                                   h[{sample_offset + localSampleId, 0, output}]});
+        Histogram<IntegerGPair>::atomic_add_type* addPosition =
+          reinterpret_cast<Histogram<IntegerGPair>::atomic_add_type*>(
+            &histogram[{sampleNode, output, bin_idx}]);
+        if (bin_idx != SparseSplitProposals<TYPE>::NOT_FOUND) {
+          atomicAdd(addPosition, gpair_quantised.grad);
+          atomicAdd(addPosition + 1, gpair_quantised.hess);
         }
       }
     }
@@ -621,17 +627,18 @@ struct TreeBuilder {
                         legate::AccessorRO<double, 3> h,
                         NodeBatch batch)
   {
-    // TODO adjust kernel parameters dynamically
-    constexpr size_t elements_per_thread = 8;
-    constexpr size_t features_per_block = 16;
-
-    const size_t blocks_x =
-      (batch.InstancesInBatch() + THREADS_PER_BLOCK * elements_per_thread - 1) /
-      (THREADS_PER_BLOCK * elements_per_thread);
-    const size_t blocks_y = (num_features + features_per_block - 1) / features_per_block;
-    dim3 grid_shape = dim3(blocks_x, blocks_y, 1);
-    fill_histogram<TYPE, elements_per_thread, features_per_block>
-      <<<grid_shape, THREADS_PER_BLOCK, 0, stream>>>(X,
+    // warp kernel without additional caching / prefetching
+    const int threads_per_block = 256;
+    const size_t blocks_x = (batch.InstancesInBatch() + threads_per_block - 1) / threads_per_block;
+
+    // splitting the features ensures better work distribution for large numbers of features;
+    // larger values allow better caching of g & h,
+    // while smaller values improve access to the split_proposals
+    const int features_per_warp = 64;
+    const size_t blocks_y = (num_features + features_per_warp - 1) / features_per_warp;
+    dim3 grid_shape = dim3(blocks_x, blocks_y, 1);
+    fill_histogram_warp<TYPE, threads_per_block, features_per_warp>
+      <<<grid_shape, threads_per_block, 0, stream>>>(X,
                                                      num_features,
                                                      X_shape.lo[0],
                                                      g,
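
For reference, the heart of the new kernel is a warp-cooperative scheduling idiom: every lane votes via ballot whether its sample needs binning, the mask is bit-reversed so __clz walks lanes in ascending order, and shfl broadcasts the owning lane's (sampleNode, localSampleId) so that all 32 lanes can stride the feature range together. Below is a minimal self-contained sketch of just that loop; it is illustrative only (the kernel name warp_ballot_demo and the plain int payload are invented here, not part of the patch).

// Hypothetical demo, not part of the patch: each lane flags whether it has a
// valid element; the warp then walks the ballot mask bit by bit, broadcasting
// the owning lane's payload to all 32 lanes, mirroring how fill_histogram_warp
// broadcasts (sampleNode, localSampleId) before the cooperative feature loop.
#include <cstdint>
#include <cstdio>

__global__ void warp_ballot_demo(const int* values, int n)
{
  const int lane    = threadIdx.x % 32;
  const int idx     = blockIdx.x * blockDim.x + threadIdx.x;
  const bool valid  = idx < n;
  const int payload = valid ? values[idx] : -1;

  // one bit per lane with work; reverse so __clz finds the lowest lane first
  uint32_t mask = __brev(__ballot_sync(0xffffffffu, valid));

  while (mask) {
    const int src = __clz(mask);                             // next lane that has work
    const int v   = __shfl_sync(0xffffffffu, payload, src);  // broadcast its payload
    mask &= 0x7fffffffu >> src;                              // clear that lane's bit
    // all 32 lanes now cooperate on element v (here lane 0 just reports it)
    if (lane == 0) { printf("processing %d (from lane %d)\n", v, src); }
  }
}

int main()
{
  int host[40];
  for (int i = 0; i < 40; ++i) { host[i] = i; }
  int* dev;
  cudaMalloc(&dev, sizeof(host));
  cudaMemcpy(dev, host, sizeof(host), cudaMemcpyHostToDevice);
  warp_ballot_demo<<<1, 64>>>(dev, 40);  // second warp is only partially valid
  cudaDeviceSynchronize();
  cudaFree(dev);
  return 0;
}

The bit-reversal is what makes __clz usable here: __clz returns a 0-based leading-zero count that doubles as the lane index, whereas __ffs on the unreversed mask returns a 1-based position and would need an extra adjustment each iteration.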
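A note on the revised launch geometry: the grid stays two-dimensional, with blocks_x tiling the batch at one sample per thread (so each warp owns 32 consecutive samples) and blocks_y tiling the features in chunks of features_per_warp = 64, which a warp strides 32 lanes at a time. As a worked example with hypothetical sizes, 1,000,000 instances and 200 features give blocks_x = ceil(1,000,000 / 256) = 3907 and blocks_y = ceil(200 / 64) = 4, so a warp that bins a sample makes at most two 32-lane passes over its 64-feature slice.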