From 99d81ff58b4447ada3868a038d292d3b4a34e2ec Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 21 Mar 2019 11:35:00 -0700 Subject: [PATCH 1/5] Refactor contract dependencies --- funsor/contract.py | 76 +++++++++++++++++----------------------------- funsor/torch.py | 13 ++++++++ 2 files changed, 41 insertions(+), 48 deletions(-) diff --git a/funsor/contract.py b/funsor/contract.py index 383e2100c..ac6d9a9bc 100644 --- a/funsor/contract.py +++ b/funsor/contract.py @@ -1,19 +1,12 @@ from __future__ import absolute_import, division, print_function +import functools from collections import OrderedDict -import opt_einsum - import funsor.ops as ops -from funsor.distributions import Gaussian, Delta from funsor.optimizer import Finitary, optimize from funsor.sum_product import _partition -from funsor.terms import Funsor, Number, Variable, eager -from funsor.torch import Tensor - - -# TODO handle Joint as well -ATOMS = (Tensor, Gaussian, Delta, Number, Variable) +from funsor.terms import Funsor, eager def _order_lhss(lhs, reduced_vars): @@ -29,10 +22,13 @@ def _order_lhss(lhs, reduced_vars): return root_lhs, remaining_lhs -def _simplify_contract(lhs, rhs, reduced_vars): +def _simplify_contract(fn, lhs, rhs, reduced_vars): """ Reduce free variables that do not appear explicitly in the lhs """ + if not reduced_vars: + return lhs * rhs + lhs_vars = frozenset(lhs.inputs) rhs_vars = frozenset(rhs.inputs) assert reduced_vars <= lhs_vars | rhs_vars @@ -45,11 +41,17 @@ def _simplify_contract(lhs, rhs, reduced_vars): lhs = lhs.reduce(ops.add, reduced_vars - rhs_vars) reduced_vars = reduced_vars & rhs_vars progress = True - if progress: return Contract(lhs, rhs, reduced_vars) - return None + return fn(lhs, rhs, reduced_vars) + + +def contractor(fn): + """ + Decorator for contract implementations to simplify inputs. + """ + return functools.partial(_simplify_contract, fn) class Contract(Funsor): @@ -77,50 +79,22 @@ def eager_subs(self, subs): self.reduced_vars) -@optimize.register(Contract, ATOMS[1:], ATOMS, frozenset) -@optimize.register(Contract, ATOMS, ATOMS[1:], frozenset) -@eager.register(Contract, ATOMS[1:], ATOMS, frozenset) -@eager.register(Contract, ATOMS, ATOMS[1:], frozenset) -def contract_ground_ground(lhs, rhs, reduced_vars): - result = _simplify_contract(lhs, rhs, reduced_vars) - if result is not None: - return result - +@optimize.register(Contract, Funsor, Funsor, frozenset) +@eager.register(Contract, Funsor, Funsor, frozenset) +@contractor +def contract_funsor_funsor(lhs, rhs, reduced_vars): return (lhs * rhs).reduce(ops.add, reduced_vars) -@eager.register(Contract, Tensor, Tensor, frozenset) -def eager_contract_tensor_tensor(lhs, rhs, reduced_vars): - result = _simplify_contract(lhs, rhs, reduced_vars) - if result is not None: - return result - - out_inputs = OrderedDict([(k, d) for t in (lhs, rhs) - for k, d in t.inputs.items() if k not in reduced_vars]) - - return Tensor( - opt_einsum.contract(lhs.data, list(lhs.inputs.keys()), - rhs.data, list(rhs.inputs.keys()), - list(out_inputs.keys()), backend="torch"), - out_inputs - ) - - -@optimize.register(Contract, ATOMS, Finitary, frozenset) +@optimize.register(Contract, Funsor, Finitary, frozenset) +@contractor def contract_ground_finitary(lhs, rhs, reduced_vars): - result = _simplify_contract(lhs, rhs, reduced_vars) - if result is not None: - return result - return Contract(rhs, lhs, reduced_vars) -@optimize.register(Contract, Finitary, (Finitary,) + ATOMS, frozenset) +@optimize.register(Contract, Finitary, (Finitary, Funsor), frozenset) +@contractor def contract_finitary_ground(lhs, rhs, reduced_vars): - result = _simplify_contract(lhs, rhs, reduced_vars) - if result is not None: - return result - # exploit linearity of contraction if lhs.op is ops.add: return Finitary( @@ -139,3 +113,9 @@ def contract_finitary_ground(lhs, rhs, reduced_vars): reduced_vars & frozenset(root_lhs.inputs)) return None + + +__all__ = [ + 'Contract', + 'contractor', +] diff --git a/funsor/torch.py b/funsor/torch.py index 62d521a31..e550353f8 100644 --- a/funsor/torch.py +++ b/funsor/torch.py @@ -3,11 +3,13 @@ import functools from collections import OrderedDict +import opt_einsum import torch from six import add_metaclass, integer_types from six.moves import reduce import funsor.ops as ops +from funsor.contract import Contract, contractor from funsor.delta import Delta from funsor.domains import Domain, bint, find_domain, reals from funsor.ops import Op @@ -352,6 +354,17 @@ def eager_binary_tensor_tensor(op, lhs, rhs): return Tensor(data, inputs, dtype) +@eager.register(Contract, Tensor, Tensor, frozenset) +@contractor +def eager_contract_tensor_tensor(lhs, rhs, reduced_vars): + inputs = OrderedDict((k, d) for t in (lhs, rhs) + for k, d in t.inputs.items() if k not in reduced_vars) + data = opt_einsum.contract(lhs.data, list(lhs.inputs), + rhs.data, list(rhs.inputs), + list(inputs), backend="torch") + return Tensor(data, inputs, rhs.dtype) + + def arange(name, size): """ Helper to create a named :func:`torch.arange` funsor. From cc2c19594e958a56d15d1b216609b91026b9bf5f Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 21 Mar 2019 11:42:32 -0700 Subject: [PATCH 2/5] Fix optimize(Contract, Tensor, Tensor) --- funsor/contract.py | 4 ++-- funsor/torch.py | 8 +++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/funsor/contract.py b/funsor/contract.py index ac6d9a9bc..236c2f9ea 100644 --- a/funsor/contract.py +++ b/funsor/contract.py @@ -88,13 +88,13 @@ def contract_funsor_funsor(lhs, rhs, reduced_vars): @optimize.register(Contract, Funsor, Finitary, frozenset) @contractor -def contract_ground_finitary(lhs, rhs, reduced_vars): +def contract_funsor_finitary(lhs, rhs, reduced_vars): return Contract(rhs, lhs, reduced_vars) @optimize.register(Contract, Finitary, (Finitary, Funsor), frozenset) @contractor -def contract_finitary_ground(lhs, rhs, reduced_vars): +def contract_finitary_funsor(lhs, rhs, reduced_vars): # exploit linearity of contraction if lhs.op is ops.add: return Finitary( diff --git a/funsor/torch.py b/funsor/torch.py index e550353f8..9f5ad6694 100644 --- a/funsor/torch.py +++ b/funsor/torch.py @@ -13,6 +13,7 @@ from funsor.delta import Delta from funsor.domains import Domain, bint, find_domain, reals from funsor.ops import Op +from funsor.optimizer import optimize from funsor.six import getargspec from funsor.terms import Binary, Funsor, FunsorMeta, Number, Variable, eager, to_data, to_funsor @@ -356,7 +357,7 @@ def eager_binary_tensor_tensor(op, lhs, rhs): @eager.register(Contract, Tensor, Tensor, frozenset) @contractor -def eager_contract_tensor_tensor(lhs, rhs, reduced_vars): +def eager_contract(lhs, rhs, reduced_vars): inputs = OrderedDict((k, d) for t in (lhs, rhs) for k, d in t.inputs.items() if k not in reduced_vars) data = opt_einsum.contract(lhs.data, list(lhs.inputs), @@ -365,6 +366,11 @@ def eager_contract_tensor_tensor(lhs, rhs, reduced_vars): return Tensor(data, inputs, rhs.dtype) +@optimize.register(Contract, Tensor, Tensor, frozenset) +def optimize_contract(lhs, rhs, reduced_vars): + return None # reflect + + def arange(name, size): """ Helper to create a named :func:`torch.arange` funsor. From ea067b2e3eb042f6f56e3b352ba9891ad14a635c Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 21 Mar 2019 11:47:18 -0700 Subject: [PATCH 3/5] Fix dtype computation --- funsor/torch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/funsor/torch.py b/funsor/torch.py index 9f5ad6694..f2e45bd02 100644 --- a/funsor/torch.py +++ b/funsor/torch.py @@ -363,7 +363,8 @@ def eager_contract(lhs, rhs, reduced_vars): data = opt_einsum.contract(lhs.data, list(lhs.inputs), rhs.data, list(rhs.inputs), list(inputs), backend="torch") - return Tensor(data, inputs, rhs.dtype) + dtype = find_domain(ops.mul, lhs.output, rhs.output).dtype + return Tensor(data, inputs, dtype) @optimize.register(Contract, Tensor, Tensor, frozenset) From 4b81ab0c075bf87bc8835c72a3ca62997f218cf8 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 21 Mar 2019 12:11:34 -0700 Subject: [PATCH 4/5] Address review comment --- funsor/contract.py | 13 +++++++++---- funsor/torch.py | 6 ------ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/funsor/contract.py b/funsor/contract.py index 236c2f9ea..90e0dd566 100644 --- a/funsor/contract.py +++ b/funsor/contract.py @@ -79,22 +79,27 @@ def eager_subs(self, subs): self.reduced_vars) -@optimize.register(Contract, Funsor, Funsor, frozenset) @eager.register(Contract, Funsor, Funsor, frozenset) @contractor -def contract_funsor_funsor(lhs, rhs, reduced_vars): +def eager_contract(lhs, rhs, reduced_vars): return (lhs * rhs).reduce(ops.add, reduced_vars) +@optimize.register(Contract, Funsor, Funsor, frozenset) +@contractor +def optmize_contract(lhs, rhs, reduced_vars): + return None + + @optimize.register(Contract, Funsor, Finitary, frozenset) @contractor -def contract_funsor_finitary(lhs, rhs, reduced_vars): +def optimize_contract_funsor_finitary(lhs, rhs, reduced_vars): return Contract(rhs, lhs, reduced_vars) @optimize.register(Contract, Finitary, (Finitary, Funsor), frozenset) @contractor -def contract_finitary_funsor(lhs, rhs, reduced_vars): +def optimize_contract_finitary_funsor(lhs, rhs, reduced_vars): # exploit linearity of contraction if lhs.op is ops.add: return Finitary( diff --git a/funsor/torch.py b/funsor/torch.py index f2e45bd02..1f66421ff 100644 --- a/funsor/torch.py +++ b/funsor/torch.py @@ -13,7 +13,6 @@ from funsor.delta import Delta from funsor.domains import Domain, bint, find_domain, reals from funsor.ops import Op -from funsor.optimizer import optimize from funsor.six import getargspec from funsor.terms import Binary, Funsor, FunsorMeta, Number, Variable, eager, to_data, to_funsor @@ -367,11 +366,6 @@ def eager_contract(lhs, rhs, reduced_vars): return Tensor(data, inputs, dtype) -@optimize.register(Contract, Tensor, Tensor, frozenset) -def optimize_contract(lhs, rhs, reduced_vars): - return None # reflect - - def arange(name, size): """ Helper to create a named :func:`torch.arange` funsor. From 774de4be7958a04b29de01e24bd9f3af2555c810 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Thu, 21 Mar 2019 12:51:51 -0700 Subject: [PATCH 5/5] Fix typo --- funsor/contract.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funsor/contract.py b/funsor/contract.py index 90e0dd566..f3b53652d 100644 --- a/funsor/contract.py +++ b/funsor/contract.py @@ -87,7 +87,7 @@ def eager_contract(lhs, rhs, reduced_vars): @optimize.register(Contract, Funsor, Funsor, frozenset) @contractor -def optmize_contract(lhs, rhs, reduced_vars): +def optimize_contract(lhs, rhs, reduced_vars): return None