Skip to content

Commit

Permalink
Alpha renaming of bound variables (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
eb8680 authored and fritzo committed Jun 25, 2019
1 parent 3edfb0f commit 7e456e6
Show file tree
Hide file tree
Showing 14 changed files with 331 additions and 26 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ format: FORCE

test: lint FORCE
pytest -v test
FUNSOR_GENERIC_SUBS=1 pytest -v test # TODO remove when removing eager_subs
FUNSOR_DEBUG=1 pytest -v test/test_gaussian.py
FUNSOR_USE_TCO=1 pytest -v test/test_terms.py
FUNSOR_USE_TCO=1 pytest -v test/test_einsum.py
Expand Down
6 changes: 6 additions & 0 deletions funsor/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ def eager_subs(self, subs):
def eager_affine(const, coeffs):
if not coeffs:
return const
if not all(isinstance(var, Variable) for var, coeff in coeffs):
result = Affine(const, tuple((var, coeff) for var, coeff in coeffs if isinstance(var, Variable)))
for var, coeff in coeffs:
if not isinstance(var, Variable):
result += var * coeff
return result
return None


Expand Down
4 changes: 3 additions & 1 deletion funsor/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def __init__(self, sum_op, prod_op, lhs, rhs, reduced_vars):
inputs = OrderedDict([(k, d) for t in (lhs, rhs)
for k, d in t.inputs.items() if k not in reduced_vars])
output = rhs.output
super(Contract, self).__init__(inputs, output)
fresh = frozenset()
bound = reduced_vars
super(Contract, self).__init__(inputs, output, fresh, bound)
self.sum_op = sum_op
self.prod_op = prod_op
self.lhs = lhs
Expand Down
10 changes: 9 additions & 1 deletion funsor/delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Unary,
Variable,
eager,
substitute,
to_funsor
)

Expand Down Expand Up @@ -58,7 +59,9 @@ def __init__(self, name, point, log_density=0):
inputs.update(point.inputs)
inputs.update(log_density.inputs)
output = reals()
super(Delta, self).__init__(inputs, output)
fresh = frozenset({name})
bound = frozenset()
super(Delta, self).__init__(inputs, output, fresh, bound)
self.name = name
self.point = point
self.log_density = log_density
Expand Down Expand Up @@ -111,6 +114,11 @@ def eager_reduce(self, op, reduced_vars):
return None # defer to default implementation


@substitute.register(Delta, tuple)
def subs_gaussian(expr, subs):
return expr.eager_subs(tuple((k, to_funsor(v, expr.inputs[k]) if k in expr.inputs else v) for k, v in subs))


@eager.register(Binary, AddOp, Delta, (Funsor, Align))
def eager_add(op, lhs, rhs):
if lhs.name in rhs.inputs:
Expand Down
11 changes: 9 additions & 2 deletions funsor/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from funsor.domains import reals
from funsor.integrate import Integrate, integrator
from funsor.ops import AddOp, NegOp, SubOp
from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, Subs, Unary, Variable, eager
from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, Subs, Unary, Variable, eager, substitute, to_funsor
from funsor.torch import Tensor, align_tensor, align_tensors, materialize
from funsor.util import lazy_property

Expand Down Expand Up @@ -322,7 +322,9 @@ def __init__(self, loc, precision, inputs):
assert _issubshape(precision.shape, batch_shape + (dim, dim))

output = reals()
super(Gaussian, self).__init__(inputs, output)
fresh = frozenset(inputs.keys())
bound = frozenset()
super(Gaussian, self).__init__(inputs, output, fresh, bound)
self.loc = loc
self.precision = precision
self.batch_shape = batch_shape
Expand Down Expand Up @@ -528,6 +530,11 @@ def unscaled_sample(self, sampled_vars, sample_inputs):
raise NotImplementedError('TODO implement partial sampling of real variables')


