diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py
index 927c2b096..40c31fa51 100644
--- a/numpyro/distributions/transforms.py
+++ b/numpyro/distributions/transforms.py
@@ -9,7 +9,6 @@
 
 import jax
 from jax import lax, vmap
-from jax.flatten_util import ravel_pytree
 from jax.nn import log_sigmoid, softplus
 import jax.numpy as jnp
 from jax.scipy.linalg import solve_triangular
@@ -1102,13 +1101,15 @@ class UnpackTransform(Transform):
     Transforms a contiguous array to a pytree of subarrays.
 
     :param unpack_fn: callable used to unpack a contiguous array.
+    :param pack_fn: callable used to pack a pytree into a contiguous array.
     """
 
     domain = constraints.real_vector
     codomain = constraints.dependent
 
-    def __init__(self, unpack_fn):
+    def __init__(self, unpack_fn, pack_fn=None):
         self.unpack_fn = unpack_fn
+        self.pack_fn = pack_fn
 
     def __call__(self, x):
         batch_shape = x.shape[:-1]
@@ -1121,9 +1122,15 @@ def __call__(self, x):
         return self.unpack_fn(x)
 
     def _inverse(self, y):
+        if self.pack_fn is None:
+            raise NotImplementedError(
+                "pack_fn needs to be provided to perform UnpackTransform.inv."
+            )
         leading_dims = [
             v.shape[0] if jnp.ndim(v) > 0 else 0 for v in jax.tree.flatten(y)[0]
         ]
+        if not leading_dims:
+            return jnp.array([])
         d0 = leading_dims[0]
         not_scalar = d0 > 0 or len(leading_dims) > 1
         if not_scalar and all(d == d0 for d in leading_dims[1:]):
@@ -1132,7 +1139,7 @@ def _inverse(self, y):
                 " cannot transform a batch of unpacked arrays.",
                 stacklevel=find_stack_level(),
             )
-        return ravel_pytree(y)[0]
+        return self.pack_fn(y)
 
     def log_abs_det_jacobian(self, x, y, intermediates=None):
         return jnp.zeros(jnp.shape(x)[:-1])
@@ -1145,10 +1152,14 @@ def inverse_shape(self, shape):
 
     def tree_flatten(self):
         # XXX: what if unpack_fn is a parametrized callable pytree?
-        return (), ((), {"unpack_fn": self.unpack_fn})
+        return (), ((), {"unpack_fn": self.unpack_fn, "pack_fn": self.pack_fn})
 
     def __eq__(self, other):
-        return isinstance(other, UnpackTransform) and self.unpack_fn is other.unpack_fn
+        return (
+            isinstance(other, UnpackTransform)
+            and (self.unpack_fn is other.unpack_fn)
+            and (self.pack_fn is other.pack_fn)
+        )
 
 
 def _get_target_shape(shape, forward_shape, inverse_shape):
diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py
index 624865a75..72b2df3bf 100644
--- a/numpyro/infer/autoguide.py
+++ b/numpyro/infer/autoguide.py
@@ -624,13 +624,20 @@ def _unravel_dict(x_flat, shape_dict):
 def _ravel_dict(x):
     """Return the flatten version of `x` and shapes of each item in `x`."""
     assert isinstance(x, dict)
-    shape_dict = {}
+    shape_dict = {name: jnp.shape(value) for name, value in x.items()}
+    x_flat = _ravel_dict_with_shape_dict(x, shape_dict)
+    return x_flat, shape_dict
+
+
+def _ravel_dict_with_shape_dict(x, shape_dict):
+    assert set(x.keys()) == set(shape_dict.keys())
     x_flat = []
-    for name, value in x.items():
-        shape_dict[name] = jnp.shape(value)
+    for name, shape in shape_dict.items():
+        value = x[name]
+        assert shape == jnp.shape(value)
         x_flat.append(jnp.reshape(value, -1))
     x_flat = jnp.concatenate(x_flat) if x_flat else jnp.zeros((0,))
-    return x_flat, shape_dict
+    return x_flat
 
 
 class AutoContinuous(AutoGuide):
@@ -661,7 +668,9 @@ def _setup_prototype(self, *args, **kwargs):
         unpack_latent = partial(_unravel_dict, shape_dict=shape_dict)
         # this is to match the behavior of Pyro, where we can apply
         # unpack_latent for a batch of samples
-        self._unpack_latent = UnpackTransform(unpack_latent)
+        self._unpack_latent = UnpackTransform(
+            unpack_latent, _ravel_dict_with_shape_dict
+        )
         self.latent_dim = jnp.size(self._init_latent)
         if self.latent_dim == 0:
             raise RuntimeError(
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 7453e3f51..e10fd7248 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -2703,7 +2703,8 @@ def test_compose_transform_with_intermediates(ts):
 def test_unpack_transform(x_dim, y_dim):
     xy = np.random.randn(x_dim + y_dim)
     unpack_fn = lambda xy: {"x": xy[:x_dim], "y": xy[x_dim:]}  # noqa: E731
-    transform = transforms.UnpackTransform(unpack_fn)
+    pack_fn = lambda d: jnp.concatenate([d["x"], d["y"]], axis=-1)  # noqa: E731
+    transform = transforms.UnpackTransform(unpack_fn, pack_fn)
     z = transform(xy)
     if x_dim == y_dim:
         with pytest.warns(UserWarning, match="UnpackTransform.inv"):