-
Notifications
You must be signed in to change notification settings - Fork 236
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
dispatch _promote_batch_shape_expanded to Independent #1630
Conversation
Building documentation fails with
presumably because functools.singledispatch does not support Union types def register_for_union_type(target_fn):
def decorator(fn):
arg_type = list(get_type_hints(fn).values())[0]
assert arg_type.__origin__ is Union
for type_ in arg_type.__args__:
fn = target_fn.register(type_)(fn)
return fn
return decorator and do @register_for_union_type(promote_batch_shape)
def _promote_batch_shape_expanded(d: Union[ExpandedDistribution, Independent]): |
I think you can do
|
numpyro/distributions/batch_util.py
Outdated
@@ -523,7 +524,7 @@ def _default_promote_batch_shape(d: Distribution): | |||
return new_self | |||
|
|||
|
|||
@promote_batch_shape.register | |||
@promote_batch_shape.register(Independent) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will disable behavior for ExpandedDistribution. You might need to have a separate register for Inpedendent (reuse _promote_batch_shape_expanded as in my comment). Anyway, could you add description for why this is needed. The default behavior does not work for Independent?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here's an example that reproduces the error I got:
import numpyro
import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro.handlers import mask
from numpyro.contrib.control_flow import scan
def mask_inside_plate_inside_scan(y=None, T=10):
def transition(carry, y_curr):
x_prev, t = carry
with numpyro.plate("N", 10, dim=-1):
with mask(mask=(t < T)):
x_curr = numpyro.sample('x', dist.Normal(jnp.zeros((10,8)), jnp.ones((10,8))).to_event(1))
y_curr = numpyro.sample('y', dist.Normal(x_curr, jnp.ones((10,8))).to_event(1), obs=y_curr)
return (x_curr, t+1), None
x0 = numpyro.sample('x_0', dist.Normal(jnp.zeros((10,8)), jnp.ones((10,8))).to_event(1))
print
x, t = scan(transition, (x0, 0), y, length=T)
return (x, y)
with numpyro.handlers.seed(rng_seed=0):
x, y = mask_inside_plate_inside_scan()
Traceback (most recent call last):
File "/home/frans/Desktop/test_batch_util_independent.py", line 25, in <module>
x, y = mask_inside_plate_inside_scan()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/frans/Desktop/test_batch_util_independent.py", line 21, in mask_inside_plate_inside_scan
x, t = scan(transition, (x0, 0), y, length=T)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/frans/protein_dmm/numpyro/numpyro/contrib/control_flow/scan.py", line 441, in scan
msg = apply_stack(initial_msg)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/frans/protein_dmm/numpyro/numpyro/primitives.py", line 53, in apply_stack
default_process_message(msg)
File "/home/frans/protein_dmm/numpyro/numpyro/primitives.py", line 28, in default_process_message
msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/frans/protein_dmm/numpyro/numpyro/contrib/control_flow/scan.py", line 313, in scan_wrapper
site["fn"] = promote_batch_shape(site["fn"])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.11/functools.py", line 909, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/frans/protein_dmm/numpyro/numpyro/distributions/batch_util.py", line 569, in _promote_batch_shape_masked
new_base_dist = promote_batch_shape(d.base_dist)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.11/functools.py", line 909, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/frans/protein_dmm/numpyro/numpyro/distributions/batch_util.py", line 515, in _default_promote_batch_shape
attr_name = list(d.arg_constraints.keys())[0]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^
IndexError: list index out of range
The problem is that x is a MaskedDistribution
, whose base_dist is an Independent
(whose base_dist is Normal
). Since Independent
isn't being dispatched to, it gets dispatched to _default_promote_batch_shape
, which tries to access arg_constraints
, which aren't there, but under the base_dist.
The error seems to only appear when inside a scan.
I misunderstood your comment before, but I've applied it now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The example is great!! Could you add the above test case to test_control_flow? Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've added it now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @deoxyribose. Comments below :-)
numpyro/distributions/batch_util.py
Outdated
@@ -563,6 +564,9 @@ def _promote_batch_shape_expanded(d: ExpandedDistribution): | |||
return new_self | |||
|
|||
|
|||
promote_batch_shape.register(Independent, _promote_batch_shape_expanded) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mmh, I don't see any reason why the logic for batch shape promotion of Independent
and ExpandedDistribution
should be the same. Promoting ExpandedDistribution
requires some care in maintaining broadcastability between itself and its base_dist
when left-most inserting batch dimensions - I don't think Independent
would require this kind of care.
Could you instead create _promote_batch_shape_independent
which simply promotes an independent's base distribution?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess we can use the code for MaskedDistribution?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mmmh, unclear - MaskedDistribution
enforces dist.batch_shape == dist.base_dist.batch_shape
while an Independent
distribution will have a different batch shape that its base.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've committed an _promote_batch_shape_independent
now, it passes the test, and outputs the correct batch_shape, but beyond that I'm not sure whether the logic is correct.
test/contrib/test_control_flow.py
Outdated
return (x, y) | ||
|
||
with numpyro.handlers.seed(rng_seed=0): | ||
x, y = model() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we be a bit more stringent, and test that some numpyro
functionalities like computing log probs work well on the model, as done in other scan
tests? That way we will spot more potential bugs in our implementation ahead of merging :-)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've added a log_density
call to the test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @deoxyribose!
Independent distributions aren't dispatched to in batch_utils - the fix reuses the batch promoting for ExpandedDistribution and runs without errors, but hasn't been tested thoroughly, perhaps @pierreglaser can review?