From 1cb43a4dd317cd0d3f7faee4662c2b6bfa918d2a Mon Sep 17 00:00:00 2001 From: eb8680 Date: Tue, 19 Jan 2021 20:36:17 -0500 Subject: [PATCH] Add Tuple funsor (#430) * Add a funsor.Tuple term for tuples of heterogeneous type * delete funsor.tensor.LazyTuple * fix test --- funsor/domains.py | 44 ++++++++++++++++++++++++++++++++++++++++---- funsor/tensor.py | 31 ++++++++----------------------- funsor/terms.py | 27 ++++++++++++++++++++++++++- test/test_tensor.py | 8 ++++---- test/test_terms.py | 24 +++++++++++++++++++++++- 5 files changed, 101 insertions(+), 33 deletions(-) diff --git a/funsor/domains.py b/funsor/domains.py index da951a1ba..87f2ce644 100644 --- a/funsor/domains.py +++ b/funsor/domains.py @@ -175,6 +175,37 @@ def bint(size): return Bint[size] +class ProductDomain(Domain): + + _type_cache = WeakValueDictionary() + + def __getitem__(cls, arg_domains): + try: + return ProductDomain._type_cache[arg_domains] + except KeyError: + assert isinstance(arg_domains, tuple) + assert all(isinstance(arg_domain, Domain) for arg_domain in arg_domains) + subcls = type("Product_", (Product,), {"__args__": arg_domains}) + ProductDomain._type_cache[arg_domains] = subcls + return subcls + + def __repr__(cls): + return "Product[{}]".format(", ".join(map(repr, cls.__args__))) + + @property + def __origin__(cls): + return Product + + @property + def shape(cls): + return (len(cls.__args__),) + + +class Product(tuple, metaclass=ProductDomain): + """like typing.Tuple, but works with issubclass""" + __args__ = NotImplemented + + @quote.register(BintType) @quote.register(RealsType) def _(arg, indent, out): @@ -215,10 +246,15 @@ def _find_domain_reshape(op, domain): @find_domain.register(ops.GetitemOp) -def _find_domain_getitem(op, lhs, rhs): - dtype = lhs.dtype - shape = lhs.shape[:op.offset] + lhs.shape[1 + op.offset:] - return Array[dtype, shape] +def _find_domain_getitem(op, lhs_domain, rhs_domain): + if isinstance(lhs_domain, ArrayType): + dtype = lhs_domain.dtype + shape = lhs_domain.shape[:op.offset] + lhs_domain.shape[1 + op.offset:] + return Array[dtype, shape] + elif isinstance(lhs_domain, ProductDomain): + # XXX should this return a Union? + raise NotImplementedError("Cannot statically infer domain from: " + f"{lhs_domain}[{rhs_domain}]") @find_domain.register(ops.EqOp) diff --git a/funsor/tensor.py b/funsor/tensor.py index 0dc1a9542..e4ec652fa 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -17,7 +17,7 @@ import funsor import funsor.ops as ops from funsor.delta import Delta -from funsor.domains import Array, ArrayType, Bint, Real, Reals, find_domain +from funsor.domains import Array, ArrayType, Bint, Product, Real, Reals, find_domain from funsor.ops import GetitemOp, MatmulOp, Op, ReshapeOp from funsor.terms import ( Binary, @@ -26,6 +26,7 @@ Lambda, Number, Slice, + Tuple, Unary, Variable, eager, @@ -33,7 +34,7 @@ to_data, to_funsor ) -from funsor.util import get_backend, get_tracing_state, getargspec, is_nn_module, lazy_property, quote +from funsor.util import get_backend, get_tracing_state, getargspec, is_nn_module, quote def get_default_prototype(): @@ -752,22 +753,6 @@ def eager_cat_homogeneous(name, part_name, *parts): return Tensor(tensor, inputs, dtype=output.dtype) -# TODO Promote this to a Funsor subclass. -class LazyTuple(tuple): - def __call__(self, *args, **kwargs): - return LazyTuple(x(*args, **kwargs) for x in self) - - @lazy_property - def __annotations__(self): - result = {} - output = [] - for part in self: - result.update(part.__annotations__) - output.append(result.pop("return")) - result["return"] = typing.Tuple[tuple(output)] - return result - - # TODO Move this to terms.py; it is no longer Tensor-specific. class Function(Funsor): r""" @@ -834,13 +819,13 @@ def _select(fn, i, *args): def _nested_function(fn, args, output): if isinstance(output, ArrayType): return Function(fn, output, args) - elif output.__origin__ in (tuple, typing.Tuple): + elif output.__origin__ in (tuple, Product, typing.Tuple): result = [] for i, output_i in enumerate(output.__args__): fn_i = functools.partial(_select, fn, i) fn_i.__name__ = "{}_{}".format(_nameof(fn), i) result.append(_nested_function(fn_i, args, output_i)) - return LazyTuple(result) + return Tuple(tuple(result)) raise ValueError("Invalid output: {}".format(output)) @@ -880,7 +865,7 @@ def _function(inputs, output, fn): for (name, domain) in zip(names, inputs)) assert len(args) == len(inputs) if not isinstance(output, ArrayType): - assert output.__origin__ in (tuple, typing.Tuple) + assert output.__origin__ in (tuple, Product, typing.Tuple) # Memoize multiple-output functions so that invocations can be shared among # all outputs. This is not foolproof, but does work in simple situations. fn = _Memoized(fn) @@ -938,14 +923,14 @@ def max_and_argmax(x: Reals[8]) -> Tuple[Real, Bint[8]]: output = inputs.pop("return") assert all(isinstance(d, ArrayType) for d in inputs.values()) assert (isinstance(output, (ArrayType, tuple)) or - output.__origin__ in (tuple, typing.Tuple)) + output.__origin__ in (tuple, Product, typing.Tuple)) return _function(inputs, output, fn) # Usage @function(input1, ..., inputN, output) inputs, output = signature[:-1], signature[-1] output = _tuple_to_Tuple(output) assert all(isinstance(d, ArrayType) for d in inputs) assert (isinstance(output, (ArrayType, tuple)) or - output.__origin__ in (tuple, typing.Tuple)) + output.__origin__ in (tuple, Product, typing.Tuple)) return functools.partial(_function, inputs, output) diff --git a/funsor/terms.py b/funsor/terms.py index 5b05f4c28..7cf6d1f51 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -17,7 +17,7 @@ import funsor.interpreter as interpreter import funsor.ops as ops -from funsor.domains import Array, Bint, Domain, Real, find_domain +from funsor.domains import Array, Bint, Domain, Product, Real, find_domain from funsor.interpreter import PatternMissingError, dispatched_interpretation, interpret from funsor.ops import AssociativeOp, GetitemOp, Op from funsor.util import getargspec, get_backend, lazy_property, pretty, quote @@ -1592,6 +1592,31 @@ def eager_independent_trivial(fn, reals_var, bint_var, diag_var): return None +class Tuple(Funsor): + """ + Funsor term representing tuples of other terms of possibly heterogeneous type. + """ + def __init__(self, args): + assert isinstance(args, tuple) + assert all(isinstance(arg, Funsor) for arg in args) + inputs = OrderedDict() + for arg in args: + inputs.update(arg.inputs) + output = Product[tuple(arg.output for arg in args)] + super().__init__(inputs, output) + self.args = args + + def __iter__(self): + for i in range(len(self.args)): + yield self[i] + + +@lazy.register(Binary, GetitemOp, Tuple, Number) +@eager.register(Binary, GetitemOp, Tuple, Number) +def eager_getitem_tuple(op, lhs, rhs): + return op(lhs.args, rhs.data) + + def _symbolic(inputs, output, fn): args, vargs, kwargs, defaults = getargspec(fn) assert not vargs diff --git a/test/test_tensor.py b/test/test_tensor.py index 12b64efce..607ed9b7c 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -6,14 +6,14 @@ import itertools import pickle from collections import OrderedDict -from typing import Tuple, get_type_hints +from typing import get_type_hints import numpy as np import pytest import funsor import funsor.ops as ops -from funsor.domains import Array, Bint, Real, Reals, find_domain +from funsor.domains import Array, Bint, Real, Product, Reals, find_domain from funsor.interpreter import interpretation from funsor.tensor import REDUCE_OP_TO_NUMERIC, Einsum, Tensor, align_tensors, numeric_array, stack, tensordot from funsor.terms import Cat, Lambda, Number, Slice, Stack, Variable, lazy @@ -723,10 +723,10 @@ def _numeric_max_and_argmax(x): def test_function_nested_eager_hint(): @funsor.function - def max_and_argmax(x: Reals[8]) -> Tuple[Real, Bint[8]]: + def max_and_argmax(x: Reals[8]) -> Product[Real, Bint[8]]: return tuple(_numeric_max_and_argmax(x)) - expected = {"x": Reals[8], "return": Tuple[Real, Bint[8]]} + expected = {"x": Reals[8], "return": Product[Real, Bint[8]]} assert get_type_hints(max_and_argmax) == expected inputs = OrderedDict([('i', Bint[2]), ('j', Bint[3])]) diff --git a/test/test_terms.py b/test/test_terms.py index d0c278a34..f36a71e92 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -15,7 +15,7 @@ import funsor import funsor.ops as ops from funsor.cnf import Contraction -from funsor.domains import Array, Bint, Real, Reals +from funsor.domains import Array, Bint, Product, Real, Reals from funsor.interpreter import interpretation, reinterpret from funsor.tensor import REDUCE_OP_TO_NUMERIC from funsor.terms import ( @@ -29,6 +29,7 @@ Slice, Stack, Subs, + Tuple, Variable, eager, eager_or_die, @@ -611,3 +612,24 @@ def test_stack_lambda(dtype): assert z[0] is x1 assert z[1] is x2 + + +def test_funsor_tuple(): + x = Number(1, 3) + y = Number(2.5, 'real') + z = random_tensor(OrderedDict([('i', Bint[2])])) + + xyz = Tuple((x, y, z)) + + check_funsor(xyz, {'i': Bint[2]}, Product[x.output, y.output, z.output]) + + assert eval(repr(xyz.output)) is xyz.output + + assert xyz[0] is x + assert xyz[1] is y + assert xyz[2] is z + + x1, y1, z1 = xyz + assert x1 is x + assert y1 is y + assert z1 is z