diff --git a/mypy/fastparse.py b/mypy/fastparse.py index a0d0ec8e34b0f..56723ee295f04 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -39,7 +39,7 @@ from mypy import message_registry, errorcodes as codes from mypy.errors import Errors from mypy.options import Options -from mypy.reachability import mark_block_unreachable +from mypy.reachability import infer_reachability_of_if_statement, mark_block_unreachable try: # pull this into a final variable to make mypyc be quiet about the @@ -447,12 +447,50 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: ret: List[Statement] = [] current_overload: List[OverloadPart] = [] current_overload_name: Optional[str] = None + last_if_stmt: Optional[IfStmt] = None + last_if_overload: Optional[Union[Decorator, OverloadedFuncDef]] = None for stmt in stmts: if (current_overload_name is not None and isinstance(stmt, (Decorator, FuncDef)) and stmt.name == current_overload_name): + if last_if_overload is not None: + if isinstance(last_if_overload, OverloadedFuncDef): + current_overload.extend(last_if_overload.items) + else: + current_overload.append(last_if_overload) + last_if_stmt, last_if_overload = None, None current_overload.append(stmt) + elif ( + current_overload_name is not None + and isinstance(stmt, IfStmt) + and len(stmt.body[0].body) == 1 + and isinstance( + stmt.body[0].body[0], (Decorator, FuncDef, OverloadedFuncDef)) + and stmt.body[0].body[0].name == current_overload_name + ): + # IfStmt only contains stmts relevant to current_overload. + # Check if stmts are reachable and add them to current_overload, + # otherwise skip IfStmt to allow subsequent overload + # or function definitions. + infer_reachability_of_if_statement(stmt, self.options) + if stmt.body[0].is_unreachable is True: + continue + if last_if_overload is not None: + if isinstance(last_if_overload, OverloadedFuncDef): + current_overload.extend(last_if_overload.items) + else: + current_overload.append(last_if_overload) + last_if_stmt, last_if_overload = None, None + last_if_overload = None + if isinstance(stmt.body[0].body[0], OverloadedFuncDef): + current_overload.extend(stmt.body[0].body[0].items) + else: + current_overload.append(stmt.body[0].body[0]) else: + if last_if_stmt is not None: + ret.append(last_if_stmt) + last_if_stmt, last_if_overload = None, None + if len(current_overload) == 1: ret.append(current_overload[0]) elif len(current_overload) > 1: @@ -466,6 +504,19 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: if isinstance(stmt, Decorator) and not unnamed_function(stmt.name): current_overload = [stmt] current_overload_name = stmt.name + elif ( + isinstance(stmt, IfStmt) + and len(stmt.body[0].body) == 1 + and isinstance( + stmt.body[0].body[0], (Decorator, OverloadedFuncDef)) + and infer_reachability_of_if_statement( + stmt, self.options + ) is None # type: ignore[func-returns-value] + and stmt.body[0].is_unreachable is False + ): + current_overload_name = stmt.body[0].body[0].name + last_if_stmt = stmt + last_if_overload = stmt.body[0].body[0] else: current_overload = [] current_overload_name = None diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index bf7acdc1cd51c..d6831c03a69a0 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -5339,3 +5339,160 @@ def register(cls: Any) -> Any: return None x = register(Foo) reveal_type(x) # N: Revealed type is "builtins.int" [builtins fixtures/dict.pyi] + +[case testOverloadIfBasic] +# flags: --always-true True +from typing import overload, Any + +class A: ... +class B: ... + +@overload +def f1(g: int) -> A: ... +if True: + @overload + def f1(g: str) -> B: ... +def f1(g: Any) -> Any: ... +reveal_type(f1(42)) # N: Revealed type is "__main__.A" +reveal_type(f1("Hello")) # N: Revealed type is "__main__.B" + +@overload +def f2(g: int) -> A: ... +@overload +def f2(g: bytes) -> A: ... +if not True: + @overload + def f2(g: str) -> B: ... +def f2(g: Any) -> Any: ... +reveal_type(f2(42)) # N: Revealed type is "__main__.A" +reveal_type(f2("Hello")) # E: No overload variant of "f2" matches argument type "str" \ + # N: Possible overload variants: \ + # N: def f2(g: int) -> A \ + # N: def f2(g: bytes) -> A \ + # N: Revealed type is "Any" + +[case testOverloadIfSysVersion] +# flags: --python-version 3.9 +from typing import overload, Any +import sys + +class A: ... +class B: ... + +@overload +def f1(g: int) -> A: ... +if sys.version_info >= (3, 9): + @overload + def f1(g: str) -> B: ... +def f1(g: Any) -> Any: ... +reveal_type(f1(42)) # N: Revealed type is "__main__.A" +reveal_type(f1("Hello")) # N: Revealed type is "__main__.B" + +@overload +def f2(g: int) -> A: ... +@overload +def f2(g: bytes) -> A: ... +if sys.version_info >= (3, 10): + @overload + def f2(g: str) -> B: ... +def f2(g: Any) -> Any: ... +reveal_type(f2(42)) # N: Revealed type is "__main__.A" +reveal_type(f2("Hello")) # E: No overload variant of "f2" matches argument type "str" \ + # N: Possible overload variants: \ + # N: def f2(g: int) -> A \ + # N: def f2(g: bytes) -> A \ + # N: Revealed type is "Any" +[builtins fixtures/tuple.pyi] + +[case testOverloadIfMatching] +from typing import overload, Any + +class A: ... +class B: ... +class C: ... + +@overload +def f1(g: int) -> A: ... +if True: + # Some comment + @overload + def f1(g: str) -> B: ... +def f1(g: Any) -> Any: ... +reveal_type(f1(42)) # N: Revealed type is "__main__.A" +reveal_type(f1("Hello")) # N: Revealed type is "__main__.B" + +@overload +def f2(g: int) -> A: ... +if True: + @overload + def f2(g: bytes) -> B: ... + @overload + def f2(g: str) -> C: ... +def f2(g: Any) -> Any: ... +reveal_type(f2(42)) # N: Revealed type is "__main__.A" +reveal_type(f2("Hello")) # N: Revealed type is "__main__.C" + +@overload +def f3(g: int) -> A: ... +@overload +def f3(g: str) -> B: ... +if True: + def f3(g: Any) -> Any: ... +reveal_type(f3(42)) # N: Revealed type is "__main__.A" +reveal_type(f3("Hello")) # N: Revealed type is "__main__.B" + +if True: + @overload + def f4(g: int) -> A: ... +@overload +def f4(g: str) -> B: ... +def f4(g: Any) -> Any: ... +reveal_type(f4(42)) # N: Revealed type is "__main__.A" +reveal_type(f4("Hello")) # N: Revealed type is "__main__.B" + +if True: + # Some comment + @overload + def f5(g: int) -> A: ... + @overload + def f5(g: str) -> B: ... +def f5(g: Any) -> Any: ... +reveal_type(f5(42)) # N: Revealed type is "__main__.A" +reveal_type(f5("Hello")) # N: Revealed type is "__main__.B" + +[case testOverloadIfNotMatching] +from typing import overload, Any + +class A: ... +class B: ... +class C: ... + +@overload # E: An overloaded function outside a stub file must have an implementation +def f1(g: int) -> A: ... +@overload +def f1(g: bytes) -> B: ... +if True: + @overload # E: Name "f1" already defined on line 7 \ + # E: Single overload definition, multiple required + def f1(g: str) -> C: ... + pass # Some other action +def f1(g: Any) -> Any: ... # E: Name "f1" already defined on line 7 +reveal_type(f1(42)) # N: Revealed type is "__main__.A" +reveal_type(f1("Hello")) # E: No overload variant of "f1" matches argument type "str" \ + # N: Possible overload variants: \ + # N: def f1(g: int) -> A \ + # N: def f1(g: bytes) -> B \ + # N: Revealed type is "Any" + +if True: + pass # Some other action + @overload # E: Single overload definition, multiple required + def f2(g: int) -> A: ... +@overload # E: Name "f2" already defined on line 21 +def f2(g: bytes) -> B: ... +@overload +def f2(g: str) -> C: ... +def f2(g: Any) -> Any: ... +reveal_type(f2(42)) # N: Revealed type is "__main__.A" +reveal_type(f2("Hello")) # N: Revealed type is "__main__.A" \ + # E: Argument 1 to "f2" has incompatible type "str"; expected "int"