Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pin accesses to import alias node #554

Merged
merged 6 commits into from
Nov 19, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions libcst/codemod/visitors/_gather_unused_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ def is_in_use(self, scope: cst.metadata.Scope, alias: cst.ImportAlias) -> bool:

for assignment in scope[name_or_alias]:
if (
isinstance(assignment, cst.metadata.Assignment)
and isinstance(assignment.node, (cst.ImportFrom, cst.Import))
isinstance(assignment, cst.metadata.ImportAssignment)
and len(assignment.references) > 0
):
return True
Expand Down
2 changes: 2 additions & 0 deletions libcst/metadata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
ComprehensionScope,
FunctionScope,
GlobalScope,
ImportAssignment,
QualifiedName,
QualifiedNameSource,
Scope,
Expand All @@ -63,6 +64,7 @@
"BaseAssignment",
"Assignment",
"BuiltinAssignment",
"ImportAssignment",
"BuiltinScope",
"Access",
"Scope",
Expand Down
65 changes: 52 additions & 13 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,23 @@ class BuiltinAssignment(BaseAssignment):
pass


class ImportAssignment(Assignment):
"""An assignment records the import node and it's alias"""

as_name: cst.CSTNode

def __init__(
self,
name: str,
scope: "Scope",
node: cst.CSTNode,
index: int,
as_name: cst.CSTNode,
):
super().__init__(name, scope, node, index)
self.as_name = as_name


class Assignments:
"""A container to provide all assignments in a scope."""

Expand Down Expand Up @@ -401,10 +418,23 @@ def __init__(self, parent: "Scope") -> None:
self._assignment_count = 0

