Implement Monte Carlo interpretation of Integrate #54
Changes from 61 commits
funsor/gaussian.py:

@@ -11,8 +11,10 @@
 import funsor.ops as ops
 from funsor.delta import Delta
 from funsor.domains import reals
+from funsor.integrate import Integrate, integrator
+from funsor.montecarlo import monte_carlo
 from funsor.ops import AddOp
-from funsor.terms import Binary, Funsor, FunsorMeta, Number, eager
+from funsor.terms import Binary, Funsor, FunsorMeta, Number, Variable, eager
 from funsor.torch import Tensor, align_tensor, align_tensors, materialize
 from funsor.util import lazy_property
@@ -26,10 +28,14 @@ def _issubshape(subshape, supershape):
     return True


-def _log_det_tril(x):
+def _log_det_tri(x):
     return x.diagonal(dim1=-1, dim2=-2).log().sum(-1)


+def _det_tri(x):
+    return x.diagonal(dim1=-1, dim2=-2).prod(-1)
+
+
 def _mv(mat, vec):
     return torch.matmul(mat, vec.unsqueeze(-1)).squeeze(-1)
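
A quick sanity check of the renamed helpers (illustrative, assuming PyTorch):
since the determinant of a triangular matrix is the product of its diagonal,
they work for upper and lower triangular factors alike.

    import torch

    def _log_det_tri(x):
        return x.diagonal(dim1=-1, dim2=-2).log().sum(-1)

    a = torch.randn(4, 4)
    spd = a @ a.t() + 4 * torch.eye(4)   # a random SPD matrix
    lower = torch.cholesky(spd)          # lower-triangular Cholesky factor
    upper = lower.t()                    # its upper-triangular transpose
    # Both factors give half the log determinant of the SPD matrix:
    assert torch.allclose(_log_det_tri(lower), 0.5 * spd.logdet())
    assert torch.allclose(_log_det_tri(upper), 0.5 * spd.logdet())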
@@ -44,6 +50,16 @@ def _vmv(mat, vec):
     return result.squeeze(-1).squeeze(-1)


+def _trace_mm(x, y):
+    """
+    Computes ``trace(x @ y)``.
+    """
+    assert x.dim() >= 2
+    assert y.dim() >= 2
+    xy = x * y
+    return xy.reshape(xy.shape[:-2] + (-1,)).sum(-1)
+
+
 def _compute_offsets(inputs):
     """
     Compute offsets of real inputs into the concatenated Gaussian dims.
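
One subtlety: the elementwise product actually computes
trace(x @ y.transpose(-1, -2)), which agrees with trace(x @ y) only when one
argument is symmetric. That holds at the call site below, where the arguments
are a precision and a covariance matrix. An illustrative check, assuming
PyTorch:

    import torch

    def _trace_mm(x, y):
        xy = x * y
        return xy.reshape(xy.shape[:-2] + (-1,)).sum(-1)

    x = torch.randn(3, 3)
    y = torch.randn(3, 3)
    y_sym = y + y.t()  # symmetrize, as for a precision or covariance matrix
    # For symmetric y, the elementwise trick matches the true trace(x @ y):
    assert torch.allclose(_trace_mm(x, y_sym), (x @ y_sym).diagonal().sum())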
@@ -195,7 +211,7 @@ def eager_subs(self, subs):
                    Tensor(self.precision, int_inputs)]
         tensors.extend(subs.values())
         inputs, tensors = align_tensors(*tensors)
-        batch_dim = self.loc.dim() - 1
+        batch_dim = tensors[0].dim() - 1
         batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
         (loc, precision), values = tensors[:2], tensors[2:]
@@ -219,8 +235,8 @@ def eager_subs(self, subs):
     @lazy_property
     def _log_normalizer(self):
         dim = self.loc.size(-1)
-        log_det_term = _log_det_tril(torch.cholesky(self.precision))
-        data = log_det_term - 0.5 * math.log(2 * math.pi) * dim
+        log_det_term = _log_det_tri(torch.cholesky(self.precision))
+        data = -log_det_term + 0.5 * math.log(2 * math.pi) * dim
         inputs = OrderedDict((k, v) for k, v in self.inputs.items() if v.dtype != 'real')
         return Tensor(data, inputs)
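
The sign fix matches the standard normalizer of an unnormalized Gaussian,
log ∫ exp(-1/2 (x - mu)' P (x - mu)) dx = (d/2) log(2 pi) - (1/2) log|P|,
where the Cholesky diagonal supplies (1/2) log|P|. An illustrative check
against torch.distributions:

    import math
    import torch

    d = 3
    a = torch.randn(d, d)
    precision = a @ a.t() + torch.eye(d)
    log_det_term = torch.cholesky(precision).diagonal().log().sum()  # 0.5 * log|P|
    log_z = -log_det_term + 0.5 * math.log(2 * math.pi) * d
    mvn = torch.distributions.MultivariateNormal(torch.zeros(d),
                                                 precision_matrix=precision)
    # The normalized log density at the mode is exactly -log Z:
    assert torch.allclose(mvn.log_prob(torch.zeros(d)), -log_z)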
@@ -247,15 +263,15 @@ def eager_reduce(self, op, reduced_vars):
             index = torch.tensor(index)

             loc = self.loc[..., index]
-            self_scale_tril = torch.inverse(torch.cholesky(self.precision))
-            self_covariance = torch.matmul(self_scale_tril, self_scale_tril.transpose(-1, -2))
+            self_scale_tri = torch.inverse(torch.cholesky(self.precision)).transpose(-1, -2)
+            self_covariance = torch.matmul(self_scale_tri, self_scale_tri.transpose(-1, -2))
             covariance = self_covariance[..., index.unsqueeze(-1), index]
-            scale_tril = torch.cholesky(covariance)
-            inv_scale_tril = torch.inverse(scale_tril)
-            precision = torch.matmul(inv_scale_tril, inv_scale_tril.transpose(-1, -2))
+            scale_tri = torch.cholesky(covariance)
+            inv_scale_tri = torch.inverse(scale_tri)
+            precision = torch.matmul(inv_scale_tri.transpose(-1, -2), inv_scale_tri)
             reduced_dim = sum(self.inputs[k].num_elements for k in reduced_reals)
-            log_det_term = _log_det_tril(scale_tril) - _log_det_tril(self_scale_tril)
-            log_prob = Tensor(log_det_term - 0.5 * math.log(2 * math.pi) * reduced_dim, int_inputs)
+            log_det_term = _log_det_tri(self_scale_tri) - _log_det_tri(scale_tri)
+            log_prob = Tensor(log_det_term + 0.5 * math.log(2 * math.pi) * reduced_dim, int_inputs)
             result = log_prob + Gaussian(loc, precision, inputs)

             return result.reduce(ops.logaddexp, reduced_ints)
@@ -265,7 +281,7 @@ def eager_reduce(self, op, reduced_vars):

         return None  # defer to default implementation

-    def sample(self, sampled_vars, sample_inputs=None):
+    def unscaled_sample(self, sampled_vars, sample_inputs=None):
         sampled_vars = sampled_vars.intersection(self.inputs)
         if not sampled_vars:
             return self
@@ -274,19 +290,21 @@ def sample(self, sampled_vars, sample_inputs=None):
         # Partition inputs into sample_inputs + int_inputs + real_inputs.
         if sample_inputs is None:
             sample_inputs = OrderedDict()
-        assert frozenset(sample_inputs).isdisjoint(self.inputs)
+        else:
+            sample_inputs = OrderedDict((k, d) for k, d in sample_inputs.items()
+                                        if k not in self.inputs)
         sample_shape = tuple(int(d.dtype) for d in sample_inputs.values())
         int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
         real_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype == 'real')
         inputs = sample_inputs.copy()
         inputs.update(int_inputs)

         if sampled_vars == frozenset(real_inputs):
-            scale_tril = torch.inverse(torch.cholesky(self.precision))
-            assert self.loc.shape == scale_tril.shape[:-1]
+            scale_tri = torch.inverse(torch.cholesky(self.precision)).transpose(-1, -2)
+            assert self.loc.shape == scale_tri.shape[:-1]
             shape = sample_shape + self.loc.shape
             white_noise = torch.randn(shape)
-            sample = self.loc + _mv(scale_tril, white_noise)
+            sample = self.loc + _mv(scale_tri, white_noise)
             offsets, _ = _compute_offsets(real_inputs)
             results = []
             for key, domain in real_inputs.items():
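
The sampling branch uses the reparameterization x = loc + U @ eps, where U is
the upper-triangular factor satisfying U @ U.T = inv(precision), which is why
the inverted Cholesky factor is transposed. An illustrative check, assuming
PyTorch:

    import torch

    torch.manual_seed(0)
    a = torch.randn(2, 2)
    precision = a @ a.t() + torch.eye(2)
    scale_tri = torch.inverse(torch.cholesky(precision)).transpose(-1, -2)
    # The upper-triangular factor reproduces the covariance exactly...
    covariance = torch.inverse(precision)
    assert torch.allclose(scale_tri @ scale_tri.t(), covariance, atol=1e-5)
    # ...and scale_tri @ eps has approximately that covariance empirically:
    eps = torch.randn(100000, 2)
    samples = eps @ scale_tri.t()
    empirical_cov = samples.t() @ samples / 100000
    assert torch.allclose(empirical_cov, covariance, atol=0.05)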
@@ -317,13 +335,71 @@ def eager_add_gaussian_gaussian(op, lhs, rhs):
     # Fuse aligned Gaussians.
     precision_loc = _mv(lhs_precision, lhs_loc) + _mv(rhs_precision, rhs_loc)
     precision = lhs_precision + rhs_precision
-    scale_tril = torch.inverse(torch.cholesky(precision))
-    loc = _mv(scale_tril.transpose(-1, -2), _mv(scale_tril, precision_loc))
+    scale_tri = torch.inverse(torch.cholesky(precision)).transpose(-1, -2)
+    loc = _mv(scale_tri, _mv(scale_tri.transpose(-1, -2), precision_loc))
     quadratic_term = _vmv(lhs_precision, loc - lhs_loc) + _vmv(rhs_precision, loc - rhs_loc)
     likelihood = Tensor(-0.5 * quadratic_term, int_inputs)
     return likelihood + Gaussian(loc, precision, inputs)


+@eager.register(Integrate, Gaussian, Variable, frozenset)
+@integrator
+def eager_integrate(log_measure, integrand, reduced_vars):
+    real_vars = frozenset(k for k in reduced_vars if log_measure.inputs[k].dtype == 'real')
+    if real_vars:
+        assert real_vars == frozenset([integrand.name])
+        data = log_measure.loc * log_measure._log_normalizer.data.exp().unsqueeze(-1)
+        data = data.reshape(log_measure.loc.shape[:-1] + integrand.output.shape)
+        inputs = OrderedDict((k, d) for k, d in log_measure.inputs.items() if d.dtype != 'real')
+        return Tensor(data, inputs)
+
+    return None  # defer to default implementation

Review comment on the @eager.register(Integrate, Gaussian, Variable, frozenset) line: It would be nice to generalize this to arbitrary polynomials.
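
This handler implements the first-moment identity: for an unnormalized
Gaussian with mode mu and total mass Z, ∫ exp(G(x)) x dx = Z mu, hence loc
scaled by the exponentiated log normalizer. An illustrative 1-D numerical
check:

    import math
    import torch

    mu, p = 1.5, 2.0  # mode and precision of exp(-0.5 * p * (x - mu)**2)
    xs = torch.linspace(-10.0, 10.0, 200001, dtype=torch.float64)
    dx = xs[1] - xs[0]
    integral = (torch.exp(-0.5 * p * (xs - mu) ** 2) * xs).sum() * dx
    z = math.sqrt(2 * math.pi / p)  # total mass of the unnormalized density
    expected = torch.tensor(z * mu, dtype=torch.float64)
    assert torch.allclose(integral, expected, atol=1e-6)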
+
+
+@eager.register(Integrate, Gaussian, Gaussian, frozenset)
+@integrator
+def eager_integrate(log_measure, integrand, reduced_vars):
+    real_vars = frozenset(k for k in reduced_vars if log_measure.inputs[k].dtype == 'real')
+    if real_vars:
+        lhs_reals = frozenset(k for k, d in log_measure.inputs.items() if d.dtype == 'real')
+        rhs_reals = frozenset(k for k, d in integrand.inputs.items() if d.dtype == 'real')
+        if lhs_reals == real_vars and rhs_reals <= real_vars:
+            inputs = OrderedDict((k, d) for t in (log_measure, integrand)
+                                 for k, d in t.inputs.items())
+            lhs_loc, lhs_precision = align_gaussian(inputs, log_measure)
+            rhs_loc, rhs_precision = align_gaussian(inputs, integrand)
+
+            # Compute the expectation of a non-normalized quadratic form.
+            # See "The Matrix Cookbook" (November 15, 2012) ss. 8.2.2 eq. 380.
+            # http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
+            lhs_scale_tri = torch.inverse(torch.cholesky(lhs_precision)).transpose(-1, -2)
+            lhs_covariance = torch.matmul(lhs_scale_tri, lhs_scale_tri.transpose(-1, -2))
+            dim = lhs_loc.size(-1)
+            norm = _det_tri(lhs_scale_tri) * (2 * math.pi) ** (0.5 * dim)
+            data = -0.5 * norm * (_vmv(rhs_precision, lhs_loc - rhs_loc) +
+                                  _trace_mm(rhs_precision, lhs_covariance))
+            inputs = OrderedDict((k, d) for k, d in inputs.items() if k not in reduced_vars)
+            result = Tensor(data, inputs)
+            return result.reduce(ops.add, reduced_vars - real_vars)
+
+        raise NotImplementedError('TODO implement partial integration')
+
+    return None  # defer to default implementation
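
For reference, eq. 380 of the cited Matrix Cookbook is the Gaussian
expectation of a quadratic form,

    ∫ N(x | m, Σ) (x - m')' A (x - m') dx = (m - m')' A (m - m') + tr(A Σ),

applied here with A = rhs_precision, m = lhs_loc, m' = rhs_loc and
Σ = lhs_covariance. Two adjustments account for the remaining factors: the
measure is unnormalized, contributing its total mass
norm = det(scale_tri) * (2 pi)**(d/2), and the Gaussian integrand is itself a
log-density term -1/2 (x - m')' A (x - m'), contributing the factor -0.5.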
+
+
+@monte_carlo.register(Integrate, Gaussian, Funsor, frozenset)
+@integrator
+def monte_carlo_integrate(log_measure, integrand, reduced_vars):
+    real_vars = frozenset(k for k in reduced_vars if log_measure.inputs[k].dtype == 'real')
+    if real_vars:
+        log_measure = log_measure.sample(real_vars, monte_carlo.sample_inputs)
+        reduced_vars = reduced_vars | frozenset(monte_carlo.sample_inputs)
+        return Integrate(log_measure, integrand, reduced_vars)
+
+    return None  # defer to default implementation
+
+
 __all__ = [
     'Gaussian',
     'align_gaussian',
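
The Monte Carlo rule defers to sampling: it replaces the Gaussian measure by
samples drawn from it and lets the resulting Integrate average over the added
sample dimensions. A hypothetical usage sketch (the interpretation() helper
and the sample_inputs plumbing are assumptions, not shown in this diff):

    from funsor.interpreter import interpretation
    from funsor.montecarlo import monte_carlo

    # log_measure, integrand, reduced_vars as defined elsewhere.
    with interpretation(monte_carlo):
        # Within this block, Integrate over a Gaussian measure is rewritten
        # to an average over samples drawn from that measure.
        estimate = Integrate(log_measure, integrand, reduced_vars)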
funsor/integrate.py (new file):

@@ -0,0 +1,85 @@
+from __future__ import absolute_import, division, print_function
+
+import functools
+from collections import OrderedDict
+
+import funsor.ops as ops
+from funsor.contract import Contract
+from funsor.terms import Funsor, Reduce, eager
+
+
+class Integrate(Funsor):
+    """
+    Funsor representing an integral wrt a log density funsor.
+    """
+    def __init__(self, log_measure, integrand, reduced_vars):
+        assert isinstance(log_measure, Funsor)
+        assert isinstance(integrand, Funsor)
+        assert isinstance(reduced_vars, frozenset)
+        inputs = OrderedDict((k, d) for term in (log_measure, integrand)
+                             for (k, d) in term.inputs.items()
+                             if k not in reduced_vars)
+        output = integrand.output
+        super(Integrate, self).__init__(inputs, output)
+        self.log_measure = log_measure
+        self.integrand = integrand
+        self.reduced_vars = reduced_vars
+
+    def eager_subs(self, subs):
+        raise NotImplementedError('TODO')

Review comment on eager_subs: I think you can just use …

+
+
+def _simplify_integrate(fn, log_measure, integrand, reduced_vars):
+    """
+    Reduce free variables that do not appear in both inputs.
+    """
+    if not reduced_vars:
+        return log_measure.exp() * integrand
+
+    log_measure_vars = frozenset(log_measure.inputs)
+    integrand_vars = frozenset(integrand.inputs)
+    assert reduced_vars <= log_measure_vars | integrand_vars
+    progress = False
+    if not reduced_vars <= log_measure_vars:
+        integrand = integrand.reduce(ops.add, reduced_vars - log_measure_vars)
+        reduced_vars = reduced_vars & log_measure_vars
+        progress = True
+    if not reduced_vars <= integrand_vars:
+        log_measure = log_measure.reduce(ops.logaddexp, reduced_vars - integrand_vars)
+        reduced_vars = reduced_vars & integrand_vars
+        progress = True
+    if progress:
+        return Integrate(log_measure, integrand, reduced_vars)
+
+    return fn(log_measure, integrand, reduced_vars)
+
+
+def integrator(fn):
+    """
+    Decorator for integration implementations.
+    """
+    return functools.partial(_simplify_integrate, fn)
+
+
+@eager.register(Integrate, Funsor, Funsor, frozenset)
+@integrator
+def eager_integrate(log_measure, integrand, reduced_vars):
+    return Contract(log_measure.exp(), integrand, reduced_vars)
+
+
+@eager.register(Integrate, Reduce, Funsor, frozenset)
+@integrator
+def eager_integrate(log_measure, integrand, reduced_vars):
+    if log_measure.op is ops.logaddexp:
+        if not log_measure.reduced_vars.isdisjoint(reduced_vars):
+            raise NotImplementedError('TODO alpha convert')
+        arg = Integrate(log_measure.arg, integrand, reduced_vars)
+        return arg.reduce(ops.add, log_measure.reduced_vars)
+
+    return None  # defer to default implementation

fritzo marked this conversation as resolved.

+
+
+__all__ = [
+    'Integrate',
+    'integrator',
+]
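
An illustrative usage of the new Integrate funsor over a discrete variable,
assuming the Tensor and domain APIs used elsewhere in this diff:

    from collections import OrderedDict

    import torch

    from funsor.domains import bint
    from funsor.integrate import Integrate
    from funsor.torch import Tensor

    # exp(log_measure) is the uniform distribution over {0, 1}.
    log_measure = Tensor(torch.full((2,), 0.5).log(), OrderedDict(x=bint(2)))
    integrand = Tensor(torch.tensor([1.0, 3.0]), OrderedDict(x=bint(2)))
    # Under the eager interpretation this contracts to
    # 0.5 * 1.0 + 0.5 * 3.0 = 2.0.
    result = Integrate(log_measure, integrand, frozenset(['x']))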
funsor/interpreter.py:

@@ -1,6 +1,7 @@
 from __future__ import absolute_import, division, print_function

 import os
+import re
 import types
 from collections import OrderedDict
@@ -24,12 +25,18 @@ def interpret(cls, *args):
         indent = ' ' * _STACK_SIZE
         typenames = [cls.__name__] + [type(arg).__name__ for arg in args]
         print(indent + ' '.join(typenames))

         _STACK_SIZE += 1
         try:
             result = _INTERPRETATION(cls, *args)
         finally:
             _STACK_SIZE -= 1
-        print(indent + '-> ' + type(result).__name__)
+        if _DEBUG > 1:
+            result_str = re.sub('\n', '\n ' + indent, str(result))
+        else:
+            result_str = type(result).__name__
+        print(indent + '-> ' + result_str)
         return result
     else:
         def interpret(cls, *args):

Review comment on the _DEBUG > 1 branch: You can specify …
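
The richer printing presumably activates at a debug level above 1, e.g. via an
environment variable read into _DEBUG (the exact variable name below is an
assumption, not shown in this diff):

    import os

    os.environ['FUNSOR_DEBUG'] = '2'  # hypothetical knob that sets _DEBUG > 1
    import funsor  # interpretations now print full funsor reprs, indented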
Review comment: It turns out some of our matrices are upper triangular, but that's OK since these functions work with both upper- and lower-triangular matrices.