Skip to content

Commit

Permalink
gh-102103: add module argument to dataclasses.make_dataclass (#10…
Browse files Browse the repository at this point in the history
  • Loading branch information
sobolevn committed Mar 11, 2023
1 parent ee6f841 commit b48be8f
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 2 deletions.
6 changes: 5 additions & 1 deletion Doc/library/dataclasses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ Module contents
:func:`astuple` raises :exc:`TypeError` if ``obj`` is not a dataclass
instance.

.. function:: make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, match_args=True, kw_only=False, slots=False, weakref_slot=False)
.. function:: make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, match_args=True, kw_only=False, slots=False, weakref_slot=False, module=None)

Creates a new dataclass with name ``cls_name``, fields as defined
in ``fields``, base classes as given in ``bases``, and initialized
Expand All @@ -401,6 +401,10 @@ Module contents
``match_args``, ``kw_only``, ``slots``, and ``weakref_slot`` have
the same meaning as they do in :func:`dataclass`.

If ``module`` is defined, the ``__module__`` attribute
of the dataclass is set to that value.
By default, it is set to the module name of the caller.

This function is not strictly required, because any Python
mechanism for creating a new class with ``__annotations__`` can
then apply the :func:`dataclass` function to convert that class to
Expand Down
15 changes: 14 additions & 1 deletion Lib/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1391,7 +1391,7 @@ def _astuple_inner(obj, tuple_factory):
def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
repr=True, eq=True, order=False, unsafe_hash=False,
frozen=False, match_args=True, kw_only=False, slots=False,
weakref_slot=False):
weakref_slot=False, module=None):
"""Return a new dynamically created dataclass.
The dataclass name will be 'cls_name'. 'fields' is an iterable
Expand Down Expand Up @@ -1455,6 +1455,19 @@ def exec_body_callback(ns):
# of generic dataclasses.
cls = types.new_class(cls_name, bases, {}, exec_body_callback)

# For pickling to work, the __module__ variable needs to be set to the frame
# where the dataclass is created.
if module is None:
try:
module = sys._getframemodulename(1) or '__main__'
except AttributeError:
try:
module = sys._getframe(1).f_globals.get('__name__', '__main__')
except (AttributeError, ValueError):
pass
if module is not None:
cls.__module__ = module

# Apply the normal decorator.
return dataclass(cls, init=init, repr=repr, eq=eq, order=order,
unsafe_hash=unsafe_hash, frozen=frozen,
Expand Down
39 changes: 39 additions & 0 deletions Lib/test/test_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3606,6 +3606,15 @@ def test_text_annotations(self):
'return': type(None)})


ByMakeDataClass = make_dataclass('ByMakeDataClass', [('x', int)])
ManualModuleMakeDataClass = make_dataclass('ManualModuleMakeDataClass',
[('x', int)],
module='test.test_dataclasses')
WrongNameMakeDataclass = make_dataclass('Wrong', [('x', int)])
WrongModuleMakeDataclass = make_dataclass('WrongModuleMakeDataclass',
[('x', int)],
module='custom')

class TestMakeDataclass(unittest.TestCase):
def test_simple(self):
C = make_dataclass('C',
Expand Down Expand Up @@ -3715,6 +3724,36 @@ def test_no_types(self):
'y': int,
'z': 'typing.Any'})

def test_module_attr(self):
self.assertEqual(ByMakeDataClass.__module__, __name__)
self.assertEqual(ByMakeDataClass(1).__module__, __name__)
self.assertEqual(WrongModuleMakeDataclass.__module__, "custom")
Nested = make_dataclass('Nested', [])
self.assertEqual(Nested.__module__, __name__)
self.assertEqual(Nested().__module__, __name__)

def test_pickle_support(self):
for klass in [ByMakeDataClass, ManualModuleMakeDataClass]:
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(proto=proto):
self.assertEqual(
pickle.loads(pickle.dumps(klass, proto)),
klass,
)
self.assertEqual(
pickle.loads(pickle.dumps(klass(1), proto)),
klass(1),
)

def test_cannot_be_pickled(self):
for klass in [WrongNameMakeDataclass, WrongModuleMakeDataclass]:
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(proto=proto):
with self.assertRaises(pickle.PickleError):
pickle.dumps(klass, proto)
with self.assertRaises(pickle.PickleError):
pickle.dumps(klass(1), proto)

def test_invalid_type_specification(self):
for bad_field in [(),
(1, 2, 3, 4),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Add ``module`` argument to :func:`dataclasses.make_dataclass` and make
classes produced by it pickleable.

0 comments on commit b48be8f

Please sign in to comment.