From 19c003aee75af263e8897e1a6afea5323d6d107b Mon Sep 17 00:00:00 2001 From: Aymeric Galan Date: Mon, 27 Mar 2023 17:25:33 +0200 Subject: [PATCH] Add inference utilities to transform between unconstrained and constrained space Improve and simplify constrain_fn and unconstrain_fn implementation Add missing doctstrings Constrain/unconstrain functions now always consider param sites Fix syntax for lint tests Fix syntax for lint tests Fix syntax for lint tests --- numpyro/infer/util.py | 44 ++++++++++++++++++++++++++++++++ test/infer/test_infer_util.py | 47 +++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+) diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 019c463fc..4a343510b 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -181,6 +181,10 @@ def substitute_fn(site): if site["type"] == "sample": with helpful_support_errors(site): return biject_to(site["fn"].support)(params[site["name"]]) + elif site["type"] == "param": + constraint = site["kwargs"].pop("constraint", constraints.real) + with helpful_support_errors(site): + return biject_to(constraint)(params[site["name"]]) else: return params[site["name"]] @@ -193,6 +197,42 @@ def substitute_fn(site): } +def get_transforms(model, model_args, model_kwargs, params): + """ + (EXPERIMENTAL INTERFACE) Retrieve (inverse) transforms via biject_to() + given a NumPyro model. This function supports 'param' sites. + NB: Parameter values are only used to retrieve the model trace. + + :param model: a callable containing NumPyro primitives. + :param tuple model_args: args provided to the model. + :param dict model_kwargs: kwargs provided to the model. + :param dict params: dictionary of values keyed by site names. + :return: `dict` of transformation keyed by site names. + """ + substituted_model = substitute(model, data=params) + transforms, _, _, _ = _get_model_transforms( + substituted_model, model_args, model_kwargs + ) + return transforms + + +def unconstrain_fn(model, model_args, model_kwargs, params): + """ + (EXPERIMENTAL INTERFACE) Given a NumPyro model and a dict of parameters, + this function applies the right transformation to convert parameter values + from constrained space to unconstrained space. + + :param model: a callable containing NumPyro primitives. + :param tuple model_args: args provided to the model. + :param dict model_kwargs: kwargs provided to the model. + :param dict params: dictionary of constrained values keyed by site + names. + :return: `dict` of transformation keyed by site names. + """ + transforms = get_transforms(model, model_args, model_kwargs, params) + return transform_fn(transforms, params, invert=True) + + def _unconstrain_reparam(params, site): name = site["name"] if name in params: @@ -449,6 +489,10 @@ def _get_model_transforms(model, model_args=(), model_kwargs=None): for arg in args: if not isinstance(getattr(support, arg), (int, float)): replay_model = True + elif v["type"] == "param": + constraint = v["kwargs"].pop("constraint", constraints.real) + with helpful_support_errors(v, raise_warnings=True): + inv_transforms[k] = biject_to(constraint) elif v["type"] == "deterministic": replay_model = True return inv_transforms, replay_model, has_enumerate_support, model_trace diff --git a/test/infer/test_infer_util.py b/test/infer/test_infer_util.py index 0c1d3945f..4412916b8 100644 --- a/test/infer/test_infer_util.py +++ b/test/infer/test_infer_util.py @@ -32,6 +32,7 @@ log_likelihood, potential_energy, transform_fn, + unconstrain_fn, ) import numpyro.optim as optim @@ -220,6 +221,52 @@ def model(): assert_allclose(actual_potential_energy, expected_potential_energy) +def test_constrain_unconstrain(): + x_prior = dist.HalfNormal(2) + y_prior = dist.LogNormal(scale=3.0) # transformed distribution + z_constraint = constraints.positive + + def model(): + numpyro.sample("x", x_prior) + numpyro.sample("y", y_prior) + numpyro.param("z", init_value=2.0, constraint=z_constraint) + + params = {"x": jnp.array(-5.0), "y": jnp.array(7.0), "z": jnp.array(3.0)} + model = handlers.seed(model, random.PRNGKey(0)) + inv_transforms = { + "x": biject_to(x_prior.support), + "y": biject_to(y_prior.support), + "z": biject_to(z_constraint), + } + expected_constrained_samples = partial(transform_fn, inv_transforms)(params) + transforms = { + "x": biject_to(x_prior.support).inv, + "y": biject_to(y_prior.support).inv, + "z": biject_to(z_constraint).inv, + } + expected_unconstrained_samples = partial(transform_fn, transforms)( + expected_constrained_samples + ) + + actual_constrained_samples = constrain_fn(model, (), {}, params) + actual_unconstrained_samples = unconstrain_fn( + model, (), {}, actual_constrained_samples + ) + + assert_allclose(expected_constrained_samples["x"], actual_constrained_samples["x"]) + assert_allclose(expected_constrained_samples["y"], actual_constrained_samples["y"]) + assert_allclose(expected_constrained_samples["z"], actual_constrained_samples["z"]) + assert_allclose( + expected_unconstrained_samples["x"], actual_unconstrained_samples["x"] + ) + assert_allclose( + expected_unconstrained_samples["y"], actual_unconstrained_samples["y"] + ) + assert_allclose( + expected_unconstrained_samples["z"], actual_unconstrained_samples["z"] + ) + + def test_model_with_mask_false(): def model(): x = numpyro.sample("x", dist.Normal())