Skip to content

Commit

Permalink
Add function names to type resolver errors and properly report when t…
Browse files Browse the repository at this point in the history
…he return value is the conflicting type

Signed-off-by: Tim Paine <[email protected]>
  • Loading branch information
AdamGlustein authored and timkpaine committed Feb 7, 2024
1 parent f4008ce commit f976c3a
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 42 deletions.
89 changes: 59 additions & 30 deletions csp/impl/types/instantiation_type_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@ def instance(cls):


class ContainerTypeVarResolutionError(TypeError):
def __init__(self, tvar, tvar_value):
self._tvar, self._tvar_value = tvar, tvar_value
def __init__(self, func_name, tvar, tvar_value):
self._func_name, self._tvar, self._tvar_value = func_name, tvar, tvar_value
super().__init__(
f"Unable to resolve container type for type variable {tvar} explicit value must have"
f"In function {func_name}: Unable to resolve container type for type variable {tvar} explicit value must have"
+ f" uniform values and be non empty, got: {tvar_value} "
)

def __reduce__(self):
return (ContainerTypeVarResolutionError, (self._tvar, self._tvar_value))
return (ContainerTypeVarResolutionError, (self._func_name, self._tvar, self._tvar_value))


class TypeMismatchError(TypeError):
Expand All @@ -117,59 +117,77 @@ def get_tvar_info_str(cls, tvar_info):


