Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor funsor.distributions #319

Closed
wants to merge 26 commits into from
Closed

Refactor funsor.distributions #319

wants to merge 26 commits into from

Conversation

eb8680
Copy link
Member

@eb8680 eb8680 commented Mar 3, 2020

This PR adds a replacement for the funsor.distributions module in which PyTorch distributions are wrapped programmatically and to_funsor and to_data are defined generically. The main API change in this version is that distributions now take a name argument rather than a value, making them behave a bit more like Gaussian or Delta, but I'm open to reverting that.

Remaining tasks:

  • Add output shape inference to tensor_to_funsor so that distribution shape inference works
  • Add more tests for conversion
  • Make tests more generic
  • Port any non-trivial eager rules for distribution evaluation to the new module
  • Deprecate the existing distributions code

@eb8680 eb8680 added the WIP label Mar 3, 2020
Comment on lines +59 to +67
def _infer_value_shape(cls, **kwargs):
# rely on the underlying distribution's logic to infer the event_shape
instance = cls.dist_class(**{k: _dummy_tensor(v.output) for k, v in kwargs.items()})
out_shape = instance.event_shape
if isinstance(instance.support, torch.distributions.constraints._IntegerInterval):
out_dtype = instance.support.upper_bound + 1
else:
out_dtype = 'real'
return Domain(dtype=out_dtype, shape=out_shape)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is the main innovation in this PR: it turns out we can reuse the distributions' existing event_shape and support inference logic to generically infer the shape of the fresh value input, rather than reimplementing that logic by hand for every distribution we add.

return dist_class


class BernoulliProbs(dist.Bernoulli):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a few of these wrappers for distributions with multiple parametrizations, but these should probably be moved upstream to Pyro or PyTorch.

@eb8680
Copy link
Member Author

eb8680 commented Mar 3, 2020

Closing in favor of #320 which is similar but narrower in scope

@eb8680 eb8680 closed this Mar 3, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant