Skip to content

Commit

Permalink
Add Tuple funsor (#430)
Browse files Browse the repository at this point in the history
* Add a funsor.Tuple term for tuples of heterogeneous type

* delete funsor.tensor.LazyTuple

* fix test
  • Loading branch information
eb8680 authored Jan 20, 2021
1 parent d16225f commit 1cb43a4
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 33 deletions.
44 changes: 40 additions & 4 deletions funsor/domains.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,37 @@ def bint(size):
return Bint[size]


class ProductDomain(Domain):

_type_cache = WeakValueDictionary()

def __getitem__(cls, arg_domains):
try:
return ProductDomain._type_cache[arg_domains]
except KeyError:
assert isinstance(arg_domains, tuple)
assert all(isinstance(arg_domain, Domain) for arg_domain in arg_domains)
subcls = type("Product_", (Product,), {"__args__": arg_domains})
ProductDomain._type_cache[arg_domains] = subcls
return subcls

def __repr__(cls):
return "Product[{}]".format(", ".join(map(repr, cls.__args__)))

@property
def __origin__(cls):
return Product

@property
def shape(cls):
return (len(cls.__args__),)


class Product(tuple, metaclass=ProductDomain):
"""like typing.Tuple, but works with issubclass"""
__args__ = NotImplemented


@quote.register(BintType)
@quote.register(RealsType)
def _(arg, indent, out):
Expand Down Expand Up @@ -215,10 +246,15 @@ def _find_domain_reshape(op, domain):


@find_domain.register(ops.GetitemOp)
def _find_domain_getitem(op, lhs, rhs):
dtype = lhs.dtype
shape = lhs.shape[:op.offset] + lhs.shape[1 + op.offset:]
return Array[dtype, shape]
def _find_domain_getitem(op, lhs_domain, rhs_domain):
if isinstance(lhs_domain, ArrayType):
dtype = lhs_domain.dtype
shape = lhs_domain.shape[:op.offset] + lhs_domain.shape[1 + op.offset:]
return Array[dtype, shape]
elif isinstance(lhs_domain, ProductDomain):
# XXX should this return a Union?
raise NotImplementedError("Cannot statically infer domain from: "
f"{lhs_domain}[{rhs_domain}]")


@find_domain.register(ops.EqOp)
Expand Down
31 changes: 8 additions & 23 deletions funsor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import funsor
import funsor.ops as ops
from funsor.delta import Delta
from funsor.domains import Array, ArrayType, Bint, Real, Reals, find_domain
from funsor.domains import Array, ArrayType, Bint, Product, Real, Reals, find_domain
from funsor.ops import GetitemOp, MatmulOp, Op, ReshapeOp
from funsor.terms import (
Binary,
Expand All @@ -26,14 +26,15 @@
Lambda,
Number,
Slice,
Tuple,
Unary,
Variable,
eager,
substitute,
to_data,
to_funsor
)
from funsor.util import get_backend, get_tracing_state, getargspec, is_nn_module, lazy_property, quote
from funsor.util import get_backend, get_tracing_state, getargspec, is_nn_module, quote


def get_default_prototype():
Expand Down Expand Up @@ -752,22 +753,6 @@ def eager_cat_homogeneous(name, part_name, *parts):
return Tensor(tensor, inputs, dtype=output.dtype)


# TODO Promote this to a Funsor subclass.
class LazyTuple(tuple):
def __call__(self, *args, **kwargs):
return LazyTuple(x(*args, **kwargs) for x in self)

@lazy_property
def __annotations__(self):
result = {}
output = []
for part in self:
result.update(part.__annotations__)
output.append(result.pop("return"))
result["return"] = typing.Tuple[tuple(output)]
return result


# TODO Move this to terms.py; it is no longer Tensor-specific.
class Function(Funsor):
r"""
Expand Down Expand Up @@ -834,13 +819,13 @@ def _select(fn, i, *args):
def _nested_function(fn, args, output):
if isinstance(output, ArrayType):
return Function(fn, output, args)
elif output.__origin__ in (tuple, typing.Tuple):
elif output.__origin__ in (tuple, Product, typing.Tuple):
result = []
for i, output_i in enumerate(output.__args__):
fn_i = functools.partial(_select, fn, i)
fn_i.__name__ = "{}_{}".format(_nameof(fn), i)
result.append(_nested_function(fn_i, args, output_i))
return LazyTuple(result)
return Tuple(tuple(result))
raise ValueError("Invalid output: {}".format(output))


Expand Down Expand Up @@ -880,7 +865,7 @@ def _function(inputs, output, fn):
for (name, domain) in zip(names, inputs))
assert len(args) == len(inputs)
if not isinstance(output, ArrayType):
assert output.__origin__ in (tuple, typing.Tuple)
assert output.__origin__ in (tuple, Product, typing.Tuple)
# Memoize multiple-output functions so that invocations can be shared among
# all outputs. This is not foolproof, but does work in simple situations.
fn = _Memoized(fn)
Expand Down Expand Up @@ -938,14 +923,14 @@ def max_and_argmax(x: Reals[8]) -> Tuple[Real, Bint[8]]:
output = inputs.pop("return")
assert all(isinstance(d, ArrayType) for d in inputs.values())
assert (isinstance(output, (ArrayType, tuple)) or
output.__origin__ in (tuple, typing.Tuple))
output.__origin__ in (tuple, Product, typing.Tuple))
return _function(inputs, output, fn)
# Usage @function(input1, ..., inputN, output)
inputs, output = signature[:-1], signature[-1]
output = _tuple_to_Tuple(output)
assert all(isinstance(d, ArrayType) for d in inputs)
assert (isinstance(output, (ArrayType, tuple)) or
output.__origin__ in (tuple, typing.Tuple))
output.__origin__ in (tuple, Product, typing.Tuple))
return functools.partial(_function, inputs, output)


Expand Down
27 changes: 26 additions & 1 deletion funsor/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import funsor.interpreter as interpreter
import funsor.ops as ops
from funsor.domains import Array, Bint, Domain, Real, find_domain
from funsor.domains import Array, Bint, Domain, Product, Real, find_domain
from funsor.interpreter import PatternMissingError, dispatched_interpretation, interpret
from funsor.ops import AssociativeOp, GetitemOp, Op
from funsor.util import getargspec, get_backend, lazy_property, pretty, quote
Expand Down Expand Up @@ -1592,6 +1592,31 @@ def eager_independent_trivial(fn, reals_var, bint_var, diag_var):
return None


class Tuple(Funsor):
"""
Funsor term representing tuples of other terms of possibly heterogeneous type.
"""
def __init__(self, args):
assert isinstance(args, tuple)
assert all(isinstance(arg, Funsor) for arg in args)
inputs = OrderedDict()
for arg in args:
inputs.update(arg.inputs)
output = Product[tuple(arg.output for arg in args)]
super().__init__(inputs, output)
self.args = args

def __iter__(self):
for i in range(len(self.args)):
yield self[i]


@lazy.register(Binary, GetitemOp, Tuple, Number)
@eager.register(Binary, GetitemOp, Tuple, Number)
def eager_getitem_tuple(op, lhs, rhs):
return op(lhs.args, rhs.data)


def _symbolic(inputs, output, fn):
args, vargs, kwargs, defaults = getargspec(fn)
assert not vargs
Expand Down
8 changes: 4 additions & 4 deletions test/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import itertools
import pickle
from collections import OrderedDict
from typing import Tuple, get_type_hints
from typing import get_type_hints

import numpy as np
import pytest

import funsor
import funsor.ops as ops
from funsor.domains import Array, Bint, Real, Reals, find_domain
from funsor.domains import Array, Bint, Real, Product, Reals, find_domain
from funsor.interpreter import interpretation
from funsor.tensor import REDUCE_OP_TO_NUMERIC, Einsum, Tensor, align_tensors, numeric_array, stack, tensordot
from funsor.terms import Cat, Lambda, Number, Slice, Stack, Variable, lazy
Expand Down Expand Up @@ -723,10 +723,10 @@ def _numeric_max_and_argmax(x):
def test_function_nested_eager_hint():

@funsor.function
def max_and_argmax(x: Reals[8]) -> Tuple[Real, Bint[8]]:
def max_and_argmax(x: Reals[8]) -> Product[Real, Bint[8]]:
return tuple(_numeric_max_and_argmax(x))

expected = {"x": Reals[8], "return": Tuple[Real, Bint[8]]}
expected = {"x": Reals[8], "return": Product[Real, Bint[8]]}
assert get_type_hints(max_and_argmax) == expected

inputs = OrderedDict([('i', Bint[2]), ('j', Bint[3])])
Expand Down
24 changes: 23 additions & 1 deletion test/test_terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import funsor
import funsor.ops as ops
from funsor.cnf import Contraction
from funsor.domains import Array, Bint, Real, Reals
from funsor.domains import Array, Bint, Product, Real, Reals
from funsor.interpreter import interpretation, reinterpret
from funsor.tensor import REDUCE_OP_TO_NUMERIC
from funsor.terms import (
Expand All @@ -29,6 +29,7 @@
Slice,
Stack,
Subs,
Tuple,
Variable,
eager,
eager_or_die,
Expand Down Expand Up @@ -611,3 +612,24 @@ def test_stack_lambda(dtype):

assert z[0] is x1
assert z[1] is x2


def test_funsor_tuple():
x = Number(1, 3)
y = Number(2.5, 'real')
z = random_tensor(OrderedDict([('i', Bint[2])]))

xyz = Tuple((x, y, z))

check_funsor(xyz, {'i': Bint[2]}, Product[x.output, y.output, z.output])

assert eval(repr(xyz.output)) is xyz.output

assert xyz[0] is x
assert xyz[1] is y
assert xyz[2] is z

x1, y1, z1 = xyz
assert x1 is x
assert y1 is y
assert z1 is z

0 comments on commit 1cb43a4

Please sign in to comment.