Skip to content

Commit

Permalink
Refit test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
kbattocchi committed Dec 23, 2020
1 parent 7b98d46 commit f636dab
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 18 deletions.
4 changes: 2 additions & 2 deletions econml/_ortho_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,6 @@ def refit(self, inference=None):
X=cached.X,
W=cached.W,
Z=cached.Z,
nuisances=cached.nuisances,
sample_weight=cached.sample_weight,
sample_var=cached.sample_var
)
Expand All @@ -633,11 +632,12 @@ def refit(self, inference=None):
# fit only the final model
self._fit_final(cached.Y,
cached.T,
nuisances=cached.nuisances,
**kwargs)

if inference is not None:
# NOTE: we call inference fit *after* calling the main fit method
inference.fit(self, Y, T, *args, **kwargs)
inference.fit(self, cached.Y, cached.T, **kwargs)
self._inference = inference

return self
Expand Down
2 changes: 1 addition & 1 deletion econml/_rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def predict(self, X):

def __init__(self, model_y, model_t, model_final,
discrete_treatment, categories, n_splits, random_state, monte_carlo_iterations=None):
self._rlearner_model_final = _model_final
self._rlearner_model_final = model_final
self._rlearner_model_y = model_y
self._rlearner_model_t = model_t
super().__init__(_ModelNuisance(clone(model_y, safe=False), clone(model_t, safe=False)),
Expand Down
14 changes: 8 additions & 6 deletions econml/dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,14 +243,15 @@ def featurizer(self):
@property
def model_final(self):
# NOTE This is used by the inference methods and is more for internal use to the library
return self._model_final
# We need to use the rlearner's copy to retain the information from fitting
return self.rlearner_model_final._model

@model_final.setter
def model_final(self, model):
model = _FinalWrapper(model,
fit_cate_intercept=super().model_final._fit_cate_intercept,
featurizer=super().model_final._original_featurizer,
use_weight_trick=super().model_final._use_weight_trick)
fit_cate_intercept=self.rlearner_model_final._fit_cate_intercept,
featurizer=self.rlearner_model_final._original_featurizer,
use_weight_trick=self.rlearner_model_final._use_weight_trick)
self._rlearner_model_final = model

@_RLearner.rlearner_model_final.setter
Expand Down Expand Up @@ -496,11 +497,12 @@ def _prepare_model_t(self, model_t):
random_state=self._random_state)
else:
model_t = WeightedLassoCVWrapper(random_state=self._random_state)
return _FirstStageWrapper(model_t, False, self._featurizer, self._linear_first_stages, self._discrete_treatment)
return _FirstStageWrapper(model_t, False, self._featurizer,
self._linear_first_stages, self._discrete_treatment)

def _prepare_final_model(self, model):
self._model_final = model
return _FinalWrapper(self.model_final, self.fit_cate_intercept, self._featurizer, False)
return _FinalWrapper(self._model_final, self._fit_cate_intercept, self._featurizer, False)

def _update_models(self):
self._rlearner_model_y = self._prepare_model_y(self.model_y)
Expand Down
19 changes: 18 additions & 1 deletion econml/tests/test_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from econml.dml import (DML, LinearDML, SparseLinearDML, KernelDML, NonParamDML, ForestDML)
import numpy as np
from econml.utilities import shape, hstack, vstack, reshape, cross_product
from econml.inference import BootstrapInference
from econml.inference import BootstrapInference, EmpiricalInferenceResults, NormalInferenceResults
from contextlib import ExitStack
from sklearn.ensemble import GradientBoostingRegressor, GradientBoostingClassifier
import itertools
Expand Down Expand Up @@ -1141,3 +1141,20 @@ def test_montecarlo(self):
v2s = [np.var([est.fit(y, T, W=W).effect() for _ in range(10)]) for _ in range(10)]
# The average variance should be lower when using monte carlo iterations
assert np.mean(v2s) < np.mean(v1s)

def test_refit_inference(self):
"""Test that we can perform inference during refit"""
est = LinearDML(linear_first_stages=False, featurizer=PolynomialFeatures(1, include_bias=False))

X = np.random.choice(np.arange(5), size=(500, 3))
y = np.random.normal(size=(500,))
T = np.random.choice(np.arange(3), size=(500, 2))
W = np.random.normal(size=(500, 2))

est.fit(y, T, X=X, W=W, cache_values=True, inference='statsmodels')

assert isinstance(est.effect_inference(X), NormalInferenceResults)

est.refit(inference=BootstrapInference(2))

