Skip to content

Commit

Permalink
feat: allow configuring to exclude lines or files (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
frostming authored Dec 8, 2022
1 parent 1a1c8e6 commit b02b4ff
Show file tree
Hide file tree
Showing 10 changed files with 160 additions and 21 deletions.
25 changes: 21 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,14 @@ def foo() -> tuple[dict[str, int], str | None]:

Unused import names will be removed, and if `from __future__ import annotations` is not found in the script, it will be automatically added if the new syntax is being used.

## Use as a command line tool

```bash
python3 -m pip install -U fix-future-annotations

fix-future-annotations my_script.py
```

## Use as pre-commit hook

Add the following to your `.pre-commit-config.yaml`:
Expand All @@ -149,12 +157,21 @@ repos:
- id: fix-future-annotations
```
## Use as command line tool
## Configurations
```bash
python3 -m pip install -U fix-future-annotations
`fix-future-annotations` can be configured via `pyproject.toml`. Here is an example:

fix-future-annotations my_script.py
```toml
[tool.fix_future_annotations]
exclude_files = [ # regex patterns to exclude files
'tests/.*',
'docs/.*',
]
exclude_lines = [ # regex patterns to exclude lines
'# ffa: ignore', # if a line ends with this comment, the whole *block* will be excluded
'class .+\(BaseModel\):' # classes that inherit from `BaseModel` will be excluded
]
```

## License
Expand Down
4 changes: 4 additions & 0 deletions fix_future_annotations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from fix_future_annotations._main import fix_file


__all__ = ["fix_file"]
38 changes: 38 additions & 0 deletions fix_future_annotations/_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
import re
import sys

if sys.version_info >= (3, 11):
import tomllib
else:
import tomli as tomllib


@dataclass
class Config:
"""Configuration for fix_future_annotations."""

# The line patterns(regex) to exclude from the fix.
exclude_lines: list[str] = field(default_factory=list)
# The file patterns(regex) to exclude from the fix.
exclude_files: list[str] = field(default_factory=list)

@classmethod
def from_file(cls, path: str | Path = "pyproject.toml") -> Config:
"""Load the configuration from a file."""
try:
with open(path, "rb") as f:
data = tomllib.load(f)
except OSError:
return cls()
else:
return cls(**data.get("tool", {}).get("fix_future_annotations", {}))

def is_file_excluded(self, file_path: str) -> bool:
return any(re.search(pattern, file_path) for pattern in self.exclude_files)

def is_line_excluded(self, line: str) -> bool:
return any(re.search(pattern, line) for pattern in self.exclude_lines)
26 changes: 19 additions & 7 deletions fix_future_annotations/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,27 @@

from tokenize_rt import reversed_enumerate, src_to_tokens, tokens_to_src

from fix_future_annotations._config import Config
from fix_future_annotations._visitor import AnnotationVisitor


def _escaped(line: str) -> bool:
return (len(line) - len(line.rstrip("\\"))) % 2 == 1


def _iter_files(*paths: str) -> Iterator[str]:
def _iter_files(*paths: str, config: Config) -> Iterator[str]:
def files_under_dir(path: str) -> Iterator[str]:
for root, _, files in os.walk(path):
for filename in files:
if filename.endswith(".py"):
yield os.path.join(root, filename)
fn = os.path.join(root, filename).replace("\\", "/")
if not config.is_file_excluded(fn):
yield fn

for path in paths:
if os.path.isdir(path):
yield from files_under_dir(path)
elif path.endswith(".py"):
elif path.endswith(".py") and not config.is_file_excluded(path):
yield path


Expand Down Expand Up @@ -82,14 +85,20 @@ def _add_future_annotations(content: str) -> str:


def fix_file(
file_path: str | Path, write: bool = False, show_diff: bool = False
file_path: str | Path,
*,
write: bool = False,
show_diff: bool = False,
config: Config | None = None,
) -> bool:
"""Fix the file at file_path to use PEP 585, 604 and 563 syntax."""
if config is None:
config = Config.from_file()
file_path = Path(file_path)
file_content = file_path.read_text("utf-8")
tokens = src_to_tokens(file_content)
tree = ast.parse(file_content)
visitor = AnnotationVisitor()
visitor = AnnotationVisitor(file_content.splitlines(), config=config)
token_funcs = visitor.get_token_functions(tree)
for i, token in reversed_enumerate(tokens):
if not token.src:
Expand Down Expand Up @@ -137,9 +146,12 @@ def main(argv: list[str] | None = None) -> None:
args = parser.parse_args(argv)
diff_count = 0
checked = 0
for filename in _iter_files(*args.path):
config = Config.from_file()
for filename in _iter_files(*args.path, config=config):
checked += 1
result = fix_file(filename, args.write, show_diff=args.verbose)
result = fix_file(
filename, write=args.write, show_diff=args.verbose, config=config
)
diff_count += int(result)
if diff_count:
if args.write:
Expand Down
31 changes: 25 additions & 6 deletions fix_future_annotations/_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from tokenize_rt import NON_CODING_TOKENS, Offset, Token

from fix_future_annotations._config import Config
from fix_future_annotations._utils import (
ast_to_offset,
find_closing_bracket,
Expand Down Expand Up @@ -134,11 +135,17 @@ def _fix_union(i: int, tokens: list[Token], *, arg_count: int) -> None:
class State(NamedTuple):
in_annotation: bool
in_literal: bool
omit: bool

def update_annotation(self) -> bool:
return self.in_annotation and not self.omit


class AnnotationVisitor(ast.NodeVisitor):
def __init__(self) -> None:
def __init__(self, lines: list[str], *, config: Config) -> None:
super().__init__()
self.lines = lines
self.config = config
self.token_funcs: dict[Offset, list[TokenFunc]] = {}

self._typing_import_name: str | None = None
Expand All @@ -162,8 +169,12 @@ def add_conditional_token_func(
(condition, partial(self.add_token_func, offset, func))
)

def _is_excluded(self, node: ast.AST) -> bool:
line = self.lines[node.lineno - 1]
return self.config.is_line_excluded(line)

def get_token_functions(self, tree: ast.Module) -> dict[Offset, list[TokenFunc]]:
with self.under_state(State(False, False)):
with self.under_state(State(False, False, False)):
self.visit(tree)
for condition, callback in self._conditional_callbacks:
if condition():
Expand All @@ -188,6 +199,14 @@ def under_state(self, state: State) -> None:
finally:
self._state_stack.pop()

def visit(self, node: ast.AST) -> Any:
if isinstance(node, ast.stmt) and self._is_excluded(node):
ctx = self.under_state(self.state._replace(omit=True))
else:
ctx = contextlib.nullcontext()
with ctx:
return super().visit(node)

def generic_visit(self, node: ast.AST) -> Any:
for field in reversed(node._fields):
value = getattr(node, field)
Expand Down Expand Up @@ -244,7 +263,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
def visit_Attribute(self, node: ast.Attribute) -> Any:
"""Transform typing.List -> list"""
if (
self.state.in_annotation
self.state.update_annotation()
and isinstance(node.value, ast.Name)
and node.value.id == self._typing_import_name
and node.attr in BASIC_COLLECTION_TYPES
Expand All @@ -258,7 +277,7 @@ def visit_Attribute(self, node: ast.Attribute) -> Any:
def visit_Name(self, node: ast.Name) -> Any:
if node.id in self._typing_imports_to_remove:
name = self._typing_imports_to_remove[node.id]
if not self.state.in_annotation:
if not self.state.update_annotation():
# It is referred to outside of an annotation, so we need to exclude it
self._conditional_callbacks.insert(
0,
Expand All @@ -283,7 +302,7 @@ def visit_BinOp(self, node: ast.BinOp) -> Any:
return self.generic_visit(node)

def visit_Subscript(self, node: ast.Subscript) -> Any:
if not self.state.in_annotation:
if not self.state.update_annotation():
return self.generic_visit(node)
if isinstance(node.value, ast.Attribute):
if (
Expand Down Expand Up @@ -327,7 +346,7 @@ def visit_Subscript(self, node: ast.Subscript) -> Any:

def visit_Constant(self, node: ast.Constant) -> Any:
if (
self.state.in_annotation
self.state.update_annotation()
and not self.state.in_literal
and isinstance(node.value, str)
):
Expand Down
4 changes: 2 additions & 2 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ authors = [
]
dependencies = [
"tokenize-rt>=5.0.0",
"tomli; python_version < '3.11'",
]
requires-python = ">=3.8"
readme = "README.md"
Expand Down Expand Up @@ -49,3 +50,9 @@ exclude = '''
| tests/samples
)/
'''


[tool.fix_future_annotations]
exclude_lines = [
"# ffa: ignore"
]
19 changes: 19 additions & 0 deletions tests/samples/exclude_lines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import List, Optional, Tuple, Union


class NoFix:
def __init__(self, names: List[str]) -> None:
self.names = names

def lengh(self) -> Optional[int]:
if self.names:
return len(self.names)
return None


def foo() -> Union[str, int]: # ffa: ignore
return 42


def bar() -> Tuple[str, int]:
return "bar", 42
21 changes: 21 additions & 0 deletions tests/samples/exclude_lines_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import annotations

from typing import List, Optional, Union


class NoFix:
def __init__(self, names: List[str]) -> None:
self.names = names

def lengh(self) -> Optional[int]:
if self.names:
return len(self.names)
return None


def foo() -> Union[str, int]: # ffa: ignore
return 42


def bar() -> tuple[str, int]:
return "bar", 42
6 changes: 4 additions & 2 deletions tests/test_fix_future_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from fix_future_annotations._main import fix_file
from fix_future_annotations._config import Config

SAMPLES = Path(__file__).with_name("samples")

Expand All @@ -19,9 +20,10 @@ def _load_samples() -> list:
@pytest.mark.parametrize("origin, fixed", _load_samples())
def test_fix_samples(origin: Path, fixed: Path, tmp_path: Path) -> None:
copied = shutil.copy2(origin, tmp_path)
result = fix_file(copied, True)
config = Config(exclude_lines=["# ffa: ignore", "class NoFix:"])
result = fix_file(copied, write=True, config=config)

assert fixed.read_text() == Path(copied).read_text()

result = fix_file(copied, False)
result = fix_file(copied, write=False, config=config)
assert not result

0 comments on commit b02b4ff

Please sign in to comment.