p2p_monolithic_bytes

crusaderky committed Nov 2, 2023
1 parent 5e57171 commit e02e945
Showing 10 changed files with 94 additions and 32 deletions.
8 changes: 5 additions & 3 deletions distributed/protocol/tests/test_utils_test.py
@@ -21,6 +21,8 @@ def test_get_host_array():
     a = np.frombuffer(buf[1:], dtype="u1")
     assert get_host_array(a) is buf.obj

-    a = np.frombuffer(bytearray(3), dtype="u1")
-    with pytest.raises(TypeError):
-        get_host_array(a)
+    for buf in (b"123", bytearray(b"123")):
+        a = np.frombuffer(buf, dtype="u1")
+        assert get_host_array(a) is buf
+        a = np.frombuffer(memoryview(buf), dtype="u1")
+        assert get_host_array(a) is buf
12 changes: 5 additions & 7 deletions distributed/protocol/utils_test.py
@@ -6,7 +6,7 @@
 import numpy


-def get_host_array(a: numpy.ndarray) -> numpy.ndarray:
+def get_host_array(a: numpy.ndarray) -> numpy.ndarray | bytes | bytearray:
     """Given a numpy array, find the underlying memory allocated by either
     distributed.protocol.utils.host_array or internally by numpy
     """
@@ -22,9 +22,7 @@ def get_host_array(a: numpy.ndarray) -> numpy.ndarray:
                 o = o.base
             else:
                 return o
-        else:
-            # distributed.comm.utils.host_array() uses numpy.empty()
-            raise TypeError(
-                "Array uses a buffer allocated neither internally nor by host_array: "
-                f"{type(o)}"
-            )
+        elif isinstance(o, (bytes, bytearray)):
+            return o
+        else:  # pragma: nocover
+            raise TypeError(f"Unexpected numpy buffer: {o!r}")
21 changes: 11 additions & 10 deletions distributed/shuffle/_buffer.py
@@ -5,21 +5,22 @@
 import logging
 from collections import defaultdict
 from collections.abc import Iterator, Sized
-from typing import Any, Generic, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, TypeVar

 from distributed.metrics import time
 from distributed.shuffle._limiter import ResourceLimiter
 from distributed.sizeof import sizeof

 logger = logging.getLogger("distributed.shuffle")
+if TYPE_CHECKING:
+    # TODO import from collections.abc (requires Python >=3.12)
+    from typing_extensions import Buffer
+else:
+    Buffer = Sized

-ShardType = TypeVar("ShardType", bound=Sized)
-T = TypeVar("T")
+ShardType = TypeVar("ShardType", bound=Buffer)

-
-class _List(list[T]):
-    # This ensures that the distributed.protocol will not iterate over this collection
-    pass
+T = TypeVar("T")


 class ShardsBuffer(Generic[ShardType]):
@@ -43,7 +44,7 @@ class ShardsBuffer(Generic[ShardType]):
     Flushing will not raise an exception. To ensure that the buffer finished successfully, please call `ShardsBuffer.raise_on_exception`
     """

-    shards: defaultdict[str, _List[ShardType]]
+    shards: defaultdict[str, list[ShardType]]
     sizes: defaultdict[str, int]
     sizes_detail: defaultdict[str, list[int]]
     concurrency_limit: int
@@ -70,7 +71,7 @@ def __init__(
         max_message_size: int = -1,
     ) -> None:
         self._accepts_input = True
-        self.shards = defaultdict(_List)
+        self.shards = defaultdict(list)
         self.sizes = defaultdict(int)
         self.sizes_detail = defaultdict(list)
         self._exception = None
@@ -146,7 +147,7 @@ def _continue() -> bool:
             part_id = max(self.sizes, key=self.sizes.__getitem__)
             if self.max_message_size > 0:
                 size = 0
-                shards: _List[ShardType] = _List()
+                shards = []
                 while size < self.max_message_size:
                     try:
                         shard = self.shards[part_id].pop()
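Aside: the `Buffer` bound above uses a standard conditional-import pattern. Type checkers see the precise `typing_extensions.Buffer` protocol, while at runtime, where TypeVar bounds must resolve and `collections.abc.Buffer` requires Python >= 3.12, `Sized` stands in. A minimal standalone sketch of the same pattern (the `Holder` class is hypothetical, for illustration only):

    from __future__ import annotations

    from collections.abc import Sized
    from typing import TYPE_CHECKING, Generic, TypeVar

    if TYPE_CHECKING:
        from typing_extensions import Buffer  # static analysis only
    else:
        Buffer = Sized  # runtime stand-in; the bound is never enforced at runtime

    ShardType = TypeVar("ShardType", bound=Buffer)

    class Holder(Generic[ShardType]):
        def __init__(self) -> None:
            self.items: list[ShardType] = []

    h: Holder[bytes] = Holder()  # bytes satisfies the Buffer bound
    h.items.append(b"shard")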
41 changes: 36 additions & 5 deletions distributed/shuffle/_core.py
@@ -4,19 +4,21 @@
 import asyncio
 import contextlib
 import itertools
+import pickle
 import time
 from collections import defaultdict
-from collections.abc import Callable, Iterator, Sequence
+from collections.abc import Callable, Iterable, Iterator, Sequence
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass, field
 from enum import Enum
 from functools import partial
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast

 from tornado.ioloop import IOLoop

 import dask.config
+from dask.core import flatten
 from dask.typing import Key
 from dask.utils import parse_timedelta
@@ -140,7 +142,7 @@ async def barrier(self, run_ids: Sequence[int]) -> int:
         return self.run_id

     async def _send(
-        self, address: str, shards: list[tuple[_T_partition_id, Any]]
+        self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
     ) -> None:
         self.raise_if_closed()
         return await self.rpc(address).shuffle_receive(
@@ -159,8 +161,19 @@ async def send(
         retry_delay_max = parse_timedelta(
             dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
         )
+
+        if _mean_shard_size(shards) < 65536:
+            # Don't send buffers individually over the tcp comms.
+            # Instead, merge everything into an opaque bytes blob, send it all at once,
+            # and unpickle it on the other side.
+            # Performance tests informing the size threshold:
+            # https://github.com/dask/distributed/pull/8318
+            shards_or_bytes: list | bytes = pickle.dumps(shards)
+        else:
+            shards_or_bytes = shards
+
         return await retry(
-            partial(self._send, address, shards),
+            partial(self._send, address, shards_or_bytes),
             count=retry_count,
             delay_min=retry_delay_min,
             delay_max=retry_delay_max,

[Codecov / codecov/patch: added line distributed/shuffle/_core.py#L173 was not covered by tests]
@@ -239,7 +252,10 @@ def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
         self.raise_if_closed()
         return self._disk_buffer.read("_".join(str(i) for i in id))

-    async def receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
+    async def receive(self, data: list[tuple[_T_partition_id, Any]] | bytes) -> None:
+        if isinstance(data, bytes):
+            # Unpack opaque blob. See send()
+            data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
         await self._receive(data)

     async def _ensure_output_worker(self, i: _T_partition_id, key: str) -> None:
@@ -422,3 +438,18 @@ def handle_unpack_errors(id: ShuffleId) -> Iterator[None]:
         raise Reschedule()
     except Exception as e:
         raise RuntimeError(f"P2P shuffling {id} failed during unpack phase") from e
+
+
+def _mean_shard_size(shards: Iterable) -> int:
+    """Return estimated mean size in bytes of each shard"""
+    size = 0
+    count = 0
+    for shard in flatten(shards, container=(tuple, list)):
+        if not isinstance(shard, int):
+            # This also asserts that shard is a Buffer and that we didn't forget
+            # a container or metadata type above
+            size += memoryview(shard).nbytes
+            count += 1
+            if count == 10:
+                break
+    return size // count if count else 0
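Taken together, the send and receive halves implement a simple framing optimization. A condensed sketch of the round trip, assuming the `_mean_shard_size` helper defined above (`pack_for_send` and `unpack_on_receive` are hypothetical names; the real logic lives in `send()` and `receive()`):

    import pickle

    THRESHOLD = 65536  # bytes; the mean-shard-size cutoff used in send()

    def pack_for_send(shards: list) -> list | bytes:
        # Many small shards: pickle the whole list into one opaque blob so the
        # comm layer sends a single large frame instead of many tiny buffers.
        if _mean_shard_size(shards) < THRESHOLD:
            return pickle.dumps(shards)
        # Large shards are sent as-is so their buffers can travel zero-copy.
        return shards

    def unpack_on_receive(data: list | bytes) -> list:
        # A bytes payload means the sender pickled the shard list;
        # anything else passes through untouched.
        if isinstance(data, bytes):
            return pickle.loads(data)
        return data

    assert unpack_on_receive(pack_for_send([(0, b"tiny")])) == [(0, b"tiny")]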
1 change: 0 additions & 1 deletion distributed/shuffle/_rechunk.py
@@ -273,7 +273,6 @@ def convert_chunk(shards: list[list[tuple[NDIndex, np.ndarray]]]) -> np.ndarray:
     for sublist in shards:
         for index, shard in sublist:
             indexed[index] = shard
-    del shards

     subshape = [max(dim) + 1 for dim in zip(*indexed.keys())]
     assert len(indexed) == np.prod(subshape)
3 changes: 0 additions & 3 deletions distributed/shuffle/_shuffle.py
@@ -457,9 +457,6 @@ def __init__(
         self.partitions_of = dict(partitions_of)
         self.worker_for = pd.Series(worker_for, name="_workers").astype("category")

-    async def receive(self, data: list[tuple[int, bytes]]) -> None:
-        await self._receive(data)
-
     async def _receive(self, data: list[tuple[int, bytes]]) -> None:
         self.raise_if_closed()

2 changes: 1 addition & 1 deletion distributed/shuffle/_worker_plugin.py
@@ -299,7 +299,7 @@ async def shuffle_receive(
         self,
         shuffle_id: ShuffleId,
         run_id: int,
-        data: list[tuple[int, Any]],
+        data: list[tuple[int, Any]] | bytes,
     ) -> None:
         """
         Handler: Receive an incoming shard of data from a peer worker.
29 changes: 29 additions & 0 deletions distributed/shuffle/tests/test_core.py
@@ -0,0 +1,29 @@
+from __future__ import annotations
+
+import pytest
+
+from distributed.shuffle._core import _mean_shard_size
+
+
+def test_mean_shard_size():
+    assert _mean_shard_size([]) == 0
+    assert _mean_shard_size([b""]) == 0
+    assert _mean_shard_size([b"123", b"45678"]) == 4
+    # Don't fully iterate over large collections
+    assert _mean_shard_size([b"12" * n for n in range(1000)]) == 9
+    # Support any Buffer object
+    assert _mean_shard_size([b"12", bytearray(b"1234"), memoryview(b"123456")]) == 4
+    # Recursion into lists or tuples; ignore int
+    assert _mean_shard_size([(1, 2, [3, b"123456"])]) == 6
+    # Don't blindly call sizeof() on unexpected objects
+    with pytest.raises(TypeError):
+        _mean_shard_size([1.2])
+    with pytest.raises(TypeError):
+        _mean_shard_size([{1: 2}])
+
+
+def test_mean_shard_size_numpy():
+    """Test that _mean_shard_size doesn't call len() on multi-byte data types"""
+    np = pytest.importorskip("numpy")
+    assert _mean_shard_size([np.zeros(10, dtype="u1")]) == 10
+    assert _mean_shard_size([np.zeros(10, dtype="u8")]) == 80
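For context on the numpy test: `len()` counts elements, not bytes, which is why `_mean_shard_size` measures shards via `memoryview(...).nbytes`. A quick illustration (assumes numpy is installed):

    import numpy as np

    a = np.zeros(10, dtype="u8")  # 10 elements of 8 bytes each
    print(len(a))                 # 10 -- element count only
    print(memoryview(a).nbytes)   # 80 -- actual payload size, as counted above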
7 changes: 6 additions & 1 deletion distributed/shuffle/tests/test_rechunk.py
@@ -147,7 +147,12 @@ async def test_lowlevel_rechunk(tmp_path, n_workers, barrier_first_worker, disk)
             total_bytes_recvd += metrics["disk"]["total"]
             total_bytes_recvd_shuffle += s.total_recvd

-        assert total_bytes_recvd_shuffle == total_bytes_sent
+        # Allow for some uncertainty due to slight differences in measuring
+        assert (
+            total_bytes_sent * 0.95
+            < total_bytes_recvd_shuffle
+            < total_bytes_sent * 1.05
+        )

         all_chunks = np.empty(tuple(len(dim) for dim in new), dtype="O")
         for ix, worker in worker_for_mapping.items():
2 changes: 1 addition & 1 deletion distributed/shuffle/tests/test_shuffle.py
@@ -1822,7 +1822,7 @@ async def test_error_receive(tmp_path, loop_in_thread):
         partitions_for_worker[w].append(part)

     class ErrorReceive(DataFrameShuffleRun):
-        async def receive(self, data: list[tuple[int, bytes]]) -> None:
+        async def _receive(self, data: list[tuple[int, bytes]]) -> None:
             raise RuntimeError("Error during receive")

     with DataFrameShuffleTestPool() as local_shuffle_pool:
