Skip to content

Commit

Permalink
Correct and simplify logic of recording assignments (#556)
Browse files Browse the repository at this point in the history
  • Loading branch information
giomeg authored Nov 19, 2021
1 parent ae8d0cd commit 9732f5e
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,23 +418,33 @@ def __init__(self, parent: "Scope") -> None:
self._assignment_count = 0

def record_assignment(self, name: str, node: cst.CSTNode) -> None:
self._add_assignment(
Assignment(name=name, scope=self, node=node, index=self._assignment_count)
target = self._find_assignment_target(name)
target._assignments[name].add(
Assignment(
name=name, scope=target, node=node, index=target._assignment_count
)
)

def record_import_assignment(
self, name: str, node: cst.CSTNode, as_name: cst.CSTNode
) -> None:
self._add_assignment(
target = self._find_assignment_target(name)
target._assignments[name].add(
ImportAssignment(
name=name,
scope=self,
scope=target,
node=node,
as_name=as_name,
index=self._assignment_count,
index=target._assignment_count,
)
)

def _find_assignment_target(self, name: str) -> "Scope":
return self

def _find_assignment_target_parent(self, name: str) -> "Scope":
return self

def record_access(self, name: str, access: Access) -> None:
self._accesses[name].add(access)

Expand All @@ -446,14 +456,6 @@ def _contains_in_self_or_parent(self, name: str) -> bool:
"""Overridden by ClassScope to hide it's assignments from child scopes."""
return name in self

def _add_assignment(self, assignment: "BaseAssignment") -> None:
assignment.scope = self
self._assignments[assignment.name].add(assignment)

def _add_assignment_as_parent(self, assignment: "BaseAssignment") -> None:
"""Overridden by ClassScope to forward 'nonlocal' assignments from child scopes."""
self._add_assignment(assignment)

@abc.abstractmethod
def __contains__(self, name: str) -> bool:
"""Check if the name str exist in current scope by ``name in scope``."""
Expand Down Expand Up @@ -615,7 +617,7 @@ def record_global_overwrite(self, name: str) -> None:
def record_nonlocal_overwrite(self, name: str) -> None:
raise NotImplementedError("declarations in builtin scope are not allowed")

def _add_assignment(self, assignment: "BaseAssignment") -> None:
def _find_assignment_target(self, name: str) -> "Scope":
raise NotImplementedError("assignments in builtin scope are not allowed")


Expand Down Expand Up @@ -668,13 +670,11 @@ def record_global_overwrite(self, name: str) -> None:
def record_nonlocal_overwrite(self, name: str) -> None:
self._scope_overwrites[name] = self.parent

def _add_assignment(self, assignment: "BaseAssignment") -> None:
if assignment.name in self._scope_overwrites:
self._scope_overwrites[assignment.name]._add_assignment_as_parent(
assignment
)
def _find_assignment_target(self, name: str) -> "Scope":
if name in self._scope_overwrites:
return self._scope_overwrites[name]._find_assignment_target_parent(name)
else:
super()._add_assignment(assignment)
return super()._find_assignment_target(name)

def __contains__(self, name: str) -> bool:
if name in self._scope_overwrites:
Expand Down Expand Up @@ -707,7 +707,7 @@ class ClassScope(LocalScope):
When a class is defined, it creates a ClassScope.
"""

def _add_assignment_as_parent(self, assignment: "BaseAssignment") -> None:
def _find_assignment_target_parent(self, name: str) -> "Scope":
"""
Forward the assignment to parent.
Expand All @@ -722,7 +722,7 @@ def inner_fn():
# hidden from its children.
"""
self.parent._add_assignment_as_parent(assignment)
return self.parent._find_assignment_target_parent(name)

def _getitem_from_self_or_parent(self, name: str) -> Set[BaseAssignment]:
"""
Expand Down

0 comments on commit 9732f5e

Please sign in to comment.