diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index a33a1c2aef71..aabdb80fd70d 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -1419,6 +1419,7 @@ async def execute_streaming_generator_async( """ cdef: int64_t cur_generator_index = 0 + CRayStatus return_status assert context.is_initialized() # Generator task should only have 1 return object ref, @@ -1433,41 +1434,33 @@ async def execute_streaming_generator_async( executor = worker.core_worker.get_event_loop_executor() interrupt_signal_event = threading.Event() - futures = [] try: try: async for output in gen: - # NOTE: Reporting generator output in a streaming fashion, - # is done in a standalone thread-pool fully *asynchronously* - # to avoid blocking the event-loop and allow it to *concurrently* - # make progress, since serializing and actual RPC I/O is done - # with "nogil". - futures.append( - loop.run_in_executor( - executor, - report_streaming_generator_output, - context, - output, - cur_generator_index, - interrupt_signal_event, - ) - ) - cur_generator_index += 1 - except Exception as e: - # Report the exception to the owner of the task. - futures.append( - loop.run_in_executor( + # NOTE: Report of streaming generator output is done in a + # standalone thread-pool to avoid blocking the event loop, + # since serializing and actual RPC I/O is done with "nogil". We + # still wait for the report to finish to ensure that the task + # does not modify the output before we serialize it. + # + # Note that the RPC is sent asynchronously, and we do not wait + # for the reply here. The exception is if the user specified a + # backpressure threshold for the streaming generator, and we + # are currently under backpressure. Then we need to wait for an + # ack from the caller (the reply for a possibly previous report + # RPC) that they have consumed more ObjectRefs. + await loop.run_in_executor( executor, - report_streaming_generator_exception, + report_streaming_generator_output, context, - e, + output, cur_generator_index, interrupt_signal_event, ) - ) - - # Make sure all RPC I/O completes before returning - await asyncio.gather(*futures) + cur_generator_index += 1 + except Exception as e: + # Report the exception to the owner of the task. + report_streaming_generator_exception(context, e, cur_generator_index, None) except BaseException as be: # NOTE: PLEASE READ CAREFULLY BEFORE CHANGING @@ -1489,6 +1482,15 @@ async def execute_streaming_generator_async( raise + # The caller gets object values through the reports. If we finish the task + # before sending the report is complete, then we may fail before the report + # is sent to the caller. Then, the caller would never be able to ray.get + # the yield'ed ObjectRef. Therefore, we must wait for all in-flight object + # reports to complete before finishing the task. + with nogil: + return_status = context.waiter.get().WaitAllObjectsReported() + check_status(return_status) + cdef create_generator_return_obj( output,