Skip to content

Commit

Permalink
[Datasets] Change map_batches to fetch input blocks on-demand (ray-…
Browse files Browse the repository at this point in the history
…project#29289)

Signed-off-by: Cheng Su [email protected]

This fixes an issue we found during the AIR benchmark. When map_batches has multiple input blocks (which can happen when dynamic block splitting is enabled by default, or when multiple input blocks are coalesced together), we previously always fetched and buffered all input blocks before producing the first batch. This is especially bad for dynamic block splitting, because it essentially buffers all of the split blocks in memory again. This PR therefore changes map_batches to fetch and buffer input blocks on demand, i.e. to fetch blocks only when they are needed to construct the next required batch.

Signed-off-by: Weichen Xu <[email protected]>
  • Loading branch information
c21 authored and WeichenXu123 committed Dec 19, 2022
1 parent cbce66d commit 05c2d34
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 5 deletions.
21 changes: 16 additions & 5 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,11 +536,8 @@ def transform(
output_buffer = BlockOutputBuffer(None, context.target_max_block_size)
# Ensure that zero-copy batch views are copied so mutating UDFs don't error.
batcher = Batcher(batch_size, ensure_copy=batch_size is not None)
for block in blocks:
batcher.add(block)
batcher.done_adding()
while batcher.has_any():
batch = batcher.next_batch()

def process_next_batch(batch: Block) -> Iterator[Block]:
# Convert to batch format.
batch = BlockAccessor.for_block(batch).to_batch_format(batch_format)
# Apply UDF.
Expand All @@ -566,6 +563,20 @@ def transform(
if output_buffer.has_next():
yield output_buffer.next()

# Process batches for each block.
for block in blocks:
batcher.add(block)
while batcher.has_batch():
batch = batcher.next_batch()
yield from process_next_batch(batch)

# Process any last remainder batch.
batcher.done_adding()
if batcher.has_any():
batch = batcher.next_batch()
yield from process_next_batch(batch)

# Yield remainder block from output buffer.
output_buffer.finalize()
if output_buffer.has_next():
yield output_buffer.next()
Expand Down
72 changes: 72 additions & 0 deletions python/ray/data/tests/test_dynamic_block_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import numpy as np
import pandas as pd
import pytest

import ray
from ray.data.block import BlockMetadata
from ray.data.datasource import Datasource
from ray.data.datasource.datasource import ReadTask, Reader

from ray.tests.conftest import * # noqa


def test_read_large_data(ray_start_cluster):
    """Read 20 one-row blocks of 1 GiB each from a single read task and map them.

    Block splitting is switched on for the duration of the test; the setting
    that was active beforehand is restored in the ``finally`` clause.
    """
    # 20 blocks x 1 GiB = 20G of input, all coming from one read task.
    num_batch = 20
    one_gib = 1024 * 1024 * 1024

    dataset_context = ray.data.context.DatasetContext.get_current()
    previous_setting = dataset_context.block_splitting_enabled
    dataset_context.block_splitting_enabled = True

    try:
        cluster = ray_start_cluster
        cluster.add_node(num_cpus=1)
        ray.init(cluster.address)

        class LargeBytesReader(Reader):
            """Reader whose read tasks each yield ``num_batch`` 1 GiB blocks."""

            def estimate_inmemory_data_size(self):
                # No size estimate is provided.
                return None

            def get_read_tasks(self, parallelism: int):
                def _generate_blocks():
                    # The payload is built inline in the yielded frame so the
                    # generator itself keeps no extra reference to it.
                    for _ in range(num_batch):
                        yield pd.DataFrame({"one": [np.random.bytes(one_gib)]})

                # Every metadata field is left unset.
                unknown_metadata = BlockMetadata(
                    num_rows=None,
                    size_bytes=None,
                    schema=None,
                    input_files=None,
                    exec_stats=None,
                )
                read_task = ReadTask(_generate_blocks, unknown_metadata)
                return [read_task] * parallelism

        class LargeBytesDatasource(Datasource):
            """Datasource that hands out a ``LargeBytesReader``."""

            def create_reader(self, **read_args):
                return LargeBytesReader()

        def shrink_to_single_row(batch):
            # Ignore the large input batch and emit a tiny one-row frame.
            return pd.DataFrame({"one": [1]})

        ds = ray.data.read_datasource(
            LargeBytesDatasource(),
            parallelism=1,
        ).map_batches(shrink_to_single_row, batch_size=None)

        # One output row per input block.
        assert ds.count() == num_batch
    finally:
        dataset_context.block_splitting_enabled = previous_setting


if __name__ == "__main__":
    # Allow running this test module directly: hand the file to pytest in
    # verbose mode and propagate its return value as the process exit status.
    import sys

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)

0 comments on commit 05c2d34

Please sign in to comment.