From 05c2d3414127849f8cdd0df852291369693480a0 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Thu, 27 Oct 2022 12:47:12 -0700 Subject: [PATCH] [Datasets] Change `map_batches` to fetch input blocks on-demand (#29289) Signed-off-by: Cheng Su scnju13@gmail.com This is the fix the issue we found during AIR benchmark. When the map_batches have multiple input blocks (it can happen when dynamic block splitting is enabled by default, or multiple input blocks are coalesced together), previously we always fetch and buffer all input blocks before producing first batch. This is bad especially for dynamic block splitting, because it essentially buffers all split blocks again in memory. So in this PR, change map_batches to fetch and buffer input blocks on-demand, i.e. only fetch blocks when needed to construct the next required batch. Signed-off-by: Weichen Xu --- python/ray/data/dataset.py | 21 ++++-- .../data/tests/test_dynamic_block_split.py | 72 +++++++++++++++++++ 2 files changed, 88 insertions(+), 5 deletions(-) create mode 100644 python/ray/data/tests/test_dynamic_block_split.py diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 4a756de94d65..c4562cd46d41 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -536,11 +536,8 @@ def transform( output_buffer = BlockOutputBuffer(None, context.target_max_block_size) # Ensure that zero-copy batch views are copied so mutating UDFs don't error. batcher = Batcher(batch_size, ensure_copy=batch_size is not None) - for block in blocks: - batcher.add(block) - batcher.done_adding() - while batcher.has_any(): - batch = batcher.next_batch() + + def process_next_batch(batch: Block) -> Iterator[Block]: # Convert to batch format. batch = BlockAccessor.for_block(batch).to_batch_format(batch_format) # Apply UDF. @@ -566,6 +563,20 @@ def transform( if output_buffer.has_next(): yield output_buffer.next() + # Process batches for each block. + for block in blocks: + batcher.add(block) + while batcher.has_batch(): + batch = batcher.next_batch() + yield from process_next_batch(batch) + + # Process any last remainder batch. + batcher.done_adding() + if batcher.has_any(): + batch = batcher.next_batch() + yield from process_next_batch(batch) + + # Yield remainder block from output buffer. output_buffer.finalize() if output_buffer.has_next(): yield output_buffer.next() diff --git a/python/ray/data/tests/test_dynamic_block_split.py b/python/ray/data/tests/test_dynamic_block_split.py new file mode 100644 index 000000000000..fc0ed9889f24 --- /dev/null +++ b/python/ray/data/tests/test_dynamic_block_split.py @@ -0,0 +1,72 @@ +import numpy as np +import pandas as pd +import pytest + +import ray +from ray.data.block import BlockMetadata +from ray.data.datasource import Datasource +from ray.data.datasource.datasource import ReadTask, Reader + +from ray.tests.conftest import * # noqa + + +def test_read_large_data(ray_start_cluster): + # Test 20G input with single task + num_batch = 20 + ctx = ray.data.context.DatasetContext.get_current() + block_splitting_enabled = ctx.block_splitting_enabled + ctx.block_splitting_enabled = True + + try: + cluster = ray_start_cluster + cluster.add_node(num_cpus=1) + + ray.init(cluster.address) + + # Data source generates multiple 1G random bytes data + class LargeBytesDatasource(Datasource): + def create_reader(self, **read_args): + return LargeBytesReader() + + class LargeBytesReader(Reader): + def estimate_inmemory_data_size(self): + return None + + def get_read_tasks(self, parallelism: int): + def _1g_batches_generator(): + for _ in range(num_batch): + yield pd.DataFrame( + {"one": [np.random.bytes(1024 * 1024 * 1024)]} + ) + + return parallelism * [ + ReadTask( + lambda: _1g_batches_generator(), + BlockMetadata( + num_rows=None, + size_bytes=None, + schema=None, + input_files=None, + exec_stats=None, + ), + ) + ] + + def foo(batch): + return pd.DataFrame({"one": [1]}) + + ds = ray.data.read_datasource( + LargeBytesDatasource(), + parallelism=1, + ) + + ds = ds.map_batches(foo, batch_size=None) + assert ds.count() == num_batch + finally: + ctx.block_splitting_enabled = block_splitting_enabled + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", __file__]))