Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add slots to base classes, @add_slots takes bases into account #605

Merged
merged 7 commits into from
Jan 16, 2022
19 changes: 15 additions & 4 deletions libcst/_add_slots.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# This file is derived from github.com/ericvsmith/dataclasses, and is Apache 2 licensed.
# https://github.com/ericvsmith/dataclasses/blob/ae712dd993420d43444f188f452/LICENSE.txt
# https://github.com/ericvsmith/dataclasses/blob/ae712dd993420d43444f/dataclass_tools.py
# Changed: takes slots in base classes into account when creating slots

import dataclasses
from itertools import chain, filterfalse
from typing import Any, Mapping, Type, TypeVar

_T = TypeVar("_T")
Expand All @@ -19,7 +21,14 @@ def add_slots(cls: Type[_T]) -> Type[_T]:
# Create a new dict for our new class.
cls_dict = dict(cls.__dict__)
field_names = tuple(f.name for f in dataclasses.fields(cls))
cls_dict["__slots__"] = field_names
inherited_slots = set(
chain.from_iterable(
superclass.__dict__.get("__slots__", ()) for superclass in cls.mro()
)
)
cls_dict["__slots__"] = tuple(
filterfalse(inherited_slots.__contains__, field_names)
)
for field_name in field_names:
# Remove our attributes, if present. They'll still be
# available in _MARKER.
Expand Down Expand Up @@ -50,12 +59,14 @@ def add_slots(cls: Type[_T]) -> Type[_T]:

def __getstate__(self: object) -> Mapping[str, Any]:
return {
slot: getattr(self, slot) for slot in self.__slots__ if hasattr(self, slot)
field.name: getattr(self, field.name)
for field in dataclasses.fields(self)
if hasattr(self, field.name)
}

def __setstate__(self: object, state: Mapping[str, Any]) -> None:
for slot, value in state.items():
object.__setattr__(self, slot, value)
for fieldname, value in state.items():
object.__setattr__(self, fieldname, value)

cls.__getstate__ = __getstate__
cls.__setstate__ = __setstate__
Expand Down
10 changes: 9 additions & 1 deletion libcst/_nodes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass, field, fields, replace
from typing import Any, cast, Dict, List, Mapping, Sequence, TypeVar, Union
from typing import Any, cast, ClassVar, Dict, List, Mapping, Sequence, TypeVar, Union

from libcst._flatten_sentinel import FlattenSentinel
from libcst._nodes.internal import CodegenState
Expand Down Expand Up @@ -109,6 +109,9 @@ def _clone(val: object) -> object:

@dataclass(frozen=True)
class CSTNode(ABC):

__slots__: ClassVar[Sequence[str]] = ()

def __post_init__(self) -> None:
# PERF: It might make more sense to move validation work into the visitor, which
# would allow us to avoid validating the tree when parsing a file.
Expand Down Expand Up @@ -468,6 +471,9 @@ def field(cls, *args: object, **kwargs: object) -> Any:


class BaseLeaf(CSTNode, ABC):

__slots__ = ()

@property
def children(self) -> Sequence[CSTNode]:
# override this with an optimized implementation
Expand All @@ -487,6 +493,8 @@ class BaseValueToken(BaseLeaf, ABC):
into the parent CSTNode, and hard-coded into the implementation of _codegen.
"""

__slots__ = ()

value: str

def _codegen_impl(self, state: CodegenState) -> None:
Expand Down
39 changes: 35 additions & 4 deletions libcst/_nodes/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ class _BaseParenthesizedNode(CSTNode, ABC):
this to get that functionality.
"""

__slots__ = ()

lpar: Sequence[LeftParen] = ()
# Sequence of parenthesis for precedence dictation.
rpar: Sequence[RightParen] = ()
Expand Down Expand Up @@ -254,6 +256,8 @@ class BaseExpression(_BaseParenthesizedNode, ABC):
An base class for all expressions. :class:`BaseExpression` contains no fields.
"""

__slots__ = ()

def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool:
"""
Returns true if this expression is safe to be use with a word operator
Expand Down Expand Up @@ -296,7 +300,7 @@ class BaseAssignTargetExpression(BaseExpression, ABC):
<https://github.com/python/cpython/blob/v3.8.0a4/Python/ast.c#L1120>`_.
"""

pass
__slots__ = ()


class BaseDelTargetExpression(BaseExpression, ABC):
Expand All @@ -316,7 +320,7 @@ class BaseDelTargetExpression(BaseExpression, ABC):
<https://github.com/python/cpython/blob/v3.8.0a4/Python/compile.c#L4854>`_.
"""

