Add Lambda and Independent funsors #97

Merged
13 commits merged on Apr 1, 2019
3 changes: 2 additions & 1 deletion funsor/__init__.py
@@ -2,7 +2,7 @@

from funsor.domains import Domain, bint, find_domain, reals
from funsor.interpreter import reinterpret
-from funsor.terms import Funsor, Number, Variable, of_shape, to_data, to_funsor
+from funsor.terms import Funsor, Lambda, Number, Variable, of_shape, to_data, to_funsor
from funsor.torch import Tensor, arange, torch_einsum

from . import (
@@ -26,6 +26,7 @@
__all__ = [
    'Domain',
    'Funsor',
+    'Lambda',
    'Number',
    'Tensor',
    'Variable',
11 changes: 5 additions & 6 deletions funsor/delta.py
@@ -6,7 +6,7 @@

import funsor.ops as ops
from funsor.domains import reals
-from funsor.ops import Op
+from funsor.ops import AddOp, Op
from funsor.terms import Align, Binary, Funsor, FunsorMeta, Number, Variable, eager, to_funsor


@@ -97,12 +97,11 @@ def eager_binary(op, lhs, rhs):
    return None  # defer to default implementation


-@eager.register(Binary, Op, (Funsor, Align), Delta)
+@eager.register(Binary, AddOp, (Funsor, Align), Delta)
def eager_binary(op, lhs, rhs):
-    if op is ops.add:
-        if rhs.name in lhs.inputs:
-            lhs = lhs(**{rhs.name: rhs.point})
-            return op(lhs, rhs)
+    if rhs.name in lhs.inputs:
+        lhs = lhs(**{rhs.name: rhs.point})
+        return op(lhs, rhs)

    return None  # defer to default implementation

40 changes: 40 additions & 0 deletions funsor/terms.py
@@ -881,6 +881,45 @@ def eager_reduce(self, op, reduced_vars):
        return Stack(components, self.name)


class Lambda(Funsor):
    """
    Lazy inverse to ``ops.getitem``.

    This is useful to simulate higher-order functions of integers
    by representing those functions as arrays.
    """
    def __init__(self, var, expr):
        assert isinstance(var, Variable)
        assert isinstance(var.dtype, integer_types)
        assert isinstance(expr, Funsor)
        inputs = expr.inputs.copy()
        inputs.pop(var.name, None)
        shape = (var.dtype,) + expr.output.shape
        output = Domain(shape, expr.dtype)
        super(Lambda, self).__init__(inputs, output)
        self.var = var
        self.expr = expr

    def eager_subs(self, subs):
        subs = tuple((k, v) for k, v in subs if k != self.var.name)
        if not any(k in self.inputs for k, v in subs):
            return self
        if any(self.var.name in v.inputs for k, v in subs):
            raise NotImplementedError('TODO alpha-convert to avoid conflict')
        expr = self.expr.eager_subs(subs)
        return Lambda(self.var, expr)


@eager.register(Binary, GetitemOp, Lambda, (Funsor, Align))
def eager_getitem_lambda(op, lhs, rhs):
    if op.offset == 0:
        return lhs.expr.eager_subs(((lhs.var.name, rhs),))
    if lhs.var.name in rhs.inputs:
        raise NotImplementedError('TODO alpha-convert to avoid conflict')
    expr = GetitemOp(op.offset - 1)(lhs.expr, rhs)
    return Lambda(lhs.var, expr)


def _of_shape(fn, shape):
    args, vargs, kwargs, defaults = getargspec(fn)
    assert not vargs
@@ -901,6 +940,7 @@ def of_shape(*shape):
__all__ = [
    'Binary',
    'Funsor',
+    'Lambda',
    'Number',
    'Reduce',
    'Stack',
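A minimal usage sketch of the new Lambda funsor (not part of the diff; it replays the base_shape = () case of test_lambda below, using only names added or already exported in this PR):

    from funsor.domains import bint, reals
    from funsor.terms import Lambda, Variable

    z = Variable('z', reals())   # a free real-valued variable
    i = Variable('i', bint(5))   # a bounded-integer variable, i in {0, ..., 4}

    zi = Lambda(i, z)            # stays lazy: no eager rule matches (Variable, Variable)
    assert zi.output.shape == (5,)
    assert zi[i] is z            # eager_getitem_lambda cancels the Lambda
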
19 changes: 18 additions & 1 deletion funsor/torch.py
@@ -15,7 +15,7 @@
from funsor.domains import Domain, bint, find_domain, reals
from funsor.ops import GetitemOp, Op
from funsor.six import getargspec
-from funsor.terms import Binary, Funsor, FunsorMeta, Number, Variable, eager, to_data, to_funsor
+from funsor.terms import Binary, Funsor, FunsorMeta, Lambda, Number, Variable, eager, to_data, to_funsor


def align_tensor(new_inputs, x):
@@ -400,6 +400,23 @@ def eager_getitem_tensor_tensor(op, lhs, rhs):
    return Tensor(data, inputs, lhs.dtype)


@eager.register(Lambda, Variable, Tensor)
def eager_lambda(var, expr):
    inputs = expr.inputs.copy()
    if var.name in inputs:
        inputs.pop(var.name)
        inputs[var.name] = var.output
        data = align_tensor(inputs, expr)
        inputs.pop(var.name)
    else:
        data = expr.data
        shape = data.shape
        dim = len(shape) - len(expr.output.shape)
        data = data.reshape(shape[:dim] + (1,) + shape[dim:])
        data = data.expand(shape[:dim] + (var.dtype,) + shape[dim:])
    return Tensor(data, inputs, expr.dtype)


@eager.register(Contract, Tensor, Tensor, frozenset)
@contractor
def eager_contract(lhs, rhs, reduced_vars):
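As a rough illustration of the else branch above (not part of the diff): when var.name is not among expr.inputs, eager_lambda inserts a new dimension of size var.dtype in front of the output dimensions and broadcasts the data along it, so the result is constant in the bound index:

    import torch

    from funsor.domains import bint
    from funsor.terms import Lambda, Variable
    from funsor.torch import Tensor

    c = Tensor(torch.randn(3))   # a constant vector with no funsor inputs
    i = Variable('i', bint(2))

    ci = Lambda(i, c)            # handled eagerly by eager_lambda above
    assert isinstance(ci, Tensor)
    assert ci.output.shape == (2, 3)   # new leading dim of size 2; both rows equal c.data
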
23 changes: 22 additions & 1 deletion test/test_terms.py
@@ -10,7 +10,7 @@
import funsor.ops as ops
from funsor.domains import Domain, bint, reals
from funsor.interpreter import interpretation
-from funsor.terms import Binary, Number, Stack, Variable, sequential, to_data, to_funsor
+from funsor.terms import Binary, Lambda, Number, Stack, Variable, sequential, to_data, to_funsor
from funsor.testing import check_funsor
from funsor.torch import REDUCE_OP_TO_TORCH

@@ -212,6 +212,27 @@ def test_reduce_subset(op, reduced_vars):
    assert actual is f


@pytest.mark.parametrize('base_shape', [(), (4,), (3, 2)], ids=str)
def test_lambda(base_shape):
    z = Variable('z', reals(*base_shape))
    i = Variable('i', bint(5))
    j = Variable('j', bint(7))

    zi = Lambda(i, z)
    assert zi.output.shape == (5,) + base_shape
    assert zi[i] is z

    zj = Lambda(j, z)
    assert zj.output.shape == (7,) + base_shape
    assert zj[j] is z

    zij = Lambda(j, zi)
    assert zij.output.shape == (7, 5) + base_shape
    assert zij[j] is zi
    assert zij[:, i] is zj
    assert zij[j, i] is z


def test_stack_simple():
    x = Number(0.)
    y = Number(1.)
22 changes: 21 additions & 1 deletion test/test_torch.py
@@ -9,7 +9,7 @@
import funsor
import funsor.ops as ops
from funsor.domains import Domain, bint, reals
-from funsor.terms import Number, Variable
+from funsor.terms import Lambda, Number, Variable
from funsor.testing import assert_close, assert_equiv, check_funsor, random_tensor
from funsor.torch import REDUCE_OP_TO_TORCH, Tensor, align_tensors, torch_einsum

@@ -379,6 +379,26 @@ def test_getitem_tensor():
    assert_close(x[i, j, k](k=y), x[i, j, y])


def test_lambda_getitem():
    data = torch.randn(2)
    x = Tensor(data)
    y = Tensor(data, OrderedDict(i=bint(2)))
    i = Variable('i', bint(2))
    assert x[i] is y
    assert Lambda(i, y) is x


def test_lambda_subs():
    x = Tensor(torch.randn(2))
    y = Tensor(torch.randn(2))
    z = Variable('z', reals())
    i = Variable('i', bint(2))

    actual = (x + Lambda(i, z))(z=y[i])  # FIXME this doesn't work
    expected = x + y
    assert_close(actual, expected)


REDUCE_OPS = [ops.add, ops.mul, ops.and_, ops.or_, ops.logaddexp, ops.min, ops.max]

