Skip to content

Commit

Permalink
Fix compatibility with TypeVar default changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Fatal1ty committed Jun 15, 2024
1 parent 2053922 commit 57d1598
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 7 deletions.
18 changes: 15 additions & 3 deletions mashumaro/core/meta/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,9 @@ def type_name(
)
return f"{_typing_name('Union', short)}[{args_str}]"
else:
bound = getattr(typ, "__default__", None)
if bound is None:
if type_var_has_default(typ):
bound = get_type_var_default(typ)
else:
bound = getattr(typ, "__bound__")
return type_name(bound, short, resolved_type_params)
elif is_new_type(typ) and not PY_310_MIN:
Expand Down Expand Up @@ -423,7 +424,7 @@ def is_type_var_any(typ: Type) -> bool:
return False
elif typ.__bound__ not in (None, Any):
return False
elif getattr(typ, "__default__", None) not in (None, NoneType):
elif type_var_has_default(typ):
return False
else:
return True
Expand Down Expand Up @@ -806,3 +807,14 @@ def is_type_alias_type(typ: Type) -> bool:
return isinstance(typ, typing.TypeAliasType) # type: ignore
else:
return False


def type_var_has_default(typ: Any) -> bool:
try:
return typ.has_default()
except AttributeError:
return getattr(typ, "__default__", None) is not None


def get_type_var_default(typ: Any) -> Type:
return getattr(typ, "__default__")
7 changes: 5 additions & 2 deletions mashumaro/core/meta/types/pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
get_function_return_annotation,
get_literal_values,
get_type_origin,
get_type_var_default,
is_final,
is_generic,
is_literal,
Expand All @@ -52,6 +53,7 @@
resolve_type_params,
substitute_type_params,
type_name,
type_var_has_default,
)
from mashumaro.core.meta.types.common import (
Expression,
Expand Down Expand Up @@ -471,8 +473,9 @@ def pack_special_typing_primitive(spec: ValueSpec) -> Optional[Expression]:
if constraints:
return pack_union(spec, constraints, "type_var")
else:
bound = getattr(spec.type, "__default__", None)
if bound is None:
if type_var_has_default(spec.type):
bound = get_type_var_default(spec.type)
else:
bound = getattr(spec.type, "__bound__")
# act as if it was Optional[bound]
pv = PackerRegistry.get(spec.copy(type=bound))
Expand Down
7 changes: 5 additions & 2 deletions mashumaro/core/meta/types/unpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
get_class_that_defines_method,
get_function_arg_annotation,
get_literal_values,
get_type_var_default,
is_final,
is_generic,
is_literal,
Expand All @@ -59,6 +60,7 @@
resolve_type_params,
substitute_type_params,
type_name,
type_var_has_default,
)
from mashumaro.core.meta.types.common import (
AbstractMethodBuilder,
Expand Down Expand Up @@ -749,8 +751,9 @@ def unpack_special_typing_primitive(spec: ValueSpec) -> Optional[Expression]:
if constraints:
return TypeVarUnpackerBuilder(constraints).build(spec)
else:
bound = getattr(spec.type, "__default__", None)
if bound is None:
if type_var_has_default(spec.type):
bound = get_type_var_default(spec.type)
else:
bound = getattr(spec.type, "__bound__")
# act as if it was Optional[bound]
uv = UnpackerRegistry.get(spec.copy(type=bound))
Expand Down
24 changes: 24 additions & 0 deletions tests/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
get_literal_values,
get_type_annotations,
get_type_origin,
get_type_var_default,
hash_type_args,
is_annotated,
is_dataclass_dict_mixin,
Expand All @@ -49,6 +50,7 @@
resolve_type_params,
substitute_type_params,
type_name,
type_var_has_default,
)
from mashumaro.core.meta.types.common import (
FieldContext,
Expand Down Expand Up @@ -807,3 +809,25 @@ def test_is_hashable_type():
assert is_hashable_type(int) is True
assert is_hashable_type(MyFrozenDataClass) is True
assert is_hashable_type(MyDataClass) is False


def test_type_var_has_default():
T_WithoutDefault = typing_extensions.TypeVar("T_WithoutDefault")
T_WithDefault = typing_extensions.TypeVar("T_WithDefault", default=int)
T_WithDefaultNone = typing_extensions.TypeVar(
"T_WithDefaultNone", default=None
)

assert not type_var_has_default(T_WithoutDefault)
assert type_var_has_default(T_WithDefault)
assert type_var_has_default(T_WithDefaultNone)


def test_get_type_var_default():
T_WithDefault = typing_extensions.TypeVar("T_WithDefault", default=int)
T_WithDefaultNone = typing_extensions.TypeVar(
"T_WithDefaultNone", default=None
)

assert get_type_var_default(T_WithDefault) is int
assert get_type_var_default(T_WithDefaultNone) is None

0 comments on commit 57d1598

Please sign in to comment.