Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce custom initval in SMC #7439

Merged
merged 3 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions pymc/smc/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ class SMC_KERNEL(ABC):
initialize_population
Choose initial population of SMC particles. Should return a dictionary
with {var.name : numpy array of size (draws, var.size)}. Defaults
to sampling from the prior distribution. This method is only called
if `start` is not specified.
to sampling from the prior distribution, except for parameters which have custom
`initval`, in which case that value is used for all SMC particles.
This method is only called if `start` is not specified.

_initialize_kernel : default
Creates initial population of particles in the variable
Expand Down Expand Up @@ -144,7 +145,8 @@ def __init__(
independent chains. Defaults to 2000.
start : dict, or array of dict, default None
Starting point in parameter space. It should be a list of dict with length `chains`.
When None (default) the starting point is sampled from the prior distribution.
When None (default) the starting point is sampled from the prior distribution, except
for parameters with a custom `initval`, in which case that value is used.
model : Model (optional if in ``with`` context).
random_seed : int, array_like of int, RandomState or Generator, optional
Value used to initialize the random number generator.
Expand Down Expand Up @@ -193,10 +195,13 @@ def initialize_population(self) -> dict[str, np.ndarray]:
)

model = self.model

prior_expression = make_initial_point_expression(
free_rvs=model.free_RVs,
rvs_to_transforms=model.rvs_to_transforms,
initval_strategies={},
initval_strategies={
**model.rvs_to_initial_values,
},
default_strategy="prior",
return_transformed=True,
)
Expand Down
25 changes: 25 additions & 0 deletions tests/smc/test_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import pymc as pm

from pymc.backends.base import MultiTrace
from pymc.distributions.transforms import Ordered
from pymc.pytensorf import floatX
from pymc.smc.kernels import IMH, systematic_resampling
from tests.helpers import assert_random_state_equal
Expand Down Expand Up @@ -269,6 +270,30 @@ def test_deprecated_abc_args(self):
):
pm.sample_smc(draws=10, chains=1, save_log_pseudolikelihood=True)

def test_ordered(self):
"""
Test that initial population respects custom initval, especially when applied
to the Ordered transformation. Regression test for #7438.
"""
with pm.Model() as m:
pm.Normal(
"a",
mu=0.0,
sigma=1.0,
size=(2,),
transform=Ordered(),
initval=[-1.0, 1.0],
)

smc = IMH(model=m)
out = smc.initialize_population()

# initial point should not include NaNs
assert not np.any(np.isnan(out["a_ordered__"]))

# initial point should match for all particles
assert np.all(out["a_ordered__"][0] == out["a_ordered__"])


class TestMHKernel:
def test_normal_model(self):
Expand Down
Loading