From fb7b5462c6d26174c99d73d722e095402fceb3ed Mon Sep 17 00:00:00 2001 From: Scott Lee Date: Thu, 27 Apr 2023 17:34:03 -0700 Subject: [PATCH 1/8] first work Signed-off-by: Scott Lee --- .../execution/streaming_executor_state.py | 2 + .../logical/rules/operator_fusion.py | 82 +++++++++++++------ python/ray/data/_internal/planner/sort.py | 6 +- .../data/tests/test_execution_optimizer.py | 11 +++ 4 files changed, 74 insertions(+), 27 deletions(-) diff --git a/python/ray/data/_internal/execution/streaming_executor_state.py b/python/ray/data/_internal/execution/streaming_executor_state.py index 24d815bd4e7c..b8c2560f7325 100644 --- a/python/ray/data/_internal/execution/streaming_executor_state.py +++ b/python/ray/data/_internal/execution/streaming_executor_state.py @@ -204,6 +204,8 @@ def get_output_blocking(self, output_split_idx: Optional[int]) -> MaybeRefBundle try: # Non-split output case. if output_split_idx is None: + if not self.outqueue: + return None return self.outqueue.popleft() # Scan the queue and look for outputs tagged for the given index. diff --git a/python/ray/data/_internal/logical/rules/operator_fusion.py b/python/ray/data/_internal/logical/rules/operator_fusion.py index ea8e91dc6b59..3c55ba5b0f1b 100644 --- a/python/ray/data/_internal/logical/rules/operator_fusion.py +++ b/python/ray/data/_internal/logical/rules/operator_fusion.py @@ -1,11 +1,18 @@ -from typing import Iterator +from typing import Iterator, List, Tuple +from ray.data._internal.logical.operators.all_to_all_operator import AbstractAllToAll +from ray.data._internal.stats import StatsDict from ray.data.block import Block # TODO(Clark): Remove compute dependency once we delete the legacy compute. from ray.data._internal.compute import is_task_compute, CallableClass, get_compute -from ray.data._internal.execution.interfaces import PhysicalOperator, TaskContext +from ray.data._internal.execution.interfaces import ( + PhysicalOperator, + RefBundle, + TaskContext, +) from ray.data._internal.logical.interfaces import Rule, PhysicalPlan +from ray.data._internal.execution.operators.all_to_all_operator import AllToAllOperator # Scheduling strategy can be inherited from upstream operator if not specified. @@ -56,8 +63,12 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool: from ray.data._internal.logical.operators.map_operator import AbstractMap from ray.data._internal.logical.operators.map_operator import AbstractUDFMap - # We only support fusing MapOperators. - if not isinstance(down_op, MapOperator) or not isinstance(up_op, MapOperator): + # We currently only support fusing for the following cases: + # - MapOperator -> MapOperator + # - MapOperator -> AllToAllOperator + if not isinstance(down_op, (MapOperator, AllToAllOperator)) or not isinstance( + up_op, MapOperator + ): return False down_logical_op = self._op_map[down_op] @@ -68,17 +79,20 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool: if not down_logical_op._input_dependencies: return False - # We only support fusing AbstractMap -> AbstractMap operators. 
- if not isinstance(down_logical_op, AbstractMap) or not isinstance( - up_logical_op, AbstractMap - ): + # We currently only support fusing for the following cases: + # - AbstractMap -> AbstractMap + # - AbstractAllToAll -> AbstractMap + if not isinstance( + down_logical_op, (AbstractMap, AbstractAllToAll) + ) or not isinstance(up_logical_op, AbstractMap): return False # Allow fusing tasks->actors if the resources are compatible (read->map), but # not the other way around. The latter (downstream op) will be used as the # compute if fused. if ( - is_task_compute(down_logical_op._compute) + isinstance(down_logical_op, AbstractUDFMap) + and is_task_compute(down_logical_op._compute) and isinstance(up_logical_op, AbstractUDFMap) and get_compute(up_logical_op._compute) != get_compute(down_logical_op._compute) @@ -129,32 +143,48 @@ def _fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator): down_logical_op = self._op_map.pop(down_op) up_logical_op = self._op_map.pop(up_op) - # Merge target block sizes. - down_target_block_size = down_logical_op._target_block_size - up_target_block_size = ( - up_logical_op._target_block_size - if isinstance(up_logical_op, AbstractUDFMap) - else None - ) - if down_target_block_size is not None and up_target_block_size is not None: - target_block_size = max(down_target_block_size, up_target_block_size) - elif up_target_block_size is not None: - target_block_size = up_target_block_size - else: - target_block_size = down_target_block_size + compute = None + target_block_size = None + transform_fn = None # Fuse transformation functions. down_transform_fn = down_op.get_transformation_fn() up_transform_fn = up_op.get_transformation_fn() - def transform_fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]: + def transform_fn_map_map( + blocks: Iterator[Block], ctx: TaskContext + ) -> Iterator[Block]: blocks = up_transform_fn(blocks, ctx) # TODO(Clark): Add zero-copy batching between transform functions. return down_transform_fn(blocks, ctx) - # We take the downstream op's compute in case we're fusing upstream tasks with a - # downstream actor pool (e.g. read->map). - compute = get_compute(down_logical_op._compute) + def transform_fn_map_alltoall( + blocks: List[RefBundle], ctx: TaskContext + ) -> Tuple[List[RefBundle], StatsDict]: + ctx.map_transform_fn = up_transform_fn + return down_transform_fn(blocks, ctx) + + if isinstance(down_logical_op, AbstractUDFMap): + # Merge target block sizes. + down_target_block_size = down_logical_op._target_block_size + up_target_block_size = ( + up_logical_op._target_block_size + if isinstance(up_logical_op, AbstractUDFMap) + else None + ) + if down_target_block_size is not None and up_target_block_size is not None: + target_block_size = max(down_target_block_size, up_target_block_size) + elif up_target_block_size is not None: + target_block_size = up_target_block_size + else: + target_block_size = down_target_block_size + transform_fn = transform_fn_map_map + # We take the downstream op's compute in case we're fusing upstream tasks with a + # downstream actor pool (e.g. read->map). + compute = get_compute(down_logical_op._compute) + elif isinstance(down_logical_op, AbstractAllToAll): + transform_fn = transform_fn_map_alltoall + ray_remote_args = down_logical_op._ray_remote_args # Make the upstream operator's inputs the new, fused operator's inputs. 
input_deps = up_op.input_dependencies diff --git a/python/ray/data/_internal/planner/sort.py b/python/ray/data/_internal/planner/sort.py index 955aac05c6d2..d0d978208f32 100644 --- a/python/ray/data/_internal/planner/sort.py +++ b/python/ray/data/_internal/planner/sort.py @@ -1,5 +1,5 @@ from functools import partial -from typing import List, Tuple +from typing import List, Optional, Tuple from ray.data._internal.execution.interfaces import ( AllToAllTransformFn, @@ -42,6 +42,10 @@ def fn( unified_schema = unify_block_metadata_schema(metadata) _validate_key_fn(unified_schema, key) + map_transform_fn: Optional[AllToAllTransformFn] = ctx.map_transform_fn + if map_transform_fn: + blocks = map_transform_fn(blocks) + if isinstance(key, str): key = [(key, "descending" if descending else "ascending")] if isinstance(key, list): diff --git a/python/ray/data/tests/test_execution_optimizer.py b/python/ray/data/tests/test_execution_optimizer.py index 068a7849d953..64ea1cb6932d 100644 --- a/python/ray/data/tests/test_execution_optimizer.py +++ b/python/ray/data/tests/test_execution_optimizer.py @@ -602,6 +602,17 @@ def __call__(self, x): assert isinstance(physical_op.input_dependencies[0], InputDataBuffer) +def test_read_map_batches_operator_fusion_with_all_to_all_operator( + ray_start_regular_shared, enable_optimizer +): + ds = ray.data.range(10) + ds = ds.map_batches(lambda batch: [x + 1 for x in batch]) + ds = ds.sort() + assert ds.take_all() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + assert "DoRead->MapBatches->Sort" in ds.stats() + _check_usage_record(["ReadRange", "MapBatches", "Sort"]) + + def test_read_map_chain_operator_fusion_e2e(ray_start_regular_shared, enable_optimizer): ds = ray.data.range(10, parallelism=2) ds = ds.filter(lambda x: x % 2 == 0) From f3eba72f1e403dcb5407593ed74bc21d3fdbd6e1 Mon Sep 17 00:00:00 2001 From: Scott Lee Date: Mon, 1 May 2023 16:06:27 -0700 Subject: [PATCH 2/8] add fusion for map->alltoall Signed-off-by: Scott Lee --- .../data/_internal/execution/interfaces.py | 4 + .../execution/streaming_executor_state.py | 2 - .../logical/operators/all_to_all_operator.py | 3 +- .../logical/rules/operator_fusion.py | 193 ++++++++++++------ .../planner/exchange/shuffle_task_spec.py | 12 +- .../data/_internal/planner/random_shuffle.py | 16 +- python/ray/data/_internal/planner/sort.py | 6 +- .../data/tests/test_execution_optimizer.py | 11 +- 8 files changed, 174 insertions(+), 73 deletions(-) diff --git a/python/ray/data/_internal/execution/interfaces.py b/python/ray/data/_internal/execution/interfaces.py index cd06591bdd0b..654c122a28c6 100644 --- a/python/ray/data/_internal/execution/interfaces.py +++ b/python/ray/data/_internal/execution/interfaces.py @@ -233,6 +233,10 @@ class TaskContext: # TODO(chengsu): clean it up from TaskContext with new optimizer framework. sub_progress_bar_dict: Optional[Dict[str, ProgressBar]] = None + # The underlying function called in a MapOperator; this is used when fusing + # an AllToAllOperator with an upstream MapOperator. + map_transform_fn: Optional["MapTransformFn"] = None + # Block transform function applied by task and actor pools in MapOperator. 
MapTransformFn = Callable[[Iterable[Block], TaskContext], Iterable[Block]] diff --git a/python/ray/data/_internal/execution/streaming_executor_state.py b/python/ray/data/_internal/execution/streaming_executor_state.py index b8c2560f7325..24d815bd4e7c 100644 --- a/python/ray/data/_internal/execution/streaming_executor_state.py +++ b/python/ray/data/_internal/execution/streaming_executor_state.py @@ -204,8 +204,6 @@ def get_output_blocking(self, output_split_idx: Optional[int]) -> MaybeRefBundle try: # Non-split output case. if output_split_idx is None: - if not self.outqueue: - return None return self.outqueue.popleft() # Scan the queue and look for outputs tagged for the given index. diff --git a/python/ray/data/_internal/logical/operators/all_to_all_operator.py b/python/ray/data/_internal/logical/operators/all_to_all_operator.py index 9dacd39ad5ec..f1a33f76d813 100644 --- a/python/ray/data/_internal/logical/operators/all_to_all_operator.py +++ b/python/ray/data/_internal/logical/operators/all_to_all_operator.py @@ -53,12 +53,13 @@ class RandomShuffle(AbstractAllToAll): def __init__( self, input_op: LogicalOperator, + name: str = "RandomShuffle", seed: Optional[int] = None, num_outputs: Optional[int] = None, ray_remote_args: Optional[Dict[str, Any]] = None, ): super().__init__( - "RandomShuffle", + name, input_op, num_outputs=num_outputs, ray_remote_args=ray_remote_args, diff --git a/python/ray/data/_internal/logical/rules/operator_fusion.py b/python/ray/data/_internal/logical/rules/operator_fusion.py index 3c55ba5b0f1b..4fcf1d05107f 100644 --- a/python/ray/data/_internal/logical/rules/operator_fusion.py +++ b/python/ray/data/_internal/logical/rules/operator_fusion.py @@ -1,5 +1,9 @@ from typing import Iterator, List, Tuple -from ray.data._internal.logical.operators.all_to_all_operator import AbstractAllToAll +from ray.data._internal.execution.operators.map_operator import MapOperator +from ray.data._internal.logical.operators.all_to_all_operator import ( + AbstractAllToAll, + RandomShuffle, +) from ray.data._internal.stats import StatsDict from ray.data.block import Block @@ -13,6 +17,7 @@ ) from ray.data._internal.logical.interfaces import Rule, PhysicalPlan from ray.data._internal.execution.operators.all_to_all_operator import AllToAllOperator +from ray.data._internal.logical.operators.map_operator import AbstractUDFMap # Scheduling strategy can be inherited from upstream operator if not specified. @@ -24,26 +29,59 @@ class OperatorFusionRule(Rule): def apply(self, plan: PhysicalPlan) -> PhysicalPlan: self._op_map = plan.op_map.copy() - # Do DFS fusion. - root = self._apply(plan.dag) - return PhysicalPlan(root, self._op_map) + # Do DFS fusion on compatible pairwise operators in two passes. + # In the first pass, only fuse back-to-back map operators together. + op_map_fused = self._fuse_map_to_map_operators(plan.dag) + + # Now that we have fused together all back-to-back map operators, + # we fuse together MapOperator -> AllToAllOperator pairs. + op_map_alltoall_fused = self._fuse_map_to_alltoall_operators(op_map_fused) - def _apply(self, op: PhysicalOperator) -> PhysicalOperator: - """Performs DFS fusion of linear chains of physical map operators, provided that - they are pairwise-compatible. + return PhysicalPlan(op_map_alltoall_fused, self._op_map) - Args: - op: The op that we're trying to fuse with its input. 
+ def _fuse_map_to_map_operators(self, op: MapOperator) -> MapOperator: + """Starting at the given operator, traverses up the DAG of operators + and recursively fuses compatible MapOperator -> MapOperator pairs. + Returns the current (root) operator after completing upstream operator fusions. """ upstream_ops = op.input_dependencies - # Fuse with upstream ops while possible. - while len(upstream_ops) == 1 and self._can_fuse(op, upstream_ops[0]): + while ( + len(upstream_ops) == 1 + and self._can_fuse(op, upstream_ops[0]) + and isinstance(op, MapOperator) + and isinstance(upstream_ops[0], MapOperator) + ): # Fuse operator with its upstream op. - op = self._fuse(op, upstream_ops[0]) + op = self._fuse_ops_map_map(op, upstream_ops[0]) upstream_ops = op.input_dependencies - # Can no longer fuse with upstream ops, proceed up the DAG. + + # Done fusing back-to-back map operators together here, + # move up the DAG to find the next map operators to fuse. op._input_dependencies = [ - self._apply(upstream_op) for upstream_op in upstream_ops + self._fuse_map_to_map_operators(upstream_op) for upstream_op in upstream_ops + ] + return op + + def _fuse_map_to_alltoall_operators(self, op: AllToAllOperator) -> AllToAllOperator: + """Starting at the given operator, traverses up the DAG of operators + and recursively fuses compatible MapOperator -> AllToAllOperator pairs. + Returns the current (root) operator after completing upstream operator fusions. + """ + upstream_ops = op.input_dependencies + while ( + len(upstream_ops) == 1 + and self._can_fuse(op, upstream_ops[0]) + and isinstance(op, AllToAllOperator) + and isinstance(upstream_ops[0], MapOperator) + ): + # Fuse operator with its upstream op. + op = self._fuse_ops_map_alltoall(op, upstream_ops[0]) + upstream_ops = op.input_dependencies + + # Done fusing MapOperator -> AllToAllOperator together here, + # move up the DAG to find the next pair of operators to fuse. + op._input_dependencies = [ + self._fuse_map_to_map_operators(upstream_op) for upstream_op in upstream_ops ] return op @@ -52,7 +90,8 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool: upstream operator. We currently support fusing two operators if the following are all true: - * They are both MapOperators. + * We are fusing either MapOperator -> MapOperator or + MapOperator -> AllToAllOperator. * They either use the same compute configuration, or the upstream operator uses a task pool while the downstream operator uses an actor pool. 
* If both operators involve callable classes, the callable classes are @@ -65,7 +104,8 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool: # We currently only support fusing for the following cases: # - MapOperator -> MapOperator - # - MapOperator -> AllToAllOperator + # - MapOperator -> AllToAllOperator (only + # RandomShuffle LogicalOperator is currently supported) if not isinstance(down_op, (MapOperator, AllToAllOperator)) or not isinstance( up_op, MapOperator ): @@ -81,9 +121,9 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool: # We currently only support fusing for the following cases: # - AbstractMap -> AbstractMap - # - AbstractAllToAll -> AbstractMap + # - AbstractMap -> RandomShuffle if not isinstance( - down_logical_op, (AbstractMap, AbstractAllToAll) + down_logical_op, (AbstractMap, RandomShuffle) ) or not isinstance(up_logical_op, AbstractMap): return False @@ -130,60 +170,47 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool: # Otherwise, ops are compatible for fusion. return True - def _fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator): - """Fuse the downstream operator with its upstream operator.""" - from ray.data._internal.execution.operators.map_operator import MapOperator - from ray.data._internal.logical.operators.map_operator import AbstractUDFMap - - assert self._can_fuse(down_op, up_op) + def _fuse_ops_map_map( + self, down_op: MapOperator, up_op: MapOperator + ) -> MapOperator: + assert self._can_fuse(down_op, up_op), ( + "Current rule supports fusing MapOperator->MapOperator, but received: " + f"{type(up_op).__name__} -> {type(down_op).__name__}" + ) # Fuse operator names. name = up_op.name + "->" + down_op.name - down_logical_op = self._op_map.pop(down_op) up_logical_op = self._op_map.pop(up_op) - compute = None - target_block_size = None - transform_fn = None + # Merge target block sizes. + down_target_block_size = down_logical_op._target_block_size + up_target_block_size = ( + up_logical_op._target_block_size + if isinstance(up_logical_op, AbstractUDFMap) + else None + ) + if down_target_block_size is not None and up_target_block_size is not None: + target_block_size = max(down_target_block_size, up_target_block_size) + elif up_target_block_size is not None: + target_block_size = up_target_block_size + else: + target_block_size = down_target_block_size # Fuse transformation functions. down_transform_fn = down_op.get_transformation_fn() up_transform_fn = up_op.get_transformation_fn() - def transform_fn_map_map( + def fused_transform_fn_map_map( blocks: Iterator[Block], ctx: TaskContext ) -> Iterator[Block]: blocks = up_transform_fn(blocks, ctx) - # TODO(Clark): Add zero-copy batching between transform functions. - return down_transform_fn(blocks, ctx) - - def transform_fn_map_alltoall( - blocks: List[RefBundle], ctx: TaskContext - ) -> Tuple[List[RefBundle], StatsDict]: - ctx.map_transform_fn = up_transform_fn + # TODO(Scott): Add zero-copy batching between transform functions. return down_transform_fn(blocks, ctx) - if isinstance(down_logical_op, AbstractUDFMap): - # Merge target block sizes. 
- down_target_block_size = down_logical_op._target_block_size - up_target_block_size = ( - up_logical_op._target_block_size - if isinstance(up_logical_op, AbstractUDFMap) - else None - ) - if down_target_block_size is not None and up_target_block_size is not None: - target_block_size = max(down_target_block_size, up_target_block_size) - elif up_target_block_size is not None: - target_block_size = up_target_block_size - else: - target_block_size = down_target_block_size - transform_fn = transform_fn_map_map - # We take the downstream op's compute in case we're fusing upstream tasks with a - # downstream actor pool (e.g. read->map). - compute = get_compute(down_logical_op._compute) - elif isinstance(down_logical_op, AbstractAllToAll): - transform_fn = transform_fn_map_alltoall + # We take the downstream op's compute in case we're fusing upstream tasks with a + # downstream actor pool (e.g. read->map). + compute = get_compute(down_logical_op._compute) ray_remote_args = down_logical_op._ray_remote_args # Make the upstream operator's inputs the new, fused operator's inputs. @@ -193,7 +220,7 @@ def transform_fn_map_alltoall( # Fused physical map operator. op = MapOperator.create( - transform_fn, + fused_transform_fn_map_map, input_op, name=name, compute_strategy=compute, @@ -202,7 +229,7 @@ def transform_fn_map_alltoall( ) # Build a map logical operator to be used as a reference for further fusion. - # TODO(Clark): This is hacky, remove this once we push fusion to be purely based + # TODO(Scott): This is hacky, remove this once we push fusion to be purely based # on a lower-level operator spec. if isinstance(up_logical_op, AbstractUDFMap): input_op = up_logical_op.input_dependencies[0] @@ -235,6 +262,56 @@ def transform_fn_map_alltoall( # Return the fused physical operator. return op + def _fuse_ops_map_alltoall( + self, down_op: AllToAllOperator, up_op: MapOperator + ) -> AllToAllOperator: + assert self._can_fuse(down_op, up_op), ( + "Current rule supports fusing MapOperator -> AllToAllOperator" + f", but received: {type(up_op).__name__} -> {type(down_op).__name__}" + ) + + # Fuse operator names. + name = up_op.name + "->" + down_op.name + down_logical_op: AbstractAllToAll = self._op_map.pop(down_op) + up_logical_op: AbstractUDFMap = self._op_map.pop(up_op) + assert isinstance(down_logical_op, RandomShuffle), ( + "Current rule supports fusing RandomShuffle downstream operators only, " + f"but got {type(down_logical_op).__name__}" + ) + + # Fuse transformation functions. + down_transform_fn = down_op.get_transformation_fn() + up_transform_fn = up_op.get_transformation_fn() + + def fused_transform_fn_map_alltoall( + blocks: List[RefBundle], ctx: TaskContext + ) -> Tuple[List[RefBundle], StatsDict]: + """To fuse MapOperator->AllToAllOperator, we store the map function + in the TaskContext, which is later called in `ShuffleTaskSpec.map`. + Then, we can return an AllToAllOperator which applies the map function + before executing the shuffle.""" + ctx.map_transform_fn = up_transform_fn + return down_transform_fn(blocks, ctx) + + ray_remote_args = down_logical_op._ray_remote_args + # Make the upstream operator's inputs the new, fused operator's inputs. + input_deps = up_op.input_dependencies + assert len(input_deps) == 1 + input_op = input_deps[0] + + op = AllToAllOperator( + fused_transform_fn_map_alltoall, + input_op, + name=name, + ) + # Bottom out at the source logical op (e.g. Read()). 
+ input_op = up_logical_op + + logical_op = RandomShuffle(input_op, name=name, ray_remote_args=ray_remote_args) + self._op_map[op] = logical_op + # Return the fused physical operator. + return op + def _are_remote_args_compatible(up_args, down_args): """Check if Ray remote arguments are compatible for merging.""" diff --git a/python/ray/data/_internal/planner/exchange/shuffle_task_spec.py b/python/ray/data/_internal/planner/exchange/shuffle_task_spec.py index 474d69b03279..9611041a299e 100644 --- a/python/ray/data/_internal/planner/exchange/shuffle_task_spec.py +++ b/python/ray/data/_internal/planner/exchange/shuffle_task_spec.py @@ -4,6 +4,7 @@ import numpy as np from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder +from ray.data._internal.execution.interfaces import MapTransformFn from ray.data._internal.planner.exchange.interfaces import ExchangeTaskSpec from ray.data.block import Block, BlockAccessor, BlockExecStats, BlockMetadata @@ -19,9 +20,10 @@ def __init__( self, random_shuffle: bool = False, random_seed: Optional[int] = None, + upstream_map_fn: Optional[MapTransformFn] = None, ): super().__init__( - map_args=[random_shuffle, random_seed], + map_args=[upstream_map_fn, random_shuffle, random_seed], reduce_args=[random_shuffle, random_seed], ) @@ -30,11 +32,19 @@ def map( idx: int, block: Block, output_num_blocks: int, + upstream_map_fn: Optional[MapTransformFn], random_shuffle: bool, random_seed: Optional[int], ) -> List[Union[BlockMetadata, Block]]: # TODO: Support fusion with other upstream operators. stats = BlockExecStats.builder() + if upstream_map_fn: + mapped_blocks = list(upstream_map_fn([block])) + assert len(mapped_blocks) == 1, ( + "Expected upstream_map_fn to return one block, but instead" + f" returned {len(mapped_blocks)} blocks" + ) + block = mapped_blocks[0] block = BlockAccessor.for_block(block) # Randomize the distribution of records to blocks. diff --git a/python/ray/data/_internal/planner/random_shuffle.py b/python/ray/data/_internal/planner/random_shuffle.py index 8f22741aa93c..9f5d5ac4a067 100644 --- a/python/ray/data/_internal/planner/random_shuffle.py +++ b/python/ray/data/_internal/planner/random_shuffle.py @@ -2,6 +2,7 @@ from ray.data._internal.execution.interfaces import ( AllToAllTransformFn, + MapTransformFn, RefBundle, TaskContext, ) @@ -28,7 +29,20 @@ def fn( ctx: TaskContext, ) -> Tuple[List[RefBundle], StatsDict]: num_input_blocks = sum(len(r.blocks) for r in refs) - shuffle_spec = ShuffleTaskSpec(random_shuffle=True, random_seed=seed) + + # If map_transform_fn is specified (e.g. from fusing + # MapOperator->AllToAllOperator), we pass a map function which + # is applied to each block before shuffling. 
+ map_transform_fn: Optional[MapTransformFn] = ctx.map_transform_fn + upstream_map_fn = None + if map_transform_fn: + upstream_map_fn = lambda block: map_transform_fn(block, ctx) # noqa: E731 + + shuffle_spec = ShuffleTaskSpec( + random_shuffle=True, + random_seed=seed, + upstream_map_fn=upstream_map_fn, + ) if DataContext.get_current().use_push_based_shuffle: if num_outputs is not None: diff --git a/python/ray/data/_internal/planner/sort.py b/python/ray/data/_internal/planner/sort.py index d0d978208f32..955aac05c6d2 100644 --- a/python/ray/data/_internal/planner/sort.py +++ b/python/ray/data/_internal/planner/sort.py @@ -1,5 +1,5 @@ from functools import partial -from typing import List, Optional, Tuple +from typing import List, Tuple from ray.data._internal.execution.interfaces import ( AllToAllTransformFn, @@ -42,10 +42,6 @@ def fn( unified_schema = unify_block_metadata_schema(metadata) _validate_key_fn(unified_schema, key) - map_transform_fn: Optional[AllToAllTransformFn] = ctx.map_transform_fn - if map_transform_fn: - blocks = map_transform_fn(blocks) - if isinstance(key, str): key = [(key, "descending" if descending else "ascending")] if isinstance(key, list): diff --git a/python/ray/data/tests/test_execution_optimizer.py b/python/ray/data/tests/test_execution_optimizer.py index 64ea1cb6932d..6df20a3934c7 100644 --- a/python/ray/data/tests/test_execution_optimizer.py +++ b/python/ray/data/tests/test_execution_optimizer.py @@ -605,12 +605,13 @@ def __call__(self, x): def test_read_map_batches_operator_fusion_with_all_to_all_operator( ray_start_regular_shared, enable_optimizer ): - ds = ray.data.range(10) + n = 10 + ds = ray.data.range(n) ds = ds.map_batches(lambda batch: [x + 1 for x in batch]) - ds = ds.sort() - assert ds.take_all() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - assert "DoRead->MapBatches->Sort" in ds.stats() - _check_usage_record(["ReadRange", "MapBatches", "Sort"]) + ds = ds.random_shuffle() + assert set(ds.take_all()) == set(range(1, n + 1)) + assert "DoRead->MapBatches->RandomShuffle" in ds.stats() + _check_usage_record(["ReadRange", "MapBatches", "RandomShuffle"]) def test_read_map_chain_operator_fusion_e2e(ray_start_regular_shared, enable_optimizer): From ebbe94f62e0ecf4e6ac3b322d78c77c15ce7ccda Mon Sep 17 00:00:00 2001 From: Scott Lee Date: Mon, 1 May 2023 16:18:15 -0700 Subject: [PATCH 3/8] add test in other direction Signed-off-by: Scott Lee --- python/ray/data/tests/test_execution_optimizer.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/ray/data/tests/test_execution_optimizer.py b/python/ray/data/tests/test_execution_optimizer.py index 6df20a3934c7..c0f03f6c6a8e 100644 --- a/python/ray/data/tests/test_execution_optimizer.py +++ b/python/ray/data/tests/test_execution_optimizer.py @@ -602,9 +602,10 @@ def __call__(self, x): assert isinstance(physical_op.input_dependencies[0], InputDataBuffer) -def test_read_map_batches_operator_fusion_with_all_to_all_operator( +def test_read_map_batches_operator_fusion_with_random_shuffle_operator( ray_start_regular_shared, enable_optimizer ): + # We currently only support fusing MapOperator->AllToAllOperator. 
n = 10 ds = ray.data.range(n) ds = ds.map_batches(lambda batch: [x + 1 for x in batch]) @@ -613,6 +614,15 @@ def test_read_map_batches_operator_fusion_with_all_to_all_operator( assert "DoRead->MapBatches->RandomShuffle" in ds.stats() _check_usage_record(["ReadRange", "MapBatches", "RandomShuffle"]) + ds = ray.data.range(n) + ds = ds.random_shuffle() + ds = ds.map_batches(lambda batch: [x + 1 for x in batch]) + assert set(ds.take_all()) == set(range(1, n + 1)) + # TODO(Scott): Update below assertion after supporting fusion in + # the other direction (AllToAllOperator->MapOperator) + assert "DoRead->RandomShuffle->MapBatches" not in ds.stats() + assert all(op in ds.stats() for op in ("DoRead", "RandomShuffle", "MapBatches")) + def test_read_map_chain_operator_fusion_e2e(ray_start_regular_shared, enable_optimizer): ds = ray.data.range(10, parallelism=2) From 8cf5215c29b7390c34bc7c6ad5cbd4f36d1482db Mon Sep 17 00:00:00 2001 From: Scott Lee Date: Mon, 1 May 2023 17:46:10 -0700 Subject: [PATCH 4/8] update strict mode test Signed-off-by: Scott Lee --- python/ray/data/tests/test_execution_optimizer.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/python/ray/data/tests/test_execution_optimizer.py b/python/ray/data/tests/test_execution_optimizer.py index d8480e9dacb6..0a197b23db43 100644 --- a/python/ray/data/tests/test_execution_optimizer.py +++ b/python/ray/data/tests/test_execution_optimizer.py @@ -613,20 +613,25 @@ def test_read_map_batches_operator_fusion_with_random_shuffle_operator( # We currently only support fusing MapOperator->AllToAllOperator. n = 10 ds = ray.data.range(n) - ds = ds.map_batches(lambda batch: [x + 1 for x in batch]) + ds = ds.map_batches( + lambda batch: {"id": [x + 1 for x in batch["id"]]}, batch_size=None + ) ds = ds.random_shuffle() - assert set(ds.take_all()) == set(range(1, n + 1)) + assert set(extract_values("id", ds.take_all())) == set(range(1, n + 1)) assert "DoRead->MapBatches->RandomShuffle" in ds.stats() _check_usage_record(["ReadRange", "MapBatches", "RandomShuffle"]) ds = ray.data.range(n) ds = ds.random_shuffle() - ds = ds.map_batches(lambda batch: [x + 1 for x in batch]) - assert set(ds.take_all()) == set(range(1, n + 1)) + ds = ds.map_batches( + lambda batch: {"id": [x + 1 for x in batch["id"]]}, batch_size=None + ) + assert set(extract_values("id", ds.take_all())) == set(range(1, n + 1)) # TODO(Scott): Update below assertion after supporting fusion in # the other direction (AllToAllOperator->MapOperator) assert "DoRead->RandomShuffle->MapBatches" not in ds.stats() assert all(op in ds.stats() for op in ("DoRead", "RandomShuffle", "MapBatches")) + _check_usage_record(["ReadRange", "RandomShuffle", "MapBatches"]) def test_read_map_chain_operator_fusion_e2e(ray_start_regular_shared, enable_optimizer): From 520a46af13c5a8545acd0c672aa3dd98dd6c2ad4 Mon Sep 17 00:00:00 2001 From: Scott Lee Date: Thu, 4 May 2023 16:14:27 -0700 Subject: [PATCH 5/8] address comments Signed-off-by: Scott Lee --- .../data/_internal/execution/interfaces.py | 2 +- .../logical/rules/operator_fusion.py | 66 +++++----- .../data/_internal/planner/random_shuffle.py | 2 +- .../data/tests/test_execution_optimizer.py | 119 ++++++++++++++++-- 4 files changed, 149 insertions(+), 40 deletions(-) diff --git a/python/ray/data/_internal/execution/interfaces.py b/python/ray/data/_internal/execution/interfaces.py index 5fd1fae6046c..734384232c45 100644 --- a/python/ray/data/_internal/execution/interfaces.py +++ b/python/ray/data/_internal/execution/interfaces.py @@ 
-235,7 +235,7 @@ class TaskContext: # The underlying function called in a MapOperator; this is used when fusing # an AllToAllOperator with an upstream MapOperator. - map_transform_fn: Optional["MapTransformFn"] = None + upstream_map_transform_fn: Optional["MapTransformFn"] = None # Block transform function applied by task and actor pools in MapOperator. diff --git a/python/ray/data/_internal/logical/rules/operator_fusion.py b/python/ray/data/_internal/logical/rules/operator_fusion.py index 4fcf1d05107f..ed1ce5cf3b04 100644 --- a/python/ray/data/_internal/logical/rules/operator_fusion.py +++ b/python/ray/data/_internal/logical/rules/operator_fusion.py @@ -31,59 +31,62 @@ def apply(self, plan: PhysicalPlan) -> PhysicalPlan: self._op_map = plan.op_map.copy() # Do DFS fusion on compatible pairwise operators in two passes. # In the first pass, only fuse back-to-back map operators together. - op_map_fused = self._fuse_map_to_map_operators(plan.dag) + fused_dag = self._fuse_map_operators_in_dag(plan.dag) # Now that we have fused together all back-to-back map operators, # we fuse together MapOperator -> AllToAllOperator pairs. - op_map_alltoall_fused = self._fuse_map_to_alltoall_operators(op_map_fused) + fused_dag = self._fuse_all_to_all_operators_in_dag(fused_dag) - return PhysicalPlan(op_map_alltoall_fused, self._op_map) + return PhysicalPlan(fused_dag, self._op_map) - def _fuse_map_to_map_operators(self, op: MapOperator) -> MapOperator: + def _fuse_map_operators_in_dag(self, dag: PhysicalOperator) -> MapOperator: """Starting at the given operator, traverses up the DAG of operators and recursively fuses compatible MapOperator -> MapOperator pairs. Returns the current (root) operator after completing upstream operator fusions. """ - upstream_ops = op.input_dependencies + upstream_ops = dag.input_dependencies while ( len(upstream_ops) == 1 - and self._can_fuse(op, upstream_ops[0]) - and isinstance(op, MapOperator) + and isinstance(dag, MapOperator) and isinstance(upstream_ops[0], MapOperator) + and self._can_fuse(dag, upstream_ops[0]) ): # Fuse operator with its upstream op. - op = self._fuse_ops_map_map(op, upstream_ops[0]) - upstream_ops = op.input_dependencies + dag = self._get_fused_map_operator(dag, upstream_ops[0]) + upstream_ops = dag.input_dependencies # Done fusing back-to-back map operators together here, # move up the DAG to find the next map operators to fuse. - op._input_dependencies = [ - self._fuse_map_to_map_operators(upstream_op) for upstream_op in upstream_ops + dag._input_dependencies = [ + self._fuse_map_operators_in_dag(upstream_op) for upstream_op in upstream_ops ] - return op + return dag - def _fuse_map_to_alltoall_operators(self, op: AllToAllOperator) -> AllToAllOperator: + def _fuse_all_to_all_operators_in_dag( + self, dag: AllToAllOperator + ) -> AllToAllOperator: """Starting at the given operator, traverses up the DAG of operators and recursively fuses compatible MapOperator -> AllToAllOperator pairs. Returns the current (root) operator after completing upstream operator fusions. """ - upstream_ops = op.input_dependencies + upstream_ops = dag.input_dependencies while ( len(upstream_ops) == 1 - and self._can_fuse(op, upstream_ops[0]) - and isinstance(op, AllToAllOperator) + and isinstance(dag, AllToAllOperator) and isinstance(upstream_ops[0], MapOperator) + and self._can_fuse(dag, upstream_ops[0]) ): # Fuse operator with its upstream op. 
- op = self._fuse_ops_map_alltoall(op, upstream_ops[0]) - upstream_ops = op.input_dependencies + dag = self._get_fused_all_to_all_operator(dag, upstream_ops[0]) + upstream_ops = dag.input_dependencies # Done fusing MapOperator -> AllToAllOperator together here, # move up the DAG to find the next pair of operators to fuse. - op._input_dependencies = [ - self._fuse_map_to_map_operators(upstream_op) for upstream_op in upstream_ops + dag._input_dependencies = [ + self._fuse_all_to_all_operators_in_dag(upstream_op) + for upstream_op in upstream_ops ] - return op + return dag def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool: """Returns whether the provided downstream operator can be fused with the given @@ -170,7 +173,7 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool: # Otherwise, ops are compatible for fusion. return True - def _fuse_ops_map_map( + def _get_fused_map_operator( self, down_op: MapOperator, up_op: MapOperator ) -> MapOperator: assert self._can_fuse(down_op, up_op), ( @@ -201,7 +204,7 @@ def _fuse_ops_map_map( down_transform_fn = down_op.get_transformation_fn() up_transform_fn = up_op.get_transformation_fn() - def fused_transform_fn_map_map( + def fused_map_transform_fn( blocks: Iterator[Block], ctx: TaskContext ) -> Iterator[Block]: blocks = up_transform_fn(blocks, ctx) @@ -220,7 +223,7 @@ def fused_transform_fn_map_map( # Fused physical map operator. op = MapOperator.create( - fused_transform_fn_map_map, + fused_map_transform_fn, input_op, name=name, compute_strategy=compute, @@ -262,7 +265,7 @@ def fused_transform_fn_map_map( # Return the fused physical operator. return op - def _fuse_ops_map_alltoall( + def _get_fused_all_to_all_operator( self, down_op: AllToAllOperator, up_op: MapOperator ) -> AllToAllOperator: assert self._can_fuse(down_op, up_op), ( @@ -283,14 +286,15 @@ def _fuse_ops_map_alltoall( down_transform_fn = down_op.get_transformation_fn() up_transform_fn = up_op.get_transformation_fn() - def fused_transform_fn_map_alltoall( + def fused_all_to_all_transform_fn( blocks: List[RefBundle], ctx: TaskContext ) -> Tuple[List[RefBundle], StatsDict]: """To fuse MapOperator->AllToAllOperator, we store the map function - in the TaskContext, which is later called in `ShuffleTaskSpec.map`. - Then, we can return an AllToAllOperator which applies the map function - before executing the shuffle.""" - ctx.map_transform_fn = up_transform_fn + in the TaskContext so that it may be used by the downstream + AllToAllOperator's transform function. Then, we can return an + AllToAllOperator which applies the map function before executing + the shuffle.""" + ctx.upstream_map_transform_fn = up_transform_fn return down_transform_fn(blocks, ctx) ray_remote_args = down_logical_op._ray_remote_args @@ -300,7 +304,7 @@ def fused_transform_fn_map_alltoall( input_op = input_deps[0] op = AllToAllOperator( - fused_transform_fn_map_alltoall, + fused_all_to_all_transform_fn, input_op, name=name, ) diff --git a/python/ray/data/_internal/planner/random_shuffle.py b/python/ray/data/_internal/planner/random_shuffle.py index 9f5d5ac4a067..5827c34802d4 100644 --- a/python/ray/data/_internal/planner/random_shuffle.py +++ b/python/ray/data/_internal/planner/random_shuffle.py @@ -33,7 +33,7 @@ def fn( # If map_transform_fn is specified (e.g. from fusing # MapOperator->AllToAllOperator), we pass a map function which # is applied to each block before shuffling. 
- map_transform_fn: Optional[MapTransformFn] = ctx.map_transform_fn + map_transform_fn: Optional[MapTransformFn] = ctx.upstream_map_transform_fn upstream_map_fn = None if map_transform_fn: upstream_map_fn = lambda block: map_transform_fn(block, ctx) # noqa: E731 diff --git a/python/ray/data/tests/test_execution_optimizer.py b/python/ray/data/tests/test_execution_optimizer.py index 0a197b23db43..3e28ddda7fe9 100644 --- a/python/ray/data/tests/test_execution_optimizer.py +++ b/python/ray/data/tests/test_execution_optimizer.py @@ -607,15 +607,38 @@ def __call__(self, x): assert isinstance(physical_op.input_dependencies[0], InputDataBuffer) +def test_read_map_batches_operator_fusion_with_randomize_blocks_operator( + ray_start_regular_shared, enable_optimizer +): + # Note: We currently do not fuse MapBatches->RandomizeBlocks. + # This test is to ensure that we don't accidentally fuse them. + # There is also an additional optimization rule, under ReorderRandomizeBlocksRule, + # which collapses RandomizeBlocks operators, so we should not be fusing them + # to begin with. + def fn(batch): + return {"id": [x + 1 for x in batch["id"]]} + + n = 10 + ds = ray.data.range(n) + ds = ds.randomize_block_order() + ds = ds.map_batches(fn, batch_size=None) + assert set(extract_values("id", ds.take_all())) == set(range(1, n + 1)) + assert "RandomizeBlocks" not in ds.stats() + assert "DoRead->MapBatches->RandomizeBlocks" not in ds.stats() + assert "DoRead->MapBatches" in ds.stats() + _check_usage_record(["ReadRange", "MapBatches", "RandomizeBlocks"]) + + def test_read_map_batches_operator_fusion_with_random_shuffle_operator( ray_start_regular_shared, enable_optimizer ): - # We currently only support fusing MapOperator->AllToAllOperator. + # Note: we currently only support fusing MapOperator->AllToAllOperator. + def fn(batch): + return {"id": [x + 1 for x in batch["id"]]} + n = 10 ds = ray.data.range(n) - ds = ds.map_batches( - lambda batch: {"id": [x + 1 for x in batch["id"]]}, batch_size=None - ) + ds = ds.map_batches(fn, batch_size=None) ds = ds.random_shuffle() assert set(extract_values("id", ds.take_all())) == set(range(1, n + 1)) assert "DoRead->MapBatches->RandomShuffle" in ds.stats() @@ -623,9 +646,7 @@ def test_read_map_batches_operator_fusion_with_random_shuffle_operator( ds = ray.data.range(n) ds = ds.random_shuffle() - ds = ds.map_batches( - lambda batch: {"id": [x + 1 for x in batch["id"]]}, batch_size=None - ) + ds = ds.map_batches(fn, batch_size=None) assert set(extract_values("id", ds.take_all())) == set(range(1, n + 1)) # TODO(Scott): Update below assertion after supporting fusion in # the other direction (AllToAllOperator->MapOperator) @@ -633,6 +654,90 @@ def test_read_map_batches_operator_fusion_with_random_shuffle_operator( assert all(op in ds.stats() for op in ("DoRead", "RandomShuffle", "MapBatches")) _check_usage_record(["ReadRange", "RandomShuffle", "MapBatches"]) + # Test fusing multiple `map_batches` with multiple `random_shuffle` operations. + ds = ray.data.range(n) + for _ in range(5): + ds = ds.map_batches(fn, batch_size=None) + ds = ds.random_shuffle() + assert set(extract_values("id", ds.take_all())) == set(range(5, n + 5)) + assert f"DoRead->{'MapBatches->' * 5}RandomShuffle" in ds.stats() + + # For interweaved map_batches and random_shuffle operations, we expect to fuse the + # two pairs of MapBatches->RandomShuffle, but not the resulting + # RandomShuffle operators. 
+ ds = ray.data.range(n) + ds = ds.map_batches(fn, batch_size=None) + ds = ds.random_shuffle() + ds = ds.map_batches(fn, batch_size=None) + ds = ds.random_shuffle() + assert set(extract_values("id", ds.take_all())) == set(range(2, n + 2)) + assert "Stage 1 DoRead->MapBatches->RandomShuffle" in ds.stats() + assert "Stage 2 MapBatches->RandomShuffle" + _check_usage_record(["ReadRange", "RandomShuffle", "MapBatches"]) + + +def test_read_map_batches_operator_fusion_with_repartition_operator( + ray_start_regular_shared, enable_optimizer +): + # Note: We currently do not fuse MapBatches->Repartition. + # This test is to ensure that we don't accidentally fuse them, until + # we implement it later. + def fn(batch): + return {"id": [x + 1 for x in batch["id"]]} + + n = 10 + ds = ray.data.range(n) + ds = ds.map_batches(fn, batch_size=None) + ds = ds.repartition(2) + assert set(extract_values("id", ds.take_all())) == set(range(1, n + 1)) + # TODO(Scott): update the below assertions after we support fusion. + assert "DoRead->MapBatches->Repartition" not in ds.stats() + assert "DoRead->MapBatches" in ds.stats() + assert "Repartition" in ds.stats() + _check_usage_record(["ReadRange", "MapBatches", "Repartition"]) + + +def test_read_map_batches_operator_fusion_with_sort_operator( + ray_start_regular_shared, enable_optimizer +): + # Note: We currently do not fuse MapBatches->Sort. + # This test is to ensure that we don't accidentally fuse them, until + # we implement it later. + def fn(batch): + return {"id": [x + 1 for x in batch["id"]]} + + n = 10 + ds = ray.data.range(n) + ds = ds.map_batches(fn, batch_size=None) + ds = ds.sort("id") + assert extract_values("id", ds.take_all()) == list(range(1, n + 1)) + # TODO(Scott): update the below assertions after we support fusion. + assert "DoRead->MapBatches->Sort" not in ds.stats() + assert "DoRead->MapBatches" in ds.stats() + assert "Sort" in ds.stats() + _check_usage_record(["ReadRange", "MapBatches", "Sort"]) + + +def test_read_map_batches_operator_fusion_with_aggregate_operator( + ray_start_regular_shared, enable_optimizer +): + # Note: We currently do not fuse MapBatches->Repartition. + # This test is to ensure that we don't accidentally fuse them, until + # we implement it later. + def fn(batch): + return {"id": [x + 1 for x in batch["id"]]} + + n = 10 + ds = ray.data.range(n) + ds = ds.map_batches(fn, batch_size=None) + ds = ds.repartition(2) + assert set(extract_values("id", ds.take_all())) == set(range(1, n + 1)) + # TODO(Scott): update the below assertions after we support fusion. 
+ assert "DoRead->MapBatches->Repartition" not in ds.stats() + assert "DoRead->MapBatches" in ds.stats() + assert "Repartition" in ds.stats() + _check_usage_record(["ReadRange", "MapBatches", "Repartition"]) + def test_read_map_chain_operator_fusion_e2e(ray_start_regular_shared, enable_optimizer): ds = ray.data.range(10, parallelism=2) From 777bc89d4e4cfa4b51c1dae1a0513f861772152e Mon Sep 17 00:00:00 2001 From: Scott Lee Date: Fri, 5 May 2023 11:42:52 -0700 Subject: [PATCH 6/8] clean up Signed-off-by: Scott Lee --- .../logical/rules/operator_fusion.py | 8 ++--- .../data/tests/test_execution_optimizer.py | 33 ++++++++++++------- 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/python/ray/data/_internal/logical/rules/operator_fusion.py b/python/ray/data/_internal/logical/rules/operator_fusion.py index ed1ce5cf3b04..293249c105fb 100644 --- a/python/ray/data/_internal/logical/rules/operator_fusion.py +++ b/python/ray/data/_internal/logical/rules/operator_fusion.py @@ -107,8 +107,8 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool: # We currently only support fusing for the following cases: # - MapOperator -> MapOperator - # - MapOperator -> AllToAllOperator (only - # RandomShuffle LogicalOperator is currently supported) + # - MapOperator -> AllToAllOperator (only RandomShuffle + # LogicalOperator is currently supported) if not isinstance(down_op, (MapOperator, AllToAllOperator)) or not isinstance( up_op, MapOperator ): @@ -277,10 +277,6 @@ def _get_fused_all_to_all_operator( name = up_op.name + "->" + down_op.name down_logical_op: AbstractAllToAll = self._op_map.pop(down_op) up_logical_op: AbstractUDFMap = self._op_map.pop(up_op) - assert isinstance(down_logical_op, RandomShuffle), ( - "Current rule supports fusing RandomShuffle downstream operators only, " - f"but got {type(down_logical_op).__name__}" - ) # Fuse transformation functions. down_transform_fn = down_op.get_transformation_fn() diff --git a/python/ray/data/tests/test_execution_optimizer.py b/python/ray/data/tests/test_execution_optimizer.py index 3e28ddda7fe9..b2b4450d70a5 100644 --- a/python/ray/data/tests/test_execution_optimizer.py +++ b/python/ray/data/tests/test_execution_optimizer.py @@ -721,22 +721,31 @@ def fn(batch): def test_read_map_batches_operator_fusion_with_aggregate_operator( ray_start_regular_shared, enable_optimizer ): - # Note: We currently do not fuse MapBatches->Repartition. + from ray.data.aggregate import AggregateFn + + # Note: We currently do not fuse MapBatches->Aggregate. # This test is to ensure that we don't accidentally fuse them, until # we implement it later. def fn(batch): - return {"id": [x + 1 for x in batch["id"]]} - - n = 10 - ds = ray.data.range(n) - ds = ds.map_batches(fn, batch_size=None) - ds = ds.repartition(2) - assert set(extract_values("id", ds.take_all())) == set(range(1, n + 1)) + return {"id": [x % 2 for x in batch["id"]]} + + n = 100 + grouped_ds = ray.data.range(n).map_batches(fn, batch_size=None).groupby("id") + agg_ds = grouped_ds.aggregate( + AggregateFn( + init=lambda k: [0, 0], + accumulate_row=lambda a, r: [a[0] + r["id"], a[1] + 1], + merge=lambda a1, a2: [a1[0] + a2[0], a1[1] + a2[1]], + finalize=lambda a: a[0] / a[1], + name="foo", + ), + ) + agg_ds.take_all() == [{"id": 0, "foo": 0.0}, {"id": 1, "foo": 1.0}] # TODO(Scott): update the below assertions after we support fusion. 
- assert "DoRead->MapBatches->Repartition" not in ds.stats() - assert "DoRead->MapBatches" in ds.stats() - assert "Repartition" in ds.stats() - _check_usage_record(["ReadRange", "MapBatches", "Repartition"]) + assert "DoRead->MapBatches->Aggregate" not in agg_ds.stats() + assert "DoRead->MapBatches" in agg_ds.stats() + assert "Aggregate" in agg_ds.stats() + _check_usage_record(["ReadRange", "MapBatches", "Aggregate"]) def test_read_map_chain_operator_fusion_e2e(ray_start_regular_shared, enable_optimizer): From 5880507b5e2ca5bf7b29d7a75842766a5b49e05f Mon Sep 17 00:00:00 2001 From: Scott Lee Date: Fri, 5 May 2023 12:02:08 -0700 Subject: [PATCH 7/8] lints Signed-off-by: Scott Lee --- python/ray/data/_internal/logical/rules/operator_fusion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/ray/data/_internal/logical/rules/operator_fusion.py b/python/ray/data/_internal/logical/rules/operator_fusion.py index 293249c105fb..edddbaf53d45 100644 --- a/python/ray/data/_internal/logical/rules/operator_fusion.py +++ b/python/ray/data/_internal/logical/rules/operator_fusion.py @@ -183,6 +183,7 @@ def _get_fused_map_operator( # Fuse operator names. name = up_op.name + "->" + down_op.name + down_logical_op = self._op_map.pop(down_op) up_logical_op = self._op_map.pop(up_op) @@ -214,7 +215,6 @@ def fused_map_transform_fn( # We take the downstream op's compute in case we're fusing upstream tasks with a # downstream actor pool (e.g. read->map). compute = get_compute(down_logical_op._compute) - ray_remote_args = down_logical_op._ray_remote_args # Make the upstream operator's inputs the new, fused operator's inputs. input_deps = up_op.input_dependencies @@ -275,6 +275,7 @@ def _get_fused_all_to_all_operator( # Fuse operator names. name = up_op.name + "->" + down_op.name + down_logical_op: AbstractAllToAll = self._op_map.pop(down_op) up_logical_op: AbstractUDFMap = self._op_map.pop(up_op) From be501bb2589071fa5f243ff61452f92a25f8ef34 Mon Sep 17 00:00:00 2001 From: Scott Lee Date: Fri, 5 May 2023 12:03:06 -0700 Subject: [PATCH 8/8] clean up Signed-off-by: Scott Lee --- python/ray/data/_internal/logical/rules/operator_fusion.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/ray/data/_internal/logical/rules/operator_fusion.py b/python/ray/data/_internal/logical/rules/operator_fusion.py index edddbaf53d45..eb5d0b2e0820 100644 --- a/python/ray/data/_internal/logical/rules/operator_fusion.py +++ b/python/ray/data/_internal/logical/rules/operator_fusion.py @@ -288,9 +288,7 @@ def fused_all_to_all_transform_fn( ) -> Tuple[List[RefBundle], StatsDict]: """To fuse MapOperator->AllToAllOperator, we store the map function in the TaskContext so that it may be used by the downstream - AllToAllOperator's transform function. Then, we can return an - AllToAllOperator which applies the map function before executing - the shuffle.""" + AllToAllOperator's transform function.""" ctx.upstream_map_transform_fn = up_transform_fn return down_transform_fn(blocks, ctx)