Skip to content

Commit

Permalink
checkexpr: cache type of container literals when possible
Browse files Browse the repository at this point in the history
When a container (list, set, tuple, or dict) literal expression is
used as an argument to an overloaded function it will get repeatedly
typechecked. This becomes particularly problematic when the expression
is somewhat large, as seen in #9427

To avoid repeated work, add a new field in the relevant AST nodes to
cache the resolved type of the expression. Right now the cache is
only used in the fast path, although it could conceivably be leveraged
for the slow path as well in a follow-up commit.

To further reduce duplicate work, when the fast-path doesn't work, we
use the cache to make a note of that, to avoid repeatedly attempting to
take the fast path.

Fixes #9427
  • Loading branch information
hugues-aff committed May 1, 2022
1 parent a56ebec commit 8b2bf54
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 17 deletions.
38 changes: 25 additions & 13 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3249,13 +3249,13 @@ def apply_type_arguments_to_callable(

def visit_list_expr(self, e: ListExpr) -> Type:
"""Type check a list expression [...]."""
return self.check_lst_expr(e.items, 'builtins.list', '<list>', e)
return self.check_lst_expr(e, 'builtins.list', '<list>')

def visit_set_expr(self, e: SetExpr) -> Type:
return self.check_lst_expr(e.items, 'builtins.set', '<set>', e)
return self.check_lst_expr(e, 'builtins.set', '<set>')

def fast_container_type(
self, items: List[Expression], container_fullname: str
self, e: Union[ListExpr, SetExpr, TupleExpr], container_fullname: str
) -> Optional[Type]:
"""
Fast path to determine the type of a list or set literal,
Expand All @@ -3270,21 +3270,26 @@ def fast_container_type(
ctx = self.type_context[-1]
if ctx:
return None
if e._resolved_type is not None:
return e._resolved_type if isinstance(e._resolved_type, Instance) else None
values: List[Type] = []
for item in items:
for item in e.items:
if isinstance(item, StarExpr):
# fallback to slow path
e._resolved_type = NoneType()
return None
values.append(self.accept(item))
vt = join.join_type_list(values)
if not isinstance(vt, Instance):
return None
return self.chk.named_generic_type(container_fullname, [vt])
ct = self.chk.named_generic_type(container_fullname, [vt])
e._resolved_type = ct
return ct

def check_lst_expr(self, items: List[Expression], fullname: str,
tag: str, context: Context) -> Type:
def check_lst_expr(self, e: Union[ListExpr, SetExpr, TupleExpr], fullname: str,
tag: str) -> Type:
# fast path
t = self.fast_container_type(items, fullname)
t = self.fast_container_type(e, fullname)
if t:
return t

Expand All @@ -3303,10 +3308,10 @@ def check_lst_expr(self, items: List[Expression], fullname: str,
variables=[tv])
out = self.check_call(constructor,
[(i.expr if isinstance(i, StarExpr) else i)
for i in items],
for i in e.items],
[(nodes.ARG_STAR if isinstance(i, StarExpr) else nodes.ARG_POS)
for i in items],
context)[0]
for i in e.items],
e)[0]
return remove_instance_last_known_values(out)

def visit_tuple_expr(self, e: TupleExpr) -> Type:
Expand Down Expand Up @@ -3356,7 +3361,7 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type:
else:
# A star expression that's not a Tuple.
# Treat the whole thing as a variable-length tuple.
return self.check_lst_expr(e.items, 'builtins.tuple', '<tuple>', e)
return self.check_lst_expr(e, 'builtins.tuple', '<tuple>')
else:
if not type_context_items or j >= len(type_context_items):
tt = self.accept(item)
Expand All @@ -3382,6 +3387,8 @@ def fast_dict_type(self, e: DictExpr) -> Optional[Type]:
ctx = self.type_context[-1]
if ctx:
return None
if e._resolved_type is not None:
return e._resolved_type if isinstance(e._resolved_type, Instance) else None
keys: List[Type] = []
values: List[Type] = []
stargs: Optional[Tuple[Type, Type]] = None
Expand All @@ -3395,17 +3402,22 @@ def fast_dict_type(self, e: DictExpr) -> Optional[Type]:
):
stargs = (st.args[0], st.args[1])
else:
e._resolved_type = NoneType()
return None
else:
keys.append(self.accept(key))
values.append(self.accept(value))
kt = join.join_type_list(keys)
vt = join.join_type_list(values)
if not (isinstance(kt, Instance) and isinstance(vt, Instance)):
e._resolved_type = NoneType()
return None
if stargs and (stargs[0] != kt or stargs[1] != vt):
e._resolved_type = NoneType()
return None
return self.chk.named_generic_type('builtins.dict', [kt, vt])
dt = self.chk.named_generic_type('builtins.dict', [kt, vt])
e._resolved_type = dt
return dt

def visit_dict_expr(self, e: DictExpr) -> Type:
"""Type check a dict expression.
Expand Down
16 changes: 12 additions & 4 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2026,13 +2026,15 @@ def is_dynamic(self) -> bool:
class ListExpr(Expression):
"""List literal expression [...]."""

__slots__ = ('items',)
__slots__ = ('items', '_resolved_type')

items: List[Expression]
_resolved_type: Optional["mypy.types.ProperType"]

def __init__(self, items: List[Expression]) -> None:
super().__init__()
self.items = items
self._resolved_type = None

def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_list_expr(self)
Expand All @@ -2041,13 +2043,15 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
class DictExpr(Expression):
"""Dictionary literal expression {key: value, ...}."""

__slots__ = ('items',)
__slots__ = ('items', '_resolved_type')

items: List[Tuple[Optional[Expression], Expression]]
_resolved_type: Optional["mypy.types.ProperType"]

def __init__(self, items: List[Tuple[Optional[Expression], Expression]]) -> None:
super().__init__()
self.items = items
self._resolved_type = None

def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_dict_expr(self)
Expand All @@ -2058,13 +2062,15 @@ class TupleExpr(Expression):
Also lvalue sequences (..., ...) and [..., ...]"""

__slots__ = ('items',)
__slots__ = ('items', '_resolved_type')

items: List[Expression]
_resolved_type: Optional["mypy.types.ProperType"]

def __init__(self, items: List[Expression]) -> None:
super().__init__()
self.items = items
self._resolved_type = None

def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_tuple_expr(self)
Expand All @@ -2073,13 +2079,15 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
class SetExpr(Expression):
"""Set literal expression {value, ...}."""

__slots__ = ('items',)
__slots__ = ('items', '_resolved_type')

items: List[Expression]
_resolved_type: Optional["mypy.types.ProperType"]

def __init__(self, items: List[Expression]) -> None:
super().__init__()
self.items = items
self._resolved_type = None

def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_set_expr(self)
Expand Down

0 comments on commit 8b2bf54

Please sign in to comment.