
Refactor Affine into Contraction #173

Merged · 8 commits · Jul 26, 2019
Changes from 6 commits
2 changes: 0 additions & 2 deletions funsor/__init__.py
@@ -8,7 +8,6 @@

 from . import (
     adjoint,
-    affine,
     cnf,
     contract,
     delta,
@@ -38,7 +37,6 @@
     'Tensor',
     'Variable',
     'adjoint',
-    'affine',
     'arange',
     'backward',
     'bint',
194 changes: 0 additions & 194 deletions funsor/affine.py

This file was deleted.

35 changes: 32 additions & 3 deletions funsor/cnf.py
@@ -13,7 +13,8 @@
 from funsor.gaussian import Gaussian, sym_inverse
 from funsor.interpreter import recursion_reinterpret
 from funsor.ops import AssociativeOp, DISTRIBUTIVE_OPS
-from funsor.terms import Binary, Funsor, Number, Reduce, Subs, Unary, eager, moment_matching, normalize, reflect
+from funsor.terms import Binary, Funsor, Number, Reduce, Subs, Unary, Variable, \
+    eager, moment_matching, normalize, reflect
 from funsor.torch import Tensor


@@ -65,6 +66,20 @@ def __init__(self, red_op, bin_op, reduced_vars, terms):
         self.bin_op = bin_op
         self.terms = terms
         self.reduced_vars = reduced_vars
+        self.is_affine = self._is_affine()
+
+    def _is_affine(self):
+        for t in self.terms:
+            if not isinstance(t, (Number, Tensor, Variable, Contraction)):
+                return False
+            if isinstance(t, Contraction):
+                if not ((self.bin_op, t.bin_op) in DISTRIBUTIVE_OPS and t.is_affine):
+                    return False
+
+        if self.bin_op is ops.add and self.red_op is not anyop:
+            return sum(1 for k, v in self.inputs.items() if v.dtype == 'real') == \
+                sum(sum(1 for k, v in t.inputs.items() if v.dtype == 'real') for t in self.terms)
+        return True


 @recursion_reinterpret.register(Contraction)
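The new `is_affine` flag replaces the structural checks that the deleted `funsor/affine.py` used to provide. A rough usage sketch (hypothetical, not part of the diff; it assumes the `normalize` interpretation rewrites the arithmetic below into a `Contraction`):

```python
import torch

from funsor.domains import reals
from funsor.interpreter import interpretation
from funsor.terms import Number, Variable, normalize
from funsor.torch import Tensor

with interpretation(normalize):
    x = Variable('x', reals())
    expr = Tensor(torch.tensor(2.)) * x + Number(1.)  # the affine expression 2*x + 1

# expr should now be a Contraction; code that previously checked
# isinstance(expr, Affine) can instead consult the cheap structural flag.
assert expr.is_affine
```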
@@ -236,15 +251,29 @@ def binary_subtract(op, lhs, rhs):
     return lhs + -rhs


+@normalize.register(Binary, ops.DivOp, Funsor, Funsor)
+def binary_divide(op, lhs, rhs):
+    return lhs * Unary(ops.reciprocal, rhs)


 @normalize.register(Unary, ops.NegOp, Contraction)
 def unary_contract(op, arg):
     if arg.bin_op is ops.add and arg.red_op is anyop:
-        return Contraction(arg.red_op, arg.bin_op, arg.reduced_vars, *(-t for t in arg.terms))
-    raise NotImplementedError("TODO")
+        return Contraction(arg.red_op, arg.bin_op, arg.reduced_vars, *(op(t) for t in arg.terms))
+    if arg.bin_op is ops.mul:
+        return arg * Number(-1.)
+    return None


+@normalize.register(Unary, ops.NegOp, Variable)
+def unary_neg_variable(op, arg):
+    return arg * Number(-1.)


+@normalize.register(Unary, ops.ReciprocalOp, Contraction)
+def unary_contract(op, arg):
+    if arg.bin_op is ops.mul and arg.red_op is anyop:
+        return Contraction(arg.red_op, arg.bin_op, arg.reduced_vars, *(op(t) for t in arg.terms))
+    raise NotImplementedError("TODO")


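These `normalize` rules push negation and division into Contraction-friendly forms: `NegOp` distributes termwise over an additive Contraction, and `DivOp` is eliminated in favor of multiplication by a reciprocal. A minimal sketch of the intended rewrites (assumed behavior, using only names visible in this diff):

```python
from funsor.domains import reals
from funsor.interpreter import interpretation
from funsor.terms import Variable, normalize

with interpretation(normalize):
    x = Variable('x', reals())
    y = Variable('y', reals())
    neg = -(x + y)   # unary_contract for ops.NegOp: becomes (-x) + (-y)
    quot = x / y     # binary_divide: becomes x * reciprocal(y)
```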
24 changes: 18 additions & 6 deletions funsor/distributions.py
@@ -10,7 +10,7 @@

 import funsor.delta
 import funsor.ops as ops
-from funsor.affine import Affine
+from funsor.cnf import Contraction
 from funsor.domains import bint, reals
 from funsor.gaussian import BlockMatrix, BlockVector, Gaussian
 from funsor.interpreter import interpretation
@@ -392,16 +392,28 @@ def eager_normal(loc, scale, value):
     return Normal(loc, scale, 'value')(value=value)


-@eager.register(Normal, (Variable, Affine), Tensor, (Variable, Affine))
-@eager.register(Normal, (Variable, Affine), Tensor, Tensor)
-@eager.register(Normal, Tensor, Tensor, (Variable, Affine))
+@eager.register(Normal, (Variable, Contraction), Tensor, (Variable, Contraction))
+@eager.register(Normal, (Variable, Contraction), Tensor, Tensor)
+@eager.register(Normal, Tensor, Tensor, (Variable, Contraction))
 def eager_normal(loc, scale, value):
     affine = (loc - value) / scale
-    assert isinstance(affine, Affine)
+    if not affine.is_affine:
+        return None

     real_inputs = OrderedDict((k, v) for k, v in affine.inputs.items() if v.dtype == 'real')
     assert not any(v.shape for v in real_inputs.values())

-    tensors = [affine.const] + [c for v, c in affine.coeffs.items()]
+    const, coeffs = to_funsor(torch.tensor(0.)), OrderedDict((k, Number(0.)) for k in real_inputs)
+    for t in affine.terms:
+        if isinstance(t, (Number, Tensor)):
+            const += t
+        elif isinstance(t, Variable):
+            coeffs[t.name] += 1.
+        elif isinstance(t, Contraction):
+            v, c = t.terms if isinstance(t.terms[0], Variable) else reversed(t.terms)
+            coeffs[v.name] += c

+    tensors = [const] + list(coeffs.values())
     inputs, tensors = align_tensors(*tensors)
     tensors = torch.broadcast_tensors(*tensors)
     const, coeffs = tensors[0], tensors[1:]
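The loop above gathers the constant and per-variable coefficients of an affine Contraction by walking its terms. A plain-Python toy analogue of that bookkeeping (hypothetical, with tagged tuples standing in for Number/Variable/Contraction terms):

```python
from collections import OrderedDict

# Terms of the affine expression 1 + x + 2*x + 3*y, as tagged tuples.
terms = [("const", 1.0), ("var", "x"), ("prod", "x", 2.0), ("prod", "y", 3.0)]

const = 0.0
coeffs = OrderedDict([("x", 0.0), ("y", 0.0)])
for t in terms:
    if t[0] == "const":    # a Number/Tensor term folds into the constant
        const += t[1]
    elif t[0] == "var":    # a bare Variable contributes coefficient 1
        coeffs[t[1]] += 1.0
    elif t[0] == "prod":   # a coefficient * Variable product
        coeffs[t[1]] += t[2]

assert const == 1.0
assert coeffs == OrderedDict([("x", 3.0), ("y", 3.0)])
```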
10 changes: 7 additions & 3 deletions funsor/ops.py
@@ -72,6 +72,10 @@ class NegOp(Op):
     pass


+class DivOp(Op):
+    pass


 class GetitemMeta(type):
     _cache = {}

@@ -112,7 +116,7 @@ def _default(self, x, y):
 ne = Op(operator.ne)
 neg = NegOp(operator.neg)
 sub = SubOp(operator.sub)
-truediv = Op(operator.truediv)
+truediv = DivOp(operator.truediv)

 add = AddOp(operator.add)
 and_ = AssociativeOp(operator.and_)
@@ -199,13 +203,13 @@ def logaddexp(x, y):
     return log(exp(x - shift) + exp(y - shift)) + shift


-@Op
+@SubOp
 def safesub(x, y):
     if isinstance(y, Number):
         return sub(x, y)


-@Op
+@DivOp
 def safediv(x, y):
     if isinstance(y, Number):
         return truediv(x, y)
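The point of the new `DivOp` (and of retagging `safesub`/`safediv`) is that pattern registration dispatches on the op's class, so division can only get its own normalize rule once `truediv` is an instance of a distinct subclass. A self-contained illustration of that design choice (a toy, not funsor's actual dispatch machinery):

```python
import operator

class Op:
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x, y):
        return self.fn(x, y)

class DivOp(Op):
    """Marker subclass so rules can match division specifically."""

truediv = DivOp(operator.truediv)

# A rule table keyed by op class, mimicking @normalize.register(Binary, ops.DivOp, ...):
rules = {DivOp: lambda lhs, rhs: ("mul", lhs, ("reciprocal", rhs))}
assert rules[type(truediv)]("x", "y") == ("mul", "x", ("reciprocal", "y"))
```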
5 changes: 5 additions & 0 deletions funsor/terms.py
@@ -1218,6 +1218,11 @@ def _log1p(x):
     return Unary(ops.log1p, x)


+@ops.reciprocal.register(Funsor)
+def _reciprocal(x):
+    return Unary(ops.reciprocal, x)


 __all__ = [
     'Binary',
     'Funsor',
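Taken together, these changes let division normalize end to end: `DivOp` makes `x / y` matchable, `binary_divide` rewrites it to `x * reciprocal(y)`, and the `ops.reciprocal.register(Funsor)` hook above builds the `Unary(ops.reciprocal, y)` term that the `ReciprocalOp` Contraction rule can then simplify. A hedged end-to-end sketch (assumed behavior):

```python
from funsor.domains import reals
from funsor.interpreter import interpretation
from funsor.terms import Variable, normalize

with interpretation(normalize):
    x = Variable('x', reals())
    y = Variable('y', reals())
    quot = x / y  # expected to normalize to a term equivalent to x * reciprocal(y)

print(quot)
```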