Skip to content

Commit

Permalink
Re-enable Arviz tests in pymc3.tests.test_sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Mar 26, 2021
1 parent e01a473 commit 1a604c3
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@

import aesara
import aesara.tensor as aet
import arviz as az
import numpy as np
import numpy.testing as npt
import pytest

from aesara import shared
from arviz import InferenceData
from arviz import from_dict as az_from_dict
from scipy import stats

import pymc3 as pm
Expand Down Expand Up @@ -200,7 +201,7 @@ def test_return_inferencedata(self, monkeypatch):

# inferencedata with tuning
result = pm.sample(**kwargs, return_inferencedata=True, discard_tuned_samples=False)
assert isinstance(result, az.InferenceData)
assert isinstance(result, InferenceData)
assert result.posterior.sizes["draw"] == 100
assert result.posterior.sizes["chain"] == 2
assert len(result._groups_warmup) > 0
Expand All @@ -215,7 +216,7 @@ def test_return_inferencedata(self, monkeypatch):
random_seed=-1
)
assert "prior" in result
assert isinstance(result, az.InferenceData)
assert isinstance(result, InferenceData)
assert result.posterior.sizes["draw"] == 100
assert result.posterior.sizes["chain"] == 2
assert len(result._groups_warmup) == 0
Expand Down Expand Up @@ -458,20 +459,26 @@ def test_normal_scalar(self):
ppc = pm.sample_posterior_predictive(trace, size=5, var_names=["a"])
assert ppc["a"].shape == (nchains * ndraws, 5)

@pytest.mark.xfail(reason="Arviz not refactored for v4")
def test_normal_scalar_idata(self):
nchains = 2
ndraws = 500
with pm.Model() as model:
mu = pm.Normal("mu", 0.0, 1.0)
a = pm.Normal("a", mu=mu, sigma=1, observed=0.0)
trace = pm.sample(
draws=ndraws, chains=nchains, return_inferencedata=True, discard_tuned_samples=False
draws=ndraws,
chains=nchains,
return_inferencedata=False,
discard_tuned_samples=False,
)

assert not isinstance(trace, InferenceData)

with model:
# test keep_size parameter and idata input
idata = pm.to_inference_data(trace)
assert isinstance(idata, InferenceData)

ppc = pm.sample_posterior_predictive(idata, keep_size=True)
assert ppc["a"].shape == (nchains, ndraws)

Expand Down Expand Up @@ -505,16 +512,19 @@ def test_normal_vector(self, caplog):
assert "a" in ppc
assert ppc["a"].shape == (10, 4, 2)

@pytest.mark.xfail(reason="Arviz not refactored for v4")
def test_normal_vector_idata(self, caplog):
with pm.Model() as model:
mu = pm.Normal("mu", 0.0, 1.0)
a = pm.Normal("a", mu=mu, sigma=1, observed=np.array([0.5, 0.2]))
trace = pm.sample(return_inferencedata=False)

assert not isinstance(trace, InferenceData)

with model:
# test keep_size parameter with inference data as input...
idata = pm.to_inference_data(trace)
assert isinstance(idata, InferenceData)

ppc = pm.sample_posterior_predictive(idata, keep_size=True)
assert ppc["a"].shape == (trace.nchains, len(trace), 2)

Expand Down Expand Up @@ -703,7 +713,7 @@ def test_potentials_warning(self):
p = pm.Potential("p", a + 1)
obs = pm.Normal("obs", a, 1, observed=5)

trace = az.from_dict({"a": np.random.rand(10)})
trace = az_from_dict({"a": np.random.rand(10)})
with m:
with pytest.warns(UserWarning, match=warning_msg):
pm.sample_posterior_predictive(trace, samples=5)
Expand Down Expand Up @@ -768,7 +778,7 @@ def test_potentials_warning(self):
p = pm.Potential("p", a + 1)
obs = pm.Normal("obs", a, 1, observed=5)

trace = az.from_dict({"a": np.random.rand(10)})
trace = az_from_dict({"a": np.random.rand(10)})
with pytest.warns(UserWarning, match=warning_msg):
pm.sample_posterior_predictive_w(samples=5, traces=[trace, trace], models=[m, m])

Expand Down Expand Up @@ -1031,17 +1041,17 @@ def test_point_list_arg_bug_spp(self, point_list_arg_bug_fixture):
with pmodel:
pp = pm.sample_posterior_predictive([trace[15]], var_names=["d"])

@pytest.mark.xfail(reason="Arviz not refactored for v4")
def test_sample_from_xarray_prior(self, point_list_arg_bug_fixture):
pmodel, trace = point_list_arg_bug_fixture

with pmodel:
prior = pm.sample_prior_predictive(samples=20)

idat = pm.to_inference_data(trace, prior=prior)

with pmodel:
pp = pm.sample_posterior_predictive(idat.prior, var_names=["d"])

@pytest.mark.xfail(reason="Arviz not refactored for v4")
def test_sample_from_xarray_posterior(self, point_list_arg_bug_fixture):
pmodel, trace = point_list_arg_bug_fixture
idat = pm.to_inference_data(trace)
Expand Down

0 comments on commit 1a604c3

Please sign in to comment.