diff --git a/libcst/metadata/scope_provider.py b/libcst/metadata/scope_provider.py index 39fb33c57..7886f458d 100644 --- a/libcst/metadata/scope_provider.py +++ b/libcst/metadata/scope_provider.py @@ -36,6 +36,25 @@ ) +_ASSIGNMENT_LIKE_NODES = ( + cst.AnnAssign, + cst.AsName, + cst.Assign, + cst.AugAssign, + cst.ClassDef, + cst.CompFor, + cst.For, + cst.FunctionDef, + cst.Global, + cst.Import, + cst.ImportFrom, + cst.NamedExpr, + cst.Nonlocal, + cst.Parameters, + cst.WithItem, +) + + @add_slots @dataclass(frozen=False) class Access: @@ -68,6 +87,7 @@ def __new__(cls) -> "Tree": is_type_hint: bool __assignments: Set["BaseAssignment"] + __index: int def __init__( self, node: cst.Name, scope: "Scope", is_annotation: bool, is_type_hint: bool @@ -77,6 +97,7 @@ def __init__( self.is_annotation = is_annotation self.is_type_hint = is_type_hint self.__assignments = set() + self.__index = scope._assignment_count def __hash__(self) -> int: return id(self) @@ -86,11 +107,25 @@ def referents(self) -> Collection["BaseAssignment"]: """Return all assignments of the access.""" return self.__assignments - def record_assignment(self, assignment: "BaseAssignment") -> None: - self.__assignments.add(assignment) + @property + def _index(self) -> int: + return self.__index - def record_assignments(self, assignments: Set["BaseAssignment"]) -> None: - self.__assignments |= assignments + def record_assignment(self, assignment: "BaseAssignment") -> None: + if assignment.scope != self.scope or assignment._index < self.__index: + self.__assignments.add(assignment) + + def record_assignments(self, name: str) -> None: + assignments = self.scope[name] + # filter out assignments that happened later than this access + previous_assignments = { + assignment + for assignment in assignments + if assignment.scope != self.scope or assignment._index < self.__index + } + if not previous_assignments and assignments: + previous_assignments = self.scope.parent[name] + self.__assignments |= previous_assignments class BaseAssignment(abc.ABC): @@ -109,10 +144,22 @@ def __init__(self, name: str, scope: "Scope") -> None: self.__accesses = set() def record_access(self, access: Access) -> None: - self.__accesses.add(access) + if access.scope != self.scope or self._index < access._index: + self.__accesses.add(access) def record_accesses(self, accesses: Set[Access]) -> None: - self.__accesses |= accesses + later_accesses = { + access + for access in accesses + if access.scope != self.scope or self._index < access._index + } + self.__accesses |= later_accesses + earlier_accesses = accesses - later_accesses + if earlier_accesses and self.scope.parent != self.scope: + # Accesses "earlier" than the relevant assignment should be attached + # to assignments of the same name in the parent + for shadowed_assignment in self.scope.parent[self.name]: + shadowed_assignment.record_accesses(earlier_accesses) @property def references(self) -> Collection[Access]: @@ -123,6 +170,11 @@ def references(self) -> Collection[Access]: def __hash__(self) -> int: return id(self) + @property + def _index(self) -> int: + """Return an integer that represents the order of assignments in `scope`""" + return -1 + class Assignment(BaseAssignment): """An assignment records the name, CSTNode and its accesses.""" @@ -130,11 +182,19 @@ class Assignment(BaseAssignment): #: The node of assignment, it could be a :class:`~libcst.Import`, :class:`~libcst.ImportFrom`, #: :class:`~libcst.Name`, :class:`~libcst.FunctionDef`, or :class:`~libcst.ClassDef`. node: cst.CSTNode + __index: int - def __init__(self, name: str, scope: "Scope", node: cst.CSTNode) -> None: + def __init__( + self, name: str, scope: "Scope", node: cst.CSTNode, index: int + ) -> None: self.node = node + self.__index = index super().__init__(name, scope) + @property + def _index(self) -> int: + return self.__index + # even though we don't override the constructor. class BuiltinAssignment(BaseAssignment): @@ -318,6 +378,7 @@ class Scope(abc.ABC): globals: "GlobalScope" _assignments: MutableMapping[str, Set[BaseAssignment]] _accesses: MutableMapping[str, Set[Access]] + _assignment_count: int def __init__(self, parent: "Scope") -> None: super().__init__() @@ -325,9 +386,12 @@ def __init__(self, parent: "Scope") -> None: self.globals = parent.globals self._assignments = defaultdict(set) self._accesses = defaultdict(set) + self._assignment_count = 0 def record_assignment(self, name: str, node: cst.CSTNode) -> None: - self._assignments[name].add(Assignment(name=name, scope=self, node=node)) + self._assignments[name].add( + Assignment(name=name, scope=self, node=node, index=self._assignment_count) + ) def record_access(self, name: str, access: Access) -> None: self._accesses[name].add(access) @@ -934,7 +998,7 @@ def infer_accesses(self) -> None: break scope_name_accesses[(access.scope, name)].add(access) - access.record_assignments(access.scope[name]) + access.record_assignments(name) access.scope.record_access(name, access) for (scope, name), accesses in scope_name_accesses.items(): @@ -945,6 +1009,8 @@ def infer_accesses(self) -> None: def on_leave(self, original_node: cst.CSTNode) -> None: self.provider.set_metadata(original_node, self.scope) + if isinstance(original_node, _ASSIGNMENT_LIKE_NODES): + self.scope._assignment_count += 1 super().on_leave(original_node) diff --git a/libcst/metadata/tests/test_scope_provider.py b/libcst/metadata/tests/test_scope_provider.py index f04efa05b..e54bbff99 100644 --- a/libcst/metadata/tests/test_scope_provider.py +++ b/libcst/metadata/tests/test_scope_provider.py @@ -1329,3 +1329,105 @@ def test_gen_dotted_names(self) -> None: ) } self.assertEqual(names, {"a.b.c", "a.b", "a"}) + + def test_ordering(self) -> None: + m, scopes = get_scope_metadata_provider( + """ + from a import b + class X: + x = b + b = b + y = b + """ + ) + global_scope = scopes[m] + import_stmt = ensure_type( + ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.ImportFrom + ) + first_assignment = list(global_scope.assignments)[0] + assert isinstance(first_assignment, cst.metadata.Assignment) + self.assertEqual(first_assignment.node, import_stmt) + global_refs = list(first_assignment.references) + self.assertEqual(len(global_refs), 2) + class_def = ensure_type(m.body[1], cst.ClassDef) + x = ensure_type( + ensure_type(class_def.body.body[0], cst.SimpleStatementLine).body[0], + cst.Assign, + ) + self.assertEqual(x.value, global_refs[0].node) + class_b = ensure_type( + ensure_type(class_def.body.body[1], cst.SimpleStatementLine).body[0], + cst.Assign, + ) + self.assertEqual(class_b.value, global_refs[1].node) + + class_accesses = list(scopes[x].accesses) + self.assertEqual(len(class_accesses), 3) + self.assertIn( + class_b.targets[0].target, + [ + ref.node + for acc in class_accesses + for ref in acc.referents + if isinstance(ref, Assignment) + ], + ) + y = ensure_type( + ensure_type(class_def.body.body[2], cst.SimpleStatementLine).body[0], + cst.Assign, + ) + self.assertIn(y.value, [access.node for access in class_accesses]) + + def test_ordering_between_scopes(self) -> None: + m, scopes = get_scope_metadata_provider( + """ + def f(a): + print(a) + print(b) + a = 1 + b = 1 + """ + ) + f = cst.ensure_type(m.body[0], cst.FunctionDef) + a_param = f.params.params[0].name + a_param_assignment = list(scopes[a_param]["a"])[0] + a_param_refs = list(a_param_assignment.references) + first_print = cst.ensure_type( + cst.ensure_type( + cst.ensure_type(f.body.body[0], cst.SimpleStatementLine).body[0], + cst.Expr, + ).value, + cst.Call, + ) + second_print = cst.ensure_type( + cst.ensure_type( + cst.ensure_type(f.body.body[1], cst.SimpleStatementLine).body[0], + cst.Expr, + ).value, + cst.Call, + ) + self.assertEqual( + first_print.args[0].value, + a_param_refs[0].node, + ) + a_global = ( + cst.ensure_type( + cst.ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Assign + ) + .targets[0] + .target + ) + a_global_assignment = list(scopes[a_global]["a"])[0] + a_global_refs = list(a_global_assignment.references) + self.assertEqual(a_global_refs, []) + b_global = ( + cst.ensure_type( + cst.ensure_type(m.body[2], cst.SimpleStatementLine).body[0], cst.Assign + ) + .targets[0] + .target + ) + b_global_assignment = list(scopes[b_global]["b"])[0] + b_global_refs = list(b_global_assignment.references) + self.assertEqual(len(b_global_refs), 1) + self.assertEqual(b_global_refs[0].node, second_print.args[0].value)