Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Qualify imported symbols when the dequalified form would cause a conflict #674

Merged
merged 4 commits into from
Apr 12, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 227 additions & 25 deletions libcst/codemod/visitors/_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union

Expand All @@ -14,6 +15,7 @@
from libcst.codemod.visitors._add_imports import AddImportsVisitor
from libcst.codemod.visitors._gather_global_names import GatherGlobalNamesVisitor
from libcst.codemod.visitors._gather_imports import GatherImportsVisitor
from libcst.codemod.visitors._imports import ImportItem
from libcst.helpers import get_full_name_for_node
from libcst.metadata import PositionProvider, QualifiedNameProvider

Expand All @@ -29,6 +31,41 @@
]


def _module_and_target(qualified_name: str) -> Tuple[str, str]:
relative_prefix = ""
while qualified_name.startswith("."):
relative_prefix += "."
qualified_name = qualified_name[1:]
split = qualified_name.rsplit(".", 1)
if len(split) == 1:
qualifier, target = "", split[0]
else:
qualifier, target = split
return (relative_prefix + qualifier, target)


def _get_unique_qualified_name(
visitor: m.MatcherDecoratableVisitor, node: cst.CSTNode
) -> str:
name = None
names = [q.name for q in visitor.get_metadata(QualifiedNameProvider, node)]
if len(names) == 0:
# we hit this branch if the stub is directly using a fully
# qualified name, which is not technically valid python but is
# convenient to allow.
name = get_full_name_for_node(node)
elif len(names) == 1 and isinstance(names[0], str):
name = names[0]
if name is None:
start = visitor.get_metadata(PositionProvider, node).start
raise ValueError(
"Could not resolve a unique qualified name for type "
+ f"{get_full_name_for_node(node)} at {start.line}:{start.column}. "
+ f"Candidate names were: {names!r}"
)
return name


def _get_import_alias_names(
import_aliases: Sequence[cst.ImportAlias],
) -> Set[str]:
Expand Down Expand Up @@ -186,6 +223,130 @@ def finish(self) -> None:
self.typevars = {k: v for k, v in self.typevars.items() if k in self.names}


@dataclass(frozen=True)
class ImportedSymbol:
"""Import of foo.Bar, where both foo and Bar are potentially aliases."""

module_name: str
module_alias: Optional[str] = None
target_name: Optional[str] = None
target_alias: Optional[str] = None

@property
def symbol(self) -> Optional[str]:
return self.target_alias or self.target_name

@property
def module_symbol(self) -> str:
return self.module_alias or self.module_name


class ImportedSymbolCollector(m.MatcherDecoratableVisitor):
"""
Collect imported symbols from a stub module.
"""

METADATA_DEPENDENCIES = (
PositionProvider,
QualifiedNameProvider,
)

def __init__(self, existing_imports: Set[str], context: CodemodContext) -> None:
super().__init__()
self.existing_imports: Set[str] = existing_imports
self.imported_symbols: Dict[str, Set[ImportedSymbol]] = defaultdict(set)

def visit_ClassDef(self, node: cst.ClassDef) -> None:
for base in node.bases:
value = base.value
if isinstance(value, NAME_OR_ATTRIBUTE):
self._handle_NameOrAttribute(value)
elif isinstance(value, cst.Subscript):
self._handle_Subscript(value)

def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
if node.returns is not None:
self._handle_Annotation(annotation=node.returns)
self._handle_Parameters(node.params)

# pyi files don't support inner functions, return False to stop the traversal.
return False

def visit_AnnAssign(self, node: cst.AnnAssign) -> None:
self._handle_Annotation(annotation=node.annotation)

# Handler functions.
#
# These ultimately all call _handle_NameOrAttribute, which adds the
# qualified name to the list of imported symbols

def _handle_NameOrAttribute(
self,
node: NameOrAttribute,
) -> None:
obj = sym = None # keep pyre happy
if isinstance(node, cst.Name):
obj = None
sym = node.value
elif isinstance(node, cst.Attribute):
obj = node.value.value # pyre-ignore[16]
sym = node.attr.value
qualified_name = _get_unique_qualified_name(self, node)
module, target = _module_and_target(qualified_name)
if module in ("", "builtins"):
return
elif qualified_name not in self.existing_imports:
mod = ImportedSymbol(
module_name=module,
module_alias=obj if obj != module else None,
target_name=target,
target_alias=sym if sym != target else None,
)
self.imported_symbols[sym].add(mod)

