dispatch _promote_batch_shape_expanded to Independent #1630

Merged: 11 commits merged into pyro-ppl:master on Aug 28, 2023

Contributor

@deoxyribose deoxyribose commented Aug 21, 2023

promote_batch_shape in batch_util has no dispatch registered for Independent distributions. The fix reuses the batch shape promotion for ExpandedDistribution. It runs without errors but hasn't been tested thoroughly. Perhaps @pierreglaser can review?

@deoxyribose
Contributor Author

deoxyribose commented Aug 21, 2023

Building documentation fails with

Warning, treated as error:
autodoc: failed to import function 'control_flow.scan' from module 'numpyro.contrib'; the following exception was raised:
Traceback (most recent call last):
  File "/opt/hostedtoolcache/Python/3.8.17/x64/lib/python3.8/site-packages/sphinx/ext/autodoc/importer.py", line 62, in import_module
    return importlib.import_module(modname)
  File "/opt/hostedtoolcache/Python/3.8.17/x64/lib/python3.8/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1014, in _gcd_import
  File "<frozen importlib._bootstrap>", line 991, in _find_and_load
  File "<frozen importlib._bootstrap>", line 975, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 671, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 843, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/home/runner/work/numpyro/numpyro/numpyro/contrib/control_flow/__init__.py", line 5, in <module>
    from numpyro.contrib.control_flow.scan import scan
  File "/home/runner/work/numpyro/numpyro/numpyro/contrib/control_flow/scan.py", line 12, in <module>
    from numpyro.distributions.batch_util import promote_batch_shape
  File "/home/runner/work/numpyro/numpyro/numpyro/distributions/batch_util.py", line 528, in <module>
    def _promote_batch_shape_expanded(d: Union[ExpandedDistribution, Independent]):
  File "/opt/hostedtoolcache/Python/3.8.17/x64/lib/python3.8/functools.py", line 860, in register
    raise TypeError(
TypeError: Invalid annotation for 'd'. typing.Union[numpyro.distributions.distribution.ExpandedDistribution, numpyro.distributions.distribution.Independent] is not a class.

Presumably this is because functools.singledispatch does not support Union type annotations on the Python 3.8 used by the docs build.
A workaround would be to add

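    from typing import Union, get_type_hints

    def register_for_union_type(target_fn):
        def decorator(fn):
            # Type hint of the first parameter, expected to be a Union
            arg_type = list(get_type_hints(fn).values())[0]
            assert arg_type.__origin__ is Union
            # Register the same function once per member of the Union
            for type_ in arg_type.__args__:
                fn = target_fn.register(type_)(fn)
            return fn

        return decorator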

and do

@register_for_union_type(promote_batch_shape)
def _promote_batch_shape_expanded(d: Union[ExpandedDistribution, Independent]):

@fehiepsi
Member

I think you can do

promote_batch_shape.register(dist.Independent)(_promote_batch_shape_expanded)

@@ -523,7 +524,7 @@ def _default_promote_batch_shape(d: Distribution):
     return new_self


-@promote_batch_shape.register
+@promote_batch_shape.register(Independent)
Member

This will disable the behavior for ExpandedDistribution. You might need a separate register call for Independent (reusing _promote_batch_shape_expanded as in my comment). Anyway, could you add a description of why this is needed? Does the default behavior not work for Independent?
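
To illustrate the concern with toy classes (stand-ins, not NumPyro's): when `register` is given an explicit class, the parameter annotation is not consulted, so the annotated type ends up without a handler. The `registry` attribute of a `singledispatch` function can be used to check which types are covered.

```python
from functools import singledispatch


class Distribution:
    """Toy base class (stand-in, not NumPyro's)."""


class ExpandedDistribution(Distribution):
    pass


class Independent(Distribution):
    pass


@singledispatch
def promote(d):
    return "default"


# Explicit class argument: only Independent is registered here.
# The ExpandedDistribution annotation is ignored for dispatch.
@promote.register(Independent)
def _promote_expanded(d: ExpandedDistribution):
    return "expanded"


registered_types = set(promote.registry)
expanded_result = promote(ExpandedDistribution())
independent_result = promote(Independent())
```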

Contributor Author

@deoxyribose deoxyribose Aug 23, 2023

Here's an example that reproduces the error I got:

import numpyro
import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro.handlers import mask
from numpyro.contrib.control_flow import scan

def mask_inside_plate_inside_scan(y=None, T=10):
    def transition(carry, y_curr):
        x_prev, t = carry
        with numpyro.plate("N", 10, dim=-1):
            with mask(mask=(t < T)):
                x_curr = numpyro.sample('x', dist.Normal(jnp.zeros((10,8)), jnp.ones((10,8))).to_event(1))
                y_curr = numpyro.sample('y', dist.Normal(x_curr, jnp.ones((10,8))).to_event(1), obs=y_curr)
                return (x_curr, t+1), None
    x0 = numpyro.sample('x_0', dist.Normal(jnp.zeros((10,8)), jnp.ones((10,8))).to_event(1))
    x, t = scan(transition, (x0, 0), y, length=T)
    return (x, y)

with numpyro.handlers.seed(rng_seed=0):
    x, y = mask_inside_plate_inside_scan()

Traceback (most recent call last):
  File "/home/frans/Desktop/test_batch_util_independent.py", line 25, in <module>
    x, y = mask_inside_plate_inside_scan()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/frans/Desktop/test_batch_util_independent.py", line 21, in mask_inside_plate_inside_scan
    x, t = scan(transition, (x0, 0), y, length=T)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/frans/protein_dmm/numpyro/numpyro/contrib/control_flow/scan.py", line 441, in scan
    msg = apply_stack(initial_msg)
          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/frans/protein_dmm/numpyro/numpyro/primitives.py", line 53, in apply_stack
    default_process_message(msg)
  File "/home/frans/protein_dmm/numpyro/numpyro/primitives.py", line 28, in default_process_message
    msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/frans/protein_dmm/numpyro/numpyro/contrib/control_flow/scan.py", line 313, in scan_wrapper
    site["fn"] = promote_batch_shape(site["fn"])
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/functools.py", line 909, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/frans/protein_dmm/numpyro/numpyro/distributions/batch_util.py", line 569, in _promote_batch_shape_masked
    new_base_dist = promote_batch_shape(d.base_dist)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/functools.py", line 909, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/frans/protein_dmm/numpyro/numpyro/distributions/batch_util.py", line 515, in _default_promote_batch_shape
    attr_name = list(d.arg_constraints.keys())[0]
                ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^
IndexError: list index out of range

The problem is that the distribution at site x is a MaskedDistribution whose base_dist is an Independent (whose base_dist is in turn a Normal). Since no handler is registered for Independent, it falls through to _default_promote_batch_shape, which reads the first key of d.arg_constraints. For an Independent that mapping is empty, because the constraints live on its base_dist, hence the IndexError.
The error only seems to appear inside a scan.

I misunderstood your comment before, but I've applied it now.

Member

The example is great!! Could you add the above test case to test_control_flow? Thanks!

Contributor Author

I've added it now.

Member

@fehiepsi fehiepsi left a comment

Thanks!

Contributor

@pierreglaser pierreglaser left a comment

Thanks @deoxyribose. Comments below :-)

@@ -563,6 +564,9 @@ def _promote_batch_shape_expanded(d: ExpandedDistribution):
     return new_self


+promote_batch_shape.register(Independent, _promote_batch_shape_expanded)
Contributor

Mmh, I don't see any reason why the logic for batch shape promotion of Independent and ExpandedDistribution should be the same. Promoting ExpandedDistribution requires some care in maintaining broadcastability between itself and its base_dist when left-most inserting batch dimensions - I don't think Independent would require this kind of care.

Could you instead create _promote_batch_shape_independent which simply promotes an independent's base distribution?
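
A rough sketch of what "promote the base distribution and rewrap" could look like, written against toy classes only. Everything here is an assumption for illustration: the `Normal` and `Independent` classes are simplified stand-ins, the "trust the parameter shape" rule is invented for the toy, and this is not the implementation that was eventually committed.

```python
from functools import singledispatch


class Normal:
    """Toy leaf distribution whose recorded batch shape may lag its parameter."""

    def __init__(self, loc_shape, batch_shape=None):
        self.loc_shape = tuple(loc_shape)
        self.batch_shape = tuple(loc_shape if batch_shape is None else batch_shape)


class Independent:
    """Toy wrapper: moves the last n batch dims of the base into the event shape."""

    def __init__(self, base_dist, reinterpreted_batch_ndims):
        self.base_dist = base_dist
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        split = len(base_dist.batch_shape) - reinterpreted_batch_ndims
        self.batch_shape = base_dist.batch_shape[:split]
        self.event_shape = base_dist.batch_shape[split:]


@singledispatch
def promote(d):
    raise NotImplementedError(type(d).__name__)


@promote.register
def _promote_normal(d: Normal):
    # Toy rule: trust the parameter shape over the recorded batch shape.
    return Normal(d.loc_shape)


@promote.register
def _promote_independent(d: Independent):
    # Promote the base first, then rebuild the wrapper so its shapes are recomputed.
    return Independent(promote(d.base_dist), d.reinterpreted_batch_ndims)


# The parameter gained a leading dim of 5, but the recorded batch shape is stale.
stale = Independent(Normal(loc_shape=(5, 10, 8), batch_shape=(10, 8)), 1)
fresh = promote(stale)
```

In this toy, no broadcasting bookkeeping is needed in the wrapper's handler: the wrapper's shapes follow from the promoted base.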

Member

I guess we can use the code for MaskedDistribution?

Contributor

Mmmh, unclear - MaskedDistribution enforces dist.batch_shape == dist.base_dist.batch_shape, while an Independent distribution will have a different batch shape than its base.
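
The shape difference can be made concrete with plain tuple arithmetic. The two helper functions below are hypothetical names introduced for illustration. The Independent rule is the usual one (the last n batch dims of the base become event dims), and the masked rule just restates the invariant mentioned above.

```python
def independent_shapes(base_batch_shape, base_event_shape, reinterpreted_batch_ndims):
    """Return (batch_shape, event_shape) of an Independent wrapper."""
    split = len(base_batch_shape) - reinterpreted_batch_ndims
    batch_shape = tuple(base_batch_shape[:split])
    event_shape = tuple(base_batch_shape[split:]) + tuple(base_event_shape)
    return batch_shape, event_shape


def masked_shapes(base_batch_shape, base_event_shape):
    """Return (batch_shape, event_shape) of a masked wrapper: same as its base."""
    return tuple(base_batch_shape), tuple(base_event_shape)


# Shapes from the reproducer: Normal with (10, 8) parameters, then .to_event(1).
base_batch, base_event = (10, 8), ()
independent = independent_shapes(base_batch, base_event, 1)
masked = masked_shapes(base_batch, base_event)
```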

Contributor Author

I've committed a _promote_batch_shape_independent now. It passes the test and outputs the correct batch_shape, but beyond that I'm not sure whether the logic is correct.

    return (x, y)

with numpyro.handlers.seed(rng_seed=0):
    x, y = model()
Contributor

Could we be a bit more stringent, and test that some numpyro functionalities like computing log probs work well on the model, as done in other scan tests? That way we will spot more potential bugs in our implementation ahead of merging :-)

Contributor Author

I've added a log_density call to the test.

Member

@fehiepsi fehiepsi left a comment

Thanks @deoxyribose!

@fehiepsi fehiepsi merged commit ca96eca into pyro-ppl:master Aug 28, 2023
4 checks passed