Skip to content

Commit

Permalink
Test for memory hog in MemoryBuffer
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Oct 27, 2023
1 parent e1db838 commit 4dccfd6
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 5 deletions.
7 changes: 2 additions & 5 deletions distributed/comm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,14 @@
# uses hugepages when available ( https://github.com/numpy/numpy/pull/14216 ).
import numpy

def numpy_host_array(n: int) -> memoryview:
def host_array(n: int) -> memoryview:

Check warning on line 30 in distributed/comm/utils.py

View check run for this annotation

Codecov / codecov/patch

distributed/comm/utils.py#L30

Added line #L30 was not covered by tests
return numpy.empty((n,), dtype="u1").data

host_array = numpy_host_array
except ImportError:

def builtin_host_array(n: int) -> memoryview:
def host_array(n: int) -> memoryview:
return memoryview(bytearray(n))

host_array = builtin_host_array


async def to_frames(
msg,
Expand Down
55 changes: 55 additions & 0 deletions distributed/shuffle/tests/test_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from dask.array.rechunk import normalize_chunks, rechunk
from dask.array.utils import assert_eq

from distributed import Event
from distributed.shuffle._core import ShuffleId
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._rechunk import (
Expand Down Expand Up @@ -1153,3 +1154,57 @@ async def test_preserve_writeable_flag(c, s, a, b):
arr = arr.map_blocks(lambda chunk: chunk.flags["WRITEABLE"])
out = await c.compute(arr)
assert out.tolist() == [True, True]


def get_host_array(a):
assert isinstance(a, np.ndarray)
while True:
if isinstance(a, memoryview):
a = a.obj
elif isinstance(a, np.ndarray):
if a.base is not None:
a = a.base
else:
return a
else:
# distributed.comm.utils.host_array() uses numpy.empty()
raise AssertionError("unreachable") # pragma: nocover


@gen_cluster(client=True, config={"distributed.p2p.disk": False})
async def test_rechunk_in_memory_shards_dont_share_buffer(c, s, a, b):
"""Test that, if two shards are sent in the same RPC call and they contribute to
different output chunks, downstream tasks don't need to consume all output chunks in
order to release the memory of the output chunks that have already been consumed.
This can happen if all numpy arrays in the same RPC call share the same buffer
coming out of the TCP stack.
"""
in_map = Event()
block_map = Event()

def blocked(chunk, in_map, block_map):
in_map.set()
block_map.wait()
return chunk

# 8 MiB array, 256 kiB chunks, 8 kiB shards
arr = da.random.random((1024, 1024), chunks=(-1, 32))
arr = arr.rechunk((32, -1), method="p2p")

arr = arr.map_blocks(blocked, in_map=in_map, block_map=block_map, dtype=arr.dtype)
fut = c.compute(arr)
await in_map.wait()

[run] = a.extensions["shuffle"].shuffle_runs._runs
shards = [
s3 for s1 in run._disk_buffer._shards.values() for s2 in s1 for _, s3 in s2
]

assert shards
for shard in shards:
buf = get_host_array(shard)
# Allow some margin for storing the pickle metadata
assert buf.nbytes <= shard.nbytes * 1.1

await block_map.set()

0 comments on commit 4dccfd6

Please sign in to comment.