Skip to content

Commit

Permalink
Merge pull request #41 from Point72/tkp/graphreturntype
Browse files Browse the repository at this point in the history
Add function names to type resolver errors and properly report when the return value is the conflicting type
  • Loading branch information
AdamGlustein committed Feb 7, 2024
2 parents f2b04f4 + f976c3a commit 6d3a8bf
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 42 deletions.
89 changes: 59 additions & 30 deletions csp/impl/types/instantiation_type_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@ def instance(cls):


class ContainerTypeVarResolutionError(TypeError):
def __init__(self, tvar, tvar_value):
self._tvar, self._tvar_value = tvar, tvar_value
def __init__(self, func_name, tvar, tvar_value):
self._func_name, self._tvar, self._tvar_value = func_name, tvar, tvar_value
super().__init__(
f"Unable to resolve container type for type variable {tvar} explicit value must have"
f"In function {func_name}: Unable to resolve container type for type variable {tvar} explicit value must have"
+ f" uniform values and be non empty, got: {tvar_value} "
)

def __reduce__(self):
return (ContainerTypeVarResolutionError, (self._tvar, self._tvar_value))
return (ContainerTypeVarResolutionError, (self._func_name, self._tvar, self._tvar_value))


class TypeMismatchError(TypeError):
Expand All @@ -117,59 +117,77 @@ def get_tvar_info_str(cls, tvar_info):


class ArgTypeMismatchError(TypeMismatchError):
def __init__(self, expected_t, actual_arg, arg_name, tvar_info=None):
self._expected_t, self._actual_arg, self._arg_name, self._tvar_info = (
def __init__(self, func_name, expected_t, actual_arg, arg_name, tvar_info=None):
self._func_name, self._expected_t, self._actual_arg, self._arg_name, self._tvar_info = (
func_name,
expected_t,
actual_arg,
arg_name,
tvar_info,
)
super().__init__(
f"Expected {self.pretty_typename(expected_t)} for argument '{arg_name}', got {actual_arg} "
+ f"({self.pretty_typename(type(actual_arg))}){self.get_tvar_info_str(tvar_info)}"
f"In function {func_name}: Expected {self.pretty_typename(expected_t)} for "
+ ("return value, " if arg_name is None else f"argument '{arg_name}', ")
+ f"got {actual_arg} ({self.pretty_typename(type(actual_arg))}){self.get_tvar_info_str(tvar_info)}"
)

def __reduce__(self):
return (ArgTypeMismatchError, (self._expected_t, self._actual_arg, self._arg_name, self._tvar_info))
return (
ArgTypeMismatchError,
(self._func_name, self._expected_t, self._actual_arg, self._arg_name, self._tvar_info),
)


class ArgContainerMismatchError(TypeMismatchError):
def __init__(self, expected_t, actual_arg, arg_name, tvar_info=None):
self._expected_t, self._actual_arg, self._arg_name = expected_t, actual_arg, arg_name
def __init__(self, func_name, expected_t, actual_arg, arg_name, tvar_info=None):
self._func_name, self._expected_t, self._actual_arg, self._arg_name = (
func_name,
expected_t,
actual_arg,
arg_name,
)
super().__init__(
f"Expected {self.pretty_typename(expected_t)} for argument '{arg_name}', got {actual_arg} "
f"In function {func_name}: Expected {self.pretty_typename(expected_t)} for argument '{arg_name}', got {actual_arg} "
+ "instead of generic container type specification"
)

def __reduce__(self):
return (ArgContainerMismatchError, (self._expected_t, self._actual_arg, self._arg_name))
return (ArgContainerMismatchError, (self._func_name, self._expected_t, self._actual_arg, self._arg_name))


class TSArgTypeMismatchError(TypeMismatchError):
def __init__(self, expected_t, actual_arg_type, arg_name, tvar_info=None):
self._expected_t, self._actual_arg_type, self._arg_name, self._tvar_info = (
def __init__(self, func_name, expected_t, actual_arg_type, arg_name, tvar_info=None):
self._func_name, self._expected_t, self._actual_arg_type, self._arg_name, self._tvar_info = (
func_name,
expected_t,
actual_arg_type,
arg_name,
tvar_info,
)
actual_type_str = f"ts[{self.pretty_typename(actual_arg_type)}]" if actual_arg_type else "None"

super().__init__(
f"Expected ts[{self.pretty_typename(expected_t)}] for argument '{arg_name}',"
+ f" got {actual_type_str}{self.get_tvar_info_str(tvar_info)}"
f"In function {func_name}: Expected ts[{self.pretty_typename(expected_t)}] for "
+ ("return value, " if arg_name is None else f"argument '{arg_name}', ")
+ f"got {actual_type_str}{self.get_tvar_info_str(tvar_info)}"
)

def __reduce__(self):
return (TSArgTypeMismatchError, (self._expected_t, self._actual_arg_type, self._arg_name, self._tvar_info))
return (
TSArgTypeMismatchError,
(self._func_name, self._expected_t, self._actual_arg_type, self._arg_name, self._tvar_info),
)


class TSDictBasketKeyMismatchError(TypeMismatchError):
def __init__(self, expected_t, arg_name):
self._expected_t, self._arg_name = expected_t, arg_name
super().__init__(f"Expected ts[{self.pretty_typename(expected_t)}] for argument {arg_name} must have str keys ")
def __init__(self, func_name, expected_t, arg_name):
self._func_name, self._expected_t, self._arg_name = func_name, expected_t, arg_name
super().__init__(
f"In function {func_name}: Expected ts[{self.pretty_typename(expected_t)}] for argument {arg_name} must have str keys "
)

def __reduce__(self):
return (TSDictBasketKeyMismatchError, (self._expected_t, self._arg_name))
return (TSDictBasketKeyMismatchError, (self._func_name, self._expected_t, self._arg_name))


class NestedTsTypeError:
Expand Down Expand Up @@ -316,13 +334,18 @@ def _raise_arg_mismatch_error(self, arg=None, tvar_info=None):
if arg is not None:
if not isinstance(arg, Edge):
raise ArgTypeMismatchError(
expected_t=self._cur_def.typ, actual_arg=arg, arg_name=self._cur_def.name, tvar_info=tvar_info
func_name=self._function_name,
expected_t=self._cur_def.typ,
actual_arg=arg,
arg_name=self._cur_def.name,
tvar_info=tvar_info,
)
if isTsDynamicBasket(arg.tstype):
arg_type = arg.tstype
else:
arg_type = arg.tstype.typ
raise TSArgTypeMismatchError(
func_name=self._function_name,
expected_t=self._cur_def.typ.typ,
actual_arg_type=arg_type,
arg_name=self._cur_def.name,
Expand All @@ -337,7 +360,11 @@ def _raise_arg_mismatch_error(self, arg=None, tvar_info=None):
else:
expected_type = typing.List[self._cur_def.typ]
raise ArgTypeMismatchError(
expected_t=expected_type, actual_arg=arg, arg_name=self._cur_def.name, tvar_info=tvar_info
func_name=self._function_name,
expected_t=expected_type,
actual_arg=arg,
arg_name=self._cur_def.name,
tvar_info=tvar_info,
)

def _add_scalar_value(self, arg, in_out_def):
Expand Down Expand Up @@ -529,12 +556,14 @@ def _resolve_tvar_container_internal_types(self, tvar, container_typ, arg, raise
# list
if arg is container_typ:
if raise_on_error:
raise ArgContainerMismatchError(expected_t=tvar, actual_arg=arg, arg_name=self._cur_def.name)
raise ArgContainerMismatchError(
func_name=self._function_name, expected_t=tvar, actual_arg=arg, arg_name=self._cur_def.name
)
else:
return False
if len(arg) == 0:
if raise_on_error:
raise ContainerTypeVarResolutionError(tvar, arg)
raise ContainerTypeVarResolutionError(self._function_name, tvar, arg)
else:
return None
res = None
Expand All @@ -561,7 +590,7 @@ def _resolve_tvar_container_internal_types(self, tvar, container_typ, arg, raise
if first_key_t and first_val_t:
res = typing.Dict[first_key_t, first_val_t]
if not res and raise_on_error:
raise ContainerTypeVarResolutionError(tvar, arg)
raise ContainerTypeVarResolutionError(self._function_name, tvar, arg)
return res

def _try_resolve_tvar_conflicts(self):
Expand Down Expand Up @@ -701,7 +730,7 @@ def _add_list_basket_ts_value(self, args, in_out_def):
if in_out_def.kind == ArgKind.BASKET_TS and in_out_def.shape is not None:
if len(args) != in_out_def.shape:
raise RuntimeError(
f"Expected output shape for output {in_out_def.name} is of length {in_out_def.shape}, actual length is {len(args)}"
f"In function {self._function_name}: Expected output shape for output {in_out_def.name} is of length {in_out_def.shape}, actual length is {len(args)}"
)
expected_ts_type = in_out_def.typ.typ
for value in args:
Expand All @@ -716,13 +745,13 @@ def _add_dict_basket_ts_value(self, args, in_out_def):
if in_out_def.kind == ArgKind.BASKET_TS and in_out_def.shape is not None:
if len(args) != len(in_out_def.shape):
raise RuntimeError(
f"Expected output shape for output {in_out_def.name} is of length {len(in_out_def.shape)}, actual length is {len(args)}"
f"In function {self._function_name}: Expected output shape for output {in_out_def.name} is of length {len(in_out_def.shape)}, actual length is {len(args)}"
)

for k in in_out_def.shape:
if k not in args:
raise RuntimeError(
f"Expected key {k} for output {in_out_def.name} is missing from the actual returned value"
f"In function {self._function_name}: Expected key {k} for output {in_out_def.name} is missing from the actual returned value"
)

expected_ts_type = in_out_def.typ.typ
Expand Down
2 changes: 1 addition & 1 deletion csp/tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -2251,7 +2251,7 @@
# return numpy.zeros(1, dtype=float)

# with _GraphTempCacheFolderConfig() as config:
# with self.assertRaisesRegex(TSArgTypeMismatchError, re.escape("Expected ts[csp.typing.Numpy1DArray[int]] for argument 'None', got ts[csp.typing.Numpy1DArray[float]]")):
# with self.assertRaisesRegex(TSArgTypeMismatchError, re.escape("In function g1: Expected ts[csp.typing.Numpy1DArray[int]] for return value, got ts[csp.typing.Numpy1DArray[float]]")):
# csp.run(g1, starttime=datetime(2020, 1, 1), endtime=timedelta(minutes=20), config=config)

# with _GraphTempCacheFolderConfig() as config:
Expand Down
43 changes: 37 additions & 6 deletions csp/tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,6 @@ def graph():

def test_bugreport_csp28(self):
"""bug where non-basket inputs after basket inputs were not being assigne dproperly in c++"""

@csp.node
def buggy(basket: [ts[int]], x: ts[bool]) -> ts[bool]:
if csp.ticked(x) and csp.valid(x):
Expand Down Expand Up @@ -861,14 +860,14 @@ def graph():
fb = csp.feedback(int)
with self.assertRaisesRegex(
TypeError,
re.escape(r"""Expected csp.impl.types.tstype.TsType[""")
re.escape(r"""In function _bind: Expected csp.impl.types.tstype.TsType[""")
+ ".*"
+ re.escape(r"""('T')] for argument 'x', got 1 (int)"""),
):
fb.bind(1)

with self.assertRaisesRegex(
TypeError, re.escape(r"""Expected ts[T] for argument 'x', got ts[str](T=int)""")
TypeError, re.escape(r"""In function _bind: Expected ts[T] for argument 'x', got ts[str](T=int)""")
):
fb.bind(csp.const("123"))

Expand Down Expand Up @@ -968,7 +967,7 @@ def graph():
)
self.assertTrue(__file__ in traceback_list[-1])
self.assertLessEqual(len(traceback_list), 10)
self.assertEqual(str(e), "Expected ts[T] for argument 'my_arg', got None")
self.assertEqual(str(e), "In function aux: Expected ts[T] for argument 'my_arg', got None")

def test_union_type_check(self):
'''was a bug "Add support for typing.Union in type checking layer"'''
Expand All @@ -981,7 +980,8 @@ def graph(x: typing.Union[int, float, str]):
build_graph(graph, 1.1)
build_graph(graph, "s")
with self.assertRaisesRegex(
TypeError, "Expected typing.Union\\[int, float, str\\] for argument 'x', got \\[1.1\\] \\(list\\)"
TypeError,
"In function graph: Expected typing.Union\\[int, float, str\\] for argument 'x', got \\[1.1\\] \\(list\\)",
):
build_graph(graph, [1.1])

Expand All @@ -994,7 +994,7 @@ def graph(x: ts[typing.Union[int, float, str]]):
build_graph(graph, csp.const("s"))
with self.assertRaisesRegex(
TypeError,
"Expected ts\\[typing.Union\\[int, float, str\\]\\] for argument 'x', got ts\\[typing.List\\[float\\]\\]",
"In function graph: Expected ts\\[typing.Union\\[int, float, str\\]\\] for argument 'x', got ts\\[typing.List\\[float\\]\\]",
):
build_graph(graph, csp.const([1.1]))

Expand Down Expand Up @@ -1548,6 +1548,37 @@ def main(use_graph: bool, pass_null: bool) -> csp.Outputs(o=csp.ts[int]):
endtime=timedelta(seconds=10),
)

def test_return_arg_mismatch(self):
@csp.graph
def my_graph(x: csp.ts[int]) -> csp.ts[str]:
return x

with self.assertRaises(TSArgTypeMismatchError) as ctxt:
csp.run(my_graph, csp.const(1), starttime=datetime.utcnow())
self.assertEqual(str(ctxt.exception), "In function my_graph: Expected ts[str] for return value, got ts[int]")

@csp.graph
def dictbasket_graph(x: csp.ts[int]) -> {str: csp.ts[str]}:
return csp.output({"a": x})

with self.assertRaises(ArgTypeMismatchError) as ctxt:
csp.run(dictbasket_graph, csp.const(1), starttime=datetime.utcnow())
self.assertRegex(
str(ctxt.exception),
"In function dictbasket_graph: Expected typing\.Dict\[str, .* for return value, got \{'a': .* \(dict\)",
)

@csp.graph
def listbasket_graph(x: csp.ts[int]) -> [csp.ts[str]]:
return csp.output([x])

with self.assertRaises(ArgTypeMismatchError) as ctxt:
csp.run(listbasket_graph, csp.const(1), starttime=datetime.utcnow())
self.assertRegex(
str(ctxt.exception),
"In function listbasket_graph: Expected typing\.List\[.* for return value, got \[.* \(list\)",
)

def test_global_context(self):
try:

Expand Down
10 changes: 5 additions & 5 deletions csp/tests/test_type_checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,11 +601,11 @@ def foo(arr: csp.ts[np.ndarray]) -> csp.ts[np.ndarray]:

def test_pickle_type_resolver_errors(self):
errors = [
type_resolver.ContainerTypeVarResolutionError("T", "NotT"),
type_resolver.ArgTypeMismatchError("T", "NotT", "Var", {"field": 1}),
type_resolver.ArgContainerMismatchError("T", "NotT", "Var"),
type_resolver.TSArgTypeMismatchError("T", "NotT", "Var"),
type_resolver.TSDictBasketKeyMismatchError("T", "Var"),
type_resolver.ContainerTypeVarResolutionError("g", "T", "NotT"),
type_resolver.ArgTypeMismatchError("g", "T", "NotT", "Var", {"field": 1}),
type_resolver.ArgContainerMismatchError("g", "T", "NotT", "Var"),
type_resolver.TSArgTypeMismatchError("g", "T", "NotT", "Var"),
type_resolver.TSDictBasketKeyMismatchError("g", "T", "Var"),
]

for err in errors:
Expand Down

0 comments on commit 6d3a8bf

Please sign in to comment.