Skip to content

Commit

Permalink
make doctests pass for pymc_model.py #323
Browse files Browse the repository at this point in the history
  • Loading branch information
drbenvincent committed May 6, 2024
1 parent 15b5756 commit 44393bb
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions causalpy/pymc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,17 @@ class ModelBuilder(pm.Model):
... }
... )
>>> model.fit(X, y)
Inference...
<BLANKLINE>
<BLANKLINE>
Inference data...
>>> X_new = rng.normal(loc=0, scale=1, size=(20,2))
>>> model.predict(X_new)
Inference...
>>> model.score(X, y) # doctest: +NUMBER
r2 0.3
r2_std 0.0
<BLANKLINE>
Inference data...
>>> model.score(X, y)
<BLANKLINE>
r2 0.390344
r2_std 0.081135
dtype: float64
"""

Expand Down Expand Up @@ -112,10 +116,7 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:

# Ensure random_seed is used in sample_prior_predictive() and
# sample_posterior_predictive() if provided in sample_kwargs.
if "random_seed" in self.sample_kwargs:
random_seed = self.sample_kwargs["random_seed"]
else:
random_seed = None
random_seed = self.sample_kwargs.get("random_seed", None)

self.build_model(X, y, coords)
with self:
Expand All @@ -137,10 +138,17 @@ def predict(self, X):
"""

# Ensure random_seed is used in sample_prior_predictive() and
# sample_posterior_predictive() if provided in sample_kwargs.
random_seed = self.sample_kwargs.get("random_seed", None)

self._data_setter(X)
with self: # sample with new input data
post_pred = pm.sample_posterior_predictive(
self.idata, var_names=["y_hat", "mu"], progressbar=False
self.idata,
var_names=["y_hat", "mu"],
progressbar=False,
random_seed=random_seed,
)
return post_pred

Expand Down Expand Up @@ -193,7 +201,9 @@ class WeightedSumFitter(ModelBuilder):
>>> y = np.asarray(sc['actual']).reshape((sc.shape[0], 1))
>>> wsf = WeightedSumFitter(sample_kwargs={"progressbar": False})
>>> wsf.fit(X,y)
Inference ...
<BLANKLINE>
<BLANKLINE>
Inference data...
""" # noqa: W605

def build_model(self, X, y, coords):
Expand Down Expand Up @@ -249,7 +259,9 @@ class LinearRegression(ModelBuilder):
... 'obs_indx': np.arange(rd.shape[0])
... },
... )
Inference...
<BLANKLINE>
<BLANKLINE>
Inference data...
""" # noqa: W605

def build_model(self, X, y, coords):
Expand Down Expand Up @@ -301,6 +313,8 @@ class InstrumentalVariableRegression(ModelBuilder):
... "eta": 2,
... "lkj_sd": 2,
... })
<BLANKLINE>
<BLANKLINE>
Inference data...
"""

Expand Down

0 comments on commit 44393bb

Please sign in to comment.