Serialize and split #4541

Merged (12 commits, Feb 26, 2021)
43 changes: 11 additions & 32 deletions distributed/protocol/core.py
@@ -6,18 +6,15 @@

from .compression import compressions, maybe_compress, decompress
from .serialize import (
serialize,
deserialize,
Serialize,
Serialized,
extract_serialize,
msgpack_decode_default,
msgpack_encode_default,
merge_and_deserialize,
serialize_and_split,
)
from .utils import frame_split_size, merge_frames, msgpack_opts
from ..utils import is_writeable, nbytes

_deserialize = deserialize
from .utils import msgpack_opts


logger = logging.getLogger(__name__)
@@ -48,7 +45,7 @@ def dumps(msg, serializers=None, on_error="message", context=None):
}

data = {
key: serialize(
key: serialize_and_split(
value.data, serializers=serializers, on_error=on_error, context=context
)
for key, value in data.items()
@@ -60,39 +57,23 @@
out_frames = []

for key, (head, frames) in data.items():
if "writeable" not in head:
head["writeable"] = tuple(map(is_writeable, frames))
if "lengths" not in head:
head["lengths"] = tuple(map(nbytes, frames))

# Compress frames that are not yet compressed
out_compression = []
_out_frames = []
for frame, compression in zip(
frames, head.get("compression") or [None] * len(frames)
):
if compression is None: # default behavior
_frames = frame_split_size(frame)
_compression, _frames = zip(
*[maybe_compress(frame, **compress_opts) for frame in _frames]
)
out_compression.extend(_compression)
_out_frames.extend(_frames)
else: # already specified, so pass
out_compression.append(compression)
_out_frames.append(frame)
if compression is None:
compression, frame = maybe_compress(frame, **compress_opts)

out_compression.append(compression)
out_frames.append(frame)

head["compression"] = out_compression
head["count"] = len(_out_frames)
head["count"] = len(frames)
header["headers"][key] = head
header["keys"].append(key)
out_frames.extend(_out_frames)

for key, (head, frames) in pre.items():
if "writeable" not in head:
head["writeable"] = tuple(map(is_writeable, frames))
if "lengths" not in head:
head["lengths"] = tuple(map(nbytes, frames))
head["count"] = len(frames)
header["headers"][key] = head
header["keys"].append(key)
@@ -146,9 +127,7 @@ def loads(frames, deserialize=True, deserializers=None):
if deserialize or key in bytestrings:
if "compression" in head:
fs = decompress(head, fs)
if not any(hasattr(f, "__cuda_array_interface__") for f in fs):
fs = merge_frames(head, fs)
value = _deserialize(head, fs, deserializers=deserializers)
value = merge_and_deserialize(head, fs, deserializers=deserializers)
else:
value = Serialized(head, fs)
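
With the splitting moved into `serialize_and_split()`, `dumps` only has to run `maybe_compress` once over each (already split) frame, and `loads` hands the decompressed frames straight to `merge_and_deserialize`. A minimal round-trip sketch, not part of the diff, using the public `dumps`/`loads`/`to_serialize` helpers:

```python
import numpy as np
from distributed.protocol import dumps, loads, to_serialize

# Build a message whose payload goes through the serialize-and-split path.
msg = {"op": "update", "x": to_serialize(np.arange(1_000, dtype="u4"))}
frames = dumps(msg)      # serialize_and_split, then maybe_compress per sub-frame
result = loads(frames)   # decompress, then merge_and_deserialize
assert (result["x"] == np.arange(1_000, dtype="u4")).all()
```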

2 changes: 0 additions & 2 deletions distributed/protocol/cupy.py
@@ -22,7 +22,6 @@ def cuda_serialize_cupy_ndarray(x):

header = x.__cuda_array_interface__.copy()
header["strides"] = tuple(x.strides)
header["lengths"] = [x.nbytes]
frames = [
cupy.ndarray(
shape=(x.nbytes,), dtype=cupy.dtype("u1"), memptr=x.data, strides=(1,)
@@ -47,7 +46,6 @@ def cuda_deserialize_cupy_ndarray(header, frames):
@dask_serialize.register(cupy.ndarray)
def dask_serialize_cupy_ndarray(x):
header, frames = cuda_serialize_cupy_ndarray(x)
header["writeable"] = (None,) * len(frames)
frames = [memoryview(cupy.asnumpy(f)) for f in frames]
return header, frames

2 changes: 0 additions & 2 deletions distributed/protocol/numba.py
@@ -23,7 +23,6 @@ def cuda_serialize_numba_ndarray(x):

header = x.__cuda_array_interface__.copy()
header["strides"] = tuple(x.strides)
header["lengths"] = [x.nbytes]
frames = [
numba.cuda.cudadrv.devicearray.DeviceNDArray(
shape=(x.nbytes,), strides=(1,), dtype=np.dtype("u1"), gpu_data=x.gpu_data
@@ -51,7 +50,6 @@ def cuda_deserialize_numba_ndarray(header, frames):
@dask_serialize.register(numba.cuda.devicearray.DeviceNDArray)
def dask_serialize_numba_ndarray(x):
header, frames = cuda_serialize_numba_ndarray(x)
header["writeable"] = (None,) * len(frames)
frames = [memoryview(f.copy_to_host()) for f in frames]
return header, frames

18 changes: 12 additions & 6 deletions distributed/protocol/numpy.py
@@ -4,7 +4,7 @@
from .serialize import dask_serialize, dask_deserialize
from . import pickle

from ..utils import log_errors, nbytes
from ..utils import log_errors


def itemsize(dt):
@@ -29,7 +29,6 @@ def serialize_numpy_ndarray(x, context=None):
buffer_callback=buffer_callback,
protocol=(context or {}).get("pickle-protocol", None),
)
header["lengths"] = tuple(map(nbytes, frames))
return header, frames

# We cannot blindly pickle the dtype as some may fail pickling,
@@ -93,15 +92,17 @@ def serialize_numpy_ndarray(x, context=None):
# "ValueError: cannot include dtype 'M' in a buffer"
data = data.view("u%d" % math.gcd(x.dtype.itemsize, 8)).data

header = {"dtype": dt, "shape": x.shape, "strides": strides}
header = {
"dtype": dt,
"shape": x.shape,
"strides": strides,
"writeable": [x.flags.writeable],
}

if broadcast_to is not None:
header["broadcast_to"] = broadcast_to

frames = [data]

header["lengths"] = [x.nbytes]

return header, frames


@@ -112,6 +113,7 @@ def deserialize_numpy_ndarray(header, frames):
return pickle.loads(frames[0], buffers=frames[1:])

(frame,) = frames
(writeable,) = header["writeable"]

is_custom, dt = header["dtype"]
if is_custom:
@@ -125,6 +127,10 @@
shape = header["shape"]

x = np.ndarray(shape, dtype=dt, buffer=frame, strides=header["strides"])
if not writeable:
x.flags.writeable = False
else:
x = np.require(x, requirements=["W"])

return x
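
The new `writeable` entry lets the deserializer restore the original array's mutability instead of guessing from the received buffer. A small sketch, not part of the diff, of what this buys:

```python
import numpy as np
from distributed.protocol import deserialize, serialize

x = np.arange(5)
x.flags.writeable = False          # make the source array read-only

header, frames = serialize(x)      # header["writeable"] == [False] per the new code
y = deserialize(header, frames)
assert not y.flags.writeable       # read-only-ness survives the round trip
```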

2 changes: 0 additions & 2 deletions distributed/protocol/rmm.py
@@ -13,7 +13,6 @@
def cuda_serialize_rmm_device_buffer(x):
header = x.__cuda_array_interface__.copy()
header["strides"] = (1,)
header["lengths"] = [x.nbytes]
frames = [x]
return header, frames

Expand All @@ -31,7 +30,6 @@ def cuda_deserialize_rmm_device_buffer(header, frames):
@dask_serialize.register(rmm.DeviceBuffer)
def dask_serialize_rmm_device_buffer(x):
header, frames = cuda_serialize_rmm_device_buffer(x)
header["writeable"] = (None,) * len(frames)
Review comment (Member):
Just wanted to note that None has a special meaning here. It basically means it doesn't matter whether this is read-only or writeable. IOW skip trying to copy this. The reason we include this (and in particular on the Dask serialization path) is to avoid an extra copy of buffers we plan to move to device later

That said, I think the changes here may already capture this use case. Just wanted to surface the logic to hopefully clarify what is going on currently and catch any remaining things not yet addressed

frames = [numba.cuda.as_cuda_array(f).copy_to_host().data for f in frames]
return header, frames

94 changes: 82 additions & 12 deletions distributed/protocol/serialize.py
@@ -10,13 +10,12 @@
import msgpack

from . import pickle
from ..utils import has_keyword, nbytes, typename, ensure_bytes, is_writeable
from ..utils import has_keyword, typename, ensure_bytes
from .compression import maybe_compress, decompress
from .utils import (
unpack_frames,
pack_frames_prelude,
frame_split_size,
merge_frames,
msgpack_opts,
)

@@ -30,7 +29,7 @@


def dask_dumps(x, context=None):
"""Serialise object using the class-based registry"""
"""Serialize object using the class-based registry"""
type_name = typename(type(x))
try:
dumps = dask_serialize.dispatch(type(x))
Expand All @@ -54,19 +53,30 @@ def dask_loads(header, frames):


def pickle_dumps(x, context=None):
header = {"serializer": "pickle"}
frames = [None]
buffer_callback = lambda f: frames.append(memoryview(f))
frames[0] = pickle.dumps(
x,
buffer_callback=buffer_callback,
protocol=context.get("pickle-protocol", None) if context else None,
)
header = {
"serializer": "pickle",
"writeable": tuple(not f.readonly for f in frames[1:]),
}
return header, frames


def pickle_loads(header, frames):
x, buffers = frames[0], frames[1:]
writeable = header["writeable"]
for i in range(len(buffers)):
readonly = memoryview(buffers[i]).readonly
if writeable[i]:
if readonly:
buffers[i] = bytearray(buffers[i])
elif not readonly:
buffers[i] = bytes(buffers[i])
return pickle.loads(x, buffers=buffers)
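
Pickle protocol-5 out-of-band buffers arrive with whatever mutability the transport gives them, so `pickle_loads` coerces each buffer back to the mutability recorded at dump time. A usage sketch, not part of the diff, assuming a contiguous NumPy array pickles into a single out-of-band buffer:

```python
import numpy as np
from distributed.protocol.serialize import pickle_dumps, pickle_loads

x = np.arange(10)
header, frames = pickle_dumps(x)
# frames[0] is the pickle payload; frames[1:] are protocol-5 out-of-band buffers
assert header["writeable"] == (True,)    # the one buffer came from a writable array

y = pickle_loads(header, frames)
assert y.flags.writeable                 # mutability restored on the receiving side
```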


@@ -374,6 +384,72 @@ def deserialize(header, frames, deserializers=None):
return loads(header, frames)


def serialize_and_split(x, serializers=None, on_error="message", context=None):
"""Serialize and split compressable frames

This function is a drop-in replacement for `serialize()` that calls `serialize()`
followed by `frame_split_size()` on frames that should be compressed.

Use `merge_and_deserialize()` to merge and deserialize the frames back.

See Also
--------
serialize
merge_and_deserialize
"""
header, frames = serialize(x, serializers, on_error, context)
num_sub_frames = []
offsets = []
out_frames = []
out_compression = []
for frame, compression in zip(
frames, header.get("compression") or [None] * len(frames)
):
if compression is None: # default behavior
sub_frames = frame_split_size(frame)
num_sub_frames.append(len(sub_frames))
offsets.append(len(out_frames))
out_frames.extend(sub_frames)
out_compression.extend([None] * len(sub_frames))
else:
num_sub_frames.append(1)
offsets.append(len(out_frames))
out_frames.append(frame)
out_compression.append(compression)
assert len(out_compression) == len(out_frames)

# Note: in order to match msgpack's implicit conversion to tuples,
# we convert to tuples here as well.
header["split-num-sub-frames"] = tuple(num_sub_frames)
header["split-offsets"] = tuple(offsets)
header["compression"] = tuple(out_compression)
return header, out_frames
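
A usage sketch, not part of the diff, showing the bookkeeping fields the function adds to the header (the exact number of sub-frames depends on the configured shard size):

```python
import numpy as np
from distributed.protocol.serialize import serialize_and_split

x = np.ones(200_000_000, dtype="u1")   # ~200 MB, well above the default shard size
header, frames = serialize_and_split(x)

# Each original frame is recorded with the number of sub-frames it was cut
# into and the offset of its first sub-frame in the flat `frames` list, so
# merge_and_deserialize can undo the split.
assert len(frames) == sum(header["split-num-sub-frames"])
assert len(header["split-offsets"]) == len(header["split-num-sub-frames"])
assert len(frames) > 1                 # the single 200 MB frame was split
```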


def merge_and_deserialize(header, frames, deserializers=None):
"""Merge and deserialize frames

This function is a drop-in replacement for `deserialize()` that merges
frames that were split by `serialize_and_split()`.

See Also
--------
deserialize
serialize_and_split
"""
merged_frames = []
if "split-num-sub-frames" not in header:
merged_frames = frames
else:
for n, offset in zip(header["split-num-sub-frames"], header["split-offsets"]):
if n == 1:
merged_frames.append(frames[offset])
else:
merged_frames.append(bytearray().join(frames[offset : offset + n]))

return deserialize(header, merged_frames, deserializers=deserializers)
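
And the corresponding round trip, again only an illustrative sketch rather than part of the diff:

```python
import numpy as np
from distributed.protocol.serialize import merge_and_deserialize, serialize_and_split

x = np.random.random(20_000_000)       # ~160 MB of float64, so it will be split
header, frames = serialize_and_split(x)
assert len(frames) > 1                 # split into several sub-frames

y = merge_and_deserialize(header, frames)
assert (x == y).all()                  # sub-frames merged back before deserializing
```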


class Serialize:
"""Mark an object that should be serialized

@@ -534,13 +610,8 @@ def replace_inner(x):


def serialize_bytelist(x, **kwargs):
header, frames = serialize(x, **kwargs)
if "writeable" not in header:
header["writeable"] = tuple(map(is_writeable, frames))
if "lengths" not in header:
header["lengths"] = tuple(map(nbytes, frames))
header, frames = serialize_and_split(x, **kwargs)
if frames:
frames = sum(map(frame_split_size, frames), [])
compression, frames = zip(*map(maybe_compress, frames))
else:
compression = []
@@ -566,8 +637,7 @@ def deserialize_bytes(b):
else:
header = {}
frames = decompress(header, frames)
frames = merge_frames(header, frames)
return deserialize(header, frames)
return merge_and_deserialize(header, frames)
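
For the byte-oriented path (used e.g. when spilling to disk), the same split/merge pair now sits behind `serialize_bytelist`/`deserialize_bytes`. A sketch, not part of the diff, using the companion `serialize_bytes` helper, which joins the frames produced by `serialize_bytelist` into a single bytes object:

```python
import numpy as np
from distributed.protocol.serialize import deserialize_bytes, serialize_bytes

x = np.arange(1_000_000)
blob = serialize_bytes(x)     # serialize_and_split, compress, pack into one bytes object
y = deserialize_bytes(blob)   # unpack, decompress, merge_and_deserialize
assert (x == y).all()
```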


################################
6 changes: 0 additions & 6 deletions distributed/protocol/tests/test_numpy.py
@@ -278,12 +278,6 @@ def test_compression_takes_advantage_of_itemsize():
assert sum(map(nbytes, aa)) < sum(map(nbytes, bb))


def test_large_numpy_array():
x = np.ones((100000000,), dtype="u4")
header, frames = serialize(x)
assert sum(header["lengths"]) == sum(map(nbytes, frames))


@pytest.mark.parametrize(
"x",
[