Skip to content

Commit

Permalink
Avoid unnecessary reshape for trivial expand (#1776)
Browse files Browse the repository at this point in the history
* avoid unnecessary reshape for trivial expand

* do not perform unnecessary broadcasting for Delta

* fix failing tests related to Delta change

* revert change at Delta in favor of a separate PR
  • Loading branch information
fehiepsi authored Apr 12, 2024
1 parent d7159b8 commit b2cee89
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def _broadcast_shape(existing_shape, new_shape):
)
return (
tuple(reversed(reversed_shape)),
OrderedDict(expanded_sizes),
OrderedDict(reversed(expanded_sizes)),
OrderedDict(interstitial_sizes),
)

Expand All @@ -601,6 +601,8 @@ def _sample(self, sample_fn, key, sample_shape=()):
batch_shape = expanded_sizes + interstitial_sizes
# shape = sample_shape + expanded_sizes + interstitial_sizes + base_dist.shape()
samples, intermediates = sample_fn(key, sample_shape=sample_shape + batch_shape)
if not interstitial_sizes:
return samples, intermediates

interstitial_dims = tuple(self._interstitial_sizes.keys())
event_dim = len(self.event_shape)
Expand Down

0 comments on commit b2cee89

Please sign in to comment.