Skip to content

Commit

Permalink
Fixed a bug in type evaluation of the two-argument form of the `super…
Browse files Browse the repository at this point in the history
…` call. There were situations where the incorrect MRO class was used. This addresses #5299.
  • Loading branch information
msfterictraut committed Jun 14, 2023
1 parent 7bfe315 commit e805ca9
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 28 deletions.
1 change: 1 addition & 0 deletions packages/pyright-internal/src/analyzer/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5450,6 +5450,7 @@ export class Checker extends ParseTreeWalker {
continue;
}

assert(isClass(mroBaseClass));
const baseClassAndSymbol = lookUpClassMember(mroBaseClass, name, ClassMemberLookupFlags.Default);
if (!baseClassAndSymbol) {
continue;
Expand Down
10 changes: 5 additions & 5 deletions packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8089,11 +8089,11 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
const parentNode = node.parent!;
if (parentNode.nodeType === ParseNodeType.MemberAccess) {
const memberName = parentNode.memberName.value;
const lookupResults = lookUpClassMember(
targetClassType,
memberName,
ClassMemberLookupFlags.SkipOriginalClass
);
const effectiveTargetClass = isClass(targetClassType) ? targetClassType : undefined;

const lookupResults = bindToType
? lookUpClassMember(bindToType, memberName, ClassMemberLookupFlags.Default, effectiveTargetClass)
: undefined;
if (lookupResults && isInstantiableClass(lookupResults.classType)) {
return {
type: resultIsInstance
Expand Down
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/analyzer/typeGuards.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1932,7 +1932,7 @@ function narrowTypeForDiscriminatedLiteralFieldComparison(

// Handle the case where the field is a property
// that has a declared literal return type for its getter.
if (isClassInstance(subtype) && isProperty(memberType)) {
if (isClassInstance(subtype) && isClassInstance(memberType) && isProperty(memberType)) {
const getterInfo = lookUpObjectMember(memberType, 'fget');

if (getterInfo && getterInfo.isTypeDeclared) {
Expand Down
57 changes: 35 additions & 22 deletions packages/pyright-internal/src/analyzer/typeUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -121,21 +121,16 @@ export const enum ClassMemberLookupFlags {
export const enum ClassIteratorFlags {
Default = 0,

// By default, the original (derived) class is searched along
// with its base classes. If this flag is set, the original
// class is skipped and only the base classes are searched.
SkipOriginalClass = 1 << 0,

// By default, base classes are searched as well as the
// original (derived) class. If this flag is set, no recursion
// is performed.
SkipBaseClasses = 1 << 1,
SkipBaseClasses = 1 << 0,

// Skip the 'object' base class in particular.
SkipObjectBaseClass = 1 << 2,
SkipObjectBaseClass = 1 << 1,

// Skip the 'type' base class in particular.
SkipTypeBaseClass = 1 << 3,
SkipTypeBaseClass = 1 << 2,
}

export const enum AssignTypeFlags {
Expand Down Expand Up @@ -1222,12 +1217,13 @@ export function getContainerDepth(type: Type, recursionCount = 0) {
}

export function lookUpObjectMember(
objectType: Type,
objectType: ClassType,
memberName: string,
flags = ClassMemberLookupFlags.Default
flags = ClassMemberLookupFlags.Default,
skipMroClass?: ClassType | undefined
): ClassMember | undefined {
if (isClassInstance(objectType)) {
return lookUpClassMember(objectType, memberName, flags);
return lookUpClassMember(objectType, memberName, flags, skipMroClass);
}

return undefined;
Expand All @@ -1236,11 +1232,12 @@ export function lookUpObjectMember(
// Looks up a member in a class using the multiple-inheritance rules
// defined by Python.
export function lookUpClassMember(
classType: Type,
classType: ClassType,
memberName: string,
flags = ClassMemberLookupFlags.Default
flags = ClassMemberLookupFlags.Default,
skipMroClass?: ClassType | undefined
): ClassMember | undefined {
const memberItr = getClassMemberIterator(classType, memberName, flags);
const memberItr = getClassMemberIterator(classType, memberName, flags, skipMroClass);

return memberItr.next()?.value;
}
Expand All @@ -1253,14 +1250,23 @@ export function lookUpClassMember(
// ClassB[str] which inherits from Dict[_T1, int], a search for '__iter__'
// would return a class type of Dict[str, int] and a symbolType of
// (self) -> Iterator[str].
export function* getClassMemberIterator(classType: Type, memberName: string, flags = ClassMemberLookupFlags.Default) {
// If skipMroClass is defined, all MRO classes up to and including that class
// are skipped.
export function* getClassMemberIterator(
classType: ClassType | AnyType | UnknownType,
memberName: string,
flags = ClassMemberLookupFlags.Default,
skipMroClass?: ClassType | undefined
) {
const declaredTypesOnly = (flags & ClassMemberLookupFlags.DeclaredTypesOnly) !== 0;
let skippedUndeclaredType = false;

if (isClass(classType)) {
let classFlags = ClassIteratorFlags.Default;
if (flags & ClassMemberLookupFlags.SkipOriginalClass) {
classFlags = classFlags | ClassIteratorFlags.SkipOriginalClass;
if (isClass(classType)) {
skipMroClass = classType;
}
}
if (flags & ClassMemberLookupFlags.SkipBaseClasses) {
classFlags = classFlags | ClassIteratorFlags.SkipBaseClasses;
Expand All @@ -1272,7 +1278,7 @@ export function* getClassMemberIterator(classType: Type, memberName: string, fla
classFlags = classFlags | ClassIteratorFlags.SkipTypeBaseClass;
}

const classItr = getClassIterator(classType, classFlags);
const classItr = getClassIterator(classType, classFlags, skipMroClass);

for (const [mroClass, specializedMroClass] of classItr) {
if (!isInstantiableClass(mroClass)) {
Expand Down Expand Up @@ -1377,14 +1383,21 @@ export function* getClassMemberIterator(classType: Type, memberName: string, fla
return undefined;
}

export function* getClassIterator(classType: Type, flags = ClassIteratorFlags.Default) {
export function* getClassIterator(classType: Type, flags = ClassIteratorFlags.Default, skipMroClass?: ClassType) {
if (isClass(classType)) {
let skipMroEntry = (flags & ClassIteratorFlags.SkipOriginalClass) !== 0;
let foundSkipMroClass = skipMroClass === undefined;

for (const mroClass of classType.details.mro) {
if (skipMroEntry) {
skipMroEntry = false;
continue;
// Are we still searching fro teh skipMroClass?
if (!foundSkipMroClass && skipMroClass) {
if (!isClass(mroClass)) {
foundSkipMroClass = true;
} else if (ClassType.isSameGenericClass(mroClass, skipMroClass)) {
foundSkipMroClass = true;
continue;
} else {
continue;
}
}

// If mroClass is an ancestor of classType, partially specialize
Expand Down
35 changes: 35 additions & 0 deletions packages/pyright-internal/src/tests/samples/super2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,38 @@ def factoryB(cls):

b2 = B.factoryB()
reveal_type(b2, expected_text="B")


class C:
def __init__(self) -> None:
...


class CChild(C):
def __init__(self, name: str) -> None:
...


class D:
def __init__(self, name: str, num: int):
...


class DChild1(CChild, D):
def __init__(self, name: str, num: int) -> None:
super(C, self).__init__(name, num)


class DChild2(CChild, D):
def __init__(self, name: str) -> None:
super(DChild2, self).__init__(name)


class DChild3(CChild, D):
def __init__(self) -> None:
super(CChild, self).__init__()


d1 = DChild1("", 1)
d2 = DChild2("")
d3 = DChild3()

0 comments on commit e805ca9

Please sign in to comment.