Refactor Joint into multivariate Delta and Contraction #169

Merged 49 commits on Jul 31, 2019.
Commits (49, all by eb8680):

24a3fcc  start removing some joint patterns (Jul 19, 2019)
ecd16b6  Merge branch 'contraction-normal-form' into cnf-joint (Jul 19, 2019)
93cf114  Update joint smoke tests (Jul 19, 2019)
ab6d97d  add a rule for permuting joint inputs (Jul 19, 2019)
84fbda2  make more joint and gaussian tests pass (Jul 19, 2019)
ff833ad  Add multidelta term (Jul 24, 2019)
8be7091  remove Joint, integrator, dead code (Jul 24, 2019)
0607930  remove remaining Joint appearances (Jul 24, 2019)
69467dc  remove duplicate test cases (Jul 24, 2019)
aa6d6e2  fix smoke tests (Jul 24, 2019)
95c243b  lint (Jul 24, 2019)
207699d  remove duplicate moment matching test (Jul 24, 2019)
a6f921a  make commutativity pattern less of a hack (Jul 24, 2019)
27b1bfc  fix bug in delta (Jul 24, 2019)
6a655c5  move joint patterns to joint.py (Jul 24, 2019)
0538b47  remove redundant pattern (Jul 24, 2019)
7990b78  Merge branch 'contraction-normal-form' into cnf-joint (Jul 25, 2019)
25526fc  Merge branch 'contraction-normal-form' into cnf-joint (Jul 25, 2019)
6fa538e  remove Delta entirely in favor of MultiDelta (Jul 25, 2019)
8a97689  Merge branch 'contraction-normal-form' into cnf-joint (Jul 25, 2019)
fa9bab4  refactor MultiDelta to have a single log-density tensor (Jul 25, 2019)
d8ee093  Merge branch 'contraction-normal-form' into cnf-joint (Jul 25, 2019)
eb7bace  have Tensor.unscaled_sample return a single MultiDelta (Jul 25, 2019)
1501039  revert Tensor.unscaled_sample to Delta (Jul 26, 2019)
d0a5b9e  fix moment matching (Jul 26, 2019)
92e4dda  Merge branch 'contraction-normal-form' into cnf-joint (Jul 26, 2019)
5495e92  lint (Jul 26, 2019)
bc28853  Merge branch 'contraction-normal-form' into cnf-joint (Jul 26, 2019)
607d029  Merge branch 'contraction-normal-form' into cnf-joint (Jul 26, 2019)
ec511bd  remove incorrect tensor contraction (Jul 26, 2019)
c7b89b7  another attempt at scaling (Jul 27, 2019)
1e557fe  removed faulty pattern that was causing gaussian integration tests to… (Jul 27, 2019)
5b37297  fix one bug in minipyro and expose another (Jul 27, 2019)
0163e49  fix minipyro.Distribution.expand_inputs (Jul 27, 2019)
9a44d44  increase tolerance in sequential_sum_product test (Jul 27, 2019)
ac664c0  fix a couple more tests (Jul 27, 2019)
ec8abc4  fix wrong log pattern (Jul 27, 2019)
970a27c  sketch independent? (Jul 27, 2019)
e78f7fa  All integrate tests pass (Jul 30, 2019)
3700920  Merge branch 'contraction-normal-form' into cnf-joint (Jul 30, 2019)
ac03a44  Add basic align method to Contraction (Jul 30, 2019)
a1f21ed  nit (Jul 30, 2019)
4b6d26f  remove inplace op in reciprocal (Jul 30, 2019)
949ad9f  Merge branch 'contraction-normal-form' into cnf-joint (Jul 30, 2019)
a6c7db7  fix advanced indexing tests (Jul 30, 2019)
08a54d7  fix independent (Jul 31, 2019)
3b961d9  Merge branch 'contraction-normal-form' into cnf-joint (Jul 31, 2019)
a973541  Merge branch 'contraction-normal-form' into cnf-joint (Jul 31, 2019)
d06d11a  Merge branch 'contraction-normal-form' into cnf-joint (Jul 31, 2019)
funsor/cnf.py (6 changes: 3 additions & 3 deletions)
@@ -4,7 +4,7 @@
 from multipledispatch.variadic import Variadic

 import funsor.ops as ops
-from funsor.delta import Delta
+from funsor.delta import MultiDelta
 from funsor.domains import find_domain
 from funsor.gaussian import Gaussian
 from funsor.interpreter import recursion_reinterpret
@@ -164,13 +164,13 @@ def eager_contraction_to_binary(red_op, bin_op, reduced_vars, lhs, rhs):
 # Normalizing Contractions
 ##########################################

-GROUND_TERMS = (Delta, Gaussian, Number, Tensor)
+GROUND_TERMS = (MultiDelta, Gaussian, Number, Tensor)


 @normalize.register(Contraction, AssociativeOp, ops.AddOp, frozenset, GROUND_TERMS, GROUND_TERMS)
 def normalize_contraction_commutative_canonical_order(red_op, bin_op, reduced_vars, *terms):
     # when bin_op is commutative, put terms into a canonical order for pattern matching
-    ordering = {Delta: 1, Number: 2, Tensor: 3, Gaussian: 4}
+    ordering = {MultiDelta: 1, Number: 2, Tensor: 3, Gaussian: 4}
     new_terms = tuple(
         v for i, v in sorted(enumerate(terms),
                              key=lambda t: (ordering.get(type(t[1]), -1), t[0]))
     )
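
The normalize rule above relies only on a stable sort key. A minimal standalone sketch of the same ordering logic, using plain Python stand-in classes rather than the real funsor terms (canonical_order is an illustrative name, not funsor API):

# Sketch of the canonical-ordering rule; the classes are stand-ins.
class MultiDelta: pass
class Number: pass
class Tensor: pass
class Gaussian: pass

ordering = {MultiDelta: 1, Number: 2, Tensor: 3, Gaussian: 4}

def canonical_order(*terms):
    # Sort by (type rank, original position); unknown types rank -1 and
    # therefore sort first, and the position tiebreak keeps the sort stable.
    return tuple(v for i, v in sorted(enumerate(terms),
                                      key=lambda t: (ordering.get(type(t[1]), -1), t[0])))

print([type(t).__name__ for t in canonical_order(Tensor(), Gaussian(), MultiDelta(), Number())])
# ['MultiDelta', 'Number', 'Tensor', 'Gaussian']

Sorting on the pair (type rank, original index) means commutative patterns only ever need to match one canonical arrangement of ground terms.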
funsor/delta.py (298 changes: 174 additions & 124 deletions)
@@ -1,9 +1,10 @@
+import math
 from collections import OrderedDict

 import funsor.ops as ops
 import funsor.terms
 from funsor.domains import Domain, reals
-from funsor.integrate import Integrate, integrator
+from funsor.integrate import Integrate
 from funsor.interpreter import debug_logged
 from funsor.ops import AddOp, SubOp, TransformOp
 from funsor.registry import KeyedRegistry
@@ -12,8 +13,6 @@
     Binary,
     Funsor,
     FunsorMeta,
-    Independent,
-    Lambda,
     Number,
     Reduce,
     Subs,
@@ -24,127 +23,10 @@
 )


-class DeltaMeta(FunsorMeta):
-    """
-    Wrapper to fill in defaults.
-    """
-    def __call__(cls, name, point, log_density=0):
-        point = to_funsor(point)
-        log_density = to_funsor(log_density)
-        return super(DeltaMeta, cls).__call__(name, point, log_density)
-
-
-class Delta(Funsor, metaclass=DeltaMeta):
-    """
-    Normalized delta distribution binding a single variable.
-
-    :param str name: Name of the bound variable.
-    :param Funsor point: Value of the bound variable.
-    :param Funsor log_density: Optional log density to be added when evaluating
-        at a point. This is needed to make :class:`Delta` closed under
-        differentiable substitution.
-    """
-    def __init__(self, name, point, log_density=0):
-        assert isinstance(name, str)
-        assert isinstance(point, Funsor)
-        assert isinstance(log_density, Funsor)
-        assert log_density.output == reals()
-        inputs = OrderedDict([(name, point.output)])
-        inputs.update(point.inputs)
-        inputs.update(log_density.inputs)
-        output = reals()
-        fresh = frozenset({name})
-        bound = frozenset()
-        super(Delta, self).__init__(inputs, output, fresh, bound)
-        self.name = name
-        self.point = point
-        self.log_density = log_density
-
-    def eager_subs(self, subs):
-        assert len(subs) == 1 and subs[0][0] == self.name
-        value = subs[0][1]
-
-        if isinstance(value, (str, Variable)):
-            value = to_funsor(value, self.output)
-            return Delta(value.name, self.point, self.log_density)
-
-        if not any(d.dtype == 'real' for side in (value, self.point)
-                   for d in side.inputs.values()):
-            return (value == self.point).all().log() + self.log_density
-
-        # Try to invert the substitution.
-        soln = solve(value, self.point)
-        if soln is None:
-            return None  # lazily substitute
-        name, point, log_density = soln
-        log_density += self.log_density
-        return Delta(name, point, log_density)
-
-    def eager_reduce(self, op, reduced_vars):
-        if op is ops.logaddexp:
-            if self.name in reduced_vars:
-                return Number(0)  # Deltas are normalized.
-
-        # TODO Implement ops.add to simulate .to_event().
-
-        return None  # defer to default implementation
-
-
-@eager.register(Binary, AddOp, Delta, (Funsor, Align))
-def eager_add(op, lhs, rhs):
-    if lhs.name in rhs.inputs:
-        rhs = rhs(**{lhs.name: lhs.point})
-        return op(lhs, rhs)
-
-    return None  # defer to default implementation
-
-
-@eager.register(Binary, SubOp, Delta, (Funsor, Align))
-def eager_sub(op, lhs, rhs):
-    if lhs.name in rhs.inputs:
-        rhs = rhs(**{lhs.name: lhs.point})
-        return op(lhs, rhs)
-
-    return None  # defer to default implementation
-
-
-@eager.register(Binary, AddOp, (Funsor, Align), Delta)
-def eager_add(op, lhs, rhs):
-    if rhs.name in lhs.inputs:
-        lhs = lhs(**{rhs.name: rhs.point})
-        return op(lhs, rhs)
-
-    return None  # defer to default implementation
-
-
-eager.register(Binary, AddOp, Delta, Reduce)(
-    funsor.terms.eager_distribute_other_reduce)
-eager.register(Binary, AddOp, Reduce, Delta)(
-    funsor.terms.eager_distribute_reduce_other)
-
-
-@eager.register(Independent, Delta, str, str)
-def eager_independent(delta, reals_var, bint_var):
-    if delta.name == reals_var or delta.name.startswith(reals_var + "__BOUND"):
-        i = Variable(bint_var, delta.inputs[bint_var])
-        point = Lambda(i, delta.point)
-        if bint_var in delta.log_density.inputs:
-            log_density = delta.log_density.reduce(ops.add, bint_var)
-        else:
-            log_density = delta.log_density * delta.inputs[bint_var].dtype
-        return Delta(reals_var, point, log_density)
-
-    return None  # defer to default implementation
-
-
-@eager.register(Integrate, Delta, Funsor, frozenset)
-@integrator
-def eager_integrate(delta, integrand, reduced_vars):
-    assert delta.name in reduced_vars
-    integrand = Subs(integrand, ((delta.name, delta.point),))
-    log_measure = delta.log_density
-    reduced_vars -= frozenset([delta.name])
-    return Integrate(log_measure, integrand, reduced_vars)
+def Delta(name, point, log_density=0):
+    """Syntactic sugar for MultiDelta"""
+    point, log_density = to_funsor(point), to_funsor(log_density)
+    return MultiDelta(((name, point),), log_density)


 def solve(expr, value):
@@ -189,7 +71,175 @@ def solve_unary(op, arg, y):
     return name, point, log_density


+class MultiDeltaMeta(FunsorMeta):
+    """
+    Wrapper to fill in defaults.
+    """
+    def __call__(cls, terms, log_density=0):
+        terms = tuple(terms.items()) if isinstance(terms, OrderedDict) else terms
+        terms = tuple((name, to_funsor(point)) for name, point in terms)
+        log_density = to_funsor(log_density)
+        return super(MultiDeltaMeta, cls).__call__(terms, log_density)
+
+
+class MultiDelta(Funsor, metaclass=MultiDeltaMeta):
+    """
+    Normalized delta distribution binding multiple variables.
+    Represents joint log-density of all points with a single Tensor.
+    """
+    def __init__(self, terms, log_density):
+        assert isinstance(terms, tuple) and len(terms) > 0
+        assert isinstance(log_density, Funsor)
+        assert log_density.output == reals()
+        inputs = log_density.inputs.copy()
+        for name, point in terms:
+            assert isinstance(name, str)
+            assert isinstance(point, Funsor)
+            assert name not in inputs
+            assert name not in point.inputs
+            inputs.update({name: point.output})
+            inputs.update(point.inputs)
+
+        output = log_density.output
+        fresh = frozenset(name for name, point in terms)
+        bound = frozenset()
+        super(MultiDelta, self).__init__(inputs, output, fresh, bound)
+        self.terms = terms
+        self.log_density = log_density
+
+    def align(self, names):
+        assert isinstance(names, tuple)
+        assert all(name in self.fresh for name in names)
+        if not names or names == tuple(n for n, p in self.terms):
+            return self
+
+        new_terms = sorted(self.terms, key=lambda t: names.index(t[0]))
+        return MultiDelta(new_terms, self.log_density)
+
+    def eager_subs(self, subs):
+        terms = OrderedDict(self.terms)
+        new_terms = terms.copy()
+        log_density = self.log_density
+        for name, value in subs:
+            if isinstance(value, (str, Variable)):
+                value = to_funsor(value, self.output)
+                new_terms[value.name] = new_terms.pop(name)
+                continue
+
+            if not any(d.dtype == 'real' for side in (value, terms[name])
+                       for d in side.inputs.values()):
+                point = new_terms.pop(name)
+                log_density += (value == point).all().log()
+                continue
+
+            # Try to invert the substitution.
+            soln = solve(value, terms[name])
+            if soln is None:
+                return None  # lazily substitute
+            new_name, point, point_log_density = soln
+            log_density += point_log_density
+            new_terms.pop(name)
+            new_terms[new_name] = point
+
+        return MultiDelta(new_terms, log_density) if new_terms else log_density
+
+    def eager_reduce(self, op, reduced_vars):
+        if op is ops.logaddexp:
+            if reduced_vars - self.fresh and self.fresh - reduced_vars:
+                result = self.eager_reduce(op, reduced_vars & self.fresh) if reduced_vars & self.fresh else self
+                if result is not self:
+                    result = result.eager_reduce(op, reduced_vars - self.fresh) if reduced_vars - self.fresh else self
+                    return result if result is not self else None
+                return None
+
+            result = Subs(self, tuple((name, point) for name, point in self.terms if name in reduced_vars))
+            reduced_vars -= frozenset(name for name, point in self.terms)
+            if isinstance(result, MultiDelta):
+                terms = []
+                for name, point in result.terms:
+                    if reduced_vars.intersection(point.inputs):
+                        point_reduced_vars = reduced_vars.intersection(
+                            frozenset(point.inputs) | frozenset(result.log_density.inputs))
+                        point = Integrate(result.log_density, point, point_reduced_vars)
+                    terms.append((name, point))
+
+                # rescale the log_density to account for reduced vars that only appeared in points
+                scale1 = Number(
+                    sum([math.log(self.inputs[v].dtype) for v in reduced_vars.difference(result.log_density.inputs)]))
+
+                # rescale the output term to account for non-point reduced_vars
+                scale2 = Number(
+                    sum([math.log(self.inputs[v].dtype) for v in reduced_vars.intersection(result.log_density.inputs)]))
+
+                log_density = result.log_density.reduce(op, reduced_vars.intersection(result.log_density.inputs))
+                final_result = MultiDelta(tuple(terms), log_density - scale1) + (scale1 + scale2)
+                return final_result
+            else:
+                value = Number(sum([math.log(self.inputs[v].dtype) for v in reduced_vars]))
+                log_density = result.reduce(op, reduced_vars.intersection(result.inputs))
+                final_result = value + (0. * log_density) if log_density.inputs else value
+                return final_result
+
+        if op is ops.add:
+            raise NotImplementedError("TODO Implement ops.add to simulate .to_event().")
+
+        return None  # defer to default implementation
+
+    def unscaled_sample(self, sampled_vars, sample_inputs):
+        if sampled_vars <= self.fresh:
+            return self
+        raise NotImplementedError("TODO implement sample for particle indices")
+
+
+eager.register(Binary, AddOp, MultiDelta, Reduce)(
+    funsor.terms.eager_distribute_other_reduce)
+eager.register(Binary, AddOp, Reduce, MultiDelta)(
+    funsor.terms.eager_distribute_reduce_other)
+
+
+@eager.register(Binary, AddOp, MultiDelta, MultiDelta)
+def eager_add_multidelta(op, lhs, rhs):
+    if lhs.fresh.intersection(rhs.inputs):
+        return eager_add_delta_funsor(op, lhs, rhs)
+
+    if rhs.fresh.intersection(lhs.inputs):
+        return eager_add_funsor_delta(op, lhs, rhs)
+
+    return MultiDelta(lhs.terms + rhs.terms, lhs.log_density + rhs.log_density)
+
+
+@eager.register(Binary, (AddOp, SubOp), MultiDelta, (Funsor, Align))
+def eager_add_delta_funsor(op, lhs, rhs):
+    if lhs.fresh.intersection(rhs.inputs):
+        rhs = rhs(**{name: point for name, point in lhs.terms if name in rhs.inputs})
+        return op(lhs, rhs)
+
+    return None  # defer to default implementation
+
+
+@eager.register(Binary, AddOp, (Funsor, Align), MultiDelta)
+def eager_add_funsor_delta(op, lhs, rhs):
+    if rhs.fresh.intersection(lhs.inputs):
+        lhs = lhs(**{name: point for name, point in rhs.terms if name in lhs.inputs})
+        return op(lhs, rhs)
+
+    return None
+
+
+@eager.register(Integrate, MultiDelta, Funsor, frozenset)
+def eager_integrate(delta, integrand, reduced_vars):
+    if not reduced_vars & delta.fresh:
+        return None
+    subs = tuple((name, point) for name, point in delta.terms
+                 if name in reduced_vars)
+    new_integrand = Subs(integrand, subs)
+    new_log_measure = Subs(delta, subs)
+    result = Integrate(new_log_measure, new_integrand, reduced_vars - delta.fresh)
+    return result
+
+
 __all__ = [
     'Delta',
+    'MultiDelta',
     'solve',
 ]
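
The core of the refactor is visible in eager_add_multidelta above: two MultiDelta terms over disjoint variables combine by concatenating their term tuples and adding their log-densities, which is what lets a single MultiDelta stand in for the old Joint of several Deltas. A rough sketch of that merge rule, with plain tuples and floats standing in for funsor terms (delta and add_multidelta are illustrative stand-ins, not funsor API):

# Sketch of the MultiDelta merge semantics from eager_add_multidelta.
def delta(name, point, log_density=0.0):
    # Mirrors the Delta() sugar above: a single-term "MultiDelta".
    return ((name, point),), log_density

def add_multidelta(lhs, rhs):
    (lhs_terms, lhs_ld), (rhs_terms, rhs_ld) = lhs, rhs
    # The eager pattern only fires when neither side's fresh variables
    # appear in the other's inputs; here we just check name-disjointness.
    assert {n for n, _ in lhs_terms}.isdisjoint(n for n, _ in rhs_terms)
    # Concatenate term tuples and add the joint log-densities.
    return lhs_terms + rhs_terms, lhs_ld + rhs_ld

terms, log_density = add_multidelta(delta("x", 1.0), delta("y", 2.0, -0.5))
print(terms, log_density)  # (('x', 1.0), ('y', 2.0)) -0.5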
funsor/gaussian.py (8 changes: 3 additions & 5 deletions)
@@ -9,7 +9,7 @@
 import funsor.ops as ops
 from funsor.delta import Delta
 from funsor.domains import reals
-from funsor.integrate import Integrate, integrator
+from funsor.integrate import Integrate
 from funsor.ops import AddOp, NegOp, SubOp
 from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, Subs, Unary, Variable, eager, reflect, to_funsor
 from funsor.torch import Tensor, align_tensor, align_tensors, materialize
@@ -593,11 +593,10 @@ def eager_neg(op, arg):


 @eager.register(Integrate, Gaussian, Variable, frozenset)
-@integrator
 def eager_integrate(log_measure, integrand, reduced_vars):
     real_vars = frozenset(k for k in reduced_vars if log_measure.inputs[k].dtype == 'real')
-    if real_vars:
-        assert real_vars == frozenset([integrand.name])
+    if real_vars == frozenset([integrand.name]):
+        # assert real_vars == frozenset([integrand.name])
         data = log_measure.loc * log_measure._log_normalizer.data.exp().unsqueeze(-1)
         data = data.reshape(log_measure.loc.shape[:-1] + integrand.output.shape)
         inputs = OrderedDict((k, d) for k, d in log_measure.inputs.items() if d.dtype != 'real')
@@ -607,7 +606,6 @@ def eager_integrate(log_measure, integrand, reduced_vars):


 @eager.register(Integrate, Gaussian, Gaussian, frozenset)
-@integrator
 def eager_integrate(log_measure, integrand, reduced_vars):
     real_vars = frozenset(k for k in reduced_vars if log_measure.inputs[k].dtype == 'real')
     if real_vars:
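
The gaussian.py change above also illustrates the dispatch convention used throughout this PR: rather than asserting that a case holds, an eager handler guards on it and returns None, which tells the interpreter to fall back to a lazy term. A toy sketch of that convention (the function name and return values here are hypothetical):

# Sketch of the "return None to defer" convention in eager patterns.
def eager_integrate_variable(real_vars, integrand_name):
    if real_vars == frozenset([integrand_name]):
        return "closed-form Gaussian integral"  # eager result
    return None  # defer: the interpreter keeps a lazy Integrate term instead

assert eager_integrate_variable(frozenset({"x"}), "x") is not None
assert eager_integrate_variable(frozenset({"x", "y"}), "x") is None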