Skip to content

Commit

Permalink
Narrow individual items when matching a tuple to a sequence pattern (#…
Browse files Browse the repository at this point in the history
…16905)

Fixes #12364

When matching a tuple to a sequence pattern, this change narrows the
type of tuple items inside the matched case:

```py
def test(a: bool, b: bool) -> None:
    match a, b:
        case True, True:
            reveal_type(a)  # before: "builtins.bool", after: "Literal[True]"
```

This also works with nested tuples, recursively:

```py
def test(a: bool, b: bool, c: bool) -> None:
    match a, (b, c):
        case _, [True, False]:
            reveal_type(c)  # before: "builtins.bool", after: "Literal[False]"
```

This only partially fixes issue #12364; see [my comment
there](#12364 (comment))
for more context.

---

This is my first contribution to mypy, so I may miss some context or
conventions; I'm eager for any feedback!

---------

Co-authored-by: Loïc Simon <[email protected]>
  • Loading branch information
loic-simon and Loïc Simon authored Apr 4, 2024
1 parent ec44015 commit 8019010
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 0 deletions.
17 changes: 17 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5119,6 +5119,9 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
)
self.remove_capture_conflicts(pattern_type.captures, inferred_types)
self.push_type_map(pattern_map)
if pattern_map:
for expr, typ in pattern_map.items():
self.push_type_map(self._get_recursive_sub_patterns_map(expr, typ))
self.push_type_map(pattern_type.captures)
if g is not None:
with self.binder.frame_context(can_skip=False, fall_through=3):
Expand Down Expand Up @@ -5156,6 +5159,20 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
with self.binder.frame_context(can_skip=False, fall_through=2):
pass

def _get_recursive_sub_patterns_map(
self, expr: Expression, typ: Type
) -> dict[Expression, Type]:
sub_patterns_map: dict[Expression, Type] = {}
typ_ = get_proper_type(typ)
if isinstance(expr, TupleExpr) and isinstance(typ_, TupleType):
# When matching a tuple expression with a sequence pattern, narrow individual tuple items
assert len(expr.items) == len(typ_.items)
for item_expr, item_typ in zip(expr.items, typ_.items):
sub_patterns_map[item_expr] = item_typ
sub_patterns_map.update(self._get_recursive_sub_patterns_map(item_expr, item_typ))

return sub_patterns_map

def infer_variable_types_from_type_maps(self, type_maps: list[TypeMap]) -> dict[Var, Type]:
all_captures: dict[Var, list[tuple[NameExpr, Type]]] = defaultdict(list)
for tm in type_maps:
Expand Down
66 changes: 66 additions & 0 deletions test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,72 @@ match m:
reveal_type(m) # N: Revealed type is "builtins.list[builtins.list[builtins.str]]"
[builtins fixtures/list.pyi]

[case testMatchSequencePatternNarrowSubjectItems]
m: int
n: str
o: bool

match m, n, o:
case [3, "foo", True]:
reveal_type(m) # N: Revealed type is "Literal[3]"
reveal_type(n) # N: Revealed type is "Literal['foo']"
reveal_type(o) # N: Revealed type is "Literal[True]"
case [a, b, c]:
reveal_type(m) # N: Revealed type is "builtins.int"
reveal_type(n) # N: Revealed type is "builtins.str"
reveal_type(o) # N: Revealed type is "builtins.bool"

reveal_type(m) # N: Revealed type is "builtins.int"
reveal_type(n) # N: Revealed type is "builtins.str"
reveal_type(o) # N: Revealed type is "builtins.bool"
[builtins fixtures/tuple.pyi]

[case testMatchSequencePatternNarrowSubjectItemsRecursive]
m: int
n: int
o: int
p: int
q: int
r: int

match m, (n, o), (p, (q, r)):
case [0, [1, 2], [3, [4, 5]]]:
reveal_type(m) # N: Revealed type is "Literal[0]"
reveal_type(n) # N: Revealed type is "Literal[1]"
reveal_type(o) # N: Revealed type is "Literal[2]"
reveal_type(p) # N: Revealed type is "Literal[3]"
reveal_type(q) # N: Revealed type is "Literal[4]"
reveal_type(r) # N: Revealed type is "Literal[5]"
[builtins fixtures/tuple.pyi]

[case testMatchSequencePatternSequencesLengthMismatchNoNarrowing]
m: int
n: str
o: bool

match m, n, o:
case [3, "foo"]:
pass
case [3, "foo", True, True]:
pass
[builtins fixtures/tuple.pyi]

[case testMatchSequencePatternSequencesLengthMismatchNoNarrowingRecursive]
m: int
n: int
o: int

match m, (n, o):
case [0]:
pass
case [0, 1, [2]]:
pass
case [0, [1]]:
pass
case [0, [1, 2, 3]]:
pass
[builtins fixtures/tuple.pyi]

-- Mapping Pattern --

[case testMatchMappingPatternCaptures]
Expand Down

0 comments on commit 8019010

Please sign in to comment.