From e755a83502b746b68b74a6a006f5244199409dc2 Mon Sep 17 00:00:00 2001 From: Giorgi Megreli Date: Fri, 19 Nov 2021 11:57:44 +0000 Subject: [PATCH] Correct and simplify logic of recording assignments --- libcst/metadata/scope_provider.py | 44 +++++++++++++++---------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/libcst/metadata/scope_provider.py b/libcst/metadata/scope_provider.py index e3590f0ca..ae829bc25 100644 --- a/libcst/metadata/scope_provider.py +++ b/libcst/metadata/scope_provider.py @@ -418,23 +418,33 @@ def __init__(self, parent: "Scope") -> None: self._assignment_count = 0 def record_assignment(self, name: str, node: cst.CSTNode) -> None: - self._add_assignment( - Assignment(name=name, scope=self, node=node, index=self._assignment_count) + target = self._find_assignment_target(name) + target._assignments[name].add( + Assignment( + name=name, scope=target, node=node, index=target._assignment_count + ) ) def record_import_assignment( self, name: str, node: cst.CSTNode, as_name: cst.CSTNode ) -> None: - self._add_assignment( + target = self._find_assignment_target(name) + target._assignments[name].add( ImportAssignment( name=name, - scope=self, + scope=target, node=node, as_name=as_name, - index=self._assignment_count, + index=target._assignment_count, ) ) + def _find_assignment_target(self, name: str) -> "Scope": + return self + + def _find_assignment_target_parent(self, name: str) -> "Scope": + return self + def record_access(self, name: str, access: Access) -> None: self._accesses[name].add(access) @@ -446,14 +456,6 @@ def _contains_in_self_or_parent(self, name: str) -> bool: """Overridden by ClassScope to hide it's assignments from child scopes.""" return name in self - def _add_assignment(self, assignment: "BaseAssignment") -> None: - assignment.scope = self - self._assignments[assignment.name].add(assignment) - - def _add_assignment_as_parent(self, assignment: "BaseAssignment") -> None: - """Overridden by ClassScope to forward 'nonlocal' assignments from child scopes.""" - self._add_assignment(assignment) - @abc.abstractmethod def __contains__(self, name: str) -> bool: """Check if the name str exist in current scope by ``name in scope``.""" @@ -615,7 +617,7 @@ def record_global_overwrite(self, name: str) -> None: def record_nonlocal_overwrite(self, name: str) -> None: raise NotImplementedError("declarations in builtin scope are not allowed") - def _add_assignment(self, assignment: "BaseAssignment") -> None: + def _find_assignment_target(self, name: str) -> "Scope": raise NotImplementedError("assignments in builtin scope are not allowed") @@ -668,13 +670,11 @@ def record_global_overwrite(self, name: str) -> None: def record_nonlocal_overwrite(self, name: str) -> None: self._scope_overwrites[name] = self.parent - def _add_assignment(self, assignment: "BaseAssignment") -> None: - if assignment.name in self._scope_overwrites: - self._scope_overwrites[assignment.name]._add_assignment_as_parent( - assignment - ) + def _find_assignment_target(self, name: str) -> "Scope": + if name in self._scope_overwrites: + return self._scope_overwrites[name]._find_assignment_target_parent(name) else: - super()._add_assignment(assignment) + return super()._find_assignment_target(name) def __contains__(self, name: str) -> bool: if name in self._scope_overwrites: @@ -707,7 +707,7 @@ class ClassScope(LocalScope): When a class is defined, it creates a ClassScope. """ - def _add_assignment_as_parent(self, assignment: "BaseAssignment") -> None: + def _find_assignment_target_parent(self, name: str) -> "Scope": """ Forward the assignment to parent. @@ -722,7 +722,7 @@ def inner_fn(): # hidden from its children. """ - self.parent._add_assignment_as_parent(assignment) + return self.parent._find_assignment_target_parent(name) def _getitem_from_self_or_parent(self, name: str) -> Set[BaseAssignment]: """