Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add init_params argument to svi.init() and svi.run() #1561

Merged
merged 3 commits into from
Mar 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions numpyro/infer/svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,26 +169,32 @@ def __init__(self, model, guide, optim, loss, **static_kwargs):

self.optim = optax_to_numpyro(optim)

def init(self, rng_key, *args, **kwargs):
def init(self, rng_key, *args, init_params=None, **kwargs):
"""
Gets the initial SVI state.
:param jax.random.PRNGKey rng_key: random number generator seed.
:param args: arguments to the model / guide (these can possibly vary during
the course of fitting).
:param dict init_params: if not None, initialize :class:`numpyro.param` sites with values from
this dictionary instead of using ``init_value`` in :class:`numpyro.param` primitives.
:param kwargs: keyword arguments to the model / guide (these can possibly vary
during the course of fitting).
:return: the initial :data:`SVIState`
"""
rng_key, model_seed, guide_seed = random.split(rng_key, 3)
model_init = seed(self.model, model_seed)
guide_init = seed(self.guide, guide_seed)
if init_params is not None:
guide_init = substitute(guide_init, init_params)
guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
init_guide_params = {
name: site["value"]
for name, site in guide_trace.items()
if site["type"] == "param"
}
if init_params is not None:
init_guide_params.update(init_params)
model_trace = trace(
substitute(replay(model_init, guide_trace), init_guide_params)
).get_trace(*args, **kwargs, **self.static_kwargs)
Expand Down Expand Up @@ -305,6 +311,7 @@ def run(
progress_bar=True,
stable_update=False,
init_state=None,
init_params=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -333,6 +340,8 @@ def run(
# continue from the end of the previous svi run rather than beginning again from iteration 0
svi_result = svi.run(random.PRNGKey(1), 2000, data, init_state=svi_result.state)
:param dict init_params: if not None, initialize :class:`numpyro.param` sites with values from
this dictionary instead of using ``init_value`` in :class:`numpyro.param` primitives.
:param kwargs: keyword arguments to the model / guide
:return: a namedtuple with fields `params` and `losses` where `params`
holds the optimized values at :class:`numpyro.param` sites,
Expand All @@ -351,7 +360,7 @@ def body_fn(svi_state, _):
return svi_state, loss

if init_state is None:
svi_state = self.init(rng_key, *args, **kwargs)
svi_state = self.init(rng_key, *args, init_params=init_params, **kwargs)
else:
svi_state = init_state
if progress_bar:
Expand Down
24 changes: 24 additions & 0 deletions test/infer/test_svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
from numpyro.util import fori_loop


def assert_equal(a, b, prec=0):
return jax.tree_util.tree_map(lambda a, b: assert_allclose(a, b, atol=prec), a, b)


@pytest.mark.parametrize("alpha", [0.0, 2.0])
def test_renyi_elbo(alpha):
def model(x):
Expand Down Expand Up @@ -224,6 +228,26 @@ def guide():
assert_allclose(svi_result.params["shared"], target_value, atol=0.1)


def test_init_params():
init_params = {"b": 1.0, "c": 2.0}

def model():
numpyro.param("a", 0.0)
# should receive initial value from init_params
numpyro.param("b")

def guide():
# should receive initial value from init_params
numpyro.param("c")

svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0), init_params=init_params)
params = svi.get_params(svi_state)
init_params["a"] = 0.0
# make sure init params ended up in the SVI state
assert_equal(params, init_params)


def test_elbo_dynamic_support():
x_prior = dist.TransformedDistribution(
dist.Normal(),
Expand Down