Skip to content

Commit

Permalink
Improved the protocol matching logic so it honors partially-solved ty…
Browse files Browse the repository at this point in the history
…pe variables whose values are provided by other argument types in a call. This addresses #5243. (#5273)

Co-authored-by: Eric Traut <[email protected]>
  • Loading branch information
erictraut and msfterictraut authored Jun 12, 2023
1 parent 783633a commit 7ffe61a
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 2 deletions.
43 changes: 41 additions & 2 deletions packages/pyright-internal/src/analyzer/protocols.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import { assignProperty } from './properties';
import { TypeEvaluator } from './typeEvaluatorTypes';
import {
ClassType,
isAnyOrUnknown,
isClassInstance,
isFunction,
isInstantiableClass,
Expand All @@ -23,7 +24,9 @@ import {
maxTypeRecursionCount,
ModuleType,
Type,
TypeVarType,
UnknownType,
Variance,
} from './types';
import {
applySolvedTypeVars,
Expand All @@ -36,6 +39,7 @@ import {
partiallySpecializeType,
populateTypeVarContextForSelfType,
removeParamSpecVariadicsFromSignature,
requiresSpecialization,
} from './typeUtils';
import { TypeVarContext } from './typeVarContext';

Expand Down Expand Up @@ -145,7 +149,7 @@ function assignClassToProtocolInternal(
return isTypeSame(destType, srcType);
}

const protocolTypeVarContext = new TypeVarContext(getTypeVarScopeId(destType));
const protocolTypeVarContext = createProtocolTypeVarContext(evaluator, destType, destTypeVarContext);
const selfTypeVarContext = new TypeVarContext(getTypeVarScopeId(destType));
const noLiteralSrcType = evaluator.stripLiteralValue(srcType) as ClassType;
populateTypeVarContextForSelfType(selfTypeVarContext, destType, noLiteralSrcType);
Expand Down Expand Up @@ -484,7 +488,7 @@ export function assignModuleToProtocol(

let typesAreConsistent = true;
const checkedSymbolSet = new Set<string>();
const protocolTypeVarContext = new TypeVarContext(getTypeVarScopeId(destType));
const protocolTypeVarContext = createProtocolTypeVarContext(evaluator, destType, destTypeVarContext);

destType.details.mro.forEach((mroClass) => {
if (!isInstantiableClass(mroClass) || !ClassType.isProtocolClass(mroClass)) {
Expand Down Expand Up @@ -581,3 +585,38 @@ export function assignModuleToProtocol(

return typesAreConsistent;
}

function createProtocolTypeVarContext(
evaluator: TypeEvaluator,
destType: ClassType,
destTypeVarContext: TypeVarContext | undefined
) {
const protocolTypeVarContext = new TypeVarContext(getTypeVarScopeId(destType));
if (destTypeVarContext && destType?.typeArguments) {
// Infer the type parameter variance because we need it below.
evaluator.inferTypeParameterVarianceForClass(destType);

// Populate the typeVarContext with any concrete constraints that
// have already been solved.
const specializedDestType = applySolvedTypeVars(destType, destTypeVarContext, {
useNarrowBoundOnly: true,
}) as ClassType;
destType.details.typeParameters.forEach((typeParam, index) => {
if (index < specializedDestType.typeArguments!.length) {
const typeArg = specializedDestType.typeArguments![index];

if (!requiresSpecialization(typeArg) && !isAnyOrUnknown(typeArg)) {
const typeParamVariance = TypeVarType.getVariance(typeParam);
protocolTypeVarContext.setTypeVarType(
typeParam,
typeParamVariance !== Variance.Contravariant ? typeArg : undefined,
/* narrowBoundNoLiterals */ undefined,
typeParamVariance !== Variance.Covariant ? typeArg : undefined
);
}
}
});
}

return protocolTypeVarContext;
}
48 changes: 48 additions & 0 deletions packages/pyright-internal/src/tests/samples/protocol41.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# This sample verifies that a generic protocol that is specialized with
# a type variable can be matched if that type variable's type is
# supplied by another argument in a call.

from typing import Protocol, TypeVar

_T_co = TypeVar("_T_co", covariant=True)
_T_contra = TypeVar("_T_contra", contravariant=True)


class MyStr:
...


class MyBytes:
def __buffer__(self, __flags: int) -> memoryview:
...


MyAnyStr = TypeVar("MyAnyStr", MyStr, MyBytes)


class Buffer(Protocol):
def __buffer__(self, __flags: int) -> memoryview:
...


class SupportsRead(Protocol[_T_co]):
def read(self, __length: int = ...) -> _T_co:
...


class SupportsWrite(Protocol[_T_contra]):
def write(self, __s: _T_contra) -> object:
...


class BufferedWriter:
def write(self, __buffer: Buffer) -> int:
raise NotImplementedError


def f(s: SupportsRead[MyAnyStr], t: SupportsWrite[MyAnyStr]) -> None:
...


def h(src: SupportsRead[MyBytes], tgt: BufferedWriter) -> None:
f(src, tgt)
6 changes: 6 additions & 0 deletions packages/pyright-internal/src/tests/typeEvaluator2.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1359,6 +1359,12 @@ test('Protocol40', () => {
TestUtils.validateResults(analysisResults, 0);
});

test('Protocol41', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['protocol41.py']);

TestUtils.validateResults(analysisResults, 0);
});

test('TypedDict1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typedDict1.py']);

Expand Down

0 comments on commit 7ffe61a

Please sign in to comment.