From 8ffb95181119f612521d14896e85a11ba7564ca6 Mon Sep 17 00:00:00 2001 From: Trey Wenger Date: Mon, 29 Jul 2024 12:15:00 -0500 Subject: [PATCH] revert formatting --- pymc/smc/kernels.py | 5 ++++- tests/smc/test_smc.py | 25 +++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/pymc/smc/kernels.py b/pymc/smc/kernels.py index 61aa403c890..fd1674b90f4 100644 --- a/pymc/smc/kernels.py +++ b/pymc/smc/kernels.py @@ -193,10 +193,13 @@ def initialize_population(self) -> dict[str, np.ndarray]: ) model = self.model + prior_expression = make_initial_point_expression( free_rvs=model.free_RVs, rvs_to_transforms=model.rvs_to_transforms, - initval_strategies={}, + initval_strategies={ + **model.rvs_to_initial_values, + }, default_strategy="prior", return_transformed=True, ) diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index 33c8718eae8..84a53695581 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -25,6 +25,7 @@ import pymc as pm from pymc.backends.base import MultiTrace +from pymc.distributions.transforms import Ordered from pymc.pytensorf import floatX from pymc.smc.kernels import IMH, systematic_resampling from tests.helpers import assert_random_state_equal @@ -269,6 +270,30 @@ def test_deprecated_abc_args(self): ): pm.sample_smc(draws=10, chains=1, save_log_pseudolikelihood=True) + def test_ordered(self): + """ + Test that initial population respects custom initval, especially when applied + to the Ordered transformation. Regression test for #7438. + """ + with pm.Model() as m: + pm.Normal( + "a", + mu=0.0, + sigma=1.0, + size=(2,), + transform=Ordered(), + initval=[-1.0, 1.0], + ) + + smc = IMH(model=m) + out = smc.initialize_population() + + # initial point should not include NaNs + assert not np.any(np.isnan(out["a_ordered__"])) + + # initial point should match for all particles + assert np.all(out["a_ordered__"][0] == out["a_ordered__"]) + class TestMHKernel: def test_normal_model(self):