Skip to content

Commit

Permalink
revert formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
tvwenger authored and aloctavodia committed Aug 1, 2024
1 parent 8eaa9be commit 8ffb951
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pymc/smc/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,13 @@ def initialize_population(self) -> dict[str, np.ndarray]:
)

model = self.model

prior_expression = make_initial_point_expression(
free_rvs=model.free_RVs,
rvs_to_transforms=model.rvs_to_transforms,
initval_strategies={},
initval_strategies={
**model.rvs_to_initial_values,
},
default_strategy="prior",
return_transformed=True,
)
Expand Down
25 changes: 25 additions & 0 deletions tests/smc/test_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import pymc as pm

from pymc.backends.base import MultiTrace
from pymc.distributions.transforms import Ordered
from pymc.pytensorf import floatX
from pymc.smc.kernels import IMH, systematic_resampling
from tests.helpers import assert_random_state_equal
Expand Down Expand Up @@ -269,6 +270,30 @@ def test_deprecated_abc_args(self):
):
pm.sample_smc(draws=10, chains=1, save_log_pseudolikelihood=True)

def test_ordered(self):
"""
Test that initial population respects custom initval, especially when applied
to the Ordered transformation. Regression test for #7438.
"""
with pm.Model() as m:
pm.Normal(
"a",
mu=0.0,
sigma=1.0,
size=(2,),
transform=Ordered(),
initval=[-1.0, 1.0],
)

smc = IMH(model=m)
out = smc.initialize_population()

# initial point should not include NaNs
assert not np.any(np.isnan(out["a_ordered__"]))

# initial point should match for all particles
assert np.all(out["a_ordered__"][0] == out["a_ordered__"])


class TestMHKernel:
def test_normal_model(self):
Expand Down

0 comments on commit 8ffb951

Please sign in to comment.