Skip to content

Commit

Permalink
[core] Streaming generator executor waits for item report to complete…
Browse files Browse the repository at this point in the history
… before continuing (ray-project#44257)

ray-project#42260 updated streaming generator tasks to asynchronously report generator returns, instead of synchronously reporting each generator return before yielding the next return. However this has a couple problems:

    If the task still has a reference to the yielded value, it may modify the value. The serialized and reported return will then have a different value than expected.

    As per [core] Streaming generator task waits for all object report acks before finishing the task ray-project#44079, we need to track the number of in-flight RPCs to report generator returns, so that we can wait for them all to reply before we return from the end of the task. If we increment the count of in-flight RPCs asynchronously, we can end up returning from the task while there are still in-flight RPCs.

So this PR reverts some of the logic in ray-project#42260 to wait for the generator return to be serialized into the protobuf sent back to the caller. Note that we do not wait for the reply (unless under backpressure).

We can later re-introduce asynchronous generator reports, but we will need to evaluate the performance benefit of a new implementation that also addresses both of the above points.

---------

Signed-off-by: Stephanie Wang <[email protected]>
  • Loading branch information
stephanie-wang committed Mar 27, 2024
1 parent 6b3186b commit 99ba261
Showing 1 changed file with 29 additions and 27 deletions.
56 changes: 29 additions & 27 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1419,6 +1419,7 @@ async def execute_streaming_generator_async(
"""
cdef:
int64_t cur_generator_index = 0
CRayStatus return_status

assert context.is_initialized()
# Generator task should only have 1 return object ref,
Expand All @@ -1433,41 +1434,33 @@ async def execute_streaming_generator_async(
executor = worker.core_worker.get_event_loop_executor()
interrupt_signal_event = threading.Event()

futures = []
try:
try:
async for output in gen:
# NOTE: Reporting generator output in a streaming fashion,
# is done in a standalone thread-pool fully *asynchronously*
# to avoid blocking the event-loop and allow it to *concurrently*
# make progress, since serializing and actual RPC I/O is done
# with "nogil".
futures.append(
loop.run_in_executor(
executor,
report_streaming_generator_output,
context,
output,
cur_generator_index,
interrupt_signal_event,
)
)
cur_generator_index += 1
except Exception as e:
# Report the exception to the owner of the task.
futures.append(
loop.run_in_executor(
# NOTE: Report of streaming generator output is done in a
# standalone thread-pool to avoid blocking the event loop,
# since serializing and actual RPC I/O is done with "nogil". We
# still wait for the report to finish to ensure that the task
# does not modify the output before we serialize it.
#
# Note that the RPC is sent asynchronously, and we do not wait
# for the reply here. The exception is if the user specified a
# backpressure threshold for the streaming generator, and we
# are currently under backpressure. Then we need to wait for an
# ack from the caller (the reply for a possibly previous report
# RPC) that they have consumed more ObjectRefs.
await loop.run_in_executor(
executor,
report_streaming_generator_exception,
report_streaming_generator_output,
context,
e,
output,
cur_generator_index,
interrupt_signal_event,
)
)

# Make sure all RPC I/O completes before returning
await asyncio.gather(*futures)
cur_generator_index += 1
except Exception as e:
# Report the exception to the owner of the task.
report_streaming_generator_exception(context, e, cur_generator_index, None)

except BaseException as be:
# NOTE: PLEASE READ CAREFULLY BEFORE CHANGING
Expand All @@ -1489,6 +1482,15 @@ async def execute_streaming_generator_async(

raise

# The caller gets object values through the reports. If we finish the task
# before sending the report is complete, then we may fail before the report
# is sent to the caller. Then, the caller would never be able to ray.get
# the yield'ed ObjectRef. Therefore, we must wait for all in-flight object
# reports to complete before finishing the task.
with nogil:
return_status = context.waiter.get().WaitAllObjectsReported()
check_status(return_status)


cdef create_generator_return_obj(
output,
Expand Down

0 comments on commit 99ba261

Please sign in to comment.