diff --git a/docs/source/metadata.rst b/docs/source/metadata.rst index 5f0df2801..f6c9c0784 100644 --- a/docs/source/metadata.rst +++ b/docs/source/metadata.rst @@ -141,13 +141,26 @@ There are four different type of scope in Python: :class:`~libcst.metadata.ComprehensionScope`. .. image:: _static/img/python_scopes.png - :alt: LibCST + :alt: Diagram showing how the above 4 scopes are nested in each other :width: 400 :align: center LibCST allows you to inspect these scopes to see what local variables are assigned or accessed within. +.. note:: + Import statements bring new symbols into scope that are declared in other files. + As such, they are represented by :class:`~libcst.metadata.Assignment` for scope + analysis purposes. Dotted imports (e.g. ``import a.b.c``) generate multiple + :class:`~libcst.metadata.Assignment` objects — one for each module. When analyzing + references, only the most specific access is recorded. + + For example, the above ``import a.b.c`` statement generates three + :class:`~libcst.metadata.Assignment` objects: one for ``a``, one for ``a.b``, and + one for ``a.b.c``. A reference for ``a.b.c`` records an access only for the last + assignment, while a reference for ``a.d`` only records an access for the + :class:`~libcst.metadata.Assignment` representing ``a``. + .. autoclass:: libcst.metadata.ScopeProvider :no-undoc-members: diff --git a/libcst/codemod/commands/tests/test_remove_unused_imports.py b/libcst/codemod/commands/tests/test_remove_unused_imports.py index 08204358e..8c500e0bc 100644 --- a/libcst/codemod/commands/tests/test_remove_unused_imports.py +++ b/libcst/codemod/commands/tests/test_remove_unused_imports.py @@ -48,3 +48,32 @@ def test_type_annotations(self) -> None: x: a = 1 """ self.assertCodemod(before, before) + + def test_dotted_imports(self) -> None: + before = """ + import a.b, a.b.c + import e.f + import g.h + import x.y, x.y.z + + def foo() -> None: + a.b + e.g + g.h.i + x.y.z + """ + + after = """ + import a.b, a.b.c + import e.f + import g.h + import x.y.z + + def foo() -> None: + a.b + e.g + g.h.i + x.y.z + """ + + self.assertCodemod(before, after) diff --git a/libcst/codemod/visitors/_remove_imports.py b/libcst/codemod/visitors/_remove_imports.py index 0023bd227..841812c4c 100644 --- a/libcst/codemod/visitors/_remove_imports.py +++ b/libcst/codemod/visitors/_remove_imports.py @@ -12,6 +12,7 @@ from libcst.codemod.visitors._gather_exports import GatherExportsVisitor from libcst.helpers import get_absolute_module_for_import, get_full_name_for_node from libcst.metadata import Assignment, Scope, ScopeProvider +from libcst.metadata.scope_provider import _gen_dotted_names class RemovedNodeVisitor(ContextAwareVisitor): @@ -295,24 +296,21 @@ def visit_Module(self, node: cst.Module) -> None: def _is_in_use(self, scope: Scope, alias: cst.ImportAlias) -> bool: # Grab the string name of this alias from the point of view of this module. asname = alias.asname - if asname is not None: - name_node = asname.name - else: - name_node = alias.name - while isinstance(name_node, cst.Attribute): - name_node = name_node.value - name_or_alias = cst.ensure_type(name_node, cst.Name).value - - if name_or_alias in self.exported_objects: - return True + names = _gen_dotted_names( + cst.ensure_type(asname.name, cst.Name) if asname is not None else alias.name + ) - for assignment in scope[name_or_alias]: - if ( - isinstance(assignment, Assignment) - and isinstance(assignment.node, (cst.ImportFrom, cst.Import)) - and len(assignment.references) > 0 - ): + for name_or_alias, _ in names: + if name_or_alias in self.exported_objects: return True + + for assignment in scope[name_or_alias]: + if ( + isinstance(assignment, Assignment) + and isinstance(assignment.node, (cst.ImportFrom, cst.Import)) + and len(assignment.references) > 0 + ): + return True return False def leave_Import( diff --git a/libcst/codemod/visitors/tests/test_remove_imports.py b/libcst/codemod/visitors/tests/test_remove_imports.py index ec8e460c4..76c751c66 100644 --- a/libcst/codemod/visitors/tests/test_remove_imports.py +++ b/libcst/codemod/visitors/tests/test_remove_imports.py @@ -387,21 +387,25 @@ def test_remove_import_complex(self) -> None: import baz, qux import a.b import c.d + import x.y.z import e.f as g import h.i as j def foo() -> None: c.d() + x.u j() """ after = """ import bar import qux import c.d + import x.y.z import h.i as j def foo() -> None: c.d() + x.u j() """ @@ -414,6 +418,7 @@ def foo() -> None: ("c.d", None, None), ("e.f", None, "g"), ("h.i", None, "j"), + ("x.y.z", None, None), ], ) @@ -428,6 +433,7 @@ def test_remove_fromimport_complex(self) -> None: from d.e import f from h.i import j as k from l.m import n as o + from x import * def foo() -> None: f() @@ -437,6 +443,7 @@ def foo() -> None: from bar import qux from d.e import f from h.i import j as k + from x import * def foo() -> None: f() diff --git a/libcst/metadata/scope_provider.py b/libcst/metadata/scope_provider.py index f77ee5764..296da70de 100644 --- a/libcst/metadata/scope_provider.py +++ b/libcst/metadata/scope_provider.py @@ -14,17 +14,20 @@ from typing import ( Collection, Dict, + Iterable, Iterator, List, Mapping, MutableMapping, Optional, Set, + Tuple, Type, Union, ) import libcst as cst +from libcst import ensure_type from libcst._add_slots import add_slots from libcst._metadata_dependent import MetadataDependent from libcst.helpers import get_full_name_for_node @@ -52,11 +55,14 @@ def __new__(cls) -> "Tree": ... """ - #: The name node of the access. A name is an access when the expression context is - #: :attr:`ExpressionContext.LOAD`. - node: cst.Name + #: The node of the access. A name is an access when the expression context is + #: :attr:`ExpressionContext.LOAD`. This is usually the name node representing the + #: access, except for dotted imports, when it might be the attribute that + #: represents the most specific part of the imported symbol. + node: Union[cst.Name, cst.Attribute] - #: The scope of the access. Note that a access could be in a child scope of its assignment. + #: The scope of the access. Note that a access could be in a child scope of its + #: assignment. scope: "Scope" __assignments: Set["BaseAssignment"] @@ -584,12 +590,32 @@ class ComprehensionScope(LocalScope): pass +# Generates dotted names from an Attribute or Name node: +# Attribute(value=Name(value="a"), attr=Name(value="b")) -> ("a.b", "a") +# each string has the corresponding CSTNode attached to it +def _gen_dotted_names( + node: Union[cst.Attribute, cst.Name] +) -> Iterable[Tuple[str, Union[cst.Attribute, cst.Name]]]: + if isinstance(node, cst.Name): + yield (node.value, node) + else: + value = node.value + if not isinstance(value, (cst.Attribute, cst.Name)): + raise ValueError(f"Unexpected name value in import: {value}") + name_values = iter(_gen_dotted_names(value)) + (next_name, next_node) = next(name_values) + yield (f"{next_name}.{node.attr.value}", node) + yield (next_name, next_node) + yield from name_values + + class ScopeVisitor(cst.CSTVisitor): # since it's probably not useful. That can makes this visitor cleaner. def __init__(self, provider: "ScopeProvider") -> None: self.provider: ScopeProvider = provider self.scope: Scope = GlobalScope() - self.__deferred_accesses: List[Access] = [] + self.__deferred_accesses: List[Tuple[Access, Optional[cst.Attribute]]] = [] + self.__top_level_attribute: Optional[cst.Attribute] = None @contextmanager def _new_scope( @@ -613,24 +639,18 @@ def _switch_scope(self, scope: Scope) -> Iterator[None]: def _visit_import_alike(self, node: Union[cst.Import, cst.ImportFrom]) -> bool: names = node.names - if not isinstance(names, cst.ImportStar): - # make sure node.names is Sequence[ImportAlias] - for name in names: - asname = name.asname - if asname is not None: - name_value = cst.ensure_type(asname.name, cst.Name).value - else: - name_node = name.name - while isinstance(name_node, cst.Attribute): - # the value of Attribute in import alike can only be either Name or Attribute - name_node = name_node.value - if isinstance(name_node, cst.Name): - name_value = name_node.value - else: - raise Exception( - f"Unexpected ImportAlias name value: {name_node}" - ) + if isinstance(names, cst.ImportStar): + return False + + # make sure node.names is Sequence[ImportAlias] + for name in names: + asname = name.asname + if asname is not None: + name_values = _gen_dotted_names(cst.ensure_type(asname.name, cst.Name)) + else: + name_values = _gen_dotted_names(name.name) + for name_value, _ in name_values: self.scope.record_assignment(name_value, node) return False @@ -641,7 +661,11 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> Optional[bool]: return self._visit_import_alike(node) def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]: + if self.__top_level_attribute is None: + self.__top_level_attribute = node node.value.visit(self) # explicitly not visiting attr + if self.__top_level_attribute is node: + self.__top_level_attribute = None return False def visit_Name(self, node: cst.Name) -> Optional[bool]: @@ -651,8 +675,7 @@ def visit_Name(self, node: cst.Name) -> Optional[bool]: self.scope.record_assignment(node.value, node) elif context in (ExpressionContext.LOAD, ExpressionContext.DEL): access = Access(node, self.scope) - self.__deferred_accesses.append(access) - self.scope.record_access(node.value, access) + self.__deferred_accesses.append((access, self.__top_level_attribute)) def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: self.scope.record_assignment(node.name.value, node) @@ -788,10 +811,19 @@ def infer_accesses(self) -> None: # In worst case, all accesses (m) and assignments (n) refer to the same name, # the time complexity is O(m x n), this optimizes it as O(m + n). scope_name_accesses = defaultdict(set) - for access in self.__deferred_accesses: - name = access.node.value + for (access, enclosing_attribute) in self.__deferred_accesses: + if enclosing_attribute is not None: + name = None + for name, node in _gen_dotted_names(enclosing_attribute): + if name in access.scope: + access.node = node + break + assert name is not None + else: + name = ensure_type(access.node, cst.Name).value scope_name_accesses[(access.scope, name)].add(access) access.record_assignments(access.scope[name]) + access.scope.record_access(name, access) for (scope, name), accesses in scope_name_accesses.items(): for assignment in scope[name]: diff --git a/libcst/metadata/tests/test_scope_provider.py b/libcst/metadata/tests/test_scope_provider.py index 5699be177..f9d164621 100644 --- a/libcst/metadata/tests/test_scope_provider.py +++ b/libcst/metadata/tests/test_scope_provider.py @@ -141,25 +141,60 @@ def test_import(self) -> None: """ ) scope_of_module = scopes[m] - for idx, in_scope in enumerate(["foo", "fizzbuzz", "a", "g"]): - self.assertEqual( - len(scope_of_module[in_scope]), 1, f"{in_scope} should be in scope." - ) + for idx, in_scopes in enumerate( + [["foo", "foo.bar"], ["fizzbuzz"], ["a", "a.b", "a.b.c"], ["g"],] + ): + for in_scope in in_scopes: + self.assertEqual( + len(scope_of_module[in_scope]), 1, f"{in_scope} should be in scope." + ) - assignment = cast(Assignment, list(scope_of_module[in_scope])[0]) - self.assertEqual( - assignment.name, - in_scope, - f"Assignment name {assignment.name} should equal to {in_scope}.", - ) - import_node = ensure_type(m.body[idx], cst.SimpleStatementLine).body[0] - self.assertEqual( - assignment.node, - import_node, - f"The node of Assignment {assignment.node} should equal to {import_node}", - ) + assignment = cast(Assignment, list(scope_of_module[in_scope])[0]) + self.assertEqual( + assignment.name, + in_scope, + f"Assignment name {assignment.name} should equal to {in_scope}.", + ) + import_node = ensure_type(m.body[idx], cst.SimpleStatementLine).body[0] + self.assertEqual( + assignment.node, + import_node, + f"The node of Assignment {assignment.node} should equal to {import_node}", + ) + + def test_dotted_import_access(self) -> None: + m, scopes = get_scope_metadata_provider( + """ + import a.b.c, x.y + a.b.c(x.z) + """ + ) + scope_of_module = scopes[m] + first_statement = ensure_type(m.body[1], cst.SimpleStatementLine) + call = ensure_type( + ensure_type(first_statement.body[0], cst.Expr).value, cst.Call + ) + self.assertTrue("a.b.c" in scope_of_module) + self.assertTrue("a" in scope_of_module) + self.assertEqual(scope_of_module.accesses["a"], set()) + + a_b_c_assignment = cast(Assignment, list(scope_of_module["a.b.c"])[0]) + a_b_c_access = list(a_b_c_assignment.references)[0] + self.assertEqual(scope_of_module.accesses["a.b.c"], {a_b_c_access}) + self.assertEqual(a_b_c_access.node, call.func) + + x_assignment = cast(Assignment, list(scope_of_module["x"])[0]) + x_access = list(x_assignment.references)[0] + self.assertEqual(scope_of_module.accesses["x"], {x_access}) + self.assertEqual( + x_access.node, ensure_type(call.args[0].value, cst.Attribute).value + ) + + self.assertTrue("x.y" in scope_of_module) + self.assertEqual(list(scope_of_module["x.y"])[0].references, set()) + self.assertEqual(scope_of_module.accesses["x.y"], set()) - def test_imoprt_from(self) -> None: + def test_import_from(self) -> None: m, scopes = get_scope_metadata_provider( """ from foo.bar import a, b as b_renamed @@ -782,7 +817,7 @@ def test_multiple_assignments(self) -> None: }, ) - def test_assignemnts_and_accesses(self) -> None: + def test_assignments_and_accesses(self) -> None: m, scopes = get_scope_metadata_provider( """ a = 1