Support UnpackTransform.inv via pack_fn #1824

Merged
merged 2 commits on Jul 1, 2024
21 changes: 16 additions & 5 deletions numpyro/distributions/transforms.py
@@ -9,7 +9,6 @@

 import jax
 from jax import lax, vmap
-from jax.flatten_util import ravel_pytree
 from jax.nn import log_sigmoid, softplus
 import jax.numpy as jnp
 from jax.scipy.linalg import solve_triangular
@@ -1102,13 +1101,15 @@ class UnpackTransform(Transform):
     Transforms a contiguous array to a pytree of subarrays.

     :param unpack_fn: callable used to unpack a contiguous array.
+    :param pack_fn: callable used to pack a pytree into a contiguous array.
     """

     domain = constraints.real_vector
     codomain = constraints.dependent

-    def __init__(self, unpack_fn):
+    def __init__(self, unpack_fn, pack_fn=None):
         self.unpack_fn = unpack_fn
+        self.pack_fn = pack_fn

     def __call__(self, x):
         batch_shape = x.shape[:-1]
@@ -1121,9 +1122,15 @@ def __call__(self, x):
         return self.unpack_fn(x)

     def _inverse(self, y):
+        if self.pack_fn is None:
+            raise NotImplementedError(
+                "pack_fn needs to be provided to perform UnpackTransform.inv."
+            )
         leading_dims = [
             v.shape[0] if jnp.ndim(v) > 0 else 0 for v in jax.tree.flatten(y)[0]
         ]
+        if not leading_dims:
+            return jnp.array([])
         d0 = leading_dims[0]
         not_scalar = d0 > 0 or len(leading_dims) > 1
         if not_scalar and all(d == d0 for d in leading_dims[1:]):
@@ -1132,7 +1139,7 @@ def _inverse(self, y):
                 " cannot transform a batch of unpacked arrays.",
                 stacklevel=find_stack_level(),
             )
-        return ravel_pytree(y)[0]
+        return self.pack_fn(y)

     def log_abs_det_jacobian(self, x, y, intermediates=None):
         return jnp.zeros(jnp.shape(x)[:-1])
@@ -1145,10 +1152,14 @@ def inverse_shape(self, shape):

     def tree_flatten(self):
         # XXX: what if unpack_fn is a parametrized callable pytree?
-        return (), ((), {"unpack_fn": self.unpack_fn})
+        return (), ((), {"unpack_fn": self.unpack_fn, "pack_fn": self.pack_fn})

     def __eq__(self, other):
-        return isinstance(other, UnpackTransform) and self.unpack_fn is other.unpack_fn
+        return (
+            isinstance(other, UnpackTransform)
+            and (self.unpack_fn is other.unpack_fn)
+            and (self.pack_fn is other.pack_fn)
+        )


 def _get_target_shape(shape, forward_shape, inverse_shape):
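For context, the extended transform can be exercised as below. This is a minimal sketch; the unpack/pack callables and the length-5 layout are illustrative, not part of the diff.

import jax.numpy as jnp

from numpyro.distributions.transforms import UnpackTransform

# Illustrative layout: split a length-5 vector into two named blocks.
unpack_fn = lambda z: {"x": z[:2], "y": z[2:]}
pack_fn = lambda tree: jnp.concatenate([tree["x"], tree["y"]])

t = UnpackTransform(unpack_fn, pack_fn)
z = jnp.arange(5.0)
tree = t(z)           # {"x": [0., 1.], "y": [2., 3., 4.]}
z_back = t.inv(tree)  # inverse now routes through pack_fn
assert jnp.array_equal(z_back, z)  # no warning: leading dims 2 and 3 differ

Without a pack_fn the inverse now raises NotImplementedError instead of falling back to ravel_pytree, whose flattening order need not match the layout produced by unpack_fn.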
19 changes: 14 additions & 5 deletions numpyro/infer/autoguide.py
@@ -624,13 +624,20 @@ def _unravel_dict(x_flat, shape_dict):
 def _ravel_dict(x):
     """Return the flatten version of `x` and shapes of each item in `x`."""
     assert isinstance(x, dict)
-    shape_dict = {}
+    shape_dict = {name: jnp.shape(value) for name, value in x.items()}
+    x_flat = _ravel_dict_with_shape_dict(x, shape_dict)
+    return x_flat, shape_dict
+
+
+def _ravel_dict_with_shape_dict(x, shape_dict):
+    assert set(x.keys()) == set(shape_dict.keys())
     x_flat = []
-    for name, value in x.items():
-        shape_dict[name] = jnp.shape(value)
+    for name, shape in shape_dict.items():
+        value = x[name]
+        assert shape == jnp.shape(value)
         x_flat.append(jnp.reshape(value, -1))
     x_flat = jnp.concatenate(x_flat) if x_flat else jnp.zeros((0,))
-    return x_flat, shape_dict
+    return x_flat


 class AutoContinuous(AutoGuide):
@@ -661,7 +668,9 @@ def _setup_prototype(self, *args, **kwargs):
         unpack_latent = partial(_unravel_dict, shape_dict=shape_dict)
         # this is to match the behavior of Pyro, where we can apply
         # unpack_latent for a batch of samples
-        self._unpack_latent = UnpackTransform(unpack_latent)
+        self._unpack_latent = UnpackTransform(
+            unpack_latent, partial(_ravel_dict_with_shape_dict, shape_dict=shape_dict)
+        )
         self.latent_dim = jnp.size(self._init_latent)
         if self.latent_dim == 0:
             raise RuntimeError(
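A quick round trip through the split helpers (sketch; the latent dict is made up, the helper names come from the diff):

import jax.numpy as jnp

from numpyro.infer.autoguide import _ravel_dict, _ravel_dict_with_shape_dict

latents = {"loc": jnp.zeros((2, 3)), "scale": jnp.ones(4)}
# flat.shape == (10,); shape_dict == {"loc": (2, 3), "scale": (4,)}
flat, shape_dict = _ravel_dict(latents)
flat_again = _ravel_dict_with_shape_dict(latents, shape_dict)
assert jnp.array_equal(flat, flat_again)

Binding shape_dict with partial, as in _setup_prototype above, turns _ravel_dict_with_shape_dict into the one-argument pack_fn that UnpackTransform._inverse expects.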
3 changes: 2 additions & 1 deletion test/test_distributions.py
@@ -2703,7 +2703,8 @@ def test_compose_transform_with_intermediates(ts):
 def test_unpack_transform(x_dim, y_dim):
     xy = np.random.randn(x_dim + y_dim)
     unpack_fn = lambda xy: {"x": xy[:x_dim], "y": xy[x_dim:]}  # noqa: E731
-    transform = transforms.UnpackTransform(unpack_fn)
+    pack_fn = lambda d: jnp.concatenate([d["x"], d["y"]], axis=-1)  # noqa: E731
+    transform = transforms.UnpackTransform(unpack_fn, pack_fn)
     z = transform(xy)
     if x_dim == y_dim:
         with pytest.warns(UserWarning, match="UnpackTransform.inv"):
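As a companion sketch to the test above, the new NotImplementedError path when pack_fn is omitted (inputs illustrative; the match string comes from the diff):

import numpy as np
import pytest

from numpyro.distributions import transforms

unpack_fn = lambda xy: {"x": xy[:2], "y": xy[2:]}  # noqa: E731
transform = transforms.UnpackTransform(unpack_fn)  # pack_fn omitted

tree = transform(np.random.randn(5))
with pytest.raises(NotImplementedError, match="pack_fn needs to be provided"):
    transform.inv(tree)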