Resurrect lazy Subs funsor (again) (#99)
* Sketch Monte Carlo interpretation of logaddexp reduction

* Use AssociativeOp in patterns

* Fix op pattern matcher

* Try eager before monte_carlo

* Drop ops.sample, ops.marginal

* Sketch VAE example using monte carlo interpretation

* Refactor, focusing on .sample() and .monte_carlo_logsumexp() methods

* Fix vae example

* Sketch Tensor.sample() (untested)

* Fix cyclic import

* Sketch Gaussian.sample() (untested)

* Implement Delta.sample()

* Sketch Expectation class

* Sketch sampler implementations

* Delete Expectation in favor of Integrate in a separate PR

* Revert .sample() sketch

* Update VAE example to use multi-output Functions

* Fix reductions in VAE

* Sketch support for multiple args in __getitem__

* Fix bugs in getitem_tensor_tensor

* Add stronger tests for tensor getitem

* Add support for string indexing

* Simplify vae example using multi-getitem

* Add stub for Integrate

* Fix typo

* Sketch monte_carlo registration of Gaussian-Gaussian things

* Add stubs for Joint integration

* Fix typos

* Sketch support for multiple samples

* Fix test usage of registry

* Fix bugs in gaussian integral

* Handle scale factors in Funsor.sample()

* Use Integrate in test_samplers.py

* Fix bug in Integrate; be less clever

* Add implementations of gaussian-linear integrals

* Add interpretation logging controlled by FUNSOR_DEBUG

* Simplify debug printing

* Fix lazy reduction for Joint.reduce()

* Fix recursion bug

* Get univariate Gaussian sampling to mostly work

* Fix bug in Tensor.eager_reduce with nontrivial output

* Fix output shape broadcasting in Tensor

* Fix assert_close in test_samplers.py

* Fix cholesky bugs

* Fix bug in _trace_mm()

* Fixes for examples/vae.py

* Remove examples/vae.py

* Add docstrings

* Resurrect lazy Subs funsor (again)

* Fix typo

* Allow completely lazy eager_subs method
fritzo authored and eb8680 committed Mar 26, 2019
1 parent 15b0c73 commit 951630c
Showing 9 changed files with 95 additions and 61 deletions.
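
The pattern running through every file below is mechanical: each direct call to x.eager_subs(subs) becomes construction of a Subs(x, subs) term. The point is that constructing Subs routes through the interpretation machinery, so a non-eager interpretation can keep the substitution as an unevaluated node instead of forcing it immediately. A minimal sketch of the intended behavior, assuming the funsor.terms and funsor.interpreter APIs imported in the diffs below, and assuming (per the commit title) that the lazy interpretation keeps Subs unevaluated:

    from funsor.domains import reals
    from funsor.interpreter import interpretation, reinterpret
    from funsor.terms import Number, Subs, Variable, lazy

    x = Variable('x', reals())
    expr = x * x

    # Under the default eager interpretation, Subs substitutes immediately.
    print(Subs(expr, (('x', Number(2.)),)))   # expect Number(4.0)

    # Under the lazy interpretation, the same constructor should return an
    # unevaluated Subs node that reinterpret() evaluates later.
    with interpretation(lazy):
        deferred = Subs(expr, (('x', Number(2.)),))
    print(type(deferred))                     # expect Subs
    print(reinterpret(deferred))              # expect Number(4.0)
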
7 changes: 4 additions & 3 deletions funsor/contract.py
@@ -6,7 +6,7 @@
 import funsor.ops as ops
 from funsor.optimizer import Finitary, optimize
 from funsor.sum_product import _partition
-from funsor.terms import Funsor, eager
+from funsor.terms import Funsor, Subs, eager
 
 
 def _order_lhss(lhs, reduced_vars):
@@ -75,8 +75,9 @@ def eager_subs(self, subs):
             return self
         if not all(self.reduced_vars.isdisjoint(v.inputs) for k, v in subs):
             raise NotImplementedError('TODO alpha-convert to avoid conflict')
-        return Contract(self.lhs.eager_subs(subs), self.rhs.eager_subs(subs),
-                        self.reduced_vars)
+        lhs = Subs(self.lhs, subs)
+        rhs = Subs(self.rhs, subs)
+        return Contract(lhs, rhs, self.reduced_vars)
 
 
 @eager.register(Contract, Funsor, Funsor, frozenset)
8 changes: 4 additions & 4 deletions funsor/delta.py
@@ -8,7 +8,7 @@
 from funsor.domains import reals
 from funsor.integrate import Integrate, integrator
 from funsor.ops import Op
-from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, Variable, eager, to_funsor
+from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, Subs, Variable, eager, to_funsor
 
 
 class DeltaMeta(FunsorMeta):
@@ -61,8 +61,8 @@ def eager_subs(self, subs):
             return self
 
         name = self.name
-        point = self.point.eager_subs(index_part)
-        log_density = self.log_density.eager_subs(index_part)
+        point = Subs(self.point, index_part)
+        log_density = Subs(self.log_density, index_part)
         if value is not None:
             if isinstance(value, Variable):
                 name = value.name
@@ -112,7 +112,7 @@ def eager_binary(op, lhs, rhs):
 @integrator
 def eager_integrate(delta, integrand, reduced_vars):
     assert delta.name in reduced_vars
-    integrand = integrand.eager_subs(((delta.name, delta.point),))
+    integrand = Subs(integrand, ((delta.name, delta.point),))
     log_measure = delta.log_density
     reduced_vars -= frozenset([delta.name])
     return Integrate(log_measure, integrand, reduced_vars)
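
The eager_integrate rule above is the sifting property of a point mass: integrating an integrand f against Delta(name, point) evaluates f at the point, i.e. ∫ f(x) δ(x − p) dx = f(p), with the delta's log_density carried along as the log measure over the remaining reduced_vars. Routing the f(p) substitution through Subs(integrand, ((delta.name, delta.point),)) means even this evaluation can stay lazy under a non-eager interpretation.
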
4 changes: 2 additions & 2 deletions funsor/distributions.py
@@ -11,7 +11,7 @@
 import funsor.ops as ops
 from funsor.domains import bint, reals
 from funsor.gaussian import Gaussian
-from funsor.terms import Funsor, FunsorMeta, Number, Variable, eager, to_funsor
+from funsor.terms import Funsor, FunsorMeta, Number, Subs, Variable, eager, to_funsor
 from funsor.torch import Tensor, align_tensors, materialize
 
 
@@ -69,7 +69,7 @@ def eager_subs(self, subs):
         assert isinstance(subs, tuple)
         if not any(k in self.inputs for k, v in subs):
             return self
-        params = OrderedDict((k, v.eager_subs(subs)) for k, v in self.params)
+        params = OrderedDict((k, Subs(v, subs)) for k, v in self.params)
         return type(self)(**params)
 
     def eager_reduce(self, op, reduced_vars):
6 changes: 3 additions & 3 deletions funsor/gaussian.py
@@ -14,7 +14,7 @@
 from funsor.integrate import Integrate, integrator
 from funsor.montecarlo import monte_carlo
 from funsor.ops import AddOp
-from funsor.terms import Binary, Funsor, FunsorMeta, Number, Variable, eager
+from funsor.terms import Binary, Funsor, FunsorMeta, Number, Subs, Variable, eager
 from funsor.torch import Tensor, align_tensor, align_tensors, materialize
 from funsor.util import lazy_property
 
@@ -196,11 +196,11 @@ def eager_subs(self, subs):
             int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
             real_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype == 'real')
             tensors = [self.loc, self.precision]
-            funsors = [Tensor(x, int_inputs).eager_subs(int_subs) for x in tensors]
+            funsors = [Subs(Tensor(x, int_inputs), int_subs) for x in tensors]
             inputs = funsors[0].inputs.copy()
             inputs.update(real_inputs)
             int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
-            return int_result.eager_subs(real_subs)
+            return Subs(int_result, real_subs)
 
         # Try to perform a complete substitution of all real variables, resulting in a Tensor.
         assert real_subs and not int_subs
22 changes: 11 additions & 11 deletions funsor/joint.py
@@ -13,7 +13,7 @@
 from funsor.integrate import Integrate, integrator
 from funsor.montecarlo import monte_carlo
 from funsor.ops import AddOp, Op
-from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, Variable, eager, to_funsor
+from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, Subs, Variable, eager, to_funsor
 from funsor.torch import Tensor, arange
 
 
@@ -61,13 +61,13 @@ def __init__(self, deltas, discrete, gaussian):
         self.gaussian = gaussian
 
     def eager_subs(self, subs):
-        gaussian = self.gaussian.eager_subs(subs)
+        gaussian = Subs(self.gaussian, subs)
         assert isinstance(gaussian, (Number, Tensor, Gaussian))
-        discrete = self.discrete.eager_subs(subs)
-        gaussian = self.gaussian.eager_subs(subs)
+        discrete = Subs(self.discrete, subs)
+        gaussian = Subs(self.gaussian, subs)
         deltas = []
         for x in self.deltas:
-            x = x.eager_subs(subs)
+            x = Subs(x, subs)
             if isinstance(x, Delta):
                 deltas.append(x)
             elif isinstance(x, (Number, Tensor)):
@@ -174,12 +174,12 @@ def eager_add(op, joint, other):
 def eager_add(op, joint, delta):
     # Update with a degenerate distribution, typically a monte carlo sample.
     if delta.name in joint.inputs:
-        joint = joint.eager_subs(((delta.name, delta.point),))
+        joint = Subs(joint, ((delta.name, delta.point),))
         if not isinstance(joint, Joint):
             return joint + delta
     for d in joint.deltas:
         if d.name in delta.inputs:
-            delta = delta.eager_subs(((d.name, d.point),))
+            delta = Subs(delta, ((d.name, d.point),))
     deltas = joint.deltas + (delta,)
     return Joint(deltas, joint.discrete, joint.gaussian)
 
@@ -189,7 +189,7 @@ def eager_add(op, joint, other):
     # Update with a delayed discrete random variable.
     subs = tuple((d.name, d.point) for d in joint.deltas if d in other.inputs)
     if subs:
-        return joint + other.eager_subs(subs)
+        return joint + Subs(other, subs)
     return Joint(joint.deltas, joint.discrete + other, joint.gaussian)
 
 
@@ -206,7 +206,7 @@ def eager_add(op, joint, other):
     # Update with a delayed gaussian random variable.
     subs = tuple((d.name, d.point) for d in joint.deltas if d.name in other.inputs)
     if subs:
-        other = other.eager_subs(subs)
+        other = Subs(other, subs)
     if joint.gaussian is not Number(0):
         other = joint.gaussian + other
     if not isinstance(other, Gaussian):
@@ -238,7 +238,7 @@ def eager_add(op, lhs, rhs):
 @eager.register(Binary, AddOp, Delta, (Number, Tensor, Gaussian))
 def eager_add(op, delta, other):
     if delta.name in other.inputs:
-        other = other.eager_subs(((delta.name, delta.point),))
+        other = Subs(other, ((delta.name, delta.point),))
         assert isinstance(other, (Number, Tensor, Gaussian))
     if isinstance(other, (Number, Tensor)):
         return Joint((delta,), discrete=other)
@@ -282,7 +282,7 @@ def _simplify_integrate(fn, joint, integrand, reduced_vars):
     subs = tuple((d.name, d.point) for d in joint.deltas if d.name in reduced_vars)
     deltas = tuple(d for d in joint.deltas if d.name not in reduced_vars)
     log_measure = Joint(deltas, joint.discrete, joint.gaussian)
-    integrand = integrand.eager_subs(subs)
+    integrand = Subs(integrand, subs)
     reduced_vars = reduced_vars - frozenset(name for name, point in subs)
     return Integrate(log_measure, integrand, reduced_vars)
4 changes: 2 additions & 2 deletions funsor/numpy.py
@@ -8,7 +8,7 @@
 
 import funsor.ops as ops
 from funsor.domains import Domain, bint, find_domain
-from funsor.terms import Binary, Funsor, FunsorMeta, Number, eager, to_data, to_funsor
+from funsor.terms import Binary, Funsor, FunsorMeta, Number, Subs, eager, to_data, to_funsor
 
 
 def align_array(new_inputs, x):
@@ -282,4 +282,4 @@ def materialize(x):
         assert not domain.shape
         subs.append((name, arange(name, domain.dtype)))
     subs = tuple(subs)
-    return x.eager_subs(subs)
+    return Subs(x, subs)
14 changes: 7 additions & 7 deletions funsor/optimizer.py
@@ -8,7 +8,7 @@
 from funsor.domains import find_domain
 from funsor.interpreter import dispatched_interpretation, interpretation, reinterpret
 from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp
-from funsor.terms import Binary, Funsor, Reduce, eager, reflect
+from funsor.terms import Binary, Funsor, Reduce, Subs, eager, lazy
 
 
 class Finitary(Funsor):
@@ -37,7 +37,7 @@ def __repr__(self):
     def eager_subs(self, subs):
         if not any(k in self.inputs for k, v in subs):
             return self
-        operands = tuple(operand.eager_subs(subs) for operand in self.operands)
+        operands = tuple(Subs(operand, subs) for operand in self.operands)
         return Finitary(self.op, operands)
 
 
@@ -50,7 +50,7 @@ def eager_finitary(op, operands):
 def associate(cls, *args):
     result = associate.dispatch(cls, *args)
     if result is None:
-        result = reflect(cls, *args)
+        result = lazy(cls, *args)
     return result
 
 
@@ -70,7 +70,7 @@ def associate_finitary(op, operands):
         else:
             new_operands.append(term)
 
-    with interpretation(reflect):
+    with interpretation(lazy):
        return Finitary(op, tuple(new_operands))
 
 
@@ -91,7 +91,7 @@ def associate_reduce(op, arg, reduced_vars):
 def distribute(cls, *args):
     result = distribute.dispatch(cls, *args)
     if result is None:
-        result = reflect(cls, *args)
+        result = lazy(cls, *args)
     return result
 
 
@@ -130,7 +130,7 @@ def distribute_finitary(op, operands):
 def optimize(cls, *args):
     result = optimize.dispatch(cls, *args)
     if result is None:
-        result = reflect(cls, *args)
+        result = lazy(cls, *args)
     return result
 
 
@@ -222,7 +222,7 @@ def remove_single_finitary(op, operands):
 def desugar(cls, *args):
     result = desugar.dispatch(cls, *args)
     if result is None:
-        result = reflect(cls, *args)
+        result = lazy(cls, *args)
     return result
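
The four dispatched interpretations above (associate, distribute, optimize, and desugar) share one shape: try the rules registered for this interpretation, and fall back for any term the rules do not match. This commit changes that fallback from reflect, which builds the raw term, to lazy, so unmatched terms (including the resurrected Subs) still get the lazy interpretation's handling. A minimal sketch of the shape, with a hypothetical interpretation name:

    from funsor.interpreter import dispatched_interpretation
    from funsor.terms import lazy

    @dispatched_interpretation
    def my_interpretation(cls, *args):
        # Try this interpretation's registered patterns first ...
        result = my_interpretation.dispatch(cls, *args)
        if result is None:
            # ... then fall back to lazy rather than raw reflect.
            result = lazy(cls, *args)
        return result

    # Rules are then attached via my_interpretation.register(...), mirroring
    # how eager.register(...) is used throughout the diffs above.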