diff --git a/python/ray/data/_internal/logical/operators/write_operator.py b/python/ray/data/_internal/logical/operators/write_operator.py
new file mode 100644
index 000000000000..f85b513e37f1
--- /dev/null
+++ b/python/ray/data/_internal/logical/operators/write_operator.py
@@ -0,0 +1,28 @@
+from typing import Any, Dict, Optional
+
+from ray.data._internal.logical.interfaces import LogicalOperator
+from ray.data._internal.logical.operators.map_operator import AbstractMap
+from ray.data.datasource.datasource import Datasource
+
+
+class Write(AbstractMap):
+    """Logical operator for write."""
+
+    def __init__(
+        self,
+        input_op: LogicalOperator,
+        datasource: Datasource,
+        ray_remote_args: Optional[Dict[str, Any]] = None,
+        **write_args,
+    ):
+        super().__init__(
+            "Write",
+            input_op,
+            ray_remote_args,
+        )
+        self._datasource = datasource
+        self._write_args = write_args
+        # Always use tasks to write.
+        self._compute = "tasks"
+        # Pass the input blocks through unchanged while writing.
+        self._target_block_size = float("inf")
diff --git a/python/ray/data/_internal/logical/rules/operator_fusion.py b/python/ray/data/_internal/logical/rules/operator_fusion.py
index 08f515e11ce0..ea8e91dc6b59 100644
--- a/python/ray/data/_internal/logical/rules/operator_fusion.py
+++ b/python/ray/data/_internal/logical/rules/operator_fusion.py
@@ -53,8 +53,8 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
         * They have compatible remote arguments.
         """
         from ray.data._internal.execution.operators.map_operator import MapOperator
+        from ray.data._internal.logical.operators.map_operator import AbstractMap
         from ray.data._internal.logical.operators.map_operator import AbstractUDFMap
-        from ray.data._internal.logical.operators.read_operator import Read

         # We only support fusing MapOperators.
         if not isinstance(down_op, MapOperator) or not isinstance(up_op, MapOperator):
@@ -63,9 +63,14 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
         down_logical_op = self._op_map[down_op]
         up_logical_op = self._op_map[up_op]

-        # We only support fusing upstream reads and maps with downstream maps.
-        if not isinstance(down_logical_op, AbstractUDFMap) or not isinstance(
-            up_logical_op, (Read, AbstractUDFMap)
+        # If the downstream operator takes no input, it cannot be fused with
+        # the upstream operator.
+        if not down_logical_op._input_dependencies:
+            return False
+
+        # We only support fusing AbstractMap -> AbstractMap operators.
+        if not isinstance(down_logical_op, AbstractMap) or not isinstance(
+            up_logical_op, AbstractMap
         ):
             return False

@@ -81,11 +86,13 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
             return False

         # Fusing callable classes is only supported if they are the same function AND
-        # their construction arguments are the same.
+        # their construction arguments are the same. Note that Write is compatible
+        # with any UDF, as Write itself doesn't have a UDF.
         # TODO(Clark): Support multiple callable classes instantiating in the same actor
         # worker.
         if (
-            isinstance(down_logical_op._fn, CallableClass)
+            isinstance(down_logical_op, AbstractUDFMap)
+            and isinstance(down_logical_op._fn, CallableClass)
             and isinstance(up_logical_op, AbstractUDFMap)
             and isinstance(up_logical_op._fn, CallableClass)
             and (
@@ -172,18 +179,28 @@ def transform_fn(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
         else:
             # Bottom out at the source logical op (e.g. Read()).
             input_op = up_logical_op
-        logical_op = AbstractUDFMap(
-            name,
-            input_op,
-            down_logical_op._fn,
-            down_logical_op._fn_args,
-            down_logical_op._fn_kwargs,
-            down_logical_op._fn_constructor_args,
-            down_logical_op._fn_constructor_kwargs,
-            target_block_size,
-            compute,
-            ray_remote_args,
-        )
+        if isinstance(down_logical_op, AbstractUDFMap):
+            logical_op = AbstractUDFMap(
+                name,
+                input_op,
+                down_logical_op._fn,
+                down_logical_op._fn_args,
+                down_logical_op._fn_kwargs,
+                down_logical_op._fn_constructor_args,
+                down_logical_op._fn_constructor_kwargs,
+                target_block_size,
+                compute,
+                ray_remote_args,
+            )
+        else:
+            from ray.data._internal.logical.operators.map_operator import AbstractMap
+
+            # The downstream op is an AbstractMap rather than an AbstractUDFMap.
+            logical_op = AbstractMap(
+                name,
+                input_op,
+                ray_remote_args,
+            )
         self._op_map[op] = logical_op
         # Return the fused physical operator.
         return op
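For readers skimming the patch, here is a minimal sketch of how the pieces above compose: because `Write` now subclasses `AbstractMap`, the fusion rule treats a `Read -> Write` pair like any other map-to-map fusion. This is illustrative only; it assumes a checkout with this patch applied, and the `LogicalPlan`/`ParquetDatasource` import paths are taken from the test changes at the end of this diff.

```python
from ray.data._internal.logical.interfaces import LogicalPlan
from ray.data._internal.logical.operators.read_operator import Read
from ray.data._internal.logical.operators.write_operator import Write
from ray.data.datasource import ParquetDatasource

# Build a Read -> Write logical DAG by hand. Write forwards **write_args
# to Datasource.write(), always computes with tasks, and uses an infinite
# target block size so input blocks pass through unchanged.
datasource = ParquetDatasource()
read_op = Read(datasource)
write_op = Write(read_op, datasource)
plan = LogicalPlan(write_op)
```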
diff --git a/python/ray/data/_internal/planner/plan_write_op.py b/python/ray/data/_internal/planner/plan_write_op.py
new file mode 100644
index 000000000000..0f3b6ee05e77
--- /dev/null
+++ b/python/ray/data/_internal/planner/plan_write_op.py
@@ -0,0 +1,24 @@
+from typing import Iterator
+
+from ray.data._internal.execution.interfaces import (
+    PhysicalOperator,
+    TaskContext,
+)
+from ray.data._internal.execution.operators.map_operator import MapOperator
+from ray.data.block import Block
+from ray.data._internal.planner.write import generate_write_fn
+from ray.data._internal.logical.operators.write_operator import Write
+
+
+def _plan_write_op(op: Write, input_physical_dag: PhysicalOperator) -> PhysicalOperator:
+    transform_fn = generate_write_fn(op._datasource, **op._write_args)
+
+    def do_write(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
+        yield from transform_fn(blocks, ctx)
+
+    return MapOperator.create(
+        do_write,
+        input_physical_dag,
+        name="Write",
+        ray_remote_args=op._ray_remote_args,
+    )
diff --git a/python/ray/data/_internal/planner/planner.py b/python/ray/data/_internal/planner/planner.py
index 02f0ebd238a7..af60ac1a4019 100644
--- a/python/ray/data/_internal/planner/planner.py
+++ b/python/ray/data/_internal/planner/planner.py
@@ -8,10 +8,12 @@
 )
 from ray.data._internal.logical.operators.all_to_all_operator import AbstractAllToAll
 from ray.data._internal.logical.operators.read_operator import Read
+from ray.data._internal.logical.operators.write_operator import Write
 from ray.data._internal.logical.operators.map_operator import AbstractUDFMap
 from ray.data._internal.planner.plan_all_to_all_op import _plan_all_to_all_op
 from ray.data._internal.planner.plan_udf_map_op import _plan_udf_map_op
 from ray.data._internal.planner.plan_read_op import _plan_read_op
+from ray.data._internal.planner.plan_write_op import _plan_write_op


 class Planner:
@@ -38,6 +40,9 @@ def _plan(self, logical_op: LogicalOperator) -> PhysicalOperator:
         if isinstance(logical_op, Read):
             assert not physical_children
             physical_op = _plan_read_op(logical_op)
+        elif isinstance(logical_op, Write):
+            assert len(physical_children) == 1
+            physical_op = _plan_write_op(logical_op, physical_children[0])
         elif isinstance(logical_op, AbstractUDFMap):
             assert len(physical_children) == 1
             physical_op = _plan_udf_map_op(logical_op, physical_children[0])
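Continuing the sketch above, the new `Planner` branch lowers a `Write` logical operator to a `MapOperator` via `_plan_write_op`, mirroring what `test_write_operator` below asserts (again assuming the internal import paths from the hunks above):

```python
from ray.data._internal.execution.operators.map_operator import MapOperator
from ray.data._internal.planner.planner import Planner

# _plan dispatches on the logical operator type: a Write takes its single
# planned child as the input physical DAG and lowers to a MapOperator
# whose transform wraps the write function from generate_write_fn.
physical_op = Planner().plan(plan).dag
assert isinstance(physical_op, MapOperator)
assert len(physical_op.input_dependencies) == 1
```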
diff --git a/python/ray/data/_internal/planner/write.py b/python/ray/data/_internal/planner/write.py
index d85285b8783a..998eaea0ac2e 100644
--- a/python/ray/data/_internal/planner/write.py
+++ b/python/ray/data/_internal/planner/write.py
@@ -1,18 +1,18 @@
 from typing import Callable, Iterator

 from ray.data._internal.execution.interfaces import TaskContext
-from ray.data.block import Block, RowUDF
+from ray.data.block import Block
 from ray.data.datasource import Datasource


 def generate_write_fn(
     datasource: Datasource, **write_args
-) -> Callable[[Iterator[Block], TaskContext, RowUDF], Iterator[Block]]:
+) -> Callable[[Iterator[Block], TaskContext], Iterator[Block]]:
     # If the write op succeeds, the resulting Dataset is a list of
     # WriteResult (one element per write task). Otherwise, an error will
     # be raised. The Datasource can handle execution outcomes with the
     # on_write_complete() and on_write_failed().
-    def fn(blocks: Iterator[Block], ctx, fn) -> Iterator[Block]:
+    def fn(blocks: Iterator[Block], ctx) -> Iterator[Block]:
         return [[datasource.write(blocks, ctx, **write_args)]]

     return fn
diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index 0dff6901ac55..e3ee8c3e2dc8 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -41,8 +41,8 @@
     FlatMap,
     MapRows,
     MapBatches,
-    Write,
 )
+from ray.data._internal.logical.operators.write_operator import Write
 from ray.data._internal.planner.filter import generate_filter_fn
 from ray.data._internal.planner.flat_map import generate_flat_map_fn
 from ray.data._internal.planner.map_batches import generate_map_batches_fn
@@ -2707,10 +2707,15 @@
         )

         if type(datasource).write != Datasource.write:
+            write_fn = generate_write_fn(datasource, **write_args)
+
+            def write_fn_wrapper(blocks: Iterator[Block], ctx, fn) -> Iterator[Block]:
+                return write_fn(blocks, ctx)
+
             plan = self._plan.with_stage(
                 OneToOneStage(
                     "write",
-                    generate_write_fn(datasource, **write_args),
+                    write_fn_wrapper,
                     "tasks",
                     ray_remote_args,
                     fn=lambda x: x,
diff --git a/python/ray/data/tests/test_execution_optimizer.py b/python/ray/data/tests/test_execution_optimizer.py
index 837bfce57e78..e47ee3a16be5 100644
--- a/python/ray/data/tests/test_execution_optimizer.py
+++ b/python/ray/data/tests/test_execution_optimizer.py
@@ -13,6 +13,7 @@
     Sort,
 )
 from ray.data._internal.logical.operators.read_operator import Read
+from ray.data._internal.logical.operators.write_operator import Write
 from ray.data._internal.logical.operators.map_operator import (
     MapRows,
     MapBatches,
@@ -497,12 +498,29 @@ def test_read_map_chain_operator_fusion_e2e(ray_start_regular_shared, enable_opt
     assert name in ds.stats()


-def test_write_operator(ray_start_regular_shared, enable_optimizer, tmp_path):
+def test_write_fusion(ray_start_regular_shared, enable_optimizer, tmp_path):
     ds = ray.data.range(10, parallelism=2)
     ds.write_csv(tmp_path)
     assert "DoRead->Write" in ds._write_ds.stats()


+def test_write_operator(ray_start_regular_shared, enable_optimizer):
+    planner = Planner()
+    datasource = ParquetDatasource()
+    read_op = Read(datasource)
+    op = Write(
+        read_op,
+        datasource,
+    )
+    plan = LogicalPlan(op)
+    physical_op = planner.plan(plan).dag
+
+    assert op.name == "Write"
+    assert isinstance(physical_op, MapOperator)
+    assert len(physical_op.input_dependencies) == 1
+    assert isinstance(physical_op.input_dependencies[0], MapOperator)
+
+
 def test_sort_operator(ray_start_regular_shared, enable_optimizer):
     planner = Planner()
     read_op = Read(ParquetDatasource())
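End to end, none of this changes the public write API. With the new optimizer path enabled (the `enable_optimizer` fixture in the test suite; outside of the tests, how to enable it is an assumption about your Ray configuration), a plain write plans through `Write` and fuses with the upstream read, which is exactly what `test_write_fusion` checks:

```python
import ray

ds = ray.data.range(10, parallelism=2)
ds.write_csv("/tmp/write_fusion_demo")  # output path is illustrative

# The read and write lower to MapOperators and fuse into a single
# "DoRead->Write" physical operator, visible in the write stats.
assert "DoRead->Write" in ds._write_ds.stats()
```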