-
Notifications
You must be signed in to change notification settings - Fork 236
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
dispatch _promote_batch_shape_expanded to Independent #1630
Changes from 8 commits
7118101
ed6a716
ccb3108
4735a69
273757c
44b4b21
3060f66
01da9f7
d283108
2ececc8
58232df
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,7 @@ | |
import numpyro | ||
from numpyro.contrib.control_flow import cond, scan | ||
import numpyro.distributions as dist | ||
from numpyro.handlers import seed, substitute, trace | ||
from numpyro.handlers import mask, seed, substitute, trace | ||
from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO | ||
from numpyro.infer.util import potential_energy | ||
|
||
|
@@ -210,3 +210,31 @@ def transition_fn(c, val): | |
tr = numpyro.handlers.trace(model).get_trace() | ||
assert tr["x"]["value"].shape == (10, 1) | ||
assert tr["x"]["fn"].log_prob(tr["x"]["value"]).shape == (10, 3) | ||
|
||
|
||
def test_scan_plate_mask(): | ||
def model(y=None, T=10): | ||
def transition(carry, y_curr): | ||
x_prev, t = carry | ||
with numpyro.plate("N", 10, dim=-1): | ||
with mask(mask=(t < T)): | ||
x_curr = numpyro.sample( | ||
"x", | ||
dist.Normal(jnp.zeros((10, 3)), jnp.ones((10, 3))).to_event(1), | ||
) | ||
y_curr = numpyro.sample( | ||
"y", | ||
dist.Normal(x_curr, jnp.ones((10, 3))).to_event(1), | ||
obs=y_curr, | ||
) | ||
return (x_curr, t + 1), None | ||
|
||
x0 = numpyro.sample( | ||
"x_0", dist.Normal(jnp.zeros((10, 3)), jnp.ones((10, 3))).to_event(1) | ||
) | ||
|
||
x, t = scan(transition, (x0, 0), y, length=T) | ||
return (x, y) | ||
|
||
with numpyro.handlers.seed(rng_seed=0): | ||
x, y = model() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we be a bit more stringent, and test that some There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've added a |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mmh, I don't see any reason why the logic for batch shape promotion of
Independent
andExpandedDistribution
should be the same. PromotingExpandedDistribution
requires some care in maintaining broadcastability between itself and itsbase_dist
when left-most inserting batch dimensions - I don't thinkIndependent
would require this kind of care.Could you instead create
_promote_batch_shape_independent
which simply promotes an independent's base distribution?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess we can use the code for MaskedDistribution?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mmmh, unclear -
MaskedDistribution
enforcesdist.batch_shape == dist.base_dist.batch_shape
while anIndependent
distribution will have a different batch shape that its base.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've committed an
_promote_batch_shape_independent
now, it passes the test, and outputs the correct batch_shape, but beyond that I'm not sure whether the logic is correct.