Skip to content

Commit

Permalink
Add function names to type resolver errors and properly report when t…
Browse files Browse the repository at this point in the history
…he return value is the conflicting type

Signed-off-by: Tim Paine <[email protected]>
  • Loading branch information
AdamGlustein authored and timkpaine committed Feb 7, 2024
1 parent f4008ce commit b3a4dea
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 63 deletions.
89 changes: 59 additions & 30 deletions csp/impl/types/instantiation_type_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@ def instance(cls):


class ContainerTypeVarResolutionError(TypeError):
def __init__(self, tvar, tvar_value):
self._tvar, self._tvar_value = tvar, tvar_value
def __init__(self, func_name, tvar, tvar_value):
self._func_name, self._tvar, self._tvar_value = func_name, tvar, tvar_value
super().__init__(
f"Unable to resolve container type for type variable {tvar} explicit value must have"
f"In function {func_name}: Unable to resolve container type for type variable {tvar} explicit value must have"
+ f" uniform values and be non empty, got: {tvar_value} "
)

def __reduce__(self):
return (ContainerTypeVarResolutionError, (self._tvar, self._tvar_value))
return (ContainerTypeVarResolutionError, (self._func_name, self._tvar, self._tvar_value))


class TypeMismatchError(TypeError):
Expand All @@ -117,59 +117,77 @@ def get_tvar_info_str(cls, tvar_info):


