Skip to content

Commit

Permalink
Add support for monte carlo nuisance estimation
Browse files Browse the repository at this point in the history
  • Loading branch information
kbattocchi committed Dec 16, 2020
1 parent fc0ba65 commit 3021200
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 12 deletions.
23 changes: 21 additions & 2 deletions econml/_ortho_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,8 @@ def _subinds_check_none(self, var, inds):
return var[inds] if var is not None else None

def _strata(self, Y, T, X=None, W=None, Z=None,
sample_weight=None, sample_var=None, groups=None, cache_values=False):
sample_weight=None, sample_var=None, groups=None,
cache_values=False, monte_carlo_iterations=None):
if self._discrete_instrument:
Z = LabelEncoder().fit_transform(np.ravel(Z))

Expand Down Expand Up @@ -517,7 +518,7 @@ def _prefit(self, Y, T, *args, **kwargs):
"we will disallow passing X, W, and Z by position.", ['X', 'W', 'Z'])
@BaseCateEstimator._wrap_fit
def fit(self, Y, T, X=None, W=None, Z=None, *, sample_weight=None, sample_var=None, groups=None,
cache_values=False, inference=None):
cache_values=False, monte_carlo_iterations=None, inference=None):
"""
Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`.
Expand All @@ -543,6 +544,8 @@ def fit(self, Y, T, X=None, W=None, Z=None, *, sample_weight=None, sample_var=No
must support a 'groups' argument to its split method.
cache_values: bool, default False
Whether to cache the inputs and computed nuisances, which will allow refitting a different final model
monte_carlo_iterations: int, optional
The number of times to rerun the first stage models to reduce the variance of the nuisances.
inference: string, :class:`.Inference` instance, or None
Method for performing inference. This estimator supports 'bootstrap'
(or an instance of :class:`.BootstrapInference`).
Expand All @@ -555,6 +558,22 @@ def fit(self, Y, T, X=None, W=None, Z=None, *, sample_weight=None, sample_var=No
Y, T, X, W, Z, sample_weight, sample_var, groups = check_input_arrays(
Y, T, X, W, Z, sample_weight, sample_var, groups)
self._check_input_dims(Y, T, X, W, Z, sample_weight, sample_var, groups)

all_nuisances = []
fitted_inds = None

for _ in range(monte_carlo_iterations or 1):
nuisances, new_inds = self._fit_nuisances(Y, T, X, W, Z, sample_weight=sample_weight, groups=groups)
all_nuisances.append(nuisances)
if fitted_inds is None:
fitted_inds = new_inds
elif not np.array_equal(fitted_inds, new_inds):
raise AttributeError("Different indices were fit by different folds, so they cannot be aggregated")

if monte_carlo_iterations is not None:
# TODO: support different ways to aggregate, like median?
nuisances = np.mean(np.array(all_nuisances), axis=0)

nuisances, fitted_inds = self._fit_nuisances(Y, T, X, W, Z, sample_weight=sample_weight, groups=groups)
Y, T, X, W, Z, sample_weight, sample_var = (self._subinds_check_none(arr, fitted_inds)
for arr in (Y, T, X, W, Z, sample_weight, sample_var))
Expand Down
7 changes: 5 additions & 2 deletions econml/_rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def __init__(self, model_y, model_t, model_final,
@_deprecate_positional("X, and should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, groups=None,
cache_values=False, inference=None):
cache_values=False, monte_carlo_iterations=None, inference=None):
"""
Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`.
Expand All @@ -310,6 +310,8 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
must support a 'groups' argument to its split method.
cache_values: bool, default False
Whether to cache inputs and first stage results, which will allow refitting a different final model
monte_carlo_iterations: int, optional
The number of times to rerun the first stage models to reduce the variance of the nuisances.
inference: string,:class:`.Inference` instance, or None
Method for performing inference. This estimator supports 'bootstrap'
(or an instance of:class:`.BootstrapInference`).
Expand All @@ -321,7 +323,8 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
# Replacing fit from _OrthoLearner, to enforce Z=None and improve the docstring
return super().fit(Y, T, X=X, W=W,
sample_weight=sample_weight, sample_var=sample_var, groups=groups,
cache_values=cache_values, inference=inference)
cache_values=cache_values, monte_carlo_iterations=monte_carlo_iterations,
inference=inference)

def score(self, Y, T, X=None, W=None):
"""
Expand Down
27 changes: 19 additions & 8 deletions econml/dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def _update_models(self):
@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, groups=None,
cache_values=False, inference='auto'):
cache_values=False, monte_carlo_iterations=None, inference='auto'):
"""
Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), ∂τ(·,·).
Expand All @@ -536,6 +536,8 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
must support a 'groups' argument to its split method.
cache_values: bool, default False
Whether to cache inputs and first stage results, which will allow refitting a different final model
monte_carlo_iterations: int, optional
The number of times to rerun the first stage models to reduce the variance of the nuisances.
inference: string, :class:`.Inference` instance, or None
Method for performing inference. This estimator supports 'bootstrap'
(or an instance of :class:`.BootstrapInference`) and 'auto'
Expand All @@ -546,7 +548,8 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
self
"""
return super().fit(Y, T, X=X, W=W, sample_weight=sample_weight, sample_var=sample_var, groups=groups,
cache_values=cache_values, inference=inference)
cache_values=cache_values, monte_carlo_iterations=monte_carlo_iterations,
inference=inference)

@property
def linear_first_stages(self):
Expand Down Expand Up @@ -701,7 +704,7 @@ def __init__(self,
@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, groups=None,
cache_values=False, inference='auto'):
cache_values=False, monte_carlo_iterations=None, inference='auto'):
"""
Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), ∂τ(·,·).
Expand All @@ -725,6 +728,8 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
must support a 'groups' argument to its split method.
cache_values: bool, default False
Whether to cache inputs and first stage results, which will allow refitting a different final model
monte_carlo_iterations: int, optional
The number of times to rerun the first stage models to reduce the variance of the nuisances.
inference: string, :class:`.Inference` instance, or None
Method for performing inference. This estimator supports 'bootstrap'
(or an instance of :class:`.BootstrapInference`) and 'statsmodels'
Expand All @@ -736,7 +741,8 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
"""
return super().fit(Y, T, X=X, W=W,
sample_weight=sample_weight, sample_var=sample_var, groups=groups,
cache_values=cache_values, inference=inference)
cache_values=cache_values, monte_carlo_iterations=monte_carlo_iterations,
inference=inference)

@DML.model_final.setter
def model_final(self, model):
Expand Down Expand Up @@ -861,7 +867,7 @@ def __init__(self,
@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, groups=None,
cache_values=False, inference='auto'):
cache_values=False, monte_carlo_iterations=None, inference='auto'):
"""
Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), ∂τ(·,·).
Expand All @@ -886,6 +892,8 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
must support a 'groups' argument to its split method.
cache_values: bool, default False
Whether to cache inputs and first stage results, which will allow refitting a different final model
monte_carlo_iterations: int, optional
The number of times to rerun the first stage models to reduce the variance of the nuisances.
inference: string, `Inference` instance, or None
Method for performing inference. This estimator supports 'bootstrap'
(or an instance of :class:`.BootstrapInference`) and 'debiasedlasso'
Expand All @@ -906,7 +914,7 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
"We recommend using the LinearDML estimator for this low-dimensional setting.")
return super().fit(Y, T, X=X, W=W,
sample_weight=sample_weight, sample_var=None, groups=groups,
cache_values=cache_values, inference=inference)
cache_values=cache_values, monte_carlo_iterations=None, inference=inference)

@DML.model_final.setter
def model_final(self, model):
Expand Down Expand Up @@ -1417,7 +1425,7 @@ def __init__(self,
@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, groups=None,
cache_values=False, inference='auto'):
cache_values=False, monte_carlo_iterations=None, inference='auto'):
"""
Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), ∂τ(·,·).
Expand All @@ -1442,6 +1450,8 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
must support a 'groups' argument to its split method.
cache_values: bool, default False
Whether to cache inputs and first stage results, which will allow refitting a different final model
monte_carlo_iterations: int, optional
The number of times to rerun the first stage models to reduce the variance of the nuisances.
inference: string, `Inference` instance, or None
Method for performing inference. This estimator supports 'bootstrap'
(or an instance of :class:`.BootstrapInference`) and 'blb'
Expand All @@ -1453,7 +1463,8 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
"""
return super().fit(Y, T, X=X, W=W,
sample_weight=sample_weight, sample_var=None, groups=groups,
cache_values=cache_values, inference=inference)
cache_values=cache_values, monte_carlo_iteration=monte_carlo_iterations,
inference=inference)

@DML.model_final.setter
def model_final(self, model):
Expand Down
8 changes: 8 additions & 0 deletions econml/tests/test_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,3 +1112,11 @@ def test_refit(self):
ldml.refit()
with pytest.raises(Exception):
dml.refit()

def test_montecarlo(self):
"""Test that we can perform nuisance averaging."""
y = np.random.normal(size=30) + [0, 1] * 15
T = np.random.normal(size=(30,)) + y
W = np.random.normal(size=(30, 3))
est = LinearDML(model_y=LinearRegression(), model_t=LinearRegression())
est.fit(y, T, W=W, monte_carlo_iterations=2).effect()

0 comments on commit 3021200

Please sign in to comment.