Skip to content

Commit

Permalink
Merge branch 'master' into mehei/shapvalue
Browse files Browse the repository at this point in the history
  • Loading branch information
heimengqi committed Dec 16, 2020
2 parents cf7ee49 + e55b093 commit 86a4283
Show file tree
Hide file tree
Showing 21 changed files with 597 additions and 437 deletions.
1 change: 0 additions & 1 deletion doc/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ Public Module Reference
.. autosummary::
:toctree: _autosummary

econml.automated_ml
econml.bootstrap
econml.cate_estimator
econml.cate_interpreter
Expand Down
18 changes: 11 additions & 7 deletions econml/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,12 @@ def get_dist(est, arr):

def get_result():
    # Build a proxy whose calls yield empirical (bootstrap-distribution) inference
    # results for this attribute.
    # Forward the wrapped estimator's recorded input names (feature/output/
    # treatment column names) when present so result tables are labeled.
    # Fall back to an empty dict: unpacking None with ** raises
    # "TypeError: argument after ** must be a mapping".
    return proxy(can_call, prefix,
                 lambda arr, est: EmpiricalInferenceResults(
                     d_t=d_t, d_y=d_y,
                     pred=est, pred_dist=get_dist(est, arr),
                     inf_type=inf_type,
                     fname_transformer=fname_transformer,
                     **(self._wrapped._input_names
                        if hasattr(self._wrapped, "_input_names") else {})))

# Note that inference results are always methods even if the inference is for a property
# (e.g. coef__inference() is a method but coef_ is a property)
def normal_inference(*args, **kwargs):
    # Standard errors follow the accessor naming convention:
    # '<prefix>' is the point estimate, '<prefix>_std' its stderr.
    stderr = getattr(self, prefix + '_std')
    if can_call:
        stderr = stderr(*args, **kwargs)
    # Forward the wrapped estimator's recorded input names (if any) so the
    # summary tables are labeled; default to an empty dict rather than None,
    # since **None raises "TypeError: argument after ** must be a mapping".
    return NormalInferenceResults(
        d_t=d_t, d_y=d_y, pred=pred,
        pred_stderr=stderr, inf_type=inf_type,
        fname_transformer=fname_transformer,
        **(self._wrapped._input_names
           if hasattr(self._wrapped, "_input_names") else {}))

# If inference is for a property, create a fresh lambda to avoid passing args through
return normal_inference if can_call else lambda: normal_inference()
Expand Down
75 changes: 53 additions & 22 deletions econml/cate_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
import shap
from slicer import Alias
from .inference import BootstrapInference
from .utilities import (tensordot, ndim, reshape, shape, parse_final_model_params, inverse_onehot, Summary,
broadcast_unit_treatments, cross_product, _shap_explain_cme)
from .utilities import (tensordot, ndim, reshape, shape, parse_final_model_params,
inverse_onehot, Summary, get_input_columns, broadcast_unit_treatments,
cross_product, _shap_explain_cme)
from .inference import StatsModelsInference, StatsModelsInferenceDiscrete, LinearModelFinalInference,\
LinearModelFinalInferenceDiscrete, NormalInferenceResults, GenericSingleTreatmentModelFinalInference,\
GenericModelFinalInferenceDiscrete
Expand Down Expand Up @@ -46,6 +47,18 @@ def _get_inference(self, inference):
# because inf now stores state from fitting est2
return deepcopy(inference)

def _set_input_names(self, Y, T, X, set_flag=True):
    """Record the column names of the fit inputs, when they carry metadata.

    Parameters
    ----------
    Y, T, X : array-like or pandas objects
        Fit inputs; column names are extracted via ``get_input_columns``.
    set_flag : bool, default True
        When True, additionally set the ``_input_names_set`` attribute to
        record that names were captured (used when a child class sets the
        names itself, so they are not derived again later).
    """
    names = {}
    names["feature_names"] = get_input_columns(X)
    names["output_names"] = get_input_columns(Y)
    names["treatment_names"] = get_input_columns(T)
    self._input_names = names
    if set_flag:
        # Marker attribute: signals that input names have already been set
        # (e.g. by a child class) so _prefit won't re-derive them.
        self._input_names_set = True

