From e50c043a4f0fc642ce8384eb80cf32e74e526541 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Sun, 18 Jun 2023 17:05:32 -0700 Subject: [PATCH] Infer function parameter types using annotations Fixes https://github.com/pylint-dev/pylint/issues/4813 & https://github.com/pylint-dev/pylint/issues/8781 --- astroid/inference.py | 24 ++++++++++++++++++++++-- tests/test_inference.py | 18 ++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/astroid/inference.py b/astroid/inference.py index 6dcfa49f1b..bf61a6377f 100644 --- a/astroid/inference.py +++ b/astroid/inference.py @@ -1168,10 +1168,30 @@ def infer_assign( assign node. """ if isinstance(self.parent, nodes.AugAssign): - return self.parent.infer(context) + yield from self.parent.infer(context) + return stmts = list(self.assigned_stmts(context=context)) - return bases._infer_stmts(stmts, context) + yield from bases._infer_stmts(stmts, context) + + # Infer function parameter types using their annotations, if present. + if ( + isinstance(self, nodes.AssignName) + and isinstance(self.parent, nodes.Arguments) + and self.parent.args is not None + ): + try: + annotation_index = self.parent.args.index(self) + except ValueError: + annotation_index = None + + if ( + annotation_index is not None + and (annotation := self.parent.annotations[annotation_index]) is not None + ): + for annotation_result in annotation.infer(context=context): + if isinstance(annotation_result, nodes.ClassDef): + yield bases.Instance(annotation_result) nodes.AssignName._infer = infer_assign diff --git a/tests/test_inference.py b/tests/test_inference.py index 6760f9c91b..92521091e7 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -4488,6 +4488,24 @@ def test_uninferable_type_subscript(self) -> None: with self.assertRaises(InferenceError): _ = next(node.infer()) + def test_infer_parameters_from_type_hints(self) -> None: + node = extract_node( + """ + class Logger: + def info(self, msg: str) -> None: + ... + + class MyClassThatLogs: + def __init__(self, logger: Logger) -> None: + self.logger = logger + + my_class_that_logs = MyClassThatLogs(Logger()) + my_class_that_logs.logger + """ + ) + inferred = list(node.infer()) + assert isinstance(inferred[1], Instance) + class GetattrTest(unittest.TestCase): def test_yes_when_unknown(self) -> None: