Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify the code path between local and distributed training. #9433

Merged
merged 2 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/common/hist_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ class HistCollection {
data_[0].resize(new_size);
}
}
[[nodiscard]] bool IsContiguous() const { return contiguous_allocation_; }

private:
/*! \brief number of all bins over all features */
Expand Down
181 changes: 33 additions & 148 deletions src/tree/hist/histogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,11 @@
#include "expand_entry.h"
#include "xgboost/tree_model.h" // for RegTree

namespace xgboost {
namespace tree {
namespace xgboost::tree {
template <typename ExpandEntry>
class HistogramBuilder {
/*! \brief culmulative histogram of gradients. */
common::HistCollection hist_;
/*! \brief culmulative local parent histogram of gradients. */
common::HistCollection hist_local_worker_;
common::ParallelGHistBuilder buffer_;
BatchParam param_;
int32_t n_threads_{-1};
Expand All @@ -46,12 +43,9 @@ class HistogramBuilder {
n_batches_ = n_batches;
param_ = p;
hist_.Init(total_bins);
hist_local_worker_.Init(total_bins);
buffer_.Init(total_bins);
is_distributed_ = is_distributed;
is_col_split_ = is_col_split;
// Workaround s390x gcc 7.5.0
auto DMLC_ATTRIBUTE_UNUSED __force_instantiation = &GradientPairPrecise::Reduce;
}

template <bool any_missing>
Expand Down Expand Up @@ -91,17 +85,19 @@ class HistogramBuilder {
});
}

void AddHistRows(int *starting_index, int *sync_count,
void AddHistRows(int *starting_index,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
RegTree const *p_tree) {
if (is_distributed_ && !is_col_split_) {
this->AddHistRowsDistributed(starting_index, sync_count, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick, p_tree);
} else {
this->AddHistRowsLocal(starting_index, sync_count, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick);
std::vector<ExpandEntry> const &nodes_for_subtraction_trick) {
for (auto const &entry : nodes_for_explicit_hist_build) {
int nid = entry.nid;
this->hist_.AddHistRow(nid);
(*starting_index) = std::min(nid, (*starting_index));
}

for (auto const &node : nodes_for_subtraction_trick) {
this->hist_.AddHistRow(node.nid);
}
this->hist_.AllocateAllData();
}

/** Main entry point of this class, build histogram for tree nodes. */
Expand All @@ -111,10 +107,9 @@ class HistogramBuilder {
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
common::Span<GradientPair const> gpair, bool force_read_by_column = false) {
int starting_index = std::numeric_limits<int>::max();
int sync_count = 0;
if (page_id == 0) {
this->AddHistRows(&starting_index, &sync_count, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick, p_tree);
this->AddHistRows(&starting_index, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick);
}
if (gidx.IsDense()) {
this->BuildLocalHistograms<false>(page_id, space, gidx, nodes_for_explicit_hist_build,
Expand All @@ -129,13 +124,8 @@ class HistogramBuilder {
return;
}

if (is_distributed_ && !is_col_split_) {
this->SyncHistogramDistributed(p_tree, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick,
starting_index, sync_count);
} else {
this->SyncHistogramLocal(p_tree, nodes_for_explicit_hist_build, nodes_for_subtraction_trick);
}
this->SyncHistogram(p_tree, nodes_for_explicit_hist_build,
nodes_for_subtraction_trick, starting_index);
}
/** same as the other build hist but handles only single batch data (in-core) */
void BuildHist(size_t page_id, GHistIndexMatrix const &gidx, RegTree *p_tree,
Expand All @@ -156,62 +146,33 @@ class HistogramBuilder {
nodes_for_subtraction_trick, gpair, force_read_by_column);
}

void SyncHistogramDistributed(RegTree const *p_tree,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
int starting_index, int sync_count) {
void SyncHistogram(RegTree const *p_tree,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
int starting_index) {
auto n_bins = buffer_.TotalBins();
common::BlockedSpace2d space(
nodes_for_explicit_hist_build.size(), [&](size_t) { return n_bins; }, 1024);
common::ParallelFor2d(space, n_threads_, [&](size_t node, common::Range1d r) {
CHECK(hist_.IsContiguous());
common::ParallelFor2d(space, this->n_threads_, [&](size_t node, common::Range1d r) {
const auto &entry = nodes_for_explicit_hist_build[node];
auto this_hist = this->hist_[entry.nid];
// Merging histograms from each thread into once
buffer_.ReduceHist(node, r.begin(), r.end());
// Store posible parent node
auto this_local = hist_local_worker_[entry.nid];
common::CopyHist(this_local, this_hist, r.begin(), r.end());

if (!p_tree->IsRoot(entry.nid)) {
const size_t parent_id = p_tree->Parent(entry.nid);
const int subtraction_node_id = nodes_for_subtraction_trick[node].nid;
auto parent_hist = this->hist_local_worker_[parent_id];
auto sibling_hist = this->hist_[subtraction_node_id];
common::SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
// Store posible parent node
auto sibling_local = hist_local_worker_[subtraction_node_id];
common::CopyHist(sibling_local, sibling_hist, r.begin(), r.end());
}
this->buffer_.ReduceHist(node, r.begin(), r.end());
});

collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double *>(this->hist_[starting_index].data()), n_bins * sync_count * 2);

ParallelSubtractionHist(space, nodes_for_explicit_hist_build, nodes_for_subtraction_trick,
p_tree);

common::BlockedSpace2d space2(
nodes_for_subtraction_trick.size(), [&](size_t) { return n_bins; }, 1024);
ParallelSubtractionHist(space2, nodes_for_subtraction_trick, nodes_for_explicit_hist_build,
p_tree);
}

void SyncHistogramLocal(RegTree const *p_tree,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick) {
const size_t nbins = this->buffer_.TotalBins();
common::BlockedSpace2d space(
nodes_for_explicit_hist_build.size(), [&](size_t) { return nbins; }, 1024);
if (is_distributed_ && !is_col_split_) {
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double *>(this->hist_[starting_index].data()),
n_bins * nodes_for_explicit_hist_build.size() * 2);
}

common::ParallelFor2d(space, this->n_threads_, [&](size_t node, common::Range1d r) {
const auto &entry = nodes_for_explicit_hist_build[node];
common::ParallelFor2d(space, this->n_threads_, [&](std::size_t nidx_in_set, common::Range1d r) {
const auto &entry = nodes_for_explicit_hist_build[nidx_in_set];
auto this_hist = this->hist_[entry.nid];
// Merging histograms from each thread into once
this->buffer_.ReduceHist(node, r.begin(), r.end());

if (!p_tree->IsRoot(entry.nid)) {
auto const parent_id = p_tree->Parent(entry.nid);
auto const subtraction_node_id = nodes_for_subtraction_trick[node].nid;
auto const subtraction_node_id = nodes_for_subtraction_trick[nidx_in_set].nid;
auto parent_hist = this->hist_[parent_id];
auto sibling_hist = this->hist_[subtraction_node_id];
common::SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
Expand All @@ -222,82 +183,7 @@ class HistogramBuilder {
public:
/* Getters for tests. */
common::HistCollection const &Histogram() { return hist_; }
auto& Buffer() { return buffer_; }

private:
void
ParallelSubtractionHist(const common::BlockedSpace2d &space,
const std::vector<ExpandEntry> &nodes,
const std::vector<ExpandEntry> &subtraction_nodes,
const RegTree *p_tree) {
common::ParallelFor2d(
space, this->n_threads_, [&](size_t node, common::Range1d r) {
const auto &entry = nodes[node];
if (!(p_tree->IsLeftChild(entry.nid))) {
auto this_hist = this->hist_[entry.nid];

if (!p_tree->IsRoot(entry.nid)) {
const int subtraction_node_id = subtraction_nodes[node].nid;
auto parent_hist = hist_[(*p_tree)[entry.nid].Parent()];
auto sibling_hist = hist_[subtraction_node_id];
common::SubtractionHist(this_hist, parent_hist, sibling_hist,
r.begin(), r.end());
}
}
});
}

// Add a tree node to histogram buffer in local training environment.
void AddHistRowsLocal(
int *starting_index, int *sync_count,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick) {
for (auto const &entry : nodes_for_explicit_hist_build) {
int nid = entry.nid;
this->hist_.AddHistRow(nid);
(*starting_index) = std::min(nid, (*starting_index));
}
(*sync_count) = nodes_for_explicit_hist_build.size();

for (auto const &node : nodes_for_subtraction_trick) {
this->hist_.AddHistRow(node.nid);
}
this->hist_.AllocateAllData();
}

void AddHistRowsDistributed(int *starting_index, int *sync_count,
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
RegTree const *p_tree) {
const size_t explicit_size = nodes_for_explicit_hist_build.size();
const size_t subtaction_size = nodes_for_subtraction_trick.size();
std::vector<int> merged_node_ids(explicit_size + subtaction_size);
for (size_t i = 0; i < explicit_size; ++i) {
merged_node_ids[i] = nodes_for_explicit_hist_build[i].nid;
}
for (size_t i = 0; i < subtaction_size; ++i) {
merged_node_ids[explicit_size + i] = nodes_for_subtraction_trick[i].nid;
}
std::sort(merged_node_ids.begin(), merged_node_ids.end());
int n_left = 0;
for (auto const &nid : merged_node_ids) {
if (p_tree->IsLeftChild(nid)) {
this->hist_.AddHistRow(nid);
(*starting_index) = std::min(nid, (*starting_index));
n_left++;
this->hist_local_worker_.AddHistRow(nid);
}
}
for (auto const &nid : merged_node_ids) {
if (!(p_tree->IsLeftChild(nid))) {
this->hist_.AddHistRow(nid);
this->hist_local_worker_.AddHistRow(nid);
}
}
this->hist_.AllocateAllData();
this->hist_local_worker_.AllocateAllData();
(*sync_count) = std::max(1, n_left);
}
auto &Buffer() { return buffer_; }
};

// Construct a work space for building histogram. Eventually we should move this
Expand All @@ -318,6 +204,5 @@ common::BlockedSpace2d ConstructHistSpace(Partitioner const &partitioners,
nodes_to_build.size(), [&](size_t nidx_in_set) { return partition_size[nidx_in_set]; }, 256};
return space;
}
} // namespace tree
} // namespace xgboost
} // namespace xgboost::tree
#endif // XGBOOST_TREE_HIST_HISTOGRAM_H_
34 changes: 10 additions & 24 deletions tests/cpp/tree/hist/test_histogram.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ void TestAddHistRows(bool is_distributed) {
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
int starting_index = std::numeric_limits<int>::max();
int sync_count = 0;

size_t constexpr kNRows = 8, kNCols = 16;
int32_t constexpr kMaxBins = 4;
Expand All @@ -49,11 +48,9 @@ void TestAddHistRows(bool is_distributed) {
HistogramBuilder<CPUExpandEntry> histogram_builder;
histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1,
is_distributed, false);
histogram_builder.AddHistRows(&starting_index, &sync_count,
nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_, &tree);
histogram_builder.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_);

ASSERT_EQ(sync_count, 2);
ASSERT_EQ(starting_index, 3);

for (const CPUExpandEntry &node : nodes_for_explicit_hist_build_) {
Expand All @@ -78,7 +75,6 @@ void TestSyncHist(bool is_distributed) {
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
int starting_index = std::numeric_limits<int>::max();
int sync_count = 0;
RegTree tree;

auto p_fmat = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
Expand All @@ -100,9 +96,8 @@ void TestSyncHist(bool is_distributed) {

// level 0
nodes_for_explicit_hist_build_.emplace_back(0, tree.GetDepth(0));
histogram.AddHistRows(&starting_index, &sync_count,
nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_, &tree);
histogram.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_);

tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
nodes_for_explicit_hist_build_.clear();
Expand All @@ -112,9 +107,8 @@ void TestSyncHist(bool is_distributed) {
nodes_for_explicit_hist_build_.emplace_back(tree[0].LeftChild(), tree.GetDepth(1));
nodes_for_subtraction_trick_.emplace_back(tree[0].RightChild(), tree.GetDepth(2));

histogram.AddHistRows(&starting_index, &sync_count,
nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_, &tree);
histogram.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_);

tree.ExpandNode(tree[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
tree.ExpandNode(tree[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
Expand All @@ -127,9 +121,8 @@ void TestSyncHist(bool is_distributed) {
nodes_for_explicit_hist_build_.emplace_back(5, tree.GetDepth(5));
nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6));

histogram.AddHistRows(&starting_index, &sync_count,
nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_, &tree);
histogram.AddHistRows(&starting_index, nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_);

const size_t n_nodes = nodes_for_explicit_hist_build_.size();
ASSERT_EQ(n_nodes, 2ul);
Expand Down Expand Up @@ -175,14 +168,8 @@ void TestSyncHist(bool is_distributed) {

histogram.Buffer().Reset(1, n_nodes, space, target_hists);
// sync hist
if (is_distributed) {
histogram.SyncHistogramDistributed(&tree, nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_,
starting_index, sync_count);
} else {
histogram.SyncHistogramLocal(&tree, nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_);
}
histogram.SyncHistogram(&tree, nodes_for_explicit_hist_build_,
nodes_for_subtraction_trick_, starting_index);

using GHistRowT = common::GHistRow;
auto check_hist = [](const GHistRowT parent, const GHistRowT left, const GHistRowT right,
Expand Down Expand Up @@ -487,4 +474,3 @@ TEST(CPUHistogram, ExternalMemory) {
TestHistogramExternalMemory(&ctx, {kBins, sparse_thresh}, false, true);
}
} // namespace xgboost::tree

Loading