
Zero-copy array shuffle #8282

Merged: 3 commits merged into dask:main on Nov 7, 2023

Conversation

@crusaderky (Collaborator) commented Oct 18, 2023

Perform the P2P rechunk of dask.array objects without ever deep-copying the buffers.

This is not working yet due to what looks like a trivial bug somewhere; however, I'm eager to receive early feedback on the design.

CC @fjetter @hendrikmakait

Step-by-step comparison of main vs. this PR:

1. Break into shards
   - main: create views of the original monolithic array.
   - this PR: ensure that the shards are never views. This lets us release the input chunk faster. First and only deep copy.
2. _add_partition
   - main: pickle.dumps the views into bytes. First deep copy.
   - this PR: preserve the numpy arrays from step 1.
3. RPC call
   - main: send the message with the bytes objects to the network stack. Here they are pickled again into another, monolithic bytes object (second deep copy), which is then sent through the network.
   - this PR: send the message containing the numpy views created at step 1 to the network stack. Here, pickle protocol 5 creates memoryviews of the numpy buffers, which are then handed over to the kernel. All memory still points to the shards created at step 1.
4. _receive
   - main: the network stack unpickles the message (third deep copy); then the RPC system calls _receive, which in turn unpickles the individual bytes objects (fourth deep copy), reorganizes them, and re-pickles them (fifth deep copy).
   - this PR: the network stack unpickles the message into the same numpy arrays produced at step 1. Their buffers point to the host_buffer (numpy.empty) created by the network stack when receiving the raw data, just like in Worker.gather_dep(). Note that this introduces an unnecessary unpickle->pickle round trip, much like in a gather_dep->spill situation, but it should be trivially cheap as long as there are no object dtypes.
5. Write to disk
   - main: sequentially write the bytes objects to disk.
   - this PR: serialize the numpy arrays into pickle.PickleBuffer objects + metadata and then write the frames to disk. This is like the SpillBuffer, except that there are multiple objects per file.
6. Read from disk
   - main: read the bytes from disk, then unpickle the numpy arrays into the same objects created at step 1. Sixth deep copy.
   - this PR: create a memory-mapped memoryview and read only the metadata from disk. Deserialize everything; this causes a bunch of seek calls interleaved with short reads. The numpy buffers point to the disk; there's an open file descriptor.
7. Merge
   - main: the shards in memory are passed to concatenate3 and then released. Seventh and final deep copy.
   - this PR: the memory-mapped shards are passed to concatenate3, which transparently performs the necessary disk reads as it writes into the final location in memory. As soon as the shards are dereferenced, the memory-mapped file descriptor is closed.
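A minimal in-process sketch (not the PR's actual code) of the zero-copy serialization that steps 2-4 rely on: pickle protocol 5 hands the numpy buffer out of band instead of embedding it in the pickle stream, and unpickling rebuilds the array on top of whatever buffer it is given.

import pickle
import numpy as np

shard = np.arange(1_000_000, dtype="float64")

# With protocol 5 and a buffer_callback, the array payload is returned as a
# PickleBuffer viewing shard's memory rather than being copied into `header`.
buffers: list[pickle.PickleBuffer] = []
header = pickle.dumps(shard, protocol=5, buffer_callback=buffers.append)
assert np.shares_memory(np.frombuffer(buffers[0], dtype="float64"), shard)

# Unpickling against the supplied buffers rebuilds the array on top of that
# memory; over the wire this would be the receive buffer, not shard itself.
restored = pickle.loads(header, buffers=buffers)
assert np.shares_memory(restored, shard)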

@crusaderky crusaderky self-assigned this Oct 18, 2023
@@ -157,11 +160,21 @@ async def _process(self, id: str, shards: list[bytes]) -> None:
with self._directory_lock.read():
    if self._closed:
        raise RuntimeError("Already closed")
    with open(
        self.directory / str(id), mode="ab", buffering=100_000_000
@crusaderky (Collaborator, Author) commented Oct 18, 2023

This huge buffering setting was just wasting memory.

@fjetter (Member) commented Oct 18, 2023

Eradicating seven deep copies? Bold claim! Very exciting! Haven't reviewed the changes yet.

else:
    # Unserialized numpy arrays
    frames = concat(
        serialize_bytelist(shard, compression=False)
@crusaderky (Collaborator, Author) commented Oct 18, 2023

compression=False is crucial for memory-mapped reads later.
We could, however, consider enabling it at a later point without changing the design (but it will require offloading).
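To illustrate why this matters, here is a minimal, self-contained sketch (hypothetical file name and layout, not the PR's on-disk format): an uncompressed frame can be wrapped by numpy directly on top of a memory map, so data is paged in lazily by the kernel, whereas a compressed frame would first have to be decompressed into a fresh, fully materialized buffer.

import mmap
import numpy as np

with open("shard_frames.bin", "wb") as fh:
    fh.write(np.arange(1024, dtype="float64").tobytes())

with open("shard_frames.bin", "rb") as fh:
    buf = mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ)

# No payload is read here; the array is a zero-copy view over the mapped file.
arr = np.frombuffer(buf, dtype="float64")
assert arr[0] == 0.0  # touching the data triggers a page-in, not an explicit read()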

@github-actions (bot) commented Oct 18, 2023

Unit Test Results

See test report for an extended history of previous test failures. This is useful for diagnosing flaky tests.

27 files ±0    27 suites ±0    14h 26m 18s ⏱️ +1h 9m 48s
3 924 tests +5:    3 800 ✔️ +6    117 💤 ±0     6 ❌ −2    1 🔥 +1
49 365 runs +63:   46 944 ✔️ +60   2 411 💤 +3   9 ❌ −1    1 🔥 +1

For more details on these failures and errors, see this check.

Results for commit 1b26ec6. ± Comparison against base commit 010f896.

♻️ This comment has been updated with latest results.

@crusaderky crusaderky force-pushed the zero_copy_numpy_shuffle branch 7 times, most recently from 9c648b3 to f657c54 on October 21, 2023 01:10
@crusaderky (Collaborator, Author)

Functional bug fixed; moving to perf tests

@crusaderky (Collaborator, Author) commented Oct 22, 2023

I've got some early LocalCluster benchmarks and they look very exciting.
coiled-benchmark A/B tests will follow.

import pickle
import dask.array as da
import distributed
from distributed.diagnostics.memory_sampler import MemorySampler

client = distributed.Client(n_workers=4, threads_per_worker=2)

N = 250  # 29 GiB
old = ((50, 40, 35, 25, 45, 55),) * 4  # 3 to 180 MiB chunks
new = ((40, 35, 45, 70, 60),) * 4
a = da.random.random((N, N, N, N), chunks=old)
b = a.rechunk(chunks=new, method="p2p")
c = b.sum()

try:
    with open("ms.pickle", "rb") as fh:
        ms = pickle.load(fh)
except FileNotFoundError:
    ms = MemorySampler()
    
with ms.sample("main", interval=0.1):
    c.compute()

with open("ms.pickle", "wb") as fh:
    pickle.dump(ms, fh)

Legend

- main: the main branch.
- zero-copy mmap: what's described in the opening post. Data is never deep-copied, except when slicing by column and possibly when the kernel performs mmap caching.
- zero-copy no-mmap: variant that doesn't use memory-mapped I/O, replacing the wealth of seek() calls and short read() disk accesses at the kernel level with a single monolithic read(), like in the SpillBuffer, followed by a deep copy in concatenate3().
- copy-on-shard mmap: variant that reintroduces a single, explicit deep copy as soon as the shards are cut out of the original chunk, to allow releasing it faster.

[MemorySampler plot comparing main, zero-copy mmap, zero-copy no-mmap, and copy-on-shard mmap]

I don't quite understand why the copy-on-shard variant is faster than the zero-copy one. Unless pickle.dumps is much slower than ndarray.copy() at generating a non-view buffer for some reason...
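A quick way to test that hypothesis locally (a rough sketch, not part of the PR; timings will vary by machine):

import pickle
import timeit

import numpy as np

# A non-contiguous view, roughly like a shard cut out of a larger chunk.
shard = np.random.random((256, 256, 16))[:, :128]

t_copy = timeit.timeit(lambda: shard.copy(), number=100)
t_pickle = timeit.timeit(lambda: pickle.dumps(shard, protocol=5), number=100)
print(f"ndarray.copy(): {t_copy:.3f}s   pickle.dumps(protocol=5): {t_pickle:.3f}s")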

@crusaderky (Collaborator, Author)

The initial A/B test results leave me a bit miffed. They deliver the expected performance boost for trivial dataset sizes, but the benefit seems to disappear as the dataset grows. I'll need to investigate why. It doesn't make sense to me; p2p runtime should scale linearly with data size and p2p memory usage should be constant.


@hendrikmakait (Member) commented Oct 23, 2023

Looking at the metrics for test_swap_axes[1-128 MiB-p2p], we seem to be mostly IO-bound with the baseline already. So, I wouldn't expect these optimizations to have much effect on runtime there. Instead, I'd expect to see significant performance benefits for diskless rechunking using #8279.

@crusaderky (Collaborator, Author)

More tests highlight a substantial performance regression for small chunk sizes. Investigating.


@crusaderky (Collaborator, Author)

Reproduced locally. There's a leak; trying to hunt it down.


@hendrikmakait (Member)

FYI, this may or may not be related to a leak I've identified (but not yet root-caused) for many chunks with disk=False.

@crusaderky (Collaborator, Author)

The "leak" was just a warmup artifact. Not sure why the warmup is larger than in the main branch, but the delta is negligible IMHO.
Here's 10 sequential runs of swap_axes on 4 workers, 8 MiB chunks, without restarting the workers, after I set the deque lengths to zero. The effect is completely invisible on 128 MiB chunks.


@crusaderky (Collaborator, Author) commented Oct 23, 2023

The slowdown with 8 MiB chunks (which in turn generate 8 kiB shards) was caused by serialize_bytelist being 10x slower than pickle. I've written a stripped-down variant which is just as fast. This IMHO really highlights the need to drastically simplify the serialization stack.
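A rough way to reproduce the comparison behind that statement (a sketch; the exact ratio depends on hardware and library versions, and the import path assumes distributed as of this PR):

import pickle
import timeit

import numpy as np
from distributed.protocol.serialize import serialize_bytelist

shard = np.random.random(1024)  # ~8 kiB shard, as produced by 8 MiB chunks

t_sb = timeit.timeit(
    lambda: serialize_bytelist(shard, compression=False), number=10_000
)
t_pk = timeit.timeit(lambda: pickle.dumps(shard, protocol=5), number=10_000)
print(f"serialize_bytelist: {t_sb:.3f}s   pickle.dumps: {t_pk:.3f}s")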

A/B tests, including the new diskless p2p:

[A/B test result plots]

Detail on the two regressions shows that these use cases are, in fact, 20-25% slower, but the increased memory consumption is quite negligible:

[detail plots for the two regressions]

# Don't wait until all shards have been transferred over the network
# before data can be released
if shard.base is not None:
    shard = shard.copy()
@crusaderky (Collaborator, Author)

This is the explicit deep copy mentioned in step 3 of the opening post.

@crusaderky crusaderky marked this pull request as ready for review October 23, 2023 23:05
@crusaderky crusaderky changed the title [WIP] Zero-copy array shuffle Zero-copy array shuffle Oct 23, 2023
@fjetter (Member) commented Oct 24, 2023

First pass over this looks good. I'm curious about the regression, though.

IIUC we're now much faster for reasonably sized arrays but slower for arrays with very small chunks. In a perfect world, we would be agnostic to chunk sizes when using P2P (at least to first order, with smart buffering and such things).

Do you see a way to reduce that penalty in a follow up PR?

@hendrikmakait (Member) left a review comment

Initial pass looks good; I'll look at a few workloads up close before signing off.

Resolved (outdated) review threads on: distributed/shuffle/_buffer.py, distributed/shuffle/_rechunk.py, distributed/shuffle/_rechunk.py
@crusaderky (Collaborator, Author) commented Oct 24, 2023

> Do you see a way to reduce that penalty in a follow up PR?

I have not run a profile to understand where the slowdown is. I suspect it may be in the general serialization and networking stack, in which case it would be hard to fix (but very worthwhile, considering it would impact all network comms, not just shuffle).

@crusaderky (Collaborator, Author)

> I suspect that the zero-copy approach causes us to hold onto the buffers created by the network stack in step 4 and only frees them once all shards from that buffer have been freed.

Yes, that's exactly the intended behaviour. And that should happen immediately after concatenate3().
You have a 200 GiB array in memory during your map_blocks(sleep) - that's the output of concatenate3(). Not entirely sure why you expected to see less RAM usage?

@hendrikmakait (Member) commented Oct 25, 2023

So that we're on the same page, here's the exact code I execute:

import time
import dask
import dask.array as da

# size, input_chunks, output_chunks, rechunk_method, disk, and client come from the benchmark parametrization
def _(x):
    time.sleep(1)
    return x

arr = da.random.random(size, chunks=input_chunks)
with dask.config.set({"array.rechunk.method": rechunk_method, "distributed.p2p.disk": disk}):
    client.compute(arr.rechunk(output_chunks).map_blocks(_).sum(), sync=True)

The important bit is the sum() as an aggregation at the end.

Ideally, I'd expect the output chunks that get aggregated into the partial results of the sum to vanish and instantly free their memory. From what I see, that does not happen. Instead, memory is only freed once the last (few) output chunks have been processed. I realize this is a trade-off, but this new behavior may hold the outputs much longer in memory if there is a more involved processing/reduction chain in place than .sum().

This may be more desirable than the previous sluggish rechunking, but it's a significant change in the memory footprint.

@hendrikmakait (Member)

> Yes, that's exactly the intended behaviour. And that should happen immediately after concatenate3().
> You have a 200 GiB array in memory during your map_blocks(sleep) - that's the output of concatenate3(). Not entirely sure why you expected to see less RAM usage?

To reiterate: If the buffer contains all shards of a single input chunk that belong to the worker, this means that it will contain a shard for every output chunk on that worker. To free this buffer, we need to concatenate3 all output buffers.

@crusaderky (Collaborator, Author)

> I realize this is a trade-off, but this new behavior may hold the outputs much longer in memory if there is a more involved processing/reduction chain in place than .sum().

I understand now. This is definitely not OK. I'll figure out why some mmap'ed buffers are surviving concatenation.

@hendrikmakait (Member)

> I'll figure out why some mmap'ed buffers are surviving concatenation.

IIUC, this should have nothing to do with mmap. This occurs when I use in-memory buffering, i.e., disk=False.

@crusaderky (Collaborator, Author)

> To reiterate: If the buffer contains all shards of a single input chunk that belong to the worker, this means that it will contain a shard for every output chunk on that worker. To free this buffer, we need to concatenate3 all output buffers.

It should not be like that.

The shards in the send phase are deep-copied immediately after sharding, so they don't hold a reference to the buffer of the original chunk. FWIW the initial version, without the explicit deep copy, was just mildly more memory intensive and slower than the one with the deep copy.

The shards in the receive phase share a single buffer per RPC call. This is a known issue (can't find the ticket). So the memory for all the shards in an RPC call will be released, all at once, when the last of them is written to disk and dereferenced.

The shards in the read-from-disk phase that come from the same file share the same buffer, so the buffer will be released when all the shards are released. However, if I understand correctly, there's one file per output chunk, so all shards in a file should contribute to the same concatenate3 call and be released all at once afterwards?
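A small, self-contained illustration of the sharing described above (toy sizes, not PR code): every view built on a shared receive buffer keeps the whole buffer alive until the last view is dereferenced.

import numpy as np

# Stand-in for the host buffer allocated for one RPC call by the network stack.
host_buffer = np.empty(4 * 1024 * 1024, dtype="uint8")
shards = [
    np.frombuffer(host_buffer, dtype="float64", count=1024, offset=i * 8192)
    for i in range(4)
]
assert all(s.base is not None for s in shards)  # every shard is a view

del host_buffer  # the 4 MiB allocation survives: the views keep it alive
shards.pop()     # dropping one shard frees nothing; only dropping all of them does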

@crusaderky (Collaborator, Author)

> > I'll figure out why some mmap'ed buffers are surviving concatenation.
>
> IIUC, this should have nothing to do with mmap. This occurs when I use in-memory buffering, i.e., disk=False.

Ah, this makes sense. You're retaining the per-RPC-call aggregation of the buffers until all concatenate calls have completed. I need to introduce an explicit deep copy specifically for the in-memory (disk=False) path.

@fjetter (Member) commented Oct 26, 2023

I'm not an expert here, but regarding the above "memory leak": maybe it would be good and easy to just copy the output array before we return it in get_output_partition, instead of doing something super smart. This way we preserve internal zero-copy and have a very unsophisticated copy at the outer layer. At the same time, this would allow us to ensure the array is a contiguous memory region, which has benefits for other kinds of operations that happen afterwards.
As I said, I'm not an expert here, so if this doesn't make sense, I'm happy to be overruled.
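For concreteness, a hedged sketch of that suggestion (hypothetical helper name and wiring; concatenate3 is the real dask helper, but this is not the PR's get_output_partition implementation): make one unsophisticated copy at the outer layer if the merged result still references shared buffers, so nothing upstream leaks into the task graph and the output is contiguous.

import numpy as np
from dask.array.core import concatenate3  # helper used by the rechunk code


def finalize_output_partition(nested_shards) -> np.ndarray:
    out = concatenate3(nested_shards)
    if out.base is not None:
        # Still a view over shared shard buffers: one explicit, C-contiguous copy
        # releases them before the chunk is handed to downstream tasks.
        out = out.copy()
    return out


chunk = finalize_output_partition([[np.ones((2, 2)), np.zeros((2, 2))]])
assert chunk.shape == (2, 4) and chunk.base is None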

Comment on lines +1198 to +1192
[run] = a.extensions["shuffle"].shuffle_runs._runs
shards = [
    s3 for s1 in run._disk_buffer._shards.values() for s2 in s1 for _, s3 in s2
]
@crusaderky (Collaborator, Author)

@hendrikmakait this is ugly; could you think of a nicer way to get the shards in the MemoryBuffer?

@hendrikmakait (Member)

There isn't really an API to get all the shards from the buffer as it wasn't designed for that.

@crusaderky (Collaborator, Author) commented Oct 29, 2023

Workload from #8282 (comment)
10 workers, p2p rechunk in memory, map(sleep 1) before reduce

#8282 + #8308 vs. main
🔥 🔥 🔥 🔥

[benchmark plot: #8282 + #8308 vs. main]

@crusaderky (Collaborator, Author)

#8282 + #8308

Moving on to investigate the regression use case


@crusaderky (Collaborator, Author) commented Oct 30, 2023

I rebased this PR; it's now blocked by and incorporates #8308 and #8318.

@crusaderky (Collaborator, Author)

> Moving on to investigate the regression use case

This is fixed in #8321.

@hendrikmakait (Member) left a review comment

The improvements look great! Thanks, @crusaderky, for diving deeper into this and giving rechunking a thorough treatment. There are some bits and pieces in the code that feel less clean, but I'm also refactoring the buffers, so I don't expect those to stick around; they'll become part of the refactored version.


frames: Iterable[bytes | bytearray | memoryview]

if not shards or isinstance(shards[0], bytes):
@hendrikmakait (Member)

Do we need to handle not shards here? IIUC, we should not attempt to write an empty list.

@crusaderky (Collaborator, Author)

Happy to change it to assert shards and see if anything breaks

@hendrikmakait (Member)

Let's give that a shot; if this breaks any tests, that would mean something doesn't work as (I) expected.


@crusaderky (Collaborator, Author)

> There are some bits and pieces in the code that feel less clean

Fully agree. I particularly dislike that half of the (de)serialization logic lives in _disk.py and the other half in _rechunk.py.

@hendrikmakait hendrikmakait merged commit a387b3b into dask:main Nov 7, 2023
24 of 34 checks passed
@crusaderky crusaderky deleted the zero_copy_numpy_shuffle branch November 7, 2023 16:41