From 57d1598260d3adcdb2d43df364a06c5777ae6653 Mon Sep 17 00:00:00 2001 From: Alexander Tikhonov Date: Sat, 15 Jun 2024 14:50:35 +0300 Subject: [PATCH] Fix compatibility with TypeVar default changes --- mashumaro/core/meta/helpers.py | 18 +++++++++++++++--- mashumaro/core/meta/types/pack.py | 7 +++++-- mashumaro/core/meta/types/unpack.py | 7 +++++-- tests/test_meta.py | 24 ++++++++++++++++++++++++ 4 files changed, 49 insertions(+), 7 deletions(-) diff --git a/mashumaro/core/meta/helpers.py b/mashumaro/core/meta/helpers.py index ea0188c2..891c1e61 100644 --- a/mashumaro/core/meta/helpers.py +++ b/mashumaro/core/meta/helpers.py @@ -272,8 +272,9 @@ def type_name( ) return f"{_typing_name('Union', short)}[{args_str}]" else: - bound = getattr(typ, "__default__", None) - if bound is None: + if type_var_has_default(typ): + bound = get_type_var_default(typ) + else: bound = getattr(typ, "__bound__") return type_name(bound, short, resolved_type_params) elif is_new_type(typ) and not PY_310_MIN: @@ -423,7 +424,7 @@ def is_type_var_any(typ: Type) -> bool: return False elif typ.__bound__ not in (None, Any): return False - elif getattr(typ, "__default__", None) not in (None, NoneType): + elif type_var_has_default(typ): return False else: return True @@ -806,3 +807,14 @@ def is_type_alias_type(typ: Type) -> bool: return isinstance(typ, typing.TypeAliasType) # type: ignore else: return False + + +def type_var_has_default(typ: Any) -> bool: + try: + return typ.has_default() + except AttributeError: + return getattr(typ, "__default__", None) is not None + + +def get_type_var_default(typ: Any) -> Type: + return getattr(typ, "__default__") diff --git a/mashumaro/core/meta/types/pack.py b/mashumaro/core/meta/types/pack.py index e8b46371..db8a09cb 100644 --- a/mashumaro/core/meta/types/pack.py +++ b/mashumaro/core/meta/types/pack.py @@ -31,6 +31,7 @@ get_function_return_annotation, get_literal_values, get_type_origin, + get_type_var_default, is_final, is_generic, is_literal, @@ -52,6 +53,7 @@ resolve_type_params, substitute_type_params, type_name, + type_var_has_default, ) from mashumaro.core.meta.types.common import ( Expression, @@ -471,8 +473,9 @@ def pack_special_typing_primitive(spec: ValueSpec) -> Optional[Expression]: if constraints: return pack_union(spec, constraints, "type_var") else: - bound = getattr(spec.type, "__default__", None) - if bound is None: + if type_var_has_default(spec.type): + bound = get_type_var_default(spec.type) + else: bound = getattr(spec.type, "__bound__") # act as if it was Optional[bound] pv = PackerRegistry.get(spec.copy(type=bound)) diff --git a/mashumaro/core/meta/types/unpack.py b/mashumaro/core/meta/types/unpack.py index 78e14b6f..15bf1ac0 100644 --- a/mashumaro/core/meta/types/unpack.py +++ b/mashumaro/core/meta/types/unpack.py @@ -37,6 +37,7 @@ get_class_that_defines_method, get_function_arg_annotation, get_literal_values, + get_type_var_default, is_final, is_generic, is_literal, @@ -59,6 +60,7 @@ resolve_type_params, substitute_type_params, type_name, + type_var_has_default, ) from mashumaro.core.meta.types.common import ( AbstractMethodBuilder, @@ -749,8 +751,9 @@ def unpack_special_typing_primitive(spec: ValueSpec) -> Optional[Expression]: if constraints: return TypeVarUnpackerBuilder(constraints).build(spec) else: - bound = getattr(spec.type, "__default__", None) - if bound is None: + if type_var_has_default(spec.type): + bound = get_type_var_default(spec.type) + else: bound = getattr(spec.type, "__bound__") # act as if it was Optional[bound] uv = UnpackerRegistry.get(spec.copy(type=bound)) diff --git a/tests/test_meta.py b/tests/test_meta.py index e50cb08a..09be6ca3 100644 --- a/tests/test_meta.py +++ b/tests/test_meta.py @@ -29,6 +29,7 @@ get_literal_values, get_type_annotations, get_type_origin, + get_type_var_default, hash_type_args, is_annotated, is_dataclass_dict_mixin, @@ -49,6 +50,7 @@ resolve_type_params, substitute_type_params, type_name, + type_var_has_default, ) from mashumaro.core.meta.types.common import ( FieldContext, @@ -807,3 +809,25 @@ def test_is_hashable_type(): assert is_hashable_type(int) is True assert is_hashable_type(MyFrozenDataClass) is True assert is_hashable_type(MyDataClass) is False + + +def test_type_var_has_default(): + T_WithoutDefault = typing_extensions.TypeVar("T_WithoutDefault") + T_WithDefault = typing_extensions.TypeVar("T_WithDefault", default=int) + T_WithDefaultNone = typing_extensions.TypeVar( + "T_WithDefaultNone", default=None + ) + + assert not type_var_has_default(T_WithoutDefault) + assert type_var_has_default(T_WithDefault) + assert type_var_has_default(T_WithDefaultNone) + + +def test_get_type_var_default(): + T_WithDefault = typing_extensions.TypeVar("T_WithDefault", default=int) + T_WithDefaultNone = typing_extensions.TypeVar( + "T_WithDefaultNone", default=None + ) + + assert get_type_var_default(T_WithDefault) is int + assert get_type_var_default(T_WithDefaultNone) is None