Implement general Gaussian funsor (#37)
* WIP sketch Gaussian funsor

* Partially implement binary_gaussian_gaussian

* Implement marginalization along a dimension; add smoke test

* Add more comments

* Implement binary_gaussian_gaussian

* Sketch to_affine() and Affine funsor

* Add xfailing test for to_affine()

* Sketch more of eager_subs

* Remove affine stuff

* WIP fix align_gaussian() using align_tensor()

* Refactor and simplify align_tensor()

* Fix eager_subs

* Get smoke tests working

* Switch from scale_tril to precision representation

* Implement basic Normal -> Gaussian transform

* Rename normal conversions

* Fix filling in of defaults for distribution classes

* Add test for binary_gaussian_number

* Add test for binary_gaussian_tensor

* Add xfailing test for gaussian + gaussian

* Add more tests

* Add test of Normal vs Gaussian

* Fix math error in Gaussian .logsumexp()

* Fix bugs in Gaussian+Gaussian, align_gaussian()

* Fix kalman_filter.py, add to make test

* Add more distribution tests
fritzo authored and eb8680 committed Mar 5, 2019
1 parent d333418 commit 448c1fa
Showing 12 changed files with 815 additions and 56 deletions.
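
The substance of the commit (our reading, pieced together from the diffs below; funsor/gaussian.py itself is among the changed files whose diff is not shown) is a Gaussian funsor constructed as Gaussian(log_density, loc, precision, inputs), encoding an unnormalized log-density in information (precision) form over its real-valued inputs x:

    f(x) = \text{log\_density} - \tfrac{1}{2}\,(x - \text{loc})^\top\,\text{precision}\,(x - \text{loc})

One payoff of the precision parameterization (see the "Switch from scale_tril to precision representation" bullet above) is that adding two Gaussian log-densities stays in the family: precisions add, and the combined mean solves (P_1 + P_2)\,\mu = P_1\mu_1 + P_2\mu_2.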
2 changes: 1 addition & 1 deletion Makefile
@@ -9,7 +9,7 @@ lint: FORCE
 test: lint FORCE
 	pytest -v test
 	python examples/discrete_hmm.py -n 2
-	@#python examples/kalman_filter.py --xfail-if-not-implemented
+	python examples/kalman_filter.py --xfail-if-not-implemented
 	@#python examples/ss_vae_delayed.py --xfail-if-not-implemented
 	@#python examples/minipyro.py --xfail-if-not-implemented
 	@echo PASS
30 changes: 11 additions & 19 deletions examples/kalman_filter.py
@@ -16,44 +16,36 @@ def main(args):
 
     # A Gaussian HMM model.
     def model(data):
-        prob = 1.
+        log_prob = funsor.to_funsor(0.)
 
-        x_curr = 0.
+        x_curr = funsor.Tensor(torch.tensor(0.))
         for t, y in enumerate(data):
             x_prev = x_curr
 
             # A delayed sample statement.
             x_curr = funsor.Variable('x_{}'.format(t), funsor.reals())
-            prob *= dist.Normal(loc=x_prev, scale=trans_noise, value=x_curr)
+            log_prob += dist.Normal(x_prev, trans_noise, value=x_curr)
 
             # If we want, we can immediately marginalize out previous sample sites.
-            prob = prob.sum('x_{}'.format(t - 1))
-            # TODO prob = Clever(funsor.eval)(prob)
+            if isinstance(x_prev, funsor.Variable):
+                log_prob = log_prob.logsumexp(x_prev.name)
 
             # An observe statement.
-            prob *= dist.Normal(loc=x_curr, scale=emit_noise, value=y)
+            log_prob += dist.Normal(x_curr, emit_noise, value=y)
 
-        return prob
+        log_prob = log_prob.logsumexp()
+        return log_prob
 
     # Train model parameters.
     print('---- training ----')
     data = torch.randn(args.time_steps)
     optim = torch.optim.Adam(params, lr=args.learning_rate)
     for step in range(args.train_steps):
         optim.zero_grad()
-        prob = model(data)
-        # TODO prob = Clever(funsor.eval)(prob)
-        loss = -prob.sum().log()  # Integrates out delayed variables.
+        log_prob = model(data)
+        assert not log_prob.inputs, 'free variables remain'
+        loss = -log_prob.data
         loss.backward()
         optim.step()
 
-    # Serve by drawing a posterior sample.
-    print('---- serving ----')
-    prob = model(data)
-    prob = funsor.eval(prob.sum())  # Forward filter.
-    samples = prob.backward(prob.log())  # Backward sample.
-    print(samples)
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description="Kalman filter example")
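
A note on the example above (our gloss, not stated in the diff): switching from prob to log_prob trades products of densities for sums of log-densities, and marginalization prob.sum(name) becomes log_prob.logsumexp(name). For a real-valued free variable this reduction is an integral rather than a finite sum,

    \operatorname{logsumexp}_x f = \log \int \exp(f(x))\,dx,

which the Gaussian funsor can evaluate in closed form, while staying numerically stable where a product of many densities would underflow.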
5 changes: 3 additions & 2 deletions funsor/__init__.py
@@ -5,7 +5,7 @@
 from funsor.terms import Funsor, Number, Variable, of_shape, to_funsor
 from funsor.torch import Function, Tensor, arange, einsum, function
 
-from . import distributions, domains, handlers, interpreter, minipyro, ops, terms, torch
+from . import distributions, domains, gaussian, handlers, interpreter, minipyro, ops, terms, torch
 
 __all__ = [
     'Domain',
@@ -16,14 +16,15 @@
     'Variable',
     'arange',
     'backward',
+    'bint',
     'distributions',
     'domains',
     'einsum',
     'find_domain',
     'function',
+    'gaussian',
     'handlers',
     'interpreter',
-    'bint',
     'minipyro',
     'of_shape',
     'ops',
65 changes: 65 additions & 0 deletions funsor/distributions.py
@@ -1,16 +1,35 @@
 from __future__ import absolute_import, division, print_function
 
+import math
 from collections import OrderedDict
 
 import pyro.distributions as dist
+import torch
 from six import add_metaclass
 
 import funsor.ops as ops
 from funsor.domains import bint, reals
+from funsor.gaussian import Gaussian
 from funsor.terms import Funsor, FunsorMeta, Number, Variable, eager, to_funsor
 from funsor.torch import Tensor, align_tensors, materialize
 
 
+def numbers_to_tensors(*args):
+    """
+    Convert :class:`~funsor.terms.Number`s to :class:`funsor.torch.Tensor`s,
+    using any provided tensor as a prototype, if available.
+    """
+    if any(isinstance(x, Number) for x in args):
+        new_tensor = torch.tensor
+        for x in args:
+            if isinstance(x, Tensor):
+                new_tensor = x.data.new_tensor
+                break
+        args = tuple(Tensor(new_tensor(x.data), dtype=x.dtype) if isinstance(x, Number) else x
+                     for x in args)
+    return args
+
+
 class Distribution(Funsor):
     """
     Funsor backed by a PyTorch distribution object.
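
A quick usage sketch of the new helper (hypothetical values; this assumes funsor.to_funsor dispatches Python numbers to Number and torch tensors to Tensor, as elsewhere in this commit):

    import torch
    import funsor
    from funsor.distributions import numbers_to_tensors

    loc = funsor.to_funsor(0.)                   # a Number
    scale = funsor.to_funsor(torch.tensor(2.))   # a Tensor prototype
    loc, scale = numbers_to_tensors(loc, scale)
    # loc is now a Tensor built via scale.data.new_tensor(0.), so it
    # matches scale's dtype and device; scale is returned unchanged.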
@@ -70,6 +89,7 @@ def __call__(cls, probs, value=None):
             value = Variable('value', bint(size))
         else:
             value = to_funsor(value)
+        probs, value = numbers_to_tensors(probs, value)
         return super(CategoricalMeta, cls).__call__(probs, value)
 
 
@@ -108,6 +128,7 @@ def __call__(cls, loc, scale, value=None):
             value = Variable('value', reals())
         else:
             value = to_funsor(value)
+        loc, scale, value = numbers_to_tensors(loc, scale, value)
         return super(NormalMeta, cls).__call__(loc, scale, value)
 
 
@@ -124,6 +145,50 @@ def eager_normal(loc, scale, value):
     return Normal.eager_log_prob(loc=loc, scale=scale, value=value)
 
 
+# Create a Gaussian from a ground observation.
+@eager.register(Normal, Variable, (Number, Tensor), (Number, Tensor))
+def eager_normal(loc, scale, value):
+    assert loc.output == reals()
+    inputs, (scale, value) = align_tensors(scale, value)
+    inputs.update(loc.inputs)
+
+    log_density = -0.5 * math.log(2 * math.pi) - scale.log()
+    loc = value.unsqueeze(-1)
+    precision = scale.pow(-2).unsqueeze(-1).unsqueeze(-1)
+    return Gaussian(log_density, loc, precision, inputs)
+
+
+# Create a Gaussian from a ground observation.
+@eager.register(Normal, (Number, Tensor), (Number, Tensor), Variable)
+def eager_normal(loc, scale, value):
+    assert value.output == reals()
+    inputs, (loc, scale) = align_tensors(loc, scale)
+    inputs.update(value.inputs)
+
+    log_density = -0.5 * math.log(2 * math.pi) - scale.log()
+    loc = loc.unsqueeze(-1)
+    precision = scale.pow(-2).unsqueeze(-1).unsqueeze(-1)
+    return Gaussian(log_density, loc, precision, inputs)
+
+
+# Create a Gaussian from a noisy identity transform.
+# This is extremely limited but suffices for examples/kalman_filter.py
+@eager.register(Normal, Variable, (Number, Tensor), Variable)
+def eager_normal(loc, scale, value):
+    assert loc.output == reals()
+    assert value.output == reals()
+    assert loc.name != value.name
+    inputs = loc.inputs.copy()
+    inputs.update(scale.inputs)
+    inputs.update(value.inputs)
+
+    log_density = -0.5 * math.log(2 * math.pi) - scale.data.log()
+    loc = scale.data.new_zeros(scale.data.shape + (2,))
+    p = scale.data.pow(-2)
+    precision = torch.stack([p, -p, -p, p], -1).reshape(p.shape + (2, 2))
+    return Gaussian(log_density, loc, precision, inputs)
+
+
 __all__ = [
     'Categorical',
     'Distribution',
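
Two sanity checks on the new registrations (our algebra; only the constructor calls come from the diff). With the information-form reading given earlier, the ground-observation cases reproduce the scalar normal log-density

    \log\mathcal{N}(v \mid \mu, \sigma) = -\tfrac{1}{2}\log(2\pi) - \log\sigma - \tfrac{1}{2}\,\sigma^{-2}(v - \mu)^2,

i.e. precision = \sigma^{-2} and log_density = -\tfrac{1}{2}\log(2\pi) - \log\sigma, with the quadratic term carried by the Gaussian. For the noisy identity transform, expanding the square with p = \sigma^{-2} gives

    -\tfrac{p}{2}(v - u)^2 = -\tfrac{1}{2} \begin{pmatrix} u & v \end{pmatrix} \begin{pmatrix} p & -p \\ -p & p \end{pmatrix} \begin{pmatrix} u \\ v \end{pmatrix},

which is exactly the 2x2 precision built by torch.stack([p, -p, -p, p], -1).reshape(p.shape + (2, 2)), with the Gaussian's loc set to zeros.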
1 change: 1 addition & 0 deletions funsor/domains.py
@@ -18,6 +18,7 @@ class Domain(namedtuple('Domain', ['shape', 'dtype'])):
"""
def __new__(cls, shape, dtype):
assert isinstance(shape, tuple)
assert all(isinstance(size, integer_types) for size in shape)
if isinstance(dtype, integer_types):
assert not shape
elif isinstance(dtype, str):
[Diffs for the remaining 7 changed files did not load and are not shown.]
