diff --git a/python/ray/data/_internal/plan.py b/python/ray/data/_internal/plan.py
index 7defa3a6e0f5..9253027c50e2 100644
--- a/python/ray/data/_internal/plan.py
+++ b/python/ray/data/_internal/plan.py
@@ -329,6 +329,8 @@ def _optimize(self) -> Tuple[BlockList, DatasetStats, List[Stage]]:
         """
         context = DatasetContext.get_current()
         blocks, stats, stages = self._get_source_blocks_and_stages()
+        if context.optimize_reorder_stages:
+            stages = _reorder_stages(stages)
         if context.optimize_fuse_stages:
             if context.optimize_fuse_read_stages:
                 # If using a lazy datasource, rewrite read stage into one-to-one stage
@@ -692,6 +694,18 @@ def __call__(
         return blocks, stage_info
 
 
+class RandomizeBlocksStage(AllToAllStage):
+    def __init__(self, seed: Optional[int]):
+        def do_randomize_block_order(block_list, *_):
+            num_blocks = block_list.executed_num_blocks()  # Blocking.
+            if num_blocks == 0:
+                return block_list, {}
+            randomized_block_list = block_list.randomize_block_order(seed)
+            return randomized_block_list, {}
+
+        super().__init__("randomize_block_order", None, do_randomize_block_order)
+
+
 def _rewrite_read_stages(
     blocks: BlockList,
     stats: DatasetStats,
@@ -741,6 +755,36 @@ def block_fn(read_fn: Callable[[], Iterator[Block]]) -> Iterator[Block]:
     return block_list, stats, stage
 
 
+def _reorder_stages(stages: List[Stage]) -> List[Stage]:
+    """Reorder randomize stages to the end to enable better stage fusion.
+
+    This applies to RandomizeBlockOrder stages specifically (issue #26057).
+
+    Args:
+        stages: Stages to try to reorder.
+
+    Returns:
+        Reordered stages.
+    """
+
+    output: List[Stage] = []
+    reorder_buf: List[RandomizeBlocksStage] = []
+
+    for s in stages:
+        if isinstance(s, RandomizeBlocksStage):
+            # Buffer it for later reordering.
+            reorder_buf.append(s)
+        else:
+            # Barrier: flush the reorder buffer.
+            if isinstance(s, AllToAllStage):
+                output.extend(reorder_buf)
+                reorder_buf = []
+            output.append(s)
+
+    output.extend(reorder_buf)
+    return output
+
+
 def _fuse_one_to_one_stages(stages: List[Stage]) -> List[Stage]:
     """Fuses compatible one-to-one stages.
 
diff --git a/python/ray/data/context.py b/python/ray/data/context.py
index 056e19b386b1..4e9778cb6e35 100644
--- a/python/ray/data/context.py
+++ b/python/ray/data/context.py
@@ -23,6 +23,9 @@
 # Whether to enable stage-fusion optimizations for dataset pipelines.
 DEFAULT_OPTIMIZE_FUSE_STAGES = True
 
+# Whether to enable stage-reorder optimizations for dataset pipelines.
+DEFAULT_OPTIMIZE_REORDER_STAGES = True
+
 # Whether to furthermore fuse read stages. When this is enabled, data will also be
 # re-read from the base dataset in each repetition of a DatasetPipeline.
 DEFAULT_OPTIMIZE_FUSE_READ_STAGES = True
@@ -62,6 +65,7 @@ def __init__(
         optimize_fuse_stages: bool,
         optimize_fuse_read_stages: bool,
         optimize_fuse_shuffle_stages: bool,
+        optimize_reorder_stages: bool,
         actor_prefetcher_enabled: bool,
         use_push_based_shuffle: bool,
         pipeline_push_based_shuffle_reduce_tasks: bool,
@@ -76,6 +80,7 @@ def __init__(
         self.optimize_fuse_stages = optimize_fuse_stages
         self.optimize_fuse_read_stages = optimize_fuse_read_stages
         self.optimize_fuse_shuffle_stages = optimize_fuse_shuffle_stages
+        self.optimize_reorder_stages = optimize_reorder_stages
         self.actor_prefetcher_enabled = actor_prefetcher_enabled
         self.use_push_based_shuffle = use_push_based_shuffle
         self.pipeline_push_based_shuffle_reduce_tasks = (
@@ -104,6 +109,7 @@ def get_current() -> "DatasetContext":
                     optimize_fuse_stages=DEFAULT_OPTIMIZE_FUSE_STAGES,
                     optimize_fuse_read_stages=DEFAULT_OPTIMIZE_FUSE_READ_STAGES,
                     optimize_fuse_shuffle_stages=DEFAULT_OPTIMIZE_FUSE_SHUFFLE_STAGES,
+                    optimize_reorder_stages=DEFAULT_OPTIMIZE_REORDER_STAGES,
                     actor_prefetcher_enabled=DEFAULT_ACTOR_PREFETCHER_ENABLED,
                     use_push_based_shuffle=DEFAULT_USE_PUSH_BASED_SHUFFLE,
                     # NOTE(swang): We have to pipeline reduce tasks right now
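To make the reordering pass concrete, here is a minimal, self-contained sketch of the same buffering logic. The Stage classes below are toy stand-ins, not the real ray.data._internal.plan types; the loop mirrors _reorder_stages above. The intuition is that randomizing block order commutes with one-to-one stages (which map each block independently), while an all-to-all stage such as repartition changes block boundaries and so acts as a barrier.

from typing import List


class Stage:
    def __init__(self, name: str):
        self.name = name


class OneToOneStage(Stage):
    pass


class AllToAllStage(Stage):
    pass


class RandomizeBlocksStage(AllToAllStage):
    def __init__(self):
        super().__init__("randomize_block_order")


def reorder_stages(stages: List[Stage]) -> List[Stage]:
    output: List[Stage] = []
    reorder_buf: List[RandomizeBlocksStage] = []
    for s in stages:
        if isinstance(s, RandomizeBlocksStage):
            # Defer: later one-to-one stages may slide ahead and fuse.
            reorder_buf.append(s)
        else:
            if isinstance(s, AllToAllStage):
                # All-to-all barrier: buffered randomize stages stay before it.
                output.extend(reorder_buf)
                reorder_buf = []
            output.append(s)
    output.extend(reorder_buf)
    return output


# Randomize sinks past map_batches, so read and map_batches can fuse:
print([s.name for s in reorder_stages(
    [RandomizeBlocksStage(), OneToOneStage("map_batches")])])
# -> ['map_batches', 'randomize_block_order']

# repartition is an all-to-all barrier; randomize stays in front of it:
print([s.name for s in reorder_stages(
    [RandomizeBlocksStage(), AllToAllStage("repartition"),
     OneToOneStage("map_batches")])])
# -> ['randomize_block_order', 'repartition', 'map_batches']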
diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index f2127e5edebf..bbaaeaae3854 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -35,7 +35,12 @@
 from ray.data._internal.fast_repartition import fast_repartition
 from ray.data._internal.lazy_block_list import LazyBlockList
 from ray.data._internal.output_buffer import BlockOutputBuffer
-from ray.data._internal.plan import AllToAllStage, ExecutionPlan, OneToOneStage
+from ray.data._internal.plan import (
+    AllToAllStage,
+    ExecutionPlan,
+    OneToOneStage,
+    RandomizeBlocksStage,
+)
 from ray.data._internal.progress_bar import ProgressBar
 from ray.data._internal.remote_fn import cached_remote_fn
 from ray.data._internal.shuffle_and_partition import (
@@ -169,6 +174,8 @@ def __init__(
         plan: ExecutionPlan,
         epoch: int,
         lazy: bool,
+        *,
+        defer_execution: bool = False,
     ):
         """Construct a Dataset (internal API).
 
@@ -183,7 +190,7 @@ def __init__(
         self._epoch = epoch
         self._lazy = lazy
 
-        if not lazy:
+        if not lazy and not defer_execution:
             self._plan.execute(allow_clear_input_blocks=False)
 
     @staticmethod
@@ -808,26 +815,11 @@ def randomize_block_order(
             based on system randomness.
 
         Returns:
-            The shuffled dataset.
+            The block-shuffled dataset.
         """
 
-        def do_randomize_block_order(block_list, *_):
-            num_blocks = block_list.executed_num_blocks()  # Blocking.
-            if num_blocks == 0:
-                return block_list, {}
-
-            randomized_block_list = block_list.randomize_block_order(seed)
-
-            return randomized_block_list, {}
-
-        plan = self._plan.with_stage(
-            AllToAllStage(
-                "randomize_block_order",
-                None,
-                do_randomize_block_order,
-            )
-        )
-        return Dataset(plan, self._epoch, self._lazy)
+        plan = self._plan.with_stage(RandomizeBlocksStage(seed))
+        return Dataset(plan, self._epoch, self._lazy, defer_execution=True)
 
     def random_sample(
         self, fraction: float, *, seed: Optional[int] = None
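From the user's side, the change should be visible in a pipeline's stats. A small usage sketch, assuming default context settings; the expected stage names follow the new test_optimize_reorder test rather than output verified here:

import ray

# randomize_block_order() no longer forces eager execution (it passes
# defer_execution=True), and the optimizer moves it past map_batches,
# so the read and map_batches stages can fuse.
ds = ray.data.range(10).randomize_block_order().map_batches(lambda x: x)

# Per test_optimize_reorder, the optimized plan should contain two stages:
# "read->map_batches" followed by "randomize_block_order".
print(ds.stats())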
diff --git a/python/ray/data/tests/test_optimize.py b/python/ray/data/tests/test_optimize.py
index c449eaa1438a..95fd02272121 100644
--- a/python/ray/data/tests/test_optimize.py
+++ b/python/ray/data/tests/test_optimize.py
@@ -9,6 +9,7 @@
 
 import ray
 from ray._private.internal_api import memory_summary
+from ray.data import Dataset
 from ray.data.block import BlockMetadata
 from ray.data.context import DatasetContext
 from ray.data.datasource import Datasource, ReadTask
@@ -48,7 +49,14 @@ def expect_stages(pipe, num_stages_expected, stage_names):
     for name in stage_names:
         name = " " + name + ":"
         assert name in stats, (name, stats)
-    assert len(pipe._optimized_stages) == num_stages_expected, pipe._optimized_stages
+    if isinstance(pipe, Dataset):
+        assert (
+            len(pipe._plan._stages_before_snapshot) == num_stages_expected
+        ), pipe._plan._stages_before_snapshot
+    else:
+        assert (
+            len(pipe._optimized_stages) == num_stages_expected
+        ), pipe._optimized_stages
 
 
 def test_memory_sanity(shutdown_only):
@@ -294,6 +302,32 @@ def test_stage_linking(ray_start_regular_shared):
     _assert_has_stages(ds._plan._last_optimized_stages, ["read->map"])
 
 
+def test_optimize_reorder(ray_start_regular_shared):
+    context = DatasetContext.get_current()
+    context.optimize_fuse_stages = True
+    context.optimize_fuse_read_stages = True
+    context.optimize_reorder_stages = True
+
+    ds = ray.data.range(10).randomize_block_order().map_batches(lambda x: x)
+    expect_stages(
+        ds,
+        2,
+        ["read->map_batches", "randomize_block_order"],
+    )
+
+    ds2 = (
+        ray.data.range(10)
+        .randomize_block_order()
+        .repartition(10)
+        .map_batches(lambda x: x)
+    )
+    expect_stages(
+        ds2,
+        3,
+        ["read", "randomize_block_order", "repartition", "map_batches"],
+    )
+
+
 def test_optimize_fuse(ray_start_regular_shared):
     context = DatasetContext.get_current()
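If the reordering is ever undesirable, the pass can be switched off at runtime through the context flag added above; a minimal sketch:

from ray.data.context import DatasetContext

# Disable the reordering pass introduced in this PR; stages then keep
# their declared order, as before.
context = DatasetContext.get_current()
context.optimize_reorder_stages = False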