@substitute.register(Gaussian, tuple)
def subs_gaussian(expr, subs):
return expr.eager_subs(tuple((k, to_funsor(v, expr.inputs[k]) if k in expr.inputs else v) for k, v in subs))


@eager.register(Binary, AddOp, Gaussian, Gaussian)
def eager_add_gaussian_gaussian(op, lhs, rhs):
# Fuse two Gaussians by adding their log-densities pointwise.
Expand Down
4 changes: 3 additions & 1 deletion funsor/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def __init__(self, log_measure, integrand, reduced_vars):
for (k, d) in term.inputs.items()
if k not in reduced_vars)
output = integrand.output
super(Integrate, self).__init__(inputs, output)
fresh = frozenset()
bound = reduced_vars
super(Integrate, self).__init__(inputs, output, fresh, bound)
self.log_measure = log_measure
self.integrand = integrand
self.reduced_vars = reduced_vars
Expand Down
13 changes: 11 additions & 2 deletions funsor/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import os
import re
import types
import uuid
from collections import OrderedDict

import torch
Expand All @@ -23,6 +22,11 @@
_INTERPRETATION = None # To be set later in funsor.terms
_USE_TCO = int(os.environ.get("FUNSOR_USE_TCO", 0))

# TODO remove this, used temporarily for testing
_GENERIC_SUBS = int(os.environ.get("FUNSOR_GENERIC_SUBS", 0))

_GENSYM_COUNTER = 0


if _DEBUG:
def interpret(cls, *args):
Expand Down Expand Up @@ -175,9 +179,14 @@ def is_atom(x):


def gensym(x=None):
global _GENSYM_COUNTER
_GENSYM_COUNTER += 1
sym = _GENSYM_COUNTER
if x is not None:
if isinstance(x, str):
return x + "_" + str(sym)
return id(x)
return "V" + str(uuid.uuid4().hex)
return "V" + str(sym)


def stack_reinterpret(x):
Expand Down
6 changes: 6 additions & 0 deletions funsor/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Unary,
Variable,
eager,
substitute,
to_funsor
)
from funsor.torch import Tensor, arange
Expand Down Expand Up @@ -244,6 +245,11 @@ def eager_independent(joint, reals_var, bint_var):
return None # defer to default implementation


@substitute.register(Joint, tuple)
def substitute_joint(expr, subs):
return expr.eager_subs(subs)


################################################################################
# Patterns to update a Joint with other funsors
################################################################################
Expand Down
11 changes: 9 additions & 2 deletions funsor/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import funsor.ops as ops
from funsor.domains import Domain, bint, find_domain
from funsor.terms import Binary, Funsor, FunsorMeta, Number, Subs, eager, to_data, to_funsor
from funsor.terms import Binary, Funsor, FunsorMeta, Number, Subs, eager, substitute, to_data, to_funsor


def align_array(new_inputs, x):
Expand Down Expand Up @@ -91,7 +91,9 @@ def __init__(self, data, inputs=None, dtype="real"):
assert all(isinstance(d.dtype, integer_types) for k, d in inputs)
inputs = OrderedDict(inputs)
output = Domain(data.shape[len(inputs):], dtype)
super(Array, self).__init__(inputs, output)
fresh = frozenset(inputs.keys())
bound = frozenset()
super(Array, self).__init__(inputs, output, fresh, bound)
self.data = data

def __repr__(self):
Expand Down Expand Up @@ -213,6 +215,11 @@ def _to_data_array(x):
return x.data


@substitute.register(Array, tuple)
def subs_gaussian(expr, subs):
return expr.eager_subs(tuple((k, to_funsor(v, expr.inputs[k]) if k in expr.inputs else v) for k, v in subs))


@eager.register(Binary, object, Array, Number)
def eager_binary_array_number(op, lhs, rhs):
if op is ops.getitem:
Expand Down
Loading

0 comments on commit 7e456e6

Please sign in to comment.