Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add shap value features on each estimator (depends on master branch of shap) #336

Merged
merged 23 commits into from
Dec 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ jobs:

- script: 'pip install sklearn-contrib-lightning'
displayName: 'Install lightning'

- script: 'pip install --force-reinstall --no-cache-dir shap'
heimengqi marked this conversation as resolved.
Show resolved Hide resolved
displayName: 'Install public shap'

- script: 'pip install --force-reinstall scikit-learn==0.23.2'
displayName: 'Install public old sklearn'

- script: 'python setup.py build_sphinx -W'
displayName: 'Build documentation'
Expand All @@ -88,9 +94,6 @@ jobs:
- template: azure-pipelines-steps.yml
parameters:
body:
- script: 'pip install shap'
displayName: 'Install shap'

- script: 'python setup.py pytest'
displayName: 'Unit tests'
env:
Expand Down
2 changes: 1 addition & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {'python': ('https://docs.python.org/3', None),
'numpy': ('https://docs.scipy.org/doc/numpy/', None),
'sklearn': ('https://scikit-learn.org/stable', None),
'sklearn': ('https://scikit-learn.org/0.23/', None),
'matplotlib': ('https://matplotlib.org/', None)}

# -- Options for todo extension ----------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions doc/spec/comparison.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ Detailed estimator comparison
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :class:`.ForestDML` | 1-d/Binary | | Yes | Yes | | Yes | | Yes |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :class:`.ForestDRLearner` | Categorical | | Yes | | | Yes | Yes | Yes |
| :class:`.ForestDRLearner` | Categorical | | Yes | | | | Yes | Yes |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :class:`.ContinuousTreatmentOrthoForest` | Continuous | | Yes | Yes | | | Yes | Yes |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :class:`.DiscreteTreatmentOrthoForest` | Categorical | | Yes | | | | Yes | Yes |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :mod:`~econml.metalearners` | Categorical | | | | | | Yes | Yes |
| :mod:`~econml.metalearners` | Categorical | | | | | Yes | Yes | Yes |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
| :class:`.DRLearner` | Categorical | | | | | | Yes | Yes |
+---------------------------------------------+--------------+--------------+------------------+-------------+-----------------+------------+--------------+--------------------+
Expand Down
45 changes: 44 additions & 1 deletion econml/cate_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
from warnings import warn
from .inference import BootstrapInference
from .utilities import (tensordot, ndim, reshape, shape, parse_final_model_params,
inverse_onehot, Summary, get_input_columns)
inverse_onehot, Summary, get_input_columns, broadcast_unit_treatments,
cross_product)
from .inference import StatsModelsInference, StatsModelsInferenceDiscrete, LinearModelFinalInference,\
LinearModelFinalInferenceDiscrete, NormalInferenceResults, GenericSingleTreatmentModelFinalInference,\
GenericModelFinalInferenceDiscrete
from .shap import _shap_explain_cme, _define_names, _shap_explain_joint_linear_model_cate


