From 297029d6b131bf066a43e7e28bbd80cb62619012 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 24 Mar 2023 18:55:17 +0800 Subject: [PATCH 01/14] Enable prediction cache for multi-target tree. --- include/xgboost/tree_updater.h | 4 +-- src/gbm/gbtree.cc | 10 ++++-- src/tree/hist/evaluate_splits.h | 42 ++++++++++++++++++++++--- src/tree/updater_approx.cc | 4 +-- src/tree/updater_gpu_hist.cu | 7 +++-- src/tree/updater_quantile_hist.cc | 21 ++++++++++--- tests/cpp/tree/test_gpu_hist.cu | 2 +- tests/cpp/tree/test_prediction_cache.cc | 36 ++++++++++++--------- 8 files changed, 92 insertions(+), 34 deletions(-) diff --git a/include/xgboost/tree_updater.h b/include/xgboost/tree_updater.h index 02248ed8ce50..79b80319f6da 100644 --- a/include/xgboost/tree_updater.h +++ b/include/xgboost/tree_updater.h @@ -85,8 +85,8 @@ class TreeUpdater : public Configurable { * the prediction cache. If true, the prediction cache will have been * updated by the time this function returns. */ - virtual bool UpdatePredictionCache(const DMatrix * /*data*/, - linalg::VectorView /*out_preds*/) { + virtual bool UpdatePredictionCache(const DMatrix* /*data*/, + linalg::MatrixView /*out_preds*/) { return false; } diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index a912d6a75d81..21479225935c 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -281,15 +281,19 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, UpdateTreeLeaf(p_fmat, predt->predictions, obj, 0, node_position, &ret); // No update prediction cache yet. 
new_trees.push_back(std::move(ret)); + std::size_t num_new_trees = ret.size(); + if (updaters_.size() > 0 && num_new_trees == 1 && predt->predictions.Size() > 0 && + updaters_.back()->UpdatePredictionCache(p_fmat, out)) { + predt->Update(1); + } } else if (model_.learner_model_param->OutputLength() == 1) { std::vector> ret; BoostNewTrees(in_gpair, p_fmat, 0, &node_position, &ret); UpdateTreeLeaf(p_fmat, predt->predictions, obj, 0, node_position, &ret); const size_t num_new_trees = ret.size(); new_trees.push_back(std::move(ret)); - auto v_predt = out.Slice(linalg::All(), 0); if (updaters_.size() > 0 && num_new_trees == 1 && predt->predictions.Size() > 0 && - updaters_.back()->UpdatePredictionCache(p_fmat, v_predt)) { + updaters_.back()->UpdatePredictionCache(p_fmat, out)) { predt->Update(1); } } else { @@ -305,7 +309,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, UpdateTreeLeaf(p_fmat, predt->predictions, obj, gid, node_position, &ret); const size_t num_new_trees = ret.size(); new_trees.push_back(std::move(ret)); - auto v_predt = out.Slice(linalg::All(), gid); + auto v_predt = out.Slice(linalg::All(), linalg::Range(gid, gid + 1)); if (!(updaters_.size() > 0 && predt->predictions.Size() > 0 && num_new_trees == 1 && updaters_.back()->UpdatePredictionCache(p_fmat, v_predt))) { update_predict = false; diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index 925a5fb76b8f..6000a735d3e0 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -677,9 +677,6 @@ template void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree, std::vector const &partitioner, linalg::VectorView out_preds) { - CHECK_GT(out_preds.Size(), 0U); - - CHECK(p_last_tree); auto const &tree = *p_last_tree; CHECK_EQ(out_preds.DeviceIdx(), Context::kCpuId); size_t n_nodes = p_last_tree->GetNodes().size(); @@ -687,7 +684,7 @@ void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree, 
CHECK_EQ(part.Size(), n_nodes); common::BlockedSpace2d space( part.Size(), [&](size_t node) { return part[node].Size(); }, 1024); - common::ParallelFor2d(space, ctx->Threads(), [&](size_t nidx, common::Range1d r) { + common::ParallelFor2d(space, ctx->Threads(), [&](bst_node_t nidx, common::Range1d r) { if (!tree[nidx].IsDeleted() && tree[nidx].IsLeaf()) { auto const &rowset = part[nidx]; auto leaf_value = tree[nidx].LeafValue(); @@ -698,5 +695,42 @@ void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree, }); } } + +template +void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree, + std::vector const &partitioner, + linalg::MatrixView out_preds) { + CHECK_GT(out_preds.Size(), 0U); + CHECK(p_last_tree); + + auto const &tree = *p_last_tree; + if (!tree.IsMultiTarget()) { + UpdatePredictionCacheImpl(ctx, p_last_tree, partitioner, out_preds.Slice(linalg::All(), 0)); + return; + } + CHECK_EQ(out_preds.Shape(1), 1); + + auto const *mttree = tree.GetMultiTargetTree(); + auto n_nodes = mttree->Size(); + auto n_targets = tree.NumTargets(); + CHECK_EQ(out_preds.DeviceIdx(), Context::kCpuId); + + for (auto &part : partitioner) { + CHECK_EQ(part.Size(), n_nodes); + common::BlockedSpace2d space( + part.Size(), [&](size_t node) { return part[node].Size(); }, 1024); + common::ParallelFor2d(space, ctx->Threads(), [&](bst_node_t nidx, common::Range1d r) { + if (tree.IsLeaf(nidx)) { + auto const &rowset = part[nidx]; + auto leaf_value = mttree->LeafValue(nidx); + for (std::size_t const *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) { + for (std::size_t i = 0; i < n_targets; ++i) { + out_preds(*it, i) += leaf_value(i); + } + } + } + }); + } +} } // namespace xgboost::tree #endif // XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_ diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index fd636d3a3c39..d6bc23f44b94 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -116,7 +116,7 @@ class 
GloablApproxBuilder { return nodes.front(); } - void UpdatePredictionCache(DMatrix const *data, linalg::VectorView out_preds) const { + void UpdatePredictionCache(DMatrix const *data, linalg::MatrixView out_preds) const { monitor_->Start(__func__); // Caching prediction seems redundant for approx tree method, as sketching takes up // majority of training time. @@ -303,7 +303,7 @@ class GlobalApproxUpdater : public TreeUpdater { } } - bool UpdatePredictionCache(const DMatrix *data, linalg::VectorView out_preds) override { + bool UpdatePredictionCache(const DMatrix *data, linalg::MatrixView out_preds) override { if (data != cached_ || !pimpl_) { return false; } diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 54ff7ea1a962..2d1b7a24df7c 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -517,7 +517,7 @@ struct GPUHistMakerDevice { }); } - bool UpdatePredictionCache(linalg::VectorView out_preds_d, RegTree const* p_tree) { + bool UpdatePredictionCache(linalg::MatrixView out_preds_d, RegTree const* p_tree) { if (positions.empty()) { return false; } @@ -535,11 +535,12 @@ struct GPUHistMakerDevice { h_nodes.size() * sizeof(RegTree::Node), cudaMemcpyHostToDevice, ctx_->CUDACtx()->Stream())); auto d_nodes = dh::ToSpan(nodes); + CHECK_EQ(out_preds_d.Shape(1), 1); dh::LaunchN(d_position.size(), ctx_->CUDACtx()->Stream(), [=] XGBOOST_DEVICE(std::size_t idx) mutable { bst_node_t nidx = d_position[idx]; auto weight = d_nodes[nidx].LeafValue(); - out_preds_d(idx) += weight; + out_preds_d(idx, 0) += weight; }); return true; } @@ -858,7 +859,7 @@ class GPUHistMaker : public TreeUpdater { } bool UpdatePredictionCache(const DMatrix* data, - linalg::VectorView p_out_preds) override { + linalg::MatrixView p_out_preds) override { if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) { return false; } diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 
012b8e78179e..dd8c6a964d84 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -125,6 +125,7 @@ class MultiTargetHistBuilder { std::vector partitioner_; // Pointer to last updated tree, used for update prediction cache. RegTree const *p_last_tree_{nullptr}; + DMatrix const *const p_last_fmat_{nullptr}; ObjInfo const *task_{nullptr}; @@ -312,6 +313,19 @@ class MultiTargetHistBuilder { task_{task} { monitor_->Init(__func__); } + + bool UpdatePredictionCache(DMatrix const *data, linalg::MatrixView out_preds) const { + // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in + // conjunction with Update(). + if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) { + return false; + } + monitor_->Start(__func__); + CHECK_EQ(out_preds.Size(), data->Info().num_row_ * p_last_tree_->NumTargets()); + UpdatePredictionCacheImpl(ctx_, p_last_tree_, partitioner_, out_preds); + monitor_->Stop(__func__); + return true; + } }; class HistBuilder { @@ -347,7 +361,7 @@ class HistBuilder { monitor_->Init(__func__); } - bool UpdatePredictionCache(DMatrix const *data, linalg::VectorView out_preds) const { + bool UpdatePredictionCache(DMatrix const *data, linalg::MatrixView out_preds) const { // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in // conjunction with Update(). if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) { @@ -582,12 +596,11 @@ class QuantileHistMaker : public TreeUpdater { } } - bool UpdatePredictionCache(const DMatrix *data, linalg::VectorView out_preds) override { + bool UpdatePredictionCache(const DMatrix *data, linalg::MatrixView out_preds) override { if (p_impl_) { return p_impl_->UpdatePredictionCache(data, out_preds); } else if (p_mtimpl_) { - // Not yet supported. 
- return false; + return p_mtimpl_->UpdatePredictionCache(data, out_preds); } else { return false; } diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index ed21230edc02..003347c8e29a 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -246,7 +246,7 @@ void UpdateTree(HostDeviceVector* gpair, DMatrix* dmat, std::vector> position(1); hist_maker.Update(&param, gpair, dmat, common::Span>{position}, {tree}); - auto cache = linalg::VectorView{preds->DeviceSpan(), {preds->Size()}, 0}; + auto cache = linalg::MakeTensorView(&ctx, preds->DeviceSpan(), preds->Size(), 1); hist_maker.UpdatePredictionCache(dmat, cache); } diff --git a/tests/cpp/tree/test_prediction_cache.cc b/tests/cpp/tree/test_prediction_cache.cc index 4f5a05eb6ead..1877b7a35b6b 100644 --- a/tests/cpp/tree/test_prediction_cache.cc +++ b/tests/cpp/tree/test_prediction_cache.cc @@ -15,15 +15,17 @@ namespace xgboost { class TestPredictionCache : public ::testing::Test { std::shared_ptr Xy_; - size_t n_samples_{2048}; + std::size_t n_samples_{2048}; protected: void SetUp() override { - size_t n_features = 13; - Xy_ = RandomDataGenerator{n_samples_, n_features, 0}.GenerateDMatrix(true); + std::size_t n_features = 13; + bst_target_t n_targets = 3; + Xy_ = RandomDataGenerator{n_samples_, n_features, 0}.Targets(n_targets).GenerateDMatrix(true); } - void RunLearnerTest(std::string updater_name, float subsample, std::string grow_policy) { + void RunLearnerTest(std::string updater_name, float subsample, std::string const& grow_policy, + std::string const& strategy) { std::unique_ptr learner{Learner::Create({Xy_})}; if (updater_name == "grow_gpu_hist") { // gpu_id setup @@ -31,6 +33,7 @@ class TestPredictionCache : public ::testing::Test { } else { learner->SetParam("updater", updater_name); } + learner->SetParam("multi_strategy", strategy); learner->SetParam("grow_policy", grow_policy); learner->SetParam("subsample", std::to_string(subsample)); 
learner->SetParam("nthread", "0"); @@ -62,7 +65,7 @@ class TestPredictionCache : public ::testing::Test { } } - void RunTest(std::string updater_name) { + void RunTest(std::string const& updater_name, std::string const& strategy) { { Context ctx; ctx.InitAllowUnknown(Args{{"nthread", "8"}}); @@ -85,28 +88,31 @@ class TestPredictionCache : public ::testing::Test { HostDeviceVector out_prediction_cached; out_prediction_cached.SetDevice(ctx.gpu_id); out_prediction_cached.Resize(n_samples_); - auto cache = linalg::VectorView{ctx.gpu_id == Context::kCpuId - ? out_prediction_cached.HostSpan() - : out_prediction_cached.DeviceSpan(), - {out_prediction_cached.Size()}, - ctx.gpu_id}; + auto cache = + linalg::MakeTensorView(&ctx, &out_prediction_cached, out_prediction_cached.Size(), 1); ASSERT_TRUE(updater->UpdatePredictionCache(Xy_.get(), cache)); } for (auto policy : {"depthwise", "lossguide"}) { for (auto subsample : {1.0f, 0.4f}) { - this->RunLearnerTest(updater_name, subsample, policy); - this->RunLearnerTest(updater_name, subsample, policy); + this->RunLearnerTest(updater_name, subsample, policy, strategy); + this->RunLearnerTest(updater_name, subsample, policy, strategy); } } } }; -TEST_F(TestPredictionCache, Approx) { this->RunTest("grow_histmaker"); } +TEST_F(TestPredictionCache, Approx) { this->RunTest("grow_histmaker", "one_output_per_tree"); } -TEST_F(TestPredictionCache, Hist) { this->RunTest("grow_quantile_histmaker"); } +TEST_F(TestPredictionCache, Hist) { + this->RunTest("grow_quantile_histmaker", "one_output_per_tree"); +} + +TEST_F(TestPredictionCache, HistMulti) { + this->RunTest("grow_quantile_histmaker", "multi_output_tree"); +} #if defined(XGBOOST_USE_CUDA) -TEST_F(TestPredictionCache, GpuHist) { this->RunTest("grow_gpu_hist"); } +TEST_F(TestPredictionCache, GpuHist) { this->RunTest("grow_gpu_hist", "one_output_per_tree"); } #endif // defined(XGBOOST_USE_CUDA) } // namespace xgboost From a1fa7b127b865505f44b599a6d4694fbe6ffbbca Mon Sep 17 00:00:00 2001 
From: Jiaming Yuan Date: Fri, 24 Mar 2023 20:39:20 +0800 Subject: [PATCH 02/14] Fix prediction range. --- src/gbm/gbtree.cc | 10 +++------- src/gbm/gbtree.h | 8 ++++---- src/gbm/gbtree_model.cc | 14 +++++++------- src/gbm/gbtree_model.h | 15 ++++++++++----- src/tree/hist/evaluate_splits.h | 2 +- src/tree/updater_quantile_hist.cc | 3 ++- 6 files changed, 27 insertions(+), 25 deletions(-) diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 21479225935c..15322ca3e861 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -279,9 +279,8 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, std::vector> ret; BoostNewTrees(in_gpair, p_fmat, 0, &node_position, &ret); UpdateTreeLeaf(p_fmat, predt->predictions, obj, 0, node_position, &ret); - // No update prediction cache yet. - new_trees.push_back(std::move(ret)); std::size_t num_new_trees = ret.size(); + new_trees.push_back(std::move(ret)); if (updaters_.size() > 0 && num_new_trees == 1 && predt->predictions.Size() > 0 && updaters_.back()->UpdatePredictionCache(p_fmat, out)) { predt->Update(1); @@ -563,11 +562,8 @@ void GBTree::Slice(int32_t layer_begin, int32_t layer_end, int32_t step, }); } -void GBTree::PredictBatch(DMatrix* p_fmat, - PredictionCacheEntry* out_preds, - bool, - unsigned layer_begin, - unsigned layer_end) { +void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool, + std::uint32_t layer_begin, std::uint32_t layer_end) { CHECK(configured_); if (layer_end == 0) { layer_end = this->BoostedRounds(); diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index b64532c614e9..bcbfc44e4e30 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -144,7 +144,7 @@ inline std::pair LayerToTree(gbm::GBTreeModel const& model, std::uint32_t layer_end) { std::uint32_t tree_begin; std::uint32_t tree_end; - if (model.learner_model_param->IsVectorLeaf()) { + if (model.HasMultiTargetTree()) { tree_begin = layer_begin * model.param.num_parallel_tree; tree_end = layer_end * 
model.param.num_parallel_tree; } else { @@ -243,7 +243,7 @@ class GBTree : public GradientBooster { // Number of trees per layer. [[nodiscard]] std::uint32_t LayerTrees() const { - if (model_.learner_model_param->IsVectorLeaf()) { + if (model_.HasMultiTargetTree()) { return model_.param.num_parallel_tree; } return model_.param.num_parallel_tree * model_.learner_model_param->OutputLength(); @@ -264,8 +264,8 @@ class GBTree : public GradientBooster { return !model_.trees.empty() || !model_.trees_to_update.empty(); } - void PredictBatch(DMatrix *p_fmat, PredictionCacheEntry *out_preds, - bool training, unsigned layer_begin, unsigned layer_end) override; + void PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool training, + std::uint32_t layer_begin, std::uint32_t layer_end) override; void InplacePredict(std::shared_ptr p_m, float missing, PredictionCacheEntry* out_preds, uint32_t layer_begin, unsigned layer_end) const override { diff --git a/src/gbm/gbtree_model.cc b/src/gbm/gbtree_model.cc index 4e9cc6655eaa..391c07d39dc5 100644 --- a/src/gbm/gbtree_model.cc +++ b/src/gbm/gbtree_model.cc @@ -1,5 +1,5 @@ -/*! 
- * Copyright 2019-2022 by Contributors +/** + * Copyright 2019-2023, XGBoost Contributors */ #include @@ -8,8 +8,7 @@ #include "gbtree_model.h" #include "gbtree.h" -namespace xgboost { -namespace gbm { +namespace xgboost::gbm { void GBTreeModel::Save(dmlc::Stream* fo) const { CHECK_EQ(param.num_trees, static_cast(trees.size())); @@ -97,11 +96,14 @@ void GBTreeModel::LoadModel(Json const& in) { trees.resize(trees_json.size()); CHECK(ctx_); + std::atomic has_multi_target_tree{0}; common::ParallelFor(trees_json.size(), ctx_->Threads(), [&](auto t) { auto tree_id = get(trees_json[t]["id"]); trees.at(tree_id).reset(new RegTree()); trees.at(tree_id)->LoadModel(trees_json[t]); + has_multi_target_tree += !!trees.at(tree_id)->IsMultiTarget(); }); + has_multi_tree_ = has_multi_target_tree > 0; tree_info.resize(param.num_trees); auto const& tree_info_json = get(in["tree_info"]); @@ -109,6 +111,4 @@ void GBTreeModel::LoadModel(Json const& in) { tree_info[i] = get(tree_info_json[i]); } } - -} // namespace gbm -} // namespace xgboost +} // namespace xgboost::gbm diff --git a/src/gbm/gbtree_model.h b/src/gbm/gbtree_model.h index 1f2bdfa639e1..9b79ed294cba 100644 --- a/src/gbm/gbtree_model.h +++ b/src/gbm/gbtree_model.h @@ -116,21 +116,22 @@ struct GBTreeModel : public Model { void SaveModel(Json* p_out) const override; void LoadModel(Json const& p_out) override; - std::vector DumpModel(const FeatureMap& fmap, bool with_stats, int32_t n_threads, - std::string format) const { + [[nodiscard]] std::vector DumpModel(const FeatureMap& fmap, bool with_stats, + int32_t n_threads, std::string format) const { std::vector dump(trees.size()); common::ParallelFor(trees.size(), n_threads, [&](size_t i) { dump[i] = trees[i]->DumpModel(fmap, with_stats, format); }); return dump; } - void CommitModel(std::vector >&& new_trees, - int bst_group) { - for (auto & new_tree : new_trees) { + void CommitModel(std::vector >&& new_trees, int bst_group) { + for (auto& new_tree : new_trees) { + 
has_multi_tree_ |= new_tree->IsMultiTarget(); trees.push_back(std::move(new_tree)); tree_info.push_back(bst_group); } param.num_trees += static_cast(new_trees.size()); } + [[nodiscard]] bool HasMultiTargetTree() const { return has_multi_tree_; } // base margin LearnerModelParam const* learner_model_param; @@ -144,6 +145,10 @@ struct GBTreeModel : public Model { std::vector tree_info; private: + /** + * \brief Whether the stack contains multi-target tree. + */ + bool has_multi_tree_{false}; Context const* ctx_; }; } // namespace gbm diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index 6000a735d3e0..0a79fbebce0f 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -708,11 +708,11 @@ void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree, UpdatePredictionCacheImpl(ctx, p_last_tree, partitioner, out_preds.Slice(linalg::All(), 0)); return; } - CHECK_EQ(out_preds.Shape(1), 1); auto const *mttree = tree.GetMultiTargetTree(); auto n_nodes = mttree->Size(); auto n_targets = tree.NumTargets(); + CHECK_EQ(out_preds.Shape(1), n_targets); CHECK_EQ(out_preds.DeviceIdx(), Context::kCpuId); for (auto &part : partitioner) { diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index dd8c6a964d84..90d9128c2b22 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -125,7 +125,7 @@ class MultiTargetHistBuilder { std::vector partitioner_; // Pointer to last updated tree, used for update prediction cache. 
RegTree const *p_last_tree_{nullptr}; - DMatrix const *const p_last_fmat_{nullptr}; + DMatrix const * p_last_fmat_{nullptr}; ObjInfo const *task_{nullptr}; @@ -148,6 +148,7 @@ class MultiTargetHistBuilder { void InitData(DMatrix *p_fmat, RegTree const *p_tree) { monitor_->Start(__func__); + p_last_fmat_= p_fmat; std::size_t page_id = 0; bst_bin_t n_total_bins = 0; partitioner_.clear(); From 060490df99fc9b9ff88de64274a433d09104e9a2 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 24 Mar 2023 20:43:30 +0800 Subject: [PATCH 03/14] lint. --- src/tree/updater_quantile_hist.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 90d9128c2b22..8387177aae79 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -148,7 +148,7 @@ class MultiTargetHistBuilder { void InitData(DMatrix *p_fmat, RegTree const *p_tree) { monitor_->Start(__func__); - p_last_fmat_= p_fmat; + p_last_fmat_ = p_fmat; std::size_t page_id = 0; bst_bin_t n_total_bins = 0; partitioner_.clear(); From 9c73eafd54cc037bc8616bd6f01a7ac14887832d Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 27 Mar 2023 12:25:03 +0800 Subject: [PATCH 04/14] Cleanup model IO, param, and headers. 
--- src/gbm/gbtree_model.cc | 28 ++++++++++------- src/gbm/gbtree_model.h | 47 +++++++++++----------------- src/predictor/cpu_predictor.cc | 5 +-- src/predictor/gpu_predictor.cu | 1 - tests/cpp/common/test_json.cc | 7 +++-- tests/cpp/helpers.cc | 17 ---------- tests/cpp/helpers.h | 8 ++--- tests/cpp/predictor/test_predictor.h | 17 ++++++++++ tests/cpp/test_serialization.cc | 5 ++- 9 files changed, 64 insertions(+), 71 deletions(-) diff --git a/src/gbm/gbtree_model.cc b/src/gbm/gbtree_model.cc index 391c07d39dc5..1cbce3ebc4f2 100644 --- a/src/gbm/gbtree_model.cc +++ b/src/gbm/gbtree_model.cc @@ -71,10 +71,10 @@ void GBTreeModel::SaveModel(Json* p_out) const { CHECK(ctx_); common::ParallelFor(trees.size(), ctx_->Threads(), [&](auto t) { auto const& tree = trees[t]; - Json tree_json{Object()}; - tree->SaveModel(&tree_json); - tree_json["id"] = Integer{static_cast(t)}; - trees_json[t] = std::move(tree_json); + Json jtree{Object{}}; + tree->SaveModel(&jtree); + jtree["id"] = Integer{static_cast(t)}; + trees_json[t] = std::move(jtree); }); std::vector tree_info_json(tree_info.size()); @@ -93,20 +93,24 @@ void GBTreeModel::LoadModel(Json const& in) { trees_to_update.clear(); auto const& trees_json = get(in["trees"]); - trees.resize(trees_json.size()); + CHECK_EQ(trees_json.size(), param.num_trees); + trees.resize(param.num_trees); + + auto const& tree_info_json = get(in["tree_info"]); + CHECK_EQ(tree_info_json.size(), param.num_trees); + tree_info.resize(param.num_trees); CHECK(ctx_); + std::atomic has_multi_target_tree{0}; - common::ParallelFor(trees_json.size(), ctx_->Threads(), [&](auto t) { - auto tree_id = get(trees_json[t]["id"]); - trees.at(tree_id).reset(new RegTree()); - trees.at(tree_id)->LoadModel(trees_json[t]); - has_multi_target_tree += !!trees.at(tree_id)->IsMultiTarget(); + common::ParallelFor(param.num_trees, ctx_->Threads(), [&](auto t) { + auto tree_id = get(trees_json[t]["id"]); + trees.at(tree_id).reset(new RegTree{}); + 
trees[tree_id]->LoadModel(trees_json[t]); + has_multi_target_tree += !!trees[tree_id]->IsMultiTarget(); }); has_multi_tree_ = has_multi_target_tree > 0; - tree_info.resize(param.num_trees); - auto const& tree_info_json = get(in["tree_info"]); for (int32_t i = 0; i < param.num_trees; ++i) { tree_info[i] = get(tree_info_json[i]); } diff --git a/src/gbm/gbtree_model.h b/src/gbm/gbtree_model.h index 9b79ed294cba..7581acb33788 100644 --- a/src/gbm/gbtree_model.h +++ b/src/gbm/gbtree_model.h @@ -1,5 +1,5 @@ -/*! - * Copyright 2017-2020 by Contributors +/** + * Copyright 2017-2023, XGBoost Contributors * \file gbtree_model.h */ #ifndef XGBOOST_GBM_GBTREE_MODEL_H_ @@ -29,22 +29,16 @@ namespace gbm { /*! \brief model parameters */ struct GBTreeModelParam : public dmlc::Parameter { public: - /*! \brief number of trees */ - int32_t num_trees; - /*! \brief (Deprecated) number of roots */ - int32_t num_parallel_tree; - /*! \brief number of features to be used by trees */ - int32_t deprecated_num_feature; - /*! \brief pad this space, for backward compatibility reason.*/ - int32_t pad_32bit; - /*! \brief deprecated padding space. */ - int64_t deprecated_num_pbuffer; - // deprecated. use learner_model_param_->num_output_group. - int32_t deprecated_num_output_group; - /*! \brief size of leaf vector needed in tree */ - int32_t size_leaf_vector; + /** + * \brief number of trees + */ + std::int32_t num_trees; + /** + * \brief Number of trees for a forest. + */ + std::int32_t num_parallel_tree; /*! \brief reserved parameters */ - int32_t reserved[32]; + int32_t reserved[38]; /*! \brief constructor */ GBTreeModelParam() { @@ -66,23 +60,14 @@ struct GBTreeModelParam : public dmlc::Parameter { .describe( "Number of parallel trees constructed during each iteration." " This option is used to support boosted random forest."); - DMLC_DECLARE_FIELD(size_leaf_vector) - .set_lower_bound(0) - .set_default(0) - .describe("Reserved option for vector tree."); } // Swap byte order for all fields. 
Useful for transporting models between machines with different // endianness (big endian vs little endian) - inline GBTreeModelParam ByteSwap() const { + GBTreeModelParam ByteSwap() const { GBTreeModelParam x = *this; dmlc::ByteSwap(&x.num_trees, sizeof(x.num_trees), 1); dmlc::ByteSwap(&x.num_parallel_tree, sizeof(x.num_parallel_tree), 1); - dmlc::ByteSwap(&x.deprecated_num_feature, sizeof(x.deprecated_num_feature), 1); - dmlc::ByteSwap(&x.pad_32bit, sizeof(x.pad_32bit), 1); - dmlc::ByteSwap(&x.deprecated_num_pbuffer, sizeof(x.deprecated_num_pbuffer), 1); - dmlc::ByteSwap(&x.deprecated_num_output_group, sizeof(x.deprecated_num_output_group), 1); - dmlc::ByteSwap(&x.size_leaf_vector, sizeof(x.size_leaf_vector), 1); dmlc::ByteSwap(x.reserved, sizeof(x.reserved[0]), sizeof(x.reserved) / sizeof(x.reserved[0])); return x; } @@ -141,8 +126,14 @@ struct GBTreeModel : public Model { std::vector > trees; /*! \brief for the update process, a place to keep the initial trees */ std::vector > trees_to_update; - /*! \brief some information indicator of the tree, reserved */ + /** + * \brief Group index for trees. + */ std::vector tree_info; + /** + * \brief Number of trees accumulated for each iteration. 
+ */ + std::vector iteration_trees; private: /** diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 3d5dfbd674ea..fe6fea02faf9 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -287,7 +287,6 @@ void PredictBatchByBlockOfRowsKernel(DataView batch, gbm::GBTreeModel const &mod linalg::TensorView out_predt) { auto &thread_temp = *p_thread_temp; - CHECK_EQ(model.param.size_leaf_vector, 0) << "size_leaf_vector is enforced to 0 so far"; // parallel over local batch const auto nsize = static_cast(batch.Size()); const int num_feature = model.learner_model_param->num_feature; @@ -515,7 +514,6 @@ class ColumnSplitHelper { void PredictBatchKernel(DataView batch, std::vector *out_preds) { auto const num_group = model_.learner_model_param->num_output_group; - CHECK_EQ(model_.param.size_leaf_vector, 0) << "size_leaf_vector is enforced to 0 so far"; // parallel over local batch auto const nsize = batch.Size(); auto const num_feature = model_.learner_model_param->num_feature; @@ -736,8 +734,7 @@ class CPUPredictor : public Predictor { if (ntree_limit == 0 || ntree_limit > model.trees.size()) { ntree_limit = static_cast(model.trees.size()); } - out_preds->resize(model.learner_model_param->num_output_group * - (model.param.size_leaf_vector + 1)); + out_preds->resize(model.learner_model_param->num_output_group); auto base_score = model.learner_model_param->BaseScore(ctx_)(0); // loop over output groups for (uint32_t gid = 0; gid < model.learner_model_param->num_output_group; ++gid) { diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 4a5c5b104179..2439a277f13d 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -342,7 +342,6 @@ class DeviceModel { void Init(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end, int32_t gpu_id) { dh::safe_cuda(cudaSetDevice(gpu_id)); - CHECK_EQ(model.param.size_leaf_vector, 0); // Copy decision trees to device 
tree_segments = std::move(HostDeviceVector({}, gpu_id)); auto& h_tree_segments = tree_segments.HostVector(); diff --git a/tests/cpp/common/test_json.cc b/tests/cpp/common/test_json.cc index cf8bcd81ddff..3e2038e13deb 100644 --- a/tests/cpp/common/test_json.cc +++ b/tests/cpp/common/test_json.cc @@ -1,5 +1,5 @@ -/*! - * Copyright (c) by Contributors 2019-2022 +/** + * Copyright (c) 2019-2023, XGBoost Contributors */ #include @@ -8,7 +8,8 @@ #include "../../../src/common/charconv.h" #include "../../../src/common/io.h" -#include "../filesystem.h" // dmlc::TemporaryDirectory +#include "../../../src/common/threading_utils.h" // for ParallelFor +#include "../filesystem.h" // dmlc::TemporaryDirectory #include "../helpers.h" #include "dmlc/logging.h" #include "xgboost/json.h" diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 49813f1d04de..27742bf6b3a8 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -557,23 +557,6 @@ std::unique_ptr CreateSparsePageDMatrixWithRC( return dmat; } -gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param, Context const* ctx, - size_t n_classes) { - gbm::GBTreeModel model(param, ctx); - - for (size_t i = 0; i < n_classes; ++i) { - std::vector> trees; - trees.push_back(std::unique_ptr(new RegTree)); - if (i == 0) { - (*trees.back())[0].SetLeaf(1.5f); - (*trees.back()).Stat(0).sum_hess = 1.0f; - } - model.CommitModel(std::move(trees), i); - } - - return model; -} - std::unique_ptr CreateTrainedGBM(std::string name, Args kwargs, size_t kRows, size_t kCols, LearnerModelParam const* learner_model_param, diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index c835444131c5..9d820e4b3e48 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -9,8 +9,10 @@ #include #include #include +#include // for LearnerModelParam +#include // for Configurable -#include // std::int32_t +#include // std::int32_t #include #include #include @@ -22,7 +24,6 @@ #include "../../src/collective/communicator-inl.h" #include 
"../../src/common/common.h" #include "../../src/data/array_interface.h" -#include "../../src/gbm/gbtree_model.h" #include "filesystem.h" // dmlc::TemporaryDirectory #include "xgboost/linalg.h" @@ -362,9 +363,6 @@ std::unique_ptr CreateSparsePageDMatrixWithRC( size_t n_rows, size_t n_cols, size_t page_size, bool deterministic, const dmlc::TemporaryDirectory& tempdir = dmlc::TemporaryDirectory()); -gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param, Context const* ctx, - size_t n_classes = 1); - std::unique_ptr CreateTrainedGBM(std::string name, Args kwargs, size_t kRows, size_t kCols, LearnerModelParam const* learner_model_param, diff --git a/tests/cpp/predictor/test_predictor.h b/tests/cpp/predictor/test_predictor.h index 56c1523a1cf1..6b0ae1fb1ff5 100644 --- a/tests/cpp/predictor/test_predictor.h +++ b/tests/cpp/predictor/test_predictor.h @@ -14,6 +14,23 @@ #include "../helpers.h" namespace xgboost { +inline gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param, Context const* ctx, + size_t n_classes = 1) { + gbm::GBTreeModel model(param, ctx); + + for (size_t i = 0; i < n_classes; ++i) { + std::vector> trees; + trees.push_back(std::unique_ptr(new RegTree)); + if (i == 0) { + (*trees.back())[0].SetLeaf(1.5f); + (*trees.back()).Stat(0).sum_hess = 1.0f; + } + model.CommitModel(std::move(trees), i); + } + + return model; +} + template void TestPredictionFromGradientIndex(std::string name, size_t rows, size_t cols, std::shared_ptr p_hist) { diff --git a/tests/cpp/test_serialization.cc b/tests/cpp/test_serialization.cc index 15765f09f29d..731f85563092 100644 --- a/tests/cpp/test_serialization.cc +++ b/tests/cpp/test_serialization.cc @@ -1,7 +1,10 @@ -// Copyright (c) 2019-2022 by Contributors +/** + * Copyright (c) 2019-2023, XGBoost Contributors + */ #include #include #include +#include // for FeatureMap #include #include From da563d0b0cfbc14d2dda72e574e0b458e19c17c9 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 27 Mar 2023 15:04:57 
+0800 Subject: [PATCH 05/14] Start working on model slicing. --- include/xgboost/base.h | 12 +- include/xgboost/gbm.h | 12 +- include/xgboost/learner.h | 8 +- src/gbm/gbtree.cc | 132 +++++++++++----------- src/gbm/gbtree.h | 99 +++++++--------- src/gbm/gbtree_model.cc | 41 +++++-- src/gbm/gbtree_model.h | 24 +++- src/learner.cc | 8 +- tests/cpp/predictor/test_gpu_predictor.cu | 8 +- tests/cpp/predictor/test_predictor.cc | 4 +- tests/cpp/predictor/test_predictor.h | 2 +- 11 files changed, 186 insertions(+), 164 deletions(-) diff --git a/include/xgboost/base.h b/include/xgboost/base.h index 00fc7fb4ac63..43540beea928 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -113,8 +113,18 @@ using bst_row_t = std::size_t; // NOLINT using bst_node_t = std::int32_t; // NOLINT /*! \brief Type for ranking group index. */ using bst_group_t = std::uint32_t; // NOLINT -/*! \brief Type for indexing into output targets. */ +/** + * \brief Type for indexing into output targets. + */ using bst_target_t = std::uint32_t; // NOLINT +/** + * \brief Type for indexing boosted layers. + */ +using bst_layer_t = std::int32_t; // NOLINT +/** + * \brief Type for indexing trees. + */ +using bst_tree_t = std::int32_t; // NOLINT namespace detail { /*! \brief Implementation of gradient statistics pair. Template specialisation diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h index 07758a524469..57db4b436dc1 100644 --- a/include/xgboost/gbm.h +++ b/include/xgboost/gbm.h @@ -59,16 +59,16 @@ class GradientBooster : public Model, public Configurable { * \param fo output stream */ virtual void Save(dmlc::Stream* fo) const = 0; - /*! + /** * \brief Slice a model using boosting index. The slice m:n indicates taking all trees * that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1). - * \param layer_begin Beginning of boosted tree layer used for prediction. - * \param layer_end End of booster layer. 0 means do not limit trees.
- * \param out Output gradient booster + * \param begin Beginning of boosted tree layer used for prediction. + * \param end End of booster layer. 0 means do not limit trees. + * \param out Output gradient booster */ - virtual void Slice(int32_t /*layer_begin*/, int32_t /*layer_end*/, int32_t /*step*/, + virtual void Slice(bst_layer_t /*begin*/, bst_layer_t /*end*/, bst_layer_t /*step*/, GradientBooster* /*out*/, bool* /*out_of_bound*/) const { - LOG(FATAL) << "Slice is not supported by current booster."; + LOG(FATAL) << "Slice is not supported by the current booster."; } /*! \brief Return number of boosted rounds. */ diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h index 08e1ded09e95..f2b377ac1d18 100644 --- a/include/xgboost/learner.h +++ b/include/xgboost/learner.h @@ -9,7 +9,7 @@ #define XGBOOST_LEARNER_H_ #include // for Serializable -#include // for bst_feature_t, bst_target_t, bst_float, Args, GradientPair +#include // for bst_feature_t, bst_target_t, bst_float, Args, GradientPair, .. #include // for Context #include // for Tensor, TensorView #include // for Metric @@ -229,7 +229,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable { */ virtual void GetFeatureTypes(std::vector* ft) const = 0; - /*! + /** * \brief Slice the model. * * See InplacePredict for layer parameters. @@ -239,8 +239,8 @@ class Learner : public Model, public Configurable, public dmlc::Serializable { * * \return a sliced model. */ - virtual Learner *Slice(int32_t begin_layer, int32_t end_layer, int32_t step, - bool *out_of_bound) = 0; + virtual Learner* Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step, + bool* out_of_bound) = 0; /*! 
* \brief dump the model in the requested format * \param fmap feature map that may help give interpretations of feature diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 15322ca3e861..558b382c488c 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -225,10 +225,9 @@ void CopyGradient(HostDeviceVector const* in_gpair, int32_t n_thre } void GBTree::UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector const& predictions, - ObjFunction const* obj, - std::int32_t group_idx, + ObjFunction const* obj, std::int32_t group_idx, std::vector> const& node_position, - std::vector>* p_trees) { + TreesOneGroup* p_trees) { CHECK(!updaters_.empty()); if (!updaters_.back()->HasNodePosition()) { return; @@ -252,7 +251,7 @@ void GBTree::UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector const void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, PredictionCacheEntry* predt, ObjFunction const* obj) { - std::vector>> new_trees; + TreesOneIter new_trees; const int ngroup = model_.learner_model_param->OutputLength(); ConfigureWithKnownData(this->cfg_, p_fmat); monitor_.Start("BoostNewTrees"); @@ -276,7 +275,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, std::vector> node_position; if (model_.learner_model_param->IsVectorLeaf()) { - std::vector> ret; + TreesOneGroup ret; BoostNewTrees(in_gpair, p_fmat, 0, &node_position, &ret); UpdateTreeLeaf(p_fmat, predt->predictions, obj, 0, node_position, &ret); std::size_t num_new_trees = ret.size(); @@ -285,8 +284,8 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, updaters_.back()->UpdatePredictionCache(p_fmat, out)) { predt->Update(1); } - } else if (model_.learner_model_param->OutputLength() == 1) { - std::vector> ret; + } else if (model_.learner_model_param->OutputLength() == 1u) { + TreesOneGroup ret; BoostNewTrees(in_gpair, p_fmat, 0, &node_position, &ret); UpdateTreeLeaf(p_fmat, predt->predictions, obj, 0, node_position, &ret); const size_t num_new_trees = ret.size(); @@ 
-303,7 +302,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, for (int gid = 0; gid < ngroup; ++gid) { node_position.clear(); CopyGradient(in_gpair, ctx_->Threads(), ngroup, gid, &tmp); - std::vector> ret; + TreesOneGroup ret; BoostNewTrees(&tmp, p_fmat, gid, &node_position, &ret); UpdateTreeLeaf(p_fmat, predt->predictions, obj, gid, node_position, &ret); const size_t num_new_trees = ret.size(); @@ -320,7 +319,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, } monitor_.Stop("BoostNewTrees"); - this->CommitModel(std::move(new_trees)); + this->model_.CommitModel(std::move(new_trees)); } void GBTree::InitUpdater(Args const& cfg) { @@ -366,7 +365,7 @@ void GBTree::InitUpdater(Args const& cfg) { void GBTree::BoostNewTrees(HostDeviceVector* gpair, DMatrix* p_fmat, int bst_group, std::vector>* out_position, - std::vector>* ret) { + TreesOneGroup* ret) { std::vector new_trees; ret->clear(); // create the trees @@ -422,15 +421,9 @@ void GBTree::BoostNewTrees(HostDeviceVector* gpair, DMatrix* p_fma tree_param_.learning_rate = lr; } -void GBTree::CommitModel(std::vector>>&& new_trees) { +void GBTree::CommitModel(TreesOneIter&& new_trees) { monitor_.Start("CommitModel"); - if (this->model_.learner_model_param->IsVectorLeaf()) { - model_.CommitModel(std::move(new_trees[0]), 0); - } else { - for (std::uint32_t gid = 0; gid < model_.learner_model_param->OutputLength(); ++gid) { - model_.CommitModel(std::move(new_trees[gid]), gid); - } - } + model_.CommitModel(std::forward(new_trees)); monitor_.Stop("CommitModel"); } @@ -522,28 +515,30 @@ void GBTree::SaveModel(Json* p_out) const { model_.SaveModel(&model); } -void GBTree::Slice(int32_t layer_begin, int32_t layer_end, int32_t step, - GradientBooster *out, bool* out_of_bound) const { +void GBTree::Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step, GradientBooster* out, + bool* out_of_bound) const { CHECK(configured_); CHECK(out); auto p_gbtree = dynamic_cast(out); CHECK(p_gbtree); 
- GBTreeModel &out_model = p_gbtree->model_; - auto layer_trees = this->LayerTrees(); - CHECK_NE(this->model_.learner_model_param->num_feature, 0); - CHECK_NE(layer_trees, 0); + GBTreeModel& out_model = p_gbtree->model_; + CHECK(this->model_.learner_model_param->Initialized()); - layer_end = layer_end == 0 ? model_.trees.size() / layer_trees : layer_end; - CHECK_GT(layer_end, layer_begin); + end = end == 0 ? model_.iteration_indptr.size() : end; CHECK_GE(step, 1); - int32_t n_layers = (layer_end - layer_begin) / step; - std::vector> &out_trees = out_model.trees; - out_trees.resize(layer_trees * n_layers); - std::vector &out_trees_info = out_model.tree_info; - out_trees_info.resize(layer_trees * n_layers); - out_model.param.num_trees = out_model.trees.size(); - out_model.param.num_parallel_tree = model_.param.num_parallel_tree; + if (step > (end - begin)) { + *out_of_bound = true; + return; + } + + auto& out_iteration_indptr = out_model.iteration_indptr; + TreesOneGroup& out_trees = out_model.trees; + std::vector& out_trees_info = out_model.tree_info; + + bst_layer_t n_layers = (end - begin) / step; + out_iteration_indptr.resize(n_layers + 1, 0); + if (!this->model_.trees_to_update.empty()) { CHECK_EQ(this->model_.trees_to_update.size(), this->model_.trees.size()) << "Not all trees are updated, " @@ -552,14 +547,23 @@ void GBTree::Slice(int32_t layer_begin, int32_t layer_end, int32_t step, "want to update a portion of trees."; } - *out_of_bound = detail::SliceTrees(layer_begin, layer_end, step, this->model_, layer_trees, - [&](auto const& in_it, auto const& out_it) { - auto new_tree = - std::make_unique(*this->model_.trees.at(in_it)); - bst_group_t group = this->model_.tree_info[in_it]; - out_trees.at(out_it) = std::move(new_tree); - out_trees_info.at(out_it) = group; - }); + *out_of_bound = detail::SliceTrees( + begin, end, step, this->model_, [&](auto in_tree_idx, auto out_l) { + auto new_tree = std::make_unique(*this->model_.trees.at(in_tree_idx)); + 
out_trees.push_back(std::move(new_tree)); + + bst_group_t group = this->model_.tree_info[in_tree_idx]; + out_trees_info.push_back(group); + + out_model.iteration_indptr[out_l + 1]++; + }); + + std::partial_sum(out_iteration_indptr.cbegin(), out_iteration_indptr.cend(), + out_iteration_indptr.begin()); + CHECK_EQ(out_model.iteration_indptr.front(), 0); + + out_model.param.num_trees = out_model.trees.size(); + out_model.param.num_parallel_tree = model_.param.num_parallel_tree; } void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool, @@ -590,8 +594,7 @@ void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool predictor->InitOutPredictions(p_fmat->Info(), &out_preds->predictions, model_); } - std::uint32_t tree_begin, tree_end; - std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end); + auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees."; if (tree_end > tree_begin) { predictor->PredictBatch(p_fmat, out_preds, model_, tree_begin, tree_end); @@ -729,10 +732,9 @@ class Dart : public GBTree { auto p_dart = dynamic_cast(out); CHECK(p_dart); CHECK(p_dart->weight_drop_.empty()); - detail::SliceTrees(layer_begin, layer_end, step, model_, this->LayerTrees(), - [&](auto const& in_it, auto const&) { - p_dart->weight_drop_.push_back(this->weight_drop_.at(in_it)); - }); + detail::SliceTrees(layer_begin, layer_end, step, model_, [&](auto const& in_it, auto const&) { + p_dart->weight_drop_.push_back(this->weight_drop_.at(in_it)); + }); } void SaveModel(Json *p_out) const override { @@ -798,8 +800,7 @@ class Dart : public GBTree { predictor->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, model_); p_out_preds->version = 0; - uint32_t tree_begin, tree_end; - std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end); + auto [tree_begin, tree_end] = 
detail::LayerToTree(model_, layer_begin, layer_end); auto n_groups = model_.learner_model_param->num_output_group; PredictionCacheEntry predts; // temporary storage for prediction @@ -807,14 +808,18 @@ class Dart : public GBTree { predts.predictions.SetDevice(ctx_->gpu_id); } predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0); + // multi-target is not yet supported. + auto layer_trees = [&]() { + return model_.param.num_parallel_tree * model_.learner_model_param->OutputLength(); + }; - for (size_t i = tree_begin; i < tree_end; i += 1) { + for (bst_tree_t i = tree_begin; i < tree_end; i += 1) { if (training && std::binary_search(idx_drop_.cbegin(), idx_drop_.cend(), i)) { continue; } CHECK_GE(i, p_out_preds->version); - auto version = i / this->LayerTrees(); + auto version = i / layer_trees(); p_out_preds->version = version; predts.predictions.Fill(0); predictor->PredictBatch(p_fmat, &predts, model_, i, i + 1); @@ -854,8 +859,7 @@ class Dart : public GBTree { PredictionCacheEntry* p_out_preds, uint32_t layer_begin, unsigned layer_end) const override { CHECK(!this->model_.learner_model_param->IsVectorLeaf()) << "dart" << MTNotImplemented(); - uint32_t tree_begin, tree_end; - std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end); + auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); auto n_groups = model_.learner_model_param->num_output_group; std::vector predictors { @@ -897,7 +901,7 @@ class Dart : public GBTree { }; // Inplace predict is not used for training, so no need to drop tree. 
- for (size_t i = tree_begin; i < tree_end; ++i) { + for (bst_tree_t i = tree_begin; i < tree_end; ++i) { predict_impl(i); if (i == tree_begin) { predictor->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, model_); @@ -941,31 +945,25 @@ class Dart : public GBTree { unsigned layer_begin, unsigned layer_end, bool approximate, int, unsigned) override { CHECK(configured_); - uint32_t tree_begin, tree_end; - std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end); - cpu_predictor_->PredictContribution(p_fmat, out_contribs, model_, - tree_end, &weight_drop_, approximate); + auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); + cpu_predictor_->PredictContribution(p_fmat, out_contribs, model_, tree_end, &weight_drop_, + approximate); } void PredictInteractionContributions( DMatrix *p_fmat, HostDeviceVector *out_contribs, unsigned layer_begin, unsigned layer_end, bool approximate) override { CHECK(configured_); - uint32_t tree_begin, tree_end; - std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end); + auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); cpu_predictor_->PredictInteractionContributions(p_fmat, out_contribs, model_, tree_end, &weight_drop_, approximate); } protected: // commit new trees all at once - void CommitModel(std::vector>>&& new_trees) override { - int num_new_trees = 0; - for (uint32_t gid = 0; gid < model_.learner_model_param->num_output_group; ++gid) { - num_new_trees += new_trees[gid].size(); - model_.CommitModel(std::move(new_trees[gid]), gid); - } - size_t num_drop = NormalizeTrees(num_new_trees); + void CommitModel(TreesOneIter&& new_trees) override { + auto n_new_trees = model_.CommitModel(std::forward(new_trees)); + size_t num_drop = NormalizeTrees(n_new_trees); LOG(INFO) << "drop " << num_drop << " trees, " << "weight = " << weight_drop_.back(); } diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 
bcbfc44e4e30..0ed418b8d1a4 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -139,23 +139,12 @@ struct DartTrainParam : public XGBoostParameter { namespace detail { // From here on, layer becomes concrete trees. -inline std::pair LayerToTree(gbm::GBTreeModel const& model, - std::uint32_t layer_begin, - std::uint32_t layer_end) { - std::uint32_t tree_begin; - std::uint32_t tree_end; - if (model.HasMultiTargetTree()) { - tree_begin = layer_begin * model.param.num_parallel_tree; - tree_end = layer_end * model.param.num_parallel_tree; - } else { - bst_group_t groups = model.learner_model_param->OutputLength(); - tree_begin = layer_begin * groups * model.param.num_parallel_tree; - tree_end = layer_end * groups * model.param.num_parallel_tree; - } - - if (tree_end == 0) { - tree_end = model.trees.size(); - } +inline std::pair LayerToTree(gbm::GBTreeModel const& model, + bst_layer_t layer_begin, + bst_layer_t layer_end) { + CHECK(!model.iteration_indptr.empty()); + bst_tree_t tree_begin = model.iteration_indptr[layer_begin]; + bst_tree_t tree_end = model.iteration_indptr[layer_end]; if (model.trees.size() != 0) { CHECK_LE(tree_begin, tree_end); } @@ -164,27 +153,30 @@ inline std::pair LayerToTree(gbm::GBTreeModel const& model, // Call fn for each pair of input output tree. Return true if index is out of bound. template -bool SliceTrees(int32_t layer_begin, int32_t layer_end, int32_t step, GBTreeModel const& model, - uint32_t layer_trees, Func fn) { - uint32_t tree_begin, tree_end; - std::tie(tree_begin, tree_end) = detail::LayerToTree(model, layer_begin, layer_end); - if (tree_end > model.trees.size()) { +bool SliceTrees(bst_layer_t begin, bst_layer_t end, bst_layer_t step, GBTreeModel const& model, + Func&& fn) { + end = end == 0 ? model.iteration_indptr.size() : end; + CHECK_GE(step, 1); + if (step > end - begin) { return true; } - layer_end = layer_end == 0 ? 
model.trees.size() / layer_trees : layer_end; - uint32_t n_layers = (layer_end - layer_begin) / step; - int32_t in_it = tree_begin; - int32_t out_it = 0; - for (uint32_t l = 0; l < n_layers; ++l) { - for (uint32_t i = 0; i < layer_trees; ++i) { - CHECK_LT(in_it, tree_end); - fn(in_it, out_it); - out_it++; - in_it++; + bst_layer_t n_layers = (end - begin) / step; + bst_layer_t out_l = 0; + + for (bst_layer_t l = begin; l < end; l += step) { + auto [tree_begin, tree_end] = detail::LayerToTree(model, l, l + 1); + if (tree_end >= static_cast(model.trees.size())) { + return true; + } + + for (bst_tree_t tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { + fn(tree_idx, out_l); } - in_it += (step - 1) * layer_trees; + ++out_l; } + + CHECK_EQ(out_l, n_layers); return false; } } // namespace detail @@ -241,23 +233,15 @@ class GBTree : public GradientBooster { void SaveModel(Json* p_out) const override; void LoadModel(Json const& in) override; - // Number of trees per layer. - [[nodiscard]] std::uint32_t LayerTrees() const { - if (model_.HasMultiTargetTree()) { - return model_.param.num_parallel_tree; - } - return model_.param.num_parallel_tree * model_.learner_model_param->OutputLength(); - } - // slice the trees, out must be already allocated - void Slice(int32_t layer_begin, int32_t layer_end, int32_t step, - GradientBooster *out, bool* out_of_bound) const override; + void Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step, GradientBooster* out, + bool* out_of_bound) const override; [[nodiscard]] std::int32_t BoostedRounds() const override { - CHECK_NE(model_.param.num_parallel_tree, 0); - CHECK_NE(model_.learner_model_param->num_output_group, 0); - - return model_.trees.size() / this->LayerTrees(); + if (model_.trees.empty()) { + return 0; + } + return model_.iteration_indptr.size() - 1; } [[nodiscard]] bool ModelFitted() const override { @@ -270,8 +254,7 @@ class GBTree : public GradientBooster { void InplacePredict(std::shared_ptr p_m, float missing, 
PredictionCacheEntry* out_preds, uint32_t layer_begin, unsigned layer_end) const override { CHECK(configured_); - uint32_t tree_begin, tree_end; - std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end); + auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees."; std::vector predictors{ cpu_predictor_.get(), @@ -364,20 +347,18 @@ class GBTree : public GradientBooster { } } - void PredictInstance(const SparsePage::Inst& inst, - std::vector* out_preds, + void PredictInstance(const SparsePage::Inst& inst, std::vector* out_preds, uint32_t layer_begin, uint32_t layer_end) override { CHECK(configured_); - uint32_t tree_begin, tree_end; - std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end); + std::uint32_t _, tree_end; + std::tie(_, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end); cpu_predictor_->PredictInstance(inst, out_preds, model_, tree_end); } void PredictLeaf(DMatrix* p_fmat, HostDeviceVector* out_preds, uint32_t layer_begin, uint32_t layer_end) override { - uint32_t tree_begin, tree_end; - std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end); + auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); CHECK_EQ(tree_begin, 0) << "Predict leaf supports only iteration end: (0, " "n_iteration), use model slicing instead."; this->GetPredictor()->PredictLeaf(p_fmat, out_preds, model_, tree_end); @@ -388,8 +369,7 @@ class GBTree : public GradientBooster { uint32_t layer_begin, uint32_t layer_end, bool approximate, int, unsigned) override { CHECK(configured_); - uint32_t tree_begin, tree_end; - std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end); + auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); CHECK_EQ(tree_begin, 0) << "Predict contribution supports only iteration end: (0, " 
"n_iteration), using model slicing instead."; @@ -401,8 +381,7 @@ class GBTree : public GradientBooster { DMatrix *p_fmat, HostDeviceVector *out_contribs, uint32_t layer_begin, uint32_t layer_end, bool approximate) override { CHECK(configured_); - uint32_t tree_begin, tree_end; - std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end); + auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); CHECK_EQ(tree_begin, 0) << "Predict interaction contribution supports only iteration end: (0, " "n_iteration), using model slicing instead."; @@ -427,7 +406,7 @@ class GBTree : public GradientBooster { DMatrix* f_dmat = nullptr) const; // commit new trees all at once - virtual void CommitModel(std::vector>>&& new_trees); + virtual void CommitModel(TreesOneIter&& new_trees); // --- data structure --- GBTreeModel model_; diff --git a/src/gbm/gbtree_model.cc b/src/gbm/gbtree_model.cc index 1cbce3ebc4f2..74f361882cd6 100644 --- a/src/gbm/gbtree_model.cc +++ b/src/gbm/gbtree_model.cc @@ -1,12 +1,19 @@ /** * Copyright 2019-2023, XGBoost Contributors */ -#include - -#include "xgboost/json.h" -#include "xgboost/logging.h" #include "gbtree_model.h" -#include "gbtree.h" + +#include // for size_t +#include // for operator<<, basic_ostream +#include // for move + +#include "../common/threading_utils.h" // for ParallelFor +#include "dmlc/base.h" // for BeginPtr +#include "dmlc/io.h" // for Stream +#include "xgboost/context.h" // for Context +#include "xgboost/json.h" // for Json, get, Integer, Array, FromJson, ToJson, Object +#include "xgboost/logging.h" // for LogCheck_EQ, CHECK_EQ, CHECK +#include "xgboost/tree_model.h" // for RegTree namespace xgboost::gbm { void GBTreeModel::Save(dmlc::Stream* fo) const { @@ -102,17 +109,33 @@ void GBTreeModel::LoadModel(Json const& in) { CHECK(ctx_); - std::atomic has_multi_target_tree{0}; common::ParallelFor(param.num_trees, ctx_->Threads(), [&](auto t) { auto tree_id = get(trees_json[t]["id"]); 
trees.at(tree_id).reset(new RegTree{}); trees[tree_id]->LoadModel(trees_json[t]); - has_multi_target_tree += !!trees[tree_id]->IsMultiTarget(); }); - has_multi_tree_ = has_multi_target_tree > 0; - for (int32_t i = 0; i < param.num_trees; ++i) { + for (bst_tree_t i = 0; i < param.num_trees; ++i) { tree_info[i] = get(tree_info_json[i]); } } + +std::uint32_t GBTreeModel::CommitModel(TreesOneIter&& new_trees) { + CHECK(!iteration_indptr.empty()); + CHECK_EQ(iteration_indptr.back(), param.num_trees); + std::uint32_t n_new_trees{0}; + + if (learner_model_param->IsVectorLeaf()) { + n_new_trees += new_trees.front().size(); + this->CommitModelGroup(std::move(new_trees.front()), 0); + } else { + for (bst_target_t gidx{0}; gidx < learner_model_param->OutputLength(); ++gidx) { + n_new_trees += new_trees[gidx].size(); + this->CommitModelGroup(std::move(new_trees[gidx]), gidx); + } + } + + iteration_indptr.push_back(n_new_trees + iteration_indptr.back()); + return n_new_trees; +} } // namespace xgboost::gbm diff --git a/src/gbm/gbtree_model.h b/src/gbm/gbtree_model.h index 7581acb33788..5a23dec8f544 100644 --- a/src/gbm/gbtree_model.h +++ b/src/gbm/gbtree_model.h @@ -25,6 +25,14 @@ namespace xgboost { class Json; namespace gbm { +/** + * \brief Container for all trees built (not updated) for one group. + */ +using TreesOneGroup = std::vector>; +/** + * \brief Container for all trees built (not updated) for one iteration. + */ +using TreesOneIter = std::vector; /*! \brief model parameters */ struct GBTreeModelParam : public dmlc::Parameter { @@ -108,15 +116,20 @@ struct GBTreeModel : public Model { [&](size_t i) { dump[i] = trees[i]->DumpModel(fmap, with_stats, format); }); return dump; } - void CommitModel(std::vector >&& new_trees, int bst_group) { + /** + * \brief Add trees to the model. + * + * \return The number of new trees.
+ */ + std::uint32_t CommitModel(TreesOneIter&& new_trees); + + void CommitModelGroup(std::vector>&& new_trees, bst_target_t group_idx) { for (auto& new_tree : new_trees) { - has_multi_tree_ |= new_tree->IsMultiTarget(); trees.push_back(std::move(new_tree)); - tree_info.push_back(bst_group); + tree_info.push_back(group_idx); } param.num_trees += static_cast(new_trees.size()); } - [[nodiscard]] bool HasMultiTargetTree() const { return has_multi_tree_; } // base margin LearnerModelParam const* learner_model_param; @@ -133,13 +146,12 @@ struct GBTreeModel : public Model { /** * \brief Number of trees accumulated for each iteration. */ - std::vector iteration_trees; + std::vector iteration_indptr{0}; private: /** * \brief Whether the stack contains multi-target tree. */ - bool has_multi_tree_{false}; Context const* ctx_; }; } // namespace gbm diff --git a/src/learner.cc b/src/learner.cc index 50d54c9fcbc7..8808c3392c64 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -45,7 +45,7 @@ #include "common/timer.h" // for Monitor #include "common/version.h" // for Version #include "dmlc/endian.h" // for ByteSwap, DMLC_IO_NO_ENDIAN_SWAP -#include "xgboost/base.h" // for Args, bst_float, GradientPair, bst_feature_t +#include "xgboost/base.h" // for Args, bst_float, GradientPair, bst_feature_t, ... 
#include "xgboost/context.h" // for Context #include "xgboost/data.h" // for DMatrix, MetaInfo #include "xgboost/gbm.h" // for GradientBooster @@ -1247,19 +1247,19 @@ class LearnerImpl : public LearnerIO { return gbm_->DumpModel(fmap, with_stats, format); } - Learner* Slice(int32_t begin_layer, int32_t end_layer, int32_t step, + Learner* Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step, bool* out_of_bound) override { this->Configure(); this->CheckModelInitialized(); CHECK_NE(this->learner_model_param_.num_feature, 0); - CHECK_GE(begin_layer, 0); + CHECK_GE(begin, 0); auto* out_impl = new LearnerImpl({}); out_impl->learner_model_param_.Copy(this->learner_model_param_); out_impl->ctx_ = this->ctx_; auto gbm = std::unique_ptr(GradientBooster::Create( this->tparam_.booster, &out_impl->ctx_, &out_impl->learner_model_param_)); - this->gbm_->Slice(begin_layer, end_layer, step, gbm.get(), out_of_bound); + this->gbm_->Slice(begin, end, step, gbm.get(), out_of_bound); out_impl->gbm_ = std::move(gbm); Json config{Object()}; diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index 4a3293dbe73d..fecb5028a6ba 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -1,5 +1,5 @@ -/*! 
- * Copyright 2017-2022 XGBoost contributors +/** + * Copyright 2017-2023, XGBoost contributors */ #include #include @@ -155,7 +155,7 @@ TEST(GPUPredictor, ShapStump) { std::vector> trees; trees.push_back(std::unique_ptr(new RegTree)); - model.CommitModel(std::move(trees), 0); + model.CommitModelGroup(std::move(trees), 0); auto gpu_lparam = CreateEmptyGenericParam(0); std::unique_ptr gpu_predictor = std::unique_ptr( @@ -183,7 +183,7 @@ TEST(GPUPredictor, Shap) { std::vector> trees; trees.push_back(std::unique_ptr(new RegTree)); trees[0]->ExpandNode(0, 0, 0.5, true, 1.0, -1.0, 1.0, 0.0, 5.0, 2.0, 3.0); - model.CommitModel(std::move(trees), 0); + model.CommitModelGroup(std::move(trees), 0); auto gpu_lparam = CreateEmptyGenericParam(0); auto cpu_lparam = CreateEmptyGenericParam(-1); diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc index d6cf33445893..575a85497761 100644 --- a/tests/cpp/predictor/test_predictor.cc +++ b/tests/cpp/predictor/test_predictor.cc @@ -209,7 +209,7 @@ void GBTreeModelForTest(gbm::GBTreeModel *model, uint32_t split_ind, p_tree->ExpandCategorical(0, split_ind, split_cats, true, 1.5f, left_weight, right_weight, 3.0f, 2.2f, 7.0f, 9.0f); - model->CommitModel(std::move(trees), 0); + model->CommitModelGroup(std::move(trees), 0); } void TestCategoricalPrediction(std::string name) { @@ -445,7 +445,7 @@ void TestVectorLeafPrediction(Context const *ctx) { ASSERT_TRUE(mparam.IsVectorLeaf()); gbm::GBTreeModel model{&mparam, ctx}; - model.CommitModel(std::move(trees), 0); + model.CommitModelGroup(std::move(trees), 0); auto run_test = [&](float expected, HostDeviceVector *p_data) { { diff --git a/tests/cpp/predictor/test_predictor.h b/tests/cpp/predictor/test_predictor.h index 6b0ae1fb1ff5..302c6bfaed95 100644 --- a/tests/cpp/predictor/test_predictor.h +++ b/tests/cpp/predictor/test_predictor.h @@ -25,7 +25,7 @@ inline gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param, Context 
(*trees.back())[0].SetLeaf(1.5f); (*trees.back()).Stat(0).sum_hess = 1.0f; } - model.CommitModel(std::move(trees), i); + model.CommitModelGroup(std::move(trees), i); } return model; From 2db19190a5c9a9f8b0c3e4eb66beeaab7d7cdb31 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 27 Mar 2023 15:17:35 +0800 Subject: [PATCH 06/14] model IO. --- src/gbm/gbtree.cc | 27 +++++++++++++-------------- src/gbm/gbtree_model.cc | 35 +++++++++++++++++++++++++++++++++++ src/gbm/gbtree_model.h | 2 +- 3 files changed, 49 insertions(+), 15 deletions(-) diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 558b382c488c..fd07d1d8df73 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -252,7 +252,7 @@ void GBTree::UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector const void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, PredictionCacheEntry* predt, ObjFunction const* obj) { TreesOneIter new_trees; - const int ngroup = model_.learner_model_param->OutputLength(); + bst_target_t const n_groups = model_.learner_model_param->OutputLength(); ConfigureWithKnownData(this->cfg_, p_fmat); monitor_.Start("BoostNewTrees"); @@ -264,7 +264,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, device, device == Context::kCpuId ? 
predt->predictions.HostSpan() : predt->predictions.DeviceSpan(), p_fmat->Info().num_row_, model_.learner_model_param->OutputLength()); - CHECK_NE(ngroup, 0); + CHECK_NE(n_groups, 0); if (!p_fmat->SingleColBlock() && obj->Task().UpdateTreeLeaf()) { LOG(FATAL) << "Current objective doesn't support external memory."; @@ -295,13 +295,13 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, predt->Update(1); } } else { - CHECK_EQ(in_gpair->Size() % ngroup, 0U) << "must have exactly ngroup * nrow gpairs"; - HostDeviceVector tmp(in_gpair->Size() / ngroup, GradientPair(), + CHECK_EQ(in_gpair->Size() % n_groups, 0U) << "must have exactly ngroup * nrow gpairs"; + HostDeviceVector tmp(in_gpair->Size() / n_groups, GradientPair(), in_gpair->DeviceIdx()); bool update_predict = true; - for (int gid = 0; gid < ngroup; ++gid) { + for (bst_target_t gid = 0; gid < n_groups; ++gid) { node_position.clear(); - CopyGradient(in_gpair, ctx_->Threads(), ngroup, gid, &tmp); + CopyGradient(in_gpair, ctx_->Threads(), n_groups, gid, &tmp); TreesOneGroup ret; BoostNewTrees(&tmp, p_fmat, gid, &node_position, &ret); UpdateTreeLeaf(p_fmat, predt->predictions, obj, gid, node_position, &ret); @@ -520,7 +520,7 @@ void GBTree::Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step, Gradien CHECK(configured_); CHECK(out); - auto p_gbtree = dynamic_cast(out); + auto p_gbtree = dynamic_cast(out); CHECK(p_gbtree); GBTreeModel& out_model = p_gbtree->model_; CHECK(this->model_.learner_model_param->Initialized()); @@ -532,12 +532,12 @@ void GBTree::Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step, Gradien return; } - auto& out_iteration_indptr = out_model.iteration_indptr; + auto& out_indptr = out_model.iteration_indptr; TreesOneGroup& out_trees = out_model.trees; std::vector& out_trees_info = out_model.tree_info; bst_layer_t n_layers = (end - begin) / step; - out_iteration_indptr.resize(n_layers + 1, 0); + out_indptr.resize(n_layers + 1, 0); if 
(!this->model_.trees_to_update.empty()) { CHECK_EQ(this->model_.trees_to_update.size(), this->model_.trees.size()) @@ -547,10 +547,10 @@ void GBTree::Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step, Gradien "want to update a portion of trees."; } - *out_of_bound = detail::SliceTrees( - begin, end, step, this->model_, [&](auto in_tree_idx, auto out_l) { + *out_of_bound = + detail::SliceTrees(begin, end, step, this->model_, [&](auto in_tree_idx, auto out_l) { auto new_tree = std::make_unique(*this->model_.trees.at(in_tree_idx)); - out_trees.push_back(std::move(new_tree)); + out_trees.emplace_back(std::move(new_tree)); bst_group_t group = this->model_.tree_info[in_tree_idx]; out_trees_info.push_back(group); @@ -558,8 +558,7 @@ void GBTree::Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step, Gradien out_model.iteration_indptr[out_l + 1]++; }); - std::partial_sum(out_iteration_indptr.cbegin(), out_iteration_indptr.cend(), - out_iteration_indptr.begin()); + std::partial_sum(out_indptr.cbegin(), out_indptr.cend(), out_indptr.begin()); CHECK_EQ(out_model.iteration_indptr.front(), 0); out_model.param.num_trees = out_model.trees.size(); diff --git a/src/gbm/gbtree_model.cc b/src/gbm/gbtree_model.cc index 74f361882cd6..2b4c9ea17bed 100644 --- a/src/gbm/gbtree_model.cc +++ b/src/gbm/gbtree_model.cc @@ -16,6 +16,20 @@ #include "xgboost/tree_model.h" // for RegTree namespace xgboost::gbm { +namespace { +// For creating the tree indptr from old models. 
+void MakeIndptr(GBTreeModel* out_model) { + auto const& tree_info = out_model->tree_info; + auto& indptr = out_model->iteration_indptr; + + auto n_groups = *std::max_element(tree_info.cbegin(), tree_info.cend()) + 1; + for (std::size_t i = 1; i < indptr.size(); ++i) { + indptr[i] = n_groups * out_model->param.num_parallel_tree; + } + std::partial_sum(indptr.cbegin(), indptr.cend(), indptr.begin()); +} +} // namespace + void GBTreeModel::Save(dmlc::Stream* fo) const { CHECK_EQ(param.num_trees, static_cast(trees.size())); @@ -67,6 +81,8 @@ void GBTreeModel::Load(dmlc::Stream* fi) { } } } + + MakeIndptr(this); } void GBTreeModel::SaveModel(Json* p_out) const { @@ -91,6 +107,11 @@ void GBTreeModel::SaveModel(Json* p_out) const { out["trees"] = Array(std::move(trees_json)); out["tree_info"] = Array(std::move(tree_info_json)); + + std::vector jiteration_indptr(iteration_indptr.size()); + std::transform(iteration_indptr.cbegin(), iteration_indptr.cend(), jiteration_indptr.begin(), + [](bst_tree_t i) { return Integer{i}; }); + out["iteration_indptr"] = Array{std::move(jiteration_indptr)}; } void GBTreeModel::LoadModel(Json const& in) { @@ -99,6 +120,8 @@ void GBTreeModel::LoadModel(Json const& in) { trees.clear(); trees_to_update.clear(); + auto const& jmodel = get(in); + auto const& trees_json = get(in["trees"]); CHECK_EQ(trees_json.size(), param.num_trees); trees.resize(param.num_trees); @@ -118,6 +141,18 @@ void GBTreeModel::LoadModel(Json const& in) { for (bst_tree_t i = 0; i < param.num_trees; ++i) { tree_info[i] = get(tree_info_json[i]); } + + auto indptr_it = jmodel.find("iteration_indptr"); + iteration_indptr.clear(); + if (indptr_it != jmodel.cend()) { + auto const& vec = get(indptr_it->second); + iteration_indptr.resize(vec.size()); + std::transform(vec.cbegin(), vec.cend(), iteration_indptr.begin(), + [](Json const& v) { return get(v); }); + CHECK_EQ(iteration_indptr.back(), trees.size()); + } else { + MakeIndptr(this); + } } std::uint32_t 
GBTreeModel::CommitModel(TreesOneIter&& new_trees) { diff --git a/src/gbm/gbtree_model.h b/src/gbm/gbtree_model.h index 5a23dec8f544..d8cc3d61e636 100644 --- a/src/gbm/gbtree_model.h +++ b/src/gbm/gbtree_model.h @@ -146,7 +146,7 @@ struct GBTreeModel : public Model { /** * \brief Number of trees accumulated for each iteration. */ - std::vector iteration_indptr{0}; + std::vector iteration_indptr{0}; private: /** From 2af50ef6e0b805db8eb253af90b8d6915ae5a473 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 27 Mar 2023 15:28:21 +0800 Subject: [PATCH 07/14] Type for prediction. --- include/xgboost/gbm.h | 37 +++++++++++++++++-------------------- src/gbm/gblinear.cc | 2 +- src/gbm/gbtree.cc | 15 ++++++--------- src/gbm/gbtree.h | 4 ++-- 4 files changed, 26 insertions(+), 32 deletions(-) diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h index 57db4b436dc1..4f690064f873 100644 --- a/include/xgboost/gbm.h +++ b/include/xgboost/gbm.h @@ -88,34 +88,31 @@ class GradientBooster : public Model, public Configurable { virtual void DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, PredictionCacheEntry*, ObjFunction const* obj) = 0; - /*! - * \brief generate predictions for given feature matrix - * \param dmat feature matrix + /** + * \brief Generate predictions for given feature matrix + * + * \param dmat The feature matrix. * \param out_preds output vector to hold the predictions * \param training Whether the prediction value is used for training. For dart booster * drop out is performed during training. - * \param layer_begin Beginning of boosted tree layer used for prediction. - * \param layer_end End of booster layer. 0 means do not limit trees. + * \param begin Beginning of boosted tree layer used for prediction. + * \param end End of booster layer. 0 means do not limit trees. 
*/ - virtual void PredictBatch(DMatrix* dmat, - PredictionCacheEntry* out_preds, - bool training, - unsigned layer_begin, - unsigned layer_end) = 0; + virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds, bool training, + bst_layer_t begin, bst_layer_t end) = 0; - /*! + /** * \brief Inplace prediction. * - * \param p_fmat A proxy DMatrix that contains the data and related - * meta info. - * \param missing Missing value in the data. - * \param [in,out] out_preds The output preds. - * \param layer_begin (Optional) Beginning of boosted tree layer used for prediction. - * \param layer_end (Optional) End of booster layer. 0 means do not limit trees. + * \param p_fmat A proxy DMatrix that contains the data and related meta info. + * \param missing Missing value in the data. + * \param [in,out] out_preds The output preds. + * \param begin (Optional) Beginning of boosted tree layer used for prediction. + * \param end (Optional) End of booster layer. 0 means do not limit trees. */ - virtual void InplacePredict(std::shared_ptr, float, PredictionCacheEntry*, uint32_t, - uint32_t) const { - LOG(FATAL) << "Inplace predict is not supported by current booster."; + virtual void InplacePredict(std::shared_ptr, float, PredictionCacheEntry*, bst_layer_t, + bst_layer_t) const { + LOG(FATAL) << "Inplace predict is not supported by the current booster."; } /*! 
* \brief online prediction function, predict score for one instance at a time diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc index 575820758492..f1189886c345 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -148,7 +148,7 @@ class GBLinear : public GradientBooster { } void PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* predts, bool /*training*/, - uint32_t layer_begin, uint32_t) override { + bst_layer_t layer_begin, bst_layer_t) override { monitor_.Start("PredictBatch"); LinearCheckLayer(layer_begin); auto* out_preds = &predts->predictions; diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index fd07d1d8df73..97bcb11ff85a 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -566,12 +566,12 @@ void GBTree::Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step, Gradien } void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool, - std::uint32_t layer_begin, std::uint32_t layer_end) { + bst_layer_t layer_begin, bst_layer_t layer_end) { CHECK(configured_); if (layer_end == 0) { layer_end = this->BoostedRounds(); } - if (layer_begin != 0 || layer_end < out_preds->version) { + if (layer_begin != 0 || layer_end < static_cast(out_preds->version)) { // cache is dropped. 
out_preds->version = 0; } @@ -845,18 +845,15 @@ class Dart : public GBTree { } } - void PredictBatch(DMatrix* p_fmat, - PredictionCacheEntry* p_out_preds, - bool training, - unsigned layer_begin, - unsigned layer_end) override { + void PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* p_out_preds, bool training, + bst_layer_t layer_begin, bst_layer_t layer_end) override { DropTrees(training); this->PredictBatchImpl(p_fmat, p_out_preds, training, layer_begin, layer_end); } void InplacePredict(std::shared_ptr p_fmat, float missing, - PredictionCacheEntry* p_out_preds, uint32_t layer_begin, - unsigned layer_end) const override { + PredictionCacheEntry* p_out_preds, bst_layer_t layer_begin, + bst_layer_t layer_end) const override { CHECK(!this->model_.learner_model_param->IsVectorLeaf()) << "dart" << MTNotImplemented(); auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); auto n_groups = model_.learner_model_param->num_output_group; diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 0ed418b8d1a4..066fc591cbbf 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -249,10 +249,10 @@ class GBTree : public GradientBooster { } void PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool training, - std::uint32_t layer_begin, std::uint32_t layer_end) override; + bst_layer_t layer_begin, bst_layer_t layer_end) override; void InplacePredict(std::shared_ptr p_m, float missing, PredictionCacheEntry* out_preds, - uint32_t layer_begin, unsigned layer_end) const override { + bst_layer_t layer_begin, bst_layer_t layer_end) const override { CHECK(configured_); auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees."; From 7d49640f600f09560f86d5ff8723923b84127554 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 27 Mar 2023 15:38:33 +0800 Subject: [PATCH 08/14] lint. 
--- src/gbm/gbtree_model.cc | 11 +++++++---- src/gbm/gbtree_model.h | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/gbm/gbtree_model.cc b/src/gbm/gbtree_model.cc index 2b4c9ea17bed..292cc1915808 100644 --- a/src/gbm/gbtree_model.cc +++ b/src/gbm/gbtree_model.cc @@ -3,15 +3,18 @@ */ #include "gbtree_model.h" +#include // for transform, max_element #include // for size_t +#include // for partial_sum #include // for operator<<, basic_ostream -#include // for move +#include // for move, pair #include "../common/threading_utils.h" // for ParallelFor #include "dmlc/base.h" // for BeginPtr #include "dmlc/io.h" // for Stream #include "xgboost/context.h" // for Context -#include "xgboost/json.h" // for Json, get, Integer, Array, FromJson, ToJson, Object +#include "xgboost/json.h" // for Json, get, Integer, Array, FromJson, ToJson, Json... +#include "xgboost/learner.h" // for LearnerModelParam #include "xgboost/logging.h" // for LogCheck_EQ, CHECK_EQ, CHECK #include "xgboost/tree_model.h" // for RegTree @@ -155,10 +158,10 @@ void GBTreeModel::LoadModel(Json const& in) { } } -std::uint32_t GBTreeModel::CommitModel(TreesOneIter&& new_trees) { +bst_tree_t GBTreeModel::CommitModel(TreesOneIter&& new_trees) { CHECK(!iteration_indptr.empty()); CHECK_EQ(iteration_indptr.back(), param.num_trees); - std::uint32_t n_new_trees{0}; + bst_tree_t n_new_trees{0}; if (learner_model_param->IsVectorLeaf()) { n_new_trees += new_trees.front().size(); diff --git a/src/gbm/gbtree_model.h b/src/gbm/gbtree_model.h index d8cc3d61e636..bbf361126e5a 100644 --- a/src/gbm/gbtree_model.h +++ b/src/gbm/gbtree_model.h @@ -121,7 +121,7 @@ struct GBTreeModel : public Model { * * \return The number of new trees. 
*/ - std::uint32_t CommitModel(TreesOneIter&& new_trees); + bst_tree_t CommitModel(TreesOneIter&& new_trees); void CommitModelGroup(std::vector>&& new_trees, bst_target_t group_idx) { for (auto& new_tree : new_trees) { From a9e0bb2d6576495e8337021419599987779eb5cc Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 27 Mar 2023 16:23:25 +0800 Subject: [PATCH 09/14] Fixes. --- src/gbm/gbtree.cc | 2 +- src/gbm/gbtree_model.cc | 19 ++++++++++++++++++- src/gbm/gbtree_model.h | 3 +++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 97bcb11ff85a..ed0b97aaa063 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -319,7 +319,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector* in_gpair, } monitor_.Stop("BoostNewTrees"); - this->model_.CommitModel(std::move(new_trees)); + this->CommitModel(std::move(new_trees)); } void GBTree::InitUpdater(Args const& cfg) { diff --git a/src/gbm/gbtree_model.cc b/src/gbm/gbtree_model.cc index 292cc1915808..2f7eda398909 100644 --- a/src/gbm/gbtree_model.cc +++ b/src/gbm/gbtree_model.cc @@ -23,7 +23,13 @@ namespace { // For creating the tree indptr from old models. void MakeIndptr(GBTreeModel* out_model) { auto const& tree_info = out_model->tree_info; + if (tree_info.empty()) { + return; + } + auto& indptr = out_model->iteration_indptr; + indptr.resize(tree_info.size() + 1, 0); + indptr[0] = 0; auto n_groups = *std::max_element(tree_info.cbegin(), tree_info.cend()) + 1; for (std::size_t i = 1; i < indptr.size(); ++i) { @@ -31,6 +37,14 @@ void MakeIndptr(GBTreeModel* out_model) { } std::partial_sum(indptr.cbegin(), indptr.cend(), indptr.begin()); } + +// Validate the consistency of the model. +void Validate(GBTreeModel const& model) { + CHECK_EQ(model.trees.size(), model.param.num_trees); + CHECK_EQ(model.tree_info.size(), model.param.num_trees); + // True even if the model is empty since we should always have 0 as the first element. 
+ CHECK_EQ(model.iteration_indptr.back(), model.param.num_trees); +} } // namespace void GBTreeModel::Save(dmlc::Stream* fo) const { @@ -86,6 +100,7 @@ void GBTreeModel::Load(dmlc::Stream* fi) { } MakeIndptr(this); + Validate(*this); } void GBTreeModel::SaveModel(Json* p_out) const { @@ -152,10 +167,11 @@ void GBTreeModel::LoadModel(Json const& in) { iteration_indptr.resize(vec.size()); std::transform(vec.cbegin(), vec.cend(), iteration_indptr.begin(), [](Json const& v) { return get(v); }); - CHECK_EQ(iteration_indptr.back(), trees.size()); } else { MakeIndptr(this); } + + Validate(*this); } bst_tree_t GBTreeModel::CommitModel(TreesOneIter&& new_trees) { @@ -174,6 +190,7 @@ bst_tree_t GBTreeModel::CommitModel(TreesOneIter&& new_trees) { } iteration_indptr.push_back(n_new_trees + iteration_indptr.back()); + Validate(*this); return n_new_trees; } } // namespace xgboost::gbm diff --git a/src/gbm/gbtree_model.h b/src/gbm/gbtree_model.h index bbf361126e5a..364353a76832 100644 --- a/src/gbm/gbtree_model.h +++ b/src/gbm/gbtree_model.h @@ -100,6 +100,9 @@ struct GBTreeModel : public Model { trees.clear(); param.num_trees = 0; tree_info.clear(); + + iteration_indptr.clear(); + iteration_indptr.push_back(0); } } From 3f36b78fd4c2b7a52ff7e5c64bceeff9f2b86e60 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 27 Mar 2023 16:38:37 +0800 Subject: [PATCH 10/14] Fix boosted rounds. --- src/gbm/gbtree.h | 16 +++++----------- src/gbm/gbtree_model.h | 7 +++++++ 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 066fc591cbbf..29e94980c70f 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -140,11 +140,11 @@ struct DartTrainParam : public XGBoostParameter { namespace detail { // From here on, layer becomes concrete trees. 
inline std::pair LayerToTree(gbm::GBTreeModel const& model, - bst_layer_t layer_begin, - bst_layer_t layer_end) { + bst_layer_t begin, bst_layer_t end) { CHECK(!model.iteration_indptr.empty()); - bst_tree_t tree_begin = model.iteration_indptr[layer_begin]; - bst_tree_t tree_end = model.iteration_indptr[layer_end]; + end = end == 0 ? model.BoostedRounds() : end; + bst_tree_t tree_begin = model.iteration_indptr[begin]; + bst_tree_t tree_end = model.iteration_indptr[end]; if (model.trees.size() != 0) { CHECK_LE(tree_begin, tree_end); } @@ -237,13 +237,7 @@ class GBTree : public GradientBooster { void Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step, GradientBooster* out, bool* out_of_bound) const override; - [[nodiscard]] std::int32_t BoostedRounds() const override { - if (model_.trees.empty()) { - return 0; - } - return model_.iteration_indptr.size() - 1; - } - + [[nodiscard]] std::int32_t BoostedRounds() const override { return this->model_.BoostedRounds(); } [[nodiscard]] bool ModelFitted() const override { return !model_.trees.empty() || !model_.trees_to_update.empty(); } diff --git a/src/gbm/gbtree_model.h b/src/gbm/gbtree_model.h index 364353a76832..1012acd32e25 100644 --- a/src/gbm/gbtree_model.h +++ b/src/gbm/gbtree_model.h @@ -134,6 +134,13 @@ struct GBTreeModel : public Model { param.num_trees += static_cast(new_trees.size()); } + [[nodiscard]] std::int32_t BoostedRounds() const { + if (trees.empty()) { + return 0; + } + return static_cast(iteration_indptr.size() - 1); + } + // base margin LearnerModelParam const* learner_model_param; // model parameter From c8e6f8ec0741eedf5da8e7dd0b5e32530ee7af37 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 27 Mar 2023 16:41:58 +0800 Subject: [PATCH 11/14] throw error. 
--- src/gbm/gbtree.h | 1 + tests/cpp/gbm/test_gbtree.cc | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 29e94980c70f..0a2e69310f20 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -143,6 +143,7 @@ inline std::pair LayerToTree(gbm::GBTreeModel const& mod bst_layer_t begin, bst_layer_t end) { CHECK(!model.iteration_indptr.empty()); end = end == 0 ? model.BoostedRounds() : end; + CHECK_LE(end, model.BoostedRounds()) << "Out of range for tree layers."; bst_tree_t tree_begin = model.iteration_indptr[begin]; bst_tree_t tree_end = model.iteration_indptr[end]; if (model.trees.size() != 0) { diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index 270eacf21710..93d0cf525b54 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -505,7 +505,7 @@ TEST(GBTree, PredictRange) { auto h_out_predt_full = out_predt->HostVector(); ASSERT_TRUE(std::equal(h_out_predt.begin(), h_out_predt.end(), h_out_predt_full.begin())); - + // Out of range. ASSERT_THROW(learner->InplacePredict(x, PredictionType::kValue, std::numeric_limits::quiet_NaN(), &out_predt, 0, 3), dmlc::Error); From 2470773581774f989d8c6e7f8ff6ae4a131f2dfc Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 27 Mar 2023 17:24:14 +0800 Subject: [PATCH 12/14] Work on slicing tests. --- src/gbm/gbtree.cc | 4 +++- src/gbm/gbtree.h | 5 ++++- src/gbm/gbtree_model.cc | 7 +++++-- src/gbm/gbtree_model.h | 2 +- tests/python/test_basic_models.py | 2 +- 5 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index ed0b97aaa063..f67c053448e3 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -525,8 +525,10 @@ void GBTree::Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step, Gradien GBTreeModel& out_model = p_gbtree->model_; CHECK(this->model_.learner_model_param->Initialized()); - end = end == 0 ? model_.iteration_indptr.size() : end; + end = end == 0 ? 
model_.BoostedRounds() : end; CHECK_GE(step, 1); + CHECK_NE(end, begin) << "Empty slice is not allowed."; + if (step > (end - begin)) { *out_of_bound = true; return; diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 0a2e69310f20..6e7da77ac1d6 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -161,13 +161,16 @@ bool SliceTrees(bst_layer_t begin, bst_layer_t end, bst_layer_t step, GBTreeMode if (step > end - begin) { return true; } + if (end > model.BoostedRounds()) { + return true; + } bst_layer_t n_layers = (end - begin) / step; bst_layer_t out_l = 0; for (bst_layer_t l = begin; l < end; l += step) { auto [tree_begin, tree_end] = detail::LayerToTree(model, l, l + 1); - if (tree_end >= static_cast(model.trees.size())) { + if (tree_end > static_cast(model.trees.size())) { return true; } diff --git a/src/gbm/gbtree_model.cc b/src/gbm/gbtree_model.cc index 2f7eda398909..1373e3e2b412 100644 --- a/src/gbm/gbtree_model.cc +++ b/src/gbm/gbtree_model.cc @@ -27,11 +27,14 @@ void MakeIndptr(GBTreeModel* out_model) { return; } + auto n_groups = *std::max_element(tree_info.cbegin(), tree_info.cend()) + 1; + auto& indptr = out_model->iteration_indptr; - indptr.resize(tree_info.size() + 1, 0); + auto layer_trees = out_model->param.num_parallel_tree * n_groups; + CHECK_NE(layer_trees, 0); + indptr.resize(out_model->param.num_trees / layer_trees + 1, 0); indptr[0] = 0; - auto n_groups = *std::max_element(tree_info.cbegin(), tree_info.cend()) + 1; for (std::size_t i = 1; i < indptr.size(); ++i) { indptr[i] = n_groups * out_model->param.num_parallel_tree; } diff --git a/src/gbm/gbtree_model.h b/src/gbm/gbtree_model.h index 1012acd32e25..32fa868638bb 100644 --- a/src/gbm/gbtree_model.h +++ b/src/gbm/gbtree_model.h @@ -136,7 +136,7 @@ struct GBTreeModel : public Model { [[nodiscard]] std::int32_t BoostedRounds() const { if (trees.empty()) { - return 0; + CHECK_EQ(iteration_indptr.size(), 1); } return static_cast(iteration_indptr.size() - 1); } diff --git 
a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py index d03ce142bc83..6b9c4bc91e36 100644 --- a/tests/python/test_basic_models.py +++ b/tests/python/test_basic_models.py @@ -524,7 +524,7 @@ def run_slice( booster[-1:0] # we do not accept empty slice. - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Empty slice"): booster[1:1] # stop can not be smaller than begin with pytest.raises(ValueError, match=r"Invalid.*"): From 4396a3b7759c09c1e4c24df94215e7784335bb00 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 27 Mar 2023 17:56:17 +0800 Subject: [PATCH 13/14] Support slicing. --- python-package/xgboost/core.py | 5 ++++ tests/python/test_basic_models.py | 40 +++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 5a0cfb3a2ece..a4026d47cd9e 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -17,6 +17,7 @@ Any, Callable, Dict, + Generator, Iterable, List, Optional, @@ -1755,6 +1756,10 @@ def __getitem__(self, val: Union[int, tuple, slice]) -> "Booster": sliced.handle = sliced_handle return sliced + def __iter__(self) -> Generator["Booster", None, None]: + for i in range(0, self.num_boosted_rounds()): + yield self[i] + def save_config(self) -> str: """Output internal parameter configuration of Booster as a JSON string. 
diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py index 6b9c4bc91e36..516cbd6cf76e 100644 --- a/tests/python/test_basic_models.py +++ b/tests/python/test_basic_models.py @@ -615,6 +615,46 @@ def test_slice(self, booster): booster = xgb.Booster(model_file=bytesarray) self.run_slice(booster, dtrain, num_parallel_tree, num_classes, num_boost_round) + def test_slice_multi(self) -> None: + from sklearn.datasets import make_classification + + num_classes = 3 + X, y = make_classification( + n_samples=1000, n_informative=5, n_classes=num_classes + ) + Xy = xgb.DMatrix(data=X, label=y) + num_parallel_tree = 4 + num_boost_round = 16 + + class ResetStrategy(xgb.callback.TrainingCallback): + def after_iteration(self, model, epoch: int, evals_log) -> bool: + model.set_param({"multi_strategy": "multi_output_tree"}) + return False + + booster = xgb.train( + { + "num_parallel_tree": num_parallel_tree, + "num_class": num_classes, + "booster": "gbtree", + "objective": "multi:softprob", + "multi_strategy": "multi_output_tree", + "tree_method": "hist", + "base_score": 0, + }, + num_boost_round=num_boost_round, + dtrain=Xy, + callbacks=[ResetStrategy()] + ) + sliced = [t for t in booster] + assert len(sliced) == 16 + + predt0 = booster.predict(Xy, output_margin=True) + predt1 = np.zeros(predt0.shape) + for t in booster: + predt1 += t.predict(Xy, output_margin=True) + + np.testing.assert_allclose(predt0, predt1, atol=1e-5) + @pytest.mark.skipif(**tm.no_pandas()) def test_feature_info(self): import pandas as pd From 2a408f1f830002600ee98bafac55e23bceaae5d0 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 27 Mar 2023 18:12:13 +0800 Subject: [PATCH 14/14] Model schema. 
--- demo/json-model/json_parser.py | 3 --- doc/model.schema | 23 ++--------------------- 2 files changed, 2 insertions(+), 24 deletions(-) diff --git a/demo/json-model/json_parser.py b/demo/json-model/json_parser.py index 315ede61b232..b744d9569aea 100644 --- a/demo/json-model/json_parser.py +++ b/demo/json-model/json_parser.py @@ -162,9 +162,6 @@ def __init__(self, model: dict) -> None: # Load the trees self.num_trees = int(model_shape["num_trees"]) - self.leaf_size = int(model_shape["size_leaf_vector"]) - # Right now XGBoost doesn't support vector leaf yet - assert self.leaf_size == 0, str(self.leaf_size) trees: List[Tree] = [] for i in range(self.num_trees): diff --git a/doc/model.schema b/doc/model.schema index 07a871820b5a..b9e2da3058db 100644 --- a/doc/model.schema +++ b/doc/model.schema @@ -19,23 +19,7 @@ "type": "object", "properties": { "tree_param": { - "type": "object", - "properties": { - "num_nodes": { - "type": "string" - }, - "size_leaf_vector": { - "type": "string" - }, - "num_feature": { - "type": "string" - } - }, - "required": [ - "num_nodes", - "num_feature", - "size_leaf_vector" - ] + "$ref": "#/definitions/tree_param" }, "id": { "type": "integer" @@ -170,14 +154,11 @@ }, "num_parallel_tree": { "type": "string" - }, - "size_leaf_vector": { - "type": "string" } }, "required": [ "num_trees", - "size_leaf_vector" + "num_parallel_tree" ] }, "tree_param": {