diff --git a/libcst/codemod/visitors/_apply_type_annotations.py b/libcst/codemod/visitors/_apply_type_annotations.py index 2ac44c023..a459d300e 100644 --- a/libcst/codemod/visitors/_apply_type_annotations.py +++ b/libcst/codemod/visitors/_apply_type_annotations.py @@ -145,15 +145,28 @@ def _get_qualified_name_and_dequalified_node( dequalified_node = node.attr if isinstance(node, cst.Attribute) else node return qualified_name, dequalified_node + def _module_and_target(self, qualified_name: str) -> Tuple[str, str]: + relative_prefix = "" + while qualified_name.startswith("."): + relative_prefix += "." + qualified_name = qualified_name[1:] + split = qualified_name.rsplit(".", 1) + if len(split) == 1: + qualifier, target = "", split[0] + else: + qualifier, target = split + return (relative_prefix + qualifier, target) + def _handle_qualification_and_should_qualify(self, qualified_name: str) -> bool: """ Basd on a qualified name and the existing module imports, record that we need to add an import if necessary and return whether or not we should use the qualified name due to a preexisting import. """ - split_name = qualified_name.split(".") - if len(split_name) > 1 and qualified_name not in self.existing_imports: - module, target = ".".join(split_name[:-1]), split_name[-1] + module, target = self._module_and_target(qualified_name) + if module in ("", "builtins"): + return False + elif qualified_name not in self.existing_imports: if module == "builtins": return False elif module in self.existing_imports: diff --git a/libcst/codemod/visitors/tests/test_apply_type_annotations.py b/libcst/codemod/visitors/tests/test_apply_type_annotations.py index 4a63ab29e..54aec5d4e 100644 --- a/libcst/codemod/visitors/tests/test_apply_type_annotations.py +++ b/libcst/codemod/visitors/tests/test_apply_type_annotations.py @@ -123,6 +123,32 @@ def run_test_case_with_flags( FOO: Union[Example, int] = bar() """, ), + "with_relative_imports": ( + """ + from .relative0 import T0 + from ..relative1 import T1 + from . import relative2 + + x0: typing.Optional[T0] + x1: typing.Optional[T1] + x2: typing.Optional[relative2.T2] + """, + """ + x0 = None + x1 = None + x2 = None + """, + """ + from ..relative1 import T1 + from .relative0 import T0 + from .relative2 import T2 + from typing import Optional + + x0: Optional[T0] = None + x1: Optional[T1] = None + x2: Optional[T2] = None + """, + ), } ) def test_annotate_globals(self, stub: str, before: str, after: str) -> None: