Skip to content

Commit

Permalink
Infer function parameter types using annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
ringohoffman committed Jun 19, 2023
1 parent eabc643 commit e50c043
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
24 changes: 22 additions & 2 deletions astroid/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,10 +1168,30 @@ def infer_assign(
assign node.
"""
if isinstance(self.parent, nodes.AugAssign):
return self.parent.infer(context)
yield from self.parent.infer(context)
return

stmts = list(self.assigned_stmts(context=context))
return bases._infer_stmts(stmts, context)
yield from bases._infer_stmts(stmts, context)

# Infer function parameter types using their annotations, if present.
if (
isinstance(self, nodes.AssignName)
and isinstance(self.parent, nodes.Arguments)
and self.parent.args is not None
):
try:
annotation_index = self.parent.args.index(self)
except ValueError:
annotation_index = None

if (
annotation_index is not None
and (annotation := self.parent.annotations[annotation_index]) is not None
):
for annotation_result in annotation.infer(context=context):
if isinstance(annotation_result, nodes.ClassDef):
yield bases.Instance(annotation_result)


nodes.AssignName._infer = infer_assign
Expand Down
18 changes: 18 additions & 0 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4488,6 +4488,24 @@ def test_uninferable_type_subscript(self) -> None:
with self.assertRaises(InferenceError):
_ = next(node.infer())

def test_infer_parameters_from_type_hints(self) -> None:
node = extract_node(
"""
class Logger:
def info(self, msg: str) -> None:
...
class MyClassThatLogs:
def __init__(self, logger: Logger) -> None:
self.logger = logger
my_class_that_logs = MyClassThatLogs(Logger())
my_class_that_logs.logger
"""
)
inferred = list(node.infer())
assert isinstance(inferred[1], Instance)


class GetattrTest(unittest.TestCase):
def test_yes_when_unknown(self) -> None:
Expand Down

0 comments on commit e50c043

Please sign in to comment.