Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kbattocchi committed Aug 5, 2020
1 parent b7c5982 commit 9835a06
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 10 deletions.
17 changes: 11 additions & 6 deletions econml/tests/test_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,22 @@ class TestDML(unittest.TestCase):

def test_cate_api(self):
"""Test that we correctly implement the CATE API."""
n = 20
n_c = 20 # number of rows for continuous models
n_d = 30 # number of rows for discrete models

def make_random(is_discrete, d):
def make_random(n, is_discrete, d):
if d is None:
return None
sz = (n, d) if d >= 0 else (n,)
if is_discrete:
while True:
arr = np.random.choice(['a', 'b', 'c'], size=sz)
# ensure that we've got at least two of every element
# ensure that we've got at least 6 of every element
# 2 outer splits, 3 inner splits when model_t is 'auto' and treatment is discrete
# NOTE: this number may need to change if the default number of folds in
# WeightedStratifiedKFold changes
_, counts = np.unique(arr, return_counts=True)
if len(counts) == 3 and counts.min() > 1:
if len(counts) == 3 and counts.min() > 5:
return arr
else:
return np.random.normal(size=sz)
Expand All @@ -55,7 +59,8 @@ def make_random(is_discrete, d):
for d_y in [3, 1, -1]:
for d_x in [2, None]:
for d_w in [2, None]:
W, X, Y, T = [make_random(is_discrete, d)
n = n_d if is_discrete else n_c
W, X, Y, T = [make_random(n, is_discrete, d)
for is_discrete, d in [(False, d_w),
(False, d_x),
(False, d_y),
Expand Down Expand Up @@ -699,7 +704,7 @@ def test_can_custom_splitter(self):
def test_can_use_featurizer(self):
"Test that we can use a featurizer, and that fit is only called during training"
dml = LinearDMLCateEstimator(LinearRegression(), LinearRegression(),
fit_cate_intercept=False, featurizer=OneHotEncoder(n_values='auto', sparse=False))
fit_cate_intercept=False, featurizer=OneHotEncoder(sparse=False))

T = np.tile([1, 2, 3], 6)
Y = np.array([1, 2, 3, 1, 2, 3])
Expand Down
2 changes: 1 addition & 1 deletion econml/tests/test_drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ def test_sparse(self):
y_lower, y_upper = sparse_dml.effect_interval(x_test, T0=0, T1=1)
in_CI = ((y_lower < true_eff) & (true_eff < y_upper))
# Check that a majority of true effects lie in the 5-95% CI
self.assertTrue(in_CI.mean() > 0.8)
self.assertGreater(in_CI.mean(), 0.8)

def _test_te(self, learner_instance, tol, te_type="const"):
if te_type not in ["const", "heterogeneous"]:
Expand Down
20 changes: 17 additions & 3 deletions econml/tests/test_orf.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,20 @@ def test_effect_shape(self):

def test_nuisance_model_has_weights(self):
"""Test whether the correct exception is being raised if model_final doesn't have weights."""

# Create a wrapper around Lasso that doesn't support weights
# since Lasso does natively support them starting in sklearn 0.23
class NoWeightModel:
def __init__(self):
self.model = Lasso()

def fit(self, X, y):
self.model.fit(X, y)
return self

def predict(self, X):
return self.model.predict(X)

# Generate data with continuous treatments
T = np.dot(TestOrthoForest.W[:, TestOrthoForest.support], TestOrthoForest.coefs_T) + \
TestOrthoForest.eta_sample(TestOrthoForest.n)
Expand All @@ -192,14 +206,14 @@ def test_nuisance_model_has_weights(self):
T * TE + TestOrthoForest.epsilon_sample(TestOrthoForest.n)
# Instantiate model with most of the default parameters
est = ContinuousTreatmentOrthoForest(n_jobs=4, n_trees=10,
model_T=Lasso(),
model_Y=Lasso())
model_T=NoWeightModel(),
model_Y=NoWeightModel())
est.fit(Y=Y, T=T, X=TestOrthoForest.X, W=TestOrthoForest.W)
weights_error_msg = (
"Estimators of type {} do not accept weights. "
"Consider using the class WeightedModelWrapper from econml.utilities to build a weighted model."
)
self.assertRaisesRegexp(TypeError, weights_error_msg.format("Lasso"),
self.assertRaisesRegexp(TypeError, weights_error_msg.format("NoWeightModel"),
est.effect, X=TestOrthoForest.X)

def _test_te(self, learner_instance, expected_te, tol, treatment_type='continuous'):
Expand Down

0 comments on commit 9835a06

Please sign in to comment.