From 5de306a871523821e027df93dc33ce0f03a80ad0 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 25 Jun 2021 17:29:32 +0800 Subject: [PATCH] Implement categorical data support for SHAP. * Add CPU implementation. * Update GPUTreeSHAP. * Add GPU implementation by defining custom split condition. --- gputreeshap | 2 +- include/xgboost/tree_model.h | 2 +- src/common/bitfield.h | 6 +- src/common/categorical.h | 6 + src/common/quantile.cuh | 8 +- src/predictor/gpu_predictor.cu | 187 ++++++++++++++++++---- src/tree/tree_model.cc | 19 +-- tests/cpp/predictor/test_cpu_predictor.cc | 5 + tests/cpp/predictor/test_gpu_predictor.cu | 5 + tests/cpp/predictor/test_predictor.cc | 74 +++++++++ tests/cpp/predictor/test_predictor.h | 2 + tests/python-gpu/test_gpu_prediction.py | 21 ++- 12 files changed, 287 insertions(+), 50 deletions(-) diff --git a/gputreeshap b/gputreeshap index 3310a30bb123..5bba198a7c2b 160000 --- a/gputreeshap +++ b/gputreeshap @@ -1 +1 @@ -Subproject commit 3310a30bb123a49ab12c58e03edc2479512d2f64 +Subproject commit 5bba198a7c2b3298dc766740965a4dffa7d8ffa4 diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 7d0d34a840a5..670bef938e9c 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -567,7 +567,7 @@ class RegTree : public Model { * \param condition_feature the index of the feature to fix * \param condition_fraction what fraction of the current weight matches our conditioning feature */ - void TreeShap(const RegTree::FVec& feat, bst_float* phi, unsigned node_index, + void TreeShap(const RegTree::FVec& feat, bst_float* phi, bst_node_t node_index, unsigned unique_depth, PathElement* parent_unique_path, bst_float parent_zero_fraction, bst_float parent_one_fraction, int parent_feature_index, int condition, diff --git a/src/common/bitfield.h b/src/common/bitfield.h index c727360b397c..2aa3f38545b3 100644 --- a/src/common/bitfield.h +++ b/src/common/bitfield.h @@ -87,9 +87,11 @@ struct BitFieldContainer { BitFieldContainer() = default; XGBOOST_DEVICE explicit BitFieldContainer(common::Span bits) : bits_{bits} {} XGBOOST_DEVICE BitFieldContainer(BitFieldContainer const& other) : bits_{other.bits_} {} + BitFieldContainer &operator=(BitFieldContainer const &that) = default; + BitFieldContainer &operator=(BitFieldContainer &&that) = default; - common::Span Bits() { return bits_; } - common::Span Bits() const { return bits_; } + XGBOOST_DEVICE common::Span Bits() { return bits_; } + XGBOOST_DEVICE common::Span Bits() const { return bits_; } /*\brief Compute the size of needed memory allocation. The returned value is in terms * of number of elements with `BitFieldContainer::value_type'. diff --git a/src/common/categorical.h b/src/common/categorical.h index 02899a9018c2..371ae1bd6d8d 100644 --- a/src/common/categorical.h +++ b/src/common/categorical.h @@ -42,6 +42,12 @@ inline XGBOOST_DEVICE bool Decision(common::Span cats, bst_cat_t return !s_cats.Check(cat); } +struct IsCatOp { + XGBOOST_DEVICE bool operator()(FeatureType ft) { + return ft == FeatureType::kCategorical; + } +}; + using CatBitField = LBitField32; using KCatBitField = CLBitField32; } // namespace common diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh index eab7290bf733..be8ea1834caf 100644 --- a/src/common/quantile.cuh +++ b/src/common/quantile.cuh @@ -8,6 +8,7 @@ #include "device_helpers.cuh" #include "quantile.h" #include "timer.h" +#include "categorical.h" namespace xgboost { namespace common { @@ -17,11 +18,6 @@ using WQSketch = WQuantileSketch; using SketchEntry = WQSketch::Entry; namespace detail { -struct IsCatOp { - XGBOOST_DEVICE bool operator()(FeatureType ft) { - return ft == FeatureType::kCategorical; - } -}; struct SketchUnique { XGBOOST_DEVICE bool operator()(SketchEntry const& a, SketchEntry const& b) const { return a.value - b.value == 0; @@ -122,7 +118,7 @@ class SketchContainer { has_categorical_ = !d_feature_types.empty() && thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types), - detail::IsCatOp{}); + common::IsCatOp{}); timer_.Init(__func__); } diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index c09de1cda522..5b91fc1bf508 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -415,6 +415,68 @@ class DeviceModel { } }; +struct ShapSplitCondition { + ShapSplitCondition() = default; + XGBOOST_DEVICE + ShapSplitCondition(float feature_lower_bound, float feature_upper_bound, + bool is_missing_branch, common::CatBitField cats) + : feature_lower_bound(feature_lower_bound), + feature_upper_bound(feature_upper_bound), + is_missing_branch(is_missing_branch), categories{std::move(cats)} { + assert(feature_lower_bound <= feature_upper_bound); + } + + /*! Feature values >= lower and < upper flow down this path. */ + float feature_lower_bound; + float feature_upper_bound; + /*! Feature value set to true flow down this path. */ + common::CatBitField categories; + /*! Do missing values flow down this path? */ + bool is_missing_branch; + + // Does this instance flow down this path? + XGBOOST_DEVICE bool EvaluateSplit(float x) const { + // is nan + if (isnan(x)) { + return is_missing_branch; + } + if (categories.Size() != 0) { + auto cat = static_cast(x); + return categories.Check(cat); + } else { + return x >= feature_lower_bound && x < feature_upper_bound; + } + } + + // the &= op in bitfiled is per cuda thread, this one loops over the entire + // bitfield. + XGBOOST_DEVICE static common::CatBitField Intersect(common::CatBitField l, + common::CatBitField r) { + if (l.Data() == r.Data()) { + return l; + } + if (l.Size() > r.Size()) { + thrust::swap(l, r); + } + for (size_t i = 0; i < r.Bits().size(); ++i) { + l.Bits()[i] &= r.Bits()[i]; + } + return l; + } + + // Combine two split conditions on the same feature + XGBOOST_DEVICE void Merge(ShapSplitCondition other) { + // Combine duplicate features + if (categories.Size() != 0 || other.categories.Size() != 0) { + categories = Intersect(categories, other.categories); + } else { + feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound); + feature_upper_bound = min(feature_upper_bound, other.feature_upper_bound); + } + is_missing_branch = is_missing_branch && other.is_missing_branch; + } +}; + struct PathInfo { int64_t leaf_position; // -1 not a leaf size_t length; @@ -422,11 +484,12 @@ struct PathInfo { }; // Transform model into path element form for GPUTreeShap -void ExtractPaths(dh::device_vector* paths, - const gbm::GBTreeModel& model, size_t tree_limit, - int gpu_id) { - DeviceModel device_model; - device_model.Init(model, 0, tree_limit, gpu_id); +void ExtractPaths( + dh::device_vector> *paths, + DeviceModel *model, dh::device_vector *path_categories, + int gpu_id) { + auto& device_model = *model; + dh::caching_device_vector info(device_model.nodes.Size()); dh::XGBCachingDeviceAllocator alloc; auto d_nodes = device_model.nodes.ConstDeviceSpan(); @@ -462,14 +525,45 @@ void ExtractPaths(dh::device_vector* paths, paths->resize(path_segments.back()); - auto d_paths = paths->data().get(); + auto d_paths = dh::ToSpan(*paths); auto d_info = info.data().get(); auto d_stats = device_model.stats.ConstDeviceSpan(); auto d_tree_group = device_model.tree_group.ConstDeviceSpan(); auto d_path_segments = path_segments.data().get(); + + auto d_split_types = device_model.split_types.ConstDeviceSpan(); + auto d_cat_segments = device_model.categories_tree_segments.ConstDeviceSpan(); + auto d_cat_node_segments = device_model.categories_node_segments.ConstDeviceSpan(); + + size_t max_cat = 0; + if (thrust::any_of(dh::tbegin(d_split_types), dh::tend(d_split_types), + common::IsCatOp{})) { + dh::PinnedMemory pinned; + auto h_max_cat = pinned.GetSpan(1); + auto max_elem_it = dh::MakeTransformIterator( + dh::tbegin(d_cat_node_segments), + [] __device__(RegTree::Segment seg) { return seg.size; }); + size_t max_cat_it = + thrust::max_element(thrust::device, max_elem_it, + max_elem_it + d_cat_node_segments.size()) - + max_elem_it; + dh::safe_cuda(cudaMemcpy(h_max_cat.data(), + d_cat_node_segments.data() + max_cat_it, + h_max_cat.size_bytes(), cudaMemcpyDeviceToHost)); + max_cat = h_max_cat[0].size; + CHECK_GE(max_cat, 1); + path_categories->resize(max_cat * paths->size()); + } + + auto d_model_categories = device_model.categories.DeviceSpan(); + common::Span d_path_categories = dh::ToSpan(*path_categories); + dh::LaunchN(gpu_id, info.size(), [=] __device__(size_t idx) { auto path_info = d_info[idx]; size_t tree_offset = d_tree_segments[path_info.tree_idx]; + TreeView tree{0, path_info.tree_idx, d_nodes, + d_tree_segments, d_split_types, d_cat_segments, + d_cat_node_segments, d_model_categories}; int group = d_tree_group[path_info.tree_idx]; size_t child_idx = path_info.leaf_position; auto child = d_nodes[child_idx]; @@ -481,20 +575,38 @@ void ExtractPaths(dh::device_vector* paths, double child_cover = d_stats[child_idx].sum_hess; double parent_cover = d_stats[parent_idx].sum_hess; double zero_fraction = child_cover / parent_cover; - auto parent = d_nodes[parent_idx]; + auto parent = tree.d_tree[child.Parent()]; + bool is_left_path = (tree_offset + parent.LeftChild()) == child_idx; bool is_missing_path = (!parent.DefaultLeft() && !is_left_path) || (parent.DefaultLeft() && is_left_path); - float lower_bound = is_left_path ? -inf : parent.SplitCond(); - float upper_bound = is_left_path ? parent.SplitCond() : inf; - d_paths[output_position--] = { - idx, parent.SplitIndex(), group, lower_bound, - upper_bound, is_missing_path, zero_fraction, v}; + + float lower_bound = -inf; + float upper_bound = inf; + common::CatBitField bits; + if (common::IsCat(tree.cats.split_type, child.Parent())) { + auto path_cats = d_path_categories.subspan(max_cat * output_position, max_cat); + size_t size = tree.cats.node_ptr[child.Parent()].size; + auto node_cats = tree.cats.categories.subspan(tree.cats.node_ptr[child.Parent()].beg, size); + SPAN_CHECK(path_cats.size() >= node_cats.size()); + for (size_t i = 0; i < node_cats.size(); ++i) { + path_cats[i] = is_left_path ? ~node_cats[i] : node_cats[i]; + } + bits = common::CatBitField{path_cats}; + } else { + lower_bound = is_left_path ? -inf : parent.SplitCond(); + upper_bound = is_left_path ? parent.SplitCond() : inf; + } + d_paths[output_position--] = + gpu_treeshap::PathElement{ + idx, parent.SplitIndex(), + group, ShapSplitCondition{lower_bound, upper_bound, is_missing_path, bits}, + zero_fraction, v}; child_idx = parent_idx; child = parent; } // Root node has feature -1 - d_paths[output_position] = {idx, -1, group, -inf, inf, false, 1.0, v}; + d_paths[output_position] = {idx, -1, group, ShapSplitCondition{-inf, inf, false, {}}, 1.0, v}; }); } @@ -696,11 +808,16 @@ class GPUPredictor : public xgboost::Predictor { void PredictContribution(DMatrix* p_fmat, HostDeviceVector* out_contribs, const gbm::GBTreeModel& model, unsigned tree_end, - std::vector const*, + std::vector const* tree_weights, bool approximate, int, unsigned) const override { + std::string not_implemented{"contribution is not implemented in GPU " + "predictor, use `cpu_predictor` instead."}; if (approximate) { - LOG(FATAL) << "Approximated contribution is not implemented in GPU Predictor."; + LOG(FATAL) << "Approximated " << not_implemented; + } + if (tree_weights != nullptr) { + LOG(FATAL) << "Dart booster feature " << not_implemented; } dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id)); out_contribs->SetDevice(generic_param_->gpu_id); @@ -718,16 +835,21 @@ class GPUPredictor : public xgboost::Predictor { out_contribs->Fill(0.0f); auto phis = out_contribs->DeviceSpan(); - dh::device_vector device_paths; - ExtractPaths(&device_paths, model, tree_end, generic_param_->gpu_id); + dh::device_vector> + device_paths; + DeviceModel d_model; + d_model.Init(model, 0, tree_end, generic_param_->gpu_id); + dh::device_vector categories; + ExtractPaths(&device_paths, &d_model, &categories, generic_param_->gpu_id); for (auto& batch : p_fmat->GetBatches()) { batch.data.SetDevice(generic_param_->gpu_id); batch.offset.SetDevice(generic_param_->gpu_id); SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), model.learner_model_param->num_feature); - gpu_treeshap::GPUTreeShap( - X, device_paths.begin(), device_paths.end(), ngroup, - phis.data() + batch.base_rowid * contributions_columns, phis.size()); + auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns; + gpu_treeshap::GPUTreeShap>( + X, device_paths.begin(), device_paths.end(), ngroup, begin, + dh::tend(phis)); } // Add the base margin term to last column p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id); @@ -746,11 +868,15 @@ class GPUPredictor : public xgboost::Predictor { HostDeviceVector* out_contribs, const gbm::GBTreeModel& model, unsigned tree_end, - std::vector const*, + std::vector const* tree_weights, bool approximate) const override { + std::string not_implemented{"contribution is not implemented in GPU " + "predictor, use `cpu_predictor` instead."}; if (approximate) { - LOG(FATAL) << "[Internal error]: " << __func__ - << " approximate is not implemented in GPU Predictor."; + LOG(FATAL) << "Approximated " << not_implemented; + } + if (tree_weights != nullptr) { + LOG(FATAL) << "Dart booster feature " << not_implemented; } dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id)); out_contribs->SetDevice(generic_param_->gpu_id); @@ -769,16 +895,21 @@ class GPUPredictor : public xgboost::Predictor { out_contribs->Fill(0.0f); auto phis = out_contribs->DeviceSpan(); - dh::device_vector device_paths; - ExtractPaths(&device_paths, model, tree_end, generic_param_->gpu_id); + dh::device_vector> + device_paths; + DeviceModel d_model; + d_model.Init(model, 0, tree_end, generic_param_->gpu_id); + dh::device_vector categories; + ExtractPaths(&device_paths, &d_model, &categories, generic_param_->gpu_id); for (auto& batch : p_fmat->GetBatches()) { batch.data.SetDevice(generic_param_->gpu_id); batch.offset.SetDevice(generic_param_->gpu_id); SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), model.learner_model_param->num_feature); - gpu_treeshap::GPUTreeShapInteractions( - X, device_paths.begin(), device_paths.end(), ngroup, - phis.data() + batch.base_rowid * contributions_columns, phis.size()); + auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns; + gpu_treeshap::GPUTreeShapInteractions>( + X, device_paths.begin(), device_paths.end(), ngroup, begin, + dh::tend(phis)); } // Add the base margin term to last column p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id); diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index fdd39f9bd995..b30f0d65330c 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -1245,7 +1245,7 @@ bst_float UnwoundPathSum(const PathElement *unique_path, unsigned unique_depth, // recursive computation of SHAP values for a decision tree void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi, - unsigned node_index, unsigned unique_depth, + bst_node_t node_index, unsigned unique_depth, PathElement *parent_unique_path, bst_float parent_zero_fraction, bst_float parent_one_fraction, int parent_feature_index, @@ -1278,16 +1278,13 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi, // internal node } else { // find which branch is "hot" (meaning x would follow it) - unsigned hot_index = 0; - if (feat.IsMissing(split_index)) { - hot_index = node.DefaultChild(); - } else if (feat.GetFvalue(split_index) < node.SplitCond()) { - hot_index = node.LeftChild(); - } else { - hot_index = node.RightChild(); - } - const unsigned cold_index = (static_cast(hot_index) == node.LeftChild() ? - node.RightChild() : node.LeftChild()); + auto const &cats = this->GetCategoriesMatrix(); + bst_node_t hot_index = predictor::GetNextNode( + node, node_index, feat.GetFvalue(split_index), + feat.IsMissing(split_index), cats); + + const auto cold_index = + (hot_index == node.LeftChild() ? node.RightChild() : node.LeftChild()); const bst_float w = this->Stat(node_index).sum_hess; const bst_float hot_zero_fraction = this->Stat(hot_index).sum_hess / w; const bst_float cold_zero_fraction = this->Stat(cold_index).sum_hess / w; diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index 13808e5c9a54..92dca55034bb 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -86,6 +86,11 @@ TEST(CpuPredictor, Basic) { } } + +TEST(CpuPredictor, IterationRange) { + TestIterationRange("cpu_predictor"); +} + TEST(CpuPredictor, ExternalMemory) { dmlc::TemporaryDirectory tmpdir; std::string filename = tmpdir.path + "/big.libsvm"; diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index d4a5dce63e29..722d24299fa2 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -224,6 +224,11 @@ TEST(GPUPredictor, Shap) { } } +TEST(GPUPredictor, IterationRange) { + TestIterationRange("gpu_predictor"); +} + + TEST(GPUPredictor, CategoricalPrediction) { TestCategoricalPrediction("gpu_predictor"); } diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc index da5f23090421..855ede0d7a38 100644 --- a/tests/cpp/predictor/test_predictor.cc +++ b/tests/cpp/predictor/test_predictor.cc @@ -281,4 +281,78 @@ void TestCategoricalPredictLeaf(StringView name) { predictor->PredictLeaf(m.get(), &out_predictions.predictions, model); ASSERT_EQ(out_predictions.predictions.HostVector()[0], 1); } + + +void TestIterationRange(std::string name) { + size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3; + auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses); + std::unique_ptr learner{Learner::Create({dmat})}; + + learner->SetParams(Args{{"num_parallel_tree", std::to_string(kForest)}, + {"predictor", name}}); + + size_t kIters = 10; + for (size_t i = 0; i < kIters; ++i) { + learner->UpdateOneIter(i, dmat); + } + + bool bound = false; + std::unique_ptr sliced {learner->Slice(0, 3, 1, &bound)}; + ASSERT_FALSE(bound); + + HostDeviceVector out_predt_sliced; + HostDeviceVector out_predt_ranged; + + // margin + { + sliced->Predict(dmat, true, &out_predt_sliced, 0, 0, false, false, false, + false, false); + + learner->Predict(dmat, true, &out_predt_ranged, 0, 3, false, false, false, + false, false); + + auto const &h_sliced = out_predt_sliced.HostVector(); + auto const &h_range = out_predt_ranged.HostVector(); + ASSERT_EQ(h_sliced.size(), h_range.size()); + ASSERT_EQ(h_sliced, h_range); + } + + // SHAP + { + sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false, + true, false, false); + + learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, false, true, + false, false); + + auto const &h_sliced = out_predt_sliced.HostVector(); + auto const &h_range = out_predt_ranged.HostVector(); + ASSERT_EQ(h_sliced.size(), h_range.size()); + ASSERT_EQ(h_sliced, h_range); + } + + // SHAP interaction + { + sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false, + false, false, true); + learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, false, false, + false, true); + auto const &h_sliced = out_predt_sliced.HostVector(); + auto const &h_range = out_predt_ranged.HostVector(); + ASSERT_EQ(h_sliced.size(), h_range.size()); + ASSERT_EQ(h_sliced, h_range); + } + + // Leaf + { + sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, true, + false, false, false); + learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, true, false, + false, false); + auto const &h_sliced = out_predt_sliced.HostVector(); + auto const &h_range = out_predt_ranged.HostVector(); + ASSERT_EQ(h_sliced.size(), h_range.size()); + ASSERT_EQ(h_sliced, h_range); + } +} } // namespace xgboost diff --git a/tests/cpp/predictor/test_predictor.h b/tests/cpp/predictor/test_predictor.h index d5eccf6a0668..0cb49d4ad773 100644 --- a/tests/cpp/predictor/test_predictor.h +++ b/tests/cpp/predictor/test_predictor.h @@ -68,6 +68,8 @@ void TestPredictionWithLesserFeatures(std::string preditor_name); void TestCategoricalPrediction(std::string name); void TestCategoricalPredictLeaf(StringView name); + +void TestIterationRange(std::string name); } // namespace xgboost #endif // XGBOOST_TEST_PREDICTOR_H_ diff --git a/tests/python-gpu/test_gpu_prediction.py b/tests/python-gpu/test_gpu_prediction.py index ce73034a4a5d..12106836bad6 100644 --- a/tests/python-gpu/test_gpu_prediction.py +++ b/tests/python-gpu/test_gpu_prediction.py @@ -5,7 +5,7 @@ import xgboost as xgb from xgboost.compat import PANDAS_INSTALLED -from hypothesis import given, strategies, assume, settings, note +from hypothesis import given, strategies, assume, settings if PANDAS_INSTALLED: from hypothesis.extra.pandas import column, data_frames, range_indexes @@ -275,6 +275,25 @@ def test_shap_interactions(self, num_rounds, dataset, param): margin, 1e-3, 1e-3) + def test_shap_categorical(self): + X, y = tm.make_categorical(100, 20, 7, False) + Xy = xgb.DMatrix(X, y, enable_categorical=True) + booster = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=10) + + booster.set_param({"predictor": "gpu_predictor"}) + shap = booster.predict(Xy, pred_contribs=True) + margin = booster.predict(Xy, output_margin=True) + np.testing.assert_allclose( + np.sum(shap, axis=len(shap.shape) - 1), margin, rtol=1e-3 + ) + + booster.set_param({"predictor": "cpu_predictor"}) + shap = booster.predict(Xy, pred_contribs=True) + margin = booster.predict(Xy, output_margin=True) + np.testing.assert_allclose( + np.sum(shap, axis=len(shap.shape) - 1), margin, rtol=1e-3 + ) + def test_predict_leaf_basic(self): gpu_leaf = run_predict_leaf('gpu_predictor') cpu_leaf = run_predict_leaf('cpu_predictor')