-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Data] Support async callable classes in map_batches()
#46129
Changes from 1 commit
599be0b
bdd812e
dec32a1
13b980f
4292b93
de61a0d
d217c1c
9175edd
c5af4ae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,8 @@ | ||
import asyncio | ||
import collections | ||
import inspect | ||
import queue | ||
from concurrent.futures import ThreadPoolExecutor | ||
from types import GeneratorType | ||
from typing import Any, Callable, Iterable, Iterator, List, Optional | ||
|
||
|
@@ -104,23 +108,35 @@ def _parse_op_fn(op: AbstractUDFMap): | |
fn_constructor_args = op._fn_constructor_args or () | ||
fn_constructor_kwargs = op._fn_constructor_kwargs or {} | ||
|
||
op_fn = make_callable_class_concurrent(op_fn) | ||
|
||
def fn(item: Any) -> Any: | ||
assert ray.data._cached_fn is not None | ||
assert ray.data._cached_cls == op_fn | ||
try: | ||
return ray.data._cached_fn(item, *fn_args, **fn_kwargs) | ||
except Exception as e: | ||
_handle_debugger_exception(e) | ||
|
||
def init_fn(): | ||
if ray.data._cached_fn is None: | ||
ray.data._cached_cls = op_fn | ||
ray.data._cached_fn = op_fn( | ||
*fn_constructor_args, **fn_constructor_kwargs | ||
) | ||
|
||
if inspect.isasyncgenfunction(op._fn.__call__): | ||
|
||
async def fn(item: Any) -> Any: | ||
assert ray.data._cached_fn is not None | ||
assert ray.data._cached_cls == op_fn | ||
|
||
try: | ||
return ray.data._cached_fn(item, *fn_args, **fn_kwargs) | ||
except Exception as e: | ||
_handle_debugger_exception(e) | ||
|
||
else: | ||
op_fn = make_callable_class_concurrent(op_fn) | ||
|
||
def fn(item: Any) -> Any: | ||
assert ray.data._cached_fn is not None | ||
assert ray.data._cached_cls == op_fn | ||
try: | ||
return ray.data._cached_fn(item, *fn_args, **fn_kwargs) | ||
except Exception as e: | ||
_handle_debugger_exception(e) | ||
|
||
else: | ||
|
||
def fn(item: Any) -> Any: | ||
|
@@ -158,6 +174,7 @@ def _validate_batch_output(batch: Block) -> None: | |
np.ndarray, | ||
collections.abc.Mapping, | ||
pd.core.frame.DataFrame, | ||
dict, | ||
), | ||
): | ||
raise ValueError( | ||
|
@@ -193,45 +210,94 @@ def _validate_batch_output(batch: Block) -> None: | |
def _generate_transform_fn_for_map_batches( | ||
fn: UserDefinedFunction, | ||
) -> MapTransformCallable[DataBatch, DataBatch]: | ||
def transform_fn( | ||
batches: Iterable[DataBatch], _: TaskContext | ||
) -> Iterable[DataBatch]: | ||
for batch in batches: | ||
try: | ||
if ( | ||
not isinstance(batch, collections.abc.Mapping) | ||
and BlockAccessor.for_block(batch).num_rows() == 0 | ||
): | ||
# For empty input blocks, we directly ouptut them without | ||
# calling the UDF. | ||
# TODO(hchen): This workaround is because some all-to-all | ||
# operators output empty blocks with no schema. | ||
res = [batch] | ||
else: | ||
res = fn(batch) | ||
if not isinstance(res, GeneratorType): | ||
res = [res] | ||
except ValueError as e: | ||
read_only_msgs = [ | ||
"assignment destination is read-only", | ||
"buffer source array is read-only", | ||
] | ||
err_msg = str(e) | ||
if any(msg in err_msg for msg in read_only_msgs): | ||
raise ValueError( | ||
f"Batch mapper function {fn.__name__} tried to mutate a " | ||
"zero-copy read-only batch. To be able to mutate the " | ||
"batch, pass zero_copy_batch=False to map_batches(); " | ||
"this will create a writable copy of the batch before " | ||
"giving it to fn. To elide this copy, modify your mapper " | ||
"function so it doesn't try to mutate its input." | ||
) from e | ||
if inspect.iscoroutinefunction(fn): | ||
# UDF is a callable class with async generator `__call__` method. | ||
def transform_fn( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This inline function is too long. let's define it as a util function. |
||
input_iterable: Iterable[DataBatch], _: TaskContext | ||
) -> Iterable[DataBatch]: | ||
# Use a queue to store results from async generator calls. | ||
# In the main event loop, we will put results into this queue | ||
# from async generator, and yield them from the queue as they | ||
# become available. | ||
result_queue = queue.Queue() | ||
|
||
async def process_batch(batch: DataBatch): | ||
output_batch_iterator = await fn(batch) | ||
# As soon as results become available from the async generator, | ||
# put them into the result queue so they can be yielded. | ||
async for output_row in output_batch_iterator: | ||
result_queue.put(output_row) | ||
|
||
async def process_all_batches(): | ||
tasks = [asyncio.create_task(process_batch(x)) for x in input_iterable] | ||
for task in asyncio.as_completed(tasks): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
await task | ||
# Sentinel to indicate completion. | ||
result_queue.put(None) | ||
|
||
def run_event_loop(): | ||
loop = asyncio.new_event_loop() | ||
asyncio.set_event_loop(loop) | ||
loop.run_until_complete(process_all_batches()) | ||
loop.close() | ||
|
||
# Start the event loop in a new thread | ||
executor = ThreadPoolExecutor(max_workers=1) | ||
executor.submit(run_event_loop) | ||
|
||
# Yield results as they become available | ||
while True: | ||
# `out_batch` here is a one-row batch which contains | ||
# output from the async generator, corresponding to a | ||
# single row from the input batch. | ||
out_batch = result_queue.get() | ||
# Exit when sentinel is received. | ||
if out_batch is None: | ||
break | ||
_validate_batch_output(out_batch) | ||
yield out_batch | ||
|
||
else: | ||
|
||
def transform_fn( | ||
batches: Iterable[DataBatch], _: TaskContext | ||
) -> Iterable[DataBatch]: | ||
for batch in batches: | ||
try: | ||
if ( | ||
not isinstance(batch, collections.abc.Mapping) | ||
and BlockAccessor.for_block(batch).num_rows() == 0 | ||
): | ||
# For empty input blocks, we directly ouptut them without | ||
# calling the UDF. | ||
# TODO(hchen): This workaround is because some all-to-all | ||
# operators output empty blocks with no schema. | ||
res = [batch] | ||
else: | ||
res = fn(batch) | ||
if not isinstance(res, GeneratorType): | ||
res = [res] | ||
except ValueError as e: | ||
read_only_msgs = [ | ||
"assignment destination is read-only", | ||
"buffer source array is read-only", | ||
] | ||
err_msg = str(e) | ||
if any(msg in err_msg for msg in read_only_msgs): | ||
raise ValueError( | ||
f"Batch mapper function {fn.__name__} tried to mutate a " | ||
"zero-copy read-only batch. To be able to mutate the " | ||
"batch, pass zero_copy_batch=False to map_batches(); " | ||
"this will create a writable copy of the batch before " | ||
"giving it to fn. To elide this copy, modify your mapper " | ||
"function so it doesn't try to mutate its input." | ||
) from e | ||
else: | ||
raise e from None | ||
else: | ||
raise e from None | ||
else: | ||
for out_batch in res: | ||
_validate_batch_output(out_batch) | ||
yield out_batch | ||
for out_batch in res: | ||
_validate_batch_output(out_batch) | ||
yield out_batch | ||
|
||
return transform_fn | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
import asyncio | ||
import itertools | ||
import math | ||
import os | ||
|
@@ -700,6 +701,34 @@ def fail_generator(batch): | |
).take() | ||
|
||
|
||
def test_map_batches_async_generator(ray_start_regular_shared): | ||
async def sleep_and_yield(i): | ||
await asyncio.sleep(i) | ||
return {"input": [i], "output": [2**i]} | ||
|
||
class AsyncActor: | ||
def __init__(self): | ||
pass | ||
|
||
async def __call__(self, batch): | ||
tasks = [sleep_and_yield(i) for i in batch["id"]] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove this line. |
||
results = await asyncio.gather(*tasks) | ||
for result in results: | ||
yield result | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid using gather, so we can yield results as soon as they become available.
|
||
|
||
n = 5 | ||
ds = ray.data.range(n, override_num_blocks=1) | ||
ds = ds.map_batches(AsyncActor, batch_size=None, concurrency=1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add max_concurrency to test concurrently handling many batches. |
||
|
||
start_t = time.time() | ||
output = ds.take_all() | ||
runtime = time.time() - start_t | ||
assert runtime < sum(range(n)), runtime | ||
|
||
expected_output = [{"input": i, "output": 2**i} for i in range(n)] | ||
assert output == expected_output, (output, expected_output) | ||
|
||
|
||
def test_map_batches_actors_preserves_order(shutdown_only): | ||
class UDFClass: | ||
def __call__(self, x): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these 2
fn
s are almost identical, except the assertion? better avoid duplicating the code.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah that's what i initially thought, but one of them is async, and the other is not. is there a way to combine the inner implementation but create an async version?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh, nvm, I didn't notice the async prefix