Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Data] Support async callable classes in map_batches() #46129

Merged
merged 9 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 114 additions & 48 deletions python/ray/data/_internal/planner/plan_udf_map_op.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import asyncio
import collections
import inspect
import queue
from concurrent.futures import ThreadPoolExecutor
from types import GeneratorType
from typing import Any, Callable, Iterable, Iterator, List, Optional

Expand Down Expand Up @@ -104,23 +108,35 @@ def _parse_op_fn(op: AbstractUDFMap):
fn_constructor_args = op._fn_constructor_args or ()
fn_constructor_kwargs = op._fn_constructor_kwargs or {}

op_fn = make_callable_class_concurrent(op_fn)

def fn(item: Any) -> Any:
assert ray.data._cached_fn is not None
assert ray.data._cached_cls == op_fn
try:
return ray.data._cached_fn(item, *fn_args, **fn_kwargs)
except Exception as e:
_handle_debugger_exception(e)

def init_fn():
if ray.data._cached_fn is None:
ray.data._cached_cls = op_fn
ray.data._cached_fn = op_fn(
*fn_constructor_args, **fn_constructor_kwargs
)

if inspect.isasyncgenfunction(op._fn.__call__):

async def fn(item: Any) -> Any:
assert ray.data._cached_fn is not None
assert ray.data._cached_cls == op_fn

try:
return ray.data._cached_fn(item, *fn_args, **fn_kwargs)
except Exception as e:
_handle_debugger_exception(e)

else:
op_fn = make_callable_class_concurrent(op_fn)

def fn(item: Any) -> Any:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these 2 fns are almost identical, except the assertion? better avoid duplicating the code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah that's what i initially thought, but one of them is async, and the other is not. is there a way to combine the inner implementation but create an async version?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, nvm, I didn't notice the async prefix

assert ray.data._cached_fn is not None
assert ray.data._cached_cls == op_fn
try:
return ray.data._cached_fn(item, *fn_args, **fn_kwargs)
except Exception as e:
_handle_debugger_exception(e)

else:

def fn(item: Any) -> Any:
Expand Down Expand Up @@ -158,6 +174,7 @@ def _validate_batch_output(batch: Block) -> None:
np.ndarray,
collections.abc.Mapping,
pd.core.frame.DataFrame,
dict,
),
):
raise ValueError(
Expand Down Expand Up @@ -193,45 +210,94 @@ def _validate_batch_output(batch: Block) -> None:
def _generate_transform_fn_for_map_batches(
    fn: UserDefinedFunction,
) -> MapTransformCallable[DataBatch, DataBatch]:
    """Build the batch-level transform function for ``map_batches``.

    Args:
        fn: The user-defined function. If it is a coroutine function (i.e. a
            callable class whose ``__call__`` is an async generator), the
            batches are processed concurrently on a background event loop;
            otherwise they are processed synchronously, one at a time.

    Returns:
        A ``MapTransformCallable`` that consumes an iterable of input batches
        and yields validated output batches.
    """
    if inspect.iscoroutinefunction(fn):
        # UDF is a callable class with an async generator `__call__` method.

        def transform_fn(
            input_iterable: Iterable[DataBatch], _: TaskContext
        ) -> Iterable[DataBatch]:
            # Hand results from the background event-loop thread to this
            # synchronous generator through a thread-safe queue: rows are
            # enqueued as the async generators produce them and yielded from
            # the queue as they become available.
            result_queue = queue.Queue()
            # Use a unique sentinel object (not None) so a UDF that yields
            # None cannot terminate the stream prematurely.
            sentinel = object()

            async def process_batch(batch: DataBatch):
                # `fn` is a coroutine function; awaiting it returns the async
                # generator produced by the UDF's `__call__`.
                output_batch_iterator = await fn(batch)
                # Enqueue each output row as soon as the async generator
                # produces it, so downstream can start consuming immediately.
                async for output_row in output_batch_iterator:
                    result_queue.put(output_row)

            async def process_all_batches():
                tasks = [
                    asyncio.create_task(process_batch(x)) for x in input_iterable
                ]
                # NOTE(review): `as_completed` does not preserve input order;
                # this path is not suitable when preserve_order is set —
                # TODO confirm and handle ordered execution separately.
                for task in asyncio.as_completed(tasks):
                    await task

            def run_event_loop():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                try:
                    loop.run_until_complete(process_all_batches())
                finally:
                    # Always unblock the consumer, even if a UDF task raised;
                    # otherwise `result_queue.get()` below would hang forever.
                    result_queue.put(sentinel)
                    loop.close()

            # Run the event loop in a dedicated thread so this (synchronous)
            # generator can yield while batches are still being processed.
            executor = ThreadPoolExecutor(max_workers=1)
            loop_future = executor.submit(run_event_loop)
            try:
                while True:
                    # `out_batch` is a one-row batch containing the output of
                    # the async generator for a single row of an input batch.
                    out_batch = result_queue.get()
                    # Exit when the sentinel is received.
                    if out_batch is sentinel:
                        break
                    _validate_batch_output(out_batch)
                    yield out_batch
                # Surface any exception raised inside the UDF tasks instead of
                # silently dropping it in the background thread.
                loop_future.result()
            finally:
                # Don't leak the event-loop thread.
                executor.shutdown(wait=False)

    else:

        def transform_fn(
            batches: Iterable[DataBatch], _: TaskContext
        ) -> Iterable[DataBatch]:
            for batch in batches:
                try:
                    if (
                        not isinstance(batch, collections.abc.Mapping)
                        and BlockAccessor.for_block(batch).num_rows() == 0
                    ):
                        # For empty input blocks, we directly output them
                        # without calling the UDF.
                        # TODO(hchen): This workaround is because some
                        # all-to-all operators output empty blocks with no
                        # schema.
                        res = [batch]
                    else:
                        res = fn(batch)
                        if not isinstance(res, GeneratorType):
                            res = [res]
                except ValueError as e:
                    # Detect attempts to mutate a zero-copy (read-only) batch
                    # and surface an actionable error message.
                    read_only_msgs = [
                        "assignment destination is read-only",
                        "buffer source array is read-only",
                    ]
                    err_msg = str(e)
                    if any(msg in err_msg for msg in read_only_msgs):
                        raise ValueError(
                            f"Batch mapper function {fn.__name__} tried to mutate a "
                            "zero-copy read-only batch. To be able to mutate the "
                            "batch, pass zero_copy_batch=False to map_batches(); "
                            "this will create a writable copy of the batch before "
                            "giving it to fn. To elide this copy, modify your mapper "
                            "function so it doesn't try to mutate its input."
                        ) from e
                    else:
                        raise e from None
                else:
                    for out_batch in res:
                        _validate_batch_output(out_batch)
                        yield out_batch

    return transform_fn

Expand Down
29 changes: 29 additions & 0 deletions python/ray/data/tests/test_map.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import itertools
import math
import os
Expand Down Expand Up @@ -700,6 +701,34 @@ def fail_generator(batch):
).take()


def test_map_batches_async_generator(ray_start_regular_shared):
    """Async callable-class UDFs should process rows of a batch concurrently."""

    async def sleep_and_yield(i):
        await asyncio.sleep(i)
        return {"input": [i], "output": [2**i]}

    class AsyncActor:
        def __init__(self):
            pass

        async def __call__(self, batch):
            # Start all rows concurrently, then yield each result as soon as
            # it is ready (in input order) instead of buffering everything
            # with asyncio.gather().
            tasks = [asyncio.create_task(sleep_and_yield(i)) for i in batch["id"]]
            for task in tasks:
                yield await task

    n = 5
    ds = ray.data.range(n, override_num_blocks=1)
    ds = ds.map_batches(AsyncActor, batch_size=None, concurrency=1)

    start_t = time.time()
    output = ds.take_all()
    runtime = time.time() - start_t
    # Rows sleep 0..n-1 seconds concurrently, so the wall-clock runtime must
    # be well below the serial sum of the sleeps.
    assert runtime < sum(range(n)), runtime

    expected_output = [{"input": i, "output": 2**i} for i in range(n)]
    assert output == expected_output, (output, expected_output)


def test_map_batches_actors_preserves_order(shutdown_only):
class UDFClass:
def __call__(self, x):
Expand Down
Loading