[Data] Support async callable classes in map_batches()
#46129
Conversation
Signed-off-by: Scott Lee <[email protected]>
Just realized that we also need to set max_concurrency to allow handling multiple batches at a time.
Some additional changes needed for this PR:
- The event loop and the thread executor should outlive the transform_fn; they should be global singletons.
python/ray/data/tests/test_map.py
Outdated
tasks = [sleep_and_yield(i) for i in batch["id"]]
results = await asyncio.gather(*tasks)
for result in results:
    yield result
Avoid using gather, so we can yield results as soon as they become available.
tasks = [asyncio.create_task(sleep_and_yield(i)) for i in batch["id"]]
for task in tasks:
    yield await task
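A minimal standalone sketch of the difference (sleep_and_yield here is a stand-in that sleeps proportionally to its input, not the exact helper from the test file): creating all tasks up front and awaiting them in order yields each result as soon as its own task finishes, instead of buffering the whole batch like gather() does.

```python
import asyncio

async def sleep_and_yield(i):
    # Stand-in for per-item async work; larger i finishes later.
    await asyncio.sleep(0.05 * i)
    return i

async def process(ids):
    # Create all tasks up front so they run concurrently, then await
    # them in input order: each result is yielded as soon as its own
    # task finishes, without collecting everything first.
    tasks = [asyncio.create_task(sleep_and_yield(i)) for i in ids]
    for task in tasks:
        yield await task

async def main():
    return [r async for r in process([3, 1, 2])]

print(asyncio.run(main()))  # [3, 1, 2]: input order is preserved
```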
async def process_all_batches():
    tasks = [asyncio.create_task(process_batch(x)) for x in input_iterable]
    for task in asyncio.as_completed(tasks):
as_completed doesn't preserve the order. Shouldn't use it when preserve_order is set.
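A small sketch of the ordering pitfall (the work coroutine is hypothetical): asyncio.as_completed yields awaitables in completion order, not submission order, which is why it should be avoided when preserve_order is set.

```python
import asyncio

async def work(i):
    # Larger i takes longer, so completion order differs from input order.
    await asyncio.sleep(0.05 * i)
    return i

async def main():
    tasks = [asyncio.create_task(work(i)) for i in (3, 1, 2)]
    # as_completed yields awaitables in the order they finish,
    # not the order they were submitted.
    return [await t for t in asyncio.as_completed(tasks)]

print(asyncio.run(main()))  # [1, 2, 3]: completion order, not input order
```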
python/ray/data/tests/test_map.py
Outdated
pass

async def __call__(self, batch):
    tasks = [sleep_and_yield(i) for i in batch["id"]]
Remove this line.
# Use the existing event loop to create and run
# tasks to process each batch
loop = ray.data._cached_loop
loop.run_until_complete(process_all_batches())
This is still running in the main thread, and thus cannot run multiple batches at the same time. We should put loop.run_forever() in a thread, and then call loop.call_soon_threadsafe here.
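A minimal sketch of that pattern, with a stand-in process_batch coroutine: the loop runs forever in a background daemon thread, and asyncio.run_coroutine_threadsafe (which is built on call_soon_threadsafe) submits coroutines to it from the main thread.

```python
import asyncio
import threading

# Run the event loop forever in a dedicated daemon thread, so the
# main thread stays free to submit coroutines for many batches at once.
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def process_batch(x):
    # Stand-in for real per-batch async work.
    await asyncio.sleep(0.01)
    return x * 2

# Schedules the coroutine on the loop's thread and returns a
# concurrent.futures.Future that the main thread can wait on.
future = asyncio.run_coroutine_threadsafe(process_batch(21), loop)
print(future.result())  # 42
```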
for task in asyncio.as_completed(tasks):
    await task
# Sentinel to indicate completion.
output_batch_queue.put(None)
Use a special sentinel object here, in case the UDF also returns None.
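The reviewer's point, as a standalone sketch: a fresh object() is identity-unique, so it cannot be confused with any value, including None, that the UDF might legitimately produce.

```python
import queue

# A unique sentinel: unlike None, no UDF output can ever be it.
_SENTINEL = object()

q = queue.Queue()
for item in (1, None, 2):  # the UDF may legitimately yield None
    q.put(item)
q.put(_SENTINEL)           # signal completion

results = []
while True:
    item = q.get()
    if item is _SENTINEL:  # identity check, not equality
        break
    results.append(item)

print(results)  # [1, None, 2]: the None value survives
```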
res = [batch]
if inspect.iscoroutinefunction(fn):
    # UDF is a callable class with async generator `__call__` method.
    def transform_fn(
This inline function is too long. Let's define it as a util function.
@@ -104,22 +107,45 @@ def _parse_op_fn(op: AbstractUDFMap):
    fn_constructor_args = op._fn_constructor_args or ()
    fn_constructor_kwargs = op._fn_constructor_kwargs or {}

    op_fn = make_callable_class_concurrent(op_fn)
    if inspect.isasyncgenfunction(op._fn.__call__):
Add a TODO that we should 1) support non-generator async functions; 2) make the entire map actor async.
python/ray/data/tests/test_map.py
Outdated
n = 5
ds = ray.data.range(n, override_num_blocks=1)
ds = ds.map_batches(AsyncActor, batch_size=None, concurrency=1)
Add max_concurrency to test concurrently handling many batches.
async def fn(item: Any) -> Any:
    assert ray.data._cached_fn is not None
    assert ray.data._cached_cls == op_fn
    assert ray.data._cached_loop is not None
To make the code cleaner, let's define a class (e.g. MapActorContext) to capture all these variables.
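A hedged sketch of such a context class (field names are illustrative, echoing the cached variables above; the later review also suggests an is_asyncio flag):

```python
import asyncio
import threading
from dataclasses import dataclass
from typing import Any, Callable, Optional

@dataclass
class MapActorContext:
    # Bundles the per-actor state that was previously stored in
    # separate module-level variables (ray.data._cached_*).
    udf_map_cls: Any
    udf_map_fn: Callable[[Any], Any]
    loop: Optional[asyncio.AbstractEventLoop] = None
    asyncio_thread: Optional[threading.Thread] = None
    is_asyncio: bool = False

ctx = MapActorContext(udf_map_cls=object, udf_map_fn=lambda x: x)
```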
LGTM except for some nits
loop.run_forever()

thread = Thread(target=run_loop)
thread.start()
Nit: maybe move the above to MapActorContext to make the code cleaner here. We can pass in a boolean flag here: _MapActorContext(..., is_asyncio=True). Then the code for the sync and async branches can be consolidated.
cached_cls: UserDefinedFunction,
cached_fn: Callable[[Any], Any],
cached_loop: Optional[asyncio.AbstractEventLoop] = None,
cached_asyncio_thread: Optional[Thread] = None,
Remove the cached_ prefixes?
for task in asyncio.as_completed(tasks):
    await task
# Sentinel to indicate completion.
output_batch_queue.put(OutputQueueSentinel())
OutputQueueSentinel isn't needed now, because of the while not future.done() check.
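A sketch of why the sentinel becomes unnecessary (the producer/consumer and timings here are illustrative, not the PR's code): the consumer can simply drain until the producer's future is done and the queue is empty.

```python
import queue
import time
from concurrent.futures import ThreadPoolExecutor

q = queue.Queue()

def producer():
    for i in range(3):
        q.put(i)
        time.sleep(0.01)

with ThreadPoolExecutor() as pool:
    future = pool.submit(producer)
    results = []
    # Loop until the producer has finished AND the queue is drained;
    # no sentinel object is needed to signal completion.
    while not future.done() or not q.empty():
        try:
            results.append(q.get(timeout=0.05))
        except queue.Empty:
            pass

print(results)  # [0, 1, 2]
```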
python/ray/data/tests/test_map.py
Outdated
pass

async def __call__(self, batch):
    tasks = [sleep_and_yield(i) for i in batch["id"]]
this line is redundant.
Thanks for the catch, I had thought I removed it but it somehow ended up back again...
else:

    def fn(item: Any) -> Any:
These two fns are almost identical, except the assertion? Better to avoid duplicating the code.
Ah, that's what I initially thought, but one of them is async and the other is not. Is there a way to combine the inner implementation but create an async version?
Oh, never mind, I didn't notice the async prefix.
Why are these changes needed?
Add support for passing a CallableClass with an asynchronous generator __call__ method to the Dataset.map_batches() API. This is useful for streaming outputs from asynchronous generators as they become available, to maximize throughput.

Related issue number
Closes #46235
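A standalone sketch of the UDF shape this PR enables (AsyncBatchMapper and its batch format are illustrative; with Ray Data you would pass the class itself, e.g. ds.map_batches(AsyncBatchMapper, concurrency=1)): a callable class whose __call__ is an async generator can stream output batches as each item's async work completes.

```python
import asyncio

class AsyncBatchMapper:
    # Callable class whose __call__ is an async generator: it can
    # stream output batches as each item's async work completes.
    async def __call__(self, batch):
        for i in batch["id"]:
            await asyncio.sleep(0)   # stand-in for async I/O
            yield {"id": [i * 2]}    # one output batch per item

async def collect(mapper, batch):
    return [out async for out in mapper(batch)]

outputs = asyncio.run(collect(AsyncBatchMapper(), {"id": [1, 2, 3]}))
print(outputs)  # [{'id': [2]}, {'id': [4]}, {'id': [6]}]
```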
Checks
- I've signed off every commit (using the -s flag, i.e., git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I've added a new method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.