class ArgTypeMismatchError(TypeMismatchError):
def __init__(self, expected_t, actual_arg, arg_name, tvar_info=None):
self._expected_t, self._actual_arg, self._arg_name, self._tvar_info = (
def __init__(self, func_name, expected_t, actual_arg, arg_name, tvar_info=None):
self._func_name, self._expected_t, self._actual_arg, self._arg_name, self._tvar_info = (
func_name,
expected_t,
actual_arg,
arg_name,
tvar_info,
)
super().__init__(
f"Expected {self.pretty_typename(expected_t)} for argument '{arg_name}', got {actual_arg} "
+ f"({self.pretty_typename(type(actual_arg))}){self.get_tvar_info_str(tvar_info)}"
f"In function {func_name}: Expected {self.pretty_typename(expected_t)} for "
+ ("return value, " if arg_name is None else f"argument '{arg_name}', ")
+ f"got {actual_arg} ({self.pretty_typename(type(actual_arg))}){self.get_tvar_info_str(tvar_info)}"
)

def __reduce__(self):
return (ArgTypeMismatchError, (self._expected_t, self._actual_arg, self._arg_name, self._tvar_info))
return (
ArgTypeMismatchError,
(self._func_name, self._expected_t, self._actual_arg, self._arg_name, self._tvar_info),
)


class ArgContainerMismatchError(TypeMismatchError):
def __init__(self, expected_t, actual_arg, arg_name, tvar_info=None):
self._expected_t, self._actual_arg, self._arg_name = expected_t, actual_arg, arg_name
def __init__(self, func_name, expected_t, actual_arg, arg_name, tvar_info=None):
self._func_name, self._expected_t, self._actual_arg, self._arg_name = (
func_name,
expected_t,
actual_arg,
arg_name,
)
super().__init__(
f"Expected {self.pretty_typename(expected_t)} for argument '{arg_name}', got {actual_arg} "
f"In function {func_name}: Expected {self.pretty_typename(expected_t)} for argument '{arg_name}', got {actual_arg} "
+ "instead of generic container type specification"
)

def __reduce__(self):
return (ArgContainerMismatchError, (self._expected_t, self._actual_arg, self._arg_name))
return (ArgContainerMismatchError, (self._func_name, self._expected_t, self._actual_arg, self._arg_name))


class TSArgTypeMismatchError(TypeMismatchError):
def __init__(self, expected_t, actual_arg_type, arg_name, tvar_info=None):
self._expected_t, self._actual_arg_type, self._arg_name, self._tvar_info = (
def __init__(self, func_name, expected_t, actual_arg_type, arg_name, tvar_info=None):
self._func_name, self._expected_t, self._actual_arg_type, self._arg_name, self._tvar_info = (
func_name,
expected_t,
actual_arg_type,
arg_name,
tvar_info,
)
actual_type_str = f"ts[{self.pretty_typename(actual_arg_type)}]" if actual_arg_type else "None"

super().__init__(
f"Expected ts[{self.pretty_typename(expected_t)}] for argument '{arg_name}',"
+ f" got {actual_type_str}{self.get_tvar_info_str(tvar_info)}"
f"In function {func_name}: Expected ts[{self.pretty_typename(expected_t)}] for "
+ ("return value, " if arg_name is None else f"argument '{arg_name}', ")
+ f"got {actual_type_str}{self.get_tvar_info_str(tvar_info)}"
)

def __reduce__(self):
return (TSArgTypeMismatchError, (self._expected_t, self._actual_arg_type, self._arg_name, self._tvar_info))
return (
TSArgTypeMismatchError,
(self._func_name, self._expected_t, self._actual_arg_type, self._arg_name, self._tvar_info),
)


class TSDictBasketKeyMismatchError(TypeMismatchError):
def __init__(self, expected_t, arg_name):
self._expected_t, self._arg_name = expected_t, arg_name
super().__init__(f"Expected ts[{self.pretty_typename(expected_t)}] for argument {arg_name} must have str keys ")
def __init__(self, func_name, expected_t, arg_name):
self._func_name, self._expected_t, self._arg_name = func_name, expected_t, arg_name
super().__init__(
f"In function {func_name}: Expected ts[{self.pretty_typename(expected_t)}] for argument {arg_name} must have str keys "
)

def __reduce__(self):
return (TSDictBasketKeyMismatchError, (self._expected_t, self._arg_name))
return (TSDictBasketKeyMismatchError, (self._func_name, self._expected_t, self._arg_name))


class NestedTsTypeError:
Expand Down Expand Up @@ -316,13 +334,18 @@ def _raise_arg_mismatch_error(self, arg=None, tvar_info=None):
if arg is not None:
if not isinstance(arg, Edge):
raise ArgTypeMismatchError(
expected_t=self._cur_def.typ, actual_arg=arg, arg_name=self._cur_def.name, tvar_info=tvar_info
func_name=self._function_name,
expected_t=self._cur_def.typ,
actual_arg=arg,
arg_name=self._cur_def.name,
tvar_info=tvar_info,
)
if isTsDynamicBasket(arg.tstype):
arg_type = arg.tstype
else:
arg_type = arg.tstype.typ
raise TSArgTypeMismatchError(
func_name=self._function_name,
expected_t=self._cur_def.typ.typ,
actual_arg_type=arg_type,
arg_name=self._cur_def.name,
Expand All @@ -337,7 +360,11 @@ def _raise_arg_mismatch_error(self, arg=None, tvar_info=None):
else:
expected_type = typing.List[self._cur_def.typ]
raise ArgTypeMismatchError(
expected_t=expected_type, actual_arg=arg, arg_name=self._cur_def.name, tvar_info=tvar_info
func_name=self._function_name,
expected_t=expected_type,
actual_arg=arg,
arg_name=self._cur_def.name,
tvar_info=tvar_info,
)

def _add_scalar_value(self, arg, in_out_def):
Expand Down Expand Up @@ -529,12 +556,14 @@ def _resolve_tvar_container_internal_types(self, tvar, container_typ, arg, raise
# list
if arg is container_typ:
if raise_on_error:
raise ArgContainerMismatchError(expected_t=tvar, actual_arg=arg, arg_name=self._cur_def.name)
raise ArgContainerMismatchError(
func_name=self._function_name, expected_t=tvar, actual_arg=arg, arg_name=self._cur_def.name
)
else:
return False
if len(arg) == 0:
if raise_on_error:
raise ContainerTypeVarResolutionError(tvar, arg)
raise ContainerTypeVarResolutionError(self._function_name, tvar, arg)
else:
return None
res = None
Expand All @@ -561,7 +590,7 @@ def _resolve_tvar_container_internal_types(self, tvar, container_typ, arg, raise
if first_key_t and first_val_t:
res = typing.Dict[first_key_t, first_val_t]
if not res and raise_on_error:
raise ContainerTypeVarResolutionError(tvar, arg)
raise ContainerTypeVarResolutionError(self._function_name, tvar, arg)
return res

def _try_resolve_tvar_conflicts(self):
Expand Down Expand Up @@ -701,7 +730,7 @@ def _add_list_basket_ts_value(self, args, in_out_def):
if in_out_def.kind == ArgKind.BASKET_TS and in_out_def.shape is not None:
if len(args) != in_out_def.shape:
raise RuntimeError(
f"Expected output shape for output {in_out_def.name} is of length {in_out_def.shape}, actual length is {len(args)}"
f"In function {self._function_name}: Expected output shape for output {in_out_def.name} is of length {in_out_def.shape}, actual length is {len(args)}"
)
expected_ts_type = in_out_def.typ.typ
for value in args:
Expand All @@ -716,13 +745,13 @@ def _add_dict_basket_ts_value(self, args, in_out_def):
if in_out_def.kind == ArgKind.BASKET_TS and in_out_def.shape is not None:
if len(args) != len(in_out_def.shape):
raise RuntimeError(
f"Expected output shape for output {in_out_def.name} is of length {len(in_out_def.shape)}, actual length is {len(args)}"
f"In function {self._function_name}: Expected output shape for output {in_out_def.name} is of length {len(in_out_def.shape)}, actual length is {len(args)}"
)

for k in in_out_def.shape:
if k not in args:
raise RuntimeError(
f"Expected key {k} for output {in_out_def.name} is missing from the actual returned value"
f"In function {self._function_name}: Expected key {k} for output {in_out_def.name} is missing from the actual returned value"
)

expected_ts_type = in_out_def.typ.typ
Expand Down
12 changes: 6 additions & 6 deletions csp/tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# from csp.impl.types.instantiation_type_resolver import TSArgTypeMismatchError
# from csp.utils.object_factory_registry import Injected, register_injected_object, set_new_registry_thread_instance
# from datetime import date, datetime, timedelta
# from csp.tests.utils.typed_curve_generator import TypedCurveGenerator
# from tests.csp.utils.typed_curve_generator import TypedCurveGenerator


# class _DummyStructWithTimestamp(csp.Struct):
Expand Down Expand Up @@ -338,8 +338,8 @@
# all_files_and_folders = sorted(glob.glob(f'{config.cache_config.data_folder}/**', recursive=True))
# files_in_cache = [v.replace(f'{config.cache_config.data_folder}/', '') for v in all_files_and_folders if os.path.isfile(v)]
# # When we right from command line, the tests import paths differ. So let's support it as well
# files_in_cache = [f.replace('csp.tests.test_caching', 'test_caching') for f in files_in_cache]
# files_in_cache = [f.replace('/csp.tests.', '/') for f in files_in_cache]
# files_in_cache = [f.replace('python.tests.csp.test_caching', 'test_caching') for f in files_in_cache]
# files_in_cache = [f.replace('/tests.csp.', '/') for f in files_in_cache]
# return files_in_cache

# def test_no_cache(self):
Expand Down Expand Up @@ -638,7 +638,7 @@
# def test_enum_field_serialization(self):
# for split_columns_to_files in (True, False):
# graph_kwargs = self._get_default_graph_caching_kwargs(split_columns_to_files)
# from csp.tests.impl.test_enum import MyEnum
# from tests.csp.impl.test_enum import MyEnum

# class MyStruct(csp.Struct):
# e: MyEnum
Expand Down Expand Up @@ -667,7 +667,7 @@
# def test_nested_struct_caching(self):
# for split_columns_to_files in (True, False):
# graph_kwargs = self._get_default_graph_caching_kwargs(split_columns_to_files)
# from csp.tests.impl.test_enum import MyEnum
# from tests.csp.impl.test_enum import MyEnum
# class MyStruct1(csp.Struct):
# v_int: int
# v_str: str
Expand Down Expand Up @@ -2251,7 +2251,7 @@
# return numpy.zeros(1, dtype=float)

# with _GraphTempCacheFolderConfig() as config:
# with self.assertRaisesRegex(TSArgTypeMismatchError, re.escape("Expected ts[csp.typing.Numpy1DArray[int]] for argument 'None', got ts[csp.typing.Numpy1DArray[float]]")):
# with self.assertRaisesRegex(TSArgTypeMismatchError, re.escape("In function g1: Expected ts[csp.typing.Numpy1DArray[int]] for return value, got ts[csp.typing.Numpy1DArray[float]]")):
# csp.run(g1, starttime=datetime(2020, 1, 1), endtime=timedelta(minutes=20), config=config)

# with _GraphTempCacheFolderConfig() as config:
Expand Down
Loading

0 comments on commit b3a4dea

Please sign in to comment.