Skip to content

Commit

Permalink
Support pattern matching literal sequences.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 568557533
  • Loading branch information
martindemello authored and rchen152 committed Sep 26, 2023
1 parent d3e3d70 commit 73b9d0f
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 1 deletion.
14 changes: 13 additions & 1 deletion pytype/abstract/_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ def __repr__(self):
return f"splat({self.iterable.data!r})"


class SequenceLength(_base.BaseValue):
class SequenceLength(_base.BaseValue, mixin.HasSlots):
"""Sequence length for match statements."""

def __init__(self, sequence, ctx):
Expand All @@ -693,10 +693,22 @@ def __init__(self, sequence, ctx):
length += 1
self.length = length
self.splat = splat
mixin.HasSlots.init_mixin(self)
self.set_native_slot("__sub__", self.sub_slot)

def __repr__(self):
splat = "+" if self.splat else ""
return f"SequenceLength[{self.length}{splat}]"

def instantiate(self, node, container=None):
return self.to_variable(node)

def sub_slot(self, node, other_var):
# We should not get a ConversionError here; this is code generated by the
# compiler from a literal sequence in a concrete match statement
val = abstract_utils.get_atomic_python_constant(other_var, int)
if self.splat:
ret = self.ctx.convert.build_int(node)
else:
ret = self.ctx.convert.constant_to_var(self.length - val, node=node)
return node, ret
43 changes: 43 additions & 0 deletions pytype/tests/test_pattern_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,49 @@ def f(x: int):
def f(x: int) -> None: ...
""")

def test_sequence3(self):
self.Check("""
from typing import Tuple
def f(path: Tuple[str, str]) -> bool:
match path:
case (('foo' | 'bar'), 'baz'):
return True
case _:
return False
""")

def test_sequence4(self):
self.Check("""
from typing import Sequence
def f(path: Sequence[str]) -> bool:
match path:
case [*_, ('foo' | 'bar'), 'baz']:
return True
case _:
return False
""")

def test_sequence5(self):
self.Check("""
from typing import List, Optional
def f(path):
match path:
case (*_, ('foo' | 'bar'), 'baz'):
return 10
case _:
return None
a = f((1, 2, 3, 4, 'foo', 'baz'))
b = f((1, 2))
xs: List[str] = []
c = f(xs)
assert_type(a, int)
assert_type(b, None)
assert_type(c, Optional[int])
""")

def test_list1(self):
ty = self.Infer("""
def f(x: list[int]):
Expand Down

0 comments on commit 73b9d0f

Please sign in to comment.