Skip to content

Commit

Permalink
Add support for conditionally defined overloads
Browse files Browse the repository at this point in the history
  • Loading branch information
cdce8p committed Nov 7, 2021
1 parent 2db0511 commit 424ba75
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 1 deletion.
53 changes: 52 additions & 1 deletion mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from mypy import message_registry, errorcodes as codes
from mypy.errors import Errors
from mypy.options import Options
from mypy.reachability import mark_block_unreachable
from mypy.reachability import infer_reachability_of_if_statement, mark_block_unreachable

try:
# pull this into a final variable to make mypyc be quiet about the
Expand Down Expand Up @@ -447,12 +447,50 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
ret: List[Statement] = []
current_overload: List[OverloadPart] = []
current_overload_name: Optional[str] = None
last_if_stmt: Optional[IfStmt] = None
last_if_overload: Optional[Union[Decorator, OverloadedFuncDef]] = None
for stmt in stmts:
if (current_overload_name is not None
and isinstance(stmt, (Decorator, FuncDef))
and stmt.name == current_overload_name):
if last_if_overload is not None:
if isinstance(last_if_overload, OverloadedFuncDef):
current_overload.extend(last_if_overload.items)
else:
current_overload.append(last_if_overload)
last_if_stmt, last_if_overload = None, None
current_overload.append(stmt)
elif (
current_overload_name is not None
and isinstance(stmt, IfStmt)
and len(stmt.body[0].body) == 1
and isinstance(
stmt.body[0].body[0], (Decorator, FuncDef, OverloadedFuncDef))
and stmt.body[0].body[0].name == current_overload_name
):
# IfStmt only contains stmts relevant to current_overload.
# Check if stmts are reachable and add them to current_overload,
# otherwise skip IfStmt to allow subsequent overload
# or function definitions.
infer_reachability_of_if_statement(stmt, self.options)
if stmt.body[0].is_unreachable is True:
continue
if last_if_overload is not None:
if isinstance(last_if_overload, OverloadedFuncDef):
current_overload.extend(last_if_overload.items)
else:
current_overload.append(last_if_overload)
last_if_stmt, last_if_overload = None, None
last_if_overload = None
if isinstance(stmt.body[0].body[0], OverloadedFuncDef):
current_overload.extend(stmt.body[0].body[0].items)
else:
current_overload.append(stmt.body[0].body[0])
else:
if last_if_stmt is not None:
ret.append(last_if_stmt)
last_if_stmt, last_if_overload = None, None

if len(current_overload) == 1:
ret.append(current_overload[0])
elif len(current_overload) > 1:
Expand All @@ -466,6 +504,19 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
if isinstance(stmt, Decorator) and not unnamed_function(stmt.name):
current_overload = [stmt]
current_overload_name = stmt.name
elif (
isinstance(stmt, IfStmt)
and len(stmt.body[0].body) == 1
and isinstance(
stmt.body[0].body[0], (Decorator, OverloadedFuncDef))
and infer_reachability_of_if_statement(
stmt, self.options
) is None # type: ignore[func-returns-value]
and stmt.body[0].is_unreachable is False
):
current_overload_name = stmt.body[0].body[0].name
last_if_stmt = stmt
last_if_overload = stmt.body[0].body[0]
else:
current_overload = []
current_overload_name = None
Expand Down
157 changes: 157 additions & 0 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -5339,3 +5339,160 @@ def register(cls: Any) -> Any: return None
x = register(Foo)
reveal_type(x) # N: Revealed type is "builtins.int"
[builtins fixtures/dict.pyi]

[case testOverloadIfBasic]
# flags: --always-true True
from typing import overload, Any

class A: ...
class B: ...

@overload
def f1(g: int) -> A: ...
if True:
@overload
def f1(g: str) -> B: ...
def f1(g: Any) -> Any: ...
reveal_type(f1(42)) # N: Revealed type is "__main__.A"
reveal_type(f1("Hello")) # N: Revealed type is "__main__.B"

@overload
def f2(g: int) -> A: ...
@overload
def f2(g: bytes) -> A: ...
if not True:
@overload
def f2(g: str) -> B: ...
def f2(g: Any) -> Any: ...
reveal_type(f2(42)) # N: Revealed type is "__main__.A"
reveal_type(f2("Hello")) # E: No overload variant of "f2" matches argument type "str" \
# N: Possible overload variants: \
# N: def f2(g: int) -> A \
# N: def f2(g: bytes) -> A \
# N: Revealed type is "Any"

