Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function names to type resolver errors and properly report when the return value is the conflicting type #41

Merged
merged 1 commit into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 59 additions & 30 deletions csp/impl/types/instantiation_type_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@ def instance(cls):


class ContainerTypeVarResolutionError(TypeError):
def __init__(self, tvar, tvar_value):
self._tvar, self._tvar_value = tvar, tvar_value
def __init__(self, func_name, tvar, tvar_value):
self._func_name, self._tvar, self._tvar_value = func_name, tvar, tvar_value
super().__init__(
f"Unable to resolve container type for type variable {tvar} explicit value must have"
f"In function {func_name}: Unable to resolve container type for type variable {tvar} explicit value must have"
+ f" uniform values and be non empty, got: {tvar_value} "
)

def __reduce__(self):
return (ContainerTypeVarResolutionError, (self._tvar, self._tvar_value))
return (ContainerTypeVarResolutionError, (self._func_name, self._tvar, self._tvar_value))


class TypeMismatchError(TypeError):
Expand All @@ -117,59 +117,77 @@ def get_tvar_info_str(cls, tvar_info):


class ArgTypeMismatchError(TypeMismatchError):
def __init__(self, expected_t, actual_arg, arg_name, tvar_info=None):
self._expected_t, self._actual_arg, self._arg_name, self._tvar_info = (
def __init__(self, func_name, expected_t, actual_arg, arg_name, tvar_info=None):
self._func_name, self._expected_t, self._actual_arg, self._arg_name, self._tvar_info = (
func_name,
expected_t,
actual_arg,
arg_name,
tvar_info,
)
super().__init__(
f"Expected {self.pretty_typename(expected_t)} for argument '{arg_name}', got {actual_arg} "
+ f"({self.pretty_typename(type(actual_arg))}){self.get_tvar_info_str(tvar_info)}"
f"In function {func_name}: Expected {self.pretty_typename(expected_t)} for "
+ ("return value, " if arg_name is None else f"argument '{arg_name}', ")
+ f"got {actual_arg} ({self.pretty_typename(type(actual_arg))}){self.get_tvar_info_str(tvar_info)}"
)

def __reduce__(self):
return (ArgTypeMismatchError, (self._expected_t, self._actual_arg, self._arg_name, self._tvar_info))
return (
ArgTypeMismatchError,
(self._func_name, self._expected_t, self._actual_arg, self._arg_name, self._tvar_info),
)


class ArgContainerMismatchError(TypeMismatchError):
def __init__(self, expected_t, actual_arg, arg_name, tvar_info=None):
self._expected_t, self._actual_arg, self._arg_name = expected_t, actual_arg, arg_name
def __init__(self, func_name, expected_t, actual_arg, arg_name, tvar_info=None):
self._func_name, self._expected_t, self._actual_arg, self._arg_name = (
func_name,
expected_t,
actual_arg,
arg_name,
)
super().__init__(
f"Expected {self.pretty_typename(expected_t)} for argument '{arg_name}', got {actual_arg} "
f"In function {func_name}: Expected {self.pretty_typename(expected_t)} for argument '{arg_name}', got {actual_arg} "
+ "instead of generic container type specification"
)

def __reduce__(self):
return (ArgContainerMismatchError, (self._expected_t, self._actual_arg, self._arg_name))
return (ArgContainerMismatchError, (self._func_name, self._expected_t, self._actual_arg, self._arg_name))


class TSArgTypeMismatchError(TypeMismatchError):
def __init__(self, expected_t, actual_arg_type, arg_name, tvar_info=None):
self._expected_t, self._actual_arg_type, self._arg_name, self._tvar_info = (
def __init__(self, func_name, expected_t, actual_arg_type, arg_name, tvar_info=None):
self._func_name, self._expected_t, self._actual_arg_type, self._arg_name, self._tvar_info = (
func_name,
expected_t,
actual_arg_type,
arg_name,
tvar_info,
)
actual_type_str = f"ts[{self.pretty_typename(actual_arg_type)}]" if actual_arg_type else "None"

super().__init__(
f"Expected ts[{self.pretty_typename(expected_t)}] for argument '{arg_name}',"
+ f" got {actual_type_str}{self.get_tvar_info_str(tvar_info)}"
f"In function {func_name}: Expected ts[{self.pretty_typename(expected_t)}] for "
+ ("return value, " if arg_name is None else f"argument '{arg_name}', ")
+ f"got {actual_type_str}{self.get_tvar_info_str(tvar_info)}"
)

def __reduce__(self):
return (TSArgTypeMismatchError, (self._expected_t, self._actual_arg_type, self._arg_name, self._tvar_info))
return (
TSArgTypeMismatchError,
(self._func_name, self._expected_t, self._actual_arg_type, self._arg_name, self._tvar_info),
)


class TSDictBasketKeyMismatchError(TypeMismatchError):
def __init__(self, expected_t, arg_name):
self._expected_t, self._arg_name = expected_t, arg_name
super().__init__(f"Expected ts[{self.pretty_typename(expected_t)}] for argument {arg_name} must have str keys ")
def __init__(self, func_name, expected_t, arg_name):
self._func_name, self._expected_t, self._arg_name = func_name, expected_t, arg_name
super().__init__(
f"In function {func_name}: Expected ts[{self.pretty_typename(expected_t)}] for argument {arg_name} must have str keys "
)

def __reduce__(self):
return (TSDictBasketKeyMismatchError, (self._expected_t, self._arg_name))
return (TSDictBasketKeyMismatchError, (self._func_name, self._expected_t, self._arg_name))


class NestedTsTypeError:
Expand Down Expand Up @@ -316,13 +334,18 @@ def _raise_arg_mismatch_error(self, arg=None, tvar_info=None):
if arg is not None:
if not isinstance(arg, Edge):
raise ArgTypeMismatchError(
expected_t=self._cur_def.typ, actual_arg=arg, arg_name=self._cur_def.name, tvar_info=tvar_info
func_name=self._function_name,
expected_t=self._cur_def.typ,
actual_arg=arg,
arg_name=self._cur_def.name,
tvar_info=tvar_info,
)
if isTsDynamicBasket(arg.tstype):
arg_type = arg.tstype
else:
arg_type = arg.tstype.typ
raise TSArgTypeMismatchError(
func_name=self._function_name,
expected_t=self._cur_def.typ.typ,
actual_arg_type=arg_type,
arg_name=self._cur_def.name,
Expand All @@ -337,7 +360,11 @@ def _raise_arg_mismatch_error(self, arg=None, tvar_info=None):
else:
expected_type = typing.List[self._cur_def.typ]
raise ArgTypeMismatchError(
expected_t=expected_type, actual_arg=arg, arg_name=self._cur_def.name, tvar_info=tvar_info
func_name=self._function_name,
expected_t=expected_type,
actual_arg=arg,
arg_name=self._cur_def.name,
tvar_info=tvar_info,
)

def _add_scalar_value(self, arg, in_out_def):
Expand Down Expand Up @@ -529,12 +556,14 @@ def _resolve_tvar_container_internal_types(self, tvar, container_typ, arg, raise
# list
if arg is container_typ:
if raise_on_error:
raise ArgContainerMismatchError(expected_t=tvar, actual_arg=arg, arg_name=self._cur_def.name)
raise ArgContainerMismatchError(
func_name=self._function_name, expected_t=tvar, actual_arg=arg, arg_name=self._cur_def.name
)
else:
return False
if len(arg) == 0:
if raise_on_error:
raise ContainerTypeVarResolutionError(tvar, arg)
raise ContainerTypeVarResolutionError(self._function_name, tvar, arg)
else:
return None
res = None
Expand All @@ -561,7 +590,7 @@ def _resolve_tvar_container_internal_types(self, tvar, container_typ, arg, raise
if first_key_t and first_val_t:
res = typing.Dict[first_key_t, first_val_t]
if not res and raise_on_error:
raise ContainerTypeVarResolutionError(tvar, arg)
raise ContainerTypeVarResolutionError(self._function_name, tvar, arg)
return res

def _try_resolve_tvar_conflicts(self):
Expand Down Expand Up @@ -701,7 +730,7 @@ def _add_list_basket_ts_value(self, args, in_out_def):
if in_out_def.kind == ArgKind.BASKET_TS and in_out_def.shape is not None:
if len(args) != in_out_def.shape:
raise RuntimeError(
f"Expected output shape for output {in_out_def.name} is of length {in_out_def.shape}, actual length is {len(args)}"
f"In function {self._function_name}: Expected output shape for output {in_out_def.name} is of length {in_out_def.shape}, actual length is {len(args)}"
)
expected_ts_type = in_out_def.typ.typ
for value in args:
Expand All @@ -716,13 +745,13 @@ def _add_dict_basket_ts_value(self, args, in_out_def):
if in_out_def.kind == ArgKind.BASKET_TS and in_out_def.shape is not None:
if len(args) != len(in_out_def.shape):
raise RuntimeError(
f"Expected output shape for output {in_out_def.name} is of length {len(in_out_def.shape)}, actual length is {len(args)}"
f"In function {self._function_name}: Expected output shape for output {in_out_def.name} is of length {len(in_out_def.shape)}, actual length is {len(args)}"
)

for k in in_out_def.shape:
if k not in args:
raise RuntimeError(
f"Expected key {k} for output {in_out_def.name} is missing from the actual returned value"
f"In function {self._function_name}: Expected key {k} for output {in_out_def.name} is missing from the actual returned value"
)

expected_ts_type = in_out_def.typ.typ
Expand Down
2 changes: 1 addition & 1 deletion csp/tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -2251,7 +2251,7 @@
# return numpy.zeros(1, dtype=float)

# with _GraphTempCacheFolderConfig() as config:
# with self.assertRaisesRegex(TSArgTypeMismatchError, re.escape("Expected ts[csp.typing.Numpy1DArray[int]] for argument 'None', got ts[csp.typing.Numpy1DArray[float]]")):
# with self.assertRaisesRegex(TSArgTypeMismatchError, re.escape("In function g1: Expected ts[csp.typing.Numpy1DArray[int]] for return value, got ts[csp.typing.Numpy1DArray[float]]")):
# csp.run(g1, starttime=datetime(2020, 1, 1), endtime=timedelta(minutes=20), config=config)

# with _GraphTempCacheFolderConfig() as config:
Expand Down
43 changes: 37 additions & 6 deletions csp/tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,6 @@ def graph():

def test_bugreport_csp28(self):
"""bug where non-basket inputs after basket inputs were not being assigne dproperly in c++"""

@csp.node
def buggy(basket: [ts[int]], x: ts[bool]) -> ts[bool]:
if csp.ticked(x) and csp.valid(x):
Expand Down Expand Up @@ -861,14 +860,14 @@ def graph():
fb = csp.feedback(int)
with self.assertRaisesRegex(
TypeError,
re.escape(r"""Expected csp.impl.types.tstype.TsType[""")
re.escape(r"""In function _bind: Expected csp.impl.types.tstype.TsType[""")
+ ".*"
+ re.escape(r"""('T')] for argument 'x', got 1 (int)"""),
):
fb.bind(1)

with self.assertRaisesRegex(
TypeError, re.escape(r"""Expected ts[T] for argument 'x', got ts[str](T=int)""")
TypeError, re.escape(r"""In function _bind: Expected ts[T] for argument 'x', got ts[str](T=int)""")
):
fb.bind(csp.const("123"))

Expand Down Expand Up @@ -968,7 +967,7 @@ def graph():
)
self.assertTrue(__file__ in traceback_list[-1])
self.assertLessEqual(len(traceback_list), 10)
self.assertEqual(str(e), "Expected ts[T] for argument 'my_arg', got None")
self.assertEqual(str(e), "In function aux: Expected ts[T] for argument 'my_arg', got None")

def test_union_type_check(self):
'''was a bug "Add support for typing.Union in type checking layer"'''
Expand All @@ -981,7 +980,8 @@ def graph(x: typing.Union[int, float, str]):
build_graph(graph, 1.1)
build_graph(graph, "s")
with self.assertRaisesRegex(
TypeError, "Expected typing.Union\\[int, float, str\\] for argument 'x', got \\[1.1\\] \\(list\\)"
TypeError,
"In function graph: Expected typing.Union\\[int, float, str\\] for argument 'x', got \\[1.1\\] \\(list\\)",
):
build_graph(graph, [1.1])

Expand All @@ -994,7 +994,7 @@ def graph(x: ts[typing.Union[int, float, str]]):
build_graph(graph, csp.const("s"))
with self.assertRaisesRegex(
TypeError,
"Expected ts\\[typing.Union\\[int, float, str\\]\\] for argument 'x', got ts\\[typing.List\\[float\\]\\]",
"In function graph: Expected ts\\[typing.Union\\[int, float, str\\]\\] for argument 'x', got ts\\[typing.List\\[float\\]\\]",
):
build_graph(graph, csp.const([1.1]))

Expand Down Expand Up @@ -1548,6 +1548,37 @@ def main(use_graph: bool, pass_null: bool) -> csp.Outputs(o=csp.ts[int]):
endtime=timedelta(seconds=10),
)

def test_return_arg_mismatch(self):
@csp.graph
def my_graph(x: csp.ts[int]) -> csp.ts[str]:
return x

with self.assertRaises(TSArgTypeMismatchError) as ctxt:
csp.run(my_graph, csp.const(1), starttime=datetime.utcnow())
self.assertEqual(str(ctxt.exception), "In function my_graph: Expected ts[str] for return value, got ts[int]")

@csp.graph
def dictbasket_graph(x: csp.ts[int]) -> {str: csp.ts[str]}:
return csp.output({"a": x})

with self.assertRaises(ArgTypeMismatchError) as ctxt:
csp.run(dictbasket_graph, csp.const(1), starttime=datetime.utcnow())
self.assertRegex(
str(ctxt.exception),
"In function dictbasket_graph: Expected typing\.Dict\[str, .* for return value, got \{'a': .* \(dict\)",
)

@csp.graph
def listbasket_graph(x: csp.ts[int]) -> [csp.ts[str]]:
return csp.output([x])

with self.assertRaises(ArgTypeMismatchError) as ctxt:
csp.run(listbasket_graph, csp.const(1), starttime=datetime.utcnow())
self.assertRegex(
str(ctxt.exception),
"In function listbasket_graph: Expected typing\.List\[.* for return value, got \[.* \(list\)",
)

def test_global_context(self):
try:

Expand Down
10 changes: 5 additions & 5 deletions csp/tests/test_type_checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,11 +601,11 @@ def foo(arr: csp.ts[np.ndarray]) -> csp.ts[np.ndarray]:

def test_pickle_type_resolver_errors(self):
errors = [
type_resolver.ContainerTypeVarResolutionError("T", "NotT"),
type_resolver.ArgTypeMismatchError("T", "NotT", "Var", {"field": 1}),
type_resolver.ArgContainerMismatchError("T", "NotT", "Var"),
type_resolver.TSArgTypeMismatchError("T", "NotT", "Var"),
type_resolver.TSDictBasketKeyMismatchError("T", "Var"),
type_resolver.ContainerTypeVarResolutionError("g", "T", "NotT"),
type_resolver.ArgTypeMismatchError("g", "T", "NotT", "Var", {"field": 1}),
type_resolver.ArgContainerMismatchError("g", "T", "NotT", "Var"),
type_resolver.TSArgTypeMismatchError("g", "T", "NotT", "Var"),
type_resolver.TSDictBasketKeyMismatchError("g", "T", "Var"),
]

for err in errors:
Expand Down