From 4c0085a37c2a8eaa746124eeb0c7f9c312cfc68c Mon Sep 17 00:00:00 2001 From: Arie Bovenberg Date: Fri, 14 Jan 2022 17:26:24 +0100 Subject: [PATCH 1/7] add slots to base classes, @add_slots takes bases into account --- libcst/_add_slots.py | 15 +++++- libcst/_nodes/base.py | 11 +++++ libcst/_nodes/expression.py | 39 +++++++++++++-- libcst/_nodes/op.py | 14 ++++++ libcst/_nodes/statement.py | 17 ++++++- libcst/_nodes/whitespace.py | 2 + libcst/tests/test_batched_visitor.py | 72 ++++++++++++++-------------- libcst/tests/test_pickle.py | 15 ++++++ 8 files changed, 141 insertions(+), 44 deletions(-) create mode 100644 libcst/tests/test_pickle.py diff --git a/libcst/_add_slots.py b/libcst/_add_slots.py index 6e9c00415..215f19b13 100644 --- a/libcst/_add_slots.py +++ b/libcst/_add_slots.py @@ -3,6 +3,7 @@ # https://github.com/ericvsmith/dataclasses/blob/ae712dd993420d43444f/dataclass_tools.py import dataclasses +from itertools import chain, filterfalse from typing import Any, Mapping, Type, TypeVar _T = TypeVar("_T") @@ -19,7 +20,14 @@ def add_slots(cls: Type[_T]) -> Type[_T]: # Create a new dict for our new class. cls_dict = dict(cls.__dict__) field_names = tuple(f.name for f in dataclasses.fields(cls)) - cls_dict["__slots__"] = field_names + inherited_slots = set( + chain.from_iterable( + superclass.__dict__.get("__slots__", ()) for superclass in cls.mro() + ) + ) + cls_dict["__slots__"] = tuple( + filterfalse(inherited_slots.__contains__, field_names) + ) for field_name in field_names: # Remove our attributes, if present. They'll still be # available in _MARKER. @@ -50,7 +58,10 @@ def add_slots(cls: Type[_T]) -> Type[_T]: def __getstate__(self: object) -> Mapping[str, Any]: return { - slot: getattr(self, slot) for slot in self.__slots__ if hasattr(self, slot) + slot: getattr(self, slot) + # pyre-ignore[16]: `object` has no attribute `__dataclass_fields__`. + for slot in self.__dataclass_fields__ + if hasattr(self, slot) } def __setstate__(self: object, state: Mapping[str, Any]) -> None: diff --git a/libcst/_nodes/base.py b/libcst/_nodes/base.py index 9173414b1..6ad81d90e 100644 --- a/libcst/_nodes/base.py +++ b/libcst/_nodes/base.py @@ -109,6 +109,12 @@ def _clone(val: object) -> object: @dataclass(frozen=True) class CSTNode(ABC): + + # pyre-ignore[4]: Attribute `__slots__` of class `CSTNode` + # has type `typing.Tuple[]` but no type is specified. + # But we know it's str + __slots__ = () + def __post_init__(self) -> None: # PERF: It might make more sense to move validation work into the visitor, which # would allow us to avoid validating the tree when parsing a file. @@ -468,6 +474,9 @@ def field(cls, *args: object, **kwargs: object) -> Any: class BaseLeaf(CSTNode, ABC): + + __slots__ = () + @property def children(self) -> Sequence[CSTNode]: # override this with an optimized implementation @@ -487,6 +496,8 @@ class BaseValueToken(BaseLeaf, ABC): into the parent CSTNode, and hard-coded into the implementation of _codegen. """ + __slots__ = () + value: str def _codegen_impl(self, state: CodegenState) -> None: diff --git a/libcst/_nodes/expression.py b/libcst/_nodes/expression.py index dba5faf3f..6b86a8b24 100644 --- a/libcst/_nodes/expression.py +++ b/libcst/_nodes/expression.py @@ -222,6 +222,8 @@ class _BaseParenthesizedNode(CSTNode, ABC): this to get that functionality. """ + __slots__ = () + lpar: Sequence[LeftParen] = () # Sequence of parenthesis for precedence dictation. rpar: Sequence[RightParen] = () @@ -254,6 +256,8 @@ class BaseExpression(_BaseParenthesizedNode, ABC): An base class for all expressions. :class:`BaseExpression` contains no fields. """ + __slots__ = () + def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool: """ Returns true if this expression is safe to be use with a word operator @@ -296,7 +300,7 @@ class BaseAssignTargetExpression(BaseExpression, ABC): `_. """ - pass + __slots__ = () class BaseDelTargetExpression(BaseExpression, ABC): @@ -316,7 +320,7 @@ class BaseDelTargetExpression(BaseExpression, ABC): `_. """ - pass + __slots__ = () @add_slots @@ -393,6 +397,8 @@ class BaseNumber(BaseExpression, ABC): used anywhere that you need to explicitly take any number type. """ + __slots__ = () + def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool: """ Numbers are funny. The expression "5in [1,2,3,4,5]" is a valid expression @@ -522,13 +528,16 @@ class BaseString(BaseExpression, ABC): :class:`SimpleString`, :class:`ConcatenatedString`, and :class:`FormattedString`. """ - pass + __slots__ = () StringQuoteLiteral = Literal['"', "'", '"""', "'''"] class _BasePrefixedString(BaseString, ABC): + + __slots__ = () + @property def prefix(self) -> str: """ @@ -699,7 +708,7 @@ class BaseFormattedStringContent(CSTNode, ABC): sequence of :class:`BaseFormattedStringContent` parts. """ - pass + __slots__ = () @add_slots @@ -1415,6 +1424,8 @@ class BaseSlice(CSTNode, ABC): This node is purely for typing. """ + __slots__ = () + @add_slots @dataclass(frozen=True) @@ -2190,6 +2201,8 @@ class _BaseExpressionWithArgs(BaseExpression, ABC): in typing. So, we have common validation functions here. """ + __slots__ = () + #: Sequence of arguments that will be passed to the function call. args: Sequence[Arg] = () @@ -2631,6 +2644,8 @@ class _BaseElementImpl(CSTNode, ABC): An internal base class for :class:`Element` and :class:`DictElement`. """ + __slots__ = () + value: BaseExpression comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT @@ -2668,6 +2683,8 @@ class BaseElement(_BaseElementImpl, ABC): BaseDictElement. """ + __slots__ = () + class BaseDictElement(_BaseElementImpl, ABC): """ @@ -2675,6 +2692,8 @@ class BaseDictElement(_BaseElementImpl, ABC): BaseElement. """ + __slots__ = () + @add_slots @dataclass(frozen=True) @@ -2957,6 +2976,8 @@ class BaseList(BaseExpression, ABC): object when evaluated. """ + __slots__ = () + lbracket: LeftSquareBracket = LeftSquareBracket.field() #: Brackets surrounding the list. rbracket: RightSquareBracket = RightSquareBracket.field() @@ -3037,6 +3058,8 @@ class _BaseSetOrDict(BaseExpression, ABC): shouldn't be exported. """ + __slots__ = () + lbrace: LeftCurlyBrace = LeftCurlyBrace.field() #: Braces surrounding the set or dict. rbrace: RightCurlyBrace = RightCurlyBrace.field() @@ -3062,6 +3085,8 @@ class BaseSet(_BaseSetOrDict, ABC): a set object when evaluated. """ + __slots__ = () + @add_slots @dataclass(frozen=True) @@ -3131,6 +3156,8 @@ class BaseDict(_BaseSetOrDict, ABC): a dict object when evaluated. """ + __slots__ = () + @add_slots @dataclass(frozen=True) @@ -3407,6 +3434,8 @@ class BaseComp(BaseExpression, ABC): :class:`GeneratorExp`, :class:`ListComp`, :class:`SetComp`, and :class:`DictComp`. """ + __slots__ = () + for_in: CompFor @@ -3417,6 +3446,8 @@ class BaseSimpleComp(BaseComp, ABC): ``value``. """ + __slots__ = () + #: The expression evaluated during each iteration of the comprehension. This #: lexically comes before the ``for_in`` clause, but it is semantically the #: inner-most element, evaluated inside the ``for_in`` clause. diff --git a/libcst/_nodes/op.py b/libcst/_nodes/op.py index ea02835a9..e19d24d34 100644 --- a/libcst/_nodes/op.py +++ b/libcst/_nodes/op.py @@ -19,6 +19,8 @@ class _BaseOneTokenOp(CSTNode, ABC): Any node that has a static value and needs to own whitespace on both sides. """ + __slots__ = () + whitespace_before: BaseParenthesizableWhitespace whitespace_after: BaseParenthesizableWhitespace @@ -51,6 +53,8 @@ class _BaseTwoTokenOp(CSTNode, ABC): in beteween them. """ + __slots__ = () + whitespace_before: BaseParenthesizableWhitespace whitespace_between: BaseParenthesizableWhitespace @@ -93,6 +97,8 @@ class BaseUnaryOp(CSTNode, ABC): Any node that has a static value used in a :class:`UnaryOperation` expression. """ + __slots__ = () + #: Any space that appears directly after this operator. whitespace_after: BaseParenthesizableWhitespace @@ -119,6 +125,8 @@ class BaseBooleanOp(_BaseOneTokenOp, ABC): This node is purely for typing. """ + __slots__ = () + class BaseBinaryOp(CSTNode, ABC): """ @@ -126,6 +134,8 @@ class BaseBinaryOp(CSTNode, ABC): This node is purely for typing. """ + __slots__ = () + class BaseCompOp(CSTNode, ABC): """ @@ -133,6 +143,8 @@ class BaseCompOp(CSTNode, ABC): This node is purely for typing. """ + __slots__ = () + class BaseAugOp(CSTNode, ABC): """ @@ -140,6 +152,8 @@ class BaseAugOp(CSTNode, ABC): This node is purely for typing. """ + __slots__ = () + @add_slots @dataclass(frozen=True) diff --git a/libcst/_nodes/statement.py b/libcst/_nodes/statement.py index 9493f57c2..675ae6f2d 100644 --- a/libcst/_nodes/statement.py +++ b/libcst/_nodes/statement.py @@ -79,6 +79,8 @@ class BaseSuite(CSTNode, ABC): -- https://docs.python.org/3/reference/compound_stmts.html """ + __slots__ = () + body: Union[Sequence["BaseStatement"], Sequence["BaseSmallStatement"]] @@ -88,7 +90,7 @@ class BaseStatement(CSTNode, ABC): in a particular location. """ - pass + __slots__ = () class BaseSmallStatement(CSTNode, ABC): @@ -99,6 +101,8 @@ class BaseSmallStatement(CSTNode, ABC): simplify type definitions and isinstance checks. """ + __slots__ = () + #: An optional semicolon that appears after a small statement. This is optional #: for the last small statement in a :class:`SimpleStatementLine` or #: :class:`SimpleStatementSuite`, but all other small statements inside a simple @@ -165,7 +169,8 @@ def _codegen_impl( semicolon._codegen(state) -@add_slots +# TODO: re-add slots after fixing test_batched_visitor.py +# @add_slots @dataclass(frozen=True) class Pass(BaseSmallStatement): """ @@ -370,6 +375,8 @@ class _BaseSimpleStatement(CSTNode, ABC): small statement. """ + __slots__ = () + #: Sequence of small statements. All but the last statement are required to have #: a semicolon. body: Sequence[BaseSmallStatement] @@ -554,6 +561,8 @@ class BaseCompoundStatement(BaseStatement, ABC): -- https://docs.python.org/3/reference/compound_stmts.html """ + __slots__ = () + #: The body of this compound statement. body: BaseSuite @@ -2633,6 +2642,8 @@ class MatchPattern(_BaseParenthesizedNode, ABC): statement. """ + __slots__ = () + @add_slots @dataclass(frozen=True) @@ -2960,6 +2971,8 @@ class MatchSequence(MatchPattern, ABC): otherwise matches a fixed length sequence. """ + __slots__ = () + #: Patterns to be matched against the subject elements if it is a sequence. patterns: Sequence[Union[MatchSequenceElement, MatchStar]] diff --git a/libcst/_nodes/whitespace.py b/libcst/_nodes/whitespace.py index 686c14fb6..b1332c132 100644 --- a/libcst/_nodes/whitespace.py +++ b/libcst/_nodes/whitespace.py @@ -48,6 +48,8 @@ class BaseParenthesizableWhitespace(CSTNode, ABC): ``iftest``), it has some semantic value. """ + __slots__ = () + # TODO: Should we somehow differentiate places where we require non-zero whitespace # with a separate type? diff --git a/libcst/tests/test_batched_visitor.py b/libcst/tests/test_batched_visitor.py index 9bcc562fe..0e34343c8 100644 --- a/libcst/tests/test_batched_visitor.py +++ b/libcst/tests/test_batched_visitor.py @@ -16,57 +16,57 @@ def test_simple(self) -> None: mock = Mock() class ABatchable(BatchableCSTVisitor): - def visit_Pass(self, node: cst.Pass) -> None: - mock.visited_a() - object.__setattr__(node, "a_attr", True) + def visit_Del(self, node: cst.Del) -> None: + object.__setattr__(node, "target", mock.visited_a()) class BBatchable(BatchableCSTVisitor): - def visit_Pass(self, node: cst.Pass) -> None: - mock.visited_b() - object.__setattr__(node, "b_attr", 1) + def visit_Del(self, node: cst.Del) -> None: + object.__setattr__(node, "semicolon", mock.visited_b()) - module = visit_batched(parse_module("pass"), [ABatchable(), BBatchable()]) - pass_ = cast(cst.SimpleStatementLine, module.body[0]).body[0] - - # Check properties were set - self.assertEqual(object.__getattribute__(pass_, "a_attr"), True) - self.assertEqual(object.__getattribute__(pass_, "b_attr"), 1) + module = visit_batched(parse_module("del 5"), [ABatchable(), BBatchable()]) + del_ = cast(cst.SimpleStatementLine, module.body[0]).body[0] # Check that each visitor was only called once mock.visited_a.assert_called_once() mock.visited_b.assert_called_once() + # Check properties were set + self.assertEqual(object.__getattribute__(del_, "target"), mock.visited_a()) + self.assertEqual(object.__getattribute__(del_, "semicolon"), mock.visited_b()) + def test_all_visits(self) -> None: mock = Mock() class Batchable(BatchableCSTVisitor): - def visit_Pass(self, node: cst.Pass) -> None: - mock.visit_Pass() - object.__setattr__(node, "visit_Pass", True) + def visit_If(self, node: cst.If) -> None: + object.__setattr__(node, "test", mock.visit_If()) - def visit_Pass_semicolon(self, node: cst.Pass) -> None: - mock.visit_Pass_semicolon() - object.__setattr__(node, "visit_Pass_semicolon", True) + def visit_If_body(self, node: cst.If) -> None: + object.__setattr__(node, "leading_lines", mock.visit_If_body()) - def leave_Pass_semicolon(self, node: cst.Pass) -> None: - mock.leave_Pass_semicolon() - object.__setattr__(node, "leave_Pass_semicolon", True) + def leave_If_body(self, node: cst.If) -> None: + object.__setattr__(node, "orelse", mock.leave_If_body()) - def leave_Pass(self, original_node: cst.Pass) -> None: - mock.leave_Pass() - object.__setattr__(original_node, "leave_Pass", True) + def leave_If(self, original_node: cst.If) -> None: + object.__setattr__( + original_node, "whitespace_before_test", mock.leave_If() + ) - module = visit_batched(parse_module("pass"), [Batchable()]) - pass_ = cast(cst.SimpleStatementLine, module.body[0]).body[0] - - # Check properties were set - self.assertEqual(object.__getattribute__(pass_, "visit_Pass"), True) - self.assertEqual(object.__getattribute__(pass_, "leave_Pass"), True) - self.assertEqual(object.__getattribute__(pass_, "visit_Pass_semicolon"), True) - self.assertEqual(object.__getattribute__(pass_, "leave_Pass_semicolon"), True) + module = visit_batched(parse_module("if True: pass"), [Batchable()]) + if_ = cast(cst.SimpleStatementLine, module.body[0]) # Check that each visitor was only called once - mock.visit_Pass.assert_called_once() - mock.leave_Pass.assert_called_once() - mock.visit_Pass_semicolon.assert_called_once() - mock.leave_Pass_semicolon.assert_called_once() + mock.visit_If.assert_called_once() + mock.leave_If.assert_called_once() + mock.visit_If_body.assert_called_once() + mock.leave_If_body.assert_called_once() + + # Check properties were set + self.assertEqual(object.__getattribute__(if_, "test"), mock.visit_If()) + self.assertEqual( + object.__getattribute__(if_, "leading_lines"), mock.visit_If_body() + ) + self.assertEqual(object.__getattribute__(if_, "orelse"), mock.leave_If_body()) + self.assertEqual( + object.__getattribute__(if_, "whitespace_before_test"), mock.leave_If() + ) diff --git a/libcst/tests/test_pickle.py b/libcst/tests/test_pickle.py new file mode 100644 index 000000000..38fe4b4ca --- /dev/null +++ b/libcst/tests/test_pickle.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import pickle + +from libcst import parse_module +from libcst.testing.utils import UnitTest + + +class PickleTest(UnitTest): + def test_load_and_dump(self) -> None: + module = parse_module("5 + 3") + self.assertTrue(module.deep_equals(pickle.loads(pickle.dumps(module)))) From 4ce4a3a0caddc65651fb0f9a3d0eca8063d4a401 Mon Sep 17 00:00:00 2001 From: Arie Bovenberg Date: Fri, 14 Jan 2022 22:24:08 +0100 Subject: [PATCH 2/7] state changes in apache 2.0 licensed add_slots --- libcst/_add_slots.py | 1 + 1 file changed, 1 insertion(+) diff --git a/libcst/_add_slots.py b/libcst/_add_slots.py index 215f19b13..acf63c9de 100644 --- a/libcst/_add_slots.py +++ b/libcst/_add_slots.py @@ -1,6 +1,7 @@ # This file is derived from github.com/ericvsmith/dataclasses, and is Apache 2 licensed. # https://github.com/ericvsmith/dataclasses/blob/ae712dd993420d43444f188f452/LICENSE.txt # https://github.com/ericvsmith/dataclasses/blob/ae712dd993420d43444f/dataclass_tools.py +# Changed: takes slots in base classes into account when creating slots import dataclasses from itertools import chain, filterfalse From 54997b68af9dd219d14624f980af9dd1b553e52d Mon Sep 17 00:00:00 2001 From: Arie Bovenberg Date: Fri, 14 Jan 2022 22:27:25 +0100 Subject: [PATCH 3/7] naming tweak in add_slots --- libcst/_add_slots.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libcst/_add_slots.py b/libcst/_add_slots.py index acf63c9de..4f9d2a2b2 100644 --- a/libcst/_add_slots.py +++ b/libcst/_add_slots.py @@ -59,10 +59,10 @@ def add_slots(cls: Type[_T]) -> Type[_T]: def __getstate__(self: object) -> Mapping[str, Any]: return { - slot: getattr(self, slot) + field: getattr(self, field) # pyre-ignore[16]: `object` has no attribute `__dataclass_fields__`. - for slot in self.__dataclass_fields__ - if hasattr(self, slot) + for field in self.__dataclass_fields__ + if hasattr(self, field) } def __setstate__(self: object, state: Mapping[str, Any]) -> None: From 04a93af0df2b3c5ea8df9abe8ade287d2ab67c94 Mon Sep 17 00:00:00 2001 From: Arie Bovenberg Date: Fri, 14 Jan 2022 22:28:16 +0100 Subject: [PATCH 4/7] remove lingering TODO --- libcst/_nodes/statement.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/libcst/_nodes/statement.py b/libcst/_nodes/statement.py index 675ae6f2d..ded7c7c6c 100644 --- a/libcst/_nodes/statement.py +++ b/libcst/_nodes/statement.py @@ -169,8 +169,7 @@ def _codegen_impl( semicolon._codegen(state) -# TODO: re-add slots after fixing test_batched_visitor.py -# @add_slots +@add_slots @dataclass(frozen=True) class Pass(BaseSmallStatement): """ From 1bfbb457f066ef4ae2dd1555a943e9d866d40652 Mon Sep 17 00:00:00 2001 From: Arie Bovenberg Date: Sat, 15 Jan 2022 16:38:05 +0100 Subject: [PATCH 5/7] remove pyre ignore, add test for fixed add_slots inheritance behavior --- libcst/_nodes/base.py | 7 ++---- libcst/tests/test_add_slots.py | 40 ++++++++++++++++++++++++++++++++++ libcst/tests/test_pickle.py | 15 ------------- 3 files changed, 42 insertions(+), 20 deletions(-) create mode 100644 libcst/tests/test_add_slots.py delete mode 100644 libcst/tests/test_pickle.py diff --git a/libcst/_nodes/base.py b/libcst/_nodes/base.py index 6ad81d90e..035976418 100644 --- a/libcst/_nodes/base.py +++ b/libcst/_nodes/base.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from copy import deepcopy from dataclasses import dataclass, field, fields, replace -from typing import Any, cast, Dict, List, Mapping, Sequence, TypeVar, Union +from typing import Any, cast, ClassVar, Dict, List, Mapping, Sequence, TypeVar, Union from libcst._flatten_sentinel import FlattenSentinel from libcst._nodes.internal import CodegenState @@ -110,10 +110,7 @@ def _clone(val: object) -> object: @dataclass(frozen=True) class CSTNode(ABC): - # pyre-ignore[4]: Attribute `__slots__` of class `CSTNode` - # has type `typing.Tuple[]` but no type is specified. - # But we know it's str - __slots__ = () + __slots__: ClassVar[Sequence[str]] = () def __post_init__(self) -> None: # PERF: It might make more sense to move validation work into the visitor, which diff --git a/libcst/tests/test_add_slots.py b/libcst/tests/test_add_slots.py new file mode 100644 index 000000000..612147c85 --- /dev/null +++ b/libcst/tests/test_add_slots.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import pickle +from dataclasses import dataclass + +from libcst._add_slots import add_slots + +from libcst.testing.utils import UnitTest + + +@add_slots +@dataclass(frozen=True) +class A: + x: int + y: str + + +class AddSlotsTest(UnitTest): + def test_pickle(self) -> None: + a = A(1, "foo") + self.assertEqual(a, pickle.loads(pickle.dumps(a))) + + def test_prevents_slots_overlap(self) -> None: + class A: + __slots__ = ("x",) + + class B(A): + __slots__ = ("z",) + + @add_slots + @dataclass + class C(B): + x: int + y: str + z: bool + + self.assertSequenceEqual(C.__slots__, ("y",)) diff --git a/libcst/tests/test_pickle.py b/libcst/tests/test_pickle.py deleted file mode 100644 index 38fe4b4ca..000000000 --- a/libcst/tests/test_pickle.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import pickle - -from libcst import parse_module -from libcst.testing.utils import UnitTest - - -class PickleTest(UnitTest): - def test_load_and_dump(self) -> None: - module = parse_module("5 + 3") - self.assertTrue(module.deep_equals(pickle.loads(pickle.dumps(module)))) From fbe97faebce4d7d158a4c50d61c062073087d279 Mon Sep 17 00:00:00 2001 From: Arie Bovenberg Date: Sat, 15 Jan 2022 20:46:56 +0100 Subject: [PATCH 6/7] fix example in visitor test causing ParserSyntaxError on native --- libcst/tests/test_batched_visitor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libcst/tests/test_batched_visitor.py b/libcst/tests/test_batched_visitor.py index 0e34343c8..9009847c6 100644 --- a/libcst/tests/test_batched_visitor.py +++ b/libcst/tests/test_batched_visitor.py @@ -23,7 +23,7 @@ class BBatchable(BatchableCSTVisitor): def visit_Del(self, node: cst.Del) -> None: object.__setattr__(node, "semicolon", mock.visited_b()) - module = visit_batched(parse_module("del 5"), [ABatchable(), BBatchable()]) + module = visit_batched(parse_module("del a"), [ABatchable(), BBatchable()]) del_ = cast(cst.SimpleStatementLine, module.body[0]).body[0] # Check that each visitor was only called once From dcb6447f53cd5cd73dac656bb83a62e022a4c2af Mon Sep 17 00:00:00 2001 From: Arie Bovenberg Date: Sun, 16 Jan 2022 12:09:46 +0100 Subject: [PATCH 7/7] some more rigorous testing of add_slots pickling --- libcst/_add_slots.py | 11 +++++------ libcst/tests/test_add_slots.py | 6 ++++++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/libcst/_add_slots.py b/libcst/_add_slots.py index 4f9d2a2b2..bbe2c6343 100644 --- a/libcst/_add_slots.py +++ b/libcst/_add_slots.py @@ -59,15 +59,14 @@ def add_slots(cls: Type[_T]) -> Type[_T]: def __getstate__(self: object) -> Mapping[str, Any]: return { - field: getattr(self, field) - # pyre-ignore[16]: `object` has no attribute `__dataclass_fields__`. - for field in self.__dataclass_fields__ - if hasattr(self, field) + field.name: getattr(self, field.name) + for field in dataclasses.fields(self) + if hasattr(self, field.name) } def __setstate__(self: object, state: Mapping[str, Any]) -> None: - for slot, value in state.items(): - object.__setattr__(self, slot, value) + for fieldname, value in state.items(): + object.__setattr__(self, fieldname, value) cls.__getstate__ = __getstate__ cls.__setstate__ = __setstate__ diff --git a/libcst/tests/test_add_slots.py b/libcst/tests/test_add_slots.py index 612147c85..e354f60b6 100644 --- a/libcst/tests/test_add_slots.py +++ b/libcst/tests/test_add_slots.py @@ -5,23 +5,29 @@ import pickle from dataclasses import dataclass +from typing import ClassVar from libcst._add_slots import add_slots from libcst.testing.utils import UnitTest +# this test class needs to be defined at module level to test pickling. @add_slots @dataclass(frozen=True) class A: x: int y: str + Z: ClassVar[int] = 5 + class AddSlotsTest(UnitTest): def test_pickle(self) -> None: a = A(1, "foo") self.assertEqual(a, pickle.loads(pickle.dumps(a))) + object.__delattr__(a, "y") + self.assertEqual(a.x, pickle.loads(pickle.dumps(a)).x) def test_prevents_slots_overlap(self) -> None: class A: