From 6ae25834bbcebb79d08c4516c71569033593b4d7 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 10 Sep 2020 21:09:59 +0300 Subject: [PATCH] Implement TypeOf matcher (#384) * Implement TypeOf matcher * Satisfy the type checker * Expand the test case * Fix the annotation of _raw_options * Add documentation... --- docs/source/matchers.rst | 1 + libcst/codegen/gen_matcher_classes.py | 11 ++- libcst/matchers/__init__.py | 58 +++++++++------- libcst/matchers/_matcher_base.py | 94 +++++++++++++++++++++++++- libcst/matchers/tests/test_matchers.py | 84 +++++++++++++++++++++++ 5 files changed, 218 insertions(+), 30 deletions(-) diff --git a/docs/source/matchers.rst b/docs/source/matchers.rst index 37398f40f..eac6faa95 100644 --- a/docs/source/matchers.rst +++ b/docs/source/matchers.rst @@ -145,6 +145,7 @@ when calling :func:`~libcst.matchers.matches` or using decorators. .. autoclass:: libcst.matchers.OneOf .. autoclass:: libcst.matchers.AllOf +.. autoclass:: libcst.matchers.TypeOf .. autofunction:: libcst.matchers.DoesNotMatch .. autoclass:: libcst.matchers.MatchIfTrue .. 
autofunction:: libcst.matchers.MatchRegex diff --git a/libcst/codegen/gen_matcher_classes.py b/libcst/codegen/gen_matcher_classes.py index 5c6a550d8..b0657890d 100644 --- a/libcst/codegen/gen_matcher_classes.py +++ b/libcst/codegen/gen_matcher_classes.py @@ -456,14 +456,13 @@ def _get_fields(node: Type[cst.CSTNode]) -> Generator[Field, None, None]: generated_code.append("") generated_code.append("") generated_code.append("# This file was generated by libcst.codegen.gen_matcher_classes") -generated_code.append("from abc import ABC") generated_code.append("from dataclasses import dataclass") generated_code.append("from typing import Callable, Sequence, Union") generated_code.append("from typing_extensions import Literal") generated_code.append("import libcst as cst") generated_code.append("") generated_code.append( - "from libcst.matchers._matcher_base import BaseMatcherNode, DoNotCareSentinel, DoNotCare, OneOf, AllOf, DoesNotMatch, MatchIfTrue, MatchRegex, MatchMetadata, MatchMetadataIfTrue, ZeroOrMore, AtLeastN, ZeroOrOne, AtMostN, SaveMatchedNode, extract, extractall, findall, matches, replace" + "from libcst.matchers._matcher_base import AbstractBaseMatcherNodeMeta, BaseMatcherNode, DoNotCareSentinel, DoNotCare, TypeOf, OneOf, AllOf, DoesNotMatch, MatchIfTrue, MatchRegex, MatchMetadata, MatchMetadataIfTrue, ZeroOrMore, AtLeastN, ZeroOrOne, AtMostN, SaveMatchedNode, extract, extractall, findall, matches, replace" ) all_exports.update( [ @@ -477,6 +476,7 @@ def _get_fields(node: Type[cst.CSTNode]) -> Generator[Field, None, None]: "MatchRegex", "MatchMetadata", "MatchMetadataIfTrue", + "TypeOf", "ZeroOrMore", "AtLeastN", "ZeroOrOne", @@ -504,10 +504,15 @@ def _get_fields(node: Type[cst.CSTNode]) -> Generator[Field, None, None]: ] ) +generated_code.append("") +generated_code.append("") +generated_code.append("class _NodeABC(metaclass=AbstractBaseMatcherNodeMeta):") +generated_code.append(" __slots__ = ()") + for base in typeclasses: generated_code.append("") 
generated_code.append("") - generated_code.append(f"class {base.__name__}(ABC):") + generated_code.append(f"class {base.__name__}(_NodeABC):") generated_code.append(" pass") all_exports.add(base.__name__) diff --git a/libcst/matchers/__init__.py b/libcst/matchers/__init__.py index 8bd9f6b63..3b2d90776 100644 --- a/libcst/matchers/__init__.py +++ b/libcst/matchers/__init__.py @@ -5,7 +5,6 @@ # This file was generated by libcst.codegen.gen_matcher_classes -from abc import ABC from dataclasses import dataclass from typing import Callable, Sequence, Union @@ -14,6 +13,7 @@ import libcst as cst from libcst.matchers._decorators import call_if_inside, call_if_not_inside, leave, visit from libcst.matchers._matcher_base import ( + AbstractBaseMatcherNodeMeta, AllOf, AtLeastN, AtMostN, @@ -27,6 +27,7 @@ MatchRegex, OneOf, SaveMatchedNode, + TypeOf, ZeroOrMore, ZeroOrOne, extract, @@ -42,103 +43,107 @@ ) -class BaseAssignTargetExpression(ABC): +class _NodeABC(metaclass=AbstractBaseMatcherNodeMeta): + __slots__ = () + + +class BaseAssignTargetExpression(_NodeABC): pass -class BaseAugOp(ABC): +class BaseAugOp(_NodeABC): pass -class BaseBinaryOp(ABC): +class BaseBinaryOp(_NodeABC): pass -class BaseBooleanOp(ABC): +class BaseBooleanOp(_NodeABC): pass -class BaseComp(ABC): +class BaseComp(_NodeABC): pass -class BaseCompOp(ABC): +class BaseCompOp(_NodeABC): pass -class BaseCompoundStatement(ABC): +class BaseCompoundStatement(_NodeABC): pass -class BaseDelTargetExpression(ABC): +class BaseDelTargetExpression(_NodeABC): pass -class BaseDict(ABC): +class BaseDict(_NodeABC): pass -class BaseDictElement(ABC): +class BaseDictElement(_NodeABC): pass -class BaseElement(ABC): +class BaseElement(_NodeABC): pass -class BaseExpression(ABC): +class BaseExpression(_NodeABC): pass -class BaseFormattedStringContent(ABC): +class BaseFormattedStringContent(_NodeABC): pass -class BaseList(ABC): +class BaseList(_NodeABC): pass -class BaseMetadataProvider(ABC): +class BaseMetadataProvider(_NodeABC): 
pass -class BaseNumber(ABC): +class BaseNumber(_NodeABC): pass -class BaseParenthesizableWhitespace(ABC): +class BaseParenthesizableWhitespace(_NodeABC): pass -class BaseSet(ABC): +class BaseSet(_NodeABC): pass -class BaseSimpleComp(ABC): +class BaseSimpleComp(_NodeABC): pass -class BaseSlice(ABC): +class BaseSlice(_NodeABC): pass -class BaseSmallStatement(ABC): +class BaseSmallStatement(_NodeABC): pass -class BaseStatement(ABC): +class BaseStatement(_NodeABC): pass -class BaseString(ABC): +class BaseString(_NodeABC): pass -class BaseSuite(ABC): +class BaseSuite(_NodeABC): pass -class BaseUnaryOp(ABC): +class BaseUnaryOp(_NodeABC): pass @@ -13242,6 +13247,7 @@ class Yield(BaseExpression, BaseMatcherNode): "TrailingWhitespace", "Try", "Tuple", + "TypeOf", "UnaryOperation", "While", "With", diff --git a/libcst/matchers/_matcher_base.py b/libcst/matchers/_matcher_base.py index 532cb53dd..70a9340ac 100644 --- a/libcst/matchers/_matcher_base.py +++ b/libcst/matchers/_matcher_base.py @@ -7,12 +7,14 @@ import copy import inspect import re +from abc import ABCMeta from dataclasses import dataclass, fields from enum import Enum, auto from typing import ( Callable, Dict, Generic, + Iterator, List, Mapping, NoReturn, @@ -51,11 +53,26 @@ def __repr__(self) -> str: _BaseMatcherNodeSelfT = TypeVar("_BaseMatcherNodeSelfT", bound="BaseMatcherNode") _OtherNodeT = TypeVar("_OtherNodeT") _MetadataValueT = TypeVar("_MetadataValueT") +_MatcherTypeT = TypeVar("_MatcherTypeT", bound=Type["BaseMatcherNode"]) +_OtherNodeMatcherTypeT = TypeVar( + "_OtherNodeMatcherTypeT", bound=Type["BaseMatcherNode"] +) _METADATA_MISSING_SENTINEL = object() +class AbstractBaseMatcherNodeMeta(ABCMeta): + """ + Metaclass that all matcher nodes use. Allows chaining two node types + together with a bitwise-or operator to produce a :class:`TypeOf` + matcher. 
+ """ + + def __or__(self, node: Type["BaseMatcherNode"]) -> "TypeOf[Type[BaseMatcherNode]]": + return TypeOf(self, node) + + class BaseMatcherNode: """ Base class that all concrete matchers subclass from. :class:`OneOf` and @@ -103,6 +120,81 @@ def DoNotCare() -> DoNotCareSentinel: return DoNotCareSentinel.DEFAULT +class TypeOf(Generic[_MatcherTypeT], BaseMatcherNode): + """ + Matcher that matches any one of the given types. Useful when you want to work + with trees where a common property might belong to more than a single type. + + For example, if you want either a binary operation or a boolean operation + where the left side has a name ``foo``:: + + m.TypeOf(m.BinaryOperation, m.BooleanOperation)(left = m.Name("foo")) + + Or you could use the shorthand, like:: + + (m.BinaryOperation | m.BooleanOperation)(left = m.Name("foo")) + + Also :class:`TypeOf` matchers can be used uncalled, which initializes the + given node matchers in their default state (without passing any extra patterns):: + + m.Name | m.SimpleString + + This will be equal to:: + + m.OneOf(m.Name(), m.SimpleString()) + """ + + def __init__(self, *options: Union[_MatcherTypeT, "TypeOf[_MatcherTypeT]"]) -> None: + actual_options: List[_MatcherTypeT] = [] + for option in options: + if isinstance(option, TypeOf): + if option.initalized: + raise Exception( + "Cannot chain an uninitalized TypeOf with an initalized one" + ) + actual_options.extend(option._raw_options) + else: + actual_options.append(option) + + self._initalized = False + self._call_items: Tuple[Tuple[object, ...], Dict[str, object]] = ((), {}) + self._raw_options: Tuple[_MatcherTypeT, ...] 
= tuple(actual_options) + + @property + def initalized(self) -> bool: + return self._initalized + + @property + def options(self) -> Iterator[BaseMatcherNode]: + for option in self._raw_options: + args, kwargs = self._call_items + matcher_pattern = option(*args, **kwargs) + yield matcher_pattern + + def __call__(self, *args: object, **kwargs: object) -> BaseMatcherNode: + self._initalized = True + self._call_items = (args, kwargs) + return self + + def __or__( + self, other: _OtherNodeMatcherTypeT + ) -> "TypeOf[Union[_MatcherTypeT, _OtherNodeMatcherTypeT]]": + return TypeOf[Union[_MatcherTypeT, _OtherNodeMatcherTypeT]](self, other) + + def __and__(self, other: _OtherNodeMatcherTypeT) -> NoReturn: + left, right = type(self).__name__, other.__name__ + raise TypeError( + f"TypeError: unsupported operand type(s) for &: {left!r} and {right!r}" + ) + + def __invert__(self) -> "AllOf[BaseMatcherNode]": + return AllOf(*map(DoesNotMatch, self.options)) + + def __repr__(self) -> str: + types = ", ".join(repr(option) for option in self._raw_options) + return f"TypeOf({types}, initalized = {self.initalized})" + + class OneOf(Generic[_MatcherT], BaseMatcherNode): """ Matcher that matches any one of its options. Useful when you want to match @@ -1387,7 +1479,7 @@ def _matches( return {} if isinstance(matcher, _InverseOf) else None # Now, evaluate the matcher node itself. - if isinstance(matcher, OneOf): + if isinstance(matcher, (OneOf, TypeOf)): for matcher in matcher.options: node_capture = _node_matches(node, matcher, metadata_lookup) if node_capture is not None: diff --git a/libcst/matchers/tests/test_matchers.py b/libcst/matchers/tests/test_matchers.py index ab1e5cf1e..11d6b5f53 100644 --- a/libcst/matchers/tests/test_matchers.py +++ b/libcst/matchers/tests/test_matchers.py @@ -3,6 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import dataclasses + import libcst as cst import libcst.matchers as m from libcst.matchers import matches @@ -202,6 +204,88 @@ def test_complex_matcher_false(self) -> None: ) ) + def test_type_of_matcher_true(self) -> None: + self.assertTrue(matches(cst.Name("true"), m.TypeOf(m.Name))) + self.assertTrue(matches(cst.Name("true"), m.TypeOf(m.Name)(value="true"))) + self.assertTrue(matches(cst.Name("true"), m.Name | m.Float | m.SimpleString)) + self.assertTrue( + matches(cst.SimpleString("'foo'"), m.TypeOf(m.Name, m.SimpleString)) + ) + self.assertTrue( + matches( + cst.SimpleString("'foo'"), + m.TypeOf(m.Name, m.SimpleString)(value="'foo'"), + ) + ) + with self.assertRaises(Exception): + # pyre-ignore + m.TypeOf(cst.Float)(value=1.0) | cst.Name + + with self.assertRaises(TypeError): + # pyre-ignore + m.TypeOf(cst.Float) & cst.SimpleString + + for case in ( + cst.BinaryOperation( + left=cst.Name("foo"), operator=cst.Add(), right=cst.Name("bar") + ), + cst.BooleanOperation( + left=cst.Name("foo"), operator=cst.Or(), right=cst.Name("bar") + ), + ): + self.assertTrue( + matches( + case, (m.BinaryOperation | m.BooleanOperation)(left=m.Name("foo")) + ) + ) + new_case = dataclasses.replace(case, left=case.right, right=case.left) + self.assertTrue( + matches( + new_case, + ~(m.BinaryOperation | m.BooleanOperation)(left=m.Name("foo")), + ) + ) + + def test_type_of_matcher_false(self) -> None: + self.assertFalse(matches(cst.Name("true"), m.TypeOf(m.SimpleString))) + self.assertFalse(matches(cst.Name("true"), m.TypeOf(m.Name)(value="false"))) + self.assertFalse( + matches(cst.Name("true"), m.TypeOf(m.SimpleString)(value="true")) + ) + self.assertFalse( + matches(cst.SimpleString("'foo'"), m.TypeOf(m.Name, m.Attribute)) + ) + self.assertFalse( + matches( + cst.SimpleString("'foo'"), m.TypeOf(m.Name, m.Attribute)(value="'foo'") + ) + ) + self.assertFalse( + matches( + cst.SimpleString("'foo'"), + m.TypeOf(m.Name, m.SimpleString)(value="'bar'"), + ) + ) + + for case in ( + 
cst.BinaryOperation( + left=cst.Name("foo"), operator=cst.Add(), right=cst.Name("bar") + ), + cst.BooleanOperation( + left=cst.Name("foo"), operator=cst.Or(), right=cst.Name("bar") + ), + ): + self.assertFalse( + matches( + case, (m.BinaryOperation | m.BooleanOperation)(left=m.Name("bar")) + ) + ) + self.assertFalse( + matches( + case, ~(m.BinaryOperation | m.BooleanOperation)(left=m.Name("foo")) + ) + ) + def test_or_matcher_true(self) -> None: # Match on either True or False identifier. self.assertTrue(