Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass 'parents' as a parameter when walking the tree #224

Merged
merged 1 commit into from
Oct 22, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 22 additions & 23 deletions src/pep8ext_naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import ast
from ast import iter_child_nodes
from collections import deque
from collections.abc import Iterable
from fnmatch import fnmatchcase
from functools import partial
from itertools import chain
Expand Down Expand Up @@ -107,8 +108,7 @@ class NamingChecker:
ignore_names = frozenset(_default_ignore_names)

def __init__(self, tree, filename):
self.parents = deque()
self._node = tree
self.tree = tree

@classmethod
def add_options(cls, parser):
Expand Down Expand Up @@ -157,23 +157,22 @@ def parse_options(cls, options):
)

def run(self):
return self.visit_tree(self._node) if self._node else ()
return self.visit_tree(self.tree, deque()) if self.tree else ()

def visit_tree(self, node):
yield from self.visit_node(node)
self.parents.append(node)
def visit_tree(self, node, parents: deque):
yield from self.visit_node(node, parents)
parents.append(node)
for child in iter_child_nodes(node):
yield from self.visit_tree(child)
self.parents.pop()
yield from self.visit_tree(child, parents)
parents.pop()

def visit_node(self, node):
def visit_node(self, node, parents: Iterable):
if isinstance(node, ast.ClassDef):
self.tag_class_functions(node)
elif isinstance(node, FUNC_NODES):
self.find_global_defs(node)

method = 'visit_' + node.__class__.__name__.lower()
parents = self.parents
ignore_names = self.ignore_names
for visitor in self.visitors:
visitor_method = getattr(visitor, method, None)
Expand Down Expand Up @@ -263,14 +262,14 @@ class ClassNameCheck(BaseASTCheck):
N818 = "exception name '{name}' should be named with an Error suffix"

@classmethod
def get_classdef(cls, name, parents):
def get_classdef(cls, name, parents: Iterable):
for parent in parents:
for node in parent.body:
if isinstance(node, ast.ClassDef) and node.name == name:
return node

@classmethod
def superclass_names(cls, name, parents, _names=None):
def superclass_names(cls, name, parents: Iterable, _names=None):
names = _names or set()
classdef = cls.get_classdef(name, parents)
if not classdef:
Expand All @@ -281,7 +280,7 @@ def superclass_names(cls, name, parents, _names=None):
names.update(cls.superclass_names(base.id, parents, names))
return names

def visit_classdef(self, node, parents, ignore=None):
def visit_classdef(self, node, parents: Iterable, ignore=None):
name = node.name
if _ignored(name, ignore):
return
Expand Down Expand Up @@ -316,7 +315,7 @@ def has_override_decorator(node):
return True
return False

def visit_functiondef(self, node, parents, ignore=None):
def visit_functiondef(self, node, parents: Iterable, ignore=None):
function_type = getattr(node, 'function_type', _FunctionType.FUNCTION)
name = node.name
if _ignored(name, ignore):
Expand Down Expand Up @@ -347,7 +346,7 @@ class FunctionArgNamesCheck(BaseASTCheck):
N804 = "first argument of a classmethod should be named 'cls'"
N805 = "first argument of a method should be named 'self'"

def visit_functiondef(self, node, parents, ignore=None):
def visit_functiondef(self, node, parents: Iterable, ignore=None):

def arg_name(arg):
return (arg, arg.arg) if arg else (node, arg)
Expand Down Expand Up @@ -389,7 +388,7 @@ class ImportAsCheck(BaseASTCheck):
N814 = "camelcase '{name}' imported as constant '{asname}'"
N817 = "camelcase '{name}' imported as acronym '{asname}'"

def visit_importfrom(self, node, parents, ignore=None):
def visit_importfrom(self, node, parents: Iterable, ignore=None):
for name in node.names:
asname = name.asname
if not asname:
Expand Down Expand Up @@ -421,7 +420,7 @@ class VariablesCheck(BaseASTCheck):
N815 = "variable '{name}' in class scope should not be mixedCase"
N816 = "variable '{name}' in global scope should not be mixedCase"

def _find_errors(self, assignment_target, parents, ignore):
def _find_errors(self, assignment_target, parents: Iterable, ignore):
for parent_func in reversed(parents):
if isinstance(parent_func, ast.ClassDef):
checker = self.class_variable_check
Expand Down Expand Up @@ -449,36 +448,36 @@ def is_namedtupe(node_value):
return True
return False

def visit_assign(self, node, parents, ignore=None):
def visit_assign(self, node, parents: Iterable, ignore=None):
if self.is_namedtupe(node.value):
return
for target in node.targets:
yield from self._find_errors(target, parents, ignore)

def visit_namedexpr(self, node, parents, ignore):
def visit_namedexpr(self, node, parents: Iterable, ignore):
if self.is_namedtupe(node.value):
return
yield from self._find_errors(node.target, parents, ignore)

visit_annassign = visit_namedexpr

def visit_with(self, node, parents, ignore):
def visit_with(self, node, parents: Iterable, ignore):
for item in node.items:
yield from self._find_errors(
item.optional_vars, parents, ignore)

visit_asyncwith = visit_with

def visit_for(self, node, parents, ignore):
def visit_for(self, node, parents: Iterable, ignore):
yield from self._find_errors(node.target, parents, ignore)

visit_asyncfor = visit_for

def visit_excepthandler(self, node, parents, ignore):
def visit_excepthandler(self, node, parents: Iterable, ignore):
if node.name:
yield from self._find_errors(node, parents, ignore)

def visit_generatorexp(self, node, parents, ignore):
def visit_generatorexp(self, node, parents: Iterable, ignore):
for gen in node.generators:
yield from self._find_errors(gen.target, parents, ignore)

Expand Down
Loading