diff --git a/funsor/distributions2.py b/funsor/distributions2.py
new file mode 100644
index 000000000..b8a67c228
--- /dev/null
+++ b/funsor/distributions2.py
@@ -0,0 +1,194 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+from collections import OrderedDict
+
+import makefun
+import torch
+import pyro.distributions as dist
+from pyro.distributions.torch_distribution import MaskedDistribution
+
+from funsor.domains import Domain, reals
+from funsor.tensor import Tensor, align_tensors
+from funsor.terms import Funsor, FunsorMeta, Independent, Number, Variable, to_data, to_funsor
+
+
+def _dummy_tensor(domain):
+    # a valid (if arbitrary) value used only to instantiate a distribution
+    # so that its event_shape and support can be inspected
+    return torch.tensor(0.1 if domain.dtype == 'real' else 1).expand(domain.shape)
+
+
+class DistributionMeta2(FunsorMeta):
+    def __call__(cls, *args, name=None):
+        # fill in a default name for the fresh value input if none was given
+        if len(args) < len(cls._ast_fields):
+            args = args + (name if name is not None else 'value',)
+        return super(DistributionMeta2, cls).__call__(*args)
+
+
+class Distribution2(Funsor, metaclass=DistributionMeta2):
+    """
+    Alternative design for the Distribution funsor wrapper,
+    closer to Gaussian or Delta in that the value is a fresh input.
+    """
+    dist_class = dist.Distribution  # defined by derived classes
+
+    def __init__(self, *args, name='value'):
+        params = OrderedDict(zip(self._ast_fields, args))
+        inputs = OrderedDict()
+        for param_name, value in params.items():
+            assert isinstance(param_name, str)
+            assert isinstance(value, Funsor)
+            inputs.update(value.inputs)
+        assert isinstance(name, str) and name not in inputs
+        inputs[name] = self._infer_value_shape(**params)
+        output = reals()
+        fresh = frozenset({name})
+        bound = frozenset()
+        super().__init__(inputs, output, fresh, bound)
+        self.params = params
+        self.name = name
+
+    def __getattribute__(self, attr):
+        if attr in type(self)._ast_fields and attr != 'name':
+            return self.params[attr]
+        return super().__getattribute__(attr)
+
+    @classmethod
+    def _infer_value_shape(cls, **kwargs):
+        # rely on the underlying distribution's logic to infer the event_shape
+        instance = cls.dist_class(**{k: _dummy_tensor(v.output) for k, v in kwargs.items()})
+        out_shape = instance.event_shape
+        if isinstance(instance.support, torch.distributions.constraints._IntegerInterval):
+            out_dtype = instance.support.upper_bound + 1
+        else:
+            out_dtype = 'real'
+        return Domain(dtype=out_dtype, shape=out_shape)
+
+    def eager_subs(self, subs):
+        name, sub = subs[0]
+        if isinstance(sub, Tensor):
+            # align the parameters together with the value so that all of
+            # their batch inputs are broadcast into a single inputs dict
+            inputs, tensors = align_tensors(*self.params.values(), sub)
+            data = self.dist_class(*tensors[:-1]).log_prob(tensors[-1])
+            return Tensor(data, inputs)
+        elif isinstance(sub, Number):
+            inputs, tensors = align_tensors(*self.params.values())
+            data = self.dist_class(*tensors).log_prob(sub.data)
+            return Tensor(data, inputs)
+        elif isinstance(sub, (Variable, str)):
+            # substituting a variable simply renames the fresh value input
+            return type(self)(*self._ast_values, name=sub.name if isinstance(sub, Variable) else sub)
+        else:
+            raise NotImplementedError("substituting {} into a Distribution2 is not supported".format(type(sub)))
+
+
+######################################
+# Converting distributions to funsors
+######################################
+
+@to_funsor.register(torch.distributions.Distribution)
+def torchdistribution_to_funsor(pyro_dist, output=None, dim_to_name=None):
+    import funsor.distributions2  # TODO find a better way to do this lookup
+    funsor_dist_class = getattr(funsor.distributions2, type(pyro_dist).__name__)
+    params = [to_funsor(getattr(pyro_dist, param_name), dim_to_name=dim_to_name)
+              for param_name in funsor_dist_class._ast_fields if param_name != 'name']
+    return funsor_dist_class(*params)
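+
+
+# An illustrative sketch (not executed here) of what this generic converter
+# does; `Normal` refers to the wrapper generated by `make_dist` at the bottom
+# of this module:
+#
+#   d = dist.Normal(torch.zeros(3), torch.ones(3))
+#   f = to_funsor(d, dim_to_name=OrderedDict([(-1, "i")]))
+#   # f is a Normal funsor with inputs {"i": bint(3), "value": reals()};
+#   # substituting the fresh "value" input yields a Tensor of log-densities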
+
+
+@to_funsor.register(torch.distributions.Independent)
+def indepdist_to_funsor(pyro_dist, output=None, dim_to_name=None):
+    result = to_funsor(pyro_dist.base_dist, dim_to_name=dim_to_name)
+    for i in range(pyro_dist.reinterpreted_batch_ndims):
+        name = ...  # XXX what is this? read off from result?
+        result = Independent(result, "value", name, "value")
+    return result
+
+
+@to_funsor.register(MaskedDistribution)
+def maskeddist_to_funsor(pyro_dist, output=None, dim_to_name=None):
+    # assumes _mask is a tensor; a scalar bool mask is not handled here
+    mask = to_funsor(pyro_dist._mask.float(), output=output, dim_to_name=dim_to_name)
+    funsor_base_dist = to_funsor(pyro_dist.base_dist, output=output, dim_to_name=dim_to_name)
+    return mask * funsor_base_dist
+
+
+@to_funsor.register(torch.distributions.TransformedDistribution)
+def transformeddist_to_funsor(pyro_dist, output=None, dim_to_name=None):
+    raise NotImplementedError("TODO")
+
+
+###########################################################
+# Converting distribution funsors to PyTorch distributions
+###########################################################
+
+@to_data.register(Distribution2)
+def distribution_to_data(funsor_dist, name_to_dim=None):
+    pyro_dist_class = funsor_dist.dist_class
+    params = [to_data(getattr(funsor_dist, param_name), name_to_dim=name_to_dim)
+              for param_name in funsor_dist._ast_fields if param_name != 'name']
+    pyro_dist = pyro_dist_class(*params)
+    funsor_event_shape = funsor_dist.inputs[funsor_dist.name].shape
+    pyro_dist = pyro_dist.to_event(max(len(funsor_event_shape) - len(pyro_dist.event_shape), 0))
+    if pyro_dist.event_shape != funsor_event_shape:
+        raise ValueError("event shape mismatch: expected {}, got {}".format(
+            funsor_event_shape, tuple(pyro_dist.event_shape)))
+    return pyro_dist
+
+
+@to_data.register(Independent)
+def indep_to_data(funsor_dist, name_to_dim=None):
+    raise NotImplementedError("TODO")
+
+
+################################################################################
+# Distribution Wrappers
+################################################################################
+
+def make_dist(pyro_dist_class, param_names=()):
+    if not param_names:
+        param_names = tuple(pyro_dist_class.arg_constraints.keys())
+    assert all(name in pyro_dist_class.arg_constraints for name in param_names)
+
+    @makefun.with_signature(f"__init__(self, {', '.join(param_names)}, name='value')")
+    def dist_init(self, **kwargs):
+        name = kwargs.pop('name')
+        return Distribution2.__init__(self, *map(to_funsor, (kwargs[k] for k in param_names)), name=name)
+
+    dist_class = DistributionMeta2(pyro_dist_class.__name__, (Distribution2,), {
+        'dist_class': pyro_dist_class,
+        '__init__': dist_init,
+    })
+
+    return dist_class
+
+
+# Subclasses that fix probs or logits as the single explicit parameter, so
+# that make_dist generates a separate wrapper for each parameterization.
+class BernoulliProbs(dist.Bernoulli):
+    def __init__(self, probs, validate_args=None):
+        super().__init__(probs=probs, validate_args=validate_args)
+
+
+class BernoulliLogits(dist.Bernoulli):
+    def __init__(self, logits, validate_args=None):
+        super().__init__(logits=logits, validate_args=validate_args)
+
+
+class CategoricalProbs(dist.Categorical):
+    def __init__(self, probs, validate_args=None):
+        super().__init__(probs=probs, validate_args=validate_args)
+
+
+class CategoricalLogits(dist.Categorical):
+    def __init__(self, logits, validate_args=None):
+        super().__init__(logits=logits, validate_args=validate_args)
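+
+
+# A rough sketch of what make_dist produces (illustrative, assuming the
+# machinery above): each generated wrapper is equivalent to writing
+#
+#   class Gamma(Distribution2):
+#       dist_class = dist.Gamma
+#       def __init__(self, concentration, rate, name='value'): ...
+#
+# by hand, with every parameter coerced to a Funsor via to_funsor.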
+
+
+_wrapped_pyro_dists = [
+    (dist.Beta, ()),
+    (BernoulliProbs, ('probs',)),
+    (BernoulliLogits, ('logits',)),
+    (CategoricalProbs, ('probs',)),
+    (CategoricalLogits, ('logits',)),
+    (dist.Poisson, ()),
+    (dist.Gamma, ()),
+    (dist.VonMises, ()),
+    (dist.Dirichlet, ()),
+    (dist.Normal, ()),
+    (dist.MultivariateNormal, ('loc', 'scale_tril')),
+]
+
+for pyro_dist_class, param_names in _wrapped_pyro_dists:
+    locals()[pyro_dist_class.__name__] = make_dist(pyro_dist_class, param_names)
diff --git a/setup.py b/setup.py
index 23a39f65a..78b25e1c9 100644
--- a/setup.py
+++ b/setup.py
@@ -34,6 +34,7 @@
     author_email='fritzo@uber.com',
     python_requires=">=3.6",
     install_requires=[
+        'makefun',
         'multipledispatch',
         'numpy>=1.7',
         'opt_einsum>=2.3.2',
diff --git a/test/test_distributions2.py b/test/test_distributions2.py
new file mode 100644
index 000000000..c748331c1
--- /dev/null
+++ b/test/test_distributions2.py
@@ -0,0 +1,271 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+from collections import OrderedDict
+
+import pyro
+import pytest
+import torch
+
+import funsor
+import funsor.distributions2 as dist
+from funsor.domains import bint, reals
+from funsor.tensor import Tensor
+from funsor.testing import assert_close, check_funsor, random_tensor
+
+funsor.set_backend("torch")
+
+
+@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
+def test_beta_density(batch_shape):
+    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
+    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
+
+    @funsor.function(reals(), reals(), reals(), reals())
+    def beta(concentration1, concentration0, value):
+        return torch.distributions.Beta(concentration1, concentration0).log_prob(value)
+
+    check_funsor(beta, {'concentration1': reals(), 'concentration0': reals(), 'value': reals()}, reals())
+
+    concentration1 = Tensor(torch.randn(batch_shape).exp(), inputs)
+    concentration0 = Tensor(torch.randn(batch_shape).exp(), inputs)
+    value = Tensor(torch.rand(batch_shape), inputs)
+    expected = beta(concentration1, concentration0, value)
+    check_funsor(expected, inputs, reals())
+
+    actual = dist.Beta(concentration1, concentration0, name='value')(value=value)
+    check_funsor(actual, inputs, reals())
+    assert_close(actual, expected)
+
+
+@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
+def test_bernoulli_probs_density(batch_shape):
+    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
+    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
+
+    @funsor.function(reals(), reals(), reals())
+    def bernoulli(probs, value):
+        return torch.distributions.Bernoulli(probs=probs).log_prob(value)
+
+    check_funsor(bernoulli, {'probs': reals(), 'value': reals()}, reals())
+
+    probs = Tensor(torch.rand(batch_shape), inputs)
+    value = Tensor(torch.rand(batch_shape).round(), inputs)
+    expected = bernoulli(probs, value)
+    check_funsor(expected, inputs, reals())
+
+    actual = dist.BernoulliProbs(probs, name='value')(value=value)
+    check_funsor(actual, inputs, reals())
+    assert_close(actual, expected)
+
+
+@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
+def test_bernoulli_logits_density(batch_shape):
+    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
+    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
+
+    @funsor.function(reals(), reals(), reals())
+    def bernoulli(logits, value):
+        return torch.distributions.Bernoulli(logits=logits).log_prob(value)
+
+    check_funsor(bernoulli, {'logits': reals(), 'value': reals()}, reals())
+
+    logits = Tensor(torch.rand(batch_shape), inputs)
+    value = Tensor(torch.rand(batch_shape).round(), inputs)
+    expected = bernoulli(logits, value)
+    check_funsor(expected, inputs, reals())
+
+    actual = dist.BernoulliLogits(logits, name='value')(value=value)
+    check_funsor(actual, inputs, reals())
+    assert_close(actual, expected)
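+
+
+# A note on the pattern shared by the density tests in this file: the
+# @funsor.function decorator takes the input domains followed by the output
+# domain, so e.g. funsor.function(reals(), reals(), reals()) declares a
+# reference log-density with two real inputs and a real output, against which
+# the wrapped distribution's substitution is compared.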
+
+
+@pytest.mark.parametrize('size', [4])
+@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
+def test_categorical_probs_density(size, batch_shape):
+    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
+    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
+
+    @funsor.function(reals(size), bint(size), reals())
+    def categorical_probs(probs, value):
+        return torch.distributions.Categorical(probs=probs).log_prob(value)
+
+    check_funsor(categorical_probs, {'probs': reals(size), 'value': bint(size)}, reals())
+
+    probs_data = torch.randn(batch_shape + (size,)).exp()
+    probs_data /= probs_data.sum(-1, keepdim=True)
+    probs = Tensor(probs_data, inputs)
+    value = random_tensor(inputs, bint(size))
+    expected = categorical_probs(probs, value)
+    check_funsor(expected, inputs, reals())
+
+    actual = dist.CategoricalProbs(probs, name='value')(value=value)
+    check_funsor(actual, inputs, reals())
+    assert_close(actual, expected)
+
+
+@pytest.mark.parametrize('size', [4])
+@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
+def test_categorical_logits_density(size, batch_shape):
+    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
+    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
+
+    @funsor.function(reals(size), bint(size), reals())
+    def categorical_logits(logits, value):
+        return torch.distributions.Categorical(logits=logits).log_prob(value)
+
+    check_funsor(categorical_logits, {'logits': reals(size), 'value': bint(size)}, reals())
+
+    logits_data = torch.randn(batch_shape + (size,))
+    logits_data = logits_data - logits_data.logsumexp(-1, keepdim=True)
+    logits = Tensor(logits_data, inputs)
+    value = random_tensor(inputs, bint(size))
+    expected = categorical_logits(logits, value)
+    check_funsor(expected, inputs, reals())
+
+    actual = dist.CategoricalLogits(logits, name='value')(value=value)
+    check_funsor(actual, inputs, reals())
+    assert_close(actual, expected)
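+
+
+# In the categorical tests above, the wrapped funsor's 'value' input has the
+# integer domain bint(size) rather than reals(): Distribution2's
+# _infer_value_shape reads the upper bound of the integer support off a dummy
+# instance of the underlying distribution.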
+
+
+@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
+@pytest.mark.parametrize('event_shape', [(1,), (4,), (5,)], ids=str)
+def test_dirichlet_density(batch_shape, event_shape):
+    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
+    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
+
+    @funsor.function(reals(*event_shape), reals(*event_shape), reals())
+    def dirichlet(concentration, value):
+        return torch.distributions.Dirichlet(concentration).log_prob(value)
+
+    check_funsor(dirichlet, {'concentration': reals(*event_shape), 'value': reals(*event_shape)}, reals())
+
+    concentration = Tensor(torch.randn(batch_shape + event_shape).exp(), inputs)
+    value_data = torch.rand(batch_shape + event_shape)
+    value_data = value_data / value_data.sum(-1, keepdim=True)
+    value = Tensor(value_data, inputs)
+    expected = dirichlet(concentration, value)
+    check_funsor(expected, inputs, reals())
+
+    actual = dist.Dirichlet(concentration, name='value')(value=value)
+    check_funsor(actual, inputs, reals())
+    assert_close(actual, expected)
+
+
+@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
+def test_normal_density(batch_shape):
+    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
+    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
+
+    @funsor.function(reals(), reals(), reals(), reals())
+    def normal(loc, scale, value):
+        return torch.distributions.Normal(loc, scale).log_prob(value)
+
+    check_funsor(normal, {'loc': reals(), 'scale': reals(), 'value': reals()}, reals())
+
+    loc = Tensor(torch.randn(batch_shape), inputs)
+    scale = Tensor(torch.randn(batch_shape).exp(), inputs)
+    value = Tensor(torch.randn(batch_shape), inputs)
+    expected = normal(loc, scale, value)
+    check_funsor(expected, inputs, reals())
+
+    actual = dist.Normal(loc, scale, name='value')(value=value)
+    check_funsor(actual, inputs, reals())
+    assert_close(actual, expected)
+
+
+@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
+def test_poisson_density(batch_shape):
+    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
+    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
+
+    @funsor.function(reals(), reals(), reals())
+    def poisson(rate, value):
+        return torch.distributions.Poisson(rate).log_prob(value)
+
+    check_funsor(poisson, {'rate': reals(), 'value': reals()}, reals())
+
+    rate = Tensor(torch.rand(batch_shape), inputs)
+    value = Tensor(torch.randn(batch_shape).exp().round(), inputs)
+    expected = poisson(rate, value)
+    check_funsor(expected, inputs, reals())
+
+    actual = dist.Poisson(rate, name='value')(value=value)
+    check_funsor(actual, inputs, reals())
+    assert_close(actual, expected)
+
+
+@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
+def test_gamma_density(batch_shape):
+    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
+    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
+
+    @funsor.function(reals(), reals(), reals(), reals())
+    def gamma(concentration, rate, value):
+        return torch.distributions.Gamma(concentration, rate).log_prob(value)
+
+    check_funsor(gamma, {'concentration': reals(), 'rate': reals(), 'value': reals()}, reals())
+
+    concentration = Tensor(torch.rand(batch_shape), inputs)
+    rate = Tensor(torch.rand(batch_shape), inputs)
+    value = Tensor(torch.randn(batch_shape).exp(), inputs)
+    expected = gamma(concentration, rate, value)
+    check_funsor(expected, inputs, reals())
+
+    actual = dist.Gamma(concentration, rate, name='value')(value=value)
+    check_funsor(actual, inputs, reals())
+    assert_close(actual, expected)
+
+
+@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
+def test_von_mises_density(batch_shape):
+    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
+    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
+
+    @funsor.function(reals(), reals(), reals(), reals())
+    def von_mises(loc, concentration, value):
+        return pyro.distributions.VonMises(loc, concentration).log_prob(value)
+
+    check_funsor(von_mises, {'concentration': reals(), 'loc': reals(), 'value': reals()}, reals())
+
+    concentration = Tensor(torch.rand(batch_shape), inputs)
+    loc = Tensor(torch.rand(batch_shape), inputs)
+    value = Tensor(torch.randn(batch_shape).abs(), inputs)
+    expected = von_mises(loc, concentration, value)
+    check_funsor(expected, inputs, reals())
+
+    actual = dist.VonMises(loc, concentration, name='value')(value=value)
+    check_funsor(actual, inputs, reals())
+    assert_close(actual, expected)
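+
+
+# The round-trip test below converts a (possibly batched) pyro Normal to a
+# funsor with to_funsor and back with to_data, checking that type, batch and
+# event shapes, and log_prob values all survive the round trip.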
+
+
+@pytest.mark.parametrize("event_shape", [
+    (),  # (5,), (4, 3),
+], ids=str)
+@pytest.mark.parametrize("batch_shape", [
+    (), (2,), (2, 3),
+], ids=str)
+def test_normal_funsor_normal(batch_shape, event_shape):
+    loc = torch.randn(batch_shape + event_shape)
+    scale = torch.randn(batch_shape + event_shape).exp()
+    d = pyro.distributions.Normal(loc, scale).to_event(len(event_shape))
+    value = d.sample()
+    # name each batch dim of size > 1 after its (negative) dim index
+    name_to_dim = OrderedDict(
+        (f'{v}', v) for v in range(-len(batch_shape), 0) if batch_shape[v] > 1)
+    dim_to_name = OrderedDict((v, k) for k, v in name_to_dim.items())
+    f = funsor.to_funsor(d, reals(), dim_to_name=dim_to_name)
+    d2 = funsor.to_data(f, name_to_dim=name_to_dim)
+    assert type(d) == type(d2)
+    assert d.batch_shape == d2.batch_shape
+    assert d.event_shape == d2.event_shape
+    expected_log_prob = d.log_prob(value)
+    actual_log_prob = d2.log_prob(value)
+    assert_close(actual_log_prob, expected_log_prob)
+    expected_funsor_log_prob = funsor.to_funsor(actual_log_prob, reals(), dim_to_name)
+    actual_funsor_log_prob = f(value=funsor.to_funsor(value, reals(*event_shape), dim_to_name))
+    assert_close(actual_funsor_log_prob, expected_funsor_log_prob)