class BaseCateEstimator(metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -456,6 +458,33 @@ def const_marginal_effect_inference(self, X=None):
"""
pass

def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None):
""" Shap value for the final stage models (const_marginal_effect)

Parameters
----------
X: (m, d_x) matrix
            Features for each sample. Should have the same shape as the X used to fit the final stage.
feature_names: optional None or list of strings of length X.shape[1] (Default=None)
The names of input features.
treatment_names: optional None or list (Default=None)
            The names of the treatments. In a discrete treatment scenario, the names should not include the name of
            the baseline treatment (i.e. the control treatment, which by default is alphabetically the smallest).
output_names: optional None or list (Default=None)
The name of the outcome.

Returns
-------
        shap_outs: nested dictionary of Explanation object
            A nested dictionary keyed by each output name (e.g. "Y0" when `output_names=None`) and
            each treatment name (e.g. "T0" when `treatment_names=None`), with the corresponding
            shap_values Explanation object as the value.


"""
return _shap_explain_cme(self.const_marginal_effect, X, self._d_t, self._d_y, feature_names, treatment_names,
output_names)


class TreatmentExpansionMixin(BaseCateEstimator):
"""Mixin which automatically handles promotions of scalar treatments to the appropriate shape."""
Expand Down Expand Up @@ -685,6 +714,20 @@ def summary(self, alpha=0.1, value=0, decimals=3, feature_names=None, treatment_
if len(smry.tables) > 0:
return smry

def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None):
(dt, dy, treatment_names, output_names) = _define_names(self._d_t, self._d_y, treatment_names, output_names)
if hasattr(self, "featurizer") and self.featurizer is not None:
X = self.featurizer.transform(X)
X, T = broadcast_unit_treatments(X, dt)
d_x = X.shape[1]
X_new = cross_product(X, T)
feature_names = self.cate_feature_names(feature_names)
return _shap_explain_joint_linear_model_cate(self.model_final, X_new, T, dt, dy, self.fit_cate_intercept,
feature_names=feature_names, treatment_names=treatment_names,
output_names=output_names)

shap_values.__doc__ = LinearCateEstimator.shap_values.__doc__


class StatsModelsCateEstimatorMixin(LinearModelFinalCateEstimatorMixin):
"""
Expand Down
28 changes: 27 additions & 1 deletion econml/dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,20 @@
from sklearn.base import TransformerMixin, clone
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import (ElasticNetCV, LassoCV, LogisticRegressionCV)
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import KFold, StratifiedKFold, check_cv
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import (FunctionTransformer, LabelEncoder,
OneHotEncoder)
from sklearn.utils import check_random_state
import copy

from ._rlearner import _RLearner
from .cate_estimator import (DebiasedLassoCateEstimatorMixin,
ForestModelFinalCateEstimatorMixin,
LinearModelFinalCateEstimatorMixin,
StatsModelsCateEstimatorMixin)
StatsModelsCateEstimatorMixin,
LinearCateEstimator)
from .inference import StatsModelsInference
from .sklearn_extensions.ensemble import SubsampledHonestForest
from .sklearn_extensions.linear_model import (MultiOutputDebiasedLasso,
Expand All @@ -62,6 +65,7 @@
cross_product, deprecated, fit_with_groups,
hstack, inverse_onehot, ndim, reshape,
reshape_treatmentwise_effects, shape, transpose)
from .shap import _shap_explain_model_cate


class _FirstStageWrapper:
Expand Down Expand Up @@ -938,6 +942,18 @@ def __init__(self,
n_splits=n_splits,
random_state=random_state)

def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None):
if self.featurizer is not None:
F = self.featurizer.transform(X)
else:
F = X
feature_names = self.cate_feature_names(feature_names)

return _shap_explain_model_cate(self.const_marginal_effect, self.model_cate, F, self._d_t, self._d_y,
feature_names=feature_names,
treatment_names=treatment_names, output_names=output_names)
shap_values.__doc__ = LinearCateEstimator.shap_values.__doc__


class ForestDML(ForestModelFinalCateEstimatorMixin, NonParamDML):
""" Instance of NonParamDML with a
Expand Down Expand Up @@ -1169,6 +1185,16 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
sample_weight=sample_weight, sample_var=None, groups=groups,
inference=inference)

def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None):
        # SubsampledHonestForest isn't recognized by SHAP, but its tree entries are consistent with a tree in
        # a RandomForestRegressor, so modify the class name so that it is identified as a tree model.
model = copy.deepcopy(self.model_cate)
model.__class__ = RandomForestRegressor
heimengqi marked this conversation as resolved.
Show resolved Hide resolved
return _shap_explain_model_cate(self.const_marginal_effect, model, X, self._d_t, self._d_y,
feature_names=feature_names,
treatment_names=treatment_names, output_names=output_names)
shap_values.__doc__ = LinearCateEstimator.shap_values.__doc__


@deprecated("The DMLCateEstimator class has been renamed to DML; "
"an upcoming release will remove support for the old name")
Expand Down
37 changes: 35 additions & 2 deletions econml/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,25 @@
"""

from warnings import warn
from copy import deepcopy

import numpy as np
from sklearn.base import clone
from sklearn.linear_model import (LassoCV, LinearRegression,
LogisticRegressionCV)
from sklearn.ensemble import RandomForestRegressor

