Skip to content

Commit

Permalink
correct event_dim use (#1661)
Browse files Browse the repository at this point in the history
* correct event_dim use

* test
  • Loading branch information
deoxyribose authored Oct 11, 2023
1 parent dca5b2b commit 065aa43
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion numpyro/distributions/batch_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ def _promote_batch_shape_masked(d: MaskedDistribution):
def _promote_batch_shape_independent(d: Independent):
new_self = copy.copy(d)
new_base_dist = promote_batch_shape(d.base_dist)
new_self._batch_shape = new_base_dist.batch_shape[: d.event_dim]
new_self._batch_shape = new_base_dist.batch_shape[: -d.event_dim]
new_self.base_dist = new_base_dist
return new_self

Expand Down
6 changes: 3 additions & 3 deletions test/contrib/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def transition_fn(c, val):


def test_scan_plate_mask():
def model(y=None, T=10):
def model(y=None, T=12):
def transition(carry, y_curr):
x_prev, t = carry
with numpyro.plate("N", 10, dim=-1):
Expand All @@ -237,7 +237,7 @@ def transition(carry, y_curr):
return (x, y)

with numpyro.handlers.seed(rng_seed=0):
model_density, model_trace = log_density(model, (None, 10), {}, {})
model_density, model_trace = log_density(model, (None, 12), {}, {})
assert model_density
assert model_trace["x"]["fn"].batch_shape == (10,)
assert model_trace["x"]["fn"].batch_shape == (12, 10)
assert model_trace["x"]["fn"].event_shape == (3,)

0 comments on commit 065aa43

Please sign in to comment.