Commit 9c648b3
WIP: Zero copy numpy shuffle
crusaderky committed Oct 21, 2023
1 parent 095a09d commit 9c648b3
Showing 12 changed files with 200 additions and 99 deletions.
1 change: 1 addition & 0 deletions distributed/protocol/__init__.py
@@ -12,6 +12,7 @@
dask_serialize,
deserialize,
deserialize_bytes,
deserialize_bytestream,
nested_deserialize,
register_generic,
register_serialization,
32 changes: 22 additions & 10 deletions distributed/protocol/serialize.py
@@ -4,6 +4,7 @@
import importlib
import traceback
from array import array
from collections.abc import Iterator
from enum import Enum
from functools import partial
from types import ModuleType
@@ -680,20 +681,31 @@ def serialize_bytelist(
return frames2


def serialize_bytes(x, **kwargs):
def serialize_bytes(x: object, **kwargs: Any) -> bytes:
L = serialize_bytelist(x, **kwargs)
return b"".join(L)


def deserialize_bytes(b):
frames = unpack_frames(b)
header, frames = frames[0], frames[1:]
if header:
header = msgpack.loads(header, raw=False, use_list=False)
else:
header = {}
frames = decompress(header, frames)
return merge_and_deserialize(header, frames)
def deserialize_bytestream(b: bytes | bytearray | memoryview) -> Iterator[Any]:
"""Deserialize the concatenated output of multiple calls to :func:`serialize_bytes`"""
while True:
frames = unpack_frames(b, remainder=True)
bin_header, frames, remainder = frames[0], frames[1:-1], frames[-1]
if bin_header:
header = msgpack.loads(bin_header, raw=False, use_list=False)
else:
header = {}
frames2 = decompress(header, frames)
yield merge_and_deserialize(header, frames2)

if remainder.nbytes == 0:
break
b = remainder


def deserialize_bytes(b: bytes | bytearray | memoryview) -> Any:
"""Deserialize the output of a single call to :func:`serialize_bytes`"""
return next(deserialize_bytestream(b))


################################
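
The new deserialize_bytestream generator is the heart of this change: serialize_bytes already emits a self-delimiting blob (a frame-length prelude, then a msgpack header, then the payload frames), so several blobs can simply be concatenated and walked one object at a time, and deserialize_bytes becomes a single next() on that generator. A minimal usage sketch, mirroring test_deserialize_bytestream added below:

    from distributed.protocol import serialize_bytes, deserialize_bytestream

    # Each serialize_bytes call produces one self-delimiting blob
    blob = b"".join(serialize_bytes(obj) for obj in [1, "abc", b"abc"])

    # deserialize_bytestream walks the concatenation and yields each object back
    assert list(deserialize_bytestream(blob)) == [1, "abc", b"abc"]
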
15 changes: 14 additions & 1 deletion distributed/protocol/tests/test_protocol_utils.py
@@ -10,8 +10,21 @@ def test_pack_frames():
b = pack_frames(frames)
assert isinstance(b, bytes)
frames2 = unpack_frames(b)
assert frames2 == frames

assert frames == frames2

@pytest.mark.parametrize("extra", [b"456", b""])
def test_unpack_frames_remainder(extra):
frames = [b"123", b"asdf"]
b = pack_frames(frames)
assert isinstance(b, bytes)

frames2 = unpack_frames(b + extra)
assert frames2 == frames

frames2 = unpack_frames(b + extra, remainder=True)
assert isinstance(frames2[-1], memoryview)
assert frames2 == frames + [extra]


class TestMergeMemoryviews:
56 changes: 53 additions & 3 deletions distributed/protocol/tests/test_serialize.py
@@ -23,6 +23,7 @@
dask_serialize,
deserialize,
deserialize_bytes,
deserialize_bytestream,
dumps,
loads,
nested_deserialize,
@@ -265,21 +266,70 @@ def test_empty_loads_deep():
assert isinstance(e2[0][0][0], Empty)


@pytest.mark.skipif(np is None, reason="Test needs numpy")
@pytest.mark.parametrize("kwargs", [{}, {"serializers": ["pickle"]}])
def test_serialize_bytes(kwargs):
for x in [
1,
"abc",
np.arange(5),
b"ab" * int(40e6),
int(2**26) * b"ab",
(int(2**25) * b"ab", int(2**25) * b"ab"),
]:
b = serialize_bytes(x, **kwargs)
assert isinstance(b, bytes)
y = deserialize_bytes(b)
assert str(x) == str(y)
assert x == y


@pytest.mark.skipif(np is None, reason="Test needs numpy")
@pytest.mark.parametrize("kwargs", [{}, {"serializers": ["pickle"]}])
def test_serialize_bytes_numpy(kwargs):
x = np.arange(5)
b = serialize_bytes(x, **kwargs)
assert isinstance(b, bytes)
y = deserialize_bytes(b)
assert (x == y).all()


@pytest.mark.skipif(np is None, reason="Test needs numpy")
def test_deserialize_bytes_zero_copy_read_only():
x = np.arange(5)
x.setflags(write=False)
blob = serialize_bytes(x, compression=False)
x2 = deserialize_bytes(blob)
x3 = deserialize_bytes(blob)
addr2 = x2.__array_interface__["data"][0]
addr3 = x3.__array_interface__["data"][0]
assert addr2 == addr3


@pytest.mark.skipif(np is None, reason="Test needs numpy")
def test_deserialize_bytes_zero_copy_writeable():
x = np.arange(5)
blob = bytearray(serialize_bytes(x, compression=False))
x2 = deserialize_bytes(blob)
x3 = deserialize_bytes(blob)
x2[0] = 123
assert x3[0] == 123


@pytest.mark.parametrize("kwargs", [{}, {"serializers": ["pickle"]}])
def test_deserialize_bytestream(kwargs):
objs = [1, "abc", b"abc"]
blob = b"".join(serialize_bytes(obj, **kwargs) for obj in objs)
objs2 = list(deserialize_bytestream(blob))
assert objs == objs2


@pytest.mark.skipif(np is None, reason="Test needs numpy")
@pytest.mark.parametrize("kwargs", [{}, {"serializers": ["pickle"]}])
def test_deserialize_bytestream_numpy(kwargs):
x = np.arange(5)
y = np.arange(3, 8)
blob = serialize_bytes(x, **kwargs) + serialize_bytes(y, **kwargs)
x2, y2 = deserialize_bytestream(blob)
assert (x2 == x).all()
assert (y2 == y).all()


@pytest.mark.skipif(np is None, reason="Test needs numpy")
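
The two zero-copy tests above pin down the behaviour this commit is after: with compression disabled, the deserialized numpy arrays are views over the serialized buffer rather than copies, so deserializing the same blob twice yields arrays backed by the same memory, and writes are shared when the blob is a mutable bytearray. A condensed sketch combining both tests:

    import numpy as np
    from distributed.protocol import deserialize_bytes, serialize_bytes

    x = np.arange(5)
    blob = bytearray(serialize_bytes(x, compression=False))

    x2 = deserialize_bytes(blob)
    x3 = deserialize_bytes(blob)

    # Both results point at the same underlying memory (no copy was made)
    assert x2.__array_interface__["data"][0] == x3.__array_interface__["data"][0]
    x2[0] = 123
    assert x3[0] == 123
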
16 changes: 15 additions & 1 deletion distributed/protocol/utils.py
@@ -61,12 +61,23 @@ def pack_frames(frames):
return b"".join([pack_frames_prelude(frames), *frames])


def unpack_frames(b):
def unpack_frames(
b: bytes | bytearray | memoryview, *, remainder: bool = False
) -> list[memoryview]:
"""Unpack bytes into a sequence of frames

This assumes that length information is at the front of the bytestring,
as performed by pack_frames

Parameters
----------
b:
packed frames, as returned by :func:`pack_frames`
remainder:
if True, return one extra frame at the end which is the continuation of a
stream created by concatenating multiple calls to :func:`pack_frames`.
This last frame will be empty at the end of the stream.

See Also
--------
pack_frames
@@ -86,6 +97,9 @@ def unpack_frames(b):
frames.append(b[start:end])
start = end

if remainder:
frames.append(b[start:])

return frames
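
The new remainder flag is what lets a reader consume a stream built by concatenating several pack_frames blobs: each call parses one blob and hands back the untouched tail as a trailing memoryview, which is how deserialize_bytestream loops above. A small sketch of that pattern (values are illustrative):

    from distributed.protocol.utils import pack_frames, unpack_frames

    stream = pack_frames([b"123", b"asdf"]) + pack_frames([b"xyz"])

    *frames, rest = unpack_frames(stream, remainder=True)
    assert [bytes(f) for f in frames] == [b"123", b"asdf"]

    # The trailing memoryview is whatever follows the first blob
    *frames, rest = unpack_frames(rest, remainder=True)
    assert [bytes(f) for f in frames] == [b"xyz"]
    assert rest.nbytes == 0  # end of stream
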


5 changes: 4 additions & 1 deletion distributed/shuffle/_buffer.py
@@ -45,6 +45,7 @@ class ShardsBuffer(Generic[ShardType]):

shards: defaultdict[str, _List[ShardType]]
sizes: defaultdict[str, int]
size_per_shard: dict[int, int]
concurrency_limit: int
memory_limiter: ResourceLimiter
diagnostics: dict[str, float]
@@ -71,6 +72,7 @@ def __init__(
self._accepts_input = True
self.shards = defaultdict(_List)
self.sizes = defaultdict(int)
self.size_per_shard = {}
self._exception = None
self.concurrency_limit = concurrency_limit
self._inputs_done = False
@@ -149,7 +151,7 @@ def _continue() -> bool:
try:
shard = self.shards[part_id].pop()
shards.append(shard)
s = sizeof(shard)
s = self.size_per_shard.pop(id(shard))
size += s
self.sizes[part_id] -= s
except IndexError:
@@ -201,6 +203,7 @@ async def write(self, data: dict[str, ShardType]) -> None:
async with self._shards_available:
for worker, shard in data.items():
self.shards[worker].append(shard)
self.size_per_shard[id(shard)] = sizes[worker]
self.sizes[worker] += sizes[worker]
self._shards_available.notify()
await self.memory_limiter.wait_for_available()
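
The sizeof(shard) call in the flush path is replaced by a lookup keyed on id(shard): write() now records each shard's size in size_per_shard and the flush pops it back out, so a shard is measured at most once. Presumably this matters because shards are no longer pre-serialized bytes but may be unserialized numpy shards, where re-running sizeof during the flush would be wasteful. A toy illustration of the bookkeeping pattern only (hypothetical helpers, not the actual ShardsBuffer API):

    # Hypothetical, stripped-down version of the size bookkeeping
    shards: list[object] = []
    size_per_shard: dict[int, int] = {}

    def remember(shard: object, nbytes: int) -> None:
        shards.append(shard)
        size_per_shard[id(shard)] = nbytes  # size measured once, at write time

    def flush_one() -> tuple[object, int]:
        shard = shards.pop()
        return shard, size_per_shard.pop(id(shard))  # reuse the recorded size
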
4 changes: 2 additions & 2 deletions distributed/shuffle/_comms.py
@@ -52,7 +52,7 @@ class CommShardsBuffer(ShardsBuffer):

def __init__(
self,
send: Callable[[str, list[tuple[Any, bytes]]], Awaitable[None]],
send: Callable[[str, list[tuple[Any, Any]]], Awaitable[None]],
memory_limiter: ResourceLimiter,
concurrency_limit: int = 10,
):
@@ -63,7 +63,7 @@ def __init__(
)
self.send = send

async def _process(self, address: str, shards: list[tuple[Any, bytes]]) -> None:
async def _process(self, address: str, shards: list[tuple[Any, Any]]) -> None:
"""Send one message off to a neighboring worker"""
with log_errors():
# Consider boosting total_size a bit here to account for duplication
10 changes: 5 additions & 5 deletions distributed/shuffle/_core.py
@@ -175,12 +175,12 @@ def heartbeat(self) -> dict[str, Any]:
}

async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, bytes]]
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)

async def _write_to_disk(self, data: dict[NDIndex, bytes]) -> None:
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
@@ -228,7 +228,7 @@ def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))

async def receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None:
async def receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
await self._receive(data)

async def _ensure_output_worker(self, i: _T_partition_id, key: str) -> None:
@@ -248,7 +248,7 @@ def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""

@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None:
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""

async def add_partition(
@@ -286,7 +286,7 @@ def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""

@abc.abstractmethod
def deserialize(self, buffer: bytes) -> Any:
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""


29 changes: 21 additions & 8 deletions distributed/shuffle/_disk.py
@@ -4,10 +4,13 @@
import pathlib
import shutil
import threading
from collections.abc import Generator
from collections.abc import Callable, Generator, Iterable
from contextlib import contextmanager
from typing import Any, Callable
from typing import Any

from toolz import concat

from distributed.protocol import serialize_bytelist
from distributed.shuffle._buffer import ShardsBuffer
from distributed.shuffle._limiter import ResourceLimiter
from distributed.utils import Deadline, log_errors
@@ -135,7 +138,7 @@ def __init__(
self._read = read
self._directory_lock = ReadWriteLock()

async def _process(self, id: str, shards: list[bytes]) -> None:
async def _process(self, id: str, shards: list[Any]) -> None:
"""Write one buffer to file
This function was built to offload the disk IO, but since then we've
@@ -157,11 +160,21 @@ async def _process(self, id: str, shards: list[bytes]) -> None:
with self._directory_lock.read():
if self._closed:
raise RuntimeError("Already closed")
with open(
self.directory / str(id), mode="ab", buffering=100_000_000
) as f:
for shard in shards:
f.write(shard)

frames: Iterable[bytes | bytearray | memoryview]

if not shards or isinstance(shards[0], bytes):
# Manually serialized dataframes
frames = shards
else:
# Unserialized numpy arrays
frames = concat(
serialize_bytelist(shard, compression=False)
for shard in shards
)

with open(self.directory / str(id), mode="ab") as f:
f.writelines(frames)
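
The disk write path now handles two kinds of shards: dataframe shards still arrive as manually serialized bytes and are written verbatim, while numpy shards arrive unserialized and are framed on the fly with serialize_bytelist(..., compression=False), so f.writelines() can hand the raw frame buffers to the OS without joining them first. Since a file written this way is just a concatenation of serialize_bytes-style blobs, reading it back presumably goes through deserialize_bytestream. A rough sketch of that round trip (the file path and shard values are illustrative only):

    import numpy as np
    from distributed.protocol import deserialize_bytestream, serialize_bytelist

    shards = [np.arange(5), np.arange(3, 8)]

    # Write: frame each shard without compression; writelines avoids a join
    with open("/tmp/shards.bin", "ab") as f:  # append, as the buffer does across flushes
        for shard in shards:
            f.writelines(serialize_bytelist(shard, compression=False))

    # Read: the file is a stream of concatenated blobs
    with open("/tmp/shards.bin", "rb") as f:
        for shard in deserialize_bytestream(f.read()):
            print(shard)
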

def read(self, id: str) -> Any:
"""Read a complete file back into memory"""
12 changes: 5 additions & 7 deletions distributed/shuffle/_memory.py
@@ -11,17 +11,15 @@


class MemoryShardsBuffer(ShardsBuffer):
_deserialize: Callable[[bytes], Any]
_shards: defaultdict[str, deque[bytes]]
_deserialize: Callable[[Any], Any]
_shards: defaultdict[str, deque[Any]]

def __init__(self, deserialize: Callable[[bytes], Any]) -> None:
super().__init__(
memory_limiter=ResourceLimiter(None),
)
def __init__(self, deserialize: Callable[[Any], Any]) -> None:
super().__init__(memory_limiter=ResourceLimiter(None))
self._deserialize = deserialize
self._shards = defaultdict(deque)

async def _process(self, id: str, shards: list[bytes]) -> None:
async def _process(self, id: str, shards: list[Any]) -> None:
# TODO: This can be greatly simplified, there's no need for
# background threads at all.
with log_errors():
(remaining changed files not loaded in this view)