Add plated einsum implementation (#46)
eb8680 authored and fritzo committed Mar 6, 2019
1 parent 448c1fa commit 4537e90
Showing 8 changed files with 231 additions and 53 deletions.
5 changes: 3 additions & 2 deletions funsor/__init__.py
@@ -3,9 +3,9 @@
from funsor.domains import Domain, bint, find_domain, reals
from funsor.interpreter import reinterpret
from funsor.terms import Funsor, Number, Variable, of_shape, to_funsor
from funsor.torch import Function, Tensor, arange, einsum, function
from funsor.torch import Function, Tensor, arange, torch_einsum, function

from . import distributions, domains, gaussian, handlers, interpreter, minipyro, ops, terms, torch
from . import distributions, domains, einsum, gaussian, handlers, interpreter, minipyro, ops, terms, torch

__all__ = [
    'Domain',
@@ -33,4 +33,5 @@
    'terms',
    'to_funsor',
    'torch',
    'torch_einsum',
]
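
The rename keeps the thin tensor-level wrapper available while freeing the name einsum for the new module; a sketch of the resulting import paths (editorial, assuming the exports shown above):

import funsor

funsor.torch_einsum   # wrapper around torch.einsum, formerly funsor.einsum
funsor.einsum.einsum  # the new optimized plated einsum entry point added in this commit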
133 changes: 133 additions & 0 deletions funsor/einsum.py
@@ -0,0 +1,133 @@
from __future__ import absolute_import, division, print_function

from collections import defaultdict, OrderedDict
from six.moves import reduce

import funsor.ops as ops
from funsor.interpreter import interpretation, reinterpret
from funsor.optimizer import apply_optimizer
from funsor.terms import Funsor, reflect


def naive_einsum(eqn, *terms, **kwargs):
    backend = kwargs.pop('backend', 'torch')
    if backend == 'torch':
        sum_op, prod_op = ops.add, ops.mul
    elif backend == 'pyro.ops.einsum.torch_log':
        sum_op, prod_op = ops.logaddexp, ops.add
    else:
        raise ValueError("{} backend not implemented".format(backend))

    assert isinstance(eqn, str)
    assert all(isinstance(term, Funsor) for term in terms)
    inputs, output = eqn.split('->')
    assert len(output.split(',')) == 1
    input_dims = frozenset(d for inp in inputs.split(',') for d in inp)
    output_dims = frozenset(output)
    reduce_dims = input_dims - output_dims
    return reduce(prod_op, terms).reduce(sum_op, reduce_dims)
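
As an aside, a minimal usage sketch of naive_einsum on funsor Tensors, assuming the Tensor(data, inputs) constructor and the bint domains exported by funsor:

from collections import OrderedDict

import torch

from funsor import Tensor, bint
from funsor.einsum import naive_einsum

x = Tensor(torch.randn(2, 3), OrderedDict([("a", bint(2)), ("b", bint(3))]))
y = Tensor(torch.randn(3, 4), OrderedDict([("b", bint(3)), ("c", bint(4))]))
z = naive_einsum("ab,bc->ac", x, y)  # multiply the factors, then sum out "b"
assert set(z.inputs) == {"a", "c"}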


def _partition(terms, sum_vars):
    # Construct a bipartite graph between terms and the vars
    neighbors = OrderedDict([(t, []) for t in terms])
    for term in terms:
        for dim in term.inputs.keys():
            if dim in sum_vars:
                neighbors[term].append(dim)
                neighbors.setdefault(dim, []).append(term)

    # Partition the bipartite graph into connected components for contraction.
    components = []
    while neighbors:
        v, pending = neighbors.popitem()
        component = OrderedDict([(v, None)])  # used as an OrderedSet
        for v in pending:
            component[v] = None
        while pending:
            v = pending.pop()
            for v in neighbors.pop(v):
                if v not in component:
                    component[v] = None
                    pending.append(v)

        # Split this connected component into tensors and dims.
        component_terms = tuple(v for v in component if isinstance(v, Funsor))
        if component_terms:
            component_dims = frozenset(v for v in component if not isinstance(v, Funsor))
            components.append((component_terms, component_dims))
    return components
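
To illustrate the partitioning, a short sketch under the same assumed Tensor constructor as above: terms that share a summed variable land in one connected component, and unrelated terms are contracted separately.

from collections import OrderedDict

import torch

from funsor import Tensor, bint
from funsor.einsum import _partition

f = Tensor(torch.randn(2, 2), OrderedDict([("a", bint(2)), ("i", bint(2))]))
g = Tensor(torch.randn(2, 2), OrderedDict([("i", bint(2)), ("j", bint(2))]))
h = Tensor(torch.randn(2), OrderedDict([("k", bint(2))]))
components = _partition((f, g, h), frozenset("ijk"))
# f and g are linked through "i", so they form one component; h stands alone.
assert sorted(len(ts) for ts, _ in components) == [1, 2]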


def naive_plated_einsum(eqn, *terms, **kwargs):
    """
    Implements Tensor Variable Elimination (Algorithm 1 in [Obermeyer et al 2019])

    [Obermeyer et al 2019] Obermeyer, F., Bingham, E., Jankowiak, M., Chiu, J.,
        Pradhan, N., Rush, A., and Goodman, N. Tensor Variable Elimination for
        Plated Factor Graphs, 2019
    """
    plates = kwargs.pop('plates', '')
    if not plates:
        return naive_einsum(eqn, *terms, **kwargs)

    backend = kwargs.pop('backend', 'torch')
    if backend == 'torch':
        sum_op, prod_op = ops.add, ops.mul
    elif backend == 'pyro.ops.einsum.torch_log':
        sum_op, prod_op = ops.logaddexp, ops.add
    else:
        raise ValueError("{} backend not implemented".format(backend))

    assert isinstance(eqn, str)
    assert all(isinstance(term, Funsor) for term in terms)
    inputs, output = eqn.split('->')
    assert len(output.split(',')) == 1
    input_dims = frozenset(d for inp in inputs.split(',') for d in inp)
    output_dims = frozenset(d for d in output)
    plate_dims = frozenset(plates) - output_dims
    reduce_vars = input_dims - output_dims - frozenset(plates)

    if output_dims:
        raise NotImplementedError("TODO")

    var_tree = {}
    term_tree = defaultdict(list)
    for term in terms:
        ordinal = frozenset(term.inputs) & plate_dims
        term_tree[ordinal].append(term)
        for var in term.inputs:
            if var not in plate_dims:
                var_tree[var] = var_tree.get(var, ordinal) & ordinal

    ordinal_to_var = defaultdict(set)
    for var, ordinal in var_tree.items():
        ordinal_to_var[ordinal].add(var)

    # Direct translation of Algorithm 1
    scalars = []
    while term_tree:
        leaf = max(term_tree, key=len)
        leaf_terms = term_tree.pop(leaf)
        leaf_reduce_vars = ordinal_to_var[leaf]
        for (group_terms, group_vars) in _partition(leaf_terms, leaf_reduce_vars):
            term = reduce(prod_op, group_terms).reduce(sum_op, group_vars)
            remaining_vars = frozenset(term.inputs) & reduce_vars
            if not remaining_vars:
                scalars.append(term.reduce(prod_op, leaf))
            else:
                new_plates = frozenset().union(
                    *(var_tree[v] for v in remaining_vars))
                if new_plates == leaf:
                    raise ValueError("intractable!")
                term = term.reduce(prod_op, leaf - new_plates)
                term_tree[new_plates].append(term)

    return reduce(prod_op, scalars)
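
For the default 'torch' backend, the loop above realizes the usual tensor-variable-elimination semantics: plate dims are eliminated with the product op and the remaining dims with the sum op. A small sketch of what that means, again assuming the Tensor(data, inputs) constructor:

from collections import OrderedDict

import torch

from funsor import Tensor, bint
from funsor.einsum import naive_plated_einsum

f = Tensor(torch.randn(2), OrderedDict([("x", bint(2))]))                     # global factor f(x)
g = Tensor(torch.randn(2, 3), OrderedDict([("x", bint(2)), ("i", bint(3))]))  # plated factor g(x, i)
z = naive_plated_einsum("x,xi->", f, g, plates="i")
# This should equal sum_x f(x) * prod_i g(x, i).
assert torch.allclose(z.data, (f.data * g.data.prod(-1)).sum())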


def einsum(eqn, *terms, **kwargs):
    with interpretation(reflect):
        naive_ast = naive_plated_einsum(eqn, *terms, **kwargs)
        optimized_ast = apply_optimizer(naive_ast)
    return reinterpret(optimized_ast)  # eager by default
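
This einsum wrapper is the optimized entry point exercised by test_optimized_plated_einsum below: it builds the same expression lazily under reflect, lets apply_optimizer reorder the contractions, and only then evaluates eagerly. With f and g as in the previous sketch, a call with the log backend would look like:

from funsor.einsum import einsum

z_log = einsum("x,xi->", f, g, plates="i", backend="pyro.ops.einsum.torch_log")
# should agree, up to numerical precision, with naive_plated_einsum on the same
# inputs, here treating f and g as log-space factors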
20 changes: 19 additions & 1 deletion funsor/ops.py
@@ -83,6 +83,16 @@ def sample(x, y):
    raise ValueError


def reciprocal(x):
    if isinstance(x, Number):
        return 1. / x
    if isinstance(x, torch.Tensor):
        result = x.reciprocal()
        result.clamp_(max=torch.finfo(result.dtype).max)
        return result
    raise ValueError("No reciprocal for type {}".format(type(x)))


REDUCE_OP_TO_TORCH = {
    add: torch.sum,
    mul: torch.prod,
@@ -115,10 +125,17 @@ def sample(x, y):
])


PRODUCT_INVERSES = {
    mul: reciprocal,
    add: neg,
}
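
PRODUCT_INVERSES pairs each product op with the unary op that divides a factor back out: reciprocal for ops.mul, and neg for ops.add, which plays the role of the product in log space. A hypothetical use:

import torch

from funsor.ops import PRODUCT_INVERSES, mul

undo = PRODUCT_INVERSES[mul]  # reciprocal
remaining = torch.tensor(6.0) * undo(torch.tensor(2.0))
assert torch.allclose(remaining, torch.tensor(3.0))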


__all__ = [
    'REDUCE_OP_TO_TORCH',
    'ASSOCIATIVE_OPS',
    'DISTRIBUTIVE_OPS',
    'PRODUCT_INVERSES',
    'REDUCE_OP_TO_TORCH',
    'abs',
    'add',
    'and_',
@@ -137,6 +154,7 @@ def sample(x, y):
    'neg',
    'or_',
    'pow',
    'reciprocal',
    'sample',
    'sub',
    'truediv',
27 changes: 1 addition & 26 deletions funsor/testing.py
@@ -9,10 +9,9 @@
import torch
from six.moves import reduce

import funsor.ops as ops
from funsor.domains import Domain, bint
from funsor.gaussian import Gaussian
from funsor.terms import Binary, Funsor
from funsor.terms import Funsor
from funsor.torch import Tensor


@@ -140,27 +139,3 @@ def random_gaussian(inputs):
    prec_sqrt = torch.randn(batch_shape + event_shape + event_shape)
    precision = torch.matmul(prec_sqrt, prec_sqrt.transpose(-1, -2))
    return Gaussian(log_density, loc, precision, inputs)


def naive_einsum(eqn, *terms, **kwargs):
    backend = kwargs.pop('backend', 'torch')
    if backend == 'torch':
        sum_op, prod_op = ops.add, ops.mul
    elif backend == 'pyro.ops.einsum.torch_log':
        sum_op, prod_op = ops.logaddexp, ops.add
    else:
        raise ValueError("{} backend not implemented".format(backend))

    assert isinstance(eqn, str)
    assert all(isinstance(term, Funsor) for term in terms)
    inputs, output = eqn.split('->')
    assert len(output.split(',')) == 1
    input_dims = frozenset(d for inp in inputs.split(',') for d in inp)
    output_dims = frozenset(d for d in output)
    reduce_dims = tuple(d for d in input_dims - output_dims)
    prod = terms[0]
    for term in terms[1:]:
        prod = Binary(prod_op, prod, term)
    for reduce_dim in reduce_dims:
        prod = prod.reduce(sum_op, reduce_dim)
    return prod
4 changes: 2 additions & 2 deletions funsor/torch.py
@@ -383,7 +383,7 @@ def mvn_log_prob(loc, scale_tril, x):
    return functools.partial(_function, inputs, output)


def einsum(equation, *operands):
def torch_einsum(equation, *operands):
"""
Wrapper around :func:`torch.einsum` to operate on real-valued Funsors.
Expand Down Expand Up @@ -412,7 +412,7 @@ def einsum(equation, *operands):
    'align_tensor',
    'align_tensors',
    'arange',
    'einsum',
    'torch_einsum',
    'function',
    'materialize',
]
41 changes: 22 additions & 19 deletions test/test_einsum.py
@@ -15,15 +15,9 @@
from funsor.torch import Tensor
from funsor.interpreter import interpretation, reinterpret
from funsor.optimizer import apply_optimizer
from funsor.testing import assert_close, make_einsum_example

from funsor.testing import assert_close, make_einsum_example, naive_einsum


def naive_plated_einsum(eqn, *terms, **kwargs):
    assert isinstance(eqn, str)
    assert all(isinstance(term, funsor.Funsor) for term in terms)
    # ...
    raise NotImplementedError("TODO implement naive plated einsum")
from funsor.einsum import naive_einsum, naive_plated_einsum


EINSUM_EXAMPLES = [
@@ -103,25 +97,34 @@ def test_einsum_categorical(equation):
            assert actual.inputs[output_dim].dtype == sizes[output_dim]


PLATED_EINSUM_EXAMPLES = [(ex, '') for ex in EINSUM_EXAMPLES] + [
PLATED_EINSUM_EXAMPLES = [
    ('i->', 'i'),
    ('i->i', 'i'),
    (',i->', 'i'),
    (',i->i', 'i'),
    ('ai->', 'i'),
    ('ai->i', 'i'),
    ('ai->ai', 'i'),
    (',ai,abij->aij', 'ij'),
    ('a,ai,bij->bij', 'ij'),
    (',ai,abij->', 'ij'),
    ('a,ai,bij->', 'ij'),
    ('ai,abi,bci,cdi->', 'i'),
    ('aij,abij,bcij,cdij->', 'ij'),
    ('a,abi,bcij,cdij->', 'ij'),
]


@pytest.mark.xfail(reason="naive plated einsum not implemented")
@pytest.mark.parametrize('equation,plates', PLATED_EINSUM_EXAMPLES)
def test_plated_einsum(equation, plates):
@pytest.mark.parametrize('backend', ['torch', 'pyro.ops.einsum.torch_log'])
def test_plated_einsum(equation, plates, backend):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    expected = naive_ubersum(equation, *operands, plates=plates, backend='torch', modulo_total=False)[0]
    actual = naive_plated_einsum(equation, *funsor_operands, plates=plates)
    expected = naive_ubersum(equation, *operands, plates=plates, backend=backend, modulo_total=False)[0]
    with interpretation(reflect):
        naive_ast = naive_plated_einsum(equation, *funsor_operands, plates=plates, backend=backend)
        optimized_ast = apply_optimizer(naive_ast)
    actual_optimized = reinterpret(optimized_ast)  # eager by default
    actual = naive_plated_einsum(equation, *funsor_operands, plates=plates, backend=backend)

    assert_close(actual, actual_optimized, atol=1e-4)

    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))

    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
50 changes: 49 additions & 1 deletion test/test_optimizer.py
@@ -5,6 +5,7 @@

import opt_einsum
import torch
import pyro.ops.contract as pyro_einsum

import funsor

@@ -15,7 +16,8 @@
from funsor.terms import reflect, Variable
from funsor.torch import Tensor

from funsor.testing import make_einsum_example, naive_einsum
from funsor.testing import make_einsum_example, assert_close
from funsor.einsum import naive_einsum, naive_plated_einsum, einsum


def make_chain_einsum(num_steps):
@@ -101,3 +103,49 @@ def test_nested_einsum(eqn1, eqn2, optimize1, optimize2, backend1, backend2):

    assert torch.allclose(expected1, actual1.data)
    assert torch.allclose(expected2, actual2.data)


def make_plated_hmm_einsum(num_steps, num_obs_plates=1, num_hidden_plates=0):

    assert num_obs_plates >= num_hidden_plates
    t0 = num_obs_plates

    obs_plates = ''.join(opt_einsum.get_symbol(i) for i in range(num_obs_plates))
    hidden_plates = ''.join(opt_einsum.get_symbol(i) for i in range(num_hidden_plates))

    inputs = [str(opt_einsum.get_symbol(t0))]
    for t in range(t0, num_steps+t0):
        inputs.append(str(opt_einsum.get_symbol(t)) + str(opt_einsum.get_symbol(t+1)) + hidden_plates)
        inputs.append(str(opt_einsum.get_symbol(t+1)) + obs_plates)
    equation = ",".join(inputs) + "->"
    return (equation, ''.join(set(obs_plates + hidden_plates)))
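
For a sense of the equations this generates: each step contributes a transition factor over consecutive hidden symbols plus an observation factor carrying the plate symbols, so for two steps with one observation plate the call below should return the pair shown in the comment.

eqn, plates = make_plated_hmm_einsum(2, num_obs_plates=1, num_hidden_plates=0)
# expected: eqn == "b,bc,ca,cd,da->" and plates == "a"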


PLATED_EINSUM_EXAMPLES = [
    make_plated_hmm_einsum(num_steps, num_obs_plates=b, num_hidden_plates=a)
    for num_steps in range(2, 6)
    for (a, b) in [(0, 1), (0, 2), (0, 0), (1, 1), (1, 2), (1, 2)]
]


@pytest.mark.parametrize('equation,plates', PLATED_EINSUM_EXAMPLES)
@pytest.mark.parametrize('backend', ['pyro.ops.einsum.torch_log'])
def test_optimized_plated_einsum(equation, plates, backend):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    expected = pyro_einsum.einsum(equation, *operands, plates=plates, backend=backend)[0]
    actual = einsum(equation, *funsor_operands, plates=plates, backend=backend)

    if len(equation) < 10:
        actual_naive = naive_plated_einsum(equation, *funsor_operands, plates=plates, backend=backend)
        assert_close(actual, actual_naive)

    assert isinstance(actual, funsor.Tensor) and len(outputs) == 1
    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))

    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]