pass
__slots__ = ()


@add_slots
Expand Down Expand Up @@ -393,6 +397,8 @@ class BaseNumber(BaseExpression, ABC):
used anywhere that you need to explicitly take any number type.
"""

__slots__ = ()

def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool:
"""
Numbers are funny. The expression "5in [1,2,3,4,5]" is a valid expression
Expand Down Expand Up @@ -522,13 +528,16 @@ class BaseString(BaseExpression, ABC):
:class:`SimpleString`, :class:`ConcatenatedString`, and :class:`FormattedString`.
"""

pass
__slots__ = ()


StringQuoteLiteral = Literal['"', "'", '"""', "'''"]


class _BasePrefixedString(BaseString, ABC):

__slots__ = ()

@property
def prefix(self) -> str:
"""
Expand Down Expand Up @@ -699,7 +708,7 @@ class BaseFormattedStringContent(CSTNode, ABC):
sequence of :class:`BaseFormattedStringContent` parts.
"""

pass
__slots__ = ()


@add_slots
Expand Down Expand Up @@ -1415,6 +1424,8 @@ class BaseSlice(CSTNode, ABC):
This node is purely for typing.
"""

__slots__ = ()


@add_slots
@dataclass(frozen=True)
Expand Down Expand Up @@ -2190,6 +2201,8 @@ class _BaseExpressionWithArgs(BaseExpression, ABC):
in typing. So, we have common validation functions here.
"""

__slots__ = ()

#: Sequence of arguments that will be passed to the function call.
args: Sequence[Arg] = ()

Expand Down Expand Up @@ -2631,6 +2644,8 @@ class _BaseElementImpl(CSTNode, ABC):
An internal base class for :class:`Element` and :class:`DictElement`.
"""

__slots__ = ()

value: BaseExpression
comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT

Expand Down Expand Up @@ -2668,13 +2683,17 @@ class BaseElement(_BaseElementImpl, ABC):
BaseDictElement.
"""

__slots__ = ()


class BaseDictElement(_BaseElementImpl, ABC):
"""
An element of a literal dict. For elements of a list, tuple, or set, see
BaseElement.
"""

__slots__ = ()


@add_slots
@dataclass(frozen=True)
Expand Down Expand Up @@ -2957,6 +2976,8 @@ class BaseList(BaseExpression, ABC):
object when evaluated.
"""

__slots__ = ()

lbracket: LeftSquareBracket = LeftSquareBracket.field()
#: Brackets surrounding the list.
rbracket: RightSquareBracket = RightSquareBracket.field()
Expand Down Expand Up @@ -3037,6 +3058,8 @@ class _BaseSetOrDict(BaseExpression, ABC):
shouldn't be exported.
"""

__slots__ = ()

lbrace: LeftCurlyBrace = LeftCurlyBrace.field()
#: Braces surrounding the set or dict.
rbrace: RightCurlyBrace = RightCurlyBrace.field()
Expand All @@ -3062,6 +3085,8 @@ class BaseSet(_BaseSetOrDict, ABC):
a set object when evaluated.
"""

__slots__ = ()


@add_slots
@dataclass(frozen=True)
Expand Down Expand Up @@ -3131,6 +3156,8 @@ class BaseDict(_BaseSetOrDict, ABC):
a dict object when evaluated.
"""

__slots__ = ()


@add_slots
@dataclass(frozen=True)
Expand Down Expand Up @@ -3407,6 +3434,8 @@ class BaseComp(BaseExpression, ABC):
:class:`GeneratorExp`, :class:`ListComp`, :class:`SetComp`, and :class:`DictComp`.
"""

__slots__ = ()

for_in: CompFor


Expand All @@ -3417,6 +3446,8 @@ class BaseSimpleComp(BaseComp, ABC):
``value``.
"""

__slots__ = ()

