From 1ae816f430a668ee527d1fd240640cff291b84cb Mon Sep 17 00:00:00 2001 From: Tom Kuson Date: Sat, 29 Jun 2024 23:23:23 +0100 Subject: [PATCH] Add rule for dict comprehensions with constant values (#553) Adds `C420` that checks for dictionary comprehensions that use constant values. Changed the fixture for `test_C418_pass` as it was failing (`dict({x: 1 for x in range(1)}, a=1)` triggers the new rule). Closes #552. --------- Co-authored-by: Adam Johnson --- CHANGELOG.rst | 4 ++ README.rst | 10 ++++ src/flake8_comprehensions/__init__.py | 66 +++++++++++++++++---------- tests/test_flake8_comprehensions.py | 57 ++++++++++++++++++++++- 4 files changed, 112 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 25b7c11..da54167 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,10 @@ Changelog ========= +* Add rule C420 to check for dict comprehensions with constant values, encouraging replacement with ``dict.fromkeys()``. + + Thanks to Tom Kuson in `PR #553 `__. + 3.14.0 (2023-07-10) ------------------- diff --git a/README.rst b/README.rst index 3b6cbbd..ba9d844 100644 --- a/README.rst +++ b/README.rst @@ -227,3 +227,13 @@ For example: * Rewrite ``all([condition(x) for x in iterable])`` as ``all(condition(x) for x in iterable)`` * Rewrite ``any([condition(x) for x in iterable])`` as ``any(condition(x) for x in iterable)`` + +C420: Unnecessary dict comprehension - rewrite using dict.fromkeys(). +---------------------------------------------------------------------- + +It's unnecessary to use a dict comprehension to build a dict with all values set to the same constant. +Use ``dict.fromkeys()`` instead, which is faster. +For example: + +* Rewrite ``{x: 1 for x in iterable}`` as ``dict.fromkeys(iterable, 1)`` +* Rewrite ``{x: None for x in iterable}`` as ``dict.fromkeys(iterable)`` diff --git a/src/flake8_comprehensions/__init__.py b/src/flake8_comprehensions/__init__.py index 3a48a38..dbe97a1 100644 --- a/src/flake8_comprehensions/__init__.py +++ b/src/flake8_comprehensions/__init__.py @@ -46,6 +46,9 @@ def __init__(self, tree: ast.AST) -> None: "C419 Unnecessary list comprehension passed to {func}() prevents " + "short-circuiting - rewrite as a generator." ), + "C420": ( + "C420 Unnecessary {type} comprehension - rewrite using dict.fromkeys()." + ), } def run(self) -> Generator[tuple[int, int, str, type[Any]], None, None]: @@ -335,32 +338,47 @@ def run(self) -> Generator[tuple[int, int, str, type[Any]], None, None]: len(node.generators) == 1 and not node.generators[0].ifs and not node.generators[0].is_async - and ( - ( - isinstance(node, (ast.ListComp, ast.SetComp)) - and isinstance(node.elt, ast.Name) - and isinstance(node.generators[0].target, ast.Name) - and node.elt.id == node.generators[0].target.id + ): + if ( + isinstance(node, (ast.ListComp, ast.SetComp)) + and isinstance(node.elt, ast.Name) + and isinstance(node.generators[0].target, ast.Name) + and node.elt.id == node.generators[0].target.id + ) or ( + isinstance(node, ast.DictComp) + and isinstance(node.key, ast.Name) + and isinstance(node.value, ast.Name) + and isinstance(node.generators[0].target, ast.Tuple) + and len(node.generators[0].target.elts) == 2 + and isinstance(node.generators[0].target.elts[0], ast.Name) + and node.generators[0].target.elts[0].id == node.key.id + and isinstance(node.generators[0].target.elts[1], ast.Name) + and node.generators[0].target.elts[1].id == node.value.id + ): + yield ( + node.lineno, + node.col_offset, + self.messages["C416"].format( + type=comp_type[node.__class__] + ), + type(self), ) - or ( - isinstance(node, ast.DictComp) - and isinstance(node.key, ast.Name) - and isinstance(node.value, ast.Name) - and isinstance(node.generators[0].target, ast.Tuple) - and len(node.generators[0].target.elts) == 2 - and isinstance(node.generators[0].target.elts[0], ast.Name) - and node.generators[0].target.elts[0].id == node.key.id - and isinstance(node.generators[0].target.elts[1], ast.Name) - and node.generators[0].target.elts[1].id == node.value.id + + elif ( + isinstance(node, ast.DictComp) + and isinstance(node.key, ast.Name) + and isinstance(node.value, ast.Constant) + and isinstance(node.generators[0].target, ast.Name) + and node.key.id == node.generators[0].target.id + ): + yield ( + node.lineno, + node.col_offset, + self.messages["C420"].format( + type=comp_type[node.__class__] + ), + type(self), ) - ) - ): - yield ( - node.lineno, - node.col_offset, - self.messages["C416"].format(type=comp_type[node.__class__]), - type(self), - ) def has_star_args(call_node: ast.Call) -> bool: diff --git a/tests/test_flake8_comprehensions.py b/tests/test_flake8_comprehensions.py index 454cc31..3843b6e 100644 --- a/tests/test_flake8_comprehensions.py +++ b/tests/test_flake8_comprehensions.py @@ -886,7 +886,7 @@ def test_C417_fail(code, failures, flake8_path): "code", [ "dict({}, a=1)", - "dict({x: 1 for x in range(1)}, a=1)", + "dict({x: [] for x in range(1)}, a=1)", ], ) def test_C418_pass(code, flake8_path): @@ -963,3 +963,58 @@ def test_C419_fail(code, failures, flake8_path): (flake8_path / "example.py").write_text(dedent(code)) result = flake8_path.run_flake8() assert result.out_lines == failures + + +@pytest.mark.parametrize( + "code", + [ + "{elt: elt * 2 for elt in range(5)}", + "{elt: [] for elt in foo}", + "{elt: {1, 2, 3} for elt in ['a', 'b', 'c']}", + "{elt: some_func() for elt in ['a', 'b', 'c']}", + "{elt: SomeClass() for elt in ['a', 'b', 'c']}", + ], +) +def test_C420_pass(code, flake8_path): + (flake8_path / "example.py").write_text(dedent(code)) + result = flake8_path.run_flake8() + assert result.out_lines == [] + + +@pytest.mark.parametrize( + "code,failures", + [ + ( + "{elt: None for elt in range(5)}", + [ + "./example.py:1:1: C420 Unnecessary dict comprehension - " + + "rewrite using dict.fromkeys()." + ], + ), + ( + "{elt: 1 for elt in foo}", + [ + "./example.py:1:1: C420 Unnecessary dict comprehension - " + + "rewrite using dict.fromkeys()." + ], + ), + ( + "{elt: 'value' for elt in ['a', 'b', 'c']}", + [ + "./example.py:1:1: C420 Unnecessary dict comprehension - " + + "rewrite using dict.fromkeys()." + ], + ), + ( + "{elt: True for elt in some_func()}", + [ + "./example.py:1:1: C420 Unnecessary dict comprehension - " + + "rewrite using dict.fromkeys()." + ], + ), + ], +) +def test_C420_fail(code, failures, flake8_path): + (flake8_path / "example.py").write_text(dedent(code)) + result = flake8_path.run_flake8() + assert result.out_lines == failures