diff --git a/numpyro/distributions/batch_util.py b/numpyro/distributions/batch_util.py
index cc13dbf40..48688ac10 100644
--- a/numpyro/distributions/batch_util.py
+++ b/numpyro/distributions/batch_util.py
@@ -51,6 +51,7 @@
 from numpyro.distributions.distribution import (
     Distribution,
     ExpandedDistribution,
+    Independent,
     MaskedDistribution,
     Unit,
 )
@@ -572,6 +573,25 @@ def _promote_batch_shape_masked(d: MaskedDistribution):
     return new_self
 
 
+@promote_batch_shape.register
+def _promote_batch_shape_independent(d: Independent):
+    """Promote the batch shape of an ``Independent`` distribution.
+
+    Returns a shallow copy of ``d`` whose base distribution has been promoted
+    and whose batch shape is recomputed from the promoted base distribution.
+    """
+    new_self = copy.copy(d)
+    new_base_dist = promote_batch_shape(d.base_dist)
+    # Only the trailing ``reinterpreted_batch_ndims`` batch dims of the base
+    # distribution are treated as event dims; every leading dim (including any
+    # added by promotion) stays in the batch shape. A non-negative stop index
+    # also keeps the slice correct when no dims are reinterpreted.
+    batch_ndims = len(new_base_dist.batch_shape) - d.reinterpreted_batch_ndims
+    new_self._batch_shape = new_base_dist.batch_shape[:batch_ndims]
+    new_self.base_dist = new_base_dist
+    return new_self
+
+
 @promote_batch_shape.register
 def _promote_batch_shape_unit(d: Unit):
     return d
diff --git a/test/contrib/test_control_flow.py b/test/contrib/test_control_flow.py
index c4b2e641d..6ccc84605 100644
--- a/test/contrib/test_control_flow.py
+++ b/test/contrib/test_control_flow.py
@@ -10,9 +10,9 @@
 import numpyro
 from numpyro.contrib.control_flow import cond, scan
 import numpyro.distributions as dist
-from numpyro.handlers import seed, substitute, trace
+from numpyro.handlers import mask, seed, substitute, trace
 from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO
-from numpyro.infer.util import potential_energy
+from numpyro.infer.util import log_density, potential_energy
 
 
 def test_scan():
@@ -210,3 +210,37 @@ def transition_fn(c, val):
     tr = numpyro.handlers.trace(model).get_trace()
     assert tr["x"]["value"].shape == (10, 1)
     assert tr["x"]["fn"].log_prob(tr["x"]["value"]).shape == (10, 3)
+
+
+def test_scan_plate_mask():
+    """Check shapes of a masked, plated ``to_event`` site sampled inside scan."""
+    def model(y=None, T=10):
+        def transition(carry, y_curr):
+            x_prev, t = carry
+            with numpyro.plate("N", 10, dim=-1):
+                with mask(mask=(t < T)):
+                    x_curr = numpyro.sample(
+                        "x",
+                        dist.Normal(jnp.zeros((10, 3)), jnp.ones((10, 3))).to_event(1),
+                    )
+                    y_curr = numpyro.sample(
+                        "y",
+                        dist.Normal(x_curr, jnp.ones((10, 3))).to_event(1),
+                        obs=y_curr,
+                    )
+            return (x_curr, t + 1), None
+
+        x0 = numpyro.sample(
+            "x_0", dist.Normal(jnp.zeros((10, 3)), jnp.ones((10, 3))).to_event(1)
+        )
+
+        x, t = scan(transition, (x0, 0), y, length=T)
+        return (x, y)
+
+    with numpyro.handlers.seed(rng_seed=0):
+        model_density, model_trace = log_density(model, (None, 10), {}, {})
+    assert model_density
+    # The scanned site stacks the time dimension (T=10) in front of the plate
+    # dimension (N=10); only the last base dimension (3) is an event dimension.
+    assert model_trace["x"]["fn"].batch_shape == (10, 10)
+    assert model_trace["x"]["fn"].event_shape == (3,)