Implement Monte Carlo interpretation of Integrate #54

Merged: 63 commits merged on Mar 26, 2019
Changes shown are from 61 commits.

Commits (63)
6d3c7fd
Sketch Monte Carlo interpretation of logaddexp reduction
fritzo Mar 7, 2019
dbe80dc
Merge branch 'master' into montecarlo
fritzo Mar 9, 2019
95d030f
Use AssociativeOp in patterns
fritzo Mar 9, 2019
ca3fdcc
Merge branch 'master' into montecarlo
fritzo Mar 13, 2019
a4a0ea7
Fix op pattern matcher
fritzo Mar 13, 2019
1df4425
Try eager before monte_carlo
fritzo Mar 13, 2019
c68e45b
Drop ops.sample, ops.marginal
fritzo Mar 13, 2019
4060457
Sketch VAE example using monte carlo interpretation
fritzo Mar 13, 2019
0bb8eba
Refactor, focusing on .sample() and .monte_carlo_logsumexp() methods
fritzo Mar 13, 2019
a799bc7
Fix vae example
fritzo Mar 13, 2019
a34bd82
Sketch Tensor.sample() (untested)
fritzo Mar 14, 2019
4dbe8b6
Fix cyclic import
fritzo Mar 14, 2019
3792e87
Sketch Gaussian.sample() (untested)
fritzo Mar 14, 2019
80eda9d
Implement Delta.sample()
fritzo Mar 14, 2019
3fa6956
Sketch Expectation class
fritzo Mar 14, 2019
4e75094
Sketch sampler implementations
fritzo Mar 14, 2019
5121478
Delete Expectation in favor of Integrate in a separate PR
fritzo Mar 15, 2019
bcbe76e
Revert .sample() sketch
fritzo Mar 15, 2019
eef211f
Update VAE example to use multi-output Functions
fritzo Mar 20, 2019
361fff8
Fix reductions in VAE
fritzo Mar 20, 2019
91fd5b2
Merge branch 'master' into montecarlo
fritzo Mar 20, 2019
ea587e3
Sketch support for multiple args in __getitem__
fritzo Mar 20, 2019
c41c6cb
Fix bugs in getitem_tensor_tensor
fritzo Mar 21, 2019
5920641
Add stronger tests for tensor getitem
fritzo Mar 21, 2019
2339f27
Add support for string indexing
fritzo Mar 21, 2019
7a692b1
Simplify vae example using multi-getitem
fritzo Mar 21, 2019
bbdca68
Add stub for Integrate
fritzo Mar 21, 2019
6193291
Merge branch 'master' into montecarlo
fritzo Mar 21, 2019
ab8107f
Fix typo
fritzo Mar 21, 2019
77f7f04
Sketch monte_carlo registration of Gaussian-Gaussian things
fritzo Mar 21, 2019
01ba6ad
Add stubs for Joint integration
fritzo Mar 21, 2019
db97b92
Fix typos
fritzo Mar 21, 2019
231967c
Sketch support for multiple samples
fritzo Mar 21, 2019
8e06fe4
Merge branch 'master' into montecarlo
fritzo Mar 21, 2019
339aa17
Fix test usage of registry
fritzo Mar 21, 2019
f5d189b
Fix bugs in gaussian integral
fritzo Mar 21, 2019
cb85e1f
Merge branch 'master' into montecarlo
fritzo Mar 21, 2019
ddcc70d
Merge branch 'master' into multi-getitem
fritzo Mar 21, 2019
f84c473
Handle scale factors in Funsor.sample()
fritzo Mar 22, 2019
801d0ff
Use Integrate in test_samplers.py
fritzo Mar 22, 2019
5d5f53e
Fix bug in Integrate; be less clever
fritzo Mar 22, 2019
d498604
Merge branch 'multi-getitem' into montecarlo
fritzo Mar 22, 2019
657779c
Add implementations of gaussian-linear integrals
fritzo Mar 22, 2019
7de197c
Add interpretation logging controlled by FUNSOR_DEBUG
fritzo Mar 22, 2019
853cd2b
Simplify debug printing
fritzo Mar 22, 2019
1588796
Merge branch 'debug-interp' into montecarlo
fritzo Mar 22, 2019
6688931
Fix lazy reduction for Joint.reduce()
fritzo Mar 22, 2019
58d693d
Merge branch 'master' into montecarlo
fritzo Mar 22, 2019
91dda7f
Merge branch 'fix-joint-reduce-lazy' into montecarlo
fritzo Mar 22, 2019
868fb66
Fix recursion bug
fritzo Mar 22, 2019
5fa8a6f
Merge branch 'fix-joint-reduce-lazy' into montecarlo
fritzo Mar 22, 2019
95f7ebc
Get univariate Gaussian sampling to mostly work
fritzo Mar 23, 2019
5ed4836
Fix bug in Tensor.eager_reduce with nontrivial output
fritzo Mar 23, 2019
78ad7b6
Fix output shape broadcasting in Tensor
fritzo Mar 23, 2019
c476a8e
Fix assert_close in test_samplers.py
fritzo Mar 23, 2019
1758620
Merge branch 'master' into montecarlo
fritzo Mar 23, 2019
1576534
Fix cholesky bugs
fritzo Mar 24, 2019
f10e09b
Fix bug in _trace_mm()
fritzo Mar 24, 2019
39586ad
Fixes for examples/vae.py
fritzo Mar 24, 2019
bee75b9
Remove examples/vae.py
fritzo Mar 24, 2019
7784fbb
Add docstrings
fritzo Mar 24, 2019
6f2d453
Updates per review
fritzo Mar 25, 2019
56ca93f
Revert accidental change
fritzo Mar 25, 2019
8 changes: 7 additions & 1 deletion funsor/__init__.py
@@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function

from funsor.domains import Domain, bint, find_domain, reals
from funsor.integrate import Integrate
from funsor.interpreter import reinterpret
from funsor.terms import Funsor, Number, Variable, of_shape, to_data, to_funsor
from funsor.torch import Tensor, arange, torch_einsum
@@ -14,9 +15,11 @@
einsum,
gaussian,
handlers,
integrate,
interpreter,
joint,
minipyro,
montecarlo,
ops,
sum_product,
terms,
@@ -26,24 +29,27 @@
__all__ = [
'Domain',
'Funsor',
'Integrate',
'Number',
'Tensor',
'Variable',
'adjoint',
'arange',
'backward',
'contract',
'bint',
'contract',
'delta',
'distributions',
'domains',
'einsum',
'find_domain',
'gaussian',
'handlers',
'integrate',
'interpreter',
'joint',
'minipyro',
'montecarlo',
'of_shape',
'ops',
'reals',
11 changes: 11 additions & 0 deletions funsor/delta.py
@@ -6,6 +6,7 @@

import funsor.ops as ops
from funsor.domains import reals
from funsor.integrate import Integrate, integrator
from funsor.ops import Op
from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, Variable, eager, to_funsor

@@ -107,6 +108,16 @@ def eager_binary(op, lhs, rhs):
return None # defer to default implementation


@eager.register(Integrate, Delta, Funsor, frozenset)
@integrator
def eager_integrate(delta, integrand, reduced_vars):
assert delta.name in reduced_vars
integrand = integrand.eager_subs(((delta.name, delta.point),))
log_measure = delta.log_density
reduced_vars -= frozenset([delta.name])
return Integrate(log_measure, integrand, reduced_vars)
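
In other words, for a Delta with name x, point x0, and log density w, this rule computes the point-mass integral ∫ f(x) exp(w) δ_{x0}(dx) = exp(w) * f(x0): the point is substituted into the integrand and the log density is kept as a multiplicative weight.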


__all__ = [
'Delta',
]
114 changes: 95 additions & 19 deletions funsor/gaussian.py
@@ -11,8 +11,10 @@
import funsor.ops as ops
from funsor.delta import Delta
from funsor.domains import reals
from funsor.integrate import Integrate, integrator
from funsor.montecarlo import monte_carlo
from funsor.ops import AddOp
from funsor.terms import Binary, Funsor, FunsorMeta, Number, eager
from funsor.terms import Binary, Funsor, FunsorMeta, Number, Variable, eager
from funsor.torch import Tensor, align_tensor, align_tensors, materialize
from funsor.util import lazy_property

@@ -26,10 +28,14 @@ def _issubshape(subshape, supershape):
return True


def _log_det_tril(x):
def _log_det_tri(x):
Review comment (Member Author): It turns out some of our matrices are upper triangular, but that's ok since these functions work with both upper- and lower-triangular matrices.

return x.diagonal(dim1=-1, dim2=-2).log().sum(-1)


def _det_tri(x):
return x.diagonal(dim1=-1, dim2=-2).prod(-1)
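
As a quick illustration of the review comment above, a small sketch (not part of the diff; assumes torch is imported) showing the helpers agree for upper- and lower-triangular factors:

import torch

L = torch.tensor([[2.0, 0.0],
                  [1.0, 3.0]])  # lower triangular
U = L.transpose(-1, -2)         # upper triangular, same diagonal

# Only the diagonal enters, so the orientation does not matter.
assert torch.allclose(_det_tri(L), torch.tensor(6.0))
assert torch.allclose(_log_det_tri(L), _log_det_tri(U))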


def _mv(mat, vec):
return torch.matmul(mat, vec.unsqueeze(-1)).squeeze(-1)

@@ -44,6 +50,16 @@ def _vmv(mat, vec):
return result.squeeze(-1).squeeze(-1)


def _trace_mm(x, y):
"""
Computes ``trace(x @ y)``, assuming at least one of ``x``, ``y`` is symmetric (as in the uses below).
"""
assert x.dim() >= 2
assert y.dim() >= 2
xy = x * y
return xy.reshape(xy.shape[:-2] + (-1,)).sum(-1)


def _compute_offsets(inputs):
"""
Compute offsets of real inputs into the concatenated Gaussian dims.
@@ -195,7 +211,7 @@ def eager_subs(self, subs):
Tensor(self.precision, int_inputs)]
tensors.extend(subs.values())
inputs, tensors = align_tensors(*tensors)
batch_dim = self.loc.dim() - 1
batch_dim = tensors[0].dim() - 1
batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
(loc, precision), values = tensors[:2], tensors[2:]

@@ -219,8 +235,8 @@
@lazy_property
def _log_normalizer(self):
dim = self.loc.size(-1)
log_det_term = _log_det_tril(torch.cholesky(self.precision))
data = log_det_term - 0.5 * math.log(2 * math.pi) * dim
log_det_term = _log_det_tri(torch.cholesky(self.precision))
data = -log_det_term + 0.5 * math.log(2 * math.pi) * dim
inputs = OrderedDict((k, v) for k, v in self.inputs.items() if v.dtype != 'real')
return Tensor(data, inputs)

@@ -247,15 +263,15 @@ def eager_reduce(self, op, reduced_vars):
index = torch.tensor(index)

loc = self.loc[..., index]
self_scale_tril = torch.inverse(torch.cholesky(self.precision))
self_covariance = torch.matmul(self_scale_tril, self_scale_tril.transpose(-1, -2))
self_scale_tri = torch.inverse(torch.cholesky(self.precision)).transpose(-1, -2)
self_covariance = torch.matmul(self_scale_tri, self_scale_tri.transpose(-1, -2))
covariance = self_covariance[..., index.unsqueeze(-1), index]
scale_tril = torch.cholesky(covariance)
inv_scale_tril = torch.inverse(scale_tril)
precision = torch.matmul(inv_scale_tril, inv_scale_tril.transpose(-1, -2))
scale_tri = torch.cholesky(covariance)
inv_scale_tri = torch.inverse(scale_tri)
precision = torch.matmul(inv_scale_tri.transpose(-1, -2), inv_scale_tri)
reduced_dim = sum(self.inputs[k].num_elements for k in reduced_reals)
log_det_term = _log_det_tril(scale_tril) - _log_det_tril(self_scale_tril)
log_prob = Tensor(log_det_term - 0.5 * math.log(2 * math.pi) * reduced_dim, int_inputs)
log_det_term = _log_det_tri(self_scale_tri) - _log_det_tri(scale_tri)
log_prob = Tensor(log_det_term + 0.5 * math.log(2 * math.pi) * reduced_dim, int_inputs)
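# Marginalizing the unnormalized Gaussian over the reduced real dims multiplies it by
# sqrt(det(2*pi*Sigma_full) / det(2*pi*Sigma_kept)); log_prob above is the log of that constant,
# with Sigma_full the full covariance and Sigma_kept the marginal covariance of the kept dims.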
result = log_prob + Gaussian(loc, precision, inputs)

return result.reduce(ops.logaddexp, reduced_ints)
@@ -265,7 +281,7 @@

return None # defer to default implementation

def sample(self, sampled_vars, sample_inputs=None):
def unscaled_sample(self, sampled_vars, sample_inputs=None):
sampled_vars = sampled_vars.intersection(self.inputs)
if not sampled_vars:
return self
@@ -274,19 +290,21 @@
# Partition inputs into sample_inputs + int_inputs + real_inputs.
if sample_inputs is None:
sample_inputs = OrderedDict()
assert frozenset(sample_inputs).isdisjoint(self.inputs)
else:
sample_inputs = OrderedDict((k, d) for k, d in sample_inputs.items()
if k not in self.inputs)
sample_shape = tuple(int(d.dtype) for d in sample_inputs.values())
int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
real_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype == 'real')
inputs = sample_inputs.copy()
inputs.update(int_inputs)

if sampled_vars == frozenset(real_inputs):
scale_tril = torch.inverse(torch.cholesky(self.precision))
assert self.loc.shape == scale_tril.shape[:-1]
scale_tri = torch.inverse(torch.cholesky(self.precision)).transpose(-1, -2)
assert self.loc.shape == scale_tri.shape[:-1]
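# Here precision = C @ C.T with C = torch.cholesky(self.precision) lower triangular, so
# scale_tri = C^{-T} is upper triangular and scale_tri @ scale_tri.T equals the covariance;
# loc + scale_tri @ white_noise below therefore has the desired distribution.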
shape = sample_shape + self.loc.shape
white_noise = torch.randn(shape)
sample = self.loc + _mv(scale_tril, white_noise)
sample = self.loc + _mv(scale_tri, white_noise)
offsets, _ = _compute_offsets(real_inputs)
results = []
for key, domain in real_inputs.items():
@@ -317,13 +335,71 @@ def eager_add_gaussian_gaussian(op, lhs, rhs):
# Fuse aligned Gaussians.
precision_loc = _mv(lhs_precision, lhs_loc) + _mv(rhs_precision, rhs_loc)
precision = lhs_precision + rhs_precision
scale_tril = torch.inverse(torch.cholesky(precision))
loc = _mv(scale_tril.transpose(-1, -2), _mv(scale_tril, precision_loc))
scale_tri = torch.inverse(torch.cholesky(precision)).transpose(-1, -2)
loc = _mv(scale_tri, _mv(scale_tri.transpose(-1, -2), precision_loc))
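# Equivalently loc = (lhs_precision + rhs_precision)^{-1} @ (lhs_precision @ lhs_loc + rhs_precision @ rhs_loc),
# the precision-weighted mean; quadratic_term below is the constant left over after completing the square.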
quadratic_term = _vmv(lhs_precision, loc - lhs_loc) + _vmv(rhs_precision, loc - rhs_loc)
likelihood = Tensor(-0.5 * quadratic_term, int_inputs)
return likelihood + Gaussian(loc, precision, inputs)


@eager.register(Integrate, Gaussian, Variable, frozenset)
Review comment (Member Author): It would be nice to generalize this to arbitrary polynomials.

@integrator
def eager_integrate(log_measure, integrand, reduced_vars):
real_vars = frozenset(k for k in reduced_vars if log_measure.inputs[k].dtype == 'real')
if real_vars:
assert real_vars == frozenset([integrand.name])
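# The Gaussian funsor is an unnormalized log density with total mass exp(_log_normalizer),
# so integrating the variable x against it yields loc times that mass.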
data = log_measure.loc * log_measure._log_normalizer.data.exp().unsqueeze(-1)
data = data.reshape(log_measure.loc.shape[:-1] + integrand.output.shape)
inputs = OrderedDict((k, d) for k, d in log_measure.inputs.items() if d.dtype != 'real')
return Tensor(data, inputs)

return None # defer to default implementation


@eager.register(Integrate, Gaussian, Gaussian, frozenset)
@integrator
def eager_integrate(log_measure, integrand, reduced_vars):
real_vars = frozenset(k for k in reduced_vars if log_measure.inputs[k].dtype == 'real')
if real_vars:

lhs_reals = frozenset(k for k, d in log_measure.inputs.items() if d.dtype == 'real')
rhs_reals = frozenset(k for k, d in integrand.inputs.items() if d.dtype == 'real')
if lhs_reals == real_vars and rhs_reals <= real_vars:
inputs = OrderedDict((k, d) for t in (log_measure, integrand)
for k, d in t.inputs.items())
lhs_loc, lhs_precision = align_gaussian(inputs, log_measure)
rhs_loc, rhs_precision = align_gaussian(inputs, integrand)

# Compute the expectation of a non-normalized quadratic form.
# See "The Matrix Cookbook" (November 15, 2012) ss. 8.2.2 eq. 380.
# http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
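# Concretely, for x ~ N(lhs_loc, lhs_covariance) and integrand
# -1/2 * (x - rhs_loc)^T rhs_precision (x - rhs_loc), eq. 380 gives
#   E[(x - rhs_loc)^T rhs_precision (x - rhs_loc)]
#       = (lhs_loc - rhs_loc)^T rhs_precision (lhs_loc - rhs_loc)
#         + trace(rhs_precision @ lhs_covariance),
# and the result is scaled by norm below, the total mass of the unnormalized lhs Gaussian.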
lhs_scale_tri = torch.inverse(torch.cholesky(lhs_precision)).transpose(-1, -2)
lhs_covariance = torch.matmul(lhs_scale_tri, lhs_scale_tri.transpose(-1, -2))
dim = lhs_loc.size(-1)
norm = _det_tri(lhs_scale_tri) * (2 * math.pi) ** (0.5 * dim)
data = -0.5 * norm * (_vmv(rhs_precision, lhs_loc - rhs_loc) +
_trace_mm(rhs_precision, lhs_covariance))
inputs = OrderedDict((k, d) for k, d in inputs.items() if k not in reduced_vars)
result = Tensor(data, inputs)
return result.reduce(ops.add, reduced_vars - real_vars)

raise NotImplementedError('TODO implement partial integration')

return None # defer to default implementation


@monte_carlo.register(Integrate, Gaussian, Funsor, frozenset)
@integrator
def monte_carlo_integrate(log_measure, integrand, reduced_vars):
real_vars = frozenset(k for k in reduced_vars if log_measure.inputs[k].dtype == 'real')
if real_vars:
log_measure = log_measure.sample(real_vars, monte_carlo.sample_inputs)
reduced_vars = reduced_vars | frozenset(monte_carlo.sample_inputs)
return Integrate(log_measure, integrand, reduced_vars)

return None # defer to default implementation
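
For intuition, the monte_carlo rule above replaces the exact Gaussian expectation with a sample average. A standalone sketch of that estimator in plain PyTorch (mc_expectation and its arguments are illustrative names, not funsor API; loc is a vector and scale_tri a square triangular factor):

import torch

def mc_expectation(loc, scale_tri, integrand_fn, num_samples=1000):
    # x = loc + scale_tri @ z with z ~ N(0, I) has covariance scale_tri @ scale_tri.T;
    # averaging integrand_fn over such samples estimates the expectation.
    noise = torch.randn((num_samples,) + loc.shape)
    samples = loc + noise @ scale_tri.transpose(-1, -2)
    return integrand_fn(samples).mean(0)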


__all__ = [
'Gaussian',
'align_gaussian',
85 changes: 85 additions & 0 deletions funsor/integrate.py
@@ -0,0 +1,85 @@
from __future__ import absolute_import, division, print_function

import functools
from collections import OrderedDict

import funsor.ops as ops
from funsor.contract import Contract
from funsor.terms import Funsor, Reduce, eager


class Integrate(Funsor):
"""
Funsor representing an integral wrt a log density funsor.
"""
def __init__(self, log_measure, integrand, reduced_vars):
assert isinstance(log_measure, Funsor)
assert isinstance(integrand, Funsor)
assert isinstance(reduced_vars, frozenset)
inputs = OrderedDict((k, d) for term in (log_measure, integrand)
for (k, d) in term.inputs.items()
if k not in reduced_vars)
output = integrand.output
super(Integrate, self).__init__(inputs, output)
self.log_measure = log_measure
self.integrand = integrand
self.reduced_vars = reduced_vars

def eager_subs(self, subs):
raise NotImplementedError('TODO')
Review comment (Member): I think you can just use Contract.eager_subs here.



def _simplify_integrate(fn, log_measure, integrand, reduced_vars):
"""
Reduce free variables that do not appear in both inputs.
"""
if not reduced_vars:
return log_measure.exp() * integrand

log_measure_vars = frozenset(log_measure.inputs)
integrand_vars = frozenset(integrand.inputs)
assert reduced_vars <= log_measure_vars | integrand_vars
progress = False
if not reduced_vars <= log_measure_vars:
integrand = integrand.reduce(ops.add, reduced_vars - log_measure_vars)
reduced_vars = reduced_vars & log_measure_vars
progress = True
if not reduced_vars <= integrand_vars:
log_measure = log_measure.reduce(ops.logaddexp, reduced_vars - integrand_vars)
reduced_vars = reduced_vars & integrand_vars
progress = True
if progress:
return Integrate(log_measure, integrand, reduced_vars)

return fn(log_measure, integrand, reduced_vars)


def integrator(fn):
"""
Decorator for integration implementations.
"""
return functools.partial(_simplify_integrate, fn)


@eager.register(Integrate, Funsor, Funsor, frozenset)
@integrator
def eager_integrate(log_measure, integrand, reduced_vars):
return Contract(log_measure.exp(), integrand, reduced_vars)


@eager.register(Integrate, Reduce, Funsor, frozenset)
@integrator
def eager_integrate(log_measure, integrand, reduced_vars):
if log_measure.op is ops.logaddexp:
if not log_measure.reduced_vars.isdisjoint(reduced_vars):
raise NotImplementedError('TODO alpha convert')
arg = Integrate(log_measure.arg, integrand, reduced_vars)
return arg.reduce(ops.add, log_measure.reduced_vars)

return None # defer to default implementation


__all__ = [
'Integrate',
'integrator',
]
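
A minimal usage sketch of the new Integrate funsor in the discrete case (assumes torch and a funsor install containing this PR; the numbers are illustrative). Under the default eager rule above, the expectation reduces to a Contract of exp(log_measure) against the integrand:

from collections import OrderedDict

import torch
from funsor import Integrate, Tensor, bint

log_p = Tensor(torch.tensor([0.2, 0.3, 0.5]).log(), OrderedDict(x=bint(3)))  # log measure over x
f = Tensor(torch.tensor([1.0, 2.0, 3.0]), OrderedDict(x=bint(3)))            # integrand
expected = Integrate(log_p, f, frozenset(['x']))  # sum_x p(x) * f(x) = 2.3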
9 changes: 8 additions & 1 deletion funsor/interpreter.py
@@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function

import os
import re
import types
from collections import OrderedDict

@@ -24,12 +25,18 @@ def interpret(cls, *args):
indent = ' ' * _STACK_SIZE
typenames = [cls.__name__] + [type(arg).__name__ for arg in args]
print(indent + ' '.join(typenames))

_STACK_SIZE += 1
try:
result = _INTERPRETATION(cls, *args)
finally:
_STACK_SIZE -= 1
print(indent + '-> ' + type(result).__name__)

if _DEBUG > 1:
Review comment (Member Author): You can specify FUNSOR_DEBUG=2 for even more verbose output.

result_str = re.sub('\n', '\n ' + indent, str(result))
else:
result_str = type(result).__name__
print(indent + '-> ' + result_str)
return result
else:
def interpret(cls, *args):
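
As the review comment above notes, FUNSOR_DEBUG=2 enables the more verbose branch. A sketch of turning it on from Python, assuming the variable is read from the environment when funsor is first imported:

import os
os.environ['FUNSOR_DEBUG'] = '2'  # must be set before the first import of funsor

import funsor  # each interpretation now prints its argument types and the full result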