Zero copy numpy shuffle
crusaderky committed Nov 3, 2023
1 parent c91a735 commit e322258
Showing 11 changed files with 222 additions and 67 deletions.
7 changes: 6 additions & 1 deletion distributed/shuffle/_buffer.py
@@ -45,6 +45,7 @@ class ShardsBuffer(Generic[ShardType]):

shards: defaultdict[str, _List[ShardType]]
sizes: defaultdict[str, int]
sizes_detail: defaultdict[str, list[int]]
concurrency_limit: int
memory_limiter: ResourceLimiter
diagnostics: dict[str, float]
@@ -71,6 +72,7 @@ def __init__(
self._accepts_input = True
self.shards = defaultdict(_List)
self.sizes = defaultdict(int)
self.sizes_detail = defaultdict(list)
self._exception = None
self.concurrency_limit = concurrency_limit
self._inputs_done = False
@@ -149,7 +151,7 @@ def _continue() -> bool:
try:
shard = self.shards[part_id].pop()
shards.append(shard)
s = sizeof(shard)
s = self.sizes_detail[part_id].pop()
size += s
self.sizes[part_id] -= s
except IndexError:
@@ -159,6 +161,8 @@ def _continue() -> bool:
del self.shards[part_id]
assert not self.sizes[part_id]
del self.sizes[part_id]
assert not self.sizes_detail[part_id]
del self.sizes_detail[part_id]
else:
shards = self.shards.pop(part_id)
size = self.sizes.pop(part_id)
@@ -201,6 +205,7 @@ async def write(self, data: dict[str, ShardType]) -> None:
async with self._shards_available:
for worker, shard in data.items():
self.shards[worker].append(shard)
self.sizes_detail[worker].append(sizes[worker])
self.sizes[worker] += sizes[worker]
self._shards_available.notify()
await self.memory_limiter.wait_for_available()
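For illustration only (not part of the commit): a toy sketch of the new `sizes_detail` bookkeeping shown above. Each shard's measured size is recorded at write time and popped back when the shard is flushed, so the per-partition total drains exactly to zero without re-measuring the (no longer bytes-only) shards with `sizeof`.

```python
from collections import defaultdict

# Toy sketch of the write/flush bookkeeping pattern used by ShardsBuffer.
shards: defaultdict[str, list[object]] = defaultdict(list)
sizes: defaultdict[str, int] = defaultdict(int)
sizes_detail: defaultdict[str, list[int]] = defaultdict(list)

def write(worker: str, shard: object, size: int) -> None:
    shards[worker].append(shard)
    sizes_detail[worker].append(size)  # remember the exact size measured here
    sizes[worker] += size

def pop(worker: str) -> tuple[object, int]:
    shard = shards[worker].pop()
    size = sizes_detail[worker].pop()  # pop the recorded size; never re-measure
    sizes[worker] -= size
    return shard, size
```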
4 changes: 2 additions & 2 deletions distributed/shuffle/_comms.py
@@ -52,7 +52,7 @@ class CommShardsBuffer(ShardsBuffer):

def __init__(
self,
send: Callable[[str, list[tuple[Any, bytes]]], Awaitable[None]],
send: Callable[[str, list[tuple[Any, Any]]], Awaitable[None]],
memory_limiter: ResourceLimiter,
concurrency_limit: int = 10,
):
@@ -63,7 +63,7 @@ def __init__(
)
self.send = send

async def _process(self, address: str, shards: list[tuple[Any, bytes]]) -> None:
async def _process(self, address: str, shards: list[tuple[Any, Any]]) -> None:
"""Send one message off to a neighboring worker"""
with log_errors():
# Consider boosting total_size a bit here to account for duplication
16 changes: 8 additions & 8 deletions distributed/shuffle/_core.py
@@ -140,7 +140,7 @@ async def barrier(self, run_ids: Sequence[int]) -> int:
return self.run_id

async def _send(
self, address: str, shards: list[tuple[_T_partition_id, bytes]]
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
@@ -150,7 +150,7 @@ async def _send(
)

async def send(
self, address: str, shards: list[tuple[_T_partition_id, bytes]]
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> None:
retry_count = dask.config.get("distributed.p2p.comm.retry.count")
retry_delay_min = parse_timedelta(
@@ -186,12 +186,12 @@ def heartbeat(self) -> dict[str, Any]:
}

async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, bytes]]
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)

async def _write_to_disk(self, data: dict[NDIndex, bytes]) -> None:
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
@@ -239,7 +239,7 @@ def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))

async def receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None:
async def receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
await self._receive(data)

async def _ensure_output_worker(self, i: _T_partition_id, key: str) -> None:
Expand All @@ -259,7 +259,7 @@ def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""

@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None:
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""

def add_partition(
Expand All @@ -275,7 +275,7 @@ def add_partition(
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, bytes]]:
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""

def get_output_partition(
Expand All @@ -299,7 +299,7 @@ def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""

@abc.abstractmethod
def deserialize(self, buffer: bytes) -> Any:
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""


26 changes: 18 additions & 8 deletions distributed/shuffle/_disk.py
@@ -4,12 +4,15 @@
import pathlib
import shutil
import threading
from collections.abc import Generator
from collections.abc import Callable, Generator, Iterable
from contextlib import contextmanager
from typing import Any, Callable
from typing import Any

from toolz import concat

from distributed.shuffle._buffer import ShardsBuffer
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._pickle import pickle_bytelist
from distributed.utils import Deadline, log_errors


@@ -135,7 +138,7 @@ def __init__(
self._read = read
self._directory_lock = ReadWriteLock()

async def _process(self, id: str, shards: list[bytes]) -> None:
async def _process(self, id: str, shards: list[Any]) -> None:
"""Write one buffer to file
This function was built to offload the disk IO, but since then we've
@@ -157,11 +160,18 @@ async def _process(self, id: str, shards: list[bytes]) -> None:
with self._directory_lock.read():
if self._closed:
raise RuntimeError("Already closed")
with open(
self.directory / str(id), mode="ab", buffering=100_000_000
) as f:
for shard in shards:
f.write(shard)

frames: Iterable[bytes | bytearray | memoryview]

if not shards or isinstance(shards[0], bytes):
# Manually serialized dataframes
frames = shards
else:
# Unserialized numpy arrays
frames = concat(pickle_bytelist(shard) for shard in shards)

with open(self.directory / str(id), mode="ab") as f:
f.writelines(frames)

def read(self, id: str) -> Any:
"""Read a complete file back into memory"""
12 changes: 5 additions & 7 deletions distributed/shuffle/_memory.py
@@ -11,17 +11,15 @@


class MemoryShardsBuffer(ShardsBuffer):
_deserialize: Callable[[bytes], Any]
_shards: defaultdict[str, deque[bytes]]
_deserialize: Callable[[Any], Any]
_shards: defaultdict[str, deque[Any]]

def __init__(self, deserialize: Callable[[bytes], Any]) -> None:
super().__init__(
memory_limiter=ResourceLimiter(None),
)
def __init__(self, deserialize: Callable[[Any], Any]) -> None:
super().__init__(memory_limiter=ResourceLimiter(None))
self._deserialize = deserialize
self._shards = defaultdict(deque)

async def _process(self, id: str, shards: list[bytes]) -> None:
async def _process(self, id: str, shards: list[Any]) -> None:
# TODO: This can be greatly simplified, there's no need for
# background threads at all.
with log_errors():
42 changes: 42 additions & 0 deletions distributed/shuffle/_pickle.py
@@ -0,0 +1,42 @@
from __future__ import annotations

import pickle
from collections.abc import Iterator
from typing import Any

from distributed.protocol.utils import pack_frames_prelude, unpack_frames


def pickle_bytelist(obj: object) -> list[bytes | memoryview]:
"""Variant of :func:`serialize_bytelist`, that doesn't support compression, locally
defined classes, or any of its other fancy features but runs 10x faster for numpy
arrays
See Also
--------
serialize_bytelist
unpickle_bytestream
"""
frames: list = []
pik = pickle.dumps(
obj, protocol=5, buffer_callback=lambda pb: frames.append(pb.raw())
)
frames.insert(0, pik)
frames.insert(0, pack_frames_prelude(frames))
return frames


def unpickle_bytestream(b: bytes | bytearray | memoryview) -> Iterator[Any]:
"""Unpickle the concatenated output of multiple calls to :func:`pickle_bytelist`
See Also
--------
pickle_bytelist
deserialize_bytes
"""
while True:
pik, *buffers, remainder = unpack_frames(b, remainder=True)
yield pickle.loads(pik, buffers=buffers)
if remainder.nbytes == 0:
break
b = remainder
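
A minimal round-trip sketch (not part of the commit), assuming only the two helpers defined above; the joined frames are the same kind of byte stream that later gets appended to a spill file:

```python
import numpy as np

from distributed.shuffle._pickle import pickle_bytelist, unpickle_bytestream

a = np.arange(10)
b = np.ones((2, 3))

# Each call returns [frames prelude, pickle bytes, out-of-band numpy buffer(s)]
frames = pickle_bytelist(a) + pickle_bytelist(b)

# unpickle_bytestream walks the concatenated stream and yields each object back
blob = memoryview(b"".join(frames))
restored = list(unpickle_bytestream(blob))
assert (restored[0] == a).all() and (restored[1] == b).all()
```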
94 changes: 55 additions & 39 deletions distributed/shuffle/_rechunk.py
@@ -96,8 +96,8 @@

from __future__ import annotations

import mmap
import os
import pickle
from collections import defaultdict
from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor
@@ -123,6 +123,7 @@
handle_unpack_errors,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._pickle import unpickle_bytestream
from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin
from distributed.shuffle._shuffle import barrier_key, shuffle_barrier
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
@@ -281,6 +282,9 @@ def convert_chunk(shards: list[list[tuple[NDIndex, np.ndarray]]]) -> np.ndarray:
for index, shard in indexed.items():
rec_cat_arg[tuple(index)] = shard
arrs = rec_cat_arg.tolist()

# This may block for several seconds, as it physically reads the memory-mapped
# buffers from disk
return concatenate3(arrs)


@@ -364,72 +368,84 @@ def __init__(
self.worker_for = worker_for
self.split_axes = split_axes(old, new)

async def _receive(self, data: list[tuple[NDIndex, bytes]]) -> None:
async def _receive(
self,
data: list[tuple[NDIndex, list[tuple[NDIndex, tuple[NDIndex, np.ndarray]]]]],
) -> None:
self.raise_if_closed()

filtered = []
# Repartition shards and filter out already received ones
shards = defaultdict(list)

for d in data:
id, payload = d
if id in self.received:
id1, payload = d
if id1 in self.received:

continue
filtered.append(payload)
self.received.add(id)
self.received.add(id1)
for id2, shard in payload:
shards[id2].append(shard)

self.total_recvd += sizeof(d)
del data
if not filtered:
if not shards:

return

try:
shards = await self.offload(self._repartition_shards, filtered)
del filtered
await self._write_to_disk(shards)
except Exception as e:
self._exception = e
raise

def _repartition_shards(self, data: list[bytes]) -> dict[NDIndex, bytes]:
repartitioned: defaultdict[
NDIndex, list[tuple[NDIndex, np.ndarray]]
] = defaultdict(list)
for buffer in data:
for id, shard in pickle.loads(buffer):
repartitioned[id].append(shard)
return {k: pickle.dumps(v) for k, v in repartitioned.items()}

def _shard_partition(
self, data: np.ndarray, partition_id: NDIndex, **kwargs: Any
) -> dict[str, tuple[NDIndex, bytes]]:
self, data: np.ndarray, partition_id: NDIndex
) -> dict[str, tuple[NDIndex, Any]]:
out: dict[str, list[tuple[NDIndex, tuple[NDIndex, np.ndarray]]]] = defaultdict(
list
)
from itertools import product

ndsplits = product(*(axis[i] for axis, i in zip(self.split_axes, partition_id)))

for ndsplit in ndsplits:
chunk_index, shard_index, ndslice = zip(*ndsplit)

shard = data[ndslice]

# Don't wait until all shards have been transferred over the network
# before data can be released
if shard.base is not None:
shard = shard.copy()


out[self.worker_for[chunk_index]].append(
(chunk_index, (shard_index, data[ndslice]))
(chunk_index, (shard_index, shard))
)
return {k: (partition_id, pickle.dumps(v)) for k, v in out.items()}
return {k: (partition_id, v) for k, v in out.items()}
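
Aside (not part of the commit): the `shard.base is not None` check above matters because a sliced shard is a NumPy view that keeps the entire source chunk alive; copying the small slice lets the large input chunk be released without waiting for the network transfer. A tiny illustration:

```python
import numpy as np

chunk = np.zeros((100, 100))   # stand-in for an input chunk
shard = chunk[:10, :10]        # basic slicing returns a view ...
assert shard.base is chunk     # ... which pins the whole 100x100 chunk in memory

shard = shard.copy()           # small private copy
assert shard.base is None      # the big chunk can now be released independently
```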


def _get_output_partition(
self, partition_id: NDIndex, key: str, **kwargs: Any
) -> np.ndarray:
# Quickly read metadata from disk.
# This is a bunch of seek()'s interleaved with short reads.
data = self._read_from_disk(partition_id)
return convert_chunk(data)

def deserialize(self, buffer: bytes) -> Any:
result = pickle.loads(buffer)
return result

def read(self, path: Path) -> tuple[Any, int]:
shards: list[list[tuple[NDIndex, np.ndarray]]] = []
with path.open(mode="rb") as f:
size = f.seek(0, os.SEEK_END)
f.seek(0)
while f.tell() < size:
shards.append(pickle.load(f))
return shards, size
# Copy the memory-mapped buffers from disk into memory.
# This is where we'll spend most time.
with self._disk_buffer.time("read"):
return convert_chunk(data)


def deserialize(self, buffer: Any) -> Any:
return buffer


def read(self, path: Path) -> tuple[list[list[tuple[NDIndex, np.ndarray]]], int]:
"""Open a memory-mapped file descriptor to disk, read all metadata, and unpickle
all arrays. This is a fast sequence of short reads interleaved with seeks.
Do not read in memory the actual data; the arrays' buffers will point to the
memory-mapped area.
The file descriptor will be automatically closed by the kernel when all the
returned arrays are dereferenced, which will happen after the call to
concatenate3.
"""
with path.open(mode="r+b") as fh:
buffer = memoryview(mmap.mmap(fh.fileno(), 0))


# The file descriptor has *not* been closed!
shards = list(unpickle_bytestream(buffer))
return shards, buffer.nbytes
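
For context (not part of the commit), a standalone sketch of this write/read path, with a made-up file name and shard values: shards are appended to disk as pickle frames (as `DiskShardsBuffer._process` does for numpy shards) and later memory-mapped back, so the arrays returned here are backed by the mapped file rather than by private copies.

```python
import mmap

import numpy as np

from distributed.shuffle._pickle import pickle_bytelist, unpickle_bytestream

shards = [np.arange(5), np.ones((2, 2))]

# Write: append each shard as [prelude, pickle bytes, raw numpy buffer] frames
with open("shards.bin", "wb") as f:
    for shard in shards:
        f.writelines(pickle_bytelist(shard))

# Read: memory-map the file and unpickle; no bulk read happens here
with open("shards.bin", "r+b") as fh:
    buffer = memoryview(mmap.mmap(fh.fileno(), 0))

restored = list(unpickle_bytestream(buffer))
assert all((a == b).all() for a, b in zip(shards, restored))
assert not restored[0].flags["OWNDATA"]  # data points into the mmap, not a copy
```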


def _get_assigned_worker(self, id: NDIndex) -> str:
return self.worker_for[id]
(The diffs for the remaining 4 changed files are not shown.)
