Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes mypy crash on dataclasses.field(**unpack) #11137

Merged
merged 3 commits into from
Sep 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
if self._is_kw_only_type(node_type):
kw_only = True

has_field_call, field_args = _collect_field_args(stmt.rvalue)
has_field_call, field_args = _collect_field_args(stmt.rvalue, ctx)

is_in_init_param = field_args.get('init')
if is_in_init_param is None:
Expand Down Expand Up @@ -447,7 +447,8 @@ def dataclass_class_maker_callback(ctx: ClassDefContext) -> None:
transformer.transform()


def _collect_field_args(expr: Expression) -> Tuple[bool, Dict[str, Expression]]:
def _collect_field_args(expr: Expression,
ctx: ClassDefContext) -> Tuple[bool, Dict[str, Expression]]:
"""Returns a tuple where the first value represents whether or not
the expression is a call to dataclass.field and the second is a
dictionary of the keyword arguments that field() was called with.
Expand All @@ -460,7 +461,15 @@ def _collect_field_args(expr: Expression) -> Tuple[bool, Dict[str, Expression]]:
# field() only takes keyword arguments.
args = {}
for name, arg in zip(expr.arg_names, expr.args):
assert name is not None
if name is None:
# This means that `field` is used with `**` unpacking,
# the best we can do for now is not to fail.
# TODO: we can infer what's inside `**` and try to collect it.
sobolevn marked this conversation as resolved.
Show resolved Hide resolved
ctx.api.fail(
'Unpacking **kwargs in "field()" is not supported',
expr,
)
return True, {}
args[name] = arg
return True, args
return False, {}
36 changes: 36 additions & 0 deletions test-data/unit/check-dataclasses.test
Original file line number Diff line number Diff line change
Expand Up @@ -1300,3 +1300,39 @@ a.x = x
a.x = x2 # E: Incompatible types in assignment (expression has type "Callable[[str], str]", variable has type "Callable[[int], int]")

[builtins fixtures/dataclasses.pyi]


[case testDataclassFieldDoesNotFailOnKwargsUnpacking]
# flags: --python-version 3.7
# https://github.com/python/mypy/issues/10879
from dataclasses import dataclass, field

@dataclass
class Foo:
bar: float = field(**{"repr": False})
[out]
main:7: error: Unpacking **kwargs in "field()" is not supported
main:7: error: No overload variant of "field" matches argument type "Dict[str, bool]"
main:7: note: Possible overload variants:
main:7: note: def [_T] field(*, default: _T, init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., metadata: Optional[Mapping[str, Any]] = ..., kw_only: bool = ...) -> _T
main:7: note: def [_T] field(*, default_factory: Callable[[], _T], init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., metadata: Optional[Mapping[str, Any]] = ..., kw_only: bool = ...) -> _T
main:7: note: def field(*, init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., metadata: Optional[Mapping[str, Any]] = ..., kw_only: bool = ...) -> Any
[builtins fixtures/dataclasses.pyi]


[case testDataclassFieldWithTypedDictUnpacking]
# flags: --python-version 3.7
from dataclasses import dataclass, field
from typing_extensions import TypedDict

class FieldKwargs(TypedDict):
repr: bool

field_kwargs: FieldKwargs = {"repr": False}

@dataclass
class Foo:
bar: float = field(**field_kwargs) # E: Unpacking **kwargs in "field()" is not supported

reveal_type(Foo(bar=1.5)) # N: Revealed type is "__main__.Foo"
[builtins fixtures/dataclasses.pyi]
25 changes: 23 additions & 2 deletions test-data/unit/fixtures/dataclasses.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from typing import Generic, Sequence, TypeVar
from typing import (
Generic, Iterator, Iterable, Mapping, Optional, Sequence, Tuple,
TypeVar, Union, overload,
)

_T = TypeVar('_T')
_U = TypeVar('_U')
KT = TypeVar('KT')
VT = TypeVar('VT')

class object:
def __init__(self) -> None: pass
Expand All @@ -15,7 +20,23 @@ class int: pass
class float: pass
class str: pass
class bool(int): pass
class dict(Generic[_T, _U]): pass

class dict(Mapping[KT, VT]):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need this for **dict unpack.

@overload
def __init__(self, **kwargs: VT) -> None: pass
@overload
def __init__(self, arg: Iterable[Tuple[KT, VT]], **kwargs: VT) -> None: pass
def __getitem__(self, key: KT) -> VT: pass
def __setitem__(self, k: KT, v: VT) -> None: pass
def __iter__(self) -> Iterator[KT]: pass
def __contains__(self, item: object) -> int: pass
def update(self, a: Mapping[KT, VT]) -> None: pass
@overload
def get(self, k: KT) -> Optional[VT]: pass
@overload
def get(self, k: KT, default: Union[KT, _T]) -> Union[VT, _T]: pass
def __len__(self) -> int: ...

class list(Generic[_T], Sequence[_T]): pass
class function: pass
class classmethod: pass
Expand Down