-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor funsor.distributions #319
Conversation
def _infer_value_shape(cls, **kwargs): | ||
# rely on the underlying distribution's logic to infer the event_shape | ||
instance = cls.dist_class(**{k: _dummy_tensor(v.output) for k, v in kwargs.items()}) | ||
out_shape = instance.event_shape | ||
if isinstance(instance.support, torch.distributions.constraints._IntegerInterval): | ||
out_dtype = instance.support.upper_bound + 1 | ||
else: | ||
out_dtype = 'real' | ||
return Domain(dtype=out_dtype, shape=out_shape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This method is the main innovation in this PR: it turns out we can reuse the distributions' existing event_shape
and support
inference logic to generically infer the shape of the fresh value
input, rather than reimplementing that logic by hand for every distribution we add.
return dist_class | ||
|
||
|
||
class BernoulliProbs(dist.Bernoulli): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added a few of these wrappers for distributions with multiple parametrizations, but these should probably be moved upstream to Pyro or PyTorch.
Closing in favor of #320 which is similar but narrower in scope |
This PR adds a replacement for the
funsor.distributions
module in which PyTorch distributions are wrapped programmatically andto_funsor
andto_data
are defined generically. The main API change in this version is that distributions now take aname
argument rather than avalue
, making them behave a bit more likeGaussian
orDelta
, but I'm open to reverting that.Remaining tasks:
tensor_to_funsor
so that distribution shape inference workseager
rules for distribution evaluation to the new moduledistributions
code