def _strata(self, Y, T, *args, **kwargs):
"""
Get an array of values representing strata that should be preserved by bootstrapping. For example,
Expand All @@ -64,6 +77,13 @@ def _strata(self, Y, T, *args, **kwargs):
def _prefit(self, Y, T, *args, **kwargs):
    """Record outcome/treatment dimensions and input column names before fitting."""
    self._d_y, self._d_t = np.shape(Y)[1:], np.shape(T)[1:]
    # Capture input column names unless a child class already recorded them
    # (indicated by the _input_names_set marker attribute).
    # NOTE: X is read from kwargs, so feature names are only captured when X
    # is passed by keyword; keyword-only X is planned for future releases.
    if not hasattr(self, "_input_names_set"):
        self._set_input_names(Y, T, kwargs.get('X'))

@abc.abstractmethod
def fit(self, *args, inference=None, **kwargs):
Expand Down Expand Up @@ -682,7 +702,7 @@ def intercept__inference(self):
"""
pass

def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None, treatment_name=None, output_name=None):
def summary(self, alpha=0.1, value=0, decimals=3, feature_names=None, treatment_names=None, output_names=None):
""" The summary of coefficient and intercept in the linear model of the constant marginal treatment
effect.
Expand All @@ -695,11 +715,11 @@ def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None, treatment_name
The mean value of the metric you'd like to test under null hypothesis.
decimals: optional int (default=3)
Number of decimal places to round each column to.
feat_name: optional list of strings or None (default is None)
feature_names: optional list of strings or None (default is None)
The input of the feature names
treatment_name: optional list of strings or None (default is None)
treatment_names: optional list of strings or None (default is None)
The names of the treatments
output_name: optional list of strings or None (default is None)
output_names: optional list of strings or None (default is None)
The names of the outputs
Returns
Expand All @@ -708,6 +728,11 @@ def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None, treatment_name
this holds the summary tables and text, which can be printed or
converted to various output formats.
"""
# Get input names
feature_names = self.cate_feature_names() if feature_names is None else feature_names
treatment_names = self._input_names["treatment_names"] if treatment_names is None else treatment_names
output_names = self._input_names["output_names"] if output_names is None else output_names
# Summary
smry = Summary()
smry.add_extra_txt(["<sub>A linear parametric conditional average treatment effect (CATE) model was fitted:",
"$Y = \\Theta(X)\\cdot T + g(X, W) + \\epsilon$",
Expand All @@ -722,9 +747,9 @@ def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None, treatment_name
try:
coef_table = self.coef__inference().summary_frame(alpha=alpha,
value=value, decimals=decimals,
feat_name=feat_name,
treatment_name=treatment_name,
output_name=output_name)
feature_names=feature_names,
treatment_names=treatment_names,
output_names=output_names)
coef_array = coef_table.values
coef_headers = [i + '\n' +
j for (i, j) in coef_table.columns] if d_t > 1 else coef_table.columns.tolist()
Expand All @@ -736,9 +761,9 @@ def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None, treatment_name
try:
intercept_table = self.intercept__inference().summary_frame(alpha=alpha,
value=value, decimals=decimals,
feat_name=None,
treatment_name=treatment_name,
output_name=output_name)
feature_names=None,
treatment_names=treatment_names,
output_names=output_names)
intercept_array = intercept_table.values
intercept_headers = [i + '\n' + j for (i, j)
in intercept_table.columns] if d_t > 1 else intercept_table.columns.tolist()
Expand Down Expand Up @@ -970,7 +995,8 @@ def intercept__inference(self, T):
"""
pass

def summary(self, T, *, alpha=0.1, value=0, decimals=3, feat_name=None, treatment_name=None, output_name=None):
def summary(self, T, *, alpha=0.1, value=0, decimals=3,
feature_names=None, treatment_names=None, output_names=None):
""" The summary of coefficient and intercept in the linear model of the constant marginal treatment
effect associated with treatment T.
Expand All @@ -983,11 +1009,11 @@ def summary(self, T, *, alpha=0.1, value=0, decimals=3, feat_name=None, treatmen
The mean value of the metric you'd like to test under null hypothesis.
decimals: optional int (default=3)
Number of decimal places to round each column to.
feat_name: optional list of strings or None (default is None)
feature_names: optional list of strings or None (default is None)
The input of the feature names
treatment_name: optional list of strings or None (default is None)
treatment_names: optional list of strings or None (default is None)
The names of the treatments
output_name: optional list of strings or None (default is None)
output_names: optional list of strings or None (default is None)
The names of the outputs
Returns
Expand All @@ -996,6 +1022,11 @@ def summary(self, T, *, alpha=0.1, value=0, decimals=3, feat_name=None, treatmen
this holds the summary tables and text, which can be printed or
converted to various output formats.
"""
# Get input names
feature_names = self.cate_feature_names() if feature_names is None else feature_names
treatment_names = self._input_names["treatment_names"] if treatment_names is None else treatment_names
output_names = self._input_names["output_names"] if output_names is None else output_names
# Summary
smry = Summary()
smry.add_extra_txt(["<sub>A linear parametric conditional average treatment effect (CATE) model was fitted:",
"$Y = \\Theta(X)\\cdot T + g(X, W) + \\epsilon$",
Expand All @@ -1008,9 +1039,9 @@ def summary(self, T, *, alpha=0.1, value=0, decimals=3, feat_name=None, treatmen
"Intercept Results table portrays the $cate\\_intercept_{ij}$ parameter.</sub>"])
try:
coef_table = self.coef__inference(T).summary_frame(
alpha=alpha, value=value, decimals=decimals, feat_name=feat_name,
treatment_name=treatment_name,
output_name=output_name)
alpha=alpha, value=value, decimals=decimals, feature_names=feature_names,
treatment_names=treatment_names,
output_names=output_names)
coef_array = coef_table.values
coef_headers = coef_table.columns.tolist()
coef_stubs = coef_table.index.tolist()
Expand All @@ -1020,9 +1051,9 @@ def summary(self, T, *, alpha=0.1, value=0, decimals=3, feat_name=None, treatmen
print("Coefficient Results: ", e)
try:
intercept_table = self.intercept__inference(T).summary_frame(
alpha=alpha, value=value, decimals=decimals, feat_name=None,
treatment_name=treatment_name,
output_name=output_name)
alpha=alpha, value=value, decimals=decimals, feature_names=None,
treatment_names=treatment_names,
output_names=output_names)
intercept_array = intercept_table.values
intercept_headers = intercept_table.columns.tolist()
intercept_stubs = intercept_table.index.tolist()
Expand Down
25 changes: 16 additions & 9 deletions econml/dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

import numpy as np
from sklearn.base import TransformerMixin, clone
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import (ElasticNetCV, LassoCV, LogisticRegressionCV)
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import KFold, StratifiedKFold, check_cv
Expand All @@ -61,8 +62,8 @@
from .sklearn_extensions.model_selection import WeightedStratifiedKFold
from .utilities import (_deprecate_positional, add_intercept,
broadcast_unit_treatments, check_high_dimensional,
check_input_arrays, cross_product, deprecated,
fit_with_groups, hstack, inverse_onehot, ndim, reshape,
cross_product, deprecated, fit_with_groups,
hstack, inverse_onehot, ndim, reshape,
reshape_treatmentwise_effects, shape, transpose)


Expand Down Expand Up @@ -282,27 +283,34 @@ def models_t(self):
"""
return [mdl._model for mdl in super().models_t]

def cate_feature_names(self, feature_names=None):
    """
    Get the output feature names.

    Parameters
    ----------
    feature_names: list of strings of length X.shape[1] or None
        The names of the input features. If None and X is a dataframe, it defaults to the column names
        from the dataframe.

    Returns
    -------
    out_feature_names: list of strings or None
        The names of the output features :math:`\\phi(X)`, i.e. the features with respect to which the
        final constant marginal CATE model is linear. It is the names of the features that are associated
        with each entry of the :meth:`coef_` parameter. Not available when the featurizer is not None and
        does not have a method: `get_feature_names(feature_names)`. Otherwise None is returned.
    """
    if self._d_x is None:
        # Handles the corner case when X=None but featurizer might be not None
        return None
    if feature_names is None:
        # Fall back to the names captured from the fit inputs.
        feature_names = self._input_names["feature_names"]
    if self.original_featurizer is None:
        return feature_names
    elif hasattr(self.original_featurizer, 'get_feature_names'):
        # This fails if X=None and featurizer is not None, but that case is handled above
        return self.original_featurizer.get_feature_names(feature_names)
    else:
        raise AttributeError("Featurizer does not have a method: get_feature_names!")

Expand Down Expand Up @@ -760,7 +768,6 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
if sample_var is not None and inference is not None:
warn("This estimator does not yet support sample variances and inference does not take "
"sample variances into account. This feature will be supported in a future release.")
Y, T, X, W, sample_weight, sample_var = check_input_arrays(Y, T, X, W, sample_weight, sample_var)
check_high_dimensional(X, T, threshold=5, featurizer=self.featurizer,
discrete_treatment=self._discrete_treatment,
msg="The number of features in the final model (< 5) is too small for a sparse model. "
Expand Down
23 changes: 14 additions & 9 deletions econml/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@
from .sklearn_extensions.linear_model import (
DebiasedLasso, StatsModelsLinearRegression, WeightedLassoCVWrapper)
from .utilities import (_deprecate_positional, check_high_dimensional,
check_input_arrays, filter_none_kwargs,
fit_with_groups, inverse_onehot, _shap_explain_cme)
filter_none_kwargs, fit_with_groups, inverse_onehot, _shap_explain_cme)


class _ModelNuisance:
Expand Down Expand Up @@ -538,27 +537,34 @@ def featurizer(self):
"""
return super().model_final._featurizer

def cate_feature_names(self, feature_names=None):
    """
    Get the output feature names.

    Parameters
    ----------
    feature_names: list of strings of length X.shape[1] or None
        The names of the input features. If None and X is a dataframe, it defaults to the column names
        from the dataframe.

    Returns
    -------
    out_feature_names: list of strings or None
        The names of the output features :math:`\\phi(X)`, i.e. the features with respect to which the
        final CATE model for each treatment is linear. It is the names of the features that are associated
        with each entry of the :meth:`coef_` parameter. Available only when the featurizer is not None and has
        a method: `get_feature_names(feature_names)`. Otherwise None is returned.
    """
    if self._d_x is None:
        # Handles the corner case when X=None but featurizer might be not None
        return None
    if feature_names is None:
        # Fall back to the names captured from the fit inputs.
        feature_names = self._input_names["feature_names"]
    if self.featurizer is None:
        return feature_names
    elif hasattr(self.featurizer, 'get_feature_names'):
        # This fails if X=None and featurizer is not None, but that case is handled above
        return self.featurizer.get_feature_names(feature_names)
    else:
        raise AttributeError("Featurizer does not have a method: get_feature_names!")

Expand Down Expand Up @@ -1006,7 +1012,6 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
if sample_weight is not None and inference is not None:
warn("This estimator does not yet support sample variances and inference does not take "
"sample variances into account. This feature will be supported in a future release.")
Y, T, X, W, sample_weight, sample_var = check_input_arrays(Y, T, X, W, sample_weight, sample_var)
check_high_dimensional(X, T, threshold=5, featurizer=self.featurizer,
discrete_treatment=self._discrete_treatment,
msg="The number of features in the final model (< 5) is too small for a sparse model. "
Expand Down
Loading

0 comments on commit 86a4283

Please sign in to comment.