Skip to content

Commit

Permalink
'in' can narrow TypedDict unions
Browse files Browse the repository at this point in the history
  • Loading branch information
ikonst committed Oct 15, 2022
1 parent d528bf2 commit 4617bf8
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 18 deletions.
89 changes: 71 additions & 18 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5009,6 +5009,44 @@ def conditional_callable_type_map(

return None, {}

def contains_operator_right_operand_type_map(
self, item_type: Type, collection_type: Type
) -> tuple[Type, Type]:
"""
Deduces the type of the right operand of the `in` operator.
For now, we only support narrowing unions of TypedDicts based on left operand being literal string(s).
"""
if_types, else_types = [collection_type], [collection_type]
item_strs = try_getting_str_literals_from_type(item_type)
if item_strs:
if_types, else_types = self._contains_string_right_operand_type_map(
set(item_strs), collection_type
)
return UnionType.make_union(if_types), UnionType.make_union(else_types)

def _contains_string_right_operand_type_map(
self, item_strs: set[str], t: Type
) -> tuple[list[Type], list[Type]]:
t = get_proper_type(t)
if_types: list[Type] = []
else_types: list[Type] = []
if isinstance(t, TypedDictType):
if item_strs <= t.items.keys():
if_types.append(t)
elif item_strs.isdisjoint(t.items.keys()):
else_types.append(t)
else:
if_types.append(t)
else_types.append(t)
elif isinstance(t, UnionType):
for union_item in t.items:
a, b = self._contains_string_right_operand_type_map(item_strs, union_item)
if_types.extend(a)
else_types.extend(b)
else:
if_types = else_types = [t]
return if_types, else_types

def _is_truthy_type(self, t: ProperType) -> bool:
return (
(
Expand Down Expand Up @@ -5316,28 +5354,39 @@ def has_no_custom_eq_checks(t: Type) -> bool:
elif operator in {"in", "not in"}:
assert len(expr_indices) == 2
left_index, right_index = expr_indices
if left_index not in narrowable_operand_index_to_hash:
continue

item_type = operand_types[left_index]
collection_type = operand_types[right_index]

# We only try and narrow away 'None' for now
if not is_optional(item_type):
continue
if_map, else_map = {}, {}

if left_index in narrowable_operand_index_to_hash:
# We only try and narrow away 'None' for now
if is_optional(item_type):
collection_item_type = get_proper_type(
builtin_item_type(collection_type)
)
if (
collection_item_type is not None
and not is_optional(collection_item_type)
and not (
isinstance(collection_item_type, Instance)
and collection_item_type.type.fullname == "builtins.object"
)
and is_overlapping_erased_types(item_type, collection_item_type)
):
if_map[operands[left_index]] = remove_optional(item_type)

if right_index in narrowable_operand_index_to_hash:
(
right_if_type,
right_else_type,
) = self.contains_operator_right_operand_type_map(
item_type, collection_type
)
expr = operands[right_index]
if_map[expr] = right_if_type
else_map[expr] = right_else_type

collection_item_type = get_proper_type(builtin_item_type(collection_type))
if collection_item_type is None or is_optional(collection_item_type):
continue
if (
isinstance(collection_item_type, Instance)
and collection_item_type.type.fullname == "builtins.object"
):
continue
if is_overlapping_erased_types(item_type, collection_item_type):
if_map, else_map = {operands[left_index]: remove_optional(item_type)}, {}
else:
continue
else:
if_map = {}
else_map = {}
Expand Down Expand Up @@ -5390,6 +5439,10 @@ def has_no_custom_eq_checks(t: Type) -> bool:
or_conditional_maps(left_if_vars, right_if_vars),
and_conditional_maps(left_else_vars, right_else_vars),
)
elif isinstance(node, OpExpr) and node.op == "in":
left_if_vars, left_else_vars = self.find_isinstance_check(node.left)
right_if_vars, right_else_vars = self.find_isinstance_check(node.right)

elif isinstance(node, UnaryExpr) and node.op == "not":
left, right = self.find_isinstance_check(node.expr)
return right, left
Expand Down
41 changes: 41 additions & 0 deletions test-data/unit/check-typeddict.test
Original file line number Diff line number Diff line change
Expand Up @@ -2012,6 +2012,47 @@ v = {bad2: 2} # E: Extra key "bad" for TypedDict "Value"
[builtins fixtures/dict.pyi]
[typing fixtures/typing-typeddict.pyi]

[case testFinalTypedDictTagged]
from __future__ import annotations
from typing import Literal, TypedDict
from typing_extensions import final

@final
class D1(TypedDict):
foo: int


@final
class D2(TypedDict):
bar: int

d: D1 | D2
val: int

val = d['foo'] # E: TypedDict "D2" has no key "foo"
if 'foo' in d:
val = d['foo']
else:
val = d['bar']

foo_or_bar: Literal['foo', 'bar']
if foo_or_bar in d:
val = d['foo'] # E: TypedDict "D2" has no key "foo"
val = d['bar'] # E: TypedDict "D1" has no key "bar"
else:
val = d['foo'] # E: TypedDict "D2" has no key "foo"
val = d['bar'] # E: TypedDict "D1" has no key "bar"

foo_or_invalid: Literal['foo', 'invalid']
if foo_or_invalid in d:
val = d['foo']
else:
val = d['foo'] # E: TypedDict "D2" has no key "foo"
val = d['bar'] # E: TypedDict "D1" has no key "bar"

[builtins fixtures/dict.pyi]
[typing fixtures/typing-typeddict.pyi]

[case testCannotSubclassFinalTypedDict]
from typing import TypedDict
from typing_extensions import final
Expand Down

0 comments on commit 4617bf8

Please sign in to comment.