From b2cee89a952cc290af7e00b003ef5556f9e8a350 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Thu, 11 Apr 2024 21:13:32 -0400 Subject: [PATCH] Avoid unnecessary reshape for trivial expand (#1776) * avoid unnecessary reshape for trivial expand * do not perform unnecessary broadcasting for Delta * fix failing tests related to Delta change * revert change at Delta in favor of a separate PR --- numpyro/distributions/distribution.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index ade4e9910..0f7b8136a 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -583,7 +583,7 @@ def _broadcast_shape(existing_shape, new_shape): ) return ( tuple(reversed(reversed_shape)), - OrderedDict(expanded_sizes), + OrderedDict(reversed(expanded_sizes)), OrderedDict(interstitial_sizes), ) @@ -601,6 +601,8 @@ def _sample(self, sample_fn, key, sample_shape=()): batch_shape = expanded_sizes + interstitial_sizes # shape = sample_shape + expanded_sizes + interstitial_sizes + base_dist.shape() samples, intermediates = sample_fn(key, sample_shape=sample_shape + batch_shape) + if not interstitial_sizes: + return samples, intermediates interstitial_dims = tuple(self._interstitial_sizes.keys()) event_dim = len(self.event_shape)