Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move SHAP explainers out of experimental #3596

Merged
merged 20 commits into from
Apr 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
32487bc
ENH Plot tests and test cleanups
dantegd Mar 10, 2021
b61f6a0
DBG Add conda-forge shap to CI for testing
dantegd Mar 10, 2021
f5e188d
FIX comment print to fix pep8
dantegd Mar 10, 2021
a1b5583
DBG Add pip shap to CI for testing
dantegd Mar 10, 2021
af3dc79
Merge branch '019-enh-shap-plots' of github.com:dantegd/cuml into 019…
dantegd Mar 10, 2021
ed059e5
Merge branch-0.19 into 019-enh-shap-plots
dantegd Mar 25, 2021
6c55212
ENH Multiple enhancements
dantegd Mar 25, 2021
0a377d5
Merge branch 'branch-0.19' of https://github.com/rapidsai/cuml into 0…
dantegd Mar 30, 2021
b29eaed
ENH Move explainers out of experimental
dantegd Mar 30, 2021
6f2dfaa
FIX Remove not needed function
dantegd Mar 30, 2021
d5761fe
ENH Multiple enhancements, corrections to remove experimental and add…
dantegd Apr 1, 2021
688b2a0
FIX The smallest copyright fix so far...
dantegd Apr 1, 2021
68475ba
Update python/cuml/explainer/kernel_shap.pyx
dantegd Apr 1, 2021
8b467d2
Update python/cuml/explainer/kernel_shap.pyx
dantegd Apr 1, 2021
e48ddaf
Update python/cuml/explainer/kernel_shap.pyx
dantegd Apr 1, 2021
0b5a4b0
Update python/cuml/explainer/permutation_shap.pyx
dantegd Apr 1, 2021
8785f19
FIX Add more samples to pytests and fix a bug in permutation shap
dantegd Apr 1, 2021
3c51ae1
FIX gpu ci build script and stray print
dantegd Apr 2, 2021
d834754
Merge branch 'branch-0.19' of https://github.com/rapidsai/cuml into 0…
dantegd Apr 2, 2021
4b10e49
FIX temporarily xfail hellinger pytest
dantegd Apr 2, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ci/gpu/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ gpuci_conda_retry install -c conda-forge -c rapidsai -c rapidsai-nightly -c nvid
"xgboost=1.3.3dev.rapidsai${MINOR_VERSION}" \
"rapids-build-env=${MINOR_VERSION}.*" \
"rapids-notebook-env=${MINOR_VERSION}.*" \
"rapids-doc-env=${MINOR_VERSION}.*"
"rapids-doc-env=${MINOR_VERSION}.*" \
"shap>=0.37,<=0.39"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will include version 0.39 as well. Did you want to exclude 0.39 or include it @dantegd

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Include it, the original had the wrong blank space because I accidentally replaced the <= for < in a bulk replace when I was doing something else, thanks for the notice!


# https://docs.rapids.ai/maintainers/depmgmt/
# gpuci_conda_retry remove --force rapids-build-env rapids-notebook-env
Expand Down
27 changes: 17 additions & 10 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ Metrics (clustering and trustworthiness)

.. automodule:: cuml.metrics.cluster.silhouette_score
:members:

.. automodule:: cuml.metrics.cluster.completeness_score
:members:

Expand Down Expand Up @@ -416,6 +416,22 @@ ARIMA
.. autoclass:: cuml.tsa.auto_arima.AutoARIMA
:members:

Model Explainability
====================

SHAP Kernel Explainer
---------------------

.. autoclass:: cuml.explainer.KernelExplainer
:members:

SHAP Permutation Explainer
--------------------------

.. autoclass:: cuml.explainer.PermutationExplainer
:members:


Multi-Node, Multi-GPU Algorithms
================================

Expand Down Expand Up @@ -533,15 +549,6 @@ Preprocessing
add_dummy_feature, binarize, minmax_scale, normalize,
PolynomialFeatures, robust_scale, scale


Model Explanation (SHAP)
------------------------
.. autoclass:: cuml.experimental.explainer.KernelExplainer
:members:

.. autoclass:: cuml.experimental.explainer.PermutationExplainer
:members:

Linear Models
-------------
.. autoclass:: cuml.experimental.linear_model.Lars
Expand Down
3 changes: 3 additions & 0 deletions python/cuml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
from cuml.ensemble.randomforestclassifier import RandomForestClassifier
from cuml.ensemble.randomforestregressor import RandomForestRegressor

from cuml.explainer.kernel_shap import KernelExplainer
from cuml.explainer.permutation_shap import PermutationExplainer

from cuml.fil import fil

