From e322258d966aea55216bdef692d0f5b4824d6acc Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 2 Nov 2023 21:30:13 +0100 Subject: [PATCH 1/2] Zero copy numpy shuffle --- distributed/shuffle/_buffer.py | 7 +- distributed/shuffle/_comms.py | 4 +- distributed/shuffle/_core.py | 16 ++-- distributed/shuffle/_disk.py | 26 +++++-- distributed/shuffle/_memory.py | 12 ++- distributed/shuffle/_pickle.py | 42 ++++++++++ distributed/shuffle/_rechunk.py | 94 +++++++++++++---------- distributed/shuffle/_shuffle.py | 2 +- distributed/shuffle/_worker_plugin.py | 2 +- distributed/shuffle/tests/test_pickle.py | 34 ++++++++ distributed/shuffle/tests/test_rechunk.py | 50 ++++++++++++ 11 files changed, 222 insertions(+), 67 deletions(-) create mode 100644 distributed/shuffle/_pickle.py create mode 100644 distributed/shuffle/tests/test_pickle.py diff --git a/distributed/shuffle/_buffer.py b/distributed/shuffle/_buffer.py index b0d24ace02..bfacac2790 100644 --- a/distributed/shuffle/_buffer.py +++ b/distributed/shuffle/_buffer.py @@ -45,6 +45,7 @@ class ShardsBuffer(Generic[ShardType]): shards: defaultdict[str, _List[ShardType]] sizes: defaultdict[str, int] + sizes_detail: defaultdict[str, list[int]] concurrency_limit: int memory_limiter: ResourceLimiter diagnostics: dict[str, float] @@ -71,6 +72,7 @@ def __init__( self._accepts_input = True self.shards = defaultdict(_List) self.sizes = defaultdict(int) + self.sizes_detail = defaultdict(list) self._exception = None self.concurrency_limit = concurrency_limit self._inputs_done = False @@ -149,7 +151,7 @@ def _continue() -> bool: try: shard = self.shards[part_id].pop() shards.append(shard) - s = sizeof(shard) + s = self.sizes_detail[part_id].pop() size += s self.sizes[part_id] -= s except IndexError: @@ -159,6 +161,8 @@ def _continue() -> bool: del self.shards[part_id] assert not self.sizes[part_id] del self.sizes[part_id] + assert not self.sizes_detail[part_id] + del self.sizes_detail[part_id] else: shards = self.shards.pop(part_id) size = self.sizes.pop(part_id) @@ -201,6 +205,7 @@ async def write(self, data: dict[str, ShardType]) -> None: async with self._shards_available: for worker, shard in data.items(): self.shards[worker].append(shard) + self.sizes_detail[worker].append(sizes[worker]) self.sizes[worker] += sizes[worker] self._shards_available.notify() await self.memory_limiter.wait_for_available() diff --git a/distributed/shuffle/_comms.py b/distributed/shuffle/_comms.py index 020313debe..6dac1a3d30 100644 --- a/distributed/shuffle/_comms.py +++ b/distributed/shuffle/_comms.py @@ -52,7 +52,7 @@ class CommShardsBuffer(ShardsBuffer): def __init__( self, - send: Callable[[str, list[tuple[Any, bytes]]], Awaitable[None]], + send: Callable[[str, list[tuple[Any, Any]]], Awaitable[None]], memory_limiter: ResourceLimiter, concurrency_limit: int = 10, ): @@ -63,7 +63,7 @@ def __init__( ) self.send = send - async def _process(self, address: str, shards: list[tuple[Any, bytes]]) -> None: + async def _process(self, address: str, shards: list[tuple[Any, Any]]) -> None: """Send one message off to a neighboring worker""" with log_errors(): # Consider boosting total_size a bit here to account for duplication diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index c79033da09..7e15a1a570 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -140,7 +140,7 @@ async def barrier(self, run_ids: Sequence[int]) -> int: return self.run_id async def _send( - self, address: str, shards: list[tuple[_T_partition_id, bytes]] + self, 
address: str, shards: list[tuple[_T_partition_id, Any]] ) -> None: self.raise_if_closed() return await self.rpc(address).shuffle_receive( @@ -150,7 +150,7 @@ async def _send( ) async def send( - self, address: str, shards: list[tuple[_T_partition_id, bytes]] + self, address: str, shards: list[tuple[_T_partition_id, Any]] ) -> None: retry_count = dask.config.get("distributed.p2p.comm.retry.count") retry_delay_min = parse_timedelta( @@ -186,12 +186,12 @@ def heartbeat(self) -> dict[str, Any]: } async def _write_to_comm( - self, data: dict[str, tuple[_T_partition_id, bytes]] + self, data: dict[str, tuple[_T_partition_id, Any]] ) -> None: self.raise_if_closed() await self._comm_buffer.write(data) - async def _write_to_disk(self, data: dict[NDIndex, bytes]) -> None: + async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None: self.raise_if_closed() await self._disk_buffer.write( {"_".join(str(i) for i in k): v for k, v in data.items()} @@ -239,7 +239,7 @@ def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing self.raise_if_closed() return self._disk_buffer.read("_".join(str(i) for i in id)) - async def receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None: + async def receive(self, data: list[tuple[_T_partition_id, Any]]) -> None: await self._receive(data) async def _ensure_output_worker(self, i: _T_partition_id, key: str) -> None: @@ -259,7 +259,7 @@ def _get_assigned_worker(self, i: _T_partition_id) -> str: """Get the address of the worker assigned to the output partition""" @abc.abstractmethod - async def _receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None: + async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None: """Receive shards belonging to output partitions of this shuffle run""" def add_partition( @@ -275,7 +275,7 @@ def add_partition( @abc.abstractmethod def _shard_partition( self, data: _T_partition_type, partition_id: _T_partition_id - ) -> dict[str, tuple[_T_partition_id, bytes]]: + ) -> dict[str, tuple[_T_partition_id, Any]]: """Shard an input partition by the assigned output workers""" def get_output_partition( @@ -299,7 +299,7 @@ def read(self, path: Path) -> tuple[Any, int]: """Read shards from disk""" @abc.abstractmethod - def deserialize(self, buffer: bytes) -> Any: + def deserialize(self, buffer: Any) -> Any: """Deserialize shards""" diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py index 87fea2cb99..6d44f58742 100644 --- a/distributed/shuffle/_disk.py +++ b/distributed/shuffle/_disk.py @@ -4,12 +4,15 @@ import pathlib import shutil import threading -from collections.abc import Generator +from collections.abc import Callable, Generator, Iterable from contextlib import contextmanager -from typing import Any, Callable +from typing import Any + +from toolz import concat from distributed.shuffle._buffer import ShardsBuffer from distributed.shuffle._limiter import ResourceLimiter +from distributed.shuffle._pickle import pickle_bytelist from distributed.utils import Deadline, log_errors @@ -135,7 +138,7 @@ def __init__( self._read = read self._directory_lock = ReadWriteLock() - async def _process(self, id: str, shards: list[bytes]) -> None: + async def _process(self, id: str, shards: list[Any]) -> None: """Write one buffer to file This function was built to offload the disk IO, but since then we've @@ -157,11 +160,18 @@ async def _process(self, id: str, shards: list[bytes]) -> None: with self._directory_lock.read(): if self._closed: raise RuntimeError("Already closed") - with open( - 
self.directory / str(id), mode="ab", buffering=100_000_000 - ) as f: - for shard in shards: - f.write(shard) + + frames: Iterable[bytes | bytearray | memoryview] + + if not shards or isinstance(shards[0], bytes): + # Manually serialized dataframes + frames = shards + else: + # Unserialized numpy arrays + frames = concat(pickle_bytelist(shard) for shard in shards) + + with open(self.directory / str(id), mode="ab") as f: + f.writelines(frames) def read(self, id: str) -> Any: """Read a complete file back into memory""" diff --git a/distributed/shuffle/_memory.py b/distributed/shuffle/_memory.py index 27c00dac90..ac523b8fe4 100644 --- a/distributed/shuffle/_memory.py +++ b/distributed/shuffle/_memory.py @@ -11,17 +11,15 @@ class MemoryShardsBuffer(ShardsBuffer): - _deserialize: Callable[[bytes], Any] - _shards: defaultdict[str, deque[bytes]] + _deserialize: Callable[[Any], Any] + _shards: defaultdict[str, deque[Any]] - def __init__(self, deserialize: Callable[[bytes], Any]) -> None: - super().__init__( - memory_limiter=ResourceLimiter(None), - ) + def __init__(self, deserialize: Callable[[Any], Any]) -> None: + super().__init__(memory_limiter=ResourceLimiter(None)) self._deserialize = deserialize self._shards = defaultdict(deque) - async def _process(self, id: str, shards: list[bytes]) -> None: + async def _process(self, id: str, shards: list[Any]) -> None: # TODO: This can be greatly simplified, there's no need for # background threads at all. with log_errors(): diff --git a/distributed/shuffle/_pickle.py b/distributed/shuffle/_pickle.py new file mode 100644 index 0000000000..5e0a76425e --- /dev/null +++ b/distributed/shuffle/_pickle.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import pickle +from collections.abc import Iterator +from typing import Any + +from distributed.protocol.utils import pack_frames_prelude, unpack_frames + + +def pickle_bytelist(obj: object) -> list[bytes | memoryview]: + """Variant of :func:`serialize_bytelist`, that doesn't support compression, locally + defined classes, or any of its other fancy features but runs 10x faster for numpy + arrays + + See Also + -------- + serialize_bytelist + unpickle_bytestream + """ + frames: list = [] + pik = pickle.dumps( + obj, protocol=5, buffer_callback=lambda pb: frames.append(pb.raw()) + ) + frames.insert(0, pik) + frames.insert(0, pack_frames_prelude(frames)) + return frames + + +def unpickle_bytestream(b: bytes | bytearray | memoryview) -> Iterator[Any]: + """Unpickle the concatenated output of multiple calls to :func:`pickle_bytelist` + + See Also + -------- + pickle_bytelist + deserialize_bytes + """ + while True: + pik, *buffers, remainder = unpack_frames(b, remainder=True) + yield pickle.loads(pik, buffers=buffers) + if remainder.nbytes == 0: + break + b = remainder diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 96aec17ead..1f47eb3c5e 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -96,8 +96,8 @@ from __future__ import annotations +import mmap import os -import pickle from collections import defaultdict from collections.abc import Callable, Sequence from concurrent.futures import ThreadPoolExecutor @@ -123,6 +123,7 @@ handle_unpack_errors, ) from distributed.shuffle._limiter import ResourceLimiter +from distributed.shuffle._pickle import unpickle_bytestream from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin from distributed.shuffle._shuffle import barrier_key, shuffle_barrier from 
distributed.shuffle._worker_plugin import ShuffleWorkerPlugin @@ -281,6 +282,9 @@ def convert_chunk(shards: list[list[tuple[NDIndex, np.ndarray]]]) -> np.ndarray: for index, shard in indexed.items(): rec_cat_arg[tuple(index)] = shard arrs = rec_cat_arg.tolist() + + # This may block for several seconds, as it physically reads the memory-mapped + # buffers from disk return concatenate3(arrs) @@ -364,72 +368,84 @@ def __init__( self.worker_for = worker_for self.split_axes = split_axes(old, new) - async def _receive(self, data: list[tuple[NDIndex, bytes]]) -> None: + async def _receive( + self, + data: list[tuple[NDIndex, list[tuple[NDIndex, tuple[NDIndex, np.ndarray]]]]], + ) -> None: self.raise_if_closed() - filtered = [] + # Repartition shards and filter out already received ones + shards = defaultdict(list) for d in data: - id, payload = d - if id in self.received: + id1, payload = d + if id1 in self.received: continue - filtered.append(payload) - self.received.add(id) + self.received.add(id1) + for id2, shard in payload: + shards[id2].append(shard) self.total_recvd += sizeof(d) del data - if not filtered: + if not shards: return + try: - shards = await self.offload(self._repartition_shards, filtered) - del filtered await self._write_to_disk(shards) except Exception as e: self._exception = e raise - def _repartition_shards(self, data: list[bytes]) -> dict[NDIndex, bytes]: - repartitioned: defaultdict[ - NDIndex, list[tuple[NDIndex, np.ndarray]] - ] = defaultdict(list) - for buffer in data: - for id, shard in pickle.loads(buffer): - repartitioned[id].append(shard) - return {k: pickle.dumps(v) for k, v in repartitioned.items()} - def _shard_partition( - self, data: np.ndarray, partition_id: NDIndex, **kwargs: Any - ) -> dict[str, tuple[NDIndex, bytes]]: + self, data: np.ndarray, partition_id: NDIndex + ) -> dict[str, tuple[NDIndex, Any]]: out: dict[str, list[tuple[NDIndex, tuple[NDIndex, np.ndarray]]]] = defaultdict( list ) - from itertools import product - ndsplits = product(*(axis[i] for axis, i in zip(self.split_axes, partition_id))) for ndsplit in ndsplits: chunk_index, shard_index, ndslice = zip(*ndsplit) + + shard = data[ndslice] + # Don't wait until all shards have been transferred over the network + # before data can be released + if shard.base is not None: + shard = shard.copy() + out[self.worker_for[chunk_index]].append( - (chunk_index, (shard_index, data[ndslice])) + (chunk_index, (shard_index, shard)) ) - return {k: (partition_id, pickle.dumps(v)) for k, v in out.items()} + return {k: (partition_id, v) for k, v in out.items()} def _get_output_partition( self, partition_id: NDIndex, key: str, **kwargs: Any ) -> np.ndarray: + # Quickly read metadata from disk. + # This is a bunch of seek()'s interleaved with short reads. data = self._read_from_disk(partition_id) - return convert_chunk(data) - - def deserialize(self, buffer: bytes) -> Any: - result = pickle.loads(buffer) - return result - - def read(self, path: Path) -> tuple[Any, int]: - shards: list[list[tuple[NDIndex, np.ndarray]]] = [] - with path.open(mode="rb") as f: - size = f.seek(0, os.SEEK_END) - f.seek(0) - while f.tell() < size: - shards.append(pickle.load(f)) - return shards, size + # Copy the memory-mapped buffers from disk into memory. + # This is where we'll spend most time. 
+ with self._disk_buffer.time("read"): + return convert_chunk(data) + + def deserialize(self, buffer: Any) -> Any: + return buffer + + def read(self, path: Path) -> tuple[list[list[tuple[NDIndex, np.ndarray]]], int]: + """Open a memory-mapped file descriptor to disk, read all metadata, and unpickle + all arrays. This is a fast sequence of short reads interleaved with seeks. + Do not read in memory the actual data; the arrays' buffers will point to the + memory-mapped area. + + The file descriptor will be automatically closed by the kernel when all the + returned arrays are dereferenced, which will happen after the call to + concatenate3. + """ + with path.open(mode="r+b") as fh: + buffer = memoryview(mmap.mmap(fh.fileno(), 0)) + + # The file descriptor has *not* been closed! + shards = list(unpickle_bytestream(buffer)) + return shards, buffer.nbytes def _get_assigned_worker(self, id: NDIndex) -> str: return self.worker_for[id] diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index d1e092f445..82b451538c 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -520,7 +520,7 @@ def _get_assigned_worker(self, id: int) -> str: def read(self, path: Path) -> tuple[pa.Table, int]: return read_from_disk(path) - def deserialize(self, buffer: bytes) -> Any: + def deserialize(self, buffer: Any) -> Any: return deserialize_table(buffer) diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py index f40627799f..8828058e69 100644 --- a/distributed/shuffle/_worker_plugin.py +++ b/distributed/shuffle/_worker_plugin.py @@ -299,7 +299,7 @@ async def shuffle_receive( self, shuffle_id: ShuffleId, run_id: int, - data: list[tuple[int, bytes]], + data: list[tuple[int, Any]], ) -> None: """ Handler: Receive an incoming shard of data from a peer worker. 
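Editor's note (not part of the patch): the new `_pickle.py` helpers are what make the zero-copy round trip work — `pickle_bytelist` emits a frame-length prelude plus the protocol-5 pickle with the numpy buffers passed out of band rather than copied into the pickle stream, and `unpickle_bytestream` walks a concatenation of such records. The sketch below shows the append-then-stream pattern used by `_process()` in `_disk.py` and `read()` in `_rechunk.py`, assuming numpy and toolz are available; the temporary file merely stands in for the real shuffle directory.

```python
# Sketch only: round-trip several numpy shards through one append-only file,
# mirroring how _process() writes frames with writelines() and how read()
# recovers them from a memory map.
import mmap
import tempfile

import numpy as np
from toolz import concat

from distributed.shuffle._pickle import pickle_bytelist, unpickle_bytestream

shards = [np.arange(4), np.ones((2, 3))]

with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as f:
    # One writelines() call over the flattened frames of all shards, as in _process()
    f.writelines(concat(pickle_bytelist(shard) for shard in shards))
    path = f.name

# As in read(): mmap the file, then lazily unpickle each record; the arrays
# returned by unpickle_bytestream point straight into the mapped buffer (no copy).
with open(path, "r+b") as fh:
    buf = memoryview(mmap.mmap(fh.fileno(), 0))

out = list(unpickle_bytestream(buf))
assert all((got == expect).all() for got, expect in zip(out, shards))
```

Because the mapping, not the file object, backs the arrays, dropping all references to `out` (and `buf`) is what ultimately releases the file — the behaviour the `read` docstring above relies on.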
diff --git a/distributed/shuffle/tests/test_pickle.py b/distributed/shuffle/tests/test_pickle.py new file mode 100644 index 0000000000..9378b8a4e7 --- /dev/null +++ b/distributed/shuffle/tests/test_pickle.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import pytest + +from distributed.shuffle._pickle import pickle_bytelist, unpickle_bytestream + + +def test_pickle(): + frames = pickle_bytelist("abc") + pickle_bytelist(123) + bin = b"".join(frames) + objs = list(unpickle_bytestream(bin)) + assert objs == ["abc", 123] + + +def test_pickle_numpy(): + np = pytest.importorskip("numpy") + a = np.array([1, 2, 3]) + frames = pickle_bytelist(a) + bin = b"".join(frames) + [a2] = unpickle_bytestream(bin) + assert (a2 == a).all() + + +def test_pickle_zero_copy(): + np = pytest.importorskip("numpy") + a = np.array([1, 2, 3]) + frames = pickle_bytelist(a) + a[0] = 4 # Test that pickle_bytelist does not deep copy + bin = bytearray(b"".join(frames)) # Deep-copies buffers + [a2] = unpickle_bytestream(bin) + a2[1] = 5 # Test that unpickle_bytelist does not deep copy + [a3] = unpickle_bytestream(bin) + expect = np.array([4, 5, 3]) + assert (a3 == expect).all() diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py index 976160b926..c31f122f9f 100644 --- a/distributed/shuffle/tests/test_rechunk.py +++ b/distributed/shuffle/tests/test_rechunk.py @@ -20,6 +20,8 @@ from dask.array.rechunk import normalize_chunks, rechunk from dask.array.utils import assert_eq +from distributed import Event +from distributed.protocol.utils_test import get_host_array from distributed.shuffle._core import ShuffleId from distributed.shuffle._limiter import ResourceLimiter from distributed.shuffle._rechunk import ( @@ -1145,3 +1147,51 @@ def test_split_axes_with_zero(): [[Split(0, 0, slice(0, 1, None)), Split(1, 0, slice(1, 2, None))]], ] assert result == expected + + +@gen_cluster(client=True) +async def test_preserve_writeable_flag(c, s, a, b): + """Make sure that the shuffled array doesn't accidentally become read-only after + the round-trip to e.g. read-only file descriptors or byte objects as buffers + """ + arr = da.random.random(10, chunks=5) + arr = arr.rechunk(((4, 6),), method="p2p") + arr = arr.map_blocks(lambda chunk: chunk.flags["WRITEABLE"]) + out = await c.compute(arr) + assert out.tolist() == [True, True] + + +@gen_cluster(client=True, config={"distributed.p2p.disk": False}) +async def test_rechunk_in_memory_shards_dont_share_buffer(c, s, a, b): + """Test that, if two shards are sent in the same RPC call and they contribute to + different output chunks, downstream tasks don't need to consume all output chunks in + order to release the memory of the output chunks that have already been consumed. + + This can happen if all numpy arrays in the same RPC call share the same buffer + coming out of the TCP stack. 
+ """ + in_map = Event() + block_map = Event() + + def blocked(chunk, in_map, block_map): + in_map.set() + block_map.wait() + return chunk + + # 8 MiB array, 256 kiB chunks, 8 kiB shards + arr = da.random.random((1024, 1024), chunks=(-1, 32)) + arr = arr.rechunk((32, -1), method="p2p") + + arr = arr.map_blocks(blocked, in_map=in_map, block_map=block_map, dtype=arr.dtype) + fut = c.compute(arr) + await in_map.wait() + + [run] = a.extensions["shuffle"].shuffle_runs._runs + shards = [ + s3 for s1 in run._disk_buffer._shards.values() for s2 in s1 for _, s3 in s2 + ] + assert shards + + buf_ids = {id(get_host_array(shard)) for shard in shards} + assert len(buf_ids) == len(shards) + await block_map.set() From 1b26ec6a167a4b92dea12d64db379d585fe60d5d Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 7 Nov 2023 15:22:05 +0100 Subject: [PATCH 2/2] shards should never be empty --- distributed/shuffle/_disk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py index 6d44f58742..6db447bc66 100644 --- a/distributed/shuffle/_disk.py +++ b/distributed/shuffle/_disk.py @@ -163,7 +163,7 @@ async def _process(self, id: str, shards: list[Any]) -> None: frames: Iterable[bytes | bytearray | memoryview] - if not shards or isinstance(shards[0], bytes): + if isinstance(shards[0], bytes): # Manually serialized dataframes frames = shards else: