Skip to content

Commit

Permalink
[ScopeProvider] Expose more granular Assignments and Accesses for dot…
Browse files Browse the repository at this point in the history
…ted imports (#284)
  • Loading branch information
zsol authored Apr 21, 2020
1 parent 30cb9f3 commit 477a03e
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 61 deletions.
15 changes: 14 additions & 1 deletion docs/source/metadata.rst
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,26 @@ There are four different type of scope in Python:
:class:`~libcst.metadata.ComprehensionScope`.

.. image:: _static/img/python_scopes.png
:alt: LibCST
:alt: Diagram showing how the above 4 scopes are nested in each other
:width: 400
:align: center

LibCST allows you to inspect these scopes to see what local variables are
assigned or accessed within.

.. note::
Import statements bring new symbols into scope that are declared in other files.
As such, they are represented by :class:`~libcst.metadata.Assignment` for scope
analysis purposes. Dotted imports (e.g. ``import a.b.c``) generate multiple
:class:`~libcst.metadata.Assignment` objects — one for each module. When analyzing
references, only the most specific access is recorded.

For example, the above ``import a.b.c`` statement generates three
:class:`~libcst.metadata.Assignment` objects: one for ``a``, one for ``a.b``, and
one for ``a.b.c``. A reference for ``a.b.c`` records an access only for the last
assignment, while a reference for ``a.d`` only records an access for the
:class:`~libcst.metadata.Assignment` representing ``a``.

.. autoclass:: libcst.metadata.ScopeProvider
:no-undoc-members:

Expand Down
29 changes: 29 additions & 0 deletions libcst/codemod/commands/tests/test_remove_unused_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,32 @@ def test_type_annotations(self) -> None:
x: a = 1
"""
self.assertCodemod(before, before)

def test_dotted_imports(self) -> None:
# Exercises unused-import removal on dotted imports. The only difference
# between the `before` and `after` fixtures below is that `import x.y, x.y.z`
# is reduced to `import x.y.z`; every other import is expected to stay.
# The function body touches `a.b`, `e.g`, `g.h.i` and `x.y.z`.
before = """
import a.b, a.b.c
import e.f
import g.h
import x.y, x.y.z
def foo() -> None:
a.b
e.g
g.h.i
x.y.z
"""

# Expected output: identical to `before` except that the `x.y` alias is gone.
after = """
import a.b, a.b.c
import e.f
import g.h
import x.y.z
def foo() -> None:
a.b
e.g
g.h.i
x.y.z
"""

self.assertCodemod(before, after)
30 changes: 14 additions & 16 deletions libcst/codemod/visitors/_remove_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from libcst.codemod.visitors._gather_exports import GatherExportsVisitor
from libcst.helpers import get_absolute_module_for_import, get_full_name_for_node
from libcst.metadata import Assignment, Scope, ScopeProvider
from libcst.metadata.scope_provider import _gen_dotted_names


class RemovedNodeVisitor(ContextAwareVisitor):
Expand Down Expand Up @@ -295,24 +296,21 @@ def visit_Module(self, node: cst.Module) -> None:
def _is_in_use(self, scope: Scope, alias: cst.ImportAlias) -> bool:
# Grab the string name of this alias from the point of view of this module.
asname = alias.asname
if asname is not None:
name_node = asname.name
else:
name_node = alias.name
while isinstance(name_node, cst.Attribute):
name_node = name_node.value
name_or_alias = cst.ensure_type(name_node, cst.Name).value

if name_or_alias in self.exported_objects:
return True
names = _gen_dotted_names(
cst.ensure_type(asname.name, cst.Name) if asname is not None else alias.name
)

for assignment in scope[name_or_alias]:
if (
isinstance(assignment, Assignment)
and isinstance(assignment.node, (cst.ImportFrom, cst.Import))
and len(assignment.references) > 0
):
for name_or_alias, _ in names:
if name_or_alias in self.exported_objects:
return True

for assignment in scope[name_or_alias]:
if (
isinstance(assignment, Assignment)
and isinstance(assignment.node, (cst.ImportFrom, cst.Import))
and len(assignment.references) > 0
):
return True
return False

def leave_Import(
Expand Down
7 changes: 7 additions & 0 deletions libcst/codemod/visitors/tests/test_remove_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,21 +387,25 @@ def test_remove_import_complex(self) -> None:
import baz, qux
import a.b
import c.d
import x.y.z
import e.f as g
import h.i as j
def foo() -> None:
c.d()
x.u
j()
"""
after = """
import bar
import qux
import c.d
import x.y.z
import h.i as j
def foo() -> None:
c.d()
x.u
j()
"""

Expand All @@ -414,6 +418,7 @@ def foo() -> None:
("c.d", None, None),
("e.f", None, "g"),
("h.i", None, "j"),
("x.y.z", None, None),
],
)