#: The expression evaluated during each iteration of the comprehension. This
#: lexically comes before the ``for_in`` clause, but it is semantically the
#: inner-most element, evaluated inside the ``for_in`` clause.
Expand Down
14 changes: 14 additions & 0 deletions libcst/_nodes/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class _BaseOneTokenOp(CSTNode, ABC):
Any node that has a static value and needs to own whitespace on both sides.
"""

__slots__ = ()

whitespace_before: BaseParenthesizableWhitespace

whitespace_after: BaseParenthesizableWhitespace
Expand Down Expand Up @@ -51,6 +53,8 @@ class _BaseTwoTokenOp(CSTNode, ABC):
in beteween them.
"""

__slots__ = ()

whitespace_before: BaseParenthesizableWhitespace

whitespace_between: BaseParenthesizableWhitespace
Expand Down Expand Up @@ -93,6 +97,8 @@ class BaseUnaryOp(CSTNode, ABC):
Any node that has a static value used in a :class:`UnaryOperation` expression.
"""

__slots__ = ()

#: Any space that appears directly after this operator.
whitespace_after: BaseParenthesizableWhitespace

Expand All @@ -119,27 +125,35 @@ class BaseBooleanOp(_BaseOneTokenOp, ABC):
This node is purely for typing.
"""

__slots__ = ()


class BaseBinaryOp(CSTNode, ABC):
"""
Any node that has a static value used in a :class:`BinaryOperation` expression.
This node is purely for typing.
"""

__slots__ = ()


class BaseCompOp(CSTNode, ABC):
"""
Any node that has a static value used in a :class:`Comparison` expression.
This node is purely for typing.
"""

__slots__ = ()


class BaseAugOp(CSTNode, ABC):
"""
Any node that has a static value used in an :class:`AugAssign` assignment.
This node is purely for typing.
"""

__slots__ = ()


@add_slots
@dataclass(frozen=True)
Expand Down
14 changes: 13 additions & 1 deletion libcst/_nodes/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class BaseSuite(CSTNode, ABC):
-- https://docs.python.org/3/reference/compound_stmts.html
"""

__slots__ = ()

body: Union[Sequence["BaseStatement"], Sequence["BaseSmallStatement"]]


Expand All @@ -88,7 +90,7 @@ class BaseStatement(CSTNode, ABC):
in a particular location.
"""

pass
__slots__ = ()


class BaseSmallStatement(CSTNode, ABC):
Expand All @@ -99,6 +101,8 @@ class BaseSmallStatement(CSTNode, ABC):
simplify type definitions and isinstance checks.
"""

__slots__ = ()

#: An optional semicolon that appears after a small statement. This is optional
#: for the last small statement in a :class:`SimpleStatementLine` or
#: :class:`SimpleStatementSuite`, but all other small statements inside a simple
Expand Down Expand Up @@ -370,6 +374,8 @@ class _BaseSimpleStatement(CSTNode, ABC):
small statement.
"""

__slots__ = ()

#: Sequence of small statements. All but the last statement are required to have
#: a semicolon.
body: Sequence[BaseSmallStatement]
Expand Down Expand Up @@ -554,6 +560,8 @@ class BaseCompoundStatement(BaseStatement, ABC):
-- https://docs.python.org/3/reference/compound_stmts.html
"""

__slots__ = ()

#: The body of this compound statement.
body: BaseSuite

Expand Down Expand Up @@ -2633,6 +2641,8 @@ class MatchPattern(_BaseParenthesizedNode, ABC):
statement.
"""

__slots__ = ()


@add_slots
@dataclass(frozen=True)
Expand Down Expand Up @@ -2960,6 +2970,8 @@ class MatchSequence(MatchPattern, ABC):
otherwise matches a fixed length sequence.
"""

__slots__ = ()

#: Patterns to be matched against the subject elements if it is a sequence.
patterns: Sequence[Union[MatchSequenceElement, MatchStar]]

Expand Down
2 changes: 2 additions & 0 deletions libcst/_nodes/whitespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class BaseParenthesizableWhitespace(CSTNode, ABC):
``iftest``), it has some semantic value.
"""

__slots__ = ()

# TODO: Should we somehow differentiate places where we require non-zero whitespace
# with a separate type?

Expand Down
Loading