Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix type argument inference for overloaded functions with explicit self types (Fixes #14943). #14975

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
9 changes: 9 additions & 0 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def __init__(
# Supported for both proper and non-proper
ignore_promotions: bool = False,
ignore_uninhabited: bool = False,
ignore_type_vars: bool = False,
# Proper subtype flags
erase_instances: bool = False,
keep_erased_types: bool = False,
Expand All @@ -96,6 +97,7 @@ def __init__(
self.ignore_declared_variance = ignore_declared_variance
self.ignore_promotions = ignore_promotions
self.ignore_uninhabited = ignore_uninhabited
self.ignore_type_vars = ignore_type_vars
self.erase_instances = erase_instances
self.keep_erased_types = keep_erased_types
self.options = options
Expand All @@ -119,6 +121,7 @@ def is_subtype(
ignore_declared_variance: bool = False,
ignore_promotions: bool = False,
ignore_uninhabited: bool = False,
ignore_type_vars: bool = False,
options: Options | None = None,
) -> bool:
"""Is 'left' subtype of 'right'?
Expand All @@ -139,6 +142,7 @@ def is_subtype(
ignore_declared_variance=ignore_declared_variance,
ignore_promotions=ignore_promotions,
ignore_uninhabited=ignore_uninhabited,
ignore_type_vars=ignore_type_vars,
options=options,
)
else:
Expand Down Expand Up @@ -287,6 +291,11 @@ def _is_subtype(
# ErasedType as we do for non-proper subtyping.
return True

if subtype_context.ignore_type_vars and (
isinstance(left, TypeVarType) or isinstance(right, TypeVarType)
):
return True

if isinstance(right, UnionType) and not isinstance(left, UnionType):
# Normally, when 'left' is not itself a union, the only way
# 'left' can be a subtype of the union 'right' is if it is a
Expand Down
20 changes: 19 additions & 1 deletion mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,28 @@ class B(A): pass
b = B().copy() # type: B

"""

from mypy.subtypes import is_subtype

if isinstance(method, Overloaded):
# Try to remove overload items with non-matching self types first (fixes #14943)
origtype = get_proper_type(original_type)
if isinstance(origtype, Instance):
methoditems = []
for mi in method.items:
selftype = get_proper_type(mi.arg_types[0])
if not isinstance(selftype, Instance) or is_subtype(
origtype, selftype, ignore_type_vars=True
):
methoditems.append(mi)
if len(methoditems) == 0:
methoditems = method.items
else:
methoditems = method.items
return cast(
F, Overloaded([bind_self(c, original_type, is_classmethod) for c in method.items])
F, Overloaded([bind_self(mi, original_type, is_classmethod) for mi in methoditems])
)

assert isinstance(method, CallableType)
func = method
if not func.arg_types:
Expand Down
39 changes: 39 additions & 0 deletions test-data/unit/check-protocols.test
Original file line number Diff line number Diff line change
Expand Up @@ -4020,3 +4020,42 @@ class P(Protocol):

[file lib.py]
class C: ...

[case TestOverloadedMethodWithExplictSelfTypes]
from typing import Generic, overload, Protocol, TypeVar, Union

AnyStr = TypeVar("AnyStr", str, bytes)
T_co = TypeVar("T_co", covariant=True)
T_contra = TypeVar("T_contra", contravariant=True)

class SupportsRead(Protocol[T_co]):
def read(self) -> T_co: ...

class SupportsWrite(Protocol[T_contra]):
def write(self, s: T_contra) -> int: ...

class Input(Generic[AnyStr]):
def read(self) -> AnyStr: ...

class Output(Generic[AnyStr]):
@overload
def write(self: Output[str], s: str) -> int: ...
@overload
def write(self: Output[bytes], s: bytes) -> int: ...
def write(self, s: Union[str, bytes]) -> int: ...

def f(src: SupportsRead[AnyStr], dst: SupportsWrite[AnyStr]) -> None: ...

def g1(a: Input[bytes], b: Output[bytes]) -> None:
f(a, b)

def g2(a: Input[bytes], b: Output[bytes]) -> None:
f(a, b)

def g3(a: Input[str], b: Output[bytes]) -> None:
f(a, b) # E: Cannot infer type argument 1 of "f"

def g4(a: Input[bytes], b: Output[str]) -> None:
f(a, b) # E: Cannot infer type argument 1 of "f"

[builtins fixtures/tuple.pyi]