Expand All @@ -428,6 +433,7 @@ def test_remove_fromimport_complex(self) -> None:
from d.e import f
from h.i import j as k
from l.m import n as o
from x import *
def foo() -> None:
f()
Expand All @@ -437,6 +443,7 @@ def foo() -> None:
from bar import qux
from d.e import f
from h.i import j as k
from x import *
def foo() -> None:
f()
Expand Down
84 changes: 58 additions & 26 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,20 @@
from typing import (
Collection,
Dict,
Iterable,
Iterator,
List,
Mapping,
MutableMapping,
Optional,
Set,
Tuple,
Type,
Union,
)

import libcst as cst
from libcst import ensure_type
from libcst._add_slots import add_slots
from libcst._metadata_dependent import MetadataDependent
from libcst.helpers import get_full_name_for_node
Expand Down Expand Up @@ -52,11 +55,14 @@ def __new__(cls) -> "Tree":
...
"""

#: The name node of the access. A name is an access when the expression context is
#: :attr:`ExpressionContext.LOAD`.
node: cst.Name
#: The node of the access. A name is an access when the expression context is
#: :attr:`ExpressionContext.LOAD`. This is usually the name node representing the
#: access, except for dotted imports, when it might be the attribute that
#: represents the most specific part of the imported symbol.
node: Union[cst.Name, cst.Attribute]

#: The scope of the access. Note that a access could be in a child scope of its assignment.
#: The scope of the access. Note that an access could be in a child scope of its
#: assignment.
scope: "Scope"

__assignments: Set["BaseAssignment"]
Expand Down Expand Up @@ -584,12 +590,32 @@ class ComprehensionScope(LocalScope):
pass


# Generates dotted names from an Attribute or Name node:
# Attribute(value=Name(value="a"), attr=Name(value="b")) -> ("a.b", "a")
# each string has the corresponding CSTNode attached to it
def _gen_dotted_names(
    node: Union[cst.Attribute, cst.Name]
) -> Iterable[Tuple[str, Union[cst.Attribute, cst.Name]]]:
    """Yield each dotted prefix of ``node`` paired with the node spelling it.

    Pairs are produced from the most specific name down to the leftmost one,
    so ``a.b.c`` yields ``("a.b.c", <Attribute>)``, ``("a.b", <Attribute>)``
    and finally ``("a", <Name>)``. A bare ``Name`` yields a single pair.

    Raises:
        ValueError: if the base of any attribute in the chain is neither an
            ``Attribute`` nor a ``Name``. Because this is a generator, the
            error surfaces when iteration starts, before anything is yielded.
    """
    # Walk from the outermost attribute towards the leftmost name, remembering
    # every attribute passed on the way.
    attributes = []
    current = node
    while not isinstance(current, cst.Name):
        base = current.value
        if not isinstance(base, (cst.Attribute, cst.Name)):
            raise ValueError(f"Unexpected name value in import: {base}")
        attributes.append(current)
        current = base

    # Rebuild the dotted names left to right, extending the previous prefix
    # with each attribute's own name.
    pairs = [(current.value, current)]
    for attribute in reversed(attributes):
        prefix = pairs[-1][0]
        pairs.append((f"{prefix}.{attribute.attr.value}", attribute))

    # Callers expect the most specific name first.
    yield from reversed(pairs)


class ScopeVisitor(cst.CSTVisitor):
# since it's probably not useful. That can make this visitor cleaner.
def __init__(self, provider: "ScopeProvider") -> None:
self.provider: ScopeProvider = provider
self.scope: Scope = GlobalScope()
self.__deferred_accesses: List[Access] = []
self.__deferred_accesses: List[Tuple[Access, Optional[cst.Attribute]]] = []
self.__top_level_attribute: Optional[cst.Attribute] = None

@contextmanager
def _new_scope(
Expand All @@ -613,24 +639,18 @@ def _switch_scope(self, scope: Scope) -> Iterator[None]:

def _visit_import_alike(self, node: Union[cst.Import, cst.ImportFrom]) -> bool:
names = node.names
if not isinstance(names, cst.ImportStar):
# make sure node.names is Sequence[ImportAlias]
for name in names:
asname = name.asname
if asname is not None:
name_value = cst.ensure_type(asname.name, cst.Name).value
else:
name_node = name.name
while isinstance(name_node, cst.Attribute):
# the value of Attribute in import alike can only be either Name or Attribute
name_node = name_node.value
if isinstance(name_node, cst.Name):
name_value = name_node.value
else:
raise Exception(
f"Unexpected ImportAlias name value: {name_node}"
)
if isinstance(names, cst.ImportStar):
return False

# make sure node.names is Sequence[ImportAlias]
for name in names:
asname = name.asname
if asname is not None:
name_values = _gen_dotted_names(cst.ensure_type(asname.name, cst.Name))
else:
name_values = _gen_dotted_names(name.name)

for name_value, _ in name_values:
self.scope.record_assignment(name_value, node)
return False

Expand All @@ -641,7 +661,11 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> Optional[bool]:
return self._visit_import_alike(node)

def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]:
# Track the outermost attribute of a dotted expression while its base is
# being visited. Only the first (outermost) attribute sets the marker;
# nested attributes reached through `node.value` leave it untouched.
if self.__top_level_attribute is None:
self.__top_level_attribute = node
node.value.visit(self) # explicitly not visiting attr
# Reset the marker only when unwinding out of the attribute that set it,
# so it stays in place for the whole visit of the outermost chain.
if self.__top_level_attribute is node:
self.__top_level_attribute = None
# The base was visited manually above; returning False is meant to keep the
# visitor from traversing this node's children a second time.
return False

def visit_Name(self, node: cst.Name) -> Optional[bool]:
Expand All @@ -651,8 +675,7 @@ def visit_Name(self, node: cst.Name) -> Optional[bool]:
self.scope.record_assignment(node.value, node)
elif context in (ExpressionContext.LOAD, ExpressionContext.DEL):
access = Access(node, self.scope)
self.__deferred_accesses.append(access)
self.scope.record_access(node.value, access)
self.__deferred_accesses.append((access, self.__top_level_attribute))

def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
self.scope.record_assignment(node.name.value, node)
Expand Down Expand Up @@ -788,10 +811,19 @@ def infer_accesses(self) -> None:
# In worst case, all accesses (m) and assignments (n) refer to the same name,
# the time complexity is O(m x n), this optimizes it as O(m + n).
scope_name_accesses = defaultdict(set)
for access in self.__deferred_accesses:
name = access.node.value
for (access, enclosing_attribute) in self.__deferred_accesses:
if enclosing_attribute is not None:
name = None
for name, node in _gen_dotted_names(enclosing_attribute):
if name in access.scope:
access.node = node
break
assert name is not None
else:
name = ensure_type(access.node, cst.Name).value
scope_name_accesses[(access.scope, name)].add(access)
access.record_assignments(access.scope[name])
access.scope.record_access(name, access)

for (scope, name), accesses in scope_name_accesses.items():
for assignment in scope[name]:
Expand Down
71 changes: 53 additions & 18 deletions libcst/metadata/tests/test_scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,25 +141,60 @@ def test_import(self) -> None:
"""
)
scope_of_module = scopes[m]
for idx, in_scope in enumerate(["foo", "fizzbuzz", "a", "g"]):
self.assertEqual(
len(scope_of_module[in_scope]), 1, f"{in_scope} should be in scope."
)
for idx, in_scopes in enumerate(
[["foo", "foo.bar"], ["fizzbuzz"], ["a", "a.b", "a.b.c"], ["g"],]
):
for in_scope in in_scopes:
self.assertEqual(
len(scope_of_module[in_scope]), 1, f"{in_scope} should be in scope."
)

