From cb80e39349bf4979329c8e13a1da792f8b26c03f Mon Sep 17 00:00:00 2001 From: Yi Cheng <74173148+iycheng@users.noreply.github.com> Date: Mon, 14 Feb 2022 23:31:32 -0800 Subject: [PATCH] [workflow] Fix event loop can't find in thread (#22363) By default, an asyncio event loop is only set up in the main thread, so the workflow machinery fails when invoked from any other thread — which can happen when it is called from a library (for example, Ray Serve). This PR fixes that by falling back to creating and registering a new event loop when none exists in the current thread. --- python/ray/workflow/api.py | 8 +++---- python/ray/workflow/common.py | 11 ++++++++++ python/ray/workflow/serialization.py | 4 ++-- python/ray/workflow/step_executor.py | 3 ++- .../workflow/tests/test_basic_workflows_2.py | 22 +++++++++++++++++++ python/ray/workflow/workflow_storage.py | 7 +----- 6 files changed, 41 insertions(+), 14 deletions(-) diff --git a/python/ray/workflow/api.py b/python/ray/workflow/api.py index a5a9dc107321..aba0d1896906 100644 --- a/python/ray/workflow/api.py +++ b/python/ray/workflow/api.py @@ -1,4 +1,3 @@ -import asyncio import logging import os import types @@ -22,6 +21,7 @@ WorkflowNotFoundError, WorkflowStepRuntimeOptions, StepType, + asyncio_run, ) from ray.workflow import serialization from ray.workflow.event_listener import EventListener, EventListenerType, TimerListener @@ -354,16 +354,14 @@ def wait_for_event( @step def get_message(event_listener_type: EventListenerType, *args, **kwargs) -> Event: event_listener = event_listener_type() - loop = asyncio.get_event_loop() - return loop.run_until_complete(event_listener.poll_for_event(*args, **kwargs)) + return asyncio_run(event_listener.poll_for_event(*args, **kwargs)) @step def message_committed( event_listener_type: EventListenerType, event: Event ) -> Event: event_listener = event_listener_type() - loop = asyncio.get_event_loop() - loop.run_until_complete(event_listener.event_checkpointed(event)) + asyncio_run(event_listener.event_checkpointed(event)) return event return message_committed.step( diff --git 
a/python/ray/workflow/common.py b/python/ray/workflow/common.py index 27a30e8e96ea..3eda03e628ba 100644 --- a/python/ray/workflow/common.py +++ b/python/ray/workflow/common.py @@ -1,4 +1,6 @@ import base64 +import asyncio + from ray import cloudpickle from collections import deque from enum import Enum, unique @@ -23,6 +25,15 @@ STORAGE_ACTOR_NAME = "StorageManagementActor" +def asyncio_run(coro): + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop.run_until_complete(coro) + + def get_module(f): return f.__module__ if hasattr(f, "__module__") else "__anonymous_module__" diff --git a/python/ray/workflow/serialization.py b/python/ray/workflow/serialization.py index 549fb71c61ad..81487aaca702 100644 --- a/python/ray/workflow/serialization.py +++ b/python/ray/workflow/serialization.py @@ -129,7 +129,7 @@ def _put_helper( ) paths = obj_id_to_paths(workflow_id, identifier) promise = dump_to_storage(paths, obj, workflow_id, storage, update_existing=False) - return asyncio.get_event_loop().run_until_complete(promise) + return common.asyncio_run(promise) def _reduce_objectref( @@ -207,7 +207,7 @@ class ObjectRefPickler(cloudpickle.CloudPickler): @ray.remote def _load_ref_helper(key: str, storage: storage.Storage): # TODO(Alex): We should stream the data directly into `cloudpickle.load`. - serialized = asyncio.get_event_loop().run_until_complete(storage.get(key)) + serialized = common.asyncio_run(storage.get(key)) return cloudpickle.loads(serialized) diff --git a/python/ray/workflow/step_executor.py b/python/ray/workflow/step_executor.py index e6dd64d0594a..fb2258033e55 100644 --- a/python/ray/workflow/step_executor.py +++ b/python/ray/workflow/step_executor.py @@ -26,6 +26,7 @@ StepID, WorkflowData, WorkflowStaticRef, + asyncio_run, ) if TYPE_CHECKING: @@ -247,7 +248,7 @@ def commit_step( # its input (again). 
if w.ref is None: tasks.append(_write_step_inputs(store, w.step_id, w.data)) - asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks)) + asyncio_run(asyncio.gather(*tasks)) context = workflow_context.get_workflow_step_context() store.save_step_output( diff --git a/python/ray/workflow/tests/test_basic_workflows_2.py b/python/ray/workflow/tests/test_basic_workflows_2.py index b0e0006bf52a..3bd71edb803f 100644 --- a/python/ray/workflow/tests/test_basic_workflows_2.py +++ b/python/ray/workflow/tests/test_basic_workflows_2.py @@ -333,6 +333,28 @@ def join(*a): assert "4" == join.step(a, a, a, a).run() +def test_run_off_main_thread(workflow_start_regular): + @workflow.step + def fake_data(num: int): + return list(range(num)) + + succ = False + + # Start new thread here ⚠️ + def run(): + global succ + # Setup the workflow. + data = fake_data.step(10) + assert data.run(workflow_id="run") == list(range(10)) + + import threading + + t = threading.Thread(target=run) + t.start() + t.join() + assert workflow.get_status("run") == workflow.SUCCESSFUL + + if __name__ == "__main__": import sys diff --git a/python/ray/workflow/workflow_storage.py b/python/ray/workflow/workflow_storage.py index 364591ecfa44..371dbe370270 100644 --- a/python/ray/workflow/workflow_storage.py +++ b/python/ray/workflow/workflow_storage.py @@ -20,6 +20,7 @@ WorkflowRef, WorkflowNotFoundError, WorkflowStepRuntimeOptions, + asyncio_run, ) from ray.workflow import workflow_context from ray.workflow import serialization @@ -55,12 +56,6 @@ DUPLICATE_NAME_COUNTER = "duplicate_name_counter" -# TODO: Get rid of this and use asyncio.run instead once we don't support py36 -def asyncio_run(coro): - loop = asyncio.get_event_loop() - return loop.run_until_complete(coro) - - @dataclass class StepInspectResult: # The step output checkpoint exists and valid. If this field