Skip to content

Commit

Permalink
[scope] keep track of assignment/access ordering (#413)
Browse files Browse the repository at this point in the history
  • Loading branch information
zsol authored Nov 17, 2020
1 parent 90df5a6 commit 2ef7302
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 9 deletions.
84 changes: 75 additions & 9 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,25 @@
)


_ASSIGNMENT_LIKE_NODES = (
cst.AnnAssign,
cst.AsName,
cst.Assign,
cst.AugAssign,
cst.ClassDef,
cst.CompFor,
cst.For,
cst.FunctionDef,
cst.Global,
cst.Import,
cst.ImportFrom,
cst.NamedExpr,
cst.Nonlocal,
cst.Parameters,
cst.WithItem,
)


@add_slots
@dataclass(frozen=False)
class Access:
Expand Down Expand Up @@ -68,6 +87,7 @@ def __new__(cls) -> "Tree":
is_type_hint: bool

__assignments: Set["BaseAssignment"]
__index: int

def __init__(
self, node: cst.Name, scope: "Scope", is_annotation: bool, is_type_hint: bool
Expand All @@ -77,6 +97,7 @@ def __init__(
self.is_annotation = is_annotation
self.is_type_hint = is_type_hint
self.__assignments = set()
self.__index = scope._assignment_count

def __hash__(self) -> int:
return id(self)
Expand All @@ -86,11 +107,25 @@ def referents(self) -> Collection["BaseAssignment"]:
"""Return all assignments of the access."""
return self.__assignments

def record_assignment(self, assignment: "BaseAssignment") -> None:
self.__assignments.add(assignment)
@property
def _index(self) -> int:
return self.__index

def record_assignments(self, assignments: Set["BaseAssignment"]) -> None:
self.__assignments |= assignments
def record_assignment(self, assignment: "BaseAssignment") -> None:
if assignment.scope != self.scope or assignment._index < self.__index:
self.__assignments.add(assignment)

def record_assignments(self, name: str) -> None:
assignments = self.scope[name]
# filter out assignments that happened later than this access
previous_assignments = {
assignment
for assignment in assignments
if assignment.scope != self.scope or assignment._index < self.__index
}
if not previous_assignments and assignments:
previous_assignments = self.scope.parent[name]
self.__assignments |= previous_assignments


class BaseAssignment(abc.ABC):
Expand All @@ -109,10 +144,22 @@ def __init__(self, name: str, scope: "Scope") -> None:
self.__accesses = set()

def record_access(self, access: Access) -> None:
self.__accesses.add(access)
if access.scope != self.scope or self._index < access._index:
self.__accesses.add(access)

def record_accesses(self, accesses: Set[Access]) -> None:
self.__accesses |= accesses
later_accesses = {
access
for access in accesses
if access.scope != self.scope or self._index < access._index
}
self.__accesses |= later_accesses
earlier_accesses = accesses - later_accesses
if earlier_accesses and self.scope.parent != self.scope:
# Accesses "earlier" than the relevant assignment should be attached
# to assignments of the same name in the parent
for shadowed_assignment in self.scope.parent[self.name]:
shadowed_assignment.record_accesses(earlier_accesses)

@property
def references(self) -> Collection[Access]:
Expand All @@ -123,18 +170,31 @@ def references(self) -> Collection[Access]:
def __hash__(self) -> int:
return id(self)

@property
def _index(self) -> int:
"""Return an integer that represents the order of assignments in `scope`"""
return -1


class Assignment(BaseAssignment):
"""An assignment records the name, CSTNode and its accesses."""

#: The node of assignment, it could be a :class:`~libcst.Import`, :class:`~libcst.ImportFrom`,
#: :class:`~libcst.Name`, :class:`~libcst.FunctionDef`, or :class:`~libcst.ClassDef`.
node: cst.CSTNode
__index: int

def __init__(self, name: str, scope: "Scope", node: cst.CSTNode) -> None:
def __init__(
self, name: str, scope: "Scope", node: cst.CSTNode, index: int
) -> None:
self.node = node
self.__index = index
super().__init__(name, scope)

@property
def _index(self) -> int:
return self.__index


# even though we don't override the constructor.
class BuiltinAssignment(BaseAssignment):
Expand Down Expand Up @@ -318,16 +378,20 @@ class Scope(abc.ABC):
globals: "GlobalScope"
_assignments: MutableMapping[str, Set[BaseAssignment]]
_accesses: MutableMapping[str, Set[Access]]
_assignment_count: int

def __init__(self, parent: "Scope") -> None:
super().__init__()
self.parent = parent
self.globals = parent.globals
self._assignments = defaultdict(set)
self._accesses = defaultdict(set)
self._assignment_count = 0

def record_assignment(self, name: str, node: cst.CSTNode) -> None:
self._assignments[name].add(Assignment(name=name, scope=self, node=node))
self._assignments[name].add(
Assignment(name=name, scope=self, node=node, index=self._assignment_count)
)

def record_access(self, name: str, access: Access) -> None:
self._accesses[name].add(access)
Expand Down Expand Up @@ -934,7 +998,7 @@ def infer_accesses(self) -> None:
break

scope_name_accesses[(access.scope, name)].add(access)
access.record_assignments(access.scope[name])
access.record_assignments(name)
access.scope.record_access(name, access)

for (scope, name), accesses in scope_name_accesses.items():
Expand All @@ -945,6 +1009,8 @@ def infer_accesses(self) -> None:

def on_leave(self, original_node: cst.CSTNode) -> None:
self.provider.set_metadata(original_node, self.scope)
if isinstance(original_node, _ASSIGNMENT_LIKE_NODES):
self.scope._assignment_count += 1
super().on_leave(original_node)


Expand Down
102 changes: 102 additions & 0 deletions libcst/metadata/tests/test_scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1329,3 +1329,105 @@ def test_gen_dotted_names(self) -> None:
)
}
self.assertEqual(names, {"a.b.c", "a.b", "a"})

