Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add inference utilities to transform between unconstrained and constrained space #1564

Merged
merged 1 commit into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ def substitute_fn(site):
if site["type"] == "sample":
with helpful_support_errors(site):
return biject_to(site["fn"].support)(params[site["name"]])
elif site["type"] == "param":
constraint = site["kwargs"].pop("constraint", constraints.real)
with helpful_support_errors(site):
return biject_to(constraint)(params[site["name"]])
else:
return params[site["name"]]

Expand All @@ -193,6 +197,42 @@ def substitute_fn(site):
}


def get_transforms(model, model_args, model_kwargs, params):
"""
(EXPERIMENTAL INTERFACE) Retrieve (inverse) transforms via biject_to()
given a NumPyro model. This function supports 'param' sites.
NB: Parameter values are only used to retrieve the model trace.

:param model: a callable containing NumPyro primitives.
:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:param dict params: dictionary of values keyed by site names.
:return: `dict` of transformation keyed by site names.
"""
substituted_model = substitute(model, data=params)
transforms, _, _, _ = _get_model_transforms(
substituted_model, model_args, model_kwargs
)
return transforms

Copy link
Member

@fehiepsi fehiepsi Mar 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you run make format to fix lint issue? I guess you need to add a new line here


def unconstrain_fn(model, model_args, model_kwargs, params):
"""
(EXPERIMENTAL INTERFACE) Given a NumPyro model and a dict of parameters,
this function applies the right transformation to convert parameter values
from constrained space to unconstrained space.

:param model: a callable containing NumPyro primitives.
:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:param dict params: dictionary of constrained values keyed by site
names.
:return: `dict` of transformation keyed by site names.
"""
transforms = get_transforms(model, model_args, model_kwargs, params)
return transform_fn(transforms, params, invert=True)


def _unconstrain_reparam(params, site):
name = site["name"]
if name in params:
Expand Down Expand Up @@ -449,6 +489,10 @@ def _get_model_transforms(model, model_args=(), model_kwargs=None):
for arg in args:
if not isinstance(getattr(support, arg), (int, float)):
replay_model = True
elif v["type"] == "param":
constraint = v["kwargs"].pop("constraint", constraints.real)
with helpful_support_errors(v, raise_warnings=True):
inv_transforms[k] = biject_to(constraint)
elif v["type"] == "deterministic":
replay_model = True
return inv_transforms, replay_model, has_enumerate_support, model_trace
Expand Down
47 changes: 47 additions & 0 deletions test/infer/test_infer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
log_likelihood,
potential_energy,
transform_fn,
unconstrain_fn,
)
import numpyro.optim as optim

Expand Down Expand Up @@ -220,6 +221,52 @@ def model():
assert_allclose(actual_potential_energy, expected_potential_energy)


def test_constrain_unconstrain():
x_prior = dist.HalfNormal(2)
y_prior = dist.LogNormal(scale=3.0) # transformed distribution
z_constraint = constraints.positive

def model():
numpyro.sample("x", x_prior)
numpyro.sample("y", y_prior)
numpyro.param("z", init_value=2.0, constraint=z_constraint)

params = {"x": jnp.array(-5.0), "y": jnp.array(7.0), "z": jnp.array(3.0)}
model = handlers.seed(model, random.PRNGKey(0))
inv_transforms = {
"x": biject_to(x_prior.support),
"y": biject_to(y_prior.support),
"z": biject_to(z_constraint),
}
expected_constrained_samples = partial(transform_fn, inv_transforms)(params)
transforms = {
"x": biject_to(x_prior.support).inv,
"y": biject_to(y_prior.support).inv,
"z": biject_to(z_constraint).inv,
}
expected_unconstrained_samples = partial(transform_fn, transforms)(
expected_constrained_samples
)

actual_constrained_samples = constrain_fn(model, (), {}, params)
actual_unconstrained_samples = unconstrain_fn(
model, (), {}, actual_constrained_samples
)

assert_allclose(expected_constrained_samples["x"], actual_constrained_samples["x"])
assert_allclose(expected_constrained_samples["y"], actual_constrained_samples["y"])
assert_allclose(expected_constrained_samples["z"], actual_constrained_samples["z"])
assert_allclose(
expected_unconstrained_samples["x"], actual_unconstrained_samples["x"]
)
assert_allclose(
expected_unconstrained_samples["y"], actual_unconstrained_samples["y"]
)
assert_allclose(
expected_unconstrained_samples["z"], actual_unconstrained_samples["z"]
)


def test_model_with_mask_false():
def model():
x = numpyro.sample("x", dist.Normal())
Expand Down