Skip to content

Commit

Permalink
Implement contribution prediction with QuantileDMatrix (#10043)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Jiaming Yuan <[email protected]>
  • Loading branch information
ldesreumaux and trivialfis authored Feb 19, 2024
1 parent 057f03c commit edf501d
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 55 deletions.
128 changes: 77 additions & 51 deletions src/predictor/cpu_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,67 @@ class CPUPredictor : public Predictor {
}
}

template <typename DataView>
void PredictContributionKernel(DataView batch, const MetaInfo& info,
const gbm::GBTreeModel& model,
const std::vector<bst_float>* tree_weights,
std::vector<std::vector<float>>* mean_values,
std::vector<RegTree::FVec>* feat_vecs,
std::vector<bst_float>* contribs, uint32_t ntree_limit,
bool approximate, int condition,
unsigned condition_feature) const {
const int num_feature = model.learner_model_param->num_feature;
const int ngroup = model.learner_model_param->num_output_group;
CHECK_NE(ngroup, 0);
size_t const ncolumns = num_feature + 1;
CHECK_NE(ncolumns, 0);
auto base_margin = info.base_margin_.View(ctx_->Device());
auto base_score = model.learner_model_param->BaseScore(ctx_->Device())(0);

// parallel over local batch
common::ParallelFor(batch.Size(), this->ctx_->Threads(), [&](auto i) {
auto row_idx = batch.base_rowid + i;
RegTree::FVec &feats = (*feat_vecs)[omp_get_thread_num()];
if (feats.Size() == 0) {
feats.Init(num_feature);
}
std::vector<bst_float> this_tree_contribs(ncolumns);
// loop over all classes
for (int gid = 0; gid < ngroup; ++gid) {
bst_float* p_contribs = &(*contribs)[(row_idx * ngroup + gid) * ncolumns];
feats.Fill(batch[i]);
// calculate contributions
for (unsigned j = 0; j < ntree_limit; ++j) {
auto *tree_mean_values = &mean_values->at(j);
std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0);
if (model.tree_info[j] != gid) {
continue;
}
if (!approximate) {
CalculateContributions(*model.trees[j], feats, tree_mean_values,
&this_tree_contribs[0], condition, condition_feature);
} else {
model.trees[j]->CalculateContributionsApprox(
feats, tree_mean_values, &this_tree_contribs[0]);
}
for (size_t ci = 0; ci < ncolumns; ++ci) {
p_contribs[ci] +=
this_tree_contribs[ci] *
(tree_weights == nullptr ? 1 : (*tree_weights)[j]);
}
}
feats.Drop();
// add base margin to BIAS
if (base_margin.Size() != 0) {
CHECK_EQ(base_margin.Shape(1), ngroup);
p_contribs[ncolumns - 1] += base_margin(row_idx, gid);
} else {
p_contribs[ncolumns - 1] += base_score;
}
}
});
}

public:
explicit CPUPredictor(Context const *ctx) : Predictor::Predictor{ctx} {}

