[Datasets] Make Write op extend AbstractMap operator (ray-project#32538)
Signed-off-by: jianoaix <[email protected]>
jianoaix authored Feb 24, 2023
1 parent b3b8ba8 commit 7fa6a0f
Showing 7 changed files with 121 additions and 24 deletions.
28 changes: 28 additions & 0 deletions python/ray/data/_internal/logical/operators/write_operator.py
@@ -0,0 +1,28 @@
from typing import Any, Dict, Optional

from ray.data._internal.logical.interfaces import LogicalOperator
from ray.data._internal.logical.operators.map_operator import AbstractMap
from ray.data.datasource.datasource import Datasource


class Write(AbstractMap):
"""Logical operator for write."""

def __init__(
self,
input_op: LogicalOperator,
datasource: Datasource,
ray_remote_args: Optional[Dict[str, Any]] = None,
**write_args,
):
super().__init__(
"Write",
input_op,
ray_remote_args,
)
self._datasource = datasource
self._write_args = write_args
# Always use tasks to write.
self._compute = "tasks"
# Take the input blocks unchanged while writing.
self._target_block_size = float("inf")
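
For context, here is a minimal sketch that constructs the new logical operator directly, mirroring the unit test added at the bottom of this diff (the ParquetDatasource import path is an assumption; any Datasource works the same way):

from ray.data._internal.logical.operators.read_operator import Read
from ray.data._internal.logical.operators.write_operator import Write
from ray.data.datasource import ParquetDatasource

# Reuse one datasource as both source and sink to keep the sketch short.
datasource = ParquetDatasource()
read_op = Read(datasource)
write_op = Write(read_op, datasource)

# Write always executes as tasks and passes input blocks through unchanged.
assert write_op._compute == "tasks"
assert write_op._target_block_size == float("inf")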
53 changes: 35 additions & 18 deletions python/ray/data/_internal/logical/rules/operator_fusion.py
@@ -53,8 +53,8 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
* They have compatible remote arguments.
"""
from ray.data._internal.execution.operators.map_operator import MapOperator
from ray.data._internal.logical.operators.map_operator import AbstractMap
from ray.data._internal.logical.operators.map_operator import AbstractUDFMap
from ray.data._internal.logical.operators.read_operator import Read

# We only support fusing MapOperators.
if not isinstance(down_op, MapOperator) or not isinstance(up_op, MapOperator):
@@ -63,9 +63,14 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
down_logical_op = self._op_map[down_op]
up_logical_op = self._op_map[up_op]

# We only support fusing upstream reads and maps with downstream maps.
if not isinstance(down_logical_op, AbstractUDFMap) or not isinstance(
up_logical_op, (Read, AbstractUDFMap)
# If the downstream operator takes no input, it cannot be fused with
# the upstream operator.
if not down_logical_op._input_dependencies:
return False

# We only support fusing AbstractMap -> AbstractMap operators.
if not isinstance(down_logical_op, AbstractMap) or not isinstance(
up_logical_op, AbstractMap
):
return False

@@ -81,11 +86,13 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
return False

# Fusing callable classes is only supported if they are the same function AND
# their construction arguments are the same.
# their construction arguments are the same. Note that Write is compatible
# with any UDF, since Write itself doesn't have a UDF.
# TODO(Clark): Support multiple callable classes instantiating in the same actor
# worker.
if (
isinstance(down_logical_op._fn, CallableClass)
isinstance(down_logical_op, AbstractUDFMap)
and isinstance(down_logical_op._fn, CallableClass)
and isinstance(up_logical_op, AbstractUDFMap)
and isinstance(up_logical_op._fn, CallableClass)
and (
@@ -172,18 +179,28 @@ def transform_fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
else:
# Bottom out at the source logical op (e.g. Read()).
input_op = up_logical_op
logical_op = AbstractUDFMap(
name,
input_op,
down_logical_op._fn,
down_logical_op._fn_args,
down_logical_op._fn_kwargs,
down_logical_op._fn_constructor_args,
down_logical_op._fn_constructor_kwargs,
target_block_size,
compute,
ray_remote_args,
)
if isinstance(down_logical_op, AbstractUDFMap):
logical_op = AbstractUDFMap(
name,
input_op,
down_logical_op._fn,
down_logical_op._fn_args,
down_logical_op._fn_kwargs,
down_logical_op._fn_constructor_args,
down_logical_op._fn_constructor_kwargs,
target_block_size,
compute,
ray_remote_args,
)
else:
from ray.data._internal.logical.operators.map_operator import AbstractMap

# The downstream op is an AbstractMap rather than an AbstractUDFMap.
logical_op = AbstractMap(
name,
input_op,
ray_remote_args,
)
self._op_map[op] = logical_op
# Return the fused physical operator.
return op
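
The practical effect of this rule change is that a Write, which is now an AbstractMap with no UDF, can fuse with its upstream read. A hedged end-to-end sketch, assuming the new optimizer code path is enabled (the enable_optimizer fixture in the tests below turns it on; the output directory is hypothetical):

import ray

# With read/write fusion, both stages execute in a single fused
# "DoRead->Write" operator, as test_write_fusion below verifies.
ds = ray.data.range(10, parallelism=2)
ds.write_csv("/tmp/fusion_demo")
assert "DoRead->Write" in ds._write_ds.stats()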
24 changes: 24 additions & 0 deletions python/ray/data/_internal/planner/plan_write_op.py
@@ -0,0 +1,24 @@
from typing import Iterator

from ray.data._internal.execution.interfaces import (
PhysicalOperator,
TaskContext,
)
from ray.data._internal.execution.operators.map_operator import MapOperator
from ray.data.block import Block
from ray.data._internal.planner.write import generate_write_fn
from ray.data._internal.logical.operators.write_operator import Write


def _plan_write_op(op: Write, input_physical_dag: PhysicalOperator) -> PhysicalOperator:
transform_fn = generate_write_fn(op._datasource, **op._write_args)

def do_write(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
yield from transform_fn(blocks, ctx)

return MapOperator.create(
do_write,
input_physical_dag,
name="Write",
ray_remote_args=op._ray_remote_args,
)
5 changes: 5 additions & 0 deletions python/ray/data/_internal/planner/planner.py
@@ -8,10 +8,12 @@
)
from ray.data._internal.logical.operators.all_to_all_operator import AbstractAllToAll
from ray.data._internal.logical.operators.read_operator import Read
from ray.data._internal.logical.operators.write_operator import Write
from ray.data._internal.logical.operators.map_operator import AbstractUDFMap
from ray.data._internal.planner.plan_all_to_all_op import _plan_all_to_all_op
from ray.data._internal.planner.plan_udf_map_op import _plan_udf_map_op
from ray.data._internal.planner.plan_read_op import _plan_read_op
from ray.data._internal.planner.plan_write_op import _plan_write_op


class Planner:
@@ -38,6 +40,9 @@ def _plan(self, logical_op: LogicalOperator) -> PhysicalOperator:
if isinstance(logical_op, Read):
assert not physical_children
physical_op = _plan_read_op(logical_op)
elif isinstance(logical_op, Write):
assert len(physical_children) == 1
physical_op = _plan_write_op(logical_op, physical_children[0])
elif isinstance(logical_op, AbstractUDFMap):
assert len(physical_children) == 1
physical_op = _plan_udf_map_op(logical_op, physical_children[0])
6 changes: 3 additions & 3 deletions python/ray/data/_internal/planner/write.py
@@ -1,18 +1,18 @@
from typing import Callable, Iterator

from ray.data._internal.execution.interfaces import TaskContext
from ray.data.block import Block, RowUDF
from ray.data.block import Block
from ray.data.datasource import Datasource


def generate_write_fn(
datasource: Datasource, **write_args
) -> Callable[[Iterator[Block], TaskContext, RowUDF], Iterator[Block]]:
) -> Callable[[Iterator[Block], TaskContext], Iterator[Block]]:
# If the write op succeeds, the resulting Dataset is a list of
# WriteResult (one element per write task). Otherwise, an error will
# be raised. The Datasource can handle execution outcomes with the
# on_write_complete() and on_write_failed() callbacks.
def fn(blocks: Iterator[Block], ctx, fn) -> Iterator[Block]:
def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]:
return [[datasource.write(blocks, ctx, **write_args)]]

return fn
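
To make the signature change concrete, here is a hedged sketch that calls the generated write function with a stub datasource (the stub and its return value are hypothetical):

from ray.data._internal.planner.write import generate_write_fn
from ray.data.datasource import Datasource

class _StubDatasource(Datasource):
    # Hypothetical sink: report success instead of persisting blocks.
    def write(self, blocks, ctx, **write_args):
        return "ok"

write_fn = generate_write_fn(_StubDatasource())
# The generated fn now takes exactly (blocks, ctx); the unused RowUDF
# parameter is gone. Each write task yields one WriteResult.
print(write_fn(iter([]), None))  # [["ok"]]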
9 changes: 7 additions & 2 deletions python/ray/data/dataset.py
@@ -41,8 +41,8 @@
FlatMap,
MapRows,
MapBatches,
Write,
)
from ray.data._internal.logical.operators.write_operator import Write
from ray.data._internal.planner.filter import generate_filter_fn
from ray.data._internal.planner.flat_map import generate_flat_map_fn
from ray.data._internal.planner.map_batches import generate_map_batches_fn
@@ -2707,10 +2707,15 @@ def write_datasource(
)

if type(datasource).write != Datasource.write:
write_fn = generate_write_fn(datasource, **write_args)

def write_fn_wrapper(blocks: Iterator[Block], ctx, fn) -> Iterator[Block]:
return write_fn(blocks, ctx)

plan = self._plan.with_stage(
OneToOneStage(
"write",
generate_write_fn(datasource, **write_args),
write_fn_wrapper,
"tasks",
ray_remote_args,
fn=lambda x: x,
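
A hedged usage sketch of this code path: any Datasource subclass that overrides write() is routed through generate_write_fn plus the wrapper above (PrintDatasource here is hypothetical):

import ray
from ray.data.datasource import Datasource

class PrintDatasource(Datasource):
    # Hypothetical sink that counts blocks instead of writing them out.
    def write(self, blocks, ctx, **write_args):
        return f"wrote {sum(1 for _ in blocks)} block(s)"

ray.data.range(8, parallelism=2).write_datasource(PrintDatasource())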
20 changes: 19 additions & 1 deletion python/ray/data/tests/test_execution_optimizer.py
@@ -13,6 +13,7 @@
Sort,
)
from ray.data._internal.logical.operators.read_operator import Read
from ray.data._internal.logical.operators.write_operator import Write
from ray.data._internal.logical.operators.map_operator import (
MapRows,
MapBatches,
@@ -497,12 +498,29 @@ def test_read_map_chain_operator_fusion_e2e(ray_start_regular_shared, enable_optimizer):
assert name in ds.stats()


def test_write_operator(ray_start_regular_shared, enable_optimizer, tmp_path):
def test_write_fusion(ray_start_regular_shared, enable_optimizer, tmp_path):
ds = ray.data.range(10, parallelism=2)
ds.write_csv(tmp_path)
assert "DoRead->Write" in ds._write_ds.stats()


def test_write_operator(ray_start_regular_shared, enable_optimizer):
planner = Planner()
datasource = ParquetDatasource()
read_op = Read(datasource)
op = Write(
read_op,
datasource,
)
plan = LogicalPlan(op)
physical_op = planner.plan(plan).dag

assert op.name == "Write"
assert isinstance(physical_op, MapOperator)
assert len(physical_op.input_dependencies) == 1
assert isinstance(physical_op.input_dependencies[0], MapOperator)


def test_sort_operator(ray_start_regular_shared, enable_optimizer):
planner = Planner()
read_op = Read(ParquetDatasource())
