Add generic unscaled_sample method to Distribution #323

Merged · 26 commits, merged on Mar 29, 2020
Commits
102e605
Add a generic unscaled_sample method to Distribution
eb8680 Mar 24, 2020
9a3be45
include dice factor in result of sample
eb8680 Mar 25, 2020
1c7f9d7
nit
eb8680 Mar 25, 2020
6a39e00
split shape and value tests for sample
eb8680 Mar 25, 2020
3650b1d
add tests for more distributions
eb8680 Mar 26, 2020
f74cbcd
fix bernoulli
eb8680 Mar 26, 2020
2cdec4a
add beta and binomial sample tests
eb8680 Mar 26, 2020
571edd5
atol and rtol params
eb8680 Mar 26, 2020
c3730f3
bernoulliprobs test
eb8680 Mar 26, 2020
e394934
normal sample test
eb8680 Mar 26, 2020
af9cb4c
poisson sample test
eb8680 Mar 26, 2020
aaf933c
add xfails for missing eager sampling patterns
eb8680 Mar 26, 2020
b006907
lint
eb8680 Mar 26, 2020
7d9ba16
test reparameterized gradients
eb8680 Mar 27, 2020
98889b6
lint
eb8680 Mar 27, 2020
78cbc03
Add nonreparametrized distributions for testing
eb8680 Mar 27, 2020
2fe6d5c
forgot beta
eb8680 Mar 27, 2020
ba4b50b
change default gradient test statistic to mean
eb8680 Mar 27, 2020
517312d
address comments
eb8680 Mar 27, 2020
6a4a8c9
tweak tests and fix parameter handling for ambiguous distributions
eb8680 Mar 27, 2020
e1581be
increase num_samples for some nonreparametrized tests
eb8680 Mar 27, 2020
a85b1e9
switch beta test to variance, weird behavior
eb8680 Mar 27, 2020
d235cb3
decrease default sample test tolerance
eb8680 Mar 27, 2020
495ed9e
increase num_samples
eb8680 Mar 27, 2020
b7cf999
tweak bernoulli test tolerance
eb8680 Mar 27, 2020
1622645
address nit
eb8680 Mar 29, 2020
35 changes: 32 additions & 3 deletions funsor/distributions.py
@@ -9,6 +9,7 @@

import makefun
import pyro.distributions as dist
import pyro.distributions.testing.fakes as fakes
from pyro.distributions.torch_distribution import MaskedDistribution
import torch
import torch.distributions.constraints as constraints
@@ -85,11 +86,11 @@ def __init__(self, *args):
inputs = OrderedDict(inputs)
output = reals()
super(Distribution, self).__init__(inputs, output)
self.params = params
self.params = OrderedDict(params)

def __repr__(self):
return '{}({})'.format(type(self).__name__,
', '.join('{}={}'.format(*kv) for kv in self.params))
', '.join('{}={}'.format(*kv) for kv in self.params.items()))

def eager_reduce(self, op, reduced_vars):
if op is ops.logaddexp and isinstance(self.value, Variable) and self.value.name in reduced_vars:
@@ -104,6 +105,30 @@ def eager_log_prob(cls, *params):
data = cls.dist_class(**params).log_prob(value)
return Tensor(data, inputs)

def unscaled_sample(self, sampled_vars, sample_inputs):
params = OrderedDict(self.params)
value = params.pop("value")
assert all(isinstance(v, (Number, Tensor)) for v in params.values())
assert isinstance(value, Variable) and value.name in sampled_vars
inputs_, tensors = align_tensors(*params.values())
inputs = OrderedDict(sample_inputs.items())
inputs.update(inputs_)
sample_shape = tuple(v.size for v in sample_inputs.values())

raw_dist = self.dist_class(**dict(zip(self._ast_fields[:-1], tensors)))
if getattr(raw_dist, "has_rsample", False):
raw_sample = raw_dist.rsample(sample_shape)
else:
raw_sample = raw_dist.sample(sample_shape)

result = funsor.delta.Delta(value.name, Tensor(raw_sample, inputs, value.output.dtype))
if not getattr(raw_dist, "has_rsample", False):
# scaling of dice_factor by num samples should already be handled by Funsor.sample
raw_log_prob = raw_dist.log_prob(raw_sample)
dice_factor = Tensor(raw_log_prob - raw_log_prob.detach(), inputs)
result += dice_factor
return result

def __getattribute__(self, attr):
if attr in type(self)._ast_fields and attr != 'name':
return self.params[attr]
@@ -194,6 +219,10 @@ def __init__(self, logits, validate_args=None):
(dist.Normal, ()),
(dist.MultivariateNormal, ('loc', 'scale_tril')),
(dist.Delta, ()),
(fakes.NonreparameterizedBeta, ()),
(fakes.NonreparameterizedGamma, ()),
(fakes.NonreparameterizedNormal, ()),
(fakes.NonreparameterizedDirichlet, ()),
]

for pyro_dist_class, param_names in _wrapped_pyro_dists:
@@ -284,7 +313,7 @@ def distribution_to_data(funsor_dist, name_to_dim=None):
pyro_dist_class = funsor_dist.dist_class
params = [to_data(getattr(funsor_dist, param_name), name_to_dim=name_to_dim)
for param_name in funsor_dist._ast_fields if param_name != 'value']
pyro_dist = pyro_dist_class(*params)
pyro_dist = pyro_dist_class(**dict(zip(funsor_dist._ast_fields[:-1], params)))
funsor_event_shape = funsor_dist.value.output.shape
pyro_dist = pyro_dist.to_event(max(len(funsor_event_shape) - len(pyro_dist.event_shape), 0))
if pyro_dist.event_shape != funsor_event_shape:
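The DiCE-style surrogate in `unscaled_sample` above (adding `log_prob - log_prob.detach()` when the distribution has no `rsample`) can be illustrated outside funsor. The following is a minimal PyTorch sketch, independent of this PR's code and with illustrative variable names: the surrogate term is identically zero in value, but its gradient is the score function, so a Monte Carlo estimate built from nonreparameterized samples still propagates gradients to the distribution's parameters.

```python
import torch
import torch.distributions as dist

# Illustrative sketch of the DiCE-style surrogate (not funsor's implementation).
rate = torch.tensor(2.0, requires_grad=True)
d = dist.Poisson(rate)                       # has no reparameterized sampler
x = d.sample((100000,))                      # nonreparameterized draws

log_prob = d.log_prob(x)
dice_factor = log_prob - log_prob.detach()   # value 0, gradient = score function
surrogate = (x * torch.exp(dice_factor)).mean()   # equals x.mean() in value

surrogate.backward()
print(surrogate.item())   # ~ 2.0, the Poisson mean
print(rate.grad)          # ~ 1.0, since d/d(rate) E[x] = 1 for Poisson(rate)
```

In the funsor code above the same factor is added to the `Delta` in log space, so `Funsor.sample` only needs to handle the scaling by the number of samples.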
200 changes: 197 additions & 3 deletions test/test_distributions.py
@@ -10,14 +10,16 @@

import funsor
import funsor.distributions as dist
import funsor.ops as ops
from funsor.cnf import Contraction, GaussianMixture
from funsor.delta import Delta
from funsor.domains import bint, reals
from funsor.interpreter import interpretation, reinterpret
from funsor.integrate import Integrate
from funsor.pyro.convert import dist_to_funsor
from funsor.tensor import Einsum, Tensor
from funsor.terms import Independent, Variable, lazy
from funsor.testing import assert_close, check_funsor, random_mvn, random_tensor
from funsor.tensor import Einsum, Tensor, align_tensors
from funsor.terms import Independent, Variable, eager, lazy
from funsor.testing import assert_close, check_funsor, random_mvn, random_tensor, xfail_param
from funsor.util import get_backend

pytestmark = pytest.mark.skipif(get_backend() != "torch",
@@ -654,3 +656,195 @@ def von_mises(loc, concentration, value):
actual = dist.VonMises(loc, concentration, d)(value=value)
check_funsor(actual, inputs, reals())
assert_close(actual, expected)


def _check_sample(funsor_dist, sample_inputs, inputs, atol=1e-2, rtol=None,
num_samples=100000, statistic="mean", skip_grad=False):
"""utility that compares a Monte Carlo estimate of a distribution mean with the true mean"""
samples_per_dim = int(num_samples ** (1./max(1, len(sample_inputs))))
sample_inputs = OrderedDict((k, bint(samples_per_dim)) for k in sample_inputs)

for tensor in list(funsor_dist.params.values())[:-1]:
tensor.data.requires_grad_()

sample_value = funsor_dist.sample(frozenset(['value']), sample_inputs)
expected_inputs = OrderedDict(
tuple(sample_inputs.items()) + tuple(inputs.items()) + (('value', funsor_dist.inputs['value']),)
)
check_funsor(sample_value, expected_inputs, reals())

if sample_inputs:

actual_mean = Integrate(
sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value'])
).reduce(ops.add, frozenset(sample_inputs))

inputs, tensors = align_tensors(*list(funsor_dist.params.values())[:-1])
raw_dist = funsor_dist.dist_class(**dict(zip(funsor_dist._ast_fields[:-1], tensors)))
expected_mean = Tensor(raw_dist.mean, inputs)

check_funsor(actual_mean, expected_mean.inputs, expected_mean.output)
assert_close(actual_mean, expected_mean, atol=atol, rtol=rtol)

if sample_inputs and not skip_grad:
if statistic == "mean":
actual_stat, expected_stat = actual_mean, expected_mean
elif statistic == "variance":
actual_stat = Integrate(
sample_value,
(Variable('value', funsor_dist.inputs['value']) - actual_mean) ** 2,
frozenset(['value'])
).reduce(ops.add, frozenset(sample_inputs))
expected_stat = Tensor(raw_dist.variance, inputs)
elif statistic == "entropy":
actual_stat = -Integrate(
sample_value, funsor_dist, frozenset(['value'])
).reduce(ops.add, frozenset(sample_inputs))
expected_stat = Tensor(raw_dist.entropy(), inputs)
else:
raise ValueError("invalid test statistic")

grad_targets = [v.data for v in list(funsor_dist.params.values())[:-1]]
actual_grads = torch.autograd.grad(actual_stat.reduce(ops.add).sum().data, grad_targets, allow_unused=True)
expected_grads = torch.autograd.grad(
expected_stat.reduce(ops.add).sum().data, grad_targets, allow_unused=True)

assert_close(actual_stat, expected_stat, atol=atol, rtol=rtol)

for actual_grad, expected_grad in zip(actual_grads, expected_grads):
if expected_grad is not None:
assert_close(actual_grad, expected_grad, atol=atol, rtol=rtol)
else:
assert_close(actual_grad, torch.zeros_like(actual_grad), atol=atol, rtol=rtol)


@pytest.mark.parametrize('sample_inputs', [(), ('ii',), ('ii', 'jj'), ('ii', 'jj', 'kk')])
@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
@pytest.mark.parametrize('reparametrized', [True, False])
def test_gamma_sample(batch_shape, sample_inputs, reparametrized):
batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

concentration = Tensor(torch.rand(batch_shape), inputs)
rate = Tensor(torch.rand(batch_shape), inputs)
funsor_dist = (dist.Gamma if reparametrized else dist.NonreparameterizedGamma)(concentration, rate)

_check_sample(funsor_dist, sample_inputs, inputs, num_samples=200000,
atol=5e-2 if reparametrized else 1e-1)


@pytest.mark.parametrize("with_lazy", [True, xfail_param(False, reason="missing pattern")])
@pytest.mark.parametrize('sample_inputs', [(), ('ii',), ('ii', 'jj'), ('ii', 'jj', 'kk')])
@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
@pytest.mark.parametrize('reparametrized', [True, False])
def test_normal_sample(with_lazy, batch_shape, sample_inputs, reparametrized):
batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

loc = Tensor(torch.randn(batch_shape), inputs)
scale = Tensor(torch.rand(batch_shape), inputs)
with interpretation(lazy if with_lazy else eager):
funsor_dist = (dist.Normal if reparametrized else dist.NonreparameterizedNormal)(loc, scale)

_check_sample(funsor_dist, sample_inputs, inputs, num_samples=200000, atol=1e-2 if reparametrized else 1e-1)


@pytest.mark.parametrize("with_lazy", [True, xfail_param(False, reason="missing pattern")])
@pytest.mark.parametrize('sample_inputs', [(), ('ii',), ('ii', 'jj'), ('ii', 'jj', 'kk')])
@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
@pytest.mark.parametrize('event_shape', [(1,), (4,), (5,)], ids=str)
def test_mvn_sample(with_lazy, batch_shape, sample_inputs, event_shape):
batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

loc = Tensor(torch.randn(batch_shape + event_shape), inputs)
scale_tril = Tensor(_random_scale_tril(batch_shape + event_shape * 2), inputs)
with interpretation(lazy if with_lazy else eager):
funsor_dist = dist.MultivariateNormal(loc, scale_tril)

_check_sample(funsor_dist, sample_inputs, inputs, atol=5e-2, num_samples=200000)


@pytest.mark.parametrize('sample_inputs', [(), ('ii',), ('ii', 'jj'), ('ii', 'jj', 'kk')])
@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
@pytest.mark.parametrize('event_shape', [(1,), (4,), (5,)], ids=str)
@pytest.mark.parametrize('reparametrized', [True, False])
def test_dirichlet_sample(batch_shape, sample_inputs, event_shape, reparametrized):
batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

concentration = Tensor(torch.randn(batch_shape + event_shape).exp(), inputs)
funsor_dist = (dist.Dirichlet if reparametrized else dist.NonreparameterizedDirichlet)(concentration)

_check_sample(funsor_dist, sample_inputs, inputs, atol=1e-2 if reparametrized else 1e-1)


@pytest.mark.parametrize('sample_inputs', [(), ('ii',), ('ii', 'jj'), ('ii', 'jj', 'kk')])
@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
def test_bernoullilogits_sample(batch_shape, sample_inputs):
batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

logits = Tensor(torch.rand(batch_shape), inputs)
funsor_dist = dist.Bernoulli(logits=logits)

_check_sample(funsor_dist, sample_inputs, inputs, atol=5e-2, num_samples=100000)


@pytest.mark.parametrize('sample_inputs', [(), ('ii',), ('ii', 'jj'), ('ii', 'jj', 'kk')])
@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
def test_bernoulliprobs_sample(batch_shape, sample_inputs):
batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

probs = Tensor(torch.rand(batch_shape), inputs)
funsor_dist = dist.Bernoulli(probs=probs)

_check_sample(funsor_dist, sample_inputs, inputs, atol=5e-2, num_samples=100000)


@pytest.mark.parametrize("with_lazy", [True, xfail_param(False, reason="missing pattern")])
@pytest.mark.parametrize('sample_inputs', [(), ('ii',), ('ii', 'jj'), ('ii', 'jj', 'kk')])
@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
@pytest.mark.parametrize('reparametrized', [True, False])
def test_beta_sample(with_lazy, batch_shape, sample_inputs, reparametrized):
batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

concentration1 = Tensor(torch.randn(batch_shape).exp(), inputs)
concentration0 = Tensor(torch.randn(batch_shape).exp(), inputs)
with interpretation(lazy if with_lazy else eager):
funsor_dist = (dist.Beta if reparametrized else dist.NonreparameterizedBeta)(
concentration1, concentration0)

_check_sample(funsor_dist, sample_inputs, inputs, atol=1e-2 if reparametrized else 1e-1,
statistic="variance", num_samples=100000)


@pytest.mark.parametrize("with_lazy", [True, xfail_param(False, reason="missing pattern")])
@pytest.mark.parametrize('sample_inputs', [(), ('ii',), ('ii', 'jj'), ('ii', 'jj', 'kk')])
@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
def test_binomial_sample(with_lazy, batch_shape, sample_inputs):
batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

max_count = 10
total_count_data = random_tensor(inputs, bint(max_count)).data.float()
total_count = Tensor(total_count_data, inputs)
probs = Tensor(torch.rand(batch_shape), inputs)
with interpretation(lazy if with_lazy else eager):
funsor_dist = dist.Binomial(total_count, probs)

_check_sample(funsor_dist, sample_inputs, inputs, skip_grad=True)


@pytest.mark.parametrize('sample_inputs', [(), ('ii',), ('ii', 'jj'), ('ii', 'jj', 'kk')])
@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
def test_poisson_sample(batch_shape, sample_inputs):
batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

rate = Tensor(torch.rand(batch_shape), inputs)
funsor_dist = dist.Poisson(rate)

_check_sample(funsor_dist, sample_inputs, inputs, skip_grad=True)
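The pattern exercised by `_check_sample` can be summarized in a short standalone sketch, assuming the API used in this diff (`Funsor.sample`, `Integrate`, `Tensor`, `bint`); the `particle` input name and the scalar parameters are illustrative, not part of the PR.

```python
from collections import OrderedDict

import torch

import funsor.distributions as dist
import funsor.ops as ops
from funsor.domains import bint
from funsor.integrate import Integrate
from funsor.tensor import Tensor
from funsor.terms import Variable

# Draw a batch of samples from a funsor distribution and compare the
# Monte Carlo estimate of its mean against the analytic mean.
inputs = OrderedDict()                       # no batch dimensions for simplicity
loc = Tensor(torch.randn(()), inputs)
scale = Tensor(torch.rand(()) + 0.5, inputs)
funsor_dist = dist.Normal(loc, scale)

sample_inputs = OrderedDict(particle=bint(100000))
sample_value = funsor_dist.sample(frozenset(['value']), sample_inputs)

mc_mean = Integrate(
    sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value'])
).reduce(ops.add, frozenset(sample_inputs))

print(mc_mean)   # should be close to loc
```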