Expand Down Expand Up @@ -861,18 +922,14 @@ class CPUPredictor : public Predictor {
CHECK(!p_fmat->Info().IsColumnSplit())
<< "Predict contribution support for column-wise data split is not yet implemented.";
auto const n_threads = this->ctx_->Threads();
const int num_feature = model.learner_model_param->num_feature;
std::vector<RegTree::FVec> feat_vecs;
InitThreadTemp(n_threads, &feat_vecs);
const MetaInfo& info = p_fmat->Info();
// number of valid trees
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
ntree_limit = static_cast<unsigned>(model.trees.size());
}
const int ngroup = model.learner_model_param->num_output_group;
CHECK_NE(ngroup, 0);
size_t const ncolumns = num_feature + 1;
CHECK_NE(ncolumns, 0);
size_t const ncolumns = model.learner_model_param->num_feature + 1;
// allocate space for (number of features + bias) times the number of rows
std::vector<bst_float>& contribs = out_contribs->HostVector();
contribs.resize(info.num_row_ * ncolumns * model.learner_model_param->num_output_group);
Expand All @@ -884,53 +941,22 @@ class CPUPredictor : public Predictor {
common::ParallelFor(ntree_limit, n_threads, [&](bst_omp_uint i) {
FillNodeMeanValues(model.trees[i].get(), &(mean_values[i]));
});
auto base_margin = info.base_margin_.View(ctx_->Device());
auto base_score = model.learner_model_param->BaseScore(ctx_->Device())(0);
// start collecting the contributions
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
auto page = batch.GetView();
// parallel over local batch
common::ParallelFor(batch.Size(), n_threads, [&](auto i) {
auto row_idx = batch.base_rowid + i;
RegTree::FVec &feats = feat_vecs[omp_get_thread_num()];
if (feats.Size() == 0) {
feats.Init(num_feature);
}
std::vector<bst_float> this_tree_contribs(ncolumns);
// loop over all classes
for (int gid = 0; gid < ngroup; ++gid) {
bst_float* p_contribs = &contribs[(row_idx * ngroup + gid) * ncolumns];
feats.Fill(page[i]);
// calculate contributions
for (unsigned j = 0; j < ntree_limit; ++j) {
auto *tree_mean_values = &mean_values.at(j);
std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0);
if (model.tree_info[j] != gid) {
continue;
}
if (!approximate) {
CalculateContributions(*model.trees[j], feats, tree_mean_values,
&this_tree_contribs[0], condition, condition_feature);
} else {
model.trees[j]->CalculateContributionsApprox(
feats, tree_mean_values, &this_tree_contribs[0]);
}
for (size_t ci = 0; ci < ncolumns; ++ci) {
p_contribs[ci] +=
this_tree_contribs[ci] *
(tree_weights == nullptr ? 1 : (*tree_weights)[j]);
}
}
feats.Drop();
// add base margin to BIAS
if (base_margin.Size() != 0) {
CHECK_EQ(base_margin.Shape(1), ngroup);
p_contribs[ncolumns - 1] += base_margin(row_idx, gid);
} else {
p_contribs[ncolumns - 1] += base_score;
}
}
});
if (!p_fmat->PageExists<SparsePage>()) {
std::vector<Entry> workspace(info.num_col_ * kUnroll * n_threads);
auto ft = p_fmat->Info().feature_types.ConstHostVector();
for (const auto &batch : p_fmat->GetBatches<GHistIndexMatrix>(ctx_, {})) {
PredictContributionKernel(
GHistIndexMatrixView{batch, info.num_col_, ft, workspace, n_threads},
info, model, tree_weights, &mean_values, &feat_vecs, &contribs, ntree_limit,
approximate, condition, condition_feature);
}
} else {
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
PredictContributionKernel(
SparsePageView{&batch}, info, model, tree_weights, &mean_values, &feat_vecs,
&contribs, ntree_limit, approximate, condition, condition_feature);
}
}
}

