From dfdcd0e1c9cc32dc8b2bfb6613ea6b5fe001ab76 Mon Sep 17 00:00:00 2001 From: sobolevn Date: Wed, 20 Oct 2021 14:44:41 +0300 Subject: [PATCH] Now overloads with ambiguous `self` are handled properly, refs #11347 --- mypy/checkexpr.py | 16 +++++-- test-data/unit/check-overloading.test | 62 +++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 3 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index e8a2e501a452..717bf737ac46 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2820,6 +2820,7 @@ def infer_overload_return_type( matches: list[CallableType] = [] return_types: list[Type] = [] inferred_types: list[Type] = [] + self_contains_any = has_any_type(object_type) if object_type is not None else False args_contain_any = any(map(has_any_type, arg_types)) type_maps: list[dict[Expression, Type]] = [] @@ -2840,7 +2841,7 @@ def infer_overload_return_type( if is_match: # Return early if possible; otherwise record info, so we can # check for ambiguity due to 'Any' below. - if not args_contain_any: + if not args_contain_any and not self_contains_any: self.chk.store_types(m) return ret_type, infer_type p_infer_type = get_proper_type(infer_type) @@ -2856,7 +2857,14 @@ def infer_overload_return_type( if not matches: return None - elif any_causes_overload_ambiguity(matches, return_types, arg_types, arg_kinds, arg_names): + elif any_causes_overload_ambiguity( + matches, + return_types, + arg_types, + arg_kinds, + arg_names, + self_contains_any=self_contains_any, + ): # An argument of type or containing the type 'Any' caused ambiguity. # We try returning a precise type if we can. If not, we give up and just return 'Any'. if all_same_types(return_types): @@ -6452,6 +6460,8 @@ def any_causes_overload_ambiguity( arg_types: list[Type], arg_kinds: list[ArgKind], arg_names: Sequence[str | None] | None, + *, + self_contains_any: bool = False, ) -> bool: """May an argument containing 'Any' cause ambiguous result type on call to overloaded function? @@ -6501,7 +6511,7 @@ def any_causes_overload_ambiguity( if not all_same_types(matching_formals) and not all_same_types(matching_returns): # Any maps to multiple different types, and the return types of these items differ. return True - return False + return self_contains_any def all_same_types(types: list[Type]) -> bool: diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 7bca5cc7b508..6d4d83b72897 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6693,3 +6693,65 @@ class B: def f(self, *args, **kwargs): pass [builtins fixtures/tuple.pyi] + +[case testOverloadSelfArgWithMultipleMatches] +# https://github.com/python/mypy/issues/11347 +from typing import Generic, TypeVar, overload, Any + +T = TypeVar('T') + +class Some(Generic[T]): + @overload + def method(self: Some[int]) -> str: ... + @overload + def method(self: Some[str]) -> float: ... + def method(self): ... + +s1: Some[int] +s2: Some[str] +s3: Some[Any] +reveal_type(s1.method()) # N: Revealed type is "builtins.str" +reveal_type(s2.method()) # N: Revealed type is "builtins.float" +reveal_type(s3.method()) # N: Revealed type is "Any" +[builtins fixtures/dict.pyi] + +[case testOverloadSelfArgWithOtherSameArgAndMultipleMatches] +from typing import Generic, TypeVar, overload, Any + +T = TypeVar('T') + +class Some(Generic[T]): + @overload + def method(self: Some[int], other: int) -> str: ... + @overload + def method(self: Some[str], other: int) -> float: ... + def method(self): ... + +# was ok +s1: Some[int] +s2: Some[str] +s3: Some[Any] +reveal_type(s1.method(1)) # N: Revealed type is "builtins.str" +reveal_type(s2.method(1)) # N: Revealed type is "builtins.float" +reveal_type(s3.method(1)) # N: Revealed type is "Any" +[builtins fixtures/dict.pyi] + +[case testOverloadSelfArgWithOtherDifferentArgAndMultipleMatches] +from typing import Generic, TypeVar, overload, Any + +T = TypeVar('T') + +class Some(Generic[T]): + @overload + def method(self: Some[int], other: int) -> str: ... + @overload + def method(self: Some[str], other: str) -> float: ... + def method(self): ... + +s1: Some[int] +s2: Some[str] +s3: Some[Any] +reveal_type(s1.method(1)) # N: Revealed type is "builtins.str" +reveal_type(s2.method('a')) # N: Revealed type is "builtins.float" +reveal_type(s3.method(1)) # N: Revealed type is "builtins.str" +[builtins fixtures/dict.pyi]