Skip to content

Commit

Permalink
some more rigorous testing of add_slots pickling
Browse files Browse the repository at this point in the history
  • Loading branch information
ariebovenberg committed Jan 16, 2022
1 parent fbe97fa commit dcb6447
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
11 changes: 5 additions & 6 deletions libcst/_add_slots.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,14 @@ def add_slots(cls: Type[_T]) -> Type[_T]:

def __getstate__(self: object) -> Mapping[str, Any]:
return {
field: getattr(self, field)
# pyre-ignore[16]: `object` has no attribute `__dataclass_fields__`.
for field in self.__dataclass_fields__
if hasattr(self, field)
field.name: getattr(self, field.name)
for field in dataclasses.fields(self)
if hasattr(self, field.name)
}

def __setstate__(self: object, state: Mapping[str, Any]) -> None:
for slot, value in state.items():
object.__setattr__(self, slot, value)
for fieldname, value in state.items():
object.__setattr__(self, fieldname, value)

cls.__getstate__ = __getstate__
cls.__setstate__ = __setstate__
Expand Down
6 changes: 6 additions & 0 deletions libcst/tests/test_add_slots.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,29 @@

import pickle
from dataclasses import dataclass
from typing import ClassVar

from libcst._add_slots import add_slots

from libcst.testing.utils import UnitTest


# this test class needs to be defined at module level to test pickling.
@add_slots
@dataclass(frozen=True)
class A:
x: int
y: str

Z: ClassVar[int] = 5


class AddSlotsTest(UnitTest):
def test_pickle(self) -> None:
a = A(1, "foo")
self.assertEqual(a, pickle.loads(pickle.dumps(a)))
object.__delattr__(a, "y")
self.assertEqual(a.x, pickle.loads(pickle.dumps(a)).x)

def test_prevents_slots_overlap(self) -> None:
class A:
Expand Down

0 comments on commit dcb6447

Please sign in to comment.