[case testOverloadIfSysVersion]
# flags: --python-version 3.9
from typing import overload, Any
import sys

class A: ...
class B: ...

@overload
def f1(g: int) -> A: ...
if sys.version_info >= (3, 9):
@overload
def f1(g: str) -> B: ...
def f1(g: Any) -> Any: ...
reveal_type(f1(42)) # N: Revealed type is "__main__.A"
reveal_type(f1("Hello")) # N: Revealed type is "__main__.B"

@overload
def f2(g: int) -> A: ...
@overload
def f2(g: bytes) -> A: ...
if sys.version_info >= (3, 10):
@overload
def f2(g: str) -> B: ...
def f2(g: Any) -> Any: ...
reveal_type(f2(42)) # N: Revealed type is "__main__.A"
reveal_type(f2("Hello")) # E: No overload variant of "f2" matches argument type "str" \
# N: Possible overload variants: \
# N: def f2(g: int) -> A \
# N: def f2(g: bytes) -> A \
# N: Revealed type is "Any"
[builtins fixtures/tuple.pyi]

[case testOverloadIfMatching]
from typing import overload, Any

class A: ...
class B: ...
class C: ...

@overload
def f1(g: int) -> A: ...
if True:
# Some comment
@overload
def f1(g: str) -> B: ...
def f1(g: Any) -> Any: ...
reveal_type(f1(42)) # N: Revealed type is "__main__.A"
reveal_type(f1("Hello")) # N: Revealed type is "__main__.B"

@overload
def f2(g: int) -> A: ...
if True:
@overload
def f2(g: bytes) -> B: ...
@overload
def f2(g: str) -> C: ...
def f2(g: Any) -> Any: ...
reveal_type(f2(42)) # N: Revealed type is "__main__.A"
reveal_type(f2("Hello")) # N: Revealed type is "__main__.C"

@overload
def f3(g: int) -> A: ...
@overload
def f3(g: str) -> B: ...
if True:
def f3(g: Any) -> Any: ...
reveal_type(f3(42)) # N: Revealed type is "__main__.A"
reveal_type(f3("Hello")) # N: Revealed type is "__main__.B"

if True:
@overload
def f4(g: int) -> A: ...
@overload
def f4(g: str) -> B: ...
def f4(g: Any) -> Any: ...
reveal_type(f4(42)) # N: Revealed type is "__main__.A"
reveal_type(f4("Hello")) # N: Revealed type is "__main__.B"

if True:
# Some comment
@overload
def f5(g: int) -> A: ...
@overload
def f5(g: str) -> B: ...
def f5(g: Any) -> Any: ...
reveal_type(f5(42)) # N: Revealed type is "__main__.A"
reveal_type(f5("Hello")) # N: Revealed type is "__main__.B"

[case testOverloadIfNotMatching]
from typing import overload, Any

class A: ...
class B: ...
class C: ...

@overload # E: An overloaded function outside a stub file must have an implementation
def f1(g: int) -> A: ...
@overload
def f1(g: bytes) -> B: ...
if True:
@overload # E: Name "f1" already defined on line 7 \
# E: Single overload definition, multiple required
def f1(g: str) -> C: ...
pass # Some other action
def f1(g: Any) -> Any: ... # E: Name "f1" already defined on line 7
reveal_type(f1(42)) # N: Revealed type is "__main__.A"
reveal_type(f1("Hello")) # E: No overload variant of "f1" matches argument type "str" \
# N: Possible overload variants: \
# N: def f1(g: int) -> A \
# N: def f1(g: bytes) -> B \
# N: Revealed type is "Any"

if True:
pass # Some other action
@overload # E: Single overload definition, multiple required
def f2(g: int) -> A: ...
@overload # E: Name "f2" already defined on line 21
def f2(g: bytes) -> B: ...
@overload
def f2(g: str) -> C: ...
def f2(g: Any) -> Any: ...
reveal_type(f2(42)) # N: Revealed type is "__main__.A"
reveal_type(f2("Hello")) # N: Revealed type is "__main__.A" \
# E: Argument 1 to "f2" has incompatible type "str"; expected "int"

0 comments on commit 424ba75

Please sign in to comment.