diff --git a/pytype/abstract/_instances.py b/pytype/abstract/_instances.py index 05dafd1da..0243f2b25 100644 --- a/pytype/abstract/_instances.py +++ b/pytype/abstract/_instances.py @@ -679,7 +679,7 @@ def __repr__(self): return f"splat({self.iterable.data!r})" -class SequenceLength(_base.BaseValue): +class SequenceLength(_base.BaseValue, mixin.HasSlots): """Sequence length for match statements.""" def __init__(self, sequence, ctx): @@ -693,6 +693,8 @@ def __init__(self, sequence, ctx): length += 1 self.length = length self.splat = splat + mixin.HasSlots.init_mixin(self) + self.set_native_slot("__sub__", self.sub_slot) def __repr__(self): splat = "+" if self.splat else "" @@ -700,3 +702,13 @@ def __repr__(self): def instantiate(self, node, container=None): return self.to_variable(node) + + def sub_slot(self, node, other_var): + # We should not get a ConversionError here; this is code generated by the + # compiler from a literal sequence in a concrete match statement + val = abstract_utils.get_atomic_python_constant(other_var, int) + if self.splat: + ret = self.ctx.convert.build_int(node) + else: + ret = self.ctx.convert.constant_to_var(self.length - val, node=node) + return node, ret diff --git a/pytype/tests/test_pattern_matching.py b/pytype/tests/test_pattern_matching.py index 0f29e54d2..ecb5e987e 100644 --- a/pytype/tests/test_pattern_matching.py +++ b/pytype/tests/test_pattern_matching.py @@ -59,6 +59,49 @@ def f(x: int): def f(x: int) -> None: ... """) + def test_sequence3(self): + self.Check(""" + from typing import Tuple + + def f(path: Tuple[str, str]) -> bool: + match path: + case (('foo' | 'bar'), 'baz'): + return True + case _: + return False + """) + + def test_sequence4(self): + self.Check(""" + from typing import Sequence + + def f(path: Sequence[str]) -> bool: + match path: + case [*_, ('foo' | 'bar'), 'baz']: + return True + case _: + return False + """) + + def test_sequence5(self): + self.Check(""" + from typing import List, Optional + def f(path): + match path: + case (*_, ('foo' | 'bar'), 'baz'): + return 10 + case _: + return None + + a = f((1, 2, 3, 4, 'foo', 'baz')) + b = f((1, 2)) + xs: List[str] = [] + c = f(xs) + assert_type(a, int) + assert_type(b, None) + assert_type(c, Optional[int]) + """) + def test_list1(self): ty = self.Infer(""" def f(x: list[int]):