[Core][Streaming Generator] Fix the perf regression from a serve handle bug fix ray-project#38171 (ray-project#38280)

Before ray-project#37972, we ran the reporting & serialization of outputs (in C++) on the main thread while all the async actor tasks ran in an async thread. However, after that PR, we run both of them in the async thread.

This caused a regression for generator workloads with decently large outputs (200~2KB, e.g. Aviary), because the serialization code runs with nogil: previously we got real multi-threading, since the serialization code ran on the main thread while the async actor code ran in the async thread, whereas now both share one thread.
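For illustration only (not part of this commit): a minimal sketch of the nogil effect, using zlib.compress as a stand-in for the C++ serialization path, since it also releases the GIL on large buffers. Two threads overlap almost fully when the hot code holds no GIL:

import threading
import time
import zlib

PAYLOAD = b"x" * 16_000_000

def serialize():
    # Stand-in for the nogil C++ serialization code: zlib releases the
    # GIL while compressing, so two calls can truly run in parallel.
    zlib.compress(PAYLOAD)

start = time.perf_counter()
serialize()
serialize()
sequential = time.perf_counter() - start

threads = [threading.Thread(target=serialize) for _ in range(2)]
start = time.perf_counter()
for t in threads:
    t.start()
for t in threads:
    t.join()
overlapped = time.perf_counter() - start

# On a multi-core machine the threaded run should take roughly half the time.
print(f"sequential: {sequential:.2f}s, two threads: {overlapped:.2f}s")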

This PR fixes the issue by dispatching the C++ code (reporting & serialization) to a separate thread again. While using a ThreadPoolExecutor, I also found a circular-reference issue that leaked objects whenever an exception happened: a Python exception captures the local references of its frame, which created reference cycles. I refactored part of the code to avoid this and added a unit test for it.
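For illustration only (not part of this commit): a minimal sketch of the leak pattern the new unit test checks for. Re-raising an exception that is also held in a frame local creates a cycle (exception -> traceback -> frame -> locals -> exception) that only the cyclic GC can reclaim, which gc.DEBUG_SAVEALL makes visible:

import gc

def report(output_or_exception):
    try:
        if isinstance(output_or_exception, Exception):
            raise output_or_exception
    except Exception as e:
        # e.__traceback__ -> report's frame -> locals['output_or_exception'] -> e
        return repr(e)

gc.set_debug(gc.DEBUG_SAVEALL)
gc.collect()
before = len(gc.garbage)

report(ValueError("boom"))

gc.collect()
# The exception, its traceback, and the frame could be freed only by the
# cyclic collector; if anything pins the cycle, these objects leak.
print("cyclic garbage created:", len(gc.garbage) - before)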

Signed-off-by: e428265 <[email protected]>
rkooo567 authored and arvind-chandra committed Aug 31, 2023
1 parent 3fc8738 commit 93e3e96
Showing 4 changed files with 143 additions and 19 deletions.
1 change: 1 addition & 0 deletions python/ray/_private/worker.py
@@ -2243,6 +2243,7 @@ def connect(
logs_dir = ""
else:
logs_dir = node.get_logs_dir_path()

worker.core_worker = ray._raylet.CoreWorker(
mode,
node.plasma_store_socket_name,
1 change: 1 addition & 0 deletions python/ray/_raylet.pxd
@@ -124,6 +124,7 @@ cdef class CoreWorker:
object eventloop_for_default_cg
object thread_for_default_cg
object fd_to_cgname_dict
object thread_pool_for_async_event_loop

cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata,
size_t data_size, ObjectRef object_ref,
53 changes: 35 additions & 18 deletions python/ray/_raylet.pyx
@@ -32,6 +32,7 @@ from typing import (
Generator,
AsyncGenerator
)
from concurrent.futures import ThreadPoolExecutor

from libc.stdint cimport (
int32_t,
@@ -1013,7 +1014,7 @@ cdef class StreamingGeneratorExecutionContext:
return self


cpdef report_streaming_generator_output(
cdef report_streaming_generator_output(
output_or_exception: Union[object, Exception],
StreamingGeneratorExecutionContext context):
"""Report a given generator output to a caller.
@@ -1036,23 +1037,12 @@
worker = ray._private.worker.global_worker

cdef:
CoreWorker core_worker = worker.core_worker
# Ray Object created from an output.
c_pair[CObjectID, shared_ptr[CRayObject]] return_obj

try:
if isinstance(output_or_exception, Exception):
raise output_or_exception
except AsyncioActorExit:
# Make the task handle this exception.
raise
except StopAsyncIteration:
return True
except StopIteration:
return True
except Exception as e:
if isinstance(output_or_exception, Exception):
create_generator_error_object(
e,
output_or_exception,
worker,
context.task_type,
context.caller_address,
@@ -1143,8 +1133,11 @@ cdef execute_streaming_generator_sync(StreamingGeneratorExecutionContext context
while True:
try:
output_or_exception = next(gen)
except StopIteration:
break
except Exception as e:
output_or_exception = e

done = report_streaming_generator_output(output_or_exception, context)
if done:
break
@@ -1180,12 +1173,24 @@ async def execute_streaming_generator_async(
while True:
try:
output_or_exception = await gen.__anext__()
except StopAsyncIteration:
break
except AsyncioActorExit:
# The execute_task will handle this case.
raise
except Exception as e:
output_or_exception = e
# TODO(sang): This method involves in serializing the output.
# Ideally, we don't want to run this inside an event loop.
done = report_streaming_generator_output(
output_or_exception, context)

loop = asyncio.get_running_loop()
worker = ray._private.worker.global_worker
# Run it in a separate thread so that we can
# avoid blocking the event loop when serializing
# the output (which has nogil).
done = await loop.run_in_executor(
worker.core_worker.get_thread_pool_for_async_event_loop(),
report_streaming_generator_output,
output_or_exception,
context)
if done:
break

@@ -1760,6 +1765,7 @@ cdef void execute_task(
worker, outputs,
caller_address,
returns)

except Exception as e:
num_errors_stored = store_task_errors(
worker, e, task_exception, actor, actor_id, function_name,
@@ -2926,6 +2932,7 @@ cdef class CoreWorker:
self.fd_to_cgname_dict = None
self.eventloop_for_default_cg = None
self.current_runtime_env = None
self.thread_pool_for_async_event_loop = None

def shutdown(self):
# If it's a worker, the core worker process should have been
@@ -4072,6 +4079,13 @@ cdef class CoreWorker:
for fd in function_descriptors:
self.fd_to_cgname_dict[fd] = cg_name

def get_thread_pool_for_async_event_loop(self):
if self.thread_pool_for_async_event_loop is None:
# Theoretically, we can use multiple threads,
self.thread_pool_for_async_event_loop = ThreadPoolExecutor(
max_workers=1)
return self.thread_pool_for_async_event_loop

def get_event_loop(self, function_descriptor, specified_cgname):
# __init__ will be invoked in default eventloop
if function_descriptor.function_name == "__init__":
@@ -4143,6 +4157,9 @@ cdef class CoreWorker:
def stop_and_join_asyncio_threads_if_exist(self):
event_loops = []
threads = []
if self.thread_pool_for_async_event_loop:
self.thread_pool_for_async_event_loop.shutdown(
wait=False, cancel_futures=True)
if self.eventloop_for_default_cg is not None:
event_loops.append(self.eventloop_for_default_cg)
if self.thread_for_default_cg is not None:
107 changes: 106 additions & 1 deletion python/ray/tests/test_streaming_generator.py
@@ -13,7 +13,7 @@

import ray
from ray._private.test_utils import wait_for_condition
from ray.experimental.state.api import list_objects
from ray.experimental.state.api import list_objects, list_actors
from ray._raylet import StreamingObjectRefGenerator, ObjectRefStreamEndOfStreamError
from ray.cloudpickle import dumps
from ray.exceptions import WorkerCrashedError
@@ -1168,6 +1168,111 @@ def f():
assert_no_leak()


def test_python_object_leak(shutdown_only):
"""Make sure the objects are not leaked
(due to circular references) when tasks run
for all the execution models in Ray actors.
"""
ray.init()

@ray.remote
class AsyncActor:
def __init__(self):
self.gc_garbage_len = 0

def get_gc_garbage_len(self):
return self.gc_garbage_len

async def gen(self, fail=False):
gc.set_debug(gc.DEBUG_SAVEALL)
gc.collect()
self.gc_garbage_len = len(gc.garbage)
print("Objects: ", self.gc_garbage_len)
if fail:
print("exception")
raise Exception
yield 1

async def f(self, fail=False):
gc.set_debug(gc.DEBUG_SAVEALL)
gc.collect()
self.gc_garbage_len = len(gc.garbage)
print("Objects: ", self.gc_garbage_len)
if fail:
print("exception")
raise Exception
return 1

@ray.remote
class A:
def __init__(self):
self.gc_garbage_len = 0

def get_gc_garbage_len(self):
return self.gc_garbage_len

def f(self, fail=False):
gc.set_debug(gc.DEBUG_SAVEALL)
gc.collect()
self.gc_garbage_len = len(gc.garbage)
print("Objects: ", self.gc_garbage_len)
if fail:
print("exception")
raise Exception
return 1

def gen(self, fail=False):
gc.set_debug(gc.DEBUG_SAVEALL)
gc.collect()
self.gc_garbage_len = len(gc.garbage)
print("Objects: ", self.gc_garbage_len)
if fail:
print("exception")
raise Exception
yield 1

def verify_regular(actor, fail):
for _ in range(100):
try:
ray.get(actor.f.remote(fail=fail))
except Exception:
pass
assert ray.get(actor.get_gc_garbage_len.remote()) == 0

def verify_generator(actor, fail):
for _ in range(100):
for ref in actor.gen.options(num_returns="streaming").remote(fail=fail):
try:
ray.get(ref)
except Exception:
pass
assert ray.get(actor.get_gc_garbage_len.remote()) == 0

print("Test regular actors")
verify_regular(A.remote(), True)
verify_regular(A.remote(), False)
print("Test regular actors + generator")
verify_generator(A.remote(), True)
verify_generator(A.remote(), False)

# Test threaded actor
print("Test threaded actors")
verify_regular(A.options(max_concurrency=10).remote(), True)
verify_regular(A.options(max_concurrency=10).remote(), False)
print("Test threaded actors + generator")
verify_generator(A.options(max_concurrency=10).remote(), True)
verify_generator(A.options(max_concurrency=10).remote(), False)

# Test async actor
print("Test async actors")
verify_regular(AsyncActor.remote(), True)
verify_regular(AsyncActor.remote(), False)
print("Test async actors + generator")
verify_generator(AsyncActor.remote(), True)
verify_generator(AsyncActor.remote(), False)
assert len(list_actors()) == 12


if __name__ == "__main__":
import os

