diff --git a/funsor/__init__.py b/funsor/__init__.py index 570408d06..3fd4d28ae 100644 --- a/funsor/__init__.py +++ b/funsor/__init__.py @@ -6,7 +6,6 @@ from . import ( adjoint, - affine, cnf, contract, delta, @@ -37,7 +36,6 @@ 'Tensor', 'Variable', 'adjoint', - 'affine', 'arange', 'backward', 'bint', diff --git a/funsor/affine.py b/funsor/affine.py deleted file mode 100644 index be7c0892f..000000000 --- a/funsor/affine.py +++ /dev/null @@ -1,192 +0,0 @@ -from collections import OrderedDict - -import funsor.ops as ops -from funsor.domains import find_domain -from funsor.ops import NegOp, Op -from funsor.terms import Binary, Funsor, Number, Unary, Variable, eager -from funsor.torch import Tensor - - -class Affine(Funsor): - """ - Pattern representing multilinear function of input variables - """ - def __init__(self, const, coeffs): - assert isinstance(const, (Number, Tensor)) - assert not any(d.dtype == "real" for d in const.inputs.values()) - assert isinstance(coeffs, tuple) - inputs = const.inputs.copy() - output = const.output - assert output.dtype == "real" - for var, coeff in coeffs: - assert isinstance(var, Variable) - assert isinstance(coeff, (Number, Tensor)) - assert not any(d.dtype == "real" for d in coeff.inputs.values()) - inputs.update(coeff.inputs) - inputs.update(var.inputs) - output = find_domain(ops.add, output, find_domain(ops.mul, var.output, coeff.output)) - assert var.dtype == "real" - assert coeff.dtype == "real" - assert output.dtype == "real" - - super(Affine, self).__init__(inputs, output) - self.coeffs = OrderedDict(coeffs) - self.const = const - - -############################################### -# patterns for merging Affine with other terms -############################################### - -@eager.register(Affine, (Number, Tensor), tuple) -def eager_affine(const, coeffs): - if not coeffs: - return const - if not all(isinstance(var, Variable) for var, coeff in coeffs): - result = Affine(const, tuple((var, coeff) for var, coeff in coeffs if isinstance(var, Variable))) - for var, coeff in coeffs: - if not isinstance(var, Variable): - result += var * coeff - return result - return None - - -@eager.register(Binary, Op, Affine, (Number, Tensor)) -def eager_binary_affine(op, lhs, rhs): - if op is ops.add or op is ops.sub: - const = op(lhs.const, rhs) - return Affine(const, tuple(lhs.coeffs.items())) - if op is ops.mul or op is ops.truediv: - const = op(lhs.const, rhs) - coeffs = tuple((var, op(coeff, rhs)) for var, coeff in lhs.coeffs.items()) - return Affine(const, coeffs) - return None - - -@eager.register(Binary, Op, (Number, Tensor), Affine) -def eager_binary_affine(op, lhs, rhs): - if op is ops.add: - const = lhs + rhs.const - return Affine(const, tuple(rhs.coeffs.items())) - elif op is ops.sub: - return lhs + -rhs - if op is ops.mul: - const = lhs * rhs.const - coeffs = tuple((var, lhs * coeff) for var, coeff in rhs.coeffs.items()) - return Affine(const, coeffs) - return None - - -@eager.register(Binary, Op, Affine, Affine) -def eager_binary_affine_affine(op, lhs, rhs): - if op is ops.add: - const = lhs.const + rhs.const - coeffs = lhs.coeffs.copy() - for var, coeff in rhs.coeffs.items(): - if var in coeffs: - coeffs[var] += coeff - else: - coeffs[var] = coeff - return Affine(const, tuple(coeffs.items())) - - if op is ops.sub: - return lhs + -rhs - - return None - - -@eager.register(Binary, Op, Affine, Variable) -def eager_binary_affine_variable(op, affine, other): - if op is ops.add: - const = affine.const - coeffs = affine.coeffs.copy() - if other in affine.inputs: - coeffs[other] += 1 - else: - coeffs[other] = Number(1.) - return Affine(const, tuple(coeffs.items())) - - if op is ops.sub: - return affine + -other - - return None - - -@eager.register(Binary, Op, Variable, Affine) -def eager_binary_variable_affine(op, other, affine): - if op is ops.add: - return affine + other - - if op is ops.sub: - return -affine + other - - return None - - -@eager.register(Unary, NegOp, Affine) -def eager_negate_affine(op, affine): - const = -affine.const - coeffs = affine.coeffs.copy() - for var, coeff in coeffs.items(): - coeffs[var] = -coeff - return Affine(const, tuple(coeffs.items())) - - -######################################### -# patterns for creating new Affine terms -######################################### - -@eager.register(Binary, Op, Variable, (Number, Tensor)) -def eager_binary(op, var, other): - if var.dtype != "real" or other.dtype != "real": - return None - - if op is ops.add: - const = other - coeffs = ((var, Number(1.)),) - return Affine(const, coeffs) - elif op is ops.mul: - const = Number(0.) - coeffs = ((var, other),) - return Affine(const, coeffs) - elif op is ops.sub: - return var + -other - elif op is ops.truediv: - return var * (1. / other) - return None - - -@eager.register(Binary, Op, Variable, Variable) -def eager_binary(op, lhs, rhs): - if lhs.dtype != "real" or rhs.dtype != "real": - return None - - if op is ops.add: - const = Number(0.) - coeffs = ((lhs, Number(1.)), (rhs, Number(1.))) - return Affine(const, coeffs) - elif op is ops.sub: - return lhs + -rhs - return None - - -@eager.register(Binary, Op, (Number, Tensor), Variable) -def eager_binary(op, other, var): - if other.dtype != "real" or var.dtype != "real": - return None - - if op is ops.add or op is ops.mul: - return op(var, other) - elif op is ops.sub: - return -var + other - return None - - -@eager.register(Unary, NegOp, Variable) -def eager_negate_variable(op, var): - if var.dtype != "real": - return None - - const = Number(0.) - coeffs = ((var, Number(-1, "real")),) - return Affine(const, coeffs) diff --git a/funsor/cnf.py b/funsor/cnf.py index 0813ee6c7..0d2abb16f 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -10,7 +10,8 @@ from funsor.gaussian import Gaussian, sym_inverse from funsor.interpreter import recursion_reinterpret from funsor.ops import AssociativeOp, DISTRIBUTIVE_OPS -from funsor.terms import Binary, Funsor, Number, Reduce, Subs, Unary, eager, moment_matching, normalize +from funsor.terms import Binary, Funsor, Number, Reduce, Subs, Unary, Variable, \ + eager, moment_matching, normalize from funsor.torch import Tensor @@ -62,6 +63,20 @@ def __init__(self, red_op, bin_op, reduced_vars, terms): self.bin_op = bin_op self.terms = terms self.reduced_vars = reduced_vars + self.is_affine = self._is_affine() + + def _is_affine(self): + for t in self.terms: + if not isinstance(t, (Number, Tensor, Variable, Contraction)): + return False + if isinstance(t, Contraction): + if not (self.bin_op, t.bin_op) in DISTRIBUTIVE_OPS and t.is_affine: + return False + + if self.bin_op is ops.add and self.red_op is not anyop: + return sum(1 for k, v in self.inputs.items() if v.dtype == 'real') == \ + sum(sum(1 for k, v in t.inputs.items() if v.dtype == 'real') for t in self.terms) + return True @recursion_reinterpret.register(Contraction) @@ -233,15 +248,29 @@ def binary_subtract(op, lhs, rhs): return lhs + -rhs +@normalize.register(Binary, ops.DivOp, Funsor, Funsor) +def binary_divide(op, lhs, rhs): + return lhs * Unary(ops.reciprocal, rhs) + + @normalize.register(Unary, ops.NegOp, Contraction) def unary_contract(op, arg): if arg.bin_op is ops.add and arg.red_op is anyop: - return Contraction(arg.red_op, arg.bin_op, arg.reduced_vars, *(-t for t in arg.terms)) - raise NotImplementedError("TODO") + return Contraction(arg.red_op, arg.bin_op, arg.reduced_vars, *(op(t) for t in arg.terms)) + if arg.bin_op is ops.mul: + return arg * Number(-1.) + return None + + +@normalize.register(Unary, ops.NegOp, Variable) +def unary_neg_variable(op, arg): + return arg * Number(-1.) @normalize.register(Unary, ops.ReciprocalOp, Contraction) def unary_contract(op, arg): + if arg.bin_op is ops.mul and arg.red_op is anyop: + return Contraction(arg.red_op, arg.bin_op, arg.reduced_vars, *(op(t) for t in arg.terms)) raise NotImplementedError("TODO") diff --git a/funsor/distributions.py b/funsor/distributions.py index 7638631b3..c32fe8d37 100644 --- a/funsor/distributions.py +++ b/funsor/distributions.py @@ -7,7 +7,7 @@ import funsor.delta import funsor.ops as ops -from funsor.affine import Affine +from funsor.cnf import Contraction from funsor.domains import bint, reals from funsor.gaussian import BlockMatrix, BlockVector, Gaussian from funsor.interpreter import interpretation @@ -388,16 +388,28 @@ def eager_normal(loc, scale, value): return Normal(loc, scale, 'value')(value=value) -@eager.register(Normal, (Variable, Affine), Tensor, (Variable, Affine)) -@eager.register(Normal, (Variable, Affine), Tensor, Tensor) -@eager.register(Normal, Tensor, Tensor, (Variable, Affine)) +@eager.register(Normal, (Variable, Contraction), Tensor, (Variable, Contraction)) +@eager.register(Normal, (Variable, Contraction), Tensor, Tensor) +@eager.register(Normal, Tensor, Tensor, (Variable, Contraction)) def eager_normal(loc, scale, value): affine = (loc - value) / scale - assert isinstance(affine, Affine) + if not affine.is_affine: + return None + real_inputs = OrderedDict((k, v) for k, v in affine.inputs.items() if v.dtype == 'real') assert not any(v.shape for v in real_inputs.values()) - tensors = [affine.const] + [c for v, c in affine.coeffs.items()] + const, coeffs = to_funsor(torch.tensor(0.)), OrderedDict((k, Number(0.)) for k in real_inputs) + for t in affine.terms: + if isinstance(t, (Number, Tensor)): + const += t + elif isinstance(t, Variable): + coeffs[t.name] += 1. + elif isinstance(t, Contraction): + v, c = t.terms if isinstance(t.terms[0], Variable) else reversed(t.terms) + coeffs[v.name] += c + + tensors = [const] + list(coeffs.values()) inputs, tensors = align_tensors(*tensors) tensors = torch.broadcast_tensors(*tensors) const, coeffs = tensors[0], tensors[1:] diff --git a/funsor/ops.py b/funsor/ops.py index a266d1f62..cb28f4460 100644 --- a/funsor/ops.py +++ b/funsor/ops.py @@ -69,6 +69,10 @@ class NegOp(Op): pass +class DivOp(Op): + pass + + class GetitemMeta(type): _cache = {} @@ -108,7 +112,7 @@ def _default(self, x, y): ne = Op(operator.ne) neg = NegOp(operator.neg) sub = SubOp(operator.sub) -truediv = Op(operator.truediv) +truediv = DivOp(operator.truediv) add = AddOp(operator.add) and_ = AssociativeOp(operator.and_) @@ -195,13 +199,13 @@ def logaddexp(x, y): return log(exp(x - shift) + exp(y - shift)) + shift -@Op +@SubOp def safesub(x, y): if isinstance(y, Number): return sub(x, y) -@Op +@DivOp def safediv(x, y): if isinstance(y, Number): return truediv(x, y) diff --git a/funsor/terms.py b/funsor/terms.py index 69f54bfdd..299574705 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1213,6 +1213,11 @@ def _log1p(x): return Unary(ops.log1p, x) +@ops.reciprocal.register(Funsor) +def _reciprocal(x): + return Unary(ops.reciprocal, x) + + __all__ = [ 'Binary', 'Funsor', diff --git a/test/test_affine.py b/test/test_affine.py index c258c80d2..9e09f2fbc 100644 --- a/test/test_affine.py +++ b/test/test_affine.py @@ -3,21 +3,21 @@ import pytest import torch -from funsor.affine import Affine +from funsor.cnf import Contraction from funsor.domains import bint, reals from funsor.terms import Number, Variable from funsor.testing import check_funsor from funsor.torch import Tensor SMOKE_TESTS = [ - ('t+x', Affine), - ('x+t', Affine), - ('n+x', Affine), - ('n*x', Affine), - ('t*x', Affine), - ('x*t', Affine), - ('-x', Affine), - ('t-x', Affine), + ('t+x', Contraction), + ('x+t', Contraction), + ('n+x', Contraction), + ('n*x', Contraction), + ('t*x', Contraction), + ('x*t', Contraction), + ('-x', Contraction), + ('t-x', Contraction), ] @@ -38,15 +38,16 @@ def test_smoke(expr, expected_type): result = eval(expr) assert isinstance(result, expected_type) + assert result.is_affine SUBS_TESTS = [ - ("(t * x)(i=1)", Affine, {"j": bint(3), "x": reals()}), - ("(t * x)(i=1, x=y)", Affine, {"j": bint(3), "y": reals()}), - ("(t * x + n)(x=y)", Affine, {"y": reals(), "i": bint(2), "j": bint(3)}), - ("(x + y)(y=z)", Affine, {"x": reals(), "z": reals()}), - ("(-x)(x=y+z)", Affine, {"y": reals(), "z": reals()}), - ("(t * x + t * y)(x=z)", Affine, {"y": reals(), "z": reals(), "i": bint(2), "j": bint(3)}), + ("(t * x)(i=1)", Contraction, {"j": bint(3), "x": reals()}), + ("(t * x)(i=1, x=y)", Contraction, {"j": bint(3), "y": reals()}), + ("(t * x + n)(x=y)", Contraction, {"y": reals(), "i": bint(2), "j": bint(3)}), + ("(x + y)(y=z)", Contraction, {"x": reals(), "z": reals()}), + ("(-x)(x=y+z)", Contraction, {"y": reals(), "z": reals()}), + ("(t * x + t * y)(x=z)", Contraction, {"y": reals(), "z": reals(), "i": bint(2), "j": bint(3)}), ] @@ -73,3 +74,4 @@ def test_affine_subs(expr, expected_type, expected_inputs): result = eval(expr) assert isinstance(result, expected_type) check_funsor(result, expected_inputs, expected_output) + assert result.is_affine diff --git a/test/test_cnf.py b/test/test_cnf.py index 36c75300a..7af293e11 100644 --- a/test/test_cnf.py +++ b/test/test_cnf.py @@ -10,8 +10,8 @@ from funsor.einsum import einsum, naive_plated_einsum from funsor.gaussian import Gaussian from funsor.interpreter import interpretation, reinterpret -from funsor.terms import Number, Variable, eager, moment_matching, normalize, reflect -from funsor.testing import assert_close, check_funsor, make_einsum_example +from funsor.terms import Number, eager, moment_matching, normalize, reflect +from funsor.testing import assert_close, make_einsum_example # , xfail_param from funsor.torch import Tensor @@ -64,52 +64,6 @@ def test_normalize_einsum(equation, plates, backend, einsum_impl): assert_close(actual, expected, rtol=1e-4) -AFFINE_SMOKE_TESTS = [ - ('t+x', Contraction, {"i": bint(2), "j": bint(3), "x": reals()}), - ('x+t', Contraction, {"x": reals(), "i": bint(2), "j": bint(3)}), - ('n+x', Contraction, {"x": reals()}), - ('n*x', Contraction, {"x": reals()}), - ('t*x', Contraction, {"i": bint(2), "j": bint(3), "x": reals()}), - ('x*t', Contraction, {"x": reals(), "i": bint(2), "j": bint(3)}), - ("-(y+z)", Contraction, {"y": reals(), "z": reals()}), - # xfail_param(('-x', Contraction, {"x": reals()}), reason="not a contraction"), - ('t-x', Contraction, {"i": bint(2), "j": bint(3), "x": reals()}), - ("(t * x)(i=1)", Contraction, {"j": bint(3), "x": reals()}), - ("(t * x)(i=1, x=y)", Contraction, {"j": bint(3), "y": reals()}), - ("(t * x + n)(x=y)", Contraction, {"y": reals(), "i": bint(2), "j": bint(3)}), - ("(x + y)(y=z)", Contraction, {"x": reals(), "z": reals()}), - # xfail_param(("(-x)(x=y+z)", Contraction, {"y": reals(), "z": reals()}), reason="not a contraction"), - ("(t * x + t * y)(x=z)", Contraction, {"y": reals(), "z": reals(), "i": bint(2), "j": bint(3)}), -] - - -@pytest.mark.parametrize("expr,expected_type,expected_inputs", AFFINE_SMOKE_TESTS) -def test_affine_subs(expr, expected_type, expected_inputs): - - expected_output = reals() - - t = Tensor(torch.randn(2, 3), OrderedDict([('i', bint(2)), ('j', bint(3))])) - assert isinstance(t, Tensor) - - n = Number(2.) - assert isinstance(n, Number) - - x = Variable('x', reals()) - assert isinstance(x, Variable) - - y = Variable('y', reals()) - assert isinstance(y, Variable) - - z = Variable('z', reals()) - assert isinstance(z, Variable) - - with interpretation(normalize): - result = eval(expr) - - assert isinstance(result, expected_type) - check_funsor(result, expected_inputs, expected_output) - - JOINT_SMOKE_TESTS = [ ('dx + dy', Contraction), ('dx + g', Contraction),