assignment = cast(Assignment, list(scope_of_module[in_scope])[0])
self.assertEqual(
assignment.name,
in_scope,
f"Assignment name {assignment.name} should equal to {in_scope}.",
)
import_node = ensure_type(m.body[idx], cst.SimpleStatementLine).body[0]
self.assertEqual(
assignment.node,
import_node,
f"The node of Assignment {assignment.node} should equal to {import_node}",
)
assignment = cast(Assignment, list(scope_of_module[in_scope])[0])
self.assertEqual(
assignment.name,
in_scope,
f"Assignment name {assignment.name} should equal to {in_scope}.",
)
import_node = ensure_type(m.body[idx], cst.SimpleStatementLine).body[0]
self.assertEqual(
assignment.node,
import_node,
f"The node of Assignment {assignment.node} should equal to {import_node}",
)

def test_dotted_import_access(self) -> None:
# Checks which dotted-import assignment each access is attributed to.
# The fixture imports `a.b.c` and `x.y`, then evaluates `a.b.c(x.z)`.
m, scopes = get_scope_metadata_provider(
"""
import a.b.c, x.y
a.b.c(x.z)
"""
)
scope_of_module = scopes[m]
# m.body[1] is the expression statement `a.b.c(x.z)`; unwrap it to the Call.
first_statement = ensure_type(m.body[1], cst.SimpleStatementLine)
call = ensure_type(
ensure_type(first_statement.body[0], cst.Expr).value, cst.Call
)
# Both the full name and its leftmost prefix are in scope, but the use of
# `a.b.c` must not be counted as an access of the shorter name `a`.
self.assertTrue("a.b.c" in scope_of_module)
self.assertTrue("a" in scope_of_module)
self.assertEqual(scope_of_module.accesses["a"], set())

