Refactor funsor.distributions #319
Closed
Changes from all commits (26 commits):
(all commits authored by eb8680)

- `3b13e27` add an inputs argument to to_funsor
- `bc56c2e` implement funsor_to_tensor and tensor_to_funsor
- `fa04123` use to_funsor and to_data in funsor.pyro.convert
- `0951569` nit
- `84d9fc0` tweak
- `a32566f` attempt at generic distribution conversion
- `6be8a34` address comment
- `630abf5` remove domains from dim_to_name
- `bbbdafa` add dim_to_name docstring comment
- `01cde8f` assert batch dim negativity
- `1a3a88f` consider even named dims of size 1 empty in tensor_to_funsor
- `be04eb1` Merge branch 'to-funsor-inputs' into to-funsor-distributions
- `4956775` sketch new distribution wrapper
- `8dc8585` split new version into second file
- `f84a031` lint
- `716a483` Merge branch 'master' into to-funsor-distributions
- `809e44c` most basic beta density test passes
- `ed3543c` basic density tests pass
- `0d84d09` tweak generic to_funsor/to_data implementations
- `06e51d4` standardize test
- `55805b9` check event shape in to_data
- `f1f9e0e` add metaclass to handle default name
- `b5ac620` add a to_funsor test for normal
- `2030536` add makefun to dependencies
- `3c1f249` add more to_funsor sketches
- `c913eeb` shuffle code around
New file `funsor.distributions2` (@@ -0,0 +1,194 @@):

```python
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math
from collections import OrderedDict

import makefun
import torch
import pyro.distributions as dist
from pyro.distributions.torch_distribution import MaskedDistribution
from pyro.distributions.util import broadcast_shape

import funsor.ops as ops
from funsor.domains import Domain, bint, reals
from funsor.tensor import Tensor, align_tensors
from funsor.terms import Funsor, FunsorMeta, Independent, Number, Variable, eager, to_data, to_funsor


def _dummy_tensor(domain):
    return torch.tensor(0.1 if domain.dtype == 'real' else 1).expand(domain.shape)


class DistributionMeta2(FunsorMeta):
    def __call__(cls, *args, name=None):
        if len(args) < len(cls._ast_fields):
            args = args + (name if name is not None else 'value',)
        return super(DistributionMeta2, cls).__call__(*args)


class Distribution2(Funsor, metaclass=DistributionMeta2):
    """
    Different design for the Distribution Funsor wrapper,
    closer to Gaussian or Delta in which the value is a fresh input.
    """
    dist_class = dist.Distribution  # defined by derived classes

    def __init__(self, *args, name='value'):
        params = OrderedDict(zip(self._ast_fields, args))
        inputs = OrderedDict()
        for param_name, value in params.items():
            assert isinstance(param_name, str)
            assert isinstance(value, Funsor)
            inputs.update(value.inputs)
        assert isinstance(name, str) and name not in inputs
        inputs[name] = self._infer_value_shape(**params)
        output = reals()
        fresh = frozenset({name})
        bound = frozenset()
        super().__init__(inputs, output, fresh, bound)
        self.params = params
        self.name = name

    def __getattribute__(self, attr):
        if attr in type(self)._ast_fields and attr != 'name':
            return self.params[attr]
        return super().__getattribute__(attr)

    @classmethod
    def _infer_value_shape(cls, **kwargs):
        # rely on the underlying distribution's logic to infer the event_shape
        instance = cls.dist_class(**{k: _dummy_tensor(v.output) for k, v in kwargs.items()})
        out_shape = instance.event_shape
        if isinstance(instance.support, torch.distributions.constraints._IntegerInterval):
            out_dtype = instance.support.upper_bound + 1
        else:
            out_dtype = 'real'
        return Domain(dtype=out_dtype, shape=out_shape)

    def eager_subs(self, subs):
        name, sub = subs[0]
        if isinstance(sub, (Number, Tensor)):
            inputs, tensors = align_tensors(*self.params.values())
            data = self.dist_class(*tensors).log_prob(sub.data)
            return Tensor(data, inputs)
        elif isinstance(sub, (Variable, str)):
            return type(self)(*self._ast_values, name=sub.name if isinstance(sub, Variable) else sub)
        else:
            raise NotImplementedError("not implemented")
```
```python
######################################
# Converting distributions to funsors
######################################

@to_funsor.register(torch.distributions.Distribution)
def torchdistribution_to_funsor(pyro_dist, output=None, dim_to_name=None):
    import funsor.distributions2  # TODO find a better way to do this lookup
    funsor_dist_class = getattr(funsor.distributions2, type(pyro_dist).__name__)
    params = [to_funsor(getattr(pyro_dist, param_name), dim_to_name=dim_to_name)
              for param_name in funsor_dist_class._ast_fields if param_name != 'name']
    return funsor_dist_class(*params)


@to_funsor.register(torch.distributions.Independent)
def indepdist_to_funsor(pyro_dist, output=None, dim_to_name=None):
    result = to_funsor(pyro_dist.base_dist, dim_to_name=dim_to_name)
    for i in range(pyro_dist.reinterpreted_batch_ndims):
        name = ...  # XXX what is this? read off from result?
        result = funsor.terms.Independent(result, "value", name, "value")
    return result


@to_funsor.register(MaskedDistribution)
def maskeddist_to_funsor(pyro_dist, output=None, dim_to_name=None):
    mask = to_funsor(pyro_dist._mask.float(), output=output, dim_to_name=dim_to_name)
    funsor_base_dist = to_funsor(pyro_dist.base_dist, output=output, dim_to_name=dim_to_name)
    return mask * funsor_base_dist


@to_funsor.register(torch.distributions.TransformedDistribution)
def transformeddist_to_funsor(pyro_dist, output=None, dim_to_name=None):
    raise NotImplementedError("TODO")
```
```python
###########################################################
# Converting distribution funsors to PyTorch distributions
###########################################################

@to_data.register(Distribution2)
def distribution_to_data(funsor_dist, name_to_dim=None):
    pyro_dist_class = funsor_dist.dist_class
    params = [to_data(getattr(funsor_dist, param_name), name_to_dim=name_to_dim)
              for param_name in funsor_dist._ast_fields if param_name != 'name']
    pyro_dist = pyro_dist_class(*params)
    funsor_event_shape = funsor_dist.inputs[funsor_dist.name].shape
    pyro_dist = pyro_dist.to_event(max(len(funsor_event_shape) - len(pyro_dist.event_shape), 0))
    if pyro_dist.event_shape != funsor_event_shape:
        raise ValueError("Event shapes don't match, something went wrong")
    return pyro_dist


@to_data.register(Independent)
def indep_to_data(funsor_dist, name_to_dim=None):
    raise NotImplementedError("TODO")
```
```python
################################################################################
# Distribution Wrappers
################################################################################

def make_dist(pyro_dist_class, param_names=()):

    if not param_names:
        param_names = tuple(pyro_dist_class.arg_constraints.keys())
    assert all(name in pyro_dist_class.arg_constraints for name in param_names)

    @makefun.with_signature(f"__init__(self, {', '.join(param_names)}, name='value')")
    def dist_init(self, *args, **kwargs):
        return Distribution2.__init__(self, *map(to_funsor, list(kwargs.values())[:-1]), name=kwargs['name'])

    dist_class = DistributionMeta2(pyro_dist_class.__name__, (Distribution2,), {
        'dist_class': pyro_dist_class,
        '__init__': dist_init,
    })

    return dist_class
```
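`make_dist` defaults the wrapper's parameter names from the backend class's `arg_constraints` dict, which preserves declaration order on Python 3.7+. That defaulting logic can be sketched without pyro installed (`FakeGamma` is a hypothetical stand-in for a pyro distribution class):

```python
# Sketch of make_dist's param_names defaulting: fall back to the keys of
# arg_constraints, and validate any explicitly supplied names against it.
class FakeGamma:
    # stand-in for pyro's Gamma; real constraints are constraint objects
    arg_constraints = {'concentration': 'positive', 'rate': 'positive'}


def default_param_names(dist_class, param_names=()):
    if not param_names:
        param_names = tuple(dist_class.arg_constraints.keys())
    assert all(name in dist_class.arg_constraints for name in param_names)
    return param_names


print(default_param_names(FakeGamma))            # ('concentration', 'rate')
print(default_param_names(FakeGamma, ('rate',)))  # explicit subset passes validation
```

The explicit-override path matters for distributions like `MultivariateNormal`, where the wrapper pins down one of several valid parametrizations (`loc`, `scale_tril`).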
> **Review comment (on `BernoulliProbs`):** I added a few of these wrappers for distributions with multiple parametrizations, but these should probably be moved upstream to Pyro or PyTorch.

```python
class BernoulliProbs(dist.Bernoulli):
    def __init__(self, probs, validate_args=None):
        return super().__init__(probs=probs, validate_args=validate_args)


class BernoulliLogits(dist.Bernoulli):
    def __init__(self, logits, validate_args=None):
        return super().__init__(logits=logits, validate_args=validate_args)


class CategoricalProbs(dist.Categorical):
    def __init__(self, probs, validate_args=None):
        return super().__init__(probs=probs, validate_args=validate_args)


class CategoricalLogits(dist.Categorical):
    def __init__(self, logits, validate_args=None):
        return super().__init__(logits=logits, validate_args=validate_args)


_wrapped_pyro_dists = [
    (dist.Beta, ()),
    (BernoulliProbs, ('probs',)),
    (BernoulliLogits, ('logits',)),
    (CategoricalProbs, ('probs',)),
    (CategoricalLogits, ('logits',)),
    (dist.Poisson, ()),
    (dist.Gamma, ()),
    (dist.VonMises, ()),
    (dist.Dirichlet, ()),
    (dist.Normal, ()),
    (dist.MultivariateNormal, ('loc', 'scale_tril')),
]

for pyro_dist_class, param_names in _wrapped_pyro_dists:
    locals()[pyro_dist_class.__name__.split(".")[-1]] = make_dist(pyro_dist_class, param_names)
```
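The `BernoulliProbs`/`BernoulliLogits` subclasses exist because `dist.Bernoulli` accepts either `probs` or `logits`, while `make_dist` needs a single positional signature per wrapper. The pattern, sketched stdlib-only (`FakeBernoulli` and its subclasses are hypothetical stand-ins):

```python
# Sketch of the single-parametrization wrapper pattern: pin down which of
# several mutually exclusive parameters the positional signature exposes.
class FakeBernoulli:
    def __init__(self, probs=None, logits=None):
        # stand-in for dist.Bernoulli's either-or parametrization
        self.probs, self.logits = probs, logits


class FakeBernoulliProbs(FakeBernoulli):
    def __init__(self, probs):
        super().__init__(probs=probs)


class FakeBernoulliLogits(FakeBernoulli):
    def __init__(self, logits):
        super().__init__(logits=logits)


b = FakeBernoulliLogits(0.5)
print(b.probs, b.logits)  # None 0.5
```

Each wrapper then yields an unambiguous `_ast_fields` order, so `torchdistribution_to_funsor` and `distribution_to_data` can round-trip parameters positionally.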
Change to `setup.py`:
```diff
@@ -34,6 +34,7 @@
     author_email='[email protected]',
     python_requires=">=3.6",
     install_requires=[
+        'makefun',
         'multipledispatch',
         'numpy>=1.7',
         'opt_einsum>=2.3.2',
```
> **Review comment (on `_infer_value_shape`):** This method is the main innovation in this PR: it turns out we can reuse the distributions' existing `event_shape` and `support` inference logic to generically infer the shape of the fresh `value` input, rather than reimplementing that logic by hand for every distribution we add.
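The idea praised in that comment, instantiating the backend class with dummy parameters and reading shape information off the instance, can be sketched without torch installed (`FakeMVN` and `infer_value_shape` are hypothetical stand-ins for a `torch.distributions` class and for `_infer_value_shape`):

```python
# Sketch of the _infer_value_shape trick: build a throwaway instance from
# dummy parameters and read event_shape off it, instead of hand-coding
# shape rules for every distribution.
class FakeMVN:
    def __init__(self, loc, scale_tril):
        # a real MultivariateNormal's event shape is loc's trailing dim
        self.event_shape = (len(loc),)


def infer_value_shape(dist_class, **dummy_params):
    instance = dist_class(**dummy_params)
    return instance.event_shape


shape = infer_value_shape(FakeMVN, loc=[0.1, 0.1, 0.1], scale_tril=None)
print(shape)  # (3,)
```

The real `_infer_value_shape` additionally inspects `instance.support` to pick a dtype (an integer cardinality for `_IntegerInterval` supports, `'real'` otherwise), so discrete and continuous distributions both get correct value domains for free.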