Categorical split support for SHAP #7053

Merged: 1 commit, Jun 25, 2021
2 changes: 1 addition & 1 deletion include/xgboost/tree_model.h
@@ -567,7 +567,7 @@ class RegTree : public Model {
* \param condition_feature the index of the feature to fix
* \param condition_fraction what fraction of the current weight matches our conditioning feature
*/
-  void TreeShap(const RegTree::FVec& feat, bst_float* phi, unsigned node_index,
+  void TreeShap(const RegTree::FVec& feat, bst_float* phi, bst_node_t node_index,
unsigned unique_depth, PathElement* parent_unique_path,
bst_float parent_zero_fraction, bst_float parent_one_fraction,
int parent_feature_index, int condition,
6 changes: 4 additions & 2 deletions src/common/bitfield.h
@@ -87,9 +87,11 @@ struct BitFieldContainer {
BitFieldContainer() = default;
XGBOOST_DEVICE explicit BitFieldContainer(common::Span<value_type> bits) : bits_{bits} {}
XGBOOST_DEVICE BitFieldContainer(BitFieldContainer const& other) : bits_{other.bits_} {}
+  BitFieldContainer &operator=(BitFieldContainer const &that) = default;
+  BitFieldContainer &operator=(BitFieldContainer &&that) = default;

-  common::Span<value_type> Bits() { return bits_; }
-  common::Span<value_type const> Bits() const { return bits_; }
+  XGBOOST_DEVICE common::Span<value_type> Bits() { return bits_; }
+  XGBOOST_DEVICE common::Span<value_type const> Bits() const { return bits_; }

/*\brief Compute the size of needed memory allocation. The returned value is in terms
* of number of elements with `BitFieldContainer::value_type'.
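The comment above describes `ComputeStorageSize`; the computation it names amounts to a ceiling division over the word width. A minimal host-side sketch, assuming 32-bit words as in `LBitField32`:

```cpp
#include <cstddef>

// Number of value_type words needed to hold n_bits bits (ceiling division).
constexpr std::size_t ComputeStorageSize(std::size_t n_bits,
                                         std::size_t bits_per_word = 32) {
  return (n_bits + bits_per_word - 1) / bits_per_word;
}

static_assert(ComputeStorageSize(0) == 0, "no bits, no storage");
static_assert(ComputeStorageSize(33) == 2, "33 bits span two 32-bit words");
```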
6 changes: 6 additions & 0 deletions src/common/categorical.h
@@ -42,6 +42,12 @@ inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, bst_cat_t
return !s_cats.Check(cat);
}

+struct IsCatOp {
+  XGBOOST_DEVICE bool operator()(FeatureType ft) {
+    return ft == FeatureType::kCategorical;
+  }
+};

using CatBitField = LBitField32;
using KCatBitField = CLBitField32;
} // namespace common
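`Decision` treats the category bitfield as a membership set: a set bit means the category matches the split and flows right, hence the negation. A standalone sketch of the same semantics, with a plain `std::vector<uint32_t>` standing in for `CatBitField` (helper names are hypothetical):

```cpp
#include <cstdint>
#include <vector>

// One bit per category value, packed into 32-bit words.
bool CheckBit(const std::vector<uint32_t>& bits, uint32_t cat) {
  const uint32_t word = cat / 32, pos = cat % 32;
  return word < bits.size() && ((bits[word] >> pos) & 1u);
}

// Mirrors Decision(): categories in the set flow right, all others left.
bool GoesLeft(const std::vector<uint32_t>& cats, uint32_t cat) {
  return !CheckBit(cats, cat);
}
```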
8 changes: 2 additions & 6 deletions src/common/quantile.cuh
@@ -8,6 +8,7 @@
#include "device_helpers.cuh"
#include "quantile.h"
#include "timer.h"
#include "categorical.h"

namespace xgboost {
namespace common {
@@ -17,11 +18,6 @@ using WQSketch = WQuantileSketch<bst_float, bst_float>;
using SketchEntry = WQSketch::Entry;

namespace detail {
-struct IsCatOp {
-  XGBOOST_DEVICE bool operator()(FeatureType ft) {
-    return ft == FeatureType::kCategorical;
-  }
-};
struct SketchUnique {
XGBOOST_DEVICE bool operator()(SketchEntry const& a, SketchEntry const& b) const {
return a.value - b.value == 0;
@@ -122,7 +118,7 @@ class SketchContainer {
has_categorical_ =
!d_feature_types.empty() &&
thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
-                     detail::IsCatOp{});
+                     common::IsCatOp{});

timer_.Init(__func__);
}
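Moving `IsCatOp` from `detail::` into `common::` leaves the detection logic unchanged; `has_categorical_` is still an `any_of` over the feature-type span. A host-side sketch of the same check (the real code runs it through `thrust::any_of` on device):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

enum class FeatureType : std::uint8_t { kNumerical = 0, kCategorical = 1 };

struct IsCatOp {
  bool operator()(FeatureType ft) const {
    return ft == FeatureType::kCategorical;
  }
};

bool HasCategorical(const std::vector<FeatureType>& types) {
  return !types.empty() && std::any_of(types.begin(), types.end(), IsCatOp{});
}
```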
187 changes: 159 additions & 28 deletions src/predictor/gpu_predictor.cu
@@ -415,18 +415,81 @@ class DeviceModel {
}
};

+struct ShapSplitCondition {
+  ShapSplitCondition() = default;
+  XGBOOST_DEVICE
+  ShapSplitCondition(float feature_lower_bound, float feature_upper_bound,
+                     bool is_missing_branch, common::CatBitField cats)
+      : feature_lower_bound(feature_lower_bound),
+        feature_upper_bound(feature_upper_bound),
+        is_missing_branch(is_missing_branch), categories{std::move(cats)} {
+    assert(feature_lower_bound <= feature_upper_bound);
+  }
+
+  /*! Feature values >= lower and < upper flow down this path. */
+  float feature_lower_bound;
+  float feature_upper_bound;
+  /*! Categories whose bit is set flow down this path. */
+  common::CatBitField categories;
+  /*! Do missing values flow down this path? */
+  bool is_missing_branch;
+
+  // Does this instance flow down this path?
+  XGBOOST_DEVICE bool EvaluateSplit(float x) const {
+    // is nan
+    if (isnan(x)) {
+      return is_missing_branch;
+    }
+    if (categories.Size() != 0) {
+      auto cat = static_cast<uint32_t>(x);
+      return categories.Check(cat);
+    } else {
+      return x >= feature_lower_bound && x < feature_upper_bound;
+    }
+  }
+
+  // The &= operator in BitField is per CUDA thread; this one loops over the
+  // entire bitfield.
+  XGBOOST_DEVICE static common::CatBitField Intersect(common::CatBitField l,
+                                                      common::CatBitField r) {
+    if (l.Data() == r.Data()) {
+      return l;
+    }
+    if (l.Size() > r.Size()) {
+      thrust::swap(l, r);
+    }
+    for (size_t i = 0; i < r.Bits().size(); ++i) {
+      l.Bits()[i] &= r.Bits()[i];
+    }
+    return l;
+  }
+
+  // Combine two split conditions on the same feature.
+  XGBOOST_DEVICE void Merge(ShapSplitCondition other) {
+    // Combine duplicate features
+    if (categories.Size() != 0 || other.categories.Size() != 0) {
+      categories = Intersect(categories, other.categories);
+    } else {
+      feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound);
+      feature_upper_bound = min(feature_upper_bound, other.feature_upper_bound);
+    }
+    is_missing_branch = is_missing_branch && other.is_missing_branch;
+  }
+};

struct PathInfo {
int64_t leaf_position; // -1 not a leaf
size_t length;
size_t tree_idx;
};

// Transform model into path element form for GPUTreeShap
-void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,
-                  const gbm::GBTreeModel& model, size_t tree_limit,
-                  int gpu_id) {
-  DeviceModel device_model;
-  device_model.Init(model, 0, tree_limit, gpu_id);
+void ExtractPaths(
+    dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>> *paths,
+    DeviceModel *model, dh::device_vector<uint32_t> *path_categories,
+    int gpu_id) {
+  auto& device_model = *model;

dh::caching_device_vector<PathInfo> info(device_model.nodes.Size());
dh::XGBCachingDeviceAllocator<PathInfo> alloc;
auto d_nodes = device_model.nodes.ConstDeviceSpan();
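`ShapSplitCondition` generalizes the old `(lower_bound, upper_bound)` pair: numerical splits keep a half-open interval, categorical splits carry a bitset, and repeated features along one path merge by intersection. A simplified host-side sketch of those semantics, with `std::vector<uint32_t>` standing in for `CatBitField` (the device version above swaps operands so the loop runs over equally sized buffers):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

struct Cond {
  float lo = -INFINITY, hi = INFINITY;  // values in [lo, hi) follow the path
  std::vector<uint32_t> cats;           // set bits follow the path (categorical)
  bool missing = false;                 // do missing values follow the path?

  bool Evaluate(float x) const {
    if (std::isnan(x)) return missing;
    if (!cats.empty()) {
      const auto c = static_cast<uint32_t>(x);
      return c / 32 < cats.size() && ((cats[c / 32] >> (c % 32)) & 1u);
    }
    return x >= lo && x < hi;
  }

  void Merge(const Cond& other) {  // combine duplicate features on one path
    if (!cats.empty() || !other.cats.empty()) {
      // In the PR both bitsets share the same width (max_cat words); the
      // shorter length is used here to keep the sketch self-contained.
      for (std::size_t i = 0; i < std::min(cats.size(), other.cats.size()); ++i) {
        cats[i] &= other.cats[i];
      }
    } else {
      lo = std::max(lo, other.lo);
      hi = std::min(hi, other.hi);
    }
    missing = missing && other.missing;
  }
};
```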
@@ -462,14 +525,45 @@ void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,

paths->resize(path_segments.back());

-  auto d_paths = paths->data().get();
+  auto d_paths = dh::ToSpan(*paths);
auto d_info = info.data().get();
auto d_stats = device_model.stats.ConstDeviceSpan();
auto d_tree_group = device_model.tree_group.ConstDeviceSpan();
auto d_path_segments = path_segments.data().get();

+  auto d_split_types = device_model.split_types.ConstDeviceSpan();
+  auto d_cat_segments = device_model.categories_tree_segments.ConstDeviceSpan();
+  auto d_cat_node_segments = device_model.categories_node_segments.ConstDeviceSpan();

+  size_t max_cat = 0;
+  if (thrust::any_of(dh::tbegin(d_split_types), dh::tend(d_split_types),
+                     common::IsCatOp{})) {
+    dh::PinnedMemory pinned;
+    auto h_max_cat = pinned.GetSpan<RegTree::Segment>(1);
+    auto max_elem_it = dh::MakeTransformIterator<size_t>(
+        dh::tbegin(d_cat_node_segments),
+        [] __device__(RegTree::Segment seg) { return seg.size; });
+    size_t max_cat_it =
+        thrust::max_element(thrust::device, max_elem_it,
+                            max_elem_it + d_cat_node_segments.size()) -
+        max_elem_it;
+    dh::safe_cuda(cudaMemcpy(h_max_cat.data(),
+                             d_cat_node_segments.data() + max_cat_it,
+                             h_max_cat.size_bytes(), cudaMemcpyDeviceToHost));
+    max_cat = h_max_cat[0].size;
+    CHECK_GE(max_cat, 1);
+    path_categories->resize(max_cat * paths->size());
+  }

+  auto d_model_categories = device_model.categories.DeviceSpan();
+  common::Span<uint32_t> d_path_categories = dh::ToSpan(*path_categories);

dh::LaunchN(gpu_id, info.size(), [=] __device__(size_t idx) {
auto path_info = d_info[idx];
size_t tree_offset = d_tree_segments[path_info.tree_idx];
+    TreeView tree{0, path_info.tree_idx, d_nodes,
+                  d_tree_segments, d_split_types, d_cat_segments,
+                  d_cat_node_segments, d_model_categories};
int group = d_tree_group[path_info.tree_idx];
size_t child_idx = path_info.leaf_position;
auto child = d_nodes[child_idx];
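The `max_cat` block above sizes a single scratch buffer: each path element reserves room for the widest per-node category segment, so `path_categories` holds `max_cat * paths->size()` words. A host-side sketch of that reduction (the real code does it with a transform iterator and `thrust::max_element`):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

struct Segment { std::size_t beg; std::size_t size; };

// Widest per-node category segment; 0 when there are no categorical nodes.
std::size_t MaxCat(const std::vector<Segment>& node_segments) {
  std::size_t max_cat = 0;
  for (const auto& seg : node_segments) max_cat = std::max(max_cat, seg.size);
  return max_cat;
}
```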
@@ -481,20 +575,38 @@ void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,
double child_cover = d_stats[child_idx].sum_hess;
double parent_cover = d_stats[parent_idx].sum_hess;
double zero_fraction = child_cover / parent_cover;
-      auto parent = d_nodes[parent_idx];
+      auto parent = tree.d_tree[child.Parent()];

bool is_left_path = (tree_offset + parent.LeftChild()) == child_idx;
bool is_missing_path = (!parent.DefaultLeft() && !is_left_path) ||
(parent.DefaultLeft() && is_left_path);
-      float lower_bound = is_left_path ? -inf : parent.SplitCond();
-      float upper_bound = is_left_path ? parent.SplitCond() : inf;
-      d_paths[output_position--] = {
-          idx, parent.SplitIndex(), group, lower_bound,
-          upper_bound, is_missing_path, zero_fraction, v};

+      float lower_bound = -inf;
+      float upper_bound = inf;
+      common::CatBitField bits;
+      if (common::IsCat(tree.cats.split_type, child.Parent())) {
+        auto path_cats = d_path_categories.subspan(max_cat * output_position, max_cat);
+        size_t size = tree.cats.node_ptr[child.Parent()].size;
+        auto node_cats = tree.cats.categories.subspan(tree.cats.node_ptr[child.Parent()].beg, size);
+        SPAN_CHECK(path_cats.size() >= node_cats.size());
+        for (size_t i = 0; i < node_cats.size(); ++i) {
+          path_cats[i] = is_left_path ? ~node_cats[i] : node_cats[i];
+        }
+        bits = common::CatBitField{path_cats};
+      } else {
+        lower_bound = is_left_path ? -inf : parent.SplitCond();
+        upper_bound = is_left_path ? parent.SplitCond() : inf;
+      }
+      d_paths[output_position--] =
+          gpu_treeshap::PathElement<ShapSplitCondition>{
+              idx, parent.SplitIndex(),
+              group, ShapSplitCondition{lower_bound, upper_bound, is_missing_path, bits},
+              zero_fraction, v};
child_idx = parent_idx;
child = parent;
}
// Root node has feature -1
-    d_paths[output_position] = {idx, -1, group, -inf, inf, false, 1.0, v};
+    d_paths[output_position] = {idx, -1, group, ShapSplitCondition{-inf, inf, false, {}}, 1.0, v};
});
}
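Inside the kernel, a categorical split contributes a bitset per path element: the right branch keeps the node's category bits while the left branch takes their complement, matching the `Decision` convention that set bits flow right. A minimal sketch:

```cpp
#include <cstdint>
#include <vector>

// Right branch keeps the node's category bits; left branch gets their
// complement, since matching categories flow right.
std::vector<uint32_t> PathCats(const std::vector<uint32_t>& node_cats,
                               bool is_left_path) {
  std::vector<uint32_t> out(node_cats.size());
  for (std::size_t i = 0; i < node_cats.size(); ++i) {
    out[i] = is_left_path ? ~node_cats[i] : node_cats[i];
  }
  return out;
}
```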

Expand Down Expand Up @@ -696,11 +808,16 @@ class GPUPredictor : public xgboost::Predictor {
void PredictContribution(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned tree_end,
-                           std::vector<bst_float> const*,
+                           std::vector<bst_float> const* tree_weights,
bool approximate, int,
unsigned) const override {
+    std::string not_implemented{"contribution is not implemented in GPU "
+                                "predictor, use `cpu_predictor` instead."};
if (approximate) {
LOG(FATAL) << "Approximated contribution is not implemented in GPU Predictor.";
LOG(FATAL) << "Approximated " << not_implemented;
}
+    if (tree_weights != nullptr) {
+      LOG(FATAL) << "Dart booster feature " << not_implemented;
+    }
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
out_contribs->SetDevice(generic_param_->gpu_id);
@@ -718,16 +835,21 @@
out_contribs->Fill(0.0f);
auto phis = out_contribs->DeviceSpan();

-    dh::device_vector<gpu_treeshap::PathElement> device_paths;
-    ExtractPaths(&device_paths, model, tree_end, generic_param_->gpu_id);
+    dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>>
+        device_paths;
+    DeviceModel d_model;
+    d_model.Init(model, 0, tree_end, generic_param_->gpu_id);
+    dh::device_vector<uint32_t> categories;
+    ExtractPaths(&device_paths, &d_model, &categories, generic_param_->gpu_id);
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id);
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature);
-      gpu_treeshap::GPUTreeShap(
-          X, device_paths.begin(), device_paths.end(), ngroup,
-          phis.data() + batch.base_rowid * contributions_columns, phis.size());
+      auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
+      gpu_treeshap::GPUTreeShap<dh::XGBDeviceAllocator<int>>(
+          X, device_paths.begin(), device_paths.end(), ngroup, begin,
+          dh::tend(phis));
}
// Add the base margin term to last column
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
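The `begin` offset above follows from the contribution layout: each row stores `num_feature + 1` values (`contributions_columns`), with the bias/base-margin term in the last slot. A sketch of the indexing:

```cpp
#include <cstddef>

// phi index for (row, feature); feature == num_feature is the bias column.
std::size_t ContribIndex(std::size_t row, std::size_t feature,
                         std::size_t num_feature) {
  return row * (num_feature + 1) + feature;
}
```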
@@ -746,11 +868,15 @@
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned tree_end,
-                                       std::vector<bst_float> const*,
+                                       std::vector<bst_float> const* tree_weights,
bool approximate) const override {
+    std::string not_implemented{"contribution is not implemented in GPU "
+                                "predictor, use `cpu_predictor` instead."};
if (approximate) {
LOG(FATAL) << "[Internal error]: " << __func__
<< " approximate is not implemented in GPU Predictor.";
LOG(FATAL) << "Approximated " << not_implemented;
}
+    if (tree_weights != nullptr) {
+      LOG(FATAL) << "Dart booster feature " << not_implemented;
+    }
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
out_contribs->SetDevice(generic_param_->gpu_id);
@@ -769,16 +895,21 @@
out_contribs->Fill(0.0f);
auto phis = out_contribs->DeviceSpan();

-    dh::device_vector<gpu_treeshap::PathElement> device_paths;
-    ExtractPaths(&device_paths, model, tree_end, generic_param_->gpu_id);
+    dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>>
+        device_paths;
+    DeviceModel d_model;
+    d_model.Init(model, 0, tree_end, generic_param_->gpu_id);
+    dh::device_vector<uint32_t> categories;
+    ExtractPaths(&device_paths, &d_model, &categories, generic_param_->gpu_id);
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id);
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature);
-      gpu_treeshap::GPUTreeShapInteractions(
-          X, device_paths.begin(), device_paths.end(), ngroup,
-          phis.data() + batch.base_rowid * contributions_columns, phis.size());
+      auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
+      gpu_treeshap::GPUTreeShapInteractions<dh::XGBDeviceAllocator<int>>(
+          X, device_paths.begin(), device_paths.end(), ngroup, begin,
+          dh::tend(phis));
}
// Add the base margin term to last column
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
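`PredictInteractionContributions` uses the analogous layout: each row gets a `(num_feature + 1) x (num_feature + 1)` matrix of pairwise terms, with bias entries in the last row and column. A sketch of the indexing:

```cpp
#include <cstddef>

// phi index for the (fi, fj) interaction term of a given row.
std::size_t InteractionIndex(std::size_t row, std::size_t fi, std::size_t fj,
                             std::size_t num_feature) {
  const std::size_t k = num_feature + 1;
  return (row * k + fi) * k + fj;
}
```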
19 changes: 8 additions & 11 deletions src/tree/tree_model.cc
@@ -1245,7 +1245,7 @@ bst_float UnwoundPathSum(const PathElement *unique_path, unsigned unique_depth,

// recursive computation of SHAP values for a decision tree
void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
-                       unsigned node_index, unsigned unique_depth,
+                       bst_node_t node_index, unsigned unique_depth,
PathElement *parent_unique_path,
bst_float parent_zero_fraction,
bst_float parent_one_fraction, int parent_feature_index,
@@ -1278,16 +1278,13 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
// internal node
} else {
// find which branch is "hot" (meaning x would follow it)
-    unsigned hot_index = 0;
-    if (feat.IsMissing(split_index)) {
-      hot_index = node.DefaultChild();
-    } else if (feat.GetFvalue(split_index) < node.SplitCond()) {
-      hot_index = node.LeftChild();
-    } else {
-      hot_index = node.RightChild();
-    }
-    const unsigned cold_index = (static_cast<int>(hot_index) == node.LeftChild() ?
-                                 node.RightChild() : node.LeftChild());
+    auto const &cats = this->GetCategoriesMatrix();
+    bst_node_t hot_index = predictor::GetNextNode<true, true>(
+        node, node_index, feat.GetFvalue(split_index),
+        feat.IsMissing(split_index), cats);
+
+    const auto cold_index =
+        (hot_index == node.LeftChild() ? node.RightChild() : node.LeftChild());
const bst_float w = this->Stat(node_index).sum_hess;
const bst_float hot_zero_fraction = this->Stat(hot_index).sum_hess / w;
const bst_float cold_zero_fraction = this->Stat(cold_index).sum_hess / w;
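The rewrite above delegates hot-branch selection to `predictor::GetNextNode`, which is categorical-aware, instead of the inline numerical comparison it replaces. A standalone sketch of the numerical-only logic being factored out (hypothetical `Node` type):

```cpp
struct Node {
  int left, right, default_child;
  float split_cond;
};

// "Hot" child: the branch this instance actually follows; GetNextNode also
// consults the node's category bitset for categorical splits.
int HotIndex(const Node& n, float fvalue, bool is_missing) {
  if (is_missing) return n.default_child;
  return fvalue < n.split_cond ? n.left : n.right;
}
// The "cold" child is the sibling: (hot == n.left) ? n.right : n.left
```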
5 changes: 5 additions & 0 deletions tests/cpp/predictor/test_cpu_predictor.cc
@@ -86,6 +86,11 @@ TEST(CpuPredictor, Basic) {
}
}


+TEST(CpuPredictor, IterationRange) {
+  TestIterationRange("cpu_predictor");
+}

TEST(CpuPredictor, ExternalMemory) {
dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm";
5 changes: 5 additions & 0 deletions tests/cpp/predictor/test_gpu_predictor.cu
@@ -224,6 +224,11 @@ TEST(GPUPredictor, Shap) {
}
}

+TEST(GPUPredictor, IterationRange) {
+  TestIterationRange("gpu_predictor");
+}


TEST(GPUPredictor, CategoricalPrediction) {
TestCategoricalPrediction("gpu_predictor");
}