dispatch _promote_batch_shape_expanded to Independent #1630

Merged
merged 11 commits into from
Aug 28, 2023
4 changes: 4 additions & 0 deletions numpyro/distributions/batch_util.py
@@ -51,6 +51,7 @@
from numpyro.distributions.distribution import (
    Distribution,
    ExpandedDistribution,
    Independent,
    MaskedDistribution,
    Unit,
)
@@ -563,6 +564,9 @@ def _promote_batch_shape_expanded(d: ExpandedDistribution):
    return new_self


promote_batch_shape.register(Independent, _promote_batch_shape_expanded)

Contributor

Mmh, I don't see any reason why the logic for batch shape promotion of Independent and ExpandedDistribution should be the same. Promoting an ExpandedDistribution requires some care to keep it broadcastable with its base_dist when batch dimensions are inserted on the left. I don't think Independent requires this kind of care.

Could you instead create _promote_batch_shape_independent, which simply promotes an Independent's base distribution?
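To make the suggestion above concrete, here is a minimal, standard-library-only sketch of the idea: a single-dispatch promotion function whose handler for an Independent-like wrapper promotes the wrapped base and then re-derives its own shapes. `ToyDist`, `ToyIndependent`, `promote`, and the explicit `new_dim` argument are hypothetical stand-ins invented for illustration; they are not NumPyro's classes or the code that was eventually committed in this PR.

```python
import copy
from functools import singledispatch


class ToyDist:
    """Hypothetical stand-in for a distribution; it only tracks shapes."""

    def __init__(self, batch_shape, event_shape=()):
        self.batch_shape = tuple(batch_shape)
        self.event_shape = tuple(event_shape)


class ToyIndependent(ToyDist):
    """Hypothetical stand-in for Independent: the trailing
    `reinterpreted_batch_ndims` batch dims of the base become event dims."""

    def __init__(self, base_dist, reinterpreted_batch_ndims):
        self.base_dist = base_dist
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        self._sync_shapes()

    def _sync_shapes(self):
        # Re-derive this wrapper's shapes from the (possibly promoted) base.
        shape = self.base_dist.batch_shape
        split = len(shape) - self.reinterpreted_batch_ndims
        self.batch_shape = shape[:split]
        self.event_shape = shape[split:] + self.base_dist.event_shape


@singledispatch
def promote(d, new_dim):
    """Default handler: prepend one batch dimension of size `new_dim`."""
    new_self = copy.copy(d)
    new_self.batch_shape = (new_dim,) + d.batch_shape
    return new_self


@promote.register
def _promote_independent(d: ToyIndependent, new_dim):
    # Promote the base, then recompute the wrapper's shapes from it.
    new_self = copy.copy(d)
    new_self.base_dist = promote(d.base_dist, new_dim)
    new_self._sync_shapes()
    return new_self


base = ToyDist(batch_shape=(10, 3))
independent = ToyIndependent(base, reinterpreted_batch_ndims=1)
promoted = promote(independent, new_dim=5)
print(promoted.batch_shape, promoted.event_shape, promoted.base_dist.batch_shape)
```

In this toy, the wrapper never edits its own batch shape directly; it delegates to the base and reads the result back, which is why no broadcasting bookkeeping between wrapper and base is needed.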

Member

I guess we can use the code for MaskedDistribution?

Contributor

Mmmh, unclear: MaskedDistribution enforces dist.batch_shape == dist.base_dist.batch_shape, while an Independent distribution will have a different batch shape than its base.

Contributor Author

I've committed a _promote_batch_shape_independent now. It passes the test and outputs the correct batch_shape, but beyond that I'm not sure whether the logic is correct.



@promote_batch_shape.register
def _promote_batch_shape_masked(d: MaskedDistribution):
    new_self = copy.copy(d)
30 changes: 29 additions & 1 deletion test/contrib/test_control_flow.py
@@ -10,7 +10,7 @@
import numpyro
from numpyro.contrib.control_flow import cond, scan
import numpyro.distributions as dist
-from numpyro.handlers import seed, substitute, trace
+from numpyro.handlers import mask, seed, substitute, trace
from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO
from numpyro.infer.util import potential_energy

@@ -210,3 +210,31 @@ def transition_fn(c, val):
    tr = numpyro.handlers.trace(model).get_trace()
    assert tr["x"]["value"].shape == (10, 1)
    assert tr["x"]["fn"].log_prob(tr["x"]["value"]).shape == (10, 3)


def test_scan_plate_mask():
    def model(y=None, T=10):
        def transition(carry, y_curr):
            x_prev, t = carry
            with numpyro.plate("N", 10, dim=-1):
                with mask(mask=(t < T)):
                    x_curr = numpyro.sample(
                        "x",
                        dist.Normal(jnp.zeros((10, 3)), jnp.ones((10, 3))).to_event(1),
                    )
                    y_curr = numpyro.sample(
                        "y",
                        dist.Normal(x_curr, jnp.ones((10, 3))).to_event(1),
                        obs=y_curr,
                    )
            return (x_curr, t + 1), None

        x0 = numpyro.sample(
            "x_0", dist.Normal(jnp.zeros((10, 3)), jnp.ones((10, 3))).to_event(1)
        )

        x, t = scan(transition, (x0, 0), y, length=T)
        return (x, y)

    with numpyro.handlers.seed(rng_seed=0):
        x, y = model()
Contributor

Could we be a bit more stringent and test that some NumPyro functionality, such as computing log probs, works on the model, as done in other scan tests? That way we will spot more potential bugs in our implementation before merging :-)

Contributor Author

I've added a log_density call to the test.
