-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SYCL] Add basic features for QuantileHistMaker (#10174)
--------- Co-authored-by: Dmitry Razdoburdin <>
- Loading branch information
1 parent
882f413
commit 6e5c335
Showing
3 changed files
with
201 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/*! | ||
* Copyright 2017-2024 by Contributors | ||
* \file updater_quantile_hist.cc | ||
*/ | ||
#include <vector> | ||
|
||
#pragma GCC diagnostic push | ||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare" | ||
#pragma GCC diagnostic ignored "-W#pragma-messages" | ||
#include "xgboost/tree_updater.h" | ||
#pragma GCC diagnostic pop | ||
|
||
#include "xgboost/logging.h" | ||
|
||
#include "updater_quantile_hist.h" | ||
#include "../data.h" | ||
|
||
namespace xgboost { | ||
namespace sycl { | ||
namespace tree { | ||
|
||
DMLC_REGISTRY_FILE_TAG(updater_quantile_hist_sycl); | ||
|
||
DMLC_REGISTER_PARAMETER(HistMakerTrainParam); | ||
|
||
void QuantileHistMaker::Configure(const Args& args) { | ||
const DeviceOrd device_spec = ctx_->Device(); | ||
qu_ = device_manager.GetQueue(device_spec); | ||
|
||
param_.UpdateAllowUnknown(args); | ||
hist_maker_param_.UpdateAllowUnknown(args); | ||
} | ||
|
||
void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param, | ||
linalg::Matrix<GradientPair>* gpair, | ||
DMatrix *dmat, | ||
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position, | ||
const std::vector<RegTree *> &trees) { | ||
LOG(FATAL) << "Not Implemented yet"; | ||
} | ||
|
||
bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data, | ||
linalg::MatrixView<float> out_preds) { | ||
LOG(FATAL) << "Not Implemented yet"; | ||
} | ||
|
||
XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker_sycl") | ||
.describe("Grow tree using quantized histogram with SYCL.") | ||
.set_body( | ||
[](Context const* ctx, ObjInfo const * task) { | ||
return new QuantileHistMaker(ctx, task); | ||
}); | ||
} // namespace tree | ||
} // namespace sycl | ||
} // namespace xgboost |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
/*! | ||
* Copyright 2017-2024 by Contributors | ||
* \file updater_quantile_hist.h | ||
*/ | ||
#ifndef PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_ | ||
#define PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_ | ||
|
||
#include <dmlc/timer.h> | ||
#include <xgboost/tree_updater.h> | ||
|
||
#include <vector> | ||
|
||
#include "../data/gradient_index.h" | ||
#include "../common/hist_util.h" | ||
#include "../common/row_set.h" | ||
#include "../common/partition_builder.h" | ||
#include "split_evaluator.h" | ||
#include "../device_manager.h" | ||
|
||
#include "xgboost/data.h" | ||
#include "xgboost/json.h" | ||
#include "../../src/tree/constraints.h" | ||
#include "../../src/common/random.h" | ||
|
||
namespace xgboost { | ||
namespace sycl { | ||
namespace tree { | ||
|
||
// training parameters specific to this algorithm | ||
struct HistMakerTrainParam | ||
: public XGBoostParameter<HistMakerTrainParam> { | ||
bool single_precision_histogram = false; | ||
// declare parameters | ||
DMLC_DECLARE_PARAMETER(HistMakerTrainParam) { | ||
DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe( | ||
"Use single precision to build histograms."); | ||
} | ||
}; | ||
|
||
/*! \brief construct a tree using quantized feature values with SYCL backend*/ | ||
class QuantileHistMaker: public TreeUpdater { | ||
public: | ||
QuantileHistMaker(Context const* ctx, ObjInfo const * task) : | ||
TreeUpdater(ctx), task_{task} { | ||
updater_monitor_.Init("SYCLQuantileHistMaker"); | ||
} | ||
void Configure(const Args& args) override; | ||
|
||
void Update(xgboost::tree::TrainParam const *param, | ||
linalg::Matrix<GradientPair>* gpair, | ||
DMatrix* dmat, | ||
xgboost::common::Span<HostDeviceVector<bst_node_t>> out_position, | ||
const std::vector<RegTree*>& trees) override; | ||
|
||
bool UpdatePredictionCache(const DMatrix* data, | ||
linalg::MatrixView<float> out_preds) override; | ||
|
||
void LoadConfig(Json const& in) override { | ||
auto const& config = get<Object const>(in); | ||
FromJson(config.at("train_param"), &this->param_); | ||
FromJson(config.at("sycl_hist_train_param"), &this->hist_maker_param_); | ||
} | ||
|
||
void SaveConfig(Json* p_out) const override { | ||
auto& out = *p_out; | ||
out["train_param"] = ToJson(param_); | ||
out["sycl_hist_train_param"] = ToJson(hist_maker_param_); | ||
} | ||
|
||
char const* Name() const override { | ||
return "grow_quantile_histmaker_sycl"; | ||
} | ||
|
||
protected: | ||
HistMakerTrainParam hist_maker_param_; | ||
// training parameter | ||
xgboost::tree::TrainParam param_; | ||
|
||
xgboost::common::Monitor updater_monitor_; | ||
|
||
::sycl::queue qu_; | ||
DeviceManager device_manager; | ||
ObjInfo const *task_{nullptr}; | ||
}; | ||
|
||
|
||
} // namespace tree | ||
} // namespace sycl | ||
} // namespace xgboost | ||
|
||
#endif // PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/** | ||
* Copyright 2020-2024 by XGBoost contributors | ||
*/ | ||
#include <gtest/gtest.h> | ||
|
||
#pragma GCC diagnostic push | ||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare" | ||
#pragma GCC diagnostic ignored "-W#pragma-messages" | ||
#include <xgboost/json.h> | ||
#include <xgboost/task.h> | ||
#include "../../../plugin/sycl/tree/updater_quantile_hist.h" // for QuantileHistMaker | ||
#pragma GCC diagnostic pop | ||
|
||
namespace xgboost::sycl::tree { | ||
TEST(SyclQuantileHistMaker, Basic) { | ||
Context ctx; | ||
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); | ||
|
||
ObjInfo task{ObjInfo::kRegression}; | ||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)}; | ||
|
||
ASSERT_EQ(updater->Name(), "grow_quantile_histmaker_sycl"); | ||
} | ||
|
||
TEST(SyclQuantileHistMaker, JsonIO) { | ||
Context ctx; | ||
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); | ||
|
||
ObjInfo task{ObjInfo::kRegression}; | ||
Json config {Object()}; | ||
{ | ||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)}; | ||
updater->Configure({{"max_depth", std::to_string(42)}}); | ||
updater->Configure({{"single_precision_histogram", std::to_string(true)}}); | ||
updater->SaveConfig(&config); | ||
} | ||
|
||
{ | ||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)}; | ||
updater->LoadConfig(config); | ||
|
||
Json new_config {Object()}; | ||
updater->SaveConfig(&new_config); | ||
|
||
ASSERT_EQ(config, new_config); | ||
|
||
auto max_depth = atoi(get<String const>(new_config["train_param"]["max_depth"]).c_str()); | ||
ASSERT_EQ(max_depth, 42); | ||
|
||
auto single_precision_histogram = atoi(get<String const>(new_config["sycl_hist_train_param"]["single_precision_histogram"]).c_str()); | ||
ASSERT_EQ(single_precision_histogram, 1); | ||
} | ||
|
||
} | ||
} // namespace xgboost::sycl::tree |