[Workflow] Support "retry_exceptions" of Ray tasks (#26913)
* support 'retry_exceptions'

Signed-off-by: Siyuan Zhuang <[email protected]>

* add test

Signed-off-by: Siyuan Zhuang <[email protected]>

* add doc

Signed-off-by: Siyuan Zhuang <[email protected]>

* fix

Signed-off-by: Siyuan Zhuang <[email protected]>

* typo

Signed-off-by: Siyuan Zhuang <[email protected]>
suquark authored Jul 25, 2022
1 parent a012033 commit 4a1ad3e
Showing 8 changed files with 133 additions and 75 deletions.
37 changes: 24 additions & 13 deletions doc/source/workflows/basics.rst
@@ -210,7 +210,30 @@ Error handling

Workflows provides two ways to handle application-level exceptions: (1) automatic retry (as in normal Ray tasks), and (2) the ability to catch and handle exceptions.

The following error-handling flags can be set either in the task decorator or via ``.options()``:

- If ``max_retries`` is given, the workflow task will be retried up to the given number of times if it fails.
- If ``retry_exceptions`` is ``True``, the workflow task retries on both task crashes and application-level errors;
  if it is ``False``, it retries only on task crashes.
- If ``catch_exceptions`` is ``True``, the return value of the function will be converted to ``Tuple[Optional[T], Optional[Exception]]``.
  This can be combined with ``max_retries`` to retry a given number of times before returning the result tuple.

``max_retries`` and ``retry_exceptions`` are also Ray task options, so they are passed directly to the Ray remote decorator
(or to ``.options()``), while ``catch_exceptions`` is a workflow option and goes through ``workflow.options()``. Here is how you could use them:

.. code-block:: python

    # specify in decorator
    @workflow.options(catch_exceptions=True)
    @ray.remote(max_retries=5, retry_exceptions=True)
    def faulty_function():
        pass

    # specify in .options()
    faulty_function.options(max_retries=3, retry_exceptions=False,
                            **workflow.options(catch_exceptions=False))

.. note:: By default, ``retry_exceptions`` is ``False`` and ``max_retries`` is ``3``.
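
For a quick illustration of what ``catch_exceptions`` does, here is a minimal sketch (an editorial example with an assumed task name ``flaky``, not part of this commit): the caller receives a ``(result, error)`` pair instead of a raised exception.

.. code-block:: python

    import ray
    from ray import workflow

    @workflow.options(catch_exceptions=True)
    @ray.remote
    def flaky():
        raise ValueError("expected failure")

    # The application error is caught and returned as the second element.
    result, err = workflow.run(flaky.bind())
    assert result is None and err is not None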

Here is one example:

.. code-block:: python
@@ -244,18 +267,6 @@ The following error handling flags can be either set in the task decorator or vi
workflow.run(handle_errors.bind(r2))
- If ``max_retries`` is given, the task will be retried for the given number of times if an exception is raised. It will only retry for the application level error. For system errors, it's controlled by ray. By default, ``max_retries`` is set to be 3.
- If ``catch_exceptions`` is True, the return value of the function will be converted to ``Tuple[Optional[T], Optional[Exception]]``. This can be combined with ``max_retries`` to try a given number of times before returning the result tuple.

The parameters can also be passed to the decorator

.. code-block:: python

    @workflow.options(catch_exceptions=True)
    @ray.remote(max_retries=5)
    def faulty_function():
        pass

Durability guarantees
---------------------

1 change: 0 additions & 1 deletion python/ray/workflow/api.py
@@ -743,7 +743,6 @@ def __init__(self, **workflow_options: Dict[str, Any]):
"name",
"metadata",
"catch_exceptions",
"allow_inplace",
"checkpoint",
}
invalid_keywords = set(workflow_options.keys()) - valid_options
10 changes: 5 additions & 5 deletions python/ray/workflow/common.py
@@ -158,10 +158,10 @@ class WorkflowStepRuntimeOptions:
step_type: "StepType"
# Whether the user want to handle the exception manually.
catch_exceptions: bool
# The num of retry for application exception.
# Whether application-level errors should be retried.
retry_exceptions: bool
# The num of retry for application exceptions & system failures.
max_retries: int
# Run the workflow step inplace.
allow_inplace: bool
# Checkpoint mode.
checkpoint: CheckpointModeType
# ray_remote options
@@ -172,7 +172,7 @@ def to_dict(self) -> Dict[str, Any]:
"step_type": self.step_type,
"max_retries": self.max_retries,
"catch_exceptions": self.catch_exceptions,
"allow_inplace": self.allow_inplace,
"retry_exceptions": self.retry_exceptions,
"checkpoint": self.checkpoint,
"ray_options": self.ray_options,
}
@@ -183,7 +183,7 @@ def from_dict(cls, value: Dict[str, Any]):
step_type=StepType[value["step_type"]],
max_retries=value["max_retries"],
catch_exceptions=value["catch_exceptions"],
allow_inplace=value["allow_inplace"],
retry_exceptions=value["retry_exceptions"],
checkpoint=value["checkpoint"],
ray_options=value["ray_options"],
)
2 changes: 2 additions & 0 deletions python/ray/workflow/step_executor.py
@@ -37,6 +37,8 @@ def get_step_executor(step_options: "WorkflowStepRuntimeOptions"):
if step_options.step_type == StepType.FUNCTION:
# prevent automatic lineage reconstruction
step_options.ray_options["max_retries"] = 0
# prevent retrying exception by Ray
step_options.ray_options["retry_exceptions"] = False
executor = _workflow_step_executor_remote.options(
**step_options.ray_options
).remote
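For context, the override above follows the pattern sketched below (an editorial illustration with assumed names such as ``executor_stub``; not code from this repository): Ray-level retries are disabled on the inner executor task so that the workflow layer owns the retry policy.

    import ray

    @ray.remote
    def executor_stub(payload):
        # stand-in for the remote function that actually executes a workflow step
        return payload

    # Options collected from the user-facing task definition (assumed example values).
    ray_options = {"num_cpus": 1, "max_retries": 5, "retry_exceptions": True}
    # The workflow executor re-applies retries itself, so Ray's own retry
    # machinery is switched off for the inner call:
    ray_options["max_retries"] = 0
    ray_options["retry_exceptions"] = False
    ref = executor_stub.options(**ray_options).remote({"step": "demo"})
    print(ray.get(ref))  # -> {'step': 'demo'}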
17 changes: 15 additions & 2 deletions python/ray/workflow/tests/test_error_handling.py
@@ -7,7 +7,7 @@


def test_step_failure(workflow_start_regular_shared, tmp_path):
@ray.remote(max_retries=10)
@ray.remote(max_retries=10, retry_exceptions=True)
def unstable_task_exception(n):
v = int((tmp_path / "test").read_text())
(tmp_path / "test").write_text(f"{v + 1}")
@@ -25,7 +25,7 @@ def unstable_task_crash(n):
os.kill(os.getpid(), 9)
return v

@ray.remote(max_retries=10)
@ray.remote(max_retries=10, retry_exceptions=True)
def unstable_task_crash_then_exception(n):
v = int((tmp_path / "test").read_text())
(tmp_path / "test").write_text(f"{v + 1}")
@@ -68,6 +68,19 @@ def unstable_task_crash_then_exception(n):
assert ret is None
assert err is not None

(tmp_path / "test").write_text("0")
with pytest.raises(workflow.WorkflowExecutionError):
workflow.run(unstable_task_exception.options(retry_exceptions=False).bind(10))

(tmp_path / "test").write_text("0")
workflow.run(unstable_task_crash.options(retry_exceptions=False).bind(10))

(tmp_path / "test").write_text("0")
with pytest.raises(workflow.WorkflowExecutionError):
workflow.run(
unstable_task_crash_then_exception.options(retry_exceptions=False).bind(10)
)


def test_nested_catch_exception(workflow_start_regular_shared):
@ray.remote
2 changes: 1 addition & 1 deletion python/ray/workflow/tests/test_storage.py
@@ -120,8 +120,8 @@ def test_workflow_storage(workflow_start_regular):
step_options = WorkflowStepRuntimeOptions(
step_type=StepType.FUNCTION,
catch_exceptions=False,
retry_exceptions=True,
max_retries=0,
allow_inplace=False,
checkpoint=False,
ray_options={},
)
120 changes: 82 additions & 38 deletions python/ray/workflow/workflow_executor.py
@@ -227,6 +227,50 @@ def _iter_callstack(self, task_id: TaskID) -> Iterator[Tuple[TaskID, Task]]:
yield task_id, state.tasks[task_id]
task_id = state.task_context[task_id].creator_task_id

def _retry_failed_task(
self, workflow_id: str, failed_task_id: TaskID, exc: Exception
) -> bool:
state = self._state
is_application_error = isinstance(exc, RayTaskError)
options = state.tasks[failed_task_id].options
if not is_application_error or options.retry_exceptions:
if state.task_retries[failed_task_id] < options.max_retries:
state.task_retries[failed_task_id] += 1
logger.info(
f"Retry [{workflow_id}@{failed_task_id}] "
f"({state.task_retries[failed_task_id]}/{options.max_retries})"
)
state.construct_scheduling_plan(failed_task_id)
return True
return False

async def _catch_failed_task(
self, workflow_id: str, failed_task_id: TaskID, exc: Exception
) -> bool:
# lookup a creator task that catches the exception
is_application_error = isinstance(exc, RayTaskError)
exception_catcher = None
if is_application_error:
for t, task in self._iter_callstack(failed_task_id):
if task.options.catch_exceptions:
exception_catcher = t
break
if exception_catcher is not None:
logger.info(
f"Exception raised by '{workflow_id}@{failed_task_id}' is caught by "
f"'{workflow_id}@{exception_catcher}'"
)
# assign output to exception catching task;
# compose output with caught exception
await self._post_process_ready_task(
exception_catcher,
metadata=WorkflowExecutionMetadata(),
output_ref=WorkflowRef(failed_task_id, ray.put((None, exc))),
)
# TODO(suquark): cancel other running tasks?
return True
return False

async def _handle_ready_task(
self, fut: asyncio.Future, workflow_id: str, wf_store: "WorkflowStorage"
) -> None:
@@ -253,10 +297,8 @@ async def _handle_ready_task(
self._broadcast_exception(err)
raise err from None
except Exception as e:
is_application_error = False
if isinstance(e, RayTaskError):
reason = "an exception raised by the task"
is_application_error = True
elif isinstance(e, RayError):
reason = "a system error"
else:
@@ -266,50 +308,52 @@
f"[{workflow_id}@{task_id}]"
)

is_application_error = isinstance(e, RayTaskError)
options = state.tasks[task_id].options

# ---------------------- retry the task ----------------------
state.task_retries[task_id] += 1
total_retries = state.tasks[task_id].options.max_retries
if state.task_retries[task_id] <= total_retries:
logger.info(
f"Retry [{workflow_id}@{task_id}] "
f"({state.task_retries[task_id]}/{total_retries})"
)
state.construct_scheduling_plan(task_id)
return
if not is_application_error or options.retry_exceptions:
if state.task_retries[task_id] < options.max_retries:
state.task_retries[task_id] += 1
logger.info(
f"Retry [{workflow_id}@{task_id}] "
f"({state.task_retries[task_id]}/{options.max_retries})"
)
state.construct_scheduling_plan(task_id)
return

# ----------- retry used up, handle the task error -----------
# on error, the error is caught by this task
exception_catching_task_id = None
# lookup a creator task that catches the exception
exception_catcher = None
if is_application_error:
for t, task in self._iter_callstack(task_id):
if task.options.catch_exceptions:
exception_catching_task_id = t
exception_catcher = t
break
if exception_catcher is not None:
logger.info(
f"Exception raised by '{workflow_id}@{task_id}' is caught by "
f"'{workflow_id}@{exception_catcher}'"
)
# assign output to exception catching task;
# compose output with caught exception
await self._post_process_ready_task(
exception_catcher,
metadata=WorkflowExecutionMetadata(),
output_ref=WorkflowRef(task_id, ray.put((None, e))),
)
# TODO(suquark): cancel other running tasks?
return

if exception_catching_task_id is None:
# NOTE: We must update the workflow status before broadcasting
# the exception. Otherwise, the workflow status would still be
# 'RUNNING' if check the status immediately after the exception.
wf_store.update_workflow_status(WorkflowStatus.FAILED)
logger.error(f"Workflow '{workflow_id}' failed due to {e}")
err = WorkflowExecutionError(workflow_id)
err.__cause__ = e # chain exceptions
self._broadcast_exception(err)
raise err

logger.info(
f"Exception raised by '{workflow_id}@{task_id}' is caught by "
f"'{workflow_id}@{exception_catching_task_id}'"
)
# assign output to exception catching task;
# compose output with caught exception
await self._post_process_ready_task(
exception_catching_task_id,
metadata=WorkflowExecutionMetadata(),
output_ref=WorkflowRef(task_id, ray.put((None, e))),
)
# TODO(suquark): cancel other running tasks?
# ------------------- raise the task error -------------------
# NOTE: We must update the workflow status before broadcasting
# the exception. Otherwise, the workflow status would still be
# 'RUNNING' if check the status immediately after the exception.
wf_store.update_workflow_status(WorkflowStatus.FAILED)
logger.error(f"Workflow '{workflow_id}' failed due to {e}")
err = WorkflowExecutionError(workflow_id)
err.__cause__ = e # chain exceptions
self._broadcast_exception(err)
raise err

async def _post_process_ready_task(
self,
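Taken together, the new ``_retry_failed_task`` and ``_catch_failed_task`` paths implement the decision sketched below (an editorial simplification with an assumed helper name, not code from this commit):

    def should_retry_failed_task(
        is_application_error: bool,
        retry_exceptions: bool,
        retries_so_far: int,
        max_retries: int,
    ) -> bool:
        # System failures are always eligible for retry; application errors
        # (RayTaskError) are retried only when retry_exceptions is enabled.
        # Both are bounded by max_retries. If no retry is granted, the error
        # may still be caught by a creator task with catch_exceptions=True.
        if is_application_error and not retry_exceptions:
            return False
        return retries_so_far < max_retries

    # Application error with retry_exceptions disabled: no retry (may still be caught).
    assert should_retry_failed_task(True, False, 0, 3) is False
    # Application error with retry_exceptions enabled: retry until the budget runs out.
    assert should_retry_failed_task(True, True, 2, 3) is True
    # System failure (e.g. a crashed worker): retried regardless of retry_exceptions.
    assert should_retry_failed_task(False, False, 1, 3) is True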
19 changes: 4 additions & 15 deletions python/ray/workflow/workflow_state_from_dag.py
@@ -95,16 +95,6 @@ def _node_visitor(node: Any) -> Any:
# should be passed recursively.
catch_exceptions = workflow_options.get("catch_exceptions", None)
if catch_exceptions is None:
# TODO(suquark): should we also handle exceptions from a "leaf node"
# in the continuation? For example, we have a workflow
# > @ray.remote
# > def A(): pass
# > @ray.remote
# > def B(x): return x
# > @ray.remote
# > def C(x): return workflow.continuation(B.bind(A.bind()))
# > dag = C.options(**workflow.options(catch_exceptions=True)).bind()
# Should C catches exceptions of A?
if node.get_stable_uuid() == dag_node.get_stable_uuid():
# 'catch_exception' context should be passed down to
# its direct continuation task.
Expand All @@ -115,17 +105,16 @@ def _node_visitor(node: Any) -> Any:
else:
catch_exceptions = False

# We do not need to check the validness of bound options, because
# Ray option has already checked them for us.
max_retries = bound_options.get("max_retries", 3)
if not isinstance(max_retries, int) or max_retries < -1:
raise ValueError(
"'max_retries' only accepts 0, -1 or a positive integer."
)
retry_exceptions = bound_options.get("retry_exceptions", False)

step_options = WorkflowStepRuntimeOptions(
step_type=StepType.FUNCTION,
catch_exceptions=catch_exceptions,
retry_exceptions=retry_exceptions,
max_retries=max_retries,
allow_inplace=False,
checkpoint=checkpoint,
ray_options=bound_options,
)
