Support Pickle's protocol 5 #3784

Merged · 18 commits · May 21, 2020
37 changes: 24 additions & 13 deletions distributed/protocol/pickle.py
@@ -1,5 +1,6 @@
 import logging
 import pickle
+from pickle import HIGHEST_PROTOCOL

 import cloudpickle

@@ -23,36 +24,46 @@ def _always_use_pickle_for(x):
     return False


-def dumps(x):
+def dumps(x, *, buffer_callback=None):
     """ Manage between cloudpickle and pickle

     1. Try pickle
     2. If it is short then check if it contains __main__
     3. If it is long, then first check type, then check __main__
     """
+    buffers = []
+    dump_kwargs = {"protocol": HIGHEST_PROTOCOL}
+    if HIGHEST_PROTOCOL >= 5 and buffer_callback is not None:
+        dump_kwargs["buffer_callback"] = buffers.append
     try:
-        result = pickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
+        buffers.clear()
+        result = pickle.dumps(x, **dump_kwargs)
         if len(result) < 1000:
             if b"__main__" in result:
-                return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
-            else:
-                return result
-        else:
-            if _always_use_pickle_for(x) or b"__main__" not in result:
-                return result
-            else:
-                return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
+                buffers.clear()
+                result = cloudpickle.dumps(x, **dump_kwargs)
+        elif not _always_use_pickle_for(x) and b"__main__" in result:
+            buffers.clear()
+            result = cloudpickle.dumps(x, **dump_kwargs)
     except Exception:
         try:
-            return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
+            buffers.clear()
+            result = cloudpickle.dumps(x, **dump_kwargs)
         except Exception as e:
             logger.info("Failed to serialize %s. Exception: %s", x, e)
             raise
+    if buffer_callback is not None:
+        for b in buffers:
+            buffer_callback(b)
+    return result


-def loads(x):
+def loads(x, *, buffers=()):
     try:
-        return pickle.loads(x)
+        if buffers:
+            return pickle.loads(x, buffers=buffers)
+        else:
+            return pickle.loads(x)
     except Exception:
         logger.info("Failed to deserialize %s", x[:10000], exc_info=True)
         raise
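A minimal round-trip sketch of how the two new keyword-only arguments fit together (an illustration, not part of the diff; it assumes Python 3.8+, where `HIGHEST_PROTOCOL >= 5`, and NumPy, whose arrays pickle their data out of band under protocol 5):

```python
import numpy as np

from distributed.protocol.pickle import dumps, loads

x = np.ones(5000)

# dumps() passes each out-of-band buffer to buffer_callback as a
# PickleBuffer instead of copying it into the in-band bytestream.
buffers = []
data = dumps(x, buffer_callback=buffers.append)

# loads() needs the same buffers back, in order, to reconstruct x
# without copying the array data.
y = loads(data, buffers=buffers)
assert (x == y).all()
```

On Python < 3.8 the `buffer_callback` branch is skipped, `buffers` stays empty, and the same calls degrade gracefully to fully in-band pickling.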
9 changes: 7 additions & 2 deletions distributed/protocol/serialize.py
@@ -52,11 +52,16 @@ def dask_loads(header, frames):


 def pickle_dumps(x):
-    return {"serializer": "pickle"}, [pickle.dumps(x)]
+    header = {"serializer": "pickle"}
+    frames = [None]
+    buffer_callback = lambda f: frames.append(memoryview(f))
+    frames[0] = pickle.dumps(x, buffer_callback=buffer_callback)
+    return header, frames


 def pickle_loads(header, frames):
-    return pickle.loads(b"".join(frames))
+    x, buffers = frames[0], frames[1:]
+    return pickle.loads(x, buffers=buffers)


 def msgpack_dumps(x):
Review thread on `x, buffers = frames[0], frames[1:]`:

Member (Author): My expectation is that frames[1:] would be an empty list here. It might not be, though. If it's not, that would be good to fix. Whether that has to do with deserialize_bytes behavior or not, I'm less sure. Maybe someone else knows?

gshimansky: frames in pickle_loads comes from deserialize_bytes for me (take a look at the exception stack trace), and for me frames[1:] is not empty. That's why I am asking whether this is expected or not.

Member (Author): Right. So do we have a reproducer? That should help us debug this further.

Member (Author): FWIW I've been trying to reproduce this too, but haven't had any luck. I believe you that there could be a problem here; it's just tricky to come up with fixes in the dark, so anything you can come up with would help 🙂

gshimansky (May 29, 2020): If you say that Dask is able to split frames, could that happen when large pandas DataFrames are serialized? My test is a benchmark that operates on such objects, about 20M records long and 40 columns wide. Modin splits them for parallel processing, but they still remain considerably large in memory.

Member (Author): Maybe something like PR #3841 will help? We still lack a test case, though, so I have no way to confirm whether it actually helps (or to ensure we don't accidentally regress).

Member: Even when sending large bytes that do get split by frame_split_size, I'm unable to reproduce an error. Things are properly reassembled on the other side. Not sure what's going on there.

Member (Author): Yeah, me neither unfortunately. Happy to dig deeper once we identify a reproducer 🙂

Member (Author): A similar issue was raised recently (#3851). Not sure if that is the same. I suspect PR #3639 fixes this; it would be good if you could try it, though.

Member (Author): This was fixed by PR #3639.
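To make the suspected failure mode concrete, here is a hypothetical sketch (not from this PR; it assumes Python 3.8+). If a transport split the in-band frame, as frame_split_size does for large frames, and the pieces reached pickle_loads unmerged, then frames[1:] would hold fragments of the pickle bytestream rather than out-of-band buffers, and unpickling the truncated frames[0] would fail:

```python
import pickle

# Pickle a moderately large object entirely in band, then simulate a
# transport that splits the bytestream into two frames without rejoining.
payload = pickle.dumps(list(range(1000)), protocol=pickle.HIGHEST_PROTOCOL)
part1, part2 = payload[:100], payload[100:]

try:
    # part2 is wrongly treated as an out-of-band buffer, so the in-band
    # stream part1 is truncated mid-opcode.
    pickle.loads(part1, buffers=[part2])
except Exception as e:
    print(f"unpickling failed as expected: {e!r}")
```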


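For reference, a short sketch of the happy path for the new `pickle_dumps`/`pickle_loads` pair (again assuming Python 3.8+ and NumPy): `frames[0]` carries the in-band pickle bytestream, and any trailing frames are out-of-band buffers exposed as memoryviews.

```python
import numpy as np

from distributed.protocol.serialize import pickle_dumps, pickle_loads

x = np.ones(5000)
header, frames = pickle_dumps(x)

assert header == {"serializer": "pickle"}
assert isinstance(frames[0], bytes)  # in-band pickle bytestream
assert all(isinstance(f, memoryview) for f in frames[1:])  # out-of-band

y = pickle_loads(header, frames)
assert (x == y).all()
```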
82 changes: 80 additions & 2 deletions distributed/protocol/tests/test_pickle.py
@@ -6,22 +6,93 @@

 import pytest

-from distributed.protocol.pickle import dumps, loads
+from distributed.protocol import deserialize, serialize
+from distributed.protocol.pickle import HIGHEST_PROTOCOL, dumps, loads
+
+try:
+    from pickle import PickleBuffer
+except ImportError:
+    pass


 def test_pickle_data():
     data = [1, b"123", "123", [123], {}, set()]
     for d in data:
         assert loads(dumps(d)) == d
+        assert deserialize(*serialize(d, serializers=("pickle",))) == d
+
+
+def test_pickle_out_of_band():
+    class MemoryviewHolder:
+        def __init__(self, mv):
+            self.mv = memoryview(mv)
+
+        def __reduce_ex__(self, protocol):
+            if protocol >= 5:
+                return MemoryviewHolder, (PickleBuffer(self.mv),)
+            else:
+                return MemoryviewHolder, (self.mv.tobytes(),)
+
+    mv = memoryview(b"123")
+    mvh = MemoryviewHolder(mv)
+
+    if HIGHEST_PROTOCOL >= 5:
+        l = []
+        d = dumps(mvh, buffer_callback=l.append)
+        mvh2 = loads(d, buffers=l)
+
+        assert len(l) == 1
+        assert isinstance(l[0], PickleBuffer)
+        assert memoryview(l[0]) == mv
+    else:
+        mvh2 = loads(dumps(mvh))
+
+    assert isinstance(mvh2, MemoryviewHolder)
+    assert isinstance(mvh2.mv, memoryview)
+    assert mvh2.mv == mv
+
+    h, f = serialize(mvh, serializers=("pickle",))
+    mvh3 = deserialize(h, f)
+
+    assert isinstance(mvh3, MemoryviewHolder)
+    assert isinstance(mvh3.mv, memoryview)
+    assert mvh3.mv == mv
+
+    if HIGHEST_PROTOCOL >= 5:
+        assert len(f) == 2
+        assert isinstance(f[0], bytes)
+        assert isinstance(f[1], memoryview)
+        assert f[1] == mv
+    else:
+        assert len(f) == 1
+        assert isinstance(f[0], bytes)


 def test_pickle_numpy():
     np = pytest.importorskip("numpy")
     x = np.ones(5)
     assert (loads(dumps(x)) == x).all()
+    assert (deserialize(*serialize(x, serializers=("pickle",))) == x).all()

     x = np.ones(5000)
     assert (loads(dumps(x)) == x).all()
+    assert (deserialize(*serialize(x, serializers=("pickle",))) == x).all()
+
+    if HIGHEST_PROTOCOL >= 5:
+        x = np.ones(5000)
+
+        l = []
+        d = dumps(x, buffer_callback=l.append)
+        assert len(l) == 1
+        assert isinstance(l[0], PickleBuffer)
+        assert memoryview(l[0]) == memoryview(x)
+        assert (loads(d, buffers=l) == x).all()
+
+        h, f = serialize(x, serializers=("pickle",))
+        assert len(f) == 2
+        assert isinstance(f[0], bytes)
+        assert isinstance(f[1], memoryview)
+        assert (deserialize(h, f) == x).all()


 @pytest.mark.xfail(
@@ -45,10 +116,17 @@ def funcs():

     for func in funcs():
         wr = weakref.ref(func)
+
         func2 = loads(dumps(func))
         wr2 = weakref.ref(func2)
         assert func2(1) == func(1)
-        del func, func2
+
+        func3 = deserialize(*serialize(func, serializers=("pickle",)))
+        wr3 = weakref.ref(func3)
+        assert func3(1) == func(1)
+
+        del func, func2, func3
         gc.collect()
         assert wr() is None
         assert wr2() is None
+        assert wr3() is None
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
 click >= 6.6
-cloudpickle >= 0.2.2
+cloudpickle >= 1.3.0
 contextvars;python_version<'3.7'
 dask >= 2.9.0
 msgpack >= 0.6.0