Add Lambda and Independent funsors (#97)
* Add a Lambda funsor, inverse to getitem

* Sketch Uncurry funsor

* Sketch Uncurry-Delta-Lambda pattern

* Sketch Joint-Uncurry-Delta rule

* Sketch uncurry-distribution test

* Change to_funsor second arg from dtype to Domain

* Add Funsor.__contains__

* Fix test_normal_uncurry

* Rename Uncurry to Independent
fritzo authored and eb8680 committed Apr 1, 2019
1 parent 89313ef commit 306aca6
Showing 8 changed files with 233 additions and 9 deletions.
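In brief, the new Independent(f, 'x', 'i') folds a bounded-integer input i into the event shape of a real-valued input x. A minimal sketch of the intended equivalence, mirroring test_independent below (an illustration, not part of the commit itself):

    from collections import OrderedDict

    import funsor.ops as ops
    from funsor.domains import bint, reals
    from funsor.terms import Independent, Variable
    from funsor.testing import assert_close, random_tensor

    # f has a real input 'x' and a batch input 'i'.
    f = Variable('x', reals(4, 5)) + random_tensor(OrderedDict(i=bint(3)))
    g = Independent(f, 'x', 'i')
    assert g.inputs['x'] == reals(3, 4, 5)
    assert 'i' not in g.inputs

    # Equivalent to substituting x['i'] and summing log densities over 'i'.
    x = Variable('x', reals(3, 4, 5))
    data = random_tensor(OrderedDict(), x.output)
    assert_close(g(x=data), f(x=x['i']).reduce(ops.add, 'i')(x=data))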
4 changes: 3 additions & 1 deletion funsor/__init__.py
@@ -3,7 +3,7 @@
 from funsor.domains import Domain, bint, find_domain, reals
 from funsor.integrate import Integrate
 from funsor.interpreter import reinterpret
-from funsor.terms import Funsor, Number, Variable, of_shape, to_data, to_funsor
+from funsor.terms import Funsor, Independent, Lambda, Number, Variable, of_shape, to_data, to_funsor
 from funsor.torch import Tensor, arange, torch_einsum
 
 from . import (
@@ -29,7 +29,9 @@
 __all__ = [
     'Domain',
     'Funsor',
+    'Independent',
     'Integrate',
+    'Lambda',
     'Number',
     'Tensor',
     'Variable',
29 changes: 28 additions & 1 deletion funsor/delta.py
@@ -10,7 +10,20 @@
 from funsor.interpreter import debug_logged
 from funsor.ops import AddOp, SubOp, TransformOp
 from funsor.registry import KeyedRegistry
-from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, Subs, Unary, Variable, eager, to_funsor
+from funsor.terms import (
+    Align,
+    Binary,
+    Funsor,
+    FunsorMeta,
+    Independent,
+    Lambda,
+    Number,
+    Subs,
+    Unary,
+    Variable,
+    eager,
+    to_funsor
+)
 
 
 class DeltaMeta(FunsorMeta):
@@ -123,6 +136,20 @@ def eager_add(op, lhs, rhs):
     return None  # defer to default implementation
 
 
+@eager.register(Independent, Delta, str, str)
+def eager_independent(delta, reals_var, bint_var):
+    if delta.name == reals_var:
+        i = Variable(bint_var, delta.inputs[bint_var])
+        point = Lambda(i, delta.point)
+        if bint_var in delta.log_density.inputs:
+            log_density = delta.log_density.reduce(ops.add, bint_var)
+        else:
+            log_density = delta.log_density * delta.inputs[bint_var].dtype
+        return Delta(delta.name, point, log_density)
+
+    return None  # defer to default implementation
+
+
 @eager.register(Integrate, Delta, Funsor, frozenset)
 @integrator
 def eager_integrate(delta, integrand, reduced_vars):
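For Delta, the rule above batches the point along bint_var with a Lambda and totals the log density: summing it if it depends on bint_var, otherwise multiplying it by the batch size. A minimal sketch, assuming Delta's log_density defaults to zero:

    from collections import OrderedDict

    from funsor.delta import Delta
    from funsor.domains import bint, reals
    from funsor.terms import Independent
    from funsor.testing import random_tensor

    point = random_tensor(OrderedDict(i=bint(3)), reals(2))  # batched along 'i'
    d = Delta('x', point)                  # log_density defaults to 0
    d2 = Independent(d, 'x', 'i')          # dispatches to eager_independent above
    assert d2.inputs['x'] == reals(3, 2)   # 'i' is folded into the event shape
    assert 'i' not in d2.inputs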
19 changes: 18 additions & 1 deletion funsor/joint.py
@@ -14,7 +14,7 @@
 from funsor.integrate import Integrate, integrator
 from funsor.montecarlo import monte_carlo
 from funsor.ops import AddOp, NegOp, SubOp
-from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, Subs, Unary, Variable, eager, to_funsor
+from funsor.terms import Align, Binary, Funsor, FunsorMeta, Independent, Number, Subs, Unary, Variable, eager, to_funsor
 from funsor.torch import Tensor, arange
 
 
@@ -162,6 +162,23 @@ def eager_joint(deltas, discrete, gaussian):
     return None  # defer to default implementation
 
 
+@eager.register(Independent, Joint, str, str)
+def eager_independent(joint, reals_var, bint_var):
+    for i, delta in enumerate(joint.deltas):
+        if delta.name == reals_var:
+            delta = Independent(delta, reals_var, bint_var)
+            deltas = joint.deltas[:i] + (delta,) + joint.deltas[1+i:]
+            discrete = joint.discrete
+            if bint_var in discrete.inputs:
+                discrete = discrete.reduce(ops.add, bint_var)
+            gaussian = joint.gaussian
+            if bint_var in gaussian.inputs:
+                gaussian = gaussian.reduce(ops.add, bint_var)
+            return Joint(deltas, discrete, gaussian)
+
+    return None  # defer to default implementation
+
+
 ################################################################################
 # Patterns to update a Joint with other funsors
 ################################################################################
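The Joint rule pushes Independent into the Delta that binds reals_var and sums the remaining discrete and Gaussian log-density terms over bint_var. Continuing the Delta sketch above (and assuming Delta + Tensor eagerly forms a Joint at this commit):

    discrete = random_tensor(OrderedDict(i=bint(3)))  # extra log-density term over 'i'
    j = d + discrete                      # eagerly a Joint with one delta
    j2 = Independent(j, 'x', 'i')
    assert j2.inputs['x'] == reals(3, 2)
    assert 'i' not in j2.inputs           # discrete part was reduced with ops.add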
102 changes: 101 additions & 1 deletion funsor/terms.py
@@ -163,6 +163,9 @@ def pretty(self):
         self._pretty(lines)
         return '\n'.join('| ' * indent + text for indent, text in lines)
 
+    def __contains__(self, item):
+        raise TypeError
+
     def __call__(self, *args, **kwargs):
         """
         Partially evaluates this funsor by substituting dimensions.
@@ -924,7 +927,7 @@ def eager_subs(self, subs):
 
         if pos is None:
             # Eagerly recurse into components.
-            assert not any(self.name in v.inputs for k, v in subs)
+            assert not any(self.name in v.inputs and self.name != k for k, v in subs)
             components = tuple(Subs(x, subs) for x in self.components)
             return Stack(components, self.name)
 
@@ -963,6 +966,101 @@ def eager_reduce(self, op, reduced_vars):
             return Stack(components, self.name)
 
 
+class Lambda(Funsor):
+    """
+    Lazy inverse to ``ops.getitem``.
+    This is useful to simulate higher-order functions of integers
+    by representing those functions as arrays.
+    """
+    def __init__(self, var, expr):
+        assert isinstance(var, Variable)
+        assert isinstance(var.dtype, integer_types)
+        assert isinstance(expr, Funsor)
+        inputs = expr.inputs.copy()
+        inputs.pop(var.name, None)
+        shape = (var.dtype,) + expr.output.shape
+        output = Domain(shape, expr.dtype)
+        super(Lambda, self).__init__(inputs, output)
+        self.var = var
+        self.expr = expr
+
+    def eager_subs(self, subs):
+        subs = tuple((k, v) for k, v in subs if k != self.var.name)
+        if not any(k in self.inputs for k, v in subs):
+            return self
+        if any(self.var.name in v.inputs for k, v in subs):
+            raise NotImplementedError('TODO alpha-convert to avoid conflict')
+        expr = self.expr.eager_subs(subs)
+        return Lambda(self.var, expr)
+
+
+@eager.register(Binary, GetitemOp, Lambda, (Funsor, Align))
+def eager_getitem_lambda(op, lhs, rhs):
+    if op.offset == 0:
+        return lhs.expr.eager_subs(((lhs.var.name, rhs),))
+    if lhs.var.name in rhs.inputs:
+        raise NotImplementedError('TODO alpha-convert to avoid conflict')
+    expr = GetitemOp(op.offset - 1)(lhs.expr, rhs)
+    return Lambda(lhs.var, expr)
+
+
+class Independent(Funsor):
+    """
+    Creates an independent diagonal distribution.
+    This is equivalent to substitution followed by reduction::
+
+        f = ...
+        assert f.inputs['x'] == reals(4, 5)
+        assert f.inputs['i'] == bint(3)
+        g = Independent(f, 'x', 'i')
+        assert g.inputs['x'] == reals(3, 4, 5)
+        assert 'i' not in g.inputs
+        x = Variable('x', reals(3, 4, 5))
+        g == f(x=x['i']).reduce(ops.logaddexp, 'i')
+    """
+    def __init__(self, fn, reals_var, bint_var):
+        assert isinstance(fn, Funsor)
+        assert isinstance(reals_var, str)
+        assert reals_var in fn.inputs
+        assert fn.inputs[reals_var].dtype == 'real'
+        assert isinstance(bint_var, str)
+        assert bint_var in fn.inputs
+        assert isinstance(fn.inputs[bint_var].dtype, int)
+        inputs = fn.inputs.copy()
+        shape = (inputs.pop(bint_var).dtype,) + inputs[reals_var].shape
+        inputs[reals_var] = reals(*shape)
+        super(Independent, self).__init__(inputs, fn.output)
+        self.fn = fn
+        self.reals_var = reals_var
+        self.bint_var = bint_var
+
+    def eager_subs(self, subs):
+        fn_subs = []
+        reals_value = None
+        for k, v in subs:
+            if self.bint_var in v.inputs:
+                raise NotImplementedError('TODO alpha-convert')
+            if k == self.reals_var:
+                reals_value = v
+            else:
+                fn_subs.append((k, v))
+        fn = Subs(self.fn, tuple(fn_subs))
+        if reals_value is None:
+            return Independent(fn, self.reals_var, self.bint_var)
+        factors = fn(**{self.reals_var: reals_value[self.bint_var]})
+        return factors.reduce(ops.add, self.bint_var)
+
+    def unscaled_sample(self, sampled_vars, sample_inputs):
+        if self.bint_var in sampled_vars or self.bint_var in sample_inputs:
+            raise NotImplementedError('TODO alpha-convert')
+        fn = self.fn.unscaled_sample(sampled_vars, sample_inputs)
+        return Independent(fn, self.reals_var, self.bint_var)
+
+
 def _of_shape(fn, shape):
     args, vargs, kwargs, defaults = getargspec(fn)
     assert not vargs
@@ -1012,6 +1110,8 @@ def _log1p(x):
 __all__ = [
     'Binary',
     'Funsor',
+    'Independent',
+    'Lambda',
     'Number',
     'Reduce',
     'Stack',
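Two details here are easy to miss: Lambda is the lazy inverse of getitem, and the new Funsor.__contains__ raises TypeError unconditionally, presumably so that membership tests must name f.inputs explicitly instead of being silently resolved by iteration. A sketch, assuming the API in this diff:

    from funsor.domains import bint, reals
    from funsor.terms import Lambda, Variable

    i = Variable('i', bint(5))
    z = Variable('z', reals(2))
    zi = Lambda(i, z)
    assert zi.output == reals(5, 2)
    assert zi[i] is z           # eager_getitem_lambda cancels the Lambda

    try:
        'z' in zi               # raises TypeError via Funsor.__contains__
        assert False
    except TypeError:
        pass
    assert 'z' in zi.inputs     # the intended spelling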
19 changes: 18 additions & 1 deletion funsor/torch.py
@@ -17,7 +17,7 @@
 from funsor.montecarlo import monte_carlo
 from funsor.ops import AssociativeOp, GetitemOp, Op
 from funsor.six import getargspec
-from funsor.terms import Binary, Funsor, FunsorMeta, Number, Subs, Variable, eager, to_data, to_funsor
+from funsor.terms import Binary, Funsor, FunsorMeta, Lambda, Number, Subs, Variable, eager, to_data, to_funsor
 
 
 def align_tensor(new_inputs, x):
@@ -420,6 +420,23 @@ def eager_getitem_tensor_tensor(op, lhs, rhs):
     return Tensor(data, inputs, lhs.dtype)
 
 
+@eager.register(Lambda, Variable, Tensor)
+def eager_lambda(var, expr):
+    inputs = expr.inputs.copy()
+    if var.name in inputs:
+        inputs.pop(var.name)
+        inputs[var.name] = var.output
+        data = align_tensor(inputs, expr)
+        inputs.pop(var.name)
+    else:
+        data = expr.data
+        shape = data.shape
+        dim = len(shape) - len(expr.output.shape)
+        data = data.reshape(shape[:dim] + (1,) + shape[dim:])
+        data = data.expand(shape[:dim] + (var.dtype,) + shape[dim:])
+    return Tensor(data, inputs, expr.dtype)
+
+
 @eager.register(Contract, AssociativeOp, AssociativeOp, Tensor, Tensor, frozenset)
 @contractor
 def eager_contract(sum_op, prod_op, lhs, rhs, reduced_vars):
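eager_lambda has two branches: when the tensor depends on var, align_tensor rotates that batch dimension into the leading event position; otherwise the data is broadcast with reshape/expand. A sketch of the broadcast branch (the dependent branch is exercised by test_lambda_getitem below):

    import torch

    from funsor.domains import bint, reals
    from funsor.terms import Lambda, Variable
    from funsor.torch import Tensor

    i = Variable('i', bint(4))
    c = Tensor(torch.randn(3))   # no dependence on 'i'; output reals(3)
    ci = Lambda(i, c)            # reshape to (1, 3), then expand to (4, 3)
    assert ci.output == reals(4, 3)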
14 changes: 13 additions & 1 deletion test/test_distributions.py
@@ -11,7 +11,7 @@
 from funsor.delta import Delta
 from funsor.domains import bint, reals
 from funsor.joint import Joint
-from funsor.terms import Variable
+from funsor.terms import Independent, Variable
 from funsor.testing import assert_close, check_funsor, random_tensor
 from funsor.torch import Tensor
 
@@ -204,6 +204,18 @@ def test_normal_gaussian_3(batch_shape):
     assert_close(actual, expected, atol=1e-4)
 
 
+def test_normal_independent():
+    loc = random_tensor(OrderedDict(), reals(2))
+    scale = random_tensor(OrderedDict(), reals(2)).exp()
+    fn = dist.Normal(loc['i'], scale['i'], value='z')
+    assert fn.inputs['z'] == reals()
+    d = Independent(fn, 'z', 'i')
+    assert d.inputs['z'] == reals(2)
+    sample = d.sample(frozenset(['z']))
+    assert isinstance(sample, Joint)
+    assert sample.inputs['z'] == reals(2)
+
+
 def test_mvn_defaults():
     loc = Variable('loc', reals(3))
     scale_tril = Variable('scale', reals(3, 3))
44 changes: 42 additions & 2 deletions test/test_terms.py
@@ -1,6 +1,7 @@
 from __future__ import absolute_import, division, print_function
 
 import itertools
+from collections import OrderedDict
 
 import numpy as np
 import pytest
@@ -10,8 +11,8 @@
 import funsor.ops as ops
 from funsor.domains import Domain, bint, reals
 from funsor.interpreter import interpretation
-from funsor.terms import Binary, Number, Stack, Variable, sequential, to_data, to_funsor
-from funsor.testing import check_funsor
+from funsor.terms import Binary, Independent, Lambda, Number, Stack, Variable, sequential, to_data, to_funsor
+from funsor.testing import assert_close, check_funsor, random_tensor
 from funsor.torch import REDUCE_OP_TO_TORCH
 
 np.seterr(all='ignore')
@@ -212,6 +213,45 @@ def test_reduce_subset(op, reduced_vars):
     assert actual is f
 
 
+@pytest.mark.parametrize('base_shape', [(), (4,), (3, 2)], ids=str)
+def test_lambda(base_shape):
+    z = Variable('z', reals(*base_shape))
+    i = Variable('i', bint(5))
+    j = Variable('j', bint(7))
+
+    zi = Lambda(i, z)
+    assert zi.output.shape == (5,) + base_shape
+    assert zi[i] is z
+
+    zj = Lambda(j, z)
+    assert zj.output.shape == (7,) + base_shape
+    assert zj[j] is z
+
+    zij = Lambda(j, zi)
+    assert zij.output.shape == (7, 5) + base_shape
+    assert zij[j] is zi
+    assert zij[:, i] is zj
+    assert zij[j, i] is z
+
+
+def test_independent():
+    f = Variable('x', reals(4, 5)) + random_tensor(OrderedDict(i=bint(3)))
+    assert f.inputs['x'] == reals(4, 5)
+    assert f.inputs['i'] == bint(3)
+
+    actual = Independent(f, 'x', 'i')
+    assert actual.inputs['x'] == reals(3, 4, 5)
+    assert 'i' not in actual.inputs
+
+    x = Variable('x', reals(3, 4, 5))
+    expected = f(x=x['i']).reduce(ops.add, 'i')
+    assert actual.inputs == expected.inputs
+    assert actual.output == expected.output
+
+    data = random_tensor(OrderedDict(), x.output)
+    assert_close(actual(data), expected(data))
+
+
 def test_stack_simple():
     x = Number(0.)
     y = Number(1.)
11 changes: 10 additions & 1 deletion test/test_torch.py
@@ -9,7 +9,7 @@
 import funsor
 import funsor.ops as ops
 from funsor.domains import Domain, bint, find_domain, reals
-from funsor.terms import Number, Variable
+from funsor.terms import Lambda, Number, Variable
 from funsor.testing import assert_close, assert_equiv, check_funsor, random_tensor
 from funsor.torch import REDUCE_OP_TO_TORCH, Tensor, align_tensors, torch_einsum
 
@@ -402,6 +402,15 @@ def test_getitem_tensor():
     assert_close(x[i, j, k](k=y), x[i, j, y])
 
 
+def test_lambda_getitem():
+    data = torch.randn(2)
+    x = Tensor(data)
+    y = Tensor(data, OrderedDict(i=bint(2)))
+    i = Variable('i', bint(2))
+    assert x[i] is y
+    assert Lambda(i, y) is x
+
+
 REDUCE_OPS = [ops.add, ops.mul, ops.and_, ops.or_, ops.logaddexp, ops.min, ops.max]
