Skip to content

Commit

Permalink
revert formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
tvwenger committed Jul 29, 2024
1 parent 8eaa9be commit defe39d
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
5 changes: 4 additions & 1 deletion pymc/smc/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,13 @@ def initialize_population(self) -> dict[str, np.ndarray]:
)

model = self.model

prior_expression = make_initial_point_expression(
free_rvs=model.free_RVs,
rvs_to_transforms=model.rvs_to_transforms,
initval_strategies={},
initval_strategies={
**model.rvs_to_initial_values,
},
default_strategy="prior",
return_transformed=True,
)
Expand Down
38 changes: 35 additions & 3 deletions tests/smc/test_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import pymc as pm

from pymc.backends.base import MultiTrace
from pymc.distributions.transforms import Ordered
from pymc.pytensorf import floatX
from pymc.smc.kernels import IMH, systematic_resampling
from tests.helpers import assert_random_state_equal
Expand Down Expand Up @@ -79,7 +80,9 @@ def test_sample(self):
initial_rng_state = np.random.get_state()
with self.SMC_test:
mtrace = pm.sample_smc(
draws=self.samples, return_inferencedata=False, progressbar=not _IS_WINDOWS
draws=self.samples,
return_inferencedata=False,
progressbar=not _IS_WINDOWS,
)

# Verify sampling was done with a non-global random generator
Expand Down Expand Up @@ -148,7 +151,10 @@ def test_marginal_likelihood(self):
a = pm.Beta("a", alpha, beta)
y = pm.Bernoulli("y", a, observed=data)
trace = pm.sample_smc(
2000, chains=2, return_inferencedata=False, progressbar=not _IS_WINDOWS
2000,
chains=2,
return_inferencedata=False,
progressbar=not _IS_WINDOWS,
)
# log_marginal_likelihood is found in the last value of each chain
lml = np.mean([chain[-1] for chain in trace.report.log_marginal_likelihood])
Expand Down Expand Up @@ -211,7 +217,9 @@ def test_return_datatype(self, chains):
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
warnings.filterwarnings("ignore", "More chains .* than draws .*", UserWarning)
idata = pm.sample_smc(
chains=chains, draws=draws, progressbar=not (chains > 1 and _IS_WINDOWS)
chains=chains,
draws=draws,
progressbar=not (chains > 1 and _IS_WINDOWS),
)
mt = pm.sample_smc(
chains=chains,
Expand Down Expand Up @@ -269,6 +277,30 @@ def test_deprecated_abc_args(self):
):
pm.sample_smc(draws=10, chains=1, save_log_pseudolikelihood=True)

def test_ordered(self):
"""
Test that initial population respects custom initval, especially when applied
to the Ordered transformation. Regression test for #7438.
"""
with pm.Model() as m:
pm.Normal(
"a",
mu=0.0,
sigma=1.0,
size=(2,),
transform=Ordered(),
initval=[-1.0, 1.0],
)

smc = IMH(model=m)
out = smc.initialize_population()

# initial point should not include NaNs
assert not np.any(np.isnan(out["a_ordered__"]))

# initial point should match for all particles
assert np.all(out["a_ordered__"][0] == out["a_ordered__"])


class TestMHKernel:
def test_normal_model(self):
Expand Down

0 comments on commit defe39d

Please sign in to comment.