Refactor optimizer to use Contraction and normalize #165

Merged
merged 11 commits on Jul 26, 2019
funsor/__init__.py (2 changes: 0 additions & 2 deletions)
@@ -8,7 +8,6 @@
adjoint,
affine,
cnf,
-contract,
delta,
distributions,
domains,
@@ -42,7 +41,6 @@
'backward',
'bint',
'cnf',
-'contract',
'delta',
'distributions',
'domains',
funsor/adjoint.py (10 changes: 5 additions & 5 deletions)
@@ -3,7 +3,7 @@
import torch

import funsor.ops as ops
-from funsor.contract import Contract
+from funsor.cnf import Contraction
from funsor.interpreter import interpretation, reinterpret
from funsor.ops import AssociativeOp
from funsor.registry import KeyedRegistry
@@ -18,7 +18,7 @@ def __init__(self):

def __call__(self, cls, *args):
result = eager(cls, *args)
-if cls in (Reduce, Contract, Binary, Tensor):
+if cls in (Reduce, Contraction, Binary, Tensor):
self.tape.append((result, cls, args))
return result

@@ -88,13 +88,13 @@ def adjoint_reduce(out_adj, out, op, arg, reduced_vars):
return {arg: out_adj + Binary(ops.safesub, out, arg)}


-@adjoint_ops.register(Contract, Funsor, Funsor, AssociativeOp, AssociativeOp, Funsor, Funsor, frozenset)
+@adjoint_ops.register(Contraction, Funsor, Funsor, AssociativeOp, AssociativeOp, frozenset, Funsor, Funsor)
def adjoint_contract(out_adj, out, sum_op, prod_op, lhs, rhs, reduced_vars):

lhs_reduced_vars = frozenset(rhs.inputs) - frozenset(lhs.inputs)
-lhs_adj = Contract(sum_op, prod_op, out_adj, rhs, lhs_reduced_vars)
+lhs_adj = Contraction(sum_op, prod_op, lhs_reduced_vars, out_adj, rhs)

rhs_reduced_vars = frozenset(lhs.inputs) - frozenset(rhs.inputs)
-rhs_adj = Contract(sum_op, prod_op, out_adj, lhs, rhs_reduced_vars)
+rhs_adj = Contraction(sum_op, prod_op, rhs_reduced_vars, out_adj, lhs)

return {lhs: lhs_adj, rhs: rhs_adj}
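For reference, every call site in this hunk also swaps the operand order: the removed funsor.contract.Contract took reduced_vars as its last argument, while funsor.cnf.Contraction takes it before the term funsors. A minimal sketch of the new spelling, assuming funsor's Tensor/bint API and an eager pattern for add/mul contraction of two tensors; the shapes and variable names are invented for illustration:

    from collections import OrderedDict

    import torch

    import funsor.ops as ops
    from funsor.cnf import Contraction
    from funsor.domains import bint
    from funsor.torch import Tensor

    # Two discrete funsors sharing the variable "i".
    x = Tensor(torch.randn(3), OrderedDict(i=bint(3)))
    y = Tensor(torch.randn(3, 4), OrderedDict(i=bint(3), j=bint(4)))

    # New interface: reduced_vars precedes the terms being contracted.
    xy = Contraction(ops.add, ops.mul, frozenset(["i"]), x, y)

    # Old interface (funsor.contract.Contract, deleted by this PR) put it last:
    #     Contract(ops.add, ops.mul, x, y, frozenset(["i"]))

Under the eager interpretation, xy should reduce "i" away and leave a Tensor whose only input is "j".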
funsor/contract.py (76 changes: 0 additions & 76 deletions)

This file was deleted.

funsor/einsum.py (39 changes: 7 additions & 32 deletions)
@@ -4,33 +4,14 @@
import torch

import funsor.ops as ops
-from funsor.contract import Contract
-from funsor.interpreter import interpretation, reinterpret
-from funsor.optimizer import Finitary, apply_optimizer, optimize
+from funsor.cnf import Contraction
+from funsor.interpreter import interpretation
+from funsor.optimizer import apply_optimizer, optimize
from funsor.sum_product import sum_product
-from funsor.terms import Funsor, reflect
+from funsor.terms import Funsor, normalize
from funsor.torch import Tensor


-def _make_base_lhs(prod_op, arg, reduced_vars, normalized=False):
-    if not all(isinstance(d.dtype, int) for d in arg.inputs.values()):
-        raise NotImplementedError("TODO implement continuous base lhss")
-
-    if prod_op not in (ops.add, ops.mul):
-        raise NotImplementedError("{} not supported product op".format(prod_op))
-
-    make_unit = torch.ones if prod_op is ops.mul else torch.zeros
-
-    sizes = OrderedDict(set((var, dtype) for var, dtype in arg.inputs.items()))
-    terms = tuple(
-        Tensor(make_unit((d.dtype,)) / float(d.dtype), OrderedDict([(var, d)]))
-        if normalized else
-        Tensor(make_unit((d.dtype,)), OrderedDict([(var, d)]))
-        for var, d in sizes.items() if var in reduced_vars
-    )
-    return Finitary(prod_op, terms) if len(terms) > 1 else terms[0]


def naive_contract_einsum(eqn, *terms, **kwargs):
"""
Use for testing Contract against einsum
@@ -56,12 +37,7 @@ def naive_contract_einsum(eqn, *terms, **kwargs):
reduced_vars = input_dims - output_dims

with interpretation(optimize):
-rhs = Finitary(prod_op, tuple(terms))
-lhs = _make_base_lhs(prod_op, rhs, reduced_vars, normalized=False)
-assert frozenset(lhs.inputs) == reduced_vars
-result = Contract(sum_op, prod_op, lhs, rhs, reduced_vars)
-
-return reinterpret(result)
+return Contraction(sum_op, prod_op, reduced_vars, terms)
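The removed path materialized a unit-valued lhs and a Finitary product before contracting; the replacement expresses the same sum-product as a single multi-term Contraction. Note that the two call sites touched by this PR spell the terms differently, and both spellings are assumed here to construct the same term:

    # Tuple of terms, as in naive_contract_einsum above:
    #     Contraction(sum_op, prod_op, reduced_vars, terms)
    # Variadic terms, as in funsor/adjoint.py in this same PR:
    #     Contraction(sum_op, prod_op, reduced_vars, out_adj, rhs)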


def naive_einsum(eqn, *terms, **kwargs):
@@ -136,7 +112,6 @@ def einsum(eqn, *terms, **kwargs):
terms): dimensions in plates but not in outputs are product-reduced;
dimensions in neither plates nor outputs are sum-reduced.
"""
-with interpretation(reflect):
+with interpretation(normalize):
naive_ast = naive_plated_einsum(eqn, *terms, **kwargs)
-optimized_ast = apply_optimizer(naive_ast)
-return reinterpret(optimized_ast) # eager by default
+return apply_optimizer(naive_ast)
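A hedged usage sketch for the einsum entry point above: it assumes the equation letters name the input dimensions of the funsor terms, that the default backend pairs ops.add with ops.mul, and that no plates are passed; the shapes here are invented.

    from collections import OrderedDict

    import torch

    from funsor.domains import bint
    from funsor.einsum import einsum
    from funsor.torch import Tensor

    x = Tensor(torch.randn(2, 3), OrderedDict(a=bint(2), b=bint(3)))
    y = Tensor(torch.randn(3, 4), OrderedDict(b=bint(3), c=bint(4)))

    # "b" appears in the inputs but not the output, so it is sum-reduced.
    # With the change above, the term is built under the normalize
    # interpretation and then rewritten by apply_optimizer.
    z = einsum("ab,bc->ac", x, y)
    assert set(z.inputs) == {"a", "c"}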
funsor/integrate.py (12 changes: 5 additions & 7 deletions)
@@ -3,7 +3,6 @@

import funsor.interpreter as interpreter
import funsor.ops as ops
-from funsor.contract import Contract
from funsor.terms import Funsor, Reduce, eager


@@ -61,19 +60,18 @@ def integrator(fn):


@eager.register(Integrate, Funsor, Funsor, frozenset)
@integrator
-def eager_integrate(log_measure, integrand, reduced_vars):
-return Contract(ops.add, ops.mul, log_measure.exp(), integrand, reduced_vars)
+def eager_integrate_generic(log_measure, integrand, reduced_vars):
+# return Contraction(ops.add, ops.mul, reduced_vars, log_measure.exp(), integrand) # XXX circular imports
+return (log_measure.exp() * integrand).reduce(ops.add, reduced_vars)


@eager.register(Integrate, Reduce, Funsor, frozenset)
@integrator
-def eager_integrate(log_measure, integrand, reduced_vars):
+def eager_integrate_reduce(log_measure, integrand, reduced_vars):
if log_measure.op is ops.logaddexp:
arg = Integrate(log_measure.arg, integrand, reduced_vars)
return arg.reduce(ops.add, log_measure.reduced_vars)

-return Contract(ops.add, ops.mul, log_measure.exp(), integrand, reduced_vars)
+return eager_integrate_generic(log_measure, integrand, reduced_vars)
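The generic rule above makes Integrate(log_measure, integrand, reduced_vars) the sum over reduced_vars of exp(log_measure) * integrand. A small sketch of that identity with made-up discrete tensors, assuming funsor's Tensor/bint API and the default eager interpretation:

    from collections import OrderedDict

    import torch

    import funsor.ops as ops
    from funsor.domains import bint
    from funsor.integrate import Integrate
    from funsor.torch import Tensor

    log_p = Tensor(torch.randn(3), OrderedDict(i=bint(3)))  # unnormalized log measure over "i"
    f = Tensor(torch.randn(3), OrderedDict(i=bint(3)))      # integrand

    actual = Integrate(log_p, f, frozenset(["i"]))
    expected = (log_p.exp() * f).reduce(ops.add, frozenset(["i"]))
    assert ((actual - expected).data.abs() < 1e-6).all()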


__all__ = [