Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Data] Support async callable classes in map_batches() #46129

Merged
merged 9 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions python/ray/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@

# Module-level cached global functions for callable classes. It needs to be defined here
# since it has to be process-global across cloudpickled funcs.
_cached_fn = None
_cached_cls = None
_map_actor_context = None

configure_logging()

Expand Down
214 changes: 166 additions & 48 deletions python/ray/data/_internal/planner/plan_udf_map_op.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import asyncio
import collections
import inspect
import queue
from threading import Thread
from types import GeneratorType
from typing import Any, Callable, Iterable, Iterator, List, Optional

Expand Down Expand Up @@ -43,6 +47,36 @@
from ray.util.rpdb import _is_ray_debugger_enabled


class _MapActorContext:
def __init__(
self,
udf_map_cls: UserDefinedFunction,
udf_map_fn: Callable[[Any], Any],
is_async: bool,
):
self.udf_map_cls = udf_map_cls
self.udf_map_fn = udf_map_fn
self.is_async = is_async
self.udf_map_asyncio_loop = None
self.udf_map_asyncio_thread = None

if is_async:
self._init_async()

def _init_async(self):
# Only used for callable class with async generator `__call__` method.
loop = asyncio.new_event_loop()

def run_loop():
asyncio.set_event_loop(loop)
loop.run_forever()

thread = Thread(target=run_loop)
thread.start()
self.udf_map_asyncio_loop = loop
self.udf_map_asyncio_thread = thread


def plan_udf_map_op(
op: AbstractUDFMap, physical_children: List[PhysicalOperator]
) -> MapOperator:
Expand Down Expand Up @@ -104,23 +138,53 @@ def _parse_op_fn(op: AbstractUDFMap):
fn_constructor_args = op._fn_constructor_args or ()
fn_constructor_kwargs = op._fn_constructor_kwargs or {}

op_fn = make_callable_class_concurrent(op_fn)
is_async_gen = inspect.isasyncgenfunction(op._fn.__call__)

def fn(item: Any) -> Any:
assert ray.data._cached_fn is not None
assert ray.data._cached_cls == op_fn
try:
return ray.data._cached_fn(item, *fn_args, **fn_kwargs)
except Exception as e:
_handle_debugger_exception(e)
# TODO(scottjlee): (1) support non-generator async functions
# (2) make the map actor async
if not is_async_gen:
op_fn = make_callable_class_concurrent(op_fn)

