From 44393bb3baa03569fe86d00c640e969cb730376c Mon Sep 17 00:00:00 2001 From: "Benjamin T. Vincent" Date: Mon, 6 May 2024 14:42:59 +0100 Subject: [PATCH] make doctests pass for pymc_model.py #323 --- causalpy/pymc_models.py | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/causalpy/pymc_models.py b/causalpy/pymc_models.py index 1d7a545d..d4fc95ee 100644 --- a/causalpy/pymc_models.py +++ b/causalpy/pymc_models.py @@ -72,13 +72,17 @@ class ModelBuilder(pm.Model): ... } ... ) >>> model.fit(X, y) - Inference... + + + Inference data... >>> X_new = rng.normal(loc=0, scale=1, size=(20,2)) >>> model.predict(X_new) - Inference... - >>> model.score(X, y) # doctest: +NUMBER - r2 0.3 - r2_std 0.0 + + Inference data... + >>> model.score(X, y) + + r2 0.390344 + r2_std 0.081135 dtype: float64 """ @@ -112,10 +116,7 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None: # Ensure random_seed is used in sample_prior_predictive() and # sample_posterior_predictive() if provided in sample_kwargs. - if "random_seed" in self.sample_kwargs: - random_seed = self.sample_kwargs["random_seed"] - else: - random_seed = None + random_seed = self.sample_kwargs.get("random_seed", None) self.build_model(X, y, coords) with self: @@ -137,10 +138,17 @@ def predict(self, X): """ + # Ensure random_seed is used in sample_prior_predictive() and + # sample_posterior_predictive() if provided in sample_kwargs. + random_seed = self.sample_kwargs.get("random_seed", None) + self._data_setter(X) with self: # sample with new input data post_pred = pm.sample_posterior_predictive( - self.idata, var_names=["y_hat", "mu"], progressbar=False + self.idata, + var_names=["y_hat", "mu"], + progressbar=False, + random_seed=random_seed, ) return post_pred @@ -193,7 +201,9 @@ class WeightedSumFitter(ModelBuilder): >>> y = np.asarray(sc['actual']).reshape((sc.shape[0], 1)) >>> wsf = WeightedSumFitter(sample_kwargs={"progressbar": False}) >>> wsf.fit(X,y) - Inference ... + + + Inference data... """ # noqa: W605 def build_model(self, X, y, coords): @@ -249,7 +259,9 @@ class LinearRegression(ModelBuilder): ... 'obs_indx': np.arange(rd.shape[0]) ... }, ... ) - Inference... + + + Inference data... """ # noqa: W605 def build_model(self, X, y, coords): @@ -301,6 +313,8 @@ class InstrumentalVariableRegression(ModelBuilder): ... "eta": 2, ... "lkj_sd": 2, ... }) + + Inference data... """