From 9d0756e069fb65d3773fb0f4d7afae813ee5ea88 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 11 Oct 2023 14:08:36 +0200 Subject: [PATCH 1/2] correct event_dim use --- numpyro/distributions/batch_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/distributions/batch_util.py b/numpyro/distributions/batch_util.py index 48688ac10..0b54fb641 100644 --- a/numpyro/distributions/batch_util.py +++ b/numpyro/distributions/batch_util.py @@ -577,7 +577,7 @@ def _promote_batch_shape_masked(d: MaskedDistribution): def _promote_batch_shape_independent(d: Independent): new_self = copy.copy(d) new_base_dist = promote_batch_shape(d.base_dist) - new_self._batch_shape = new_base_dist.batch_shape[: d.event_dim] + new_self._batch_shape = new_base_dist.batch_shape[: -d.event_dim] new_self.base_dist = new_base_dist return new_self From 96d12392ff0ad998c80637e939e4463a6635071c Mon Sep 17 00:00:00 2001 From: = Date: Wed, 11 Oct 2023 15:01:58 +0200 Subject: [PATCH 2/2] test --- test/contrib/test_control_flow.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/contrib/test_control_flow.py b/test/contrib/test_control_flow.py index 6ccc84605..67273d02c 100644 --- a/test/contrib/test_control_flow.py +++ b/test/contrib/test_control_flow.py @@ -213,7 +213,7 @@ def transition_fn(c, val): def test_scan_plate_mask(): - def model(y=None, T=10): + def model(y=None, T=12): def transition(carry, y_curr): x_prev, t = carry with numpyro.plate("N", 10, dim=-1): @@ -237,7 +237,7 @@ def transition(carry, y_curr): return (x, y) with numpyro.handlers.seed(rng_seed=0): - model_density, model_trace = log_density(model, (None, 10), {}, {}) + model_density, model_trace = log_density(model, (None, 12), {}, {}) assert model_density - assert model_trace["x"]["fn"].batch_shape == (10,) + assert model_trace["x"]["fn"].batch_shape == (12, 10) assert model_trace["x"]["fn"].event_shape == (3,)