From f976c3a5485c96fb89f5343e1ece3bb05aecb6af Mon Sep 17 00:00:00 2001 From: Adam G Date: Wed, 7 Feb 2024 09:15:05 -0500 Subject: [PATCH] Add function names to type resolver errors and properly report when the return value is the conflicting type Signed-off-by: Tim Paine --- csp/impl/types/instantiation_type_resolver.py | 89 ++++++++++++------- csp/tests/test_caching.py | 2 +- csp/tests/test_engine.py | 43 +++++++-- csp/tests/test_type_checking.py | 10 +-- 4 files changed, 102 insertions(+), 42 deletions(-) diff --git a/csp/impl/types/instantiation_type_resolver.py b/csp/impl/types/instantiation_type_resolver.py index 5a58bb8f5..22e5e66cd 100644 --- a/csp/impl/types/instantiation_type_resolver.py +++ b/csp/impl/types/instantiation_type_resolver.py @@ -85,15 +85,15 @@ def instance(cls): class ContainerTypeVarResolutionError(TypeError): - def __init__(self, tvar, tvar_value): - self._tvar, self._tvar_value = tvar, tvar_value + def __init__(self, func_name, tvar, tvar_value): + self._func_name, self._tvar, self._tvar_value = func_name, tvar, tvar_value super().__init__( - f"Unable to resolve container type for type variable {tvar} explicit value must have" + f"In function {func_name}: Unable to resolve container type for type variable {tvar} explicit value must have" + f" uniform values and be non empty, got: {tvar_value} " ) def __reduce__(self): - return (ContainerTypeVarResolutionError, (self._tvar, self._tvar_value)) + return (ContainerTypeVarResolutionError, (self._func_name, self._tvar, self._tvar_value)) class TypeMismatchError(TypeError): @@ -117,59 +117,77 @@ def get_tvar_info_str(cls, tvar_info): class ArgTypeMismatchError(TypeMismatchError): - def __init__(self, expected_t, actual_arg, arg_name, tvar_info=None): - self._expected_t, self._actual_arg, self._arg_name, self._tvar_info = ( + def __init__(self, func_name, expected_t, actual_arg, arg_name, tvar_info=None): + self._func_name, self._expected_t, self._actual_arg, self._arg_name, self._tvar_info = ( + func_name, expected_t, actual_arg, arg_name, tvar_info, ) super().__init__( - f"Expected {self.pretty_typename(expected_t)} for argument '{arg_name}', got {actual_arg} " - + f"({self.pretty_typename(type(actual_arg))}){self.get_tvar_info_str(tvar_info)}" + f"In function {func_name}: Expected {self.pretty_typename(expected_t)} for " + + ("return value, " if arg_name is None else f"argument '{arg_name}', ") + + f"got {actual_arg} ({self.pretty_typename(type(actual_arg))}){self.get_tvar_info_str(tvar_info)}" ) def __reduce__(self): - return (ArgTypeMismatchError, (self._expected_t, self._actual_arg, self._arg_name, self._tvar_info)) + return ( + ArgTypeMismatchError, + (self._func_name, self._expected_t, self._actual_arg, self._arg_name, self._tvar_info), + ) class ArgContainerMismatchError(TypeMismatchError): - def __init__(self, expected_t, actual_arg, arg_name, tvar_info=None): - self._expected_t, self._actual_arg, self._arg_name = expected_t, actual_arg, arg_name + def __init__(self, func_name, expected_t, actual_arg, arg_name, tvar_info=None): + self._func_name, self._expected_t, self._actual_arg, self._arg_name = ( + func_name, + expected_t, + actual_arg, + arg_name, + ) super().__init__( - f"Expected {self.pretty_typename(expected_t)} for argument '{arg_name}', got {actual_arg} " + f"In function {func_name}: Expected {self.pretty_typename(expected_t)} for argument '{arg_name}', got {actual_arg} " + "instead of generic container type specification" ) def __reduce__(self): - return (ArgContainerMismatchError, (self._expected_t, self._actual_arg, self._arg_name)) + return (ArgContainerMismatchError, (self._func_name, self._expected_t, self._actual_arg, self._arg_name)) class TSArgTypeMismatchError(TypeMismatchError): - def __init__(self, expected_t, actual_arg_type, arg_name, tvar_info=None): - self._expected_t, self._actual_arg_type, self._arg_name, self._tvar_info = ( + def __init__(self, func_name, expected_t, actual_arg_type, arg_name, tvar_info=None): + self._func_name, self._expected_t, self._actual_arg_type, self._arg_name, self._tvar_info = ( + func_name, expected_t, actual_arg_type, arg_name, tvar_info, ) actual_type_str = f"ts[{self.pretty_typename(actual_arg_type)}]" if actual_arg_type else "None" + super().__init__( - f"Expected ts[{self.pretty_typename(expected_t)}] for argument '{arg_name}'," - + f" got {actual_type_str}{self.get_tvar_info_str(tvar_info)}" + f"In function {func_name}: Expected ts[{self.pretty_typename(expected_t)}] for " + + ("return value, " if arg_name is None else f"argument '{arg_name}', ") + + f"got {actual_type_str}{self.get_tvar_info_str(tvar_info)}" ) def __reduce__(self): - return (TSArgTypeMismatchError, (self._expected_t, self._actual_arg_type, self._arg_name, self._tvar_info)) + return ( + TSArgTypeMismatchError, + (self._func_name, self._expected_t, self._actual_arg_type, self._arg_name, self._tvar_info), + ) class TSDictBasketKeyMismatchError(TypeMismatchError): - def __init__(self, expected_t, arg_name): - self._expected_t, self._arg_name = expected_t, arg_name - super().__init__(f"Expected ts[{self.pretty_typename(expected_t)}] for argument {arg_name} must have str keys ") + def __init__(self, func_name, expected_t, arg_name): + self._func_name, self._expected_t, self._arg_name = func_name, expected_t, arg_name + super().__init__( + f"In function {func_name}: Expected ts[{self.pretty_typename(expected_t)}] for argument {arg_name} must have str keys " + ) def __reduce__(self): - return (TSDictBasketKeyMismatchError, (self._expected_t, self._arg_name)) + return (TSDictBasketKeyMismatchError, (self._func_name, self._expected_t, self._arg_name)) class NestedTsTypeError: @@ -316,13 +334,18 @@ def _raise_arg_mismatch_error(self, arg=None, tvar_info=None): if arg is not None: if not isinstance(arg, Edge): raise ArgTypeMismatchError( - expected_t=self._cur_def.typ, actual_arg=arg, arg_name=self._cur_def.name, tvar_info=tvar_info + func_name=self._function_name, + expected_t=self._cur_def.typ, + actual_arg=arg, + arg_name=self._cur_def.name, + tvar_info=tvar_info, ) if isTsDynamicBasket(arg.tstype): arg_type = arg.tstype else: arg_type = arg.tstype.typ raise TSArgTypeMismatchError( + func_name=self._function_name, expected_t=self._cur_def.typ.typ, actual_arg_type=arg_type, arg_name=self._cur_def.name, @@ -337,7 +360,11 @@ def _raise_arg_mismatch_error(self, arg=None, tvar_info=None): else: expected_type = typing.List[self._cur_def.typ] raise ArgTypeMismatchError( - expected_t=expected_type, actual_arg=arg, arg_name=self._cur_def.name, tvar_info=tvar_info + func_name=self._function_name, + expected_t=expected_type, + actual_arg=arg, + arg_name=self._cur_def.name, + tvar_info=tvar_info, ) def _add_scalar_value(self, arg, in_out_def): @@ -529,12 +556,14 @@ def _resolve_tvar_container_internal_types(self, tvar, container_typ, arg, raise # list if arg is container_typ: if raise_on_error: - raise ArgContainerMismatchError(expected_t=tvar, actual_arg=arg, arg_name=self._cur_def.name) + raise ArgContainerMismatchError( + func_name=self._function_name, expected_t=tvar, actual_arg=arg, arg_name=self._cur_def.name + ) else: return False if len(arg) == 0: if raise_on_error: - raise ContainerTypeVarResolutionError(tvar, arg) + raise ContainerTypeVarResolutionError(self._function_name, tvar, arg) else: return None res = None @@ -561,7 +590,7 @@ def _resolve_tvar_container_internal_types(self, tvar, container_typ, arg, raise if first_key_t and first_val_t: res = typing.Dict[first_key_t, first_val_t] if not res and raise_on_error: - raise ContainerTypeVarResolutionError(tvar, arg) + raise ContainerTypeVarResolutionError(self._function_name, tvar, arg) return res def _try_resolve_tvar_conflicts(self): @@ -701,7 +730,7 @@ def _add_list_basket_ts_value(self, args, in_out_def): if in_out_def.kind == ArgKind.BASKET_TS and in_out_def.shape is not None: if len(args) != in_out_def.shape: raise RuntimeError( - f"Expected output shape for output {in_out_def.name} is of length {in_out_def.shape}, actual length is {len(args)}" + f"In function {self._function_name}: Expected output shape for output {in_out_def.name} is of length {in_out_def.shape}, actual length is {len(args)}" ) expected_ts_type = in_out_def.typ.typ for value in args: @@ -716,13 +745,13 @@ def _add_dict_basket_ts_value(self, args, in_out_def): if in_out_def.kind == ArgKind.BASKET_TS and in_out_def.shape is not None: if len(args) != len(in_out_def.shape): raise RuntimeError( - f"Expected output shape for output {in_out_def.name} is of length {len(in_out_def.shape)}, actual length is {len(args)}" + f"In function {self._function_name}: Expected output shape for output {in_out_def.name} is of length {len(in_out_def.shape)}, actual length is {len(args)}" ) for k in in_out_def.shape: if k not in args: raise RuntimeError( - f"Expected key {k} for output {in_out_def.name} is missing from the actual returned value" + f"In function {self._function_name}: Expected key {k} for output {in_out_def.name} is missing from the actual returned value" ) expected_ts_type = in_out_def.typ.typ diff --git a/csp/tests/test_caching.py b/csp/tests/test_caching.py index 203f3d3b6..04c86de13 100644 --- a/csp/tests/test_caching.py +++ b/csp/tests/test_caching.py @@ -2251,7 +2251,7 @@ # return numpy.zeros(1, dtype=float) # with _GraphTempCacheFolderConfig() as config: -# with self.assertRaisesRegex(TSArgTypeMismatchError, re.escape("Expected ts[csp.typing.Numpy1DArray[int]] for argument 'None', got ts[csp.typing.Numpy1DArray[float]]")): +# with self.assertRaisesRegex(TSArgTypeMismatchError, re.escape("In function g1: Expected ts[csp.typing.Numpy1DArray[int]] for return value, got ts[csp.typing.Numpy1DArray[float]]")): # csp.run(g1, starttime=datetime(2020, 1, 1), endtime=timedelta(minutes=20), config=config) # with _GraphTempCacheFolderConfig() as config: diff --git a/csp/tests/test_engine.py b/csp/tests/test_engine.py index 1b63e9d7b..2b45c3197 100644 --- a/csp/tests/test_engine.py +++ b/csp/tests/test_engine.py @@ -488,7 +488,6 @@ def graph(): def test_bugreport_csp28(self): """bug where non-basket inputs after basket inputs were not being assigne dproperly in c++""" - @csp.node def buggy(basket: [ts[int]], x: ts[bool]) -> ts[bool]: if csp.ticked(x) and csp.valid(x): @@ -861,14 +860,14 @@ def graph(): fb = csp.feedback(int) with self.assertRaisesRegex( TypeError, - re.escape(r"""Expected csp.impl.types.tstype.TsType[""") + re.escape(r"""In function _bind: Expected csp.impl.types.tstype.TsType[""") + ".*" + re.escape(r"""('T')] for argument 'x', got 1 (int)"""), ): fb.bind(1) with self.assertRaisesRegex( - TypeError, re.escape(r"""Expected ts[T] for argument 'x', got ts[str](T=int)""") + TypeError, re.escape(r"""In function _bind: Expected ts[T] for argument 'x', got ts[str](T=int)""") ): fb.bind(csp.const("123")) @@ -968,7 +967,7 @@ def graph(): ) self.assertTrue(__file__ in traceback_list[-1]) self.assertLessEqual(len(traceback_list), 10) - self.assertEqual(str(e), "Expected ts[T] for argument 'my_arg', got None") + self.assertEqual(str(e), "In function aux: Expected ts[T] for argument 'my_arg', got None") def test_union_type_check(self): '''was a bug "Add support for typing.Union in type checking layer"''' @@ -981,7 +980,8 @@ def graph(x: typing.Union[int, float, str]): build_graph(graph, 1.1) build_graph(graph, "s") with self.assertRaisesRegex( - TypeError, "Expected typing.Union\\[int, float, str\\] for argument 'x', got \\[1.1\\] \\(list\\)" + TypeError, + "In function graph: Expected typing.Union\\[int, float, str\\] for argument 'x', got \\[1.1\\] \\(list\\)", ): build_graph(graph, [1.1]) @@ -994,7 +994,7 @@ def graph(x: ts[typing.Union[int, float, str]]): build_graph(graph, csp.const("s")) with self.assertRaisesRegex( TypeError, - "Expected ts\\[typing.Union\\[int, float, str\\]\\] for argument 'x', got ts\\[typing.List\\[float\\]\\]", + "In function graph: Expected ts\\[typing.Union\\[int, float, str\\]\\] for argument 'x', got ts\\[typing.List\\[float\\]\\]", ): build_graph(graph, csp.const([1.1])) @@ -1548,6 +1548,37 @@ def main(use_graph: bool, pass_null: bool) -> csp.Outputs(o=csp.ts[int]): endtime=timedelta(seconds=10), ) + def test_return_arg_mismatch(self): + @csp.graph + def my_graph(x: csp.ts[int]) -> csp.ts[str]: + return x + + with self.assertRaises(TSArgTypeMismatchError) as ctxt: + csp.run(my_graph, csp.const(1), starttime=datetime.utcnow()) + self.assertEqual(str(ctxt.exception), "In function my_graph: Expected ts[str] for return value, got ts[int]") + + @csp.graph + def dictbasket_graph(x: csp.ts[int]) -> {str: csp.ts[str]}: + return csp.output({"a": x}) + + with self.assertRaises(ArgTypeMismatchError) as ctxt: + csp.run(dictbasket_graph, csp.const(1), starttime=datetime.utcnow()) + self.assertRegex( + str(ctxt.exception), + "In function dictbasket_graph: Expected typing\.Dict\[str, .* for return value, got \{'a': .* \(dict\)", + ) + + @csp.graph + def listbasket_graph(x: csp.ts[int]) -> [csp.ts[str]]: + return csp.output([x]) + + with self.assertRaises(ArgTypeMismatchError) as ctxt: + csp.run(listbasket_graph, csp.const(1), starttime=datetime.utcnow()) + self.assertRegex( + str(ctxt.exception), + "In function listbasket_graph: Expected typing\.List\[.* for return value, got \[.* \(list\)", + ) + def test_global_context(self): try: diff --git a/csp/tests/test_type_checking.py b/csp/tests/test_type_checking.py index 4a5210fd2..88ef15588 100644 --- a/csp/tests/test_type_checking.py +++ b/csp/tests/test_type_checking.py @@ -601,11 +601,11 @@ def foo(arr: csp.ts[np.ndarray]) -> csp.ts[np.ndarray]: def test_pickle_type_resolver_errors(self): errors = [ - type_resolver.ContainerTypeVarResolutionError("T", "NotT"), - type_resolver.ArgTypeMismatchError("T", "NotT", "Var", {"field": 1}), - type_resolver.ArgContainerMismatchError("T", "NotT", "Var"), - type_resolver.TSArgTypeMismatchError("T", "NotT", "Var"), - type_resolver.TSDictBasketKeyMismatchError("T", "Var"), + type_resolver.ContainerTypeVarResolutionError("g", "T", "NotT"), + type_resolver.ArgTypeMismatchError("g", "T", "NotT", "Var", {"field": 1}), + type_resolver.ArgContainerMismatchError("g", "T", "NotT", "Var"), + type_resolver.TSArgTypeMismatchError("g", "T", "NotT", "Var"), + type_resolver.TSDictBasketKeyMismatchError("g", "T", "Var"), ] for err in errors: