
[Data] Allow fusing MapOperator -> AllToAllOperator #34847

Merged · 13 commits · May 5, 2023
4 changes: 4 additions & 0 deletions python/ray/data/_internal/execution/interfaces.py
@@ -233,6 +233,10 @@ class TaskContext:
# TODO(chengsu): clean it up from TaskContext with new optimizer framework.
sub_progress_bar_dict: Optional[Dict[str, ProgressBar]] = None

# The underlying function called in a MapOperator; this is used when fusing
# an AllToAllOperator with an upstream MapOperator.
upstream_map_transform_fn: Optional["MapTransformFn"] = None


# Block transform function applied by task and actor pools in MapOperator.
MapTransformFn = Callable[[Iterable[Block], TaskContext], Iterable[Block]]
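For context on the new field: MapTransformFn, defined just below it, is a callable over an iterable of blocks plus the task context. The following is an illustrative sketch of a function with that shape, using plain Python lists as stand-in blocks (not part of the diff, and not Ray code):

from typing import Any, Iterable, List

ToyBlock = List[int]  # stand-in for a Ray Data Block

def double_rows(blocks: Iterable[ToyBlock], ctx: Any = None) -> Iterable[ToyBlock]:
    # Same shape as MapTransformFn: (Iterable[Block], TaskContext) -> Iterable[Block].
    for block in blocks:
        yield [x * 2 for x in block]

# When a MapOperator is fused into an AllToAllOperator, a callable with this shape
# is stored on TaskContext.upstream_map_transform_fn so that the shuffle tasks can
# apply it to each block before exchanging data.
print(list(double_rows([[1, 2], [3]])))  # [[2, 4], [6]]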
python/ray/data/_internal/logical/operators/all_to_all_operator.py
@@ -52,12 +52,13 @@ class RandomShuffle(AbstractAllToAll):
def __init__(
self,
input_op: LogicalOperator,
name: str = "RandomShuffle",
seed: Optional[int] = None,
num_outputs: Optional[int] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
):
super().__init__(
"RandomShuffle",
name,
Review comment (Contributor Author): this is needed so we can override the name with the upstream operator-fused name.

input_op,
num_outputs=num_outputs,
ray_remote_args=ray_remote_args,
180 changes: 143 additions & 37 deletions python/ray/data/_internal/logical/rules/operator_fusion.py
@@ -1,11 +1,23 @@
from typing import Iterator
from typing import Iterator, List, Tuple
from ray.data._internal.execution.operators.map_operator import MapOperator
from ray.data._internal.logical.operators.all_to_all_operator import (
AbstractAllToAll,
RandomShuffle,
)
from ray.data._internal.stats import StatsDict

from ray.data.block import Block

# TODO(Clark): Remove compute dependency once we delete the legacy compute.
from ray.data._internal.compute import is_task_compute, CallableClass, get_compute
from ray.data._internal.execution.interfaces import PhysicalOperator, TaskContext
from ray.data._internal.execution.interfaces import (
PhysicalOperator,
RefBundle,
TaskContext,
)
from ray.data._internal.logical.interfaces import Rule, PhysicalPlan
from ray.data._internal.execution.operators.all_to_all_operator import AllToAllOperator
from ray.data._internal.logical.operators.map_operator import AbstractUDFMap


# Scheduling strategy can be inherited from upstream operator if not specified.
@@ -17,35 +29,72 @@ class OperatorFusionRule(Rule):

def apply(self, plan: PhysicalPlan) -> PhysicalPlan:
self._op_map = plan.op_map.copy()
# Do DFS fusion.
root = self._apply(plan.dag)
return PhysicalPlan(root, self._op_map)
# Do DFS fusion on compatible pairwise operators in two passes.
# In the first pass, only fuse back-to-back map operators together.
fused_dag = self._fuse_map_operators_in_dag(plan.dag)

def _apply(self, op: PhysicalOperator) -> PhysicalOperator:
"""Performs DFS fusion of linear chains of physical map operators, provided that
they are pairwise-compatible.
# Now that we have fused together all back-to-back map operators,
# we fuse together MapOperator -> AllToAllOperator pairs.
fused_dag = self._fuse_all_to_all_operators_in_dag(fused_dag)

Args:
op: The op that we're trying to fuse with its input.
return PhysicalPlan(fused_dag, self._op_map)

def _fuse_map_operators_in_dag(self, dag: PhysicalOperator) -> MapOperator:
"""Starting at the given operator, traverses up the DAG of operators
and recursively fuses compatible MapOperator -> MapOperator pairs.
Returns the current (root) operator after completing upstream operator fusions.
"""
upstream_ops = op.input_dependencies
# Fuse with upstream ops while possible.
while len(upstream_ops) == 1 and self._can_fuse(op, upstream_ops[0]):
upstream_ops = dag.input_dependencies
while (
len(upstream_ops) == 1
and isinstance(dag, MapOperator)
and isinstance(upstream_ops[0], MapOperator)
and self._can_fuse(dag, upstream_ops[0])
):
# Fuse operator with its upstream op.
op = self._fuse(op, upstream_ops[0])
upstream_ops = op.input_dependencies
# Can no longer fuse with upstream ops, proceed up the DAG.
op._input_dependencies = [
self._apply(upstream_op) for upstream_op in upstream_ops
dag = self._get_fused_map_operator(dag, upstream_ops[0])
upstream_ops = dag.input_dependencies

# Done fusing back-to-back map operators together here,
# move up the DAG to find the next map operators to fuse.
dag._input_dependencies = [
self._fuse_map_operators_in_dag(upstream_op) for upstream_op in upstream_ops
]
return op
return dag

def _fuse_all_to_all_operators_in_dag(
self, dag: AllToAllOperator
) -> AllToAllOperator:
"""Starting at the given operator, traverses up the DAG of operators
and recursively fuses compatible MapOperator -> AllToAllOperator pairs.
Returns the current (root) operator after completing upstream operator fusions.
"""
upstream_ops = dag.input_dependencies
while (
len(upstream_ops) == 1
and isinstance(dag, AllToAllOperator)
and isinstance(upstream_ops[0], MapOperator)
and self._can_fuse(dag, upstream_ops[0])
):
# Fuse operator with its upstream op.
dag = self._get_fused_all_to_all_operator(dag, upstream_ops[0])
upstream_ops = dag.input_dependencies

# Done fusing MapOperator -> AllToAllOperator together here,
# move up the DAG to find the next pair of operators to fuse.
dag._input_dependencies = [
self._fuse_all_to_all_operators_in_dag(upstream_op)
for upstream_op in upstream_ops
]
return dag
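
Both traversal methods above follow the same pattern: walk up a linear chain, fusing adjacent compatible operators, then recurse into the remaining inputs. A self-contained toy version of that pattern on generic nodes (illustrative only; no Ray types):

from typing import Callable, List, Optional

class Node:
    def __init__(self, name: str, inputs: Optional[List["Node"]] = None):
        self.name = name
        self.inputs = inputs or []

def fuse_chain(node: Node, can_fuse: Callable[[Node, Node], bool]) -> Node:
    # Fuse `node` with its single upstream input while the pair is compatible.
    while len(node.inputs) == 1 and can_fuse(node, node.inputs[0]):
        up = node.inputs[0]
        node = Node(up.name + "->" + node.name, up.inputs)
    # No more fusion at this level; recurse into the remaining upstream nodes.
    node.inputs = [fuse_chain(up, can_fuse) for up in node.inputs]
    return node

# Example: Read -> Map -> Shuffle, where only a Map may fuse into its consumer.
dag = Node("Shuffle", [Node("Map", [Node("Read")])])
fused = fuse_chain(dag, lambda down, up: up.name == "Map")
print(fused.name)                        # Map->Shuffle
print([n.name for n in fused.inputs])    # ['Read']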

def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
"""Returns whether the provided downstream operator can be fused with the given
upstream operator.

We currently support fusing two operators if the following are all true:
* They are both MapOperators.
* We are fusing either MapOperator -> MapOperator or
MapOperator -> AllToAllOperator.
* They either use the same compute configuration, or the upstream operator
uses a task pool while the downstream operator uses an actor pool.
* If both operators involve callable classes, the callable classes are
@@ -56,8 +105,13 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
from ray.data._internal.logical.operators.map_operator import AbstractMap
from ray.data._internal.logical.operators.map_operator import AbstractUDFMap

# We only support fusing MapOperators.
if not isinstance(down_op, MapOperator) or not isinstance(up_op, MapOperator):
# We currently only support fusing for the following cases:
# - MapOperator -> MapOperator
# - MapOperator -> AllToAllOperator (only RandomShuffle
# LogicalOperator is currently supported)
if not isinstance(down_op, (MapOperator, AllToAllOperator)) or not isinstance(
up_op, MapOperator
):
return False

down_logical_op = self._op_map[down_op]
@@ -68,17 +122,20 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
if not down_logical_op._input_dependencies:
return False

# We only support fusing AbstractMap -> AbstractMap operators.
if not isinstance(down_logical_op, AbstractMap) or not isinstance(
up_logical_op, AbstractMap
):
# We currently only support fusing for the following cases:
# - AbstractMap -> AbstractMap
# - AbstractMap -> RandomShuffle
if not isinstance(
down_logical_op, (AbstractMap, RandomShuffle)
Review comment (Contributor): so we will support Repartition in a followup PR right?
Review comment (Contributor Author): yes, we will do this in a future PR.

) or not isinstance(up_logical_op, AbstractMap):
return False

# Allow fusing tasks->actors if the resources are compatible (read->map), but
# not the other way around. The latter (downstream op) will be used as the
# compute if fused.
if (
is_task_compute(down_logical_op._compute)
isinstance(down_logical_op, AbstractUDFMap)
and is_task_compute(down_logical_op._compute)
and isinstance(up_logical_op, AbstractUDFMap)
and get_compute(up_logical_op._compute)
!= get_compute(down_logical_op._compute)
@@ -116,12 +173,13 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
# Otherwise, ops are compatible for fusion.
return True

def _fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator):
"""Fuse the downstream operator with its upstream operator."""
from ray.data._internal.execution.operators.map_operator import MapOperator
from ray.data._internal.logical.operators.map_operator import AbstractUDFMap

assert self._can_fuse(down_op, up_op)
def _get_fused_map_operator(
self, down_op: MapOperator, up_op: MapOperator
) -> MapOperator:
assert self._can_fuse(down_op, up_op), (
"Current rule supports fusing MapOperator->MapOperator, but received: "
f"{type(up_op).__name__} -> {type(down_op).__name__}"
)

# Fuse operator names.
name = up_op.name + "->" + down_op.name
@@ -147,9 +205,11 @@ def _fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator):
down_transform_fn = down_op.get_transformation_fn()
up_transform_fn = up_op.get_transformation_fn()

def transform_fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
def fused_map_transform_fn(
blocks: Iterator[Block], ctx: TaskContext
) -> Iterator[Block]:
blocks = up_transform_fn(blocks, ctx)
# TODO(Clark): Add zero-copy batching between transform functions.
# TODO(Scott): Add zero-copy batching between transform functions.
return down_transform_fn(blocks, ctx)

# We take the downstream op's compute in case we're fusing upstream tasks with a
@@ -163,7 +223,7 @@ def transform_fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:

# Fused physical map operator.
op = MapOperator.create(
transform_fn,
fused_map_transform_fn,
input_op,
name=name,
compute_strategy=compute,
Expand All @@ -172,7 +232,7 @@ def transform_fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
)

# Build a map logical operator to be used as a reference for further fusion.
# TODO(Clark): This is hacky, remove this once we push fusion to be purely based
# TODO(Scott): This is hacky, remove this once we push fusion to be purely based
# on a lower-level operator spec.
if isinstance(up_logical_op, AbstractUDFMap):
input_op = up_logical_op.input_dependencies[0]
@@ -205,6 +265,52 @@ def transform_fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
# Return the fused physical operator.
return op
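
The fused transform above is plain composition of two block-iterator functions. A tiny self-contained illustration with list blocks (not Ray code):

from typing import Iterable, List

ToyBlock = List[int]

def add_one(blocks: Iterable[ToyBlock], ctx=None) -> Iterable[ToyBlock]:
    for b in blocks:
        yield [x + 1 for x in b]

def keep_even(blocks: Iterable[ToyBlock], ctx=None) -> Iterable[ToyBlock]:
    for b in blocks:
        yield [x for x in b if x % 2 == 0]

def fused(blocks: Iterable[ToyBlock], ctx=None) -> Iterable[ToyBlock]:
    # Same shape as fused_map_transform_fn: run the upstream transform, then feed
    # its output iterator into the downstream transform.
    return keep_even(add_one(blocks, ctx), ctx)

print(list(fused([[1, 2, 3], [4, 5]])))  # [[2, 4], [6]]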

def _get_fused_all_to_all_operator(
self, down_op: AllToAllOperator, up_op: MapOperator
) -> AllToAllOperator:
assert self._can_fuse(down_op, up_op), (
"Current rule supports fusing MapOperator -> AllToAllOperator"
f", but received: {type(up_op).__name__} -> {type(down_op).__name__}"
)

# Fuse operator names.
name = up_op.name + "->" + down_op.name

down_logical_op: AbstractAllToAll = self._op_map.pop(down_op)
up_logical_op: AbstractUDFMap = self._op_map.pop(up_op)

# Fuse transformation functions.
down_transform_fn = down_op.get_transformation_fn()
up_transform_fn = up_op.get_transformation_fn()

def fused_all_to_all_transform_fn(
blocks: List[RefBundle], ctx: TaskContext
) -> Tuple[List[RefBundle], StatsDict]:
"""To fuse MapOperator->AllToAllOperator, we store the map function
in the TaskContext so that it may be used by the downstream
AllToAllOperator's transform function."""
ctx.upstream_map_transform_fn = up_transform_fn
return down_transform_fn(blocks, ctx)

ray_remote_args = down_logical_op._ray_remote_args
# Make the upstream operator's inputs the new, fused operator's inputs.
input_deps = up_op.input_dependencies
assert len(input_deps) == 1
input_op = input_deps[0]

op = AllToAllOperator(
fused_all_to_all_transform_fn,
input_op,
name=name,
)
# Bottom out at the source logical op (e.g. Read()).
input_op = up_logical_op

logical_op = RandomShuffle(input_op, name=name, ray_remote_args=ray_remote_args)
Review comment (Contributor): how do we support other operator like Repartition?

self._op_map[op] = logical_op
# Return the fused physical operator.
return op


def _are_remote_args_compatible(up_args, down_args):
"""Check if Ray remote arguments are compatible for merging."""
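
For orientation, a minimal user-level pipeline of the shape this rule targets (assuming a Ray 2.4-era Dataset API; whether fusion actually applies depends on the new optimizer being enabled in your build):

import ray

ds = ray.data.range(1000).map_batches(lambda batch: batch).random_shuffle()
ds.take(5)
# With this rule, the map stage can be folded into the shuffle's map tasks, so the
# physical plan carries a single fused operator named along the lines of
# "ReadRange->MapBatches(<lambda>)->RandomShuffle" instead of separate stages.
print(ds.stats())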
python/ray/data/_internal/planner/exchange/shuffle_task_spec.py
@@ -4,6 +4,7 @@
import numpy as np

from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.execution.interfaces import MapTransformFn
from ray.data._internal.planner.exchange.interfaces import ExchangeTaskSpec
from ray.data.block import Block, BlockAccessor, BlockExecStats, BlockMetadata

@@ -19,9 +20,10 @@ def __init__(
self,
random_shuffle: bool = False,
random_seed: Optional[int] = None,
upstream_map_fn: Optional[MapTransformFn] = None,
):
super().__init__(
map_args=[random_shuffle, random_seed],
map_args=[upstream_map_fn, random_shuffle, random_seed],
reduce_args=[random_shuffle, random_seed],
)

@@ -30,11 +32,19 @@ def map(
idx: int,
block: Block,
output_num_blocks: int,
upstream_map_fn: Optional[MapTransformFn],
random_shuffle: bool,
random_seed: Optional[int],
) -> List[Union[BlockMetadata, Block]]:
# TODO: Support fusion with other upstream operators.
stats = BlockExecStats.builder()
if upstream_map_fn:
mapped_blocks = list(upstream_map_fn([block]))
assert len(mapped_blocks) == 1, (
"Expected upstream_map_fn to return one block, but instead"
f" returned {len(mapped_blocks)} blocks"
)
block = mapped_blocks[0]
block = BlockAccessor.for_block(block)

# Randomize the distribution of records to blocks.
16 changes: 15 additions & 1 deletion python/ray/data/_internal/planner/random_shuffle.py
@@ -2,6 +2,7 @@

from ray.data._internal.execution.interfaces import (
AllToAllTransformFn,
MapTransformFn,
RefBundle,
TaskContext,
)
@@ -28,7 +29,20 @@ def fn(
ctx: TaskContext,
) -> Tuple[List[RefBundle], StatsDict]:
num_input_blocks = sum(len(r.blocks) for r in refs)
shuffle_spec = ShuffleTaskSpec(random_shuffle=True, random_seed=seed)

# If map_transform_fn is specified (e.g. from fusing
# MapOperator->AllToAllOperator), we pass a map function which
# is applied to each block before shuffling.
map_transform_fn: Optional[MapTransformFn] = ctx.upstream_map_transform_fn
upstream_map_fn = None
if map_transform_fn:
upstream_map_fn = lambda block: map_transform_fn(block, ctx) # noqa: E731

shuffle_spec = ShuffleTaskSpec(
random_shuffle=True,
random_seed=seed,
upstream_map_fn=upstream_map_fn,
)

if DataContext.get_current().use_push_based_shuffle:
if num_outputs is not None:
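
Putting the pieces together: the planner above reads the stashed callable off the TaskContext, binds the context, and each shuffle map task applies it to its block before partitioning. A self-contained toy of that wiring (plain Python stand-ins, not Ray internals):

from typing import Callable, Iterable, List, Optional

ToyBlock = List[int]  # stand-in for a Ray Data block
MapFn = Callable[[Iterable[ToyBlock]], Iterable[ToyBlock]]

def shuffle_map_task(block: ToyBlock, upstream_map_fn: Optional[MapFn]) -> ToyBlock:
    # Mirrors ShuffleTaskSpec.map: apply the fused map fn to the single input block.
    if upstream_map_fn is not None:
        mapped = list(upstream_map_fn([block]))
        assert len(mapped) == 1, "expected the upstream map fn to yield exactly one block"
        block = mapped[0]
    return sorted(block)  # stand-in for the real partition/shuffle work

# Mirrors the planner: a transform fn stashed on the context, bound so downstream
# code only needs a blocks -> blocks callable.
def stashed_transform(blocks: Iterable[ToyBlock], _ctx=None) -> Iterable[ToyBlock]:
    return ([x * 2 for x in b] for b in blocks)

ctx = object()  # placeholder context
bound = lambda blocks: stashed_transform(blocks, ctx)  # noqa: E731

print(shuffle_map_task([3, 1, 2], bound))  # [2, 4, 6]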