Implement Monte Carlo interpretation of Integrate #54

Merged
merged 63 commits on Mar 26, 2019

Changes from 16 commits

Commits (63)
6d3c7fd
Sketch Monte Carlo interpretation of logaddexp reduction
fritzo Mar 7, 2019
dbe80dc
Merge branch 'master' into montecarlo
fritzo Mar 9, 2019
95d030f
Use AssociativeOp in patterns
fritzo Mar 9, 2019
ca3fdcc
Merge branch 'master' into montecarlo
fritzo Mar 13, 2019
a4a0ea7
Fix op pattern matcher
fritzo Mar 13, 2019
1df4425
Try eager before monte_carlo
fritzo Mar 13, 2019
c68e45b
Drop ops.sample, ops.marginal
fritzo Mar 13, 2019
4060457
Sketch VAE example using monte carlo interpretation
fritzo Mar 13, 2019
0bb8eba
Refactor, focusing on .sample() and .monte_carlo_logsumexp() methods
fritzo Mar 13, 2019
a799bc7
Fix vae example
fritzo Mar 13, 2019
a34bd82
Sketch Tensor.sample() (untested)
fritzo Mar 14, 2019
4dbe8b6
Fix cyclic import
fritzo Mar 14, 2019
3792e87
Sketch Gaussian.sample() (untested)
fritzo Mar 14, 2019
80eda9d
Implement Delta.sample()
fritzo Mar 14, 2019
3fa6956
Sketch Expectation class
fritzo Mar 14, 2019
4e75094
Sketch sampler implementations
fritzo Mar 14, 2019
5121478
Delete Expectation in favor of Integrate in a separate PR
fritzo Mar 15, 2019
bcbe76e
Revert .sample() sketch
fritzo Mar 15, 2019
eef211f
Update VAE example to use multi-output Functions
fritzo Mar 20, 2019
361fff8
Fix reductions in VAE
fritzo Mar 20, 2019
91fd5b2
Merge branch 'master' into montecarlo
fritzo Mar 20, 2019
ea587e3
Sketch support for multiple args in __getitem__
fritzo Mar 20, 2019
c41c6cb
Fix bugs in getitem_tensor_tensor
fritzo Mar 21, 2019
5920641
Add stronger tests for tensor getitem
fritzo Mar 21, 2019
2339f27
Add support for string indexing
fritzo Mar 21, 2019
7a692b1
Simplify vae example using multi-getitem
fritzo Mar 21, 2019
bbdca68
Add stub for Integrate
fritzo Mar 21, 2019
6193291
Merge branch 'master' into montecarlo
fritzo Mar 21, 2019
ab8107f
Fix typo
fritzo Mar 21, 2019
77f7f04
Sketch monte_carlo registration of Gaussian-Gaussian things
fritzo Mar 21, 2019
01ba6ad
Add stubs for Joint integration
fritzo Mar 21, 2019
db97b92
Fix typos
fritzo Mar 21, 2019
231967c
Sketch support for multiple samples
fritzo Mar 21, 2019
8e06fe4
Merge branch 'master' into montecarlo
fritzo Mar 21, 2019
339aa17
Fix test usage of registry
fritzo Mar 21, 2019
f5d189b
Fix bugs in gaussian integral
fritzo Mar 21, 2019
cb85e1f
Merge branch 'master' into montecarlo
fritzo Mar 21, 2019
ddcc70d
Merge branch 'master' into multi-getitem
fritzo Mar 21, 2019
f84c473
Handle scale factors in Funsor.sample()
fritzo Mar 22, 2019
801d0ff
Use Integrate in test_samplers.py
fritzo Mar 22, 2019
5d5f53e
Fix bug in Integrate; be less clever
fritzo Mar 22, 2019
d498604
Merge branch 'multi-getitem' into montecarlo
fritzo Mar 22, 2019
657779c
Add implementations of gaussian-linear integrals
fritzo Mar 22, 2019
7de197c
Add interpretation logging controlled by FUNSOR_DEBUG
fritzo Mar 22, 2019
853cd2b
Simplify debug printing
fritzo Mar 22, 2019
1588796
Merge branch 'debug-interp' into montecarlo
fritzo Mar 22, 2019
6688931
Fix lazy reduction for Joint.reduce()
fritzo Mar 22, 2019
58d693d
Merge branch 'master' into montecarlo
fritzo Mar 22, 2019
91dda7f
Merge branch 'fix-joint-reduce-lazy' into montecarlo
fritzo Mar 22, 2019
868fb66
Fix recursion bug
fritzo Mar 22, 2019
5fa8a6f
Merge branch 'fix-joint-reduce-lazy' into montecarlo
fritzo Mar 22, 2019
95f7ebc
Get univariate Gaussian sampling to mostly work
fritzo Mar 23, 2019
5ed4836
Fix bug in Tensor.eager_reduce with nontrivial output
fritzo Mar 23, 2019
78ad7b6
Fix output shape broadcasting in Tensor
fritzo Mar 23, 2019
c476a8e
Fix assert_close in test_samplers.py
fritzo Mar 23, 2019
1758620
Merge branch 'master' into montecarlo
fritzo Mar 23, 2019
1576534
Fix cholesky bugs
fritzo Mar 24, 2019
f10e09b
Fix bug in _trace_mm()
fritzo Mar 24, 2019
39586ad
Fixes for examples/vae.py
fritzo Mar 24, 2019
bee75b9
Remove examples/vae.py
fritzo Mar 24, 2019
7784fbb
Add docstrings
fritzo Mar 24, 2019
6f2d453
Updates per review
fritzo Mar 25, 2019
56ca93f
Revert accidental change
fritzo Mar 25, 2019
1 change: 1 addition & 0 deletions Makefile
@@ -12,6 +12,7 @@ test: lint FORCE
python examples/discrete_hmm.py -n 2 -t 50 --lazy
python examples/kalman_filter.py --xfail-if-not-implemented
python examples/kalman_filter.py -n 2 -t 50 --lazy
@#python examples/vae.py
@#python examples/ss_vae_delayed.py --xfail-if-not-implemented
@#python examples/minipyro.py --xfail-if-not-implemented
@echo PASS
98 changes: 98 additions & 0 deletions examples/vae.py
@@ -0,0 +1,98 @@
from __future__ import absolute_import, division, print_function

import argparse

import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms

import funsor
import funsor.ops as ops
import funsor.distributions as dist
from funsor.domains import reals


class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()
self.fc1 = nn.Linear(784, 400)
self.fc21 = nn.Linear(400, 20)
self.fc22 = nn.Linear(400, 20)

def forward(self, image):
h1 = F.relu(self.fc1(image))
return self.fc21(h1), self.fc22(h1)


class Decoder(nn.Module):
def __init__(self):
        super(Decoder, self).__init__()
self.fc3 = nn.Linear(20, 400)
self.fc4 = nn.Linear(400, 784)

def forward(self, z):
h3 = F.relu(self.fc3(z))
return torch.sigmoid(self.fc4(h3))


def main(args):
encoder = Encoder()
decoder = Decoder()

@funsor.function(reals(28, 28), reals(2, 20))
def encode(image):
        # Flatten the 28x28 image and constrain the scale to be positive.
        loc, log_scale = encoder(image.reshape(image.shape[:-2] + (-1,)))
        return torch.stack([loc, log_scale.exp()], dim=-2)

    @funsor.function(reals(20), reals(28, 28))
    def decode(z):
        # Reshape the flat 784-dim decoder output back to a 28x28 image.
        return decoder(z).reshape(z.shape[:-1] + (28, 28))

    @funsor.interpreter.interpretation(funsor.terms.monte_carlo)
def loss_function(data):
loc, scale = encode(data)
z = funsor.Variable('z', reals(20))
q = dist.Normal(loc, scale, value=z)

probs = decode(z)
p = dist.Bernoulli(probs, value=data)

elbo = (q.exp() * (p - q)).reduce(ops.add)
loss = -elbo
return loss.data

train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.ToTensor()),
batch_size=args.batch_size, shuffle=True)

encoder.train()
decoder.train()
    optimizer = optim.Adam(list(encoder.parameters()) +
                           list(decoder.parameters()), lr=1e-3)
for epoch in range(args.num_epochs):
train_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
optimizer.zero_grad()
loss = loss_function(data)
loss.backward()
train_loss += loss.item()
optimizer.step()
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader),
loss.item() / len(data)))


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='VAE MNIST Example')
    parser.add_argument('-n', '--num-epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=128)
args = parser.parse_args()
main(args)
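
The key line in loss_function above is `elbo = (q.exp() * (p - q)).reduce(ops.add)`: since `q` and `p` are log-densities, it denotes E_{z~q}[log p(data|z) - log q(z)], and the monte_carlo interpretation estimates that expectation with a reparameterized sample of `z` instead of an exact reduction. A minimal plain-PyTorch sketch of the same estimator (`naive_elbo` and `decode_fn` are hypothetical names, not part of this PR):

```python
import torch

def naive_elbo(data, loc, scale, decode_fn):
    # Reparameterized sample from the guide q(z) = Normal(loc, scale).
    z = loc + scale * torch.randn_like(loc)
    log_q = torch.distributions.Normal(loc, scale).log_prob(z).sum(-1)
    # Bernoulli likelihood p(data | z) from the decoder probabilities.
    probs = decode_fn(z)
    log_p = torch.distributions.Bernoulli(probs).log_prob(data).sum([-1, -2])
    return log_p - log_q  # one-sample Monte Carlo estimate of the ELBO
```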
7 changes: 5 additions & 2 deletions funsor/delta.py
@@ -8,7 +8,6 @@
from funsor.domains import reals
from funsor.ops import Op
from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, Variable, eager, to_funsor
from funsor.torch import Tensor


class DeltaMeta(FunsorMeta):
@@ -66,7 +65,7 @@ def eager_subs(self, subs):
if value is not None:
if isinstance(value, Variable):
name = value.name
elif isinstance(value, (Number, Tensor)) and isinstance(point, (Number, Tensor)):
elif not any(d.dtype == 'real' for side in (value, point) for d in side.inputs.values()):
return (value == point).all().log() + log_density
else:
# TODO Compute a jacobian, update log_prob, and emit another Delta.
@@ -83,6 +82,10 @@ def eager_reduce(self, op, reduced_vars):

return None # defer to default implementation

def sample(self, sampled_vars):
assert all(k == self.name for k in sampled_vars if k in self.inputs)
return self


@eager.register(Binary, Op, Delta, (Funsor, Delta, Align))
def eager_binary(op, lhs, rhs):
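
The new Delta.sample() above is intentionally a no-op: a Delta already represents a point mass, so drawing a sample for the variable it pins down changes nothing. A small sketch of that contract (assuming the point is coerced to a funsor as in the constructor above):

```python
from funsor.delta import Delta
from funsor.terms import Number

d = Delta('z', Number(0.5))              # point mass at z = 0.5
assert d.sample(frozenset(['z'])) is d   # sampling is the identity
```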
97 changes: 97 additions & 0 deletions funsor/expectation.py
@@ -0,0 +1,97 @@
from __future__ import absolute_import, division, print_function

from collections import OrderedDict, defaultdict

from six.moves import reduce

import funsor.ops as ops
from funsor.domains import reals
from funsor.interpreter import interpretation, reinterpret
from funsor.optimizer import associate, Finitary
from funsor.terms import Funsor, Unary, eager, monte_carlo, to_funsor, Reduce


class Expectation(Funsor):
"""
    Expectation of an ``integrand`` with respect to a nonnegative ``measure``.
"""
def __init__(self, measure, integrand, reduced_vars):
assert isinstance(measure, Funsor)
assert isinstance(integrand, Funsor)
        assert isinstance(reduced_vars, frozenset)
assert measure.output == reals()
inputs = OrderedDict((k, d) for part in (measure, integrand)
for k, d in part.inputs.items()
if k not in reduced_vars)
output = integrand.output
super(Expectation, self).__init__(inputs, output)
self.measure = measure
self.integrand = integrand
self.reduced_vars = reduced_vars

def eager_subs(self, subs):
raise NotImplementedError('TODO')


@eager.register(Expectation, Funsor, Funsor, frozenset)
def eager_expectation(measure, integrand, reduced_vars):
return (measure * integrand).reduce(ops.add, reduced_vars)


@monte_carlo.register(Expectation, Funsor, Funsor, frozenset)
def monte_carlo_expectation(measure, integrand, reduced_vars):
if not reduced_vars:
return measure * integrand

# Split measure into a finitary product of factors.
with interpretation(associate):
        measure = reinterpret(measure)
factors = []
if isinstance(measure, Finitary) and measure.op is ops.mul:
factors.extend(measure.operands)
elif (isinstance(measure, Unary) and measure.op is ops.exp and
isinstance(measure.arg, Finitary) and measure.arg.op is ops.add):
for operand in measure.arg.operands:
factors.append(operand.exp())
else:
factors.append(measure)
if len(factors) != len(set(factors)):
raise NotImplementedError('TODO combine duplicates e.g. x*x -> x**2')

# Split integrand into a finitary sum of terms.
with interpretation(associate):
integrand = reinterpret(integrand)
terms = []
if isinstance(integrand, Finitary) and integrand.op is ops.add:
terms.extend(integrand.operands)
else:
terms.append(integrand)
if len(terms) != len(set(terms)):
raise NotImplementedError('TODO combine duplicates e.g. x+x -> x*2')

vars_to_factors = defaultdict(set)
for factor in factors:
for var in reduced_vars.intersection(factor.inputs):
vars_to_factors[var].add(factor)

# Compute each term separately.
# TODO share work across terms.
results = []
for term in terms:
term_reduced_vars = reduced_vars.intersection(term.inputs)
upstream_factors = set().union(*(vars_to_factors[v] for v in term_reduced_vars))
        remaining = reduce(ops.mul, set(factors) - upstream_factors, to_funsor(1.))
        upstream = reduce(ops.mul, upstream_factors, to_funsor(1.))
# Try analytic integration.
local = eager_expectation(upstream, term, term_reduced_vars)
        if isinstance(local, Reduce):
# Fall back to monte carlo integration.
upstream = upstream.sample(term_reduced_vars)
local = eager_expectation(upstream, term, term_reduced_vars)
results.append(eager_expectation(
remaining, local, reduced_vars - term_reduced_vars))

return reduce(ops.add, results)


__all__ = [
'Expectation',
]
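
For orientation, a hedged usage sketch of the term defined above: under the default eager interpretation, Expectation reduces analytically; under monte_carlo it routes through monte_carlo_expectation, which samples a factor only where analytic integration fails. The Normal-with-value construction of the measure follows the pattern used in examples/vae.py and is an assumption about the distributions API:

```python
import funsor
import funsor.distributions as dist
from funsor.domains import reals
from funsor.expectation import Expectation
from funsor.interpreter import interpretation
from funsor.terms import monte_carlo

# Estimate E_{x ~ Normal(0, 1)}[x ** 2].
x = funsor.Variable('x', reals())
measure = dist.Normal(0., 1., value=x).exp()  # nonnegative measure q(x)
integrand = x * x

with interpretation(monte_carlo):
    estimate = Expectation(measure, integrand, frozenset(['x']))
```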
25 changes: 25 additions & 0 deletions funsor/gaussian.py
@@ -6,8 +6,10 @@
import torch
from pyro.distributions.util import broadcast_shape
from six import add_metaclass, integer_types
from six.moves import reduce

import funsor.ops as ops
from funsor.delta import Delta
from funsor.domains import reals
from funsor.ops import AddOp
from funsor.terms import Binary, Funsor, FunsorMeta, Number, eager
@@ -250,6 +252,29 @@ def eager_reduce(self, op, reduced_vars):

return None # defer to default implementation

def sample(self, sampled_vars, sample_inputs=None):
sampled_vars = sampled_vars.intersection(self.inputs)
if not sampled_vars:
return self
assert all(self.inputs[k].dtype == 'real' for k in sampled_vars)

int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
real_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype == 'real')
if sampled_vars == frozenset(real_inputs):
            # We need A with A @ A.T = precision^{-1}; the Cholesky factor
            # of the covariance provides that.
            scale_tril = torch.cholesky(torch.inverse(self.precision))
assert self.loc.shape == scale_tril.shape[:-1]
white_noise = torch.randn(self.loc.shape)
sample = self.loc + _mv(scale_tril, white_noise)
offsets, _ = _compute_offsets(real_inputs)
results = []
for key, domain in real_inputs.items():
data = sample[..., offsets[key]: offsets[key] + domain.num_elements]
                point = Tensor(data, int_inputs)
results.append(Delta(key, point))
return reduce(ops.add, results)

raise NotImplementedError('TODO implement partial sampling of real variables')


@eager.register(Binary, AddOp, Gaussian, Gaussian)
def eager_add_gaussian_gaussian(op, lhs, rhs):
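
A note on the sampling identity in Gaussian.sample() above: to draw x = loc + A·w with w ~ N(0, I) and covariance precision⁻¹, the factor A must satisfy A·Aᵀ = precision⁻¹. The Cholesky factor of the inverse precision does; the inverse of cholesky(precision) instead satisfies the transposed identity Aᵀ·A = precision⁻¹. A small numeric check (a standalone sketch, not part of the diff):

```python
import torch

precision = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
covariance = torch.inverse(precision)

# Correct factor: Cholesky of the covariance.
scale_tril = torch.cholesky(covariance)
assert torch.allclose(scale_tril @ scale_tril.t(), covariance, atol=1e-6)

# The inverted Cholesky of the precision satisfies the transposed identity.
L_inv = torch.inverse(torch.cholesky(precision))
assert torch.allclose(L_inv.t() @ L_inv, covariance, atol=1e-6)
```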
3 changes: 3 additions & 0 deletions funsor/joint.py
@@ -98,6 +98,9 @@ def eager_reduce(self, op, reduced_vars):

return None # defer to default implementation

def monte_carlo_logsumexp(self, reduced_vars):
raise NotImplementedError('TODO sample variables')


@eager.register(Joint, tuple, Funsor, Funsor)
def eager_joint(deltas, discrete, gaussian):
1 change: 0 additions & 1 deletion funsor/minipyro.py
@@ -389,7 +389,6 @@ def elbo(model, guide, *args, **kwargs):
# FIXME do not marginalize; instead sample.
q = guide_joint.log_prob.reduce(ops.logaddexp)
tr = guide_joint.samples
tr.update(funsor.backward(ops.sample, q)) # force deferred samples?

# replay model against guide
with log_joint() as model_joint, replay(guide_trace=tr):
14 changes: 0 additions & 14 deletions funsor/ops.py
@@ -123,18 +123,6 @@ def safediv(x, y):
return truediv(x, y)


# just a placeholder
@Op
def marginal(x, y):
raise ValueError


# just a placeholder
@Op
def sample(x, y):
raise ValueError


@Op
def reciprocal(x):
if isinstance(x, Number):
@@ -176,7 +164,6 @@ def reciprocal(x):
'log',
'log1p',
'lt',
'marginal',
'max',
'min',
'mul',
@@ -186,7 +173,6 @@
'pow',
'safediv',
'safesub',
'sample',
'sqrt',
'sub',
'truediv',