def _handle_Index(self, slice: cst.Index) -> None:
value = slice.value
if isinstance(value, cst.Subscript):
self._handle_Subscript(value)
elif isinstance(value, cst.Attribute):
self._handle_NameOrAttribute(value)

def _handle_Subscript(self, node: cst.Subscript) -> None:
value = node.value
if isinstance(value, NAME_OR_ATTRIBUTE):
self._handle_NameOrAttribute(value)
else:
raise ValueError("Expected any indexed type to have")
if _get_unique_qualified_name(self, node) in ("Type", "typing.Type"):
return
slice = node.slice
if isinstance(slice, tuple):
for item in slice:
if isinstance(item.slice.value, NAME_OR_ATTRIBUTE):
self._handle_NameOrAttribute(item.slice.value)
else:
if isinstance(item.slice, cst.Index):
self._handle_Index(item.slice)
elif isinstance(slice, cst.Index):
self._handle_Index(slice)

def _handle_Annotation(self, annotation: cst.Annotation) -> None:
node = annotation.annotation
if isinstance(node, cst.Subscript):
self._handle_Subscript(node)
elif isinstance(node, NAME_OR_ATTRIBUTE):
self._handle_NameOrAttribute(node)
elif isinstance(node, cst.SimpleString):
pass
else:
raise ValueError(f"Unexpected annotation node: {node}")

def _handle_Parameters(self, parameters: cst.Parameters) -> None:
for parameter in list(parameters.params):
if parameter.annotation is not None:
self._handle_Annotation(annotation=parameter.annotation)