class ArgTypeMismatchError(TypeMismatchError):
def __init__(self, expected_t, actual_arg, arg_name, tvar_info=None):
self._expected_t, self._actual_arg, self._arg_name, self._tvar_info = (
def __init__(self, func_name, expected_t, actual_arg, arg_name, tvar_info=None):
self._func_name, self._expected_t, self._actual_arg, self._arg_name, self._tvar_info = (
func_name,
expected_t,
actual_arg,
arg_name,
tvar_info,
)
super().__init__(
f"Expected {self.pretty_typename(expected_t)} for argument '{arg_name}', got {actual_arg} "
+ f"({self.pretty_typename(type(actual_arg))}){self.get_tvar_info_str(tvar_info)}"
f"In function {func_name}: Expected {self.pretty_typename(expected_t)} for "
+ ("return value, " if arg_name is None else f"argument '{arg_name}', ")
+ f"got {actual_arg} ({self.pretty_typename(type(actual_arg))}){self.get_tvar_info_str(tvar_info)}"
)

def __reduce__(self):
return (ArgTypeMismatchError, (self._expected_t, self._actual_arg, self._arg_name, self._tvar_info))
return (
ArgTypeMismatchError,
(self._func_name, self._expected_t, self._actual_arg, self._arg_name, self._tvar_info),
)


class ArgContainerMismatchError(TypeMismatchError):
def __init__(self, expected_t, actual_arg, arg_name, tvar_info=None):
self._expected_t, self._actual_arg, self._arg_name = expected_t, actual_arg, arg_name
def __init__(self, func_name, expected_t, actual_arg, arg_name, tvar_info=None):
self._func_name, self._expected_t, self._actual_arg, self._arg_name = (
func_name,
expected_t,
actual_arg,
arg_name,
)
super().__init__(
f"Expected {self.pretty_typename(expected_t)} for argument '{arg_name}', got {actual_arg} "
f"In function {func_name}: Expected {self.pretty_typename(expected_t)} for argument '{arg_name}', got {actual_arg} "
+ "instead of generic container type specification"
)

def __reduce__(self):
return (ArgContainerMismatchError, (self._expected_t, self._actual_arg, self._arg_name))
return (ArgContainerMismatchError, (self._func_name, self._expected_t, self._actual_arg, self._arg_name))


class TSArgTypeMismatchError(TypeMismatchError):
def __init__(self, expected_t, actual_arg_type, arg_name, tvar_info=None):
self._expected_t, self._actual_arg_type, self._arg_name, self._tvar_info = (
def __init__(self, func_name, expected_t, actual_arg_type, arg_name, tvar_info=None):
self._func_name, self._expected_t, self._actual_arg_type, self._arg_name, self._tvar_info = (
func_name,
expected_t,
actual_arg_type,
arg_name,
tvar_info,
)
actual_type_str = f"ts[{self.pretty_typename(actual_arg_type)}]" if actual_arg_type else "None"

super().__init__(
f"Expected ts[{self.pretty_typename(expected_t)}] for argument '{arg_name}',"
+ f" got {actual_type_str}{self.get_tvar_info_str(tvar_info)}"
f"In function {func_name}: Expected ts[{self.pretty_typename(expected_t)}] for "
+ ("return value, " if arg_name is None else f"argument '{arg_name}', ")
+ f"got {actual_type_str}{self.get_tvar_info_str(tvar_info)}"
)

def __reduce__(self):
return (TSArgTypeMismatchError, (self._expected_t, self._actual_arg_type, self._arg_name, self._tvar_info))
return (
TSArgTypeMismatchError,
(self._func_name, self._expected_t, self._actual_arg_type, self._arg_name, self._tvar_info),
)


class TSDictBasketKeyMismatchError(TypeMismatchError):
def __init__(self, expected_t, arg_name):
self._expected_t, self._arg_name = expected_t, arg_name
super().__init__(f"Expected ts[{self.pretty_typename(expected_t)}] for argument {arg_name} must have str keys ")
def __init__(self, func_name, expected_t, arg_name):
self._func_name, self._expected_t, self._arg_name = func_name, expected_t, arg_name
super().__init__(
f"In function {func_name}: Expected ts[{self.pretty_typename(expected_t)}] for argument {arg_name} must have str keys "
)

def __reduce__(self):
return (TSDictBasketKeyMismatchError, (self._expected_t, self._arg_name))
return (TSDictBasketKeyMismatchError, (self._func_name, self._expected_t, self._arg_name))


class NestedTsTypeError:
Expand Down Expand Up @@ -316,13 +334,18 @@ def _raise_arg_mismatch_error(self, arg=None, tvar_info=None):
if arg is not None:
if not isinstance(arg, Edge):
raise ArgTypeMismatchError(
expected_t=self._cur_def.typ, actual_arg=arg, arg_name=self._cur_def.name, tvar_info=tvar_info
func_name=self._function_name,
expected_t=self._cur_def.typ,
actual_arg=arg,
arg_name=self._cur_def.name,
tvar_info=tvar_info,
)
if isTsDynamicBasket(arg.tstype):
arg_type = arg.tstype
else:
arg_type = arg.tstype.typ
raise TSArgTypeMismatchError(
func_name=self._function_name,
expected_t=self._cur_def.typ.typ,
actual_arg_type=arg_type,
arg_name=self._cur_def.name,
Expand All @@ -337,7 +360,11 @@ def _raise_arg_mismatch_error(self, arg=None, tvar_info=None):
else:
expected_type = typing.List[self._cur_def.typ]
raise ArgTypeMismatchError(
expected_t=expected_type, actual_arg=arg, arg_name=self._cur_def.name, tvar_info=tvar_info
func_name=self._function_name,
expected_t=expected_type,
actual_arg=arg,
arg_name=self._cur_def.name,
tvar_info=tvar_info,
)

def _add_scalar_value(self, arg, in_out_def):
Expand Down Expand Up @@ -529,12 +556,14 @@ def _resolve_tvar_container_internal_types(self, tvar, container_typ, arg, raise
# list
if arg is container_typ:
if raise_on_error:
raise ArgContainerMismatchError(expected_t=tvar, actual_arg=arg, arg_name=self._cur_def.name)
raise ArgContainerMismatchError(
func_name=self._function_name, expected_t=tvar, actual_arg=arg, arg_name=self._cur_def.name
)
else:
return False
if len(arg) == 0:
if raise_on_error:
raise ContainerTypeVarResolutionError(tvar, arg)
raise ContainerTypeVarResolutionError(self._function_name, tvar, arg)
else:
return None
res = None
Expand All @@ -561,7 +590,7 @@ def _resolve_tvar_container_internal_types(self, tvar, container_typ, arg, raise
if first_key_t and first_val_t:
res = typing.Dict[first_key_t, first_val_t]
if not res and raise_on_error:
raise ContainerTypeVarResolutionError(tvar, arg)
raise ContainerTypeVarResolutionError(self._function_name, tvar, arg)
return res

def _try_resolve_tvar_conflicts(self):
Expand Down Expand Up @@ -701,7 +730,7 @@ def _add_list_basket_ts_value(self, args, in_out_def):
if in_out_def.kind == ArgKind.BASKET_TS and in_out_def.shape is not None:
if len(args) != in_out_def.shape:
raise RuntimeError(
f"Expected output shape for output {in_out_def.name} is of length {in_out_def.shape}, actual length is {len(args)}"
f"In function {self._function_name}: Expected output shape for output {in_out_def.name} is of length {in_out_def.shape}, actual length is {len(args)}"
)
expected_ts_type = in_out_def.typ.typ
for value in args:
Expand All @@ -716,13 +745,13 @@ def _add_dict_basket_ts_value(self, args, in_out_def):
if in_out_def.kind == ArgKind.BASKET_TS and in_out_def.shape is not None:
if len(args) != len(in_out_def.shape):
raise RuntimeError(
f"Expected output shape for output {in_out_def.name} is of length {len(in_out_def.shape)}, actual length is {len(args)}"
f"In function {self._function_name}: Expected output shape for output {in_out_def.name} is of length {len(in_out_def.shape)}, actual length is {len(args)}"
)

for k in in_out_def.shape:
if k not in args:
raise RuntimeError(
f"Expected key {k} for output {in_out_def.name} is missing from the actual returned value"
f"In function {self._function_name}: Expected key {k} for output {in_out_def.name} is missing from the actual returned value"
)

expected_ts_type = in_out_def.typ.typ
Expand Down
2 changes: 1 addition & 1 deletion csp/tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -2251,7 +2251,7 @@
# return numpy.zeros(1, dtype=float)

# with _GraphTempCacheFolderConfig() as config:
# with self.assertRaisesRegex(TSArgTypeMismatchError, re.escape("Expected ts[csp.typing.Numpy1DArray[int]] for argument 'None', got ts[csp.typing.Numpy1DArray[float]]")):
# with self.assertRaisesRegex(TSArgTypeMismatchError, re.escape("In function g1: Expected ts[csp.typing.Numpy1DArray[int]] for return value, got ts[csp.typing.Numpy1DArray[float]]")):
# csp.run(g1, starttime=datetime(2020, 1, 1), endtime=timedelta(minutes=20), config=config)

# with _GraphTempCacheFolderConfig() as config:
Expand Down
43 changes: 37 additions & 6 deletions csp/tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,6 @@ def graph():

def test_bugreport_csp28(self):
"""bug where non-basket inputs after basket inputs were not being assigne dproperly in c++"""

@csp.node
def buggy(basket: [ts[int]], x: ts[bool]) -> ts[bool]:
if csp.ticked(x) and csp.valid(x):
Expand Down Expand Up @@ -861,14 +860,14 @@ def graph():
fb = csp.feedback(int)
with self.assertRaisesRegex(
TypeError,
re.escape(r"""Expected csp.impl.types.tstype.TsType[""")
re.escape(r"""In function _bind: Expected csp.impl.types.tstype.TsType[""")
+ ".*"
+ re.escape(r"""('T')] for argument 'x', got 1 (int)"""),
):
fb.bind(1)

with self.assertRaisesRegex(
TypeError, re.escape(r"""Expected ts[T] for argument 'x', got ts[str](T=int)""")
TypeError, re.escape(r"""In function _bind: Expected ts[T] for argument 'x', got ts[str](T=int)""")
):
fb.bind(csp.const("123"))

Expand Down Expand Up @@ -968,7 +967,7 @@ def graph():
)
self.assertTrue(__file__ in traceback_list[-1])
self.assertLessEqual(len(traceback_list), 10)
self.assertEqual(str(e), "Expected ts[T] for argument 'my_arg', got None")
self.assertEqual(str(e), "In function aux: Expected ts[T] for argument 'my_arg', got None")

def test_union_type_check(self):
'''was a bug "Add support for typing.Union in type checking layer"'''
Expand All @@ -981,7 +980,8 @@ def graph(x: typing.Union[int, float, str]):
build_graph(graph, 1.1)
build_graph(graph, "s")
with self.assertRaisesRegex(
TypeError, "Expected typing.Union\\[int, float, str\\] for argument 'x', got \\[1.1\\] \\(list\\)"
TypeError,
"In function graph: Expected typing.Union\\[int, float, str\\] for argument 'x', got \\[1.1\\] \\(list\\)",
):
build_graph(graph, [1.1])

Expand All @@ -994,7 +994,7 @@ def graph(x: ts[typing.Union[int, float, str]]):
build_graph(graph, csp.const("s"))
with self.assertRaisesRegex(
TypeError,
"Expected ts\\[typing.Union\\[int, float, str\\]\\] for argument 'x', got ts\\[typing.List\\[float\\]\\]",
"In function graph: Expected ts\\[typing.Union\\[int, float, str\\]\\] for argument 'x', got ts\\[typing.List\\[float\\]\\]",
):
build_graph(graph, csp.const([1.1]))

Expand Down Expand Up @@ -1548,6 +1548,37 @@ def main(use_graph: bool, pass_null: bool) -> csp.Outputs(o=csp.ts[int]):
endtime=timedelta(seconds=10),
)

def test_return_arg_mismatch(self):
@csp.graph
def my_graph(x: csp.ts[int]) -> csp.ts[str]:
return x

with self.assertRaises(TSArgTypeMismatchError) as ctxt:
csp.run(my_graph, csp.const(1), starttime=datetime.utcnow())
self.assertEqual(str(ctxt.exception), "In function my_graph: Expected ts[str] for return value, got ts[int]")

@csp.graph
def dictbasket_graph(x: csp.ts[int]) -> {str: csp.ts[str]}:
return csp.output({"a": x})

with self.assertRaises(ArgTypeMismatchError) as ctxt:
csp.run(dictbasket_graph, csp.const(1), starttime=datetime.utcnow())
self.assertRegex(
str(ctxt.exception),
"In function dictbasket_graph: Expected typing\.Dict\[str, .* for return value, got \{'a': .* \(dict\)",
)

@csp.graph
def listbasket_graph(x: csp.ts[int]) -> [csp.ts[str]]:
return csp.output([x])

with self.assertRaises(ArgTypeMismatchError) as ctxt:
csp.run(listbasket_graph, csp.const(1), starttime=datetime.utcnow())
self.assertRegex(
str(ctxt.exception),
"In function listbasket_graph: Expected typing\.List\[.* for return value, got \[.* \(list\)",
)

def test_global_context(self):
try:

Expand Down
10 changes: 5 additions & 5 deletions csp/tests/test_type_checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,11 +601,11 @@ def foo(arr: csp.ts[np.ndarray]) -> csp.ts[np.ndarray]:

def test_pickle_type_resolver_errors(self):
errors = [
type_resolver.ContainerTypeVarResolutionError("T", "NotT"),
type_resolver.ArgTypeMismatchError("T", "NotT", "Var", {"field": 1}),
type_resolver.ArgContainerMismatchError("T", "NotT", "Var"),
type_resolver.TSArgTypeMismatchError("T", "NotT", "Var"),
type_resolver.TSDictBasketKeyMismatchError("T", "Var"),
type_resolver.ContainerTypeVarResolutionError("g", "T", "NotT"),
type_resolver.ArgTypeMismatchError("g", "T", "NotT", "Var", {"field": 1}),
type_resolver.ArgContainerMismatchError("g", "T", "NotT", "Var"),
type_resolver.TSArgTypeMismatchError("g", "T", "NotT", "Var"),
type_resolver.TSDictBasketKeyMismatchError("g", "T", "Var"),
]

for err in errors:
Expand Down

0 comments on commit f976c3a

Please sign in to comment.