assert isinstance(est.effect_inference(X), EmpiricalInferenceResults)
14 changes: 7 additions & 7 deletions econml/tests/test_ortho_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def score(self, Y, T, W=None, nuisances=None):
np.testing.assert_array_almost_equal(est.effect(T0=0, T1=10), np.ones(1) * 10, decimal=2)
np.testing.assert_almost_equal(est.score(y, X[:, 0], W=X[:, 1:]), sigma**2, decimal=3)
np.testing.assert_almost_equal(est.score_, sigma**2, decimal=3)
np.testing.assert_almost_equal(est.model_final.model.coef_[0], 1, decimal=3)
np.testing.assert_almost_equal(est.ortho_learner_model_final.model.coef_[0], 1, decimal=3)
# Nuisance model has no score method, so nuisance_scores_ should be none
assert est.nuisance_scores_ is None

Expand All @@ -191,7 +191,7 @@ def score(self, Y, T, W=None, nuisances=None):
np.testing.assert_array_almost_equal(est.effect(T0=0, T1=10), np.ones(1) * 10, decimal=2)
np.testing.assert_almost_equal(est.score(y, X[:, 0], None, X[:, 1:]), sigma**2, decimal=3)
np.testing.assert_almost_equal(est.score_, sigma**2, decimal=3)
np.testing.assert_almost_equal(est.model_final.model.coef_[0], 1, decimal=3)
np.testing.assert_almost_equal(est.ortho_learner_model_final.model.coef_[0], 1, decimal=3)

# Test custom splitter
np.random.seed(123)
Expand All @@ -208,7 +208,7 @@ def score(self, Y, T, W=None, nuisances=None):
np.testing.assert_array_almost_equal(est.effect(T0=0, T1=10), np.ones(1) * 10, decimal=2)
np.testing.assert_almost_equal(est.score(y, X[:, 0], W=X[:, 1:]), sigma**2, decimal=3)
np.testing.assert_almost_equal(est.score_, sigma**2, decimal=3)
np.testing.assert_almost_equal(est.model_final.model.coef_[0], 1, decimal=3)
np.testing.assert_almost_equal(est.ortho_learner_model_final.model.coef_[0], 1, decimal=3)

# Test incomplete set of test folds
np.random.seed(123)
Expand All @@ -225,7 +225,7 @@ def score(self, Y, T, W=None, nuisances=None):
np.testing.assert_array_almost_equal(est.effect(T0=0, T1=10), np.ones(1) * 10, decimal=2)
np.testing.assert_almost_equal(est.score(y, X[:, 0], W=X[:, 1:]), sigma**2, decimal=3)
np.testing.assert_almost_equal(est.score_, sigma**2, decimal=3)
np.testing.assert_almost_equal(est.model_final.model.coef_[0], 1, decimal=3)
np.testing.assert_almost_equal(est.ortho_learner_model_final.model.coef_[0], 1, decimal=3)

def test_ol_no_score_final(self):
class ModelNuisance:
Expand Down Expand Up @@ -266,7 +266,7 @@ def predict(self, X=None):
np.testing.assert_array_almost_equal(est.effect(), np.ones(1), decimal=3)
np.testing.assert_array_almost_equal(est.effect(T0=0, T1=10), np.ones(1) * 10, decimal=2)
assert est.score_ is None
np.testing.assert_almost_equal(est.model_final.model.coef_[0], 1, decimal=3)
np.testing.assert_almost_equal(est.ortho_learner_model_final.model.coef_[0], 1, decimal=3)

def test_ol_nuisance_scores(self):
class ModelNuisance:
Expand Down Expand Up @@ -309,7 +309,7 @@ def predict(self, X=None):
np.testing.assert_almost_equal(est.const_marginal_effect(), 1, decimal=3)
np.testing.assert_array_almost_equal(est.effect(), np.ones(1), decimal=3)
np.testing.assert_array_almost_equal(est.effect(T0=0, T1=10), np.ones(1) * 10, decimal=2)
np.testing.assert_almost_equal(est.model_final.model.coef_[0], 1, decimal=3)
np.testing.assert_almost_equal(est.ortho_learner_model_final.model.coef_[0], 1, decimal=3)
nuisance_scores_y = est.nuisance_scores_[0]
nuisance_scores_t = est.nuisance_scores_[1]
assert len(nuisance_scores_y) == len(nuisance_scores_t) == 2 # as many scores as splits
Expand Down Expand Up @@ -363,4 +363,4 @@ def score(self, Y, T, W=None, nuisances=None):
np.testing.assert_almost_equal(est.const_marginal_effect(), 1, decimal=3)
np.testing.assert_array_almost_equal(est.effect(), np.ones(1), decimal=3)
np.testing.assert_almost_equal(est.score(y, T, W=X), sigma**2, decimal=3)
np.testing.assert_almost_equal(est.model_final.model.coef_[0], 1, decimal=3)
np.testing.assert_almost_equal(est.ortho_learner_model_final.model.coef_[0], 1, decimal=3)
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ setup_requires =
install_requires =
numpy
scipy != 1.4.0
scikit-learn > 0.21.0
scikit-learn > 0.21.0, < 0.24
keras < 2.4
sparse
tensorflow > 1.10, < 2.3
Expand Down

0 comments on commit f636dab

Please sign in to comment.