Skip to content

Commit

Permalink
Improved type inference for lambdas in the case where a parameter inc…
Browse files Browse the repository at this point in the history
…ludes a default value and the expected type doesn't include that parameter. This improvement was suggested in the [mypy issue tracker](python/mypy#15459). (#5337)

Co-authored-by: Eric Traut <[email protected]>
  • Loading branch information
erictraut and msfterictraut authored Jun 18, 2023
1 parent 8ce23eb commit 643bb1d
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 5 deletions.
19 changes: 14 additions & 5 deletions packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13231,17 +13231,26 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
// For now, use only the first expected type.
const expectedFunctionType = expectedFunctionTypes.length > 0 ? expectedFunctionTypes[0] : undefined;
let paramsArePositionOnly = true;
const expectedParamDetails = expectedFunctionType ? getParameterListDetails(expectedFunctionType) : undefined;

node.parameters.forEach((param, index) => {
let paramType: Type = UnknownType.create();
if (expectedFunctionType && index < expectedFunctionType.details.parameters.length) {
paramType = FunctionType.getEffectiveParameterType(expectedFunctionType, index);
let paramType: Type | undefined;
if (expectedParamDetails) {
if (index < expectedParamDetails.params.length) {
paramType = expectedParamDetails.params[index].type;
} else if (param.defaultValue) {
// If the lambda param has a default value but there is no associated
// parameter in the expected type, assume that the default value is
// being used to explicitly capture a value from an outer scope. Infer
// its type from the default value expression.
paramType = getTypeOfExpression(param.defaultValue, undefined, inferenceContext).type;
}
}

if (param.name) {
writeTypeCache(
param.name,
{ type: transformVariadicParamType(node, param.category, paramType) },
{ type: transformVariadicParamType(node, param.category, paramType ?? UnknownType.create()) },
EvaluatorFlags.None
);
}
Expand Down Expand Up @@ -13288,7 +13297,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
hasDefault: !!param.defaultValue,
defaultValueExpression: param.defaultValue,
hasDeclaredType: true,
type: paramType,
type: paramType ?? UnknownType.create(),
};

FunctionType.addParameter(functionType, functionParam);
Expand Down
12 changes: 12 additions & 0 deletions packages/pyright-internal/src/tests/samples/lambda12.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# This sample tests the case where a lambda includes one or more parameters
# that accept a default value and the the expected type does not include
# these parameters. In this case, the types of the extra parameters should
# be inferred based on the default value type.

# pyright: strict

from typing import Callable


def func1() -> list[Callable[[int], int]]:
return [lambda x, i=i: i * x for i in range(5)]
6 changes: 6 additions & 0 deletions packages/pyright-internal/src/tests/typeEvaluator1.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,12 @@ test('Lambda11', () => {
TestUtils.validateResults(analysisResults, 0);
});

test('Lambda12', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['lambda12.py']);

TestUtils.validateResults(analysisResults, 0);
});

test('Call1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['call1.py']);

Expand Down

0 comments on commit 643bb1d

Please sign in to comment.