Skip to content

Commit

Permalink
[workflow] Fix event loop not being found in non-main threads (ray-project#22363)
Browse files Browse the repository at this point in the history
By default, an event loop is only set in the main thread, so workflow fails when it is called from any other thread. This can happen when workflow is invoked from a library (for example, Ray Serve).
This PR fixes that.
  • Loading branch information
fishbone authored and simonsays1980 committed Feb 27, 2022
1 parent ecf1f04 commit cb80e39
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 14 deletions.
8 changes: 3 additions & 5 deletions python/ray/workflow/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import logging
import os
import types
Expand All @@ -22,6 +21,7 @@
WorkflowNotFoundError,
WorkflowStepRuntimeOptions,
StepType,
asyncio_run,
)
from ray.workflow import serialization
from ray.workflow.event_listener import EventListener, EventListenerType, TimerListener
Expand Down Expand Up @@ -354,16 +354,14 @@ def wait_for_event(
@step
def get_message(event_listener_type: EventListenerType, *args, **kwargs) -> Event:
event_listener = event_listener_type()
loop = asyncio.get_event_loop()
return loop.run_until_complete(event_listener.poll_for_event(*args, **kwargs))
return asyncio_run(event_listener.poll_for_event(*args, **kwargs))

@step
def message_committed(
event_listener_type: EventListenerType, event: Event
) -> Event:
event_listener = event_listener_type()
loop = asyncio.get_event_loop()
loop.run_until_complete(event_listener.event_checkpointed(event))
asyncio_run(event_listener.event_checkpointed(event))
return event

return message_committed.step(
Expand Down
11 changes: 11 additions & 0 deletions python/ray/workflow/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import base64
import asyncio

from ray import cloudpickle
from collections import deque
from enum import Enum, unique
Expand All @@ -23,6 +25,15 @@
STORAGE_ACTOR_NAME = "StorageManagementActor"


def asyncio_run(coro):
    """Run ``coro`` to completion on the calling thread's event loop.

    ``asyncio.get_event_loop()`` raises ``RuntimeError`` when the calling
    thread has no event loop set; in that case a new loop is created and
    installed as the thread's current loop. The same is done when the
    current loop has already been closed, because a closed loop cannot run
    anything.

    Args:
        coro: The coroutine (or other awaitable accepted by
            ``loop.run_until_complete``) to execute.

    Returns:
        Whatever ``coro`` evaluates to.

    Raises:
        RuntimeError: Propagated from ``run_until_complete``, e.g. when the
            current loop is already running.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No event loop is set for this thread.
        loop = None
    if loop is None or loop.is_closed():
        # Install a fresh loop so later calls from this thread reuse it.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(coro)


def get_module(f):
    """Return ``f.__module__``, or ``"__anonymous_module__"`` if ``f`` has none."""
    return getattr(f, "__module__", "__anonymous_module__")

Expand Down
4 changes: 2 additions & 2 deletions python/ray/workflow/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def _put_helper(
)
paths = obj_id_to_paths(workflow_id, identifier)
promise = dump_to_storage(paths, obj, workflow_id, storage, update_existing=False)
return asyncio.get_event_loop().run_until_complete(promise)
return common.asyncio_run(promise)


def _reduce_objectref(
Expand Down Expand Up @@ -207,7 +207,7 @@ class ObjectRefPickler(cloudpickle.CloudPickler):
@ray.remote
def _load_ref_helper(key: str, storage: storage.Storage):
# TODO(Alex): We should stream the data directly into `cloudpickle.load`.
serialized = asyncio.get_event_loop().run_until_complete(storage.get(key))
serialized = common.asyncio_run(storage.get(key))
return cloudpickle.loads(serialized)


Expand Down
3 changes: 2 additions & 1 deletion python/ray/workflow/step_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
StepID,
WorkflowData,
WorkflowStaticRef,
asyncio_run,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -247,7 +248,7 @@ def commit_step(
# its input (again).
if w.ref is None:
tasks.append(_write_step_inputs(store, w.step_id, w.data))
asyncio.get_event_loop().run_until_complete(asyncio.gather(*tasks))
asyncio_run(asyncio.gather(*tasks))

context = workflow_context.get_workflow_step_context()
store.save_step_output(
Expand Down
22 changes: 22 additions & 0 deletions python/ray/workflow/tests/test_basic_workflows_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,28 @@ def join(*a):
assert "4" == join.step(a, a, a, a).run()


def test_run_off_main_thread(workflow_start_regular):
    """Check that a workflow can be run from a thread other than the main one."""

    @workflow.step
    def fake_data(num: int):
        return list(range(num))

    import threading

    # An exception raised inside a thread (including a failed assert) does
    # not fail the test by itself, so the worker records its outcome here and
    # the checks are performed on the main thread after joining.
    outcome = {}

    def run():
        try:
            # Set up and run the workflow from the worker thread.
            data = fake_data.step(10)
            outcome["result"] = data.run(workflow_id="run")
        except Exception as exc:
            outcome["error"] = exc

    t = threading.Thread(target=run)
    t.start()
    t.join()

    if "error" in outcome:
        raise outcome["error"]
    assert outcome["result"] == list(range(10))
    assert workflow.get_status("run") == workflow.SUCCESSFUL


if __name__ == "__main__":
import sys

Expand Down
7 changes: 1 addition & 6 deletions python/ray/workflow/workflow_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
WorkflowRef,
WorkflowNotFoundError,
WorkflowStepRuntimeOptions,
asyncio_run,
)
from ray.workflow import workflow_context
from ray.workflow import serialization
Expand Down Expand Up @@ -55,12 +56,6 @@
DUPLICATE_NAME_COUNTER = "duplicate_name_counter"


# TODO: Get rid of this and use asyncio.run instead once we don't support py36
def asyncio_run(coro):
    """Drive ``coro`` to completion on the current event loop and return its result."""
    return asyncio.get_event_loop().run_until_complete(coro)


@dataclass
class StepInspectResult:
# The step output checkpoint exists and valid. If this field
Expand Down

0 comments on commit cb80e39

Please sign in to comment.