diff --git a/packages/pyright-internal/src/analyzer/dataClasses.ts b/packages/pyright-internal/src/analyzer/dataClasses.ts index a2ec3f02642c..0ed2710d61a4 100644 --- a/packages/pyright-internal/src/analyzer/dataClasses.ts +++ b/packages/pyright-internal/src/analyzer/dataClasses.ts @@ -25,10 +25,11 @@ import { TypeAnnotationNode, } from '../parser/parseNodes'; import * as AnalyzerNodeInfo from './analyzerNodeInfo'; +import { getFileInfo } from './analyzerNodeInfo'; import { createFunctionFromConstructor } from './constructors'; import { DeclarationType } from './declaration'; import { updateNamedTupleBaseClass } from './namedTuples'; -import { getEnclosingClassOrFunction, getScopeIdForNode } from './parseTreeUtils'; +import { getClassFullName, getEnclosingClassOrFunction, getScopeIdForNode, getTypeSourceId } from './parseTreeUtils'; import { evaluateStaticBoolExpression } from './staticExpressions'; import { Symbol, SymbolFlags } from './symbol'; import { isPrivateName } from './symbolNameUtils'; @@ -513,7 +514,19 @@ export function synthesizeDataClassMethods( effectiveType = transformDescriptorType(evaluator, effectiveType); if (entry.converter) { + const fieldType = effectiveType; effectiveType = getConverterInputType(evaluator, entry.converter, effectiveType, entry.name); + symbolTable.set( + entry.name, + getDescriptorForConverterField( + evaluator, + node, + entry.converter, + entry.name, + fieldType, + effectiveType + ) + ); } const effectiveName = entry.alias || entry.name; @@ -799,6 +812,88 @@ function getConverterAsFunction( return undefined; } +// Synthesizes an asymmetric descriptor class to be used in place of the +// annotated type of a field with a converter. The descriptor's __get__ method +// returns the declared type of the field and its __set__ method accepts the +// converter's input type. Returns the symbol for an instance of this descriptor +// type. +function getDescriptorForConverterField( + evaluator: TypeEvaluator, + dataclassNode: ParseNode, + converterNode: ParseNode, + fieldName: string, + getType: Type, + setType: Type +): Symbol { + const fileInfo = getFileInfo(dataclassNode); + const typeMetaclass = evaluator.getBuiltInType(dataclassNode, 'type'); + const descriptorName = `__converterDescriptor_${fieldName}`; + + const descriptorClass = ClassType.createInstantiable( + descriptorName, + getClassFullName(converterNode, fileInfo.moduleName, descriptorName), + fileInfo.moduleName, + fileInfo.filePath, + ClassTypeFlags.None, + getTypeSourceId(converterNode), + /* declaredMetaclass */ undefined, + isInstantiableClass(typeMetaclass) ? typeMetaclass : UnknownType.create() + ); + descriptorClass.details.baseClasses.push(evaluator.getBuiltInType(dataclassNode, 'object')); + computeMroLinearization(descriptorClass); + + const fields = descriptorClass.details.fields; + const selfType = synthesizeTypeVarForSelfCls(descriptorClass, /* isClsParam */ false); + + const setFunction = FunctionType.createSynthesizedInstance('__set__'); + FunctionType.addParameter(setFunction, { + category: ParameterCategory.Simple, + name: 'self', + type: selfType, + hasDeclaredType: true, + }); + FunctionType.addParameter(setFunction, { + category: ParameterCategory.Simple, + name: 'obj', + type: AnyType.create(), + hasDeclaredType: true, + }); + FunctionType.addParameter(setFunction, { + category: ParameterCategory.Simple, + name: 'value', + type: setType, + hasDeclaredType: true, + }); + setFunction.details.declaredReturnType = NoneType.createInstance(); + const setSymbol = Symbol.createWithType(SymbolFlags.ClassMember, setFunction); + fields.set('__set__', setSymbol); + + const getFunction = FunctionType.createSynthesizedInstance('__get__'); + FunctionType.addParameter(getFunction, { + category: ParameterCategory.Simple, + name: 'self', + type: selfType, + hasDeclaredType: true, + }); + FunctionType.addParameter(getFunction, { + category: ParameterCategory.Simple, + name: 'obj', + type: AnyType.create(), + hasDeclaredType: true, + }); + FunctionType.addParameter(getFunction, { + category: ParameterCategory.Simple, + name: 'objtype', + type: AnyType.create(), + hasDeclaredType: true, + }); + getFunction.details.declaredReturnType = getType; + const getSymbol = Symbol.createWithType(SymbolFlags.ClassMember, getFunction); + fields.set('__get__', getSymbol); + + return Symbol.createWithType(SymbolFlags.ClassMember, ClassType.cloneAsInstance(descriptorClass)); +} + // If the specified type is a descriptor — in particular, if it implements a // __set__ method, this method transforms the type into the input parameter // for the set method. diff --git a/packages/pyright-internal/src/tests/samples/dataclassConverter2.py b/packages/pyright-internal/src/tests/samples/dataclassConverter2.py new file mode 100644 index 000000000000..83c68b673ec3 --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/dataclassConverter2.py @@ -0,0 +1,38 @@ +# This sample tests assignment of dataclass fields that use +# the coverter parameter described in PEP 712. + +from dataclasses import dataclass, field + + +def converter_simple(s: str) -> int: ... +def converter_passThru(x: str | int) -> str | int: ... + +@dataclass +class Foo: + # This should generate an error because "converter" is not an official property yet. + asymmetric: int = field(converter=converter_simple) + # This should generate an error because "converter" is not an official property yet. + symmetric: str | int = field(converter=converter_passThru) + +foo = Foo("1", 1) + +reveal_type(foo.asymmetric, expected_text="int") +foo.asymmetric = "2" +reveal_type(foo.asymmetric, expected_text="int") # Asymmetric -- type narrowing should not occur +# This should generate an error because only strs can be assigned to field0. +foo.asymmetric = 2 + +reveal_type(foo.symmetric, expected_text="str | int") +foo.symmetric = "1" +reveal_type(foo.symmetric, expected_text="Literal['1']") # Symmetric -- type narrowing should occur + + +reveal_type(Foo.asymmetric, expected_text="int") +Foo.asymmetric = "2" +reveal_type(Foo.asymmetric, expected_text="int") +# This should generate an error because only strs can be assigned to field0. +Foo.asymmetric = 2 + +reveal_type(Foo.symmetric, expected_text="str | int") +Foo.symmetric = "1" +reveal_type(Foo.symmetric, expected_text="Literal['1']") \ No newline at end of file diff --git a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts index 25d8fb31ec60..88271ff3d786 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts @@ -679,6 +679,12 @@ test('DataClassConverter1', () => { TestUtils.validateResults(analysisResults, 17); }); +test('DataClassConverter2', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['dataclassConverter2.py']); + + TestUtils.validateResults(analysisResults, 4); +}); + test('DataClassPostInit1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['dataclassPostInit1.py']);