Skip to content

Commit

Permalink
Support relative imports in ATAV qualifier handling
Browse files Browse the repository at this point in the history
Based on diff review of #536,
I investigated relatvie import handling and realized that with minor
changes we can now handle them correctly.

Relative imports aren't likely in code coming from an automated
tool, but they could happen in hand-written stubs if anyone tries
to use this codemod tool to merge stubs with code.

Added a new test:
```
> python -m unittest libcst.codemod.visitors.tests.test_apply_type_annotations
.............................................
----------------------------------------------------------------------
Ran 45 tests in 2.195s

OK

```
  • Loading branch information
stroxler committed Oct 28, 2021
1 parent 3743c70 commit 20b0070
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
19 changes: 16 additions & 3 deletions libcst/codemod/visitors/_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,28 @@ def _get_qualified_name_and_dequalified_node(
dequalified_node = node.attr if isinstance(node, cst.Attribute) else node
return qualified_name, dequalified_node

def _module_and_target(self, qualified_name: str) -> Tuple[str, str]:
relative_prefix = ""
while qualified_name.startswith("."):
relative_prefix += "."
qualified_name = qualified_name[1:]
split = qualified_name.rsplit(".", 1)
if len(split) == 1:
qualifier, target = "", split[0]
else:
qualifier, target = split
return (relative_prefix + qualifier, target)

def _handle_qualification_and_should_qualify(self, qualified_name: str) -> bool:
"""
Basd on a qualified name and the existing module imports, record that
we need to add an import if necessary and return whether or not we
should use the qualified name due to a preexisting import.
"""
split_name = qualified_name.split(".")
if len(split_name) > 1 and qualified_name not in self.existing_imports:
module, target = ".".join(split_name[:-1]), split_name[-1]
module, target = self._module_and_target(qualified_name)
if module in ("", "builtins"):
return False
elif qualified_name not in self.existing_imports:
if module == "builtins":
return False
elif module in self.existing_imports:
Expand Down
26 changes: 26 additions & 0 deletions libcst/codemod/visitors/tests/test_apply_type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,32 @@ def run_test_case_with_flags(
FOO: Union[Example, int] = bar()
""",
),
"with_relative_imports": (
"""
from .relative0 import T0
from ..relative1 import T1
from . import relative2
x0: typing.Optional[T0]
x1: typing.Optional[T1]
x2: typing.Optional[relative2.T2]
""",
"""
x0 = None
x1 = None
x2 = None
""",
"""
from ..relative1 import T1
from .relative0 import T0
from .relative2 import T2
from typing import Optional
x0: Optional[T0] = None
x1: Optional[T1] = None
x2: Optional[T2] = None
""",
),
}
)
def test_annotate_globals(self, stub: str, before: str, after: str) -> None:
Expand Down

0 comments on commit 20b0070

Please sign in to comment.