diff --git a/python/ray/data/__init__.py b/python/ray/data/__init__.py index 901a282775c5..9a27c0b6585d 100644 --- a/python/ray/data/__init__.py +++ b/python/ray/data/__init__.py @@ -62,8 +62,7 @@ # Module-level cached global functions for callable classes. It needs to be defined here # since it has to be process-global across cloudpickled funcs. -_cached_fn = None -_cached_cls = None +_map_actor_context = None configure_logging() diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index b1747616c655..d1be38f8ac52 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -1,4 +1,8 @@ +import asyncio import collections +import inspect +import queue +from threading import Thread from types import GeneratorType from typing import Any, Callable, Iterable, Iterator, List, Optional @@ -43,6 +47,36 @@ from ray.util.rpdb import _is_ray_debugger_enabled +class _MapActorContext: + def __init__( + self, + udf_map_cls: UserDefinedFunction, + udf_map_fn: Callable[[Any], Any], + is_async: bool, + ): + self.udf_map_cls = udf_map_cls + self.udf_map_fn = udf_map_fn + self.is_async = is_async + self.udf_map_asyncio_loop = None + self.udf_map_asyncio_thread = None + + if is_async: + self._init_async() + + def _init_async(self): + # Only used for callable class with async generator `__call__` method. 
+ loop = asyncio.new_event_loop() + + def run_loop(): + asyncio.set_event_loop(loop) + loop.run_forever() + + thread = Thread(target=run_loop) + thread.start() + self.udf_map_asyncio_loop = loop + self.udf_map_asyncio_thread = thread + + def plan_udf_map_op( op: AbstractUDFMap, physical_children: List[PhysicalOperator] ) -> MapOperator: @@ -104,23 +138,53 @@ def _parse_op_fn(op: AbstractUDFMap): fn_constructor_args = op._fn_constructor_args or () fn_constructor_kwargs = op._fn_constructor_kwargs or {} - op_fn = make_callable_class_concurrent(op_fn) + is_async_gen = inspect.isasyncgenfunction(op._fn.__call__) - def fn(item: Any) -> Any: - assert ray.data._cached_fn is not None - assert ray.data._cached_cls == op_fn - try: - return ray.data._cached_fn(item, *fn_args, **fn_kwargs) - except Exception as e: - _handle_debugger_exception(e) + # TODO(scottjlee): (1) support non-generator async functions + # (2) make the map actor async + if not is_async_gen: + op_fn = make_callable_class_concurrent(op_fn) def init_fn(): - if ray.data._cached_fn is None: - ray.data._cached_cls = op_fn - ray.data._cached_fn = op_fn( - *fn_constructor_args, **fn_constructor_kwargs + if ray.data._map_actor_context is None: + ray.data._map_actor_context = _MapActorContext( + udf_map_cls=op_fn, + udf_map_fn=op_fn( + *fn_constructor_args, + **fn_constructor_kwargs, + ), + is_async=is_async_gen, ) + if is_async_gen: + + async def fn(item: Any) -> Any: + assert ray.data._map_actor_context is not None + assert ray.data._map_actor_context.is_async + + try: + return ray.data._map_actor_context.udf_map_fn( + item, + *fn_args, + **fn_kwargs, + ) + except Exception as e: + _handle_debugger_exception(e) + + else: + + def fn(item: Any) -> Any: + assert ray.data._map_actor_context is not None + assert not ray.data._map_actor_context.is_async + try: + return ray.data._map_actor_context.udf_map_fn( + item, + *fn_args, + **fn_kwargs, + ) + except Exception as e: + _handle_debugger_exception(e) + else: def 
fn(item: Any) -> Any: @@ -158,6 +222,7 @@ def _validate_batch_output(batch: Block) -> None: np.ndarray, collections.abc.Mapping, pd.core.frame.DataFrame, + dict, ), ): raise ValueError( @@ -192,46 +257,99 @@ def _validate_batch_output(batch: Block) -> None: def _generate_transform_fn_for_map_batches( fn: UserDefinedFunction, +) -> MapTransformCallable[DataBatch, DataBatch]: + if inspect.iscoroutinefunction(fn): + # UDF is a callable class with async generator `__call__` method. + transform_fn = _generate_transform_fn_for_async_map_batches(fn) + + else: + + def transform_fn( + batches: Iterable[DataBatch], _: TaskContext + ) -> Iterable[DataBatch]: + for batch in batches: + try: + if ( + not isinstance(batch, collections.abc.Mapping) + and BlockAccessor.for_block(batch).num_rows() == 0 + ): + # For empty input blocks, we directly output them without + # calling the UDF. + # TODO(hchen): This workaround is because some all-to-all + # operators output empty blocks with no schema. + res = [batch] + else: + res = fn(batch) + if not isinstance(res, GeneratorType): + res = [res] + except ValueError as e: + read_only_msgs = [ + "assignment destination is read-only", + "buffer source array is read-only", + ] + err_msg = str(e) + if any(msg in err_msg for msg in read_only_msgs): + raise ValueError( + f"Batch mapper function {fn.__name__} tried to mutate a " + "zero-copy read-only batch. To be able to mutate the " + "batch, pass zero_copy_batch=False to map_batches(); " + "this will create a writable copy of the batch before " + "giving it to fn. To elide this copy, modify your mapper " + "function so it doesn't try to mutate its input." 
+ ) from e + else: + raise e from None + else: + for out_batch in res: + _validate_batch_output(out_batch) + yield out_batch + + return transform_fn + + +def _generate_transform_fn_for_async_map_batches( + fn: UserDefinedFunction, ) -> MapTransformCallable[DataBatch, DataBatch]: def transform_fn( - batches: Iterable[DataBatch], _: TaskContext + input_iterable: Iterable[DataBatch], _: TaskContext ) -> Iterable[DataBatch]: - for batch in batches: - try: - if ( - not isinstance(batch, collections.abc.Mapping) - and BlockAccessor.for_block(batch).num_rows() == 0 - ): - # For empty input blocks, we directly ouptut them without - # calling the UDF. - # TODO(hchen): This workaround is because some all-to-all - # operators output empty blocks with no schema. - res = [batch] - else: - res = fn(batch) - if not isinstance(res, GeneratorType): - res = [res] - except ValueError as e: - read_only_msgs = [ - "assignment destination is read-only", - "buffer source array is read-only", - ] - err_msg = str(e) - if any(msg in err_msg for msg in read_only_msgs): - raise ValueError( - f"Batch mapper function {fn.__name__} tried to mutate a " - "zero-copy read-only batch. To be able to mutate the " - "batch, pass zero_copy_batch=False to map_batches(); " - "this will create a writable copy of the batch before " - "giving it to fn. To elide this copy, modify your mapper " - "function so it doesn't try to mutate its input." - ) from e - else: - raise e from None + # Use a queue to store outputs from async generator calls. + # We will put output batches into this queue from async + # generators, and in the main event loop, yield them from + # the queue as they become available. + output_batch_queue = queue.Queue() + + async def process_batch(batch: DataBatch): + output_batch_iterator = await fn(batch) + # As soon as results become available from the async generator, + # put them into the result queue so they can be yielded. 
+ async for output_batch in output_batch_iterator: + output_batch_queue.put(output_batch) + + async def process_all_batches(): + loop = ray.data._map_actor_context.udf_map_asyncio_loop + tasks = [loop.create_task(process_batch(x)) for x in input_iterable] + + ctx = ray.data.DataContext.get_current() + if ctx.execution_options.preserve_order: + for task in tasks: + await task else: - for out_batch in res: - _validate_batch_output(out_batch) - yield out_batch + for task in asyncio.as_completed(tasks): + await task + + # Use the existing event loop to create and run Tasks to process each batch + loop = ray.data._map_actor_context.udf_map_asyncio_loop + future = asyncio.run_coroutine_threadsafe(process_all_batches(), loop) + + # Yield results as they become available. + while not future.done() or not output_batch_queue.empty(): + try: # Poll; avoids dropping late batches or blocking forever. + out_batch = output_batch_queue.get(timeout=0.1) + except queue.Empty: + continue + _validate_batch_output(out_batch) + yield out_batch return transform_fn diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index a511eb7255c3..337bacc6f77d 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -1,3 +1,4 @@ +import asyncio import itertools import math import os @@ -1057,6 +1058,39 @@ def test_nonserializable_map_batches(shutdown_only): x.map_batches(lambda _: lock).take(1) +def test_map_batches_async_generator(shutdown_only): + ray.shutdown() + ray.init(num_cpus=10) + + async def sleep_and_yield(i): + print("sleep", i) + await asyncio.sleep(i % 5) + print("yield", i) + return {"input": [i], "output": [2**i]} + + class AsyncActor: + def __init__(self): + pass + + async def __call__(self, batch): + tasks = [asyncio.create_task(sleep_and_yield(i)) for i in batch["id"]] + for task in tasks: + yield await task + + n = 10 + ds = ray.data.range(n, override_num_blocks=2) + ds = ds.map(lambda x: x) + ds = ds.map_batches(AsyncActor, 
batch_size=1, concurrency=1, max_concurrency=2) + + start_t = time.time() + output = ds.take_all() + runtime = time.time() - start_t + assert runtime < sum(range(n)), runtime + + expected_output = [{"input": i, "output": 2**i} for i in range(n)] + assert output == expected_output, (output, expected_output) + + if __name__ == "__main__": import sys