From 17f87b29e4120c26cc15feb85f259943362e8f73 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 14 Mar 2019 11:42:53 -0700 Subject: [PATCH 01/10] Sketch sampler implementations --- funsor/delta.py | 7 +++++-- funsor/gaussian.py | 25 +++++++++++++++++++++++++ funsor/minipyro.py | 1 - funsor/ops.py | 14 -------------- funsor/terms.py | 6 ++++++ funsor/torch.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 79 insertions(+), 17 deletions(-) diff --git a/funsor/delta.py b/funsor/delta.py index 5102d384f..e04527e16 100644 --- a/funsor/delta.py +++ b/funsor/delta.py @@ -8,7 +8,6 @@ from funsor.domains import reals from funsor.ops import Op from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, Variable, eager, to_funsor -from funsor.torch import Tensor class DeltaMeta(FunsorMeta): @@ -66,7 +65,7 @@ def eager_subs(self, subs): if value is not None: if isinstance(value, Variable): name = value.name - elif isinstance(value, (Number, Tensor)) and isinstance(point, (Number, Tensor)): + elif not any(d.dtype == 'real' for side in (value, point) for d in side.inputs.values()): return (value == point).all().log() + log_density else: # TODO Compute a jacobian, update log_prob, and emit another Delta. @@ -83,6 +82,10 @@ def eager_reduce(self, op, reduced_vars): return None # defer to default implementation + def sample(self, sampled_vars): + assert all(k == self.name for k in sampled_vars if k in self.inputs) + return self + @eager.register(Binary, Op, Delta, (Funsor, Delta, Align)) def eager_binary(op, lhs, rhs): diff --git a/funsor/gaussian.py b/funsor/gaussian.py index 26e5f713d..774baea80 100644 --- a/funsor/gaussian.py +++ b/funsor/gaussian.py @@ -6,8 +6,10 @@ import torch from pyro.distributions.util import broadcast_shape from six import add_metaclass, integer_types +from six.moves import reduce import funsor.ops as ops +from funsor.delta import Delta from funsor.domains import reals from funsor.ops import AddOp from funsor.terms import Binary, Funsor, FunsorMeta, Number, eager @@ -250,6 +252,29 @@ def eager_reduce(self, op, reduced_vars): return None # defer to default implementation + def sample(self, sampled_vars, sample_inputs=None): + sampled_vars = sampled_vars.intersection(self.inputs) + if not sampled_vars: + return self + assert all(self.inputs[k].dtype == 'real' for k in sampled_vars) + + int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real') + real_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype == 'real') + if sampled_vars == frozenset(real_inputs): + scale_tril = torch.inverse(torch.cholesky(self.precision)) + assert self.loc.shape == scale_tril.shape[:-1] + white_noise = torch.randn(self.loc.shape) + sample = self.loc + _mv(scale_tril, white_noise) + offsets, _ = _compute_offsets(real_inputs) + results = [] + for key, domain in real_inputs.items(): + data = sample[..., offsets[key]: offsets[key] + domain.num_elements] + point = Tensor(data, int_inputs, domain) + results.append(Delta(key, point)) + return reduce(ops.add, results) + + raise NotImplementedError('TODO implement partial sampling of real variables') + @eager.register(Binary, AddOp, Gaussian, Gaussian) def eager_add_gaussian_gaussian(op, lhs, rhs): diff --git a/funsor/minipyro.py b/funsor/minipyro.py index 5793e61ad..6d89afb63 100644 --- a/funsor/minipyro.py +++ b/funsor/minipyro.py @@ -389,7 +389,6 @@ def elbo(model, guide, *args, **kwargs): # FIXME do not marginalize; instead sample. 
q = guide_joint.log_prob.reduce(ops.logaddexp) tr = guide_joint.samples - tr.update(funsor.backward(ops.sample, q)) # force deferred samples? # replay model against guide with log_joint() as model_joint, replay(guide_trace=tr): diff --git a/funsor/ops.py b/funsor/ops.py index 421333649..d15ff0e88 100644 --- a/funsor/ops.py +++ b/funsor/ops.py @@ -123,18 +123,6 @@ def safediv(x, y): return truediv(x, y) -# just a placeholder -@Op -def marginal(x, y): - raise ValueError - - -# just a placeholder -@Op -def sample(x, y): - raise ValueError - - @Op def reciprocal(x): if isinstance(x, Number): @@ -176,7 +164,6 @@ def reciprocal(x): 'log', 'log1p', 'lt', - 'marginal', 'max', 'min', 'mul', @@ -186,7 +173,6 @@ def reciprocal(x): 'pow', 'safediv', 'safesub', - 'sample', 'sqrt', 'sub', 'truediv', diff --git a/funsor/terms.py b/funsor/terms.py index edab81213..c4e74f166 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -213,6 +213,12 @@ def reduce(self, op, reduced_vars=None): assert reduced_vars.issubset(self.inputs) return Reduce(op, self, reduced_vars) + def sample(self, sampled_vars, sample_inputs=None): + assert isinstance(sampled_vars, frozenset) + if sampled_vars.isdisjoint(self.inputs): + return self + raise NotImplementedError + def align(self, names): """ Align this funsor to match given ``names``. diff --git a/funsor/torch.py b/funsor/torch.py index c6dc86dfb..1aa46fdba 100644 --- a/funsor/torch.py +++ b/funsor/torch.py @@ -5,8 +5,10 @@ import torch from six import add_metaclass, integer_types +from six.moves import reduce import funsor.ops as ops +from funsor.delta import Delta from funsor.domains import Domain, bint, find_domain, reals from funsor.ops import Op from funsor.six import getargspec @@ -234,6 +236,47 @@ def eager_reduce(self, op, reduced_vars): return Tensor(data, inputs, self.dtype) return super(Tensor, self).eager_reduce(op, reduced_vars) + def sample(self, sampled_vars, sample_inputs=None): + assert self.output == reals() + sampled_vars = sampled_vars.intersection(self.inputs) + if not sampled_vars: + return self + + # Partition inputs into sample_inputs + batch_inputs + event_inputs. + if sample_inputs is None: + sample_inputs = OrderedDict() + assert frozenset(sample_inputs).isdisjoint(self.inputs) + sample_shape = tuple(int(d.dtype) for d in sample_inputs.values()) + batch_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if k not in sampled_vars) + event_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if k in sampled_vars) + be_inputs = batch_inputs.copy() + be_inputs.update(event_inputs) + sb_inputs = sample_inputs.copy() + sb_inputs.update(batch_inputs) + + # Sample all variables in a single Categorical call. + logits = align_tensor(be_inputs, self.data) + flat_logits = logits.reshape(logits.shape[:len(batch_inputs)] + (-1,)) + sample_shape = tuple(d.dtype for d in sample_inputs.values()) + flat_sample = torch.distributions.Categorical(logits=flat_logits).sample(sample_shape) + results = [] + for name, domain in reversed(list(event_inputs.items())): + size = domain.dtype + point = Tensor(flat_sample % size, sb_inputs, bint(size)) + flat_sample = flat_sample / size + results.append(Delta(name, point)) + + # Apply an optional dice factor to preserve differentiability. 
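+        # (The factor appended below, log_prob - log_prob.detach(), is zero in
+        # value but carries the score-function gradient of log_prob, so Monte
+        # Carlo expectations built from this sample remain differentiable with
+        # respect to the logits.)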
+ if flat_logits.requires_grad: + index = [torch.arange(n).reshape((n,) + (1,) * (flat_sample.dim() - i)) + for i, n in enumerate(flat_sample.shape)] + index.append(flat_sample) + log_prob = flat_logits[index] + assert log_prob.shape == flat_sample.shape + results.append(Tensor(log_prob - log_prob.detach(), sb_inputs)) + + return reduce(ops.add, results) + @eager.register(Binary, Op, Tensor, Number) def eager_binary_tensor_number(op, lhs, rhs): From 7caac56a1483fffbe3290aba6d654ff875dbb95d Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 14 Mar 2019 12:14:44 -0700 Subject: [PATCH 02/10] Implement Joint.sample() --- funsor/joint.py | 14 ++++++++++++++ funsor/terms.py | 33 +++++++++++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/funsor/joint.py b/funsor/joint.py index c3c6691cf..7f9f0525b 100644 --- a/funsor/joint.py +++ b/funsor/joint.py @@ -98,6 +98,20 @@ def eager_reduce(self, op, reduced_vars): return None # defer to default implementation + def sample(self, sampled_vars, sample_inputs=None): + discrete_vars = sampled_vars.intersection(self.discrete.inputs) + gaussian_vars = frozenset(k for k in sampled_vars + if k in self.gaussian.inputs + if self.gaussian.inputs[k].dtype == 'real') + result = self + if discrete_vars: + discrete = result.discrete.sample(discrete_vars, sample_inputs) + result = Joint(result.deltas, gaussian=result.gaussian) + discrete + if gaussian_vars: + gaussian = result.gaussian.sample(gaussian_vars, sample_inputs) + result = Joint(result.deltas, result.discrete) + gaussian + return result + @eager.register(Joint, tuple, Funsor, Funsor) def eager_joint(deltas, discrete, gaussian): diff --git a/funsor/terms.py b/funsor/terms.py index c4e74f166..2654e3a05 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -16,7 +16,7 @@ import itertools import numbers from abc import ABCMeta, abstractmethod -from collections import OrderedDict, Hashable +from collections import Hashable, OrderedDict from weakref import WeakValueDictionary from six import add_metaclass, integer_types @@ -24,7 +24,7 @@ import funsor.interpreter as interpreter import funsor.ops as ops -from funsor.domains import Domain, bint, find_domain +from funsor.domains import Domain, bint, find_domain, reals from funsor.interpreter import interpret from funsor.ops import AssociativeOp, Op from funsor.registry import KeyedRegistry @@ -214,6 +214,35 @@ def reduce(self, op, reduced_vars=None): return Reduce(op, self, reduced_vars) def sample(self, sampled_vars, sample_inputs=None): + """ + Create a Monte Carlo approximation to this funsor by replacing + functions of ``sampled_vars`` with :class:`~funsor.delta.Delta`s. 
+ + If ``sample_inputs`` is not provided, the result is a :class:`Funsor` + with the same ``.inputs`` and ``.output`` as the original funsor, so + that self can be replaced by the sample in expectation computations:: + + y = x.sample(sampled_vars) + assert y.inputs == x.inputs + assert y.output == x.output + exact = (x.exp() * integrand).reduce(ops.add) + approx = (y.exp() * integrand).reduce(ops.add) + + If ``sample_inputs`` is provided, this creates a batch of samples + that are intended to be averaged, however this reduction is not + performed by the :meth:`sample` method:: + + y = x.sample(sampled_vars, sample_inputs) + total = reduce(ops.mul, d.num_elements) for d in sample_inputs.values()) + exact = (x.exp() * integrand).reduce(ops.add) + approx = (y.exp() * integrand).reduce(ops.add) / total + + :param frozenset sampled_vars: A set of input variables to sample. + :param OrderedDict sample_inputs: An optional mapping from variable + name to :class:`~funsor.domains.Domain` over which samples will + be batched. + """ + assert self.output == reals() assert isinstance(sampled_vars, frozenset) if sampled_vars.isdisjoint(self.inputs): return self From 6cbd1c37d0e83248a205ae9ece110e56f8cca35c Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 14 Mar 2019 14:37:18 -0700 Subject: [PATCH 03/10] Add smoke test for Tensor.sample() --- funsor/testing.py | 12 ++++++++++-- funsor/torch.py | 4 ++-- test/test_gaussian.py | 8 +------- test/test_samplers.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 55 insertions(+), 11 deletions(-) create mode 100644 test/test_samplers.py diff --git a/funsor/testing.py b/funsor/testing.py index 2c53d4519..fbfb95baf 100644 --- a/funsor/testing.py +++ b/funsor/testing.py @@ -1,10 +1,10 @@ from __future__ import absolute_import, division, print_function -import contextlib import itertools import operator from collections import OrderedDict, namedtuple +import contextlib2 import numpy as np import opt_einsum import pytest @@ -19,7 +19,7 @@ from funsor.torch import Tensor -@contextlib.contextmanager +@contextlib2.contextmanager def xfail_if_not_implemented(msg="Not implemented"): try: yield @@ -35,6 +35,14 @@ def __repr__(self): return '\n'.join(['Expected:', str(self.expected), 'Actual:', str(self.actual)]) +def id_from_inputs(inputs): + if isinstance(inputs, (dict, OrderedDict)): + inputs = inputs.items() + if not inputs: + return '()' + return ','.join(k + ''.join(map(str, d.shape)) for k, d in inputs) + + def assert_close(actual, expected, atol=1e-6, rtol=1e-6): msg = ActualExpected(actual, expected) assert type(actual) == type(expected), msg diff --git a/funsor/torch.py b/funsor/torch.py index 1aa46fdba..3f2b6ec6b 100644 --- a/funsor/torch.py +++ b/funsor/torch.py @@ -255,14 +255,14 @@ def sample(self, sampled_vars, sample_inputs=None): sb_inputs.update(batch_inputs) # Sample all variables in a single Categorical call. 
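+        # (All sampled event dims are flattened into one axis so that a single
+        # Categorical draw covers them jointly; the flat sample is decoded back
+        # into per-variable points by the mod/div loop below.)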
- logits = align_tensor(be_inputs, self.data) + logits = align_tensor(be_inputs, self) flat_logits = logits.reshape(logits.shape[:len(batch_inputs)] + (-1,)) sample_shape = tuple(d.dtype for d in sample_inputs.values()) flat_sample = torch.distributions.Categorical(logits=flat_logits).sample(sample_shape) results = [] for name, domain in reversed(list(event_inputs.items())): size = domain.dtype - point = Tensor(flat_sample % size, sb_inputs, bint(size)) + point = Tensor(flat_sample % size, sb_inputs, size) flat_sample = flat_sample / size results.append(Delta(name, point)) diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 113c6f414..699709fac 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -11,16 +11,10 @@ from funsor.gaussian import Gaussian from funsor.joint import Joint from funsor.terms import Number -from funsor.testing import assert_close, random_gaussian, random_tensor, xfail_if_not_implemented +from funsor.testing import assert_close, id_from_inputs, random_gaussian, random_tensor, xfail_if_not_implemented from funsor.torch import Tensor -def id_from_inputs(inputs): - if not inputs: - return '()' - return ','.join(k + ''.join(map(str, d.shape)) for k, d in inputs.items()) - - @pytest.mark.parametrize('expr,expected_type', [ ('g1 + 1', Joint), ('g1 - 1', Joint), diff --git a/test/test_samplers.py b/test/test_samplers.py new file mode 100644 index 000000000..e33e3fa70 --- /dev/null +++ b/test/test_samplers.py @@ -0,0 +1,42 @@ +from __future__ import absolute_import, division, print_function + +import itertools +from collections import OrderedDict + +import pytest + +from funsor.domains import bint +from funsor.testing import id_from_inputs, random_tensor + + +@pytest.mark.parametrize('sample_inputs', [ + (), + (('s', bint(2)),), + (('s', bint(2)), ('t', bint(3))), +], ids=id_from_inputs) +@pytest.mark.parametrize('batch_inputs', [ + (), + (('b', bint(2)),), + (('b', bint(2)), ('c', bint(3))), +], ids=id_from_inputs) +@pytest.mark.parametrize('event_inputs', [ + (), + (('e', bint(2)),), + (('e', bint(2)), ('f', bint(3))), +], ids=id_from_inputs) +def test_tensor_smoke(sample_inputs, batch_inputs, event_inputs): + be_inputs = OrderedDict(batch_inputs + event_inputs) + expected_inputs = OrderedDict(sample_inputs + batch_inputs + event_inputs) + sample_inputs = OrderedDict(sample_inputs) + batch_inputs = OrderedDict(batch_inputs) + event_inputs = OrderedDict(event_inputs) + + x = random_tensor(be_inputs) + for num_sampled in range(len(event_inputs)): + for sampled_vars in itertools.combinations(list(event_inputs), num_sampled): + sampled_vars = frozenset(sampled_vars) + y = x.sample(sampled_vars, sample_inputs) + if sampled_vars: + assert dict(y.inputs) == dict(expected_inputs), sampled_vars + else: + assert y is x From 7c6e10a2ea56583d3792f1557a807eaec6a53d55 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 14 Mar 2019 15:58:58 -0700 Subject: [PATCH 04/10] Add smoke test for Gaussian.sample() --- funsor/gaussian.py | 4 +++- test/test_samplers.py | 55 ++++++++++++++++++++++++++++++++++++------- 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/funsor/gaussian.py b/funsor/gaussian.py index 774baea80..10a609922 100644 --- a/funsor/gaussian.py +++ b/funsor/gaussian.py @@ -269,7 +269,9 @@ def sample(self, sampled_vars, sample_inputs=None): results = [] for key, domain in real_inputs.items(): data = sample[..., offsets[key]: offsets[key] + domain.num_elements] - point = Tensor(data, int_inputs, domain) + data = 
data.reshape(self.loc.shape[:-1] + domain.shape) + point = Tensor(data, int_inputs) + assert point.output == domain results.append(Delta(key, point)) return reduce(ops.add, results) diff --git a/test/test_samplers.py b/test/test_samplers.py index e33e3fa70..b97abc4f3 100644 --- a/test/test_samplers.py +++ b/test/test_samplers.py @@ -5,22 +5,21 @@ import pytest -from funsor.domains import bint -from funsor.testing import id_from_inputs, random_tensor +from funsor.domains import bint, reals +from funsor.testing import id_from_inputs, random_gaussian, random_tensor @pytest.mark.parametrize('sample_inputs', [ (), - (('s', bint(2)),), - (('s', bint(2)), ('t', bint(3))), + (('s', bint(6)),), + (('s', bint(6)), ('t', bint(7))), ], ids=id_from_inputs) @pytest.mark.parametrize('batch_inputs', [ (), - (('b', bint(2)),), - (('b', bint(2)), ('c', bint(3))), + (('b', bint(4)),), + (('b', bint(4)), ('c', bint(5))), ], ids=id_from_inputs) @pytest.mark.parametrize('event_inputs', [ - (), (('e', bint(2)),), (('e', bint(2)), ('f', bint(3))), ], ids=id_from_inputs) @@ -30,8 +29,8 @@ def test_tensor_smoke(sample_inputs, batch_inputs, event_inputs): sample_inputs = OrderedDict(sample_inputs) batch_inputs = OrderedDict(batch_inputs) event_inputs = OrderedDict(event_inputs) - x = random_tensor(be_inputs) + for num_sampled in range(len(event_inputs)): for sampled_vars in itertools.combinations(list(event_inputs), num_sampled): sampled_vars = frozenset(sampled_vars) @@ -40,3 +39,43 @@ def test_tensor_smoke(sample_inputs, batch_inputs, event_inputs): assert dict(y.inputs) == dict(expected_inputs), sampled_vars else: assert y is x + + +@pytest.mark.parametrize('sample_inputs', [ + (), + (('s', bint(3)),), + (('s', bint(3)), ('t', bint(4))), +], ids=id_from_inputs) +@pytest.mark.parametrize('batch_inputs', [ + (), + (('b', bint(2)),), + (('c', reals()),), + (('b', bint(2)), ('c', reals())), +], ids=id_from_inputs) +@pytest.mark.parametrize('event_inputs', [ + (('e', reals()),), + (('e', reals()), ('f', reals(2))), +], ids=id_from_inputs) +def test_gaussian_smoke(sample_inputs, batch_inputs, event_inputs): + be_inputs = OrderedDict(batch_inputs + event_inputs) + expected_inputs = OrderedDict(sample_inputs + batch_inputs + event_inputs) + sample_inputs = OrderedDict(sample_inputs) + batch_inputs = OrderedDict(batch_inputs) + event_inputs = OrderedDict(event_inputs) + x = random_gaussian(be_inputs) + + xfail = False + for num_sampled in range(len(event_inputs)): + for sampled_vars in itertools.combinations(list(event_inputs), num_sampled): + sampled_vars = frozenset(sampled_vars) + try: + y = x.sample(sampled_vars, sample_inputs) + except NotImplementedError: + xfail = True + continue + if sampled_vars: + assert dict(y.inputs) == dict(expected_inputs), sampled_vars + else: + assert y is x + if xfail: + pytest.xfail(reason='Not implemented') From 950e6e3f6609e5dfa3da0032e3a3c2d48b9f4815 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 14 Mar 2019 16:45:24 -0700 Subject: [PATCH 05/10] Add smoke test for Joint.sample() --- funsor/gaussian.py | 15 +++++++++++--- funsor/joint.py | 7 ++++++- funsor/testing.py | 4 ++-- test/test_samplers.py | 48 ++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 67 insertions(+), 7 deletions(-) diff --git a/funsor/gaussian.py b/funsor/gaussian.py index 10a609922..7d7eef371 100644 --- a/funsor/gaussian.py +++ b/funsor/gaussian.py @@ -258,19 +258,28 @@ def sample(self, sampled_vars, sample_inputs=None): return self assert all(self.inputs[k].dtype == 'real' for k in 
sampled_vars) + # Partition inputs into sample_inputs + int_inputs + real_inputs. + if sample_inputs is None: + sample_inputs = OrderedDict() + assert frozenset(sample_inputs).isdisjoint(self.inputs) + sample_shape = tuple(int(d.dtype) for d in sample_inputs.values()) int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real') real_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype == 'real') + inputs = sample_inputs.copy() + inputs.update(int_inputs) + if sampled_vars == frozenset(real_inputs): scale_tril = torch.inverse(torch.cholesky(self.precision)) assert self.loc.shape == scale_tril.shape[:-1] - white_noise = torch.randn(self.loc.shape) + shape = sample_shape + self.loc.shape + white_noise = torch.randn(shape) sample = self.loc + _mv(scale_tril, white_noise) offsets, _ = _compute_offsets(real_inputs) results = [] for key, domain in real_inputs.items(): data = sample[..., offsets[key]: offsets[key] + domain.num_elements] - data = data.reshape(self.loc.shape[:-1] + domain.shape) - point = Tensor(data, int_inputs) + data = data.reshape(shape[:-1] + domain.shape) + point = Tensor(data, inputs) assert point.output == domain results.append(Delta(key, point)) return reduce(ops.add, results) diff --git a/funsor/joint.py b/funsor/joint.py index 7f9f0525b..549fb924a 100644 --- a/funsor/joint.py +++ b/funsor/joint.py @@ -99,6 +99,9 @@ def eager_reduce(self, op, reduced_vars): return None # defer to default implementation def sample(self, sampled_vars, sample_inputs=None): + if sample_inputs is None: + sample_inputs = OrderedDict() + assert frozenset(sample_inputs).isdisjoint(self.inputs) discrete_vars = sampled_vars.intersection(self.discrete.inputs) gaussian_vars = frozenset(k for k in sampled_vars if k in self.gaussian.inputs @@ -108,6 +111,8 @@ def sample(self, sampled_vars, sample_inputs=None): discrete = result.discrete.sample(discrete_vars, sample_inputs) result = Joint(result.deltas, gaussian=result.gaussian) + discrete if gaussian_vars: + sample_inputs = OrderedDict((k, v) for k, v in sample_inputs.items() + if k not in result.gaussian.inputs) gaussian = result.gaussian.sample(gaussian_vars, sample_inputs) result = Joint(result.deltas, result.discrete) + gaussian return result @@ -168,7 +173,7 @@ def eager_add(op, joint, other): @eager.register(Binary, AddOp, Joint, Gaussian) def eager_add(op, joint, other): # Update with a delayed gaussian random variable. 
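+    # (other.inputs is keyed by variable name, so membership must be tested
+    # against d.name rather than against the Delta funsor d itself.)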
- subs = tuple((d.name, d.point) for d in joint.deltas if d in other.inputs) + subs = tuple((d.name, d.point) for d in joint.deltas if d.name in other.inputs) if subs: other = other.eager_subs(subs) if joint.gaussian is not Number(0): diff --git a/funsor/testing.py b/funsor/testing.py index fbfb95baf..1880323e1 100644 --- a/funsor/testing.py +++ b/funsor/testing.py @@ -1,10 +1,10 @@ from __future__ import absolute_import, division, print_function +import contextlib import itertools import operator from collections import OrderedDict, namedtuple -import contextlib2 import numpy as np import opt_einsum import pytest @@ -19,7 +19,7 @@ from funsor.torch import Tensor -@contextlib2.contextmanager +@contextlib.contextmanager def xfail_if_not_implemented(msg="Not implemented"): try: yield diff --git a/test/test_samplers.py b/test/test_samplers.py index b97abc4f3..1a43a627e 100644 --- a/test/test_samplers.py +++ b/test/test_samplers.py @@ -6,6 +6,7 @@ import pytest from funsor.domains import bint, reals +from funsor.joint import Joint from funsor.testing import id_from_inputs, random_gaussian, random_tensor @@ -31,9 +32,10 @@ def test_tensor_smoke(sample_inputs, batch_inputs, event_inputs): event_inputs = OrderedDict(event_inputs) x = random_tensor(be_inputs) - for num_sampled in range(len(event_inputs)): + for num_sampled in range(len(event_inputs) + 1): for sampled_vars in itertools.combinations(list(event_inputs), num_sampled): sampled_vars = frozenset(sampled_vars) + print('sampled_vars: {}'.format(', '.join(sampled_vars))) y = x.sample(sampled_vars, sample_inputs) if sampled_vars: assert dict(y.inputs) == dict(expected_inputs), sampled_vars @@ -64,10 +66,54 @@ def test_gaussian_smoke(sample_inputs, batch_inputs, event_inputs): event_inputs = OrderedDict(event_inputs) x = random_gaussian(be_inputs) + xfail = False + for num_sampled in range(len(event_inputs) + 1): + for sampled_vars in itertools.combinations(list(event_inputs), num_sampled): + sampled_vars = frozenset(sampled_vars) + print('sampled_vars: {}'.format(', '.join(sampled_vars))) + try: + y = x.sample(sampled_vars, sample_inputs) + except NotImplementedError: + xfail = True + continue + if sampled_vars: + assert dict(y.inputs) == dict(expected_inputs), sampled_vars + else: + assert y is x + if xfail: + pytest.xfail(reason='Not implemented') + + +@pytest.mark.parametrize('sample_inputs', [ + (), + (('s', bint(6)),), + (('s', bint(6)), ('t', bint(7))), +], ids=id_from_inputs) +@pytest.mark.parametrize('int_event_inputs', [ + (), + (('d', bint(2)),), + (('d', bint(2)), ('e', bint(3))), +], ids=id_from_inputs) +@pytest.mark.parametrize('real_event_inputs', [ + (('g', reals()),), + (('g', reals()), ('h', reals(4))), +], ids=id_from_inputs) +def test_joint_smoke(sample_inputs, int_event_inputs, real_event_inputs): + event_inputs = int_event_inputs + real_event_inputs + discrete_inputs = OrderedDict(int_event_inputs) + gaussian_inputs = OrderedDict(event_inputs) + expected_inputs = OrderedDict(sample_inputs + event_inputs) + sample_inputs = OrderedDict(sample_inputs) + event_inputs = OrderedDict(event_inputs) + t = random_tensor(discrete_inputs) + g = random_gaussian(gaussian_inputs) + x = Joint(discrete=t, gaussian=g) + xfail = False for num_sampled in range(len(event_inputs)): for sampled_vars in itertools.combinations(list(event_inputs), num_sampled): sampled_vars = frozenset(sampled_vars) + print('sampled_vars: {}'.format(', '.join(sampled_vars))) try: y = x.sample(sampled_vars, sample_inputs) except NotImplementedError: From 
7a71abb24ec7cd73e0d22a9776164e8f58c251d3 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 14 Mar 2019 18:41:55 -0700 Subject: [PATCH 06/10] Fix tensor distribution test --- funsor/domains.py | 6 +++- funsor/torch.py | 30 +++++++++++++++--- test/test_samplers.py | 72 ++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 98 insertions(+), 10 deletions(-) diff --git a/funsor/domains.py b/funsor/domains.py index 0da97b1b0..e25c3cfb7 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -72,7 +72,11 @@ def find_domain(op, *domains): assert callable(op), op assert all(isinstance(arg, Domain) for arg in domains) if len(domains) == 1: - return domains[0] + dtype = domains[0].dtype + shape = domains[0].shape + if op is ops.log or op is ops.exp: + dtype = 'real' + return Domain(shape, dtype) lhs, rhs = domains if op is ops.getitem: diff --git a/funsor/torch.py b/funsor/torch.py index 3f2b6ec6b..d19cfc41c 100644 --- a/funsor/torch.py +++ b/funsor/torch.py @@ -197,14 +197,15 @@ def eager_subs(self, subs): return Tensor(data, inputs, self.dtype) def eager_unary(self, op): + dtype = find_domain(op, self.output).dtype if op in REDUCE_OP_TO_TORCH: batch_dim = len(self.data.shape) - len(self.output.shape) data = self.data.reshape(self.data.shape[:batch_dim] + (-1,)) data = REDUCE_OP_TO_TORCH[op](data, -1) if op is ops.min or op is ops.max: data = data[0] - return Tensor(data, self.inputs, self.dtype) - return Tensor(op(self.data), self.inputs, self.dtype) + return Tensor(data, self.inputs, dtype) + return Tensor(op(self.data), self.inputs, dtype) def eager_reduce(self, op, reduced_vars): if op in REDUCE_OP_TO_TORCH: @@ -256,9 +257,11 @@ def sample(self, sampled_vars, sample_inputs=None): # Sample all variables in a single Categorical call. logits = align_tensor(be_inputs, self) - flat_logits = logits.reshape(logits.shape[:len(batch_inputs)] + (-1,)) + batch_shape = logits.shape[:len(batch_inputs)] + flat_logits = logits.reshape(batch_shape + (-1,)) sample_shape = tuple(d.dtype for d in sample_inputs.values()) flat_sample = torch.distributions.Categorical(logits=flat_logits).sample(sample_shape) + assert flat_sample.shape == sample_shape + batch_shape results = [] for name, domain in reversed(list(event_inputs.items())): size = domain.dtype @@ -266,14 +269,31 @@ def sample(self, sampled_vars, sample_inputs=None): flat_sample = flat_sample / size results.append(Delta(name, point)) - # Apply an optional dice factor to preserve differentiability. + # Account for the log normalizer factor. + # Derivation: Let f be a nonnormalized distribution (a funsor), and + # consider operations in linear space (source code is in log space). + # Let x0 ~ f/|f| be a monte carlo sample from a normalized f. + # f(x0) / |f| # dice numerator + # Let g = delta(x=x0) |f| ----------------- + # detach(f(x0)/|f|) # dice denominator + # f(x0) |detach(f)| + # = delta(x=x0) ----------------- be a dice approximation of f. + # detach(f(x0)) + # Then g should be an unbiased estimator of f in value and all + # derivatives, including the normalier |f|. + # In case f = detach(f), we can simplify to + # g = delta(x=x0) |detach(f)|. if flat_logits.requires_grad: + # Apply a dice factor to preserve differentiability. 
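+            # (The index list gathers flat_logits at each sampled class index,
+            # elementwise over all leading sample/batch dimensions.)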
index = [torch.arange(n).reshape((n,) + (1,) * (flat_sample.dim() - i)) for i, n in enumerate(flat_sample.shape)] index.append(flat_sample) log_prob = flat_logits[index] assert log_prob.shape == flat_sample.shape - results.append(Tensor(log_prob - log_prob.detach(), sb_inputs)) + results.append(Tensor(flat_logits.detach().logsumexp(-1) + + (log_prob - log_prob.detach()), sb_inputs)) + else: + results.append(Tensor(flat_logits.logsumexp(-1), batch_inputs)) return reduce(ops.add, results) diff --git a/test/test_samplers.py b/test/test_samplers.py index 1a43a627e..20e6f65d4 100644 --- a/test/test_samplers.py +++ b/test/test_samplers.py @@ -5,9 +5,12 @@ import pytest +import funsor.ops as ops from funsor.domains import bint, reals from funsor.joint import Joint -from funsor.testing import id_from_inputs, random_gaussian, random_tensor +from funsor.terms import Variable +from funsor.testing import assert_close, id_from_inputs, random_gaussian, random_tensor +from funsor.torch import materialize @pytest.mark.parametrize('sample_inputs', [ @@ -24,7 +27,7 @@ (('e', bint(2)),), (('e', bint(2)), ('f', bint(3))), ], ids=id_from_inputs) -def test_tensor_smoke(sample_inputs, batch_inputs, event_inputs): +def test_tensor_shape(sample_inputs, batch_inputs, event_inputs): be_inputs = OrderedDict(batch_inputs + event_inputs) expected_inputs = OrderedDict(sample_inputs + batch_inputs + event_inputs) sample_inputs = OrderedDict(sample_inputs) @@ -58,7 +61,7 @@ def test_tensor_smoke(sample_inputs, batch_inputs, event_inputs): (('e', reals()),), (('e', reals()), ('f', reals(2))), ], ids=id_from_inputs) -def test_gaussian_smoke(sample_inputs, batch_inputs, event_inputs): +def test_gaussian_shape(sample_inputs, batch_inputs, event_inputs): be_inputs = OrderedDict(batch_inputs + event_inputs) expected_inputs = OrderedDict(sample_inputs + batch_inputs + event_inputs) sample_inputs = OrderedDict(sample_inputs) @@ -98,7 +101,7 @@ def test_gaussian_smoke(sample_inputs, batch_inputs, event_inputs): (('g', reals()),), (('g', reals()), ('h', reals(4))), ], ids=id_from_inputs) -def test_joint_smoke(sample_inputs, int_event_inputs, real_event_inputs): +def test_joint_shape(sample_inputs, int_event_inputs, real_event_inputs): event_inputs = int_event_inputs + real_event_inputs discrete_inputs = OrderedDict(int_event_inputs) gaussian_inputs = OrderedDict(event_inputs) @@ -125,3 +128,64 @@ def test_joint_smoke(sample_inputs, int_event_inputs, real_event_inputs): assert y is x if xfail: pytest.xfail(reason='Not implemented') + + +@pytest.mark.parametrize('batch_inputs', [ + (), + (('b', bint(4)),), + (('b', bint(4)), ('c', bint(2))), +], ids=id_from_inputs) +@pytest.mark.parametrize('event_inputs', [ + (('e', bint(2)),), + (('e', bint(2)), ('f', bint(3))), +], ids=id_from_inputs) +def test_tensor_distribution(event_inputs, batch_inputs): + num_samples = 50000 + sample_inputs = OrderedDict(n=bint(num_samples)) + be_inputs = OrderedDict(batch_inputs + event_inputs) + batch_inputs = OrderedDict(batch_inputs) + event_inputs = OrderedDict(event_inputs) + sampled_vars = frozenset(event_inputs) + p = random_tensor(be_inputs) + + q = p.sample(sampled_vars, sample_inputs) - ops.log(num_samples) + mq = materialize(q).reduce(ops.logaddexp, 'n') + mq = mq.align(tuple(p.inputs)) + assert_close(mq, p, atol=0.1, rtol=None) + + +@pytest.mark.skip(reason='infinite loop') +@pytest.mark.parametrize('batch_inputs', [ + (), + (('b', bint(4)),), + (('b', bint(4)), ('c', bint(5))), +], ids=id_from_inputs) +@pytest.mark.parametrize('event_inputs', [ 
+ (('e', reals()),), + (('e', reals()), ('f', reals(2))), +], ids=id_from_inputs) +def test_gaussian_distribution(event_inputs, batch_inputs): + num_samples = 10000 + sample_inputs = OrderedDict(n=bint(num_samples)) + be_inputs = OrderedDict(batch_inputs + event_inputs) + batch_inputs = OrderedDict(batch_inputs) + event_inputs = OrderedDict(event_inputs) + sampled_vars = frozenset(event_inputs) + p = random_gaussian(be_inputs) + + q = p.sample(sampled_vars, sample_inputs) - ops.log(num_samples) + p_vars = sampled_vars + q_vars = sampled_vars | frozenset(['n']) + # Check zeroth moment. + assert_close(q.reduce(ops.logaddexp, q_vars), + p.reduce(ops.logaddexp, p_vars), atol=1e-6, rtol=None) + for k1, d1 in event_inputs.item(): + x = Variable(k1, d1) + # Check first moments. + assert_close((q.exp() * x).reduce(ops.add, q_vars), + (p.exp() * x).reduce(ops.add, p_vars), atol=1e-2, rtol=None) + for k2, d2 in event_inputs.item(): + y = Variable(k2, d2) + # Check second moments. + assert_close((q.exp() * (x * y)).reduce(ops.add, q_vars), + (p.exp() * (x * y)).reduce(ops.add, p_vars), atol=1e-2, rtol=None) From 0e0a24e18c2eedbe5a115c5c84389b9bbe936177 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 14 Mar 2019 19:08:26 -0700 Subject: [PATCH 07/10] Add test for gradient of Tensor.sample() --- funsor/torch.py | 21 +++++++++++---------- test/test_samplers.py | 18 +++++++++++++++--- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/funsor/torch.py b/funsor/torch.py index d19cfc41c..7e715030b 100644 --- a/funsor/torch.py +++ b/funsor/torch.py @@ -263,36 +263,37 @@ def sample(self, sampled_vars, sample_inputs=None): flat_sample = torch.distributions.Categorical(logits=flat_logits).sample(sample_shape) assert flat_sample.shape == sample_shape + batch_shape results = [] + mod_sample = flat_sample for name, domain in reversed(list(event_inputs.items())): size = domain.dtype - point = Tensor(flat_sample % size, sb_inputs, size) - flat_sample = flat_sample / size + point = Tensor(mod_sample % size, sb_inputs, size) + mod_sample = mod_sample / size results.append(Delta(name, point)) # Account for the log normalizer factor. # Derivation: Let f be a nonnormalized distribution (a funsor), and # consider operations in linear space (source code is in log space). - # Let x0 ~ f/|f| be a monte carlo sample from a normalized f. + # Let x0 ~ f/|f| be a monte carlo sample from a normalized f/|f|. # f(x0) / |f| # dice numerator # Let g = delta(x=x0) |f| ----------------- # detach(f(x0)/|f|) # dice denominator - # f(x0) |detach(f)| + # |detach(f)| f(x0) # = delta(x=x0) ----------------- be a dice approximation of f. # detach(f(x0)) - # Then g should be an unbiased estimator of f in value and all - # derivatives, including the normalier |f|. - # In case f = detach(f), we can simplify to - # g = delta(x=x0) |detach(f)|. + # Then g is an unbiased estimator of f in value and all derivatives. + # In the special case f = detach(f), we can simplify to + # g = delta(x=x0) |f|. if flat_logits.requires_grad: # Apply a dice factor to preserve differentiability. 
- index = [torch.arange(n).reshape((n,) + (1,) * (flat_sample.dim() - i)) - for i, n in enumerate(flat_sample.shape)] + index = [torch.arange(n).reshape((n,) + (1,) * (flat_logits.dim() - i - 2)) + for i, n in enumerate(flat_logits.shape[:-1])] index.append(flat_sample) log_prob = flat_logits[index] assert log_prob.shape == flat_sample.shape results.append(Tensor(flat_logits.detach().logsumexp(-1) + (log_prob - log_prob.detach()), sb_inputs)) else: + # This is the special case f = detach(f). results.append(Tensor(flat_logits.logsumexp(-1), batch_inputs)) return reduce(ops.add, results) diff --git a/test/test_samplers.py b/test/test_samplers.py index 20e6f65d4..f71d29fb2 100644 --- a/test/test_samplers.py +++ b/test/test_samplers.py @@ -4,13 +4,15 @@ from collections import OrderedDict import pytest +import torch +from torch.autograd import grad import funsor.ops as ops from funsor.domains import bint, reals from funsor.joint import Joint from funsor.terms import Variable from funsor.testing import assert_close, id_from_inputs, random_gaussian, random_tensor -from funsor.torch import materialize +from funsor.torch import align_tensors, materialize @pytest.mark.parametrize('sample_inputs', [ @@ -133,13 +135,14 @@ def test_joint_shape(sample_inputs, int_event_inputs, real_event_inputs): @pytest.mark.parametrize('batch_inputs', [ (), (('b', bint(4)),), - (('b', bint(4)), ('c', bint(2))), + (('b', bint(3)), ('c', bint(2))), ], ids=id_from_inputs) @pytest.mark.parametrize('event_inputs', [ (('e', bint(2)),), (('e', bint(2)), ('f', bint(3))), ], ids=id_from_inputs) -def test_tensor_distribution(event_inputs, batch_inputs): +@pytest.mark.parametrize('test_grad', [False, True], ids=['value', 'grad']) +def test_tensor_distribution(event_inputs, batch_inputs, test_grad): num_samples = 50000 sample_inputs = OrderedDict(n=bint(num_samples)) be_inputs = OrderedDict(batch_inputs + event_inputs) @@ -147,12 +150,21 @@ def test_tensor_distribution(event_inputs, batch_inputs): event_inputs = OrderedDict(event_inputs) sampled_vars = frozenset(event_inputs) p = random_tensor(be_inputs) + p.data.requires_grad_(test_grad) q = p.sample(sampled_vars, sample_inputs) - ops.log(num_samples) mq = materialize(q).reduce(ops.logaddexp, 'n') mq = mq.align(tuple(p.inputs)) assert_close(mq, p, atol=0.1, rtol=None) + if test_grad: + _, (p_data, mq_data) = align_tensors(p, mq) + assert p_data.shape == mq_data.shape + probe = torch.randn(p_data.shape) + expected = grad((p_data.exp() * probe).sum(), [p.data])[0] + actual = grad((mq_data.exp() * probe).sum(), [p.data])[0] + assert_close(actual, expected, atol=0.1, rtol=None) + @pytest.mark.skip(reason='infinite loop') @pytest.mark.parametrize('batch_inputs', [ From 98202e6c4e1481e169a6a8749a275e8cc7996410 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 15 Mar 2019 11:51:58 -0700 Subject: [PATCH 08/10] Account for log_normalizer in Gaussian.sample() --- funsor/gaussian.py | 24 +++++++++++++++++++----- funsor/joint.py | 41 +++++++++++++++++++++++++++++------------ funsor/testing.py | 11 +++++++++-- 3 files changed, 57 insertions(+), 19 deletions(-) diff --git a/funsor/gaussian.py b/funsor/gaussian.py index 7d7eef371..b6add5840 100644 --- a/funsor/gaussian.py +++ b/funsor/gaussian.py @@ -14,6 +14,7 @@ from funsor.ops import AddOp from funsor.terms import Binary, Funsor, FunsorMeta, Number, eager from funsor.torch import Tensor, align_tensor, align_tensors, materialize +from funsor.util import lazy_property def _issubshape(subshape, supershape): @@ -125,6 +126,13 @@ class 
Gaussian(Funsor): """ Funsor representing a batched joint Gaussian distribution as a log-density function. + + Note that :class:`Gaussian`s are not normalized, rather they are + canonicalized to evaluate to zero at their maximum value (at ``loc``). This + canonical form is useful because it allows :class:`Gaussian`s with + incomplete information, i.e. zero eigenvalues in the precision matrix. + These incomplete distributions arise when making low-dimensional + observations on higher dimensional hidden state. """ def __init__(self, loc, precision, inputs): assert isinstance(loc, torch.Tensor) @@ -208,6 +216,14 @@ def eager_subs(self, subs): raise NotImplementedError('TODO implement partial substitution of real variables') + @lazy_property + def _log_normalizer(self): + dim = self.loc.size(-1) + log_det_term = _log_det_tril(torch.cholesky(self.precision)) + data = log_det_term - 0.5 * math.log(2 * math.pi) * dim + inputs = OrderedDict((k, v) for k, v in self.inputs.items() if v.dtype != 'real') + return Tensor(data, inputs) + def eager_reduce(self, op, reduced_vars): if op is ops.logaddexp: # Marginalize out real variables, but keep mixtures lazy. @@ -219,13 +235,10 @@ def eager_reduce(self, op, reduced_vars): return None # defer to default implementation inputs = OrderedDict((k, d) for k, d in self.inputs.items() if k not in reduced_reals) - int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real') if reduced_reals == real_vars: - dim = self.loc.size(-1) - log_det_term = _log_det_tril(torch.cholesky(self.precision)) - data = log_det_term - 0.5 * math.log(2 * math.pi) * dim - result = Tensor(data, int_inputs) + result = self._log_normalizer else: + int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real') offsets, _ = _compute_offsets(self.inputs) index = [] for key, domain in inputs.items(): @@ -282,6 +295,7 @@ def sample(self, sampled_vars, sample_inputs=None): point = Tensor(data, inputs) assert point.output == domain results.append(Delta(key, point)) + results.append(self._log_normalizer) return reduce(ops.add, results) raise NotImplementedError('TODO implement partial sampling of real variables') diff --git a/funsor/joint.py b/funsor/joint.py index 549fb924a..a32a805e3 100644 --- a/funsor/joint.py +++ b/funsor/joint.py @@ -75,23 +75,40 @@ def eager_subs(self, subs): def eager_reduce(self, op, reduced_vars): if op is ops.logaddexp: - # Integrate out delayed discrete variables. - discrete_vars = reduced_vars.intersection(self.discrete.inputs) - mixture_params = frozenset(self.gaussian.inputs).union(*(x.point.inputs for x in self.deltas)) - lazy_vars = discrete_vars & mixture_params # Mixtures must remain lazy. - discrete_vars -= mixture_params + # Integrate out degenerate variables, i.e. drop selected delta. + deltas = [] + remaining_vars = set(reduced_vars) + for d in self.deltas: + if d.name in reduced_vars: + remaining_vars.remove(d.name) + else: + deltas.append(d) + deltas = tuple(deltas) + reduced_vars = frozenset(remaining_vars) + + # Integrate out delayed discrete variables, but keep mixtures lazy. + lazy_vars = reduced_vars.difference(self.gaussian.inputs, *(x.inputs for x in deltas)) + discrete_vars = reduced_vars.intersection(self.discrete.inputs).difference(lazy_vars) discrete = self.discrete.reduce(op, discrete_vars) + reduced_vars = reduced_vars.difference(discrete_vars, lazy_vars) # Integrate out delayed gaussian variables. 
gaussian_vars = reduced_vars.intersection(self.gaussian.inputs) gaussian = self.gaussian.reduce(ops.logaddexp, gaussian_vars) - assert (reduced_vars - gaussian_vars).issubset(d.name for d in self.deltas) - - # Integrate out delayed degenerate variables, i.e. drop them. - deltas = tuple(d for d in self.deltas if d.name not in reduced_vars) - - assert not lazy_vars - return (Joint(deltas, discrete) + gaussian).reduce(ops.logaddexp, lazy_vars) + reduced_vars = reduced_vars.difference(gaussian_vars) + + # Account for remaining reduced vars that were inputs to dropped deltas. + eager_result = Joint(deltas, discrete) + gaussian + reduced_vars |= lazy_vars.difference(eager_result.inputs) + lazy_vars = lazy_vars.intersection(eager_result.inputs) + if reduced_vars: + eager_result += ops.log(reduce(ops.mul, [self.inputs[v].dtype for v in reduced_vars])) + + # Return a value only if progress has been made. + if eager_result is self: + return None # defer to default implementation + else: + return eager_result.reduce(ops.logaddexp, lazy_vars) if op is ops.add: raise NotImplementedError('TODO product-reduce along a plate dimension') diff --git a/funsor/testing.py b/funsor/testing.py index 1880323e1..3c05d28db 100644 --- a/funsor/testing.py +++ b/funsor/testing.py @@ -2,6 +2,7 @@ import contextlib import itertools +import numbers import operator from collections import OrderedDict, namedtuple @@ -15,7 +16,7 @@ from funsor.gaussian import Gaussian from funsor.joint import Joint from funsor.numpy import Array -from funsor.terms import Funsor +from funsor.terms import Funsor, Number from funsor.torch import Tensor @@ -52,7 +53,7 @@ def assert_close(actual, expected, atol=1e-6, rtol=1e-6): assert actual.inputs == expected.inputs, (actual.inputs, expected.inputs) assert actual.output == expected.output, (actual.output, expected.output) - if isinstance(actual, Tensor): + if isinstance(actual, (Number, Tensor)): assert_close(actual.data, expected.data, atol=atol, rtol=rtol) elif isinstance(actual, Gaussian): assert_close(actual.loc, expected.loc, atol=atol, rtol=rtol) @@ -81,6 +82,12 @@ def assert_close(actual, expected, atol=1e-6, rtol=1e-6): assert diff.max() < atol, msg if rtol is not None: assert (diff / (atol + expected.detach().abs())).max() < rtol, msg + elif isinstance(actual, numbers.Number): + diff = abs(actual - expected) + if atol is not None: + assert diff < atol, msg + if rtol is not None: + assert diff < (atol + expected) * rtol, msg else: raise ValueError('cannot compare objects of type {}'.format(type(actual))) From 859cb1c0b5301ff58bd1414dd2a5d0c69b65dbda Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 15 Mar 2019 12:26:06 -0700 Subject: [PATCH 09/10] Fix bugs in Gaussian.sample() --- funsor/joint.py | 9 +++++++++ test/test_samplers.py | 5 +++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/funsor/joint.py b/funsor/joint.py index a32a805e3..e62a02ea9 100644 --- a/funsor/joint.py +++ b/funsor/joint.py @@ -3,6 +3,7 @@ from collections import OrderedDict from six import add_metaclass +from six.moves import reduce import funsor.ops as ops from funsor.delta import Delta @@ -187,6 +188,14 @@ def eager_add(op, joint, other): return Joint(joint.deltas, joint.discrete + other, joint.gaussian) +@eager.register(Binary, Op, Joint, (Number, Tensor)) +def eager_add(op, joint, other): + if op is ops.sub: + return joint + -other + + return None # defer to default implementation + + @eager.register(Binary, AddOp, Joint, Gaussian) def eager_add(op, joint, other): # Update with a 
delayed gaussian random variable. diff --git a/test/test_samplers.py b/test/test_samplers.py index f71d29fb2..5aadec8b1 100644 --- a/test/test_samplers.py +++ b/test/test_samplers.py @@ -166,7 +166,6 @@ def test_tensor_distribution(event_inputs, batch_inputs, test_grad): assert_close(actual, expected, atol=0.1, rtol=None) -@pytest.mark.skip(reason='infinite loop') @pytest.mark.parametrize('batch_inputs', [ (), (('b', bint(4)),), @@ -191,7 +190,9 @@ def test_gaussian_distribution(event_inputs, batch_inputs): # Check zeroth moment. assert_close(q.reduce(ops.logaddexp, q_vars), p.reduce(ops.logaddexp, p_vars), atol=1e-6, rtol=None) - for k1, d1 in event_inputs.item(): + + pytest.xfail(reason='infinite loop') + for k1, d1 in event_inputs.items(): x = Variable(k1, d1) # Check first moments. assert_close((q.exp() * x).reduce(ops.add, q_vars), From e0cdc166530e7f3d4ecee437db589093ca46ccf3 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Fri, 15 Mar 2019 14:18:19 -0700 Subject: [PATCH 10/10] Add stub to use Integrate in test_samplers.py --- test/test_samplers.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/test/test_samplers.py b/test/test_samplers.py index 5aadec8b1..9764ca676 100644 --- a/test/test_samplers.py +++ b/test/test_samplers.py @@ -135,11 +135,11 @@ def test_joint_shape(sample_inputs, int_event_inputs, real_event_inputs): @pytest.mark.parametrize('batch_inputs', [ (), (('b', bint(4)),), - (('b', bint(3)), ('c', bint(2))), + (('b', bint(2)), ('c', bint(2))), ], ids=id_from_inputs) @pytest.mark.parametrize('event_inputs', [ - (('e', bint(2)),), - (('e', bint(2)), ('f', bint(3))), + (('e', bint(3)),), + (('e', bint(2)), ('f', bint(2))), ], ids=id_from_inputs) @pytest.mark.parametrize('test_grad', [False, True], ids=['value', 'grad']) def test_tensor_distribution(event_inputs, batch_inputs, test_grad): @@ -166,6 +166,11 @@ def test_tensor_distribution(event_inputs, batch_inputs, test_grad): assert_close(actual, expected, atol=0.1, rtol=None) +# This is a stub for a future PR. +def Integrate(log_measure, integrand, reduced_vars): + pytest.xfail(reason='Integrate is not implemented') + + @pytest.mark.parametrize('batch_inputs', [ (), (('b', bint(4)),), @@ -190,15 +195,13 @@ def test_gaussian_distribution(event_inputs, batch_inputs): # Check zeroth moment. assert_close(q.reduce(ops.logaddexp, q_vars), p.reduce(ops.logaddexp, p_vars), atol=1e-6, rtol=None) - - pytest.xfail(reason='infinite loop') for k1, d1 in event_inputs.items(): x = Variable(k1, d1) # Check first moments. - assert_close((q.exp() * x).reduce(ops.add, q_vars), - (p.exp() * x).reduce(ops.add, p_vars), atol=1e-2, rtol=None) + assert_close(Integrate(q, x, q_vars), + Integrate(p, x, p_vars), atol=1e-2, rtol=None) for k2, d2 in event_inputs.item(): y = Variable(k2, d2) # Check second moments. - assert_close((q.exp() * (x * y)).reduce(ops.add, q_vars), - (p.exp() * (x * y)).reduce(ops.add, p_vars), atol=1e-2, rtol=None) + assert_close(Integrate(q, x * y, q_vars), + Integrate(p, x * y, p_vars), atol=1e-2, rtol=None)
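
Usage sketch: the workflow exercised by test_tensor_distribution above amounts
to roughly the following (a minimal sketch; names and imports are assumed to be
those used in test/test_samplers.py as of these patches):

    from collections import OrderedDict

    import funsor.ops as ops
    from funsor.domains import bint
    from funsor.testing import random_tensor
    from funsor.torch import materialize

    # A nonnormalized discrete log-density over event variable 'e' with batch 'b'.
    p = random_tensor(OrderedDict(b=bint(4), e=bint(3)))

    # Draw 1000 indexed samples of 'e'; the result is a sum of Deltas plus a
    # normalizer/dice Tensor, whose inputs are p's inputs plus the sample dim 'n'.
    num_samples = 1000
    q = p.sample(frozenset(['e']), OrderedDict(n=bint(num_samples)))
    q = q - ops.log(num_samples)

    # Averaging out the sample dimension approximately recovers p.
    approx = materialize(q).reduce(ops.logaddexp, 'n')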