diff --git a/funsor/factory.py b/funsor/factory.py index f75cc41d8..ac33e44ac 100644 --- a/funsor/factory.py +++ b/funsor/factory.py @@ -60,6 +60,26 @@ class Bound: pass +class ValueMeta(type): + def __getitem__(cls, value_type): + return Value(value_type) + + +class Value(metaclass=ValueMeta): + def __init__(self, value_type): + if issubclass(value_type, Funsor): + raise TypeError("Types cannot depend on Funsor values") + self.value_type = value_type + + +def _get_dependent_args(fields, hints, args): + return { + name: arg if isinstance(hint, Value) else arg.output + for name, arg, hint in zip(fields, args, hints) + if hint in (Funsor, Bound) or isinstance(hint, Value) + } + + def make_funsor(fn): """ Decorator to dynamically create a subclass of @@ -73,6 +93,8 @@ def make_funsor(fn): - Bound variable inputs (names) are typed :class:`Bound`. - Fresh variable inputs (names) are typed :class:`Fresh` together with lambda to compute the dependent domain. + - Ground value inputs (e.g. Python ints) are typed :class:`Value` together with + their actual data type, e.g. ``Value[int]``. - The return value is typed :class:`Fresh` together with a lambda to compute the dependent return domain. @@ -94,7 +116,7 @@ def Unflatten( """ input_types = typing.get_type_hints(fn) for name, hint in input_types.items(): - if not (hint in (Funsor, Bound) or isinstance(hint, Fresh)): + if not (hint in (Funsor, Bound) or isinstance(hint, (Fresh, Value))): raise TypeError(f"Invalid type hint {name}: {hint}") output_type = input_types.pop("return") hints = tuple(input_types.values()) @@ -117,13 +139,15 @@ def __call__(cls, *args): if not isinstance(arg, Variable): raise ValueError(f"Cannot infer domain of {name}={arg}") args[i] = arg + elif isinstance(hint, Value): + if not isinstance(arg, hint.value_type): + raise TypeError( + f"invalid dependent value type: {arg}: {hint.value_type}" + ) + args[i] = arg # Compute domains of fresh variables. - dependent_args = { - name: arg.output - for name, arg, hint in zip(cls._ast_fields, args, hints) - if hint in (Funsor, Bound) - } + dependent_args = _get_dependent_args(cls._ast_fields, hints, args) for i, (hint, arg) in enumerate(zip(hints, args)): if isinstance(hint, Fresh): domain = hint(**dependent_args) @@ -135,11 +159,7 @@ def __call__(cls, *args): ) def __init__(self, **kwargs): args = tuple(kwargs[k] for k in self._ast_fields) - dependent_args = { - name: arg.output - for name, arg, hint in zip(self._ast_fields, args, hints) - if hint in (Funsor, Bound) - } + dependent_args = _get_dependent_args(self._ast_fields, hints, args) output = output_type(**dependent_args) inputs = OrderedDict() fresh = set() diff --git a/test/test_factory.py b/test/test_factory.py index ea977b92a..a9fe8b78c 100644 --- a/test/test_factory.py +++ b/test/test_factory.py @@ -8,10 +8,12 @@ import funsor.ops as ops from funsor.domains import Array, Bint, Real, Reals -from funsor.factory import Bound, Fresh, make_funsor, to_funsor +from funsor.factory import Bound, Fresh, Value, make_funsor, to_funsor +from funsor.interpretations import reflect +from funsor.interpreter import reinterpret from funsor.tensor import Tensor from funsor.terms import Cat, Funsor, Lambda, Number, eager -from funsor.testing import check_funsor, random_tensor +from funsor.testing import assert_close, check_funsor, random_tensor def test_lambda_lambda(): @@ -175,3 +177,49 @@ def Scatter1( value = Number(4, 9) x = Scatter1(destin, "a", "a", value, source) check_funsor(x, {"a": Bint[9], "b": Bint[3]}, Real) + + +def test_value_dependence(): + @make_funsor + def Sum( + x: Funsor, + dim: Value[int], + ) -> Fresh[lambda x, dim: Array[x.dtype, x.shape[:dim] + x.shape[dim + 1 :]]]: + return None + + @eager.register(Sum, Tensor, int) + def eager_sum(x, dim): + data = x.data.sum(len(x.data.shape) - len(x.shape) + dim) + return Tensor(data, x.inputs, x.dtype) + + x = random_tensor(OrderedDict(a=Bint[3]), Reals[2, 4, 5]) + + with reflect: + y0 = Sum(x, 0) + check_funsor(y0, x.inputs, Reals[4, 5]) + y1 = Sum(x, 1) + check_funsor(y1, x.inputs, Reals[2, 5]) + y2 = Sum(x, 2) + check_funsor(y2, x.inputs, Reals[2, 4]) + + z0 = reinterpret(y0) + check_funsor(z0, x.inputs, Reals[4, 5]) + assert_close(z0.data, x.data.sum(1 + 0)) + z1 = reinterpret(y1) + check_funsor(z1, x.inputs, Reals[2, 5]) + assert_close(z1.data, x.data.sum(1 + 1)) + z2 = reinterpret(y2) + check_funsor(z2, x.inputs, Reals[2, 4]) + assert_close(z2.data, x.data.sum(1 + 2)) + + with pytest.raises(TypeError): + with reflect: + Sum(x, 1.5) + + with pytest.raises(TypeError): + + @make_funsor + def Sum( + x: Funsor, dim: Value[Number] + ) -> Fresh[lambda x, dim: Array[x.dtype, x.shape[:dim] + x.shape[dim + 1 :]]]: + return None