diff --git a/libcst/matchers/_matcher_base.py b/libcst/matchers/_matcher_base.py index 27475d5bc..0cf281c0d 100644 --- a/libcst/matchers/_matcher_base.py +++ b/libcst/matchers/_matcher_base.py @@ -932,6 +932,24 @@ def SaveMatchedNode(matcher: _OtherNodeT, name: str) -> _OtherNodeT: return cast(_OtherNodeT, _ExtractMatchingNode(matcher, name)) +def _matches_zero_nodes( + matcher: Union[ + BaseMatcherNode, + _BaseWildcardNode, + MatchIfTrue[Callable[[object], bool]], + _BaseMetadataMatcher, + DoNotCareSentinel, + ] +) -> bool: + if isinstance(matcher, AtLeastN) and matcher.n == 0: + return True + if isinstance(matcher, AtMostN): + return True + if isinstance(matcher, _ExtractMatchingNode): + return _matches_zero_nodes(matcher.matcher) + return False + + @dataclass(frozen=True) class _SequenceMatchesResult: sequence_capture: Optional[ @@ -960,14 +978,13 @@ def _sequence_matches( # noqa: C901 return _SequenceMatchesResult({}, None) if not nodes and matchers: # Base case, we have one or more matcher that wasn't matched - return ( - _SequenceMatchesResult({}, []) - if all( - (isinstance(m, AtLeastN) and m.n == 0) or isinstance(m, AtMostN) - for m in matchers + if all(_matches_zero_nodes(m) for m in matchers): + return _SequenceMatchesResult( + {m.name: () for m in matchers if isinstance(m, _ExtractMatchingNode)}, + (), ) - else _SequenceMatchesResult(None, None) - ) + else: + return _SequenceMatchesResult(None, None) if nodes and not matchers: # Base case, we have nodes left that don't match any matcher return _SequenceMatchesResult(None, None) diff --git a/libcst/matchers/tests/test_extract.py b/libcst/matchers/tests/test_extract.py index 5c3cf12a1..77c134a8a 100644 --- a/libcst/matchers/tests/test_extract.py +++ b/libcst/matchers/tests/test_extract.py @@ -322,6 +322,34 @@ def test_extract_optional_wildcard(self) -> None: ) self.assertEqual(nodes, {}) + def test_extract_optional_wildcard_head(self) -> None: + expression = cst.parse_expression("[3]") + nodes = m.extract( + expression, + m.List( + elements=[ + m.SaveMatchedNode(m.ZeroOrMore(), "head1"), + m.SaveMatchedNode(m.ZeroOrMore(), "head2"), + m.Element(value=m.Integer(value="3")), + ] + ), + ) + self.assertEqual(nodes, {"head1": (), "head2": ()}) + + def test_extract_optional_wildcard_tail(self) -> None: + expression = cst.parse_expression("[3]") + nodes = m.extract( + expression, + m.List( + elements=[ + m.Element(value=m.Integer(value="3")), + m.SaveMatchedNode(m.ZeroOrMore(), "tail1"), + m.SaveMatchedNode(m.ZeroOrMore(), "tail2"), + ] + ), + ) + self.assertEqual(nodes, {"tail1": (), "tail2": ()}) + def test_extract_optional_wildcard_present(self) -> None: expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)") nodes = m.extract(