def init_fn():
if ray.data._cached_fn is None:
ray.data._cached_cls = op_fn
ray.data._cached_fn = op_fn(
*fn_constructor_args, **fn_constructor_kwargs
if ray.data._map_actor_context is None:
ray.data._map_actor_context = _MapActorContext(
udf_map_cls=op_fn,
udf_map_fn=op_fn(
*fn_constructor_args,
**fn_constructor_kwargs,
),
is_async=is_async_gen,
)

if is_async_gen:

async def fn(item: Any) -> Any:
assert ray.data._map_actor_context is not None
assert ray.data._map_actor_context.is_async

try:
return ray.data._map_actor_context.udf_map_fn(
item,
*fn_args,
**fn_kwargs,
)
except Exception as e:
_handle_debugger_exception(e)

else:

def fn(item: Any) -> Any:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these 2 fns are almost identical, except the assertion? better avoid duplicating the code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah that's what i initially thought, but one of them is async, and the other is not. is there a way to combine the inner implementation but create an async version?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, nvm, I didn't notice the async prefix

assert ray.data._map_actor_context is not None
assert not ray.data._map_actor_context.is_async
try:
return ray.data._map_actor_context.udf_map_fn(
item,
*fn_args,
**fn_kwargs,
)
except Exception as e:
_handle_debugger_exception(e)

else:

def fn(item: Any) -> Any:
Expand Down Expand Up @@ -158,6 +222,7 @@ def _validate_batch_output(batch: Block) -> None:
np.ndarray,
collections.abc.Mapping,
pd.core.frame.DataFrame,
dict,
),
):
raise ValueError(
Expand Down Expand Up @@ -192,46 +257,99 @@ def _validate_batch_output(batch: Block) -> None:

def _generate_transform_fn_for_map_batches(
    fn: UserDefinedFunction,
) -> MapTransformCallable[DataBatch, DataBatch]:
    """Build the batch-transform function wrapping the user's `map_batches` UDF.

    Dispatches on the UDF kind: a coroutine function (callable class with an
    async generator ``__call__``) is delegated to
    `_generate_transform_fn_for_async_map_batches`; otherwise a synchronous
    generator transform is built here.

    Args:
        fn: The user-defined function to apply to each input batch.

    Returns:
        A transform callable mapping an iterable of input batches (plus a
        `TaskContext`) to an iterable of validated output batches.
    """
    if inspect.iscoroutinefunction(fn):
        # UDF is a callable class with async generator `__call__` method.
        transform_fn = _generate_transform_fn_for_async_map_batches(fn)

    else:

        def transform_fn(
            batches: Iterable[DataBatch], _: TaskContext
        ) -> Iterable[DataBatch]:
            for batch in batches:
                try:
                    if (
                        not isinstance(batch, collections.abc.Mapping)
                        and BlockAccessor.for_block(batch).num_rows() == 0
                    ):
                        # For empty input blocks, we directly output them without
                        # calling the UDF.
                        # TODO(hchen): This workaround is because some all-to-all
                        # operators output empty blocks with no schema.
                        res = [batch]
                    else:
                        res = fn(batch)
                        # Normalize a single returned batch to a one-element
                        # list so the success path below can iterate uniformly.
                        if not isinstance(res, GeneratorType):
                            res = [res]
                except ValueError as e:
                    # Translate NumPy/pandas read-only buffer errors into an
                    # actionable message about zero_copy_batch.
                    read_only_msgs = [
                        "assignment destination is read-only",
                        "buffer source array is read-only",
                    ]
                    err_msg = str(e)
                    if any(msg in err_msg for msg in read_only_msgs):
                        raise ValueError(
                            f"Batch mapper function {fn.__name__} tried to mutate a "
                            "zero-copy read-only batch. To be able to mutate the "
                            "batch, pass zero_copy_batch=False to map_batches(); "
                            "this will create a writable copy of the batch before "
                            "giving it to fn. To elide this copy, modify your mapper "
                            "function so it doesn't try to mutate its input."
                        ) from e
                    else:
                        raise e from None
                else:
                    # Success path: validate and yield each output batch.
                    for out_batch in res:
                        _validate_batch_output(out_batch)
                        yield out_batch

    return transform_fn


def _generate_transform_fn_for_async_map_batches(
    fn: UserDefinedFunction,
) -> MapTransformCallable[DataBatch, DataBatch]:
    """Build a transform for a callable-class UDF with an async generator
    ``__call__``.

    Batches are scheduled as tasks on the actor's dedicated asyncio event
    loop (see `_MapActorContext`); outputs are handed back to the caller's
    thread through a queue and yielded as they become available.

    Args:
        fn: Async wrapper around the UDF; awaiting ``fn(batch)`` returns an
            async iterator of output batches.

    Returns:
        A synchronous generator transform over an iterable of input batches.
    """

    def transform_fn(
        input_iterable: Iterable[DataBatch], _: TaskContext
    ) -> Iterable[DataBatch]:
        # Use a queue to store outputs from async generator calls.
        # We will put output batches into this queue from async
        # generators, and in the main thread, yield them from
        # the queue as they become available.
        output_batch_queue = queue.Queue()
        # Sentinel enqueued once all tasks have finished (or failed), so the
        # consumer loop below can distinguish "no more output" from "output
        # not ready yet".
        sentinel = object()

        async def process_batch(batch: DataBatch):
            output_batch_iterator = await fn(batch)
            # As soon as results become available from the async generator,
            # put them into the result queue so they can be yielded.
            async for output_batch in output_batch_iterator:
                output_batch_queue.put(output_batch)

        async def process_all_batches():
            try:
                loop = ray.data._map_actor_context.udf_map_asyncio_loop
                tasks = [loop.create_task(process_batch(x)) for x in input_iterable]

                ctx = ray.data.DataContext.get_current()
                if ctx.execution_options.preserve_order:
                    for task in tasks:
                        # Fix: an asyncio.Task is awaited, not called — the
                        # original `await task()` raised TypeError here.
                        await task
                else:
                    for task in asyncio.as_completed(tasks):
                        await task
            finally:
                # Always unblock the consumer, even when a UDF call raised;
                # otherwise `output_batch_queue.get()` could block forever.
                output_batch_queue.put(sentinel)

        # Use the existing event loop to create and run Tasks to process each batch
        loop = ray.data._map_actor_context.udf_map_asyncio_loop
        future = asyncio.run_coroutine_threadsafe(process_all_batches(), loop)

        # Yield results as they become available. Draining until the sentinel
        # (rather than polling `future.done()`) guarantees that batches
        # enqueued just before the future completes are not dropped.
        while True:
            # Here, `out_batch` is a one-row output batch
            # from the async generator, corresponding to a
            # single row from the input batch.
            out_batch = output_batch_queue.get()
            if out_batch is sentinel:
                break
            _validate_batch_output(out_batch)
            yield out_batch

        # Re-raise any exception from the async UDF calls in the caller's
        # thread instead of silently discarding it.
        future.result()

    return transform_fn

Expand Down
34 changes: 34 additions & 0 deletions python/ray/data/tests/test_map.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import itertools
import math
import os
Expand Down Expand Up @@ -1057,6 +1058,39 @@ def test_nonserializable_map_batches(shutdown_only):
x.map_batches(lambda _: lock).take(1)


def test_map_batches_async_generator(shutdown_only):
    """map_batches() with a callable class whose __call__ is an async generator
    should overlap the per-row sleeps and still produce the expected rows."""
    ray.shutdown()
    ray.init(num_cpus=10)

    async def sleep_and_yield(i):
        print("sleep", i)
        await asyncio.sleep(i % 5)
        print("yield", i)
        return {"input": [i], "output": [2**i]}

    class AsyncActor:
        def __init__(self):
            pass

        async def __call__(self, batch):
            pending = [asyncio.create_task(sleep_and_yield(i)) for i in batch["id"]]
            for pending_task in pending:
                yield await pending_task

    num_rows = 10
    ds = (
        ray.data.range(num_rows, override_num_blocks=2)
        .map(lambda row: row)
        .map_batches(AsyncActor, batch_size=1, concurrency=1, max_concurrency=2)
    )

    started_at = time.time()
    output = ds.take_all()
    elapsed = time.time() - started_at
    # Sleeps run concurrently, so wall time must beat the serial sum.
    assert elapsed < sum(range(num_rows)), elapsed

    expected_output = [{"input": i, "output": 2**i} for i in range(num_rows)]
    assert output == expected_output, (output, expected_output)


if __name__ == "__main__":
import sys

Expand Down
Loading