def record_assignment(self, name: str, node: cst.CSTNode) -> None:
self._assignments[name].add(
self._add_assignment(
Assignment(name=name, scope=self, node=node, index=self._assignment_count)
giomeg marked this conversation as resolved.
Show resolved Hide resolved
)

def record_import_assignment(
self, name: str, node: cst.CSTNode, as_name: cst.CSTNode
) -> None:
self._add_assignment(
ImportAssignment(
name=name,
scope=self,
node=node,
as_name=as_name,
index=self._assignment_count,
)
)

def record_access(self, name: str, access: Access) -> None:
self._accesses[name].add(access)

Expand All @@ -416,9 +446,12 @@ def _contains_in_self_or_parent(self, name: str) -> bool:
"""Overridden by ClassScope to hide it's assignments from child scopes."""
return name in self

def _record_assignment_as_parent(self, name: str, node: cst.CSTNode) -> None:
def _add_assignment(self, assignment: "BaseAssignment") -> None:
self._assignments[assignment.name].add(assignment)
giomeg marked this conversation as resolved.
Show resolved Hide resolved

def _add_assignment_as_parent(self, assignment):
giomeg marked this conversation as resolved.
Show resolved Hide resolved
"""Overridden by ClassScope to forward 'nonlocal' assignments from child scopes."""
self.record_assignment(name, node)
self._add_assignment(assignment)

@abc.abstractmethod
def __contains__(self, name: str) -> bool:
Expand Down Expand Up @@ -575,15 +608,15 @@ def __getitem__(self, name: str) -> Set[BaseAssignment]:
return self._assignments[name]
return set()

def record_assignment(self, name: str, node: cst.CSTNode) -> None:
raise NotImplementedError("assignments in builtin scope are not allowed")

def record_global_overwrite(self, name: str) -> None:
raise NotImplementedError("global overwrite in builtin scope are not allowed")

def record_nonlocal_overwrite(self, name: str) -> None:
raise NotImplementedError("declarations in builtin scope are not allowed")

def _add_assignment(self, assignment: "BaseAssignment") -> None:
raise NotImplementedError("assignments in builtin scope are not allowed")


class GlobalScope(Scope):
"""
Expand Down Expand Up @@ -634,11 +667,13 @@ def record_global_overwrite(self, name: str) -> None:
def record_nonlocal_overwrite(self, name: str) -> None:
self._scope_overwrites[name] = self.parent

def record_assignment(self, name: str, node: cst.CSTNode) -> None:
if name in self._scope_overwrites:
self._scope_overwrites[name]._record_assignment_as_parent(name, node)
def _add_assignment(self, assignment: "BaseAssignment") -> None:
if assignment.name in self._scope_overwrites:
self._scope_overwrites[assignment.name]._add_assignment_as_parent(
assignment
)
else:
super().record_assignment(name, node)
super()._add_assignment(assignment)

def __contains__(self, name: str) -> bool:
if name in self._scope_overwrites:
Expand Down Expand Up @@ -671,7 +706,7 @@ class ClassScope(LocalScope):
When a class is defined, it creates a ClassScope.
"""

def _record_assignment_as_parent(self, name: str, node: cst.CSTNode) -> None:
def _add_assignment_as_parent(self, assignment: "BaseAssignment") -> None:
"""
Forward the assignment to parent.

Expand All @@ -686,7 +721,7 @@ def inner_fn():
# hidden from its children.

"""
self.parent._record_assignment_as_parent(name, node)
self.parent._add_assignment_as_parent(assignment)

def _getitem_from_self_or_parent(self, name: str) -> Set[BaseAssignment]:
"""
Expand Down Expand Up @@ -826,11 +861,15 @@ def _visit_import_alike(self, node: Union[cst.Import, cst.ImportFrom]) -> bool:
asname = name.asname
if asname is not None:
name_values = _gen_dotted_names(cst.ensure_type(asname.name, cst.Name))
import_node_asname = asname.name
else:
name_values = _gen_dotted_names(name.name)
import_node_asname = name.name

for name_value, _ in name_values:
self.scope.record_assignment(name_value, node)
self.scope.record_import_assignment(
name_value, node, import_node_asname
)
return False

def visit_Import(self, node: cst.Import) -> Optional[bool]:
Expand Down
43 changes: 34 additions & 9 deletions libcst/metadata/tests/test_scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from libcst.metadata import MetadataWrapper
from libcst.metadata.scope_provider import (
Assignment,
ImportAssignment,
BuiltinAssignment,
BuiltinScope,
ClassScope,
Expand Down Expand Up @@ -192,17 +193,24 @@ def test_import(self) -> None:
len(scope_of_module[in_scope]), 1, f"{in_scope} should be in scope."
)

assignment = cast(Assignment, list(scope_of_module[in_scope])[0])
assignment = cast(ImportAssignment, list(scope_of_module[in_scope])[0])
self.assertEqual(
assignment.name,
in_scope,
f"Assignment name {assignment.name} should equal to {in_scope}.",
f"ImportAssignment name {assignment.name} should equal to {in_scope}.",
)
import_node = ensure_type(m.body[idx], cst.SimpleStatementLine).body[0]
self.assertEqual(
assignment.node,
import_node,
f"The node of Assignment {assignment.node} should equal to {import_node}",
f"The node of ImportAssignment {assignment.node} should equal to {import_node}",
)
alias = import_node.names[0]
as_name = alias.asname.name if alias.asname else alias.name
self.assertEqual(
assignment.as_name,
as_name,
f"The alias name of ImportAssignment {assignment.as_name} should equal to {as_name}",
)

def test_dotted_import_access(self) -> None:
Expand All @@ -221,7 +229,7 @@ def test_dotted_import_access(self) -> None:
self.assertTrue("a" in scope_of_module)
self.assertEqual(scope_of_module.accesses["a"], set())

a_b_c_assignment = cast(Assignment, list(scope_of_module["a.b.c"])[0])
a_b_c_assignment = cast(ImportAssignment, list(scope_of_module["a.b.c"])[0])
a_b_c_access = list(a_b_c_assignment.references)[0]
self.assertEqual(scope_of_module.accesses["a.b.c"], {a_b_c_access})
self.assertEqual(a_b_c_access.node, call.func)
Expand Down Expand Up @@ -261,7 +269,9 @@ def test_dotted_import_with_call_access(self) -> None:
self.assertTrue("os.path" in scope_of_module)
self.assertTrue("os" in scope_of_module)

os_path_join_assignment = cast(Assignment, list(scope_of_module["os.path"])[0])
os_path_join_assignment = cast(
ImportAssignment, list(scope_of_module["os.path"])[0]
)
os_path_join_assignment_references = list(os_path_join_assignment.references)
self.assertNotEqual(len(os_path_join_assignment_references), 0)
os_path_join_access = os_path_join_assignment_references[0]
Expand Down Expand Up @@ -289,21 +299,36 @@ def test_import_from(self) -> None:
for alias in import_aliases:
self.assertEqual(scopes[alias], scope_of_module)

for idx, in_scope in [(0, "a"), (0, "b_renamed"), (1, "c"), (2, "d")]:
for idx, in_scope, imported_object_idx in [
(0, "a", 0),
(0, "b_renamed", 1),
(1, "c", 0),
(2, "d", 0),
]:
self.assertEqual(
len(scope_of_module[in_scope]), 1, f"{in_scope} should be in scope."
)
import_assignment = cast(Assignment, list(scope_of_module[in_scope])[0])
import_assignment = cast(
ImportAssignment, list(scope_of_module[in_scope])[0]
)
self.assertEqual(
import_assignment.name,
in_scope,
f"The name of Assignment {import_assignment.name} should equal to {in_scope}.",
f"The name of ImportAssignment {import_assignment.name} should equal to {in_scope}.",
)
import_node = ensure_type(m.body[idx], cst.SimpleStatementLine).body[0]
self.assertEqual(
import_assignment.node,
import_node,
f"The node of Assignment {import_assignment.node} should equal to {import_node}",
f"The node of ImportAssignment {import_assignment.node} should equal to {import_node}",
)

alias = import_node.names[imported_object_idx]
as_name = alias.asname.name if alias.asname else alias.name
self.assertEqual(
import_assignment.as_name,
as_name,
f"The alias name of ImportAssignment {import_assignment.as_name} should equal to {as_name}",
)

for not_in_scope in ["foo", "bar", "foo.bar", "b"]:
Expand Down