diff --git a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py index afd8fe5606a6..c9a01a14a316 100644 --- a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py @@ -81,6 +81,10 @@ def __init__( ray_remote_args, ) self._ray_remote_args = self._apply_default_remote_args(self._ray_remote_args) + self._ray_actor_task_remote_args = {} + actor_task_errors = DataContext.get_current().actor_task_retry_on_errors + if actor_task_errors: + self._ray_actor_task_remote_args["retry_exceptions"] = actor_task_errors self._min_rows_per_bundle = min_rows_per_bundle # Create autoscaling policy from compute strategy. @@ -194,9 +198,11 @@ def _dispatch_tasks(self): task_idx=self._next_data_task_idx, target_max_block_size=self.actual_target_max_block_size, ) - gen = actor.submit.options(num_returns="streaming", name=self.name).remote( - DataContext.get_current(), ctx, *input_blocks - ) + gen = actor.submit.options( + num_returns="streaming", + name=self.name, + **self._ray_actor_task_remote_args, + ).remote(DataContext.get_current(), ctx, *input_blocks) def _task_done_callback(actor_to_return): # Return the actor that was running the task to the pool. diff --git a/python/ray/data/context.py b/python/ray/data/context.py index 642313703c9a..ce2171d986c6 100644 --- a/python/ray/data/context.py +++ b/python/ray/data/context.py @@ -1,6 +1,6 @@ import os import threading -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import ray from ray._private.ray_constants import env_integer @@ -159,6 +159,12 @@ "AWS Error SLOW_DOWN", ] +# The application-level errors that actor task would retry. +# Default to `False` to not retry on any errors. +# Set to `True` to retry all errors, or set to a list of errors to retry. 
+# This follows the same format as `retry_exceptions` in Ray Core.
test_concurrency(shutdown_only): ray.init(num_cpus=6) ds = ray.data.range(10, parallelism=10)