From 4eaeac17228ef55e96bbbd3b7bcfb7f2300c7678 Mon Sep 17 00:00:00 2001 From: Siyuan Zhuang Date: Thu, 7 Apr 2022 18:15:40 -0700 Subject: [PATCH 1/2] update workflow events --- python/ray/workflow/api.py | 16 ++++---- python/ray/workflow/api.pyi | 6 ++- python/ray/workflow/tests/test_events.py | 48 ++++++++++++++---------- 3 files changed, 40 insertions(+), 30 deletions(-) diff --git a/python/ray/workflow/api.py b/python/ray/workflow/api.py index ac879218c9fb..d8dc4f0e54a8 100644 --- a/python/ray/workflow/api.py +++ b/python/ray/workflow/api.py @@ -364,19 +364,19 @@ def get_status(workflow_id: str) -> WorkflowStatus: @PublicAPI(stability="beta") def wait_for_event( event_listener_type: EventListenerType, *args, **kwargs -) -> Workflow[Event]: +) -> "DAGNode[Event]": if not issubclass(event_listener_type, EventListener): raise TypeError( f"Event listener type is {event_listener_type.__name__}" ", which is not a subclass of workflow.EventListener" ) - @step + @ray.remote def get_message(event_listener_type: EventListenerType, *args, **kwargs) -> Event: event_listener = event_listener_type() return asyncio_run(event_listener.poll_for_event(*args, **kwargs)) - @step + @ray.remote def message_committed( event_listener_type: EventListenerType, event: Event ) -> Event: @@ -384,22 +384,22 @@ def message_committed( asyncio_run(event_listener.event_checkpointed(event)) return event - return message_committed.step( - event_listener_type, get_message.step(event_listener_type, *args, **kwargs) + return message_committed.bind( + event_listener_type, get_message.bind(event_listener_type, *args, **kwargs) ) @PublicAPI(stability="beta") -def sleep(duration: float) -> Workflow[Event]: +def sleep(duration: float) -> "DAGNode[Event]": """ A workfow that resolves after sleeping for a given duration. """ - @step + @ray.remote def end_time(): return time.time() + duration - return wait_for_event(TimerListener, end_time.step()) + return wait_for_event(TimerListener, end_time.bind()) @PublicAPI(stability="beta") diff --git a/python/ray/workflow/api.pyi b/python/ray/workflow/api.pyi index e09eafc9e05b..6effa82910d0 100644 --- a/python/ray/workflow/api.pyi +++ b/python/ray/workflow/api.pyi @@ -8,6 +8,8 @@ from ray.workflow.storage import Storage from ray.workflow.virtual_actor_class import VirtualActorClass, VirtualActor from ray.workflow.common import WorkflowStatus +from ray.experimental.dag import DAGNode + T0 = TypeVar("T0") T1 = TypeVar("T1") T2 = TypeVar("T2") @@ -101,9 +103,9 @@ def resume_all(include_failed: bool) -> List[str]: ... def get_status(workflow_id: str) -> WorkflowStatus: ... -def wait_for_event(event_listener_type: EventListenerType, *args, **kwargs) -> Workflow: ... +def wait_for_event(event_listener_type: EventListenerType, *args, **kwargs) -> DAGNode: ... -def sleep(duration: float) -> Workflow: ... +def sleep(duration: float) -> DAGNode: ... @overload def get_metadata(workflow_id: str) -> Dict[str, Any]: ... diff --git a/python/ray/workflow/tests/test_events.py b/python/ray/workflow/tests/test_events.py index f7ee4440d65b..0da8e1d11602 100644 --- a/python/ray/workflow/tests/test_events.py +++ b/python/ray/workflow/tests/test_events.py @@ -11,15 +11,15 @@ def test_sleep(workflow_start_regular_shared): - @workflow.step - def sleep_helper(): - @workflow.step - def after_sleep(sleep_start_time, _): - return (sleep_start_time, time.time()) + @ray.remote + def after_sleep(sleep_start_time, _): + return sleep_start_time, time.time() - return after_sleep.step(time.time(), workflow.sleep(2)) + @ray.remote + def sleep_helper(): + return workflow.continuation(after_sleep.bind(time.time(), workflow.sleep(2))) - start, end = sleep_helper.step().run() + start, end = workflow.create(sleep_helper.bind()).run() duration = end - start assert 1 < duration @@ -31,7 +31,7 @@ def test_sleep_checkpointing(workflow_start_regular_shared): sleep_step = workflow.sleep(2) time.sleep(2) start_time = time.time() - sleep_step.run() + workflow.create(sleep_step).run() end_time = time.time() duration = end_time - start_time assert 1 < duration @@ -66,14 +66,16 @@ async def poll_for_event(self): await asyncio.sleep(0.1) return "event2" - @workflow.step + @ray.remote def trivial_step(arg1, arg2): return f"{arg1} {arg2}" event1_promise = workflow.wait_for_event(EventListener1) event2_promise = workflow.wait_for_event(EventListener2) - promise = trivial_step.step(event1_promise, event2_promise).run_async() + promise = workflow.create( + trivial_step.bind(event1_promise, event2_promise) + ).run_async() while not ( utils.check_global_mark("listener1") and utils.check_global_mark("listener2") @@ -106,17 +108,20 @@ async def poll_for_event(self): # Give the other step time to finish. await asyncio.sleep(1) - @workflow.step + @ray.remote def triggers_event(): utils.set_global_mark() - @workflow.step + @ray.remote def gather(*args): return args event_promise = workflow.wait_for_event(MyEventListener) - assert gather.step(event_promise, triggers_event.step()).run() == (None, None) + assert workflow.create(gather.bind(event_promise, triggers_event.bind())).run() == ( + None, + None, + ) @pytest.mark.parametrize( @@ -140,18 +145,21 @@ async def poll_for_event(self): await asyncio.sleep(0.1) utils.set_global_mark("event_returning") - @workflow.step + @ray.remote def triggers_event(): utils.set_global_mark() while not utils.check_global_mark("event_returning"): time.sleep(0.1) - @workflow.step + @ray.remote def gather(*args): return args event_promise = workflow.wait_for_event(MyEventListener) - assert gather.step(event_promise, triggers_event.step()).run() == (None, None) + assert workflow.create(gather.bind(event_promise, triggers_event.bind())).run() == ( + None, + None, + ) def test_crash_during_event_checkpointing(workflow_start_regular_shared): @@ -175,12 +183,12 @@ async def poll_for_event(self): async def event_checkpointed(self, event): utils.set_global_mark("committed") - @workflow.step + @ray.remote def wait_then_finish(arg): pass event_promise = workflow.wait_for_event(MyEventListener) - wait_then_finish.step(event_promise).run_async("workflow") + workflow.create(wait_then_finish.bind(event_promise)).run_async("workflow") while not utils.check_global_mark("time_to_die"): time.sleep(0.1) @@ -233,7 +241,7 @@ async def event_checkpointed(self, event): await asyncio.sleep(1000000) event_promise = workflow.wait_for_event(MyEventListener) - event_promise.run_async("workflow") + workflow.create(event_promise).run_async("workflow") while not utils.check_global_mark("first"): time.sleep(0.1) @@ -266,7 +274,7 @@ async def poll_for_event(self): await asyncio.sleep(1) utils.unset_global_mark() - promise = workflow.wait_for_event(MyEventListener).run_async("wf") + promise = workflow.create(workflow.wait_for_event(MyEventListener)).run_async("wf") assert workflow.get_status("wf") == workflow.WorkflowStatus.RUNNING From 779944dd0bb063660f5119b8d74d0dcb3017a3c4 Mon Sep 17 00:00:00 2001 From: Siyuan Zhuang Date: Fri, 8 Apr 2022 09:20:51 -0700 Subject: [PATCH 2/2] flaky tests --- python/ray/workflow/tests/test_virtual_actor_2.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/ray/workflow/tests/test_virtual_actor_2.py b/python/ray/workflow/tests/test_virtual_actor_2.py index d63b4ec1281b..d11be53a62ec 100644 --- a/python/ray/workflow/tests/test_virtual_actor_2.py +++ b/python/ray/workflow/tests/test_virtual_actor_2.py @@ -305,6 +305,9 @@ def __setstate__(self, n): # guarentee obj1's finish. obj1 = c.incr.run_async(10) # noqa obj2 = c.incr.run(10) # noqa + # TODO(suquark): The test is flaky sometimes (only on CI), which might indicates + # some bugs. This is a workaroundde temporarily. + time.sleep(3) assert c.get.run() == 20