Skip to content

Commit

Permalink
Add a Value hint to funsor.factory (#462)
Browse files Browse the repository at this point in the history
  • Loading branch information
eb8680 authored Feb 20, 2021
1 parent 5b3d4a1 commit 78996fd
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 13 deletions.
42 changes: 31 additions & 11 deletions funsor/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,26 @@ class Bound:
pass


class ValueMeta(type):
def __getitem__(cls, value_type):
return Value(value_type)


class Value(metaclass=ValueMeta):
def __init__(self, value_type):
if issubclass(value_type, Funsor):
raise TypeError("Types cannot depend on Funsor values")
self.value_type = value_type


def _get_dependent_args(fields, hints, args):
return {
name: arg if isinstance(hint, Value) else arg.output
for name, arg, hint in zip(fields, args, hints)
if hint in (Funsor, Bound) or isinstance(hint, Value)
}


def make_funsor(fn):
"""
Decorator to dynamically create a subclass of
Expand All @@ -73,6 +93,8 @@ def make_funsor(fn):
- Bound variable inputs (names) are typed :class:`Bound`.
- Fresh variable inputs (names) are typed :class:`Fresh` together with
lambda to compute the dependent domain.
- Ground value inputs (e.g. Python ints) are typed :class:`Value` together with
their actual data type, e.g. ``Value[int]``.
- The return value is typed :class:`Fresh` together with a lambda to
compute the dependent return domain.
Expand All @@ -94,7 +116,7 @@ def Unflatten(
"""
input_types = typing.get_type_hints(fn)
for name, hint in input_types.items():
if not (hint in (Funsor, Bound) or isinstance(hint, Fresh)):
if not (hint in (Funsor, Bound) or isinstance(hint, (Fresh, Value))):
raise TypeError(f"Invalid type hint {name}: {hint}")
output_type = input_types.pop("return")
hints = tuple(input_types.values())
Expand All @@ -117,13 +139,15 @@ def __call__(cls, *args):
if not isinstance(arg, Variable):
raise ValueError(f"Cannot infer domain of {name}={arg}")
args[i] = arg
elif isinstance(hint, Value):
if not isinstance(arg, hint.value_type):
raise TypeError(
f"invalid dependent value type: {arg}: {hint.value_type}"
)
args[i] = arg

# Compute domains of fresh variables.
dependent_args = {
name: arg.output
for name, arg, hint in zip(cls._ast_fields, args, hints)
if hint in (Funsor, Bound)
}
dependent_args = _get_dependent_args(cls._ast_fields, hints, args)
for i, (hint, arg) in enumerate(zip(hints, args)):
if isinstance(hint, Fresh):
domain = hint(**dependent_args)
Expand All @@ -135,11 +159,7 @@ def __call__(cls, *args):
)
def __init__(self, **kwargs):
args = tuple(kwargs[k] for k in self._ast_fields)
dependent_args = {
name: arg.output
for name, arg, hint in zip(self._ast_fields, args, hints)
if hint in (Funsor, Bound)
}
dependent_args = _get_dependent_args(self._ast_fields, hints, args)
output = output_type(**dependent_args)
inputs = OrderedDict()
fresh = set()
Expand Down
52 changes: 50 additions & 2 deletions test/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@

import funsor.ops as ops
from funsor.domains import Array, Bint, Real, Reals
from funsor.factory import Bound, Fresh, make_funsor, to_funsor
from funsor.factory import Bound, Fresh, Value, make_funsor, to_funsor
from funsor.interpretations import reflect
from funsor.interpreter import reinterpret
from funsor.tensor import Tensor
from funsor.terms import Cat, Funsor, Lambda, Number, eager
from funsor.testing import check_funsor, random_tensor
from funsor.testing import assert_close, check_funsor, random_tensor


def test_lambda_lambda():
Expand Down Expand Up @@ -175,3 +177,49 @@ def Scatter1(
value = Number(4, 9)
x = Scatter1(destin, "a", "a", value, source)
check_funsor(x, {"a": Bint[9], "b": Bint[3]}, Real)


def test_value_dependence():
@make_funsor
def Sum(
x: Funsor,
dim: Value[int],
) -> Fresh[lambda x, dim: Array[x.dtype, x.shape[:dim] + x.shape[dim + 1 :]]]:
return None

@eager.register(Sum, Tensor, int)
def eager_sum(x, dim):
data = x.data.sum(len(x.data.shape) - len(x.shape) + dim)
return Tensor(data, x.inputs, x.dtype)

x = random_tensor(OrderedDict(a=Bint[3]), Reals[2, 4, 5])

with reflect:
y0 = Sum(x, 0)
check_funsor(y0, x.inputs, Reals[4, 5])
y1 = Sum(x, 1)
check_funsor(y1, x.inputs, Reals[2, 5])
y2 = Sum(x, 2)
check_funsor(y2, x.inputs, Reals[2, 4])

z0 = reinterpret(y0)
check_funsor(z0, x.inputs, Reals[4, 5])
assert_close(z0.data, x.data.sum(1 + 0))
z1 = reinterpret(y1)
check_funsor(z1, x.inputs, Reals[2, 5])
assert_close(z1.data, x.data.sum(1 + 1))
z2 = reinterpret(y2)
check_funsor(z2, x.inputs, Reals[2, 4])
assert_close(z2.data, x.data.sum(1 + 2))

with pytest.raises(TypeError):
with reflect:
Sum(x, 1.5)

with pytest.raises(TypeError):

@make_funsor
def Sum(
x: Funsor, dim: Value[Number]
) -> Fresh[lambda x, dim: Array[x.dtype, x.shape[:dim] + x.shape[dim + 1 :]]]:
return None

0 comments on commit 78996fd

Please sign in to comment.