Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a Value hint to funsor.factory #462

Merged
merged 1 commit into from
Feb 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 31 additions & 11 deletions funsor/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,26 @@ class Bound:
pass


class ValueMeta(type):
def __getitem__(cls, value_type):
return Value(value_type)


class Value(metaclass=ValueMeta):
def __init__(self, value_type):
if issubclass(value_type, Funsor):
raise TypeError("Types cannot depend on Funsor values")
self.value_type = value_type


def _get_dependent_args(fields, hints, args):
return {
name: arg if isinstance(hint, Value) else arg.output
for name, arg, hint in zip(fields, args, hints)
if hint in (Funsor, Bound) or isinstance(hint, Value)
}


def make_funsor(fn):
"""
Decorator to dynamically create a subclass of
Expand All @@ -73,6 +93,8 @@ def make_funsor(fn):
- Bound variable inputs (names) are typed :class:`Bound`.
- Fresh variable inputs (names) are typed :class:`Fresh` together with
lambda to compute the dependent domain.
- Ground value inputs (e.g. Python ints) are typed :class:`Value` together with
their actual data type, e.g. ``Value[int]``.
- The return value is typed :class:`Fresh` together with a lambda to
compute the dependent return domain.

Expand All @@ -94,7 +116,7 @@ def Unflatten(
"""
input_types = typing.get_type_hints(fn)
for name, hint in input_types.items():
if not (hint in (Funsor, Bound) or isinstance(hint, Fresh)):
if not (hint in (Funsor, Bound) or isinstance(hint, (Fresh, Value))):
raise TypeError(f"Invalid type hint {name}: {hint}")
output_type = input_types.pop("return")
hints = tuple(input_types.values())
Expand All @@ -117,13 +139,15 @@ def __call__(cls, *args):
if not isinstance(arg, Variable):
raise ValueError(f"Cannot infer domain of {name}={arg}")
args[i] = arg
elif isinstance(hint, Value):
if not isinstance(arg, hint.value_type):
raise TypeError(
f"invalid dependent value type: {arg}: {hint.value_type}"
)
args[i] = arg

# Compute domains of fresh variables.
dependent_args = {
name: arg.output
for name, arg, hint in zip(cls._ast_fields, args, hints)
if hint in (Funsor, Bound)
}
dependent_args = _get_dependent_args(cls._ast_fields, hints, args)
for i, (hint, arg) in enumerate(zip(hints, args)):
if isinstance(hint, Fresh):
domain = hint(**dependent_args)
Expand All @@ -135,11 +159,7 @@ def __call__(cls, *args):
)
def __init__(self, **kwargs):
args = tuple(kwargs[k] for k in self._ast_fields)
dependent_args = {
name: arg.output
for name, arg, hint in zip(self._ast_fields, args, hints)
if hint in (Funsor, Bound)
}
dependent_args = _get_dependent_args(self._ast_fields, hints, args)
output = output_type(**dependent_args)
inputs = OrderedDict()
fresh = set()
Expand Down
52 changes: 50 additions & 2 deletions test/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@

import funsor.ops as ops
from funsor.domains import Array, Bint, Real, Reals
from funsor.factory import Bound, Fresh, make_funsor, to_funsor
from funsor.factory import Bound, Fresh, Value, make_funsor, to_funsor
from funsor.interpretations import reflect
from funsor.interpreter import reinterpret
from funsor.tensor import Tensor
from funsor.terms import Cat, Funsor, Lambda, Number, eager
from funsor.testing import check_funsor, random_tensor
from funsor.testing import assert_close, check_funsor, random_tensor


def test_lambda_lambda():
Expand Down Expand Up @@ -175,3 +177,49 @@ def Scatter1(
value = Number(4, 9)
x = Scatter1(destin, "a", "a", value, source)
check_funsor(x, {"a": Bint[9], "b": Bint[3]}, Real)


def test_value_dependence():
@make_funsor
def Sum(
x: Funsor,
dim: Value[int],
) -> Fresh[lambda x, dim: Array[x.dtype, x.shape[:dim] + x.shape[dim + 1 :]]]:
return None

@eager.register(Sum, Tensor, int)
def eager_sum(x, dim):
data = x.data.sum(len(x.data.shape) - len(x.shape) + dim)
return Tensor(data, x.inputs, x.dtype)

x = random_tensor(OrderedDict(a=Bint[3]), Reals[2, 4, 5])

with reflect:
y0 = Sum(x, 0)
check_funsor(y0, x.inputs, Reals[4, 5])
y1 = Sum(x, 1)
check_funsor(y1, x.inputs, Reals[2, 5])
y2 = Sum(x, 2)
check_funsor(y2, x.inputs, Reals[2, 4])

z0 = reinterpret(y0)
check_funsor(z0, x.inputs, Reals[4, 5])
assert_close(z0.data, x.data.sum(1 + 0))
z1 = reinterpret(y1)
check_funsor(z1, x.inputs, Reals[2, 5])
assert_close(z1.data, x.data.sum(1 + 1))
z2 = reinterpret(y2)
check_funsor(z2, x.inputs, Reals[2, 4])
assert_close(z2.data, x.data.sum(1 + 2))

with pytest.raises(TypeError):
with reflect:
Sum(x, 1.5)

with pytest.raises(TypeError):

@make_funsor
def Sum(
x: Funsor, dim: Value[Number]
) -> Fresh[lambda x, dim: Array[x.dtype, x.shape[:dim] + x.shape[dim + 1 :]]]:
return None