# The access for `a.b.c` is attached to the `a.b.c` assignment and its node
# is the callee of the call (the whole attribute expression).
a_b_c_assignment = cast(Assignment, list(scope_of_module["a.b.c"])[0])
a_b_c_access = list(a_b_c_assignment.references)[0]
self.assertEqual(scope_of_module.accesses["a.b.c"], {a_b_c_access})
self.assertEqual(a_b_c_access.node, call.func)

# `x.z` was never imported, so the access is expected on `x`, and its node
# is the base of the argument's attribute expression.
x_assignment = cast(Assignment, list(scope_of_module["x"])[0])
x_access = list(x_assignment.references)[0]
self.assertEqual(scope_of_module.accesses["x"], {x_access})
self.assertEqual(
x_access.node, ensure_type(call.args[0].value, cst.Attribute).value
)

# `x.y` is in scope through the import but is expected to have no
# references and no recorded accesses.
self.assertTrue("x.y" in scope_of_module)
self.assertEqual(list(scope_of_module["x.y"])[0].references, set())
self.assertEqual(scope_of_module.accesses["x.y"], set())

def test_imoprt_from(self) -> None:
def test_import_from(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
from foo.bar import a, b as b_renamed
Expand Down Expand Up @@ -782,7 +817,7 @@ def test_multiple_assignments(self) -> None:
},
)

def test_assignemnts_and_accesses(self) -> None:
def test_assignments_and_accesses(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
a = 1
Expand Down

0 comments on commit 477a03e

Please sign in to comment.