Don't share host_array between objects
crusaderky committed Nov 2, 2023
1 parent 954e9d0 commit bf43081
Showing 7 changed files with 352 additions and 34 deletions.
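
Context for the change: before this commit, TCP.read() received every frame of a message into one shared host_array, so any numpy or parquet buffer deserialized zero-copy from that message kept the whole allocation alive until every other object from the same message was dropped. A minimal sketch of that pinning effect (illustrative only, assuming numpy; not code from this commit):

import numpy as np

# One big receive buffer standing in for the old shared host_array.
shared = bytearray(100_000_000)
a = np.frombuffer(memoryview(shared)[:50_000_000], dtype="u1")
b = np.frombuffer(memoryview(shared)[50_000_000:], dtype="u1")

del shared, a
# `b` still reaches the full 100 MB allocation through its .base / .obj chain,
# so none of that memory can be released until `b` is dropped as well.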
114 changes: 85 additions & 29 deletions distributed/comm/tcp.py
@@ -220,18 +220,16 @@ async def read(self, deserializers=None):
fmt_size = struct.calcsize(fmt)

try:
frames_nbytes = await stream.read_bytes(fmt_size)
(frames_nbytes,) = struct.unpack(fmt, frames_nbytes)

frames = host_array(frames_nbytes)
for i, j in sliding_window(
2,
range(0, frames_nbytes + OPENSSL_MAX_CHUNKSIZE, OPENSSL_MAX_CHUNKSIZE),
):
chunk = frames[i:j]
chunk_nbytes = chunk.nbytes
n = await stream.read_into(chunk)
assert n == chunk_nbytes, (n, chunk_nbytes)
# Don't store multiple numpy or parquet buffers into the same buffer, or
# none will be released until all are released.
frames_nosplit_nbytes_bin = await stream.read_bytes(fmt_size)
(frames_nosplit_nbytes,) = struct.unpack(fmt, frames_nosplit_nbytes_bin)
frames_nosplit = await read_bytes_rw(stream, frames_nosplit_nbytes)
frames, buffers_nbytes = unpack_frames(frames_nosplit, partial=True)
for buffer_nbytes in buffers_nbytes:
buffer = await read_bytes_rw(stream, buffer_nbytes)
frames.append(buffer)

except StreamClosedError as e:
self.stream = None
self._closed = True
@@ -247,8 +245,6 @@ async def read(self, deserializers=None):
raise
else:
try:
frames = unpack_frames(frames)

msg = await from_frames(
frames,
deserialize=self.deserialize,
@@ -278,23 +274,10 @@ async def write(self, msg, serializers=None, on_error="message"):
},
frame_split_size=self.max_shard_size,
)
frames_nbytes = [nbytes(f) for f in frames]
frames_nbytes_total = sum(frames_nbytes)

header = pack_frames_prelude(frames)
header = struct.pack("Q", nbytes(header) + frames_nbytes_total) + header

frames = [header, *frames]
frames_nbytes = [nbytes(header), *frames_nbytes]
frames_nbytes_total += frames_nbytes[0]

if frames_nbytes_total < 2**17: # 128kiB
# small enough, send in one go
frames = [b"".join(frames)]
frames_nbytes = [frames_nbytes_total]
frames, frames_nbytes, frames_nbytes_total = _add_frames_header(frames)

try:
# trick to enque all frames for writing beforehand
# trick to enqueue all frames for writing beforehand
for each_frame_nbytes, each_frame in zip(frames_nbytes, frames):
if each_frame_nbytes:
# Make sure that `len(data) == data.nbytes`
@@ -371,6 +354,79 @@ def extra_info(self):
return self._extra


async def read_bytes_rw(stream: IOStream, n: int) -> memoryview:
"""Read n bytes from stream. Unlike stream.read_bytes, allow for
very large messages and return a writeable buffer.
"""
buf = host_array(n)

for i, j in sliding_window(
2,
range(0, n + OPENSSL_MAX_CHUNKSIZE, OPENSSL_MAX_CHUNKSIZE),
):
chunk = buf[i:j]
chunk_nbytes = chunk.nbytes
n = await stream.read_into(chunk) # type: ignore[arg-type]
assert n == chunk_nbytes, (n, chunk_nbytes)

return buf


def _add_frames_header(
frames: list[bytes | memoryview],
) -> tuple[list[bytes | memoryview], list[int], int]:
""" """
frames_nbytes = [nbytes(f) for f in frames]
frames_nbytes_total = sum(frames_nbytes)

# Calculate the number of bytes that are inclusive of:
# - prelude
# - msgpack header
# - simple pickle bytes
# - compressed buffers
# - first uncompressed buffer (possibly sharded), IFF the pickle bytes are
# negligible in size
#
# All these can be fetched by read() into a single buffer with a single call to
# Tornado, because they will be dereferenced soon after they are deserialized.
# Read uncompressed numpy/parquet buffers, which will survive indefinitely past
# the end of read(), into their own host arrays so that their memory can be
# released independently.
frames_nbytes_nosplit = 0
first_uncompressed_buffer: object = None
for frame, nb in zip(frames, frames_nbytes):
buffer = frame.obj if isinstance(frame, memoryview) else frame
if not isinstance(buffer, bytes):
# Uncompressed buffer; it will be referenced by the unpickled object
if first_uncompressed_buffer is None:
if frames_nbytes_nosplit > max(2048, nb * 0.05):
# Don't extend the lifespan of non-trivial amounts of pickled bytes
# to that of the buffers
break
first_uncompressed_buffer = buffer
elif first_uncompressed_buffer is not buffer: # don't split sharded frame
# Always store 2+ separate numpy/parquet objects onto separate
# buffers
break

frames_nbytes_nosplit += nb

header = pack_frames_prelude(frames)
header = struct.pack("Q", nbytes(header) + frames_nbytes_nosplit) + header
header_nbytes = nbytes(header)

frames = [header, *frames]
frames_nbytes = [header_nbytes, *frames_nbytes]
frames_nbytes_total += header_nbytes

if frames_nbytes_total < 2**17: # 128kiB
# small enough, send in one go
frames = [b"".join(frames)]
frames_nbytes = [frames_nbytes_total]

return frames, frames_nbytes, frames_nbytes_total


class TLS(TCP):
"""
A TLS-specific version of TCP.
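
Putting the read()/write() changes together, the wire layout now starts with a uint64 that tells the receiver how many bytes it may read into a single shared buffer (prelude, msgpack header, pickled bytes, compressed buffers and, when the pickled bytes are negligible, the first uncompressed buffer); every remaining uncompressed buffer is then read into its own host_array. A condensed, hypothetical mirror of the receive path (toy_read and stream are illustrative, assuming a Tornado-style read_bytes coroutine):

import struct

from distributed.protocol.utils import unpack_frames


async def toy_read(stream):
    # 1. uint64 prefix: size of the "nosplit" region written by _add_frames_header.
    (nosplit_nbytes,) = struct.unpack("Q", await stream.read_bytes(8))
    # 2. Read that region in one go and split it into frames; partial=True also
    #    returns the lengths of the buffers that were not bundled into it.
    nosplit = await stream.read_bytes(nosplit_nbytes)
    frames, remaining_sizes = unpack_frames(nosplit, partial=True)
    # 3. Each remaining buffer gets its own allocation, so deserialized numpy or
    #    parquet objects no longer pin each other's memory.
    for nb in remaining_sizes:
        frames.append(await stream.read_bytes(nb))
    return frames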
95 changes: 94 additions & 1 deletion distributed/comm/tests/test_comms.py
@@ -26,10 +26,11 @@
)
from distributed.comm.registry import backends, get_backend
from distributed.comm.tcp import get_stream_address
from distributed.compatibility import asyncio_run
from distributed.compatibility import WINDOWS, asyncio_run
from distributed.config import get_loop_factory
from distributed.metrics import time
from distributed.protocol import Serialized, deserialize, serialize, to_serialize
from distributed.protocol.utils_test import get_host_array
from distributed.utils import get_ip, get_ipv6, get_mp_context, wait_for
from distributed.utils_test import (
gen_test,
@@ -1379,3 +1380,95 @@ def test_register_backend_entrypoint(tmp_path):
with get_mp_context().Pool(1) as pool:
assert pool.apply(_get_backend_on_path, args=(tmp_path,)) == 1
pool.join()


class OpaqueList(list):
"""Don't let the serialization layer travese this object"""

pass


@pytest.mark.parametrize(
"list_cls",
[
list, # Use protocol.numpy.serialize_numpy_array / deserialize_numpy_array
OpaqueList, # Use generic pickle.dumps / pickle.loads
],
)
@gen_test()
async def test_do_not_share_buffers(tcp, list_cls):
"""Test that two objects with buffer interface in the same message do not share
their buffer upon deserialization
See Also
--------
test_share_buffer_with_header
test_ucx.py::test_do_not_share_buffers
"""
np = pytest.importorskip("numpy")

async def handle_comm(comm):
msg = await comm.read()
msg["data"] = to_serialize(list_cls([np.array([1, 2]), np.array([3, 4])]))
await comm.write(msg)
await comm.close()

listener = await tcp.TCPListener("127.0.0.1", handle_comm)
comm = await connect(listener.contact_address)

await comm.write({"op": "ping"})
msg = await comm.read()
await comm.close()

a, b = msg["data"]
assert get_host_array(a) is not get_host_array(b)


@pytest.mark.parametrize(
"nbytes_np,nbytes_other,expect_separate_buffer",
[
(1, 0, False), # <2 kiB (including prelude and msgpack header)
(1, 1800, False), # <2 kiB
(1, 2100, True), # >2 kiB
(200_000, 9500, False), # <5% of numpy array
(200_000, 10500, True), # >5% of numpy array
(350_000, 0, False), # sharded buffer
],
)
@gen_test()
async def test_share_buffer_with_header(
tcp, nbytes_np, nbytes_other, expect_separate_buffer
):
"""Test that a numpy or parquet object shares the buffer with its serialized header
to improve performance, but only as long as the header is trivial in size.
See Also
--------
test_do_not_share_buffers
"""
np = pytest.importorskip("numpy")
if tcp is asyncio_tcp and WINDOWS:
pytest.xfail("asyncio_tcp is faulty on windows")

async def handle_comm(comm):
comm.max_shard_size = 250_000
msg = await comm.read()
msg["np"] = to_serialize(np.random.randint(0, 256, nbytes_np, dtype="u1"))
msg["other"] = np.random.bytes(nbytes_other)
await comm.write(msg)
await comm.close()

listener = await tcp.TCPListener("127.0.0.1", handle_comm)
comm = await connect(listener.contact_address)

await comm.write({"op": "ping"})
msg = await comm.read()
await comm.close()

a = msg["np"]
ha = get_host_array(a)
if tcp is asyncio_tcp:
# TODO unimplemented optimization. Buffers are always split.
assert ha.nbytes == a.nbytes
else:
assert (ha.nbytes == a.nbytes) == expect_separate_buffer
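
The parametrization above lines up with the max(2048, nb * 0.05) cut-off in _add_frames_header; a rough sanity check of the arithmetic (approximate, since the real comparison also counts the prelude, msgpack header and pickle bytes, not just the extra payload):

def shares_buffer_with_header(nbytes_np: int, nbytes_other: int) -> bool:
    # Approximation of the cut-off used by _add_frames_header.
    return nbytes_other <= max(2048, nbytes_np * 0.05)

assert shares_buffer_with_header(1, 1800)               # 1800 B  < 2 kiB
assert not shares_buffer_with_header(1, 2100)           # 2100 B  > 2 kiB
assert shares_buffer_with_header(200_000, 9_500)        # 9.5 kB  < 5% of 200 kB
assert not shares_buffer_with_header(200_000, 10_500)   # 10.5 kB > 5% of 200 kB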
28 changes: 28 additions & 0 deletions distributed/comm/tests/test_ucx.py
@@ -26,6 +26,7 @@
has_cuda_context,
)
from distributed.protocol import to_serialize
from distributed.protocol.utils_test import get_host_array
from distributed.utils_test import gen_test, inc

try:
@@ -432,3 +433,30 @@ async def test_embedded_cupy_array(
x = da.from_array(a, chunks=(10000,))
b = await client.compute(x)
cupy.testing.assert_array_equal(a, b)


@gen_test()
async def test_do_not_share_buffers(ucx_loop):
"""Test that two objects with buffer interface in the same message do not share
their buffer upon deserialization.
See Also
--------
test_comms.py::test_do_not_share_buffers
"""
np = pytest.importorskip("numpy")

com, serv_com = await get_comm_pair()
msg = {"data": to_serialize([np.array([1, 2]), np.array([3, 4])])}

await com.write(msg)
result = await serv_com.read()
await com.close()
await serv_com.close()

a, b = result["data"]
ha = get_host_array(a)
hb = get_host_array(b)
assert ha is not hb
assert ha.nbytes == a.nbytes
assert hb.nbytes == b.nbytes
42 changes: 40 additions & 2 deletions distributed/protocol/tests/test_protocol_utils.py
@@ -2,16 +2,54 @@

import pytest

from distributed.protocol.utils import merge_memoryviews, pack_frames, unpack_frames
from distributed.protocol.utils import (
merge_memoryviews,
pack_frames,
pack_frames_prelude,
unpack_frames,
)


def test_pack_frames():
frames = [b"123", b"asdf"]
b = pack_frames(frames)
assert isinstance(b, bytes)
frames2 = unpack_frames(b)
assert frames2 == frames

assert frames == frames2

@pytest.mark.parametrize("extra", [b"456", b""])
def test_unpack_frames_remainder(extra):
frames = [b"123", b"asdf"]
b = pack_frames(frames)
assert isinstance(b, bytes)

frames2 = unpack_frames(b + extra)
assert frames2 == frames

frames2 = unpack_frames(b + extra, remainder=True)
assert isinstance(frames2[-1], memoryview)
assert frames2 == frames + [extra]


def test_unpack_frames_partial():
frames = [b"123", b"asdf"]
frames.insert(0, pack_frames_prelude(frames))

frames2, missing_lengths = unpack_frames(b"".join(frames), partial=True)
assert frames2 == frames[1:]
assert missing_lengths == []

frames2, missing_lengths = unpack_frames(b"".join(frames[:-1]), partial=True)
assert frames2 == frames[1:-1]
assert missing_lengths == [4]

frames2, missing_lengths = unpack_frames(frames[0], partial=True)
assert frames2 == []
assert missing_lengths == [3, 4]

with pytest.raises(AssertionError):
unpack_frames(b"".join(frames[:-1]))


class TestMergeMemroyviews:
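
For reference, the prelude exercised by these tests is assumed to be a uint64 frame count followed by one uint64 length per frame (which is what pack_frames_prelude emits); a worked example of the partial decoding shown above:

import struct

from distributed.protocol.utils import pack_frames_prelude, unpack_frames

frames = [b"123", b"asdf"]
prelude = pack_frames_prelude(frames)
assert prelude == struct.pack("QQQ", 2, 3, 4)  # count, then per-frame lengths

# Prelude only: nothing decoded yet, both frame lengths still outstanding.
decoded, missing = unpack_frames(prelude, partial=True)
assert decoded == [] and missing == [3, 4]

# Prelude plus the first frame: only the second frame is still missing.
decoded, missing = unpack_frames(prelude + b"123", partial=True)
assert decoded == [b"123"] and missing == [4]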
26 changes: 26 additions & 0 deletions distributed/protocol/tests/test_utils_test.py
@@ -0,0 +1,26 @@
from __future__ import annotations

import pytest

from distributed.protocol.utils import host_array
from distributed.protocol.utils_test import get_host_array


def test_get_host_array():
np = pytest.importorskip("numpy")

a = np.array([1, 2, 3])
assert get_host_array(a) is a
assert get_host_array(a[1:]) is a
assert get_host_array(a[1:][1:]) is a

buf = host_array(3)
a = np.frombuffer(buf, dtype="u1")
assert get_host_array(a) is buf.obj
assert get_host_array(a[1:]) is buf.obj
a = np.frombuffer(buf[1:], dtype="u1")
assert get_host_array(a) is buf.obj

a = np.frombuffer(bytearray(3), dtype="u1")
with pytest.raises(TypeError):
get_host_array(a)
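
The get_host_array helper used throughout these tests lives in distributed/protocol/utils_test.py, one of the two changed files not shown in this view. Based on the assertions above, a hypothetical sketch of such a helper (names and structure are illustrative, not the actual implementation):

import numpy as np


def get_host_array_sketch(a: np.ndarray):
    """Walk ndarray.base / memoryview.obj down to the root host allocation."""
    o = a
    while True:
        if isinstance(o, np.ndarray):
            if o.base is None:
                return o  # root numpy allocation (host_array is numpy-backed)
            o = o.base
        elif isinstance(o, memoryview):
            o = o.obj
        else:
            # e.g. a plain bytearray that did not come from host_array()
            raise TypeError(f"unexpected host buffer: {type(o)}")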
