Implement monte carlo .sample() methods #75

Merged: 11 commits, Mar 15, 2019
7 changes: 5 additions & 2 deletions funsor/delta.py
@@ -8,7 +8,6 @@
from funsor.domains import reals
from funsor.ops import Op
from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, Variable, eager, to_funsor
from funsor.torch import Tensor


class DeltaMeta(FunsorMeta):
@@ -66,7 +65,7 @@ def eager_subs(self, subs):
if value is not None:
if isinstance(value, Variable):
name = value.name
elif isinstance(value, (Number, Tensor)) and isinstance(point, (Number, Tensor)):
elif not any(d.dtype == 'real' for side in (value, point) for d in side.inputs.values()):
return (value == point).all().log() + log_density
else:
# TODO Compute a jacobian, update log_prob, and emit another Delta.
@@ -83,6 +82,10 @@ def eager_reduce(self, op, reduced_vars):

return None # defer to default implementation

def sample(self, sampled_vars):
assert all(k == self.name for k in sampled_vars if k in self.inputs)
return self


@eager.register(Binary, Op, Delta, (Funsor, Delta, Align))
def eager_binary(op, lhs, rhs):
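The new ``Delta.sample()`` is intentionally a no-op: a Delta already represents a degenerate point-mass sample, so sampling its own variable simply returns ``self``. A minimal sketch of this behavior (the variable name ``z`` and the point value are hypothetical):

    from funsor.delta import Delta
    from funsor.terms import Number

    # A Delta is already a (degenerate) sample, so sampling its own
    # variable returns the same funsor object unchanged.
    d = Delta('z', Number(0.5))
    assert d.sample(frozenset(['z'])) is d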
6 changes: 5 additions & 1 deletion funsor/domains.py
@@ -72,7 +72,11 @@ def find_domain(op, *domains):
assert callable(op), op
assert all(isinstance(arg, Domain) for arg in domains)
if len(domains) == 1:
return domains[0]
dtype = domains[0].dtype
shape = domains[0].shape
if op is ops.log or op is ops.exp:
dtype = 'real'
return Domain(shape, dtype)

lhs, rhs = domains
if op is ops.getitem:
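With the change above, ``find_domain`` promotes the result of unary ``ops.log`` and ``ops.exp`` to a ``'real'`` dtype while preserving shape, instead of returning the input domain unchanged. A small sketch of the expected behavior:

    import funsor.ops as ops
    from funsor.domains import bint, find_domain, reals

    # log of a bounded-integer scalar becomes a real scalar.
    assert find_domain(ops.log, bint(3)) == reals()
    # exp of a real vector keeps its shape and stays real.
    assert find_domain(ops.exp, reals(2)) == reals(2)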
60 changes: 55 additions & 5 deletions funsor/gaussian.py
@@ -6,12 +6,15 @@
import torch
from pyro.distributions.util import broadcast_shape
from six import add_metaclass, integer_types
from six.moves import reduce

import funsor.ops as ops
from funsor.delta import Delta
from funsor.domains import reals
from funsor.ops import AddOp
from funsor.terms import Binary, Funsor, FunsorMeta, Number, eager
from funsor.torch import Tensor, align_tensor, align_tensors, materialize
from funsor.util import lazy_property


def _issubshape(subshape, supershape):
@@ -123,6 +126,13 @@ class Gaussian(Funsor):
"""
Funsor representing a batched joint Gaussian distribution as a log-density
function.

Note that :class:`Gaussian`s are not normalized; rather, they are
canonicalized to evaluate to zero at their maximum value (at ``loc``). This
canonical form is useful because it allows :class:`Gaussian`s with
incomplete information, i.e. zero eigenvalues in the precision matrix.
These incomplete distributions arise when making low-dimensional
observations of a higher-dimensional hidden state.
"""
def __init__(self, loc, precision, inputs):
assert isinstance(loc, torch.Tensor)
@@ -206,6 +216,14 @@ def eager_subs(self, subs):

raise NotImplementedError('TODO implement partial substitution of real variables')

@lazy_property
def _log_normalizer(self):
dim = self.loc.size(-1)
log_det_term = _log_det_tril(torch.cholesky(self.precision))
data = log_det_term - 0.5 * math.log(2 * math.pi) * dim
inputs = OrderedDict((k, v) for k, v in self.inputs.items() if v.dtype != 'real')
return Tensor(data, inputs)

def eager_reduce(self, op, reduced_vars):
if op is ops.logaddexp:
# Marginalize out real variables, but keep mixtures lazy.
@@ -217,13 +235,10 @@ def eager_reduce(self, op, reduced_vars):
return None # defer to default implementation

inputs = OrderedDict((k, d) for k, d in self.inputs.items() if k not in reduced_reals)
int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')
if reduced_reals == real_vars:
dim = self.loc.size(-1)
log_det_term = _log_det_tril(torch.cholesky(self.precision))
data = log_det_term - 0.5 * math.log(2 * math.pi) * dim
result = Tensor(data, int_inputs)
result = self._log_normalizer
else:
int_inputs = OrderedDict((k, v) for k, v in inputs.items() if v.dtype != 'real')
offsets, _ = _compute_offsets(self.inputs)
index = []
for key, domain in inputs.items():
@@ -250,6 +265,41 @@

return None # defer to default implementation

def sample(self, sampled_vars, sample_inputs=None):
sampled_vars = sampled_vars.intersection(self.inputs)
if not sampled_vars:
return self
assert all(self.inputs[k].dtype == 'real' for k in sampled_vars)

# Partition inputs into sample_inputs + int_inputs + real_inputs.
if sample_inputs is None:
sample_inputs = OrderedDict()
assert frozenset(sample_inputs).isdisjoint(self.inputs)
sample_shape = tuple(int(d.dtype) for d in sample_inputs.values())
int_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype != 'real')
real_inputs = OrderedDict((k, d) for k, d in self.inputs.items() if d.dtype == 'real')
inputs = sample_inputs.copy()
inputs.update(int_inputs)

if sampled_vars == frozenset(real_inputs):
scale_tril = torch.inverse(torch.cholesky(self.precision))
assert self.loc.shape == scale_tril.shape[:-1]
shape = sample_shape + self.loc.shape
white_noise = torch.randn(shape)
sample = self.loc + _mv(scale_tril, white_noise)
offsets, _ = _compute_offsets(real_inputs)
results = []
for key, domain in real_inputs.items():
data = sample[..., offsets[key]: offsets[key] + domain.num_elements]
data = data.reshape(shape[:-1] + domain.shape)
point = Tensor(data, inputs)
assert point.output == domain
results.append(Delta(key, point))
results.append(self._log_normalizer)
return reduce(ops.add, results)

raise NotImplementedError('TODO implement partial sampling of real variables')


@eager.register(Binary, AddOp, Gaussian, Gaussian)
def eager_add_gaussian_gaussian(op, lhs, rhs):
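``Gaussian.sample()`` above draws a reparameterized sample through the inverse Cholesky factor of the precision matrix and returns a ``Delta`` at the drawn point plus the Gaussian's log normalizer, so the result can stand in for the original funsor in expectation computations. A sketch of the whole-variable sampling path, assuming a single 2-dimensional real input whose name ``x`` is hypothetical:

    from collections import OrderedDict

    import torch

    from funsor.domains import reals
    from funsor.gaussian import Gaussian

    # A standard-normal Gaussian funsor over one real input "x".
    g = Gaussian(torch.zeros(2), torch.eye(2), OrderedDict(x=reals(2)))

    y = g.sample(frozenset(['x']))
    # The sample is a Delta at the drawn point plus g's log normalizer,
    # so it has the same inputs and output as g itself.
    assert y.inputs == g.inputs
    assert y.output == g.output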
71 changes: 58 additions & 13 deletions funsor/joint.py
@@ -3,6 +3,7 @@
from collections import OrderedDict

from six import add_metaclass
from six.moves import reduce

import funsor.ops as ops
from funsor.delta import Delta
@@ -75,29 +76,65 @@ def eager_subs(self, subs):

def eager_reduce(self, op, reduced_vars):
if op is ops.logaddexp:
# Integrate out delayed discrete variables.
discrete_vars = reduced_vars.intersection(self.discrete.inputs)
mixture_params = frozenset(self.gaussian.inputs).union(*(x.point.inputs for x in self.deltas))
lazy_vars = discrete_vars & mixture_params # Mixtures must remain lazy.
discrete_vars -= mixture_params
# Integrate out degenerate variables, i.e. drop selected delta.
deltas = []
remaining_vars = set(reduced_vars)
for d in self.deltas:
if d.name in reduced_vars:
remaining_vars.remove(d.name)
else:
deltas.append(d)
deltas = tuple(deltas)
reduced_vars = frozenset(remaining_vars)

# Integrate out delayed discrete variables, but keep mixtures lazy.
lazy_vars = reduced_vars.difference(self.gaussian.inputs, *(x.inputs for x in deltas))
discrete_vars = reduced_vars.intersection(self.discrete.inputs).difference(lazy_vars)
discrete = self.discrete.reduce(op, discrete_vars)
reduced_vars = reduced_vars.difference(discrete_vars, lazy_vars)

# Integrate out delayed gaussian variables.
gaussian_vars = reduced_vars.intersection(self.gaussian.inputs)
gaussian = self.gaussian.reduce(ops.logaddexp, gaussian_vars)
assert (reduced_vars - gaussian_vars).issubset(d.name for d in self.deltas)

# Integrate out delayed degenerate variables, i.e. drop them.
deltas = tuple(d for d in self.deltas if d.name not in reduced_vars)

assert not lazy_vars
return (Joint(deltas, discrete) + gaussian).reduce(ops.logaddexp, lazy_vars)
reduced_vars = reduced_vars.difference(gaussian_vars)

# Account for remaining reduced vars that were inputs to dropped deltas.
eager_result = Joint(deltas, discrete) + gaussian
reduced_vars |= lazy_vars.difference(eager_result.inputs)
lazy_vars = lazy_vars.intersection(eager_result.inputs)
if reduced_vars:
eager_result += ops.log(reduce(ops.mul, [self.inputs[v].dtype for v in reduced_vars]))

# Return a value only if progress has been made.
if eager_result is self:
return None # defer to default implementation
else:
return eager_result.reduce(ops.logaddexp, lazy_vars)

if op is ops.add:
raise NotImplementedError('TODO product-reduce along a plate dimension')

return None # defer to default implementation

def sample(self, sampled_vars, sample_inputs=None):
if sample_inputs is None:
sample_inputs = OrderedDict()
assert frozenset(sample_inputs).isdisjoint(self.inputs)
discrete_vars = sampled_vars.intersection(self.discrete.inputs)
gaussian_vars = frozenset(k for k in sampled_vars
if k in self.gaussian.inputs
if self.gaussian.inputs[k].dtype == 'real')
result = self
if discrete_vars:
discrete = result.discrete.sample(discrete_vars, sample_inputs)
result = Joint(result.deltas, gaussian=result.gaussian) + discrete
if gaussian_vars:
sample_inputs = OrderedDict((k, v) for k, v in sample_inputs.items()
if k not in result.gaussian.inputs)
gaussian = result.gaussian.sample(gaussian_vars, sample_inputs)
result = Joint(result.deltas, result.discrete) + gaussian
return result


@eager.register(Joint, tuple, Funsor, Funsor)
def eager_joint(deltas, discrete, gaussian):
@@ -151,10 +188,18 @@ def eager_add(op, joint, other):
return Joint(joint.deltas, joint.discrete + other, joint.gaussian)


@eager.register(Binary, Op, Joint, (Number, Tensor))
def eager_add(op, joint, other):
if op is ops.sub:
return joint + -other

return None # defer to default implementation


@eager.register(Binary, AddOp, Joint, Gaussian)
def eager_add(op, joint, other):
# Update with a delayed gaussian random variable.
subs = tuple((d.name, d.point) for d in joint.deltas if d in other.inputs)
subs = tuple((d.name, d.point) for d in joint.deltas if d.name in other.inputs)
if subs:
other = other.eager_subs(subs)
if joint.gaussian is not Number(0):
1 change: 0 additions & 1 deletion funsor/minipyro.py
@@ -389,7 +389,6 @@ def elbo(model, guide, *args, **kwargs):
# FIXME do not marginalize; instead sample.
q = guide_joint.log_prob.reduce(ops.logaddexp)
tr = guide_joint.samples
tr.update(funsor.backward(ops.sample, q)) # force deferred samples?

# replay model against guide
with log_joint() as model_joint, replay(guide_trace=tr):
14 changes: 0 additions & 14 deletions funsor/ops.py
@@ -123,18 +123,6 @@ def safediv(x, y):
return truediv(x, y)


# just a placeholder
@Op
def marginal(x, y):
raise ValueError


# just a placeholder
@Op
def sample(x, y):
raise ValueError


@Op
def reciprocal(x):
if isinstance(x, Number):
@@ -176,7 +164,6 @@ def reciprocal(x):
'log',
'log1p',
'lt',
'marginal',
'max',
'min',
'mul',
@@ -186,7 +173,6 @@
'pow',
'safediv',
'safesub',
'sample',
'sqrt',
'sub',
'truediv',
39 changes: 37 additions & 2 deletions funsor/terms.py
@@ -16,15 +16,15 @@
import itertools
import numbers
from abc import ABCMeta, abstractmethod
from collections import OrderedDict, Hashable
from collections import Hashable, OrderedDict
from weakref import WeakValueDictionary

from six import add_metaclass, integer_types
from six.moves import reduce

import funsor.interpreter as interpreter
import funsor.ops as ops
from funsor.domains import Domain, bint, find_domain
from funsor.domains import Domain, bint, find_domain, reals
from funsor.interpreter import interpret
from funsor.ops import AssociativeOp, Op
from funsor.registry import KeyedRegistry
@@ -213,6 +213,41 @@ def reduce(self, op, reduced_vars=None):
assert reduced_vars.issubset(self.inputs)
return Reduce(op, self, reduced_vars)

def sample(self, sampled_vars, sample_inputs=None):
"""
Create a Monte Carlo approximation to this funsor by replacing
functions of ``sampled_vars`` with :class:`~funsor.delta.Delta`s.

If ``sample_inputs`` is not provided, the result is a :class:`Funsor`
with the same ``.inputs`` and ``.output`` as the original funsor, so
that ``self`` can be replaced by the sample in expectation computations::

y = x.sample(sampled_vars)
assert y.inputs == x.inputs
assert y.output == x.output
exact = (x.exp() * integrand).reduce(ops.add)
approx = (y.exp() * integrand).reduce(ops.add)

If ``sample_inputs`` is provided, this creates a batch of samples
that are intended to be averaged; however, that reduction is not
performed by the :meth:`sample` method::

y = x.sample(sampled_vars, sample_inputs)
total = reduce(ops.mul, [d.dtype for d in sample_inputs.values()])
exact = (x.exp() * integrand).reduce(ops.add)
approx = (y.exp() * integrand).reduce(ops.add) / total

:param frozenset sampled_vars: A set of input variables to sample.
:param OrderedDict sample_inputs: An optional mapping from variable
name to :class:`~funsor.domains.Domain` over which samples will
be batched.
"""
assert self.output == reals()
assert isinstance(sampled_vars, frozenset)
if sampled_vars.isdisjoint(self.inputs):
return self
raise NotImplementedError

def align(self, names):
"""
Align this funsor to match given ``names``.
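Continuing the Gaussian sketch above, the batched form of the contract documented in ``Funsor.sample()`` adds the sample batch dimensions as new inputs rather than averaging over them; the batch variable name ``particle`` here is hypothetical:

    from collections import OrderedDict

    from funsor.domains import bint

    sample_inputs = OrderedDict(particle=bint(100))
    y = g.sample(frozenset(['x']), sample_inputs)
    # The batch dimension shows up as an extra input; dividing an
    # approximated expectation by its size recovers the `total` factor
    # from the docstring above.
    assert set(y.inputs) == set(g.inputs) | {'particle'}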