diff --git a/numpyro/infer/svi.py b/numpyro/infer/svi.py
index eca928c5c..e20bfabdb 100644
--- a/numpyro/infer/svi.py
+++ b/numpyro/infer/svi.py
@@ -169,13 +169,15 @@ def __init__(self, model, guide, optim, loss, **static_kwargs):
 
             self.optim = optax_to_numpyro(optim)
 
-    def init(self, rng_key, *args, **kwargs):
+    def init(self, rng_key, *args, init_params=None, **kwargs):
         """
         Gets the initial SVI state.
 
         :param jax.random.PRNGKey rng_key: random number generator seed.
         :param args: arguments to the model / guide (these can possibly vary during
             the course of fitting).
+        :param dict init_params: if not None, initialize :class:`numpyro.param` sites with values from
+            this dictionary instead of using ``init_value`` in :class:`numpyro.param` primitives.
         :param kwargs: keyword arguments to the model / guide (these can possibly vary
             during the course of fitting).
         :return: the initial :data:`SVIState`
@@ -183,12 +185,16 @@ def init(self, rng_key, *args, **kwargs):
         rng_key, model_seed, guide_seed = random.split(rng_key, 3)
         model_init = seed(self.model, model_seed)
         guide_init = seed(self.guide, guide_seed)
+        if init_params is not None:
+            guide_init = substitute(guide_init, init_params)
         guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
         init_guide_params = {
             name: site["value"]
             for name, site in guide_trace.items()
             if site["type"] == "param"
         }
+        if init_params is not None:
+            init_guide_params.update(init_params)
         model_trace = trace(
             substitute(replay(model_init, guide_trace), init_guide_params)
         ).get_trace(*args, **kwargs, **self.static_kwargs)
@@ -305,6 +311,7 @@ def run(
         progress_bar=True,
         stable_update=False,
         init_state=None,
+        init_params=None,
         **kwargs,
     ):
         """
@@ -333,6 +340,8 @@ def run(
                 # continue from the end of the previous svi run rather than beginning again from iteration 0
                 svi_result = svi.run(random.PRNGKey(1), 2000, data, init_state=svi_result.state)
 
+        :param dict init_params: if not None, initialize :class:`numpyro.param` sites with values from
+            this dictionary instead of using ``init_value`` in :class:`numpyro.param` primitives.
         :param kwargs: keyword arguments to the model / guide
         :return: a namedtuple with fields `params` and `losses` where `params`
             holds the optimized values at :class:`numpyro.param` sites,
@@ -351,7 +360,7 @@ def body_fn(svi_state, _):
             return svi_state, loss
 
         if init_state is None:
-            svi_state = self.init(rng_key, *args, **kwargs)
+            svi_state = self.init(rng_key, *args, init_params=init_params, **kwargs)
         else:
             svi_state = init_state
         if progress_bar:
diff --git a/test/infer/test_svi.py b/test/infer/test_svi.py
index 2db18a3ca..929f923b3 100644
--- a/test/infer/test_svi.py
+++ b/test/infer/test_svi.py
@@ -36,6 +36,10 @@
 from numpyro.util import fori_loop
 
 
+def assert_equal(a, b, prec=0):
+    return jax.tree_util.tree_map(lambda a, b: assert_allclose(a, b, atol=prec), a, b)
+
+
 @pytest.mark.parametrize("alpha", [0.0, 2.0])
 def test_renyi_elbo(alpha):
     def model(x):
@@ -224,6 +228,26 @@ def guide():
     assert_allclose(svi_result.params["shared"], target_value, atol=0.1)
 
 
+def test_init_params():
+    init_params = {"b": 1.0, "c": 2.0}
+
+    def model():
+        numpyro.param("a", 0.0)
+        # should receive initial value from init_params
+        numpyro.param("b")
+
+    def guide():
+        # should receive initial value from init_params
+        numpyro.param("c")
+
+    svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO())
+    svi_state = svi.init(random.PRNGKey(0), init_params=init_params)
+    params = svi.get_params(svi_state)
+    init_params["a"] = 0.0
+    # make sure init params ended up in the SVI state
+    assert_equal(params, init_params)
+
+
 def test_elbo_dynamic_support():
     x_prior = dist.TransformedDistribution(
         dist.Normal(),