diff --git a/python/cuml/explainer/kernel_shap.pyx b/python/cuml/explainer/kernel_shap.pyx index 6a2281e76e..eb312775e8 100644 --- a/python/cuml/explainer/kernel_shap.pyx +++ b/python/cuml/explainer/kernel_shap.pyx @@ -645,8 +645,13 @@ def _weighted_linear_regression(X, # from nonzero_inds and some additional arrays # nonzero_inds tells us which cols of X to use y = y - X[:, nonzero_inds[-1]] * (fx - expected_value) - Xw = cp.transpose( - cp.transpose(X[:, nonzero_inds[:-1]]) - X[:, nonzero_inds[-1]]) + if len(nonzero_inds) == 1: + # when only one index is nonzero, use that column + Xw = X[:, nonzero_inds] + else: + Xw = cp.transpose( + cp.transpose( + X[:, nonzero_inds[:-1]]) - X[:, nonzero_inds[-1]]) Xw = Xw * cp.sqrt(weights[:, cp.newaxis]) y = y * cp.sqrt(weights) diff --git a/python/cuml/test/explainer/test_explainer_kernel_shap.py b/python/cuml/test/explainer/test_explainer_kernel_shap.py index 219801f384..ec6e6cce18 100644 --- a/python/cuml/test/explainer/test_explainer_kernel_shap.py +++ b/python/cuml/test/explainer/test_explainer_kernel_shap.py @@ -21,9 +21,11 @@ import pytest import sklearn.neighbors +from cuml import Lasso from cuml import KernelExplainer from cuml.common.import_utils import has_scipy from cuml.common.import_utils import has_shap +from cuml.datasets import make_regression from cuml.test.conftest import create_synthetic_dataset from cuml.test.utils import ClassEnumerator from cuml.test.utils import get_shap_values @@ -322,6 +324,18 @@ def test_l1_regularization(exact_shap_regression_dataset, l1_type): assert isinstance(nz, cp.ndarray) +def test_typeerror_input(): + X, y = make_regression(n_samples=100, n_features=10, random_state=10) + clf = Lasso() + clf.fit(X, y) + exp = KernelExplainer(model=clf.predict, data=X, nsamples=10) + try: + _ = exp.shap_values(X) + assert True + except TypeError: + assert False + + ############################################################################### # Precomputed results # # and testing variables #