diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 38c58b495c3e..324674b33dbe 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -427,36 +427,6 @@ class TemporaryArray { size_t size_; }; -/** - * \brief A double buffer, useful for algorithms like sort. - */ -template -class DoubleBuffer { - public: - cub::DoubleBuffer buff; - xgboost::common::Span a, b; - DoubleBuffer() = default; - template - DoubleBuffer(VectorT *v1, VectorT *v2) { - a = xgboost::common::Span(v1->data().get(), v1->size()); - b = xgboost::common::Span(v2->data().get(), v2->size()); - buff = cub::DoubleBuffer(a.data(), b.data()); - } - - size_t Size() const { - CHECK_EQ(a.size(), b.size()); - return a.size(); - } - cub::DoubleBuffer &CubBuffer() { return buff; } - - T *Current() { return buff.Current(); } - xgboost::common::Span CurrentSpan() { - return xgboost::common::Span{buff.Current(), Size()}; - } - - T *Other() { return buff.Alternate(); } -}; - /** * \brief Copies device span to std::vector. * diff --git a/src/tree/gpu_hist/row_partitioner.cu b/src/tree/gpu_hist/row_partitioner.cu index 25a72940c222..2b7fbe4afda1 100644 --- a/src/tree/gpu_hist/row_partitioner.cu +++ b/src/tree/gpu_hist/row_partitioner.cu @@ -93,26 +93,23 @@ void RowPartitioner::SortPosition(common::Span position, position.size(), stream); } +void Reset(int device_idx, common::Span ridx, + common::Span position) { + CHECK_EQ(ridx.size(), position.size()); + dh::LaunchN(device_idx, ridx.size(), [=] __device__(size_t idx) { + ridx[idx] = idx; + position[idx] = 0; + }); +} + RowPartitioner::RowPartitioner(int device_idx, size_t num_rows) - : device_idx_(device_idx) { + : device_idx_(device_idx), ridx_a_(num_rows), position_a_(num_rows) { dh::safe_cuda(cudaSetDevice(device_idx_)); - ridx_a_.resize(num_rows); - ridx_b_.resize(num_rows); - position_a_.resize(num_rows); - position_b_.resize(num_rows); - ridx_ = dh::DoubleBuffer{&ridx_a_, &ridx_b_}; - position_ = dh::DoubleBuffer{&position_a_, &position_b_}; - ridx_segments_.emplace_back(Segment(0, num_rows)); - - thrust::sequence( - thrust::device_pointer_cast(ridx_.CurrentSpan().data()), - thrust::device_pointer_cast(ridx_.CurrentSpan().data() + ridx_.Size())); - thrust::fill( - thrust::device_pointer_cast(position_.Current()), - thrust::device_pointer_cast(position_.Current() + position_.Size()), 0); + Reset(device_idx, dh::ToSpan(ridx_a_), dh::ToSpan(position_a_)); left_counts_.resize(256); thrust::fill(left_counts_.begin(), left_counts_.end(), 0); streams_.resize(2); + ridx_segments_.emplace_back(Segment(0, num_rows)); for (auto& stream : streams_) { dh::safe_cuda(cudaStreamCreate(&stream)); } @@ -132,15 +129,15 @@ common::Span RowPartitioner::GetRows( if (segment.Size() == 0) { return common::Span(); } - return ridx_.CurrentSpan().subspan(segment.begin, segment.Size()); + return dh::ToSpan(ridx_a_).subspan(segment.begin, segment.Size()); } common::Span RowPartitioner::GetRows() { - return ridx_.CurrentSpan(); + return dh::ToSpan(ridx_a_); } common::Span RowPartitioner::GetPosition() { - return position_.CurrentSpan(); + return dh::ToSpan(position_a_); } std::vector RowPartitioner::GetRowsHost( bst_node_t nidx) { @@ -162,23 +159,25 @@ void RowPartitioner::SortPositionAndCopy(const Segment& segment, bst_node_t right_nidx, int64_t* d_left_count, cudaStream_t stream) { + dh::TemporaryArray position_temp(position_a_.size()); + dh::TemporaryArray ridx_temp(ridx_a_.size()); SortPosition( // position_in - common::Span(position_.Current() + segment.begin, + common::Span(position_a_.data().get() + segment.begin, segment.Size()), // position_out - common::Span(position_.Other() + segment.begin, + common::Span(position_temp.data().get() + segment.begin, segment.Size()), // row index in - common::Span(ridx_.Current() + segment.begin, segment.Size()), + common::Span(ridx_a_.data().get() + segment.begin, segment.Size()), // row index out - common::Span(ridx_.Other() + segment.begin, segment.Size()), + common::Span(ridx_temp.data().get() + segment.begin, segment.Size()), left_nidx, right_nidx, d_left_count, stream); // Copy back key/value - const auto d_position_current = position_.Current() + segment.begin; - const auto d_position_other = position_.Other() + segment.begin; - const auto d_ridx_current = ridx_.Current() + segment.begin; - const auto d_ridx_other = ridx_.Other() + segment.begin; + const auto d_position_current = position_a_.data().get() + segment.begin; + const auto d_position_other = position_temp.data().get() + segment.begin; + const auto d_ridx_current = ridx_a_.data().get() + segment.begin; + const auto d_ridx_other = ridx_temp.data().get() + segment.begin; dh::LaunchN(device_idx_, segment.Size(), stream, [=] __device__(size_t idx) { d_position_current[idx] = d_position_other[idx]; d_ridx_current[idx] = d_ridx_other[idx]; diff --git a/src/tree/gpu_hist/row_partitioner.cuh b/src/tree/gpu_hist/row_partitioner.cuh index fd42234fd345..e0e0998ea98f 100644 --- a/src/tree/gpu_hist/row_partitioner.cuh +++ b/src/tree/gpu_hist/row_partitioner.cuh @@ -46,18 +46,8 @@ class RowPartitioner { */ /*! \brief Range of row index for each node, pointers into ridx below. */ std::vector ridx_segments_; - dh::caching_device_vector ridx_a_; - dh::caching_device_vector ridx_b_; - dh::caching_device_vector position_a_; - dh::caching_device_vector position_b_; - /*! \brief mapping for node id -> rows. - * This looks like: - * node id | 1 | 2 | - * rows idx | 3, 5, 1 | 13, 31 | - */ - dh::DoubleBuffer ridx_; - /*! \brief mapping for row -> node id. */ - dh::DoubleBuffer position_; + dh::TemporaryArray ridx_a_; + dh::TemporaryArray position_a_; dh::caching_device_vector left_counts_; // Useful to keep a bunch of zeroed memory for sort position std::vector streams_; @@ -110,8 +100,8 @@ class RowPartitioner { void UpdatePosition(bst_node_t nidx, bst_node_t left_nidx, bst_node_t right_nidx, UpdatePositionOpT op) { Segment segment = ridx_segments_.at(nidx); // rows belongs to node nidx - auto d_ridx = ridx_.CurrentSpan(); - auto d_position = position_.CurrentSpan(); + auto d_ridx = dh::ToSpan(ridx_a_); + auto d_position = dh::ToSpan(position_a_); if (left_counts_.size() <= nidx) { left_counts_.resize((nidx * 2) + 1); thrust::fill(left_counts_.begin(), left_counts_.end(), 0); @@ -159,9 +149,9 @@ class RowPartitioner { */ template void FinalisePosition(FinalisePositionOpT op) { - auto d_position = position_.Current(); - const auto d_ridx = ridx_.Current(); - dh::LaunchN(device_idx_, position_.Size(), [=] __device__(size_t idx) { + auto d_position = position_a_.data().get(); + const auto d_ridx = ridx_a_.data().get(); + dh::LaunchN(device_idx_, position_a_.size(), [=] __device__(size_t idx) { auto position = d_position[idx]; RowIndexT ridx = d_ridx[idx]; bst_node_t new_position = op(ridx, position); diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 2de692c0a117..25d2645e1032 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -511,7 +511,6 @@ struct GPUHistMakerDevice { reinterpret_cast(d_node_hist), reinterpret_cast(d_node_hist), page->Cuts().TotalBins() * (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT))); - reducer->Synchronize(); monitor.Stop("AllReduce"); }