Skip to content

Commit

Permalink
Add tests related to named models
Browse files Browse the repository at this point in the history
A test was added to asserts that strace.name is a string (this was not the case before pymc-devs#4365).

Non-empty model names are actually not supported (again, see pymc-devs#4365) so attempting to SMC-sample a named model will now raise a NotImplementedError.
  • Loading branch information
basnijholt authored and michaelosthege committed Jan 15, 2021
1 parent 3944c07 commit 624d337
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 1 deletion.
5 changes: 5 additions & 0 deletions pymc3/smc/sample_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,11 @@ def sample_smc(
_log.info("Initializing SMC sampler...")

model = modelcontext(model)
if model.name:
raise NotImplementedError(
"The SMC implementation currently does not support named models. "
"See https://github.com/pymc-devs/pymc3/pull/4365."
)
if cores is None:
cores = _cpu_count()

Expand Down
2 changes: 1 addition & 1 deletion pymc3/smc/smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def posterior_to_trace(self):
varnames = [v.name for v in self.variables]

with self.model:
strace = NDArray(self.model.name)
strace = NDArray(name=self.model.name)
strace.setup(lenght_pos, self.chain)
for i in range(lenght_pos):
value = []
Expand Down
21 changes: 21 additions & 0 deletions pymc3/tests/test_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import numpy as np
import pytest
import theano.tensor as tt

import pymc3 as pm
Expand Down Expand Up @@ -189,3 +190,23 @@ def test_repr_latex(self):
assert expected == self.s._repr_latex_()
assert self.s._repr_latex_() == self.s.__latex__()
assert self.SMABC_test.model._repr_latex_() == self.SMABC_test.model.__latex__()

def test_name_is_string_type(self):
with self.SMABC_potential:
assert not self.SMABC_potential.name
trace = pm.sample_smc(draws=10, kernel="ABC")
assert isinstance(trace._straces[0].name, str)

def test_named_models_are_unsupported(self):
def normal_sim(a, b):
return np.random.normal(a, b, 1000)

with pm.Model(name="NamedModel"):
a = pm.Normal("a", mu=0, sigma=1)
b = pm.HalfNormal("b", sigma=1)
c = pm.Potential("c", pm.math.switch(a > 0, 0, -np.inf))
s = pm.Simulator(
"s", normal_sim, params=(a, b), sum_stat="sort", epsilon=1, observed=self.data
)
with pytest.raises(NotImplementedError, match="named models"):
pm.sample_smc(draws=10, kernel="ABC")

0 comments on commit 624d337

Please sign in to comment.