[Data] Allow fusing MapOperator -> AllToAllOperator #34847
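For context, a minimal sketch of the kind of pipeline this change targets (this is an illustrative example, not taken from the PR itself, and the fused operator name in the comment is likewise illustrative):

```python
import ray

ds = (
    ray.data.range(1000)
    .map_batches(lambda batch: batch)  # planned as a MapOperator
    .random_shuffle()                  # planned as an AllToAllOperator (RandomShuffle)
)
ds.take(1)
# With this rule, the map stage feeding the shuffle can be fused into the
# AllToAllOperator, yielding a single operator named along the lines of
# "ReadRange->MapBatches(<lambda>)->RandomShuffle".
```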
@@ -1,11 +1,23 @@
from typing import Iterator
from typing import Iterator, List, Tuple
from ray.data._internal.execution.operators.map_operator import MapOperator
from ray.data._internal.logical.operators.all_to_all_operator import (
    AbstractAllToAll,
    RandomShuffle,
)
from ray.data._internal.stats import StatsDict

from ray.data.block import Block

# TODO(Clark): Remove compute dependency once we delete the legacy compute.
from ray.data._internal.compute import is_task_compute, CallableClass, get_compute
from ray.data._internal.execution.interfaces import PhysicalOperator, TaskContext
from ray.data._internal.execution.interfaces import (
    PhysicalOperator,
    RefBundle,
    TaskContext,
)
from ray.data._internal.logical.interfaces import Rule, PhysicalPlan
from ray.data._internal.execution.operators.all_to_all_operator import AllToAllOperator
from ray.data._internal.logical.operators.map_operator import AbstractUDFMap


# Scheduling strategy can be inherited from upstream operator if not specified.
@@ -17,35 +29,72 @@ class OperatorFusionRule(Rule):

    def apply(self, plan: PhysicalPlan) -> PhysicalPlan:
        self._op_map = plan.op_map.copy()
        # Do DFS fusion.
        root = self._apply(plan.dag)
        return PhysicalPlan(root, self._op_map)
        # Do DFS fusion on compatible pairwise operators in two passes.
        # In the first pass, only fuse back-to-back map operators together.
        fused_dag = self._fuse_map_operators_in_dag(plan.dag)

    def _apply(self, op: PhysicalOperator) -> PhysicalOperator:
        """Performs DFS fusion of linear chains of physical map operators, provided that
        they are pairwise-compatible.
        # Now that we have fused together all back-to-back map operators,
        # we fuse together MapOperator -> AllToAllOperator pairs.
        fused_dag = self._fuse_all_to_all_operators_in_dag(fused_dag)

        Args:
            op: The op that we're trying to fuse with its input.
        return PhysicalPlan(fused_dag, self._op_map)

    def _fuse_map_operators_in_dag(self, dag: PhysicalOperator) -> MapOperator:
        """Starting at the given operator, traverses up the DAG of operators
        and recursively fuses compatible MapOperator -> MapOperator pairs.
        Returns the current (root) operator after completing upstream operator fusions.
        """
        upstream_ops = op.input_dependencies
        # Fuse with upstream ops while possible.
        while len(upstream_ops) == 1 and self._can_fuse(op, upstream_ops[0]):
        upstream_ops = dag.input_dependencies
        while (
            len(upstream_ops) == 1
            and isinstance(dag, MapOperator)
            and isinstance(upstream_ops[0], MapOperator)
            and self._can_fuse(dag, upstream_ops[0])
        ):
            # Fuse operator with its upstream op.
            op = self._fuse(op, upstream_ops[0])
            upstream_ops = op.input_dependencies
        # Can no longer fuse with upstream ops, proceed up the DAG.
        op._input_dependencies = [
            self._apply(upstream_op) for upstream_op in upstream_ops
            dag = self._get_fused_map_operator(dag, upstream_ops[0])
            upstream_ops = dag.input_dependencies

        # Done fusing back-to-back map operators together here,
        # move up the DAG to find the next map operators to fuse.
        dag._input_dependencies = [
            self._fuse_map_operators_in_dag(upstream_op) for upstream_op in upstream_ops
        ]
        return op
        return dag

    def _fuse_all_to_all_operators_in_dag(
        self, dag: AllToAllOperator
    ) -> AllToAllOperator:
        """Starting at the given operator, traverses up the DAG of operators
        and recursively fuses compatible MapOperator -> AllToAllOperator pairs.
        Returns the current (root) operator after completing upstream operator fusions.
        """
        upstream_ops = dag.input_dependencies
        while (
            len(upstream_ops) == 1
            and isinstance(dag, AllToAllOperator)
            and isinstance(upstream_ops[0], MapOperator)
            and self._can_fuse(dag, upstream_ops[0])
        ):
            # Fuse operator with its upstream op.
            dag = self._get_fused_all_to_all_operator(dag, upstream_ops[0])
            upstream_ops = dag.input_dependencies

        # Done fusing MapOperator -> AllToAllOperator together here,
        # move up the DAG to find the next pair of operators to fuse.
        dag._input_dependencies = [
            self._fuse_all_to_all_operators_in_dag(upstream_op)
            for upstream_op in upstream_ops
        ]
        return dag

    def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
        """Returns whether the provided downstream operator can be fused with the given
        upstream operator.

        We currently support fusing two operators if the following are all true:
        * They are both MapOperators.
        * We are fusing either MapOperator -> MapOperator or
          MapOperator -> AllToAllOperator.
        * They either use the same compute configuration, or the upstream operator
          uses a task pool while the downstream operator uses an actor pool.
        * If both operators involve callable classes, the callable classes are
@@ -56,8 +105,13 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
        from ray.data._internal.logical.operators.map_operator import AbstractMap
        from ray.data._internal.logical.operators.map_operator import AbstractUDFMap

        # We only support fusing MapOperators.
        if not isinstance(down_op, MapOperator) or not isinstance(up_op, MapOperator):
        # We currently only support fusing for the following cases:
        # - MapOperator -> MapOperator
        # - MapOperator -> AllToAllOperator (only RandomShuffle
        #   LogicalOperator is currently supported)
        if not isinstance(down_op, (MapOperator, AllToAllOperator)) or not isinstance(
            up_op, MapOperator
        ):
            return False

        down_logical_op = self._op_map[down_op]
@@ -68,17 +122,20 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
        if not down_logical_op._input_dependencies:
            return False

        # We only support fusing AbstractMap -> AbstractMap operators.
        if not isinstance(down_logical_op, AbstractMap) or not isinstance(
            up_logical_op, AbstractMap
        ):
        # We currently only support fusing for the following cases:
        # - AbstractMap -> AbstractMap
        # - AbstractMap -> RandomShuffle
        if not isinstance(
            down_logical_op, (AbstractMap, RandomShuffle)
(inline review: "so we will support ..." / reply: "yes, we will do this in a future PR.")
        ) or not isinstance(up_logical_op, AbstractMap):
            return False

        # Allow fusing tasks->actors if the resources are compatible (read->map), but
        # not the other way around. The latter (downstream op) will be used as the
        # compute if fused.
        if (
            is_task_compute(down_logical_op._compute)
            isinstance(down_logical_op, AbstractUDFMap)
            and is_task_compute(down_logical_op._compute)
            and isinstance(up_logical_op, AbstractUDFMap)
            and get_compute(up_logical_op._compute)
            != get_compute(down_logical_op._compute)
@@ -116,12 +173,13 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
        # Otherwise, ops are compatible for fusion.
        return True

    def _fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator):
        """Fuse the downstream operator with its upstream operator."""
        from ray.data._internal.execution.operators.map_operator import MapOperator
        from ray.data._internal.logical.operators.map_operator import AbstractUDFMap

        assert self._can_fuse(down_op, up_op)
    def _get_fused_map_operator(
        self, down_op: MapOperator, up_op: MapOperator
    ) -> MapOperator:
        assert self._can_fuse(down_op, up_op), (
            "Current rule supports fusing MapOperator->MapOperator, but received: "
            f"{type(up_op).__name__} -> {type(down_op).__name__}"
        )

        # Fuse operator names.
        name = up_op.name + "->" + down_op.name
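As an aside on the compute-compatibility branch of `_can_fuse` above (upstream tasks may fuse into a downstream actor pool, but not the reverse), here is a standalone restatement under simplifying assumptions; the string values stand in for `TaskPoolStrategy`/`ActorPoolStrategy`, and this helper is illustrative only, not part of the PR:

```python
def compute_compatible_for_fusion(up_compute: str, down_compute: str) -> bool:
    """Illustrative sketch only: 'tasks'/'actors' stand in for the real strategies."""
    if up_compute == down_compute:
        # Same compute strategy on both sides is always fusable.
        return True
    if up_compute == "tasks" and down_compute == "actors":
        # Upstream tasks can be pulled into the downstream actor pool; the
        # downstream compute is what the fused operator uses.
        return True
    # A downstream task-pool op is not fused into an upstream actor-pool op,
    # since that would silently drop the actor pool.
    return False
```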
@@ -147,9 +205,11 @@ def _fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator):
        down_transform_fn = down_op.get_transformation_fn()
        up_transform_fn = up_op.get_transformation_fn()

        def transform_fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
        def fused_map_transform_fn(
            blocks: Iterator[Block], ctx: TaskContext
        ) -> Iterator[Block]:
            blocks = up_transform_fn(blocks, ctx)
            # TODO(Clark): Add zero-copy batching between transform functions.
            # TODO(Scott): Add zero-copy batching between transform functions.
            return down_transform_fn(blocks, ctx)

        # We take the downstream op's compute in case we're fusing upstream tasks with a
@@ -163,7 +223,7 @@ def transform_fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:

        # Fused physical map operator.
        op = MapOperator.create(
            transform_fn,
            fused_map_transform_fn,
            input_op,
            name=name,
            compute_strategy=compute,
@@ -172,7 +232,7 @@ def transform_fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
        )

        # Build a map logical operator to be used as a reference for further fusion.
        # TODO(Clark): This is hacky, remove this once we push fusion to be purely based
        # TODO(Scott): This is hacky, remove this once we push fusion to be purely based
        # on a lower-level operator spec.
        if isinstance(up_logical_op, AbstractUDFMap):
            input_op = up_logical_op.input_dependencies[0]
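Looking back at `fused_map_transform_fn` above, a self-contained sketch of how the two block-level transforms compose as chained generators; plain lists stand in for real blocks, and the `TaskContext` argument is omitted for brevity:

```python
from typing import Iterator, List

Block = List[int]  # stand-in for ray.data.block.Block in this sketch


def up_transform(blocks: Iterator[Block]) -> Iterator[Block]:
    for block in blocks:
        yield [x + 1 for x in block]


def down_transform(blocks: Iterator[Block]) -> Iterator[Block]:
    for block in blocks:
        yield [x * 2 for x in block]


def fused(blocks: Iterator[Block]) -> Iterator[Block]:
    # Same shape as fused_map_transform_fn: stream the upstream output into the
    # downstream transform instead of materializing intermediate blocks.
    return down_transform(up_transform(blocks))


assert list(fused(iter([[1, 2], [3]]))) == [[4, 6], [8]]
```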
@@ -205,6 +265,52 @@ def transform_fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
        # Return the fused physical operator.
        return op

    def _get_fused_all_to_all_operator(
        self, down_op: AllToAllOperator, up_op: MapOperator
    ) -> AllToAllOperator:
        assert self._can_fuse(down_op, up_op), (
            "Current rule supports fusing MapOperator -> AllToAllOperator"
            f", but received: {type(up_op).__name__} -> {type(down_op).__name__}"
        )

        # Fuse operator names.
        name = up_op.name + "->" + down_op.name

        down_logical_op: AbstractAllToAll = self._op_map.pop(down_op)
        up_logical_op: AbstractUDFMap = self._op_map.pop(up_op)

        # Fuse transformation functions.
        down_transform_fn = down_op.get_transformation_fn()
        up_transform_fn = up_op.get_transformation_fn()

        def fused_all_to_all_transform_fn(
            blocks: List[RefBundle], ctx: TaskContext
        ) -> Tuple[List[RefBundle], StatsDict]:
            """To fuse MapOperator->AllToAllOperator, we store the map function
            in the TaskContext so that it may be used by the downstream
            AllToAllOperator's transform function."""
            ctx.upstream_map_transform_fn = up_transform_fn
            return down_transform_fn(blocks, ctx)

        ray_remote_args = down_logical_op._ray_remote_args
        # Make the upstream operator's inputs the new, fused operator's inputs.
        input_deps = up_op.input_dependencies
        assert len(input_deps) == 1
        input_op = input_deps[0]

        op = AllToAllOperator(
            fused_all_to_all_transform_fn,
            input_op,
            name=name,
        )
        # Bottom out at the source logical op (e.g. Read()).
        input_op = up_logical_op

        logical_op = RandomShuffle(input_op, name=name, ray_remote_args=ray_remote_args)
(inline review question: "how do we support other operator like ...")
        self._op_map[op] = logical_op
        # Return the fused physical operator.
        return op


def _are_remote_args_compatible(up_args, down_args):
    """Check if Ray remote arguments are compatible for merging."""
(inline review reply: "this is needed so we can override the name with the upstream operator-fused name.")
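Finally, a rough, self-contained sketch of the handshake that `fused_all_to_all_transform_fn` sets up: the upstream map transform is stashed on the `TaskContext`, and the shuffle side is expected to apply it to incoming blocks before shuffling. Everything below (the toy context, bundle type, and shuffle) is hypothetical and only illustrates the intent:

```python
import random
from typing import Callable, Iterable, List, Optional

Bundle = List[int]  # toy stand-in for a RefBundle's blocks


class ToyTaskContext:
    # Mirrors the attribute set by fused_all_to_all_transform_fn.
    upstream_map_transform_fn: Optional[Callable[[Iterable[Bundle]], Iterable[Bundle]]] = None


def toy_shuffle_transform(bundles: List[Bundle], ctx: ToyTaskContext) -> List[Bundle]:
    map_fn = getattr(ctx, "upstream_map_transform_fn", None)
    if map_fn is not None:
        # Apply the fused-in map transform first, so the shuffle sees
        # already-mapped blocks.
        bundles = list(map_fn(bundles))
    random.shuffle(bundles)
    return bundles


ctx = ToyTaskContext()
ctx.upstream_map_transform_fn = lambda bs: ([x + 1 for x in b] for b in bs)
out = toy_shuffle_transform([[1, 2], [3]], ctx)
assert sorted(out) == [[2, 3], [4]]
```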