[Data] Async batch fetching for map_batches #31576

Merged on Jan 21, 2023 (23 commits).
59 changes: 58 additions & 1 deletion python/ray/data/_internal/block_batching.py
@@ -1,7 +1,9 @@
import collections
import itertools
import queue
import sys
from typing import Iterator, Optional, Union
import threading
from typing import Iterator, Optional, TypeVar, Union

import ray
from ray.actor import ActorHandle
@@ -13,6 +15,7 @@
from ray.types import ObjectRef
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

T = TypeVar("T")

if sys.version_info >= (3, 7):
from contextlib import nullcontext
@@ -39,6 +42,7 @@ def batch_block_refs(
shuffle_buffer_min_size: Optional[int] = None,
shuffle_seed: Optional[int] = None,
ensure_copy: bool = False,
prefetch_batches: int = 0,
) -> Iterator[DataBatch]:
"""Create formatted batches of data from 1 or more block object references.

@@ -71,6 +75,12 @@
shuffle_seed: The seed to use for the local random shuffle.
ensure_copy: Whether batches are always copied from the underlying base
blocks (not zero-copy views).
prefetch_batches: The number of batches to fetch ahead of the current batch to
process. If set to greater than 0, a separate thread will be used to fetch
the specified number of formatted batches from blocks. This improves
performance for non-CPU-bound UDFs, allowing batch fetching and formatting
to be overlapped with the UDF. Defaults to 0 (no prefetching enabled).

Returns:
An iterator over record batches.
@@ -107,6 +117,7 @@ def batch_block_refs(
shuffle_buffer_min_size=shuffle_buffer_min_size,
shuffle_seed=shuffle_seed,
ensure_copy=ensure_copy,
prefetch_batches=prefetch_batches,
)


@@ -120,6 +131,7 @@ def batch_blocks(
shuffle_buffer_min_size: Optional[int] = None,
shuffle_seed: Optional[int] = None,
ensure_copy: bool = False,
prefetch_batches: int = 0,
) -> Iterator[DataBatch]:
"""Create formatted batches of data from 1 or more blocks.

@@ -142,12 +154,57 @@
stats=stats,
)

    if prefetch_batches > 0:
        batch_iter = _make_async_gen(
            batch_iter, prefetch_buffer_size=prefetch_batches
        )

    for formatted_batch in batch_iter:
        user_timer = stats.iter_user_s.timer() if stats else nullcontext()
        with user_timer:
            yield formatted_batch


def _make_async_gen(
    base_iterator: Iterator[T], prefetch_buffer_size: int = 1
) -> Iterator[T]:
    """Returns a new iterator with elements fetched from the base_iterator
    in an async fashion using a background thread.

    Args:
        base_iterator: The iterator to asynchronously fetch from.
        prefetch_buffer_size: The maximum number of items to prefetch. Increasing
            the size allows for more computation overlap for very expensive
            downstream UDFs. However, it comes at the cost of additional memory
            overhead. Defaults to 1.

    Returns:
        An iterator with the same elements as the base_iterator.
    """

    fetch_queue = queue.Queue(maxsize=prefetch_buffer_size)

    sentinel = object()

    def _async_fetch():
        for item in base_iterator:
            fetch_queue.put(item, block=True)

        # Indicate done adding items.
        fetch_queue.put(sentinel, block=True)

    fetch_thread = threading.Thread(target=_async_fetch)
    fetch_thread.start()

    while True:
        next_item = fetch_queue.get(block=True)
        if next_item is not sentinel:
            yield next_item
        fetch_queue.task_done()
        if next_item is sentinel:
            break

    fetch_queue.join()
    fetch_thread.join()


def _resolve_blocks(
block_ref_iter: Iterator[ObjectRef[Block]],
stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None,
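For reference, a minimal usage sketch of the new _make_async_gen helper added above (not part of the diff; slow_fetch and sleep_udf are illustrative stand-ins for batch formatting and a user UDF):

import time

from ray.data._internal.block_batching import _make_async_gen


def slow_fetch():
    # Stand-in for batch fetching/formatting work (e.g. block -> batch conversion).
    for i in range(5):
        time.sleep(0.5)
        yield i


def sleep_udf(item):
    # Stand-in for a non-CPU-bound user UDF.
    time.sleep(0.5)
    return item


# The background fetch thread produces item i+1 while sleep_udf(item i) runs,
# so the loop takes roughly 5 * 0.5s instead of 5 * 1.0s.
results = [sleep_udf(x) for x in _make_async_gen(slow_fetch(), prefetch_buffer_size=1)]
assert results == list(range(5))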
34 changes: 24 additions & 10 deletions python/ray/data/_internal/compute.py
@@ -91,7 +91,7 @@ def _apply(
# Bin blocks by target block size.
if target_block_size is not None:
_check_batch_size(blocks, target_block_size, name)
block_bundles = _bundle_blocks_up_to_size(blocks, target_block_size, name)
block_bundles = _bundle_blocks_up_to_size(blocks, target_block_size)
else:
block_bundles = [((b,), (m,)) for b, m in blocks]
del blocks
@@ -254,15 +254,30 @@ def _apply(

if name is None:
name = "map"
blocks_in = block_list.get_blocks_with_metadata()
# Bin blocks by target block size.
blocks_in: List[
Tuple[ObjectRef[Block], BlockMetadata]
] = block_list.get_blocks_with_metadata()

# We bundle blocks according to the following rules:
# 1. Attempt to bundle up to the target block size.
# 2. If the max concurrency of the ActorPool is set, then
# cap the number of bundles to match the size of the ActorPool.
# This avoids additional overhead in submitting new actor tasks and allows
# the actor task to do optimizations such as batch prefetching.
target_num_bundles = float("inf")
if target_block_size is not None:
_check_batch_size(blocks_in, target_block_size, name)
block_bundles = _bundle_blocks_up_to_size(
blocks_in, target_block_size, name
)
total_size = sum(metadata.num_rows for _, metadata in blocks_in)
Contributor:
metadata.num_rows could technically be None, but shouldn't happen in practice.

Contributor Author:
updated to handle None the same way as _bundle_blocks_up_to_size
target_num_bundles = min(target_num_bundles, total_size / target_block_size)
target_num_bundles = min(target_num_bundles, self.max_size)
if not math.isinf(target_num_bundles):
target_block_size = total_size // target_num_bundles
block_bundles: List[
Tuple[Tuple[ObjectRef[Block]], Tuple[BlockMetadata]]
] = _bundle_blocks_up_to_size(blocks_in, target_block_size)
else:
block_bundles = [((b,), (m,)) for b, m in blocks_in]

del blocks_in
owned_by_consumer = block_list._owned_by_consumer

@@ -502,13 +517,12 @@ def _map_block_nosplit(
def _bundle_blocks_up_to_size(
blocks: List[Tuple[ObjectRef[Block], BlockMetadata]],
target_size: int,
name: str,
) -> List[Tuple[List[ObjectRef[Block]], List[BlockMetadata]]]:
) -> List[Tuple[Tuple[ObjectRef[Block]], Tuple[BlockMetadata]]]:
"""Group blocks into bundles that are up to (but not exceeding) the provided target
size.
"""
block_bundles = []
curr_bundle = []
block_bundles: List[List[Tuple[ObjectRef[Block], BlockMetadata]]] = []
curr_bundle: List[Tuple[ObjectRef[Block], BlockMetadata]] = []
curr_bundle_size = 0
for b, m in blocks:
num_rows = m.num_rows
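To illustrate the bundling rules described in the comment added to _apply above (bundle blocks up to the target block size, but never create more bundles than the actor pool can run concurrently), here is a small standalone sketch of the arithmetic; the function name and variables are illustrative, not the actual implementation:

import math


def effective_target_block_size(
    total_rows: int, target_block_size: int, pool_max_size: float
) -> int:
    # Rule 1: bundling to the target block size would yield this many bundles.
    target_num_bundles = min(float("inf"), total_rows / target_block_size)
    # Rule 2: cap the bundle count at the actor pool size.
    target_num_bundles = min(target_num_bundles, pool_max_size)
    if not math.isinf(target_num_bundles):
        # Grow the per-bundle size so the bundle count matches the cap.
        return total_rows // int(target_num_bundles)
    return target_block_size


# Example: 10,000 rows with target block size 100 would give 100 bundles, but an
# actor pool capped at 8 workers yields 8 bundles of ~1,250 rows each instead.
assert effective_target_block_size(10_000, 100, 8) == 1250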
5 changes: 5 additions & 0 deletions python/ray/data/dataset.py
@@ -336,6 +336,7 @@ def map_batches(
batch_size: Optional[Union[int, Literal["default"]]] = "default",
compute: Optional[Union[str, ComputeStrategy]] = None,
batch_format: Literal["default", "pandas", "pyarrow", "numpy"] = "default",
prefetch_batches: int = 0,
zero_copy_batch: bool = False,
fn_args: Optional[Iterable[Any]] = None,
fn_kwargs: Optional[Dict[str, Any]] = None,
@@ -480,6 +481,9 @@
``pandas.DataFrame``, "pyarrow" to select ``pyarrow.Table``, or
``"numpy"`` to select ``numpy.ndarray`` for tensor datasets and
``Dict[str, numpy.ndarray]`` for tabular datasets. Default is "default".
prefetch_batches: The number of batches to fetch ahead of the current batch
to process. If set to greater than 0, a separate thread will be used
to fetch the specified number of formatted batches from blocks. This improves
performance for non-CPU-bound UDFs, allowing batch fetching and formatting
to be overlapped with the UDF. Defaults to 0 (no prefetching enabled).
Increasing the number of batches to prefetch can result in higher throughput,
at the expense of requiring more heap memory to buffer the batches.

Contributor:
When porting this to the new executor, we should try to consolidate prefetch_batches and prefetch_blocks into a single prefetch_batches argument, where we always prefetch enough blocks to satisfy prefetch_batches, which should be simple enough to implement since we have the size for each to-be-fetched block on hand.

Contributor Author:
yep +1!
zero_copy_batch: Whether ``fn`` should be provided zero-copy, read-only
batches. If this is ``True`` and no copy is required for the
``batch_format`` conversion, the batch will be a zero-copy, read-only
@@ -633,6 +637,7 @@ def process_next_batch(batch: DataBatch) -> Iterator[Block]:
batch_size=batch_size,
batch_format=batch_format,
ensure_copy=not zero_copy_batch and batch_size is not None,
prefetch_batches=prefetch_batches,
)

for batch in formatted_batch_iter:
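For context, a small usage sketch of the new argument on the public map_batches API (the dataset and UDF here are illustrative; the benefit shows up when the UDF is not purely CPU-bound, e.g. GPU inference):

import ray

ds = ray.data.range(10_000)


def add_one(batch):
    # Imagine this UDF spends part of its time waiting (e.g. on a GPU), so
    # fetching and formatting the next batch can overlap with it.
    return [x + 1 for x in batch]


# Prefetch up to 2 formatted batches on a background thread while the UDF runs.
ds = ds.map_batches(add_one, batch_size=256, prefetch_batches=2)
print(ds.take(3))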
2 changes: 2 additions & 0 deletions python/ray/data/dataset_pipeline.py
@@ -788,6 +788,7 @@ def map_batches(
batch_size: Optional[Union[int, Literal["default"]]] = "default",
compute: Optional[Union[str, ComputeStrategy]] = None,
batch_format: Literal["default", "pandas", "pyarrow", "numpy"] = "default",
prefetch_batches: int = 0,
fn_args: Optional[Iterable[Any]] = None,
fn_kwargs: Optional[Dict[str, Any]] = None,
fn_constructor_args: Optional[Iterable[Any]] = None,
@@ -802,6 +803,7 @@
batch_size=batch_size,
compute=compute,
batch_format=batch_format,
prefetch_batches=prefetch_batches,
fn_args=fn_args,
fn_kwargs=fn_kwargs,
fn_constructor_args=fn_constructor_args,
99 changes: 99 additions & 0 deletions python/ray/data/tests/test_block_batching.py
@@ -1,4 +1,5 @@
import pytest
import time
from typing import List
from unittest import mock

@@ -14,6 +15,7 @@
_prefetch_blocks,
_blocks_to_batches,
_format_batches,
_make_async_gen,
)


@@ -121,6 +123,103 @@ def test_format_batches(batch_format):
assert isinstance(batch["foo"], np.ndarray)


def test_make_async_gen():
    """Tests that make_async_gen overlaps compute."""

    num_items = 10

    def gen():
        for i in range(num_items):
            time.sleep(2)
            yield i

    def sleep_udf(item):
        time.sleep(3)
        return item

    iterator = _make_async_gen(gen())

    start_time = time.time()
    outputs = []
    for item in iterator:
        outputs.append(sleep_udf(item))
    end_time = time.time()

    assert outputs == list(range(num_items))

    assert end_time - start_time < num_items * 3 + 3


def test_make_async_gen_buffer_size():
    """Tests that multiple items can be prefetched at a time
    with larger buffer size."""

    num_items = 5

    def gen():
        for i in range(num_items):
            time.sleep(1)
            yield i

    def sleep_udf(item):
        time.sleep(5)
        return item

    iterator = _make_async_gen(gen(), prefetch_buffer_size=4)

    start_time = time.time()

    # Only sleep for first item.
    sleep_udf(next(iterator))

    # All subsequent items should already be prefetched and should be ready.
    for _ in iterator:
        pass
    end_time = time.time()

    # 1 second for first item, 5 seconds for udf, 0.5 seconds buffer
    assert end_time - start_time < 6.5


# Test for 3 cases
# 1. Batch size is less than block size
# 2. Batch size is more than block size
# 3. Block size is not divisible by batch size
@pytest.mark.parametrize("batch_size", [4, 10, 7])
def test_async_batch_fetching(batch_size):
    blocks = block_generator(num_blocks=5, num_rows=8)

    def sleep_batch_format(batch_iter, *args, **kwargs):
        for batch in batch_iter:
            time.sleep(2)
            yield batch

    with mock.patch(
        "ray.data._internal.block_batching._format_batches", sleep_batch_format
    ):
        batch_iter = batch_blocks(
            batch_size=batch_size, blocks=blocks, prefetch_batches=1
        )
        outputs = []
        start_time = time.time()
        for batch in batch_iter:
            time.sleep(3)
            outputs.append(batch)
        end_time = time.time()

    total_time = end_time - start_time
    # Total time should be based on number of times the udf is called
    # (which is equal to len(outputs)).
    # The 2 seconds sleep in sleep_batch_format is overlapped, so does not count
    # towards total time.
    assert total_time < len(outputs) * 3 + 3

    # There should be no dropped rows.
    assert sum(len(output_batch) for output_batch in outputs) == 40, sum(
        len(output_batch) for output_batch in outputs
    )  # 5 blocks with 8 rows each.


if __name__ == "__main__":
import sys

1 change: 1 addition & 0 deletions python/ray/train/batch_predictor.py
@@ -306,6 +306,7 @@ def __call__(self, input_batch: DataBatchType) -> DataBatchType:
if self.get_preprocessor() is not None
else predict_stage_batch_format,
batch_size=batch_size,
prefetch_batches=int(num_gpus_per_worker > 0),
**ray_remote_args,
)

@@ -60,6 +60,8 @@ def to_tensor(batch: np.ndarray) -> torch.Tensor:
predictor.predict(
dataset,
num_gpus_per_worker=int(not smoke_test),
min_scoring_workers=1,
max_scoring_workers=int(ray.cluster_resources()["GPU"]),
batch_size=512,
)
total_time_s = round(time.time() - start, 2)
2 changes: 2 additions & 0 deletions release/release_tests.yaml
@@ -327,6 +327,8 @@
script: python workloads/gpu_batch_prediction.py --data-size-gb 100
type: job

wait_for_nodes:
num_nodes: 4

alert: default