class TypeCollector(m.MatcherDecoratableVisitor):
"""
Collect type annotations from a stub module.
Expand All @@ -201,6 +362,7 @@ class TypeCollector(m.MatcherDecoratableVisitor):
def __init__(
self,
existing_imports: Set[str],
module_imports: Dict[str, ImportItem],
context: CodemodContext,
) -> None:
super().__init__()
Expand All @@ -212,6 +374,9 @@ def __init__(
# as well as module names, although downstream we effectively ignore
# the module names as of the current implementation.
self.existing_imports: Set[str] = existing_imports
# Module imports, gathered by prescanning the stub file to determine
# which modules need to be imported directly to qualify their symbols.
self.module_imports: Dict[str, ImportItem] = module_imports
# Fields that help us track temporary state as we recurse
self.qualifier: List[str] = []
self.current_assign: Optional[cst.Assign] = None # used to collect typevars
Expand Down Expand Up @@ -323,33 +488,11 @@ def leave_Module(
) -> None:
self.annotations.finish()

def _get_unique_qualified_name(
self,
node: cst.CSTNode,
) -> str:
name = None
names = [q.name for q in self.get_metadata(QualifiedNameProvider, node)]
if len(names) == 0:
# we hit this branch if the stub is directly using a fully
# qualified name, which is not technically valid python but is
# convenient to allow.
name = get_full_name_for_node(node)
elif len(names) == 1 and isinstance(names[0], str):
name = names[0]
if name is None:
start = self.get_metadata(PositionProvider, node).start
raise ValueError(
"Could not resolve a unique qualified name for type "
+ f"{get_full_name_for_node(node)} at {start.line}:{start.column}. "
+ f"Candidate names were: {names!r}"
)
return name

def _get_qualified_name_and_dequalified_node(
self,
node: Union[cst.Name, cst.Attribute],
) -> Tuple[str, Union[cst.Name, cst.Attribute]]:
qualified_name = self._get_unique_qualified_name(node)
qualified_name = _get_unique_qualified_name(self, node)
dequalified_node = node.attr if isinstance(node, cst.Attribute) else node
return qualified_name, dequalified_node

Expand Down Expand Up @@ -382,6 +525,16 @@ def _handle_qualification_and_should_qualify(
elif qualified_name not in self.existing_imports:
if module in self.existing_imports:
return True
elif module in self.module_imports:
m = self.module_imports[module]
if m.obj_name is None:
asname = m.alias
else:
asname = None
AddImportsVisitor.add_needed_import(
self.context, m.module_name, asname=asname
)
return True
else:
if node and isinstance(node, cst.Name) and node.value != target:
asname = node.value
Expand Down Expand Up @@ -443,7 +596,7 @@ def _handle_Subscript(
new_node = node.with_changes(value=self._handle_NameOrAttribute(value))
else:
raise ValueError("Expected any indexed type to have")
if self._get_unique_qualified_name(node) in ("Type", "typing.Type"):
if _get_unique_qualified_name(self, node) in ("Type", "typing.Type"):
# Note: we are intentionally not handling qualification of
# anything inside `Type` because it's common to have nested
# classes, which we cannot currently distinguish from classes
Expand Down Expand Up @@ -679,7 +832,8 @@ def transform_module_impl(
self.strict_annotation_matching = (
self.strict_annotation_matching or strict_annotation_matching
)
visitor = TypeCollector(existing_import_names, self.context)
module_imports = self._get_module_imports(stub, existing_import_names)
visitor = TypeCollector(existing_import_names, module_imports, self.context)
cst.MetadataWrapper(stub).visit(visitor)
self.annotations.update(visitor.annotations)

Expand All @@ -697,6 +851,54 @@ def transform_module_impl(
else:
return tree

# helpers for collecting type information from the stub files

def _get_module_imports(
self, stub: cst.Module, existing_import_names: Set[str]
) -> Dict[str, ImportItem]:
"""Returns a dict of modules that need to be imported to qualify symbols."""
# We correlate all imported symbols, e.g. foo.bar.Baz, with a list of module
# and from imports. If the same unqualified symbol is used from different
# modules, we give preference to an explicit from-import if any, and qualify
# everything else by importing the module.
#
# e.g. the following stub:
# import foo as quux
# from bar import Baz as X
# def f(x: X) -> quux.X: ...
# will return {'foo': ImportItem("foo", "quux")}. When the apply type
# annotation visitor hits `quux.X` it will retrieve the canonical name
# `foo.X` and then note that `foo` is in the module imports map, so it will
# leave the symbol qualified.
import_gatherer = GatherImportsVisitor(CodemodContext())
stub.visit(import_gatherer)
symbol_map = import_gatherer.symbol_mapping
isc = ImportedSymbolCollector(existing_import_names, self.context)
cst.MetadataWrapper(stub).visit(isc)
module_imports = {}
for sym, isyms in isc.imported_symbols.items():
martindemello marked this conversation as resolved.
Show resolved Hide resolved
if len(isyms) == 1:
# If we have a single use of a symbol we can from-import it
continue
used = False
for isym in isyms:
if not isym.symbol:
continue
imp = symbol_map.get(isym.symbol)
if not used and imp and imp.module_name == isym.module_name:
# We can only import a symbol directly once.
used = True
else:
imp = symbol_map.get(isym.module_symbol)
if imp:
# imp will be None in corner cases like
# import foo.bar as Baz
# x: Baz
# which is technically valid python but nonsensical as a
# type annotation. Dropping it on the floor for now.
module_imports[imp.module_name] = imp
return module_imports

# helpers for processing annotation nodes
def _quote_future_annotations(self, annotation: cst.Annotation) -> cst.Annotation:
# TODO: We probably want to make sure references to classes defined in the current
Expand Down
12 changes: 12 additions & 0 deletions libcst/codemod/visitors/_gather_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import libcst
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareVisitor
from libcst.codemod.visitors._imports import ImportItem
from libcst.helpers import get_absolute_module_for_import


Expand Down Expand Up @@ -60,19 +61,24 @@ def __init__(self, context: CodemodContext) -> None:
self.alias_mapping: Dict[str, List[Tuple[str, str]]] = {}
# Track all of the imports found in this transform
self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = []
# Track the import for every symbol introduced into the module
self.symbol_mapping: Dict[str, ImportItem] = {}

def visit_Import(self, node: libcst.Import) -> None:
# Track this import statement for later analysis.
self.all_imports.append(node)

for name in node.names:
alias = name.evaluated_alias
imp = ImportItem(name.evaluated_name, alias=alias)
if alias is not None:
# Track this as an aliased module
self.module_aliases[name.evaluated_name] = alias
self.symbol_mapping[alias] = imp
else:
# Get the module we're importing as a string.
self.module_imports.add(name.evaluated_name)
self.symbol_mapping[name.evaluated_name] = imp

def visit_ImportFrom(self, node: libcst.ImportFrom) -> None:
# Track this import statement for later analysis.
Expand Down Expand Up @@ -114,3 +120,9 @@ def visit_ImportFrom(self, node: libcst.ImportFrom) -> None:
return

self.object_mapping[module].update(new_objects)
for ia in nodenames:
martindemello marked this conversation as resolved.
Show resolved Hide resolved
imp = ImportItem(
module, obj_name=ia.evaluated_name, alias=ia.evaluated_alias
)
key = ia.evaluated_alias or ia.evaluated_name
self.symbol_mapping[key] = imp
Loading