From 452ed1f6e7e168c34a50a83e82635453d8faf6e2 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 12 May 2023 06:19:11 -0700 Subject: [PATCH 01/77] initial version Signed-off-by: SangBin Cho --- .../runtime/task/local_mode_task_submitter.cc | 1 + python/ray/_private/worker.py | 4 + python/ray/_private/workers/default_worker.py | 1 - python/ray/_raylet.pxd | 1 + python/ray/_raylet.pyx | 301 +++++++++++++++++- python/ray/actor.py | 4 + python/ray/exceptions.py | 4 + python/ray/includes/libcoreworker.pxd | 14 +- python/ray/remote_function.py | 4 + python/ray/tests/BUILD | 1 + python/ray/tests/test_streaming_generator.py | 229 +++++++++++++ src/ray/common/task/task_spec.cc | 4 + src/ray/common/task/task_spec.h | 2 + src/ray/common/task/task_util.h | 2 + src/ray/core_worker/common.cc | 30 ++ src/ray/core_worker/common.h | 5 + src/ray/core_worker/core_worker.cc | 93 +++++- src/ray/core_worker/core_worker.h | 26 +- src/ray/core_worker/core_worker_options.h | 5 +- .../java/io_ray_runtime_RayNativeRuntime.cc | 3 +- src/ray/core_worker/reference_count.cc | 29 ++ src/ray/core_worker/reference_count.h | 4 + src/ray/core_worker/task_manager.cc | 109 ++++++- src/ray/core_worker/task_manager.h | 52 ++- src/ray/core_worker/test/core_worker_test.cc | 1 + .../test/dependency_resolver_test.cc | 1 + .../test/direct_task_transport_test.cc | 1 + src/ray/core_worker/test/mock_worker.cc | 3 +- .../transport/direct_actor_transport.cc | 30 -- .../transport/direct_actor_transport.h | 1 + src/ray/gcs/test/gcs_test_util.h | 1 + src/ray/protobuf/common.proto | 22 ++ src/ray/protobuf/core_worker.proto | 28 +- .../scheduling/cluster_task_manager_test.cc | 1 + src/ray/rpc/worker/core_worker_client.h | 10 + src/ray/rpc/worker/core_worker_server.h | 3 + 36 files changed, 943 insertions(+), 87 deletions(-) create mode 100644 python/ray/tests/test_streaming_generator.py diff --git a/cpp/src/ray/runtime/task/local_mode_task_submitter.cc b/cpp/src/ray/runtime/task/local_mode_task_submitter.cc index 145e8130fe15..8e82b06e1eaa 100644 --- a/cpp/src/ray/runtime/task/local_mode_task_submitter.cc +++ b/cpp/src/ray/runtime/task/local_mode_task_submitter.cc @@ -61,6 +61,7 @@ ObjectID LocalModeTaskSubmitter::Submit(InvocationSpec &invocation, address, 1, /*returns_dynamic=*/false, + /*is_streaming_generator*/false, required_resources, required_placement_resources, "", diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index a9b81d672fb3..7fe3db7d0a6f 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -2810,6 +2810,10 @@ def cancel(object_ref: "ray.ObjectRef", *, force: bool = False, recursive: bool worker = ray._private.worker.global_worker worker.check_connected() + if isinstance(object_ref, ray._raylet.StreamingObjectRefGeneratorV2): + assert hasattr(object_ref, "_generator_ref") + object_ref = object_ref._generator_ref + if not isinstance(object_ref, ray.ObjectRef): raise TypeError( "ray.cancel() only supported for non-actor object refs. " diff --git a/python/ray/_private/workers/default_worker.py b/python/ray/_private/workers/default_worker.py index 937f45a8b85d..462c9e284f49 100644 --- a/python/ray/_private/workers/default_worker.py +++ b/python/ray/_private/workers/default_worker.py @@ -169,7 +169,6 @@ # https://github.com/ray-project/ray/pull/12225#issue-525059663. args = parser.parse_args() ray._private.ray_logging.setup_logger(args.logging_level, args.logging_format) - worker_launched_time_ms = time.time_ns() // 1e6 if args.worker_type == "WORKER": diff --git a/python/ray/_raylet.pxd b/python/ray/_raylet.pxd index 6af1879a5d8a..28a7632ed8c1 100644 --- a/python/ray/_raylet.pxd +++ b/python/ray/_raylet.pxd @@ -143,6 +143,7 @@ cdef class CoreWorker: self, worker, outputs, c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns, + const CAddress &caller_address, CObjectID ref_generator_id=*) cdef yield_current_fiber(self, CFiberEvent &fiber_event) cdef make_actor_handle(self, ActorHandleSharedPtr c_actor_handle) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 5b135b35d419..8bf936f6497c 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -131,6 +131,7 @@ from ray.exceptions import ( AsyncioActorExit, PendingCallsLimitExceeded, RpcError, + RayKeyError, ) from ray._private import external_storage from ray.util.scheduling_strategies import ( @@ -195,6 +196,63 @@ class ObjectRefGenerator: return len(self._refs) +class StreamingObjectRefGeneratorV2: + def __init__(self, generator_ref): + self._generator_ref = generator_ref + self._generator_task_completed_time = None + self._generator_task_exception = None + + def __iter__(self): + return self + + def __next__(self): + core_worker = ray._private.worker.global_worker.core_worker + obj = self._handle_next() + while obj.is_nil(): + if self._generator_task_exception: + # The generator task has failed. We raise StopIteration + # to conform the next interface in Python. + raise StopIteration + else: + # Otherwise, check the task status. + r, _ = ray.wait([self._generator_ref], timeout=0) + if len(r) > 0: + try: + ray.get(r) + except Exception as e: + # If it has failed, return the generator task ref + # so that the ref will raise an exception. + self._generator_task_exception = e + return self._generator_ref + finally: + if self._generator_task_completed_time is None: + self._generator_task_completed_time = time.time() + + if self._generator_task_completed_time: + if time.time() - self._generator_task_completed_time > 30: + # It means the next wasn't reported although the task + # has been terminated 30 seconds ago. + assert False, "Unexpected network failure occured." + + + time.sleep(0.001) + obj = self._handle_next() + return obj + + def _handle_next(self): + try: + core_worker = ray._private.worker.global_worker.core_worker + obj = core_worker.generator_get_next(self._generator_ref) + return obj + except RayKeyError: + raise StopIteration + + def __del__(self): + worker = ray._private.worker.global_worker + if hasattr(worker, "core_worker"): + worker.core_worker.generator_del(self._generator_ref) + + cdef int check_status(const CRayStatus& status) nogil except -1: if status.ok(): return 0 @@ -206,6 +264,9 @@ cdef int check_status(const CRayStatus& status) nogil except -1: raise ObjectStoreFullError(message) elif status.IsOutOfDisk(): raise OutOfDiskError(message) + # SANG-TODO Use a different error NotFound + elif status.IsKeyError(): + raise RayKeyError(message) elif status.IsInterrupted(): raise KeyboardInterrupt() elif status.IsTimedOut(): @@ -597,7 +658,7 @@ cdef store_task_errors( proctitle, c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns, c_string* application_error, - ): + const CAddress &caller_address): cdef: CoreWorker core_worker = worker.core_worker @@ -641,7 +702,8 @@ cdef store_task_errors( errors.append(failure_object) num_errors_stored = core_worker.store_task_outputs( worker, errors, - returns) + returns, + caller_address) ray._private.utils.push_error_to_driver( worker, @@ -652,6 +714,154 @@ cdef store_task_errors( raise RayActorError.from_task_error(failure_object) return num_errors_stored + +cdef execute_streaming_generator( + generator, + const CObjectID &generator_id, + CTaskType task_type, + const CAddress &caller_address, + TaskID task_id, + const c_string &serialized_retry_exception_allowlist, + function_name, + function_descriptor, + title, + actor, + actor_id, + c_bool *is_retryable_error, + c_string *application_error): + """Execute a given generator and streaming-report the + result to the given caller_address. + + The output from the generator will be stored to the in-memory + or plasma object store. The generated return objects will be + reported to the owner of the task as soon as they are generated. + + It means when this method is used, the result of each generator + will be reported and available from the given "caller address" + before the task is finished. + + Args: + generator: The generator to run. + generator_id: The object ref id of the generator task. + task_type: The type of the task. E.g., actor task, normal task. + caller_address: The address of the caller. By our protocol, + the caller of the streaming generator task is always + the owner, so we can also call it "owner address". + task_id: The task ID of the generator task. + serialized_retry_exception_allowlist: A list of + exceptions that are allowed to retry this generator task. + function_name: The name of the generator function. Used for + writing an error message. + function_descriptor: The function descriptor of + the generator function. Used for writing an error message. + title: The process title of the generator task. Used for + writing an error message. + actor: The instance of the actor created in this worker. + It is used to write an error message. + actor_id: The ID of the actor. It is used to write an error message. + is_retryable_error(out): It is set to True if the generator + raises an exception, and the error is retryable. + application_error(out): It is set if the generator raises an + application error. + """ + worker = ray._private.worker.global_worker + cdef: + CoreWorker core_worker = worker.core_worker + c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] intermediate_result + + generator_index = 0 + assert inspect.isgenerator(generator), ( + "execute_generator's first argument must be a generator." + ) + + while True: + try: + output = next(generator) + except StopIteration: + break + except Exception as e: + # Report the error if the generator failed to execute. + is_retryable_error[0] = determine_if_retryable( + e, + serialized_retry_exception_allowlist, + function_descriptor, + ) + + if ( + is_retryable_error[0] + and core_worker.get_current_task_retry_exceptions() + ): + logger.debug("Task failed with retryable exception:" + " {}.".format( + task_id), + exc_info=True) + # Raise an exception directly and halt the execution + # because there's no need to set the exception + # for the return value when the task is retryable. + raise e + + logger.debug("Task failed with unretryable exception:" + " {}.".format( + task_id), + exc_info=True) + + error_id = (CCoreWorkerProcess.GetCoreWorker() + .AllocateDynamicReturnId(caller_address)) + intermediate_result.push_back( + c_pair[CObjectID, shared_ptr[CRayObject]]( + error_id, shared_ptr[CRayObject]())) + + store_task_errors( + worker, e, + True, # task_exception + actor, # actor + actor_id, # actor id + function_name, task_type, title, + &intermediate_result, application_error, caller_address) + + CCoreWorkerProcess.GetCoreWorker().ObjectRefStreamWrite( + intermediate_result.back(), + generator_id, caller_address, generator_index, False) + break + else: + # Report the intermediate result if there was no error. + return_id = (CCoreWorkerProcess.GetCoreWorker() + .AllocateDynamicReturnId(caller_address)) + intermediate_result.push_back( + c_pair[CObjectID, shared_ptr[CRayObject]]( + return_id, shared_ptr[CRayObject]())) + + core_worker.store_task_outputs( + worker, [output], + &intermediate_result, + caller_address, + generator_id) + # print("SANG-TODO Writes an index ", i) + assert intermediate_result.size() == 1 + del output + + CCoreWorkerProcess.GetCoreWorker().ObjectRefStreamWrite( + intermediate_result.back(), + generator_id, + caller_address, + generator_index, + False) + finally: + if intermediate_result.size() > 0: + intermediate_result.pop_back() + generator_index += 1 + + # Close it. + # SANG-TODO Implement the close API. + # print("SANG-TODO Closes an index ", i) + CCoreWorkerProcess.GetCoreWorker().ObjectRefStreamWrite( + c_pair[CObjectID, shared_ptr[CRayObject]](CObjectID.Nil(), shared_ptr[CRayObject]()), + generator_id, + caller_address, + generator_index, + True) + + cdef execute_dynamic_generator_and_store_task_outputs( generator, const CObjectID &generator_id, @@ -663,7 +873,8 @@ cdef execute_dynamic_generator_and_store_task_outputs( c_bool is_reattempt, function_name, function_descriptor, - title): + title, + const CAddress &caller_address): worker = ray._private.worker.global_worker cdef: CoreWorker core_worker = worker.core_worker @@ -672,6 +883,7 @@ cdef execute_dynamic_generator_and_store_task_outputs( core_worker.store_task_outputs( worker, generator, dynamic_returns, + caller_address, generator_id) except Exception as error: is_retryable_error[0] = determine_if_retryable( @@ -699,7 +911,7 @@ cdef execute_dynamic_generator_and_store_task_outputs( # generate one additional ObjectRef. This last # ObjectRef will contain the error. error_id = (CCoreWorkerProcess.GetCoreWorker() - .AllocateDynamicReturnId()) + .AllocateDynamicReturnId(caller_address)) dynamic_returns[0].push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( error_id, shared_ptr[CRayObject]())) @@ -713,7 +925,7 @@ cdef execute_dynamic_generator_and_store_task_outputs( None, # actor None, # actor id function_name, task_type, title, - dynamic_returns, application_error) + dynamic_returns, application_error, caller_address) if num_errors_stored == 0: assert is_reattempt # TODO(swang): The generator task failed and we @@ -750,7 +962,8 @@ cdef void execute_task( c_bool is_reattempt, execution_info, title, - task_name) except *: + task_name, + c_bool is_streaming_generator) except *: worker = ray._private.worker.global_worker manager = worker.function_actor_manager actor = None @@ -899,6 +1112,35 @@ cdef void execute_task( ray.util.pdb.set_trace( breakpoint_uuid=debugger_breakpoint) outputs = function_executor(*args, **kwargs) + + if is_streaming_generator: + # Streaming generator always has a single return value + # which is the generator task return. + assert returns[0].size() == 1 + + if not inspect.isgenerator(outputs): + raise ValueError( + "Functions with " + "@ray.remote(num_returns=\"streaming\" must return a " + "generator") + + execute_streaming_generator( + outputs, + returns[0][0].first, # generator object ID. + task_type, + caller_address, + task_id, + serialized_retry_exception_allowlist, + function_name, + function_descriptor, + title, + actor, + actor_id, + is_retryable_error, + application_error) + # Streaming generator output is not used, so set it to None. + outputs = None + next_breakpoint = ( ray._private.worker.global_worker.debugger_breakpoint) if next_breakpoint != b"": @@ -979,6 +1221,8 @@ cdef void execute_task( # Store the outputs in the object store. with core_worker.profile_event(b"task:store_outputs"): num_returns = returns[0].size() + # TODO(sang): Remove it once we use streaming generator + # by default. if dynamic_returns != NULL: if not inspect.isgenerator(outputs): raise ValueError( @@ -998,7 +1242,8 @@ cdef void execute_task( is_reattempt, function_name, function_descriptor, - title) + title, + caller_address) task_exception = False dynamic_refs = [] @@ -1016,11 +1261,12 @@ cdef void execute_task( # all generator tasks, both static and dynamic. core_worker.store_task_outputs( worker, outputs, - returns) + returns, + caller_address) except Exception as e: num_errors_stored = store_task_errors( worker, e, task_exception, actor, actor_id, function_name, - task_type, title, returns, application_error) + task_type, title, returns, application_error, caller_address) if returns[0].size() > 0 and num_errors_stored == 0: logger.exception( "Unhandled error: Task threw exception, but all " @@ -1047,7 +1293,8 @@ cdef execute_task_with_cancellation_handler( # the concurrency groups of this actor. const c_vector[CConcurrencyGroup] &c_defined_concurrency_groups, const c_string c_name_of_concurrency_group_to_execute, - c_bool is_reattempt): + c_bool is_reattempt, + c_bool is_streaming_generator): is_retryable_error[0] = False @@ -1132,7 +1379,8 @@ cdef execute_task_with_cancellation_handler( application_error, c_defined_concurrency_groups, c_name_of_concurrency_group_to_execute, - is_reattempt, execution_info, title, task_name) + is_reattempt, execution_info, title, task_name, + is_streaming_generator) # Check for cancellation. PyErr_CheckSignals() @@ -1159,7 +1407,8 @@ cdef execute_task_with_cancellation_handler( task_type, title, returns, # application_error: we are passing NULL since we don't want the # cancel tasks to fail. - NULL) + NULL, + caller_address) finally: with current_task_id_lock: current_task_id = None @@ -1204,7 +1453,8 @@ cdef CRayStatus task_execution_handler( c_string *application_error, const c_vector[CConcurrencyGroup] &defined_concurrency_groups, const c_string name_of_concurrency_group_to_execute, - c_bool is_reattempt) nogil: + c_bool is_reattempt, + c_bool is_streaming_generator) nogil: with gil, disable_client_hook(): # Initialize job_config if it hasn't already. # Setup system paths configured in job_config. @@ -1228,7 +1478,8 @@ cdef CRayStatus task_execution_handler( application_error, defined_concurrency_groups, name_of_concurrency_group_to_execute, - is_reattempt) + is_reattempt, + is_streaming_generator) except Exception as e: sys_exit = SystemExit() if isinstance(e, RayActorError) and \ @@ -2722,6 +2973,7 @@ cdef class CoreWorker: worker, outputs, c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns, + const CAddress &caller_address, CObjectID ref_generator_id=CObjectID.Nil()): cdef: CObjectID return_id @@ -2763,7 +3015,7 @@ cdef class CoreWorker: num_returns)) while i >= returns[0].size(): return_id = (CCoreWorkerProcess.GetCoreWorker() - .AllocateDynamicReturnId()) + .AllocateDynamicReturnId(caller_address)) returns[0].push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( return_id, shared_ptr[CRayObject]())) @@ -3035,6 +3287,25 @@ cdef class CoreWorker: CCoreWorkerProcess.GetCoreWorker() \ .RecordTaskLogEnd(out_end_offset, err_end_offset) + def generator_del(self, ObjectRef generator_id): + cdef: + CObjectID c_generator_id = generator_id.native() + + CCoreWorkerProcess.GetCoreWorker().DelGenerator(c_generator_id) + + def generator_get_next(self, ObjectRef generator_id): + cdef: + CObjectID c_generator_id = generator_id.native() + CObjectReference c_object_ref + + check_status(CCoreWorkerProcess.GetCoreWorker().GetNextObjectRef(c_generator_id, &c_object_ref)) + return ObjectRef( + c_object_ref.object_id(), + c_object_ref.owner_address().SerializeAsString(), + "", + # Already added when the ref is updated. + skip_adding_local_ref=True) + cdef void async_callback(shared_ptr[CRayObject] obj, CObjectID object_ref, void *user_callback) with gil: diff --git a/python/ray/actor.py b/python/ray/actor.py index 7191031e059b..27fcf8c05a16 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -1167,6 +1167,10 @@ def _actor_method_call( if num_returns == "dynamic": num_returns = -1 + elif num_returns == "streaming": + # TODO(sang): This is a temporary private API. + # Remove it when we migrate to the streaming generator. + num_returns = -2 object_refs = worker.core_worker.submit_actor_task( self._ray_actor_language, diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py index 276acfd372c6..945661d0a96e 100644 --- a/python/ray/exceptions.py +++ b/python/ray/exceptions.py @@ -336,6 +336,10 @@ def __str__(self): return error_msg +class RayKeyError(RayError): + pass + + @PublicAPI class ObjectStoreFullError(RayError): """Indicates that the object store is full. diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index f1763aa89b35..026ba9f57ae6 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -147,7 +147,10 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: shared_ptr[CRayObject] *return_object, const CObjectID& generator_id ) - CObjectID AllocateDynamicReturnId() + void DelGenerator(const CObjectID &generator_id) + CRayStatus GetNextObjectRef(const CObjectID &generator_id, CObjectReference *object_ref_out) + CObjectID AllocateDynamicReturnId(const CAddress &owner_address) + CJobID GetCurrentJobId() CTaskID GetCurrentTaskId() @@ -235,6 +238,12 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: int64_t timeout_ms, c_vector[shared_ptr[CObjectLocation]] *results) CRayStatus TriggerGlobalGC() + CRayStatus ObjectRefStreamWrite( + const pair[CObjectID, shared_ptr[CRayObject]] &dynamic_return_object, + const CObjectID &generator_id, + const CAddress &caller_address, + int64_t idx, + c_bool finished) c_string MemoryUsageString() CWorkerContext &GetWorkerContext() @@ -304,7 +313,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: c_string *application_error, const c_vector[CConcurrencyGroup] &defined_concurrency_groups, const c_string name_of_concurrency_group_to_execute, - c_bool is_reattempt) nogil + c_bool is_reattempt, + c_bool is_streaming_generator) nogil ) task_execution_callback (void(const CWorkerID &) nogil) on_worker_shutdown (CRayStatus() nogil) check_signals diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index 2d1162b6ce33..3b36800344fc 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -306,6 +306,10 @@ def _remote(self, args=None, kwargs=None, **task_options): num_returns = task_options["num_returns"] if num_returns == "dynamic": num_returns = -1 + elif num_returns == "streaming": + # TODO(sang): This is a temporary private API. + # Remove it when we migrate to the streaming generator. + num_returns = -2 max_retries = task_options["max_retries"] retry_exceptions = task_options["retry_exceptions"] if isinstance(retry_exceptions, (list, tuple)): diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 34c239ceb9c7..f07618222f94 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -46,6 +46,7 @@ py_test_module_list( "test_gcs_fault_tolerance.py", "test_gcs_utils.py", "test_generators.py", + "test_streaming_generators.py", "test_metrics_agent.py", "test_metrics_head.py", "test_component_failures_2.py", diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py new file mode 100644 index 000000000000..5f81a737b58f --- /dev/null +++ b/python/ray/tests/test_streaming_generator.py @@ -0,0 +1,229 @@ +import asyncio +import pytest +import numpy as np +import sys +import time + +import ray +from ray.util.client.ray_client_helpers import ( + ray_start_client_server_for_address, +) +from ray._private.client_mode_hook import enable_client_mode +from ray.tests.conftest import call_ray_start_context +from ray._private.test_utils import wait_for_condition +from ray.experimental.state.api import list_tasks +from ray._raylet import StreamingObjectRefGeneratorV2 + + +def test_generator_basic(shutdown_only): + ray.init(num_cpus=1) + + """Basic cases""" + @ray.remote + def f(): + for i in range(5): + yield i + + gen = f.options(num_returns="streaming").remote() + i = 0 + for ref in gen: + print(ray.get(ref)) + assert i == ray.get(ref) + del ref + i += 1 + + """Exceptions""" + + @ray.remote + def f(): + for i in range(5): + if i == 2: + raise ValueError + yield i + + gen = f.options(num_returns="streaming").remote() + ray.get(next(gen)) + ray.get(next(gen)) + with pytest.raises(ray.exceptions.RayTaskError) as e: + ray.get(next(gen)) + print(str(e.value)) + with pytest.raises(StopIteration): + ray.get(next(gen)) + with pytest.raises(StopIteration): + ray.get(next(gen)) + + """Generator Task failure""" + + @ray.remote + class A: + def getpid(self): + import os + + return os.getpid() + + def f(self): + for i in range(5): + import time + + time.sleep(0.1) + yield i + + a = A.remote() + i = 0 + gen = a.f.options(num_returns="streaming").remote() + i = 0 + for ref in gen: + if i == 2: + ray.kill(a) + if i == 3: + with pytest.raises(ray.exceptions.RayActorError) as e: + ray.get(ref) + assert "The actor is dead because it was killed by `ray.kill`" in str( + e.value + ) + break + assert i == ray.get(ref) + del ref + i += 1 + for _ in range(10): + with pytest.raises(StopIteration): + next(gen) + + """Retry exceptions""" + # TODO(sang): Enable it + # @ray.remote + # class Actor: + # def __init__(self): + # self.should_kill = True + + # def should_kill(self): + # return self.should_kill + + # async def set(self, wait_s): + # await asyncio.sleep(wait_s) + # self.should_kill = False + + # @ray.remote(retry_exceptions=[ValueError], max_retries=10) + # def f(a): + # for i in range(5): + # should_kill = ray.get(a.should_kill.remote()) + # if i == 3 and should_kill: + # raise ValueError + # yield i + + # a = Actor.remote() + # gen = f.options(num_returns="streaming").remote(a) + # assert ray.get(next(gen)) == 0 + # assert ray.get(next(gen)) == 1 + # assert ray.get(next(gen)) == 2 + # a.set.remote(3) + # assert ray.get(next(gen)) == 3 + # assert ray.get(next(gen)) == 4 + # with pytest.raises(StopIteration): + # ray.get(next(gen)) + + """Cancel""" + + @ray.remote + def f(): + for i in range(5): + time.sleep(5) + yield i + + gen = f.options(num_returns="streaming").remote() + assert ray.get(next(gen)) == 0 + ray.cancel(gen) + with pytest.raises(ray.exceptions.RayTaskError) as e: + assert ray.get(next(gen)) == 1 + assert "was cancelled" in str(e.value) + with pytest.raises(StopIteration): + next(gen) + + +@pytest.mark.parametrize("use_actors", [False, True]) +@pytest.mark.parametrize("store_in_plasma", [False, True]) +def test_generator_streaming(shutdown_only, use_actors, store_in_plasma): + """Verify the generator is working in a streaming fashion.""" + ray.init() + remote_generator_fn = None + if use_actors: + + @ray.remote + class Generator: + def __init__(self): + pass + + def generator(self, num_returns, store_in_plasma): + for i in range(num_returns): + if store_in_plasma: + yield np.ones(1_000_000, dtype=np.int8) * i + else: + yield [i] + + g = Generator.remote() + remote_generator_fn = g.generator + else: + + @ray.remote(max_retries=0) + def generator(num_returns, store_in_plasma): + for i in range(num_returns): + if store_in_plasma: + yield np.ones(1_000_000, dtype=np.int8) * i + else: + yield [i] + + remote_generator_fn = generator + + """Verify num_returns="streaming" is streaming""" + gen = remote_generator_fn.options(num_returns="streaming").remote(3, store_in_plasma) + for ref in gen: + id = ref.hex() + print(ray.get(ref)) + del ref + from ray.experimental.state.api import list_objects + + wait_for_condition( + lambda: len(list_objects(filters=[("object_id", "=", id)])) == 0 + ) + + +def test_generator_dist_chain(ray_start_cluster): + cluster = ray_start_cluster + cluster.add_node(num_cpus=0, object_store_memory=1 * 1024 * 1024 * 1024) + ray.init() + cluster.add_node(num_cpus=1) + cluster.add_node(num_cpus=1) + cluster.add_node(num_cpus=1) + cluster.add_node(num_cpus=1) + + @ray.remote + class ChainActor: + def __init__(self, child=None): + self.child = child + + def get_data(self): + if not self.child: + for _ in range(10): + time.sleep(0.1) + yield np.ones(5 * 1024 * 1024) + else: + for data in self.child.get_data.options(num_returns="streaming").remote(): + yield ray.get(data) + + chain_actor = ChainActor.remote() + chain_actor_2 = ChainActor.remote(chain_actor) + chain_actor_3 = ChainActor.remote(chain_actor_2) + chain_actor_4 = ChainActor.remote(chain_actor_3) + + for ref in chain_actor_4.get_data.options(num_returns="streaming").remote(): + assert np.array_equal(np.ones(5 * 1024 * 1024), ray.get(ref)) + del ref + + +if __name__ == "__main__": + import os + + if os.environ.get("PARALLEL_CI"): + sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) + else: + sys.exit(pytest.main(["-sv", __file__])) diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index 71000748cb44..d28a6f671334 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -218,6 +218,10 @@ ObjectID TaskSpecification::ReturnId(size_t return_index) const { bool TaskSpecification::ReturnsDynamic() const { return message_->returns_dynamic(); } +// TODO(sang): Merge this with ReturnsDynamic once migrating to the +// streaming generator. +bool TaskSpecification::IsStreamingGenerator() const { return message_->streaming_generator(); } + std::vector TaskSpecification::DynamicReturnIds() const { RAY_CHECK(message_->returns_dynamic()); std::vector dynamic_return_ids; diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h index 3b29d2aadb3b..eea53f3d0348 100644 --- a/src/ray/common/task/task_spec.h +++ b/src/ray/common/task/task_spec.h @@ -262,6 +262,8 @@ class TaskSpecification : public MessageWrapper { bool ReturnsDynamic() const; + bool IsStreamingGenerator() const; + std::vector DynamicReturnIds() const; void AddDynamicReturnId(const ObjectID &dynamic_return_id); diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h index c260745b7161..1110504ea0b5 100644 --- a/src/ray/common/task/task_util.h +++ b/src/ray/common/task/task_util.h @@ -126,6 +126,7 @@ class TaskSpecBuilder { const rpc::Address &caller_address, uint64_t num_returns, bool returns_dynamic, + bool is_streaming_generator, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, const std::string &debugger_breakpoint, @@ -149,6 +150,7 @@ class TaskSpecBuilder { message_->mutable_caller_address()->CopyFrom(caller_address); message_->set_num_returns(num_returns); message_->set_returns_dynamic(returns_dynamic); + message_->set_streaming_generator(is_streaming_generator); message_->mutable_required_resources()->insert(required_resources.begin(), required_resources.end()); message_->mutable_required_placement_resources()->insert( diff --git a/src/ray/core_worker/common.cc b/src/ray/core_worker/common.cc index e0849c29ec1f..0f640e154bc3 100644 --- a/src/ray/core_worker/common.cc +++ b/src/ray/core_worker/common.cc @@ -49,5 +49,35 @@ std::string GenerateCachedActorName(const std::string &ns, return ns + "-" + actor_name; } +void SerializeReturnObject(const ObjectID &object_id, + const std::shared_ptr &return_object, + rpc::ReturnObject *return_object_proto) { + return_object_proto->set_object_id(object_id.Binary()); + + if (!return_object) { + // This should only happen if the local raylet died. Caller should + // retry the task. + RAY_LOG(WARNING) << "Failed to create task return object " << object_id + << " in the object store, exiting."; + QuickExit(); + } + return_object_proto->set_size(return_object->GetSize()); + if (return_object->GetData() != nullptr && return_object->GetData()->IsPlasmaBuffer()) { + return_object_proto->set_in_plasma(true); + } else { + if (return_object->GetData() != nullptr) { + return_object_proto->set_data(return_object->GetData()->Data(), + return_object->GetData()->Size()); + } + if (return_object->GetMetadata() != nullptr) { + return_object_proto->set_metadata(return_object->GetMetadata()->Data(), + return_object->GetMetadata()->Size()); + } + } + for (const auto &nested_ref : return_object->GetNestedRefs()) { + return_object_proto->add_nested_inlined_refs()->CopyFrom(nested_ref); + } +} + } // namespace core } // namespace ray diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index 9ca7daa2a950..86d7499b0f4b 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -21,6 +21,7 @@ #include "ray/common/task/task_spec.h" #include "ray/raylet_client/raylet_client.h" #include "ray/util/util.h" +#include "src/ray/protobuf/common.pb.h" namespace ray { namespace core { @@ -37,6 +38,10 @@ std::string LanguageString(Language language); // `namespace-[job_id-]actor_name` std::string GenerateCachedActorName(const std::string &ns, const std::string &actor_name); +void SerializeReturnObject(const ObjectID &object_id, + const std::shared_ptr &return_object, + rpc::ReturnObject *return_object_proto); + /// Information about a remote function. class RayFunction { public: diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index ff03b5b85508..7f163ad4d8b7 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1657,6 +1657,45 @@ void CoreWorker::TriggerGlobalGC() { }); } +Status CoreWorker::ObjectRefStreamWrite( + const std::pair> &dynamic_return_object, + const ObjectID &generator_id, + const rpc::Address &caller_address, + int64_t idx, + bool finished) { + RAY_LOG(DEBUG) << "SANG-TODO Write the object ref stream, index: " << idx + << " finished: " << finished << ", id: " << dynamic_return_object.first; + rpc::WriteObjectRefStreamRequest request; + request.mutable_worker_addr()->CopyFrom(rpc_address_); + request.set_idx(idx); + request.set_finished(finished); + request.set_generator_id(generator_id.Binary()); + auto client = core_worker_client_pool_->GetOrConnect(caller_address); + + // Object id is nil if it is the close operations. + // SANG-TODO Support a separate endpoint Close. + if (!dynamic_return_object.first.IsNil()) { + auto return_object_proto = request.add_dynamic_return_objects(); + SerializeReturnObject( + dynamic_return_object.first, dynamic_return_object.second, return_object_proto); + std::vector deleted; + ReferenceCounter::ReferenceTableProto borrowed_refs; + reference_counter_->PopAndClearLocalBorrowers( + {dynamic_return_object.first}, &borrowed_refs, &deleted); + memory_store_->Delete(deleted); + } + + client->WriteObjectRefStream( + request, [](const Status &status, const rpc::WriteObjectRefStreamReply &reply) { + if (status.ok()) { + RAY_LOG(DEBUG) << "SANG-TODO Succeeded to send the object ref"; + } else { + RAY_LOG(DEBUG) << "SANG-TODO Failed to send the object ref"; + } + }); + return Status::OK(); +} + std::string CoreWorker::MemoryUsageString() { // Currently only the Plasma store returns a debug string. return plasma_store_provider_->MemoryUsageString(); @@ -1842,6 +1881,16 @@ void CoreWorker::BuildCommonTaskSpec( // is a generator of ObjectRefs. num_returns = 1; } + // TODO(sang): Remove this and integrate it to + // nun_returns == -1 once migrating to streaming + // generator. + bool is_streaming_generator = num_returns == -2; + if (is_streaming_generator) { + num_returns = 1; + // We are using the dynamic return if + // the streaming generator is used. + returns_dynamic = true; + } RAY_CHECK(num_returns >= 0); builder.SetCommonTaskSpec( task_id, @@ -1858,6 +1907,7 @@ void CoreWorker::BuildCommonTaskSpec( address, num_returns, returns_dynamic, + is_streaming_generator, required_resources, required_placement_resources, debugger_breakpoint, @@ -2591,6 +2641,9 @@ Status CoreWorker::ExecuteTask( dynamic_return_objects = NULL; } else if (task_spec.AttemptNumber() > 0) { for (const auto &dynamic_return_id : task_spec.DynamicReturnIds()) { + // Increase the put index so that when the generator creates a new obj + // the object id won't conflict. + worker_context_.GetNextPutIndex(); dynamic_return_objects->push_back( std::make_pair<>(dynamic_return_id, std::shared_ptr())); RAY_LOG(DEBUG) << "Re-executed task " << task_spec.TaskId() @@ -2651,7 +2704,8 @@ Status CoreWorker::ExecuteTask( application_error, defined_concurrency_groups, name_of_concurrency_group_to_execute, - /*is_reattempt=*/task_spec.AttemptNumber() > 0); + /*is_reattempt=*/task_spec.AttemptNumber() > 0, + /*is_streaming_generator*/task_spec.IsStreamingGenerator()); // Get the reference counts for any IDs that we borrowed during this task, // remove the local reference for these IDs, and return the ref count info to @@ -2744,6 +2798,24 @@ Status CoreWorker::SealReturnObject(const ObjectID &return_id, return status; } +void CoreWorker::DelGenerator(const ObjectID &generator_id) { + task_manager_->DelGenerator(generator_id); +} + +Status CoreWorker::GetNextObjectRef(const ObjectID &generator_id, + rpc::ObjectReference *object_ref_out) { + ObjectID object_id; + const auto &status = task_manager_->GetNextObjectRef(generator_id, &object_id); + if (!status.ok()) { + return status; + } + + RAY_CHECK(object_ref_out != nullptr); + object_ref_out->set_object_id(object_id.Binary()); + object_ref_out->mutable_owner_address()->CopyFrom(rpc_address_); + return status; +} + bool CoreWorker::PinExistingReturnObject(const ObjectID &return_id, std::shared_ptr *return_object, const ObjectID &generator_id) { @@ -2797,13 +2869,11 @@ bool CoreWorker::PinExistingReturnObject(const ObjectID &return_id, } } -ObjectID CoreWorker::AllocateDynamicReturnId() { +ObjectID CoreWorker::AllocateDynamicReturnId(const rpc::Address &owner_address) { const auto &task_spec = worker_context_.GetCurrentTask(); - const auto return_id = - ObjectID::FromIndex(task_spec->TaskId(), worker_context_.GetNextPutIndex()); + const auto return_id = ObjectID::FromIndex(task_spec->TaskId(), worker_context_.GetNextPutIndex()); AddLocalReference(return_id, ""); - reference_counter_->AddBorrowedObject( - return_id, ObjectID::Nil(), worker_context_.GetCurrentTask()->CallerAddress()); + reference_counter_->AddBorrowedObject(return_id, ObjectID::Nil(), owner_address); return return_id; } @@ -3174,6 +3244,7 @@ void CoreWorker::ProcessSubscribeForObjectEviction( // counter so that we know that it exists. const auto generator_id = ObjectID::FromBinary(message.generator_id()); RAY_CHECK(!generator_id.IsNil()); + // SANG-TODO If streaming, use streaming instead. reference_counter_->AddDynamicReturn(object_id, generator_id); } @@ -3308,6 +3379,7 @@ void CoreWorker::AddSpilledObjectLocationOwner( // object. Add the dynamically created object to our ref counter so that we // know that it exists. RAY_CHECK(!generator_id->IsNil()); + // SANG-TODO If streaming, use streaming instead. reference_counter_->AddDynamicReturn(object_id, *generator_id); } @@ -3339,6 +3411,7 @@ void CoreWorker::AddObjectLocationOwner(const ObjectID &object_id, // The task is a generator and may not have finished yet. Add the internal // ObjectID so that we can update its location. reference_counter_->AddDynamicReturn(object_id, maybe_generator_id); + // SANG-TODO If streaming, use streaming instead. RAY_UNUSED(reference_counter_->AddObjectLocation(object_id, node_id)); } } @@ -3369,6 +3442,14 @@ void CoreWorker::ProcessSubscribeObjectLocations( reference_counter_->PublishObjectLocationSnapshot(object_id); } +void CoreWorker::HandleWriteObjectRefStream(rpc::WriteObjectRefStreamRequest request, + rpc::WriteObjectRefStreamReply *reply, + rpc::SendReplyCallback send_reply_callback) { + RAY_LOG(DEBUG) << "SANG-TODO HandleWriteObjectRefStream"; + task_manager_->HandleIntermediateResult(request); + send_reply_callback(Status::OK(), nullptr, nullptr); +} + void CoreWorker::HandleGetObjectLocationsOwner( rpc::GetObjectLocationsOwnerRequest request, rpc::GetObjectLocationsOwnerReply *reply, diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index b87621238f4a..ddfc6928da03 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -354,6 +354,13 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { NodeID GetCurrentNodeId() const { return NodeID::FromBinary(rpc_address_.raylet_id()); } + // SANG-TODO Update the docstring. + void DelGenerator(const ObjectID &generator_id); + + // SANG-TODO Update the docstring. + Status GetNextObjectRef(const ObjectID &generator_id, + rpc::ObjectReference *object_ref_out); + const PlacementGroupID &GetCurrentPlacementGroupId() const { return worker_context_.GetCurrentPlacementGroupId(); } @@ -697,6 +704,15 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// Trigger garbage collection on each worker in the cluster. void TriggerGlobalGC(); + + /// SANG-TODO Update the docstring. + /// SANG-TODO Support close separately. + Status ObjectRefStreamWrite( + const std::pair> &dynamic_return_object, + const ObjectID &generator_id, + const rpc::Address &caller_address, + int64_t idx, + bool finished); /// Get a string describing object store memory usage for debugging purposes. /// @@ -937,15 +953,18 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { std::shared_ptr *return_object, const ObjectID &generator_id); + /// SANG-TODO Update the docstring. /// Dynamically allocate an object. /// /// This should be used during task execution, if the task wants to return an /// object to the task caller and have the resulting ObjectRef be owned by /// the caller. This is in contrast to static allocation, where the caller /// decides at task invocation time how many returns the task should have. + /// \param[in] owner_address The address of the owner who will own this + /// dynamically generated object. /// /// \param[out] The ObjectID that the caller should use to store the object. - ObjectID AllocateDynamicReturnId(); + ObjectID AllocateDynamicReturnId(const rpc::Address &owner_address); /// Get a handle to an actor. /// @@ -1035,6 +1054,11 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { rpc::GetObjectLocationsOwnerReply *reply, rpc::SendReplyCallback send_reply_callback) override; + /// Implements gRPC server handler. + void HandleWriteObjectRefStream(rpc::WriteObjectRefStreamRequest request, + rpc::WriteObjectRefStreamReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + /// Implements gRPC server handler. void HandleKillActor(rpc::KillActorRequest request, rpc::KillActorReply *reply, diff --git a/src/ray/core_worker/core_worker_options.h b/src/ray/core_worker/core_worker_options.h index 157a3fbc53a3..3a8346776077 100644 --- a/src/ray/core_worker/core_worker_options.h +++ b/src/ray/core_worker/core_worker_options.h @@ -56,7 +56,10 @@ struct CoreWorkerOptions { // used for actor creation task. const std::vector &defined_concurrency_groups, const std::string name_of_concurrency_group_to_execute, - bool is_reattempt)>; + bool is_reattempt, + // True if the task is for streaming generator. + // TODO(sang): Remove it and combine it with dynamic returns. + bool is_streaming_generator)>; CoreWorkerOptions() : store_socket(""), diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc index 5afb92f853be..109dd0dc9686 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc @@ -124,7 +124,8 @@ Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNIEnv *env, std::string *application_error, const std::vector &defined_concurrency_groups, const std::string name_of_concurrency_group_to_execute, - bool is_reattempt) { + bool is_reattempt, + bool is_streaming_generator) { // These 2 parameters are used for Python only, and Java worker // will not use them. RAY_UNUSED(defined_concurrency_groups); diff --git a/src/ray/core_worker/reference_count.cc b/src/ray/core_worker/reference_count.cc index ba5321828207..3aa7cce0a044 100644 --- a/src/ray/core_worker/reference_count.cc +++ b/src/ray/core_worker/reference_count.cc @@ -239,6 +239,35 @@ void ReferenceCounter::AddDynamicReturn(const ObjectID &object_id, AddNestedObjectIdsInternal(generator_id, {object_id}, owner_address); } +void ReferenceCounter::AddStreamingDynamicReturn(const ObjectID &object_id, + const ObjectID &generator_id) { + absl::MutexLock lock(&mutex_); + auto outer_it = object_id_refs_.find(generator_id); + if (outer_it == object_id_refs_.end()) { + // Outer object already went out of scope. Either: + // 1. The inner object was never deserialized and has already gone out of + // scope. + // 2. The inner object was deserialized and we already added it as a + // dynamic return. + // Either way, we shouldn't add the inner object to the ref count. + return; + } + RAY_LOG(DEBUG) << "Adding dynamic return " << object_id + << " contained in generator object " << generator_id; + RAY_CHECK(outer_it->second.owned_by_us); + RAY_CHECK(outer_it->second.owner_address.has_value()); + rpc::Address owner_address(outer_it->second.owner_address.value()); + RAY_UNUSED(AddOwnedObjectInternal(object_id, + {}, + owner_address, + outer_it->second.call_site, + /*object_size=*/-1, + outer_it->second.is_reconstructable, + /*add_local_ref=*/true, + absl::optional())); + UpdateObjectPendingCreation(object_id, false); +} + bool ReferenceCounter::AddOwnedObjectInternal( const ObjectID &object_id, const std::vector &inner_ids, diff --git a/src/ray/core_worker/reference_count.h b/src/ray/core_worker/reference_count.h index c16ee0392119..4b4a597595a2 100644 --- a/src/ray/core_worker/reference_count.h +++ b/src/ray/core_worker/reference_count.h @@ -201,6 +201,10 @@ class ReferenceCounter : public ReferenceCounterInterface, void AddDynamicReturn(const ObjectID &object_id, const ObjectID &generator_id) LOCKS_EXCLUDED(mutex_); + // SANG-TODO Update the docstring. + void AddStreamingDynamicReturn(const ObjectID &object_id, const ObjectID &generator_id) + LOCKS_EXCLUDED(mutex_); + /// Update the size of the object. /// /// \param[in] object_id The ID of the object. diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index f5de3de65cc3..2c433d023a00 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -300,6 +300,110 @@ bool TaskManager::HandleTaskReturn(const ObjectID &object_id, return direct_return; } +void TaskManager::DelGenerator(const ObjectID &generator_id) { + while (true) { + ObjectID object_id; + const auto &status = GetNextObjectRef(generator_id, &object_id); + // SANG-TODO We should remove a reference. Need a test. + if (status.IsKeyError()) { + break; + } + if (object_id == ObjectID::Nil()) { + break; + } + RAY_LOG(DEBUG) << "SANG-TODO DelGenerator Get Next"; + } + RAY_LOG(DEBUG) << "SANG-TODO Delete generator from " << generator_id; +} + +Status TaskManager::GetNextObjectRef(const ObjectID &generator_id, + ObjectID *object_id_out) { + absl::MutexLock lock(&mu_); + RAY_CHECK(object_id_out != nullptr); + auto it = dynamic_ids_from_generator_.find(generator_id); + if (it == dynamic_ids_from_generator_.end()) { + RAY_LOG(DEBUG) << "SANG-TODO Generator already GC'ed " << *object_id_out + << " generator id: " << generator_id; + *object_id_out = ObjectID::Nil(); + return Status::OK(); + } + + auto &reader = dynamic_ids_from_generator_[generator_id]; + if (reader.last != -1 && reader.curr >= reader.last) { + RAY_LOG(DEBUG) << "SANG-TODO Generator has no more objects " << generator_id; + return Status::KeyError("Finished"); + } + auto reader_it = reader.idx_to_refs.find(reader.curr); + if (reader_it != reader.idx_to_refs.end()) { + *object_id_out = reader_it->second; + reader.idx_to_refs.erase(reader.curr); + reader.curr += 1; + RAY_LOG(DEBUG) << "SANG-TODO Get the next object id " << *object_id_out + << " generator id: " << generator_id; + } else { + RAY_LOG(DEBUG) << "SANG-TODO Object not available. Current index: " << reader.curr + << " last: " << reader.last << " generator id: " << generator_id; + *object_id_out = ObjectID::Nil(); + } + return Status::OK(); +} + +void TaskManager::HandleIntermediateResult( + const rpc::WriteObjectRefStreamRequest &request) { + const auto &generator_id = ObjectID::FromBinary(request.generator_id()); + const auto &task_id = generator_id.TaskId(); + int64_t idx = request.idx(); + // Every generated object has the same task id. + RAY_LOG(DEBUG) << "SANG-TODO Received an intermediate result of index " << request.idx() + << " generator_id: " << generator_id; + + { + absl::MutexLock lock(&mu_); + if (request.finished()) { + RAY_LOG(DEBUG) << "SANG-TODO Finished with an index " << request.idx(); + auto &reader = dynamic_ids_from_generator_[generator_id]; + reader.last = request.idx(); + RAY_CHECK(request.dynamic_return_objects_size() == 0); + } + } + + const auto store_in_plasma_ids = GetTaskReturnObjectsToStoreInPlasma(task_id); + + for (const auto &return_object : request.dynamic_return_objects()) { + const auto object_id = ObjectID::FromBinary(return_object.object_id()); + RAY_LOG(DEBUG) << "SANG-TODO Add an object " << object_id; + int64_t curr; + { + absl::MutexLock lock(&mu_); + auto &reader = dynamic_ids_from_generator_[generator_id]; + curr = reader.curr; + if (idx >= curr) { + reader.idx_to_refs.emplace(idx, object_id); + // TODO(sang): Add it when retry is supported. + // auto it = submissible_tasks_.find(task_id); + // if (it != submissible_tasks_.end()) { + // // NOTE(sang): This is a hack to modify immutable field. + // // It is possible because most of attributes under + // // TaskSpecification is a pointer to the protobuf message. + // TaskSpecification spec; + // spec = it->second.spec; + // spec.AddDynamicReturnId(object_id); + // it->second.reconstructable_return_ids.insert(object_id); + // } + } + } + // If we call this method while holding a lock, it can deadlock. + if (idx >= curr) { + reference_counter_->AddStreamingDynamicReturn(object_id, generator_id); + } + HandleTaskReturn(object_id, + return_object, + NodeID::FromBinary(request.worker_addr().raylet_id()), + /*store_in_plasma*/store_in_plasma_ids.count(object_id)); + } + RAY_LOG(DEBUG) << "SANG-TODO Finished handling intermediate result"; +} + void TaskManager::CompletePendingTask(const TaskID &task_id, const rpc::PushTaskReply &reply, const rpc::Address &worker_addr, @@ -710,8 +814,9 @@ absl::flat_hash_set TaskManager::GetTaskReturnObjectsToStoreInPlasma( absl::flat_hash_set store_in_plasma_ids = {}; absl::MutexLock lock(&mu_); auto it = submissible_tasks_.find(task_id); - RAY_CHECK(it != submissible_tasks_.end()) - << "Tried to store return values for task that was not pending " << task_id; + if (it == submissible_tasks_.end()) { + return {}; + } first_execution = it->second.num_successful_executions == 0; if (!first_execution) { store_in_plasma_ids = it->second.reconstructable_return_ids; diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index 0ab8621368d6..9a80893c8edd 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -37,6 +37,14 @@ class TaskFinisherInterface { const rpc::Address &actor_addr, bool is_application_error) = 0; + virtual void HandleIntermediateResult( + const rpc::WriteObjectRefStreamRequest &request) = 0; + + virtual void DelGenerator(const ObjectID &generator_id) = 0; + + virtual Status GetNextObjectRef(const ObjectID &generator_id, + ObjectID *object_id_out) = 0; + virtual bool RetryTaskIfPossible(const TaskID &task_id, const rpc::RayErrorInfo &error_info) = 0; @@ -87,6 +95,12 @@ using PushErrorCallback = std::function; +struct ObjectRefStreamReader { + absl::flat_hash_map idx_to_refs; + int64_t last = -1; + int64_t curr = 0; +}; + class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterface { public: TaskManager(std::shared_ptr in_memory_store, @@ -167,6 +181,16 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa const rpc::Address &worker_addr, bool is_application_error) override; + + // SANG-TODO Docstring + change the method. + void HandleIntermediateResult(const rpc::WriteObjectRefStreamRequest &request) override; + + // SANG-TODO Docstring + change the method. + void DelGenerator(const ObjectID &generator_id) override; + + // SANG-TODO Docstring + change the method. + Status GetNextObjectRef(const ObjectID &generator_id, ObjectID *object_id_out) override; + /// Returns true if task can be retried. /// /// \param[in] task_id ID of the task to be retried. @@ -459,17 +483,19 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa const rpc::Address &worker_addr, const ReferenceCounter::ReferenceTableProto &borrowed_refs); - // Get the objects that were stored in plasma upon the first successful - // execution of this task. If the task is re-executed, these objects should - // get stored in plasma again, even if they are small and were returned - // directly in the worker's reply. This ensures that any reference holders - // that are already scheduled at the raylet can retrieve these objects - // through plasma. - // \param[in] task_id The task ID. - // \param[out] first_execution Whether the task has been successfully - // executed before. If this is false, then the objects to store in plasma - // will be empty. - // \param [out] Return objects that should be stored in plasma. + /// Get the objects that were stored in plasma upon the first successful + /// execution of this task. If the task is re-executed, these objects should + /// get stored in plasma again, even if they are small and were returned + /// directly in the worker's reply. This ensures that any reference holders + /// that are already scheduled at the raylet can retrieve these objects + /// through plasma. + /// + /// \param[in] task_id The task ID. + /// \param[out] first_execution Whether the task has been successfully + /// executed before. If this is false, then the objects to store in plasma + /// will be empty. + /// \param [out] Return objects that should be stored in plasma. If the + /// task has been already terminated, it returns an empty set. absl::flat_hash_set GetTaskReturnObjectsToStoreInPlasma( const TaskID &task_id, bool *first_execution = nullptr) const LOCKS_EXCLUDED(mu_); @@ -560,6 +586,10 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// error). worker::TaskEventBuffer &task_event_buffer_; + // SANG-TODO Docstring + change the name. + absl::flat_hash_map dynamic_ids_from_generator_ + GUARDED_BY(mu_); + friend class TaskManagerTest; }; diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 31a97db7bd4f..62dd91f4474b 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -570,6 +570,7 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { address, num_returns, false, + false, resources, resources, "", diff --git a/src/ray/core_worker/test/dependency_resolver_test.cc b/src/ray/core_worker/test/dependency_resolver_test.cc index 4d2406e006ec..5ca82b773b7a 100644 --- a/src/ray/core_worker/test/dependency_resolver_test.cc +++ b/src/ray/core_worker/test/dependency_resolver_test.cc @@ -44,6 +44,7 @@ TaskSpecification BuildTaskSpec(const std::unordered_map &r empty_address, 1, false, + false, resources, resources, serialized_runtime_env, diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index 61eb4370c3f4..498551b61334 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -65,6 +65,7 @@ TaskSpecification BuildTaskSpec(const std::unordered_map &r empty_address, 1, false, + false, resources, resources, serialized_runtime_env, diff --git a/src/ray/core_worker/test/mock_worker.cc b/src/ray/core_worker/test/mock_worker.cc index 1c782438ae28..7529a5255ee0 100644 --- a/src/ray/core_worker/test/mock_worker.cc +++ b/src/ray/core_worker/test/mock_worker.cc @@ -67,7 +67,8 @@ class MockWorker { std::string *application_error, const std::vector &defined_concurrency_groups, const std::string name_of_concurrency_group_to_execute, - bool is_reattempt) { + bool is_reattempt, + bool is_streaming_generator) { return ExecuteTask(caller_address, task_type, task_name, diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index c355d5f42108..57e7dbd1ca76 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -25,36 +25,6 @@ using namespace ray::gcs; namespace ray { namespace core { -void SerializeReturnObject(const ObjectID &object_id, - const std::shared_ptr &return_object, - rpc::ReturnObject *return_object_proto) { - return_object_proto->set_object_id(object_id.Binary()); - - if (!return_object) { - // This should only happen if the local raylet died. Caller should - // retry the task. - RAY_LOG(WARNING) << "Failed to create task return object " << object_id - << " in the object store, exiting."; - QuickExit(); - } - return_object_proto->set_size(return_object->GetSize()); - if (return_object->GetData() != nullptr && return_object->GetData()->IsPlasmaBuffer()) { - return_object_proto->set_in_plasma(true); - } else { - if (return_object->GetData() != nullptr) { - return_object_proto->set_data(return_object->GetData()->Data(), - return_object->GetData()->Size()); - } - if (return_object->GetMetadata() != nullptr) { - return_object_proto->set_metadata(return_object->GetMetadata()->Data(), - return_object->GetMetadata()->Size()); - } - } - for (const auto &nested_ref : return_object->GetNestedRefs()) { - return_object_proto->add_nested_inlined_refs()->CopyFrom(nested_ref); - } -} - void CoreWorkerDirectTaskReceiver::Init( std::shared_ptr client_pool, rpc::Address rpc_address, diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index a81899f4127e..d77ec7fcb34e 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -30,6 +30,7 @@ #include "ray/common/ray_object.h" #include "ray/core_worker/actor_creator.h" #include "ray/core_worker/actor_handle.h" +#include "ray/core_worker/common.h" #include "ray/core_worker/context.h" #include "ray/core_worker/fiber.h" #include "ray/core_worker/store_provider/memory_store/memory_store.h" diff --git a/src/ray/gcs/test/gcs_test_util.h b/src/ray/gcs/test/gcs_test_util.h index a0746add894c..07acdc20237d 100644 --- a/src/ray/gcs/test/gcs_test_util.h +++ b/src/ray/gcs/test/gcs_test_util.h @@ -58,6 +58,7 @@ struct Mocker { owner_address, 1, false, + false, required_resources, required_placement_resources, "", diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index 6ac9b1411135..bf10020a37b9 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -422,6 +422,12 @@ message TaskSpec { // This will be the actor creation task's task id for concurrent actors. Or // the main thread's task id for other cases. bytes submitter_task_id = 33; + // True if the task is a streaming generator. When it is true, + // returns_dynamic has to be true as well. This is a temporary flag + // until we migrate the generator implementatino to streaming. + // TODO(sang): Remove it once migrating to the streaming generator + // by default. + bool streaming_generator = 34; } message TaskInfoEntry { @@ -539,6 +545,22 @@ message TaskArg { repeated ObjectReference nested_inlined_refs = 4; } +message ReturnObject { + // Object ID. + bytes object_id = 1; + // If set, indicates the data is in plasma instead of inline. This + // means that data and metadata will be empty. + bool in_plasma = 2; + // Data of the object. + bytes data = 3; + // Metadata of the object. + bytes metadata = 4; + // ObjectIDs that were nested in data. This is only set for inlined objects. + repeated ObjectReference nested_inlined_refs = 5; + // Size of this object. + int64 size = 6; +} + // Task spec of an actor creation task. message ActorCreationTaskSpec { // ID of the actor that will be created by this task. diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index ab709d8cd9a3..a0ac0832185a 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -69,22 +69,6 @@ message ActorHandle { int32 max_pending_calls = 13; } -message ReturnObject { - // Object ID. - bytes object_id = 1; - // If set, indicates the data is in plasma instead of inline. This - // means that data and metadata will be empty. - bool in_plasma = 2; - // Data of the object. - bytes data = 3; - // Metadata of the object. - bytes metadata = 4; - // ObjectIDs that were nested in data. This is only set for inlined objects. - repeated ObjectReference nested_inlined_refs = 5; - // Size of this object. - int64 size = 6; -} - message PushTaskRequest { // The ID of the worker this message is intended for. bytes intended_worker_id = 1; @@ -398,6 +382,16 @@ message RayletNotifyGCSRestartRequest {} message RayletNotifyGCSRestartReply {} +message WriteObjectRefStreamRequest { + repeated ReturnObject dynamic_return_objects = 1; + Address worker_addr = 2; + int64 idx = 3; + bool finished = 4; + bytes generator_id = 5; +} + +message WriteObjectRefStreamReply {} + service CoreWorkerService { // Notify core worker GCS has restarted. rpc RayletNotifyGCSRestart(RayletNotifyGCSRestartRequest) @@ -418,6 +412,8 @@ service CoreWorkerService { /// It is replied once there are batch of objects that need to be published to /// the caller (subscriber). rpc PubsubLongPolling(PubsubLongPollingRequest) returns (PubsubLongPollingReply); + // SANG-TODO Write a docstring and change the RPC name. + rpc WriteObjectRefStream(WriteObjectRefStreamRequest) returns (WriteObjectRefStreamReply); /// The pubsub command batch request used by the subscriber. rpc PubsubCommandBatch(PubsubCommandBatchRequest) returns (PubsubCommandBatchReply); // Update the batched object location information to the ownership-based object diff --git a/src/ray/raylet/scheduling/cluster_task_manager_test.cc b/src/ray/raylet/scheduling/cluster_task_manager_test.cc index de2bd227996c..d5f312864e79 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager_test.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager_test.cc @@ -165,6 +165,7 @@ RayTask CreateTask( address, 0, /*returns_dynamic=*/false, + /*is_streaming_generator*/false, required_resources, {}, "", diff --git a/src/ray/rpc/worker/core_worker_client.h b/src/ray/rpc/worker/core_worker_client.h index de9b68ba0fd5..c950f9ad56ff 100644 --- a/src/ray/rpc/worker/core_worker_client.h +++ b/src/ray/rpc/worker/core_worker_client.h @@ -154,6 +154,10 @@ class CoreWorkerClientInterface : public pubsub::SubscriberClientInterface { const GetObjectLocationsOwnerRequest &request, const ClientCallback &callback) {} + virtual void WriteObjectRefStream( + const WriteObjectRefStreamRequest &request, + const ClientCallback &callback) {} + /// Tell this actor to exit immediately. virtual void KillActor(const KillActorRequest &request, const ClientCallback &callback) {} @@ -283,6 +287,12 @@ class CoreWorkerClient : public std::enable_shared_from_this, /*method_timeout_ms*/ -1, override) + VOID_RPC_CLIENT_METHOD(CoreWorkerService, + WriteObjectRefStream, + grpc_client_, + /*method_timeout_ms*/ -1, + override) + VOID_RPC_CLIENT_METHOD(CoreWorkerService, GetCoreWorkerStats, grpc_client_, diff --git a/src/ray/rpc/worker/core_worker_server.h b/src/ray/rpc/worker/core_worker_server.h index b881778f03de..c9dc97967edd 100644 --- a/src/ray/rpc/worker/core_worker_server.h +++ b/src/ray/rpc/worker/core_worker_server.h @@ -43,6 +43,8 @@ namespace rpc { CoreWorkerService, UpdateObjectLocationBatch, -1) \ RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED( \ CoreWorkerService, GetObjectLocationsOwner, -1) \ + RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED( \ + CoreWorkerService, WriteObjectRefStream, -1) \ RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, KillActor, -1) \ RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, CancelTask, -1) \ RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, RemoteCancelTask, -1) \ @@ -68,6 +70,7 @@ namespace rpc { DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(PubsubCommandBatch) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(UpdateObjectLocationBatch) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(GetObjectLocationsOwner) \ + DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(WriteObjectRefStream) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(KillActor) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(CancelTask) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(RemoteCancelTask) \ From 3ebe327916d5e488594dab8d7ab8ce194db67376 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 12 May 2023 07:49:50 -0700 Subject: [PATCH 02/77] in progress. Signed-off-by: SangBin Cho --- .../runtime/task/local_mode_task_submitter.cc | 2 +- python/ray/_private/ray_option_utils.py | 2 +- python/ray/_private/worker.py | 2 +- python/ray/_raylet.pyx | 48 ++++++++++--------- python/ray/actor.py | 8 +++- python/ray/includes/libcoreworker.pxd | 10 ++-- python/ray/remote_function.py | 9 +++- python/ray/tests/test_streaming_generator.py | 31 ++++++------ src/ray/common/task/task_spec.cc | 4 +- src/ray/core_worker/core_worker.cc | 11 +++-- src/ray/core_worker/core_worker.h | 12 ++--- src/ray/core_worker/task_manager.cc | 2 +- src/ray/core_worker/task_manager.h | 1 - .../scheduling/cluster_task_manager_test.cc | 2 +- 14 files changed, 83 insertions(+), 61 deletions(-) diff --git a/cpp/src/ray/runtime/task/local_mode_task_submitter.cc b/cpp/src/ray/runtime/task/local_mode_task_submitter.cc index 8e82b06e1eaa..6052531e1211 100644 --- a/cpp/src/ray/runtime/task/local_mode_task_submitter.cc +++ b/cpp/src/ray/runtime/task/local_mode_task_submitter.cc @@ -61,7 +61,7 @@ ObjectID LocalModeTaskSubmitter::Submit(InvocationSpec &invocation, address, 1, /*returns_dynamic=*/false, - /*is_streaming_generator*/false, + /*is_streaming_generator*/ false, required_resources, required_placement_resources, "", diff --git a/python/ray/_private/ray_option_utils.py b/python/ray/_private/ray_option_utils.py index 88703942f64e..97c35f9449ca 100644 --- a/python/ray/_private/ray_option_utils.py +++ b/python/ray/_private/ray_option_utils.py @@ -154,7 +154,7 @@ def issubclass_safe(obj: Any, cls_: type) -> bool: "num_returns": Option( (int, str, type(None)), lambda x: None - if (x is None or x == "dynamic" or x >= 0) + if (x is None or x == "dynamic" or x == "streaming" or x >= 0) else "The keyword 'num_returns' only accepts None, a non-negative integer, or " '"dynamic" (for generators)', default_value=1, diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index 7fe3db7d0a6f..9fa4e2e9bcc2 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -2810,7 +2810,7 @@ def cancel(object_ref: "ray.ObjectRef", *, force: bool = False, recursive: bool worker = ray._private.worker.global_worker worker.check_connected() - if isinstance(object_ref, ray._raylet.StreamingObjectRefGeneratorV2): + if isinstance(object_ref, ray._raylet.StreamingObjectRefGenerator): assert hasattr(object_ref, "_generator_ref") object_ref = object_ref._generator_ref diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 8bf936f6497c..0c3843b60e8e 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -196,7 +196,7 @@ class ObjectRefGenerator: return len(self._refs) -class StreamingObjectRefGeneratorV2: +class StreamingObjectRefGenerator: def __init__(self, generator_ref): self._generator_ref = generator_ref self._generator_task_completed_time = None @@ -234,8 +234,8 @@ class StreamingObjectRefGeneratorV2: # has been terminated 30 seconds ago. assert False, "Unexpected network failure occured." - - time.sleep(0.001) + # 100us busy waiting + time.sleep(0.0001) obj = self._handle_next() return obj @@ -792,18 +792,14 @@ cdef execute_streaming_generator( and core_worker.get_current_task_retry_exceptions() ): logger.debug("Task failed with retryable exception:" - " {}.".format( - task_id), - exc_info=True) + " {}.".format(task_id), exc_info=True) # Raise an exception directly and halt the execution # because there's no need to set the exception # for the return value when the task is retryable. raise e logger.debug("Task failed with unretryable exception:" - " {}.".format( - task_id), - exc_info=True) + " {}.".format(task_id), exc_info=True) error_id = (CCoreWorkerProcess.GetCoreWorker() .AllocateDynamicReturnId(caller_address)) @@ -822,11 +818,16 @@ cdef execute_streaming_generator( CCoreWorkerProcess.GetCoreWorker().ObjectRefStreamWrite( intermediate_result.back(), generator_id, caller_address, generator_index, False) + + if intermediate_result.size() > 0: + intermediate_result.pop_back() + generator_index += 1 break else: # Report the intermediate result if there was no error. - return_id = (CCoreWorkerProcess.GetCoreWorker() - .AllocateDynamicReturnId(caller_address)) + return_id = ( + CCoreWorkerProcess.GetCoreWorker().AllocateDynamicReturnId( + caller_address)) intermediate_result.push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( return_id, shared_ptr[CRayObject]())) @@ -839,23 +840,25 @@ cdef execute_streaming_generator( # print("SANG-TODO Writes an index ", i) assert intermediate_result.size() == 1 del output - + CCoreWorkerProcess.GetCoreWorker().ObjectRefStreamWrite( intermediate_result.back(), generator_id, caller_address, generator_index, False) - finally: + if intermediate_result.size() > 0: intermediate_result.pop_back() generator_index += 1 - # Close it. - # SANG-TODO Implement the close API. + # All the intermediate result has to be popped and reported. + assert intermediate_result.size() == 0 + # Report the owner that there's no more objects. # print("SANG-TODO Closes an index ", i) CCoreWorkerProcess.GetCoreWorker().ObjectRefStreamWrite( - c_pair[CObjectID, shared_ptr[CRayObject]](CObjectID.Nil(), shared_ptr[CRayObject]()), + c_pair[CObjectID, shared_ptr[CRayObject]]( + CObjectID.Nil(), shared_ptr[CRayObject]()), generator_id, caller_address, generator_index, @@ -1121,12 +1124,12 @@ cdef void execute_task( if not inspect.isgenerator(outputs): raise ValueError( "Functions with " - "@ray.remote(num_returns=\"streaming\" must return a " - "generator") + "@ray.remote(num_returns=\"streaming\" " + "must return a generator") execute_streaming_generator( outputs, - returns[0][0].first, # generator object ID. + returns[0][0].first, # generator object ID. task_type, caller_address, task_id, @@ -1220,10 +1223,9 @@ cdef void execute_task( # Store the outputs in the object store. with core_worker.profile_event(b"task:store_outputs"): - num_returns = returns[0].size() # TODO(sang): Remove it once we use streaming generator # by default. - if dynamic_returns != NULL: + if dynamic_returns != NULL and not is_streaming_generator: if not inspect.isgenerator(outputs): raise ValueError( "Functions with " @@ -3298,7 +3300,9 @@ cdef class CoreWorker: CObjectID c_generator_id = generator_id.native() CObjectReference c_object_ref - check_status(CCoreWorkerProcess.GetCoreWorker().GetNextObjectRef(c_generator_id, &c_object_ref)) + check_status( + CCoreWorkerProcess.GetCoreWorker().GetNextObjectRef( + c_generator_id, &c_object_ref)) return ObjectRef( c_object_ref.object_id(), c_object_ref.owner_address().SerializeAsString(), diff --git a/python/ray/actor.py b/python/ray/actor.py index 27fcf8c05a16..6b4127067680 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -22,7 +22,7 @@ ) from ray._private.ray_option_utils import _warn_if_using_deprecated_placement_group from ray._private.utils import get_runtime_env_info, parse_runtime_env -from ray._raylet import PythonFunctionDescriptor +from ray._raylet import PythonFunctionDescriptor, StreamingObjectRefGenerator from ray.exceptions import AsyncioActorExit from ray.util.annotations import DeveloperAPI, PublicAPI from ray.util.placement_group import _configure_placement_group_based_on_context @@ -1183,6 +1183,12 @@ def _actor_method_call( concurrency_group_name if concurrency_group_name is not None else b"", ) + if num_returns == -2: + # Streaming generator will return a single ref + # that is for the generator task. + assert len(object_refs) == 1 + generator_ref = object_refs[0] + return StreamingObjectRefGenerator(generator_ref) if len(object_refs) == 1: object_refs = object_refs[0] elif len(object_refs) == 0: diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 026ba9f57ae6..1575c0687b88 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -145,13 +145,13 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: c_bool PinExistingReturnObject( const CObjectID& return_id, shared_ptr[CRayObject] *return_object, - const CObjectID& generator_id - ) - void DelGenerator(const CObjectID &generator_id) - CRayStatus GetNextObjectRef(const CObjectID &generator_id, CObjectReference *object_ref_out) + const CObjectID& generator_id) + void DelGenerator(const CObjectID &generator_id) + CRayStatus GetNextObjectRef( + const CObjectID &generator_id, + CObjectReference *object_ref_out) CObjectID AllocateDynamicReturnId(const CAddress &owner_address) - CJobID GetCurrentJobId() CTaskID GetCurrentTaskId() CNodeID GetCurrentNodeId() diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index 3b36800344fc..607ae9fec640 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -15,7 +15,7 @@ from ray._private.ray_option_utils import _warn_if_using_deprecated_placement_group from ray._private.serialization import pickle_dumps from ray._private.utils import get_runtime_env_info, parse_runtime_env -from ray._raylet import PythonFunctionDescriptor +from ray._raylet import PythonFunctionDescriptor, StreamingObjectRefGenerator from ray.util.annotations import DeveloperAPI, PublicAPI from ray.util.placement_group import _configure_placement_group_based_on_context from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -310,6 +310,7 @@ def _remote(self, args=None, kwargs=None, **task_options): # TODO(sang): This is a temporary private API. # Remove it when we migrate to the streaming generator. num_returns = -2 + max_retries = task_options["max_retries"] retry_exceptions = task_options["retry_exceptions"] if isinstance(retry_exceptions, (list, tuple)): @@ -400,6 +401,12 @@ def invocation(args, kwargs): # Reset worker's debug context from the last "remote" command # (which applies only to this .remote call). worker.debugger_breakpoint = b"" + if num_returns == -2: + # Streaming generator will return a single ref + # that is for the generator task. + assert len(object_refs) == 1 + generator_ref = object_refs[0] + return StreamingObjectRefGenerator(generator_ref) if len(object_refs) == 1: return object_refs[0] elif len(object_refs) > 1: diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py index 5f81a737b58f..d99234d23e71 100644 --- a/python/ray/tests/test_streaming_generator.py +++ b/python/ray/tests/test_streaming_generator.py @@ -1,24 +1,17 @@ -import asyncio import pytest import numpy as np import sys import time import ray -from ray.util.client.ray_client_helpers import ( - ray_start_client_server_for_address, -) -from ray._private.client_mode_hook import enable_client_mode -from ray.tests.conftest import call_ray_start_context from ray._private.test_utils import wait_for_condition -from ray.experimental.state.api import list_tasks -from ray._raylet import StreamingObjectRefGeneratorV2 def test_generator_basic(shutdown_only): ray.init(num_cpus=1) """Basic cases""" + @ray.remote def f(): for i in range(5): @@ -63,8 +56,6 @@ def getpid(self): def f(self): for i in range(5): - import time - time.sleep(0.1) yield i @@ -90,7 +81,7 @@ def f(self): next(gen) """Retry exceptions""" - # TODO(sang): Enable it + # TODO(sang): Enable it once retry is supported. # @ray.remote # class Actor: # def __init__(self): @@ -175,16 +166,26 @@ def generator(num_returns, store_in_plasma): remote_generator_fn = generator """Verify num_returns="streaming" is streaming""" - gen = remote_generator_fn.options(num_returns="streaming").remote(3, store_in_plasma) + gen = remote_generator_fn.options(num_returns="streaming").remote( + 3, store_in_plasma + ) + i = 0 for ref in gen: id = ref.hex() - print(ray.get(ref)) + if store_in_plasma: + expected = np.ones(1_000_000, dtype=np.int8) * i + assert np.array_equal(ray.get(ref), expected) + else: + expected = [i] + assert ray.get(ref) == expected + del ref from ray.experimental.state.api import list_objects wait_for_condition( lambda: len(list_objects(filters=[("object_id", "=", id)])) == 0 ) + i += 1 def test_generator_dist_chain(ray_start_cluster): @@ -207,7 +208,9 @@ def get_data(self): time.sleep(0.1) yield np.ones(5 * 1024 * 1024) else: - for data in self.child.get_data.options(num_returns="streaming").remote(): + for data in self.child.get_data.options( + num_returns="streaming" + ).remote(): yield ray.get(data) chain_actor = ChainActor.remote() diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index d28a6f671334..11e4778b297e 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -220,7 +220,9 @@ bool TaskSpecification::ReturnsDynamic() const { return message_->returns_dynami // TODO(sang): Merge this with ReturnsDynamic once migrating to the // streaming generator. -bool TaskSpecification::IsStreamingGenerator() const { return message_->streaming_generator(); } +bool TaskSpecification::IsStreamingGenerator() const { + return message_->streaming_generator(); +} std::vector TaskSpecification::DynamicReturnIds() const { RAY_CHECK(message_->returns_dynamic()); diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 7f163ad4d8b7..4a57e99904f6 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -2705,7 +2705,7 @@ Status CoreWorker::ExecuteTask( defined_concurrency_groups, name_of_concurrency_group_to_execute, /*is_reattempt=*/task_spec.AttemptNumber() > 0, - /*is_streaming_generator*/task_spec.IsStreamingGenerator()); + /*is_streaming_generator*/ task_spec.IsStreamingGenerator()); // Get the reference counts for any IDs that we borrowed during this task, // remove the local reference for these IDs, and return the ref count info to @@ -2871,7 +2871,8 @@ bool CoreWorker::PinExistingReturnObject(const ObjectID &return_id, ObjectID CoreWorker::AllocateDynamicReturnId(const rpc::Address &owner_address) { const auto &task_spec = worker_context_.GetCurrentTask(); - const auto return_id = ObjectID::FromIndex(task_spec->TaskId(), worker_context_.GetNextPutIndex()); + const auto return_id = + ObjectID::FromIndex(task_spec->TaskId(), worker_context_.GetNextPutIndex()); AddLocalReference(return_id, ""); reference_counter_->AddBorrowedObject(return_id, ObjectID::Nil(), owner_address); return return_id; @@ -3245,7 +3246,7 @@ void CoreWorker::ProcessSubscribeForObjectEviction( const auto generator_id = ObjectID::FromBinary(message.generator_id()); RAY_CHECK(!generator_id.IsNil()); // SANG-TODO If streaming, use streaming instead. - reference_counter_->AddDynamicReturn(object_id, generator_id); + reference_counter_->AddStreamingDynamicReturn(object_id, generator_id); } // Returns true if the object was present and the callback was added. It might have @@ -3380,7 +3381,7 @@ void CoreWorker::AddSpilledObjectLocationOwner( // know that it exists. RAY_CHECK(!generator_id->IsNil()); // SANG-TODO If streaming, use streaming instead. - reference_counter_->AddDynamicReturn(object_id, *generator_id); + reference_counter_->AddStreamingDynamicReturn(object_id, *generator_id); } auto reference_exists = @@ -3410,8 +3411,8 @@ void CoreWorker::AddObjectLocationOwner(const ObjectID &object_id, if (!maybe_generator_id.IsNil()) { // The task is a generator and may not have finished yet. Add the internal // ObjectID so that we can update its location. - reference_counter_->AddDynamicReturn(object_id, maybe_generator_id); // SANG-TODO If streaming, use streaming instead. + reference_counter_->AddStreamingDynamicReturn(object_id, maybe_generator_id); RAY_UNUSED(reference_counter_->AddObjectLocation(object_id, node_id)); } } diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index ddfc6928da03..05b05fc53214 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -704,15 +704,15 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// Trigger garbage collection on each worker in the cluster. void TriggerGlobalGC(); - + /// SANG-TODO Update the docstring. /// SANG-TODO Support close separately. Status ObjectRefStreamWrite( - const std::pair> &dynamic_return_object, - const ObjectID &generator_id, - const rpc::Address &caller_address, - int64_t idx, - bool finished); + const std::pair> &dynamic_return_object, + const ObjectID &generator_id, + const rpc::Address &caller_address, + int64_t idx, + bool finished); /// Get a string describing object store memory usage for debugging purposes. /// diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index 2c433d023a00..3c0ade09278b 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -399,7 +399,7 @@ void TaskManager::HandleIntermediateResult( HandleTaskReturn(object_id, return_object, NodeID::FromBinary(request.worker_addr().raylet_id()), - /*store_in_plasma*/store_in_plasma_ids.count(object_id)); + /*store_in_plasma*/ store_in_plasma_ids.count(object_id)); } RAY_LOG(DEBUG) << "SANG-TODO Finished handling intermediate result"; } diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index 9a80893c8edd..bf4561db8d17 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -181,7 +181,6 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa const rpc::Address &worker_addr, bool is_application_error) override; - // SANG-TODO Docstring + change the method. void HandleIntermediateResult(const rpc::WriteObjectRefStreamRequest &request) override; diff --git a/src/ray/raylet/scheduling/cluster_task_manager_test.cc b/src/ray/raylet/scheduling/cluster_task_manager_test.cc index d5f312864e79..d5e17ee0fe62 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager_test.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager_test.cc @@ -165,7 +165,7 @@ RayTask CreateTask( address, 0, /*returns_dynamic=*/false, - /*is_streaming_generator*/false, + /*is_streaming_generator*/ false, required_resources, {}, "", From c140a5caefd759d0006dd394b018a212996ea557 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 12 May 2023 14:08:52 -0700 Subject: [PATCH 03/77] finished basics. Signed-off-by: SangBin Cho --- python/ray/_private/worker.py | 5 + python/ray/_raylet.pyx | 75 +++++++-- python/ray/actor.py | 2 +- python/ray/includes/libcoreworker.pxd | 7 +- python/ray/remote_function.py | 2 +- python/ray/tests/test_generators.py | 52 ++++-- python/ray/tests/test_streaming_generator.py | 52 +++++- src/ray/core_worker/core_worker.cc | 63 ++++--- src/ray/core_worker/core_worker.h | 43 ++++- src/ray/core_worker/reference_count.cc | 17 +- src/ray/core_worker/reference_count.h | 10 +- src/ray/core_worker/task_manager.cc | 167 ++++++++++++------- src/ray/core_worker/task_manager.h | 96 +++++++++-- src/ray/protobuf/core_worker.proto | 19 ++- src/ray/rpc/worker/core_worker_client.h | 8 +- src/ray/rpc/worker/core_worker_server.h | 4 +- 16 files changed, 458 insertions(+), 164 deletions(-) diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index 9fa4e2e9bcc2..87c5dbdd8eca 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -2492,6 +2492,11 @@ def get( blocking_get_inside_async_warned = True with profiling.profile("ray.get"): + # TODO(sang): Should make StreamingObjectRefGenerator + # compatible to ray.get for dataset. + if isinstance(object_refs, ray._raylet.StreamingObjectRefGenerator): + return object_refs + is_individual_id = isinstance(object_refs, ray.ObjectRef) if is_individual_id: object_refs = [object_refs] diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 0c3843b60e8e..d4983abbbba5 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -197,24 +197,45 @@ class ObjectRefGenerator: class StreamingObjectRefGenerator: - def __init__(self, generator_ref): + def __init__(self, generator_ref, worker): self._generator_ref = generator_ref self._generator_task_completed_time = None self._generator_task_exception = None + self.worker = worker + assert hasattr(worker, "core_worker") + worker.core_worker.create_generator(self._generator_ref) def __iter__(self): return self def __next__(self): - core_worker = ray._private.worker.global_worker.core_worker + """Wait until the next ref is available to the + generator and return the object ref. + + The API will raise StopIteration if there's no more objects + to generate. + + The object ref will contain an exception if the task fails. + When the generator task returns N objects, it can return + up to N + 1 objects (if there's a system failure, the + last object will contain a system level exception). + """ obj = self._handle_next() + + # The generator ref will be None if the task succeeds. + # It will contain an exception if the task fails by + # a system error. while obj.is_nil(): if self._generator_task_exception: - # The generator task has failed. We raise StopIteration + # The generator task has failed already. + # We raise StopIteration # to conform the next interface in Python. raise StopIteration else: - # Otherwise, check the task status. + # Otherwise, we should ray.get on the generator + # ref to find if the task has a system failure. + # Return the generator ref that contains the system + # error as soon as possible. r, _ = ray.wait([self._generator_ref], timeout=0) if len(r) > 0: try: @@ -228,6 +249,12 @@ class StreamingObjectRefGenerator: if self._generator_task_completed_time is None: self._generator_task_completed_time = time.time() + # Currently, since the ordering of intermediate result report + # is not guaranteed, it is possible that althoug the task + # has succeeded, all of the object references are not reported + # (e.g., when there are network failures). + # If all the object refs are not reported to the generator + # within 30 seconds, we consider is as an unreconverable error. if self._generator_task_completed_time: if time.time() - self._generator_task_completed_time > 30: # It means the next wasn't reported although the task @@ -241,16 +268,26 @@ class StreamingObjectRefGenerator: def _handle_next(self): try: - core_worker = ray._private.worker.global_worker.core_worker - obj = core_worker.generator_get_next(self._generator_ref) - return obj + worker = ray._private.worker.global_worker + if hasattr(worker, "core_worker"): + obj = worker.core_worker.generator_get_next(self._generator_ref) + return obj + else: + raise ValueError( + "Cannot access the core worker. " + "Did you already shutdown Ray via ray.shutdown()?") except RayKeyError: raise StopIteration def __del__(self): worker = ray._private.worker.global_worker if hasattr(worker, "core_worker"): - worker.core_worker.generator_del(self._generator_ref) + worker.core_worker.delete_generator(self._generator_ref) + + def __getstate__(self): + raise TypeError( + "Serialization of the StreamingObjectRefGenerator " + "is now allowed") cdef int check_status(const CRayStatus& status) nogil except -1: @@ -815,7 +852,7 @@ cdef execute_streaming_generator( function_name, task_type, title, &intermediate_result, application_error, caller_address) - CCoreWorkerProcess.GetCoreWorker().ObjectRefStreamWrite( + CCoreWorkerProcess.GetCoreWorker().ReportIntermediateTaskReturn( intermediate_result.back(), generator_id, caller_address, generator_index, False) @@ -841,7 +878,7 @@ cdef execute_streaming_generator( assert intermediate_result.size() == 1 del output - CCoreWorkerProcess.GetCoreWorker().ObjectRefStreamWrite( + CCoreWorkerProcess.GetCoreWorker().ReportIntermediateTaskReturn( intermediate_result.back(), generator_id, caller_address, @@ -856,13 +893,13 @@ cdef execute_streaming_generator( assert intermediate_result.size() == 0 # Report the owner that there's no more objects. # print("SANG-TODO Closes an index ", i) - CCoreWorkerProcess.GetCoreWorker().ObjectRefStreamWrite( + CCoreWorkerProcess.GetCoreWorker().ReportIntermediateTaskReturn( c_pair[CObjectID, shared_ptr[CRayObject]]( CObjectID.Nil(), shared_ptr[CRayObject]()), generator_id, caller_address, generator_index, - True) + True) # finished. cdef execute_dynamic_generator_and_store_task_outputs( @@ -3015,6 +3052,8 @@ cdef class CoreWorker: raise ValueError( "Task returned more than num_returns={} objects.".format( num_returns)) + # TODO(sang): Remove it when the streaming generator is + # enabled by default. while i >= returns[0].size(): return_id = (CCoreWorkerProcess.GetCoreWorker() .AllocateDynamicReturnId(caller_address)) @@ -3289,11 +3328,17 @@ cdef class CoreWorker: CCoreWorkerProcess.GetCoreWorker() \ .RecordTaskLogEnd(out_end_offset, err_end_offset) - def generator_del(self, ObjectRef generator_id): + def create_generator(self, ObjectRef generator_id): + cdef: + CObjectID c_generator_id = generator_id.native() + + CCoreWorkerProcess.GetCoreWorker().CreateObjectRefStream(c_generator_id) + + def delete_generator(self, ObjectRef generator_id): cdef: CObjectID c_generator_id = generator_id.native() - CCoreWorkerProcess.GetCoreWorker().DelGenerator(c_generator_id) + CCoreWorkerProcess.GetCoreWorker().DelObjectRefStream(c_generator_id) def generator_get_next(self, ObjectRef generator_id): cdef: @@ -3301,7 +3346,7 @@ cdef class CoreWorker: CObjectReference c_object_ref check_status( - CCoreWorkerProcess.GetCoreWorker().GetNextObjectRef( + CCoreWorkerProcess.GetCoreWorker().AsyncReadObjectRefStream( c_generator_id, &c_object_ref)) return ObjectRef( c_object_ref.object_id(), diff --git a/python/ray/actor.py b/python/ray/actor.py index 6b4127067680..91b88de7b947 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -1188,7 +1188,7 @@ def _actor_method_call( # that is for the generator task. assert len(object_refs) == 1 generator_ref = object_refs[0] - return StreamingObjectRefGenerator(generator_ref) + return StreamingObjectRefGenerator(generator_ref, worker) if len(object_refs) == 1: object_refs = object_refs[0] elif len(object_refs) == 0: diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 1575c0687b88..c9e99a24d8ef 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -146,8 +146,9 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: const CObjectID& return_id, shared_ptr[CRayObject] *return_object, const CObjectID& generator_id) - void DelGenerator(const CObjectID &generator_id) - CRayStatus GetNextObjectRef( + void DelObjectRefStream(const CObjectID &generator_id) + void CreateObjectRefStream(const CObjectID &generator_id) + CRayStatus AsyncReadObjectRefStream( const CObjectID &generator_id, CObjectReference *object_ref_out) CObjectID AllocateDynamicReturnId(const CAddress &owner_address) @@ -238,7 +239,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: int64_t timeout_ms, c_vector[shared_ptr[CObjectLocation]] *results) CRayStatus TriggerGlobalGC() - CRayStatus ObjectRefStreamWrite( + CRayStatus ReportIntermediateTaskReturn( const pair[CObjectID, shared_ptr[CRayObject]] &dynamic_return_object, const CObjectID &generator_id, const CAddress &caller_address, diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index 607ae9fec640..bb627f09af92 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -406,7 +406,7 @@ def invocation(args, kwargs): # that is for the generator task. assert len(object_refs) == 1 generator_ref = object_refs[0] - return StreamingObjectRefGenerator(generator_ref) + return StreamingObjectRefGenerator(generator_ref, worker) if len(object_refs) == 1: return object_refs[0] elif len(object_refs) > 1: diff --git a/python/ray/tests/test_generators.py b/python/ray/tests/test_generators.py index 64cd59d6002a..3430da39cda2 100644 --- a/python/ray/tests/test_generators.py +++ b/python/ray/tests/test_generators.py @@ -117,7 +117,10 @@ def generator(num_returns, store_in_plasma): @pytest.mark.parametrize("use_actors", [False, True]) @pytest.mark.parametrize("store_in_plasma", [False, True]) -def test_generator_errors(ray_start_regular, use_actors, store_in_plasma): +@pytest.mark.parametrize("num_returns_type", ["dynamic", "streaming"]) +def test_generator_errors( + ray_start_regular, use_actors, store_in_plasma, num_returns_type +): remote_generator_fn = None if use_actors: @@ -158,7 +161,7 @@ def generator(num_returns, store_in_plasma): with pytest.raises(ray.exceptions.RayTaskError): ray.get(ref3) - dynamic_ref = remote_generator_fn.options(num_returns="dynamic").remote( + dynamic_ref = remote_generator_fn.options(num_returns=num_returns_type).remote( 3, store_in_plasma ) ref1, ref2 = ray.get(dynamic_ref) @@ -218,10 +221,13 @@ def generator(num_returns, store_in_plasma, counter): @pytest.mark.parametrize("use_actors", [False, True]) @pytest.mark.parametrize("store_in_plasma", [False, True]) -def test_dynamic_generator(ray_start_regular, use_actors, store_in_plasma): +@pytest.mark.parametrize("num_returns_type", ["streaming"]) +def test_dynamic_generator( + ray_start_regular, use_actors, store_in_plasma, num_returns_type +): if use_actors: - @ray.remote(num_returns="dynamic") + @ray.remote(num_returns=num_returns_type) def dynamic_generator(num_returns, store_in_plasma): for i in range(num_returns): if store_in_plasma: @@ -255,21 +261,34 @@ def read(gen): return True gen = ray.get( - remote_generator_fn.options(num_returns="dynamic").remote(10, store_in_plasma) + remote_generator_fn.options(num_returns=num_returns_type).remote( + 10, store_in_plasma + ) ) for i, ref in enumerate(gen): assert ray.get(ref)[0] == i # Test empty generator. gen = ray.get( - remote_generator_fn.options(num_returns="dynamic").remote(0, store_in_plasma) + remote_generator_fn.options(num_returns=num_returns_type).remote( + 0, store_in_plasma + ) ) - assert len(gen) == 0 + assert len(list(gen)) == 0 # Check that passing as task arg. - gen = remote_generator_fn.options(num_returns="dynamic").remote(10, store_in_plasma) - assert ray.get(read.remote(gen)) - assert ray.get(read.remote(ray.get(gen))) + if num_returns_type == "dynamic": + gen = remote_generator_fn.options(num_returns=num_returns_type).remote( + 10, store_in_plasma + ) + assert ray.get(read.remote(gen)) + assert ray.get(read.remote(ray.get(gen))) + else: + with pytest.raises(TypeError): + gen = remote_generator_fn.options(num_returns=num_returns_type).remote( + 10, store_in_plasma + ) + assert ray.get(read.remote(gen)) # Also works if we override num_returns with a static value. ray.get( @@ -279,15 +298,18 @@ def read(gen): ) # Normal remote functions don't work with num_returns="dynamic". - @ray.remote(num_returns="dynamic") + @ray.remote(num_returns=num_returns_type) def static(num_returns): return list(range(num_returns)) with pytest.raises(ray.exceptions.RayTaskError): - ray.get(static.remote(3)) + gen = ray.get(static.remote(3)) + for ref in gen: + ray.get(ref) -def test_dynamic_generator_distributed(ray_start_cluster): +@pytest.mark.parametrize("num_returns_type", ["dynamic", "streaming"]) +def test_dynamic_generator_distributed(ray_start_cluster, num_returns_type): cluster = ray_start_cluster # Head node with no resources. cluster.add_node(num_cpus=0) @@ -295,7 +317,7 @@ def test_dynamic_generator_distributed(ray_start_cluster): cluster.add_node(num_cpus=1) cluster.wait_for_nodes() - @ray.remote(num_returns="dynamic") + @ray.remote(num_returns=num_returns_type) def dynamic_generator(num_returns): for i in range(num_returns): yield np.ones(1_000_000, dtype=np.int8) * i @@ -535,7 +557,7 @@ def maybe_empty_generator(exec_counter): @ray.remote def check(empty_generator): - return len(empty_generator) == 0 + return len(list(empty_generator)) == 0 exec_counter = ExecutionCounter.remote() gen = maybe_empty_generator.remote(exec_counter) diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py index d99234d23e71..72c07a2af658 100644 --- a/python/ray/tests/test_streaming_generator.py +++ b/python/ray/tests/test_streaming_generator.py @@ -2,9 +2,11 @@ import numpy as np import sys import time +import gc import ray from ray._private.test_utils import wait_for_condition +from ray.experimental.state.api import list_objects def test_generator_basic(shutdown_only): @@ -131,6 +133,55 @@ def f(): next(gen) +@pytest.mark.parametrize("crash_type", ["exception", "worker_crash"]) +def test_generator_streaming_no_leak_upon_failures( + monkeypatch, shutdown_only, crash_type +): + with monkeypatch.context() as m: + # defer for 10s for the second node. + m.setenv( + "RAY_testing_asio_delay_us", + "CoreWorkerService.grpc_server.ReportIntermediateTaskReturn=100000:1000000", + ) + ray.init(num_cpus=1) + + @ray.remote + def g(): + try: + gen = f.options(num_returns="streaming").remote() + for ref in gen: + print(ref) + ray.get(ref) + except Exception: + print("exception!") + del ref + + del gen + gc.collect() + + # Only the ref g is alive. + def verify(): + print(list_objects()) + return len(list_objects()) == 1 + + wait_for_condition(verify) + return True + + @ray.remote + def f(): + for i in range(10): + time.sleep(0.2) + if i == 4: + if crash_type == "exception": + raise ValueError + else: + sys.exit(9) + yield 2 + + for _ in range(5): + ray.get(g.remote()) + + @pytest.mark.parametrize("use_actors", [False, True]) @pytest.mark.parametrize("store_in_plasma", [False, True]) def test_generator_streaming(shutdown_only, use_actors, store_in_plasma): @@ -180,7 +231,6 @@ def generator(num_returns, store_in_plasma): assert ray.get(ref) == expected del ref - from ray.experimental.state.api import list_objects wait_for_condition( lambda: len(list_objects(filters=[("object_id", "=", id)])) == 0 diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 4a57e99904f6..6c8899d226de 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1657,7 +1657,7 @@ void CoreWorker::TriggerGlobalGC() { }); } -Status CoreWorker::ObjectRefStreamWrite( +Status CoreWorker::ReportIntermediateTaskReturn( const std::pair> &dynamic_return_object, const ObjectID &generator_id, const rpc::Address &caller_address, @@ -1665,7 +1665,7 @@ Status CoreWorker::ObjectRefStreamWrite( bool finished) { RAY_LOG(DEBUG) << "SANG-TODO Write the object ref stream, index: " << idx << " finished: " << finished << ", id: " << dynamic_return_object.first; - rpc::WriteObjectRefStreamRequest request; + rpc::ReportIntermediateTaskReturnRequest request; request.mutable_worker_addr()->CopyFrom(rpc_address_); request.set_idx(idx); request.set_finished(finished); @@ -1685,8 +1685,9 @@ Status CoreWorker::ObjectRefStreamWrite( memory_store_->Delete(deleted); } - client->WriteObjectRefStream( - request, [](const Status &status, const rpc::WriteObjectRefStreamReply &reply) { + client->ReportIntermediateTaskReturn( + request, + [](const Status &status, const rpc::ReportIntermediateTaskReturnReply &reply) { if (status.ok()) { RAY_LOG(DEBUG) << "SANG-TODO Succeeded to send the object ref"; } else { @@ -2798,14 +2799,18 @@ Status CoreWorker::SealReturnObject(const ObjectID &return_id, return status; } -void CoreWorker::DelGenerator(const ObjectID &generator_id) { - task_manager_->DelGenerator(generator_id); +void CoreWorker::CreateObjectRefStream(const ObjectID &generator_id) { + task_manager_->CreateObjectRefStream(generator_id); } -Status CoreWorker::GetNextObjectRef(const ObjectID &generator_id, - rpc::ObjectReference *object_ref_out) { +void CoreWorker::DelObjectRefStream(const ObjectID &generator_id) { + task_manager_->DelObjectRefStream(generator_id); +} + +Status CoreWorker::AsyncReadObjectRefStream(const ObjectID &generator_id, + rpc::ObjectReference *object_ref_out) { ObjectID object_id; - const auto &status = task_manager_->GetNextObjectRef(generator_id, &object_id); + const auto &status = task_manager_->AsyncReadObjectRefStream(generator_id, &object_id); if (!status.ok()) { return status; } @@ -3245,8 +3250,12 @@ void CoreWorker::ProcessSubscribeForObjectEviction( // counter so that we know that it exists. const auto generator_id = ObjectID::FromBinary(message.generator_id()); RAY_CHECK(!generator_id.IsNil()); - // SANG-TODO If streaming, use streaming instead. - reference_counter_->AddStreamingDynamicReturn(object_id, generator_id); + if (task_manager_->ObjectRefStreamExists(generator_id)) { + reference_counter_->AddIntermediatelyReporteDynamicReturnRef(object_id, + generator_id); + } else { + reference_counter_->AddDynamicReturn(object_id, generator_id); + } } // Returns true if the object was present and the callback was added. It might have @@ -3380,8 +3389,12 @@ void CoreWorker::AddSpilledObjectLocationOwner( // object. Add the dynamically created object to our ref counter so that we // know that it exists. RAY_CHECK(!generator_id->IsNil()); - // SANG-TODO If streaming, use streaming instead. - reference_counter_->AddStreamingDynamicReturn(object_id, *generator_id); + if (task_manager_->ObjectRefStreamExists(*generator_id)) { + reference_counter_->AddIntermediatelyReporteDynamicReturnRef(object_id, + *generator_id); + } else { + reference_counter_->AddDynamicReturn(object_id, *generator_id); + } } auto reference_exists = @@ -3409,10 +3422,15 @@ void CoreWorker::AddObjectLocationOwner(const ObjectID &object_id, // until the task finishes. const auto &maybe_generator_id = task_manager_->TaskGeneratorId(object_id.TaskId()); if (!maybe_generator_id.IsNil()) { - // The task is a generator and may not have finished yet. Add the internal - // ObjectID so that we can update its location. - // SANG-TODO If streaming, use streaming instead. - reference_counter_->AddStreamingDynamicReturn(object_id, maybe_generator_id); + if (task_manager_->ObjectRefStreamExists(maybe_generator_id)) { + // If the stream exists, it means it is a streaming generator. + reference_counter_->AddIntermediatelyReporteDynamicReturnRef(object_id, + maybe_generator_id); + } else { + // The task is a generator and may not have finished yet. Add the internal + // ObjectID so that we can update its location. + reference_counter_->AddDynamicReturn(object_id, maybe_generator_id); + } RAY_UNUSED(reference_counter_->AddObjectLocation(object_id, node_id)); } } @@ -3443,11 +3461,12 @@ void CoreWorker::ProcessSubscribeObjectLocations( reference_counter_->PublishObjectLocationSnapshot(object_id); } -void CoreWorker::HandleWriteObjectRefStream(rpc::WriteObjectRefStreamRequest request, - rpc::WriteObjectRefStreamReply *reply, - rpc::SendReplyCallback send_reply_callback) { - RAY_LOG(DEBUG) << "SANG-TODO HandleWriteObjectRefStream"; - task_manager_->HandleIntermediateResult(request); +void CoreWorker::HandleReportIntermediateTaskReturn( + rpc::ReportIntermediateTaskReturnRequest request, + rpc::ReportIntermediateTaskReturnReply *reply, + rpc::SendReplyCallback send_reply_callback) { + RAY_LOG(DEBUG) << "SANG-TODO HandleReportIntermediateTaskReturn"; + task_manager_->HandleReportIntermediateTaskReturn(request); send_reply_callback(Status::OK(), nullptr, nullptr); } diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 05b05fc53214..713df90f7cd9 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -355,11 +355,13 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { NodeID GetCurrentNodeId() const { return NodeID::FromBinary(rpc_address_.raylet_id()); } // SANG-TODO Update the docstring. - void DelGenerator(const ObjectID &generator_id); + void DelObjectRefStream(const ObjectID &generator_id); + + void CreateObjectRefStream(const ObjectID &generator_id); // SANG-TODO Update the docstring. - Status GetNextObjectRef(const ObjectID &generator_id, - rpc::ObjectReference *object_ref_out); + Status AsyncReadObjectRefStream(const ObjectID &generator_id, + rpc::ObjectReference *object_ref_out); const PlacementGroupID &GetCurrentPlacementGroupId() const { return worker_context_.GetCurrentPlacementGroupId(); @@ -705,9 +707,31 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// Trigger garbage collection on each worker in the cluster. void TriggerGlobalGC(); - /// SANG-TODO Update the docstring. - /// SANG-TODO Support close separately. - Status ObjectRefStreamWrite( + /// Report the task caller at caller_address that the intermediate + /// task return. It means if this API is used, the caller will be notified + /// the task return before the current task is terminated. The caller must + /// implement HandleReportIntermediateTaskReturn API endpoint + /// to handle the intermediate result report. + /// This API makes sense only for a generator task + /// (task that can return multiple intermediate + /// result before the task terminates). + /// + /// NOTE: The API doesn't guarantee the ordering of the report. The + /// caller is supposed to reorder the report based on the idx. + /// + /// \param[in] dynamic_return_object A intermediate ray object to report + /// to the caller before the task terminates. This object must have been + /// created dynamically from this worker via AllocateReturnObject. + /// \param[in] generator_id The return object ref ID from a current generator + /// task. + /// \param[in] caller_address The address of the caller of the current task + /// that created a generator_id. + /// \param[in] idx The index of the task return. It is used to reorder the + /// report from the caller side. + /// \param[in] finished True indicates there's going to be no more intermediate + /// task return. When finished is provided dynamic_return_object input will be + /// ignored. + Status ReportIntermediateTaskReturn( const std::pair> &dynamic_return_object, const ObjectID &generator_id, const rpc::Address &caller_address, @@ -1055,9 +1079,10 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { rpc::SendReplyCallback send_reply_callback) override; /// Implements gRPC server handler. - void HandleWriteObjectRefStream(rpc::WriteObjectRefStreamRequest request, - rpc::WriteObjectRefStreamReply *reply, - rpc::SendReplyCallback send_reply_callback) override; + void HandleReportIntermediateTaskReturn( + rpc::ReportIntermediateTaskReturnRequest request, + rpc::ReportIntermediateTaskReturnReply *reply, + rpc::SendReplyCallback send_reply_callback) override; /// Implements gRPC server handler. void HandleKillActor(rpc::KillActorRequest request, diff --git a/src/ray/core_worker/reference_count.cc b/src/ray/core_worker/reference_count.cc index 3aa7cce0a044..b50d816ed8fa 100644 --- a/src/ray/core_worker/reference_count.cc +++ b/src/ray/core_worker/reference_count.cc @@ -239,17 +239,14 @@ void ReferenceCounter::AddDynamicReturn(const ObjectID &object_id, AddNestedObjectIdsInternal(generator_id, {object_id}, owner_address); } -void ReferenceCounter::AddStreamingDynamicReturn(const ObjectID &object_id, - const ObjectID &generator_id) { +void ReferenceCounter::AddIntermediatelyReporteDynamicReturnRef( + const ObjectID &object_id, const ObjectID &generator_id) { absl::MutexLock lock(&mutex_); auto outer_it = object_id_refs_.find(generator_id); if (outer_it == object_id_refs_.end()) { - // Outer object already went out of scope. Either: - // 1. The inner object was never deserialized and has already gone out of - // scope. - // 2. The inner object was deserialized and we already added it as a - // dynamic return. - // Either way, we shouldn't add the inner object to the ref count. + // Generator object already went out of scope. + // It means the generator is already GC'ed. No need to + // update the reference. return; } RAY_LOG(DEBUG) << "Adding dynamic return " << object_id @@ -257,6 +254,8 @@ void ReferenceCounter::AddStreamingDynamicReturn(const ObjectID &object_id, RAY_CHECK(outer_it->second.owned_by_us); RAY_CHECK(outer_it->second.owner_address.has_value()); rpc::Address owner_address(outer_it->second.owner_address.value()); + // We add a local reference here. The ref removal will be handled + // by the ObjectRefStream. RAY_UNUSED(AddOwnedObjectInternal(object_id, {}, owner_address, @@ -265,6 +264,8 @@ void ReferenceCounter::AddStreamingDynamicReturn(const ObjectID &object_id, outer_it->second.is_reconstructable, /*add_local_ref=*/true, absl::optional())); + // When the intermediate object ref is reported, it means the + // object is already created. UpdateObjectPendingCreation(object_id, false); } diff --git a/src/ray/core_worker/reference_count.h b/src/ray/core_worker/reference_count.h index 4b4a597595a2..69e3a269d6df 100644 --- a/src/ray/core_worker/reference_count.h +++ b/src/ray/core_worker/reference_count.h @@ -201,8 +201,14 @@ class ReferenceCounter : public ReferenceCounterInterface, void AddDynamicReturn(const ObjectID &object_id, const ObjectID &generator_id) LOCKS_EXCLUDED(mutex_); - // SANG-TODO Update the docstring. - void AddStreamingDynamicReturn(const ObjectID &object_id, const ObjectID &generator_id) + /// Add a owned object that was dynamically created and reported intermediately. + /// These are objects that were created by a task that we called, but that we own. + /// + /// + /// \param[in] object_id The ID of the object that we now own. + /// \param[in] generator_id The Object ID of the streaming generator task. + void AddIntermediatelyReporteDynamicReturnRef(const ObjectID &object_id, + const ObjectID &generator_id) LOCKS_EXCLUDED(mutex_); /// Update the size of the object. diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index 3c0ade09278b..4147b5874950 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -30,6 +30,50 @@ const int64_t kTaskFailureThrottlingThreshold = 50; // Throttle task failure logs to once this interval. const int64_t kTaskFailureLoggingFrequencyMillis = 5000; +Status ObjectRefStream::AsyncReadNext(ObjectID *object_id_out) { + bool is_eof_set = last_ != -1; + if (is_eof_set && curr_ >= last_) { + RAY_LOG(DEBUG) << "ObjectRefStream of an id " << generator_id_ + << "has no more objects."; + return Status::KeyError("Finished"); + } + + auto it = idx_to_refs_.find(curr_); + if (it != idx_to_refs_.end()) { + // If the current index has been written, + // return the object ref. + // The returned object ref will always have a ref count of 1. + // The caller of this API is supposed to remove the reference + // when the obtained object id goes out of scope. + *object_id_out = it->second; + curr_ += 1; + RAY_LOG(DEBUG) << "SANG-TODO Get the next object id " << *object_id_out + << " generator id: " << generator_id_; + } else { + // If the current index hasn't been written, return nothing. + // The caller is supposed to retry. + RAY_LOG(DEBUG) << "SANG-TODO Object not available. Current index: " << curr_ + << " last: " << last_ << " generator id: " << generator_id_; + *object_id_out = ObjectID::Nil(); + } + return Status::OK(); +} + +bool ObjectRefStream::Write(const ObjectID &object_id, int64_t idx) { + if (last_ != -1) { + RAY_CHECK(curr_ < last_); + } + + if (idx < curr_) { + return false; + } + + idx_to_refs_.emplace(idx, object_id); + return true; +} + +void ObjectRefStream::WriteEoF(int64_t idx) { last_ = idx; } + std::vector TaskManager::AddPendingTask( const rpc::Address &caller_address, const TaskSpecification &spec, @@ -300,101 +344,106 @@ bool TaskManager::HandleTaskReturn(const ObjectID &object_id, return direct_return; } -void TaskManager::DelGenerator(const ObjectID &generator_id) { +void TaskManager::CreateObjectRefStream(const ObjectID &generator_id) { + absl::MutexLock lock(&mu_); + auto it = object_ref_streams_.find(generator_id); + RAY_CHECK(it == object_ref_streams_.end()) + << "CreateObjectRefStream can be called only once. The caller of the API should " + "guarantee the API is not called twice."; + object_ref_streams_.emplace(generator_id, ObjectRefStream(generator_id)); +} + +void TaskManager::DelObjectRefStream(const ObjectID &generator_id) { + RAY_LOG(DEBUG) << "Deleting the object ref stream of an id " << generator_id; while (true) { ObjectID object_id; - const auto &status = GetNextObjectRef(generator_id, &object_id); - // SANG-TODO We should remove a reference. Need a test. + const auto &status = AsyncReadObjectRefStream(generator_id, &object_id); + + // keyError means the stream reaches to EoF. if (status.IsKeyError()) { break; } + if (object_id == ObjectID::Nil()) { break; + } else { + std::vector deleted; + reference_counter_->RemoveLocalReference(object_id, &deleted); + RAY_CHECK_EQ(deleted.size(), 1); } - RAY_LOG(DEBUG) << "SANG-TODO DelGenerator Get Next"; } - RAY_LOG(DEBUG) << "SANG-TODO Delete generator from " << generator_id; + + absl::MutexLock lock(&mu_); + object_ref_streams_.erase(generator_id); } -Status TaskManager::GetNextObjectRef(const ObjectID &generator_id, - ObjectID *object_id_out) { +Status TaskManager::AsyncReadObjectRefStream(const ObjectID &generator_id, + ObjectID *object_id_out) { absl::MutexLock lock(&mu_); RAY_CHECK(object_id_out != nullptr); - auto it = dynamic_ids_from_generator_.find(generator_id); - if (it == dynamic_ids_from_generator_.end()) { - RAY_LOG(DEBUG) << "SANG-TODO Generator already GC'ed " << *object_id_out - << " generator id: " << generator_id; - *object_id_out = ObjectID::Nil(); - return Status::OK(); - } - auto &reader = dynamic_ids_from_generator_[generator_id]; - if (reader.last != -1 && reader.curr >= reader.last) { - RAY_LOG(DEBUG) << "SANG-TODO Generator has no more objects " << generator_id; - return Status::KeyError("Finished"); - } - auto reader_it = reader.idx_to_refs.find(reader.curr); - if (reader_it != reader.idx_to_refs.end()) { - *object_id_out = reader_it->second; - reader.idx_to_refs.erase(reader.curr); - reader.curr += 1; - RAY_LOG(DEBUG) << "SANG-TODO Get the next object id " << *object_id_out - << " generator id: " << generator_id; - } else { - RAY_LOG(DEBUG) << "SANG-TODO Object not available. Current index: " << reader.curr - << " last: " << reader.last << " generator id: " << generator_id; - *object_id_out = ObjectID::Nil(); - } - return Status::OK(); + auto it = object_ref_streams_.find(generator_id); + RAY_CHECK(it != object_ref_streams_.end()) + << "AsyncReadObjectRefStream API can be used only when the stream has been created " + "and not removed."; + auto &stream = it->second; + + const auto &status = stream.AsyncReadNext(object_id_out); + return status; +} + +bool TaskManager::ObjectRefStreamExists(const ObjectID &generator_id) { + absl::MutexLock lock(&mu_); + auto it = object_ref_streams_.find(generator_id); + return it != object_ref_streams_.end(); } -void TaskManager::HandleIntermediateResult( - const rpc::WriteObjectRefStreamRequest &request) { +void TaskManager::HandleReportIntermediateTaskReturn( + const rpc::ReportIntermediateTaskReturnRequest &request) { const auto &generator_id = ObjectID::FromBinary(request.generator_id()); const auto &task_id = generator_id.TaskId(); int64_t idx = request.idx(); // Every generated object has the same task id. - RAY_LOG(DEBUG) << "SANG-TODO Received an intermediate result of index " << request.idx() + RAY_LOG(DEBUG) << "SANG-TODO Received an intermediate result of index " << idx << " generator_id: " << generator_id; - { + if (request.finished()) { absl::MutexLock lock(&mu_); - if (request.finished()) { - RAY_LOG(DEBUG) << "SANG-TODO Finished with an index " << request.idx(); - auto &reader = dynamic_ids_from_generator_[generator_id]; - reader.last = request.idx(); - RAY_CHECK(request.dynamic_return_objects_size() == 0); + RAY_LOG(DEBUG) << "SANG-TODO Finished with an index " << idx; + auto it = object_ref_streams_.find(generator_id); + if (it != object_ref_streams_.end()) { + it->second.WriteEoF(idx); } + // The last report should not have any return objects. + RAY_CHECK(request.dynamic_return_objects_size() == 0); + return; } + // Handle the intermediate values. + // NOTE: Until we support the retry, this is always empty return value. const auto store_in_plasma_ids = GetTaskReturnObjectsToStoreInPlasma(task_id); + // TODO(sang): Support the regular return values as well. for (const auto &return_object : request.dynamic_return_objects()) { const auto object_id = ObjectID::FromBinary(return_object.object_id()); RAY_LOG(DEBUG) << "SANG-TODO Add an object " << object_id; - int64_t curr; + bool is_written_to_stream = false; { absl::MutexLock lock(&mu_); - auto &reader = dynamic_ids_from_generator_[generator_id]; - curr = reader.curr; - if (idx >= curr) { - reader.idx_to_refs.emplace(idx, object_id); - // TODO(sang): Add it when retry is supported. - // auto it = submissible_tasks_.find(task_id); - // if (it != submissible_tasks_.end()) { - // // NOTE(sang): This is a hack to modify immutable field. - // // It is possible because most of attributes under - // // TaskSpecification is a pointer to the protobuf message. - // TaskSpecification spec; - // spec = it->second.spec; - // spec.AddDynamicReturnId(object_id); - // it->second.reconstructable_return_ids.insert(object_id); - // } + auto it = object_ref_streams_.find(generator_id); + if (it != object_ref_streams_.end()) { + is_written_to_stream = it->second.Write(object_id, idx); } + // TODO(sang): Update the reconstruct ids and task spec + // when we support retry. } + + // If the ref was written to a stream, we should also + // update the ref count accordingly. // If we call this method while holding a lock, it can deadlock. - if (idx >= curr) { - reference_counter_->AddStreamingDynamicReturn(object_id, generator_id); + if (is_written_to_stream) { + reference_counter_->AddIntermediatelyReporteDynamicReturnRef(object_id, + generator_id); } HandleTaskReturn(object_id, return_object, diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index bf4561db8d17..3534a06fba28 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -37,13 +37,17 @@ class TaskFinisherInterface { const rpc::Address &actor_addr, bool is_application_error) = 0; - virtual void HandleIntermediateResult( - const rpc::WriteObjectRefStreamRequest &request) = 0; + virtual void HandleReportIntermediateTaskReturn( + const rpc::ReportIntermediateTaskReturnRequest &request) = 0; - virtual void DelGenerator(const ObjectID &generator_id) = 0; + virtual void DelObjectRefStream(const ObjectID &generator_id) = 0; - virtual Status GetNextObjectRef(const ObjectID &generator_id, - ObjectID *object_id_out) = 0; + virtual void CreateObjectRefStream(const ObjectID &generator_id) = 0; + + virtual bool ObjectRefStreamExists(const ObjectID &generator_id) = 0; + + virtual Status AsyncReadObjectRefStream(const ObjectID &generator_id, + ObjectID *object_id_out) = 0; virtual bool RetryTaskIfPossible(const TaskID &task_id, const rpc::RayErrorInfo &error_info) = 0; @@ -95,10 +99,44 @@ using PushErrorCallback = std::function; -struct ObjectRefStreamReader { - absl::flat_hash_map idx_to_refs; - int64_t last = -1; - int64_t curr = 0; +/// When the streaming generator tasks are submitted, +/// the intermediate return objects are streamed +/// back to the task manager. +/// This class manages the references of intermediately +/// streamed object references. +/// The API is not thread-safe. +class ObjectRefStream { + public: + ObjectRefStream(const ObjectID &generator_id) : generator_id_(generator_id) {} + + /// Asynchronously read object reference of the next index. + /// + /// \param[out] object_id_out The next object ID from the stream. + /// Nil ID is returned if the next index hasn't been written. + /// \return KeyError if it reaches to EoF. Ok otherwise. + Status AsyncReadNext(ObjectID *object_id_out); + + /// Write the object id to the stream of an index idx. + /// + /// \param[in] The object id that will be read at index idx. + /// \param[in] The index where the object id will be written. + bool Write(const ObjectID &object_id, int64_t idx); + + /// Mark the stream canont be used anymore. + void WriteEoF(int64_t idx); + + private: + const ObjectID generator_id_; + + /// The index -> object reference ids. + absl::flat_hash_map idx_to_refs_; + /// The last index of the stream. + /// idx < last will contain object references. + /// If -1, that means the stream hasn't reached to EoF. + int64_t last_ = -1; + /// The current index of the stream. + /// If curr_ == last_, that means it is EoF. + int64_t curr_ = 0; }; class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterface { @@ -181,14 +219,37 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa const rpc::Address &worker_addr, bool is_application_error) override; - // SANG-TODO Docstring + change the method. - void HandleIntermediateResult(const rpc::WriteObjectRefStreamRequest &request) override; + /// Handle the task return reported before the task terminates. + /// + void HandleReportIntermediateTaskReturn( + const rpc::ReportIntermediateTaskReturnRequest &request) override; - // SANG-TODO Docstring + change the method. - void DelGenerator(const ObjectID &generator_id) override; + /// Delete the object ref stream. + /// Once the stream is deleted, it will clean up all unconsumed + /// object references, and all the future intermediate report + /// will be ignored. + /// + /// \param[in] generator_id The object ref id of the streaming + /// generator task. + void DelObjectRefStream(const ObjectID &generator_id) override; + + /// Create the object ref stream. + /// If the object ref stream is not created by this API, + /// all object ref stream operation will be no-op. + /// Once the stream is created, it has to be deleted + /// by DelObjectRefStream when it is not used anymore. + /// The API is not idempotent. + /// + /// \param[in] generator_id The object ref id of the streaming + /// generator task. + void CreateObjectRefStream(const ObjectID &generator_id) override; + + /// Return true if the object ref stream exists. + bool ObjectRefStreamExists(const ObjectID &generator_id) override; // SANG-TODO Docstring + change the method. - Status GetNextObjectRef(const ObjectID &generator_id, ObjectID *object_id_out) override; + Status AsyncReadObjectRefStream(const ObjectID &generator_id, + ObjectID *object_id_out) override; /// Returns true if task can be retried. /// @@ -538,6 +599,9 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// submitted tasks (dependencies and return objects). std::shared_ptr reference_counter_; + /// Mapping from a streaming generator task id -> object ref stream. + absl::flat_hash_map object_ref_streams_ GUARDED_BY(mu_); + /// Callback to store objects in plasma. This is used for objects that were /// originally stored in plasma. During reconstruction, we ensure that these /// objects get stored in plasma again so that any reference holders can @@ -585,10 +649,6 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// error). worker::TaskEventBuffer &task_event_buffer_; - // SANG-TODO Docstring + change the name. - absl::flat_hash_map dynamic_ids_from_generator_ - GUARDED_BY(mu_); - friend class TaskManagerTest; }; diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index a0ac0832185a..605c3a5460de 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -382,15 +382,26 @@ message RayletNotifyGCSRestartRequest {} message RayletNotifyGCSRestartReply {} -message WriteObjectRefStreamRequest { +message ReportIntermediateTaskReturnRequest { + // The intermediate return object that's dynamically + // generated from the executor side. repeated ReturnObject dynamic_return_objects = 1; + // The address of the executor. Address worker_addr = 2; + // The index of the task return. It is used to + // reorder the intermediate return object + // because the ordering of this request + // is not guaranteed. int64 idx = 3; + // If true, it means there's going to be no more + // task return after this request. bool finished = 4; + // The object ref id of the executor task that + // generates intermediate results. bytes generator_id = 5; } -message WriteObjectRefStreamReply {} +message ReportIntermediateTaskReturnReply {} service CoreWorkerService { // Notify core worker GCS has restarted. @@ -412,8 +423,8 @@ service CoreWorkerService { /// It is replied once there are batch of objects that need to be published to /// the caller (subscriber). rpc PubsubLongPolling(PubsubLongPollingRequest) returns (PubsubLongPollingReply); - // SANG-TODO Write a docstring and change the RPC name. - rpc WriteObjectRefStream(WriteObjectRefStreamRequest) returns (WriteObjectRefStreamReply); + // The RPC to report the intermediate task return to the caller. + rpc ReportIntermediateTaskReturn(ReportIntermediateTaskReturnRequest) returns (ReportIntermediateTaskReturnReply); /// The pubsub command batch request used by the subscriber. rpc PubsubCommandBatch(PubsubCommandBatchRequest) returns (PubsubCommandBatchReply); // Update the batched object location information to the ownership-based object diff --git a/src/ray/rpc/worker/core_worker_client.h b/src/ray/rpc/worker/core_worker_client.h index c950f9ad56ff..3b7caa1592f2 100644 --- a/src/ray/rpc/worker/core_worker_client.h +++ b/src/ray/rpc/worker/core_worker_client.h @@ -154,9 +154,9 @@ class CoreWorkerClientInterface : public pubsub::SubscriberClientInterface { const GetObjectLocationsOwnerRequest &request, const ClientCallback &callback) {} - virtual void WriteObjectRefStream( - const WriteObjectRefStreamRequest &request, - const ClientCallback &callback) {} + virtual void ReportIntermediateTaskReturn( + const ReportIntermediateTaskReturnRequest &request, + const ClientCallback &callback) {} /// Tell this actor to exit immediately. virtual void KillActor(const KillActorRequest &request, @@ -288,7 +288,7 @@ class CoreWorkerClient : public std::enable_shared_from_this, override) VOID_RPC_CLIENT_METHOD(CoreWorkerService, - WriteObjectRefStream, + ReportIntermediateTaskReturn, grpc_client_, /*method_timeout_ms*/ -1, override) diff --git a/src/ray/rpc/worker/core_worker_server.h b/src/ray/rpc/worker/core_worker_server.h index c9dc97967edd..c41486fb4af8 100644 --- a/src/ray/rpc/worker/core_worker_server.h +++ b/src/ray/rpc/worker/core_worker_server.h @@ -44,7 +44,7 @@ namespace rpc { RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED( \ CoreWorkerService, GetObjectLocationsOwner, -1) \ RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED( \ - CoreWorkerService, WriteObjectRefStream, -1) \ + CoreWorkerService, ReportIntermediateTaskReturn, -1) \ RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, KillActor, -1) \ RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, CancelTask, -1) \ RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, RemoteCancelTask, -1) \ @@ -70,7 +70,7 @@ namespace rpc { DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(PubsubCommandBatch) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(UpdateObjectLocationBatch) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(GetObjectLocationsOwner) \ - DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(WriteObjectRefStream) \ + DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(ReportIntermediateTaskReturn) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(KillActor) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(CancelTask) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(RemoteCancelTask) \ From b83af80c215582f8ce218fcbaa0b92b744533470 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 12 May 2023 18:26:08 -0700 Subject: [PATCH 04/77] fix cpp error Signed-off-by: SangBin Cho --- cpp/src/ray/runtime/task/task_executor.cc | 2 +- cpp/src/ray/runtime/task/task_executor.h | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/cpp/src/ray/runtime/task/task_executor.cc b/cpp/src/ray/runtime/task/task_executor.cc index ca4aae05fd7e..9c9d131acbdf 100644 --- a/cpp/src/ray/runtime/task/task_executor.cc +++ b/cpp/src/ray/runtime/task/task_executor.cc @@ -135,7 +135,7 @@ Status TaskExecutor::ExecuteTask( std::string *application_error, const std::vector &defined_concurrency_groups, const std::string name_of_concurrency_group_to_execute, - bool is_reattempt) { + bool is_reattempt.bool is_streaming_generator) { RAY_LOG(DEBUG) << "Execute task type: " << TaskType_Name(task_type) << " name:" << task_name; RAY_CHECK(ray_function.GetLanguage() == ray::Language::CPP); diff --git a/cpp/src/ray/runtime/task/task_executor.h b/cpp/src/ray/runtime/task/task_executor.h index 4ce2f6009e7e..4ec3df555de9 100644 --- a/cpp/src/ray/runtime/task/task_executor.h +++ b/cpp/src/ray/runtime/task/task_executor.h @@ -91,7 +91,8 @@ class TaskExecutor { std::string *application_error, const std::vector &defined_concurrency_groups, const std::string name_of_concurrency_group_to_execute, - bool is_reattempt); + bool is_reattempt, + bool is_streaming_generator); virtual ~TaskExecutor(){}; From 509b3114449708459c26faa191f00d70228e9a5d Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sat, 13 May 2023 05:13:42 -0700 Subject: [PATCH 05/77] working now. Signed-off-by: SangBin Cho --- python/ray/_raylet.pyx | 30 ++++++------ python/ray/exceptions.py | 2 +- python/ray/includes/common.pxd | 4 ++ src/ray/common/status.h | 8 +++- src/ray/core_worker/core_worker.cc | 25 +++++----- src/ray/core_worker/core_worker.h | 46 +++++++++++++----- src/ray/core_worker/reference_count.cc | 23 +++++---- src/ray/core_worker/reference_count.h | 26 +++++++++-- src/ray/core_worker/task_manager.cc | 64 +++++++++++++++----------- src/ray/core_worker/task_manager.h | 44 +++++++++--------- 10 files changed, 169 insertions(+), 103 deletions(-) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index d4983abbbba5..a977ef744478 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -131,7 +131,7 @@ from ray.exceptions import ( AsyncioActorExit, PendingCallsLimitExceeded, RpcError, - RayKeyError, + ObjectRefStreamEoFError, ) from ray._private import external_storage from ray.util.scheduling_strategies import ( @@ -203,7 +203,7 @@ class StreamingObjectRefGenerator: self._generator_task_exception = None self.worker = worker assert hasattr(worker, "core_worker") - worker.core_worker.create_generator(self._generator_ref) + worker.core_worker.create_object_ref_stream(self._generator_ref) def __iter__(self): return self @@ -270,19 +270,20 @@ class StreamingObjectRefGenerator: try: worker = ray._private.worker.global_worker if hasattr(worker, "core_worker"): - obj = worker.core_worker.generator_get_next(self._generator_ref) + obj = worker.core_worker.async_read_object_ref_stream( + self._generator_ref) return obj else: raise ValueError( "Cannot access the core worker. " "Did you already shutdown Ray via ray.shutdown()?") - except RayKeyError: + except ObjectRefStreamEoFError: raise StopIteration def __del__(self): worker = ray._private.worker.global_worker if hasattr(worker, "core_worker"): - worker.core_worker.delete_generator(self._generator_ref) + worker.core_worker.delete_object_ref_stream(self._generator_ref) def __getstate__(self): raise TypeError( @@ -301,9 +302,8 @@ cdef int check_status(const CRayStatus& status) nogil except -1: raise ObjectStoreFullError(message) elif status.IsOutOfDisk(): raise OutOfDiskError(message) - # SANG-TODO Use a different error NotFound - elif status.IsKeyError(): - raise RayKeyError(message) + elif status.IsObjectRefStreamEoF(): + raise ObjectRefStreamEoFError(message) elif status.IsInterrupted(): raise KeyboardInterrupt() elif status.IsTimedOut(): @@ -874,7 +874,9 @@ cdef execute_streaming_generator( &intermediate_result, caller_address, generator_id) - # print("SANG-TODO Writes an index ", i) + logger.debug( + "Writes to a ObjectRefStream of an " + "index {}".format(generator_index)) assert intermediate_result.size() == 1 del output @@ -892,7 +894,9 @@ cdef execute_streaming_generator( # All the intermediate result has to be popped and reported. assert intermediate_result.size() == 0 # Report the owner that there's no more objects. - # print("SANG-TODO Closes an index ", i) + logger.debug( + "Writes EoF to a ObjectRefStream " + "of an index {}".format(generator_index)) CCoreWorkerProcess.GetCoreWorker().ReportIntermediateTaskReturn( c_pair[CObjectID, shared_ptr[CRayObject]]( CObjectID.Nil(), shared_ptr[CRayObject]()), @@ -3328,19 +3332,19 @@ cdef class CoreWorker: CCoreWorkerProcess.GetCoreWorker() \ .RecordTaskLogEnd(out_end_offset, err_end_offset) - def create_generator(self, ObjectRef generator_id): + def create_object_ref_stream(self, ObjectRef generator_id): cdef: CObjectID c_generator_id = generator_id.native() CCoreWorkerProcess.GetCoreWorker().CreateObjectRefStream(c_generator_id) - def delete_generator(self, ObjectRef generator_id): + def delete_object_ref_stream(self, ObjectRef generator_id): cdef: CObjectID c_generator_id = generator_id.native() CCoreWorkerProcess.GetCoreWorker().DelObjectRefStream(c_generator_id) - def generator_get_next(self, ObjectRef generator_id): + def async_read_object_ref_stream(self, ObjectRef generator_id): cdef: CObjectID c_generator_id = generator_id.native() CObjectReference c_object_ref diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py index 945661d0a96e..dd97806fecaf 100644 --- a/python/ray/exceptions.py +++ b/python/ray/exceptions.py @@ -336,7 +336,7 @@ def __str__(self): return error_msg -class RayKeyError(RayError): +class ObjectRefStreamEoFError(RayError): pass diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index e0f8b8ee9712..3c5640eae62a 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -99,6 +99,9 @@ cdef extern from "ray/common/status.h" namespace "ray" nogil: @staticmethod CRayStatus NotFound() + @staticmethod + CRayStatus ObjectRefStreamEoF() + c_bool ok() c_bool IsOutOfMemory() c_bool IsKeyError() @@ -118,6 +121,7 @@ cdef extern from "ray/common/status.h" namespace "ray" nogil: c_bool IsObjectUnknownOwner() c_bool IsRpcError() c_bool IsOutOfResource() + c_bool IsObjectRefStreamEoF() c_string ToString() c_string CodeAsString() diff --git a/src/ray/common/status.h b/src/ray/common/status.h index bda9860ddc4a..25d9befdfd08 100644 --- a/src/ray/common/status.h +++ b/src/ray/common/status.h @@ -114,7 +114,8 @@ enum class StatusCode : char { OutOfDisk = 28, ObjectUnknownOwner = 29, RpcError = 30, - OutOfResource = 31 + OutOfResource = 31, + ObjectRefStreamEoF = 32 }; #if defined(__clang__) @@ -146,6 +147,10 @@ class RAY_EXPORT Status { return Status(StatusCode::KeyError, msg); } + static Status ObjectRefStreamEoF(const std::string &msg) { + return Status(StatusCode::ObjectRefStreamEoF, msg); + } + static Status TypeError(const std::string &msg) { return Status(StatusCode::TypeError, msg); } @@ -254,6 +259,7 @@ class RAY_EXPORT Status { bool IsOutOfMemory() const { return code() == StatusCode::OutOfMemory; } bool IsOutOfDisk() const { return code() == StatusCode::OutOfDisk; } bool IsKeyError() const { return code() == StatusCode::KeyError; } + bool IsObjectRefStreamEoF() const { return code() == StatusCode::ObjectRefStreamEoF; } bool IsInvalid() const { return code() == StatusCode::Invalid; } bool IsIOError() const { return code() == StatusCode::IOError; } bool IsTypeError() const { return code() == StatusCode::TypeError; } diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 6c8899d226de..7674de813d4c 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1663,7 +1663,7 @@ Status CoreWorker::ReportIntermediateTaskReturn( const rpc::Address &caller_address, int64_t idx, bool finished) { - RAY_LOG(DEBUG) << "SANG-TODO Write the object ref stream, index: " << idx + RAY_LOG(DEBUG) << "Write the object ref stream, index: " << idx << " finished: " << finished << ", id: " << dynamic_return_object.first; rpc::ReportIntermediateTaskReturnRequest request; request.mutable_worker_addr()->CopyFrom(rpc_address_); @@ -1672,9 +1672,8 @@ Status CoreWorker::ReportIntermediateTaskReturn( request.set_generator_id(generator_id.Binary()); auto client = core_worker_client_pool_->GetOrConnect(caller_address); - // Object id is nil if it is the close operations. - // SANG-TODO Support a separate endpoint Close. if (!dynamic_return_object.first.IsNil()) { + RAY_CHECK_EQ(finished, false); auto return_object_proto = request.add_dynamic_return_objects(); SerializeReturnObject( dynamic_return_object.first, dynamic_return_object.second, return_object_proto); @@ -1688,10 +1687,9 @@ Status CoreWorker::ReportIntermediateTaskReturn( client->ReportIntermediateTaskReturn( request, [](const Status &status, const rpc::ReportIntermediateTaskReturnReply &reply) { - if (status.ok()) { - RAY_LOG(DEBUG) << "SANG-TODO Succeeded to send the object ref"; - } else { - RAY_LOG(DEBUG) << "SANG-TODO Failed to send the object ref"; + if (!status.ok()) { + // TODO(sang): Handle network error more gracefully. + RAY_LOG(ERROR) << "Failed to send the object ref."; } }); return Status::OK(); @@ -3251,8 +3249,8 @@ void CoreWorker::ProcessSubscribeForObjectEviction( const auto generator_id = ObjectID::FromBinary(message.generator_id()); RAY_CHECK(!generator_id.IsNil()); if (task_manager_->ObjectRefStreamExists(generator_id)) { - reference_counter_->AddIntermediatelyReporteDynamicReturnRef(object_id, - generator_id); + reference_counter_->OwnDynamicallyGeneratedStreamingTaskReturn(object_id, + generator_id); } else { reference_counter_->AddDynamicReturn(object_id, generator_id); } @@ -3390,8 +3388,8 @@ void CoreWorker::AddSpilledObjectLocationOwner( // know that it exists. RAY_CHECK(!generator_id->IsNil()); if (task_manager_->ObjectRefStreamExists(*generator_id)) { - reference_counter_->AddIntermediatelyReporteDynamicReturnRef(object_id, - *generator_id); + reference_counter_->OwnDynamicallyGeneratedStreamingTaskReturn(object_id, + *generator_id); } else { reference_counter_->AddDynamicReturn(object_id, *generator_id); } @@ -3424,8 +3422,8 @@ void CoreWorker::AddObjectLocationOwner(const ObjectID &object_id, if (!maybe_generator_id.IsNil()) { if (task_manager_->ObjectRefStreamExists(maybe_generator_id)) { // If the stream exists, it means it is a streaming generator. - reference_counter_->AddIntermediatelyReporteDynamicReturnRef(object_id, - maybe_generator_id); + reference_counter_->OwnDynamicallyGeneratedStreamingTaskReturn(object_id, + maybe_generator_id); } else { // The task is a generator and may not have finished yet. Add the internal // ObjectID so that we can update its location. @@ -3465,7 +3463,6 @@ void CoreWorker::HandleReportIntermediateTaskReturn( rpc::ReportIntermediateTaskReturnRequest request, rpc::ReportIntermediateTaskReturnReply *reply, rpc::SendReplyCallback send_reply_callback) { - RAY_LOG(DEBUG) << "SANG-TODO HandleReportIntermediateTaskReturn"; task_manager_->HandleReportIntermediateTaskReturn(request); send_reply_callback(Status::OK(), nullptr, nullptr); } diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 713df90f7cd9..7ab896740690 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -354,15 +354,38 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { NodeID GetCurrentNodeId() const { return NodeID::FromBinary(rpc_address_.raylet_id()); } - // SANG-TODO Update the docstring. - void DelObjectRefStream(const ObjectID &generator_id); - + /// Create the ObjectRefStream of generator_id. + /// + /// It is a pass-through method. See TaskManager::CreateObjectRefStream + /// for details. + /// + /// \param[in] generator_id The object ref id of the streaming + /// generator task. void CreateObjectRefStream(const ObjectID &generator_id); - // SANG-TODO Update the docstring. + /// Read the next index of a ObjectRefStream of generator_id. + /// + /// \param[in] generator_id The object ref id of the streaming + /// generator task. + /// \param[out] object_ref_out The ObjectReference + /// that the caller can convert to its own ObjectRef. + /// The current process is always the owner of the + /// generated ObjectReference. + /// \return Status RayKeyError if the stream reaches to EoF. + /// OK otherwise. Status AsyncReadObjectRefStream(const ObjectID &generator_id, rpc::ObjectReference *object_ref_out); + /// Delete the ObjectRefStream of generator_id + /// created by CreateObjectRefStream. + /// + /// It is a pass-through method. See TaskManager::DelObjectRefStream + /// for details. + /// + /// \param[in] generator_id The object ref id of the streaming + /// generator task. + void DelObjectRefStream(const ObjectID &generator_id); + const PlacementGroupID &GetCurrentPlacementGroupId() const { return worker_context_.GetCurrentPlacementGroupId(); } @@ -738,6 +761,14 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { int64_t idx, bool finished); + /// Implements gRPC server handler. + /// If an executor can generator task return before the task is finished, + /// it invokes this endpoint via ReportIntermediateTaskReturn RPC. + void HandleReportIntermediateTaskReturn( + rpc::ReportIntermediateTaskReturnRequest request, + rpc::ReportIntermediateTaskReturnReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + /// Get a string describing object store memory usage for debugging purposes. /// /// \return std::string The string describing memory usage. @@ -977,7 +1008,6 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { std::shared_ptr *return_object, const ObjectID &generator_id); - /// SANG-TODO Update the docstring. /// Dynamically allocate an object. /// /// This should be used during task execution, if the task wants to return an @@ -1078,12 +1108,6 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { rpc::GetObjectLocationsOwnerReply *reply, rpc::SendReplyCallback send_reply_callback) override; - /// Implements gRPC server handler. - void HandleReportIntermediateTaskReturn( - rpc::ReportIntermediateTaskReturnRequest request, - rpc::ReportIntermediateTaskReturnReply *reply, - rpc::SendReplyCallback send_reply_callback) override; - /// Implements gRPC server handler. void HandleKillActor(rpc::KillActorRequest request, rpc::KillActorReply *reply, diff --git a/src/ray/core_worker/reference_count.cc b/src/ray/core_worker/reference_count.cc index b50d816ed8fa..854ac6f283e1 100644 --- a/src/ray/core_worker/reference_count.cc +++ b/src/ray/core_worker/reference_count.cc @@ -239,9 +239,12 @@ void ReferenceCounter::AddDynamicReturn(const ObjectID &object_id, AddNestedObjectIdsInternal(generator_id, {object_id}, owner_address); } -void ReferenceCounter::AddIntermediatelyReporteDynamicReturnRef( +void ReferenceCounter::OwnDynamicallyGeneratedStreamingTaskReturn( const ObjectID &object_id, const ObjectID &generator_id) { absl::MutexLock lock(&mutex_); + // NOTE: The upper layer (the layer that manges the object ref stream) + // should make sure the generator ref is not GC'ed when the + // auto outer_it = object_id_refs_.find(generator_id); if (outer_it == object_id_refs_.end()) { // Generator object already went out of scope. @@ -264,9 +267,6 @@ void ReferenceCounter::AddIntermediatelyReporteDynamicReturnRef( outer_it->second.is_reconstructable, /*add_local_ref=*/true, absl::optional())); - // When the intermediate object ref is reported, it means the - // object is already created. - UpdateObjectPendingCreation(object_id, false); } bool ReferenceCounter::AddOwnedObjectInternal( @@ -412,7 +412,7 @@ void ReferenceCounter::UpdateSubmittedTaskReferences( std::vector *deleted) { absl::MutexLock lock(&mutex_); for (const auto &return_id : return_ids) { - UpdateObjectPendingCreation(return_id, true); + UpdateObjectPendingCreationInternal(return_id, true); } for (const ObjectID &argument_id : argument_ids_to_add) { RAY_LOG(DEBUG) << "Increment ref count for submitted task argument " << argument_id; @@ -441,7 +441,7 @@ void ReferenceCounter::UpdateResubmittedTaskReferences( const std::vector return_ids, const std::vector &argument_ids) { absl::MutexLock lock(&mutex_); for (const auto &return_id : return_ids) { - UpdateObjectPendingCreation(return_id, true); + UpdateObjectPendingCreationInternal(return_id, true); } for (const ObjectID &argument_id : argument_ids) { auto it = object_id_refs_.find(argument_id); @@ -463,7 +463,7 @@ void ReferenceCounter::UpdateFinishedTaskReferences( std::vector *deleted) { absl::MutexLock lock(&mutex_); for (const auto &return_id : return_ids) { - UpdateObjectPendingCreation(return_id, false); + UpdateObjectPendingCreationInternal(return_id, false); } // Must merge the borrower refs before decrementing any ref counts. This is // to make sure that for serialized IDs, we increment the borrower count for @@ -1308,8 +1308,8 @@ void ReferenceCounter::RemoveObjectLocationInternal(ReferenceTable::iterator it, PushToLocationSubscribers(it); } -void ReferenceCounter::UpdateObjectPendingCreation(const ObjectID &object_id, - bool pending_creation) { +void ReferenceCounter::UpdateObjectPendingCreationInternal(const ObjectID &object_id, + bool pending_creation) { auto it = object_id_refs_.find(object_id); bool push = false; if (it != object_id_refs_.end()) { @@ -1469,6 +1469,11 @@ bool ReferenceCounter::IsObjectReconstructable(const ObjectID &object_id, return it->second.is_reconstructable; } +void ReferenceCounter::UpdateObjectReady(const ObjectID &object_id) { + absl::MutexLock lock(&mutex_); + UpdateObjectPendingCreationInternal(object_id, /*pending_creation*/ false); +} + bool ReferenceCounter::IsObjectPendingCreation(const ObjectID &object_id) const { absl::MutexLock lock(&mutex_); auto it = object_id_refs_.find(object_id); diff --git a/src/ray/core_worker/reference_count.h b/src/ray/core_worker/reference_count.h index 69e3a269d6df..a0eba802bf08 100644 --- a/src/ray/core_worker/reference_count.h +++ b/src/ray/core_worker/reference_count.h @@ -201,14 +201,26 @@ class ReferenceCounter : public ReferenceCounterInterface, void AddDynamicReturn(const ObjectID &object_id, const ObjectID &generator_id) LOCKS_EXCLUDED(mutex_); - /// Add a owned object that was dynamically created and reported intermediately. - /// These are objects that were created by a task that we called, but that we own. + /// Own an object that the current owner (current process) dynamically created. /// + /// The API is idempotent. + /// + /// TODO(sang): This API should be merged with AddDynamicReturn when + /// we turn on streaming generator by default. + /// + /// For normal task return, the owner creates and owns the references before + /// the object values are created. However, when you dynamically create objects, + /// the owner doesn't know (i.e., own) the references until it is reported from + /// the executor side. + /// + /// This API is used to own this type of dynamically generated references. + /// The executor should ensure the objects are not GC'ed until the owner + /// registers the dynamically created references by this API. /// /// \param[in] object_id The ID of the object that we now own. /// \param[in] generator_id The Object ID of the streaming generator task. - void AddIntermediatelyReporteDynamicReturnRef(const ObjectID &object_id, - const ObjectID &generator_id) + void OwnDynamicallyGeneratedStreamingTaskReturn(const ObjectID &object_id, + const ObjectID &generator_id) LOCKS_EXCLUDED(mutex_); /// Update the size of the object. @@ -520,6 +532,9 @@ class ReferenceCounter : public ReferenceCounterInterface, /// \param[in] min_bytes_to_evict The minimum number of bytes to evict. int64_t EvictLineage(int64_t min_bytes_to_evict); + /// Update that the object is ready to be fetched. + void UpdateObjectReady(const ObjectID &object_id); + /// Whether the object is pending creation (the task that creates it is /// scheduled/executing). bool IsObjectPendingCreation(const ObjectID &object_id) const; @@ -925,7 +940,8 @@ class ReferenceCounter : public ReferenceCounterInterface, void RemoveObjectLocationInternal(ReferenceTable::iterator it, const NodeID &node_id) EXCLUSIVE_LOCKS_REQUIRED(mutex_); - void UpdateObjectPendingCreation(const ObjectID &object_id, bool pending_creation) + void UpdateObjectPendingCreationInternal(const ObjectID &object_id, + bool pending_creation) EXCLUSIVE_LOCKS_REQUIRED(mutex_); /// Publish object locations to all subscribers. diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index 4147b5874950..a700ecdf2d6d 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -35,7 +35,7 @@ Status ObjectRefStream::AsyncReadNext(ObjectID *object_id_out) { if (is_eof_set && curr_ >= last_) { RAY_LOG(DEBUG) << "ObjectRefStream of an id " << generator_id_ << "has no more objects."; - return Status::KeyError("Finished"); + return Status::ObjectRefStreamEoF(""); } auto it = idx_to_refs_.find(curr_); @@ -47,13 +47,14 @@ Status ObjectRefStream::AsyncReadNext(ObjectID *object_id_out) { // when the obtained object id goes out of scope. *object_id_out = it->second; curr_ += 1; - RAY_LOG(DEBUG) << "SANG-TODO Get the next object id " << *object_id_out - << " generator id: " << generator_id_; + RAY_LOG_EVERY_MS(DEBUG, 10000) << "Get the next object id " << *object_id_out + << " generator id: " << generator_id_; } else { // If the current index hasn't been written, return nothing. // The caller is supposed to retry. - RAY_LOG(DEBUG) << "SANG-TODO Object not available. Current index: " << curr_ - << " last: " << last_ << " generator id: " << generator_id_; + RAY_LOG_EVERY_MS(DEBUG, 10000) + << "Object not available. Current index: " << curr_ << " last: " << last_ + << " generator id: " << generator_id_; *object_id_out = ObjectID::Nil(); } return Status::OK(); @@ -65,6 +66,7 @@ bool ObjectRefStream::Write(const ObjectID &object_id, int64_t idx) { } if (idx < curr_) { + // Index is already used. Don't write it to the stream. return false; } @@ -360,7 +362,7 @@ void TaskManager::DelObjectRefStream(const ObjectID &generator_id) { const auto &status = AsyncReadObjectRefStream(generator_id, &object_id); // keyError means the stream reaches to EoF. - if (status.IsKeyError()) { + if (status.IsObjectRefStreamEoF()) { break; } @@ -382,13 +384,11 @@ Status TaskManager::AsyncReadObjectRefStream(const ObjectID &generator_id, absl::MutexLock lock(&mu_); RAY_CHECK(object_id_out != nullptr); - auto it = object_ref_streams_.find(generator_id); - RAY_CHECK(it != object_ref_streams_.end()) + auto stream_it = object_ref_streams_.find(generator_id); + RAY_CHECK(stream_it != object_ref_streams_.end()) << "AsyncReadObjectRefStream API can be used only when the stream has been created " "and not removed."; - auto &stream = it->second; - - const auto &status = stream.AsyncReadNext(object_id_out); + const auto &status = stream_it->second.AsyncReadNext(object_id_out); return status; } @@ -404,15 +404,15 @@ void TaskManager::HandleReportIntermediateTaskReturn( const auto &task_id = generator_id.TaskId(); int64_t idx = request.idx(); // Every generated object has the same task id. - RAY_LOG(DEBUG) << "SANG-TODO Received an intermediate result of index " << idx + RAY_LOG(DEBUG) << "Received an intermediate result of index " << idx << " generator_id: " << generator_id; if (request.finished()) { absl::MutexLock lock(&mu_); - RAY_LOG(DEBUG) << "SANG-TODO Finished with an index " << idx; - auto it = object_ref_streams_.find(generator_id); - if (it != object_ref_streams_.end()) { - it->second.WriteEoF(idx); + RAY_LOG(DEBUG) << "Write EoF to the object ref stream. Index: " << idx; + auto stream_it = object_ref_streams_.find(generator_id); + if (stream_it != object_ref_streams_.end()) { + stream_it->second.WriteEoF(idx); } // The last report should not have any return objects. RAY_CHECK(request.dynamic_return_objects_size() == 0); @@ -426,31 +426,41 @@ void TaskManager::HandleReportIntermediateTaskReturn( // TODO(sang): Support the regular return values as well. for (const auto &return_object : request.dynamic_return_objects()) { const auto object_id = ObjectID::FromBinary(return_object.object_id()); - RAY_LOG(DEBUG) << "SANG-TODO Add an object " << object_id; - bool is_written_to_stream = false; + RAY_LOG(DEBUG) << "Write an object " << object_id + << " to the object ref stream of id " << generator_id; + bool index_not_used_yet = false; { absl::MutexLock lock(&mu_); - auto it = object_ref_streams_.find(generator_id); - if (it != object_ref_streams_.end()) { - is_written_to_stream = it->second.Write(object_id, idx); + auto stream_it = object_ref_streams_.find(generator_id); + if (stream_it != object_ref_streams_.end()) { + index_not_used_yet = stream_it->second.Write(object_id, idx); } // TODO(sang): Update the reconstruct ids and task spec // when we support retry. } // If the ref was written to a stream, we should also - // update the ref count accordingly. - // If we call this method while holding a lock, it can deadlock. - if (is_written_to_stream) { - reference_counter_->AddIntermediatelyReporteDynamicReturnRef(object_id, - generator_id); + // own the dynamically generated task return. + // NOTE: If we call this method while holding a lock, it can deadlock. + if (index_not_used_yet) { + reference_counter_->OwnDynamicallyGeneratedStreamingTaskReturn(object_id, + generator_id); + // When an object is reported, the object is ready to be fetched. + // TODO(sang): It is possible this invairant is not true + // if tasks can be retried. For example, imagine the intermediate + // task return is reported after a task is resubmitted. + // It is okay now because we don't support retry yet. But when + // we support retry, we should guarantee it is not called + // after the task resubmission. We can do it by guaranteeing + // HandleReportIntermediateTaskReturn is not called after the task + // CompletePendingTask. + reference_counter_->UpdateObjectReady(object_id); } HandleTaskReturn(object_id, return_object, NodeID::FromBinary(request.worker_addr().raylet_id()), /*store_in_plasma*/ store_in_plasma_ids.count(object_id)); } - RAY_LOG(DEBUG) << "SANG-TODO Finished handling intermediate result"; } void TaskManager::CompletePendingTask(const TaskID &task_id, diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index 3534a06fba28..67b7e13e1be6 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -37,18 +37,6 @@ class TaskFinisherInterface { const rpc::Address &actor_addr, bool is_application_error) = 0; - virtual void HandleReportIntermediateTaskReturn( - const rpc::ReportIntermediateTaskReturnRequest &request) = 0; - - virtual void DelObjectRefStream(const ObjectID &generator_id) = 0; - - virtual void CreateObjectRefStream(const ObjectID &generator_id) = 0; - - virtual bool ObjectRefStreamExists(const ObjectID &generator_id) = 0; - - virtual Status AsyncReadObjectRefStream(const ObjectID &generator_id, - ObjectID *object_id_out) = 0; - virtual bool RetryTaskIfPossible(const TaskID &task_id, const rpc::RayErrorInfo &error_info) = 0; @@ -118,8 +106,9 @@ class ObjectRefStream { /// Write the object id to the stream of an index idx. /// - /// \param[in] The object id that will be read at index idx. - /// \param[in] The index where the object id will be written. + /// \param[in] object_id The object id that will be read at index idx. + /// \param[in] idx The index where the object id will be written. + /// \return True if the idx hasn't been used. False otherwise. bool Write(const ObjectID &object_id, int64_t idx); /// Mark the stream canont be used anymore. @@ -220,9 +209,8 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa bool is_application_error) override; /// Handle the task return reported before the task terminates. - /// void HandleReportIntermediateTaskReturn( - const rpc::ReportIntermediateTaskReturnRequest &request) override; + const rpc::ReportIntermediateTaskReturnRequest &request); /// Delete the object ref stream. /// Once the stream is deleted, it will clean up all unconsumed @@ -231,7 +219,7 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// /// \param[in] generator_id The object ref id of the streaming /// generator task. - void DelObjectRefStream(const ObjectID &generator_id) override; + void DelObjectRefStream(const ObjectID &generator_id); /// Create the object ref stream. /// If the object ref stream is not created by this API, @@ -242,14 +230,26 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// /// \param[in] generator_id The object ref id of the streaming /// generator task. - void CreateObjectRefStream(const ObjectID &generator_id) override; + void CreateObjectRefStream(const ObjectID &generator_id); /// Return true if the object ref stream exists. - bool ObjectRefStreamExists(const ObjectID &generator_id) override; + /// + /// \param[in] generator_id The object ref id of the streaming + /// generator task. + bool ObjectRefStreamExists(const ObjectID &generator_id); - // SANG-TODO Docstring + change the method. - Status AsyncReadObjectRefStream(const ObjectID &generator_id, - ObjectID *object_id_out) override; + /// Asynchronously read object reference of the next index from the + /// object stream of a generator_id. + /// + /// The caller should ensure the ObjectRefStream is already created + /// via CreateObjectRefStream. + /// If it is called after the stream hasn't been created or deleted + /// it will panic. + /// + /// \param[out] object_id_out The next object ID from the stream. + /// Nil ID is returned if the next index hasn't been written. + /// \return KeyError if it reaches to EoF. Ok otherwise. + Status AsyncReadObjectRefStream(const ObjectID &generator_id, ObjectID *object_id_out); /// Returns true if task can be retried. /// From f8a90f6e6bd3bbe459a66de95f3633ef7043235d Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sat, 13 May 2023 05:15:10 -0700 Subject: [PATCH 06/77] fix a bug Signed-off-by: SangBin Cho --- cpp/src/ray/runtime/task/task_executor.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/ray/runtime/task/task_executor.cc b/cpp/src/ray/runtime/task/task_executor.cc index 9c9d131acbdf..97d67c760279 100644 --- a/cpp/src/ray/runtime/task/task_executor.cc +++ b/cpp/src/ray/runtime/task/task_executor.cc @@ -135,7 +135,8 @@ Status TaskExecutor::ExecuteTask( std::string *application_error, const std::vector &defined_concurrency_groups, const std::string name_of_concurrency_group_to_execute, - bool is_reattempt.bool is_streaming_generator) { + bool is_reattempt, + bool is_streaming_generator) { RAY_LOG(DEBUG) << "Execute task type: " << TaskType_Name(task_type) << " name:" << task_name; RAY_CHECK(ray_function.GetLanguage() == ray::Language::CPP); From 0a9169d57a2eed32241c65a9f87185a836d00381 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sun, 14 May 2023 09:31:58 -0700 Subject: [PATCH 07/77] Basic version finished. Signed-off-by: SangBin Cho --- python/ray/_raylet.pyx | 68 ++- python/ray/tests/BUILD | 2 +- python/ray/tests/test_streaming_generator.py | 127 +++++ src/ray/core_worker/core_worker.cc | 105 ++-- src/ray/core_worker/core_worker.h | 2 + src/ray/core_worker/reference_count.cc | 4 +- src/ray/core_worker/reference_count.h | 4 +- src/ray/core_worker/task_manager.cc | 89 ++- src/ray/core_worker/task_manager.h | 14 + src/ray/core_worker/test/task_manager_test.cc | 508 +++++++++++++++++- 10 files changed, 824 insertions(+), 99 deletions(-) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 865e7fe45667..6762ade4578e 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -201,21 +201,24 @@ class ObjectRefGenerator: class StreamingObjectRefGenerator: def __init__(self, generator_ref, worker): + # The reference to a generator task. self._generator_ref = generator_ref + # The last time generator task has completed. self._generator_task_completed_time = None + # The exception raised from a generator task. self._generator_task_exception = None + # Ray's worker class. ray._private.worker.global_worker self.worker = worker assert hasattr(worker, "core_worker") - worker.core_worker.create_object_ref_stream(self._generator_ref) + self.worker.core_worker.create_object_ref_stream(self._generator_ref) def __iter__(self): return self def __next__(self): - """Wait until the next ref is available to the - generator and return the object ref. + """Waits until a next ref is available and returns the object ref. - The API will raise StopIteration if there's no more objects + Raises StopIteration if there's no more objects to generate. The object ref will contain an exception if the task fails. @@ -223,7 +226,42 @@ class StreamingObjectRefGenerator: up to N + 1 objects (if there's a system failure, the last object will contain a system level exception). """ + return self._next() + + def _next( + self, + timeout_s: float = -1, + sleep_interval_s: float = 0.0001, + unexpected_network_failure_timeout_s: float = 30): + """Waits for timeout_s and returns the object ref if available. + + If an object is not available within the given timeout, it + returns a nil object reference. + + If -1 timeout is provided, it means it waits infinitely. + + Waiting is implemented as busy waiting. You can control + the busy waiting interval via sleep_interval_s. + + Raises StopIteration if there's no more objects + to generate. + + The object ref will contain an exception if the task fails. + When the generator task returns N objects, it can return + up to N + 1 objects (if there's a system failure, the + last object will contain a system level exception). + + Args: + timeout_s: If the next object is not ready within + this timeout, it returns the nil object ref. + sleep_interval_s: busy waiting interval. + unexpected_network_failure_timeout_s: If the + task is finished, but the next ref is not + available within this time, it will hard fail + the generator. + """ obj = self._handle_next() + last_time = time.time() # The generator ref will be None if the task succeeds. # It will contain an exception if the task fails by @@ -259,21 +297,25 @@ class StreamingObjectRefGenerator: # If all the object refs are not reported to the generator # within 30 seconds, we consider is as an unreconverable error. if self._generator_task_completed_time: - if time.time() - self._generator_task_completed_time > 30: + if (time.time() - self._generator_task_completed_time + > unexpected_network_failure_timeout_s): # It means the next wasn't reported although the task # has been terminated 30 seconds ago. + self._generator_task_exception = AssertionError assert False, "Unexpected network failure occured." + if timeout_s != -1 and time.time() - last_time > timeout_s: + return ObjectRef.nil() + # 100us busy waiting - time.sleep(0.0001) + time.sleep(sleep_interval_s) obj = self._handle_next() return obj def _handle_next(self): try: - worker = ray._private.worker.global_worker - if hasattr(worker, "core_worker"): - obj = worker.core_worker.async_read_object_ref_stream( + if hasattr(self.worker, "core_worker"): + obj = self.worker.core_worker.async_read_object_ref_stream( self._generator_ref) return obj else: @@ -284,9 +326,11 @@ class StreamingObjectRefGenerator: raise StopIteration def __del__(self): - worker = ray._private.worker.global_worker - if hasattr(worker, "core_worker"): - worker.core_worker.delete_object_ref_stream(self._generator_ref) + if hasattr(self.worker, "core_worker"): + # NOTE: This can be called multiple times + # because python doesn't guarantee __del__ is called + # only once. + self.worker.core_worker.delete_object_ref_stream(self._generator_ref) def __getstate__(self): raise TypeError( diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index f2311d1a0963..2321c1ef14e7 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -46,7 +46,7 @@ py_test_module_list( "test_gcs_fault_tolerance.py", "test_gcs_utils.py", "test_generators.py", - "test_streaming_generators.py", + "test_streaming_generator.py", "test_metrics_agent.py", "test_metrics_head.py", "test_component_failures_2.py", diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py index 72c07a2af658..277d8226cb50 100644 --- a/python/ray/tests/test_streaming_generator.py +++ b/python/ray/tests/test_streaming_generator.py @@ -4,9 +4,136 @@ import time import gc +from unittest.mock import patch, Mock + import ray from ray._private.test_utils import wait_for_condition from ray.experimental.state.api import list_objects +from ray._raylet import StreamingObjectRefGenerator +from ray.cloudpickle import dumps +from ray.exceptions import ObjectRefStreamEoFError, WorkerCrashedError + + +class MockedWorker: + def __init__(self, mocked_core_worker): + self.core_worker = mocked_core_worker + + def reset_core_worker(self): + """Emulate the case ray.shutdown is called + and the core_worker instance is GC'ed. + """ + self.core_worker = None + + +@pytest.fixture +def mocked_worker(): + mocked_core_worker = Mock() + mocked_core_worker.async_read_object_ref_stream.return_value = None + mocked_core_worker.delete_object_ref_stream.return_value = None + mocked_core_worker.create_object_ref_stream.return_value = None + worker = MockedWorker(mocked_core_worker) + yield worker + + +def test_streaming_object_ref_generator_basic_unit(mocked_worker): + """ + Verify the basic case: + create a generator -> read values -> nothing more to read -> delete. + """ + with patch("ray.wait") as mocked_ray_wait: + c = mocked_worker.core_worker + generator_ref = ray.ObjectRef.from_random() + generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) + c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() + c.create_object_ref_stream.assert_called() + + # Test when there's no new ref, it returns a nil. + mocked_ray_wait.return_value = [], [generator_ref] + ref = generator._next(timeout_s=0) + assert ref.is_nil() + + # When the new ref is available, next should return it. + for _ in range(3): + new_ref = ray.ObjectRef.from_random() + c.async_read_object_ref_stream.return_value = new_ref + ref = generator._next(timeout_s=0) + assert new_ref == ref + + # When async_read_object_ref_stream raises a + # ObjectRefStreamEoFError, it should raise a stop iteration. + c.async_read_object_ref_stream.side_effect = ObjectRefStreamEoFError("") # noqa + with pytest.raises(StopIteration): + ref = generator._next(timeout_s=0) + + # Make sure we cannot serialize the generator. + with pytest.raises(TypeError): + dumps(generator) + + del generator + c.delete_object_ref_stream.assert_called() + + +def test_streaming_object_ref_generator_task_failed_unit(mocked_worker): + """ + Verify when a task is failed by a system error, + the generator ref is returned. + """ + with patch("ray.get") as mocked_ray_get: + with patch("ray.wait") as mocked_ray_wait: + c = mocked_worker.core_worker + generator_ref = ray.ObjectRef.from_random() + generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) + + # Simulate the worker failure happens. + mocked_ray_wait.return_value = [generator_ref], [] + mocked_ray_get.side_effect = WorkerCrashedError() + + c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() + ref = generator._next(timeout_s=0) + # If the generator task fails by a systsem error, + # meaning the ref will raise an exception + # it should be returned. + print(ref) + print(generator_ref) + assert ref == generator_ref + + # Once exception is raised, it should always + # raise stopIteration regardless of what + # the ref contains now. + with pytest.raises(StopIteration): + ref = generator._next(timeout_s=0) + + +def test_streaming_object_ref_generator_network_failed_unit(mocked_worker): + """ + Verify when a task is finished, but if the next ref is not available + on time, it raises an assertion error. + + TODO(sang): Once we move the task subimssion path to use pubsub + to guarantee the ordering, we don't need this test anymore. + """ + with patch("ray.get") as mocked_ray_get: + with patch("ray.wait") as mocked_ray_wait: + c = mocked_worker.core_worker + generator_ref = ray.ObjectRef.from_random() + generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) + + # Simulate the task has finished. + mocked_ray_wait.return_value = [generator_ref], [] + mocked_ray_get.return_value = None + + # If StopIteration is not raised within + # unexpected_network_failure_timeout_s second, + # it should fail. + c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() + ref = generator._next(timeout_s=0, unexpected_network_failure_timeout_s=1) + assert ref == ray.ObjectRef.nil() + time.sleep(1) + with pytest.raises(AssertionError): + generator._next(timeout_s=0, unexpected_network_failure_timeout_s=1) + # After that StopIteration should be raised. + with pytest.raises(StopIteration): + generator._next(timeout_s=0, unexpected_network_failure_timeout_s=1) def test_generator_basic(shutdown_only): diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 088a8ece3692..9f2d950db681 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1665,44 +1665,6 @@ void CoreWorker::TriggerGlobalGC() { }); } -Status CoreWorker::ReportIntermediateTaskReturn( - const std::pair> &dynamic_return_object, - const ObjectID &generator_id, - const rpc::Address &caller_address, - int64_t idx, - bool finished) { - RAY_LOG(DEBUG) << "Write the object ref stream, index: " << idx - << " finished: " << finished << ", id: " << dynamic_return_object.first; - rpc::ReportIntermediateTaskReturnRequest request; - request.mutable_worker_addr()->CopyFrom(rpc_address_); - request.set_idx(idx); - request.set_finished(finished); - request.set_generator_id(generator_id.Binary()); - auto client = core_worker_client_pool_->GetOrConnect(caller_address); - - if (!dynamic_return_object.first.IsNil()) { - RAY_CHECK_EQ(finished, false); - auto return_object_proto = request.add_dynamic_return_objects(); - SerializeReturnObject( - dynamic_return_object.first, dynamic_return_object.second, return_object_proto); - std::vector deleted; - ReferenceCounter::ReferenceTableProto borrowed_refs; - reference_counter_->PopAndClearLocalBorrowers( - {dynamic_return_object.first}, &borrowed_refs, &deleted); - memory_store_->Delete(deleted); - } - - client->ReportIntermediateTaskReturn( - request, - [](const Status &status, const rpc::ReportIntermediateTaskReturnReply &reply) { - if (!status.ok()) { - // TODO(sang): Handle network error more gracefully. - RAY_LOG(ERROR) << "Failed to send the object ref."; - } - }); - return Status::OK(); -} - std::string CoreWorker::MemoryUsageString() { // Currently only the Plasma store returns a debug string. return plasma_store_provider_->MemoryUsageString(); @@ -2889,6 +2851,56 @@ ObjectID CoreWorker::AllocateDynamicReturnId(const rpc::Address &owner_address) return return_id; } +Status CoreWorker::ReportIntermediateTaskReturn( + const std::pair> &dynamic_return_object, + const ObjectID &generator_id, + const rpc::Address &caller_address, + int64_t idx, + bool finished) { + RAY_LOG(DEBUG) << "Write the object ref stream, index: " << idx + << " finished: " << finished << ", id: " << dynamic_return_object.first; + rpc::ReportIntermediateTaskReturnRequest request; + request.mutable_worker_addr()->CopyFrom(rpc_address_); + request.set_idx(idx); + request.set_finished(finished); + request.set_generator_id(generator_id.Binary()); + auto client = core_worker_client_pool_->GetOrConnect(caller_address); + + if (!dynamic_return_object.first.IsNil()) { + RAY_CHECK_EQ(finished, false); + auto return_object_proto = request.add_dynamic_return_objects(); + SerializeReturnObject( + dynamic_return_object.first, dynamic_return_object.second, return_object_proto); + std::vector deleted; + // When we allocate a dynamic return ID (AllocateDynamicReturnId), + // we borrow the object. When the object value is allocatd, the + // memory store is updated. We should clear borrowers and memory store + // here. + ReferenceCounter::ReferenceTableProto borrowed_refs; + reference_counter_->PopAndClearLocalBorrowers( + {dynamic_return_object.first}, &borrowed_refs, &deleted); + memory_store_->Delete(deleted); + } + + client->ReportIntermediateTaskReturn( + request, + [](const Status &status, const rpc::ReportIntermediateTaskReturnReply &reply) { + if (!status.ok()) { + // TODO(sang): Handle network error more gracefully. + RAY_LOG(ERROR) << "Failed to send the object ref."; + } + }); + return Status::OK(); +} + +void CoreWorker::HandleReportIntermediateTaskReturn( + rpc::ReportIntermediateTaskReturnRequest request, + rpc::ReportIntermediateTaskReturnReply *reply, + rpc::SendReplyCallback send_reply_callback) { + task_manager_->HandleReportIntermediateTaskReturn(request); + send_reply_callback(Status::OK(), nullptr, nullptr); +} + std::vector CoreWorker::ExecuteTaskLocalMode( const TaskSpecification &task_spec, const ActorID &actor_id) { auto resource_ids = std::make_shared(); @@ -3257,8 +3269,7 @@ void CoreWorker::ProcessSubscribeForObjectEviction( const auto generator_id = ObjectID::FromBinary(message.generator_id()); RAY_CHECK(!generator_id.IsNil()); if (task_manager_->ObjectRefStreamExists(generator_id)) { - reference_counter_->OwnDynamicallyGeneratedStreamingTaskReturn(object_id, - generator_id); + reference_counter_->OwnDynamicStreamingTaskReturnRef(object_id, generator_id); } else { reference_counter_->AddDynamicReturn(object_id, generator_id); } @@ -3396,8 +3407,7 @@ void CoreWorker::AddSpilledObjectLocationOwner( // know that it exists. RAY_CHECK(!generator_id->IsNil()); if (task_manager_->ObjectRefStreamExists(*generator_id)) { - reference_counter_->OwnDynamicallyGeneratedStreamingTaskReturn(object_id, - *generator_id); + reference_counter_->OwnDynamicStreamingTaskReturnRef(object_id, *generator_id); } else { reference_counter_->AddDynamicReturn(object_id, *generator_id); } @@ -3430,8 +3440,7 @@ void CoreWorker::AddObjectLocationOwner(const ObjectID &object_id, if (!maybe_generator_id.IsNil()) { if (task_manager_->ObjectRefStreamExists(maybe_generator_id)) { // If the stream exists, it means it is a streaming generator. - reference_counter_->OwnDynamicallyGeneratedStreamingTaskReturn(object_id, - maybe_generator_id); + reference_counter_->OwnDynamicStreamingTaskReturnRef(object_id, maybe_generator_id); } else { // The task is a generator and may not have finished yet. Add the internal // ObjectID so that we can update its location. @@ -3467,14 +3476,6 @@ void CoreWorker::ProcessSubscribeObjectLocations( reference_counter_->PublishObjectLocationSnapshot(object_id); } -void CoreWorker::HandleReportIntermediateTaskReturn( - rpc::ReportIntermediateTaskReturnRequest request, - rpc::ReportIntermediateTaskReturnReply *reply, - rpc::SendReplyCallback send_reply_callback) { - task_manager_->HandleReportIntermediateTaskReturn(request); - send_reply_callback(Status::OK(), nullptr, nullptr); -} - void CoreWorker::HandleGetObjectLocationsOwner( rpc::GetObjectLocationsOwnerRequest request, rpc::GetObjectLocationsOwnerReply *reply, diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 154bf94cbb18..1f8c725f7080 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -441,6 +441,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { std::vector deleted; reference_counter_->RemoveLocalReference(object_id, &deleted); // TOOD(ilr): better way of keeping an object from being deleted + // TODO(sang): This seems bad... We should delete the memory store + // properly from reference counter. if (!options_.is_local_mode) { memory_store_->Delete(deleted); } diff --git a/src/ray/core_worker/reference_count.cc b/src/ray/core_worker/reference_count.cc index 854ac6f283e1..ede9cd844705 100644 --- a/src/ray/core_worker/reference_count.cc +++ b/src/ray/core_worker/reference_count.cc @@ -239,8 +239,8 @@ void ReferenceCounter::AddDynamicReturn(const ObjectID &object_id, AddNestedObjectIdsInternal(generator_id, {object_id}, owner_address); } -void ReferenceCounter::OwnDynamicallyGeneratedStreamingTaskReturn( - const ObjectID &object_id, const ObjectID &generator_id) { +void ReferenceCounter::OwnDynamicStreamingTaskReturnRef(const ObjectID &object_id, + const ObjectID &generator_id) { absl::MutexLock lock(&mutex_); // NOTE: The upper layer (the layer that manges the object ref stream) // should make sure the generator ref is not GC'ed when the diff --git a/src/ray/core_worker/reference_count.h b/src/ray/core_worker/reference_count.h index a0eba802bf08..894b426a9d97 100644 --- a/src/ray/core_worker/reference_count.h +++ b/src/ray/core_worker/reference_count.h @@ -219,8 +219,8 @@ class ReferenceCounter : public ReferenceCounterInterface, /// /// \param[in] object_id The ID of the object that we now own. /// \param[in] generator_id The Object ID of the streaming generator task. - void OwnDynamicallyGeneratedStreamingTaskReturn(const ObjectID &object_id, - const ObjectID &generator_id) + void OwnDynamicStreamingTaskReturnRef(const ObjectID &object_id, + const ObjectID &generator_id) LOCKS_EXCLUDED(mutex_); /// Update the size of the object. diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index a700ecdf2d6d..f2afe87c9bb5 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -34,7 +34,8 @@ Status ObjectRefStream::AsyncReadNext(ObjectID *object_id_out) { bool is_eof_set = last_ != -1; if (is_eof_set && curr_ >= last_) { RAY_LOG(DEBUG) << "ObjectRefStream of an id " << generator_id_ - << "has no more objects."; + << " has no more objects."; + *object_id_out = ObjectID::Nil(); return Status::ObjectRefStreamEoF(""); } @@ -62,7 +63,7 @@ Status ObjectRefStream::AsyncReadNext(ObjectID *object_id_out) { bool ObjectRefStream::Write(const ObjectID &object_id, int64_t idx) { if (last_ != -1) { - RAY_CHECK(curr_ < last_); + RAY_CHECK(curr_ <= last_); } if (idx < curr_) { @@ -70,6 +71,18 @@ bool ObjectRefStream::Write(const ObjectID &object_id, int64_t idx) { return false; } + auto it = idx_to_refs_.find(idx); + if (it != idx_to_refs_.end()) { + // It means the when a task is retried it returns a different object id + // for the same index, which means the task was not deterministic. + // Fail the owner if it happens. + RAY_CHECK_EQ(object_id, it->second) + << "The task has been retried with none deterministic task return ids. Previous " + "return id: " + << it->second << ". New task return id: " << object_id + << ". It means a undeterministic task has been retried. Disable the retry " + "feature using `max_retries=0` (task) or `max_task_retries=0` (actor)."; + } idx_to_refs_.emplace(idx, object_id); return true; } @@ -347,6 +360,7 @@ bool TaskManager::HandleTaskReturn(const ObjectID &object_id, } void TaskManager::CreateObjectRefStream(const ObjectID &generator_id) { + RAY_LOG(DEBUG) << "Create an object ref stream of an id " << generator_id; absl::MutexLock lock(&mu_); auto it = object_ref_streams_.find(generator_id); RAY_CHECK(it == object_ref_streams_.end()) @@ -356,34 +370,49 @@ void TaskManager::CreateObjectRefStream(const ObjectID &generator_id) { } void TaskManager::DelObjectRefStream(const ObjectID &generator_id) { - RAY_LOG(DEBUG) << "Deleting the object ref stream of an id " << generator_id; - while (true) { - ObjectID object_id; - const auto &status = AsyncReadObjectRefStream(generator_id, &object_id); + RAY_LOG(DEBUG) << "Deleting an object ref stream of an id " << generator_id; + std::vector object_ids_unconsumed; - // keyError means the stream reaches to EoF. - if (status.IsObjectRefStreamEoF()) { - break; + { + absl::MutexLock lock(&mu_); + auto it = object_ref_streams_.find(generator_id); + if (it == object_ref_streams_.end()) { + return; } - if (object_id == ObjectID::Nil()) { - break; - } else { - std::vector deleted; - reference_counter_->RemoveLocalReference(object_id, &deleted); - RAY_CHECK_EQ(deleted.size(), 1); + while (true) { + ObjectID object_id; + const auto &status = AsyncReadObjectRefStreamInternal(generator_id, &object_id); + + // keyError means the stream reaches to EoF. + if (status.IsObjectRefStreamEoF()) { + break; + } + + if (object_id == ObjectID::Nil()) { + // No more objects to obtain. Stop iteration. + break; + } else { + // It means the object hasn't been consumed. + // We should remove references since we have 1 reference to this object. + object_ids_unconsumed.push_back(object_id); + } } + + object_ref_streams_.erase(generator_id); } - absl::MutexLock lock(&mu_); - object_ref_streams_.erase(generator_id); + // When calling RemoveLocalReference, we shouldn't hold a lock. + for (const auto &object_id : object_ids_unconsumed) { + std::vector deleted; + reference_counter_->RemoveLocalReference(object_id, &deleted); + RAY_CHECK(deleted.size() == 1); + } } -Status TaskManager::AsyncReadObjectRefStream(const ObjectID &generator_id, - ObjectID *object_id_out) { - absl::MutexLock lock(&mu_); +Status TaskManager::AsyncReadObjectRefStreamInternal(const ObjectID &generator_id, + ObjectID *object_id_out) { RAY_CHECK(object_id_out != nullptr); - auto stream_it = object_ref_streams_.find(generator_id); RAY_CHECK(stream_it != object_ref_streams_.end()) << "AsyncReadObjectRefStream API can be used only when the stream has been created " @@ -392,6 +421,12 @@ Status TaskManager::AsyncReadObjectRefStream(const ObjectID &generator_id, return status; } +Status TaskManager::AsyncReadObjectRefStream(const ObjectID &generator_id, + ObjectID *object_id_out) { + absl::MutexLock lock(&mu_); + return AsyncReadObjectRefStreamInternal(generator_id, object_id_out); +} + bool TaskManager::ObjectRefStreamExists(const ObjectID &generator_id) { absl::MutexLock lock(&mu_); auto it = object_ref_streams_.find(generator_id); @@ -438,13 +473,11 @@ void TaskManager::HandleReportIntermediateTaskReturn( // TODO(sang): Update the reconstruct ids and task spec // when we support retry. } - // If the ref was written to a stream, we should also // own the dynamically generated task return. // NOTE: If we call this method while holding a lock, it can deadlock. if (index_not_used_yet) { - reference_counter_->OwnDynamicallyGeneratedStreamingTaskReturn(object_id, - generator_id); + reference_counter_->OwnDynamicStreamingTaskReturnRef(object_id, generator_id); // When an object is reported, the object is ready to be fetched. // TODO(sang): It is possible this invairant is not true // if tasks can be retried. For example, imagine the intermediate @@ -455,11 +488,11 @@ void TaskManager::HandleReportIntermediateTaskReturn( // HandleReportIntermediateTaskReturn is not called after the task // CompletePendingTask. reference_counter_->UpdateObjectReady(object_id); + HandleTaskReturn(object_id, + return_object, + NodeID::FromBinary(request.worker_addr().raylet_id()), + /*store_in_plasma*/ store_in_plasma_ids.count(object_id)); } - HandleTaskReturn(object_id, - return_object, - NodeID::FromBinary(request.worker_addr().raylet_id()), - /*store_in_plasma*/ store_in_plasma_ids.count(object_id)); } } diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index 67b7e13e1be6..4a55a090c201 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -106,6 +106,10 @@ class ObjectRefStream { /// Write the object id to the stream of an index idx. /// + /// If the idx has been already read (by AsyncReadNext), + /// the write request will be ignored. If the idx has been + /// already written, it will be no-op. It doesn't override. + /// /// \param[in] object_id The object id that will be read at index idx. /// \param[in] idx The index where the object id will be written. /// \return True if the idx hasn't been used. False otherwise. @@ -213,10 +217,16 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa const rpc::ReportIntermediateTaskReturnRequest &request); /// Delete the object ref stream. + /// /// Once the stream is deleted, it will clean up all unconsumed /// object references, and all the future intermediate report /// will be ignored. /// + /// This method is idempotent. It is because the language + /// frontend often calls this method upon destructor, but + /// not every langauge guarantees the destructor is called + /// only once. + /// /// \param[in] generator_id The object ref id of the streaming /// generator task. void DelObjectRefStream(const ObjectID &generator_id); @@ -591,6 +601,10 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// \param task_entry Task entry for the corresponding task attempt void MarkTaskRetryOnFailed(TaskEntry &task_entry, const rpc::RayErrorInfo &error_info); + Status AsyncReadObjectRefStreamInternal(const ObjectID &generator_id, + ObjectID *object_id_out) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + /// Used to store task results. std::shared_ptr in_memory_store_; diff --git a/src/ray/core_worker/test/task_manager_test.cc b/src/ray/core_worker/test/task_manager_test.cc index 94d7749f521a..0fc669be9082 100644 --- a/src/ray/core_worker/test/task_manager_test.cc +++ b/src/ray/core_worker/test/task_manager_test.cc @@ -29,7 +29,8 @@ namespace core { TaskSpecification CreateTaskHelper(uint64_t num_returns, std::vector dependencies, - bool dynamic_returns = false) { + bool dynamic_returns = false, + bool streaming_generator = false) { TaskSpecification task; task.GetMutableMessage().set_task_id(TaskID::FromRandom(JobID::FromInt(1)).Binary()); task.GetMutableMessage().set_num_returns(num_returns); @@ -41,6 +42,9 @@ TaskSpecification CreateTaskHelper(uint64_t num_returns, if (dynamic_returns) { task.GetMutableMessage().set_returns_dynamic(true); } + if (streaming_generator) { + task.GetMutableMessage().set_streaming_generator(true); + } return task; } @@ -51,6 +55,37 @@ rpc::Address GetRandomWorkerAddr() { return addr; } +rpc::ReportIntermediateTaskReturnRequest GetIntermediateTaskReturn( + int64_t idx, + bool finished, + const ObjectID &generator_id, + const ObjectID &dynamic_return_id, + std::shared_ptr data, + bool set_in_plasma) { + rpc::ReportIntermediateTaskReturnRequest request; + rpc::Address addr; + request.mutable_worker_addr()->CopyFrom(addr); + request.set_idx(idx); + request.set_finished(finished); + request.set_generator_id(generator_id.Binary()); + auto dynamic_return_object = request.add_dynamic_return_objects(); + dynamic_return_object->set_object_id(dynamic_return_id.Binary()); + dynamic_return_object->set_data(data->Data(), data->Size()); + dynamic_return_object->set_in_plasma(set_in_plasma); + return request; +} + +rpc::ReportIntermediateTaskReturnRequest GetEoFTaskReturn(int64_t idx, + const ObjectID &generator_id) { + rpc::ReportIntermediateTaskReturnRequest request; + rpc::Address addr; + request.mutable_worker_addr()->CopyFrom(addr); + request.set_idx(idx); + request.set_finished(true); + request.set_generator_id(generator_id.Binary()); + return request; +} + class MockTaskEventBuffer : public worker::TaskEventBuffer { public: MOCK_METHOD(void, @@ -73,7 +108,8 @@ class TaskManagerTest : public ::testing::Test { public: TaskManagerTest(bool lineage_pinning_enabled = false, int64_t max_lineage_bytes = 1024 * 1024 * 1024) - : addr_(GetRandomWorkerAddr()), + : lineage_pinning_enabled_(lineage_pinning_enabled), + addr_(GetRandomWorkerAddr()), publisher_(std::make_shared()), subscriber_(std::make_shared()), task_event_buffer_mock_(std::make_unique()), @@ -113,6 +149,7 @@ class TaskManagerTest : public ::testing::Test { ASSERT_EQ(manager_.total_lineage_footprint_bytes_, 0); } + bool lineage_pinning_enabled_; rpc::Address addr_; std::shared_ptr publisher_; std::shared_ptr subscriber_; @@ -1145,6 +1182,473 @@ TEST_F(TaskManagerLineageTest, TestResubmittedDynamicReturnsTaskFails) { ASSERT_EQ(stored_in_plasma.size(), 3); } +TEST_F(TaskManagerTest, TestObjectRefStreamCreateDelete) { + /** + * Test create and deletion of stream works. + * CREATE EXISTS (true) DELETE EXISTS (false) + */ + auto spec = CreateTaskHelper(1, {}, /*dynamic_returns=*/true); + auto generator_id = spec.ReturnId(0); + manager_.CreateObjectRefStream(generator_id); + ASSERT_TRUE(manager_.ObjectRefStreamExists(generator_id)); + manager_.DelObjectRefStream(generator_id); + ASSERT_FALSE(manager_.ObjectRefStreamExists(generator_id)); + // Test DelObjectRefStream is idempotent + manager_.DelObjectRefStream(generator_id); + manager_.DelObjectRefStream(generator_id); + manager_.DelObjectRefStream(generator_id); + manager_.DelObjectRefStream(generator_id); + ASSERT_FALSE(manager_.ObjectRefStreamExists(generator_id)); +} + +TEST_F(TaskManagerTest, TestObjectRefStreamBasic) { + /** + * Test the basic cases (write -> read). + * CREATE WRITE, WRITE, WRITEEoF, READ, READ, KeyERROR DELETE + */ + auto spec = CreateTaskHelper(1, {}, /*dynamic_returns=*/true); + auto generator_id = spec.ReturnId(0); + // CREATE + manager_.CreateObjectRefStream(generator_id); + + auto last_idx = 2; + std::vector dynamic_return_ids; + std::vector> datas; + for (auto i = 0; i < last_idx; i++) { + auto dynamic_return_id = ObjectID::FromIndex(spec.TaskId(), i + 2); + dynamic_return_ids.push_back(dynamic_return_id); + auto data = GenerateRandomBuffer(); + datas.push_back(data); + + auto req = GetIntermediateTaskReturn( + /*idx*/ i, + /*finished*/ false, + generator_id, + /*dynamic_return_id*/ dynamic_return_id, + /*data*/ data, + /*set_in_plasma*/ false); + // WRITE * 2 + manager_.HandleReportIntermediateTaskReturn(req); + } + // WRITEEoF + manager_.HandleReportIntermediateTaskReturn(GetEoFTaskReturn(last_idx, generator_id)); + + ObjectID obj_id; + for (auto i = 0; i < last_idx; i++) { + // READ * 2 + auto status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(obj_id, dynamic_return_ids[i]); + } + // READ (EoF) + auto status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + ASSERT_TRUE(status.IsObjectRefStreamEoF()); + ASSERT_EQ(obj_id, ObjectID::Nil()); + // DELETE + manager_.DelObjectRefStream(generator_id); +} + +TEST_F(TaskManagerTest, TestObjectRefStreamMixture) { + /** + * Test the basic cases, but write and read are mixed up. + * CREATE WRITE READ WRITE READ WRITEEoF KeyError DELETE + */ + auto spec = CreateTaskHelper(1, {}, /*dynamic_returns=*/true); + auto generator_id = spec.ReturnId(0); + // CREATE + manager_.CreateObjectRefStream(generator_id); + + auto last_idx = 2; + std::vector dynamic_return_ids; + std::vector> datas; + for (auto i = 0; i < last_idx; i++) { + auto dynamic_return_id = ObjectID::FromIndex(spec.TaskId(), i + 2); + dynamic_return_ids.push_back(dynamic_return_id); + auto data = GenerateRandomBuffer(); + datas.push_back(data); + + auto req = GetIntermediateTaskReturn( + /*idx*/ i, + /*finished*/ false, + generator_id, + /*dynamic_return_id*/ dynamic_return_id, + /*data*/ data, + /*set_in_plasma*/ false); + // WRITE + manager_.HandleReportIntermediateTaskReturn(req); + // READ + ObjectID obj_id; + auto status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(obj_id, dynamic_return_ids[i]); + } + // WRITEEoF + manager_.HandleReportIntermediateTaskReturn(GetEoFTaskReturn(last_idx, generator_id)); + + ObjectID obj_id; + // READ (EoF) + auto status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + ASSERT_TRUE(status.IsObjectRefStreamEoF()); + ASSERT_EQ(obj_id, ObjectID::Nil()); + // DELETE + manager_.DelObjectRefStream(generator_id); +} + +TEST_F(TaskManagerTest, TestObjectRefStreamEoF) { + /** + * Test that after writing EoF, write/read doesn't work. + * CREATE WRITE WRITEEoF, WRITE(verify no op) DELETE + */ + auto spec = CreateTaskHelper(1, {}, /*dynamic_returns=*/true); + auto generator_id = spec.ReturnId(0); + // CREATE + manager_.CreateObjectRefStream(generator_id); + + // WRITE + auto dynamic_return_id = ObjectID::FromIndex(spec.TaskId(), 2); + auto data = GenerateRandomBuffer(); + auto req = GetIntermediateTaskReturn( + /*idx*/ 0, + /*finished*/ false, + generator_id, + /*dynamic_return_id*/ dynamic_return_id, + /*data*/ data, + /*set_in_plasma*/ false); + manager_.HandleReportIntermediateTaskReturn(req); + // WRITEEoF + manager_.HandleReportIntermediateTaskReturn(GetEoFTaskReturn(1, generator_id)); + // READ (works) + ObjectID obj_id; + auto status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(obj_id, dynamic_return_id); + + // WRITE + dynamic_return_id = ObjectID::FromIndex(spec.TaskId(), 3); + data = GenerateRandomBuffer(); + req = GetIntermediateTaskReturn( + /*idx*/ 2, + /*finished*/ false, + generator_id, + /*dynamic_return_id*/ dynamic_return_id, + /*data*/ data, + /*set_in_plasma*/ false); + manager_.HandleReportIntermediateTaskReturn(req); + // READ (doesn't works because EoF is already written) + status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + ASSERT_TRUE(status.IsObjectRefStreamEoF()); +} + +TEST_F(TaskManagerTest, TestObjectRefStreamIndexDiscarded) { + /** + * Test that when the ObjectRefStream is already written + * the WRITE will be ignored. + */ + auto spec = CreateTaskHelper(1, {}, /*dynamic_returns=*/true); + auto generator_id = spec.ReturnId(0); + // CREATE + manager_.CreateObjectRefStream(generator_id); + + // WRITE + auto dynamic_return_id = ObjectID::FromIndex(spec.TaskId(), 2); + auto data = GenerateRandomBuffer(); + auto req = GetIntermediateTaskReturn( + /*idx*/ 0, + /*finished*/ false, + generator_id, + /*dynamic_return_id*/ dynamic_return_id, + /*data*/ data, + /*set_in_plasma*/ false); + manager_.HandleReportIntermediateTaskReturn(req); + // READ + ObjectID obj_id; + auto status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(obj_id, dynamic_return_id); + + // WRITE to the first index again. + dynamic_return_id = ObjectID::FromIndex(spec.TaskId(), 3); + data = GenerateRandomBuffer(); + req = GetIntermediateTaskReturn( + /*idx*/ 0, + /*finished*/ false, + generator_id, + /*dynamic_return_id*/ dynamic_return_id, + /*data*/ data, + /*set_in_plasma*/ false); + manager_.HandleReportIntermediateTaskReturn(req); + // READ (New write will be ignored). + status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(obj_id, ObjectID::Nil()); +} + +TEST_F(TaskManagerTest, TestObjectRefStreamReadIgnoredWhenNothingWritten) { + /** + * Test read will return Nil if nothing was written. + * CREATE READ (no op) WRITE READ (working) READ (no op) + */ + auto spec = CreateTaskHelper(1, {}, /*dynamic_returns=*/true); + auto generator_id = spec.ReturnId(0); + // CREATE + manager_.CreateObjectRefStream(generator_id); + + // READ (no-op) + ObjectID obj_id; + auto status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(obj_id, ObjectID::Nil()); + + // WRITE + auto dynamic_return_id = ObjectID::FromIndex(spec.TaskId(), 2); + auto data = GenerateRandomBuffer(); + auto req = GetIntermediateTaskReturn( + /*idx*/ 0, + /*finished*/ false, + generator_id, + /*dynamic_return_id*/ dynamic_return_id, + /*data*/ data, + /*set_in_plasma*/ false); + manager_.HandleReportIntermediateTaskReturn(req); + // READ (works this time) + status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(obj_id, dynamic_return_id); + + // READ (nothing should return) + status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(obj_id, ObjectID::Nil()); +} + +TEST_F(TaskManagerTest, TestObjectRefStreamEndtoEnd) { + /** + * Test e2e + * (task submitted -> report intermediate task return -> task finished) + * This also tests if we can read / write stream before / after task finishes. + */ + // Submit a task. + rpc::Address caller_address; + auto spec = + CreateTaskHelper(1, {}, /*dynamic_returns=*/true, /*streaming_generator*/ true); + auto generator_id = spec.ReturnId(0); + manager_.AddPendingTask(caller_address, spec, "", /*num_retries=*/0); + // CREATE + manager_.CreateObjectRefStream(generator_id); + manager_.MarkDependenciesResolved(spec.TaskId()); + manager_.MarkTaskWaitingForExecution( + spec.TaskId(), NodeID::FromRandom(), WorkerID::FromRandom()); + + // The results are reported before the task is finished. + auto dynamic_return_id = ObjectID::FromIndex(spec.TaskId(), 2); + auto data = GenerateRandomBuffer(); + auto req = GetIntermediateTaskReturn( + /*idx*/ 0, + /*finished*/ false, + generator_id, + /*dynamic_return_id*/ dynamic_return_id, + /*data*/ data, + /*set_in_plasma*/ false); + manager_.HandleReportIntermediateTaskReturn(req); + + // NumObjectIDsInScope == Generator + intermediate result. + ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 2); + std::vector> results; + WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); + RAY_CHECK_OK(store_->Get({dynamic_return_id}, 1, 1, ctx, false, &results)); + ASSERT_EQ(results.size(), 1); + + // Make sure you can read. + ObjectID obj_id; + auto status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(obj_id, dynamic_return_id); + + // Finish the task. + rpc::PushTaskReply reply; + auto return_object = reply.add_return_objects(); + return_object->set_object_id(generator_id.Binary()); + data = GenerateRandomBuffer(); + return_object->set_data(data->Data(), data->Size()); + manager_.CompletePendingTask(spec.TaskId(), reply, caller_address, false); + + // Test you can write to the stream after task finishes. + // TODO(sang): Make sure this doesn't happen by ensuring the ordering + // from the executor side. + auto dynamic_return_id2 = ObjectID::FromIndex(spec.TaskId(), 3); + data = GenerateRandomBuffer(); + req = GetIntermediateTaskReturn( + /*idx*/ 1, + /*finished*/ false, + generator_id, + /*dynamic_return_id*/ dynamic_return_id2, + /*data*/ data, + /*set_in_plasma*/ false); + manager_.HandleReportIntermediateTaskReturn(req); + // EoF + manager_.HandleReportIntermediateTaskReturn(GetEoFTaskReturn(2, generator_id)); + + // NumObjectIDsInScope == Generator + 2 intermediate result. + ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3); + results.clear(); + RAY_CHECK_OK(store_->Get({dynamic_return_id2}, 1, 1, ctx, false, &results)); + ASSERT_EQ(results.size(), 1); + + // Make sure you can read. + status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(obj_id, dynamic_return_id2); + + // Nothing more to read. + status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + ASSERT_TRUE(status.IsObjectRefStreamEoF()); + + manager_.DelObjectRefStream(generator_id); +} + +TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferences) { + /** + * Verify DEL cleans all references and ignore all future WRITE. + * + * CREATE WRITE WRITE DEL (make sure no refs are leaked) + */ + // Submit a task so that generator ID will be available + // to the reference counter. + rpc::Address caller_address; + auto spec = CreateTaskHelper(1, {}, /*dynamic_returns=*/true); + auto generator_id = spec.ReturnId(0); + manager_.AddPendingTask(caller_address, spec, "", /*num_retries=*/0); + manager_.MarkDependenciesResolved(spec.TaskId()); + manager_.MarkTaskWaitingForExecution( + spec.TaskId(), NodeID::FromRandom(), WorkerID::FromRandom()); + RAY_LOG(ERROR) << "SANG-TODO 0"; + // CREATE + manager_.CreateObjectRefStream(generator_id); + // WRITE + auto dynamic_return_id = ObjectID::FromIndex(spec.TaskId(), 2); + auto data = GenerateRandomBuffer(); + auto req = GetIntermediateTaskReturn( + /*idx*/ 0, + /*finished*/ false, + generator_id, + /*dynamic_return_id*/ dynamic_return_id, + /*data*/ data, + /*set_in_plasma*/ false); + manager_.HandleReportIntermediateTaskReturn(req); + // WRITE 2 + auto dynamic_return_id2 = ObjectID::FromIndex(spec.TaskId(), 3); + data = GenerateRandomBuffer(); + req = GetIntermediateTaskReturn( + /*idx*/ 1, + /*finished*/ false, + generator_id, + /*dynamic_return_id*/ dynamic_return_id2, + /*data*/ data, + /*set_in_plasma*/ false); + manager_.HandleReportIntermediateTaskReturn(req); + RAY_LOG(ERROR) << "SANG-TODO 1"; + // NumObjectIDsInScope == Generator + 2 WRITE + ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3); + std::vector> results; + WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); + RAY_CHECK_OK(store_->Get({dynamic_return_id}, 1, 1, ctx, false, &results)); + ASSERT_EQ(results.size(), 1); + results.clear(); + RAY_CHECK_OK(store_->Get({dynamic_return_id2}, 1, 1, ctx, false, &results)); + ASSERT_EQ(results.size(), 1); + results.clear(); + RAY_LOG(ERROR) << "SANG-TODO 2"; + // DELETE. This should clean all references except generator id. + manager_.DelObjectRefStream(generator_id); + ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1); + // Unfortunately, when the obj ref goes out of scope, + // this is called from the language frontend. We mimic this behavior + // by manually calling these APIs. + store_->Delete({dynamic_return_id}); + store_->Delete({dynamic_return_id2}); + ASSERT_TRUE(store_->Get({dynamic_return_id}, 1, 1, ctx, false, &results).IsTimedOut()); + results.clear(); + ASSERT_TRUE(store_->Get({dynamic_return_id2}, 1, 1, ctx, false, &results).IsTimedOut()); + results.clear(); + + // NOTE: We panic if READ is called after DELETE. The + // API caller should guarantee this doesn't happen. + // So we don't test it. + RAY_LOG(ERROR) << "SANG-TODO 3"; + // WRITE 3. Should be ignored. + auto dynamic_return_id3 = ObjectID::FromIndex(spec.TaskId(), 4); + data = GenerateRandomBuffer(); + req = GetIntermediateTaskReturn( + /*idx*/ 2, + /*finished*/ false, + generator_id, + /*dynamic_return_id*/ dynamic_return_id3, + /*data*/ data, + /*set_in_plasma*/ false); + manager_.HandleReportIntermediateTaskReturn(req); + // The write should have been no op. No refs and no obj values except the generator id. + ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1); + ASSERT_TRUE(store_->Get({dynamic_return_id3}, 1, 1, ctx, false, &results).IsTimedOut()); + results.clear(); + RAY_LOG(ERROR) << "SANG-TODO 4"; + // Finish the task. + // This is needed to pass AssertNoLeaks. + rpc::PushTaskReply reply; + auto return_object = reply.add_return_objects(); + return_object->set_object_id(generator_id.Binary()); + data = GenerateRandomBuffer(); + return_object->set_data(data->Data(), data->Size()); + manager_.CompletePendingTask(spec.TaskId(), reply, caller_address, false); +} + +TEST_F(TaskManagerTest, TestObjectRefStreamOutofOrder) { + /** + * Test the case where the task return RPC is received out of order + */ + auto spec = + CreateTaskHelper(1, {}, /*dynamic_returns=*/true, /*streaming_generator*/ true); + auto generator_id = spec.ReturnId(0); + // CREATE + manager_.CreateObjectRefStream(generator_id); + + auto last_idx = 2; + std::vector dynamic_return_ids; + // EoF reported first. + manager_.HandleReportIntermediateTaskReturn(GetEoFTaskReturn(last_idx, generator_id)); + + // Write index 1 -> 0 + for (auto i = last_idx - 1; i > -1; i--) { + auto dynamic_return_id = ObjectID::FromIndex(spec.TaskId(), i + 2); + dynamic_return_ids.insert(dynamic_return_ids.begin(), dynamic_return_id); + auto data = GenerateRandomBuffer(); + + auto req = GetIntermediateTaskReturn( + /*idx*/ i, + /*finished*/ false, + generator_id, + /*dynamic_return_id*/ dynamic_return_id, + /*data*/ data, + /*set_in_plasma*/ false); + // WRITE * 2 + manager_.HandleReportIntermediateTaskReturn(req); + } + + // Verify read works. + ObjectID obj_id; + for (auto i = 0; i < last_idx; i++) { + auto status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(obj_id, dynamic_return_ids[i]); + } + + // READ (EoF) + auto status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + ASSERT_TRUE(status.IsObjectRefStreamEoF()); + ASSERT_EQ(obj_id, ObjectID::Nil()); + // DELETE + manager_.DelObjectRefStream(generator_id); +} + } // namespace core } // namespace ray From 05f468a3cc79b52512bb375a73bfb223ef3f8c21 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sun, 14 May 2023 10:07:27 -0700 Subject: [PATCH 08/77] [Please Revert] Work e2e. Signed-off-by: SangBin Cho --- .../runtime/task/local_mode_task_submitter.cc | 1 - cpp/src/ray/runtime/task/task_executor.cc | 3 +- cpp/src/ray/runtime/task/task_executor.h | 3 +- python/ray/_private/ray_option_utils.py | 2 +- python/ray/_private/worker.py | 9 - python/ray/_private/workers/default_worker.py | 1 + python/ray/_raylet.pxd | 1 - python/ray/_raylet.pyx | 232 ++------------- python/ray/actor.py | 12 +- python/ray/includes/libcoreworker.pxd | 5 +- python/ray/remote_function.py | 12 +- python/ray/tests/test_generators.py | 44 +-- python/ray/tests/test_streaming_generator.py | 268 ------------------ src/ray/common/task/task_spec.cc | 6 - src/ray/common/task/task_spec.h | 2 - src/ray/common/task/task_util.h | 2 - src/ray/core_worker/core_worker.cc | 42 +-- src/ray/core_worker/core_worker.h | 4 +- src/ray/core_worker/core_worker_options.h | 5 +- .../java/io_ray_runtime_RayNativeRuntime.cc | 3 +- src/ray/core_worker/test/core_worker_test.cc | 1 - .../test/dependency_resolver_test.cc | 1 - .../test/direct_task_transport_test.cc | 1 - src/ray/core_worker/test/mock_worker.cc | 3 +- src/ray/gcs/test/gcs_test_util.h | 1 - src/ray/protobuf/common.proto | 6 - .../scheduling/cluster_task_manager_test.cc | 1 - 27 files changed, 49 insertions(+), 622 deletions(-) diff --git a/cpp/src/ray/runtime/task/local_mode_task_submitter.cc b/cpp/src/ray/runtime/task/local_mode_task_submitter.cc index 6052531e1211..145e8130fe15 100644 --- a/cpp/src/ray/runtime/task/local_mode_task_submitter.cc +++ b/cpp/src/ray/runtime/task/local_mode_task_submitter.cc @@ -61,7 +61,6 @@ ObjectID LocalModeTaskSubmitter::Submit(InvocationSpec &invocation, address, 1, /*returns_dynamic=*/false, - /*is_streaming_generator*/ false, required_resources, required_placement_resources, "", diff --git a/cpp/src/ray/runtime/task/task_executor.cc b/cpp/src/ray/runtime/task/task_executor.cc index 97d67c760279..ca4aae05fd7e 100644 --- a/cpp/src/ray/runtime/task/task_executor.cc +++ b/cpp/src/ray/runtime/task/task_executor.cc @@ -135,8 +135,7 @@ Status TaskExecutor::ExecuteTask( std::string *application_error, const std::vector &defined_concurrency_groups, const std::string name_of_concurrency_group_to_execute, - bool is_reattempt, - bool is_streaming_generator) { + bool is_reattempt) { RAY_LOG(DEBUG) << "Execute task type: " << TaskType_Name(task_type) << " name:" << task_name; RAY_CHECK(ray_function.GetLanguage() == ray::Language::CPP); diff --git a/cpp/src/ray/runtime/task/task_executor.h b/cpp/src/ray/runtime/task/task_executor.h index 4ec3df555de9..4ce2f6009e7e 100644 --- a/cpp/src/ray/runtime/task/task_executor.h +++ b/cpp/src/ray/runtime/task/task_executor.h @@ -91,8 +91,7 @@ class TaskExecutor { std::string *application_error, const std::vector &defined_concurrency_groups, const std::string name_of_concurrency_group_to_execute, - bool is_reattempt, - bool is_streaming_generator); + bool is_reattempt); virtual ~TaskExecutor(){}; diff --git a/python/ray/_private/ray_option_utils.py b/python/ray/_private/ray_option_utils.py index 97c35f9449ca..f433fc3f153a 100644 --- a/python/ray/_private/ray_option_utils.py +++ b/python/ray/_private/ray_option_utils.py @@ -154,7 +154,7 @@ def issubclass_safe(obj: Any, cls_: type) -> bool: "num_returns": Option( (int, str, type(None)), lambda x: None - if (x is None or x == "dynamic" or x == "streaming" or x >= 0) + if (x is None or x == "dynamic") else "The keyword 'num_returns' only accepts None, a non-negative integer, or " '"dynamic" (for generators)', default_value=1, diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index e45c33eb3027..1bb275a2312e 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -2499,11 +2499,6 @@ def get( blocking_get_inside_async_warned = True with profiling.profile("ray.get"): - # TODO(sang): Should make StreamingObjectRefGenerator - # compatible to ray.get for dataset. - if isinstance(object_refs, ray._raylet.StreamingObjectRefGenerator): - return object_refs - is_individual_id = isinstance(object_refs, ray.ObjectRef) if is_individual_id: object_refs = [object_refs] @@ -2822,10 +2817,6 @@ def cancel(object_ref: "ray.ObjectRef", *, force: bool = False, recursive: bool worker = ray._private.worker.global_worker worker.check_connected() - if isinstance(object_ref, ray._raylet.StreamingObjectRefGenerator): - assert hasattr(object_ref, "_generator_ref") - object_ref = object_ref._generator_ref - if not isinstance(object_ref, ray.ObjectRef): raise TypeError( "ray.cancel() only supported for non-actor object refs. " diff --git a/python/ray/_private/workers/default_worker.py b/python/ray/_private/workers/default_worker.py index 462c9e284f49..937f45a8b85d 100644 --- a/python/ray/_private/workers/default_worker.py +++ b/python/ray/_private/workers/default_worker.py @@ -169,6 +169,7 @@ # https://github.com/ray-project/ray/pull/12225#issue-525059663. args = parser.parse_args() ray._private.ray_logging.setup_logger(args.logging_level, args.logging_format) + worker_launched_time_ms = time.time_ns() // 1e6 if args.worker_type == "WORKER": diff --git a/python/ray/_raylet.pxd b/python/ray/_raylet.pxd index 28a7632ed8c1..6af1879a5d8a 100644 --- a/python/ray/_raylet.pxd +++ b/python/ray/_raylet.pxd @@ -143,7 +143,6 @@ cdef class CoreWorker: self, worker, outputs, c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns, - const CAddress &caller_address, CObjectID ref_generator_id=*) cdef yield_current_fiber(self, CFiberEvent &fiber_event) cdef make_actor_handle(self, ActorHandleSharedPtr c_actor_handle) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 6762ade4578e..89773c560aaf 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -741,8 +741,7 @@ cdef store_task_errors( CTaskType task_type, proctitle, c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns, - c_string* application_error, - const CAddress &caller_address): + c_string* application_error): cdef: CoreWorker core_worker = worker.core_worker @@ -786,8 +785,7 @@ cdef store_task_errors( errors.append(failure_object) num_errors_stored = core_worker.store_task_outputs( worker, errors, - returns, - caller_address) + returns) ray._private.utils.push_error_to_driver( worker, @@ -799,160 +797,6 @@ cdef store_task_errors( return num_errors_stored -cdef execute_streaming_generator( - generator, - const CObjectID &generator_id, - CTaskType task_type, - const CAddress &caller_address, - TaskID task_id, - const c_string &serialized_retry_exception_allowlist, - function_name, - function_descriptor, - title, - actor, - actor_id, - c_bool *is_retryable_error, - c_string *application_error): - """Execute a given generator and streaming-report the - result to the given caller_address. - - The output from the generator will be stored to the in-memory - or plasma object store. The generated return objects will be - reported to the owner of the task as soon as they are generated. - - It means when this method is used, the result of each generator - will be reported and available from the given "caller address" - before the task is finished. - - Args: - generator: The generator to run. - generator_id: The object ref id of the generator task. - task_type: The type of the task. E.g., actor task, normal task. - caller_address: The address of the caller. By our protocol, - the caller of the streaming generator task is always - the owner, so we can also call it "owner address". - task_id: The task ID of the generator task. - serialized_retry_exception_allowlist: A list of - exceptions that are allowed to retry this generator task. - function_name: The name of the generator function. Used for - writing an error message. - function_descriptor: The function descriptor of - the generator function. Used for writing an error message. - title: The process title of the generator task. Used for - writing an error message. - actor: The instance of the actor created in this worker. - It is used to write an error message. - actor_id: The ID of the actor. It is used to write an error message. - is_retryable_error(out): It is set to True if the generator - raises an exception, and the error is retryable. - application_error(out): It is set if the generator raises an - application error. - """ - worker = ray._private.worker.global_worker - cdef: - CoreWorker core_worker = worker.core_worker - c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] intermediate_result - - generator_index = 0 - assert inspect.isgenerator(generator), ( - "execute_generator's first argument must be a generator." - ) - - while True: - try: - output = next(generator) - except StopIteration: - break - except Exception as e: - # Report the error if the generator failed to execute. - is_retryable_error[0] = determine_if_retryable( - e, - serialized_retry_exception_allowlist, - function_descriptor, - ) - - if ( - is_retryable_error[0] - and core_worker.get_current_task_retry_exceptions() - ): - logger.debug("Task failed with retryable exception:" - " {}.".format(task_id), exc_info=True) - # Raise an exception directly and halt the execution - # because there's no need to set the exception - # for the return value when the task is retryable. - raise e - - logger.debug("Task failed with unretryable exception:" - " {}.".format(task_id), exc_info=True) - - error_id = (CCoreWorkerProcess.GetCoreWorker() - .AllocateDynamicReturnId(caller_address)) - intermediate_result.push_back( - c_pair[CObjectID, shared_ptr[CRayObject]]( - error_id, shared_ptr[CRayObject]())) - - store_task_errors( - worker, e, - True, # task_exception - actor, # actor - actor_id, # actor id - function_name, task_type, title, - &intermediate_result, application_error, caller_address) - - CCoreWorkerProcess.GetCoreWorker().ReportIntermediateTaskReturn( - intermediate_result.back(), - generator_id, caller_address, generator_index, False) - - if intermediate_result.size() > 0: - intermediate_result.pop_back() - generator_index += 1 - break - else: - # Report the intermediate result if there was no error. - return_id = ( - CCoreWorkerProcess.GetCoreWorker().AllocateDynamicReturnId( - caller_address)) - intermediate_result.push_back( - c_pair[CObjectID, shared_ptr[CRayObject]]( - return_id, shared_ptr[CRayObject]())) - - core_worker.store_task_outputs( - worker, [output], - &intermediate_result, - caller_address, - generator_id) - logger.debug( - "Writes to a ObjectRefStream of an " - "index {}".format(generator_index)) - assert intermediate_result.size() == 1 - del output - - CCoreWorkerProcess.GetCoreWorker().ReportIntermediateTaskReturn( - intermediate_result.back(), - generator_id, - caller_address, - generator_index, - False) - - if intermediate_result.size() > 0: - intermediate_result.pop_back() - generator_index += 1 - - # All the intermediate result has to be popped and reported. - assert intermediate_result.size() == 0 - # Report the owner that there's no more objects. - logger.debug( - "Writes EoF to a ObjectRefStream " - "of an index {}".format(generator_index)) - CCoreWorkerProcess.GetCoreWorker().ReportIntermediateTaskReturn( - c_pair[CObjectID, shared_ptr[CRayObject]]( - CObjectID.Nil(), shared_ptr[CRayObject]()), - generator_id, - caller_address, - generator_index, - True) # finished. - - cdef execute_dynamic_generator_and_store_task_outputs( generator, const CObjectID &generator_id, @@ -964,8 +808,7 @@ cdef execute_dynamic_generator_and_store_task_outputs( c_bool is_reattempt, function_name, function_descriptor, - title, - const CAddress &caller_address): + title): worker = ray._private.worker.global_worker cdef: CoreWorker core_worker = worker.core_worker @@ -974,7 +817,6 @@ cdef execute_dynamic_generator_and_store_task_outputs( core_worker.store_task_outputs( worker, generator, dynamic_returns, - caller_address, generator_id) except Exception as error: is_retryable_error[0] = determine_if_retryable( @@ -1002,7 +844,7 @@ cdef execute_dynamic_generator_and_store_task_outputs( # generate one additional ObjectRef. This last # ObjectRef will contain the error. error_id = (CCoreWorkerProcess.GetCoreWorker() - .AllocateDynamicReturnId(caller_address)) + .AllocateDynamicReturnId()) dynamic_returns[0].push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( error_id, shared_ptr[CRayObject]())) @@ -1016,7 +858,7 @@ cdef execute_dynamic_generator_and_store_task_outputs( None, # actor None, # actor id function_name, task_type, title, - dynamic_returns, application_error, caller_address) + dynamic_returns, application_error) if num_errors_stored == 0: assert is_reattempt # TODO(swang): The generator task failed and we @@ -1053,8 +895,7 @@ cdef void execute_task( c_bool is_reattempt, execution_info, title, - task_name, - c_bool is_streaming_generator) except *: + task_name) except *: worker = ray._private.worker.global_worker manager = worker.function_actor_manager actor = None @@ -1212,35 +1053,6 @@ cdef void execute_task( ray.util.pdb.set_trace( breakpoint_uuid=debugger_breakpoint) outputs = function_executor(*args, **kwargs) - - if is_streaming_generator: - # Streaming generator always has a single return value - # which is the generator task return. - assert returns[0].size() == 1 - - if not inspect.isgenerator(outputs): - raise ValueError( - "Functions with " - "@ray.remote(num_returns=\"streaming\" " - "must return a generator") - - execute_streaming_generator( - outputs, - returns[0][0].first, # generator object ID. - task_type, - caller_address, - task_id, - serialized_retry_exception_allowlist, - function_name, - function_descriptor, - title, - actor, - actor_id, - is_retryable_error, - application_error) - # Streaming generator output is not used, so set it to None. - outputs = None - next_breakpoint = ( ray._private.worker.global_worker.debugger_breakpoint) if next_breakpoint != b"": @@ -1325,9 +1137,7 @@ cdef void execute_task( # Store the outputs in the object store. with core_worker.profile_event(b"task:store_outputs"): - # TODO(sang): Remove it once we use streaming generator - # by default. - if dynamic_returns != NULL and not is_streaming_generator: + if dynamic_returns != NULL: if not inspect.isgenerator(outputs): raise ValueError( "Functions with " @@ -1346,8 +1156,7 @@ cdef void execute_task( is_reattempt, function_name, function_descriptor, - title, - caller_address) + title) task_exception = False dynamic_refs = [] @@ -1365,12 +1174,11 @@ cdef void execute_task( # all generator tasks, both static and dynamic. core_worker.store_task_outputs( worker, outputs, - returns, - caller_address) + returns) except Exception as e: num_errors_stored = store_task_errors( worker, e, task_exception, actor, actor_id, function_name, - task_type, title, returns, application_error, caller_address) + task_type, title, returns, application_error) if returns[0].size() > 0 and num_errors_stored == 0: logger.exception( "Unhandled error: Task threw exception, but all " @@ -1397,8 +1205,7 @@ cdef execute_task_with_cancellation_handler( # the concurrency groups of this actor. const c_vector[CConcurrencyGroup] &c_defined_concurrency_groups, const c_string c_name_of_concurrency_group_to_execute, - c_bool is_reattempt, - c_bool is_streaming_generator): + c_bool is_reattempt): is_retryable_error[0] = False @@ -1483,8 +1290,7 @@ cdef execute_task_with_cancellation_handler( application_error, c_defined_concurrency_groups, c_name_of_concurrency_group_to_execute, - is_reattempt, execution_info, title, task_name, - is_streaming_generator) + is_reattempt, execution_info, title, task_name) # Check for cancellation. PyErr_CheckSignals() @@ -1511,8 +1317,7 @@ cdef execute_task_with_cancellation_handler( task_type, title, returns, # application_error: we are passing NULL since we don't want the # cancel tasks to fail. - NULL, - caller_address) + NULL) finally: with current_task_id_lock: current_task_id = None @@ -1557,8 +1362,7 @@ cdef CRayStatus task_execution_handler( c_string *application_error, const c_vector[CConcurrencyGroup] &defined_concurrency_groups, const c_string name_of_concurrency_group_to_execute, - c_bool is_reattempt, - c_bool is_streaming_generator) nogil: + c_bool is_reattempt) nogil: with gil, disable_client_hook(): # Initialize job_config if it hasn't already. # Setup system paths configured in job_config. @@ -1582,8 +1386,7 @@ cdef CRayStatus task_execution_handler( application_error, defined_concurrency_groups, name_of_concurrency_group_to_execute, - is_reattempt, - is_streaming_generator) + is_reattempt) except Exception as e: sys_exit = SystemExit() if isinstance(e, RayActorError) and \ @@ -3140,7 +2943,6 @@ cdef class CoreWorker: worker, outputs, c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns, - const CAddress &caller_address, CObjectID ref_generator_id=CObjectID.Nil()): cdef: CObjectID return_id @@ -3180,11 +2982,9 @@ cdef class CoreWorker: raise ValueError( "Task returned more than num_returns={} objects.".format( num_returns)) - # TODO(sang): Remove it when the streaming generator is - # enabled by default. while i >= returns[0].size(): return_id = (CCoreWorkerProcess.GetCoreWorker() - .AllocateDynamicReturnId(caller_address)) + .AllocateDynamicReturnId()) returns[0].push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( return_id, shared_ptr[CRayObject]())) diff --git a/python/ray/actor.py b/python/ray/actor.py index 91b88de7b947..7191031e059b 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -22,7 +22,7 @@ ) from ray._private.ray_option_utils import _warn_if_using_deprecated_placement_group from ray._private.utils import get_runtime_env_info, parse_runtime_env -from ray._raylet import PythonFunctionDescriptor, StreamingObjectRefGenerator +from ray._raylet import PythonFunctionDescriptor from ray.exceptions import AsyncioActorExit from ray.util.annotations import DeveloperAPI, PublicAPI from ray.util.placement_group import _configure_placement_group_based_on_context @@ -1167,10 +1167,6 @@ def _actor_method_call( if num_returns == "dynamic": num_returns = -1 - elif num_returns == "streaming": - # TODO(sang): This is a temporary private API. - # Remove it when we migrate to the streaming generator. - num_returns = -2 object_refs = worker.core_worker.submit_actor_task( self._ray_actor_language, @@ -1183,12 +1179,6 @@ def _actor_method_call( concurrency_group_name if concurrency_group_name is not None else b"", ) - if num_returns == -2: - # Streaming generator will return a single ref - # that is for the generator task. - assert len(object_refs) == 1 - generator_ref = object_refs[0] - return StreamingObjectRefGenerator(generator_ref, worker) if len(object_refs) == 1: object_refs = object_refs[0] elif len(object_refs) == 0: diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index b8a5f14f9d6b..8dac68ea651e 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -151,7 +151,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: CRayStatus AsyncReadObjectRefStream( const CObjectID &generator_id, CObjectReference *object_ref_out) - CObjectID AllocateDynamicReturnId(const CAddress &owner_address) + CObjectID AllocateDynamicReturnId() CJobID GetCurrentJobId() CTaskID GetCurrentTaskId() @@ -315,8 +315,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: c_string *application_error, const c_vector[CConcurrencyGroup] &defined_concurrency_groups, const c_string name_of_concurrency_group_to_execute, - c_bool is_reattempt, - c_bool is_streaming_generator) nogil + c_bool is_reattempt) nogil ) task_execution_callback (void(const CWorkerID &) nogil) on_worker_shutdown (CRayStatus() nogil) check_signals diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index bb627f09af92..79853deff098 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -15,7 +15,7 @@ from ray._private.ray_option_utils import _warn_if_using_deprecated_placement_group from ray._private.serialization import pickle_dumps from ray._private.utils import get_runtime_env_info, parse_runtime_env -from ray._raylet import PythonFunctionDescriptor, StreamingObjectRefGenerator +from ray._raylet import PythonFunctionDescriptor from ray.util.annotations import DeveloperAPI, PublicAPI from ray.util.placement_group import _configure_placement_group_based_on_context from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -306,10 +306,6 @@ def _remote(self, args=None, kwargs=None, **task_options): num_returns = task_options["num_returns"] if num_returns == "dynamic": num_returns = -1 - elif num_returns == "streaming": - # TODO(sang): This is a temporary private API. - # Remove it when we migrate to the streaming generator. - num_returns = -2 max_retries = task_options["max_retries"] retry_exceptions = task_options["retry_exceptions"] @@ -401,12 +397,6 @@ def invocation(args, kwargs): # Reset worker's debug context from the last "remote" command # (which applies only to this .remote call). worker.debugger_breakpoint = b"" - if num_returns == -2: - # Streaming generator will return a single ref - # that is for the generator task. - assert len(object_refs) == 1 - generator_ref = object_refs[0] - return StreamingObjectRefGenerator(generator_ref, worker) if len(object_refs) == 1: return object_refs[0] elif len(object_refs) > 1: diff --git a/python/ray/tests/test_generators.py b/python/ray/tests/test_generators.py index 3430da39cda2..9284c6a3f8c3 100644 --- a/python/ray/tests/test_generators.py +++ b/python/ray/tests/test_generators.py @@ -117,10 +117,7 @@ def generator(num_returns, store_in_plasma): @pytest.mark.parametrize("use_actors", [False, True]) @pytest.mark.parametrize("store_in_plasma", [False, True]) -@pytest.mark.parametrize("num_returns_type", ["dynamic", "streaming"]) -def test_generator_errors( - ray_start_regular, use_actors, store_in_plasma, num_returns_type -): +def test_generator_errors(ray_start_regular, use_actors, store_in_plasma): remote_generator_fn = None if use_actors: @@ -161,7 +158,7 @@ def generator(num_returns, store_in_plasma): with pytest.raises(ray.exceptions.RayTaskError): ray.get(ref3) - dynamic_ref = remote_generator_fn.options(num_returns=num_returns_type).remote( + dynamic_ref = remote_generator_fn.options(num_returns="dynamic").remote( 3, store_in_plasma ) ref1, ref2 = ray.get(dynamic_ref) @@ -221,13 +218,10 @@ def generator(num_returns, store_in_plasma, counter): @pytest.mark.parametrize("use_actors", [False, True]) @pytest.mark.parametrize("store_in_plasma", [False, True]) -@pytest.mark.parametrize("num_returns_type", ["streaming"]) -def test_dynamic_generator( - ray_start_regular, use_actors, store_in_plasma, num_returns_type -): +def test_dynamic_generator(ray_start_regular, use_actors, store_in_plasma): if use_actors: - @ray.remote(num_returns=num_returns_type) + @ray.remote(num_returns="dynamic") def dynamic_generator(num_returns, store_in_plasma): for i in range(num_returns): if store_in_plasma: @@ -261,34 +255,21 @@ def read(gen): return True gen = ray.get( - remote_generator_fn.options(num_returns=num_returns_type).remote( - 10, store_in_plasma - ) + remote_generator_fn.options(num_returns="dynamic").remote(10, store_in_plasma) ) for i, ref in enumerate(gen): assert ray.get(ref)[0] == i # Test empty generator. gen = ray.get( - remote_generator_fn.options(num_returns=num_returns_type).remote( - 0, store_in_plasma - ) + remote_generator_fn.options(num_returns="dynamic").remote(0, store_in_plasma) ) assert len(list(gen)) == 0 # Check that passing as task arg. - if num_returns_type == "dynamic": - gen = remote_generator_fn.options(num_returns=num_returns_type).remote( - 10, store_in_plasma - ) - assert ray.get(read.remote(gen)) - assert ray.get(read.remote(ray.get(gen))) - else: - with pytest.raises(TypeError): - gen = remote_generator_fn.options(num_returns=num_returns_type).remote( - 10, store_in_plasma - ) - assert ray.get(read.remote(gen)) + gen = remote_generator_fn.options(num_returns="dynamic").remote(10, store_in_plasma) + assert ray.get(read.remote(gen)) + assert ray.get(read.remote(ray.get(gen))) # Also works if we override num_returns with a static value. ray.get( @@ -298,7 +279,7 @@ def read(gen): ) # Normal remote functions don't work with num_returns="dynamic". - @ray.remote(num_returns=num_returns_type) + @ray.remote(num_returns="dynamic") def static(num_returns): return list(range(num_returns)) @@ -308,8 +289,7 @@ def static(num_returns): ray.get(ref) -@pytest.mark.parametrize("num_returns_type", ["dynamic", "streaming"]) -def test_dynamic_generator_distributed(ray_start_cluster, num_returns_type): +def test_dynamic_generator_distributed(ray_start_cluster): cluster = ray_start_cluster # Head node with no resources. cluster.add_node(num_cpus=0) @@ -317,7 +297,7 @@ def test_dynamic_generator_distributed(ray_start_cluster, num_returns_type): cluster.add_node(num_cpus=1) cluster.wait_for_nodes() - @ray.remote(num_returns=num_returns_type) + @ray.remote(num_returns="dynamic") def dynamic_generator(num_returns): for i in range(num_returns): yield np.ones(1_000_000, dtype=np.int8) * i diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py index 277d8226cb50..c496d52b6179 100644 --- a/python/ray/tests/test_streaming_generator.py +++ b/python/ray/tests/test_streaming_generator.py @@ -1,14 +1,10 @@ import pytest -import numpy as np import sys import time -import gc from unittest.mock import patch, Mock import ray -from ray._private.test_utils import wait_for_condition -from ray.experimental.state.api import list_objects from ray._raylet import StreamingObjectRefGenerator from ray.cloudpickle import dumps from ray.exceptions import ObjectRefStreamEoFError, WorkerCrashedError @@ -136,270 +132,6 @@ def test_streaming_object_ref_generator_network_failed_unit(mocked_worker): generator._next(timeout_s=0, unexpected_network_failure_timeout_s=1) -def test_generator_basic(shutdown_only): - ray.init(num_cpus=1) - - """Basic cases""" - - @ray.remote - def f(): - for i in range(5): - yield i - - gen = f.options(num_returns="streaming").remote() - i = 0 - for ref in gen: - print(ray.get(ref)) - assert i == ray.get(ref) - del ref - i += 1 - - """Exceptions""" - - @ray.remote - def f(): - for i in range(5): - if i == 2: - raise ValueError - yield i - - gen = f.options(num_returns="streaming").remote() - ray.get(next(gen)) - ray.get(next(gen)) - with pytest.raises(ray.exceptions.RayTaskError) as e: - ray.get(next(gen)) - print(str(e.value)) - with pytest.raises(StopIteration): - ray.get(next(gen)) - with pytest.raises(StopIteration): - ray.get(next(gen)) - - """Generator Task failure""" - - @ray.remote - class A: - def getpid(self): - import os - - return os.getpid() - - def f(self): - for i in range(5): - time.sleep(0.1) - yield i - - a = A.remote() - i = 0 - gen = a.f.options(num_returns="streaming").remote() - i = 0 - for ref in gen: - if i == 2: - ray.kill(a) - if i == 3: - with pytest.raises(ray.exceptions.RayActorError) as e: - ray.get(ref) - assert "The actor is dead because it was killed by `ray.kill`" in str( - e.value - ) - break - assert i == ray.get(ref) - del ref - i += 1 - for _ in range(10): - with pytest.raises(StopIteration): - next(gen) - - """Retry exceptions""" - # TODO(sang): Enable it once retry is supported. - # @ray.remote - # class Actor: - # def __init__(self): - # self.should_kill = True - - # def should_kill(self): - # return self.should_kill - - # async def set(self, wait_s): - # await asyncio.sleep(wait_s) - # self.should_kill = False - - # @ray.remote(retry_exceptions=[ValueError], max_retries=10) - # def f(a): - # for i in range(5): - # should_kill = ray.get(a.should_kill.remote()) - # if i == 3 and should_kill: - # raise ValueError - # yield i - - # a = Actor.remote() - # gen = f.options(num_returns="streaming").remote(a) - # assert ray.get(next(gen)) == 0 - # assert ray.get(next(gen)) == 1 - # assert ray.get(next(gen)) == 2 - # a.set.remote(3) - # assert ray.get(next(gen)) == 3 - # assert ray.get(next(gen)) == 4 - # with pytest.raises(StopIteration): - # ray.get(next(gen)) - - """Cancel""" - - @ray.remote - def f(): - for i in range(5): - time.sleep(5) - yield i - - gen = f.options(num_returns="streaming").remote() - assert ray.get(next(gen)) == 0 - ray.cancel(gen) - with pytest.raises(ray.exceptions.RayTaskError) as e: - assert ray.get(next(gen)) == 1 - assert "was cancelled" in str(e.value) - with pytest.raises(StopIteration): - next(gen) - - -@pytest.mark.parametrize("crash_type", ["exception", "worker_crash"]) -def test_generator_streaming_no_leak_upon_failures( - monkeypatch, shutdown_only, crash_type -): - with monkeypatch.context() as m: - # defer for 10s for the second node. - m.setenv( - "RAY_testing_asio_delay_us", - "CoreWorkerService.grpc_server.ReportIntermediateTaskReturn=100000:1000000", - ) - ray.init(num_cpus=1) - - @ray.remote - def g(): - try: - gen = f.options(num_returns="streaming").remote() - for ref in gen: - print(ref) - ray.get(ref) - except Exception: - print("exception!") - del ref - - del gen - gc.collect() - - # Only the ref g is alive. - def verify(): - print(list_objects()) - return len(list_objects()) == 1 - - wait_for_condition(verify) - return True - - @ray.remote - def f(): - for i in range(10): - time.sleep(0.2) - if i == 4: - if crash_type == "exception": - raise ValueError - else: - sys.exit(9) - yield 2 - - for _ in range(5): - ray.get(g.remote()) - - -@pytest.mark.parametrize("use_actors", [False, True]) -@pytest.mark.parametrize("store_in_plasma", [False, True]) -def test_generator_streaming(shutdown_only, use_actors, store_in_plasma): - """Verify the generator is working in a streaming fashion.""" - ray.init() - remote_generator_fn = None - if use_actors: - - @ray.remote - class Generator: - def __init__(self): - pass - - def generator(self, num_returns, store_in_plasma): - for i in range(num_returns): - if store_in_plasma: - yield np.ones(1_000_000, dtype=np.int8) * i - else: - yield [i] - - g = Generator.remote() - remote_generator_fn = g.generator - else: - - @ray.remote(max_retries=0) - def generator(num_returns, store_in_plasma): - for i in range(num_returns): - if store_in_plasma: - yield np.ones(1_000_000, dtype=np.int8) * i - else: - yield [i] - - remote_generator_fn = generator - - """Verify num_returns="streaming" is streaming""" - gen = remote_generator_fn.options(num_returns="streaming").remote( - 3, store_in_plasma - ) - i = 0 - for ref in gen: - id = ref.hex() - if store_in_plasma: - expected = np.ones(1_000_000, dtype=np.int8) * i - assert np.array_equal(ray.get(ref), expected) - else: - expected = [i] - assert ray.get(ref) == expected - - del ref - - wait_for_condition( - lambda: len(list_objects(filters=[("object_id", "=", id)])) == 0 - ) - i += 1 - - -def test_generator_dist_chain(ray_start_cluster): - cluster = ray_start_cluster - cluster.add_node(num_cpus=0, object_store_memory=1 * 1024 * 1024 * 1024) - ray.init() - cluster.add_node(num_cpus=1) - cluster.add_node(num_cpus=1) - cluster.add_node(num_cpus=1) - cluster.add_node(num_cpus=1) - - @ray.remote - class ChainActor: - def __init__(self, child=None): - self.child = child - - def get_data(self): - if not self.child: - for _ in range(10): - time.sleep(0.1) - yield np.ones(5 * 1024 * 1024) - else: - for data in self.child.get_data.options( - num_returns="streaming" - ).remote(): - yield ray.get(data) - - chain_actor = ChainActor.remote() - chain_actor_2 = ChainActor.remote(chain_actor) - chain_actor_3 = ChainActor.remote(chain_actor_2) - chain_actor_4 = ChainActor.remote(chain_actor_3) - - for ref in chain_actor_4.get_data.options(num_returns="streaming").remote(): - assert np.array_equal(np.ones(5 * 1024 * 1024), ray.get(ref)) - del ref - - if __name__ == "__main__": import os diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index 11e4778b297e..71000748cb44 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -218,12 +218,6 @@ ObjectID TaskSpecification::ReturnId(size_t return_index) const { bool TaskSpecification::ReturnsDynamic() const { return message_->returns_dynamic(); } -// TODO(sang): Merge this with ReturnsDynamic once migrating to the -// streaming generator. -bool TaskSpecification::IsStreamingGenerator() const { - return message_->streaming_generator(); -} - std::vector TaskSpecification::DynamicReturnIds() const { RAY_CHECK(message_->returns_dynamic()); std::vector dynamic_return_ids; diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h index eea53f3d0348..3b29d2aadb3b 100644 --- a/src/ray/common/task/task_spec.h +++ b/src/ray/common/task/task_spec.h @@ -262,8 +262,6 @@ class TaskSpecification : public MessageWrapper { bool ReturnsDynamic() const; - bool IsStreamingGenerator() const; - std::vector DynamicReturnIds() const; void AddDynamicReturnId(const ObjectID &dynamic_return_id); diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h index 1110504ea0b5..c260745b7161 100644 --- a/src/ray/common/task/task_util.h +++ b/src/ray/common/task/task_util.h @@ -126,7 +126,6 @@ class TaskSpecBuilder { const rpc::Address &caller_address, uint64_t num_returns, bool returns_dynamic, - bool is_streaming_generator, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, const std::string &debugger_breakpoint, @@ -150,7 +149,6 @@ class TaskSpecBuilder { message_->mutable_caller_address()->CopyFrom(caller_address); message_->set_num_returns(num_returns); message_->set_returns_dynamic(returns_dynamic); - message_->set_streaming_generator(is_streaming_generator); message_->mutable_required_resources()->insert(required_resources.begin(), required_resources.end()); message_->mutable_required_placement_resources()->insert( diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 9f2d950db681..e3088741be3e 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1850,16 +1850,6 @@ void CoreWorker::BuildCommonTaskSpec( // is a generator of ObjectRefs. num_returns = 1; } - // TODO(sang): Remove this and integrate it to - // nun_returns == -1 once migrating to streaming - // generator. - bool is_streaming_generator = num_returns == -2; - if (is_streaming_generator) { - num_returns = 1; - // We are using the dynamic return if - // the streaming generator is used. - returns_dynamic = true; - } RAY_CHECK(num_returns >= 0); builder.SetCommonTaskSpec( task_id, @@ -1876,7 +1866,6 @@ void CoreWorker::BuildCommonTaskSpec( address, num_returns, returns_dynamic, - is_streaming_generator, required_resources, required_placement_resources, debugger_breakpoint, @@ -2673,8 +2662,7 @@ Status CoreWorker::ExecuteTask( application_error, defined_concurrency_groups, name_of_concurrency_group_to_execute, - /*is_reattempt=*/task_spec.AttemptNumber() > 0, - /*is_streaming_generator*/ task_spec.IsStreamingGenerator()); + /*is_reattempt=*/task_spec.AttemptNumber() > 0); // Get the reference counts for any IDs that we borrowed during this task, // remove the local reference for these IDs, and return the ref count info to @@ -2842,12 +2830,13 @@ bool CoreWorker::PinExistingReturnObject(const ObjectID &return_id, } } -ObjectID CoreWorker::AllocateDynamicReturnId(const rpc::Address &owner_address) { +ObjectID CoreWorker::AllocateDynamicReturnId() { const auto &task_spec = worker_context_.GetCurrentTask(); const auto return_id = ObjectID::FromIndex(task_spec->TaskId(), worker_context_.GetNextPutIndex()); AddLocalReference(return_id, ""); - reference_counter_->AddBorrowedObject(return_id, ObjectID::Nil(), owner_address); + reference_counter_->AddBorrowedObject( + return_id, ObjectID::Nil(), worker_context_.GetCurrentTask()->CallerAddress()); return return_id; } @@ -3268,11 +3257,7 @@ void CoreWorker::ProcessSubscribeForObjectEviction( // counter so that we know that it exists. const auto generator_id = ObjectID::FromBinary(message.generator_id()); RAY_CHECK(!generator_id.IsNil()); - if (task_manager_->ObjectRefStreamExists(generator_id)) { - reference_counter_->OwnDynamicStreamingTaskReturnRef(object_id, generator_id); - } else { - reference_counter_->AddDynamicReturn(object_id, generator_id); - } + reference_counter_->AddDynamicReturn(object_id, generator_id); } // Returns true if the object was present and the callback was added. It might have @@ -3406,11 +3391,7 @@ void CoreWorker::AddSpilledObjectLocationOwner( // object. Add the dynamically created object to our ref counter so that we // know that it exists. RAY_CHECK(!generator_id->IsNil()); - if (task_manager_->ObjectRefStreamExists(*generator_id)) { - reference_counter_->OwnDynamicStreamingTaskReturnRef(object_id, *generator_id); - } else { - reference_counter_->AddDynamicReturn(object_id, *generator_id); - } + reference_counter_->AddDynamicReturn(object_id, *generator_id); } auto reference_exists = @@ -3438,14 +3419,9 @@ void CoreWorker::AddObjectLocationOwner(const ObjectID &object_id, // until the task finishes. const auto &maybe_generator_id = task_manager_->TaskGeneratorId(object_id.TaskId()); if (!maybe_generator_id.IsNil()) { - if (task_manager_->ObjectRefStreamExists(maybe_generator_id)) { - // If the stream exists, it means it is a streaming generator. - reference_counter_->OwnDynamicStreamingTaskReturnRef(object_id, maybe_generator_id); - } else { - // The task is a generator and may not have finished yet. Add the internal - // ObjectID so that we can update its location. - reference_counter_->AddDynamicReturn(object_id, maybe_generator_id); - } + // The task is a generator and may not have finished yet. Add the internal + // ObjectID so that we can update its location. + reference_counter_->AddDynamicReturn(object_id, maybe_generator_id); RAY_UNUSED(reference_counter_->AddObjectLocation(object_id, node_id)); } } diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 1f8c725f7080..4b27ce3f0b2a 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -1022,11 +1022,9 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// object to the task caller and have the resulting ObjectRef be owned by /// the caller. This is in contrast to static allocation, where the caller /// decides at task invocation time how many returns the task should have. - /// \param[in] owner_address The address of the owner who will own this - /// dynamically generated object. /// /// \param[out] The ObjectID that the caller should use to store the object. - ObjectID AllocateDynamicReturnId(const rpc::Address &owner_address); + ObjectID AllocateDynamicReturnId(); /// Get a handle to an actor. /// diff --git a/src/ray/core_worker/core_worker_options.h b/src/ray/core_worker/core_worker_options.h index 3a8346776077..157a3fbc53a3 100644 --- a/src/ray/core_worker/core_worker_options.h +++ b/src/ray/core_worker/core_worker_options.h @@ -56,10 +56,7 @@ struct CoreWorkerOptions { // used for actor creation task. const std::vector &defined_concurrency_groups, const std::string name_of_concurrency_group_to_execute, - bool is_reattempt, - // True if the task is for streaming generator. - // TODO(sang): Remove it and combine it with dynamic returns. - bool is_streaming_generator)>; + bool is_reattempt)>; CoreWorkerOptions() : store_socket(""), diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc index 109dd0dc9686..5afb92f853be 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc @@ -124,8 +124,7 @@ Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNIEnv *env, std::string *application_error, const std::vector &defined_concurrency_groups, const std::string name_of_concurrency_group_to_execute, - bool is_reattempt, - bool is_streaming_generator) { + bool is_reattempt) { // These 2 parameters are used for Python only, and Java worker // will not use them. RAY_UNUSED(defined_concurrency_groups); diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 62dd91f4474b..31a97db7bd4f 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -570,7 +570,6 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { address, num_returns, false, - false, resources, resources, "", diff --git a/src/ray/core_worker/test/dependency_resolver_test.cc b/src/ray/core_worker/test/dependency_resolver_test.cc index 5ca82b773b7a..4d2406e006ec 100644 --- a/src/ray/core_worker/test/dependency_resolver_test.cc +++ b/src/ray/core_worker/test/dependency_resolver_test.cc @@ -44,7 +44,6 @@ TaskSpecification BuildTaskSpec(const std::unordered_map &r empty_address, 1, false, - false, resources, resources, serialized_runtime_env, diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index 498551b61334..61eb4370c3f4 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -65,7 +65,6 @@ TaskSpecification BuildTaskSpec(const std::unordered_map &r empty_address, 1, false, - false, resources, resources, serialized_runtime_env, diff --git a/src/ray/core_worker/test/mock_worker.cc b/src/ray/core_worker/test/mock_worker.cc index 7529a5255ee0..1c782438ae28 100644 --- a/src/ray/core_worker/test/mock_worker.cc +++ b/src/ray/core_worker/test/mock_worker.cc @@ -67,8 +67,7 @@ class MockWorker { std::string *application_error, const std::vector &defined_concurrency_groups, const std::string name_of_concurrency_group_to_execute, - bool is_reattempt, - bool is_streaming_generator) { + bool is_reattempt) { return ExecuteTask(caller_address, task_type, task_name, diff --git a/src/ray/gcs/test/gcs_test_util.h b/src/ray/gcs/test/gcs_test_util.h index 744b3ae2bb2a..fdef576c32e3 100644 --- a/src/ray/gcs/test/gcs_test_util.h +++ b/src/ray/gcs/test/gcs_test_util.h @@ -58,7 +58,6 @@ struct Mocker { owner_address, 1, false, - false, required_resources, required_placement_resources, "", diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index bf10020a37b9..b78e354768a4 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -422,12 +422,6 @@ message TaskSpec { // This will be the actor creation task's task id for concurrent actors. Or // the main thread's task id for other cases. bytes submitter_task_id = 33; - // True if the task is a streaming generator. When it is true, - // returns_dynamic has to be true as well. This is a temporary flag - // until we migrate the generator implementatino to streaming. - // TODO(sang): Remove it once migrating to the streaming generator - // by default. - bool streaming_generator = 34; } message TaskInfoEntry { diff --git a/src/ray/raylet/scheduling/cluster_task_manager_test.cc b/src/ray/raylet/scheduling/cluster_task_manager_test.cc index d5e17ee0fe62..de2bd227996c 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager_test.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager_test.cc @@ -165,7 +165,6 @@ RayTask CreateTask( address, 0, /*returns_dynamic=*/false, - /*is_streaming_generator*/ false, required_resources, {}, "", From 122b70574cef011174262edce12c4b8444f65a53 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sun, 14 May 2023 10:16:04 -0700 Subject: [PATCH 09/77] [Revert Please] Support core worker APIs and a generator. Signed-off-by: SangBin Cho --- python/ray/_private/ray_option_utils.py | 2 +- python/ray/_raylet.pyx | 169 ------------------- python/ray/exceptions.py | 4 - python/ray/includes/common.pxd | 4 - python/ray/includes/libcoreworker.pxd | 5 - python/ray/tests/BUILD | 1 - python/ray/tests/test_streaming_generator.py | 141 ---------------- src/ray/core_worker/core_worker.cc | 22 --- src/ray/core_worker/core_worker.h | 32 ---- 9 files changed, 1 insertion(+), 379 deletions(-) delete mode 100644 python/ray/tests/test_streaming_generator.py diff --git a/python/ray/_private/ray_option_utils.py b/python/ray/_private/ray_option_utils.py index f433fc3f153a..afe66e816f30 100644 --- a/python/ray/_private/ray_option_utils.py +++ b/python/ray/_private/ray_option_utils.py @@ -154,7 +154,7 @@ def issubclass_safe(obj: Any, cls_: type) -> bool: "num_returns": Option( (int, str, type(None)), lambda x: None - if (x is None or x == "dynamic") + if (x is None or x == "dynamic" or x > 0) else "The keyword 'num_returns' only accepts None, a non-negative integer, or " '"dynamic" (for generators)', default_value=1, diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 89773c560aaf..94551ca1deb9 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -134,7 +134,6 @@ from ray.exceptions import ( AsyncioActorExit, PendingCallsLimitExceeded, RpcError, - ObjectRefStreamEoFError, ) from ray._private import external_storage from ray.util.scheduling_strategies import ( @@ -199,145 +198,6 @@ class ObjectRefGenerator: return len(self._refs) -class StreamingObjectRefGenerator: - def __init__(self, generator_ref, worker): - # The reference to a generator task. - self._generator_ref = generator_ref - # The last time generator task has completed. - self._generator_task_completed_time = None - # The exception raised from a generator task. - self._generator_task_exception = None - # Ray's worker class. ray._private.worker.global_worker - self.worker = worker - assert hasattr(worker, "core_worker") - self.worker.core_worker.create_object_ref_stream(self._generator_ref) - - def __iter__(self): - return self - - def __next__(self): - """Waits until a next ref is available and returns the object ref. - - Raises StopIteration if there's no more objects - to generate. - - The object ref will contain an exception if the task fails. - When the generator task returns N objects, it can return - up to N + 1 objects (if there's a system failure, the - last object will contain a system level exception). - """ - return self._next() - - def _next( - self, - timeout_s: float = -1, - sleep_interval_s: float = 0.0001, - unexpected_network_failure_timeout_s: float = 30): - """Waits for timeout_s and returns the object ref if available. - - If an object is not available within the given timeout, it - returns a nil object reference. - - If -1 timeout is provided, it means it waits infinitely. - - Waiting is implemented as busy waiting. You can control - the busy waiting interval via sleep_interval_s. - - Raises StopIteration if there's no more objects - to generate. - - The object ref will contain an exception if the task fails. - When the generator task returns N objects, it can return - up to N + 1 objects (if there's a system failure, the - last object will contain a system level exception). - - Args: - timeout_s: If the next object is not ready within - this timeout, it returns the nil object ref. - sleep_interval_s: busy waiting interval. - unexpected_network_failure_timeout_s: If the - task is finished, but the next ref is not - available within this time, it will hard fail - the generator. - """ - obj = self._handle_next() - last_time = time.time() - - # The generator ref will be None if the task succeeds. - # It will contain an exception if the task fails by - # a system error. - while obj.is_nil(): - if self._generator_task_exception: - # The generator task has failed already. - # We raise StopIteration - # to conform the next interface in Python. - raise StopIteration - else: - # Otherwise, we should ray.get on the generator - # ref to find if the task has a system failure. - # Return the generator ref that contains the system - # error as soon as possible. - r, _ = ray.wait([self._generator_ref], timeout=0) - if len(r) > 0: - try: - ray.get(r) - except Exception as e: - # If it has failed, return the generator task ref - # so that the ref will raise an exception. - self._generator_task_exception = e - return self._generator_ref - finally: - if self._generator_task_completed_time is None: - self._generator_task_completed_time = time.time() - - # Currently, since the ordering of intermediate result report - # is not guaranteed, it is possible that althoug the task - # has succeeded, all of the object references are not reported - # (e.g., when there are network failures). - # If all the object refs are not reported to the generator - # within 30 seconds, we consider is as an unreconverable error. - if self._generator_task_completed_time: - if (time.time() - self._generator_task_completed_time - > unexpected_network_failure_timeout_s): - # It means the next wasn't reported although the task - # has been terminated 30 seconds ago. - self._generator_task_exception = AssertionError - assert False, "Unexpected network failure occured." - - if timeout_s != -1 and time.time() - last_time > timeout_s: - return ObjectRef.nil() - - # 100us busy waiting - time.sleep(sleep_interval_s) - obj = self._handle_next() - return obj - - def _handle_next(self): - try: - if hasattr(self.worker, "core_worker"): - obj = self.worker.core_worker.async_read_object_ref_stream( - self._generator_ref) - return obj - else: - raise ValueError( - "Cannot access the core worker. " - "Did you already shutdown Ray via ray.shutdown()?") - except ObjectRefStreamEoFError: - raise StopIteration - - def __del__(self): - if hasattr(self.worker, "core_worker"): - # NOTE: This can be called multiple times - # because python doesn't guarantee __del__ is called - # only once. - self.worker.core_worker.delete_object_ref_stream(self._generator_ref) - - def __getstate__(self): - raise TypeError( - "Serialization of the StreamingObjectRefGenerator " - "is now allowed") - - cdef int check_status(const CRayStatus& status) nogil except -1: if status.ok(): return 0 @@ -349,8 +209,6 @@ cdef int check_status(const CRayStatus& status) nogil except -1: raise ObjectStoreFullError(message) elif status.IsOutOfDisk(): raise OutOfDiskError(message) - elif status.IsObjectRefStreamEoF(): - raise ObjectRefStreamEoFError(message) elif status.IsInterrupted(): raise KeyboardInterrupt() elif status.IsTimedOut(): @@ -3256,33 +3114,6 @@ cdef class CoreWorker: CCoreWorkerProcess.GetCoreWorker() \ .RecordTaskLogEnd(out_end_offset, err_end_offset) - def create_object_ref_stream(self, ObjectRef generator_id): - cdef: - CObjectID c_generator_id = generator_id.native() - - CCoreWorkerProcess.GetCoreWorker().CreateObjectRefStream(c_generator_id) - - def delete_object_ref_stream(self, ObjectRef generator_id): - cdef: - CObjectID c_generator_id = generator_id.native() - - CCoreWorkerProcess.GetCoreWorker().DelObjectRefStream(c_generator_id) - - def async_read_object_ref_stream(self, ObjectRef generator_id): - cdef: - CObjectID c_generator_id = generator_id.native() - CObjectReference c_object_ref - - check_status( - CCoreWorkerProcess.GetCoreWorker().AsyncReadObjectRefStream( - c_generator_id, &c_object_ref)) - return ObjectRef( - c_object_ref.object_id(), - c_object_ref.owner_address().SerializeAsString(), - "", - # Already added when the ref is updated. - skip_adding_local_ref=True) - cdef void async_callback(shared_ptr[CRayObject] obj, CObjectID object_ref, void *user_callback) with gil: diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py index dd97806fecaf..276acfd372c6 100644 --- a/python/ray/exceptions.py +++ b/python/ray/exceptions.py @@ -336,10 +336,6 @@ def __str__(self): return error_msg -class ObjectRefStreamEoFError(RayError): - pass - - @PublicAPI class ObjectStoreFullError(RayError): """Indicates that the object store is full. diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index d2d06bc357be..4250470f3013 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -99,9 +99,6 @@ cdef extern from "ray/common/status.h" namespace "ray" nogil: @staticmethod CRayStatus NotFound() - @staticmethod - CRayStatus ObjectRefStreamEoF() - c_bool ok() c_bool IsOutOfMemory() c_bool IsKeyError() @@ -121,7 +118,6 @@ cdef extern from "ray/common/status.h" namespace "ray" nogil: c_bool IsObjectUnknownOwner() c_bool IsRpcError() c_bool IsOutOfResource() - c_bool IsObjectRefStreamEoF() c_string ToString() c_string CodeAsString() diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 8dac68ea651e..eda70ac22767 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -146,11 +146,6 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: const CObjectID& return_id, shared_ptr[CRayObject] *return_object, const CObjectID& generator_id) - void DelObjectRefStream(const CObjectID &generator_id) - void CreateObjectRefStream(const CObjectID &generator_id) - CRayStatus AsyncReadObjectRefStream( - const CObjectID &generator_id, - CObjectReference *object_ref_out) CObjectID AllocateDynamicReturnId() CJobID GetCurrentJobId() diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 2321c1ef14e7..7b483064b550 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -46,7 +46,6 @@ py_test_module_list( "test_gcs_fault_tolerance.py", "test_gcs_utils.py", "test_generators.py", - "test_streaming_generator.py", "test_metrics_agent.py", "test_metrics_head.py", "test_component_failures_2.py", diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py deleted file mode 100644 index c496d52b6179..000000000000 --- a/python/ray/tests/test_streaming_generator.py +++ /dev/null @@ -1,141 +0,0 @@ -import pytest -import sys -import time - -from unittest.mock import patch, Mock - -import ray -from ray._raylet import StreamingObjectRefGenerator -from ray.cloudpickle import dumps -from ray.exceptions import ObjectRefStreamEoFError, WorkerCrashedError - - -class MockedWorker: - def __init__(self, mocked_core_worker): - self.core_worker = mocked_core_worker - - def reset_core_worker(self): - """Emulate the case ray.shutdown is called - and the core_worker instance is GC'ed. - """ - self.core_worker = None - - -@pytest.fixture -def mocked_worker(): - mocked_core_worker = Mock() - mocked_core_worker.async_read_object_ref_stream.return_value = None - mocked_core_worker.delete_object_ref_stream.return_value = None - mocked_core_worker.create_object_ref_stream.return_value = None - worker = MockedWorker(mocked_core_worker) - yield worker - - -def test_streaming_object_ref_generator_basic_unit(mocked_worker): - """ - Verify the basic case: - create a generator -> read values -> nothing more to read -> delete. - """ - with patch("ray.wait") as mocked_ray_wait: - c = mocked_worker.core_worker - generator_ref = ray.ObjectRef.from_random() - generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) - c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() - c.create_object_ref_stream.assert_called() - - # Test when there's no new ref, it returns a nil. - mocked_ray_wait.return_value = [], [generator_ref] - ref = generator._next(timeout_s=0) - assert ref.is_nil() - - # When the new ref is available, next should return it. - for _ in range(3): - new_ref = ray.ObjectRef.from_random() - c.async_read_object_ref_stream.return_value = new_ref - ref = generator._next(timeout_s=0) - assert new_ref == ref - - # When async_read_object_ref_stream raises a - # ObjectRefStreamEoFError, it should raise a stop iteration. - c.async_read_object_ref_stream.side_effect = ObjectRefStreamEoFError("") # noqa - with pytest.raises(StopIteration): - ref = generator._next(timeout_s=0) - - # Make sure we cannot serialize the generator. - with pytest.raises(TypeError): - dumps(generator) - - del generator - c.delete_object_ref_stream.assert_called() - - -def test_streaming_object_ref_generator_task_failed_unit(mocked_worker): - """ - Verify when a task is failed by a system error, - the generator ref is returned. - """ - with patch("ray.get") as mocked_ray_get: - with patch("ray.wait") as mocked_ray_wait: - c = mocked_worker.core_worker - generator_ref = ray.ObjectRef.from_random() - generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) - - # Simulate the worker failure happens. - mocked_ray_wait.return_value = [generator_ref], [] - mocked_ray_get.side_effect = WorkerCrashedError() - - c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() - ref = generator._next(timeout_s=0) - # If the generator task fails by a systsem error, - # meaning the ref will raise an exception - # it should be returned. - print(ref) - print(generator_ref) - assert ref == generator_ref - - # Once exception is raised, it should always - # raise stopIteration regardless of what - # the ref contains now. - with pytest.raises(StopIteration): - ref = generator._next(timeout_s=0) - - -def test_streaming_object_ref_generator_network_failed_unit(mocked_worker): - """ - Verify when a task is finished, but if the next ref is not available - on time, it raises an assertion error. - - TODO(sang): Once we move the task subimssion path to use pubsub - to guarantee the ordering, we don't need this test anymore. - """ - with patch("ray.get") as mocked_ray_get: - with patch("ray.wait") as mocked_ray_wait: - c = mocked_worker.core_worker - generator_ref = ray.ObjectRef.from_random() - generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) - - # Simulate the task has finished. - mocked_ray_wait.return_value = [generator_ref], [] - mocked_ray_get.return_value = None - - # If StopIteration is not raised within - # unexpected_network_failure_timeout_s second, - # it should fail. - c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() - ref = generator._next(timeout_s=0, unexpected_network_failure_timeout_s=1) - assert ref == ray.ObjectRef.nil() - time.sleep(1) - with pytest.raises(AssertionError): - generator._next(timeout_s=0, unexpected_network_failure_timeout_s=1) - # After that StopIteration should be raised. - with pytest.raises(StopIteration): - generator._next(timeout_s=0, unexpected_network_failure_timeout_s=1) - - -if __name__ == "__main__": - import os - - if os.environ.get("PARALLEL_CI"): - sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) - else: - sys.exit(pytest.main(["-sv", __file__])) diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index e3088741be3e..bd1e3ed48dc0 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -2755,28 +2755,6 @@ Status CoreWorker::SealReturnObject(const ObjectID &return_id, return status; } -void CoreWorker::CreateObjectRefStream(const ObjectID &generator_id) { - task_manager_->CreateObjectRefStream(generator_id); -} - -void CoreWorker::DelObjectRefStream(const ObjectID &generator_id) { - task_manager_->DelObjectRefStream(generator_id); -} - -Status CoreWorker::AsyncReadObjectRefStream(const ObjectID &generator_id, - rpc::ObjectReference *object_ref_out) { - ObjectID object_id; - const auto &status = task_manager_->AsyncReadObjectRefStream(generator_id, &object_id); - if (!status.ok()) { - return status; - } - - RAY_CHECK(object_ref_out != nullptr); - object_ref_out->set_object_id(object_id.Binary()); - object_ref_out->mutable_owner_address()->CopyFrom(rpc_address_); - return status; -} - bool CoreWorker::PinExistingReturnObject(const ObjectID &return_id, std::shared_ptr *return_object, const ObjectID &generator_id) { diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 4b27ce3f0b2a..754b0cbed229 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -360,38 +360,6 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { NodeID GetCurrentNodeId() const { return NodeID::FromBinary(rpc_address_.raylet_id()); } - /// Create the ObjectRefStream of generator_id. - /// - /// It is a pass-through method. See TaskManager::CreateObjectRefStream - /// for details. - /// - /// \param[in] generator_id The object ref id of the streaming - /// generator task. - void CreateObjectRefStream(const ObjectID &generator_id); - - /// Read the next index of a ObjectRefStream of generator_id. - /// - /// \param[in] generator_id The object ref id of the streaming - /// generator task. - /// \param[out] object_ref_out The ObjectReference - /// that the caller can convert to its own ObjectRef. - /// The current process is always the owner of the - /// generated ObjectReference. - /// \return Status RayKeyError if the stream reaches to EoF. - /// OK otherwise. - Status AsyncReadObjectRefStream(const ObjectID &generator_id, - rpc::ObjectReference *object_ref_out); - - /// Delete the ObjectRefStream of generator_id - /// created by CreateObjectRefStream. - /// - /// It is a pass-through method. See TaskManager::DelObjectRefStream - /// for details. - /// - /// \param[in] generator_id The object ref id of the streaming - /// generator task. - void DelObjectRefStream(const ObjectID &generator_id); - const PlacementGroupID &GetCurrentPlacementGroupId() const { return worker_context_.GetCurrentPlacementGroupId(); } From 7a8fe2cedb384865f2d6588d78c0078f36e00da5 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sun, 14 May 2023 10:23:27 -0700 Subject: [PATCH 10/77] fix a bug Signed-off-by: SangBin Cho --- python/ray/_private/ray_option_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/_private/ray_option_utils.py b/python/ray/_private/ray_option_utils.py index afe66e816f30..88703942f64e 100644 --- a/python/ray/_private/ray_option_utils.py +++ b/python/ray/_private/ray_option_utils.py @@ -154,7 +154,7 @@ def issubclass_safe(obj: Any, cls_: type) -> bool: "num_returns": Option( (int, str, type(None)), lambda x: None - if (x is None or x == "dynamic" or x > 0) + if (x is None or x == "dynamic" or x >= 0) else "The keyword 'num_returns' only accepts None, a non-negative integer, or " '"dynamic" (for generators)', default_value=1, From d880763075ced9f1d1c27cbd9c9aa099727be3c2 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sun, 14 May 2023 10:24:05 -0700 Subject: [PATCH 11/77] Revert "[Revert Please] Support core worker APIs and a generator." This reverts commit 122b70574cef011174262edce12c4b8444f65a53. Signed-off-by: SangBin Cho --- python/ray/_raylet.pyx | 169 +++++++++++++++++++ python/ray/exceptions.py | 4 + python/ray/includes/common.pxd | 4 + python/ray/includes/libcoreworker.pxd | 5 + python/ray/tests/BUILD | 1 + python/ray/tests/test_streaming_generator.py | 141 ++++++++++++++++ src/ray/core_worker/core_worker.cc | 22 +++ src/ray/core_worker/core_worker.h | 32 ++++ 8 files changed, 378 insertions(+) create mode 100644 python/ray/tests/test_streaming_generator.py diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 94551ca1deb9..89773c560aaf 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -134,6 +134,7 @@ from ray.exceptions import ( AsyncioActorExit, PendingCallsLimitExceeded, RpcError, + ObjectRefStreamEoFError, ) from ray._private import external_storage from ray.util.scheduling_strategies import ( @@ -198,6 +199,145 @@ class ObjectRefGenerator: return len(self._refs) +class StreamingObjectRefGenerator: + def __init__(self, generator_ref, worker): + # The reference to a generator task. + self._generator_ref = generator_ref + # The last time generator task has completed. + self._generator_task_completed_time = None + # The exception raised from a generator task. + self._generator_task_exception = None + # Ray's worker class. ray._private.worker.global_worker + self.worker = worker + assert hasattr(worker, "core_worker") + self.worker.core_worker.create_object_ref_stream(self._generator_ref) + + def __iter__(self): + return self + + def __next__(self): + """Waits until a next ref is available and returns the object ref. + + Raises StopIteration if there's no more objects + to generate. + + The object ref will contain an exception if the task fails. + When the generator task returns N objects, it can return + up to N + 1 objects (if there's a system failure, the + last object will contain a system level exception). + """ + return self._next() + + def _next( + self, + timeout_s: float = -1, + sleep_interval_s: float = 0.0001, + unexpected_network_failure_timeout_s: float = 30): + """Waits for timeout_s and returns the object ref if available. + + If an object is not available within the given timeout, it + returns a nil object reference. + + If -1 timeout is provided, it means it waits infinitely. + + Waiting is implemented as busy waiting. You can control + the busy waiting interval via sleep_interval_s. + + Raises StopIteration if there's no more objects + to generate. + + The object ref will contain an exception if the task fails. + When the generator task returns N objects, it can return + up to N + 1 objects (if there's a system failure, the + last object will contain a system level exception). + + Args: + timeout_s: If the next object is not ready within + this timeout, it returns the nil object ref. + sleep_interval_s: busy waiting interval. + unexpected_network_failure_timeout_s: If the + task is finished, but the next ref is not + available within this time, it will hard fail + the generator. + """ + obj = self._handle_next() + last_time = time.time() + + # The generator ref will be None if the task succeeds. + # It will contain an exception if the task fails by + # a system error. + while obj.is_nil(): + if self._generator_task_exception: + # The generator task has failed already. + # We raise StopIteration + # to conform the next interface in Python. + raise StopIteration + else: + # Otherwise, we should ray.get on the generator + # ref to find if the task has a system failure. + # Return the generator ref that contains the system + # error as soon as possible. + r, _ = ray.wait([self._generator_ref], timeout=0) + if len(r) > 0: + try: + ray.get(r) + except Exception as e: + # If it has failed, return the generator task ref + # so that the ref will raise an exception. + self._generator_task_exception = e + return self._generator_ref + finally: + if self._generator_task_completed_time is None: + self._generator_task_completed_time = time.time() + + # Currently, since the ordering of intermediate result report + # is not guaranteed, it is possible that althoug the task + # has succeeded, all of the object references are not reported + # (e.g., when there are network failures). + # If all the object refs are not reported to the generator + # within 30 seconds, we consider is as an unreconverable error. + if self._generator_task_completed_time: + if (time.time() - self._generator_task_completed_time + > unexpected_network_failure_timeout_s): + # It means the next wasn't reported although the task + # has been terminated 30 seconds ago. + self._generator_task_exception = AssertionError + assert False, "Unexpected network failure occured." + + if timeout_s != -1 and time.time() - last_time > timeout_s: + return ObjectRef.nil() + + # 100us busy waiting + time.sleep(sleep_interval_s) + obj = self._handle_next() + return obj + + def _handle_next(self): + try: + if hasattr(self.worker, "core_worker"): + obj = self.worker.core_worker.async_read_object_ref_stream( + self._generator_ref) + return obj + else: + raise ValueError( + "Cannot access the core worker. " + "Did you already shutdown Ray via ray.shutdown()?") + except ObjectRefStreamEoFError: + raise StopIteration + + def __del__(self): + if hasattr(self.worker, "core_worker"): + # NOTE: This can be called multiple times + # because python doesn't guarantee __del__ is called + # only once. + self.worker.core_worker.delete_object_ref_stream(self._generator_ref) + + def __getstate__(self): + raise TypeError( + "Serialization of the StreamingObjectRefGenerator " + "is now allowed") + + cdef int check_status(const CRayStatus& status) nogil except -1: if status.ok(): return 0 @@ -209,6 +349,8 @@ cdef int check_status(const CRayStatus& status) nogil except -1: raise ObjectStoreFullError(message) elif status.IsOutOfDisk(): raise OutOfDiskError(message) + elif status.IsObjectRefStreamEoF(): + raise ObjectRefStreamEoFError(message) elif status.IsInterrupted(): raise KeyboardInterrupt() elif status.IsTimedOut(): @@ -3114,6 +3256,33 @@ cdef class CoreWorker: CCoreWorkerProcess.GetCoreWorker() \ .RecordTaskLogEnd(out_end_offset, err_end_offset) + def create_object_ref_stream(self, ObjectRef generator_id): + cdef: + CObjectID c_generator_id = generator_id.native() + + CCoreWorkerProcess.GetCoreWorker().CreateObjectRefStream(c_generator_id) + + def delete_object_ref_stream(self, ObjectRef generator_id): + cdef: + CObjectID c_generator_id = generator_id.native() + + CCoreWorkerProcess.GetCoreWorker().DelObjectRefStream(c_generator_id) + + def async_read_object_ref_stream(self, ObjectRef generator_id): + cdef: + CObjectID c_generator_id = generator_id.native() + CObjectReference c_object_ref + + check_status( + CCoreWorkerProcess.GetCoreWorker().AsyncReadObjectRefStream( + c_generator_id, &c_object_ref)) + return ObjectRef( + c_object_ref.object_id(), + c_object_ref.owner_address().SerializeAsString(), + "", + # Already added when the ref is updated. + skip_adding_local_ref=True) + cdef void async_callback(shared_ptr[CRayObject] obj, CObjectID object_ref, void *user_callback) with gil: diff --git a/python/ray/exceptions.py b/python/ray/exceptions.py index 276acfd372c6..dd97806fecaf 100644 --- a/python/ray/exceptions.py +++ b/python/ray/exceptions.py @@ -336,6 +336,10 @@ def __str__(self): return error_msg +class ObjectRefStreamEoFError(RayError): + pass + + @PublicAPI class ObjectStoreFullError(RayError): """Indicates that the object store is full. diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index 4250470f3013..d2d06bc357be 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -99,6 +99,9 @@ cdef extern from "ray/common/status.h" namespace "ray" nogil: @staticmethod CRayStatus NotFound() + @staticmethod + CRayStatus ObjectRefStreamEoF() + c_bool ok() c_bool IsOutOfMemory() c_bool IsKeyError() @@ -118,6 +121,7 @@ cdef extern from "ray/common/status.h" namespace "ray" nogil: c_bool IsObjectUnknownOwner() c_bool IsRpcError() c_bool IsOutOfResource() + c_bool IsObjectRefStreamEoF() c_string ToString() c_string CodeAsString() diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index eda70ac22767..8dac68ea651e 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -146,6 +146,11 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: const CObjectID& return_id, shared_ptr[CRayObject] *return_object, const CObjectID& generator_id) + void DelObjectRefStream(const CObjectID &generator_id) + void CreateObjectRefStream(const CObjectID &generator_id) + CRayStatus AsyncReadObjectRefStream( + const CObjectID &generator_id, + CObjectReference *object_ref_out) CObjectID AllocateDynamicReturnId() CJobID GetCurrentJobId() diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 7b483064b550..2321c1ef14e7 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -46,6 +46,7 @@ py_test_module_list( "test_gcs_fault_tolerance.py", "test_gcs_utils.py", "test_generators.py", + "test_streaming_generator.py", "test_metrics_agent.py", "test_metrics_head.py", "test_component_failures_2.py", diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py new file mode 100644 index 000000000000..c496d52b6179 --- /dev/null +++ b/python/ray/tests/test_streaming_generator.py @@ -0,0 +1,141 @@ +import pytest +import sys +import time + +from unittest.mock import patch, Mock + +import ray +from ray._raylet import StreamingObjectRefGenerator +from ray.cloudpickle import dumps +from ray.exceptions import ObjectRefStreamEoFError, WorkerCrashedError + + +class MockedWorker: + def __init__(self, mocked_core_worker): + self.core_worker = mocked_core_worker + + def reset_core_worker(self): + """Emulate the case ray.shutdown is called + and the core_worker instance is GC'ed. + """ + self.core_worker = None + + +@pytest.fixture +def mocked_worker(): + mocked_core_worker = Mock() + mocked_core_worker.async_read_object_ref_stream.return_value = None + mocked_core_worker.delete_object_ref_stream.return_value = None + mocked_core_worker.create_object_ref_stream.return_value = None + worker = MockedWorker(mocked_core_worker) + yield worker + + +def test_streaming_object_ref_generator_basic_unit(mocked_worker): + """ + Verify the basic case: + create a generator -> read values -> nothing more to read -> delete. + """ + with patch("ray.wait") as mocked_ray_wait: + c = mocked_worker.core_worker + generator_ref = ray.ObjectRef.from_random() + generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) + c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() + c.create_object_ref_stream.assert_called() + + # Test when there's no new ref, it returns a nil. + mocked_ray_wait.return_value = [], [generator_ref] + ref = generator._next(timeout_s=0) + assert ref.is_nil() + + # When the new ref is available, next should return it. + for _ in range(3): + new_ref = ray.ObjectRef.from_random() + c.async_read_object_ref_stream.return_value = new_ref + ref = generator._next(timeout_s=0) + assert new_ref == ref + + # When async_read_object_ref_stream raises a + # ObjectRefStreamEoFError, it should raise a stop iteration. + c.async_read_object_ref_stream.side_effect = ObjectRefStreamEoFError("") # noqa + with pytest.raises(StopIteration): + ref = generator._next(timeout_s=0) + + # Make sure we cannot serialize the generator. + with pytest.raises(TypeError): + dumps(generator) + + del generator + c.delete_object_ref_stream.assert_called() + + +def test_streaming_object_ref_generator_task_failed_unit(mocked_worker): + """ + Verify when a task is failed by a system error, + the generator ref is returned. + """ + with patch("ray.get") as mocked_ray_get: + with patch("ray.wait") as mocked_ray_wait: + c = mocked_worker.core_worker + generator_ref = ray.ObjectRef.from_random() + generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) + + # Simulate the worker failure happens. + mocked_ray_wait.return_value = [generator_ref], [] + mocked_ray_get.side_effect = WorkerCrashedError() + + c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() + ref = generator._next(timeout_s=0) + # If the generator task fails by a systsem error, + # meaning the ref will raise an exception + # it should be returned. + print(ref) + print(generator_ref) + assert ref == generator_ref + + # Once exception is raised, it should always + # raise stopIteration regardless of what + # the ref contains now. + with pytest.raises(StopIteration): + ref = generator._next(timeout_s=0) + + +def test_streaming_object_ref_generator_network_failed_unit(mocked_worker): + """ + Verify when a task is finished, but if the next ref is not available + on time, it raises an assertion error. + + TODO(sang): Once we move the task subimssion path to use pubsub + to guarantee the ordering, we don't need this test anymore. + """ + with patch("ray.get") as mocked_ray_get: + with patch("ray.wait") as mocked_ray_wait: + c = mocked_worker.core_worker + generator_ref = ray.ObjectRef.from_random() + generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) + + # Simulate the task has finished. + mocked_ray_wait.return_value = [generator_ref], [] + mocked_ray_get.return_value = None + + # If StopIteration is not raised within + # unexpected_network_failure_timeout_s second, + # it should fail. + c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() + ref = generator._next(timeout_s=0, unexpected_network_failure_timeout_s=1) + assert ref == ray.ObjectRef.nil() + time.sleep(1) + with pytest.raises(AssertionError): + generator._next(timeout_s=0, unexpected_network_failure_timeout_s=1) + # After that StopIteration should be raised. + with pytest.raises(StopIteration): + generator._next(timeout_s=0, unexpected_network_failure_timeout_s=1) + + +if __name__ == "__main__": + import os + + if os.environ.get("PARALLEL_CI"): + sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) + else: + sys.exit(pytest.main(["-sv", __file__])) diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index bd1e3ed48dc0..e3088741be3e 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -2755,6 +2755,28 @@ Status CoreWorker::SealReturnObject(const ObjectID &return_id, return status; } +void CoreWorker::CreateObjectRefStream(const ObjectID &generator_id) { + task_manager_->CreateObjectRefStream(generator_id); +} + +void CoreWorker::DelObjectRefStream(const ObjectID &generator_id) { + task_manager_->DelObjectRefStream(generator_id); +} + +Status CoreWorker::AsyncReadObjectRefStream(const ObjectID &generator_id, + rpc::ObjectReference *object_ref_out) { + ObjectID object_id; + const auto &status = task_manager_->AsyncReadObjectRefStream(generator_id, &object_id); + if (!status.ok()) { + return status; + } + + RAY_CHECK(object_ref_out != nullptr); + object_ref_out->set_object_id(object_id.Binary()); + object_ref_out->mutable_owner_address()->CopyFrom(rpc_address_); + return status; +} + bool CoreWorker::PinExistingReturnObject(const ObjectID &return_id, std::shared_ptr *return_object, const ObjectID &generator_id) { diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 754b0cbed229..4b27ce3f0b2a 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -360,6 +360,38 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { NodeID GetCurrentNodeId() const { return NodeID::FromBinary(rpc_address_.raylet_id()); } + /// Create the ObjectRefStream of generator_id. + /// + /// It is a pass-through method. See TaskManager::CreateObjectRefStream + /// for details. + /// + /// \param[in] generator_id The object ref id of the streaming + /// generator task. + void CreateObjectRefStream(const ObjectID &generator_id); + + /// Read the next index of a ObjectRefStream of generator_id. + /// + /// \param[in] generator_id The object ref id of the streaming + /// generator task. + /// \param[out] object_ref_out The ObjectReference + /// that the caller can convert to its own ObjectRef. + /// The current process is always the owner of the + /// generated ObjectReference. + /// \return Status RayKeyError if the stream reaches to EoF. + /// OK otherwise. + Status AsyncReadObjectRefStream(const ObjectID &generator_id, + rpc::ObjectReference *object_ref_out); + + /// Delete the ObjectRefStream of generator_id + /// created by CreateObjectRefStream. + /// + /// It is a pass-through method. See TaskManager::DelObjectRefStream + /// for details. + /// + /// \param[in] generator_id The object ref id of the streaming + /// generator task. + void DelObjectRefStream(const ObjectID &generator_id); + const PlacementGroupID &GetCurrentPlacementGroupId() const { return worker_context_.GetCurrentPlacementGroupId(); } From f501c22a72346e38199038277856ca533ea141f9 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sun, 14 May 2023 10:28:43 -0700 Subject: [PATCH 12/77] Revert "[Please Revert] Work e2e." This reverts commit 05f468a3cc79b52512bb375a73bfb223ef3f8c21. Signed-off-by: SangBin Cho --- .../runtime/task/local_mode_task_submitter.cc | 1 + cpp/src/ray/runtime/task/task_executor.cc | 3 +- cpp/src/ray/runtime/task/task_executor.h | 3 +- python/ray/_private/ray_option_utils.py | 2 +- python/ray/_private/worker.py | 9 + python/ray/_private/workers/default_worker.py | 1 - python/ray/_raylet.pxd | 1 + python/ray/_raylet.pyx | 232 +++++++++++++-- python/ray/actor.py | 12 +- python/ray/includes/libcoreworker.pxd | 5 +- python/ray/remote_function.py | 12 +- python/ray/tests/test_generators.py | 44 ++- python/ray/tests/test_streaming_generator.py | 268 ++++++++++++++++++ src/ray/common/task/task_spec.cc | 6 + src/ray/common/task/task_spec.h | 2 + src/ray/common/task/task_util.h | 2 + src/ray/core_worker/core_worker.cc | 42 ++- src/ray/core_worker/core_worker.h | 4 +- src/ray/core_worker/core_worker_options.h | 5 +- .../java/io_ray_runtime_RayNativeRuntime.cc | 3 +- src/ray/core_worker/test/core_worker_test.cc | 1 + .../test/dependency_resolver_test.cc | 1 + .../test/direct_task_transport_test.cc | 1 + src/ray/core_worker/test/mock_worker.cc | 3 +- src/ray/gcs/test/gcs_test_util.h | 1 + src/ray/protobuf/common.proto | 6 + .../scheduling/cluster_task_manager_test.cc | 1 + 27 files changed, 622 insertions(+), 49 deletions(-) diff --git a/cpp/src/ray/runtime/task/local_mode_task_submitter.cc b/cpp/src/ray/runtime/task/local_mode_task_submitter.cc index 145e8130fe15..6052531e1211 100644 --- a/cpp/src/ray/runtime/task/local_mode_task_submitter.cc +++ b/cpp/src/ray/runtime/task/local_mode_task_submitter.cc @@ -61,6 +61,7 @@ ObjectID LocalModeTaskSubmitter::Submit(InvocationSpec &invocation, address, 1, /*returns_dynamic=*/false, + /*is_streaming_generator*/ false, required_resources, required_placement_resources, "", diff --git a/cpp/src/ray/runtime/task/task_executor.cc b/cpp/src/ray/runtime/task/task_executor.cc index ca4aae05fd7e..97d67c760279 100644 --- a/cpp/src/ray/runtime/task/task_executor.cc +++ b/cpp/src/ray/runtime/task/task_executor.cc @@ -135,7 +135,8 @@ Status TaskExecutor::ExecuteTask( std::string *application_error, const std::vector &defined_concurrency_groups, const std::string name_of_concurrency_group_to_execute, - bool is_reattempt) { + bool is_reattempt, + bool is_streaming_generator) { RAY_LOG(DEBUG) << "Execute task type: " << TaskType_Name(task_type) << " name:" << task_name; RAY_CHECK(ray_function.GetLanguage() == ray::Language::CPP); diff --git a/cpp/src/ray/runtime/task/task_executor.h b/cpp/src/ray/runtime/task/task_executor.h index 4ce2f6009e7e..4ec3df555de9 100644 --- a/cpp/src/ray/runtime/task/task_executor.h +++ b/cpp/src/ray/runtime/task/task_executor.h @@ -91,7 +91,8 @@ class TaskExecutor { std::string *application_error, const std::vector &defined_concurrency_groups, const std::string name_of_concurrency_group_to_execute, - bool is_reattempt); + bool is_reattempt, + bool is_streaming_generator); virtual ~TaskExecutor(){}; diff --git a/python/ray/_private/ray_option_utils.py b/python/ray/_private/ray_option_utils.py index 88703942f64e..97c35f9449ca 100644 --- a/python/ray/_private/ray_option_utils.py +++ b/python/ray/_private/ray_option_utils.py @@ -154,7 +154,7 @@ def issubclass_safe(obj: Any, cls_: type) -> bool: "num_returns": Option( (int, str, type(None)), lambda x: None - if (x is None or x == "dynamic" or x >= 0) + if (x is None or x == "dynamic" or x == "streaming" or x >= 0) else "The keyword 'num_returns' only accepts None, a non-negative integer, or " '"dynamic" (for generators)', default_value=1, diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index 1bb275a2312e..e45c33eb3027 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -2499,6 +2499,11 @@ def get( blocking_get_inside_async_warned = True with profiling.profile("ray.get"): + # TODO(sang): Should make StreamingObjectRefGenerator + # compatible to ray.get for dataset. + if isinstance(object_refs, ray._raylet.StreamingObjectRefGenerator): + return object_refs + is_individual_id = isinstance(object_refs, ray.ObjectRef) if is_individual_id: object_refs = [object_refs] @@ -2817,6 +2822,10 @@ def cancel(object_ref: "ray.ObjectRef", *, force: bool = False, recursive: bool worker = ray._private.worker.global_worker worker.check_connected() + if isinstance(object_ref, ray._raylet.StreamingObjectRefGenerator): + assert hasattr(object_ref, "_generator_ref") + object_ref = object_ref._generator_ref + if not isinstance(object_ref, ray.ObjectRef): raise TypeError( "ray.cancel() only supported for non-actor object refs. " diff --git a/python/ray/_private/workers/default_worker.py b/python/ray/_private/workers/default_worker.py index 937f45a8b85d..462c9e284f49 100644 --- a/python/ray/_private/workers/default_worker.py +++ b/python/ray/_private/workers/default_worker.py @@ -169,7 +169,6 @@ # https://github.com/ray-project/ray/pull/12225#issue-525059663. args = parser.parse_args() ray._private.ray_logging.setup_logger(args.logging_level, args.logging_format) - worker_launched_time_ms = time.time_ns() // 1e6 if args.worker_type == "WORKER": diff --git a/python/ray/_raylet.pxd b/python/ray/_raylet.pxd index 6af1879a5d8a..28a7632ed8c1 100644 --- a/python/ray/_raylet.pxd +++ b/python/ray/_raylet.pxd @@ -143,6 +143,7 @@ cdef class CoreWorker: self, worker, outputs, c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns, + const CAddress &caller_address, CObjectID ref_generator_id=*) cdef yield_current_fiber(self, CFiberEvent &fiber_event) cdef make_actor_handle(self, ActorHandleSharedPtr c_actor_handle) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 89773c560aaf..6762ade4578e 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -741,7 +741,8 @@ cdef store_task_errors( CTaskType task_type, proctitle, c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns, - c_string* application_error): + c_string* application_error, + const CAddress &caller_address): cdef: CoreWorker core_worker = worker.core_worker @@ -785,7 +786,8 @@ cdef store_task_errors( errors.append(failure_object) num_errors_stored = core_worker.store_task_outputs( worker, errors, - returns) + returns, + caller_address) ray._private.utils.push_error_to_driver( worker, @@ -797,6 +799,160 @@ cdef store_task_errors( return num_errors_stored +cdef execute_streaming_generator( + generator, + const CObjectID &generator_id, + CTaskType task_type, + const CAddress &caller_address, + TaskID task_id, + const c_string &serialized_retry_exception_allowlist, + function_name, + function_descriptor, + title, + actor, + actor_id, + c_bool *is_retryable_error, + c_string *application_error): + """Execute a given generator and streaming-report the + result to the given caller_address. + + The output from the generator will be stored to the in-memory + or plasma object store. The generated return objects will be + reported to the owner of the task as soon as they are generated. + + It means when this method is used, the result of each generator + will be reported and available from the given "caller address" + before the task is finished. + + Args: + generator: The generator to run. + generator_id: The object ref id of the generator task. + task_type: The type of the task. E.g., actor task, normal task. + caller_address: The address of the caller. By our protocol, + the caller of the streaming generator task is always + the owner, so we can also call it "owner address". + task_id: The task ID of the generator task. + serialized_retry_exception_allowlist: A list of + exceptions that are allowed to retry this generator task. + function_name: The name of the generator function. Used for + writing an error message. + function_descriptor: The function descriptor of + the generator function. Used for writing an error message. + title: The process title of the generator task. Used for + writing an error message. + actor: The instance of the actor created in this worker. + It is used to write an error message. + actor_id: The ID of the actor. It is used to write an error message. + is_retryable_error(out): It is set to True if the generator + raises an exception, and the error is retryable. + application_error(out): It is set if the generator raises an + application error. + """ + worker = ray._private.worker.global_worker + cdef: + CoreWorker core_worker = worker.core_worker + c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] intermediate_result + + generator_index = 0 + assert inspect.isgenerator(generator), ( + "execute_generator's first argument must be a generator." + ) + + while True: + try: + output = next(generator) + except StopIteration: + break + except Exception as e: + # Report the error if the generator failed to execute. + is_retryable_error[0] = determine_if_retryable( + e, + serialized_retry_exception_allowlist, + function_descriptor, + ) + + if ( + is_retryable_error[0] + and core_worker.get_current_task_retry_exceptions() + ): + logger.debug("Task failed with retryable exception:" + " {}.".format(task_id), exc_info=True) + # Raise an exception directly and halt the execution + # because there's no need to set the exception + # for the return value when the task is retryable. + raise e + + logger.debug("Task failed with unretryable exception:" + " {}.".format(task_id), exc_info=True) + + error_id = (CCoreWorkerProcess.GetCoreWorker() + .AllocateDynamicReturnId(caller_address)) + intermediate_result.push_back( + c_pair[CObjectID, shared_ptr[CRayObject]]( + error_id, shared_ptr[CRayObject]())) + + store_task_errors( + worker, e, + True, # task_exception + actor, # actor + actor_id, # actor id + function_name, task_type, title, + &intermediate_result, application_error, caller_address) + + CCoreWorkerProcess.GetCoreWorker().ReportIntermediateTaskReturn( + intermediate_result.back(), + generator_id, caller_address, generator_index, False) + + if intermediate_result.size() > 0: + intermediate_result.pop_back() + generator_index += 1 + break + else: + # Report the intermediate result if there was no error. + return_id = ( + CCoreWorkerProcess.GetCoreWorker().AllocateDynamicReturnId( + caller_address)) + intermediate_result.push_back( + c_pair[CObjectID, shared_ptr[CRayObject]]( + return_id, shared_ptr[CRayObject]())) + + core_worker.store_task_outputs( + worker, [output], + &intermediate_result, + caller_address, + generator_id) + logger.debug( + "Writes to a ObjectRefStream of an " + "index {}".format(generator_index)) + assert intermediate_result.size() == 1 + del output + + CCoreWorkerProcess.GetCoreWorker().ReportIntermediateTaskReturn( + intermediate_result.back(), + generator_id, + caller_address, + generator_index, + False) + + if intermediate_result.size() > 0: + intermediate_result.pop_back() + generator_index += 1 + + # All the intermediate result has to be popped and reported. + assert intermediate_result.size() == 0 + # Report the owner that there's no more objects. + logger.debug( + "Writes EoF to a ObjectRefStream " + "of an index {}".format(generator_index)) + CCoreWorkerProcess.GetCoreWorker().ReportIntermediateTaskReturn( + c_pair[CObjectID, shared_ptr[CRayObject]]( + CObjectID.Nil(), shared_ptr[CRayObject]()), + generator_id, + caller_address, + generator_index, + True) # finished. + + cdef execute_dynamic_generator_and_store_task_outputs( generator, const CObjectID &generator_id, @@ -808,7 +964,8 @@ cdef execute_dynamic_generator_and_store_task_outputs( c_bool is_reattempt, function_name, function_descriptor, - title): + title, + const CAddress &caller_address): worker = ray._private.worker.global_worker cdef: CoreWorker core_worker = worker.core_worker @@ -817,6 +974,7 @@ cdef execute_dynamic_generator_and_store_task_outputs( core_worker.store_task_outputs( worker, generator, dynamic_returns, + caller_address, generator_id) except Exception as error: is_retryable_error[0] = determine_if_retryable( @@ -844,7 +1002,7 @@ cdef execute_dynamic_generator_and_store_task_outputs( # generate one additional ObjectRef. This last # ObjectRef will contain the error. error_id = (CCoreWorkerProcess.GetCoreWorker() - .AllocateDynamicReturnId()) + .AllocateDynamicReturnId(caller_address)) dynamic_returns[0].push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( error_id, shared_ptr[CRayObject]())) @@ -858,7 +1016,7 @@ cdef execute_dynamic_generator_and_store_task_outputs( None, # actor None, # actor id function_name, task_type, title, - dynamic_returns, application_error) + dynamic_returns, application_error, caller_address) if num_errors_stored == 0: assert is_reattempt # TODO(swang): The generator task failed and we @@ -895,7 +1053,8 @@ cdef void execute_task( c_bool is_reattempt, execution_info, title, - task_name) except *: + task_name, + c_bool is_streaming_generator) except *: worker = ray._private.worker.global_worker manager = worker.function_actor_manager actor = None @@ -1053,6 +1212,35 @@ cdef void execute_task( ray.util.pdb.set_trace( breakpoint_uuid=debugger_breakpoint) outputs = function_executor(*args, **kwargs) + + if is_streaming_generator: + # Streaming generator always has a single return value + # which is the generator task return. + assert returns[0].size() == 1 + + if not inspect.isgenerator(outputs): + raise ValueError( + "Functions with " + "@ray.remote(num_returns=\"streaming\" " + "must return a generator") + + execute_streaming_generator( + outputs, + returns[0][0].first, # generator object ID. + task_type, + caller_address, + task_id, + serialized_retry_exception_allowlist, + function_name, + function_descriptor, + title, + actor, + actor_id, + is_retryable_error, + application_error) + # Streaming generator output is not used, so set it to None. + outputs = None + next_breakpoint = ( ray._private.worker.global_worker.debugger_breakpoint) if next_breakpoint != b"": @@ -1137,7 +1325,9 @@ cdef void execute_task( # Store the outputs in the object store. with core_worker.profile_event(b"task:store_outputs"): - if dynamic_returns != NULL: + # TODO(sang): Remove it once we use streaming generator + # by default. + if dynamic_returns != NULL and not is_streaming_generator: if not inspect.isgenerator(outputs): raise ValueError( "Functions with " @@ -1156,7 +1346,8 @@ cdef void execute_task( is_reattempt, function_name, function_descriptor, - title) + title, + caller_address) task_exception = False dynamic_refs = [] @@ -1174,11 +1365,12 @@ cdef void execute_task( # all generator tasks, both static and dynamic. core_worker.store_task_outputs( worker, outputs, - returns) + returns, + caller_address) except Exception as e: num_errors_stored = store_task_errors( worker, e, task_exception, actor, actor_id, function_name, - task_type, title, returns, application_error) + task_type, title, returns, application_error, caller_address) if returns[0].size() > 0 and num_errors_stored == 0: logger.exception( "Unhandled error: Task threw exception, but all " @@ -1205,7 +1397,8 @@ cdef execute_task_with_cancellation_handler( # the concurrency groups of this actor. const c_vector[CConcurrencyGroup] &c_defined_concurrency_groups, const c_string c_name_of_concurrency_group_to_execute, - c_bool is_reattempt): + c_bool is_reattempt, + c_bool is_streaming_generator): is_retryable_error[0] = False @@ -1290,7 +1483,8 @@ cdef execute_task_with_cancellation_handler( application_error, c_defined_concurrency_groups, c_name_of_concurrency_group_to_execute, - is_reattempt, execution_info, title, task_name) + is_reattempt, execution_info, title, task_name, + is_streaming_generator) # Check for cancellation. PyErr_CheckSignals() @@ -1317,7 +1511,8 @@ cdef execute_task_with_cancellation_handler( task_type, title, returns, # application_error: we are passing NULL since we don't want the # cancel tasks to fail. - NULL) + NULL, + caller_address) finally: with current_task_id_lock: current_task_id = None @@ -1362,7 +1557,8 @@ cdef CRayStatus task_execution_handler( c_string *application_error, const c_vector[CConcurrencyGroup] &defined_concurrency_groups, const c_string name_of_concurrency_group_to_execute, - c_bool is_reattempt) nogil: + c_bool is_reattempt, + c_bool is_streaming_generator) nogil: with gil, disable_client_hook(): # Initialize job_config if it hasn't already. # Setup system paths configured in job_config. @@ -1386,7 +1582,8 @@ cdef CRayStatus task_execution_handler( application_error, defined_concurrency_groups, name_of_concurrency_group_to_execute, - is_reattempt) + is_reattempt, + is_streaming_generator) except Exception as e: sys_exit = SystemExit() if isinstance(e, RayActorError) and \ @@ -2943,6 +3140,7 @@ cdef class CoreWorker: worker, outputs, c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns, + const CAddress &caller_address, CObjectID ref_generator_id=CObjectID.Nil()): cdef: CObjectID return_id @@ -2982,9 +3180,11 @@ cdef class CoreWorker: raise ValueError( "Task returned more than num_returns={} objects.".format( num_returns)) + # TODO(sang): Remove it when the streaming generator is + # enabled by default. while i >= returns[0].size(): return_id = (CCoreWorkerProcess.GetCoreWorker() - .AllocateDynamicReturnId()) + .AllocateDynamicReturnId(caller_address)) returns[0].push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( return_id, shared_ptr[CRayObject]())) diff --git a/python/ray/actor.py b/python/ray/actor.py index 7191031e059b..91b88de7b947 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -22,7 +22,7 @@ ) from ray._private.ray_option_utils import _warn_if_using_deprecated_placement_group from ray._private.utils import get_runtime_env_info, parse_runtime_env -from ray._raylet import PythonFunctionDescriptor +from ray._raylet import PythonFunctionDescriptor, StreamingObjectRefGenerator from ray.exceptions import AsyncioActorExit from ray.util.annotations import DeveloperAPI, PublicAPI from ray.util.placement_group import _configure_placement_group_based_on_context @@ -1167,6 +1167,10 @@ def _actor_method_call( if num_returns == "dynamic": num_returns = -1 + elif num_returns == "streaming": + # TODO(sang): This is a temporary private API. + # Remove it when we migrate to the streaming generator. + num_returns = -2 object_refs = worker.core_worker.submit_actor_task( self._ray_actor_language, @@ -1179,6 +1183,12 @@ def _actor_method_call( concurrency_group_name if concurrency_group_name is not None else b"", ) + if num_returns == -2: + # Streaming generator will return a single ref + # that is for the generator task. + assert len(object_refs) == 1 + generator_ref = object_refs[0] + return StreamingObjectRefGenerator(generator_ref, worker) if len(object_refs) == 1: object_refs = object_refs[0] elif len(object_refs) == 0: diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 8dac68ea651e..b8a5f14f9d6b 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -151,7 +151,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: CRayStatus AsyncReadObjectRefStream( const CObjectID &generator_id, CObjectReference *object_ref_out) - CObjectID AllocateDynamicReturnId() + CObjectID AllocateDynamicReturnId(const CAddress &owner_address) CJobID GetCurrentJobId() CTaskID GetCurrentTaskId() @@ -315,7 +315,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: c_string *application_error, const c_vector[CConcurrencyGroup] &defined_concurrency_groups, const c_string name_of_concurrency_group_to_execute, - c_bool is_reattempt) nogil + c_bool is_reattempt, + c_bool is_streaming_generator) nogil ) task_execution_callback (void(const CWorkerID &) nogil) on_worker_shutdown (CRayStatus() nogil) check_signals diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index 79853deff098..bb627f09af92 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -15,7 +15,7 @@ from ray._private.ray_option_utils import _warn_if_using_deprecated_placement_group from ray._private.serialization import pickle_dumps from ray._private.utils import get_runtime_env_info, parse_runtime_env -from ray._raylet import PythonFunctionDescriptor +from ray._raylet import PythonFunctionDescriptor, StreamingObjectRefGenerator from ray.util.annotations import DeveloperAPI, PublicAPI from ray.util.placement_group import _configure_placement_group_based_on_context from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -306,6 +306,10 @@ def _remote(self, args=None, kwargs=None, **task_options): num_returns = task_options["num_returns"] if num_returns == "dynamic": num_returns = -1 + elif num_returns == "streaming": + # TODO(sang): This is a temporary private API. + # Remove it when we migrate to the streaming generator. + num_returns = -2 max_retries = task_options["max_retries"] retry_exceptions = task_options["retry_exceptions"] @@ -397,6 +401,12 @@ def invocation(args, kwargs): # Reset worker's debug context from the last "remote" command # (which applies only to this .remote call). worker.debugger_breakpoint = b"" + if num_returns == -2: + # Streaming generator will return a single ref + # that is for the generator task. + assert len(object_refs) == 1 + generator_ref = object_refs[0] + return StreamingObjectRefGenerator(generator_ref, worker) if len(object_refs) == 1: return object_refs[0] elif len(object_refs) > 1: diff --git a/python/ray/tests/test_generators.py b/python/ray/tests/test_generators.py index 9284c6a3f8c3..3430da39cda2 100644 --- a/python/ray/tests/test_generators.py +++ b/python/ray/tests/test_generators.py @@ -117,7 +117,10 @@ def generator(num_returns, store_in_plasma): @pytest.mark.parametrize("use_actors", [False, True]) @pytest.mark.parametrize("store_in_plasma", [False, True]) -def test_generator_errors(ray_start_regular, use_actors, store_in_plasma): +@pytest.mark.parametrize("num_returns_type", ["dynamic", "streaming"]) +def test_generator_errors( + ray_start_regular, use_actors, store_in_plasma, num_returns_type +): remote_generator_fn = None if use_actors: @@ -158,7 +161,7 @@ def generator(num_returns, store_in_plasma): with pytest.raises(ray.exceptions.RayTaskError): ray.get(ref3) - dynamic_ref = remote_generator_fn.options(num_returns="dynamic").remote( + dynamic_ref = remote_generator_fn.options(num_returns=num_returns_type).remote( 3, store_in_plasma ) ref1, ref2 = ray.get(dynamic_ref) @@ -218,10 +221,13 @@ def generator(num_returns, store_in_plasma, counter): @pytest.mark.parametrize("use_actors", [False, True]) @pytest.mark.parametrize("store_in_plasma", [False, True]) -def test_dynamic_generator(ray_start_regular, use_actors, store_in_plasma): +@pytest.mark.parametrize("num_returns_type", ["streaming"]) +def test_dynamic_generator( + ray_start_regular, use_actors, store_in_plasma, num_returns_type +): if use_actors: - @ray.remote(num_returns="dynamic") + @ray.remote(num_returns=num_returns_type) def dynamic_generator(num_returns, store_in_plasma): for i in range(num_returns): if store_in_plasma: @@ -255,21 +261,34 @@ def read(gen): return True gen = ray.get( - remote_generator_fn.options(num_returns="dynamic").remote(10, store_in_plasma) + remote_generator_fn.options(num_returns=num_returns_type).remote( + 10, store_in_plasma + ) ) for i, ref in enumerate(gen): assert ray.get(ref)[0] == i # Test empty generator. gen = ray.get( - remote_generator_fn.options(num_returns="dynamic").remote(0, store_in_plasma) + remote_generator_fn.options(num_returns=num_returns_type).remote( + 0, store_in_plasma + ) ) assert len(list(gen)) == 0 # Check that passing as task arg. - gen = remote_generator_fn.options(num_returns="dynamic").remote(10, store_in_plasma) - assert ray.get(read.remote(gen)) - assert ray.get(read.remote(ray.get(gen))) + if num_returns_type == "dynamic": + gen = remote_generator_fn.options(num_returns=num_returns_type).remote( + 10, store_in_plasma + ) + assert ray.get(read.remote(gen)) + assert ray.get(read.remote(ray.get(gen))) + else: + with pytest.raises(TypeError): + gen = remote_generator_fn.options(num_returns=num_returns_type).remote( + 10, store_in_plasma + ) + assert ray.get(read.remote(gen)) # Also works if we override num_returns with a static value. ray.get( @@ -279,7 +298,7 @@ def read(gen): ) # Normal remote functions don't work with num_returns="dynamic". - @ray.remote(num_returns="dynamic") + @ray.remote(num_returns=num_returns_type) def static(num_returns): return list(range(num_returns)) @@ -289,7 +308,8 @@ def static(num_returns): ray.get(ref) -def test_dynamic_generator_distributed(ray_start_cluster): +@pytest.mark.parametrize("num_returns_type", ["dynamic", "streaming"]) +def test_dynamic_generator_distributed(ray_start_cluster, num_returns_type): cluster = ray_start_cluster # Head node with no resources. cluster.add_node(num_cpus=0) @@ -297,7 +317,7 @@ def test_dynamic_generator_distributed(ray_start_cluster): cluster.add_node(num_cpus=1) cluster.wait_for_nodes() - @ray.remote(num_returns="dynamic") + @ray.remote(num_returns=num_returns_type) def dynamic_generator(num_returns): for i in range(num_returns): yield np.ones(1_000_000, dtype=np.int8) * i diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py index c496d52b6179..277d8226cb50 100644 --- a/python/ray/tests/test_streaming_generator.py +++ b/python/ray/tests/test_streaming_generator.py @@ -1,10 +1,14 @@ import pytest +import numpy as np import sys import time +import gc from unittest.mock import patch, Mock import ray +from ray._private.test_utils import wait_for_condition +from ray.experimental.state.api import list_objects from ray._raylet import StreamingObjectRefGenerator from ray.cloudpickle import dumps from ray.exceptions import ObjectRefStreamEoFError, WorkerCrashedError @@ -132,6 +136,270 @@ def test_streaming_object_ref_generator_network_failed_unit(mocked_worker): generator._next(timeout_s=0, unexpected_network_failure_timeout_s=1) +def test_generator_basic(shutdown_only): + ray.init(num_cpus=1) + + """Basic cases""" + + @ray.remote + def f(): + for i in range(5): + yield i + + gen = f.options(num_returns="streaming").remote() + i = 0 + for ref in gen: + print(ray.get(ref)) + assert i == ray.get(ref) + del ref + i += 1 + + """Exceptions""" + + @ray.remote + def f(): + for i in range(5): + if i == 2: + raise ValueError + yield i + + gen = f.options(num_returns="streaming").remote() + ray.get(next(gen)) + ray.get(next(gen)) + with pytest.raises(ray.exceptions.RayTaskError) as e: + ray.get(next(gen)) + print(str(e.value)) + with pytest.raises(StopIteration): + ray.get(next(gen)) + with pytest.raises(StopIteration): + ray.get(next(gen)) + + """Generator Task failure""" + + @ray.remote + class A: + def getpid(self): + import os + + return os.getpid() + + def f(self): + for i in range(5): + time.sleep(0.1) + yield i + + a = A.remote() + i = 0 + gen = a.f.options(num_returns="streaming").remote() + i = 0 + for ref in gen: + if i == 2: + ray.kill(a) + if i == 3: + with pytest.raises(ray.exceptions.RayActorError) as e: + ray.get(ref) + assert "The actor is dead because it was killed by `ray.kill`" in str( + e.value + ) + break + assert i == ray.get(ref) + del ref + i += 1 + for _ in range(10): + with pytest.raises(StopIteration): + next(gen) + + """Retry exceptions""" + # TODO(sang): Enable it once retry is supported. + # @ray.remote + # class Actor: + # def __init__(self): + # self.should_kill = True + + # def should_kill(self): + # return self.should_kill + + # async def set(self, wait_s): + # await asyncio.sleep(wait_s) + # self.should_kill = False + + # @ray.remote(retry_exceptions=[ValueError], max_retries=10) + # def f(a): + # for i in range(5): + # should_kill = ray.get(a.should_kill.remote()) + # if i == 3 and should_kill: + # raise ValueError + # yield i + + # a = Actor.remote() + # gen = f.options(num_returns="streaming").remote(a) + # assert ray.get(next(gen)) == 0 + # assert ray.get(next(gen)) == 1 + # assert ray.get(next(gen)) == 2 + # a.set.remote(3) + # assert ray.get(next(gen)) == 3 + # assert ray.get(next(gen)) == 4 + # with pytest.raises(StopIteration): + # ray.get(next(gen)) + + """Cancel""" + + @ray.remote + def f(): + for i in range(5): + time.sleep(5) + yield i + + gen = f.options(num_returns="streaming").remote() + assert ray.get(next(gen)) == 0 + ray.cancel(gen) + with pytest.raises(ray.exceptions.RayTaskError) as e: + assert ray.get(next(gen)) == 1 + assert "was cancelled" in str(e.value) + with pytest.raises(StopIteration): + next(gen) + + +@pytest.mark.parametrize("crash_type", ["exception", "worker_crash"]) +def test_generator_streaming_no_leak_upon_failures( + monkeypatch, shutdown_only, crash_type +): + with monkeypatch.context() as m: + # defer for 10s for the second node. + m.setenv( + "RAY_testing_asio_delay_us", + "CoreWorkerService.grpc_server.ReportIntermediateTaskReturn=100000:1000000", + ) + ray.init(num_cpus=1) + + @ray.remote + def g(): + try: + gen = f.options(num_returns="streaming").remote() + for ref in gen: + print(ref) + ray.get(ref) + except Exception: + print("exception!") + del ref + + del gen + gc.collect() + + # Only the ref g is alive. + def verify(): + print(list_objects()) + return len(list_objects()) == 1 + + wait_for_condition(verify) + return True + + @ray.remote + def f(): + for i in range(10): + time.sleep(0.2) + if i == 4: + if crash_type == "exception": + raise ValueError + else: + sys.exit(9) + yield 2 + + for _ in range(5): + ray.get(g.remote()) + + +@pytest.mark.parametrize("use_actors", [False, True]) +@pytest.mark.parametrize("store_in_plasma", [False, True]) +def test_generator_streaming(shutdown_only, use_actors, store_in_plasma): + """Verify the generator is working in a streaming fashion.""" + ray.init() + remote_generator_fn = None + if use_actors: + + @ray.remote + class Generator: + def __init__(self): + pass + + def generator(self, num_returns, store_in_plasma): + for i in range(num_returns): + if store_in_plasma: + yield np.ones(1_000_000, dtype=np.int8) * i + else: + yield [i] + + g = Generator.remote() + remote_generator_fn = g.generator + else: + + @ray.remote(max_retries=0) + def generator(num_returns, store_in_plasma): + for i in range(num_returns): + if store_in_plasma: + yield np.ones(1_000_000, dtype=np.int8) * i + else: + yield [i] + + remote_generator_fn = generator + + """Verify num_returns="streaming" is streaming""" + gen = remote_generator_fn.options(num_returns="streaming").remote( + 3, store_in_plasma + ) + i = 0 + for ref in gen: + id = ref.hex() + if store_in_plasma: + expected = np.ones(1_000_000, dtype=np.int8) * i + assert np.array_equal(ray.get(ref), expected) + else: + expected = [i] + assert ray.get(ref) == expected + + del ref + + wait_for_condition( + lambda: len(list_objects(filters=[("object_id", "=", id)])) == 0 + ) + i += 1 + + +def test_generator_dist_chain(ray_start_cluster): + cluster = ray_start_cluster + cluster.add_node(num_cpus=0, object_store_memory=1 * 1024 * 1024 * 1024) + ray.init() + cluster.add_node(num_cpus=1) + cluster.add_node(num_cpus=1) + cluster.add_node(num_cpus=1) + cluster.add_node(num_cpus=1) + + @ray.remote + class ChainActor: + def __init__(self, child=None): + self.child = child + + def get_data(self): + if not self.child: + for _ in range(10): + time.sleep(0.1) + yield np.ones(5 * 1024 * 1024) + else: + for data in self.child.get_data.options( + num_returns="streaming" + ).remote(): + yield ray.get(data) + + chain_actor = ChainActor.remote() + chain_actor_2 = ChainActor.remote(chain_actor) + chain_actor_3 = ChainActor.remote(chain_actor_2) + chain_actor_4 = ChainActor.remote(chain_actor_3) + + for ref in chain_actor_4.get_data.options(num_returns="streaming").remote(): + assert np.array_equal(np.ones(5 * 1024 * 1024), ray.get(ref)) + del ref + + if __name__ == "__main__": import os diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index 71000748cb44..11e4778b297e 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -218,6 +218,12 @@ ObjectID TaskSpecification::ReturnId(size_t return_index) const { bool TaskSpecification::ReturnsDynamic() const { return message_->returns_dynamic(); } +// TODO(sang): Merge this with ReturnsDynamic once migrating to the +// streaming generator. +bool TaskSpecification::IsStreamingGenerator() const { + return message_->streaming_generator(); +} + std::vector TaskSpecification::DynamicReturnIds() const { RAY_CHECK(message_->returns_dynamic()); std::vector dynamic_return_ids; diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h index 3b29d2aadb3b..eea53f3d0348 100644 --- a/src/ray/common/task/task_spec.h +++ b/src/ray/common/task/task_spec.h @@ -262,6 +262,8 @@ class TaskSpecification : public MessageWrapper { bool ReturnsDynamic() const; + bool IsStreamingGenerator() const; + std::vector DynamicReturnIds() const; void AddDynamicReturnId(const ObjectID &dynamic_return_id); diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h index c260745b7161..1110504ea0b5 100644 --- a/src/ray/common/task/task_util.h +++ b/src/ray/common/task/task_util.h @@ -126,6 +126,7 @@ class TaskSpecBuilder { const rpc::Address &caller_address, uint64_t num_returns, bool returns_dynamic, + bool is_streaming_generator, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, const std::string &debugger_breakpoint, @@ -149,6 +150,7 @@ class TaskSpecBuilder { message_->mutable_caller_address()->CopyFrom(caller_address); message_->set_num_returns(num_returns); message_->set_returns_dynamic(returns_dynamic); + message_->set_streaming_generator(is_streaming_generator); message_->mutable_required_resources()->insert(required_resources.begin(), required_resources.end()); message_->mutable_required_placement_resources()->insert( diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index e3088741be3e..9f2d950db681 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1850,6 +1850,16 @@ void CoreWorker::BuildCommonTaskSpec( // is a generator of ObjectRefs. num_returns = 1; } + // TODO(sang): Remove this and integrate it to + // nun_returns == -1 once migrating to streaming + // generator. + bool is_streaming_generator = num_returns == -2; + if (is_streaming_generator) { + num_returns = 1; + // We are using the dynamic return if + // the streaming generator is used. + returns_dynamic = true; + } RAY_CHECK(num_returns >= 0); builder.SetCommonTaskSpec( task_id, @@ -1866,6 +1876,7 @@ void CoreWorker::BuildCommonTaskSpec( address, num_returns, returns_dynamic, + is_streaming_generator, required_resources, required_placement_resources, debugger_breakpoint, @@ -2662,7 +2673,8 @@ Status CoreWorker::ExecuteTask( application_error, defined_concurrency_groups, name_of_concurrency_group_to_execute, - /*is_reattempt=*/task_spec.AttemptNumber() > 0); + /*is_reattempt=*/task_spec.AttemptNumber() > 0, + /*is_streaming_generator*/ task_spec.IsStreamingGenerator()); // Get the reference counts for any IDs that we borrowed during this task, // remove the local reference for these IDs, and return the ref count info to @@ -2830,13 +2842,12 @@ bool CoreWorker::PinExistingReturnObject(const ObjectID &return_id, } } -ObjectID CoreWorker::AllocateDynamicReturnId() { +ObjectID CoreWorker::AllocateDynamicReturnId(const rpc::Address &owner_address) { const auto &task_spec = worker_context_.GetCurrentTask(); const auto return_id = ObjectID::FromIndex(task_spec->TaskId(), worker_context_.GetNextPutIndex()); AddLocalReference(return_id, ""); - reference_counter_->AddBorrowedObject( - return_id, ObjectID::Nil(), worker_context_.GetCurrentTask()->CallerAddress()); + reference_counter_->AddBorrowedObject(return_id, ObjectID::Nil(), owner_address); return return_id; } @@ -3257,7 +3268,11 @@ void CoreWorker::ProcessSubscribeForObjectEviction( // counter so that we know that it exists. const auto generator_id = ObjectID::FromBinary(message.generator_id()); RAY_CHECK(!generator_id.IsNil()); - reference_counter_->AddDynamicReturn(object_id, generator_id); + if (task_manager_->ObjectRefStreamExists(generator_id)) { + reference_counter_->OwnDynamicStreamingTaskReturnRef(object_id, generator_id); + } else { + reference_counter_->AddDynamicReturn(object_id, generator_id); + } } // Returns true if the object was present and the callback was added. It might have @@ -3391,7 +3406,11 @@ void CoreWorker::AddSpilledObjectLocationOwner( // object. Add the dynamically created object to our ref counter so that we // know that it exists. RAY_CHECK(!generator_id->IsNil()); - reference_counter_->AddDynamicReturn(object_id, *generator_id); + if (task_manager_->ObjectRefStreamExists(*generator_id)) { + reference_counter_->OwnDynamicStreamingTaskReturnRef(object_id, *generator_id); + } else { + reference_counter_->AddDynamicReturn(object_id, *generator_id); + } } auto reference_exists = @@ -3419,9 +3438,14 @@ void CoreWorker::AddObjectLocationOwner(const ObjectID &object_id, // until the task finishes. const auto &maybe_generator_id = task_manager_->TaskGeneratorId(object_id.TaskId()); if (!maybe_generator_id.IsNil()) { - // The task is a generator and may not have finished yet. Add the internal - // ObjectID so that we can update its location. - reference_counter_->AddDynamicReturn(object_id, maybe_generator_id); + if (task_manager_->ObjectRefStreamExists(maybe_generator_id)) { + // If the stream exists, it means it is a streaming generator. + reference_counter_->OwnDynamicStreamingTaskReturnRef(object_id, maybe_generator_id); + } else { + // The task is a generator and may not have finished yet. Add the internal + // ObjectID so that we can update its location. + reference_counter_->AddDynamicReturn(object_id, maybe_generator_id); + } RAY_UNUSED(reference_counter_->AddObjectLocation(object_id, node_id)); } } diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 4b27ce3f0b2a..1f8c725f7080 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -1022,9 +1022,11 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// object to the task caller and have the resulting ObjectRef be owned by /// the caller. This is in contrast to static allocation, where the caller /// decides at task invocation time how many returns the task should have. + /// \param[in] owner_address The address of the owner who will own this + /// dynamically generated object. /// /// \param[out] The ObjectID that the caller should use to store the object. - ObjectID AllocateDynamicReturnId(); + ObjectID AllocateDynamicReturnId(const rpc::Address &owner_address); /// Get a handle to an actor. /// diff --git a/src/ray/core_worker/core_worker_options.h b/src/ray/core_worker/core_worker_options.h index 157a3fbc53a3..3a8346776077 100644 --- a/src/ray/core_worker/core_worker_options.h +++ b/src/ray/core_worker/core_worker_options.h @@ -56,7 +56,10 @@ struct CoreWorkerOptions { // used for actor creation task. const std::vector &defined_concurrency_groups, const std::string name_of_concurrency_group_to_execute, - bool is_reattempt)>; + bool is_reattempt, + // True if the task is for streaming generator. + // TODO(sang): Remove it and combine it with dynamic returns. + bool is_streaming_generator)>; CoreWorkerOptions() : store_socket(""), diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc index 5afb92f853be..109dd0dc9686 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc @@ -124,7 +124,8 @@ Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNIEnv *env, std::string *application_error, const std::vector &defined_concurrency_groups, const std::string name_of_concurrency_group_to_execute, - bool is_reattempt) { + bool is_reattempt, + bool is_streaming_generator) { // These 2 parameters are used for Python only, and Java worker // will not use them. RAY_UNUSED(defined_concurrency_groups); diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 31a97db7bd4f..62dd91f4474b 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -570,6 +570,7 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { address, num_returns, false, + false, resources, resources, "", diff --git a/src/ray/core_worker/test/dependency_resolver_test.cc b/src/ray/core_worker/test/dependency_resolver_test.cc index 4d2406e006ec..5ca82b773b7a 100644 --- a/src/ray/core_worker/test/dependency_resolver_test.cc +++ b/src/ray/core_worker/test/dependency_resolver_test.cc @@ -44,6 +44,7 @@ TaskSpecification BuildTaskSpec(const std::unordered_map &r empty_address, 1, false, + false, resources, resources, serialized_runtime_env, diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index 61eb4370c3f4..498551b61334 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -65,6 +65,7 @@ TaskSpecification BuildTaskSpec(const std::unordered_map &r empty_address, 1, false, + false, resources, resources, serialized_runtime_env, diff --git a/src/ray/core_worker/test/mock_worker.cc b/src/ray/core_worker/test/mock_worker.cc index 1c782438ae28..7529a5255ee0 100644 --- a/src/ray/core_worker/test/mock_worker.cc +++ b/src/ray/core_worker/test/mock_worker.cc @@ -67,7 +67,8 @@ class MockWorker { std::string *application_error, const std::vector &defined_concurrency_groups, const std::string name_of_concurrency_group_to_execute, - bool is_reattempt) { + bool is_reattempt, + bool is_streaming_generator) { return ExecuteTask(caller_address, task_type, task_name, diff --git a/src/ray/gcs/test/gcs_test_util.h b/src/ray/gcs/test/gcs_test_util.h index fdef576c32e3..744b3ae2bb2a 100644 --- a/src/ray/gcs/test/gcs_test_util.h +++ b/src/ray/gcs/test/gcs_test_util.h @@ -58,6 +58,7 @@ struct Mocker { owner_address, 1, false, + false, required_resources, required_placement_resources, "", diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index b78e354768a4..bf10020a37b9 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -422,6 +422,12 @@ message TaskSpec { // This will be the actor creation task's task id for concurrent actors. Or // the main thread's task id for other cases. bytes submitter_task_id = 33; + // True if the task is a streaming generator. When it is true, + // returns_dynamic has to be true as well. This is a temporary flag + // until we migrate the generator implementatino to streaming. + // TODO(sang): Remove it once migrating to the streaming generator + // by default. + bool streaming_generator = 34; } message TaskInfoEntry { diff --git a/src/ray/raylet/scheduling/cluster_task_manager_test.cc b/src/ray/raylet/scheduling/cluster_task_manager_test.cc index de2bd227996c..d5e17ee0fe62 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager_test.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager_test.cc @@ -165,6 +165,7 @@ RayTask CreateTask( address, 0, /*returns_dynamic=*/false, + /*is_streaming_generator*/ false, required_resources, {}, "", From 3e0212e7ebbfcd6c58671cd686d78c4fad40fa15 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sun, 14 May 2023 21:06:39 -0700 Subject: [PATCH 13/77] Fix failing tests. Signed-off-by: SangBin Cho --- src/ray/core_worker/test/task_manager_test.cc | 6 +----- src/ray/protobuf/core_worker.proto | 3 ++- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/ray/core_worker/test/task_manager_test.cc b/src/ray/core_worker/test/task_manager_test.cc index 0fc669be9082..a565633ea672 100644 --- a/src/ray/core_worker/test/task_manager_test.cc +++ b/src/ray/core_worker/test/task_manager_test.cc @@ -29,8 +29,7 @@ namespace core { TaskSpecification CreateTaskHelper(uint64_t num_returns, std::vector dependencies, - bool dynamic_returns = false, - bool streaming_generator = false) { + bool dynamic_returns = false) { TaskSpecification task; task.GetMutableMessage().set_task_id(TaskID::FromRandom(JobID::FromInt(1)).Binary()); task.GetMutableMessage().set_num_returns(num_returns); @@ -42,9 +41,6 @@ TaskSpecification CreateTaskHelper(uint64_t num_returns, if (dynamic_returns) { task.GetMutableMessage().set_returns_dynamic(true); } - if (streaming_generator) { - task.GetMutableMessage().set_streaming_generator(true); - } return task; } diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index 605c3a5460de..65b0b077866d 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -424,7 +424,8 @@ service CoreWorkerService { /// the caller (subscriber). rpc PubsubLongPolling(PubsubLongPollingRequest) returns (PubsubLongPollingReply); // The RPC to report the intermediate task return to the caller. - rpc ReportIntermediateTaskReturn(ReportIntermediateTaskReturnRequest) returns (ReportIntermediateTaskReturnReply); + rpc ReportIntermediateTaskReturn(ReportIntermediateTaskReturnRequest) + returns (ReportIntermediateTaskReturnReply); /// The pubsub command batch request used by the subscriber. rpc PubsubCommandBatch(PubsubCommandBatchRequest) returns (PubsubCommandBatchReply); // Update the batched object location information to the ownership-based object From 7610474388b8178c727412ce417800a4684f4f67 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sun, 14 May 2023 21:08:27 -0700 Subject: [PATCH 14/77] Fix Signed-off-by: SangBin Cho --- src/ray/core_worker/test/task_manager_test.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/ray/core_worker/test/task_manager_test.cc b/src/ray/core_worker/test/task_manager_test.cc index a565633ea672..0fc669be9082 100644 --- a/src/ray/core_worker/test/task_manager_test.cc +++ b/src/ray/core_worker/test/task_manager_test.cc @@ -29,7 +29,8 @@ namespace core { TaskSpecification CreateTaskHelper(uint64_t num_returns, std::vector dependencies, - bool dynamic_returns = false) { + bool dynamic_returns = false, + bool streaming_generator = false) { TaskSpecification task; task.GetMutableMessage().set_task_id(TaskID::FromRandom(JobID::FromInt(1)).Binary()); task.GetMutableMessage().set_num_returns(num_returns); @@ -41,6 +42,9 @@ TaskSpecification CreateTaskHelper(uint64_t num_returns, if (dynamic_returns) { task.GetMutableMessage().set_returns_dynamic(true); } + if (streaming_generator) { + task.GetMutableMessage().set_streaming_generator(true); + } return task; } From aaa058255d95c0c18346bde5abfd793830025e3c Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sun, 14 May 2023 22:23:27 -0700 Subject: [PATCH 15/77] Fix a broken test. Signed-off-by: SangBin Cho --- src/ray/core_worker/test/task_manager_test.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/ray/core_worker/test/task_manager_test.cc b/src/ray/core_worker/test/task_manager_test.cc index a565633ea672..f8e1d92c212d 100644 --- a/src/ray/core_worker/test/task_manager_test.cc +++ b/src/ray/core_worker/test/task_manager_test.cc @@ -1425,8 +1425,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamEndtoEnd) { */ // Submit a task. rpc::Address caller_address; - auto spec = - CreateTaskHelper(1, {}, /*dynamic_returns=*/true, /*streaming_generator*/ true); + auto spec = CreateTaskHelper(1, {}, /*dynamic_returns=*/true); auto generator_id = spec.ReturnId(0); manager_.AddPendingTask(caller_address, spec, "", /*num_retries=*/0); // CREATE @@ -1601,8 +1600,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamOutofOrder) { /** * Test the case where the task return RPC is received out of order */ - auto spec = - CreateTaskHelper(1, {}, /*dynamic_returns=*/true, /*streaming_generator*/ true); + auto spec = CreateTaskHelper(1, {}, /*dynamic_returns=*/true); auto generator_id = spec.ReturnId(0); // CREATE manager_.CreateObjectRefStream(generator_id); From 74a2e31252517fb99898f5560f0204f694aee80d Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 16 May 2023 04:25:25 -0700 Subject: [PATCH 16/77] Finished async actor. Signed-off-by: SangBin Cho --- python/ray/_private/async_compat.py | 6 +- python/ray/_raylet.pxd | 10 +- python/ray/_raylet.pyx | 336 ++++++++++++++----- python/ray/actor.py | 8 +- python/ray/includes/libcoreworker.pxd | 6 +- python/ray/includes/unique_ids.pxd | 2 + python/ray/tests/test_streaming_generator.py | 250 +++++++++++++- python/ray/util/tracing/tracing_helper.py | 6 + src/ray/core_worker/core_worker.cc | 23 +- src/ray/core_worker/core_worker.h | 22 +- 10 files changed, 565 insertions(+), 104 deletions(-) diff --git a/python/ray/_private/async_compat.py b/python/ray/_private/async_compat.py index b1ecccf2590e..7821c6424d2f 100644 --- a/python/ray/_private/async_compat.py +++ b/python/ray/_private/async_compat.py @@ -19,10 +19,14 @@ def get_new_event_loop(): return asyncio.new_event_loop() +def is_async_func(func): + return inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func) + + def sync_to_async(func): """Convert a blocking function to async function""" - if inspect.iscoroutinefunction(func): + if is_async_func(func): return func async def wrapper(*args, **kwargs): diff --git a/python/ray/_raylet.pxd b/python/ray/_raylet.pxd index 28a7632ed8c1..38d35bde3ba6 100644 --- a/python/ray/_raylet.pxd +++ b/python/ray/_raylet.pxd @@ -39,7 +39,8 @@ from ray.includes.libcoreworker cimport ( from ray.includes.unique_ids cimport ( CObjectID, - CActorID + CActorID, + CTaskID, ) from ray.includes.function_descriptor cimport ( CFunctionDescriptor, @@ -154,6 +155,13 @@ cdef class CoreWorker: cdef python_scheduling_strategy_to_c( self, python_scheduling_strategy, CSchedulingStrategy *c_scheduling_strategy) + cdef CObjectID allocate_dynamic_return_id_for_generator( + self, + const CAddress &owner_address, + const CTaskID &task_id, + c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns, + generator_index, + is_async_actor) cdef class FunctionDescriptor: cdef: diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 6762ade4578e..a5f751b7cdb1 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -144,7 +144,11 @@ from ray.util.scheduling_strategies import ( import ray._private.ray_constants as ray_constants import ray.cloudpickle as ray_pickle from ray.core.generated.common_pb2 import ActorDiedErrorContext -from ray._private.async_compat import sync_to_async, get_new_event_loop +from ray._private.async_compat import ( + sync_to_async, + get_new_event_loop, + is_async_func +) from ray._private.client_mode_hook import disable_client_hook import ray._private.gcs_utils as gcs_utils import ray._private.memory_monitor as memory_monitor @@ -226,9 +230,15 @@ class StreamingObjectRefGenerator: up to N + 1 objects (if there's a system failure, the last object will contain a system level exception). """ - return self._next() + return self._next_sync() + + def __aiter__(self): + return self - def _next( + async def __anext__(self): + return await self._next_async() + + def _next_sync( self, timeout_s: float = -1, sleep_interval_s: float = 0.0001, @@ -252,6 +262,8 @@ class StreamingObjectRefGenerator: last object will contain a system level exception). Args: + is_async: True if the generator is used inside + an async event loop. False otherwise. timeout_s: If the next object is not ready within this timeout, it returns the nil object ref. sleep_interval_s: busy waiting interval. @@ -260,71 +272,124 @@ class StreamingObjectRefGenerator: available within this time, it will hard fail the generator. """ - obj = self._handle_next() + obj = self._handle_next_sync() last_time = time.time() # The generator ref will be None if the task succeeds. # It will contain an exception if the task fails by # a system error. while obj.is_nil(): - if self._generator_task_exception: - # The generator task has failed already. - # We raise StopIteration - # to conform the next interface in Python. - raise StopIteration - else: - # Otherwise, we should ray.get on the generator - # ref to find if the task has a system failure. - # Return the generator ref that contains the system - # error as soon as possible. - r, _ = ray.wait([self._generator_ref], timeout=0) - if len(r) > 0: - try: - ray.get(r) - except Exception as e: - # If it has failed, return the generator task ref - # so that the ref will raise an exception. - self._generator_task_exception = e - return self._generator_ref - finally: - if self._generator_task_completed_time is None: - self._generator_task_completed_time = time.time() - - # Currently, since the ordering of intermediate result report - # is not guaranteed, it is possible that althoug the task - # has succeeded, all of the object references are not reported - # (e.g., when there are network failures). - # If all the object refs are not reported to the generator - # within 30 seconds, we consider is as an unreconverable error. - if self._generator_task_completed_time: - if (time.time() - self._generator_task_completed_time - > unexpected_network_failure_timeout_s): - # It means the next wasn't reported although the task - # has been terminated 30 seconds ago. - self._generator_task_exception = AssertionError - assert False, "Unexpected network failure occured." - - if timeout_s != -1 and time.time() - last_time > timeout_s: - return ObjectRef.nil() - - # 100us busy waiting + error_ref = self._handle_error( + last_time, + timeout_s, + unexpected_network_failure_timeout_s) + if error_ref is not None: + return error_ref + time.sleep(sleep_interval_s) - obj = self._handle_next() + obj = self._handle_next_sync() + return obj - def _handle_next(self): + async def _next_async( + self, + timeout_s: float = -1, + sleep_interval_s: float = 0.0001, + unexpected_network_failure_timeout_s: float = 30): + """Same API as _next_sync, but it is for async context.""" + obj = await self._handle_next_async() + last_time = time.time() + + # The generator ref will be None if the task succeeds. + # It will contain an exception if the task fails by + # a system error. + while obj.is_nil(): + error_ref = self._handle_error( + last_time, + timeout_s, + unexpected_network_failure_timeout_s) + if error_ref is not None: + return error_ref + + await asyncio.sleep(sleep_interval_s) + obj = await self._handle_next_async() + + return obj + + async def _handle_next_async(self): try: - if hasattr(self.worker, "core_worker"): - obj = self.worker.core_worker.async_read_object_ref_stream( - self._generator_ref) - return obj - else: - raise ValueError( - "Cannot access the core worker. " - "Did you already shutdown Ray via ray.shutdown()?") + return self._handle_next() + except ObjectRefStreamEoFError: + raise StopAsyncIteration + + def _handle_next_sync(self): + try: + return self._handle_next() except ObjectRefStreamEoFError: raise StopIteration + def _handle_next(self): + if hasattr(self.worker, "core_worker"): + obj = self.worker.core_worker.async_read_object_ref_stream( + self._generator_ref) + return obj + else: + raise ValueError( + "Cannot access the core worker. " + "Did you already shutdown Ray via ray.shutdown()?") + + def _handle_error( + self, + last_time: int, + timeout_s: float, + unexpected_network_failure_timeout_s: float): + """Handle the error case of next APIs. + + Return None if there's no error. Returns a ref if + the ref is supposed to be return. + """ + if self._generator_task_exception: + # The generator task has failed already. + # We raise StopIteration + # to conform the next interface in Python. + raise StopIteration + else: + # Otherwise, we should ray.get on the generator + # ref to find if the task has a system failure. + # Return the generator ref that contains the system + # error as soon as possible. + r, _ = ray.wait([self._generator_ref], timeout=0) + if len(r) > 0: + try: + ray.get(r) + except Exception as e: + # If it has failed, return the generator task ref + # so that the ref will raise an exception. + self._generator_task_exception = e + return self._generator_ref + finally: + if self._generator_task_completed_time is None: + self._generator_task_completed_time = time.time() + + # Currently, since the ordering of intermediate result report + # is not guaranteed, it is possible that althoug the task + # has succeeded, all of the object references are not reported + # (e.g., when there are network failures). + # If all the object refs are not reported to the generator + # within 30 seconds, we consider is as an unreconverable error. + if self._generator_task_completed_time: + if (time.time() - self._generator_task_completed_time + > unexpected_network_failure_timeout_s): + # It means the next wasn't reported although the task + # has been terminated 30 seconds ago. + self._generator_task_exception = AssertionError + assert False, "Unexpected network failure occured." + + if timeout_s != -1 and time.time() - last_time > timeout_s: + return ObjectRef.nil() + + return None + def __del__(self): if hasattr(self.worker, "core_worker"): # NOTE: This can be called multiple times @@ -801,6 +866,7 @@ cdef store_task_errors( cdef execute_streaming_generator( generator, + c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns, const CObjectID &generator_id, CTaskType task_type, const CAddress &caller_address, @@ -811,6 +877,7 @@ cdef execute_streaming_generator( title, actor, actor_id, + name_of_concurrency_group_to_execute, c_bool *is_retryable_error, c_string *application_error): """Execute a given generator and streaming-report the @@ -849,18 +916,32 @@ cdef execute_streaming_generator( application error. """ worker = ray._private.worker.global_worker + # Generator task should only have 1 return object ref, + # which contains None or exceptions (if system error occurs). + assert returns != NULL + assert returns[0].size() == 1 + cdef: CoreWorker core_worker = worker.core_worker c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] intermediate_result generator_index = 0 - assert inspect.isgenerator(generator), ( - "execute_generator's first argument must be a generator." - ) + is_async = inspect.isasyncgen(generator) while True: try: - output = next(generator) + if is_async: + output = core_worker.run_async_func_or_coro_in_event_loop( + generator.__anext__(), + function_descriptor, + name_of_concurrency_group_to_execute) + else: + output = next(generator) + except AsyncioActorExit: + # Make the task handle this exception. + raise + except StopAsyncIteration: + break except StopIteration: break except Exception as e: @@ -885,8 +966,13 @@ cdef execute_streaming_generator( logger.debug("Task failed with unretryable exception:" " {}.".format(task_id), exc_info=True) - error_id = (CCoreWorkerProcess.GetCoreWorker() - .AllocateDynamicReturnId(caller_address)) + error_id = core_worker.allocate_dynamic_return_id_for_generator( + caller_address, + task_id.native(), + returns, + generator_index, + is_async, + ) intermediate_result.push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( error_id, shared_ptr[CRayObject]())) @@ -909,9 +995,13 @@ cdef execute_streaming_generator( break else: # Report the intermediate result if there was no error. - return_id = ( - CCoreWorkerProcess.GetCoreWorker().AllocateDynamicReturnId( - caller_address)) + return_id = core_worker.allocate_dynamic_return_id_for_generator( + caller_address, + task_id.native(), + returns, + generator_index, + is_async, + ) intermediate_result.push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( return_id, shared_ptr[CRayObject]())) @@ -1002,7 +1092,8 @@ cdef execute_dynamic_generator_and_store_task_outputs( # generate one additional ObjectRef. This last # ObjectRef will contain the error. error_id = (CCoreWorkerProcess.GetCoreWorker() - .AllocateDynamicReturnId(caller_address)) + .AllocateDynamicReturnId( + caller_address, CTaskID.Nil(), -1)) dynamic_returns[0].push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( error_id, shared_ptr[CRayObject]())) @@ -1110,10 +1201,10 @@ cdef void execute_task( if core_worker.current_actor_is_asyncio(): if len(inspect.getmembers( actor.__class__, - predicate=inspect.iscoroutinefunction)) == 0: + predicate=is_async_func)) == 0: error_message = ( - "Failed to create actor. The failure reason " - "is that you set the async flag, but the actor does not " + "Failed to create actor. You set the async flag, " + "but the actor does not " "have any coroutine functions.") raise RayActorError( ActorDiedErrorContext( @@ -1131,7 +1222,7 @@ cdef void execute_task( # transport with max_concurrency flag. increase_recursion_limit() - if inspect.iscoroutinefunction(function.method): + if is_async_func(function.method): async_function = function else: # Just execute the method if it's ray internal method. @@ -1139,10 +1230,15 @@ cdef void execute_task( return function(actor, *arguments, **kwarguments) async_function = sync_to_async(function) - return core_worker.run_async_func_in_event_loop( - async_function, function_descriptor, - name_of_concurrency_group_to_execute, actor, - *arguments, **kwarguments) + if inspect.isasyncgenfunction(function.method): + # The coroutine will be handled separately by + # execute_dynamic_generator_and_store_task_outputs + return async_function(actor, *arguments, **kwarguments) + else: + return core_worker.run_async_func_or_coro_in_event_loop( + async_function, function_descriptor, + name_of_concurrency_group_to_execute, actor, + *arguments, **kwarguments) return function(actor, *arguments, **kwarguments) @@ -1164,7 +1260,7 @@ cdef void execute_task( return (ray._private.worker.global_worker .deserialize_objects( metadata_pairs, object_refs)) - args = core_worker.run_async_func_in_event_loop( + args = core_worker.run_async_func_or_coro_in_event_loop( deserialize_args, function_descriptor, name_of_concurrency_group_to_execute) else: @@ -1218,7 +1314,8 @@ cdef void execute_task( # which is the generator task return. assert returns[0].size() == 1 - if not inspect.isgenerator(outputs): + if (not inspect.isgenerator(outputs) + and not inspect.isasyncgen(outputs)): raise ValueError( "Functions with " "@ray.remote(num_returns=\"streaming\" " @@ -1226,6 +1323,7 @@ cdef void execute_task( execute_streaming_generator( outputs, + returns, returns[0][0].first, # generator object ID. task_type, caller_address, @@ -1236,6 +1334,7 @@ cdef void execute_task( title, actor, actor_id, + name_of_concurrency_group_to_execute, is_retryable_error, application_error) # Streaming generator output is not used, so set it to None. @@ -3150,6 +3249,7 @@ cdef class CoreWorker: int64_t task_output_inlined_bytes int64_t num_returns = -1 shared_ptr[CRayObject] *return_ptr + num_outputs_stored = 0 if not ref_generator_id.IsNil(): # The task specified a dynamic number of return values. Determine @@ -3184,7 +3284,8 @@ cdef class CoreWorker: # enabled by default. while i >= returns[0].size(): return_id = (CCoreWorkerProcess.GetCoreWorker() - .AllocateDynamicReturnId(caller_address)) + .AllocateDynamicReturnId( + caller_address, CTaskID.Nil(), -1)) returns[0].push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( return_id, shared_ptr[CRayObject]())) @@ -3314,14 +3415,21 @@ cdef class CoreWorker: return self.eventloop_for_default_cg, self.thread_for_default_cg - def run_async_func_in_event_loop( - self, func, function_descriptor, specified_cgname, *args, **kwargs): - + def run_async_func_or_coro_in_event_loop( + self, func_or_coro, function_descriptor, specified_cgname, *args, **kwargs): + """Run the async function or coroutine to the event loop. + The event loop is running in a separate thread. + """ cdef: CFiberEvent event eventloop, async_thread = self.get_event_loop( function_descriptor, specified_cgname) - coroutine = func(*args, **kwargs) + + if inspect.isawaitable(func_or_coro): + coroutine = func_or_coro + else: + coroutine = func_or_coro(*args, **kwargs) + future = asyncio.run_coroutine_threadsafe(coroutine, eventloop) future.add_done_callback(lambda _: event.Notify()) with nogil: @@ -3456,6 +3564,74 @@ cdef class CoreWorker: CCoreWorkerProcess.GetCoreWorker() \ .RecordTaskLogEnd(out_end_offset, err_end_offset) + cdef CObjectID allocate_dynamic_return_id_for_generator( + self, + const CAddress &owner_address, + const CTaskID &task_id, + c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns, + generator_index, + is_async_actor): + """Allocate a dynamic return ID for a generator task. + + NOTE: When is_async_actor is True, + this API SHOULD NOT BE called + within an async actor's event IO thread. The caller MUST ensure + this for correctness. It is due to the limitation WorkerContext + API when async actor is used. + See https://github.com/ray-project/ray/issues/10324 for further details. + + Args: + owner_address: The address of the owner (caller) of the + generator task. + task_id: The task ID of the generator task. + returns: A list of return objects. This is used to + calculate the number of return values. + generator_index: The index of dynamically generated object + ref. + is_async_actor: True if the allocation is for async actor. + If async actor is used, we should calculate the + put_index ourselves. + """ + assert returns != NULL + cdef: + num_returns = returns[0].size() + + if is_async_actor: + # This part of code has a couple of assumptions. + # - This API is not called within an asyncio event loop + # thread. + # - Ray object ref is generated by incrementing put_index + # whenever a new return value is added or ray.put is called. + # + # When an async actor is used, it uses its own thread to execute + # async tasks. That means all the ray.put will use a put_index + # scoped to a asyncio event loop thread. + # This means the execution thread that this API will be called + # will only create "return" objects. That means if we use + # num_returns + genreator_index as a put_index, it is guaranteed + # to be unique. + # + # Why do we need it? + # + # We have to provide a put_index ourselves here because + # the current implementation only has 1 worker context at any + # given time, meaning WorkerContext::TaskID & WorkerContext::PutIndex + # both could be incorrect (duplicated) when this API is called. + return CCoreWorkerProcess.GetCoreWorker().AllocateDynamicReturnId( + owner_address, + task_id, + # Should add 1 because put index is always incremented + # before it is used. So if you have 1 return object + # the next index will be 2. + 1 + num_returns + generator_index, # put_index + ) + else: + return CCoreWorkerProcess.GetCoreWorker().AllocateDynamicReturnId( + owner_address, + CTaskID.Nil(), + -1 + ) + def create_object_ref_stream(self, ObjectRef generator_id): cdef: CObjectID c_generator_id = generator_id.native() diff --git a/python/ray/actor.py b/python/ray/actor.py index 91b88de7b947..25a39c97cc74 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -9,6 +9,7 @@ import ray._raylet from ray import ActorClassID, Language, cross_language from ray._private import ray_option_utils +from ray._private.async_compat import is_async_func from ray._private.auto_init_hook import auto_init_ray from ray._private.client_mode_hook import ( client_mode_convert_actor, @@ -752,12 +753,7 @@ def _remote(self, args=None, kwargs=None, **actor_options): kwargs = {} meta = self.__ray_metadata__ actor_has_async_methods = ( - len( - inspect.getmembers( - meta.modified_class, predicate=inspect.iscoroutinefunction - ) - ) - > 0 + len(inspect.getmembers(meta.modified_class, predicate=is_async_func)) > 0 ) is_asyncio = actor_has_async_methods diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index b8a5f14f9d6b..4b31551747b9 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -19,6 +19,7 @@ from ray.includes.unique_ids cimport ( CObjectID, CPlacementGroupID, CWorkerID, + ObjectIDIndexType, ) from ray.includes.common cimport ( @@ -151,7 +152,10 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: CRayStatus AsyncReadObjectRefStream( const CObjectID &generator_id, CObjectReference *object_ref_out) - CObjectID AllocateDynamicReturnId(const CAddress &owner_address) + CObjectID AllocateDynamicReturnId( + const CAddress &owner_address, + const CTaskID &task_id, + ObjectIDIndexType put_index) CJobID GetCurrentJobId() CTaskID GetCurrentTaskId() diff --git a/python/ray/includes/unique_ids.pxd b/python/ray/includes/unique_ids.pxd index cd7890119a40..2fb14e6322c0 100644 --- a/python/ray/includes/unique_ids.pxd +++ b/python/ray/includes/unique_ids.pxd @@ -173,3 +173,5 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil: @staticmethod CPlacementGroupID Of(CJobID job_id) + + ctypedef uint32_t ObjectIDIndexType diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py index 277d8226cb50..d652206f0886 100644 --- a/python/ray/tests/test_streaming_generator.py +++ b/python/ray/tests/test_streaming_generator.py @@ -1,3 +1,4 @@ +import asyncio import pytest import numpy as np import sys @@ -49,21 +50,21 @@ def test_streaming_object_ref_generator_basic_unit(mocked_worker): # Test when there's no new ref, it returns a nil. mocked_ray_wait.return_value = [], [generator_ref] - ref = generator._next(timeout_s=0) + ref = generator._next_sync(timeout_s=0) assert ref.is_nil() # When the new ref is available, next should return it. for _ in range(3): new_ref = ray.ObjectRef.from_random() c.async_read_object_ref_stream.return_value = new_ref - ref = generator._next(timeout_s=0) + ref = generator._next_sync(timeout_s=0) assert new_ref == ref # When async_read_object_ref_stream raises a # ObjectRefStreamEoFError, it should raise a stop iteration. c.async_read_object_ref_stream.side_effect = ObjectRefStreamEoFError("") # noqa with pytest.raises(StopIteration): - ref = generator._next(timeout_s=0) + ref = generator._next_sync(timeout_s=0) # Make sure we cannot serialize the generator. with pytest.raises(TypeError): @@ -89,7 +90,7 @@ def test_streaming_object_ref_generator_task_failed_unit(mocked_worker): mocked_ray_get.side_effect = WorkerCrashedError() c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() - ref = generator._next(timeout_s=0) + ref = generator._next_sync(timeout_s=0) # If the generator task fails by a systsem error, # meaning the ref will raise an exception # it should be returned. @@ -101,7 +102,7 @@ def test_streaming_object_ref_generator_task_failed_unit(mocked_worker): # raise stopIteration regardless of what # the ref contains now. with pytest.raises(StopIteration): - ref = generator._next(timeout_s=0) + ref = generator._next_sync(timeout_s=0) def test_streaming_object_ref_generator_network_failed_unit(mocked_worker): @@ -126,14 +127,20 @@ def test_streaming_object_ref_generator_network_failed_unit(mocked_worker): # unexpected_network_failure_timeout_s second, # it should fail. c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() - ref = generator._next(timeout_s=0, unexpected_network_failure_timeout_s=1) + ref = generator._next_sync( + timeout_s=0, unexpected_network_failure_timeout_s=1 + ) assert ref == ray.ObjectRef.nil() time.sleep(1) with pytest.raises(AssertionError): - generator._next(timeout_s=0, unexpected_network_failure_timeout_s=1) + generator._next_sync( + timeout_s=0, unexpected_network_failure_timeout_s=1 + ) # After that StopIteration should be raised. with pytest.raises(StopIteration): - generator._next(timeout_s=0, unexpected_network_failure_timeout_s=1) + generator._next_sync( + timeout_s=0, unexpected_network_failure_timeout_s=1 + ) def test_generator_basic(shutdown_only): @@ -366,6 +373,7 @@ def generator(num_returns, store_in_plasma): def test_generator_dist_chain(ray_start_cluster): + """E2E test to verify chain of generator works properly.""" cluster = ray_start_cluster cluster.add_node(num_cpus=0, object_store_memory=1 * 1024 * 1024 * 1024) ray.init() @@ -400,6 +408,232 @@ def get_data(self): del ref +@pytest.mark.parametrize("store_in_plasma", [False, True]) +def test_actor_streaming_generator(shutdown_only, store_in_plasma): + """Test actor/async actor with sync/async generator interfaces.""" + ray.init() + + @ray.remote + class Actor: + def f(self, ref): + for i in range(3): + yield i + + async def async_f(self, ref): + for i in range(3): + await asyncio.sleep(0.1) + yield i + + def g(self): + return 3 + + a = Actor.remote() + if store_in_plasma: + arr = np.random.rand(5 * 1024 * 1024) + else: + arr = 3 + + def verify_sync_task_executor(): + generator = a.f.options(num_returns="streaming").remote(ray.put(arr)) + # Verify it works with next. + assert isinstance(generator, StreamingObjectRefGenerator) + assert ray.get(next(generator)) == 0 + assert ray.get(next(generator)) == 1 + assert ray.get(next(generator)) == 2 + with pytest.raises(StopIteration): + ray.get(next(generator)) + + # Verify it works with for. + generator = a.f.options(num_returns="streaming").remote(ray.put(3)) + for index, ref in enumerate(generator): + assert index == ray.get(ref) + + def verify_async_task_executor(): + # Verify it works with next. + generator = a.async_f.options(num_returns="streaming").remote(ray.put(arr)) + assert isinstance(generator, StreamingObjectRefGenerator) + assert ray.get(next(generator)) == 0 + assert ray.get(next(generator)) == 1 + assert ray.get(next(generator)) == 2 + + # Verify it works with for. + generator = a.f.options(num_returns="streaming").remote(ray.put(3)) + for index, ref in enumerate(generator): + assert index == ray.get(ref) + + async def verify_sync_task_async_generator(): + # Verify anext + async_generator = a.f.options(num_returns="streaming").remote(ray.put(arr)) + assert isinstance(async_generator, StreamingObjectRefGenerator) + for expected in range(3): + ref = await async_generator.__anext__() + assert await ref == expected + with pytest.raises(StopAsyncIteration): + await async_generator.__anext__() + + # Verify async for. + async_generator = a.f.options(num_returns="streaming").remote(ray.put(arr)) + expected = 0 + async for ref in async_generator: + value = await ref + assert value == value + expected += 1 + + async def verify_async_task_async_generator(): + async_generator = a.async_f.options(num_returns="streaming").remote( + ray.put(arr) + ) + assert isinstance(async_generator, StreamingObjectRefGenerator) + for expected in range(3): + ref = await async_generator.__anext__() + assert await ref == expected + with pytest.raises(StopAsyncIteration): + await async_generator.__anext__() + + # Verify async for. + async_generator = a.async_f.options(num_returns="streaming").remote( + ray.put(arr) + ) + expected = 0 + async for value in async_generator: + value = await ref + assert value == value + expected += 1 + + verify_sync_task_executor() + verify_async_task_executor() + asyncio.run(verify_sync_task_async_generator()) + asyncio.run(verify_async_task_async_generator()) + + +def test_actor_generator_call_stats(shutdown_only): + """Verify that the private API _get_actor_call_stats + works correctly when the generator is used. + """ + pass + + +def test_streaming_generator_exception(shutdown_only): + # Verify the exceptions are correctly raised. + # Also verify the followup next will raise StopIteration. + ray.init() + + @ray.remote + class Actor: + def f(self): + raise ValueError + yield 1 # noqa + + async def async_f(self): + raise ValueError + yield 1 # noqa + + a = Actor.remote() + g = a.f.options(num_returns="streaming").remote() + with pytest.raises(ValueError): + ray.get(next(g)) + + with pytest.raises(StopIteration): + ray.get(next(g)) + + with pytest.raises(StopIteration): + ray.get(next(g)) + + g = a.async_f.options(num_returns="streaming").remote() + with pytest.raises(ValueError): + ray.get(next(g)) + + with pytest.raises(StopIteration): + ray.get(next(g)) + + with pytest.raises(StopIteration): + ray.get(next(g)) + + +def test_threaded_actor_generator(shutdown_only): + ray.init() + + @ray.remote(max_concurrency=10) + class Actor: + def f(self): + for i in range(30): + time.sleep(0.1) + yield np.ones(1024 * 1024) * i + + @ray.remote(max_concurrency=20) + class AsyncActor: + async def f(self): + for i in range(30): + await asyncio.sleep(0.1) + yield np.ones(1024 * 1024) * i + + async def main(): + a = Actor.remote() + asy = AsyncActor.remote() + + async def run(): + i = 0 + async for ref in a.f.options(num_returns="streaming").remote(): + val = ray.get(ref) + print(val) + print(ref) + assert np.array_equal(val, np.ones(1024 * 1024) * i) + i += 1 + del ref + + async def run2(): + i = 0 + async for ref in asy.f.options(num_returns="streaming").remote(): + val = await ref + print(ref) + print(val) + assert np.array_equal(val, np.ones(1024 * 1024) * i), ref + i += 1 + del ref + + coroutines = [run() for _ in range(10)] + coroutines = [run2() for _ in range(20)] + + await asyncio.gather(*coroutines) + + asyncio.run(main()) + + +def test_generator_dist_all_gather(ray_start_cluster): + cluster = ray_start_cluster + cluster.add_node(num_cpus=0, object_store_memory=1 * 1024 * 1024 * 1024) + ray.init() + cluster.add_node(num_cpus=1) + cluster.add_node(num_cpus=1) + cluster.add_node(num_cpus=1) + cluster.add_node(num_cpus=1) + + @ray.remote(num_cpus=1) + class Actor: + def __init__(self, child=None): + self.child = child + + def get_data(self): + for _ in range(10): + time.sleep(0.1) + yield np.ones(5 * 1024 * 1024) + + async def all_gather(): + actor = Actor.remote() + async for ref in actor.get_data.options(num_returns="streaming").remote(): + val = await ref + assert np.array_equal(np.ones(5 * 1024 * 1024), val) + del ref + + async def main(): + await asyncio.gather(all_gather(), all_gather(), all_gather(), all_gather()) + + asyncio.run(main()) + summary = ray._private.internal_api.memory_summary(stats_only=True) + print(summary) + # assert "Spilled" not in summary, summary + + if __name__ == "__main__": import os diff --git a/python/ray/util/tracing/tracing_helper.py b/python/ray/util/tracing/tracing_helper.py index 0c027c33a8e7..985edb0d612c 100644 --- a/python/ray/util/tracing/tracing_helper.py +++ b/python/ray/util/tracing/tracing_helper.py @@ -520,6 +520,12 @@ async def _resume_span( if is_static_method(_cls, name) or is_class_method(method): continue + if inspect.isgeneratorfunction(method) or inspect.isasyncgenfunction(method): + # Right now, this method somehow changes the signature of the method + # when they are generator. + # TODO(sang): Fix it. + continue + # Don't decorate the __del__ magic method. # It's because the __del__ can be called after Python # modules are garbage colleted, which means the modules diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 9f2d950db681..2a9c0c0a0bf4 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -2842,10 +2842,25 @@ bool CoreWorker::PinExistingReturnObject(const ObjectID &return_id, } } -ObjectID CoreWorker::AllocateDynamicReturnId(const rpc::Address &owner_address) { - const auto &task_spec = worker_context_.GetCurrentTask(); - const auto return_id = - ObjectID::FromIndex(task_spec->TaskId(), worker_context_.GetNextPutIndex()); +ObjectID CoreWorker::AllocateDynamicReturnId(const rpc::Address &owner_address, + const TaskID &task_id, + ObjectIDIndexType put_index) { + TaskID current_task_id; + if (task_id.IsNil()) { + const auto &task_spec = worker_context_.GetCurrentTask(); + current_task_id = task_spec->TaskId(); + } else { + current_task_id = task_id; + } + + ObjectIDIndexType current_put_index; + if (put_index == -1) { + current_put_index = worker_context_.GetNextPutIndex(); + } else { + current_put_index = put_index; + } + + const auto return_id = ObjectID::FromIndex(current_task_id, current_put_index); AddLocalReference(return_id, ""); reference_counter_->AddBorrowedObject(return_id, ObjectID::Nil(), owner_address); return return_id; diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 1f8c725f7080..d456442625d1 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -1022,11 +1022,27 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// object to the task caller and have the resulting ObjectRef be owned by /// the caller. This is in contrast to static allocation, where the caller /// decides at task invocation time how many returns the task should have. + /// + /// NOTE: Normally task_id and put_index it not necessary to be specified + /// because we can obtain them from the global worker context. However, + /// when the async actor uses this API, it cannot find the correct + /// worker context due to the implementation limitation. + /// In this case, the caller is responsible for providing the correct + /// task ID and index. + /// See https://github.com/ray-project/ray/issues/10324 for the further details. + /// /// \param[in] owner_address The address of the owner who will own this /// dynamically generated object. - /// - /// \param[out] The ObjectID that the caller should use to store the object. - ObjectID AllocateDynamicReturnId(const rpc::Address &owner_address); + /// \param[in] task_id The task id of the dynamically generated return ID. + /// If Nil() is specified, it will deduce the Task ID from the current + /// worker context. + /// \param[in] put_index The equivalent of the return value of + /// WorkerContext::GetNextPutIndex. + /// If -1 is specified, it will deduce the Task ID from the current + /// worker context. + ObjectID AllocateDynamicReturnId(const rpc::Address &owner_address, + const TaskID &task_id = TaskID::Nil(), + ObjectIDIndexType put_index = -1); /// Get a handle to an actor. /// From 8b9ba39b08dc5549435f2ec785a2795689b5c26a Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 16 May 2023 05:36:18 -0700 Subject: [PATCH 17/77] Add a unit test. Signed-off-by: SangBin Cho --- python/ray/tests/test_runtime_context.py | 21 +++++++++++ python/ray/tests/test_streaming_generator.py | 39 ++++++++++++++++---- 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/python/ray/tests/test_runtime_context.py b/python/ray/tests/test_runtime_context.py index 42b7b5fed42e..503ab6a10320 100644 --- a/python/ray/tests/test_runtime_context.py +++ b/python/ray/tests/test_runtime_context.py @@ -240,6 +240,27 @@ async def func(self): assert max(result["AysncActor.func"]["pending"] for result in results) == 3 +def test_actor_stats_async_actor_generator(ray_start_regular): + signal = SignalActor.remote() + + @ray.remote + class AysncActor: + async def func(self): + await signal.wait.remote() + yield ray.get_runtime_context()._get_actor_call_stats() + + actor = AysncActor.options(max_concurrency=3).remote() + gens = [actor.func.options(num_returns="streaming").remote() for _ in range(6)] + time.sleep(1) + signal.send.remote() + results = [] + for gen in gens: + for ref in gen: + results.append(ray.get(ref)) + assert max(result["AysncActor.func"]["running"] for result in results) == 3 + assert max(result["AysncActor.func"]["pending"] for result in results) == 3 + + # Use default filterwarnings behavior for this test @pytest.mark.filterwarnings("default") def test_ids(ray_start_regular): diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py index d652206f0886..73a3b04ceebf 100644 --- a/python/ray/tests/test_streaming_generator.py +++ b/python/ray/tests/test_streaming_generator.py @@ -143,6 +143,38 @@ def test_streaming_object_ref_generator_network_failed_unit(mocked_worker): ) +@pytest.mark.asyncio +async def test_streaming_object_ref_generator_unit_async(mocked_worker): + """ + Verify the basic case: + create a generator -> read values -> nothing more to read -> delete. + """ + with patch("ray.wait") as mocked_ray_wait: + c = mocked_worker.core_worker + generator_ref = ray.ObjectRef.from_random() + generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) + c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() + c.create_object_ref_stream.assert_called() + + # Test when there's no new ref, it returns a nil. + mocked_ray_wait.return_value = [], [generator_ref] + ref = await generator._next_async(timeout_s=0) + assert ref.is_nil() + + # When the new ref is available, next should return it. + for _ in range(3): + new_ref = ray.ObjectRef.from_random() + c.async_read_object_ref_stream.return_value = new_ref + ref = await generator._next_async(timeout_s=0) + assert new_ref == ref + + # When async_read_object_ref_stream raises a + # ObjectRefStreamEoFError, it should raise a stop iteration. + c.async_read_object_ref_stream.side_effect = ObjectRefStreamEoFError("") # noqa + with pytest.raises(StopAsyncIteration): + ref = await generator._next_async(timeout_s=0) + + def test_generator_basic(shutdown_only): ray.init(num_cpus=1) @@ -506,13 +538,6 @@ async def verify_async_task_async_generator(): asyncio.run(verify_async_task_async_generator()) -def test_actor_generator_call_stats(shutdown_only): - """Verify that the private API _get_actor_call_stats - works correctly when the generator is used. - """ - pass - - def test_streaming_generator_exception(shutdown_only): # Verify the exceptions are correctly raised. # Also verify the followup next will raise StopIteration. From a4b62ac21d9c078b868a6d30027c8c4d1af2aaf8 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 16 May 2023 05:36:38 -0700 Subject: [PATCH 18/77] done Signed-off-by: SangBin Cho --- python/ray/_raylet.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index a5f751b7cdb1..399a28a5225d 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -344,7 +344,7 @@ class StreamingObjectRefGenerator: timeout_s: float, unexpected_network_failure_timeout_s: float): """Handle the error case of next APIs. - + Return None if there's no error. Returns a ref if the ref is supposed to be return. """ From 9ed05d98d9ef19d9fcc13d505334bb71ebc1a387 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 16 May 2023 06:07:18 -0700 Subject: [PATCH 19/77] Addressed code review. Signed-off-by: SangBin Cho --- python/ray/includes/libcoreworker.pxd | 4 +- python/ray/tests/test_streaming_generator.py | 668 ++++++++++++++++++ src/ray/core_worker/core_worker.cc | 22 +- src/ray/core_worker/core_worker.h | 18 +- src/ray/core_worker/reference_count.cc | 4 + src/ray/core_worker/task_manager.cc | 51 +- src/ray/core_worker/task_manager.h | 34 +- .../core_worker/test/reference_count_test.cc | 51 ++ src/ray/core_worker/test/task_manager_test.cc | 82 +-- src/ray/protobuf/core_worker.proto | 10 +- src/ray/rpc/worker/core_worker_client.h | 8 +- src/ray/rpc/worker/core_worker_server.h | 4 +- 12 files changed, 841 insertions(+), 115 deletions(-) create mode 100644 python/ray/tests/test_streaming_generator.py diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 1b8d978cfcc6..2c71e1a5d809 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -236,11 +236,11 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: int64_t timeout_ms, c_vector[shared_ptr[CObjectLocation]] *results) CRayStatus TriggerGlobalGC() - CRayStatus ReportIntermediateTaskReturn( + CRayStatus ReportGeneratorItemReturns( const pair[CObjectID, shared_ptr[CRayObject]] &dynamic_return_object, const CObjectID &generator_id, const CAddress &caller_address, - int64_t idx, + int64_t item_index, c_bool finished) c_string MemoryUsageString() diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py new file mode 100644 index 000000000000..3344782c7be0 --- /dev/null +++ b/python/ray/tests/test_streaming_generator.py @@ -0,0 +1,668 @@ +import asyncio +import pytest +import numpy as np +import sys +import time +import gc + +from unittest.mock import patch, Mock + +import ray +from ray._private.test_utils import wait_for_condition +from ray.experimental.state.api import list_objects +from ray._raylet import StreamingObjectRefGenerator +from ray.cloudpickle import dumps +from ray.exceptions import ObjectRefStreamEoFError, WorkerCrashedError + + +class MockedWorker: + def __init__(self, mocked_core_worker): + self.core_worker = mocked_core_worker + + def reset_core_worker(self): + """Emulate the case ray.shutdown is called + and the core_worker instance is GC'ed. + """ + self.core_worker = None + + +@pytest.fixture +def mocked_worker(): + mocked_core_worker = Mock() + mocked_core_worker.async_read_object_ref_stream.return_value = None + mocked_core_worker.delete_object_ref_stream.return_value = None + mocked_core_worker.create_object_ref_stream.return_value = None + worker = MockedWorker(mocked_core_worker) + yield worker + + +def test_streaming_object_ref_generator_basic_unit(mocked_worker): + """ + Verify the basic case: + create a generator -> read values -> nothing more to read -> delete. + """ + with patch("ray.wait") as mocked_ray_wait: + c = mocked_worker.core_worker + generator_ref = ray.ObjectRef.from_random() + generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) + c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() + c.create_object_ref_stream.assert_called() + + # Test when there's no new ref, it returns a nil. + mocked_ray_wait.return_value = [], [generator_ref] + ref = generator._next_sync(timeout_s=0) + assert ref.is_nil() + + # When the new ref is available, next should return it. + for _ in range(3): + new_ref = ray.ObjectRef.from_random() + c.async_read_object_ref_stream.return_value = new_ref + ref = generator._next_sync(timeout_s=0) + assert new_ref == ref + + # When async_read_object_ref_stream raises a + # ObjectRefStreamEoFError, it should raise a stop iteration. + c.async_read_object_ref_stream.side_effect = ObjectRefStreamEoFError("") # noqa + with pytest.raises(StopIteration): + ref = generator._next_sync(timeout_s=0) + + # Make sure we cannot serialize the generator. + with pytest.raises(TypeError): + dumps(generator) + + del generator + c.delete_object_ref_stream.assert_called() + + +def test_streaming_object_ref_generator_task_failed_unit(mocked_worker): + """ + Verify when a task is failed by a system error, + the generator ref is returned. + """ + with patch("ray.get") as mocked_ray_get: + with patch("ray.wait") as mocked_ray_wait: + c = mocked_worker.core_worker + generator_ref = ray.ObjectRef.from_random() + generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) + + # Simulate the worker failure happens. + mocked_ray_wait.return_value = [generator_ref], [] + mocked_ray_get.side_effect = WorkerCrashedError() + + c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() + ref = generator._next_sync(timeout_s=0) + # If the generator task fails by a systsem error, + # meaning the ref will raise an exception + # it should be returned. + print(ref) + print(generator_ref) + assert ref == generator_ref + + # Once exception is raised, it should always + # raise stopIteration regardless of what + # the ref contains now. + with pytest.raises(StopIteration): + ref = generator._next_sync(timeout_s=0) + + +def test_streaming_object_ref_generator_network_failed_unit(mocked_worker): + """ + Verify when a task is finished, but if the next ref is not available + on time, it raises an assertion error. + + TODO(sang): Once we move the task subimssion path to use pubsub + to guarantee the ordering, we don't need this test anymore. + """ + with patch("ray.get") as mocked_ray_get: + with patch("ray.wait") as mocked_ray_wait: + c = mocked_worker.core_worker + generator_ref = ray.ObjectRef.from_random() + generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) + + # Simulate the task has finished. + mocked_ray_wait.return_value = [generator_ref], [] + mocked_ray_get.return_value = None + + # If StopIteration is not raised within + # unexpected_network_failure_timeout_s second, + # it should fail. + c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() + ref = generator._next_sync( + timeout_s=0, unexpected_network_failure_timeout_s=1 + ) + assert ref == ray.ObjectRef.nil() + time.sleep(1) + with pytest.raises(AssertionError): + generator._next_sync( + timeout_s=0, unexpected_network_failure_timeout_s=1 + ) + # After that StopIteration should be raised. + with pytest.raises(StopIteration): + generator._next_sync( + timeout_s=0, unexpected_network_failure_timeout_s=1 + ) + + +@pytest.mark.asyncio +async def test_streaming_object_ref_generator_unit_async(mocked_worker): + """ + Verify the basic case: + create a generator -> read values -> nothing more to read -> delete. + """ + with patch("ray.wait") as mocked_ray_wait: + c = mocked_worker.core_worker + generator_ref = ray.ObjectRef.from_random() + generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) + c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() + c.create_object_ref_stream.assert_called() + + # Test when there's no new ref, it returns a nil. + mocked_ray_wait.return_value = [], [generator_ref] + ref = await generator._next_async(timeout_s=0) + assert ref.is_nil() + + # When the new ref is available, next should return it. + for _ in range(3): + new_ref = ray.ObjectRef.from_random() + c.async_read_object_ref_stream.return_value = new_ref + ref = await generator._next_async(timeout_s=0) + assert new_ref == ref + + # When async_read_object_ref_stream raises a + # ObjectRefStreamEoFError, it should raise a stop iteration. + c.async_read_object_ref_stream.side_effect = ObjectRefStreamEoFError("") # noqa + with pytest.raises(StopAsyncIteration): + ref = await generator._next_async(timeout_s=0) + + +def test_generator_basic(shutdown_only): + ray.init(num_cpus=1) + + """Basic cases""" + + @ray.remote + def f(): + for i in range(5): + yield i + + gen = f.options(num_returns="streaming").remote() + i = 0 + for ref in gen: + print(ray.get(ref)) + assert i == ray.get(ref) + del ref + i += 1 + + """Exceptions""" + + @ray.remote + def f(): + for i in range(5): + if i == 2: + raise ValueError + yield i + + gen = f.options(num_returns="streaming").remote() + ray.get(next(gen)) + ray.get(next(gen)) + with pytest.raises(ray.exceptions.RayTaskError) as e: + ray.get(next(gen)) + print(str(e.value)) + with pytest.raises(StopIteration): + ray.get(next(gen)) + with pytest.raises(StopIteration): + ray.get(next(gen)) + + """Generator Task failure""" + + @ray.remote + class A: + def getpid(self): + import os + + return os.getpid() + + def f(self): + for i in range(5): + time.sleep(0.1) + yield i + + a = A.remote() + i = 0 + gen = a.f.options(num_returns="streaming").remote() + i = 0 + for ref in gen: + if i == 2: + ray.kill(a) + if i == 3: + with pytest.raises(ray.exceptions.RayActorError) as e: + ray.get(ref) + assert "The actor is dead because it was killed by `ray.kill`" in str( + e.value + ) + break + assert i == ray.get(ref) + del ref + i += 1 + for _ in range(10): + with pytest.raises(StopIteration): + next(gen) + + """Retry exceptions""" + # TODO(sang): Enable it once retry is supported. + # @ray.remote + # class Actor: + # def __init__(self): + # self.should_kill = True + + # def should_kill(self): + # return self.should_kill + + # async def set(self, wait_s): + # await asyncio.sleep(wait_s) + # self.should_kill = False + + # @ray.remote(retry_exceptions=[ValueError], max_retries=10) + # def f(a): + # for i in range(5): + # should_kill = ray.get(a.should_kill.remote()) + # if i == 3 and should_kill: + # raise ValueError + # yield i + + # a = Actor.remote() + # gen = f.options(num_returns="streaming").remote(a) + # assert ray.get(next(gen)) == 0 + # assert ray.get(next(gen)) == 1 + # assert ray.get(next(gen)) == 2 + # a.set.remote(3) + # assert ray.get(next(gen)) == 3 + # assert ray.get(next(gen)) == 4 + # with pytest.raises(StopIteration): + # ray.get(next(gen)) + + """Cancel""" + + @ray.remote + def f(): + for i in range(5): + time.sleep(5) + yield i + + gen = f.options(num_returns="streaming").remote() + assert ray.get(next(gen)) == 0 + ray.cancel(gen) + with pytest.raises(ray.exceptions.RayTaskError) as e: + assert ray.get(next(gen)) == 1 + assert "was cancelled" in str(e.value) + with pytest.raises(StopIteration): + next(gen) + + +@pytest.mark.parametrize("crash_type", ["exception", "worker_crash"]) +def test_generator_streaming_no_leak_upon_failures( + monkeypatch, shutdown_only, crash_type +): + with monkeypatch.context() as m: + # defer for 10s for the second node. + m.setenv( + "RAY_testing_asio_delay_us", + "CoreWorkerService.grpc_server.ReportGeneratorItemReturns=100000:1000000", + ) + ray.init(num_cpus=1) + + @ray.remote + def g(): + try: + gen = f.options(num_returns="streaming").remote() + for ref in gen: + print(ref) + ray.get(ref) + except Exception: + print("exception!") + del ref + + del gen + gc.collect() + + # Only the ref g is alive. + def verify(): + print(list_objects()) + return len(list_objects()) == 1 + + wait_for_condition(verify) + return True + + @ray.remote + def f(): + for i in range(10): + time.sleep(0.2) + if i == 4: + if crash_type == "exception": + raise ValueError + else: + sys.exit(9) + yield 2 + + for _ in range(5): + ray.get(g.remote()) + + +@pytest.mark.parametrize("use_actors", [False, True]) +@pytest.mark.parametrize("store_in_plasma", [False, True]) +def test_generator_streaming(shutdown_only, use_actors, store_in_plasma): + """Verify the generator is working in a streaming fashion.""" + ray.init() + remote_generator_fn = None + if use_actors: + + @ray.remote + class Generator: + def __init__(self): + pass + + def generator(self, num_returns, store_in_plasma): + for i in range(num_returns): + if store_in_plasma: + yield np.ones(1_000_000, dtype=np.int8) * i + else: + yield [i] + + g = Generator.remote() + remote_generator_fn = g.generator + else: + + @ray.remote(max_retries=0) + def generator(num_returns, store_in_plasma): + for i in range(num_returns): + if store_in_plasma: + yield np.ones(1_000_000, dtype=np.int8) * i + else: + yield [i] + + remote_generator_fn = generator + + """Verify num_returns="streaming" is streaming""" + gen = remote_generator_fn.options(num_returns="streaming").remote( + 3, store_in_plasma + ) + i = 0 + for ref in gen: + id = ref.hex() + if store_in_plasma: + expected = np.ones(1_000_000, dtype=np.int8) * i + assert np.array_equal(ray.get(ref), expected) + else: + expected = [i] + assert ray.get(ref) == expected + + del ref + + wait_for_condition( + lambda: len(list_objects(filters=[("object_id", "=", id)])) == 0 + ) + i += 1 + + +def test_generator_dist_chain(ray_start_cluster): + """E2E test to verify chain of generator works properly.""" + cluster = ray_start_cluster + cluster.add_node(num_cpus=0, object_store_memory=1 * 1024 * 1024 * 1024) + ray.init() + cluster.add_node(num_cpus=1) + cluster.add_node(num_cpus=1) + cluster.add_node(num_cpus=1) + cluster.add_node(num_cpus=1) + + @ray.remote + class ChainActor: + def __init__(self, child=None): + self.child = child + + def get_data(self): + if not self.child: + for _ in range(10): + time.sleep(0.1) + yield np.ones(5 * 1024 * 1024) + else: + for data in self.child.get_data.options( + num_returns="streaming" + ).remote(): + yield ray.get(data) + + chain_actor = ChainActor.remote() + chain_actor_2 = ChainActor.remote(chain_actor) + chain_actor_3 = ChainActor.remote(chain_actor_2) + chain_actor_4 = ChainActor.remote(chain_actor_3) + + for ref in chain_actor_4.get_data.options(num_returns="streaming").remote(): + assert np.array_equal(np.ones(5 * 1024 * 1024), ray.get(ref)) + del ref + + +@pytest.mark.parametrize("store_in_plasma", [False, True]) +def test_actor_streaming_generator(shutdown_only, store_in_plasma): + """Test actor/async actor with sync/async generator interfaces.""" + ray.init() + + @ray.remote + class Actor: + def f(self, ref): + for i in range(3): + yield i + + async def async_f(self, ref): + for i in range(3): + await asyncio.sleep(0.1) + yield i + + def g(self): + return 3 + + a = Actor.remote() + if store_in_plasma: + arr = np.random.rand(5 * 1024 * 1024) + else: + arr = 3 + + def verify_sync_task_executor(): + generator = a.f.options(num_returns="streaming").remote(ray.put(arr)) + # Verify it works with next. + assert isinstance(generator, StreamingObjectRefGenerator) + assert ray.get(next(generator)) == 0 + assert ray.get(next(generator)) == 1 + assert ray.get(next(generator)) == 2 + with pytest.raises(StopIteration): + ray.get(next(generator)) + + # Verify it works with for. + generator = a.f.options(num_returns="streaming").remote(ray.put(3)) + for index, ref in enumerate(generator): + assert index == ray.get(ref) + + def verify_async_task_executor(): + # Verify it works with next. + generator = a.async_f.options(num_returns="streaming").remote(ray.put(arr)) + assert isinstance(generator, StreamingObjectRefGenerator) + assert ray.get(next(generator)) == 0 + assert ray.get(next(generator)) == 1 + assert ray.get(next(generator)) == 2 + + # Verify it works with for. + generator = a.f.options(num_returns="streaming").remote(ray.put(3)) + for index, ref in enumerate(generator): + assert index == ray.get(ref) + + async def verify_sync_task_async_generator(): + # Verify anext + async_generator = a.f.options(num_returns="streaming").remote(ray.put(arr)) + assert isinstance(async_generator, StreamingObjectRefGenerator) + for expected in range(3): + ref = await async_generator.__anext__() + assert await ref == expected + with pytest.raises(StopAsyncIteration): + await async_generator.__anext__() + + # Verify async for. + async_generator = a.f.options(num_returns="streaming").remote(ray.put(arr)) + expected = 0 + async for ref in async_generator: + value = await ref + assert value == value + expected += 1 + + async def verify_async_task_async_generator(): + async_generator = a.async_f.options(num_returns="streaming").remote( + ray.put(arr) + ) + assert isinstance(async_generator, StreamingObjectRefGenerator) + for expected in range(3): + ref = await async_generator.__anext__() + assert await ref == expected + with pytest.raises(StopAsyncIteration): + await async_generator.__anext__() + + # Verify async for. + async_generator = a.async_f.options(num_returns="streaming").remote( + ray.put(arr) + ) + expected = 0 + async for value in async_generator: + value = await ref + assert value == value + expected += 1 + + verify_sync_task_executor() + verify_async_task_executor() + asyncio.run(verify_sync_task_async_generator()) + asyncio.run(verify_async_task_async_generator()) + + +def test_streaming_generator_exception(shutdown_only): + # Verify the exceptions are correctly raised. + # Also verify the followup next will raise StopIteration. + ray.init() + + @ray.remote + class Actor: + def f(self): + raise ValueError + yield 1 # noqa + + async def async_f(self): + raise ValueError + yield 1 # noqa + + a = Actor.remote() + g = a.f.options(num_returns="streaming").remote() + with pytest.raises(ValueError): + ray.get(next(g)) + + with pytest.raises(StopIteration): + ray.get(next(g)) + + with pytest.raises(StopIteration): + ray.get(next(g)) + + g = a.async_f.options(num_returns="streaming").remote() + with pytest.raises(ValueError): + ray.get(next(g)) + + with pytest.raises(StopIteration): + ray.get(next(g)) + + with pytest.raises(StopIteration): + ray.get(next(g)) + + +def test_threaded_actor_generator(shutdown_only): + ray.init() + + @ray.remote(max_concurrency=10) + class Actor: + def f(self): + for i in range(30): + time.sleep(0.1) + yield np.ones(1024 * 1024) * i + + @ray.remote(max_concurrency=20) + class AsyncActor: + async def f(self): + for i in range(30): + await asyncio.sleep(0.1) + yield np.ones(1024 * 1024) * i + + async def main(): + a = Actor.remote() + asy = AsyncActor.remote() + + async def run(): + i = 0 + async for ref in a.f.options(num_returns="streaming").remote(): + val = ray.get(ref) + print(val) + print(ref) + assert np.array_equal(val, np.ones(1024 * 1024) * i) + i += 1 + del ref + + async def run2(): + i = 0 + async for ref in asy.f.options(num_returns="streaming").remote(): + val = await ref + print(ref) + print(val) + assert np.array_equal(val, np.ones(1024 * 1024) * i), ref + i += 1 + del ref + + coroutines = [run() for _ in range(10)] + coroutines = [run2() for _ in range(20)] + + await asyncio.gather(*coroutines) + + asyncio.run(main()) + + +def test_generator_dist_all_gather(ray_start_cluster): + cluster = ray_start_cluster + cluster.add_node(num_cpus=0, object_store_memory=1 * 1024 * 1024 * 1024) + ray.init() + cluster.add_node(num_cpus=1) + cluster.add_node(num_cpus=1) + cluster.add_node(num_cpus=1) + cluster.add_node(num_cpus=1) + + @ray.remote(num_cpus=1) + class Actor: + def __init__(self, child=None): + self.child = child + + def get_data(self): + for _ in range(10): + time.sleep(0.1) + yield np.ones(5 * 1024 * 1024) + + async def all_gather(): + actor = Actor.remote() + async for ref in actor.get_data.options(num_returns="streaming").remote(): + val = await ref + assert np.array_equal(np.ones(5 * 1024 * 1024), val) + del ref + + async def main(): + await asyncio.gather(all_gather(), all_gather(), all_gather(), all_gather()) + + asyncio.run(main()) + summary = ray._private.internal_api.memory_summary(stats_only=True) + print(summary) + # assert "Spilled" not in summary, summary + + +if __name__ == "__main__": + import os + + if os.environ.get("PARALLEL_CI"): + sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) + else: + sys.exit(pytest.main(["-sv", __file__])) diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 95b355e683cb..4081613fde65 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -2832,17 +2832,17 @@ ObjectID CoreWorker::AllocateDynamicReturnId() { return return_id; } -Status CoreWorker::ReportIntermediateTaskReturn( +Status CoreWorker::ReportGeneratorItemReturns( const std::pair> &dynamic_return_object, const ObjectID &generator_id, const rpc::Address &caller_address, - int64_t idx, + int64_t item_index, bool finished) { - RAY_LOG(DEBUG) << "Write the object ref stream, index: " << idx + RAY_LOG(DEBUG) << "Write the object ref stream, index: " << item_index << " finished: " << finished << ", id: " << dynamic_return_object.first; - rpc::ReportIntermediateTaskReturnRequest request; + rpc::ReportGeneratorItemReturnsRequest request; request.mutable_worker_addr()->CopyFrom(rpc_address_); - request.set_idx(idx); + request.set_item_index(item_index); request.set_finished(finished); request.set_generator_id(generator_id.Binary()); auto client = core_worker_client_pool_->GetOrConnect(caller_address); @@ -2863,9 +2863,9 @@ Status CoreWorker::ReportIntermediateTaskReturn( memory_store_->Delete(deleted); } - client->ReportIntermediateTaskReturn( + client->ReportGeneratorItemReturns( request, - [](const Status &status, const rpc::ReportIntermediateTaskReturnReply &reply) { + [](const Status &status, const rpc::ReportGeneratorItemReturnsReply &reply) { if (!status.ok()) { // TODO(sang): Handle network error more gracefully. RAY_LOG(ERROR) << "Failed to send the object ref."; @@ -2874,11 +2874,11 @@ Status CoreWorker::ReportIntermediateTaskReturn( return Status::OK(); } -void CoreWorker::HandleReportIntermediateTaskReturn( - rpc::ReportIntermediateTaskReturnRequest request, - rpc::ReportIntermediateTaskReturnReply *reply, +void CoreWorker::HandleReportGeneratorItemReturns( + rpc::ReportGeneratorItemReturnsRequest request, + rpc::ReportGeneratorItemReturnsReply *reply, rpc::SendReplyCallback send_reply_callback) { - task_manager_->HandleReportIntermediateTaskReturn(request); + task_manager_->HandleReportGeneratorItemReturns(request); send_reply_callback(Status::OK(), nullptr, nullptr); } diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index a3862421085a..e9430c213e57 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -709,14 +709,14 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// Report the task caller at caller_address that the intermediate /// task return. It means if this API is used, the caller will be notified /// the task return before the current task is terminated. The caller must - /// implement HandleReportIntermediateTaskReturn API endpoint + /// implement HandleReportGeneratorItemReturns API endpoint /// to handle the intermediate result report. /// This API makes sense only for a generator task /// (task that can return multiple intermediate /// result before the task terminates). /// /// NOTE: The API doesn't guarantee the ordering of the report. The - /// caller is supposed to reorder the report based on the idx. + /// caller is supposed to reorder the report based on the item_index. /// /// \param[in] dynamic_return_object A intermediate ray object to report /// to the caller before the task terminates. This object must have been @@ -725,24 +725,24 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// task. /// \param[in] caller_address The address of the caller of the current task /// that created a generator_id. - /// \param[in] idx The index of the task return. It is used to reorder the + /// \param[in] item_index The index of the task return. It is used to reorder the /// report from the caller side. /// \param[in] finished True indicates there's going to be no more intermediate /// task return. When finished is provided dynamic_return_object input will be /// ignored. - Status ReportIntermediateTaskReturn( + Status ReportGeneratorItemReturns( const std::pair> &dynamic_return_object, const ObjectID &generator_id, const rpc::Address &caller_address, - int64_t idx, + int64_t item_index, bool finished); /// Implements gRPC server handler. /// If an executor can generator task return before the task is finished, - /// it invokes this endpoint via ReportIntermediateTaskReturn RPC. - void HandleReportIntermediateTaskReturn( - rpc::ReportIntermediateTaskReturnRequest request, - rpc::ReportIntermediateTaskReturnReply *reply, + /// it invokes this endpoint via ReportGeneratorItemReturns RPC. + void HandleReportGeneratorItemReturns( + rpc::ReportGeneratorItemReturnsRequest request, + rpc::ReportGeneratorItemReturnsReply *reply, rpc::SendReplyCallback send_reply_callback) override; /// Get a string describing object store memory usage for debugging purposes. diff --git a/src/ray/core_worker/reference_count.cc b/src/ray/core_worker/reference_count.cc index ede9cd844705..584cca3d85e2 100644 --- a/src/ray/core_worker/reference_count.cc +++ b/src/ray/core_worker/reference_count.cc @@ -250,6 +250,10 @@ void ReferenceCounter::OwnDynamicStreamingTaskReturnRef(const ObjectID &object_i // Generator object already went out of scope. // It means the generator is already GC'ed. No need to // update the reference. + RAY_LOG(DEBUG) + << "Ignore OwnDynamicStreamingTaskReturnRef. The dynamic return reference " + << object_id << " is registered after the generator id " << generator_id + << " went out of scope."; return; } RAY_LOG(DEBUG) << "Adding dynamic return " << object_id diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index f2afe87c9bb5..b2fe7373d213 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -30,7 +30,7 @@ const int64_t kTaskFailureThrottlingThreshold = 50; // Throttle task failure logs to once this interval. const int64_t kTaskFailureLoggingFrequencyMillis = 5000; -Status ObjectRefStream::AsyncReadNext(ObjectID *object_id_out) { +Status ObjectRefStream::TryReadNextItem(ObjectID *object_id_out) { bool is_eof_set = last_ != -1; if (is_eof_set && curr_ >= last_) { RAY_LOG(DEBUG) << "ObjectRefStream of an id " << generator_id_ @@ -39,8 +39,8 @@ Status ObjectRefStream::AsyncReadNext(ObjectID *object_id_out) { return Status::ObjectRefStreamEoF(""); } - auto it = idx_to_refs_.find(curr_); - if (it != idx_to_refs_.end()) { + auto it = item_index_to_refs_.find(curr_); + if (it != item_index_to_refs_.end()) { // If the current index has been written, // return the object ref. // The returned object ref will always have a ref count of 1. @@ -61,18 +61,18 @@ Status ObjectRefStream::AsyncReadNext(ObjectID *object_id_out) { return Status::OK(); } -bool ObjectRefStream::Write(const ObjectID &object_id, int64_t idx) { +bool ObjectRefStream::InsertToStream(const ObjectID &object_id, int64_t item_index) { if (last_ != -1) { RAY_CHECK(curr_ <= last_); } - if (idx < curr_) { + if (item_index < curr_) { // Index is already used. Don't write it to the stream. return false; } - auto it = idx_to_refs_.find(idx); - if (it != idx_to_refs_.end()) { + auto it = item_index_to_refs_.find(item_index); + if (it != item_index_to_refs_.end()) { // It means the when a task is retried it returns a different object id // for the same index, which means the task was not deterministic. // Fail the owner if it happens. @@ -83,11 +83,11 @@ bool ObjectRefStream::Write(const ObjectID &object_id, int64_t idx) { << ". It means a undeterministic task has been retried. Disable the retry " "feature using `max_retries=0` (task) or `max_task_retries=0` (actor)."; } - idx_to_refs_.emplace(idx, object_id); + item_index_to_refs_.emplace(item_index, object_id); return true; } -void ObjectRefStream::WriteEoF(int64_t idx) { last_ = idx; } +void ObjectRefStream::MarkEndOfStream(int64_t item_index) { last_ = item_index; } std::vector TaskManager::AddPendingTask( const rpc::Address &caller_address, @@ -382,7 +382,7 @@ void TaskManager::DelObjectRefStream(const ObjectID &generator_id) { while (true) { ObjectID object_id; - const auto &status = AsyncReadObjectRefStreamInternal(generator_id, &object_id); + const auto &status = TryReadObjectRefStreamInternal(generator_id, &object_id); // keyError means the stream reaches to EoF. if (status.IsObjectRefStreamEoF()) { @@ -410,21 +410,22 @@ void TaskManager::DelObjectRefStream(const ObjectID &generator_id) { } } -Status TaskManager::AsyncReadObjectRefStreamInternal(const ObjectID &generator_id, - ObjectID *object_id_out) { +Status TaskManager::TryReadObjectRefStreamInternal(const ObjectID &generator_id, + ObjectID *object_id_out) { RAY_CHECK(object_id_out != nullptr); auto stream_it = object_ref_streams_.find(generator_id); RAY_CHECK(stream_it != object_ref_streams_.end()) - << "AsyncReadObjectRefStream API can be used only when the stream has been created " + << "TryReadObjectRefStreamInternal API can be used only when the stream has been " + "created " "and not removed."; - const auto &status = stream_it->second.AsyncReadNext(object_id_out); + const auto &status = stream_it->second.TryReadNextItem(object_id_out); return status; } -Status TaskManager::AsyncReadObjectRefStream(const ObjectID &generator_id, - ObjectID *object_id_out) { +Status TaskManager::TryReadObjectRefStream(const ObjectID &generator_id, + ObjectID *object_id_out) { absl::MutexLock lock(&mu_); - return AsyncReadObjectRefStreamInternal(generator_id, object_id_out); + return TryReadObjectRefStreamInternal(generator_id, object_id_out); } bool TaskManager::ObjectRefStreamExists(const ObjectID &generator_id) { @@ -433,21 +434,21 @@ bool TaskManager::ObjectRefStreamExists(const ObjectID &generator_id) { return it != object_ref_streams_.end(); } -void TaskManager::HandleReportIntermediateTaskReturn( - const rpc::ReportIntermediateTaskReturnRequest &request) { +void TaskManager::HandleReportGeneratorItemReturns( + const rpc::ReportGeneratorItemReturnsRequest &request) { const auto &generator_id = ObjectID::FromBinary(request.generator_id()); const auto &task_id = generator_id.TaskId(); - int64_t idx = request.idx(); + int64_t item_index = request.item_index(); // Every generated object has the same task id. - RAY_LOG(DEBUG) << "Received an intermediate result of index " << idx + RAY_LOG(DEBUG) << "Received an intermediate result of index " << item_index << " generator_id: " << generator_id; if (request.finished()) { absl::MutexLock lock(&mu_); - RAY_LOG(DEBUG) << "Write EoF to the object ref stream. Index: " << idx; + RAY_LOG(DEBUG) << "Write EoF to the object ref stream. Index: " << item_index; auto stream_it = object_ref_streams_.find(generator_id); if (stream_it != object_ref_streams_.end()) { - stream_it->second.WriteEoF(idx); + stream_it->second.MarkEndOfStream(item_index); } // The last report should not have any return objects. RAY_CHECK(request.dynamic_return_objects_size() == 0); @@ -468,7 +469,7 @@ void TaskManager::HandleReportIntermediateTaskReturn( absl::MutexLock lock(&mu_); auto stream_it = object_ref_streams_.find(generator_id); if (stream_it != object_ref_streams_.end()) { - index_not_used_yet = stream_it->second.Write(object_id, idx); + index_not_used_yet = stream_it->second.InsertToStream(object_id, item_index); } // TODO(sang): Update the reconstruct ids and task spec // when we support retry. @@ -485,7 +486,7 @@ void TaskManager::HandleReportIntermediateTaskReturn( // It is okay now because we don't support retry yet. But when // we support retry, we should guarantee it is not called // after the task resubmission. We can do it by guaranteeing - // HandleReportIntermediateTaskReturn is not called after the task + // HandleReportGeneratorItemReturns is not called after the task // CompletePendingTask. reference_counter_->UpdateObjectReady(object_id); HandleTaskReturn(object_id, diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index 4a55a090c201..b94a2263accd 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -102,29 +102,31 @@ class ObjectRefStream { /// \param[out] object_id_out The next object ID from the stream. /// Nil ID is returned if the next index hasn't been written. /// \return KeyError if it reaches to EoF. Ok otherwise. - Status AsyncReadNext(ObjectID *object_id_out); + Status TryReadNextItem(ObjectID *object_id_out); - /// Write the object id to the stream of an index idx. + /// Insert the object id to the stream of an index item_index. /// - /// If the idx has been already read (by AsyncReadNext), - /// the write request will be ignored. If the idx has been + /// If the item_index has been already read (by TryReadNextItem), + /// the write request will be ignored. If the item_index has been /// already written, it will be no-op. It doesn't override. /// - /// \param[in] object_id The object id that will be read at index idx. - /// \param[in] idx The index where the object id will be written. + /// \param[in] object_id The object id that will be read at index item_index. + /// \param[in] item_index The index where the object id will be written. /// \return True if the idx hasn't been used. False otherwise. - bool Write(const ObjectID &object_id, int64_t idx); + bool InsertToStream(const ObjectID &object_id, int64_t item_index); /// Mark the stream canont be used anymore. - void WriteEoF(int64_t idx); + /// + /// \param[in] The last item index that means the end of stream. + void MarkEndOfStream(int64_t item_index); private: const ObjectID generator_id_; - /// The index -> object reference ids. - absl::flat_hash_map idx_to_refs_; + /// The item_index -> object reference ids. + absl::flat_hash_map item_index_to_refs_; /// The last index of the stream. - /// idx < last will contain object references. + /// item_index < last will contain object references. /// If -1, that means the stream hasn't reached to EoF. int64_t last_ = -1; /// The current index of the stream. @@ -213,8 +215,8 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa bool is_application_error) override; /// Handle the task return reported before the task terminates. - void HandleReportIntermediateTaskReturn( - const rpc::ReportIntermediateTaskReturnRequest &request); + void HandleReportGeneratorItemReturns( + const rpc::ReportGeneratorItemReturnsRequest &request); /// Delete the object ref stream. /// @@ -259,7 +261,7 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// \param[out] object_id_out The next object ID from the stream. /// Nil ID is returned if the next index hasn't been written. /// \return KeyError if it reaches to EoF. Ok otherwise. - Status AsyncReadObjectRefStream(const ObjectID &generator_id, ObjectID *object_id_out); + Status TryReadObjectRefStream(const ObjectID &generator_id, ObjectID *object_id_out); /// Returns true if task can be retried. /// @@ -601,8 +603,8 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// \param task_entry Task entry for the corresponding task attempt void MarkTaskRetryOnFailed(TaskEntry &task_entry, const rpc::RayErrorInfo &error_info); - Status AsyncReadObjectRefStreamInternal(const ObjectID &generator_id, - ObjectID *object_id_out) + Status TryReadObjectRefStreamInternal(const ObjectID &generator_id, + ObjectID *object_id_out) EXCLUSIVE_LOCKS_REQUIRED(mu_); /// Used to store task results. diff --git a/src/ray/core_worker/test/reference_count_test.cc b/src/ray/core_worker/test/reference_count_test.cc index 51b5d51523ac..dd715888d646 100644 --- a/src/ray/core_worker/test/reference_count_test.cc +++ b/src/ray/core_worker/test/reference_count_test.cc @@ -2946,6 +2946,57 @@ TEST_F(ReferenceCountTest, TestForwardNestedRefs) { borrower2->rc_.RemoveLocalReference(inner_id, nullptr); } +TEST_F(ReferenceCountTest, TestOwnDynamicStreamingTaskReturnRef) { + auto object_id = ObjectID::FromRandom(); + auto generator_id = ObjectID::FromRandom(); + auto generator_id_2 = ObjectID::FromRandom(); + rpc::Address added_address; + + // Verify OwnDynamicStreamingTaskReturnRef is ignored + // when there's no generator id. + rc->OwnDynamicStreamingTaskReturnRef(object_id, generator_id); + ASSERT_FALSE(rc->GetOwner(generator_id, &added_address)); + ASSERT_FALSE(rc->GetOwner(object_id, &added_address)); + ASSERT_FALSE(rc->HasReference(object_id)); + ASSERT_FALSE(rc->HasReference(generator_id)); + + // Add a generator id. + rpc::Address address; + address.set_ip_address("1234"); + rc->AddOwnedObject(generator_id, {}, address, "", 0, false, /*add_local_ref=*/true); + ASSERT_TRUE(rc->HasReference(generator_id)); + + // Verify object id is not registered if the incorrect generator id is given. + rc->OwnDynamicStreamingTaskReturnRef(object_id, generator_id_2); + ASSERT_FALSE(rc->HasReference(object_id)); + + // Verify object is owned. + rc->OwnDynamicStreamingTaskReturnRef(object_id, generator_id); + ASSERT_TRUE(rc->HasReference(object_id)); + // Verify the number of objects: Generator + object. + ASSERT_EQ(rc->NumObjectIDsInScope(), 2); + // Verify it is owned by us. + ASSERT_TRUE(rc->GetOwner(object_id, &added_address)); + ASSERT_EQ(address.ip_address(), added_address.ip_address()); + // Verify it had 1 local reference. + std::vector deleted; + rc_.RemoveLocalReference(object_id, &deleted); + ASSERT_EQ(rc->NumObjectIDsInScope(), 1); + ASSERT_EQ(deleted.size(), 1); + ASSERT_FALSE(rc->GetOwner(object_id, &added_address)); + + // Remove the generator. + rc_.RemoveLocalReference(generator_id, nullptr); + ASSERT_EQ(rc->NumObjectIDsInScope(), 1); + ASSERT_FALSE(rc->GetOwner(generator_id, &added_address)); + + // Verify we cannot register a new object after the generator id is removed. + auto object_id_2 = ObjectID::FromRandom(); + rc->OwnDynamicStreamingTaskReturnRef(object_id_2, generator_id); + ASSERT_FALSE(rc->GetOwner(object_id_2, &added_address)); + ASSERT_FALSE(rc->HasReference(object_id_2)); +} + } // namespace core } // namespace ray diff --git a/src/ray/core_worker/test/task_manager_test.cc b/src/ray/core_worker/test/task_manager_test.cc index f8e1d92c212d..d6198abccb7a 100644 --- a/src/ray/core_worker/test/task_manager_test.cc +++ b/src/ray/core_worker/test/task_manager_test.cc @@ -51,17 +51,17 @@ rpc::Address GetRandomWorkerAddr() { return addr; } -rpc::ReportIntermediateTaskReturnRequest GetIntermediateTaskReturn( +rpc::ReportGeneratorItemReturnsRequest GetIntermediateTaskReturn( int64_t idx, bool finished, const ObjectID &generator_id, const ObjectID &dynamic_return_id, std::shared_ptr data, bool set_in_plasma) { - rpc::ReportIntermediateTaskReturnRequest request; + rpc::ReportGeneratorItemReturnsRequest request; rpc::Address addr; request.mutable_worker_addr()->CopyFrom(addr); - request.set_idx(idx); + request.set_item_index(idx); request.set_finished(finished); request.set_generator_id(generator_id.Binary()); auto dynamic_return_object = request.add_dynamic_return_objects(); @@ -71,12 +71,12 @@ rpc::ReportIntermediateTaskReturnRequest GetIntermediateTaskReturn( return request; } -rpc::ReportIntermediateTaskReturnRequest GetEoFTaskReturn(int64_t idx, - const ObjectID &generator_id) { - rpc::ReportIntermediateTaskReturnRequest request; +rpc::ReportGeneratorItemReturnsRequest GetEoFTaskReturn(int64_t idx, + const ObjectID &generator_id) { + rpc::ReportGeneratorItemReturnsRequest request; rpc::Address addr; request.mutable_worker_addr()->CopyFrom(addr); - request.set_idx(idx); + request.set_item_index(idx); request.set_finished(true); request.set_generator_id(generator_id.Binary()); return request; @@ -1224,20 +1224,20 @@ TEST_F(TaskManagerTest, TestObjectRefStreamBasic) { /*data*/ data, /*set_in_plasma*/ false); // WRITE * 2 - manager_.HandleReportIntermediateTaskReturn(req); + manager_.HandleReportGeneratorItemReturns(req); } // WRITEEoF - manager_.HandleReportIntermediateTaskReturn(GetEoFTaskReturn(last_idx, generator_id)); + manager_.HandleReportGeneratorItemReturns(GetEoFTaskReturn(last_idx, generator_id)); ObjectID obj_id; for (auto i = 0; i < last_idx; i++) { // READ * 2 - auto status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + auto status = manager_.TryReadObjectRefStream(generator_id, &obj_id); ASSERT_TRUE(status.ok()); ASSERT_EQ(obj_id, dynamic_return_ids[i]); } // READ (EoF) - auto status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + auto status = manager_.TryReadObjectRefStream(generator_id, &obj_id); ASSERT_TRUE(status.IsObjectRefStreamEoF()); ASSERT_EQ(obj_id, ObjectID::Nil()); // DELETE @@ -1271,19 +1271,19 @@ TEST_F(TaskManagerTest, TestObjectRefStreamMixture) { /*data*/ data, /*set_in_plasma*/ false); // WRITE - manager_.HandleReportIntermediateTaskReturn(req); + manager_.HandleReportGeneratorItemReturns(req); // READ ObjectID obj_id; - auto status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + auto status = manager_.TryReadObjectRefStream(generator_id, &obj_id); ASSERT_TRUE(status.ok()); ASSERT_EQ(obj_id, dynamic_return_ids[i]); } // WRITEEoF - manager_.HandleReportIntermediateTaskReturn(GetEoFTaskReturn(last_idx, generator_id)); + manager_.HandleReportGeneratorItemReturns(GetEoFTaskReturn(last_idx, generator_id)); ObjectID obj_id; // READ (EoF) - auto status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + auto status = manager_.TryReadObjectRefStream(generator_id, &obj_id); ASSERT_TRUE(status.IsObjectRefStreamEoF()); ASSERT_EQ(obj_id, ObjectID::Nil()); // DELETE @@ -1310,12 +1310,12 @@ TEST_F(TaskManagerTest, TestObjectRefStreamEoF) { /*dynamic_return_id*/ dynamic_return_id, /*data*/ data, /*set_in_plasma*/ false); - manager_.HandleReportIntermediateTaskReturn(req); + manager_.HandleReportGeneratorItemReturns(req); // WRITEEoF - manager_.HandleReportIntermediateTaskReturn(GetEoFTaskReturn(1, generator_id)); + manager_.HandleReportGeneratorItemReturns(GetEoFTaskReturn(1, generator_id)); // READ (works) ObjectID obj_id; - auto status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + auto status = manager_.TryReadObjectRefStream(generator_id, &obj_id); ASSERT_TRUE(status.ok()); ASSERT_EQ(obj_id, dynamic_return_id); @@ -1329,9 +1329,9 @@ TEST_F(TaskManagerTest, TestObjectRefStreamEoF) { /*dynamic_return_id*/ dynamic_return_id, /*data*/ data, /*set_in_plasma*/ false); - manager_.HandleReportIntermediateTaskReturn(req); + manager_.HandleReportGeneratorItemReturns(req); // READ (doesn't works because EoF is already written) - status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + status = manager_.TryReadObjectRefStream(generator_id, &obj_id); ASSERT_TRUE(status.IsObjectRefStreamEoF()); } @@ -1355,10 +1355,10 @@ TEST_F(TaskManagerTest, TestObjectRefStreamIndexDiscarded) { /*dynamic_return_id*/ dynamic_return_id, /*data*/ data, /*set_in_plasma*/ false); - manager_.HandleReportIntermediateTaskReturn(req); + manager_.HandleReportGeneratorItemReturns(req); // READ ObjectID obj_id; - auto status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + auto status = manager_.TryReadObjectRefStream(generator_id, &obj_id); ASSERT_TRUE(status.ok()); ASSERT_EQ(obj_id, dynamic_return_id); @@ -1372,9 +1372,9 @@ TEST_F(TaskManagerTest, TestObjectRefStreamIndexDiscarded) { /*dynamic_return_id*/ dynamic_return_id, /*data*/ data, /*set_in_plasma*/ false); - manager_.HandleReportIntermediateTaskReturn(req); + manager_.HandleReportGeneratorItemReturns(req); // READ (New write will be ignored). - status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + status = manager_.TryReadObjectRefStream(generator_id, &obj_id); ASSERT_TRUE(status.ok()); ASSERT_EQ(obj_id, ObjectID::Nil()); } @@ -1391,7 +1391,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamReadIgnoredWhenNothingWritten) { // READ (no-op) ObjectID obj_id; - auto status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + auto status = manager_.TryReadObjectRefStream(generator_id, &obj_id); ASSERT_TRUE(status.ok()); ASSERT_EQ(obj_id, ObjectID::Nil()); @@ -1405,14 +1405,14 @@ TEST_F(TaskManagerTest, TestObjectRefStreamReadIgnoredWhenNothingWritten) { /*dynamic_return_id*/ dynamic_return_id, /*data*/ data, /*set_in_plasma*/ false); - manager_.HandleReportIntermediateTaskReturn(req); + manager_.HandleReportGeneratorItemReturns(req); // READ (works this time) - status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + status = manager_.TryReadObjectRefStream(generator_id, &obj_id); ASSERT_TRUE(status.ok()); ASSERT_EQ(obj_id, dynamic_return_id); // READ (nothing should return) - status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + status = manager_.TryReadObjectRefStream(generator_id, &obj_id); ASSERT_TRUE(status.ok()); ASSERT_EQ(obj_id, ObjectID::Nil()); } @@ -1444,7 +1444,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamEndtoEnd) { /*dynamic_return_id*/ dynamic_return_id, /*data*/ data, /*set_in_plasma*/ false); - manager_.HandleReportIntermediateTaskReturn(req); + manager_.HandleReportGeneratorItemReturns(req); // NumObjectIDsInScope == Generator + intermediate result. ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 2); @@ -1455,7 +1455,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamEndtoEnd) { // Make sure you can read. ObjectID obj_id; - auto status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + auto status = manager_.TryReadObjectRefStream(generator_id, &obj_id); ASSERT_TRUE(status.ok()); ASSERT_EQ(obj_id, dynamic_return_id); @@ -1479,9 +1479,9 @@ TEST_F(TaskManagerTest, TestObjectRefStreamEndtoEnd) { /*dynamic_return_id*/ dynamic_return_id2, /*data*/ data, /*set_in_plasma*/ false); - manager_.HandleReportIntermediateTaskReturn(req); + manager_.HandleReportGeneratorItemReturns(req); // EoF - manager_.HandleReportIntermediateTaskReturn(GetEoFTaskReturn(2, generator_id)); + manager_.HandleReportGeneratorItemReturns(GetEoFTaskReturn(2, generator_id)); // NumObjectIDsInScope == Generator + 2 intermediate result. ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3); @@ -1490,12 +1490,12 @@ TEST_F(TaskManagerTest, TestObjectRefStreamEndtoEnd) { ASSERT_EQ(results.size(), 1); // Make sure you can read. - status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + status = manager_.TryReadObjectRefStream(generator_id, &obj_id); ASSERT_TRUE(status.ok()); ASSERT_EQ(obj_id, dynamic_return_id2); // Nothing more to read. - status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + status = manager_.TryReadObjectRefStream(generator_id, &obj_id); ASSERT_TRUE(status.IsObjectRefStreamEoF()); manager_.DelObjectRefStream(generator_id); @@ -1529,7 +1529,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferences) { /*dynamic_return_id*/ dynamic_return_id, /*data*/ data, /*set_in_plasma*/ false); - manager_.HandleReportIntermediateTaskReturn(req); + manager_.HandleReportGeneratorItemReturns(req); // WRITE 2 auto dynamic_return_id2 = ObjectID::FromIndex(spec.TaskId(), 3); data = GenerateRandomBuffer(); @@ -1540,7 +1540,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferences) { /*dynamic_return_id*/ dynamic_return_id2, /*data*/ data, /*set_in_plasma*/ false); - manager_.HandleReportIntermediateTaskReturn(req); + manager_.HandleReportGeneratorItemReturns(req); RAY_LOG(ERROR) << "SANG-TODO 1"; // NumObjectIDsInScope == Generator + 2 WRITE ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3); @@ -1580,7 +1580,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferences) { /*dynamic_return_id*/ dynamic_return_id3, /*data*/ data, /*set_in_plasma*/ false); - manager_.HandleReportIntermediateTaskReturn(req); + manager_.HandleReportGeneratorItemReturns(req); // The write should have been no op. No refs and no obj values except the generator id. ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1); ASSERT_TRUE(store_->Get({dynamic_return_id3}, 1, 1, ctx, false, &results).IsTimedOut()); @@ -1608,7 +1608,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamOutofOrder) { auto last_idx = 2; std::vector dynamic_return_ids; // EoF reported first. - manager_.HandleReportIntermediateTaskReturn(GetEoFTaskReturn(last_idx, generator_id)); + manager_.HandleReportGeneratorItemReturns(GetEoFTaskReturn(last_idx, generator_id)); // Write index 1 -> 0 for (auto i = last_idx - 1; i > -1; i--) { @@ -1624,19 +1624,19 @@ TEST_F(TaskManagerTest, TestObjectRefStreamOutofOrder) { /*data*/ data, /*set_in_plasma*/ false); // WRITE * 2 - manager_.HandleReportIntermediateTaskReturn(req); + manager_.HandleReportGeneratorItemReturns(req); } // Verify read works. ObjectID obj_id; for (auto i = 0; i < last_idx; i++) { - auto status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + auto status = manager_.TryReadObjectRefStream(generator_id, &obj_id); ASSERT_TRUE(status.ok()); ASSERT_EQ(obj_id, dynamic_return_ids[i]); } // READ (EoF) - auto status = manager_.AsyncReadObjectRefStream(generator_id, &obj_id); + auto status = manager_.TryReadObjectRefStream(generator_id, &obj_id); ASSERT_TRUE(status.IsObjectRefStreamEoF()); ASSERT_EQ(obj_id, ObjectID::Nil()); // DELETE diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index 65b0b077866d..23350fcb0f0f 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -382,7 +382,7 @@ message RayletNotifyGCSRestartRequest {} message RayletNotifyGCSRestartReply {} -message ReportIntermediateTaskReturnRequest { +message ReportGeneratorItemReturnsRequest { // The intermediate return object that's dynamically // generated from the executor side. repeated ReturnObject dynamic_return_objects = 1; @@ -392,7 +392,7 @@ message ReportIntermediateTaskReturnRequest { // reorder the intermediate return object // because the ordering of this request // is not guaranteed. - int64 idx = 3; + int64 item_index = 3; // If true, it means there's going to be no more // task return after this request. bool finished = 4; @@ -401,7 +401,7 @@ message ReportIntermediateTaskReturnRequest { bytes generator_id = 5; } -message ReportIntermediateTaskReturnReply {} +message ReportGeneratorItemReturnsReply {} service CoreWorkerService { // Notify core worker GCS has restarted. @@ -424,8 +424,8 @@ service CoreWorkerService { /// the caller (subscriber). rpc PubsubLongPolling(PubsubLongPollingRequest) returns (PubsubLongPollingReply); // The RPC to report the intermediate task return to the caller. - rpc ReportIntermediateTaskReturn(ReportIntermediateTaskReturnRequest) - returns (ReportIntermediateTaskReturnReply); + rpc ReportGeneratorItemReturns(ReportGeneratorItemReturnsRequest) + returns (ReportGeneratorItemReturnsReply); /// The pubsub command batch request used by the subscriber. rpc PubsubCommandBatch(PubsubCommandBatchRequest) returns (PubsubCommandBatchReply); // Update the batched object location information to the ownership-based object diff --git a/src/ray/rpc/worker/core_worker_client.h b/src/ray/rpc/worker/core_worker_client.h index 3b7caa1592f2..b8341f7eb6b8 100644 --- a/src/ray/rpc/worker/core_worker_client.h +++ b/src/ray/rpc/worker/core_worker_client.h @@ -154,9 +154,9 @@ class CoreWorkerClientInterface : public pubsub::SubscriberClientInterface { const GetObjectLocationsOwnerRequest &request, const ClientCallback &callback) {} - virtual void ReportIntermediateTaskReturn( - const ReportIntermediateTaskReturnRequest &request, - const ClientCallback &callback) {} + virtual void ReportGeneratorItemReturns( + const ReportGeneratorItemReturnsRequest &request, + const ClientCallback &callback) {} /// Tell this actor to exit immediately. virtual void KillActor(const KillActorRequest &request, @@ -288,7 +288,7 @@ class CoreWorkerClient : public std::enable_shared_from_this, override) VOID_RPC_CLIENT_METHOD(CoreWorkerService, - ReportIntermediateTaskReturn, + ReportGeneratorItemReturns, grpc_client_, /*method_timeout_ms*/ -1, override) diff --git a/src/ray/rpc/worker/core_worker_server.h b/src/ray/rpc/worker/core_worker_server.h index c41486fb4af8..9c548463a786 100644 --- a/src/ray/rpc/worker/core_worker_server.h +++ b/src/ray/rpc/worker/core_worker_server.h @@ -44,7 +44,7 @@ namespace rpc { RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED( \ CoreWorkerService, GetObjectLocationsOwner, -1) \ RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED( \ - CoreWorkerService, ReportIntermediateTaskReturn, -1) \ + CoreWorkerService, ReportGeneratorItemReturns, -1) \ RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, KillActor, -1) \ RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, CancelTask, -1) \ RPC_SERVICE_HANDLER_SERVER_METRICS_DISABLED(CoreWorkerService, RemoteCancelTask, -1) \ @@ -70,7 +70,7 @@ namespace rpc { DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(PubsubCommandBatch) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(UpdateObjectLocationBatch) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(GetObjectLocationsOwner) \ - DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(ReportIntermediateTaskReturn) \ + DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(ReportGeneratorItemReturns) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(KillActor) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(CancelTask) \ DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(RemoteCancelTask) \ From e2f19801a10b3ec4dc02b17cf2102d03a19bb4d1 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 16 May 2023 06:09:02 -0700 Subject: [PATCH 20/77] removed a test file Signed-off-by: SangBin Cho --- python/ray/tests/test_streaming_generator.py | 668 ------------------- 1 file changed, 668 deletions(-) delete mode 100644 python/ray/tests/test_streaming_generator.py diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py deleted file mode 100644 index 3344782c7be0..000000000000 --- a/python/ray/tests/test_streaming_generator.py +++ /dev/null @@ -1,668 +0,0 @@ -import asyncio -import pytest -import numpy as np -import sys -import time -import gc - -from unittest.mock import patch, Mock - -import ray -from ray._private.test_utils import wait_for_condition -from ray.experimental.state.api import list_objects -from ray._raylet import StreamingObjectRefGenerator -from ray.cloudpickle import dumps -from ray.exceptions import ObjectRefStreamEoFError, WorkerCrashedError - - -class MockedWorker: - def __init__(self, mocked_core_worker): - self.core_worker = mocked_core_worker - - def reset_core_worker(self): - """Emulate the case ray.shutdown is called - and the core_worker instance is GC'ed. - """ - self.core_worker = None - - -@pytest.fixture -def mocked_worker(): - mocked_core_worker = Mock() - mocked_core_worker.async_read_object_ref_stream.return_value = None - mocked_core_worker.delete_object_ref_stream.return_value = None - mocked_core_worker.create_object_ref_stream.return_value = None - worker = MockedWorker(mocked_core_worker) - yield worker - - -def test_streaming_object_ref_generator_basic_unit(mocked_worker): - """ - Verify the basic case: - create a generator -> read values -> nothing more to read -> delete. - """ - with patch("ray.wait") as mocked_ray_wait: - c = mocked_worker.core_worker - generator_ref = ray.ObjectRef.from_random() - generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) - c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() - c.create_object_ref_stream.assert_called() - - # Test when there's no new ref, it returns a nil. - mocked_ray_wait.return_value = [], [generator_ref] - ref = generator._next_sync(timeout_s=0) - assert ref.is_nil() - - # When the new ref is available, next should return it. - for _ in range(3): - new_ref = ray.ObjectRef.from_random() - c.async_read_object_ref_stream.return_value = new_ref - ref = generator._next_sync(timeout_s=0) - assert new_ref == ref - - # When async_read_object_ref_stream raises a - # ObjectRefStreamEoFError, it should raise a stop iteration. - c.async_read_object_ref_stream.side_effect = ObjectRefStreamEoFError("") # noqa - with pytest.raises(StopIteration): - ref = generator._next_sync(timeout_s=0) - - # Make sure we cannot serialize the generator. - with pytest.raises(TypeError): - dumps(generator) - - del generator - c.delete_object_ref_stream.assert_called() - - -def test_streaming_object_ref_generator_task_failed_unit(mocked_worker): - """ - Verify when a task is failed by a system error, - the generator ref is returned. - """ - with patch("ray.get") as mocked_ray_get: - with patch("ray.wait") as mocked_ray_wait: - c = mocked_worker.core_worker - generator_ref = ray.ObjectRef.from_random() - generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) - - # Simulate the worker failure happens. - mocked_ray_wait.return_value = [generator_ref], [] - mocked_ray_get.side_effect = WorkerCrashedError() - - c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() - ref = generator._next_sync(timeout_s=0) - # If the generator task fails by a systsem error, - # meaning the ref will raise an exception - # it should be returned. - print(ref) - print(generator_ref) - assert ref == generator_ref - - # Once exception is raised, it should always - # raise stopIteration regardless of what - # the ref contains now. - with pytest.raises(StopIteration): - ref = generator._next_sync(timeout_s=0) - - -def test_streaming_object_ref_generator_network_failed_unit(mocked_worker): - """ - Verify when a task is finished, but if the next ref is not available - on time, it raises an assertion error. - - TODO(sang): Once we move the task subimssion path to use pubsub - to guarantee the ordering, we don't need this test anymore. - """ - with patch("ray.get") as mocked_ray_get: - with patch("ray.wait") as mocked_ray_wait: - c = mocked_worker.core_worker - generator_ref = ray.ObjectRef.from_random() - generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) - - # Simulate the task has finished. - mocked_ray_wait.return_value = [generator_ref], [] - mocked_ray_get.return_value = None - - # If StopIteration is not raised within - # unexpected_network_failure_timeout_s second, - # it should fail. - c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() - ref = generator._next_sync( - timeout_s=0, unexpected_network_failure_timeout_s=1 - ) - assert ref == ray.ObjectRef.nil() - time.sleep(1) - with pytest.raises(AssertionError): - generator._next_sync( - timeout_s=0, unexpected_network_failure_timeout_s=1 - ) - # After that StopIteration should be raised. - with pytest.raises(StopIteration): - generator._next_sync( - timeout_s=0, unexpected_network_failure_timeout_s=1 - ) - - -@pytest.mark.asyncio -async def test_streaming_object_ref_generator_unit_async(mocked_worker): - """ - Verify the basic case: - create a generator -> read values -> nothing more to read -> delete. - """ - with patch("ray.wait") as mocked_ray_wait: - c = mocked_worker.core_worker - generator_ref = ray.ObjectRef.from_random() - generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) - c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() - c.create_object_ref_stream.assert_called() - - # Test when there's no new ref, it returns a nil. - mocked_ray_wait.return_value = [], [generator_ref] - ref = await generator._next_async(timeout_s=0) - assert ref.is_nil() - - # When the new ref is available, next should return it. - for _ in range(3): - new_ref = ray.ObjectRef.from_random() - c.async_read_object_ref_stream.return_value = new_ref - ref = await generator._next_async(timeout_s=0) - assert new_ref == ref - - # When async_read_object_ref_stream raises a - # ObjectRefStreamEoFError, it should raise a stop iteration. - c.async_read_object_ref_stream.side_effect = ObjectRefStreamEoFError("") # noqa - with pytest.raises(StopAsyncIteration): - ref = await generator._next_async(timeout_s=0) - - -def test_generator_basic(shutdown_only): - ray.init(num_cpus=1) - - """Basic cases""" - - @ray.remote - def f(): - for i in range(5): - yield i - - gen = f.options(num_returns="streaming").remote() - i = 0 - for ref in gen: - print(ray.get(ref)) - assert i == ray.get(ref) - del ref - i += 1 - - """Exceptions""" - - @ray.remote - def f(): - for i in range(5): - if i == 2: - raise ValueError - yield i - - gen = f.options(num_returns="streaming").remote() - ray.get(next(gen)) - ray.get(next(gen)) - with pytest.raises(ray.exceptions.RayTaskError) as e: - ray.get(next(gen)) - print(str(e.value)) - with pytest.raises(StopIteration): - ray.get(next(gen)) - with pytest.raises(StopIteration): - ray.get(next(gen)) - - """Generator Task failure""" - - @ray.remote - class A: - def getpid(self): - import os - - return os.getpid() - - def f(self): - for i in range(5): - time.sleep(0.1) - yield i - - a = A.remote() - i = 0 - gen = a.f.options(num_returns="streaming").remote() - i = 0 - for ref in gen: - if i == 2: - ray.kill(a) - if i == 3: - with pytest.raises(ray.exceptions.RayActorError) as e: - ray.get(ref) - assert "The actor is dead because it was killed by `ray.kill`" in str( - e.value - ) - break - assert i == ray.get(ref) - del ref - i += 1 - for _ in range(10): - with pytest.raises(StopIteration): - next(gen) - - """Retry exceptions""" - # TODO(sang): Enable it once retry is supported. - # @ray.remote - # class Actor: - # def __init__(self): - # self.should_kill = True - - # def should_kill(self): - # return self.should_kill - - # async def set(self, wait_s): - # await asyncio.sleep(wait_s) - # self.should_kill = False - - # @ray.remote(retry_exceptions=[ValueError], max_retries=10) - # def f(a): - # for i in range(5): - # should_kill = ray.get(a.should_kill.remote()) - # if i == 3 and should_kill: - # raise ValueError - # yield i - - # a = Actor.remote() - # gen = f.options(num_returns="streaming").remote(a) - # assert ray.get(next(gen)) == 0 - # assert ray.get(next(gen)) == 1 - # assert ray.get(next(gen)) == 2 - # a.set.remote(3) - # assert ray.get(next(gen)) == 3 - # assert ray.get(next(gen)) == 4 - # with pytest.raises(StopIteration): - # ray.get(next(gen)) - - """Cancel""" - - @ray.remote - def f(): - for i in range(5): - time.sleep(5) - yield i - - gen = f.options(num_returns="streaming").remote() - assert ray.get(next(gen)) == 0 - ray.cancel(gen) - with pytest.raises(ray.exceptions.RayTaskError) as e: - assert ray.get(next(gen)) == 1 - assert "was cancelled" in str(e.value) - with pytest.raises(StopIteration): - next(gen) - - -@pytest.mark.parametrize("crash_type", ["exception", "worker_crash"]) -def test_generator_streaming_no_leak_upon_failures( - monkeypatch, shutdown_only, crash_type -): - with monkeypatch.context() as m: - # defer for 10s for the second node. - m.setenv( - "RAY_testing_asio_delay_us", - "CoreWorkerService.grpc_server.ReportGeneratorItemReturns=100000:1000000", - ) - ray.init(num_cpus=1) - - @ray.remote - def g(): - try: - gen = f.options(num_returns="streaming").remote() - for ref in gen: - print(ref) - ray.get(ref) - except Exception: - print("exception!") - del ref - - del gen - gc.collect() - - # Only the ref g is alive. - def verify(): - print(list_objects()) - return len(list_objects()) == 1 - - wait_for_condition(verify) - return True - - @ray.remote - def f(): - for i in range(10): - time.sleep(0.2) - if i == 4: - if crash_type == "exception": - raise ValueError - else: - sys.exit(9) - yield 2 - - for _ in range(5): - ray.get(g.remote()) - - -@pytest.mark.parametrize("use_actors", [False, True]) -@pytest.mark.parametrize("store_in_plasma", [False, True]) -def test_generator_streaming(shutdown_only, use_actors, store_in_plasma): - """Verify the generator is working in a streaming fashion.""" - ray.init() - remote_generator_fn = None - if use_actors: - - @ray.remote - class Generator: - def __init__(self): - pass - - def generator(self, num_returns, store_in_plasma): - for i in range(num_returns): - if store_in_plasma: - yield np.ones(1_000_000, dtype=np.int8) * i - else: - yield [i] - - g = Generator.remote() - remote_generator_fn = g.generator - else: - - @ray.remote(max_retries=0) - def generator(num_returns, store_in_plasma): - for i in range(num_returns): - if store_in_plasma: - yield np.ones(1_000_000, dtype=np.int8) * i - else: - yield [i] - - remote_generator_fn = generator - - """Verify num_returns="streaming" is streaming""" - gen = remote_generator_fn.options(num_returns="streaming").remote( - 3, store_in_plasma - ) - i = 0 - for ref in gen: - id = ref.hex() - if store_in_plasma: - expected = np.ones(1_000_000, dtype=np.int8) * i - assert np.array_equal(ray.get(ref), expected) - else: - expected = [i] - assert ray.get(ref) == expected - - del ref - - wait_for_condition( - lambda: len(list_objects(filters=[("object_id", "=", id)])) == 0 - ) - i += 1 - - -def test_generator_dist_chain(ray_start_cluster): - """E2E test to verify chain of generator works properly.""" - cluster = ray_start_cluster - cluster.add_node(num_cpus=0, object_store_memory=1 * 1024 * 1024 * 1024) - ray.init() - cluster.add_node(num_cpus=1) - cluster.add_node(num_cpus=1) - cluster.add_node(num_cpus=1) - cluster.add_node(num_cpus=1) - - @ray.remote - class ChainActor: - def __init__(self, child=None): - self.child = child - - def get_data(self): - if not self.child: - for _ in range(10): - time.sleep(0.1) - yield np.ones(5 * 1024 * 1024) - else: - for data in self.child.get_data.options( - num_returns="streaming" - ).remote(): - yield ray.get(data) - - chain_actor = ChainActor.remote() - chain_actor_2 = ChainActor.remote(chain_actor) - chain_actor_3 = ChainActor.remote(chain_actor_2) - chain_actor_4 = ChainActor.remote(chain_actor_3) - - for ref in chain_actor_4.get_data.options(num_returns="streaming").remote(): - assert np.array_equal(np.ones(5 * 1024 * 1024), ray.get(ref)) - del ref - - -@pytest.mark.parametrize("store_in_plasma", [False, True]) -def test_actor_streaming_generator(shutdown_only, store_in_plasma): - """Test actor/async actor with sync/async generator interfaces.""" - ray.init() - - @ray.remote - class Actor: - def f(self, ref): - for i in range(3): - yield i - - async def async_f(self, ref): - for i in range(3): - await asyncio.sleep(0.1) - yield i - - def g(self): - return 3 - - a = Actor.remote() - if store_in_plasma: - arr = np.random.rand(5 * 1024 * 1024) - else: - arr = 3 - - def verify_sync_task_executor(): - generator = a.f.options(num_returns="streaming").remote(ray.put(arr)) - # Verify it works with next. - assert isinstance(generator, StreamingObjectRefGenerator) - assert ray.get(next(generator)) == 0 - assert ray.get(next(generator)) == 1 - assert ray.get(next(generator)) == 2 - with pytest.raises(StopIteration): - ray.get(next(generator)) - - # Verify it works with for. - generator = a.f.options(num_returns="streaming").remote(ray.put(3)) - for index, ref in enumerate(generator): - assert index == ray.get(ref) - - def verify_async_task_executor(): - # Verify it works with next. - generator = a.async_f.options(num_returns="streaming").remote(ray.put(arr)) - assert isinstance(generator, StreamingObjectRefGenerator) - assert ray.get(next(generator)) == 0 - assert ray.get(next(generator)) == 1 - assert ray.get(next(generator)) == 2 - - # Verify it works with for. - generator = a.f.options(num_returns="streaming").remote(ray.put(3)) - for index, ref in enumerate(generator): - assert index == ray.get(ref) - - async def verify_sync_task_async_generator(): - # Verify anext - async_generator = a.f.options(num_returns="streaming").remote(ray.put(arr)) - assert isinstance(async_generator, StreamingObjectRefGenerator) - for expected in range(3): - ref = await async_generator.__anext__() - assert await ref == expected - with pytest.raises(StopAsyncIteration): - await async_generator.__anext__() - - # Verify async for. - async_generator = a.f.options(num_returns="streaming").remote(ray.put(arr)) - expected = 0 - async for ref in async_generator: - value = await ref - assert value == value - expected += 1 - - async def verify_async_task_async_generator(): - async_generator = a.async_f.options(num_returns="streaming").remote( - ray.put(arr) - ) - assert isinstance(async_generator, StreamingObjectRefGenerator) - for expected in range(3): - ref = await async_generator.__anext__() - assert await ref == expected - with pytest.raises(StopAsyncIteration): - await async_generator.__anext__() - - # Verify async for. - async_generator = a.async_f.options(num_returns="streaming").remote( - ray.put(arr) - ) - expected = 0 - async for value in async_generator: - value = await ref - assert value == value - expected += 1 - - verify_sync_task_executor() - verify_async_task_executor() - asyncio.run(verify_sync_task_async_generator()) - asyncio.run(verify_async_task_async_generator()) - - -def test_streaming_generator_exception(shutdown_only): - # Verify the exceptions are correctly raised. - # Also verify the followup next will raise StopIteration. - ray.init() - - @ray.remote - class Actor: - def f(self): - raise ValueError - yield 1 # noqa - - async def async_f(self): - raise ValueError - yield 1 # noqa - - a = Actor.remote() - g = a.f.options(num_returns="streaming").remote() - with pytest.raises(ValueError): - ray.get(next(g)) - - with pytest.raises(StopIteration): - ray.get(next(g)) - - with pytest.raises(StopIteration): - ray.get(next(g)) - - g = a.async_f.options(num_returns="streaming").remote() - with pytest.raises(ValueError): - ray.get(next(g)) - - with pytest.raises(StopIteration): - ray.get(next(g)) - - with pytest.raises(StopIteration): - ray.get(next(g)) - - -def test_threaded_actor_generator(shutdown_only): - ray.init() - - @ray.remote(max_concurrency=10) - class Actor: - def f(self): - for i in range(30): - time.sleep(0.1) - yield np.ones(1024 * 1024) * i - - @ray.remote(max_concurrency=20) - class AsyncActor: - async def f(self): - for i in range(30): - await asyncio.sleep(0.1) - yield np.ones(1024 * 1024) * i - - async def main(): - a = Actor.remote() - asy = AsyncActor.remote() - - async def run(): - i = 0 - async for ref in a.f.options(num_returns="streaming").remote(): - val = ray.get(ref) - print(val) - print(ref) - assert np.array_equal(val, np.ones(1024 * 1024) * i) - i += 1 - del ref - - async def run2(): - i = 0 - async for ref in asy.f.options(num_returns="streaming").remote(): - val = await ref - print(ref) - print(val) - assert np.array_equal(val, np.ones(1024 * 1024) * i), ref - i += 1 - del ref - - coroutines = [run() for _ in range(10)] - coroutines = [run2() for _ in range(20)] - - await asyncio.gather(*coroutines) - - asyncio.run(main()) - - -def test_generator_dist_all_gather(ray_start_cluster): - cluster = ray_start_cluster - cluster.add_node(num_cpus=0, object_store_memory=1 * 1024 * 1024 * 1024) - ray.init() - cluster.add_node(num_cpus=1) - cluster.add_node(num_cpus=1) - cluster.add_node(num_cpus=1) - cluster.add_node(num_cpus=1) - - @ray.remote(num_cpus=1) - class Actor: - def __init__(self, child=None): - self.child = child - - def get_data(self): - for _ in range(10): - time.sleep(0.1) - yield np.ones(5 * 1024 * 1024) - - async def all_gather(): - actor = Actor.remote() - async for ref in actor.get_data.options(num_returns="streaming").remote(): - val = await ref - assert np.array_equal(np.ones(5 * 1024 * 1024), val) - del ref - - async def main(): - await asyncio.gather(all_gather(), all_gather(), all_gather(), all_gather()) - - asyncio.run(main()) - summary = ray._private.internal_api.memory_summary(stats_only=True) - print(summary) - # assert "Spilled" not in summary, summary - - -if __name__ == "__main__": - import os - - if os.environ.get("PARALLEL_CI"): - sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) - else: - sys.exit(pytest.main(["-sv", __file__])) From 805c7bb26861693f516dc9ce499892ff7964c665 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 16 May 2023 06:26:53 -0700 Subject: [PATCH 21/77] Updated Signed-off-by: SangBin Cho --- python/ray/_raylet.pyx | 12 ++++++------ python/ray/includes/libcoreworker.pxd | 2 +- python/ray/tests/test_streaming_generator.py | 18 ++++++++++-------- src/ray/core_worker/core_worker.cc | 6 +++--- src/ray/core_worker/core_worker.h | 4 ++-- 5 files changed, 22 insertions(+), 20 deletions(-) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index d91321342786..994535158e79 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -315,7 +315,7 @@ class StreamingObjectRefGenerator: def _handle_next(self): try: if hasattr(self.worker, "core_worker"): - obj = self.worker.core_worker.async_read_object_ref_stream( + obj = self.worker.core_worker.try_read_next_object_ref_stream( self._generator_ref) return obj else: @@ -899,7 +899,7 @@ cdef execute_streaming_generator( function_name, task_type, title, &intermediate_result, application_error, caller_address) - CCoreWorkerProcess.GetCoreWorker().ReportIntermediateTaskReturn( + CCoreWorkerProcess.GetCoreWorker().ReportGeneratorItemReturns( intermediate_result.back(), generator_id, caller_address, generator_index, False) @@ -927,7 +927,7 @@ cdef execute_streaming_generator( assert intermediate_result.size() == 1 del output - CCoreWorkerProcess.GetCoreWorker().ReportIntermediateTaskReturn( + CCoreWorkerProcess.GetCoreWorker().ReportGeneratorItemReturns( intermediate_result.back(), generator_id, caller_address, @@ -944,7 +944,7 @@ cdef execute_streaming_generator( logger.debug( "Writes EoF to a ObjectRefStream " "of an index {}".format(generator_index)) - CCoreWorkerProcess.GetCoreWorker().ReportIntermediateTaskReturn( + CCoreWorkerProcess.GetCoreWorker().ReportGeneratorItemReturns( c_pair[CObjectID, shared_ptr[CRayObject]]( CObjectID.Nil(), shared_ptr[CRayObject]()), generator_id, @@ -3491,13 +3491,13 @@ cdef class CoreWorker: CCoreWorkerProcess.GetCoreWorker().DelObjectRefStream(c_generator_id) - def async_read_object_ref_stream(self, ObjectRef generator_id): + def try_read_next_object_ref_stream(self, ObjectRef generator_id): cdef: CObjectID c_generator_id = generator_id.native() CObjectReference c_object_ref check_status( - CCoreWorkerProcess.GetCoreWorker().AsyncReadObjectRefStream( + CCoreWorkerProcess.GetCoreWorker().TryReadObjectRefStream( c_generator_id, &c_object_ref)) return ObjectRef( c_object_ref.object_id(), diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index c9e00b847b57..3998de724433 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -149,7 +149,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: const CObjectID& generator_id) void DelObjectRefStream(const CObjectID &generator_id) void CreateObjectRefStream(const CObjectID &generator_id) - CRayStatus AsyncReadObjectRefStream( + CRayStatus TryReadObjectRefStream( const CObjectID &generator_id, CObjectReference *object_ref_out) CObjectID AllocateDynamicReturnId(const CAddress &owner_address) diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py index 277d8226cb50..ef071536c346 100644 --- a/python/ray/tests/test_streaming_generator.py +++ b/python/ray/tests/test_streaming_generator.py @@ -28,7 +28,7 @@ def reset_core_worker(self): @pytest.fixture def mocked_worker(): mocked_core_worker = Mock() - mocked_core_worker.async_read_object_ref_stream.return_value = None + mocked_core_worker.try_read_next_object_ref_stream.return_value = None mocked_core_worker.delete_object_ref_stream.return_value = None mocked_core_worker.create_object_ref_stream.return_value = None worker = MockedWorker(mocked_core_worker) @@ -44,7 +44,7 @@ def test_streaming_object_ref_generator_basic_unit(mocked_worker): c = mocked_worker.core_worker generator_ref = ray.ObjectRef.from_random() generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) - c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() + c.try_read_next_object_ref_stream.return_value = ray.ObjectRef.nil() c.create_object_ref_stream.assert_called() # Test when there's no new ref, it returns a nil. @@ -55,13 +55,15 @@ def test_streaming_object_ref_generator_basic_unit(mocked_worker): # When the new ref is available, next should return it. for _ in range(3): new_ref = ray.ObjectRef.from_random() - c.async_read_object_ref_stream.return_value = new_ref + c.try_read_next_object_ref_stream.return_value = new_ref ref = generator._next(timeout_s=0) assert new_ref == ref - # When async_read_object_ref_stream raises a + # When try_read_next_object_ref_stream raises a # ObjectRefStreamEoFError, it should raise a stop iteration. - c.async_read_object_ref_stream.side_effect = ObjectRefStreamEoFError("") # noqa + c.try_read_next_object_ref_stream.side_effect = ObjectRefStreamEoFError( + "" + ) # noqa with pytest.raises(StopIteration): ref = generator._next(timeout_s=0) @@ -88,7 +90,7 @@ def test_streaming_object_ref_generator_task_failed_unit(mocked_worker): mocked_ray_wait.return_value = [generator_ref], [] mocked_ray_get.side_effect = WorkerCrashedError() - c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() + c.try_read_next_object_ref_stream.return_value = ray.ObjectRef.nil() ref = generator._next(timeout_s=0) # If the generator task fails by a systsem error, # meaning the ref will raise an exception @@ -125,7 +127,7 @@ def test_streaming_object_ref_generator_network_failed_unit(mocked_worker): # If StopIteration is not raised within # unexpected_network_failure_timeout_s second, # it should fail. - c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() + c.try_read_next_object_ref_stream.return_value = ray.ObjectRef.nil() ref = generator._next(timeout_s=0, unexpected_network_failure_timeout_s=1) assert ref == ray.ObjectRef.nil() time.sleep(1) @@ -268,7 +270,7 @@ def test_generator_streaming_no_leak_upon_failures( # defer for 10s for the second node. m.setenv( "RAY_testing_asio_delay_us", - "CoreWorkerService.grpc_server.ReportIntermediateTaskReturn=100000:1000000", + "CoreWorkerService.grpc_server.ReportGeneratorItemReturns=100000:1000000", ) ray.init(num_cpus=1) diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 7497dbed8128..4becae7a1958 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -2789,10 +2789,10 @@ void CoreWorker::DelObjectRefStream(const ObjectID &generator_id) { task_manager_->DelObjectRefStream(generator_id); } -Status CoreWorker::AsyncReadObjectRefStream(const ObjectID &generator_id, - rpc::ObjectReference *object_ref_out) { +Status CoreWorker::TryReadObjectRefStream(const ObjectID &generator_id, + rpc::ObjectReference *object_ref_out) { ObjectID object_id; - const auto &status = task_manager_->AsyncReadObjectRefStream(generator_id, &object_id); + const auto &status = task_manager_->TryReadObjectRefStream(generator_id, &object_id); if (!status.ok()) { return status; } diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 0b3c181bfa75..6c9da7050f9c 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -379,8 +379,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// generated ObjectReference. /// \return Status RayKeyError if the stream reaches to EoF. /// OK otherwise. - Status AsyncReadObjectRefStream(const ObjectID &generator_id, - rpc::ObjectReference *object_ref_out); + Status TryReadObjectRefStream(const ObjectID &generator_id, + rpc::ObjectReference *object_ref_out); /// Delete the ObjectRefStream of generator_id /// created by CreateObjectRefStream. From fa7fe24e520a1004e627fbb5452c910ef75d00e9 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 16 May 2023 06:42:45 -0700 Subject: [PATCH 22/77] done Signed-off-by: SangBin Cho --- python/ray/tests/test_streaming_generator.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py index 815d3c11bcff..3a71e55871af 100644 --- a/python/ray/tests/test_streaming_generator.py +++ b/python/ray/tests/test_streaming_generator.py @@ -57,7 +57,7 @@ def test_streaming_object_ref_generator_basic_unit(mocked_worker): for _ in range(3): new_ref = ray.ObjectRef.from_random() c.try_read_next_object_ref_stream.return_value = new_ref - ref = generator._next(timeout_s=0) + ref = generator._next_sync(timeout_s=0) assert new_ref == ref # When try_read_next_object_ref_stream raises a @@ -92,12 +92,10 @@ def test_streaming_object_ref_generator_task_failed_unit(mocked_worker): mocked_ray_get.side_effect = WorkerCrashedError() c.try_read_next_object_ref_stream.return_value = ray.ObjectRef.nil() - ref = generator._next(timeout_s=0) + ref = generator._next_sync(timeout_s=0) # If the generator task fails by a systsem error, # meaning the ref will raise an exception # it should be returned. - print(ref) - print(generator_ref) assert ref == generator_ref # Once exception is raised, it should always @@ -129,7 +127,7 @@ def test_streaming_object_ref_generator_network_failed_unit(mocked_worker): # unexpected_network_failure_timeout_s second, # it should fail. c.try_read_next_object_ref_stream.return_value = ray.ObjectRef.nil() - ref = generator._next(timeout_s=0, unexpected_network_failure_timeout_s=1) + ref = generator._next_sync(timeout_s=0, unexpected_network_failure_timeout_s=1) assert ref == ray.ObjectRef.nil() time.sleep(1) with pytest.raises(AssertionError): @@ -624,7 +622,7 @@ async def run2(): asyncio.run(main()) -def test_generator_dist_all_gather(ray_start_cluster): +def test_generator_dist_gather(ray_start_cluster): cluster = ray_start_cluster cluster.add_node(num_cpus=0, object_store_memory=1 * 1024 * 1024 * 1024) ray.init() From 5dc6b98d24d6af6e2a9e6c3e36f7ec1c4bca7488 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 16 May 2023 06:48:59 -0700 Subject: [PATCH 23/77] Fixed a unit test. Signed-off-by: SangBin Cho --- src/ray/core_worker/test/reference_count_test.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ray/core_worker/test/reference_count_test.cc b/src/ray/core_worker/test/reference_count_test.cc index dd715888d646..ee26467516bc 100644 --- a/src/ray/core_worker/test/reference_count_test.cc +++ b/src/ray/core_worker/test/reference_count_test.cc @@ -2980,14 +2980,14 @@ TEST_F(ReferenceCountTest, TestOwnDynamicStreamingTaskReturnRef) { ASSERT_EQ(address.ip_address(), added_address.ip_address()); // Verify it had 1 local reference. std::vector deleted; - rc_.RemoveLocalReference(object_id, &deleted); + rc->RemoveLocalReference(object_id, &deleted); ASSERT_EQ(rc->NumObjectIDsInScope(), 1); ASSERT_EQ(deleted.size(), 1); ASSERT_FALSE(rc->GetOwner(object_id, &added_address)); // Remove the generator. - rc_.RemoveLocalReference(generator_id, nullptr); - ASSERT_EQ(rc->NumObjectIDsInScope(), 1); + rc->RemoveLocalReference(generator_id, nullptr); + ASSERT_EQ(rc->NumObjectIDsInScope(), 0); ASSERT_FALSE(rc->GetOwner(generator_id, &added_address)); // Verify we cannot register a new object after the generator id is removed. From 7c449be48b943e3abac90f4b67cf00090c8f8600 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 16 May 2023 17:06:12 -0700 Subject: [PATCH 24/77] fix apis Signed-off-by: SangBin Cho --- python/ray/_raylet.pyx | 6 +++--- python/ray/includes/libcoreworker.pxd | 2 +- python/ray/tests/test_streaming_generator.py | 16 +++++++++------- src/ray/core_worker/core_worker.cc | 6 +++--- src/ray/core_worker/core_worker.h | 4 ++-- 5 files changed, 18 insertions(+), 16 deletions(-) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index ee46c8705cb9..24c1fbcab3a9 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -315,7 +315,7 @@ class StreamingObjectRefGenerator: def _handle_next(self): try: if hasattr(self.worker, "core_worker"): - obj = self.worker.core_worker.async_read_object_ref_stream( + obj = self.worker.core_worker.try_read_next_object_ref_stream( self._generator_ref) return obj else: @@ -3291,13 +3291,13 @@ cdef class CoreWorker: CCoreWorkerProcess.GetCoreWorker().DelObjectRefStream(c_generator_id) - def async_read_object_ref_stream(self, ObjectRef generator_id): + def try_read_next_object_ref_stream(self, ObjectRef generator_id): cdef: CObjectID c_generator_id = generator_id.native() CObjectReference c_object_ref check_status( - CCoreWorkerProcess.GetCoreWorker().AsyncReadObjectRefStream( + CCoreWorkerProcess.GetCoreWorker().TryReadObjectRefStream( c_generator_id, &c_object_ref)) return ObjectRef( c_object_ref.object_id(), diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 8a6ff6f21a05..306d5f940fd2 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -149,7 +149,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: const CObjectID& generator_id) void DelObjectRefStream(const CObjectID &generator_id) void CreateObjectRefStream(const CObjectID &generator_id) - CRayStatus AsyncReadObjectRefStream( + CRayStatus TryReadObjectRefStream( const CObjectID &generator_id, CObjectReference *object_ref_out) CObjectID AllocateDynamicReturnId() diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py index c496d52b6179..248b7dc67749 100644 --- a/python/ray/tests/test_streaming_generator.py +++ b/python/ray/tests/test_streaming_generator.py @@ -24,7 +24,7 @@ def reset_core_worker(self): @pytest.fixture def mocked_worker(): mocked_core_worker = Mock() - mocked_core_worker.async_read_object_ref_stream.return_value = None + mocked_core_worker.try_read_next_object_ref_stream.return_value = None mocked_core_worker.delete_object_ref_stream.return_value = None mocked_core_worker.create_object_ref_stream.return_value = None worker = MockedWorker(mocked_core_worker) @@ -40,7 +40,7 @@ def test_streaming_object_ref_generator_basic_unit(mocked_worker): c = mocked_worker.core_worker generator_ref = ray.ObjectRef.from_random() generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) - c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() + c.try_read_next_object_ref_stream.return_value = ray.ObjectRef.nil() c.create_object_ref_stream.assert_called() # Test when there's no new ref, it returns a nil. @@ -51,13 +51,15 @@ def test_streaming_object_ref_generator_basic_unit(mocked_worker): # When the new ref is available, next should return it. for _ in range(3): new_ref = ray.ObjectRef.from_random() - c.async_read_object_ref_stream.return_value = new_ref + c.try_read_next_object_ref_stream.return_value = new_ref ref = generator._next(timeout_s=0) assert new_ref == ref - # When async_read_object_ref_stream raises a + # When try_read_next_object_ref_stream raises a # ObjectRefStreamEoFError, it should raise a stop iteration. - c.async_read_object_ref_stream.side_effect = ObjectRefStreamEoFError("") # noqa + c.try_read_next_object_ref_stream.side_effect = ObjectRefStreamEoFError( + "" + ) # noqa with pytest.raises(StopIteration): ref = generator._next(timeout_s=0) @@ -84,7 +86,7 @@ def test_streaming_object_ref_generator_task_failed_unit(mocked_worker): mocked_ray_wait.return_value = [generator_ref], [] mocked_ray_get.side_effect = WorkerCrashedError() - c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() + c.try_read_next_object_ref_stream.return_value = ray.ObjectRef.nil() ref = generator._next(timeout_s=0) # If the generator task fails by a systsem error, # meaning the ref will raise an exception @@ -121,7 +123,7 @@ def test_streaming_object_ref_generator_network_failed_unit(mocked_worker): # If StopIteration is not raised within # unexpected_network_failure_timeout_s second, # it should fail. - c.async_read_object_ref_stream.return_value = ray.ObjectRef.nil() + c.try_read_next_object_ref_stream.return_value = ray.ObjectRef.nil() ref = generator._next(timeout_s=0, unexpected_network_failure_timeout_s=1) assert ref == ray.ObjectRef.nil() time.sleep(1) diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index bd7d948b0b53..1029b1e51fd0 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -2777,10 +2777,10 @@ void CoreWorker::DelObjectRefStream(const ObjectID &generator_id) { task_manager_->DelObjectRefStream(generator_id); } -Status CoreWorker::AsyncReadObjectRefStream(const ObjectID &generator_id, - rpc::ObjectReference *object_ref_out) { +Status CoreWorker::TryReadObjectRefStream(const ObjectID &generator_id, + rpc::ObjectReference *object_ref_out) { ObjectID object_id; - const auto &status = task_manager_->AsyncReadObjectRefStream(generator_id, &object_id); + const auto &status = task_manager_->TryReadObjectRefStream(generator_id, &object_id); if (!status.ok()) { return status; } diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 2e0aeaddc2b5..0673ee280f3c 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -379,8 +379,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// generated ObjectReference. /// \return Status RayKeyError if the stream reaches to EoF. /// OK otherwise. - Status AsyncReadObjectRefStream(const ObjectID &generator_id, - rpc::ObjectReference *object_ref_out); + Status TryReadObjectRefStream(const ObjectID &generator_id, + rpc::ObjectReference *object_ref_out); /// Delete the ObjectRefStream of generator_id /// created by CreateObjectRefStream. From d7ebad1ca1703cb75c5c96b51b05b7d659b0376f Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 16 May 2023 18:56:31 -0700 Subject: [PATCH 25/77] lint. Signed-off-by: SangBin Cho --- python/ray/_raylet.pyx | 14 +++++++++----- python/ray/includes/libcoreworker.pxd | 4 ++-- python/ray/includes/unique_ids.pxi | 3 ++- python/ray/tests/test_streaming_generator.py | 8 ++++++-- src/ray/core_worker/core_worker.cc | 6 +++--- src/ray/core_worker/core_worker.h | 6 +++--- 6 files changed, 25 insertions(+), 16 deletions(-) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 47faddbf72ff..f6541e582d19 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -106,6 +106,7 @@ from ray.includes.unique_ids cimport ( CObjectID, CNodeID, CPlacementGroupID, + ObjectIDIndexType, ) from ray.includes.libcoreworker cimport ( ActorHandleSharedPtr, @@ -123,7 +124,7 @@ from ray.includes.ray_config cimport RayConfig from ray.includes.global_state_accessor cimport CGlobalStateAccessor from ray.includes.global_state_accessor cimport RedisDelKeySync from ray.includes.optional cimport ( - optional + optional, nullopt ) import ray @@ -187,6 +188,8 @@ current_task_id_lock = threading.Lock() job_config_initialized = False job_config_initialization_lock = threading.Lock() +cdef optional[ObjectIDIndexType] NULL_PUT_INDEX = nullopt + class ObjectRefGenerator: def __init__(self, refs): @@ -1093,7 +1096,7 @@ cdef execute_dynamic_generator_and_store_task_outputs( # ObjectRef will contain the error. error_id = (CCoreWorkerProcess.GetCoreWorker() .AllocateDynamicReturnId( - caller_address, CTaskID.Nil(), -1)) + caller_address, CTaskID.Nil(), NULL_PUT_INDEX)) dynamic_returns[0].push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( error_id, shared_ptr[CRayObject]())) @@ -3308,7 +3311,7 @@ cdef class CoreWorker: while i >= returns[0].size(): return_id = (CCoreWorkerProcess.GetCoreWorker() .AllocateDynamicReturnId( - caller_address, CTaskID.Nil(), -1)) + caller_address, CTaskID.Nil(), NULL_PUT_INDEX)) returns[0].push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( return_id, shared_ptr[CRayObject]())) @@ -3646,13 +3649,14 @@ cdef class CoreWorker: # Should add 1 because put index is always incremented # before it is used. So if you have 1 return object # the next index will be 2. - 1 + num_returns + generator_index, # put_index + make_optional[ObjectIDIndexType]( + 1 + num_returns + generator_index) # put_index ) else: return CCoreWorkerProcess.GetCoreWorker().AllocateDynamicReturnId( owner_address, CTaskID.Nil(), - -1 + NULL_PUT_INDEX ) def create_object_ref_stream(self, ObjectRef generator_id): diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index ec963a9dd259..d2359317e30d 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -50,7 +50,7 @@ from ray.includes.function_descriptor cimport ( ) from ray.includes.optional cimport ( - optional + optional, ) ctypedef unordered_map[c_string, c_vector[pair[int64_t, double]]] \ @@ -156,7 +156,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: CObjectID AllocateDynamicReturnId( const CAddress &owner_address, const CTaskID &task_id, - ObjectIDIndexType put_index) + optional[ObjectIDIndexType] put_index) CJobID GetCurrentJobId() CTaskID GetCurrentTaskId() diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index 2b4f5c78f5ba..8221111a2955 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -21,9 +21,10 @@ from ray.includes.unique_ids cimport ( CTaskID, CUniqueID, CWorkerID, - CPlacementGroupID + CPlacementGroupID, ) + import ray from ray._private.utils import decode diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py index 3a71e55871af..1037ef3c0f6d 100644 --- a/python/ray/tests/test_streaming_generator.py +++ b/python/ray/tests/test_streaming_generator.py @@ -127,7 +127,9 @@ def test_streaming_object_ref_generator_network_failed_unit(mocked_worker): # unexpected_network_failure_timeout_s second, # it should fail. c.try_read_next_object_ref_stream.return_value = ray.ObjectRef.nil() - ref = generator._next_sync(timeout_s=0, unexpected_network_failure_timeout_s=1) + ref = generator._next_sync( + timeout_s=0, unexpected_network_failure_timeout_s=1 + ) assert ref == ray.ObjectRef.nil() time.sleep(1) with pytest.raises(AssertionError): @@ -168,7 +170,9 @@ async def test_streaming_object_ref_generator_unit_async(mocked_worker): # When try_read_next_object_ref_stream raises a # ObjectRefStreamEoFError, it should raise a stop iteration. - c.try_read_next_object_ref_stream.side_effect = ObjectRefStreamEoFError("") # noqa + c.try_read_next_object_ref_stream.side_effect = ObjectRefStreamEoFError( + "" + ) # noqa with pytest.raises(StopAsyncIteration): ref = await generator._next_async(timeout_s=0) diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index f98e515c85ae..326455543a56 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -2858,7 +2858,7 @@ bool CoreWorker::PinExistingReturnObject(const ObjectID &return_id, ObjectID CoreWorker::AllocateDynamicReturnId(const rpc::Address &owner_address, const TaskID &task_id, - ObjectIDIndexType put_index) { + std::optional put_index) { TaskID current_task_id; if (task_id.IsNil()) { const auto &task_spec = worker_context_.GetCurrentTask(); @@ -2868,10 +2868,10 @@ ObjectID CoreWorker::AllocateDynamicReturnId(const rpc::Address &owner_address, } ObjectIDIndexType current_put_index; - if (put_index == -1) { + if (!put_index.has_value()) { current_put_index = worker_context_.GetNextPutIndex(); } else { - current_put_index = put_index; + current_put_index = put_index.value(); } const auto return_id = ObjectID::FromIndex(current_task_id, current_put_index); diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index ee745f736c92..fbdcffa91995 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -1038,11 +1038,11 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// worker context. /// \param[in] put_index The equivalent of the return value of /// WorkerContext::GetNextPutIndex. - /// If -1 is specified, it will deduce the Task ID from the current - /// worker context. + /// If std::nullopt is specified, it will deduce the put index from the + /// current worker context. ObjectID AllocateDynamicReturnId(const rpc::Address &owner_address, const TaskID &task_id = TaskID::Nil(), - ObjectIDIndexType put_index = -1); + std::optional put_index = -1); /// Get a handle to an actor. /// From 2b046b6b6be2028f84f3d1c2218d6e35f3e96bb0 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 16 May 2023 23:45:47 -0700 Subject: [PATCH 26/77] Ready for a benchmark. Signed-off-by: SangBin Cho --- python/ray/_private/ray_perf.py | 130 ++++++++++++++++++++++++++++++++ python/ray/_raylet.pyx | 28 +++++-- 2 files changed, 151 insertions(+), 7 deletions(-) diff --git a/python/ray/_private/ray_perf.py b/python/ray/_private/ray_perf.py index a902ae0f7dc4..f811ef5640f3 100644 --- a/python/ray/_private/ray_perf.py +++ b/python/ray/_private/ray_perf.py @@ -35,6 +35,55 @@ async def small_value_batch(self, n): await asyncio.wait([small_value.remote() for _ in range(n)]) +@ray.remote +class AsyncGeneratorActor: + async def small_value(self): + yield b"ok" + + async def small_value_with_arg(self, x): + yield b"ok" + + async def small_value_batch(self, n): + yield await asyncio.wait([small_value.remote() for _ in range(n)]) + + +@ray.remote(num_cpus=0) +class GeneratorClient: + def __init__(self, servers): + if not isinstance(servers, list): + servers = [servers] + self.servers = servers + + def small_value_batch(self, n): + results = [] + for s in self.servers: + results.extend( + [ + s.small_value.options(num_returns="streaming").remote() + for _ in range(n) + ] + ) + refs = [] + for gen in results: + refs.extend(list(gen)) + ray.get(refs) + + def small_value_batch_arg(self, n): + x = ray.put(0) + results = [] + for s in self.servers: + results.extend( + [ + s.small_value_arg.options(num_returns="streaming").remote(x) + for _ in range(n) + ] + ) + refs = [] + for gen in results: + refs.extend(list(gen)) + ray.get(refs) + + @ray.remote(num_cpus=0) class Client: def __init__(self, servers): @@ -280,6 +329,87 @@ def async_actor_multi(): ray.get([async_actor_work.remote(a) for _ in range(m)]) results += timeit("n:n async-actor calls async", async_actor_multi, m * n) + + """ + Async generator actor + """ + a = AsyncGeneratorActor.remote() + + def actor_sync_generator(): + ray.get(list(a.small_value.options(num_returns="streaming").remote())) + + results += timeit("1:1 async-actor-generator calls sync", actor_sync_generator) + + a = AsyncGeneratorActor.options().remote() + + def async_actor_generator(): + gens = [ + a.small_value.options(num_returns="streaming").remote() for _ in range(1000) + ] + refs = [] + for gen in gens: + refs.extend(list(gen)) + ray.get(refs) + + results += timeit( + "1:1 async-actor-generator calls async", async_actor_generator, 1000 + ) + + a = AsyncGeneratorActor.remote() + + def async_actor_generator(): + gens = [ + a.small_value_with_arg.options(num_returns="streaming").remote(i) + for i in range(1000) + ] + refs = [] + for gen in gens: + refs.extend(list(gen)) + print(refs) + print(len(refs)) + ray.get(refs) + + results += timeit( + "1:1 async-actor-generator calls with args async", async_actor_generator, 1000 + ) + + n = 5000 + n_cpu = multiprocessing.cpu_count() // 2 + actors = [AsyncGeneratorActor.remote() for _ in range(n_cpu)] + client = GeneratorClient.remote(actors) + + def async_actor_async_generator(): + gen = client.small_value_batch.remote(n) + ray.get(list(gen)) + + results += timeit( + "1:n async-actor-generator calls async", + async_actor_async_generator, + n * len(actors), + ) + + n = 5000 + m = 4 + n_cpu = multiprocessing.cpu_count() // 2 + a = [AsyncGeneratorActor.remote() for _ in range(n_cpu)] + + @ray.remote + def async_actor_work_generator(actors): + gens = [ + actors[i % n_cpu].small_value.options(num_returns="streaming").remote() + for i in range(n) + ] + refs = [] + for gen in gens: + refs.extend(list(gen)) + ray.get(refs) + + def async_actor_multi_generator(): + ray.get([async_actor_work_generator.remote(a) for _ in range(m)]) + + results += timeit( + "n:n async-actor-generator calls async", async_actor_multi_generator, m * n + ) ray.shutdown() NUM_PGS = 100 diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index f6541e582d19..d2346a8243b2 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -245,7 +245,7 @@ class StreamingObjectRefGenerator: self, timeout_s: float = -1, sleep_interval_s: float = 0.0001, - unexpected_network_failure_timeout_s: float = 30): + unexpected_network_failure_timeout_s: float = 60): """Waits for timeout_s and returns the object ref if available. If an object is not available within the given timeout, it @@ -298,7 +298,7 @@ class StreamingObjectRefGenerator: self, timeout_s: float = -1, sleep_interval_s: float = 0.0001, - unexpected_network_failure_timeout_s: float = 30): + unexpected_network_failure_timeout_s: float = 60): """Same API as _next_sync, but it is for async context.""" obj = await self._handle_next_async() last_time = time.time() @@ -386,7 +386,10 @@ class StreamingObjectRefGenerator: # It means the next wasn't reported although the task # has been terminated 30 seconds ago. self._generator_task_exception = AssertionError - assert False, "Unexpected network failure occured." + assert False, ( + "Unexpected network failure occured. " + f"Task ID: {self._generator_ref.task_id()}" + ) if timeout_s != -1 and time.time() - last_time > timeout_s: return ObjectRef.nil() @@ -1393,7 +1396,9 @@ cdef void execute_task( print(task_attempt_magic_token, end="") print(task_attempt_magic_token, file=sys.stderr, end="") - if returns[0].size() == 1 and not inspect.isgenerator(outputs): + if (returns[0].size() == 1 + and not inspect.isgenerator(outputs) + and not inspect.isasyncgen(outputs)): # If there is only one return specified, we should return # all return values as a single object. outputs = (outputs,) @@ -1418,13 +1423,22 @@ cdef void execute_task( # like GCS has such info. core_worker.set_actor_repr_name(actor_repr) - if (returns[0].size() > 0 and - not inspect.isgenerator(outputs) and - len(outputs) != int(returns[0].size())): + if (returns[0].size() > 0 + and not inspect.isgenerator(outputs) + and not inspect.isasyncgen(outputs) + and len(outputs) != int(returns[0].size())): raise ValueError( "Task returned {} objects, but num_returns={}.".format( len(outputs), returns[0].size())) + if inspect.isgenerator(outputs) or inspect.isasyncgen(outputs): + if dynamic_returns == NULL and not is_streaming_generator: + raise ValueError( + f"{name} is a generator function, " + "but it doesn't specify " + "@ray.remote(num_returns=\"dynamic\") or " + "@ray.remote (num_returns=\"streaming\"). ") + # Store the outputs in the object store. with core_worker.profile_event(b"task:store_outputs"): # TODO(sang): Remove it once we use streaming generator From c726484164c00593d8223b116a3f4dc9bc0b9094 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Thu, 18 May 2023 03:03:16 -0700 Subject: [PATCH 27/77] Made it work. Signed-off-by: SangBin Cho --- python/ray/_private/ray_perf.py | 6 +- python/ray/_raylet.pyx | 28 +++---- python/ray/includes/libcoreworker.pxd | 1 - python/ray/tests/test_streaming_generator.py | 83 +++++++++++++++++++- src/ray/core_worker/core_worker.cc | 15 ++++ src/ray/core_worker/task_manager.cc | 3 +- src/ray/core_worker/task_manager.h | 4 + 7 files changed, 117 insertions(+), 23 deletions(-) diff --git a/python/ray/_private/ray_perf.py b/python/ray/_private/ray_perf.py index f811ef5640f3..219dd00833a7 100644 --- a/python/ray/_private/ray_perf.py +++ b/python/ray/_private/ray_perf.py @@ -365,8 +365,6 @@ def async_actor_generator(): refs = [] for gen in gens: refs.extend(list(gen)) - print(refs) - print(len(refs)) ray.get(refs) results += timeit( @@ -379,8 +377,8 @@ def async_actor_generator(): client = GeneratorClient.remote(actors) def async_actor_async_generator(): - gen = client.small_value_batch.remote(n) - ray.get(list(gen)) + ref = client.small_value_batch.remote(n) + ray.get(ref) results += timeit( "1:n async-actor-generator calls async", diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index d2346a8243b2..3c0178efc51f 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -217,7 +217,6 @@ class StreamingObjectRefGenerator: # Ray's worker class. ray._private.worker.global_worker self.worker = worker assert hasattr(worker, "core_worker") - self.worker.core_worker.create_object_ref_stream(self._generator_ref) def __iter__(self): return self @@ -398,6 +397,8 @@ class StreamingObjectRefGenerator: def __del__(self): if hasattr(self.worker, "core_worker"): + # The stream is created when a task is first submitted via + # CreateObjectRefStream. # NOTE: This can be called multiple times # because python doesn't guarantee __del__ is called # only once. @@ -491,6 +492,7 @@ def compute_task_id(ObjectRef object_ref): cdef increase_recursion_limit(): """Double the recusion limit if current depth is close to the limit""" + t = time.time() cdef: CPyThreadState * s = PyThreadState_Get() int current_limit = Py_GetRecursionLimit() @@ -506,7 +508,6 @@ cdef increase_recursion_limit(): int CURRENT_DEPTH(CPyThreadState *x) int current_depth = CURRENT_DEPTH(s) - if current_limit - current_depth < 500: Py_SetRecursionLimit(new_limit) logger.debug("Increasing Python recursion limit to {} " @@ -1219,14 +1220,6 @@ cdef void execute_task( class_name=class_name ) ) - # Increase recursion limit if necessary. In asyncio mode, - # we have many parallel callstacks (represented in fibers) - # that's suspended for execution. Python interpreter will - # mistakenly count each callstack towards recusion limit. - # We don't need to worry about stackoverflow here because - # the max number of callstacks is limited in direct actor - # transport with max_concurrency flag. - increase_recursion_limit() if is_async_func(function.method): async_function = function @@ -3465,6 +3458,15 @@ cdef class CoreWorker: eventloop, async_thread = self.get_event_loop( function_descriptor, specified_cgname) + # Increase recursion limit if necessary. In asyncio mode, + # we have many parallel callstacks (represented in fibers) + # that's suspended for execution. Python interpreter will + # mistakenly count each callstack towards recusion limit. + # We don't need to worry about stackoverflow here because + # the max number of callstacks is limited in direct actor + # transport with max_concurrency flag. + increase_recursion_limit() + if inspect.isawaitable(func_or_coro): coroutine = func_or_coro else: @@ -3673,12 +3675,6 @@ cdef class CoreWorker: NULL_PUT_INDEX ) - def create_object_ref_stream(self, ObjectRef generator_id): - cdef: - CObjectID c_generator_id = generator_id.native() - - CCoreWorkerProcess.GetCoreWorker().CreateObjectRefStream(c_generator_id) - def delete_object_ref_stream(self, ObjectRef generator_id): cdef: CObjectID c_generator_id = generator_id.native() diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index d2359317e30d..cc0b3092ffb2 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -149,7 +149,6 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: shared_ptr[CRayObject] *return_object, const CObjectID& generator_id) void DelObjectRefStream(const CObjectID &generator_id) - void CreateObjectRefStream(const CObjectID &generator_id) CRayStatus TryReadObjectRefStream( const CObjectID &generator_id, CObjectReference *object_ref_out) diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py index 1037ef3c0f6d..81df9dc71b7e 100644 --- a/python/ray/tests/test_streaming_generator.py +++ b/python/ray/tests/test_streaming_generator.py @@ -658,7 +658,88 @@ async def main(): asyncio.run(main()) summary = ray._private.internal_api.memory_summary(stats_only=True) print(summary) - # assert "Spilled" not in summary, summary + + +def test_ray_serve_like_generator_stress_test(ray_start_cluster, monkeypatch): + """Mock the stressful Ray Serve workloads. + + Ray Serve has a single actor that invokes many generator tasks. + All the actors are async actor for Ray Serve. + """ + with monkeypatch.context() as m: + # Add a 10ms ~ 1 second delay to the RPC. + m.setenv( + "RAY_testing_asio_delay_us", + "CoreWorkerService.grpc_server.ReportGeneratorItemReturns=10000:1000000", + ) + + cluster = ray_start_cluster + total_cpus = 20 + # 5 nodes cluster, 4 CPUs each. + cluster.add_node(num_cpus=total_cpus // 5) + ray.init() + for _ in range(4): + cluster.add_node(num_cpus=total_cpus // 5) + + @ray.remote(num_cpus=1) + class ProxyActor: + async def get_data(self, child): + await asyncio.sleep(0.1) + gen = child.get_data.options(num_returns="streaming").remote() + async for ref in gen: + yield ref + del ref + + @ray.remote + class ChainActor: + def __init__(self, child=None): + self.child = child + + async def get_data(self): + if not self.child: + for i in range(10): + await asyncio.sleep(0.1) + yield np.ones(5 * 1024) * i + else: + async for ref in self.child.get_data.options( + num_returns="streaming" + ).remote(): + yield ref + + chain_actors = [] + num_chain_actors = 16 + for _ in range(num_chain_actors): + chain_actor = ChainActor.remote() + chain_actor_2 = ChainActor.remote(chain_actor) + chain_actor_3 = ChainActor.remote(chain_actor_2) + chain_actor_4 = ChainActor.remote(chain_actor_3) + chain_actors.append(chain_actor_4) + + proxy_actor = ProxyActor.remote() + + async def get_stream(proxy_actor, chain_actor): + i = 0 + async for ref in proxy_actor.get_data.options( + num_returns="streaming" + ).remote(chain_actor): + for _ in range(5): + ref = await ref + assert np.array_equal(np.ones(5 * 1024) * i, ref) + del ref + i += 1 + + async def main(): + await asyncio.gather( + *[get_stream(proxy_actor, chain_actor) for chain_actor in chain_actors] + ) + result = list_objects(raise_on_missing_output=False) + ref_types = set() + for r in result: + ref_types.add(r.reference_type) + # Verify no leaks + assert ref_types == {"ACTOR_HANDLE"} + + asyncio.run(main()) if __name__ == "__main__": diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 326455543a56..fe7e6f35aee7 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1947,6 +1947,13 @@ std::vector CoreWorker::SubmitTask( } else { returned_refs = task_manager_->AddPendingTask( task_spec.CallerAddress(), task_spec, CurrentCallSite(), max_retries); + + // If it is a generator task, create a object ref stream. + // The language frontend is responsible for calling DeleteObjectRefStream. + if (task_spec.IsStreamingGenerator()) { + CreateObjectRefStream(task_spec.ReturnId(0)); + } + io_service_.post( [this, task_spec]() { RAY_UNUSED(direct_task_submitter_->SubmitTask(task_spec)); @@ -2272,6 +2279,14 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, } else { returned_refs = task_manager_->AddPendingTask( rpc_address_, task_spec, CurrentCallSite(), actor_handle->MaxTaskRetries()); + + // If it is a generator task, create a object ref stream. + // The language frontend is responsible for calling DeleteObjectRefStream. + if (task_spec.IsStreamingGenerator()) { + // Generator task only has 1 return. + CreateObjectRefStream(task_spec.ReturnId(0)); + } + RAY_CHECK_OK(direct_actor_submitter_->SubmitTask(task_spec)); } task_returns = std::move(returned_refs); diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index b2fe7373d213..e702cc81c3fb 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -445,9 +445,10 @@ void TaskManager::HandleReportGeneratorItemReturns( if (request.finished()) { absl::MutexLock lock(&mu_); - RAY_LOG(DEBUG) << "Write EoF to the object ref stream. Index: " << item_index; + RAY_LOG(DEBUG) << "Writing EoF to the object ref stream. Index: " << item_index; auto stream_it = object_ref_streams_.find(generator_id); if (stream_it != object_ref_streams_.end()) { + RAY_LOG(DEBUG) << "Wrote EoF to the object ref stream. Index: " << item_index; stream_it->second.MarkEndOfStream(item_index); } // The last report should not have any return objects. diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index b94a2263accd..3cadecc6e7b0 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -236,8 +236,12 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// Create the object ref stream. /// If the object ref stream is not created by this API, /// all object ref stream operation will be no-op. + /// /// Once the stream is created, it has to be deleted /// by DelObjectRefStream when it is not used anymore. + /// Once you generate a stream, it is the caller's responsibility + /// to call DelObjectRefStream. + /// /// The API is not idempotent. /// /// \param[in] generator_id The object ref id of the streaming From 1151a289393501871fdff218ae02ce2a15810530 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Thu, 18 May 2023 05:14:53 -0700 Subject: [PATCH 28/77] done. Signed-off-by: SangBin Cho --- python/ray/_raylet.pyx | 15 +++------------ src/ray/core_worker/core_worker.cc | 1 - 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 3c0178efc51f..fb807d9237eb 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -244,7 +244,7 @@ class StreamingObjectRefGenerator: self, timeout_s: float = -1, sleep_interval_s: float = 0.0001, - unexpected_network_failure_timeout_s: float = 60): + unexpected_network_failure_timeout_s: float = 30): """Waits for timeout_s and returns the object ref if available. If an object is not available within the given timeout, it @@ -297,7 +297,7 @@ class StreamingObjectRefGenerator: self, timeout_s: float = -1, sleep_interval_s: float = 0.0001, - unexpected_network_failure_timeout_s: float = 60): + unexpected_network_failure_timeout_s: float = 30): """Same API as _next_sync, but it is for async context.""" obj = await self._handle_next_async() last_time = time.time() @@ -387,7 +387,7 @@ class StreamingObjectRefGenerator: self._generator_task_exception = AssertionError assert False, ( "Unexpected network failure occured. " - f"Task ID: {self._generator_ref.task_id()}" + f"Task ID: {self._generator_ref.task_id().hex()}" ) if timeout_s != -1 and time.time() - last_time > timeout_s: @@ -492,7 +492,6 @@ def compute_task_id(ObjectRef object_ref): cdef increase_recursion_limit(): """Double the recusion limit if current depth is close to the limit""" - t = time.time() cdef: CPyThreadState * s = PyThreadState_Get() int current_limit = Py_GetRecursionLimit() @@ -1424,14 +1423,6 @@ cdef void execute_task( "Task returned {} objects, but num_returns={}.".format( len(outputs), returns[0].size())) - if inspect.isgenerator(outputs) or inspect.isasyncgen(outputs): - if dynamic_returns == NULL and not is_streaming_generator: - raise ValueError( - f"{name} is a generator function, " - "but it doesn't specify " - "@ray.remote(num_returns=\"dynamic\") or " - "@ray.remote (num_returns=\"streaming\"). ") - # Store the outputs in the object store. with core_worker.profile_event(b"task:store_outputs"): # TODO(sang): Remove it once we use streaming generator diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index fe7e6f35aee7..f0b60f105b98 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -2283,7 +2283,6 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, // If it is a generator task, create a object ref stream. // The language frontend is responsible for calling DeleteObjectRefStream. if (task_spec.IsStreamingGenerator()) { - // Generator task only has 1 return. CreateObjectRefStream(task_spec.ReturnId(0)); } From b7be576913c115d5d9f844bebcea288f7bd70ea4 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 19 May 2023 08:28:34 -0700 Subject: [PATCH 29/77] Addressed code review. Signed-off-by: SangBin Cho --- python/ray/_raylet.pyx | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 907d19618149..33517ecda2e7 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -200,7 +200,7 @@ class ObjectRefGenerator: class StreamingObjectRefGenerator: - def __init__(self, generator_ref, worker): + def __init__(self, generator_ref: ObjectRef, worker: "Worker"): # The reference to a generator task. self._generator_ref = generator_ref # The last time generator task has completed. @@ -212,10 +212,10 @@ class StreamingObjectRefGenerator: assert hasattr(worker, "core_worker") self.worker.core_worker.create_object_ref_stream(self._generator_ref) - def __iter__(self): + def __iter__(self) -> "StreamingObjectRefGenerator": return self - def __next__(self): + def __next__(self) -> ObjectRef: """Waits until a next ref is available and returns the object ref. Raises StopIteration if there's no more objects @@ -232,7 +232,7 @@ class StreamingObjectRefGenerator: self, timeout_s: float = -1, sleep_interval_s: float = 0.0001, - unexpected_network_failure_timeout_s: float = 30): + unexpected_network_failure_timeout_s: float = 30) -> ObjectRef: """Waits for timeout_s and returns the object ref if available. If an object is not available within the given timeout, it @@ -312,7 +312,7 @@ class StreamingObjectRefGenerator: obj = self._handle_next() return obj - def _handle_next(self): + def _handle_next(self) -> ObjectRef: try: if hasattr(self.worker, "core_worker"): obj = self.worker.core_worker.try_read_next_object_ref_stream( @@ -334,8 +334,8 @@ class StreamingObjectRefGenerator: def __getstate__(self): raise TypeError( - "Serialization of the StreamingObjectRefGenerator " - "is now allowed") + "You cannot return or pass a generator to other task. " + "Serializing a StreamingObjectRefGenerator is not allowed.") cdef int check_status(const CRayStatus& status) nogil except -1: From 11686d43f82cf0ad924e0407435f3b270971e4d2 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 19 May 2023 09:09:20 -0700 Subject: [PATCH 30/77] Addressed code review. Signed-off-by: SangBin Cho --- python/ray/_private/ray_option_utils.py | 4 +- python/ray/_raylet.pyx | 213 +++++++++++++++++------- python/ray/actor.py | 2 +- python/ray/includes/common.pxd | 1 + python/ray/includes/common.pxi | 2 + python/ray/remote_function.py | 2 +- python/ray/tests/test_generators.py | 2 +- src/ray/common/constants.h | 2 + src/ray/core_worker/core_worker.cc | 2 +- 9 files changed, 164 insertions(+), 66 deletions(-) diff --git a/python/ray/_private/ray_option_utils.py b/python/ray/_private/ray_option_utils.py index 97c35f9449ca..9d3104151a53 100644 --- a/python/ray/_private/ray_option_utils.py +++ b/python/ray/_private/ray_option_utils.py @@ -155,8 +155,8 @@ def issubclass_safe(obj: Any, cls_: type) -> bool: (int, str, type(None)), lambda x: None if (x is None or x == "dynamic" or x == "streaming" or x >= 0) - else "The keyword 'num_returns' only accepts None, a non-negative integer, or " - '"dynamic" (for generators)', + else "The keyword 'num_returns' only accepts None, a non-negative integer, " + '"dynamic" (for generators), or "streaming" (for streaming generators)', default_value=1, ), "object_store_memory": Option( # override "_common_options" diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 2635fbd700b1..f4aa5789f9a2 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -851,7 +851,6 @@ cdef execute_streaming_generator( worker = ray._private.worker.global_worker cdef: CoreWorker core_worker = worker.core_worker - c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] intermediate_result generator_index = 0 assert inspect.isgenerator(generator), ( @@ -864,85 +863,54 @@ cdef execute_streaming_generator( except StopIteration: break except Exception as e: - # Report the error if the generator failed to execute. - is_retryable_error[0] = determine_if_retryable( + error_obj = create_generator_error_object( e, + worker, + task_type, + caller_address, + task_id, serialized_retry_exception_allowlist, + function_name, function_descriptor, + title, + actor, + actor_id, + is_retryable_error, + application_error ) - - if ( - is_retryable_error[0] - and core_worker.get_current_task_retry_exceptions() - ): - logger.debug("Task failed with retryable exception:" - " {}.".format(task_id), exc_info=True) - # Raise an exception directly and halt the execution - # because there's no need to set the exception - # for the return value when the task is retryable. - raise e - - logger.debug("Task failed with unretryable exception:" - " {}.".format(task_id), exc_info=True) - - error_id = (CCoreWorkerProcess.GetCoreWorker() - .AllocateDynamicReturnId(caller_address)) - intermediate_result.push_back( - c_pair[CObjectID, shared_ptr[CRayObject]]( - error_id, shared_ptr[CRayObject]())) - - store_task_errors( - worker, e, - True, # task_exception - actor, # actor - actor_id, # actor id - function_name, task_type, title, - &intermediate_result, application_error, caller_address) - CCoreWorkerProcess.GetCoreWorker().ReportGeneratorItemReturns( - intermediate_result.back(), - generator_id, caller_address, generator_index, False) - - if intermediate_result.size() > 0: - intermediate_result.pop_back() + error_obj, + generator_id, + caller_address, + generator_index, + False) # finished generator_index += 1 break else: # Report the intermediate result if there was no error. - return_id = ( - CCoreWorkerProcess.GetCoreWorker().AllocateDynamicReturnId( - caller_address)) - intermediate_result.push_back( - c_pair[CObjectID, shared_ptr[CRayObject]]( - return_id, shared_ptr[CRayObject]())) - - core_worker.store_task_outputs( - worker, [output], - &intermediate_result, - caller_address, - generator_id) + generator_return_obj = create_generator_return_obj( + output, + generator_id, + worker, + caller_address) + # Del output here so that we can GC the memory + # usage asap. + del output + logger.debug( "Writes to a ObjectRefStream of an " "index {}".format(generator_index)) - assert intermediate_result.size() == 1 - del output - CCoreWorkerProcess.GetCoreWorker().ReportGeneratorItemReturns( - intermediate_result.back(), + generator_return_obj, generator_id, caller_address, generator_index, - False) - - if intermediate_result.size() > 0: - intermediate_result.pop_back() + False) # finished generator_index += 1 - # All the intermediate result has to be popped and reported. - assert intermediate_result.size() == 0 # Report the owner that there's no more objects. logger.debug( - "Writes EoF to a ObjectRefStream " + "Writes End of stream to a ObjectRefStream " "of an index {}".format(generator_index)) CCoreWorkerProcess.GetCoreWorker().ReportGeneratorItemReturns( c_pair[CObjectID, shared_ptr[CRayObject]]( @@ -953,6 +921,131 @@ cdef execute_streaming_generator( True) # finished. +cdef c_pair[CObjectID, shared_ptr[CRayObject]] create_generator_return_obj( + output, + const CObjectID &generator_id, + worker: "Worker", + const CAddress &caller_address): + """Create a generator return object based on a given output. + + Args: + output: The output from a next(generator). + generator_id: The object ref id of the generator task. + worker: The Python worker class inside worker.py + caller_address: The address of the caller. By our protocol, + the caller of the streaming generator task is always + the owner, so we can also call it "owner address". + + Returns: + A Ray Object that contains the given output. + """ + cdef: + c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] intermediate_result + CoreWorker core_worker = worker.core_worker + + return_id = ( + CCoreWorkerProcess.GetCoreWorker().AllocateDynamicReturnId( + caller_address)) + intermediate_result.push_back( + c_pair[CObjectID, shared_ptr[CRayObject]]( + return_id, shared_ptr[CRayObject]())) + core_worker.store_task_outputs( + worker, [output], + &intermediate_result, + caller_address, + generator_id) + + return intermediate_result.back() + + +cdef c_pair[CObjectID, shared_ptr[CRayObject]] create_generator_error_object( + e: Exception, + worker: "Worker", + CTaskType task_type, + const CAddress &caller_address, + TaskID task_id, + const c_string &serialized_retry_exception_allowlist, + function_name, + function_descriptor, + title, + actor, + actor_id, + c_bool *is_retryable_error, + c_string *application_error): + """Create a generator error object. + + This API sets is_retryable_error and application_error, + It also creates and returns a new RayObject that + contains the exception `e`. + + Args: + e: The exception raised from a generator. + worker: The Python worker class inside worker.py + task_type: The type of the task. E.g., actor task, normal task. + caller_address: The address of the caller. By our protocol, + the caller of the streaming generator task is always + the owner, so we can also call it "owner address". + task_id: The task ID of the generator task. + serialized_retry_exception_allowlist: A list of + exceptions that are allowed to retry this generator task. + function_name: The name of the generator function. Used for + writing an error message. + function_descriptor: The function descriptor of + the generator function. Used for writing an error message. + title: The process title of the generator task. Used for + writing an error message. + actor: The instance of the actor created in this worker. + It is used to write an error message. + actor_id: The ID of the actor. It is used to write an error message. + is_retryable_error(out): It is set to True if the generator + raises an exception, and the error is retryable. + application_error(out): It is set if the generator raises an + application error. + + Returns: + A Ray Object that contains the given error exception. + """ + cdef: + c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] intermediate_result + CoreWorker core_worker = worker.core_worker + + is_retryable_error[0] = determine_if_retryable( + e, + serialized_retry_exception_allowlist, + function_descriptor, + ) + + if ( + is_retryable_error[0] + and core_worker.get_current_task_retry_exceptions() + ): + logger.debug("Task failed with retryable exception:" + " {}.".format(task_id), exc_info=True) + # Raise an exception directly and halt the execution + # because there's no need to set the exception + # for the return value when the task is retryable. + raise e + + logger.debug("Task failed with unretryable exception:" + " {}.".format(task_id), exc_info=True) + + error_id = (CCoreWorkerProcess.GetCoreWorker() + .AllocateDynamicReturnId(caller_address)) + intermediate_result.push_back( + c_pair[CObjectID, shared_ptr[CRayObject]]( + error_id, shared_ptr[CRayObject]())) + + store_task_errors( + worker, e, + True, # task_exception + actor, # actor + actor_id, # actor id + function_name, task_type, title, + &intermediate_result, application_error, caller_address) + + return intermediate_result.back() + + cdef execute_dynamic_generator_and_store_task_outputs( generator, const CObjectID &generator_id, diff --git a/python/ray/actor.py b/python/ray/actor.py index 91b88de7b947..f3b647916f6e 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -1170,7 +1170,7 @@ def _actor_method_call( elif num_returns == "streaming": # TODO(sang): This is a temporary private API. # Remove it when we migrate to the streaming generator. - num_returns = -2 + num_returns = ray._raylet.STREAMING_GENERATOR_RETURN object_refs = worker.core_worker.submit_actor_task( self._ray_actor_language, diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index 09d1de01c251..8c191d5cc24a 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -433,3 +433,4 @@ cdef extern from "ray/common/task/task_spec.h" nogil: cdef extern from "ray/common/constants.h" nogil: cdef const char[] kWorkerSetupHookKeyName cdef int kResourceUnitScaling + cdef int kStreamingGeneratorReturn diff --git a/python/ray/includes/common.pxi b/python/ray/includes/common.pxi index ea402ded009e..268817021120 100644 --- a/python/ray/includes/common.pxi +++ b/python/ray/includes/common.pxi @@ -9,6 +9,7 @@ from ray.includes.common cimport ( CPythonGcsPublisher, kWorkerSetupHookKeyName, kResourceUnitScaling, + kStreamingGeneratorReturn, ) @@ -30,3 +31,4 @@ cdef class GcsClientOptions: WORKER_SETUP_HOOK_KEY_NAME_GCS = str(kWorkerSetupHookKeyName) RESOURCE_UNIT_SCALING = kResourceUnitScaling +STREAMING_GENERATOR_RETURN = kStreamingGeneratorReturn diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index bb627f09af92..b8e416f16014 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -309,7 +309,7 @@ def _remote(self, args=None, kwargs=None, **task_options): elif num_returns == "streaming": # TODO(sang): This is a temporary private API. # Remove it when we migrate to the streaming generator. - num_returns = -2 + num_returns = ray._raylet.STREAMING_GENERATOR_RETURN max_retries = task_options["max_retries"] retry_exceptions = task_options["retry_exceptions"] diff --git a/python/ray/tests/test_generators.py b/python/ray/tests/test_generators.py index 3430da39cda2..7d89ff863d44 100644 --- a/python/ray/tests/test_generators.py +++ b/python/ray/tests/test_generators.py @@ -221,7 +221,7 @@ def generator(num_returns, store_in_plasma, counter): @pytest.mark.parametrize("use_actors", [False, True]) @pytest.mark.parametrize("store_in_plasma", [False, True]) -@pytest.mark.parametrize("num_returns_type", ["streaming"]) +@pytest.mark.parametrize("num_returns_type", ["dynamic", "streaming"]) def test_dynamic_generator( ray_start_regular, use_actors, store_in_plasma, num_returns_type ): diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index bf83ecc5189c..00e0e02070fa 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -22,6 +22,8 @@ constexpr int kResourceUnitScaling = 10000; constexpr char kWorkerSetupHookKeyName[] = "FunctionsToRun"; +constexpr int kStreamingGeneratorReturn = -2; + /// Length of Ray full-length IDs in bytes. constexpr size_t kUniqueIDSize = 28; diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 0fce1289bb2a..fd47f4fd57ca 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1853,7 +1853,7 @@ void CoreWorker::BuildCommonTaskSpec( // TODO(sang): Remove this and integrate it to // nun_returns == -1 once migrating to streaming // generator. - bool is_streaming_generator = num_returns == -2; + bool is_streaming_generator = num_returns == kStreamingGeneratorReturn; if (is_streaming_generator) { num_returns = 1; // We are using the dynamic return if From 98d7292e29f5ae8329decc714d1d712a87c7fc97 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 19 May 2023 09:10:32 -0700 Subject: [PATCH 31/77] lint Signed-off-by: SangBin Cho --- python/ray/_raylet.pyx | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index f4aa5789f9a2..deb0769c09d9 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -935,7 +935,7 @@ cdef c_pair[CObjectID, shared_ptr[CRayObject]] create_generator_return_obj( caller_address: The address of the caller. By our protocol, the caller of the streaming generator task is always the owner, so we can also call it "owner address". - + Returns: A Ray Object that contains the given output. """ @@ -954,7 +954,7 @@ cdef c_pair[CObjectID, shared_ptr[CRayObject]] create_generator_return_obj( &intermediate_result, caller_address, generator_id) - + return intermediate_result.back() @@ -975,7 +975,7 @@ cdef c_pair[CObjectID, shared_ptr[CRayObject]] create_generator_error_object( """Create a generator error object. This API sets is_retryable_error and application_error, - It also creates and returns a new RayObject that + It also creates and returns a new RayObject that contains the exception `e`. Args: @@ -1019,15 +1019,17 @@ cdef c_pair[CObjectID, shared_ptr[CRayObject]] create_generator_error_object( is_retryable_error[0] and core_worker.get_current_task_retry_exceptions() ): - logger.debug("Task failed with retryable exception:" - " {}.".format(task_id), exc_info=True) + logger.debug( + "Task failed with retryable exception:" + " {}.".format(task_id), exc_info=True) # Raise an exception directly and halt the execution # because there's no need to set the exception # for the return value when the task is retryable. raise e - logger.debug("Task failed with unretryable exception:" - " {}.".format(task_id), exc_info=True) + logger.debug( + "Task failed with unretryable exception:" + " {}.".format(task_id), exc_info=True) error_id = (CCoreWorkerProcess.GetCoreWorker() .AllocateDynamicReturnId(caller_address)) From 391eb0f7e3a8a7f5baf8419b3ed1b52efaf18790 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 19 May 2023 19:00:14 -0700 Subject: [PATCH 32/77] addressed Signed-off-by: SangBin Cho --- python/ray/_raylet.pxd | 2 +- python/ray/_raylet.pyx | 66 ++++++++++++++------ python/ray/tests/test_streaming_generator.py | 3 - 3 files changed, 49 insertions(+), 22 deletions(-) diff --git a/python/ray/_raylet.pxd b/python/ray/_raylet.pxd index 38d35bde3ba6..920709f45721 100644 --- a/python/ray/_raylet.pxd +++ b/python/ray/_raylet.pxd @@ -159,7 +159,7 @@ cdef class CoreWorker: self, const CAddress &owner_address, const CTaskID &task_id, - c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns, + return_size, generator_index, is_async_actor) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 3ef0abe5b3db..ccfbfb6a6eb4 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -872,7 +872,6 @@ cdef store_task_errors( cdef execute_streaming_generator( generator, - c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns, const CObjectID &generator_id, CTaskType task_type, const CAddress &caller_address, @@ -884,6 +883,7 @@ cdef execute_streaming_generator( actor, actor_id, name_of_concurrency_group_to_execute, + return_size, c_bool *is_retryable_error, c_string *application_error): """Execute a given generator and streaming-report the @@ -916,6 +916,7 @@ cdef execute_streaming_generator( actor: The instance of the actor created in this worker. It is used to write an error message. actor_id: The ID of the actor. It is used to write an error message. + return_size: The number of static returns. is_retryable_error(out): It is set to True if the generator raises an exception, and the error is retryable. application_error(out): It is set if the generator raises an @@ -924,8 +925,7 @@ cdef execute_streaming_generator( worker = ray._private.worker.global_worker # Generator task should only have 1 return object ref, # which contains None or exceptions (if system error occurs). - assert returns != NULL - assert returns[0].size() == 1 + assert return_size == 1 cdef: CoreWorker core_worker = worker.core_worker @@ -962,6 +962,9 @@ cdef execute_streaming_generator( title, actor, actor_id, + return_size, + generator_index, + is_async, is_retryable_error, application_error ) @@ -979,7 +982,11 @@ cdef execute_streaming_generator( output, generator_id, worker, - caller_address) + caller_address, + task_id, + return_size, + generator_index, + is_async) # Del output here so that we can GC the memory # usage asap. del output @@ -1012,7 +1019,11 @@ cdef c_pair[CObjectID, shared_ptr[CRayObject]] create_generator_return_obj( output, const CObjectID &generator_id, worker: "Worker", - const CAddress &caller_address): + const CAddress &caller_address, + TaskID task_id, + return_size, + generator_index, + is_async): """Create a generator return object based on a given output. Args: @@ -1022,6 +1033,11 @@ cdef c_pair[CObjectID, shared_ptr[CRayObject]] create_generator_return_obj( caller_address: The address of the caller. By our protocol, the caller of the streaming generator task is always the owner, so we can also call it "owner address". + task_id: The task ID of the generator task. + return_size: The number of static returns. + generator_index: The index of a current error object. + is_async: Whether or not the given object is created within + an async actor. Returns: A Ray Object that contains the given output. @@ -1030,9 +1046,13 @@ cdef c_pair[CObjectID, shared_ptr[CRayObject]] create_generator_return_obj( c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] intermediate_result CoreWorker core_worker = worker.core_worker - return_id = ( - CCoreWorkerProcess.GetCoreWorker().AllocateDynamicReturnId( - caller_address)) + return_id = core_worker.allocate_dynamic_return_id_for_generator( + caller_address, + task_id.native(), + return_size, + generator_index, + is_async, + ) intermediate_result.push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( return_id, shared_ptr[CRayObject]())) @@ -1057,6 +1077,9 @@ cdef c_pair[CObjectID, shared_ptr[CRayObject]] create_generator_error_object( title, actor, actor_id, + return_size, + generator_index, + is_async, c_bool *is_retryable_error, c_string *application_error): """Create a generator error object. @@ -1084,6 +1107,10 @@ cdef c_pair[CObjectID, shared_ptr[CRayObject]] create_generator_error_object( actor: The instance of the actor created in this worker. It is used to write an error message. actor_id: The ID of the actor. It is used to write an error message. + return_size: The number of static returns. + generator_index: The index of a current error object. + is_async: Whether or not the given object is created within + an async actor. is_retryable_error(out): It is set to True if the generator raises an exception, and the error is retryable. application_error(out): It is set if the generator raises an @@ -1096,6 +1123,8 @@ cdef c_pair[CObjectID, shared_ptr[CRayObject]] create_generator_error_object( c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] intermediate_result CoreWorker core_worker = worker.core_worker + # Generator only has 1 static return. + assert return_size == 1 is_retryable_error[0] = determine_if_retryable( e, serialized_retry_exception_allowlist, @@ -1118,8 +1147,13 @@ cdef c_pair[CObjectID, shared_ptr[CRayObject]] create_generator_error_object( "Task failed with unretryable exception:" " {}.".format(task_id), exc_info=True) - error_id = (CCoreWorkerProcess.GetCoreWorker() - .AllocateDynamicReturnId(caller_address)) + error_id = core_worker.allocate_dynamic_return_id_for_generator( + caller_address, + task_id.native(), + return_size, + generator_index, + is_async, + ) intermediate_result.push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( error_id, shared_ptr[CRayObject]())) @@ -1407,7 +1441,6 @@ cdef void execute_task( execute_streaming_generator( outputs, - returns, returns[0][0].first, # generator object ID. task_type, caller_address, @@ -1419,6 +1452,7 @@ cdef void execute_task( actor, actor_id, name_of_concurrency_group_to_execute, + returns[0].size(), is_retryable_error, application_error) # Streaming generator output is not used, so set it to None. @@ -3697,7 +3731,7 @@ cdef class CoreWorker: self, const CAddress &owner_address, const CTaskID &task_id, - c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns, + return_size, generator_index, is_async_actor): """Allocate a dynamic return ID for a generator task. @@ -3721,10 +3755,6 @@ cdef class CoreWorker: If async actor is used, we should calculate the put_index ourselves. """ - assert returns != NULL - cdef: - num_returns = returns[0].size() - if is_async_actor: # This part of code has a couple of assumptions. # - This API is not called within an asyncio event loop @@ -3737,7 +3767,7 @@ cdef class CoreWorker: # scoped to a asyncio event loop thread. # This means the execution thread that this API will be called # will only create "return" objects. That means if we use - # num_returns + genreator_index as a put_index, it is guaranteed + # return_size + genreator_index as a put_index, it is guaranteed # to be unique. # # Why do we need it? @@ -3753,7 +3783,7 @@ cdef class CoreWorker: # before it is used. So if you have 1 return object # the next index will be 2. make_optional[ObjectIDIndexType]( - 1 + num_returns + generator_index) # put_index + 1 + return_size + generator_index) # put_index ) else: return CCoreWorkerProcess.GetCoreWorker().AllocateDynamicReturnId( diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py index 81df9dc71b7e..72f1c5440696 100644 --- a/python/ray/tests/test_streaming_generator.py +++ b/python/ray/tests/test_streaming_generator.py @@ -46,7 +46,6 @@ def test_streaming_object_ref_generator_basic_unit(mocked_worker): generator_ref = ray.ObjectRef.from_random() generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) c.try_read_next_object_ref_stream.return_value = ray.ObjectRef.nil() - c.create_object_ref_stream.assert_called() # Test when there's no new ref, it returns a nil. mocked_ray_wait.return_value = [], [generator_ref] @@ -73,7 +72,6 @@ def test_streaming_object_ref_generator_basic_unit(mocked_worker): dumps(generator) del generator - c.delete_object_ref_stream.assert_called() def test_streaming_object_ref_generator_task_failed_unit(mocked_worker): @@ -154,7 +152,6 @@ async def test_streaming_object_ref_generator_unit_async(mocked_worker): generator_ref = ray.ObjectRef.from_random() generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) c.try_read_next_object_ref_stream.return_value = ray.ObjectRef.nil() - c.create_object_ref_stream.assert_called() # Test when there's no new ref, it returns a nil. mocked_ray_wait.return_value = [], [generator_ref] From e49d1a56833909783af4a2a980b32cd34ae87fa8 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sun, 21 May 2023 18:22:27 -0700 Subject: [PATCH 33/77] working Signed-off-by: SangBin Cho --- python/ray/_raylet.pxd | 1 - python/ray/_raylet.pyx | 18 +++--------------- 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/python/ray/_raylet.pxd b/python/ray/_raylet.pxd index 0a2f1afb17fd..7297547b32df 100644 --- a/python/ray/_raylet.pxd +++ b/python/ray/_raylet.pxd @@ -145,7 +145,6 @@ cdef class CoreWorker: worker, outputs, const CAddress &caller_address, c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns, - const CAddress &caller_address, CObjectID ref_generator_id=*) cdef yield_current_fiber(self, CFiberEvent &fiber_event) cdef make_actor_handle(self, ActorHandleSharedPtr c_actor_handle) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 2be797b41cf3..a2c5b923e926 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -140,7 +140,6 @@ from ray.exceptions import ( AsyncioActorExit, PendingCallsLimitExceeded, RpcError, - ObjectRefStreamEoFError, ) from ray._private import external_storage from ray.util.scheduling_strategies import ( @@ -1194,7 +1193,6 @@ cdef execute_dynamic_generator_and_store_task_outputs( worker, generator, caller_address, dynamic_returns, - caller_address, generator_id) except Exception as error: is_retryable_error[0] = determine_if_retryable( @@ -1732,11 +1730,11 @@ cdef execute_task_with_cancellation_handler( actor, actor_id, execution_info.function_name, - task_type, title, caller_address, returns, + task_type, title, caller_address, + returns, # application_error: we are passing NULL since we don't want the # cancel tasks to fail. - NULL, - caller_address) + NULL) finally: with current_task_id_lock: current_task_id = None @@ -3388,7 +3386,6 @@ cdef class CoreWorker: const CAddress &caller_address, c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *returns, - const CAddress &caller_address, CObjectID ref_generator_id=CObjectID.Nil()): cdef: CObjectID return_id @@ -3584,15 +3581,6 @@ cdef class CoreWorker: eventloop, async_thread = self.get_event_loop( function_descriptor, specified_cgname) - # Increase recursion limit if necessary. In asyncio mode, - # we have many parallel callstacks (represented in fibers) - # that's suspended for execution. Python interpreter will - # mistakenly count each callstack towards recusion limit. - # We don't need to worry about stackoverflow here because - # the max number of callstacks is limited in direct actor - # transport with max_concurrency flag. - increase_recursion_limit() - if inspect.isawaitable(func_or_coro): coroutine = func_or_coro else: From ce1e79d874f77cec87fecaeaf24591e6d8453278 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sun, 21 May 2023 18:27:15 -0700 Subject: [PATCH 34/77] [Revert] Add more complicated tests Signed-off-by: SangBin Cho --- python/ray/_private/ray_perf.py | 129 ------------------- python/ray/tests/test_streaming_generator.py | 82 ------------ 2 files changed, 211 deletions(-) diff --git a/python/ray/_private/ray_perf.py b/python/ray/_private/ray_perf.py index 219dd00833a7..cbff4bb69aca 100644 --- a/python/ray/_private/ray_perf.py +++ b/python/ray/_private/ray_perf.py @@ -35,55 +35,6 @@ async def small_value_batch(self, n): await asyncio.wait([small_value.remote() for _ in range(n)]) -@ray.remote -class AsyncGeneratorActor: - async def small_value(self): - yield b"ok" - - async def small_value_with_arg(self, x): - yield b"ok" - - async def small_value_batch(self, n): - yield await asyncio.wait([small_value.remote() for _ in range(n)]) - - -@ray.remote(num_cpus=0) -class GeneratorClient: - def __init__(self, servers): - if not isinstance(servers, list): - servers = [servers] - self.servers = servers - - def small_value_batch(self, n): - results = [] - for s in self.servers: - results.extend( - [ - s.small_value.options(num_returns="streaming").remote() - for _ in range(n) - ] - ) - refs = [] - for gen in results: - refs.extend(list(gen)) - ray.get(refs) - - def small_value_batch_arg(self, n): - x = ray.put(0) - results = [] - for s in self.servers: - results.extend( - [ - s.small_value_arg.options(num_returns="streaming").remote(x) - for _ in range(n) - ] - ) - refs = [] - for gen in results: - refs.extend(list(gen)) - ray.get(refs) - - @ray.remote(num_cpus=0) class Client: def __init__(self, servers): @@ -330,86 +281,6 @@ def async_actor_multi(): results += timeit("n:n async-actor calls async", async_actor_multi, m * n) - """ - Async generator actor - """ - a = AsyncGeneratorActor.remote() - - def actor_sync_generator(): - ray.get(list(a.small_value.options(num_returns="streaming").remote())) - - results += timeit("1:1 async-actor-generator calls sync", actor_sync_generator) - - a = AsyncGeneratorActor.options().remote() - - def async_actor_generator(): - gens = [ - a.small_value.options(num_returns="streaming").remote() for _ in range(1000) - ] - refs = [] - for gen in gens: - refs.extend(list(gen)) - ray.get(refs) - - results += timeit( - "1:1 async-actor-generator calls async", async_actor_generator, 1000 - ) - - a = AsyncGeneratorActor.remote() - - def async_actor_generator(): - gens = [ - a.small_value_with_arg.options(num_returns="streaming").remote(i) - for i in range(1000) - ] - refs = [] - for gen in gens: - refs.extend(list(gen)) - ray.get(refs) - - results += timeit( - "1:1 async-actor-generator calls with args async", async_actor_generator, 1000 - ) - - n = 5000 - n_cpu = multiprocessing.cpu_count() // 2 - actors = [AsyncGeneratorActor.remote() for _ in range(n_cpu)] - client = GeneratorClient.remote(actors) - - def async_actor_async_generator(): - ref = client.small_value_batch.remote(n) - ray.get(ref) - - results += timeit( - "1:n async-actor-generator calls async", - async_actor_async_generator, - n * len(actors), - ) - - n = 5000 - m = 4 - n_cpu = multiprocessing.cpu_count() // 2 - a = [AsyncGeneratorActor.remote() for _ in range(n_cpu)] - - @ray.remote - def async_actor_work_generator(actors): - gens = [ - actors[i % n_cpu].small_value.options(num_returns="streaming").remote() - for i in range(n) - ] - refs = [] - for gen in gens: - refs.extend(list(gen)) - ray.get(refs) - - def async_actor_multi_generator(): - ray.get([async_actor_work_generator.remote(a) for _ in range(m)]) - - results += timeit( - "n:n async-actor-generator calls async", async_actor_multi_generator, m * n - ) - ray.shutdown() - NUM_PGS = 100 NUM_BUNDLES = 1 ray.init(resources={"custom": 100}) diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py index 865b23d9e338..6795c18d0567 100644 --- a/python/ray/tests/test_streaming_generator.py +++ b/python/ray/tests/test_streaming_generator.py @@ -657,88 +657,6 @@ async def main(): print(summary) -def test_ray_serve_like_generator_stress_test(ray_start_cluster, monkeypatch): - """Mock the stressful Ray Serve workloads. - - Ray Serve has a single actor that invokes many generator tasks. - All the actors are async actor for Ray Serve. - """ - with monkeypatch.context() as m: - # Add a 10ms ~ 1 second delay to the RPC. - m.setenv( - "RAY_testing_asio_delay_us", - "CoreWorkerService.grpc_server.ReportGeneratorItemReturns=10000:1000000", - ) - - cluster = ray_start_cluster - total_cpus = 20 - # 5 nodes cluster, 4 CPUs each. - cluster.add_node(num_cpus=total_cpus // 5) - ray.init() - for _ in range(4): - cluster.add_node(num_cpus=total_cpus // 5) - - @ray.remote(num_cpus=1) - class ProxyActor: - async def get_data(self, child): - await asyncio.sleep(0.1) - gen = child.get_data.options(num_returns="streaming").remote() - async for ref in gen: - yield ref - del ref - - @ray.remote - class ChainActor: - def __init__(self, child=None): - self.child = child - - async def get_data(self): - if not self.child: - for i in range(10): - await asyncio.sleep(0.1) - yield np.ones(5 * 1024) * i - else: - async for ref in self.child.get_data.options( - num_returns="streaming" - ).remote(): - yield ref - - chain_actors = [] - num_chain_actors = 16 - for _ in range(num_chain_actors): - chain_actor = ChainActor.remote() - chain_actor_2 = ChainActor.remote(chain_actor) - chain_actor_3 = ChainActor.remote(chain_actor_2) - chain_actor_4 = ChainActor.remote(chain_actor_3) - chain_actors.append(chain_actor_4) - - proxy_actor = ProxyActor.remote() - - async def get_stream(proxy_actor, chain_actor): - i = 0 - async for ref in proxy_actor.get_data.options( - num_returns="streaming" - ).remote(chain_actor): - for _ in range(5): - ref = await ref - assert np.array_equal(np.ones(5 * 1024) * i, ref) - del ref - i += 1 - - async def main(): - await asyncio.gather( - *[get_stream(proxy_actor, chain_actor) for chain_actor in chain_actors] - ) - result = list_objects(raise_on_missing_output=False) - ref_types = set() - for r in result: - ref_types.add(r.reference_type) - # Verify no leaks - assert ref_types == {"ACTOR_HANDLE"} - - asyncio.run(main()) - - if __name__ == "__main__": import os From 6c0448b2720ab50ebc1ebaed18b380b4de94bc90 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 23 May 2023 03:24:50 -0700 Subject: [PATCH 35/77] Addressed code review. Signed-off-by: SangBin Cho --- python/ray/_private/async_compat.py | 1 + python/ray/_raylet.pyx | 49 ++++++++++++++----- python/ray/includes/common.pxd | 4 +- python/ray/tests/test_async.py | 20 ++++++++ python/ray/tests/test_streaming_generator.py | 40 +++++++++++++-- src/ray/common/status.h | 11 +++-- src/ray/core_worker/context.cc | 23 +++++++++ src/ray/core_worker/context.h | 20 ++++++++ src/ray/core_worker/core_worker.cc | 17 +------ src/ray/core_worker/core_worker.h | 8 +-- src/ray/core_worker/task_manager.cc | 4 +- src/ray/core_worker/task_manager.h | 5 +- src/ray/core_worker/test/core_worker_test.cc | 33 +++++++++++++ src/ray/core_worker/test/task_manager_test.cc | 12 ++--- 14 files changed, 195 insertions(+), 52 deletions(-) diff --git a/python/ray/_private/async_compat.py b/python/ray/_private/async_compat.py index 7821c6424d2f..2e3b03aca623 100644 --- a/python/ray/_private/async_compat.py +++ b/python/ray/_private/async_compat.py @@ -20,6 +20,7 @@ def get_new_event_loop(): def is_async_func(func): + """Return True if the function is an async or async generator method.""" return inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index a2c5b923e926..6c4e101e7c76 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -23,6 +23,7 @@ import time import traceback import _thread import typing +from typing import Union, Awaitable, Callable, Any from libc.stdint cimport ( int32_t, @@ -187,6 +188,8 @@ current_task_id_lock = threading.Lock() job_config_initialized = False job_config_initialization_lock = threading.Lock() +# It is used to indicate optional::nullopt for +# AllocateDynamicReturnId. cdef optional[ObjectIDIndexType] NULL_PUT_INDEX = nullopt @@ -205,7 +208,7 @@ class ObjectRefGenerator: return len(self._refs) -class ObjectRefStreamEoFError(RayError): +class ObjectRefStreamEneOfStreamError(RayError): pass @@ -283,6 +286,7 @@ class StreamingObjectRefGenerator: # a system error. while obj.is_nil(): error_ref = self._handle_error( + False, last_time, timeout_s, unexpected_network_failure_timeout_s) @@ -308,6 +312,7 @@ class StreamingObjectRefGenerator: # a system error. while obj.is_nil(): error_ref = self._handle_error( + True, last_time, timeout_s, unexpected_network_failure_timeout_s) @@ -322,16 +327,23 @@ class StreamingObjectRefGenerator: async def _handle_next_async(self): try: return self._handle_next() - except ObjectRefStreamEoFError: + except ObjectRefStreamEneOfStreamError: raise StopAsyncIteration def _handle_next_sync(self): try: return self._handle_next() - except ObjectRefStreamEoFError: + except ObjectRefStreamEneOfStreamError: raise StopIteration def _handle_next(self): + """Get the next item from the ObjectRefStream. + + This API return immediately all the time. It returns a nil object + if it doesn't have the next item ready. It raises + ObjectRefStreamEneOfStreamError if there's nothing more to read. + If there's a next item, it will return a object ref. + """ if hasattr(self.worker, "core_worker"): obj = self.worker.core_worker.try_read_next_object_ref_stream( self._generator_ref) @@ -343,6 +355,7 @@ class StreamingObjectRefGenerator: def _handle_error( self, + is_async: bool, last_time: int, timeout_s: float, unexpected_network_failure_timeout_s: float): @@ -355,7 +368,10 @@ class StreamingObjectRefGenerator: # The generator task has failed already. # We raise StopIteration # to conform the next interface in Python. - raise StopIteration + if is_async: + raise StopAsyncIteration + else: + raise StopIteration else: # Otherwise, we should ray.get on the generator # ref to find if the task has a system failure. @@ -422,8 +438,8 @@ cdef int check_status(const CRayStatus& status) nogil except -1: raise ObjectStoreFullError(message) elif status.IsOutOfDisk(): raise OutOfDiskError(message) - elif status.IsObjectRefStreamEoF(): - raise ObjectRefStreamEoFError(message) + elif status.IsObjectRefEndOfStream(): + raise ObjectRefStreamEneOfStreamError(message) elif status.IsInterrupted(): raise KeyboardInterrupt() elif status.IsTimedOut(): @@ -1124,8 +1140,6 @@ cdef c_pair[CObjectID, shared_ptr[CRayObject]] create_generator_error_object( c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] intermediate_result CoreWorker core_worker = worker.core_worker - # Generator only has 1 static return. - assert return_size == 1 is_retryable_error[0] = determine_if_retryable( e, serialized_retry_exception_allowlist, @@ -3562,9 +3576,21 @@ cdef class CoreWorker: return self.eventloop_for_default_cg, self.thread_for_default_cg def run_async_func_or_coro_in_event_loop( - self, func_or_coro, function_descriptor, specified_cgname, *args, **kwargs): + self, + func_or_coro: Union[Callable[[Any, Any], Awaitable[Any]], Awaitable], + function_descriptor: FunctionDescriptor, + specified_cgname: str, + *args, + **kwargs): """Run the async function or coroutine to the event loop. The event loop is running in a separate thread. + + Args: + func_or_coro: Async function (not a generator) or awaitable objects. + function_descriptor: The function descriptor. + specified_cgname: The name of a concurrent group. + args: The arguments for the async function. + kwargs: The keyword arguments for the async function. """ cdef: CFiberEvent event @@ -3740,14 +3766,15 @@ cdef class CoreWorker: owner_address: The address of the owner (caller) of the generator task. task_id: The task ID of the generator task. - returns: A list of return objects. This is used to - calculate the number of return values. + return_size: The size of the static return from the task. generator_index: The index of dynamically generated object ref. is_async_actor: True if the allocation is for async actor. If async actor is used, we should calculate the put_index ourselves. """ + # Generator only has 1 static return. + assert return_size == 1 if is_async_actor: # This part of code has a couple of assumptions. # - This API is not called within an asyncio event loop diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index 8c191d5cc24a..535412f4ab50 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -100,7 +100,7 @@ cdef extern from "ray/common/status.h" namespace "ray" nogil: CRayStatus NotFound() @staticmethod - CRayStatus ObjectRefStreamEoF() + CRayStatus ObjectRefEndOfStream() c_bool ok() c_bool IsOutOfMemory() @@ -121,7 +121,7 @@ cdef extern from "ray/common/status.h" namespace "ray" nogil: c_bool IsObjectUnknownOwner() c_bool IsRpcError() c_bool IsOutOfResource() - c_bool IsObjectRefStreamEoF() + c_bool IsObjectRefEndOfStream() c_string ToString() c_string CodeAsString() diff --git a/python/ray/tests/test_async.py b/python/ray/tests/test_async.py index 21fa6a026c31..5136fb2bb593 100644 --- a/python/ray/tests/test_async.py +++ b/python/ray/tests/test_async.py @@ -8,6 +8,7 @@ import pytest import ray +from ray._private.async_compat import is_async_func from ray._private.test_utils import wait_for_condition from ray._private.utils import ( get_or_create_event_loop, @@ -33,6 +34,25 @@ def f(n): return [f.remote(i) for i in range(5)] +def test_is_async_func(): + def f(): + return 1 + + def f_gen(): + yield 1 + + async def g(): + return 1 + + async def g_gen(): + yield 1 + + assert is_async_func(f) is False + assert is_async_func(f_gen) is False + assert is_async_func(g) is True + assert is_async_func(g_gen) is True + + def test_simple(init): @ray.remote def f(): diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py index 6795c18d0567..68b0c6ba5ed3 100644 --- a/python/ray/tests/test_streaming_generator.py +++ b/python/ray/tests/test_streaming_generator.py @@ -10,7 +10,7 @@ import ray from ray._private.test_utils import wait_for_condition from ray.experimental.state.api import list_objects -from ray._raylet import StreamingObjectRefGenerator, ObjectRefStreamEoFError +from ray._raylet import StreamingObjectRefGenerator, ObjectRefStreamEneOfStreamError from ray.cloudpickle import dumps from ray.exceptions import WorkerCrashedError @@ -60,8 +60,8 @@ def test_streaming_object_ref_generator_basic_unit(mocked_worker): assert new_ref == ref # When try_read_next_object_ref_stream raises a - # ObjectRefStreamEoFError, it should raise a stop iteration. - c.try_read_next_object_ref_stream.side_effect = ObjectRefStreamEoFError( + # ObjectRefStreamEneOfStreamError, it should raise a stop iteration. + c.try_read_next_object_ref_stream.side_effect = ObjectRefStreamEneOfStreamError( "" ) # noqa with pytest.raises(StopIteration): @@ -166,14 +166,44 @@ async def test_streaming_object_ref_generator_unit_async(mocked_worker): assert new_ref == ref # When try_read_next_object_ref_stream raises a - # ObjectRefStreamEoFError, it should raise a stop iteration. - c.try_read_next_object_ref_stream.side_effect = ObjectRefStreamEoFError( + # ObjectRefStreamEneOfStreamError, it should raise a stop iteration. + c.try_read_next_object_ref_stream.side_effect = ObjectRefStreamEneOfStreamError( "" ) # noqa with pytest.raises(StopAsyncIteration): ref = await generator._next_async(timeout_s=0) +@pytest.mark.asyncio +async def test_async_ref_generator_task_failed_unit(mocked_worker): + """ + Verify when a task is failed by a system error, + the generator ref is returned. + """ + with patch("ray.get") as mocked_ray_get: + with patch("ray.wait") as mocked_ray_wait: + c = mocked_worker.core_worker + generator_ref = ray.ObjectRef.from_random() + generator = StreamingObjectRefGenerator(generator_ref, mocked_worker) + + # Simulate the worker failure happens. + mocked_ray_wait.return_value = [generator_ref], [] + mocked_ray_get.side_effect = WorkerCrashedError() + + c.try_read_next_object_ref_stream.return_value = ray.ObjectRef.nil() + ref = await generator._next_async(timeout_s=0) + # If the generator task fails by a systsem error, + # meaning the ref will raise an exception + # it should be returned. + assert ref == generator_ref + + # Once exception is raised, it should always + # raise stopIteration regardless of what + # the ref contains now. + with pytest.raises(StopAsyncIteration): + ref = await generator._next_async(timeout_s=0) + + def test_generator_basic(shutdown_only): ray.init(num_cpus=1) diff --git a/src/ray/common/status.h b/src/ray/common/status.h index 25d9befdfd08..cfbcff3dfc89 100644 --- a/src/ray/common/status.h +++ b/src/ray/common/status.h @@ -115,7 +115,8 @@ enum class StatusCode : char { ObjectUnknownOwner = 29, RpcError = 30, OutOfResource = 31, - ObjectRefStreamEoF = 32 + // Meaning the ObjectRefStream reaches to the end of stream. + ObjectRefEndOfStream = 32 }; #if defined(__clang__) @@ -147,8 +148,8 @@ class RAY_EXPORT Status { return Status(StatusCode::KeyError, msg); } - static Status ObjectRefStreamEoF(const std::string &msg) { - return Status(StatusCode::ObjectRefStreamEoF, msg); + static Status ObjectRefEndOfStream(const std::string &msg) { + return Status(StatusCode::ObjectRefEndOfStream, msg); } static Status TypeError(const std::string &msg) { @@ -259,7 +260,9 @@ class RAY_EXPORT Status { bool IsOutOfMemory() const { return code() == StatusCode::OutOfMemory; } bool IsOutOfDisk() const { return code() == StatusCode::OutOfDisk; } bool IsKeyError() const { return code() == StatusCode::KeyError; } - bool IsObjectRefStreamEoF() const { return code() == StatusCode::ObjectRefStreamEoF; } + bool IsObjectRefEndOfStream() const { + return code() == StatusCode::ObjectRefEndOfStream; + } bool IsInvalid() const { return code() == StatusCode::Invalid; } bool IsIOError() const { return code() == StatusCode::IOError; } bool IsTypeError() const { return code() == StatusCode::TypeError; } diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index 125f42d17e39..7715d9636851 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -363,6 +363,29 @@ bool WorkerContext::CurrentActorDetached() const { return is_detached_actor_; } +const ObjectID WorkerContext::GetGeneratorReturnId( + const TaskID &task_id, std::optional put_index) { + TaskID current_task_id; + // We only allow to specify both task id and put index or not specifying both. + RAY_CHECK((task_id.IsNil() && !put_index.has_value()) || + (!task_id.IsNil() || put_index.has_value())); + if (task_id.IsNil()) { + const auto &task_spec = GetCurrentTask(); + current_task_id = task_spec->TaskId(); + } else { + current_task_id = task_id; + } + + ObjectIDIndexType current_put_index; + if (!put_index.has_value()) { + current_put_index = GetNextPutIndex(); + } else { + current_put_index = put_index.value(); + } + + return ObjectID::FromIndex(current_task_id, current_put_index); +} + WorkerThreadContext &WorkerContext::GetThreadContext() const { if (thread_context_ == nullptr) { absl::ReaderMutexLock lock(&mutex_); diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index 692063900526..b7d2d50e7260 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -31,6 +31,26 @@ class WorkerContext { public: WorkerContext(WorkerType worker_type, const WorkerID &worker_id, const JobID &job_id); + // Return the generator return ID. + /// + /// By default, it deduces a generator return ID from a current task + /// from the context. However, it also supports manual specification of + /// put index and task id to support `AllocateDynamicReturnId`. + /// See the docstring of AllocateDynamicReturnId for more details. + /// + /// The caller should either not specify both task_id AND put_index + /// or specify both at the same time. Otherwise it will panic. + /// + /// \param[in] task_id The task id of the dynamically generated return ID. + /// If Nil() is specified, it will deduce the Task ID from the current + /// worker context. + /// \param[in] put_index The equivalent of the return value of + /// WorkerContext::GetNextPutIndex. + /// If std::nullopt is specified, it will deduce the put index from the + /// current worker context. + const ObjectID GetGeneratorReturnId(const TaskID &task_id, + std::optional put_index); + const WorkerType GetWorkerType() const; const WorkerID &GetWorkerID() const; diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 914ea280fb46..327a04c671a2 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -2873,22 +2873,7 @@ bool CoreWorker::PinExistingReturnObject(const ObjectID &return_id, ObjectID CoreWorker::AllocateDynamicReturnId(const rpc::Address &owner_address, const TaskID &task_id, std::optional put_index) { - TaskID current_task_id; - if (task_id.IsNil()) { - const auto &task_spec = worker_context_.GetCurrentTask(); - current_task_id = task_spec->TaskId(); - } else { - current_task_id = task_id; - } - - ObjectIDIndexType current_put_index; - if (!put_index.has_value()) { - current_put_index = worker_context_.GetNextPutIndex(); - } else { - current_put_index = put_index.value(); - } - - const auto return_id = ObjectID::FromIndex(current_task_id, current_put_index); + const auto return_id = worker_context_.GetGeneratorReturnId(task_id, put_index); AddLocalReference(return_id, ""); reference_counter_->AddBorrowedObject(return_id, ObjectID::Nil(), owner_address); return return_id; diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index f3a4b5f33cd8..574ea8b69a95 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -370,14 +370,16 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { void CreateObjectRefStream(const ObjectID &generator_id); /// Read the next index of a ObjectRefStream of generator_id. + /// This API always return immediately. /// /// \param[in] generator_id The object ref id of the streaming /// generator task. /// \param[out] object_ref_out The ObjectReference /// that the caller can convert to its own ObjectRef. /// The current process is always the owner of the - /// generated ObjectReference. - /// \return Status RayKeyError if the stream reaches to EoF. + /// generated ObjectReference. It will be Nil() if there's + /// no next item. + /// \return Status ObjectRefEndOfStream if the stream reaches to EoF. /// OK otherwise. Status TryReadObjectRefStream(const ObjectID &generator_id, rpc::ObjectReference *object_ref_out); @@ -1025,8 +1027,6 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// object to the task caller and have the resulting ObjectRef be owned by /// the caller. This is in contrast to static allocation, where the caller /// decides at task invocation time how many returns the task should have. - /// \param[in] owner_address The address of the owner who will own this - /// dynamically generated object. /// /// NOTE: Normally task_id and put_index it not necessary to be specified /// because we can obtain them from the global worker context. However, diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index 33f14580d5e0..07b09b6feb5c 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -38,7 +38,7 @@ Status ObjectRefStream::TryReadNextItem(ObjectID *object_id_out) { RAY_LOG(DEBUG) << "ObjectRefStream of an id " << generator_id_ << " has no more objects."; *object_id_out = ObjectID::Nil(); - return Status::ObjectRefStreamEoF(""); + return Status::ObjectRefEndOfStream(""); } auto it = item_index_to_refs_.find(next_index_); @@ -393,7 +393,7 @@ void TaskManager::DelObjectRefStream(const ObjectID &generator_id) { const auto &status = TryReadObjectRefStreamInternal(generator_id, &object_id); // keyError means the stream reaches to EoF. - if (status.IsObjectRefStreamEoF()) { + if (status.IsObjectRefEndOfStream()) { break; } diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h index 0c19f48a9352..6af64dc26e5d 100644 --- a/src/ray/core_worker/task_manager.h +++ b/src/ray/core_worker/task_manager.h @@ -256,8 +256,9 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// generator task. bool ObjectRefStreamExists(const ObjectID &generator_id); - /// Asynchronously read object reference of the next index from the + /// Read object reference of the next index from the /// object stream of a generator_id. + /// This API always return immediately. /// /// The caller should ensure the ObjectRefStream is already created /// via CreateObjectRefStream. @@ -266,7 +267,7 @@ class TaskManager : public TaskFinisherInterface, public TaskResubmissionInterfa /// /// \param[out] object_id_out The next object ID from the stream. /// Nil ID is returned if the next index hasn't been written. - /// \return KeyError if it reaches to EoF. Ok otherwise. + /// \return ObjectRefEndOfStream if it reaches to EoF. Ok otherwise. Status TryReadObjectRefStream(const ObjectID &generator_id, ObjectID *object_id_out); /// Returns true if task can be retried. diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 62dd91f4474b..31b6089d854c 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -684,6 +684,39 @@ TEST_F(ZeroNodeTest, TestWorkerContext) { ASSERT_EQ(context.GetNextPutIndex(), num_returns + 2); } +TEST_F(ZeroNodeTest, TestWorkerContextGeneratorReturn) { + auto job_id = NextJobId(); + + WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), job_id); + TaskSpecification task_spec; + size_t num_returns = 1; + task_spec.GetMutableMessage().set_job_id(job_id.Binary()); + task_spec.GetMutableMessage().set_num_returns(num_returns); + context.ResetCurrentTask(); + context.SetCurrentTask(task_spec); + ASSERT_EQ(context.GetCurrentTaskID(), task_spec.TaskId()); + ; + + // Verify when task ID is nil and put index is nullopt, + // it deduces the next return ID from the current context. + auto return_id = context.GetGeneratorReturnId(TaskID::Nil(), std::nullopt); + ASSERT_EQ(return_id.TaskId(), context.GetCurrentTaskID()); + ASSERT_EQ(return_id, ObjectID::FromIndex(context.GetCurrentTaskID(), 2)); + auto return_id2 = context.GetGeneratorReturnId(TaskID::Nil(), std::nullopt); + ASSERT_EQ(return_id2.TaskId(), context.GetCurrentTaskID()); + ASSERT_EQ(return_id2, ObjectID::FromIndex(context.GetCurrentTaskID(), 3)); + + // Verify manual specification of put index and taskId. + auto task_id = TaskID::FromRandom(job_id); + auto put_index = 1; + return_id = context.GetGeneratorReturnId(task_id, put_index); + ASSERT_EQ(return_id.TaskId(), task_id); + ASSERT_EQ(return_id, ObjectID::FromIndex(task_id, put_index)); + // Although we repeat, it should return the same value. + return_id = context.GetGeneratorReturnId(task_id, put_index); + ASSERT_EQ(return_id, ObjectID::FromIndex(task_id, put_index)); +} + TEST_F(ZeroNodeTest, TestActorHandle) { // Test actor handle serialization and deserialization round trip. JobID job_id = NextJobId(); diff --git a/src/ray/core_worker/test/task_manager_test.cc b/src/ray/core_worker/test/task_manager_test.cc index 6f163f352f11..3151fc629e27 100644 --- a/src/ray/core_worker/test/task_manager_test.cc +++ b/src/ray/core_worker/test/task_manager_test.cc @@ -1268,7 +1268,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamBasic) { } // READ (EoF) auto status = manager_.TryReadObjectRefStream(generator_id, &obj_id); - ASSERT_TRUE(status.IsObjectRefStreamEoF()); + ASSERT_TRUE(status.IsObjectRefEndOfStream()); ASSERT_EQ(obj_id, ObjectID::Nil()); // DELETE manager_.DelObjectRefStream(generator_id); @@ -1315,13 +1315,13 @@ TEST_F(TaskManagerTest, TestObjectRefStreamMixture) { ObjectID obj_id; // READ (EoF) auto status = manager_.TryReadObjectRefStream(generator_id, &obj_id); - ASSERT_TRUE(status.IsObjectRefStreamEoF()); + ASSERT_TRUE(status.IsObjectRefEndOfStream()); ASSERT_EQ(obj_id, ObjectID::Nil()); // DELETE manager_.DelObjectRefStream(generator_id); } -TEST_F(TaskManagerTest, TestObjectRefStreamEoF) { +TEST_F(TaskManagerTest, TestObjectRefEndOfStream) { /** * Test that after writing EoF, write/read doesn't work. * CREATE WRITE WRITEEoF, WRITE(verify no op) DELETE @@ -1364,7 +1364,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamEoF) { ASSERT_TRUE(manager_.HandleReportGeneratorItemReturns(req)); // READ (doesn't works because EoF is already written) status = manager_.TryReadObjectRefStream(generator_id, &obj_id); - ASSERT_TRUE(status.IsObjectRefStreamEoF()); + ASSERT_TRUE(status.IsObjectRefEndOfStream()); } TEST_F(TaskManagerTest, TestObjectRefStreamIndexDiscarded) { @@ -1529,7 +1529,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamEndtoEnd) { // Nothing more to read. status = manager_.TryReadObjectRefStream(generator_id, &obj_id); - ASSERT_TRUE(status.IsObjectRefStreamEoF()); + ASSERT_TRUE(status.IsObjectRefEndOfStream()); manager_.DelObjectRefStream(generator_id); } @@ -1670,7 +1670,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamOutofOrder) { // READ (EoF) auto status = manager_.TryReadObjectRefStream(generator_id, &obj_id); - ASSERT_TRUE(status.IsObjectRefStreamEoF()); + ASSERT_TRUE(status.IsObjectRefEndOfStream()); ASSERT_EQ(obj_id, ObjectID::Nil()); // DELETE manager_.DelObjectRefStream(generator_id); From a827426509a592f2b5520814c6cb24db100242d8 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 24 May 2023 11:30:31 -0500 Subject: [PATCH 36/77] WIP Signed-off-by: Edward Oakes --- python/ray/serve/_private/client.py | 3 + python/ray/serve/_private/http_proxy.py | 56 ++++++++++++- python/ray/serve/_private/http_util.py | 29 ++++++- python/ray/serve/_private/replica.py | 106 ++++++++++++++++++------ python/ray/serve/_private/router.py | 83 ++++++++++++++----- 5 files changed, 225 insertions(+), 52 deletions(-) diff --git a/python/ray/serve/_private/client.py b/python/ray/serve/_private/client.py index fd4b9630e9aa..c1d61497616b 100644 --- a/python/ray/serve/_private/client.py +++ b/python/ray/serve/_private/client.py @@ -440,6 +440,7 @@ def get_handle( missing_ok: Optional[bool] = False, sync: bool = True, _internal_pickled_http_request: bool = False, + _use_ray_streaming: bool = False, ) -> Union[RayServeHandle, RayServeSyncHandle]: """Retrieve RayServeHandle for service deployment to invoke it from Python. @@ -469,12 +470,14 @@ def get_handle( self._controller, deployment_name, _internal_pickled_http_request=_internal_pickled_http_request, + _use_ray_streaming=_use_ray_streaming, ) else: handle = RayServeHandle( self._controller, deployment_name, _internal_pickled_http_request=_internal_pickled_http_request, + _use_ray_streaming=_use_ray_streaming, ) self.handle_cache[cache_key] = handle diff --git a/python/ray/serve/_private/http_proxy.py b/python/ray/serve/_private/http_proxy.py index b35fa1507c02..2d1e49753e1b 100644 --- a/python/ray/serve/_private/http_proxy.py +++ b/python/ray/serve/_private/http_proxy.py @@ -12,6 +12,7 @@ import uvicorn import starlette.responses import starlette.routing +from starlette.types import Receive, Scope, Send import ray from ray.exceptions import RayActorError, RayTaskError @@ -29,9 +30,10 @@ from ray.serve._private.common import EndpointInfo, EndpointTag, ApplicationName from ray.serve._private.constants import ( SERVE_LOGGER_NAME, + SERVE_MULTIPLEXED_MODEL_ID, SERVE_NAMESPACE, DEFAULT_LATENCY_BUCKET_MS, - SERVE_MULTIPLEXED_MODEL_ID, + RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING, ) from ray.serve._private.long_poll import LongPollClient, LongPollNamespace from ray.serve._private.logging_utils import ( @@ -72,6 +74,53 @@ ) +async def _handle_streaming_response( + asgi_response_generator: ray._raylet.StreamingObjectRefGenerator, + scope: Scope, + receive: Receive, + send: Send, +) -> str: + """Consumes the `asgi_response_generator` and sends its data over `send`. + + This function is essentially a proxy for a downstream ASGI response. The + passed generator is expected to return a stream of pickled ASGI messages + (dictionaries) that will be passed to the provided ASGI send interface. + + Exception handling depends on if the first message has already been sent: + - if an exception happens *before* the first message, a 500 status will be sent. + - if an exception happens *after* the first message, the response stream will be + terminated. + + This is because once the first message has been sent, the client has already + received the status code. + + Returns: + status_code + """ + + status_code = "" + try: + async for obj_ref in asgi_response_generator: + asgi_message = pickle.loads(await obj_ref) + if asgi_message["type"] == "http.response.start": + status_code = str(asgi_message["status"]) + + await send(asgi_message) + except Exception as e: + error_message = "Unexpected error, traceback: {}.".format(e) + logger.warning(error_message) + + if status_code == "": + # If first message hasn't been sent, return 500 status. + await Response(error_message, status_code=500).send(scope, receive, send) + return "500" + else: + # If first message has been sent, terminate the response stream. + return status_code + + return status_code + + async def _send_request_to_handle(handle, scope, receive, send) -> str: http_body_bytes = await receive_http_body(scope, receive, send) @@ -112,6 +161,9 @@ async def _send_request_to_handle(handle, scope, receive, send) -> str: try: object_ref = await assignment_task + if isinstance(object_ref, ray._raylet.StreamingObjectRefGenerator): + return await _handle_streaming_response(object_ref, send) + # NOTE (shrekris-anyscale): when the gcs, Serve controller, and # some replicas crash simultaneously (e.g. if the head node crashes), # requests to the dead replicas hang until the gcs recovers. @@ -277,6 +329,7 @@ def get_handle(name): sync=False, missing_ok=True, _internal_pickled_http_request=True, + _use_ray_streaming=RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING, ) self.prefix_router = LongestPrefixRouter(get_handle) @@ -439,7 +492,6 @@ async def __call__(self, scope, receive, send): ray.serve.context.RequestContext(**request_context_info) ) status_code = await _send_request_to_handle(handle, scope, receive, send) - self.request_counter.inc( tags={ "route": route_path, diff --git a/python/ray/serve/_private/http_util.py b/python/ray/serve/_private/http_util.py index 7eac01e2459f..d63d7252621c 100644 --- a/python/ray/serve/_private/http_util.py +++ b/python/ray/serve/_private/http_util.py @@ -4,10 +4,9 @@ import inspect import json import logging -from typing import Any, Dict, Type +from typing import Any, Dict, Generator, Type -import starlette.responses -import starlette.requests +from starlette.requests import Request from starlette.types import Send, ASGIApp from fastapi.encoders import jsonable_encoder @@ -48,7 +47,7 @@ async def mock_receive(): received = True return {"body": serialized_body, "type": "http.request", "more_body": False} - return starlette.requests.Request(scope, mock_receive) + return Request(scope, mock_receive) class Response: @@ -159,6 +158,28 @@ def build_asgi_response(self) -> RawASGIResponse: return RawASGIResponse(self.messages) +class ASGIHTTPQueueSender(Send, asyncio.Queue): + """TODO: doc and better name""" + + def __init__(self): + self._message_queue = asyncio.Queue() + self._new_message_event = asyncio.Event() + + async def __call__(self, message: Dict[str, Any]): + assert message["type"] in ("http.response.start", "http.response.body") + await self._message_queue.put(message) + self._new_message_event.set() + + def get_messages_nowait(self) -> Generator[Dict[str, Any], None, None]: + while not self._message_queue.empty(): + yield self._message_queue.get_nowait() + + self._new_message_event.clear() + + async def wait_for_message(self): + await self._new_message_event.wait() + + def make_fastapi_class_based_view(fastapi_app, cls: Type) -> None: """Transform the `cls`'s methods and class annotations to FastAPI routes. diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index 09dfef537335..806d771b5375 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -6,18 +6,20 @@ import os import pickle import time -from typing import Any, Callable, Optional, Tuple, Dict +from typing import Any, AsyncGenerator, Callable, Optional, Tuple, Dict import traceback import starlette.responses +from starlette.types import Send import ray from ray import cloudpickle from ray.actor import ActorClass, ActorHandle from ray.remote_function import RemoteFunction -from ray.serve import metrics from ray._private.async_compat import sync_to_async +from ray._private.utils import get_or_create_event_loop +from ray.serve import metrics from ray.serve._private.common import ( HEALTH_CHECK_CONCURRENCY_GROUP, ReplicaTag, @@ -33,7 +35,7 @@ ) from ray.serve.deployment import Deployment from ray.serve.exceptions import RayServeException -from ray.serve._private.http_util import ASGIHTTPSender +from ray.serve._private.http_util import ASGIHTTPSender, ASGIHTTPQueueSender from ray.serve._private.logging_utils import ( access_log_msg, configure_component_logger, @@ -85,6 +87,8 @@ async def __init__( component_id=replica_tag, ) + self._event_loop = get_or_create_event_loop() + deployment_def = cloudpickle.loads(serialized_deployment_def) if isinstance(deployment_def, str): @@ -194,13 +198,42 @@ async def handle_request( *request_args, **request_kwargs, ): - # The request metadata should be pickled for performance. - request_metadata: RequestMetadata = pickle.loads(pickled_request_metadata) - - # Directly receive input because it might contain an ObjectRef. - query = Query(request_args, request_kwargs, request_metadata) + query = Query( + request_args, + request_kwargs, + pickle.loads(pickled_request_metadata), + ) return await self.replica.handle_request(query) + async def handle_request_streaming( + self, + pickled_request_metadata: bytes, + *request_args, + **request_kwargs, + ) -> AsyncGenerator[Dict[str, Any], None]: + """TODO""" + query = Query( + request_args, + request_kwargs, + pickle.loads(pickled_request_metadata), + ) + + asgi_queue_sender = ASGIHTTPQueueSender() + handle_request_task = self._event_loop.create_task( + self.replica.handle_request(query, asgi_sender=asgi_queue_sender) + ) + while not handle_request_task.done(): + done, pending = await asyncio.wait( + [handle_request_task, asgi_queue_sender.wait_for_message()], + return_when=asyncio.FIRST_COMPLETED, + ) + for msg in asgi_queue_sender.get_messages_nowait(): + yield pickle.dumps(msg) + + e = handle_request_task.exception() + if e is not None: + raise e from None + async def handle_request_from_java( self, proto_request_metadata: bytes, @@ -377,15 +410,18 @@ async def check_health(self): await self.user_health_check() def _get_handle_request_stats(self) -> Optional[Dict[str, int]]: + replica_actor_name = _format_replica_actor_name(self.deployment_name) actor_stats = ray.runtime_context.get_runtime_context()._get_actor_call_stats() - method_stat = actor_stats.get( - f"{_format_replica_actor_name(self.deployment_name)}.handle_request" + method_stat = actor_stats.get(f"{replica_actor_name}.handle_request") + streaming_method_stat = actor_stats.get( + f"{replica_actor_name}.handle_request_streaming" ) method_stat_java = actor_stats.get( - f"{_format_replica_actor_name(self.deployment_name)}" - f".handle_request_from_java" + f"{replica_actor_name}.handle_request_from_java" + ) + return merge_dict( + merge_dict(method_stat, streaming_method_stat), method_stat_java ) - return merge_dict(method_stat, method_stat_java) def _collect_autoscaling_metrics(self): method_stat = self._get_handle_request_stats() @@ -418,22 +454,30 @@ def callable_method_filter(attr): return self.callable return getattr(self.callable, method_name) - async def ensure_serializable_response(self, response: Any) -> Any: - if isinstance(response, starlette.responses.StreamingResponse): - - async def mock_receive(): - # This is called in a tight loop in response() just to check - # for an http disconnect. So rather than return immediately - # we should suspend execution to avoid wasting CPU cycles. - never_set_event = asyncio.Event() - await never_set_event.wait() - + async def handle_http_response( + self, response: Any, asgi_sender: Optional[Send] = None + ) -> Any: + async def mock_receive(): + # This is called in a tight loop in response() just to check + # for an http disconnect. So rather than return immediately + # we should suspend execution to avoid wasting CPU cycles. + never_set_event = asyncio.Event() + await never_set_event.wait() + + if asgi_sender is not None and isinstance( + response, starlette.responses.Response + ): + await response(scope=None, receive=mock_receive, send=asgi_sender) + elif isinstance(response, starlette.responses.StreamingResponse): sender = ASGIHTTPSender() await response(scope=None, receive=mock_receive, send=sender) return sender.build_asgi_response() + return response - async def invoke_single(self, request_item: Query) -> Tuple[Any, bool]: + async def invoke_single( + self, request_item: Query, *, asgi_sender: Optional[Send] = None + ) -> Tuple[Any, bool]: """Executes the provided request on this replica. Returns the user-provided output and a boolean indicating if the @@ -445,6 +489,9 @@ async def invoke_single(self, request_item: Query) -> Tuple[Any, bool]: ) args, kwargs = parse_request_item(request_item) + if asgi_sender is not None and hasattr(self.callable, "_serve_app"): + # TODO: comment + kwargs["asgi_sender"] = asgi_sender method_to_call = None success = True @@ -468,7 +515,7 @@ async def invoke_single(self, request_item: Query) -> Tuple[Any, bool]: # call with non-empty args result = await method_to_call(*args, **kwargs) - result = await self.ensure_serializable_response(result) + result = await self.handle_http_response(result, asgi_sender=asgi_sender) self.request_counter.inc(tags={"route": request_item.metadata.route}) except Exception as e: logger.exception(f"Request failed due to {type(e).__name__}:") @@ -517,7 +564,9 @@ async def reconfigure(self, deployment_config: DeploymentConfig): ) await reconfigure_method(self.deployment_config.user_config) - async def handle_request(self, request: Query) -> asyncio.Future: + async def handle_request( + self, request: Query, *, asgi_sender: Optional[Send] = None + ) -> asyncio.Future: async with self.rwlock.reader_lock: num_running_requests = self._get_handle_request_stats()["running"] self.num_processing_items.set(num_running_requests) @@ -534,7 +583,10 @@ async def handle_request(self, request: Query) -> asyncio.Future: ) start_time = time.time() - result, success = await self.invoke_single(request) + result, success = await self.invoke_single( + request, + asgi_sender=asgi_sender, + ) latency_ms = (time.time() - start_time) * 1000 self.processing_latency_tracker.observe( latency_ms, tags={"route": request.metadata.route} diff --git a/python/ray/serve/_private/router.py b/python/ray/serve/_private/router.py index 1cefee55cc26..93600ed64ad4 100644 --- a/python/ray/serve/_private/router.py +++ b/python/ray/serve/_private/router.py @@ -1,3 +1,4 @@ +from abc import ABC import asyncio from collections import defaultdict from dataclasses import dataclass @@ -6,7 +7,7 @@ import pickle import random import sys -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import ray from ray.actor import ActorHandle @@ -23,12 +24,12 @@ from ray.serve._private.utils import ( compute_iterable_delta, JavaActorHandleProxy, + MetricsPusher, ) from ray.serve.generated.serve_pb2 import ( - RequestMetadata as RequestMetadataProto, DeploymentRoute, + RequestMetadata as RequestMetadataProto, ) -from ray.serve._private.utils import MetricsPusher logger = logging.getLogger(SERVE_LOGGER_NAME) @@ -74,7 +75,51 @@ async def resolve_async_tasks(self): scanner.clear() -class ReplicaSet: +class ReplicaScheduler(ABC): + async def assign_replica( + self, query: Query + ) -> Union[ray.ObjectRef, "ray._raylet.StreamingObjectRefGenerator"]: + pass + + def update_running_replicas(self, running_replicas: List[RunningReplicaInfo]): + pass + + +class RoundRobinStreamingReplicaScheduler(ReplicaScheduler): + def __init__(self): + self._replica_iterator = itertools.cycle([]) + self._replicas_updated_event = asyncio.Event() + + async def assign_replica( + self, query: Query + ) -> "ray._raylet.StreamingObjectRefGenerator": + replica = None + while replica is None: + try: + replica = next(self._replica_iterator) + except StopIteration: + logger.info( + "Tried to assign replica but none available", + extra={"log_to_stderr": False}, + ) + await self._replicas_updated_event.wait() + + if replica.is_cross_language: + raise RuntimeError( + "Streaming is not yet supported for cross-language actors." + ) + + return replica.actor_handle.handle_request_streaming.options( + num_returns="streaming" + ).remote(pickle.dumps(query.metadata), *query.args, **query.kwargs) + + def update_running_replicas(self, running_replicas: List[RunningReplicaInfo]): + random.shuffle(running_replicas) + self._replica_iterator = itertools.cycle(running_replicas) + self._replicas_updated_event.set() + + +class RoundRobinReplicaScheduler(ReplicaScheduler): """Data structure representing a set of replica actor handles""" def __init__(self, event_loop: asyncio.AbstractEventLoop): @@ -316,14 +361,17 @@ def __init__( controller_handle: ActorHandle, deployment_name: str, event_loop: asyncio.BaseEventLoop = None, + _use_ray_streaming: bool = False, ): """Router process incoming queries: assign a replica. Args: controller_handle: The controller handle. """ - self._event_loop = event_loop - self._replica_set = ReplicaSet(event_loop) + if _use_ray_streaming: + self._replica_scheduler = RoundRobinStreamingReplicaScheduler(event_loop) + else: + self._replica_scheduler = RoundRobinReplicaScheduler(event_loop) # -- Metrics Registration -- # self.num_router_requests = metrics.Counter( @@ -350,7 +398,7 @@ def __init__( ( LongPollNamespace.RUNNING_REPLICAS, deployment_name, - ): self._replica_set.update_running_replicas, + ): self._replica_scheduler.update_running_replicas, }, call_in_event_loop=event_loop, ) @@ -370,18 +418,15 @@ def __init__( self.metrics_pusher.start() def _collect_handle_queue_metrics(self) -> Dict[str, int]: - return {self.deployment_name: self.get_num_queued_queries()} - - def get_num_queued_queries(self): - return self.num_queued_queries + return {self.deployment_name: self.num_queued_queries} async def assign_request( self, request_meta: RequestMetadata, *request_args, **request_kwargs, - ) -> ray.ObjectRef: - """Assign a query and returns an object ref represent the result""" + ) -> Union[ray.ObjectRef, "ray._raylet.StreamingObjectRefGenerator"]: + """Assign a query and returns an object ref represent the result.""" self.num_router_requests.inc( tags={"route": request_meta.route, "application": request_meta.app_name} @@ -394,13 +439,13 @@ async def assign_request( }, ) - result: ray.ObjectRef = await self._replica_set.assign_replica( - Query( - args=list(request_args), - kwargs=request_kwargs, - metadata=request_meta, - ) + query = Query( + args=list(request_args), + kwargs=request_kwargs, + metadata=request_meta, ) + await query.resolve_async_tasks() + result = await self._replica_scheduler.assign_replica(query) self.num_queued_queries -= 1 self.num_queued_queries_gauge.set( From e4106b0f63f9f991d63ab2702d790c3c0b97f7ba Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 24 May 2023 12:11:01 -0500 Subject: [PATCH 37/77] fix typing Signed-off-by: Edward Oakes --- python/ray/serve/_private/router.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/ray/serve/_private/router.py b/python/ray/serve/_private/router.py index 93600ed64ad4..b2d4d5f53520 100644 --- a/python/ray/serve/_private/router.py +++ b/python/ray/serve/_private/router.py @@ -78,7 +78,7 @@ async def resolve_async_tasks(self): class ReplicaScheduler(ABC): async def assign_replica( self, query: Query - ) -> Union[ray.ObjectRef, "ray._raylet.StreamingObjectRefGenerator"]: + ) -> Union[ray.ObjectRef, ray._raylet.StreamingObjectRefGenerator]: pass def update_running_replicas(self, running_replicas: List[RunningReplicaInfo]): @@ -92,7 +92,7 @@ def __init__(self): async def assign_replica( self, query: Query - ) -> "ray._raylet.StreamingObjectRefGenerator": + ) -> ray._raylet.StreamingObjectRefGenerator: replica = None while replica is None: try: @@ -425,7 +425,7 @@ async def assign_request( request_meta: RequestMetadata, *request_args, **request_kwargs, - ) -> Union[ray.ObjectRef, "ray._raylet.StreamingObjectRefGenerator"]: + ) -> Union[ray.ObjectRef, ray._raylet.StreamingObjectRefGenerator]: """Assign a query and returns an object ref represent the result.""" self.num_router_requests.inc( From 2d58c669866dcfba6ac802e01c51738fb5476741 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 24 May 2023 12:18:35 -0500 Subject: [PATCH 38/77] stuff Signed-off-by: Edward Oakes --- python/ray/serve/_private/constants.py | 6 ++++++ python/ray/serve/_private/http_proxy.py | 2 +- python/ray/serve/_private/router.py | 2 +- python/ray/serve/api.py | 15 +++++++++++---- python/ray/serve/handle.py | 4 ++++ python/ray/serve/tests/test_router.py | 4 ++-- 6 files changed, 25 insertions(+), 8 deletions(-) diff --git a/python/ray/serve/_private/constants.py b/python/ray/serve/_private/constants.py index 527c7d47123b..23401337dcd1 100644 --- a/python/ray/serve/_private/constants.py +++ b/python/ray/serve/_private/constants.py @@ -191,3 +191,9 @@ class ServeHandleType(str, Enum): # Serve HTTP request header key for routing requests. SERVE_MULTIPLEXED_MODEL_ID = "serve_multiplexed_model_id" + +# Feature flag to enable StreamingResponse support. +# When turned on, *all* HTTP responses will use Ray streaming object refs. +RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING = ( + os.environ.get("RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING", "0") != "0" +) diff --git a/python/ray/serve/_private/http_proxy.py b/python/ray/serve/_private/http_proxy.py index 2d1e49753e1b..69f7c417bf2b 100644 --- a/python/ray/serve/_private/http_proxy.py +++ b/python/ray/serve/_private/http_proxy.py @@ -162,7 +162,7 @@ async def _send_request_to_handle(handle, scope, receive, send) -> str: object_ref = await assignment_task if isinstance(object_ref, ray._raylet.StreamingObjectRefGenerator): - return await _handle_streaming_response(object_ref, send) + return await _handle_streaming_response(object_ref, scope, receive, send) # NOTE (shrekris-anyscale): when the gcs, Serve controller, and # some replicas crash simultaneously (e.g. if the head node crashes), diff --git a/python/ray/serve/_private/router.py b/python/ray/serve/_private/router.py index b2d4d5f53520..3e2dc59dcf45 100644 --- a/python/ray/serve/_private/router.py +++ b/python/ray/serve/_private/router.py @@ -369,7 +369,7 @@ def __init__( controller_handle: The controller handle. """ if _use_ray_streaming: - self._replica_scheduler = RoundRobinStreamingReplicaScheduler(event_loop) + self._replica_scheduler = RoundRobinStreamingReplicaScheduler() else: self._replica_scheduler = RoundRobinReplicaScheduler(event_loop) diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index c86d8b749f36..02c072ee2f66 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -7,6 +7,7 @@ from fastapi import APIRouter, FastAPI from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag from starlette.requests import Request +from starlette.types import ASGIApp, Send from uvicorn.config import Config from uvicorn.lifespan.on import LifespanOn @@ -236,14 +237,20 @@ async def __init__(self, *args, **kwargs): ): await self._serve_asgi_lifespan.startup() - async def __call__(self, request: Request): - sender = ASGIHTTPSender() + async def __call__( + self, request: Request, asgi_sender: Optional[Send] = None + ) -> Optional[ASGIApp]: + if asgi_sender is None: + asgi_sender = ASGIHTTPSender() + await self._serve_app( request.scope, request.receive, - sender, + asgi_sender, ) - return sender.build_asgi_response() + + if asgi_sender is None: + return asgi_sender.build_asgi_response() # NOTE: __del__ must be async so that we can run asgi shutdown # in the same event loop. diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index c449dd100920..b282bd3ba7e2 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -119,12 +119,14 @@ def __init__( *, _router: Optional[Router] = None, _internal_pickled_http_request: bool = False, + _use_ray_streaming: bool = False, ): self.controller_handle = controller_handle self.deployment_name = deployment_name self.handle_options = handle_options or HandleOptions() self.handle_tag = f"{self.deployment_name}#{get_random_letters()}" self._pickled_http_request = _internal_pickled_http_request + self._use_ray_streaming = _use_ray_streaming self.request_counter = metrics.Counter( "serve_handle_request_counter", @@ -145,6 +147,7 @@ def _make_router(self) -> Router: self.controller_handle, self.deployment_name, event_loop=get_or_create_event_loop(), + _use_ray_streaming=self._use_ray_streaming, ) @property @@ -312,6 +315,7 @@ def _make_router(self) -> Router: self.controller_handle, self.deployment_name, event_loop=_create_or_get_async_loop_in_thread(), + _use_ray_streaming=self._use_ray_streaming, ) def options( diff --git a/python/ray/serve/tests/test_router.py b/python/ray/serve/tests/test_router.py index 4a287ce95c4d..f18d424c8b69 100644 --- a/python/ray/serve/tests/test_router.py +++ b/python/ray/serve/tests/test_router.py @@ -10,7 +10,7 @@ import ray from ray._private.utils import get_or_create_event_loop from ray.serve._private.common import RunningReplicaInfo -from ray.serve._private.router import Query, ReplicaSet, RequestMetadata +from ray.serve._private.router import Query, RoundRobinReplicaScheduler, RequestMetadata from ray._private.test_utils import SignalActor pytestmark = pytest.mark.asyncio @@ -79,7 +79,7 @@ async def num_queries(self): return self._num_queries # We will test a scenario with two replicas in the replica set. - rs = ReplicaSet(get_or_create_event_loop()) + rs = RoundRobinReplicaScheduler(get_or_create_event_loop()) replicas = [ RunningReplicaInfo( deployment_name="my_deployment", From a3b53970a0ab840359f3a21f78069eae8cfe19b7 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 24 May 2023 12:34:36 -0500 Subject: [PATCH 39/77] add basic tests Signed-off-by: Edward Oakes --- python/ray/serve/BUILD | 25 ++++++++- python/ray/serve/_private/http_proxy.py | 4 +- .../tests/test_experimental_streaming.py | 19 +++++++ .../serve/tests/test_streaming_response.py | 56 +++++++++++++++++++ 4 files changed, 102 insertions(+), 2 deletions(-) create mode 100644 python/ray/serve/tests/test_experimental_streaming.py create mode 100644 python/ray/serve/tests/test_streaming_response.py diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD index c4495b9cd4f7..4cdeb6ea97f6 100644 --- a/python/ray/serve/BUILD +++ b/python/ray/serve/BUILD @@ -448,6 +448,7 @@ py_test( tags = ["exclusive", "team:serve"], deps = [":serve_lib"], ) + # Runs test_api and test_failure with injected failures in the controller. py_test( name = "test_controller_crashes", @@ -460,6 +461,28 @@ py_test( deps = [":serve_lib"], ) +# Runs test_api, test_fastapi, and test_http_adapters with experimental streaming turned on. +py_test( + name = "test_experimental_streaming", + size = "large", + srcs = glob(["tests/test_experimental_streaming.py", + "tests/test_api.py", + "tests/test_fastapi.py", + "tests/test_http_adapters.py", + "**/conftest.py"]), + tags = ["exclusive", "team:serve"], + deps = [":serve_lib"], +) + +py_test( + name = "test_streaming_response", + size = "large", + srcs = serve_tests_srcs, + tags = ["exclusive", "team:serve"], + deps = [":serve_lib"], + env = {"RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING": "1"}, +) + py_test( name = "test_controller_recovery", size = "medium", @@ -581,4 +604,4 @@ py_test( srcs = serve_tests_srcs, tags = ["exclusive", "team:serve"], deps = [":serve_lib"], -) \ No newline at end of file +) diff --git a/python/ray/serve/_private/http_proxy.py b/python/ray/serve/_private/http_proxy.py index 69f7c417bf2b..3f20bbbe1f0c 100644 --- a/python/ray/serve/_private/http_proxy.py +++ b/python/ray/serve/_private/http_proxy.py @@ -162,7 +162,9 @@ async def _send_request_to_handle(handle, scope, receive, send) -> str: object_ref = await assignment_task if isinstance(object_ref, ray._raylet.StreamingObjectRefGenerator): - return await _handle_streaming_response(object_ref, scope, receive, send) + return await _handle_streaming_response( + object_ref, scope, receive, send + ) # NOTE (shrekris-anyscale): when the gcs, Serve controller, and # some replicas crash simultaneously (e.g. if the head node crashes), diff --git a/python/ray/serve/tests/test_experimental_streaming.py b/python/ray/serve/tests/test_experimental_streaming.py new file mode 100644 index 000000000000..4aacad77fc43 --- /dev/null +++ b/python/ray/serve/tests/test_experimental_streaming.py @@ -0,0 +1,19 @@ +import os +import pytest +from pathlib import Path +import sys + +if __name__ == "__main__": + curr_dir = Path(__file__).parent + test_paths = curr_dir.rglob("test_*.py") + sorted_path = sorted(map(lambda path: str(path.absolute()), test_paths)) + serve_tests_files = list(sorted_path) + + print("Testing the following files") + for test_file in serve_tests_files: + print("->", test_file.split("/")[-1]) + + print("Setting RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1") + os.environ["RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING"] = "1" + + sys.exit(pytest.main(["-v", "-s"] + serve_tests_files)) diff --git a/python/ray/serve/tests/test_streaming_response.py b/python/ray/serve/tests/test_streaming_response.py new file mode 100644 index 000000000000..c7a24664744f --- /dev/null +++ b/python/ray/serve/tests/test_streaming_response.py @@ -0,0 +1,56 @@ +import asyncio +import pytest +from typing import Generator + +from fastapi import FastAPI +import requests +from starlette.responses import StreamingResponse +from starlette.requests import Request + +import ray +from ray import serve +from ray.serve._private.constants import RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING + +def make_streaming_request() -> Generator[str, None, None]: + r = requests.get("http://localhost:8000", stream=True) + r.raise_for_status() + for chunk in r.iter_content(chunk_size=None, decode_unicode=True): + yield chunk + +@pytest.mark.skipif(not RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING, reason="Streaming feature flag is disabled.") +@pytest.mark.parametrize("use_fastapi", [False, True]) +def test_basic(serve_instance, use_fastapi: bool): + if use_fastapi: + app = FastAPI() + + @serve.deployment + @serve.ingress(app) + class SimpleGenerator: + async def hi_gen(self): + for i in range(10): + yield f"hi_{i}" + await asyncio.sleep(0.01) + + @app.get("/") + def stream_hi(self, request: Request) -> StreamingResponse: + return StreamingResponse(self.hi_gen(), media_type="text/plain") + else: + @serve.deployment + class SimpleGenerator: + async def hi_gen(self): + for i in range(10): + yield f"hi_{i}" + await asyncio.sleep(0.01) + + def __call__(self, request: Request) -> StreamingResponse: + return StreamingResponse(self.hi_gen(), media_type="text/plain") + + serve.run(SimpleGenerator.bind()) + + for i, chunk in enumerate(make_streaming_request()): + assert chunk == f"hi_{i}" + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", "-s", __file__])) From ff16f5497eab862263d6b8561c25d521fd865d1e Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 24 May 2023 13:54:18 -0500 Subject: [PATCH 40/77] nit Signed-off-by: Edward Oakes --- python/ray/serve/_private/http_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/serve/_private/http_util.py b/python/ray/serve/_private/http_util.py index d63d7252621c..4f2caee37d5b 100644 --- a/python/ray/serve/_private/http_util.py +++ b/python/ray/serve/_private/http_util.py @@ -158,7 +158,7 @@ def build_asgi_response(self) -> RawASGIResponse: return RawASGIResponse(self.messages) -class ASGIHTTPQueueSender(Send, asyncio.Queue): +class ASGIHTTPQueueSender(Send): """TODO: doc and better name""" def __init__(self): From 6402eb08ba035b2c58b006af8ea06d28f11a7d17 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 24 May 2023 13:58:38 -0500 Subject: [PATCH 41/77] no signoff Signed-off-by: Edward Oakes --- python/ray/serve/tests/test_streaming_response.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/python/ray/serve/tests/test_streaming_response.py b/python/ray/serve/tests/test_streaming_response.py index c7a24664744f..d8540c66fc08 100644 --- a/python/ray/serve/tests/test_streaming_response.py +++ b/python/ray/serve/tests/test_streaming_response.py @@ -7,17 +7,21 @@ from starlette.responses import StreamingResponse from starlette.requests import Request -import ray from ray import serve from ray.serve._private.constants import RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING + def make_streaming_request() -> Generator[str, None, None]: r = requests.get("http://localhost:8000", stream=True) r.raise_for_status() for chunk in r.iter_content(chunk_size=None, decode_unicode=True): yield chunk -@pytest.mark.skipif(not RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING, reason="Streaming feature flag is disabled.") + +@pytest.mark.skipif( + not RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING, + reason="Streaming feature flag is disabled.", +) @pytest.mark.parametrize("use_fastapi", [False, True]) def test_basic(serve_instance, use_fastapi: bool): if use_fastapi: @@ -34,7 +38,9 @@ async def hi_gen(self): @app.get("/") def stream_hi(self, request: Request) -> StreamingResponse: return StreamingResponse(self.hi_gen(), media_type="text/plain") + else: + @serve.deployment class SimpleGenerator: async def hi_gen(self): @@ -50,6 +56,7 @@ def __call__(self, request: Request) -> StreamingResponse: for i, chunk in enumerate(make_streaming_request()): assert chunk == f"hi_{i}" + if __name__ == "__main__": import sys From aa8823abc9578738d04417e5637ff75ba627ce61 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 24 May 2023 14:52:21 -0500 Subject: [PATCH 42/77] replace _event_loop Signed-off-by: Edward Oakes --- python/ray/serve/_private/router.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/ray/serve/_private/router.py b/python/ray/serve/_private/router.py index 3e2dc59dcf45..9f2dd1a143f3 100644 --- a/python/ray/serve/_private/router.py +++ b/python/ray/serve/_private/router.py @@ -368,6 +368,7 @@ def __init__( Args: controller_handle: The controller handle. """ + self._event_loop = event_loop if _use_ray_streaming: self._replica_scheduler = RoundRobinStreamingReplicaScheduler() else: From f21f89c1591611d63eafb8e809c152c91c0dc76f Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 24 May 2023 15:20:09 -0500 Subject: [PATCH 43/77] add sender test Signed-off-by: Edward Oakes --- python/ray/serve/BUILD | 8 ++++ python/ray/serve/_private/http_util.py | 18 +++++++- python/ray/serve/tests/test_http_util.py | 53 ++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 1 deletion(-) create mode 100644 python/ray/serve/tests/test_http_util.py diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD index 4cdeb6ea97f6..b0b5f12e22dd 100644 --- a/python/ray/serve/BUILD +++ b/python/ray/serve/BUILD @@ -105,6 +105,14 @@ py_test( deps = [":serve_lib"], ) +py_test( + name = "test_http_util", + size = "small", + srcs = serve_tests_srcs, + tags = ["exclusive", "team:serve"], + deps = [":serve_lib"], +) + py_test( name = "test_advanced", size = "small", diff --git a/python/ray/serve/_private/http_util.py b/python/ray/serve/_private/http_util.py index 4f2caee37d5b..38632440cc6f 100644 --- a/python/ray/serve/_private/http_util.py +++ b/python/ray/serve/_private/http_util.py @@ -159,7 +159,12 @@ def build_asgi_response(self) -> RawASGIResponse: class ASGIHTTPQueueSender(Send): - """TODO: doc and better name""" + """ASGI sender that enables polling for the sent messages off a queue. + + This class assumes there's only a single consumer of the queue (concurrent + calls to `get_messages_nowait` and `wait_for_message` may result in undefined + behavior). + """ def __init__(self): self._message_queue = asyncio.Queue() @@ -171,12 +176,23 @@ async def __call__(self, message: Dict[str, Any]): self._new_message_event.set() def get_messages_nowait(self) -> Generator[Dict[str, Any], None, None]: + """Returns all messages that are currently available (non-blocking). + + At least one message will be present if `wait_for_message` had previously + returned and a subsequent call to `wait_for_message` will block until at + least one new message is available. + """ while not self._message_queue.empty(): yield self._message_queue.get_nowait() self._new_message_event.clear() async def wait_for_message(self): + """Wait until at least one new message is available. + + This will continuously return immediately once a message is available until + `get_messages_nowait` is called. + """ await self._new_message_event.wait() diff --git a/python/ray/serve/tests/test_http_util.py b/python/ray/serve/tests/test_http_util.py new file mode 100644 index 000000000000..abcb03f7c688 --- /dev/null +++ b/python/ray/serve/tests/test_http_util.py @@ -0,0 +1,53 @@ +import asyncio +import pytest + +from ray.serve._private.http_util import ASGIHTTPQueueSender + + +@pytest.mark.asyncio +async def test_asgi_queue_sender(): + sender = ASGIHTTPQueueSender() + + # Check that wait_for_message hangs until a message is sent. + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(sender.wait_for_message(), 0.001) + + assert len(list(sender.get_messages_nowait())) == 0 + + await sender({"type": "http.response.start"}) + await sender.wait_for_message() + assert len(list(sender.get_messages_nowait())) == 1 + + # Check that messages are cleared after being consumed. + assert len(list(sender.get_messages_nowait())) == 0 + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(sender.wait_for_message(), 0.001) + + # Check that consecutive messages are returned in order. + await sender({"type": "http.response.start", "idx": 0}) + await sender({"type": "http.response.start", "idx": 1}) + await sender.wait_for_message() + messages = list(sender.get_messages_nowait()) + assert len(messages) == 2 + assert messages[0]["idx"] == 0 + assert messages[1]["idx"] == 1 + + assert len(list(sender.get_messages_nowait())) == 0 + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(sender.wait_for_message(), 0.001) + + # Check that a concurrent waiter is notified when a message is available. + loop = asyncio.get_running_loop() + waiting_task = loop.create_task(sender.wait_for_message()) + for _ in range(1000): + assert not waiting_task.done() + + await sender({"type": "http.response.start"}) + await waiting_task + assert len(list(sender.get_messages_nowait())) == 1 + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", "-s", __file__])) From afa8699a83aa3271cac193f588f01225c5485d26 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 24 May 2023 15:44:09 -0500 Subject: [PATCH 44/77] more comments Signed-off-by: Edward Oakes --- python/ray/serve/_private/replica.py | 58 ++++++++++++++++++++++------ python/ray/serve/_private/router.py | 14 ++++++- 2 files changed, 60 insertions(+), 12 deletions(-) diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index 806d771b5375..150fd983f395 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -35,7 +35,7 @@ ) from ray.serve.deployment import Deployment from ray.serve.exceptions import RayServeException -from ray.serve._private.http_util import ASGIHTTPSender, ASGIHTTPQueueSender +from ray.serve._private.http_util import ASGIHTTPSender, ASGIHTTPQueueSender, Response from ray.serve._private.logging_utils import ( access_log_msg, configure_component_logger, @@ -211,13 +211,26 @@ async def handle_request_streaming( *request_args, **request_kwargs, ) -> AsyncGenerator[Dict[str, Any], None]: - """TODO""" + """Handle a request and stream the results to the caller. + + This is currently only used by the HTTP proxy for experimental + StreamingResponse support. + + The messages yielded by this generator will be ASGI-compliant messages + sent via an ASGI sender interface. This allows us to effectively proxy the + messages back to the HTTP proxy as they're sent by user code + (e.g., FastAPI wrapper). + """ query = Query( request_args, request_kwargs, pickle.loads(pickled_request_metadata), ) + # Handle the request in a background asyncio.Task. It's expected that this + # task will use the provided ASGI sender interface to send its HTTP + # response. We will poll for the sent messages and yield them back to the + # caller. asgi_queue_sender = ASGIHTTPQueueSender() handle_request_task = self._event_loop.create_task( self.replica.handle_request(query, asgi_sender=asgi_queue_sender) @@ -227,7 +240,11 @@ async def handle_request_streaming( [handle_request_task, asgi_queue_sender.wait_for_message()], return_when=asyncio.FIRST_COMPLETED, ) + # Consume all available messages in the queue. If handle_request_task + # is done, this must contain all messages sent by the user code. for msg in asgi_queue_sender.get_messages_nowait(): + # Pickle the raw ASGI dictionary because vanilla pickle is faster + # then cloudpickle and we know it's safe for these messages. yield pickle.dumps(msg) e = handle_request_task.exception() @@ -455,8 +472,18 @@ def callable_method_filter(attr): return getattr(self.callable, method_name) async def handle_http_response( - self, response: Any, asgi_sender: Optional[Send] = None + self, result: Any, asgi_sender: Optional[Send] = None ) -> Any: + """Handle a starlette Response returned by user code. + + If the result is not a Response type, this is a no-op. If it is, the + behavior depends on if asgi_sender was passed or not: + - if asgi_sender is provided, we consume the response using the sender. + It's expected that the caller is consuming the messages from this sender. + - if asgi_sender is not provided, we check to see if the response is a + StreamingResponse and convert it to a vanilla unary response. + """ + async def mock_receive(): # This is called in a tight loop in response() just to check # for an http disconnect. So rather than return immediately @@ -464,16 +491,22 @@ async def mock_receive(): never_set_event = asyncio.Event() await never_set_event.wait() - if asgi_sender is not None and isinstance( - response, starlette.responses.Response - ): - await response(scope=None, receive=mock_receive, send=asgi_sender) + if asgi_sender is not None: + if isinstance(result, starlette.responses.Response): + # Case where the user returns a Response directly. + await result(scope=None, receive=mock_receive, send=asgi_sender) + else: + # Case where the user returns a plain object (not a Response). + response = Response(result) + await response.send(scope=None, receive=mock_receive, send=asgi_sender) + + result = None elif isinstance(response, starlette.responses.StreamingResponse): sender = ASGIHTTPSender() - await response(scope=None, receive=mock_receive, send=sender) - return sender.build_asgi_response() + await result(scope=None, receive=mock_receive, send=sender) + result = sender.build_asgi_response() - return response + return result async def invoke_single( self, request_item: Query, *, asgi_sender: Optional[Send] = None @@ -482,6 +515,9 @@ async def invoke_single( Returns the user-provided output and a boolean indicating if the request succeeded (user code didn't raise an exception). + + If asgi_sender is provided, then the result will always be `None` + because the response will be sent over that interface instead. """ logger.info( f"Started executing request {request_item.metadata.request_id}", @@ -490,7 +526,7 @@ async def invoke_single( args, kwargs = parse_request_item(request_item) if asgi_sender is not None and hasattr(self.callable, "_serve_app"): - # TODO: comment + # If the callable is our FastAPI wrapper, pass it the asgi_sender. kwargs["asgi_sender"] = asgi_sender method_to_call = None diff --git a/python/ray/serve/_private/router.py b/python/ray/serve/_private/router.py index 9f2dd1a143f3..b739f8832329 100644 --- a/python/ray/serve/_private/router.py +++ b/python/ray/serve/_private/router.py @@ -86,6 +86,11 @@ def update_running_replicas(self, running_replicas: List[RunningReplicaInfo]): class RoundRobinStreamingReplicaScheduler(ReplicaScheduler): + """Round-robins requests across a set of actor replicas using streaming calls. + + This policy does *not* currently respect `max_concurrent_queries`. + """ + def __init__(self): self._replica_iterator = itertools.cycle([]) self._replicas_updated_event = asyncio.Event() @@ -120,7 +125,14 @@ def update_running_replicas(self, running_replicas: List[RunningReplicaInfo]): class RoundRobinReplicaScheduler(ReplicaScheduler): - """Data structure representing a set of replica actor handles""" + """Round-robins requests across a set of actor replicas. + + The policy respects `max_concurrent_queries` for the replicas: a replica + will not be chosen if `max_concurrent_queries` requests are already outstanding. + + This is maintained using a "tracker" object ref to determine when a given request + has finished (to decrement the number of concurrent queries). + """ def __init__(self, event_loop: asyncio.AbstractEventLoop): self.in_flight_queries: Dict[RunningReplicaInfo, set] = dict() From 4c3442f2552b230740d5a80dcd85be054141b20a Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 24 May 2023 16:16:07 -0500 Subject: [PATCH 45/77] weeee Signed-off-by: Edward Oakes --- python/ray/serve/_private/replica.py | 11 +-- .../serve/tests/test_streaming_response.py | 97 +++++++++++++++---- 2 files changed, 84 insertions(+), 24 deletions(-) diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index 150fd983f395..ec57bd1e4892 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -374,8 +374,6 @@ def user_health_check(): self.user_health_check = sync_to_async(user_health_check) - self.num_ongoing_requests = 0 - self.request_counter = metrics.Counter( "serve_deployment_request_counter", description=( @@ -495,13 +493,13 @@ async def mock_receive(): if isinstance(result, starlette.responses.Response): # Case where the user returns a Response directly. await result(scope=None, receive=mock_receive, send=asgi_sender) - else: + elif result is not None: # Case where the user returns a plain object (not a Response). response = Response(result) await response.send(scope=None, receive=mock_receive, send=asgi_sender) result = None - elif isinstance(response, starlette.responses.StreamingResponse): + elif isinstance(result, starlette.responses.StreamingResponse): sender = ASGIHTTPSender() await result(scope=None, receive=mock_receive, send=sender) result = sender.build_asgi_response() @@ -654,14 +652,15 @@ async def prepare_for_shutdown(self): # The handle_request method wasn't even invoked. if method_stat is None: break + num_ongoing_requests = method_stat["running"] + method_stat["pending"] # The handle_request method has 0 inflight requests. - if method_stat["running"] + method_stat["pending"] == 0: + if num_ongoing_requests == 0: break else: logger.info( "Waiting for an additional " f"{self.deployment_config.graceful_shutdown_wait_loop_s}s to shut " - f"down because there are {self.num_ongoing_requests} ongoing " + f"down because there are {num_ongoing_requests} ongoing " "requests." ) diff --git a/python/ray/serve/tests/test_streaming_response.py b/python/ray/serve/tests/test_streaming_response.py index d8540c66fc08..3fef6a492c66 100644 --- a/python/ray/serve/tests/test_streaming_response.py +++ b/python/ray/serve/tests/test_streaming_response.py @@ -7,15 +7,20 @@ from starlette.responses import StreamingResponse from starlette.requests import Request +import ray +from ray._private.test_utils import SignalActor + from ray import serve from ray.serve._private.constants import RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING -def make_streaming_request() -> Generator[str, None, None]: - r = requests.get("http://localhost:8000", stream=True) - r.raise_for_status() - for chunk in r.iter_content(chunk_size=None, decode_unicode=True): - yield chunk +@ray.remote +class StreamingRequester: + def make_request(self) -> Generator[str, None, None]: + r = requests.get("http://localhost:8000", stream=True) + r.raise_for_status() + for chunk in r.iter_content(chunk_size=None, decode_unicode=True): + yield chunk @pytest.mark.skipif( @@ -24,39 +29,95 @@ def make_streaming_request() -> Generator[str, None, None]: ) @pytest.mark.parametrize("use_fastapi", [False, True]) def test_basic(serve_instance, use_fastapi: bool): + async def hi_gen(): + for i in range(10): + yield f"hi_{i}" + await asyncio.sleep(0.01) + if use_fastapi: app = FastAPI() @serve.deployment @serve.ingress(app) class SimpleGenerator: - async def hi_gen(self): - for i in range(10): - yield f"hi_{i}" - await asyncio.sleep(0.01) - @app.get("/") def stream_hi(self, request: Request) -> StreamingResponse: - return StreamingResponse(self.hi_gen(), media_type="text/plain") + return StreamingResponse(hi_gen(), media_type="text/plain") else: @serve.deployment class SimpleGenerator: - async def hi_gen(self): - for i in range(10): - yield f"hi_{i}" - await asyncio.sleep(0.01) - def __call__(self, request: Request) -> StreamingResponse: - return StreamingResponse(self.hi_gen(), media_type="text/plain") + return StreamingResponse(hi_gen(), media_type="text/plain") serve.run(SimpleGenerator.bind()) - for i, chunk in enumerate(make_streaming_request()): + r = requests.get("http://localhost:8000", stream=True) + r.raise_for_status() + for i, chunk in enumerate(r.iter_content(chunk_size=None, decode_unicode=True)): assert chunk == f"hi_{i}" +@pytest.mark.skipif( + not RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING, + reason="Streaming feature flag is disabled.", +) +@pytest.mark.parametrize("use_fastapi", [False, True]) +def test_response_actually_streamed(serve_instance, use_fastapi: bool): + """Checks that responses are streamed as they are yielded.""" + signal_actor = SignalActor.remote() + + async def wait_on_signal_generator(): + yield "before signal" + await signal_actor.wait.remote() + yield "after signal" + + if use_fastapi: + app = FastAPI() + + @serve.deployment + @serve.ingress(app) + class SimpleGenerator: + @app.get("/") + def stream(self, request: Request) -> StreamingResponse: + return StreamingResponse( + wait_on_signal_generator(), media_type="text/plain" + ) + + else: + + @serve.deployment + class SimpleGenerator: + def __call__(self, request: Request) -> StreamingResponse: + return StreamingResponse( + wait_on_signal_generator(), media_type="text/plain" + ) + + serve.run(SimpleGenerator.bind()) + + requester = StreamingRequester.remote() + gen = requester.make_request.options(num_returns="streaming").remote() + + # Check that we get the first response before the signal is sent + # (so the generator is still hanging after the first yield). + obj_ref = next(gen) + assert ray.get(obj_ref) == "before signal" + + # Check that the next obj_ref is not ready yet. + obj_ref = gen._next_sync(timeout_s=0.01) + assert obj_ref.is_nil() + + # Now send signal to actor, second yield happens. + ray.get(signal_actor.send.remote()) + obj_ref = next(gen) + assert ray.get(obj_ref) == "after signal" + + # Client should be done getting messages. + with pytest.raises(StopIteration): + next(gen) + + if __name__ == "__main__": import sys From 3d48b017175bc95ec908c17934b235796e3c6625 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 24 May 2023 16:28:03 -0500 Subject: [PATCH 46/77] more tests Signed-off-by: Edward Oakes --- .../serve/tests/test_streaming_response.py | 66 +++++++++++++++++-- 1 file changed, 60 insertions(+), 6 deletions(-) diff --git a/python/ray/serve/tests/test_streaming_response.py b/python/ray/serve/tests/test_streaming_response.py index 3fef6a492c66..b7021a6f26cc 100644 --- a/python/ray/serve/tests/test_streaming_response.py +++ b/python/ray/serve/tests/test_streaming_response.py @@ -1,4 +1,3 @@ -import asyncio import pytest from typing import Generator @@ -28,11 +27,15 @@ def make_request(self) -> Generator[str, None, None]: reason="Streaming feature flag is disabled.", ) @pytest.mark.parametrize("use_fastapi", [False, True]) -def test_basic(serve_instance, use_fastapi: bool): - async def hi_gen(): +@pytest.mark.parametrize("use_async", [False, True]) +def test_basic(serve_instance, use_async: bool, use_fastapi: bool): + async def hi_gen_async(): + for i in range(10): + yield f"hi_{i}" + + def hi_gen_sync(): for i in range(10): yield f"hi_{i}" - await asyncio.sleep(0.01) if use_fastapi: app = FastAPI() @@ -42,14 +45,16 @@ async def hi_gen(): class SimpleGenerator: @app.get("/") def stream_hi(self, request: Request) -> StreamingResponse: - return StreamingResponse(hi_gen(), media_type="text/plain") + gen = hi_gen_async() if use_async else hi_gen_sync() + return StreamingResponse(gen, media_type="text/plain") else: @serve.deployment class SimpleGenerator: def __call__(self, request: Request) -> StreamingResponse: - return StreamingResponse(hi_gen(), media_type="text/plain") + gen = hi_gen_async() if use_async else hi_gen_sync() + return StreamingResponse(gen, media_type="text/plain") serve.run(SimpleGenerator.bind()) @@ -118,6 +123,55 @@ def __call__(self, request: Request) -> StreamingResponse: next(gen) +@pytest.mark.skipif( + not RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING, + reason="Streaming feature flag is disabled.", +) +@pytest.mark.parametrize("use_fastapi", [False, True]) +def test_metadata_preserved(serve_instance, use_fastapi: bool): + """Check that status code, headers, and media type are preserved.""" + + def hi_gen(): + for i in range(10): + yield f"hi_{i}" + + if use_fastapi: + app = FastAPI() + + @serve.deployment + @serve.ingress(app) + class SimpleGenerator: + @app.get("/") + def stream_hi(self, request: Request) -> StreamingResponse: + return StreamingResponse( + hi_gen(), + status_code=301, + headers={"hello": "world"}, + media_type="foo/bar", + ) + + else: + + @serve.deployment + class SimpleGenerator: + def __call__(self, request: Request) -> StreamingResponse: + return StreamingResponse( + hi_gen(), + status_code=301, + headers={"hello": "world"}, + media_type="foo/bar", + ) + + serve.run(SimpleGenerator.bind()) + + r = requests.get("http://localhost:8000", stream=True) + assert r.status_code == 301 + assert r.headers["hello"] == "world" + assert r.headers["content-type"] == "foo/bar" + for i, chunk in enumerate(r.iter_content(chunk_size=None)): + assert chunk == f"hi_{i}".encode("utf-8") + + if __name__ == "__main__": import sys From 88781fb1c75872839d5c3cfbca445d12ca52dd67 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 24 May 2023 16:33:52 -0500 Subject: [PATCH 47/77] async/sync Signed-off-by: Edward Oakes --- .../serve/tests/test_streaming_response.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/python/ray/serve/tests/test_streaming_response.py b/python/ray/serve/tests/test_streaming_response.py index b7021a6f26cc..d6e9b42431a4 100644 --- a/python/ray/serve/tests/test_streaming_response.py +++ b/python/ray/serve/tests/test_streaming_response.py @@ -69,15 +69,21 @@ def __call__(self, request: Request) -> StreamingResponse: reason="Streaming feature flag is disabled.", ) @pytest.mark.parametrize("use_fastapi", [False, True]) -def test_response_actually_streamed(serve_instance, use_fastapi: bool): +@pytest.mark.parametrize("use_async", [False, True]) +def test_response_actually_streamed(serve_instance, use_fastapi: bool, use_async: bool): """Checks that responses are streamed as they are yielded.""" signal_actor = SignalActor.remote() - async def wait_on_signal_generator(): + async def wait_on_signal_async(): yield "before signal" await signal_actor.wait.remote() yield "after signal" + async def wait_on_signal_sync(): + yield "before signal" + ray.get(signal_actor.wait.remote()) + yield "after signal" + if use_fastapi: app = FastAPI() @@ -86,18 +92,16 @@ async def wait_on_signal_generator(): class SimpleGenerator: @app.get("/") def stream(self, request: Request) -> StreamingResponse: - return StreamingResponse( - wait_on_signal_generator(), media_type="text/plain" - ) + gen = wait_on_signal_async() if use_async else wait_on_signal_sync() + return StreamingResponse(gen, media_type="text/plain") else: @serve.deployment class SimpleGenerator: def __call__(self, request: Request) -> StreamingResponse: - return StreamingResponse( - wait_on_signal_generator(), media_type="text/plain" - ) + gen = wait_on_signal_async() if use_async else wait_on_signal_sync() + return StreamingResponse(gen, media_type="text/plain") serve.run(SimpleGenerator.bind()) From 4d6eacea7892300fad2de9a55bf9ba31857793b4 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 24 May 2023 16:53:31 -0500 Subject: [PATCH 48/77] more tests Signed-off-by: Edward Oakes --- .../serve/tests/test_streaming_response.py | 73 +++++++++++++------ 1 file changed, 52 insertions(+), 21 deletions(-) diff --git a/python/ray/serve/tests/test_streaming_response.py b/python/ray/serve/tests/test_streaming_response.py index d6e9b42431a4..742f4a29fc39 100644 --- a/python/ray/serve/tests/test_streaming_response.py +++ b/python/ray/serve/tests/test_streaming_response.py @@ -1,5 +1,7 @@ +import asyncio +import os import pytest -from typing import Generator +from typing import AsyncGenerator from fastapi import FastAPI import requests @@ -15,11 +17,12 @@ @ray.remote class StreamingRequester: - def make_request(self) -> Generator[str, None, None]: + async def make_request(self) -> AsyncGenerator[str, None]: r = requests.get("http://localhost:8000", stream=True) r.raise_for_status() for chunk in r.iter_content(chunk_size=None, decode_unicode=True): yield chunk + await asyncio.sleep(0.001) @pytest.mark.skipif( @@ -70,19 +73,26 @@ def __call__(self, request: Request) -> StreamingResponse: ) @pytest.mark.parametrize("use_fastapi", [False, True]) @pytest.mark.parametrize("use_async", [False, True]) -def test_response_actually_streamed(serve_instance, use_fastapi: bool, use_async: bool): - """Checks that responses are streamed as they are yielded.""" +@pytest.mark.parametrize("use_multiple_replicas", [False, True]) +def test_responses_actually_streamed( + serve_instance, use_fastapi: bool, use_async: bool, use_multiple_replicas: bool +): + """Checks that responses are streamed as they are yielded. + + Also checks that responses can be streamed concurrently from a single replica + or from multiple replicas. + """ signal_actor = SignalActor.remote() async def wait_on_signal_async(): - yield "before signal" + yield f"{os.getpid()}: before signal" await signal_actor.wait.remote() - yield "after signal" + yield f"{os.getpid()}: after signal" - async def wait_on_signal_sync(): - yield "before signal" + def wait_on_signal_sync(): + yield f"{os.getpid()}: before signal" ray.get(signal_actor.wait.remote()) - yield "after signal" + yield f"{os.getpid()}: after signal" if use_fastapi: app = FastAPI() @@ -103,28 +113,49 @@ def __call__(self, request: Request) -> StreamingResponse: gen = wait_on_signal_async() if use_async else wait_on_signal_sync() return StreamingResponse(gen, media_type="text/plain") - serve.run(SimpleGenerator.bind()) + serve.run( + SimpleGenerator.options( + ray_actor_options={"num_cpus": 0}, + num_replicas=2 if use_multiple_replicas else 1, + ).bind() + ) requester = StreamingRequester.remote() - gen = requester.make_request.options(num_returns="streaming").remote() + gen1 = requester.make_request.options(num_returns="streaming").remote() + gen2 = requester.make_request.options(num_returns="streaming").remote() - # Check that we get the first response before the signal is sent + # Check that we get the first responses before the signal is sent # (so the generator is still hanging after the first yield). - obj_ref = next(gen) - assert ray.get(obj_ref) == "before signal" + gen1_result = ray.get(next(gen1)) + gen2_result = ray.get(next(gen2)) + assert gen1_result.endswith("before signal") + assert gen2_result.endswith("before signal") + gen1_pid = gen1_result.split(":")[0] + gen2_pid = gen2_result.split(":")[0] + if use_multiple_replicas: + assert gen1_pid != gen2_pid + else: + assert gen1_pid == gen2_pid - # Check that the next obj_ref is not ready yet. - obj_ref = gen._next_sync(timeout_s=0.01) - assert obj_ref.is_nil() + # Check that the next obj_ref is not ready yet for both generators. + assert gen1._next_sync(timeout_s=0.01).is_nil() + assert gen2._next_sync(timeout_s=0.01).is_nil() - # Now send signal to actor, second yield happens. + # Now send signal to actor, second yield happens and we should get responses. ray.get(signal_actor.send.remote()) - obj_ref = next(gen) - assert ray.get(obj_ref) == "after signal" + gen1_result = ray.get(next(gen1)) + gen2_result = ray.get(next(gen2)) + assert gen1_result.startswith(gen1_pid) + assert gen2_result.startswith(gen2_pid) + assert gen1_result.endswith("after signal") + assert gen2_result.endswith("after signal") # Client should be done getting messages. with pytest.raises(StopIteration): - next(gen) + next(gen1) + + with pytest.raises(StopIteration): + next(gen2) @pytest.mark.skipif( From 7a050a7440908d2cf608a5bf35cd7cd64c89c5d7 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 24 May 2023 16:58:45 -0500 Subject: [PATCH 49/77] fix fastapi test Signed-off-by: Edward Oakes --- python/ray/serve/api.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 02c072ee2f66..3016593c910c 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -240,8 +240,18 @@ async def __init__(self, *args, **kwargs): async def __call__( self, request: Request, asgi_sender: Optional[Send] = None ) -> Optional[ASGIApp]: + """Calls into the wrapped ASGI app. + + If asgi_sender is provided, it's passed into the app and nothing is + returned. + + If no asgi_sender is provided, an ASGI response will be built and + returned. + """ + build_and_return_response = False if asgi_sender is None: asgi_sender = ASGIHTTPSender() + build_and_return_response = True await self._serve_app( request.scope, @@ -249,7 +259,7 @@ async def __call__( asgi_sender, ) - if asgi_sender is None: + if build_and_return_response: return asgi_sender.build_asgi_response() # NOTE: __del__ must be async so that we can run asgi shutdown From 5bde9fdcc81a486a0bd1b4c6c75c3be77698f0a3 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 24 May 2023 17:17:31 -0500 Subject: [PATCH 50/77] batch Signed-off-by: Edward Oakes --- python/ray/serve/_private/http_proxy.py | 11 ++++++----- python/ray/serve/_private/replica.py | 11 +++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/python/ray/serve/_private/http_proxy.py b/python/ray/serve/_private/http_proxy.py index 3f20bbbe1f0c..206d5d36c23d 100644 --- a/python/ray/serve/_private/http_proxy.py +++ b/python/ray/serve/_private/http_proxy.py @@ -6,7 +6,7 @@ import pickle import socket import time -from typing import Callable, List, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple from ray._private.utils import get_or_create_event_loop import uvicorn @@ -101,11 +101,12 @@ async def _handle_streaming_response( status_code = "" try: async for obj_ref in asgi_response_generator: - asgi_message = pickle.loads(await obj_ref) - if asgi_message["type"] == "http.response.start": - status_code = str(asgi_message["status"]) + asgi_messages: List[Dict[str, Any]] = pickle.loads(await obj_ref) + for asgi_message in asgi_messages: + if asgi_message["type"] == "http.response.start": + status_code = str(asgi_message["status"]) - await send(asgi_message) + await send(asgi_message) except Exception as e: error_message = "Unexpected error, traceback: {}.".format(e) logger.warning(error_message) diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index ec57bd1e4892..e7fed184f80f 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -240,12 +240,11 @@ async def handle_request_streaming( [handle_request_task, asgi_queue_sender.wait_for_message()], return_when=asyncio.FIRST_COMPLETED, ) - # Consume all available messages in the queue. If handle_request_task - # is done, this must contain all messages sent by the user code. - for msg in asgi_queue_sender.get_messages_nowait(): - # Pickle the raw ASGI dictionary because vanilla pickle is faster - # then cloudpickle and we know it's safe for these messages. - yield pickle.dumps(msg) + # Consume and yield all available messages in the queue. + # The messages are batched into a list to avoid unnecessary RPCs and + # we use vanilla pickle because it's faster than cloudpickle and we + # know it's safe for these messages containing primitive types. + yield pickle.dumps(list(asgi_queue_sender.get_messages_nowait())) e = handle_request_task.exception() if e is not None: From 3aa69b90840ccc74e95316e6fafa2b231786fe84 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 24 May 2023 20:59:50 -0500 Subject: [PATCH 51/77] fix Signed-off-by: Edward Oakes --- python/ray/serve/_private/http_util.py | 8 +++++--- python/ray/serve/_private/replica.py | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/python/ray/serve/_private/http_util.py b/python/ray/serve/_private/http_util.py index 38632440cc6f..af2726c6c9a9 100644 --- a/python/ray/serve/_private/http_util.py +++ b/python/ray/serve/_private/http_util.py @@ -4,7 +4,7 @@ import inspect import json import logging -from typing import Any, Dict, Generator, Type +from typing import Any, Dict, List, Type from starlette.requests import Request from starlette.types import Send, ASGIApp @@ -175,17 +175,19 @@ async def __call__(self, message: Dict[str, Any]): await self._message_queue.put(message) self._new_message_event.set() - def get_messages_nowait(self) -> Generator[Dict[str, Any], None, None]: + def get_messages_nowait(self) -> List[Dict[str, Any]]: """Returns all messages that are currently available (non-blocking). At least one message will be present if `wait_for_message` had previously returned and a subsequent call to `wait_for_message` will block until at least one new message is available. """ + messages = [] while not self._message_queue.empty(): - yield self._message_queue.get_nowait() + messages.append(self._message_queue.get_nowait()) self._new_message_event.clear() + return messages async def wait_for_message(self): """Wait until at least one new message is available. diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index e7fed184f80f..9df2d2394d28 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -235,8 +235,10 @@ async def handle_request_streaming( handle_request_task = self._event_loop.create_task( self.replica.handle_request(query, asgi_sender=asgi_queue_sender) ) - while not handle_request_task.done(): - done, pending = await asyncio.wait( + + done = [] + while handle_request_task not in done: + done, _ = await asyncio.wait( [handle_request_task, asgi_queue_sender.wait_for_message()], return_when=asyncio.FIRST_COMPLETED, ) @@ -244,7 +246,7 @@ async def handle_request_streaming( # The messages are batched into a list to avoid unnecessary RPCs and # we use vanilla pickle because it's faster than cloudpickle and we # know it's safe for these messages containing primitive types. - yield pickle.dumps(list(asgi_queue_sender.get_messages_nowait())) + yield pickle.dumps(asgi_queue_sender.get_messages_nowait()) e = handle_request_task.exception() if e is not None: From 61acf4104159c7b35c1929d00deaef334147288a Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 24 May 2023 22:48:02 -0500 Subject: [PATCH 52/77] fix tests Signed-off-by: Edward Oakes --- python/ray/serve/_private/http_util.py | 2 +- python/ray/serve/_private/replica.py | 9 +++++++-- python/ray/serve/tests/test_multiplex.py | 16 ++++++++-------- python/ray/serve/tests/test_router.py | 2 +- python/ray/serve/tests/test_standalone3.py | 14 +++++++------- 5 files changed, 24 insertions(+), 19 deletions(-) diff --git a/python/ray/serve/_private/http_util.py b/python/ray/serve/_private/http_util.py index af2726c6c9a9..e34ed07e70d1 100644 --- a/python/ray/serve/_private/http_util.py +++ b/python/ray/serve/_private/http_util.py @@ -133,7 +133,7 @@ class RawASGIResponse(ASGIApp): def __init__(self, messages): self.messages = messages - async def __call__(self, _scope, _receive, send): + async def __call__(self, scope, receive, send): for message in self.messages: await send(message) diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index 9df2d2394d28..bf2a5eecdd5a 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -35,7 +35,12 @@ ) from ray.serve.deployment import Deployment from ray.serve.exceptions import RayServeException -from ray.serve._private.http_util import ASGIHTTPSender, ASGIHTTPQueueSender, Response +from ray.serve._private.http_util import ( + ASGIHTTPSender, + ASGIHTTPQueueSender, + RawASGIResponse, + Response, +) from ray.serve._private.logging_utils import ( access_log_msg, configure_component_logger, @@ -491,7 +496,7 @@ async def mock_receive(): await never_set_event.wait() if asgi_sender is not None: - if isinstance(result, starlette.responses.Response): + if isinstance(result, (starlette.responses.Response, RawASGIResponse)): # Case where the user returns a Response directly. await result(scope=None, receive=mock_receive, send=asgi_sender) elif result is not None: diff --git a/python/ray/serve/tests/test_multiplex.py b/python/ray/serve/tests/test_multiplex.py index 86c2bf785124..8213eb603e7d 100644 --- a/python/ray/serve/tests/test_multiplex.py +++ b/python/ray/serve/tests/test_multiplex.py @@ -232,7 +232,7 @@ def check_replica_information( wait_for_condition( check_replica_information, - replicas=handle.router._replica_set.in_flight_queries.keys(), + replicas=handle.router._replica_scheduler.in_flight_queries.keys(), deployment=deployment, replica_tag=replica_tag, model_ids=[ @@ -243,7 +243,7 @@ def check_replica_information( ray.get(handle.remote("model2")) wait_for_condition( check_replica_information, - replicas=handle.router._replica_set.in_flight_queries.keys(), + replicas=handle.router._replica_scheduler.in_flight_queries.keys(), deployment=deployment, replica_tag=replica_tag, model_ids=[ @@ -256,7 +256,7 @@ def check_replica_information( ray.get(handle.remote("model3")) wait_for_condition( check_replica_information, - replicas=handle.router._replica_set.in_flight_queries.keys(), + replicas=handle.router._replica_scheduler.in_flight_queries.keys(), deployment=deployment, replica_tag=replica_tag, model_ids=[ @@ -291,7 +291,7 @@ async def __call__(self, request): resp = requests.get("http://localhost:8000", headers=headers) assert resp.json() == pid wait_for_condition( - lambda: "1" in handle.router._replica_set.multiplexed_replicas_table, + lambda: "1" in handle.router._replica_scheduler.multiplexed_replicas_table, ) for _ in range(10): @@ -325,8 +325,8 @@ async def __call__(self, request): requests.get("http://localhost:8000", headers=headers) wait_for_condition( - lambda: "1" in handle.router._replica_set.multiplexed_replicas_table - and "3" in handle.router._replica_set.multiplexed_replicas_table, + lambda: "1" in handle.router._replica_scheduler.multiplexed_replicas_table + and "3" in handle.router._replica_scheduler.multiplexed_replicas_table, ) @@ -354,8 +354,8 @@ async def __call__(self, request): signal.send.remote() assert ray.get(resp1_ref) != ray.get(resp2_ref) wait_for_condition( - lambda: "1" in handle.router._replica_set.multiplexed_replicas_table - and len(handle.router._replica_set.multiplexed_replicas_table["1"]) == 2 + lambda: "1" in handle.router._replica_scheduler.multiplexed_replicas_table + and len(handle.router._replica_scheduler.multiplexed_replicas_table["1"]) == 2 ) diff --git a/python/ray/serve/tests/test_router.py b/python/ray/serve/tests/test_router.py index f18d424c8b69..ae46fb0e7e84 100644 --- a/python/ray/serve/tests/test_router.py +++ b/python/ray/serve/tests/test_router.py @@ -62,7 +62,7 @@ def task_runner_mock_actor(): yield mock_task_runner() -async def test_replica_set(ray_instance): +async def test_replica_scheduler(ray_instance): signal = SignalActor.remote() @ray.remote(num_cpus=0) diff --git a/python/ray/serve/tests/test_standalone3.py b/python/ray/serve/tests/test_standalone3.py index 47b7173b187c..287cc2042b2b 100644 --- a/python/ray/serve/tests/test_standalone3.py +++ b/python/ray/serve/tests/test_standalone3.py @@ -135,13 +135,13 @@ async def f(): ) # Make sure the inflight queries still one - assert len(handle.router._replica_set.in_flight_queries) == 1 - key = list(handle.router._replica_set.in_flight_queries.keys())[0] - assert len(handle.router._replica_set.in_flight_queries[key]) == 1 + assert len(handle.router._replica_scheduler.in_flight_queries) == 1 + key = list(handle.router._replica_scheduler.in_flight_queries.keys())[0] + assert len(handle.router._replica_scheduler.in_flight_queries[key]) == 1 # Make sure the first request is being run. - replicas = list(handle.router._replica_set.in_flight_queries.keys()) - assert len(handle.router._replica_set.in_flight_queries[replicas[0]]) == 1 + replicas = list(handle.router._replica_scheduler.in_flight_queries.keys()) + assert len(handle.router._replica_scheduler.in_flight_queries[replicas[0]]) == 1 # First ref should be still ongoing with pytest.raises(ray.exceptions.GetTimeoutError): ray.get(first_ref, timeout=1) @@ -307,7 +307,7 @@ def f(do_crash: bool = False): handle = serve.run(f.bind()) pids = ray.get([handle.remote() for _ in range(2)]) assert len(set(pids)) == 2 - assert len(handle.router._replica_set.in_flight_queries.keys()) == 2 + assert len(handle.router._replica_scheduler.in_flight_queries.keys()) == 2 client = get_global_client() # Kill the controller so that the replicas membership won't be updated @@ -319,7 +319,7 @@ def f(do_crash: bool = False): pids = ray.get([handle.remote() for _ in range(10)]) assert len(set(pids)) == 1 - assert len(handle.router._replica_set.in_flight_queries.keys()) == 1 + assert len(handle.router._replica_scheduler.in_flight_queries.keys()) == 1 # Restart the controller, and then clean up all the replicas serve.start(detached=True) From 86b741a797b08d46d9fb3de174da7c25df4c4c61 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 08:58:23 -0500 Subject: [PATCH 53/77] fix doc import error Signed-off-by: Edward Oakes --- python/ray/serve/_private/router.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/ray/serve/_private/router.py b/python/ray/serve/_private/router.py index b739f8832329..4664db29f343 100644 --- a/python/ray/serve/_private/router.py +++ b/python/ray/serve/_private/router.py @@ -78,7 +78,7 @@ async def resolve_async_tasks(self): class ReplicaScheduler(ABC): async def assign_replica( self, query: Query - ) -> Union[ray.ObjectRef, ray._raylet.StreamingObjectRefGenerator]: + ) -> Union[ray.ObjectRef, "ray._raylet.StreamingObjectRefGenerator"]: pass def update_running_replicas(self, running_replicas: List[RunningReplicaInfo]): @@ -97,7 +97,7 @@ def __init__(self): async def assign_replica( self, query: Query - ) -> ray._raylet.StreamingObjectRefGenerator: + ) -> "ray._raylet.StreamingObjectRefGenerator": replica = None while replica is None: try: @@ -438,7 +438,7 @@ async def assign_request( request_meta: RequestMetadata, *request_args, **request_kwargs, - ) -> Union[ray.ObjectRef, ray._raylet.StreamingObjectRefGenerator]: + ) -> Union[ray.ObjectRef, "ray._raylet.StreamingObjectRefGenerator"]: """Assign a query and returns an object ref represent the result.""" self.num_router_requests.inc( From c16f437c25859ca8c413ab8ef01fa1a23948600f Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 09:09:52 -0500 Subject: [PATCH 54/77] skip windows Signed-off-by: Edward Oakes --- python/ray/serve/tests/test_fastapi.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/python/ray/serve/tests/test_fastapi.py b/python/ray/serve/tests/test_fastapi.py index f340314bbc26..61a4000fae7a 100644 --- a/python/ray/serve/tests/test_fastapi.py +++ b/python/ray/serve/tests/test_fastapi.py @@ -1,11 +1,8 @@ import time from typing import Any, List, Optional import tempfile -import numpy as np +import sys -import pytest -import inspect -import requests from fastapi import ( Cookie, Depends, @@ -19,19 +16,25 @@ ) from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse +import inspect +import numpy as np from pydantic import BaseModel, Field +import pytest +import requests from starlette.applications import Starlette import starlette.responses from starlette.routing import Route import ray +from ray._private.test_utils import SignalActor, wait_for_condition + from ray import serve from ray.exceptions import GetTimeoutError from ray.serve.exceptions import RayServeException +from ray.serve._private.constants import RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING from ray.serve._private.client import ServeControllerClient from ray.serve._private.http_util import make_fastapi_class_based_view from ray.serve._private.utils import DEFAULT -from ray._private.test_utils import SignalActor, wait_for_condition def test_fastapi_function(serve_instance): @@ -648,6 +651,10 @@ def decr2(self): assert requests.get("http://localhost:8000" + path).status_code == 404, path +@pytest.mark.skipif( + RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING and sys.platform == "win32", + reason="https://github.com/ray-project/ray/issues/35775", +) def test_fastapi_custom_serializers(serve_instance): app = FastAPI() From 7c066f8b643dfddaac93e48b5e7c81af3e09c2e8 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 09:31:56 -0500 Subject: [PATCH 55/77] fix Signed-off-by: Edward Oakes --- .../serve/tests/test_streaming_response.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/python/ray/serve/tests/test_streaming_response.py b/python/ray/serve/tests/test_streaming_response.py index 742f4a29fc39..e322c016bd2b 100644 --- a/python/ray/serve/tests/test_streaming_response.py +++ b/python/ray/serve/tests/test_streaming_response.py @@ -207,6 +207,52 @@ def __call__(self, request: Request) -> StreamingResponse: assert chunk == f"hi_{i}".encode("utf-8") +@pytest.mark.skipif( + not RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING, + reason="Streaming feature flag is disabled.", +) +@pytest.mark.parametrize("use_fastapi", [False, True]) +@pytest.mark.parametrize("use_async", [False, True]) +def test_exception_in_generator(serve_instance, use_async: bool, use_fastapi: bool): + async def hi_gen_async(): + yield "first result" + raise Exception("raised in generator") + yield "never reached" + + def hi_gen_sync(): + yield "first result" + raise Exception("raised in generator") + yield "never reached" + + if use_fastapi: + app = FastAPI() + + @serve.deployment + @serve.ingress(app) + class SimpleGenerator: + @app.get("/") + def stream_hi(self, request: Request) -> StreamingResponse: + gen = hi_gen_async() if use_async else hi_gen_sync() + return StreamingResponse(gen, media_type="text/plain") + + else: + + @serve.deployment + class SimpleGenerator: + def __call__(self, request: Request) -> StreamingResponse: + gen = hi_gen_async() if use_async else hi_gen_sync() + return StreamingResponse(gen, media_type="text/plain") + + serve.run(SimpleGenerator.bind()) + + r = requests.get("http://localhost:8000", stream=True) + r.raise_for_status() + stream_iter = r.iter_content(chunk_size=None, decode_unicode=True) + assert next(stream_iter) == "first result" + with pytest.raises(requests.exceptions.ChunkedEncodingError): + next(stream_iter) + + if __name__ == "__main__": import sys From 9db27af3101b968af6804b6d83b605063cabdab6 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 09:33:44 -0500 Subject: [PATCH 56/77] docstring Signed-off-by: Edward Oakes --- python/ray/serve/_private/client.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/ray/serve/_private/client.py b/python/ray/serve/_private/client.py index c1d61497616b..8c9dc3f10eac 100644 --- a/python/ray/serve/_private/client.py +++ b/python/ray/serve/_private/client.py @@ -451,6 +451,10 @@ def get_handle( sync: If true, then Serve will return a ServeHandle that works everywhere. Otherwise, Serve will return a ServeHandle that's only usable in asyncio loop. + _internal_pickled_http_request: Indicates that this handle will be used + to send HTTP requests from the proxy to ingress deployment replicas. + _use_ray_streaming: Indicates that this handle should use + `num_returns="streaming"`. Returns: RayServeHandle From 7f049060577417d6e35bb596f351bc474bdb6170 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 09:36:59 -0500 Subject: [PATCH 57/77] nit Signed-off-by: Edward Oakes --- python/ray/serve/tests/test_experimental_streaming.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/ray/serve/tests/test_experimental_streaming.py b/python/ray/serve/tests/test_experimental_streaming.py index 4aacad77fc43..2558426752e8 100644 --- a/python/ray/serve/tests/test_experimental_streaming.py +++ b/python/ray/serve/tests/test_experimental_streaming.py @@ -1,8 +1,9 @@ -import os import pytest from pathlib import Path import sys +from ray.serve._private.constants import RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING + if __name__ == "__main__": curr_dir = Path(__file__).parent test_paths = curr_dir.rglob("test_*.py") @@ -13,7 +14,8 @@ for test_file in serve_tests_files: print("->", test_file.split("/")[-1]) - print("Setting RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1") - os.environ["RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING"] = "1" + assert ( + RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING + ), "RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1 must be set." sys.exit(pytest.main(["-v", "-s"] + serve_tests_files)) From 6bfa34f43ecdcf25db6ef524f9bd1a96123adf25 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 09:38:32 -0500 Subject: [PATCH 58/77] s/_use_ray_streaming/_stream Signed-off-by: Edward Oakes --- python/ray/serve/BUILD | 1 + python/ray/serve/_private/client.py | 8 ++++---- python/ray/serve/_private/http_proxy.py | 2 +- python/ray/serve/_private/router.py | 4 ++-- python/ray/serve/handle.py | 8 ++++---- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD index b0b5f12e22dd..d9566c848631 100644 --- a/python/ray/serve/BUILD +++ b/python/ray/serve/BUILD @@ -480,6 +480,7 @@ py_test( "**/conftest.py"]), tags = ["exclusive", "team:serve"], deps = [":serve_lib"], + env = {"RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING": "1"}, ) py_test( diff --git a/python/ray/serve/_private/client.py b/python/ray/serve/_private/client.py index 8c9dc3f10eac..ca4966ef7327 100644 --- a/python/ray/serve/_private/client.py +++ b/python/ray/serve/_private/client.py @@ -440,7 +440,7 @@ def get_handle( missing_ok: Optional[bool] = False, sync: bool = True, _internal_pickled_http_request: bool = False, - _use_ray_streaming: bool = False, + _stream: bool = False, ) -> Union[RayServeHandle, RayServeSyncHandle]: """Retrieve RayServeHandle for service deployment to invoke it from Python. @@ -453,7 +453,7 @@ def get_handle( that's only usable in asyncio loop. _internal_pickled_http_request: Indicates that this handle will be used to send HTTP requests from the proxy to ingress deployment replicas. - _use_ray_streaming: Indicates that this handle should use + _stream: Indicates that this handle should use `num_returns="streaming"`. Returns: @@ -474,14 +474,14 @@ def get_handle( self._controller, deployment_name, _internal_pickled_http_request=_internal_pickled_http_request, - _use_ray_streaming=_use_ray_streaming, + _stream=_stream, ) else: handle = RayServeHandle( self._controller, deployment_name, _internal_pickled_http_request=_internal_pickled_http_request, - _use_ray_streaming=_use_ray_streaming, + _stream=_stream, ) self.handle_cache[cache_key] = handle diff --git a/python/ray/serve/_private/http_proxy.py b/python/ray/serve/_private/http_proxy.py index 206d5d36c23d..e1c141d31d27 100644 --- a/python/ray/serve/_private/http_proxy.py +++ b/python/ray/serve/_private/http_proxy.py @@ -332,7 +332,7 @@ def get_handle(name): sync=False, missing_ok=True, _internal_pickled_http_request=True, - _use_ray_streaming=RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING, + _stream=RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING, ) self.prefix_router = LongestPrefixRouter(get_handle) diff --git a/python/ray/serve/_private/router.py b/python/ray/serve/_private/router.py index 4664db29f343..29ec4ef2fdbe 100644 --- a/python/ray/serve/_private/router.py +++ b/python/ray/serve/_private/router.py @@ -373,7 +373,7 @@ def __init__( controller_handle: ActorHandle, deployment_name: str, event_loop: asyncio.BaseEventLoop = None, - _use_ray_streaming: bool = False, + _stream: bool = False, ): """Router process incoming queries: assign a replica. @@ -381,7 +381,7 @@ def __init__( controller_handle: The controller handle. """ self._event_loop = event_loop - if _use_ray_streaming: + if _stream: self._replica_scheduler = RoundRobinStreamingReplicaScheduler() else: self._replica_scheduler = RoundRobinReplicaScheduler(event_loop) diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index b282bd3ba7e2..0b7f2e3623fa 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -119,14 +119,14 @@ def __init__( *, _router: Optional[Router] = None, _internal_pickled_http_request: bool = False, - _use_ray_streaming: bool = False, + _stream: bool = False, ): self.controller_handle = controller_handle self.deployment_name = deployment_name self.handle_options = handle_options or HandleOptions() self.handle_tag = f"{self.deployment_name}#{get_random_letters()}" self._pickled_http_request = _internal_pickled_http_request - self._use_ray_streaming = _use_ray_streaming + self._stream = _stream self.request_counter = metrics.Counter( "serve_handle_request_counter", @@ -147,7 +147,7 @@ def _make_router(self) -> Router: self.controller_handle, self.deployment_name, event_loop=get_or_create_event_loop(), - _use_ray_streaming=self._use_ray_streaming, + _stream=self._stream, ) @property @@ -315,7 +315,7 @@ def _make_router(self) -> Router: self.controller_handle, self.deployment_name, event_loop=_create_or_get_async_loop_in_thread(), - _use_ray_streaming=self._use_ray_streaming, + _stream=self._stream, ) def options( From 5c081723f9e8689d6be7f2c72d6376bf6afa9928 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 09:39:51 -0500 Subject: [PATCH 59/77] .format change Signed-off-by: Edward Oakes --- python/ray/serve/_private/http_proxy.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/ray/serve/_private/http_proxy.py b/python/ray/serve/_private/http_proxy.py index e1c141d31d27..28a111315b08 100644 --- a/python/ray/serve/_private/http_proxy.py +++ b/python/ray/serve/_private/http_proxy.py @@ -108,7 +108,7 @@ async def _handle_streaming_response( await send(asgi_message) except Exception as e: - error_message = "Unexpected error, traceback: {}.".format(e) + error_message = f"Unexpected error, traceback: {e}." logger.warning(error_message) if status_code == "": @@ -194,8 +194,8 @@ async def _send_request_to_handle(handle, scope, receive, send) -> str: # Here because the client disconnected, we will return a custom # error code for metric tracking. return DISCONNECT_ERROR_CODE - except RayTaskError as error: - error_message = "Task Error. Traceback: {}.".format(error) + except RayTaskError as e: + error_message = f"Unexpected error, traceback: {e}." await Response(error_message, status_code=500).send(scope, receive, send) return "500" except RayActorError: From 31c78202099ab6dca513568076396249f2371425 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 09:46:13 -0500 Subject: [PATCH 60/77] fix Signed-off-by: Edward Oakes --- python/ray/serve/_private/http_proxy.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/ray/serve/_private/http_proxy.py b/python/ray/serve/_private/http_proxy.py index 28a111315b08..3863b6aeea3b 100644 --- a/python/ray/serve/_private/http_proxy.py +++ b/python/ray/serve/_private/http_proxy.py @@ -103,7 +103,15 @@ async def _handle_streaming_response( async for obj_ref in asgi_response_generator: asgi_messages: List[Dict[str, Any]] = pickle.loads(await obj_ref) for asgi_message in asgi_messages: - if asgi_message["type"] == "http.response.start": + # There must be exactly one "http.response.start" message that + # always contains the "status" field. + if not status_code: + assert asgi_message["type"] == "http.response.start", ( + "First response message must be 'http.response.start'", + ) + assert "status" in asgi_message, ( + "'http.response.start' message must contain 'status'", + ) status_code = str(asgi_message["status"]) await send(asgi_message) From 3751f898f5cab90fb9659e05e14f6049896b840d Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 10:51:57 -0500 Subject: [PATCH 61/77] add docs Signed-off-by: Edward Oakes --- doc/BUILD | 29 ++++++++++- .../serve/doc_code/streaming_example.py | 38 ++++++++++++++ doc/source/serve/http-guide.md | 49 ++++++++++++++++++- 3 files changed, 113 insertions(+), 3 deletions(-) create mode 100644 doc/source/serve/doc_code/streaming_example.py diff --git a/doc/BUILD b/doc/BUILD index ff6dc8a956bc..b5a728b81355 100644 --- a/doc/BUILD +++ b/doc/BUILD @@ -131,10 +131,36 @@ py_test_run_all_subdirectory( exclude = [ "source/serve/doc_code/distilbert.py", "source/serve/doc_code/stable_diffusion.py", - "source/serve/doc_code/object_detection.py", + "source/serve/doc_code/object_detection.py", + ], + extra_srcs = [], + tags = ["exclusive", "team:serve"], +) + +py_test_run_all_subdirectory( + size = "medium", + include = [ + "source/serve/doc_code/distilbert.py", + "source/serve/doc_code/stable_diffusion.py", + "source/serve/doc_code/object_detection.py", + ], + exclude = [], + extra_srcs = [], + tags = ["exclusive", "team:serve", "gpu"], +) + +# Run all Serve doc code with streaming enabled as well. +py_test_run_all_subdirectory( + size = "medium", + include = ["source/serve/doc_code/**/*.py"], + exclude = [ + "source/serve/doc_code/distilbert.py", + "source/serve/doc_code/stable_diffusion.py", + "source/serve/doc_code/object_detection.py", ], extra_srcs = [], tags = ["exclusive", "team:serve"], + env = {"RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING": "1"}, ) py_test_run_all_subdirectory( @@ -147,6 +173,7 @@ py_test_run_all_subdirectory( exclude = [], extra_srcs = [], tags = ["exclusive", "team:serve", "gpu"], + env = {"RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING": "1"}, ) diff --git a/doc/source/serve/doc_code/streaming_example.py b/doc/source/serve/doc_code/streaming_example.py new file mode 100644 index 000000000000..e42866e3f9c2 --- /dev/null +++ b/doc/source/serve/doc_code/streaming_example.py @@ -0,0 +1,38 @@ +# flake8: noqa + +# __begin_example__ +import time +from typing import Generator + +import requests +from starlette.responses import StreamingResponse +from starlette.requests import Request + +from ray import serve + +@serve.deployment +class StreamingResponder: + def generate_numbers(self, max: int) -> Generator[str, None, None]: + for i in range(max): + yield str(i) + time.sleep(0.1) + + def __call__(self, request: Request) -> StreamingResponse: + max = request.query_params.get("max", "25") + gen = self.generate_numbers(int(max)) + return StreamingResponse(gen, status_code=200, media_type="text/plain") + +serve.run(StreamingResponder.bind()) + +r = requests.get("http://localhost:8000?max=10", stream=True) +start = time.time() +r.raise_for_status() +for chunk in r.iter_content(chunk_size=None, decode_unicode=True): + print(f"Got result {round(time.time()-start, 1)}s after start: '{chunk}'") +# __end_example__ + + +r = requests.get("http://localhost:8000?max=10", stream=True) +r.raise_for_status() +for i, chunk in enumerate(r.iter_content(chunk_size=None, decode_unicode=True)): + assert chunk == str(i) diff --git a/doc/source/serve/http-guide.md b/doc/source/serve/http-guide.md index 234e252c29e7..9d5044806fa4 100644 --- a/doc/source/serve/http-guide.md +++ b/doc/source/serve/http-guide.md @@ -3,7 +3,7 @@ This section helps you understand how to: - send HTTP requests to Serve deployments - use Ray Serve to integrate with FastAPI -- use customized HTTP Adapters +- use customized HTTP adapters - choose which feature to use for your use case ## Choosing the right HTTP feature @@ -74,6 +74,51 @@ Existing middlewares, **automatic OpenAPI documentation generation**, and other Serve currently does not support WebSockets. If you have a use case that requires it, please [let us know](https://github.com/ray-project/ray/issues/new/choose)! ``` +(serve-http-streaming-response)= +## Streaming Responses + +:::{warning} +Support for HTTP streaming responses is currently experimental. To enable this feature, set `RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1 on the cluster. If you encounter any issues, please [file an issue on GitHub](https://github.com/ray-project/ray/issues/new/choose). +::: + +Some applications require streaming incremental results back to the caller. +This is common for large language models (LLMs) used for text generation. +The full forward pass may take multiple seconds, so providing incremental results as they're available provides a much better user experience. + +To use HTTP response streaming, return a [StreamingResponse](https://www.starlette.io/responses/#streamingresponse) that wraps a generator from your HTTP handler. +This is supported for basic HTTP ingress deployments using a `__call__` method and when using the [FastAPI integration](serve-fastapi-http). + +The code below defines a Serve application that incrementally streams numbers up to a provided `max`. +The client-side code is also updated to handle the streaming outputs. +In this case, we are using the `stream=True` option to the [requests](https://requests.readthedocs.io/en/latest/user/advanced/#streaming-requests) library. + +```{literalinclude} ../serve/doc_code/streaming_example.py +:start-after: __begin_example__ +:end-before: __end_example__ +:language: python +``` + +Save this code in `stream.py` and run it: + +```bash +$ RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1 python stream.py +[2023-05-25 10:44:23] INFO ray._private.worker::Started a local Ray instance. View the dashboard at http://127.0.0.1:8265 +(ServeController pid=40401) INFO 2023-05-25 10:44:25,296 controller 40401 deployment_state.py:1259 - Deploying new version of deployment default_StreamingResponder. +(HTTPProxyActor pid=40403) INFO: Started server process [40403] +(ServeController pid=40401) INFO 2023-05-25 10:44:25,333 controller 40401 deployment_state.py:1498 - Adding 1 replica to deployment default_StreamingResponder. +Got result 0.0s after start: '0' +Got result 0.1s after start: '1' +Got result 0.2s after start: '2' +Got result 0.3s after start: '3' +Got result 0.4s after start: '4' +Got result 0.5s after start: '5' +Got result 0.6s after start: '6' +Got result 0.7s after start: '7' +Got result 0.8s after start: '8' +Got result 0.9s after start: '9' +(ServeReplica:default_StreamingResponder pid=41052) INFO 2023-05-25 10:49:52,230 default_StreamingResponder default_StreamingResponder#qlZFCa yomKnJifNJ / default replica.py:634 - __CALL__ OK 1017.6ms +``` + (serve-http-adapters)= ## HTTP Adapters @@ -190,7 +235,7 @@ PredictorDeployment.deploy(..., http_adapter=User) DAGDriver.bind(other_node, http_adapter=User) ``` -### List of Built-in Adapters +### List of built-in adapters Here is a list of adapters; please feel free to [contribute more](https://github.com/ray-project/ray/issues/new/choose)! From 8b2df0de31c081342a1568d56092120ede358145 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 10:54:00 -0500 Subject: [PATCH 62/77] fix fmt Signed-off-by: Edward Oakes --- doc/source/serve/http-guide.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/source/serve/http-guide.md b/doc/source/serve/http-guide.md index 9d5044806fa4..ef6ef71738a6 100644 --- a/doc/source/serve/http-guide.md +++ b/doc/source/serve/http-guide.md @@ -77,9 +77,9 @@ Serve currently does not support WebSockets. If you have a use case that require (serve-http-streaming-response)= ## Streaming Responses -:::{warning} +```{warning} Support for HTTP streaming responses is currently experimental. To enable this feature, set `RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1 on the cluster. If you encounter any issues, please [file an issue on GitHub](https://github.com/ray-project/ray/issues/new/choose). -::: +``` Some applications require streaming incremental results back to the caller. This is common for large language models (LLMs) used for text generation. From b047feb4a86197182dcf64cb1d616954316b2481 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 11:35:58 -0500 Subject: [PATCH 63/77] fix Signed-off-by: Edward Oakes --- doc/source/serve/doc_code/streaming_example.py | 2 ++ doc/source/serve/http-guide.md | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/source/serve/doc_code/streaming_example.py b/doc/source/serve/doc_code/streaming_example.py index e42866e3f9c2..9c3b0dd63472 100644 --- a/doc/source/serve/doc_code/streaming_example.py +++ b/doc/source/serve/doc_code/streaming_example.py @@ -10,6 +10,7 @@ from ray import serve + @serve.deployment class StreamingResponder: def generate_numbers(self, max: int) -> Generator[str, None, None]: @@ -22,6 +23,7 @@ def __call__(self, request: Request) -> StreamingResponse: gen = self.generate_numbers(int(max)) return StreamingResponse(gen, status_code=200, media_type="text/plain") + serve.run(StreamingResponder.bind()) r = requests.get("http://localhost:8000?max=10", stream=True) diff --git a/doc/source/serve/http-guide.md b/doc/source/serve/http-guide.md index ef6ef71738a6..9bcd91a18875 100644 --- a/doc/source/serve/http-guide.md +++ b/doc/source/serve/http-guide.md @@ -78,7 +78,7 @@ Serve currently does not support WebSockets. If you have a use case that require ## Streaming Responses ```{warning} -Support for HTTP streaming responses is currently experimental. To enable this feature, set `RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1 on the cluster. If you encounter any issues, please [file an issue on GitHub](https://github.com/ray-project/ray/issues/new/choose). +Support for HTTP streaming responses is currently experimental. To enable this feature, set `RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1 on the cluster before starting Ray. If you encounter any issues, please [file an issue on GitHub](https://github.com/ray-project/ray/issues/new/choose). ``` Some applications require streaming incremental results back to the caller. From 34f585e7c5b5a07a709e19738b3396796fd03340 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 11:37:07 -0500 Subject: [PATCH 64/77] fix Signed-off-by: Edward Oakes --- doc/source/serve/http-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/serve/http-guide.md b/doc/source/serve/http-guide.md index 9bcd91a18875..cc398ae15b6a 100644 --- a/doc/source/serve/http-guide.md +++ b/doc/source/serve/http-guide.md @@ -82,7 +82,7 @@ Support for HTTP streaming responses is currently experimental. To enable this f ``` Some applications require streaming incremental results back to the caller. -This is common for large language models (LLMs) used for text generation. +This is common for large language models (LLMs) used for text generation or video processing applications. The full forward pass may take multiple seconds, so providing incremental results as they're available provides a much better user experience. To use HTTP response streaming, return a [StreamingResponse](https://www.starlette.io/responses/#streamingresponse) that wraps a generator from your HTTP handler. From 12271c674caf8f326743a9fc2d529295129c71c3 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 12:08:03 -0500 Subject: [PATCH 65/77] revert Signed-off-by: Edward Oakes --- doc/BUILD | 23 +++-------------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/doc/BUILD b/doc/BUILD index b5a728b81355..867508dd3fe6 100644 --- a/doc/BUILD +++ b/doc/BUILD @@ -132,6 +132,7 @@ py_test_run_all_subdirectory( "source/serve/doc_code/distilbert.py", "source/serve/doc_code/stable_diffusion.py", "source/serve/doc_code/object_detection.py", + "source/serve/doc_code/streaming_example.py", ], extra_srcs = [], tags = ["exclusive", "team:serve"], @@ -139,25 +140,7 @@ py_test_run_all_subdirectory( py_test_run_all_subdirectory( size = "medium", - include = [ - "source/serve/doc_code/distilbert.py", - "source/serve/doc_code/stable_diffusion.py", - "source/serve/doc_code/object_detection.py", - ], - exclude = [], - extra_srcs = [], - tags = ["exclusive", "team:serve", "gpu"], -) - -# Run all Serve doc code with streaming enabled as well. -py_test_run_all_subdirectory( - size = "medium", - include = ["source/serve/doc_code/**/*.py"], - exclude = [ - "source/serve/doc_code/distilbert.py", - "source/serve/doc_code/stable_diffusion.py", - "source/serve/doc_code/object_detection.py", - ], + include = ["source/serve/doc_code/streaming_example.py"], extra_srcs = [], tags = ["exclusive", "team:serve"], env = {"RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING": "1"}, @@ -173,11 +156,11 @@ py_test_run_all_subdirectory( exclude = [], extra_srcs = [], tags = ["exclusive", "team:serve", "gpu"], - env = {"RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING": "1"}, ) + # -------------------------------------------------------------------- # Test all doc/source/tune/doc_code code included in rst/md files. # -------------------------------------------------------------------- From 4efcf7b66a92220ddb4419a2a6ad4d4530e827ea Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 13:25:47 -0500 Subject: [PATCH 66/77] fix BUILD Signed-off-by: Edward Oakes --- doc/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/BUILD b/doc/BUILD index 867508dd3fe6..9b34a804c86b 100644 --- a/doc/BUILD +++ b/doc/BUILD @@ -141,6 +141,7 @@ py_test_run_all_subdirectory( py_test_run_all_subdirectory( size = "medium", include = ["source/serve/doc_code/streaming_example.py"], + exclude = [], extra_srcs = [], tags = ["exclusive", "team:serve"], env = {"RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING": "1"}, From 97611c51f8cbe55f977ec5172797a24280e113c0 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 13:56:04 -0500 Subject: [PATCH 67/77] clearer Signed-off-by: Edward Oakes --- python/ray/serve/_private/replica.py | 83 ++++++++++++++++------------ python/ray/serve/api.py | 3 +- 2 files changed, 49 insertions(+), 37 deletions(-) diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index bf2a5eecdd5a..7059f85dde7b 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -60,6 +60,14 @@ logger = logging.getLogger(SERVE_LOGGER_NAME) +async def mock_asgi_receive(): + # This is called in a tight loop in responses just to check + # for an HTTP disconnect. So rather than returning immediately + # we should suspend execution to avoid wasting CPU cycles. + never_set_event = asyncio.Event() + await never_set_event.wait() + + def _format_replica_actor_name(deployment_name: str): return f"ServeReplica:{deployment_name}" @@ -475,42 +483,31 @@ def callable_method_filter(attr): return self.callable return getattr(self.callable, method_name) - async def handle_http_response( - self, result: Any, asgi_sender: Optional[Send] = None - ) -> Any: - """Handle a starlette Response returned by user code. - - If the result is not a Response type, this is a no-op. If it is, the - behavior depends on if asgi_sender was passed or not: - - if asgi_sender is provided, we consume the response using the sender. - It's expected that the caller is consuming the messages from this sender. - - if asgi_sender is not provided, we check to see if the response is a - StreamingResponse and convert it to a vanilla unary response. - """ + async def send_user_result_over_asgi(self, result: Any, asgi_sender: Send): + """Handle the result from user code and send it over the ASGI interface. - async def mock_receive(): - # This is called in a tight loop in response() just to check - # for an http disconnect. So rather than return immediately - # we should suspend execution to avoid wasting CPU cycles. - never_set_event = asyncio.Event() - await never_set_event.wait() - - if asgi_sender is not None: - if isinstance(result, (starlette.responses.Response, RawASGIResponse)): - # Case where the user returns a Response directly. - await result(scope=None, receive=mock_receive, send=asgi_sender) - elif result is not None: - # Case where the user returns a plain object (not a Response). - response = Response(result) - await response.send(scope=None, receive=mock_receive, send=asgi_sender) + If the result is already a Response type, it will be sent directly. Else it + will be converted to our custom Response type that handles serialization for + common Python objects. + """ + if not isinstance(result, (starlette.responses.Response, RawASGIResponse)): + await Response(result).send( + scope=None, receive=mock_asgi_receive, send=asgi_sender + ) + else: + await result(scope=None, receive=mock_asgi_receive, send=asgi_sender) - result = None - elif isinstance(result, starlette.responses.StreamingResponse): - sender = ASGIHTTPSender() - await result(scope=None, receive=mock_receive, send=sender) - result = sender.build_asgi_response() + async def convert_streaming_response( + self, response: starlette.responses.StreamingResponse + ) -> RawASGIResponse: + """Convert a StreamingResponse to a custom buffered unary response. - return result + This is used on the legacy non-streaming codepath because we cannot serialize + and return a StreamingResponse. + """ + sender = ASGIHTTPSender() + await response(scope=None, receive=mock_asgi_receive, send=sender) + return sender.build_asgi_response() async def invoke_single( self, request_item: Query, *, asgi_sender: Optional[Send] = None @@ -529,8 +526,9 @@ async def invoke_single( ) args, kwargs = parse_request_item(request_item) - if asgi_sender is not None and hasattr(self.callable, "_serve_app"): - # If the callable is our FastAPI wrapper, pass it the asgi_sender. + + callable_is_asgi_wrapper = hasattr(self.callable, "_is_serve_asgi_wrapper") + if asgi_sender is not None and callable_is_asgi_wrapper: kwargs["asgi_sender"] = asgi_sender method_to_call = None @@ -555,7 +553,20 @@ async def invoke_single( # call with non-empty args result = await method_to_call(*args, **kwargs) - result = await self.handle_http_response(result, asgi_sender=asgi_sender) + # Streaming HTTP codepath: always send response over ASGI interface. + if asgi_sender is not None: + # For the FastAPI codepath, the response has already been sent over the + # ASGI interace and result should always be `None`. + if callable_is_asgi_wrapper: + assert result is None + # For the vanilla deployment codepath, always send the result over ASGI. + else: + result = await self.send_user_result_over_asgi(result, asgi_sender) + + # Legacy codepath: always return the result, so ensure it can be serialized. + elif isinstance(result, starlette.responses.StreamingResponse): + result = await self.convert_streaming_response(result) + self.request_counter.inc(tags={"route": request_item.metadata.route}) except Exception as e: logger.exception(f"Request failed due to {type(e).__name__}:") diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 3016593c910c..8ef8e9440f69 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -222,7 +222,8 @@ async def __init__(self, *args, **kwargs): install_serve_encoders_to_fastapi() self._serve_app = frozen_app - + # Used in `replica.py` to detect the usage of this class. + self._is_serve_asgi_wrapper = True # Use uvicorn's lifespan handling code to properly deal with # startup and shutdown event. self._serve_asgi_lifespan = LifespanOn( From 5fdc3bc1834dfe663d1b9340ba1d8c2f7aaae640 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 13:58:07 -0500 Subject: [PATCH 68/77] fix Signed-off-by: Edward Oakes --- python/ray/serve/_private/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/serve/_private/constants.py b/python/ray/serve/_private/constants.py index 23401337dcd1..dd00525422a4 100644 --- a/python/ray/serve/_private/constants.py +++ b/python/ray/serve/_private/constants.py @@ -195,5 +195,5 @@ class ServeHandleType(str, Enum): # Feature flag to enable StreamingResponse support. # When turned on, *all* HTTP responses will use Ray streaming object refs. RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING = ( - os.environ.get("RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING", "0") != "0" + os.environ.get("RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING", "0") == "1" ) From 8617ad55465ea8b7e9b6d8adb66f17168820ab8d Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 14:28:22 -0500 Subject: [PATCH 69/77] fix? Signed-off-by: Edward Oakes --- python/ray/serve/_private/http_proxy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/serve/_private/http_proxy.py b/python/ray/serve/_private/http_proxy.py index 3863b6aeea3b..3f289e0ef123 100644 --- a/python/ray/serve/_private/http_proxy.py +++ b/python/ray/serve/_private/http_proxy.py @@ -75,7 +75,7 @@ async def _handle_streaming_response( - asgi_response_generator: ray._raylet.StreamingObjectRefGenerator, + asgi_response_generator: "ray._raylet.StreamingObjectRefGenerator", scope: Scope, receive: Receive, send: Send, From 1cc577267df824f503e78569b93677a6a8f495eb Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 15:06:58 -0500 Subject: [PATCH 70/77] comment Signed-off-by: Edward Oakes --- python/ray/serve/_private/replica.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index 7059f85dde7b..b361cd593e57 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -527,6 +527,8 @@ async def invoke_single( args, kwargs = parse_request_item(request_item) + # Check if the callable is our ASGI wrapper (i.e., the user used + # `@serve.ingress`). callable_is_asgi_wrapper = hasattr(self.callable, "_is_serve_asgi_wrapper") if asgi_sender is not None and callable_is_asgi_wrapper: kwargs["asgi_sender"] = asgi_sender From d039bb4458675ae9a96a780cec64ada5b4911daf Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 15:07:38 -0500 Subject: [PATCH 71/77] fix Signed-off-by: Edward Oakes --- python/ray/serve/_private/replica.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index b361cd593e57..3e4c2fb37b88 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -497,7 +497,7 @@ async def send_user_result_over_asgi(self, result: Any, asgi_sender: Send): else: await result(scope=None, receive=mock_asgi_receive, send=asgi_sender) - async def convert_streaming_response( + async def convert_streaming_response_to_unary( self, response: starlette.responses.StreamingResponse ) -> RawASGIResponse: """Convert a StreamingResponse to a custom buffered unary response. @@ -567,7 +567,7 @@ async def invoke_single( # Legacy codepath: always return the result, so ensure it can be serialized. elif isinstance(result, starlette.responses.StreamingResponse): - result = await self.convert_streaming_response(result) + result = await self.convert_streaming_response_to_unary(result) self.request_counter.inc(tags={"route": request_item.metadata.route}) except Exception as e: From 8dc379cf09761869484a0a0026005e137184bfb6 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 17:09:34 -0500 Subject: [PATCH 72/77] Update doc/source/serve/http-guide.md Co-authored-by: angelinalg <122562471+angelinalg@users.noreply.github.com> Signed-off-by: Edward Oakes --- doc/source/serve/http-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/serve/http-guide.md b/doc/source/serve/http-guide.md index cc398ae15b6a..5b683c14310b 100644 --- a/doc/source/serve/http-guide.md +++ b/doc/source/serve/http-guide.md @@ -78,7 +78,7 @@ Serve currently does not support WebSockets. If you have a use case that require ## Streaming Responses ```{warning} -Support for HTTP streaming responses is currently experimental. To enable this feature, set `RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1 on the cluster before starting Ray. If you encounter any issues, please [file an issue on GitHub](https://github.com/ray-project/ray/issues/new/choose). +Support for HTTP streaming responses is experimental. To enable this feature, set `RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1` on the cluster before starting Ray. If you encounter any issues, [file an issue on GitHub](https://github.com/ray-project/ray/issues/new/choose). ``` Some applications require streaming incremental results back to the caller. From d5ccabbde84031fa4c12fb8099c2f4b44a8c8d58 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 17:09:46 -0500 Subject: [PATCH 73/77] Update doc/source/serve/http-guide.md Co-authored-by: angelinalg <122562471+angelinalg@users.noreply.github.com> Signed-off-by: Edward Oakes --- doc/source/serve/http-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/serve/http-guide.md b/doc/source/serve/http-guide.md index 5b683c14310b..8bbb3f129557 100644 --- a/doc/source/serve/http-guide.md +++ b/doc/source/serve/http-guide.md @@ -81,7 +81,7 @@ Serve currently does not support WebSockets. If you have a use case that require Support for HTTP streaming responses is experimental. To enable this feature, set `RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1` on the cluster before starting Ray. If you encounter any issues, [file an issue on GitHub](https://github.com/ray-project/ray/issues/new/choose). ``` -Some applications require streaming incremental results back to the caller. +Some applications must stream incremental results back to the caller. This is common for large language models (LLMs) used for text generation or video processing applications. The full forward pass may take multiple seconds, so providing incremental results as they're available provides a much better user experience. From 2d21721e29bb88b5377759188957abf0f6a4652a Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 17:17:23 -0500 Subject: [PATCH 74/77] angelina's suggestions Signed-off-by: Edward Oakes --- doc/source/serve/http-guide.md | 4 ++-- python/ray/serve/_private/http_proxy.py | 17 +++++++++-------- python/ray/serve/_private/http_util.py | 7 +++---- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/doc/source/serve/http-guide.md b/doc/source/serve/http-guide.md index 8bbb3f129557..06a16f544294 100644 --- a/doc/source/serve/http-guide.md +++ b/doc/source/serve/http-guide.md @@ -82,7 +82,7 @@ Support for HTTP streaming responses is experimental. To enable this feature, se ``` Some applications must stream incremental results back to the caller. -This is common for large language models (LLMs) used for text generation or video processing applications. +This is common for text generation using large language models (LLMs) or video processing applications. The full forward pass may take multiple seconds, so providing incremental results as they're available provides a much better user experience. To use HTTP response streaming, return a [StreamingResponse](https://www.starlette.io/responses/#streamingresponse) that wraps a generator from your HTTP handler. @@ -90,7 +90,7 @@ This is supported for basic HTTP ingress deployments using a `__call__` method a The code below defines a Serve application that incrementally streams numbers up to a provided `max`. The client-side code is also updated to handle the streaming outputs. -In this case, we are using the `stream=True` option to the [requests](https://requests.readthedocs.io/en/latest/user/advanced/#streaming-requests) library. +This code uses the `stream=True` option to the [requests](https://requests.readthedocs.io/en/latest/user/advanced/#streaming-requests) library. ```{literalinclude} ../serve/doc_code/streaming_example.py :start-after: __begin_example__ diff --git a/python/ray/serve/_private/http_proxy.py b/python/ray/serve/_private/http_proxy.py index 3f289e0ef123..45032752cb62 100644 --- a/python/ray/serve/_private/http_proxy.py +++ b/python/ray/serve/_private/http_proxy.py @@ -82,17 +82,18 @@ async def _handle_streaming_response( ) -> str: """Consumes the `asgi_response_generator` and sends its data over `send`. - This function is essentially a proxy for a downstream ASGI response. The - passed generator is expected to return a stream of pickled ASGI messages - (dictionaries) that will be passed to the provided ASGI send interface. + This function is a proxy for a downstream ASGI response. The passed + generator is expected to return a stream of pickled ASGI messages + (dictionaries) that are sent using the provided ASGI interface. - Exception handling depends on if the first message has already been sent: - - if an exception happens *before* the first message, a 500 status will be sent. - - if an exception happens *after* the first message, the response stream will be + Exception handling depends on whether the first message has already been sent: + - if an exception happens *before* the first message, a 500 status is sent. + - if an exception happens *after* the first message, the response stream is terminated. - This is because once the first message has been sent, the client has already - received the status code. + The difference in behavior is because once the first message has been sent, the + client has already received the status code so we cannot send a `500` (internal + server error). Returns: status_code diff --git a/python/ray/serve/_private/http_util.py b/python/ray/serve/_private/http_util.py index e34ed07e70d1..acf568a36bdd 100644 --- a/python/ray/serve/_private/http_util.py +++ b/python/ray/serve/_private/http_util.py @@ -161,9 +161,8 @@ def build_asgi_response(self) -> RawASGIResponse: class ASGIHTTPQueueSender(Send): """ASGI sender that enables polling for the sent messages off a queue. - This class assumes there's only a single consumer of the queue (concurrent - calls to `get_messages_nowait` and `wait_for_message` may result in undefined - behavior). + This class assumes a single consumer of the queue (concurrent calls to + `get_messages_nowait` and `wait_for_message` may result in undefined behavior). """ def __init__(self): @@ -179,7 +178,7 @@ def get_messages_nowait(self) -> List[Dict[str, Any]]: """Returns all messages that are currently available (non-blocking). At least one message will be present if `wait_for_message` had previously - returned and a subsequent call to `wait_for_message` will block until at + returned and a subsequent call to `wait_for_message` blocks until at least one new message is available. """ messages = [] From e35fb1334476414b9ce547ddcb0d7ac8c1a9ed25 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 17:19:25 -0500 Subject: [PATCH 75/77] angelina 2 Signed-off-by: Edward Oakes --- python/ray/serve/_private/http_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/serve/_private/http_util.py b/python/ray/serve/_private/http_util.py index acf568a36bdd..7c5410995649 100644 --- a/python/ray/serve/_private/http_util.py +++ b/python/ray/serve/_private/http_util.py @@ -191,8 +191,8 @@ def get_messages_nowait(self) -> List[Dict[str, Any]]: async def wait_for_message(self): """Wait until at least one new message is available. - This will continuously return immediately once a message is available until - `get_messages_nowait` is called. + If a message is available, this method will return immediately on each call + until `get_messages_nowait` is called. """ await self._new_message_event.wait() From 8b5f9be181b1fe36caa891f26c18962456c14544 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 17:20:17 -0500 Subject: [PATCH 76/77] angelina3 Signed-off-by: Edward Oakes --- python/ray/serve/_private/replica.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index 3e4c2fb37b88..0aa376a7a407 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -226,13 +226,11 @@ async def handle_request_streaming( ) -> AsyncGenerator[Dict[str, Any], None]: """Handle a request and stream the results to the caller. - This is currently only used by the HTTP proxy for experimental - StreamingResponse support. + This is used by the HTTP proxy for experimental StreamingResponse support. - The messages yielded by this generator will be ASGI-compliant messages - sent via an ASGI sender interface. This allows us to effectively proxy the - messages back to the HTTP proxy as they're sent by user code - (e.g., FastAPI wrapper). + This generator yields ASGI-compliant messages sent via an ASGI sender + interface. This allows us to return the messages back to the HTTP proxy as + they're sent by user code (e.g., the FastAPI wrapper). """ query = Query( request_args, @@ -486,8 +484,8 @@ def callable_method_filter(attr): async def send_user_result_over_asgi(self, result: Any, asgi_sender: Send): """Handle the result from user code and send it over the ASGI interface. - If the result is already a Response type, it will be sent directly. Else it - will be converted to our custom Response type that handles serialization for + If the result is already a Response type, it is sent directly. Otherwise, it + is converted to a custom Response type that handles serialization for common Python objects. """ if not isinstance(result, (starlette.responses.Response, RawASGIResponse)): @@ -517,8 +515,8 @@ async def invoke_single( Returns the user-provided output and a boolean indicating if the request succeeded (user code didn't raise an exception). - If asgi_sender is provided, then the result will always be `None` - because the response will be sent over that interface instead. + If asgi_sender is provided, then the result is always `None` + because the response is sent over that interface instead. """ logger.info( f"Started executing request {request_item.metadata.request_id}", From dca76c6fcc2586a36929eaa1ed07bd028c684680 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 May 2023 17:23:41 -0500 Subject: [PATCH 77/77] angelina4 Signed-off-by: Edward Oakes --- python/ray/serve/_private/router.py | 4 ++-- python/ray/serve/api.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/ray/serve/_private/router.py b/python/ray/serve/_private/router.py index 29ec4ef2fdbe..6bbed7c75887 100644 --- a/python/ray/serve/_private/router.py +++ b/python/ray/serve/_private/router.py @@ -88,7 +88,7 @@ def update_running_replicas(self, running_replicas: List[RunningReplicaInfo]): class RoundRobinStreamingReplicaScheduler(ReplicaScheduler): """Round-robins requests across a set of actor replicas using streaming calls. - This policy does *not* currently respect `max_concurrent_queries`. + This policy does *not* respect `max_concurrent_queries`. """ def __init__(self): @@ -128,7 +128,7 @@ class RoundRobinReplicaScheduler(ReplicaScheduler): """Round-robins requests across a set of actor replicas. The policy respects `max_concurrent_queries` for the replicas: a replica - will not be chosen if `max_concurrent_queries` requests are already outstanding. + is not chosen if `max_concurrent_queries` requests are already outstanding. This is maintained using a "tracker" object ref to determine when a given request has finished (to decrement the number of concurrent queries). diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 8ef8e9440f69..6e92457da2c8 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -246,8 +246,7 @@ async def __call__( If asgi_sender is provided, it's passed into the app and nothing is returned. - If no asgi_sender is provided, an ASGI response will be built and - returned. + If no asgi_sender is provided, an ASGI response is built and returned. """ build_and_return_response = False if asgi_sender is None: