[data] [3/3] [no_early_kickoff] Async iter batches e2e (ray-project#33620)

Final PR for async iter_batches.

The new codepath is enabled for streaming execution. The old codepath is still accessible via a feature flag in DatasetContext.

Bulk execution still uses the old codepath by default.
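
For reference, a minimal sketch of switching back to the old codepath via DatasetContext; the exact flag name (use_legacy_iter_batches below) is an assumption and is not confirmed by this diff:

# Sketch only: `use_legacy_iter_batches` is an assumed flag name on DatasetContext.
import ray
from ray.data.context import DatasetContext

ctx = DatasetContext.get_current()
ctx.use_legacy_iter_batches = True  # opt back into the old iter_batches codepath

ds = ray.data.range(1_000)
for batch in ds.iter_batches(batch_size=128):
    pass  # batches now come from the legacy implementation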

This also deprecates prefetch_batches on map_batches, since that feature is half-baked and not fully supported.
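
With that deprecation, prefetching is configured on the consuming side rather than on map_batches. A hedged usage sketch (argument values are illustrative):

# Sketch: configure prefetching on iter_batches rather than map_batches.
import ray

ds = ray.data.range(10_000).map_batches(lambda batch: batch, batch_size=256)

# Prefetch up to 4 formatted batches ahead of the consuming loop.
for batch in ds.iter_batches(batch_size=256, prefetch_batches=4):
    pass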

---------

Signed-off-by: amogkam <[email protected]>
Signed-off-by: Jack He <[email protected]>
amogkam authored and ProjectsByJackHe committed May 4, 2023
1 parent f67b8fc commit 4394179
Showing 31 changed files with 813 additions and 257 deletions.
2 changes: 1 addition & 1 deletion doc/source/ray-air/doc_code/air_ingest.py
@@ -27,7 +27,7 @@
datasets={"train": dataset},
preprocessor=preprocessor,
num_epochs=1, # Stop after this number of epochs is read.
prefetch_blocks=1, # Number of blocks to prefetch when reading data.
prefetch_batches=1, # Number of batches to prefetch when reading data.
batch_size=None, # Use whole blocks as batches.
)
trainer.fit()
23 changes: 15 additions & 8 deletions python/ray/air/util/check_ingest.py
@@ -29,7 +29,7 @@ class DummyTrainer(DataParallelTrainer):
scaling_config: Configuration for how to scale training. This is the same
as for :class:`~ray.train.base_trainer.BaseTrainer`.
num_epochs: How many times to iterate through the datasets for.
prefetch_blocks: The number of blocks to prefetch ahead of the
prefetch_batches: The number of batches to prefetch ahead of the
current block during the scan. This is the same as
:meth:`~ray.data.dataset.Dataset.iter_batches`
time_preprocessing_separately: Whether to time the preprocessing separately
@@ -44,16 +44,18 @@ def __init__(
*args,
scaling_config: Optional[ScalingConfig] = None,
num_epochs: int = 1,
prefetch_blocks: int = 1,
prefetch_batches: int = 1,
batch_size: Optional[int] = 4096,
time_preprocessing_separately: bool = False,
# Deprecated.
prefetch_blocks: int = 0,
**kwargs,
):
if not scaling_config:
scaling_config = ScalingConfig(num_workers=1)
super().__init__(
train_loop_per_worker=DummyTrainer.make_train_loop(
num_epochs, prefetch_blocks, batch_size
num_epochs, prefetch_batches, prefetch_blocks, batch_size
),
*args,
scaling_config=scaling_config,
@@ -81,7 +83,10 @@ def preprocess_datasets(self):

@staticmethod
def make_train_loop(
num_epochs: int, prefetch_blocks: int, batch_size: Optional[int]
num_epochs: int,
prefetch_batches: int,
prefetch_blocks: int,
batch_size: Optional[int],
):
"""Make a debug train loop that runs for the given amount of epochs."""

@@ -99,7 +104,9 @@ def train_loop_per_worker():
epochs_read += 1
batch_start = time.perf_counter()
for batch in data_shard.iter_batches(
prefetch_blocks=prefetch_blocks, batch_size=batch_size
prefetch_batches=prefetch_batches,
prefetch_blocks=prefetch_blocks,
batch_size=batch_size,
):
batch_delay = time.perf_counter() - batch_start
batch_delays.append(batch_delay)
@@ -189,11 +196,11 @@ def make_local_dataset_iterator(
"--num-epochs", "-e", type=int, default=1, help="Number of epochs to read."
)
parser.add_argument(
"--prefetch-blocks",
"--prefetch-batches",
"-b",
type=int,
default=1,
help="Number of blocks to prefetch when reading data.",
help="Number of batches to prefetch when reading data.",
)

args = parser.parse_args()
@@ -215,7 +222,7 @@ def make_local_dataset_iterator(
datasets={"train": dataset},
preprocessor=preprocessor,
num_epochs=args.num_epochs,
prefetch_blocks=args.prefetch_blocks,
prefetch_batches=args.prefetch_batches,
dataset_config={"train": DatasetConfig()},
batch_size=None,
)
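
The DummyTrainer diff above keeps the old prefetch_blocks keyword as a deprecated argument alongside the new prefetch_batches. A hypothetical shim of that shape (not the code from this PR) could look like:

import warnings

def _resolve_prefetch(prefetch_batches: int = 1, prefetch_blocks: int = 0) -> int:
    # Hypothetical helper, not from this PR: surface the deprecation and
    # keep only the new setting.
    if prefetch_blocks != 0:
        warnings.warn(
            "`prefetch_blocks` is deprecated; use `prefetch_batches` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
    return prefetch_batches
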
33 changes: 13 additions & 20 deletions python/ray/data/_internal/block_batching/block_batching.py
@@ -11,10 +11,10 @@
format_batches,
collate,
extract_data_from_batch,
make_async_gen,
WaitBlockPrefetcher,
ActorBlockPrefetcher,
)
from ray.data._internal.memory_tracing import trace_deallocation
from ray.data._internal.stats import DatasetPipelineStats, DatasetStats
from ray.data.block import Block, DataBatch
from ray.data.context import DatasetContext
@@ -45,7 +45,6 @@ def batch_block_refs(
shuffle_buffer_min_size: Optional[int] = None,
shuffle_seed: Optional[int] = None,
ensure_copy: bool = False,
prefetch_batches: int = 0,
) -> Iterator[DataBatch]:
"""Create formatted batches of data from 1 or more block object references.
@@ -79,17 +78,12 @@
shuffle_seed: The seed to use for the local random shuffle.
ensure_copy: Whether batches are always copied from the underlying base
blocks (not zero-copy views).
prefetch_batches: The number of batches to fetch ahead of the current batch to
process. If set to greater than 0, a separate thread will be used to fetch
the specified amount of formatted batches from blocks. This improves
performance for non-CPU bound UDFs, allowing batch fetching compute and
formatting to be overlapped with the UDF. Defaults to 0 (no prefetching
enabled).
Returns:
An iterator over record batches.
"""

if stats:
stats._legacy_iter_batches = True
context = DatasetContext.get_current()

if (
@@ -107,11 +101,10 @@
_prefetch_blocks(
block_ref_iter=block_refs,
prefetcher=prefetcher,
stats=stats,
num_blocks_to_prefetch=prefetch_blocks,
eager_free=eager_free,
),
stats=stats,
eager_free=eager_free,
)

yield from batch_blocks(
@@ -124,7 +117,6 @@
shuffle_buffer_min_size=shuffle_buffer_min_size,
shuffle_seed=shuffle_seed,
ensure_copy=ensure_copy,
prefetch_batches=prefetch_batches,
)


@@ -139,7 +131,6 @@ def batch_blocks(
shuffle_buffer_min_size: Optional[int] = None,
shuffle_seed: Optional[int] = None,
ensure_copy: bool = False,
prefetch_batches: int = 0,
) -> Iterator[DataBatch]:
"""Create formatted batches of data from 1 or more blocks.
@@ -164,17 +155,12 @@ def _iterator_fn(base_iterator: Iterator[Block]) -> Iterator[DataBatch]:
)

if collate_fn is not None:
batch_iter = collate(batch_iter, collate_fn=collate_fn)
batch_iter = collate(batch_iter, collate_fn=collate_fn, stats=stats)

batch_iter = extract_data_from_batch(batch_iter)
yield from batch_iter

if prefetch_batches > 0:
batch_iter = make_async_gen(
blocks, fn=_iterator_fn, num_workers=prefetch_batches
)
else:
batch_iter = _iterator_fn(blocks)
batch_iter = _iterator_fn(blocks)

for formatted_batch in batch_iter:
user_timer = stats.iter_user_s.timer() if stats else nullcontext()
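
For context, the thread-based prefetching that the removed make_async_gen import provided (and that the new async codepath relies on) follows a produce-on-a-thread, consume-from-a-queue pattern. A simplified illustration of that pattern, not Ray's actual implementation:

# Simplified illustration of generator prefetching on a background thread.
import queue
import threading
from typing import Callable, Iterator, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def async_gen(
    source: Iterator[T],
    fn: Callable[[Iterator[T]], Iterator[U]],
    max_prefetch: int = 1,
) -> Iterator[U]:
    out: "queue.Queue" = queue.Queue(maxsize=max_prefetch)
    sentinel = object()

    def worker() -> None:
        # Run the batching pipeline on a worker thread and hand results over a
        # bounded queue, so batch formatting overlaps with the consumer.
        for item in fn(source):
            out.put(item)
        out.put(sentinel)

    threading.Thread(target=worker, daemon=True).start()
    while True:
        item = out.get()
        if item is sentinel:
            return
        yield item
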
@@ -186,6 +172,7 @@ def _prefetch_blocks(
block_ref_iter: Iterator[ObjectRef[Block]],
prefetcher: BlockPrefetcher,
num_blocks_to_prefetch: int,
eager_free: bool = False,
stats: Optional[Union[DatasetStats, DatasetPipelineStats]] = None,
) -> Iterator[ObjectRef[Block]]:
"""Given an iterable of Block Object References, returns an iterator
@@ -201,6 +188,9 @@
if num_blocks_to_prefetch == 0:
for block_ref in block_ref_iter:
yield block_ref
trace_deallocation(
block_ref, "block_batching._prefetch_blocks", free=eager_free
)

window_size = num_blocks_to_prefetch
# Create the initial set of blocks to prefetch.
@@ -219,3 +209,6 @@
except StopIteration:
pass
yield block_ref
trace_deallocation(
block_ref, "block_batching._prefetch_blocks", free=eager_free
)
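
The _prefetch_blocks changes above thread eager_free through so each yielded block ref is traced (and optionally freed) right after it is handed out. The windowed prefetch itself is the usual sliding-window pattern; a simplified, generic sketch of that pattern (not the PR's exact code):

# Generic sliding-window prefetch; illustration only.
import collections
import itertools
from typing import Callable, Iterator, List, TypeVar

T = TypeVar("T")


def sliding_window_prefetch(
    items: Iterator[T],
    prefetch: Callable[[List[T]], None],
    window_size: int,
) -> Iterator[T]:
    items = iter(items)
    # Prime a window of refs and ask the prefetcher to start pulling them.
    window = collections.deque(
        itertools.islice(items, window_size), maxlen=window_size
    )
    if window:
        prefetch(list(window))
    while window:
        current = window.popleft()
        try:
            # Slide forward by one and re-issue the hint for the new window.
            window.append(next(items))
            prefetch(list(window))
        except StopIteration:
            pass
        yield current

With Ray's WaitBlockPrefetcher (imported near the top of this file), the prefetch callback would presumably wrap ray.wait(refs, num_returns=1, fetch_local=True) to pull blocks into the local object store.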