def test_ordering(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
from a import b
class X:
x = b
b = b
y = b
"""
)
global_scope = scopes[m]
import_stmt = ensure_type(
ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.ImportFrom
)
first_assignment = list(global_scope.assignments)[0]
assert isinstance(first_assignment, cst.metadata.Assignment)
self.assertEqual(first_assignment.node, import_stmt)
global_refs = list(first_assignment.references)
self.assertEqual(len(global_refs), 2)
class_def = ensure_type(m.body[1], cst.ClassDef)
x = ensure_type(
ensure_type(class_def.body.body[0], cst.SimpleStatementLine).body[0],
cst.Assign,
)
self.assertEqual(x.value, global_refs[0].node)
class_b = ensure_type(
ensure_type(class_def.body.body[1], cst.SimpleStatementLine).body[0],
cst.Assign,
)
self.assertEqual(class_b.value, global_refs[1].node)

class_accesses = list(scopes[x].accesses)
self.assertEqual(len(class_accesses), 3)
self.assertIn(
class_b.targets[0].target,
[
ref.node
for acc in class_accesses
for ref in acc.referents
if isinstance(ref, Assignment)
],
)
y = ensure_type(
ensure_type(class_def.body.body[2], cst.SimpleStatementLine).body[0],
cst.Assign,
)
self.assertIn(y.value, [access.node for access in class_accesses])

def test_ordering_between_scopes(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
def f(a):
print(a)
print(b)
a = 1
b = 1
"""
)
f = cst.ensure_type(m.body[0], cst.FunctionDef)
a_param = f.params.params[0].name
a_param_assignment = list(scopes[a_param]["a"])[0]
a_param_refs = list(a_param_assignment.references)
first_print = cst.ensure_type(
cst.ensure_type(
cst.ensure_type(f.body.body[0], cst.SimpleStatementLine).body[0],
cst.Expr,
).value,
cst.Call,
)
second_print = cst.ensure_type(
cst.ensure_type(
cst.ensure_type(f.body.body[1], cst.SimpleStatementLine).body[0],
cst.Expr,
).value,
cst.Call,
)
self.assertEqual(
first_print.args[0].value,
a_param_refs[0].node,
)
a_global = (
cst.ensure_type(
cst.ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Assign
)
.targets[0]
.target
)
a_global_assignment = list(scopes[a_global]["a"])[0]
a_global_refs = list(a_global_assignment.references)
self.assertEqual(a_global_refs, [])
b_global = (
cst.ensure_type(
cst.ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Assign
)
.targets[0]
.target
)
b_global_assignment = list(scopes[b_global]["b"])[0]
b_global_refs = list(b_global_assignment.references)
self.assertEqual(len(b_global_refs), 1)
self.assertEqual(b_global_refs[0].node, second_print.args[0].value)

0 comments on commit 2ef7302

Please sign in to comment.