From 065aa4352dbc9ebeb1132b15b51d8b3149848640 Mon Sep 17 00:00:00 2001 From: Frans Zdyb Date: Wed, 11 Oct 2023 15:53:43 +0200 Subject: [PATCH] correct event_dim use (#1661) * correct event_dim use * test --- numpyro/distributions/batch_util.py | 2 +- test/contrib/test_control_flow.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/numpyro/distributions/batch_util.py b/numpyro/distributions/batch_util.py index 48688ac10..0b54fb641 100644 --- a/numpyro/distributions/batch_util.py +++ b/numpyro/distributions/batch_util.py @@ -577,7 +577,7 @@ def _promote_batch_shape_masked(d: MaskedDistribution): def _promote_batch_shape_independent(d: Independent): new_self = copy.copy(d) new_base_dist = promote_batch_shape(d.base_dist) - new_self._batch_shape = new_base_dist.batch_shape[: d.event_dim] + new_self._batch_shape = new_base_dist.batch_shape[: -d.event_dim] new_self.base_dist = new_base_dist return new_self diff --git a/test/contrib/test_control_flow.py b/test/contrib/test_control_flow.py index 6ccc84605..67273d02c 100644 --- a/test/contrib/test_control_flow.py +++ b/test/contrib/test_control_flow.py @@ -213,7 +213,7 @@ def transition_fn(c, val): def test_scan_plate_mask(): - def model(y=None, T=10): + def model(y=None, T=12): def transition(carry, y_curr): x_prev, t = carry with numpyro.plate("N", 10, dim=-1): @@ -237,7 +237,7 @@ def transition(carry, y_curr): return (x, y) with numpyro.handlers.seed(rng_seed=0): - model_density, model_trace = log_density(model, (None, 10), {}, {}) + model_density, model_trace = log_density(model, (None, 12), {}, {}) assert model_density - assert model_trace["x"]["fn"].batch_shape == (10,) + assert model_trace["x"]["fn"].batch_shape == (12, 10) assert model_trace["x"]["fn"].event_shape == (3,)