from cuml.internals.global_settings import (
Expand Down
2 changes: 1 addition & 1 deletion python/cuml/common/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def has_sklearn():
return False


def has_shap(min_version=None):
def has_shap(min_version="0.37"):
try:
import shap # noqa
if min_version is None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,5 +14,5 @@
# limitations under the License.
#

from cuml.experimental.explainer.kernel_shap import KernelExplainer
from cuml.experimental.explainer.permutation_shap import PermutationExplainer
from cuml.explainer.kernel_shap import KernelExplainer
from cuml.explainer.permutation_shap import PermutationExplainer
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ from cuml.common.input_utils import input_to_cupy_array
from cuml.common.input_utils import input_to_host_array
from cuml.common.logger import debug
from cuml.common.logger import warn
from cuml.experimental.explainer.common import get_dtype_from_model_func
from cuml.experimental.explainer.common import get_handle_from_cuml_model_func
from cuml.experimental.explainer.common import get_link_fn_from_str_or_fn
from cuml.experimental.explainer.common import get_tag_from_model_func
from cuml.experimental.explainer.common import model_func_call
from cuml.experimental.explainer.common import output_list_shap_values
from cuml.explainer.common import get_dtype_from_model_func
from cuml.explainer.common import get_handle_from_cuml_model_func
from cuml.explainer.common import get_link_fn_from_str_or_fn
from cuml.explainer.common import get_tag_from_model_func
from cuml.explainer.common import model_func_call
from cuml.explainer.common import output_list_shap_values

from cuml.raft.common.handle cimport handle_t
from libcpp cimport bool
Expand Down Expand Up @@ -269,7 +269,7 @@ class SHAPBase():
shap_values.append(cp.zeros(X.shape, dtype=self.dtype))

# Allocate synthetic dataset array once for multiple explanations
if getattr(self, "synth_data", None) is None and synth_data_shape \
if getattr(self, "_synth_data", None) is None and synth_data_shape \
is not None:
self._synth_data = cp.zeros(
shape=synth_data_shape,
Expand Down Expand Up @@ -297,16 +297,14 @@ class SHAPBase():
output_type=self.output_type
)

debug(self._get_timers_str())

return shap_values

def __call__(self,
X,
main_effects=False,
**kwargs):

if not has_shap("0.37"):
if not has_shap(min_version="0.37"):
raise ImportError("SHAP >= 0.37 was not found, please install it "
" or use the explainer.shap_values function "
"instead. ")
Expand Down Expand Up @@ -411,9 +409,3 @@ class SHAPBase():
def _reset_timers(self):
self.total_time = 0
self.model_call_time = 0

def _get_timers_str(self):
res_str = "Time spent by category:\n"
res_str += "Total time: {}".format(self.total_time)
res_str += "Time spent in model calls {}:".format(self.model_call_time)
return res_str
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,12 @@ import cupy as cp
import numpy as np
import time

from cuml.common.import_utils import has_shap
from cuml.common.import_utils import has_sklearn
from cuml.common.input_utils import input_to_cupy_array
from cuml.experimental.explainer.base import SHAPBase
from cuml.experimental.explainer.common import get_cai_ptr
from cuml.experimental.explainer.common import model_func_call
from cuml.experimental.explainer.common import output_list_shap_values
from cuml.explainer.base import SHAPBase
from cuml.explainer.common import get_cai_ptr
from cuml.explainer.common import model_func_call
from cuml.explainer.common import output_list_shap_values
from cuml.linear_model import Lasso
from cuml.linear_model import LinearRegression
from cuml.raft.common.handle import Handle
Expand Down Expand Up @@ -71,28 +70,32 @@ cdef extern from "cuml/explainer/kernel_shap.hpp" namespace "ML":

class KernelExplainer(SHAPBase):
"""
GPU accelerated of SHAP's kernel explainer (experimental).

Based on the SHAP package:
GPU accelerated of SHAP's kernel explainer.

cuML's SHAP based explainers accelerate the algorithmic part of SHAP.
They are optimized to be used with fast GPU based models, like those in
cuML. By creating the datasets and internal calculations,
alongside minimizing data copies and transfers, they can accelerate
explanations significantly. But they can also be used with
CPU based models, where speedups can still be achieved, but those can be
capped by factors like data transfers and the speed of the models.

KernelExplainer is based on the Python SHAP
package's KernelExplainer class:
https://github.com/slundberg/shap/blob/master/shap/explainers/_kernel.py

Main differences of the GPU version:

- Data generation and Kernel SHAP calculations are significantly faster,
but this has a tradeoff of having more model evaluations if both the
observation explained and the background data have many 0-valued
columns.
- Support for SHAP's new Explanation and API will be available in the
next version.
- There is a small initialization cost (similar to training time of
regular Scikit/cuML models) of a few seconds, which was a tradeoff for
faster explanations after that.
- Only tabular data is supported for now, via passing the background
dataset explicitly. Since the new API of SHAP is still evolving, the
main supported API right now is the old one
(i.e. ``explainer.shap_values()``)
- Sparse data support is planned for the near future.
- Further optimizations are in progress.
Current characteristics of the GPU version:

* Unlike the SHAP package, ``nsamples`` is a parameter at the
initialization of the explainer and there is a small initialization
time.
* Only tabular data is supported for now, via passing the background
dataset explicitly.
* Sparse data support is planned for the near future.
* Further optimizations are in progress. For example, if the background
dataset has constant value columns and the observation has the same
value in some entries, the number of evaluations of the function can
be reduced (this will come in the next version).

Parameters
----------
Expand Down Expand Up @@ -125,7 +128,7 @@ class KernelExplainer(SHAPBase):
random_state: int, RandomState instance or None (default = None)
Seed for the random number generator for dataset creation. Note: due to
the design of the sampling algorithm the concurrency can affect
results so currently 100% deterministic execution is not guaranteed.
results, so currently 100% deterministic execution is not guaranteed.
gpu_model : bool or None (default = None)
If None Explainer will try to infer whether `model` can take GPU data
(as CuPy arrays), otherwise it will use NumPy arrays to call `model`.
Expand Down Expand Up @@ -155,7 +158,7 @@ class KernelExplainer(SHAPBase):
>>> from cuml import make_regression
>>> from cuml import train_test_split
>>>
>>> from cuml.experimental.explainer import KernelExplainer as cuKE
>>> from cuml.explainer import KernelExplainer
>>>
>>> X, y = make_regression(
... n_samples=102,
Expand All @@ -171,7 +174,7 @@ class KernelExplainer(SHAPBase):
>>>
>>> model = SVR().fit(X_train, y_train)
>>>
>>> cu_explainer = cuKE(
>>> cu_explainer = KernelExplainer(
... model=model.predict,
... data=X_train,
... gpu_model=True)
Expand All @@ -190,7 +193,7 @@ class KernelExplainer(SHAPBase):
*,
model,
data,
nsamples=2**11,
nsamples='auto',
link='identity',
verbose=False,
random_state=None,
Expand All @@ -199,7 +202,7 @@ class KernelExplainer(SHAPBase):
dtype=None,
output_type=None):

super(KernelExplainer, self).__init__(
super().__init__(
model=model,
background=data,
order='C',
Expand All @@ -212,7 +215,12 @@ class KernelExplainer(SHAPBase):
output_type=output_type
)

self.nsamples = nsamples
# default value matching SHAP package
if nsamples == 'auto':
self.nsamples = 2 * self.ncols + 2**11
else:
self.nsamples = nsamples

# Maximum number of samples that user can set
max_samples = 2 ** 32

Expand Down Expand Up @@ -249,8 +257,6 @@ class KernelExplainer(SHAPBase):
self._weights = cp.ones(self.nsamples, dtype=self.dtype)
self._weights[:self.nsamples_exact] = cp.array(weight)

self._reset_timers()

def shap_values(self,
X,
l1_reg='auto',
Expand All @@ -275,8 +281,7 @@ class KernelExplainer(SHAPBase):

Returns
-------
values : array or list

shap_values : array or list
"""
return self._explain(X,
synth_data_shape=(self.nrows * self.nsamples,
Expand Down Expand Up @@ -366,8 +371,6 @@ class KernelExplainer(SHAPBase):
<int> maxsample,
<uint64_t> self.random_state)

# kept while in experimental namespace. It is not needed for cuml
# models, but for other GPU models it is
self.handle.sync()

model_timer = time.time()
Expand All @@ -379,8 +382,6 @@ class KernelExplainer(SHAPBase):
self.model_call_time = \
self.model_call_time + (time.time() - model_timer)

l1_reg_time = 0

for i in range(self.model_dimensions):
if self.model_dimensions == 1:
y_hat = y - self._expected_value
Expand Down Expand Up @@ -415,7 +416,7 @@ class KernelExplainer(SHAPBase):
self.l1_reg_time = \
self.l1_reg_time + (time.time() - reg_timer)
# in case all indexes become zero
if nonzero_inds.shape == (0, ):
if len(nonzero_inds) == 0:
return None

reg_timer = time.time()
Expand Down Expand Up @@ -449,14 +450,6 @@ class KernelExplainer(SHAPBase):
self.l1_reg_time = 0
self.linear_model_time = 0

def _get_timers_str(self):
res_str = super()._get_timers_str()
res_str += "Time spent in L1 regularization: {}".format(
self.l1_reg_time)
res_str += "Time spent in linear model calculation: {}".format(
self.linear_model_time)
return res_str


def _get_number_of_exact_random_samples(ncols, nsamples):
"""
Expand Down
Loading