Expand Down
6 changes: 6 additions & 0 deletions src/predictor/gpu_predictor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1042,6 +1042,9 @@ class GPUPredictor : public xgboost::Predictor {
if (tree_weights != nullptr) {
LOG(FATAL) << "Dart booster feature " << not_implemented;
}
if (!p_fmat->PageExists<SparsePage>()) {
LOG(FATAL) << "SHAP value for QuantileDMatrix is not yet implemented for GPU.";
}
CHECK(!p_fmat->Info().IsColumnSplit())
<< "Predict contribution support for column-wise data split is not yet implemented.";
dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
Expand Down Expand Up @@ -1102,6 +1105,9 @@ class GPUPredictor : public xgboost::Predictor {
if (tree_weights != nullptr) {
LOG(FATAL) << "Dart booster feature " << not_implemented;
}
if (!p_fmat->PageExists<SparsePage>()) {
LOG(FATAL) << "SHAP value for QuantileDMatrix is not yet implemented for GPU.";
}
dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
out_contribs->SetDevice(ctx_->Device());
if (tree_end == 0 || tree_end > model.trees.size()) {
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/predictor/test_cpu_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ TEST(CPUPredictor, GHistIndexTraining) {
auto adapter = data::ArrayAdapter(columnar.c_str());
std::shared_ptr<DMatrix> p_full{
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1)};
TestTrainingPrediction(&ctx, kRows, kBins, p_full, p_hist);
TestTrainingPrediction(&ctx, kRows, kBins, p_full, p_hist, true);
}

TEST(CPUPredictor, CategoricalPrediction) {
Expand Down
25 changes: 24 additions & 1 deletion tests/cpp/predictor/test_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ TEST(Predictor, PredictionCache) {
}

void TestTrainingPrediction(Context const *ctx, size_t rows, size_t bins,
std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist) {
std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist,
bool check_contribs) {
size_t constexpr kCols = 16;
size_t constexpr kClasses = 3;
size_t constexpr kIters = 3;
Expand Down Expand Up @@ -161,6 +162,28 @@ void TestTrainingPrediction(Context const *ctx, size_t rows, size_t bins,
for (size_t i = 0; i < rows; ++i) {
EXPECT_NEAR(from_hist.ConstHostVector()[i], from_full.ConstHostVector()[i], kRtEps);
}

if (check_contribs) {
// Contributions
HostDeviceVector<float> from_full_contribs;
learner->Predict(p_full, false, &from_full_contribs, 0, 0, false, false, true);
HostDeviceVector<float> from_hist_contribs;
learner->Predict(p_hist, false, &from_hist_contribs, 0, 0, false, false, true);
for (size_t i = 0; i < from_full_contribs.ConstHostVector().size(); ++i) {
EXPECT_NEAR(from_hist_contribs.ConstHostVector()[i],
from_full_contribs.ConstHostVector()[i], kRtEps);
}

// Contributions (approximate method)
HostDeviceVector<float> from_full_approx_contribs;
learner->Predict(p_full, false, &from_full_approx_contribs, 0, 0, false, false, false, true);
HostDeviceVector<float> from_hist_approx_contribs;
learner->Predict(p_hist, false, &from_hist_approx_contribs, 0, 0, false, false, false, true);
for (size_t i = 0; i < from_full_approx_contribs.ConstHostVector().size(); ++i) {
EXPECT_NEAR(from_hist_approx_contribs.ConstHostVector()[i],
from_full_approx_contribs.ConstHostVector()[i], kRtEps);
}
}
}

void TestInplacePrediction(Context const *ctx, std::shared_ptr<DMatrix> x, bst_row_t rows,
Expand Down
3 changes: 2 additions & 1 deletion tests/cpp/predictor/test_predictor.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ void TestBasic(DMatrix* dmat, Context const * ctx);

// p_full and p_hist should come from the same data set.
void TestTrainingPrediction(Context const* ctx, size_t rows, size_t bins,
std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist);
std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist,
bool check_contribs = false);

void TestInplacePrediction(Context const* ctx, std::shared_ptr<DMatrix> x, bst_row_t rows,
bst_feature_t cols);
Expand Down
28 changes: 27 additions & 1 deletion tests/python/test_shap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import re

import numpy as np
import scipy
import scipy.special

import xgboost as xgb
Expand Down Expand Up @@ -256,3 +255,30 @@ def interaction_value(trees, x, i, j):
brute_force[-1, -1] += base_score
fast_method = bst.predict(xgb.DMatrix(X[0:1, :]), pred_interactions=True)
assert np.linalg.norm(brute_force - fast_method[0, :, :]) < 1e-4

def test_shap_values(self) -> None:
from sklearn.datasets import make_classification, make_regression

def assert_same(X: np.ndarray, y: np.ndarray) -> None:
Xy = xgb.DMatrix(X, y)
booster = xgb.train({}, Xy, num_boost_round=4)
shap_dm = booster.predict(Xy, pred_contribs=True)
Xy = xgb.QuantileDMatrix(X, y)
shap_qdm = booster.predict(Xy, pred_contribs=True)
np.testing.assert_allclose(shap_dm, shap_qdm)

margin = booster.predict(Xy, output_margin=True)
np.testing.assert_allclose(
np.sum(shap_qdm, axis=len(shap_qdm.shape) - 1), margin, 1e-3, 1e-3
)

shap_dm = booster.predict(Xy, pred_interactions=True)
Xy = xgb.QuantileDMatrix(X, y)
shap_qdm = booster.predict(Xy, pred_interactions=True)
np.testing.assert_allclose(shap_dm, shap_qdm)

X, y = make_regression()
assert_same(X, y)

X, y = make_classification()
assert_same(X, y)

0 comments on commit edf501d

Please sign in to comment.