Implement plate reductions for Gaussian, Joint (#108)
fritzo authored Mar 29, 2019
1 parent 37ce962 commit b9cdbac
Showing 4 changed files with 77 additions and 12 deletions.
38 changes: 32 additions & 6 deletions funsor/gaussian.py
@@ -1,10 +1,9 @@
from __future__ import absolute_import, division, print_function

import math
import sys
import warnings
from collections import OrderedDict

import six
import torch
from pyro.distributions.util import broadcast_shape
from six import add_metaclass, integer_types
@@ -72,11 +71,15 @@ def _sym_solve_mv(mat, vec):
tri = torch.inverse(torch.cholesky(mat))
return _mv(tri.transpose(-1, -2), _mv(tri, vec))
except RuntimeError as e:
if 'not positive definite' not in str(e):
_, _, traceback = sys.exc_info()
six.reraise(RuntimeError, e, traceback)
warnings.warn(str(e), RuntimeWarning)

# Fall back to pseudoinverse.
if mat.size(-1) == 1:
mat = mat.squeeze(-1)
mat, vec = torch.broadcast_tensors(mat, vec)
result = vec / mat
result[mat == 0] = 0  # define x/0 := 0 where the 1x1 precision is singular
return result
return _mv(torch.pinverse(mat), vec)
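
# A minimal usage sketch, assuming a well-conditioned symmetric `mat`
# (the pseudoinverse branch above covers the singular case):
#
#     mat = torch.eye(3) * 2.
#     vec = torch.randn(3)
#     x = _sym_solve_mv(mat, vec)
#     assert torch.allclose(_mv(mat, x), vec, atol=1e-5)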


@@ -270,6 +273,9 @@ def eager_subs(self, subs):
assert result.output == reals()
return Subs(result, lazy_subs)

# Perform a partial substitution of a subset of real variables, resulting in a Joint.
# See "The Matrix Cookbook" (November 15, 2012) ss. 8.1.3 eq. 353.
# http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
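# A hedged sketch of that rule: with precision blocks [[P11, P12], [P21, P22]]
# over reals (x1, x2), substituting x2 = b should yield a Gaussian over x1
# with precision P11 and location loc1 - P11^{-1} @ P12 @ (b - loc2), plus a
# Tensor term for the marginal likelihood of b (hence the Joint result).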
raise NotImplementedError('TODO implement partial substitution of real variables')

@lazy_property
@@ -317,7 +323,27 @@ def eager_reduce(self, op, reduced_vars):
return result.reduce(ops.logaddexp, reduced_ints)

elif op is ops.add:
raise NotImplementedError('TODO product-reduce along a plate dimension')
for v in reduced_vars:
if self.inputs[v].dtype == 'real':
raise ValueError("Cannot sum along a real dimension: {}".format(repr(v)))

# Fuse Gaussians along a plate. Compare to eager_add_gaussian_gaussian().
old_ints = OrderedDict((k, v) for k, v in self.inputs.items() if v.dtype != 'real')
new_ints = OrderedDict((k, v) for k, v in old_ints.items() if k not in reduced_vars)
inputs = OrderedDict((k, v) for k, v in self.inputs.items() if k not in reduced_vars)

precision = Tensor(self.precision, old_ints).reduce(ops.add, reduced_vars)
precision_loc = Tensor(_mv(self.precision, self.loc),
old_ints).reduce(ops.add, reduced_vars)
assert precision.inputs == new_ints
assert precision_loc.inputs == new_ints
loc = Tensor(_sym_solve_mv(precision.data, precision_loc.data), new_ints)
expanded_loc = align_tensor(old_ints, loc)
quadratic_term = Tensor(_vmv(self.precision, expanded_loc - self.loc),
old_ints).reduce(ops.add, reduced_vars)
assert quadratic_term.inputs == new_ints
likelihood = -0.5 * quadratic_term
return likelihood + Gaussian(loc.data, precision.data, inputs)

return None # defer to default implementation

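As a numeric sanity sketch of the fusion identity implemented above (plain PyTorch; names here are illustrative, not part of the commit): summing per-plate Gaussian log-density terms matches a single fused Gaussian plus a constant quadratic correction, exactly the `likelihood` term returned by `eager_reduce`.

import torch

P = torch.stack([torch.eye(2) * (i + 1.) for i in range(3)])  # per-plate precisions P_i
locs = torch.randn(3, 2)                                      # per-plate locations loc_i
x = torch.randn(2)

# Left side: sum_i -1/2 (x - loc_i)' P_i (x - loc_i).
lhs = sum(-0.5 * (x - locs[i]) @ P[i] @ (x - locs[i]) for i in range(3))

# Right side: fused precision sum_i P_i, fused loc solving
# (sum_i P_i) loc = sum_i P_i loc_i, plus the constant correction term.
P_sum = P.sum(0)
loc = torch.pinverse(P_sum) @ torch.einsum('iab,ib->a', P, locs)
correction = sum(-0.5 * (loc - locs[i]) @ P[i] @ (loc - locs[i]) for i in range(3))
rhs = -0.5 * (x - loc) @ P_sum @ (x - loc) + correction

assert torch.allclose(lhs, rhs, atol=1e-4)
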
5 changes: 4 additions & 1 deletion funsor/joint.py
@@ -121,7 +121,10 @@ def eager_reduce(self, op, reduced_vars):
return eager_result.reduce(ops.logaddexp, lazy_vars)

if op is ops.add:
raise NotImplementedError('TODO product-reduce along a plate dimension')
terms = list(self.deltas) + [self.discrete, self.gaussian]
for i, term in enumerate(terms):
terms[i] = term.reduce(ops.add, reduced_vars.intersection(term.inputs))
return reduce(ops.add, terms)

return None # defer to default implementation

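A small sketch of the set arithmetic used above (plain Python, assuming `reduced_vars` behaves like a frozenset of input names): intersecting with each term's `inputs` keeps only the reduced variables that term actually mentions, so terms untouched by the plate are reduced over nothing and pass through unchanged.

reduced_vars = frozenset(['i'])
delta_inputs = {'a': 'reals()'}                      # a term that ignores the plate
gaussian_inputs = {'i': 'bint(3)', 'x': 'reals()'}   # a term indexed by the plate
assert reduced_vars.intersection(delta_inputs) == frozenset()
assert reduced_vars.intersection(gaussian_inputs) == frozenset(['i'])
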
18 changes: 17 additions & 1 deletion test/test_gaussian.py
@@ -222,6 +222,22 @@ def test_add_gaussian_gaussian(lhs_inputs, rhs_inputs):
assert_close((g1 + g2)(**values), g1(**values) + g2(**values), atol=1e-4, rtol=None)


@pytest.mark.parametrize('inputs', [
OrderedDict([('i', bint(2)), ('x', reals())]),
OrderedDict([('i', bint(3)), ('x', reals())]),
OrderedDict([('i', bint(2)), ('x', reals(2))]),
OrderedDict([('i', bint(2)), ('x', reals()), ('y', reals())]),
OrderedDict([('i', bint(3)), ('j', bint(4)), ('x', reals(2))]),
], ids=id_from_inputs)
def test_reduce_add(inputs):
g = random_gaussian(inputs)
actual = g.reduce(ops.add, 'i')

gs = [g(i=i) for i in range(g.inputs['i'].dtype)]
expected = reduce(ops.add, gs)
assert_close(actual, expected)


@pytest.mark.parametrize('int_inputs', [
{},
{'i': bint(2)},
@@ -233,7 +249,7 @@ def test_add_gaussian_gaussian(lhs_inputs, rhs_inputs):
{'x': reals(4), 'y': reals(2, 3), 'z': reals()},
{'w': reals(5), 'x': reals(4), 'y': reals(2, 3), 'z': reals()},
], ids=id_from_inputs)
def test_logsumexp(int_inputs, real_inputs):
def test_reduce_logsumexp(int_inputs, real_inputs):
int_inputs = OrderedDict(sorted(int_inputs.items()))
real_inputs = OrderedDict(sorted(real_inputs.items()))
inputs = int_inputs.copy()
28 changes: 24 additions & 4 deletions test/test_joint.py
@@ -4,6 +4,7 @@

import pytest
import torch
from six.moves import reduce

import funsor.ops as ops
from funsor.delta import Delta
@@ -132,7 +133,7 @@ def test_smoke(expr, expected_type):
{'x': reals(2), 'y': reals(3)},
{'x': reals(4), 'y': reals(2, 3), 'z': reals()},
], ids=id_from_inputs)
def test_reduce(int_inputs, real_inputs):
def test_reduce_logaddexp(int_inputs, real_inputs):
int_inputs = OrderedDict(sorted(int_inputs.items()))
real_inputs = OrderedDict(sorted(real_inputs.items()))
inputs = int_inputs.copy()
@@ -154,7 +155,7 @@ def test_reduce(int_inputs, real_inputs):
assert_close(actual, expected)


def test_reduce_deltas_lazy():
def test_reduce_logaddexp_deltas_lazy():
a = Delta('a', Tensor(torch.randn(3, 2), OrderedDict(i=bint(3))))
b = Delta('b', Tensor(torch.randn(3), OrderedDict(i=bint(3))))
x = a + b
@@ -167,7 +168,7 @@ def test_reduce_deltas_lazy():
assert_close(x.reduce(ops.logaddexp), y.reduce(ops.logaddexp))


def test_reduce_deltas_discrete_lazy():
def test_reduce_logaddexp_deltas_discrete_lazy():
a = Delta('a', Tensor(torch.randn(3, 2), OrderedDict(i=bint(3))))
b = Delta('b', Tensor(torch.randn(3), OrderedDict(i=bint(3))))
c = Tensor(torch.randn(3), OrderedDict(i=bint(3)))
@@ -181,7 +182,7 @@ def test_reduce_deltas_discrete_lazy():
assert_close(x.reduce(ops.logaddexp), y.reduce(ops.logaddexp))


def test_reduce_gaussian_lazy():
def test_reduce_logaddexp_gaussian_lazy():
a = random_gaussian(OrderedDict(i=bint(3), a=reals(2)))
b = random_tensor(OrderedDict(i=bint(3), b=bint(2)))
x = a + b
@@ -192,3 +193,22 @@ def test_reduce_gaussian_lazy():
assert isinstance(y, Reduce)
assert set(y.inputs) == {'a', 'b'}
assert_close(x.reduce(ops.logaddexp), y.reduce(ops.logaddexp))


@pytest.mark.parametrize('inputs', [
OrderedDict([('i', bint(2)), ('x', reals())]),
OrderedDict([('i', bint(3)), ('x', reals())]),
OrderedDict([('i', bint(2)), ('x', reals(2))]),
OrderedDict([('i', bint(2)), ('x', reals()), ('y', reals())]),
OrderedDict([('i', bint(3)), ('j', bint(4)), ('x', reals(2))]),
OrderedDict([('j', bint(2)), ('i', bint(3)), ('k', bint(2)), ('x', reals(2))]),
], ids=id_from_inputs)
def test_reduce_add(inputs):
int_inputs = OrderedDict((k, d) for k, d in inputs.items() if d.dtype != 'real')
x = random_gaussian(inputs) + random_tensor(int_inputs)
assert isinstance(x, Joint)
actual = x.reduce(ops.add, 'i')

xs = [x(i=i) for i in range(x.inputs['i'].dtype)]
expected = reduce(ops.add, xs)
assert_close(actual, expected, atol=1e-3, rtol=1e-4)
