Skip to content

Commit

Permalink
Add rule for dict comprehensions with constant values (#553)
Browse files Browse the repository at this point in the history
Adds `C420` that checks for dictionary comprehensions that use constant
values.

Changed the fixture for `test_C418_pass` as it was failing (`dict({x: 1
for x in range(1)}, a=1)` triggers the new rule).

Closes #552.

---------

Co-authored-by: Adam Johnson <[email protected]>
  • Loading branch information
tjkuson and adamchainz authored Jun 29, 2024
1 parent 226a7e0 commit 1ae816f
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 25 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
Changelog
=========

* Add rule C420 to check for dict comprehensions with constant values, encouraging replacement with ``dict.fromkeys()``.

Thanks to Tom Kuson in `PR #553 <https://github.com/adamchainz/flake8-comprehensions/pull/553>`__.

3.14.0 (2023-07-10)
-------------------

Expand Down
10 changes: 10 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,13 @@ For example:

* Rewrite ``all([condition(x) for x in iterable])`` as ``all(condition(x) for x in iterable)``
* Rewrite ``any([condition(x) for x in iterable])`` as ``any(condition(x) for x in iterable)``

C420: Unnecessary dict comprehension - rewrite using dict.fromkeys().
----------------------------------------------------------------------

It's unnecessary to use a dict comprehension to build a dict with all values set to the same constant.
Use ``dict.fromkeys()`` instead, which is faster.
For example:

* Rewrite ``{x: 1 for x in iterable}`` as ``dict.fromkeys(iterable, 1)``
* Rewrite ``{x: None for x in iterable}`` as ``dict.fromkeys(iterable)``
66 changes: 42 additions & 24 deletions src/flake8_comprehensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def __init__(self, tree: ast.AST) -> None:
"C419 Unnecessary list comprehension passed to {func}() prevents "
+ "short-circuiting - rewrite as a generator."
),
"C420": (
"C420 Unnecessary {type} comprehension - rewrite using dict.fromkeys()."
),
}

def run(self) -> Generator[tuple[int, int, str, type[Any]], None, None]:
Expand Down Expand Up @@ -335,32 +338,47 @@ def run(self) -> Generator[tuple[int, int, str, type[Any]], None, None]:
len(node.generators) == 1
and not node.generators[0].ifs
and not node.generators[0].is_async
and (
(
isinstance(node, (ast.ListComp, ast.SetComp))
and isinstance(node.elt, ast.Name)
and isinstance(node.generators[0].target, ast.Name)
and node.elt.id == node.generators[0].target.id
):
if (
isinstance(node, (ast.ListComp, ast.SetComp))
and isinstance(node.elt, ast.Name)
and isinstance(node.generators[0].target, ast.Name)
and node.elt.id == node.generators[0].target.id
) or (
isinstance(node, ast.DictComp)
and isinstance(node.key, ast.Name)
and isinstance(node.value, ast.Name)
and isinstance(node.generators[0].target, ast.Tuple)
and len(node.generators[0].target.elts) == 2
and isinstance(node.generators[0].target.elts[0], ast.Name)
and node.generators[0].target.elts[0].id == node.key.id
and isinstance(node.generators[0].target.elts[1], ast.Name)
and node.generators[0].target.elts[1].id == node.value.id
):
yield (
node.lineno,
node.col_offset,
self.messages["C416"].format(
type=comp_type[node.__class__]
),
type(self),
)
or (
isinstance(node, ast.DictComp)
and isinstance(node.key, ast.Name)
and isinstance(node.value, ast.Name)
and isinstance(node.generators[0].target, ast.Tuple)
and len(node.generators[0].target.elts) == 2
and isinstance(node.generators[0].target.elts[0], ast.Name)
and node.generators[0].target.elts[0].id == node.key.id
and isinstance(node.generators[0].target.elts[1], ast.Name)
and node.generators[0].target.elts[1].id == node.value.id

elif (
isinstance(node, ast.DictComp)
and isinstance(node.key, ast.Name)
and isinstance(node.value, ast.Constant)
and isinstance(node.generators[0].target, ast.Name)
and node.key.id == node.generators[0].target.id
):
yield (
node.lineno,
node.col_offset,
self.messages["C420"].format(
type=comp_type[node.__class__]
),
type(self),
)
)
):
yield (
node.lineno,
node.col_offset,
self.messages["C416"].format(type=comp_type[node.__class__]),
type(self),
)


def has_star_args(call_node: ast.Call) -> bool:
Expand Down
57 changes: 56 additions & 1 deletion tests/test_flake8_comprehensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,7 @@ def test_C417_fail(code, failures, flake8_path):
"code",
[
"dict({}, a=1)",
"dict({x: 1 for x in range(1)}, a=1)",
"dict({x: [] for x in range(1)}, a=1)",
],
)
def test_C418_pass(code, flake8_path):
Expand Down Expand Up @@ -963,3 +963,58 @@ def test_C419_fail(code, failures, flake8_path):
(flake8_path / "example.py").write_text(dedent(code))
result = flake8_path.run_flake8()
assert result.out_lines == failures


@pytest.mark.parametrize(
"code",
[
"{elt: elt * 2 for elt in range(5)}",
"{elt: [] for elt in foo}",
"{elt: {1, 2, 3} for elt in ['a', 'b', 'c']}",
"{elt: some_func() for elt in ['a', 'b', 'c']}",
"{elt: SomeClass() for elt in ['a', 'b', 'c']}",
],
)
def test_C420_pass(code, flake8_path):
(flake8_path / "example.py").write_text(dedent(code))
result = flake8_path.run_flake8()
assert result.out_lines == []


@pytest.mark.parametrize(
"code,failures",
[
(
"{elt: None for elt in range(5)}",
[
"./example.py:1:1: C420 Unnecessary dict comprehension - "
+ "rewrite using dict.fromkeys()."
],
),
(
"{elt: 1 for elt in foo}",
[
"./example.py:1:1: C420 Unnecessary dict comprehension - "
+ "rewrite using dict.fromkeys()."
],
),
(
"{elt: 'value' for elt in ['a', 'b', 'c']}",
[
"./example.py:1:1: C420 Unnecessary dict comprehension - "
+ "rewrite using dict.fromkeys()."
],
),
(
"{elt: True for elt in some_func()}",
[
"./example.py:1:1: C420 Unnecessary dict comprehension - "
+ "rewrite using dict.fromkeys()."
],
),
],
)
def test_C420_fail(code, failures, flake8_path):
(flake8_path / "example.py").write_text(dedent(code))
result = flake8_path.run_flake8()
assert result.out_lines == failures

0 comments on commit 1ae816f

Please sign in to comment.