from ._ortho_learner import _OrthoLearner
from .cate_estimator import (DebiasedLassoCateEstimatorDiscreteMixin,
ForestModelFinalCateEstimatorDiscreteMixin,
StatsModelsCateEstimatorDiscreteMixin)
StatsModelsCateEstimatorDiscreteMixin, LinearCateEstimator)
from .inference import GenericModelFinalInferenceDiscrete
from .sklearn_extensions.ensemble import SubsampledHonestForest
from .sklearn_extensions.linear_model import (
DebiasedLasso, StatsModelsLinearRegression, WeightedLassoCVWrapper)
from .utilities import (_deprecate_positional, check_high_dimensional,
filter_none_kwargs, fit_with_groups, inverse_onehot)
from .shap import _shap_explain_multitask_model_cate, _shap_explain_model_cate


class _ModelNuisance:
Expand Down Expand Up @@ -564,6 +567,23 @@ def cate_feature_names(self, feature_names=None):
else:
raise AttributeError("Featurizer does not have a method: get_feature_names!")

def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None):
if self.featurizer is not None:
F = self.featurizer.transform(X)
else:
F = X
feature_names = self.cate_feature_names(feature_names)

if self._multitask_model_final:
heimengqi marked this conversation as resolved.
Show resolved Hide resolved
return _shap_explain_multitask_model_cate(self.const_marginal_effect, self.multitask_model_cate, F,
self._d_t, self._d_y, feature_names,
treatment_names, output_names)
else:
return _shap_explain_model_cate(self.const_marginal_effect, super().model_final.models_cate,
F, self._d_t, self._d_y, feature_names=feature_names,
treatment_names=treatment_names, output_names=output_names)
shap_values.__doc__ = LinearCateEstimator.shap_values.__doc__


class LinearDRLearner(StatsModelsCateEstimatorDiscreteMixin, DRLearner):
"""
Expand Down Expand Up @@ -1147,7 +1167,7 @@ class ForestDRLearner(ForestModelFinalCateEstimatorDiscreteMixin, DRLearner):
"""

def __init__(self,
model_regression, model_propensity,
model_regression="auto", model_propensity="auto",
min_propensity=1e-6,
categories='auto',
n_crossfit_splits=2,
Expand Down Expand Up @@ -1235,3 +1255,16 @@ def model_final(self):
@property
def fitted_models_final(self):
return super().model_final.models_cate

def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None):
models = []
for fitted_model in self.fitted_models_final:
            # SubsampledHonestForest isn't recognized by SHAP, but its tree entries are consistent with a tree in
            # a RandomForestRegressor, so modify the class name so that it is identified as a tree model.
model = deepcopy(fitted_model)
model.__class__ = RandomForestRegressor
models.append(model)
return _shap_explain_model_cate(self.const_marginal_effect, models, X, self._d_t, self._d_y,
feature_names=feature_names,
treatment_names=treatment_names, output_names=output_names)
shap_values.__doc__ = LinearCateEstimator.shap_values.__doc__
7 changes: 7 additions & 0 deletions econml/metalearners.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sklearn.preprocessing import OneHotEncoder, FunctionTransformer
from .utilities import (check_inputs, check_models, broadcast_unit_treatments, reshape_treatmentwise_effects,
inverse_onehot, transpose, _EncoderWrapper, _deprecate_positional)
from .shap import _shap_explain_model_cate


class TLearner(TreatmentExpansionMixin, LinearCateEstimator):
Expand Down Expand Up @@ -458,3 +459,9 @@ def _fit_weighted_pipeline(self, model_instance, X, y, sample_weight):
else:
last_step_name = model_instance.steps[-1][0]
model_instance.fit(X, y, **{"{0}__sample_weight".format(last_step_name): sample_weight})

def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None):
vsyrgkanis marked this conversation as resolved.
Show resolved Hide resolved
return _shap_explain_model_cate(self.const_marginal_effect, self.final_models, X, self._d_t, self._d_y,
feature_names=feature_names,
treatment_names=treatment_names, output_names=output_names)
shap_values.__doc__ = LinearCateEstimator.shap_values.__doc__
Loading