diff --git a/python/ray/workflow/common.py b/python/ray/workflow/common.py index 6461fbab251e..2813498ef549 100644 --- a/python/ray/workflow/common.py +++ b/python/ray/workflow/common.py @@ -216,12 +216,10 @@ def make( ): if max_retries is None: max_retries = 3 - elif not isinstance(max_retries, int) or max_retries < 1: - raise ValueError("max_retries should be greater or equal to 1.") + elif not isinstance(max_retries, int) or max_retries < -1: + raise ValueError("'max_retries' only accepts 0, -1 or a positive integer.") if catch_exceptions is None: catch_exceptions = False - if max_retries is None: - max_retries = 3 if not isinstance(checkpoint, bool) and checkpoint is not None: raise ValueError("'checkpoint' should be None or a boolean.") if ray_options is None: diff --git a/python/ray/workflow/step_executor.py b/python/ray/workflow/step_executor.py index b7aa091593a1..ec261bac5e89 100644 --- a/python/ray/workflow/step_executor.py +++ b/python/ray/workflow/step_executor.py @@ -318,28 +318,39 @@ def _wrap_run( """ exception = None result = None + done = False # max_retries are for application level failure. # For ray failure, we should use max_retries. - for i in range(runtime_options.max_retries): - logger.info( - f"{get_step_status_info(WorkflowStatus.RUNNING)}" - f"\t[{i + 1}/{runtime_options.max_retries}]" - ) + i = 0 + while not done: + if i == 0: + logger.info(f"{get_step_status_info(WorkflowStatus.RUNNING)}") + else: + total_retries = ( + runtime_options.max_retries + if runtime_options.max_retries != -1 + else "inf" + ) + logger.info( + f"{get_step_status_info(WorkflowStatus.RUNNING)}" + f"\tretries: [{i}/{total_retries}]" + ) try: result = func(*args, **kwargs) exception = None - break + done = True except BaseException as e: - if i + 1 == runtime_options.max_retries: + if i == runtime_options.max_retries: retry_msg = "Maximum retry reached, stop retry." + exception = e + done = True else: retry_msg = "The step will be retried." + i += 1 logger.error( f"{workflow_context.get_name()} failed with error message" f" {e}. {retry_msg}" ) - exception = e - step_type = runtime_options.step_type if runtime_options.catch_exceptions: if step_type == StepType.FUNCTION: diff --git a/python/ray/workflow/tests/test_basic_workflows.py b/python/ray/workflow/tests/test_basic_workflows.py index 7e03a323e8e5..bd046b4263c2 100644 --- a/python/ray/workflow/tests/test_basic_workflows.py +++ b/python/ray/workflow/tests/test_basic_workflows.py @@ -197,19 +197,19 @@ def unstable_step(): return v with pytest.raises(Exception): - unstable_step.options(max_retries=-1).step().run() + unstable_step.options(max_retries=-2).step().run() with pytest.raises(Exception): - unstable_step.options(max_retries=3).step().run() - assert 10 == unstable_step.options(max_retries=8).step().run() + unstable_step.options(max_retries=2).step().run() + assert 10 == unstable_step.options(max_retries=7).step().run() (tmp_path / "test").write_text("0") (ret, err) = ( - unstable_step.options(max_retries=3, catch_exceptions=True).step().run() + unstable_step.options(max_retries=2, catch_exceptions=True).step().run() ) assert ret is None assert isinstance(err, ValueError) (ret, err) = ( - unstable_step.options(max_retries=8, catch_exceptions=True).step().run() + unstable_step.options(max_retries=7, catch_exceptions=True).step().run() ) assert ret == 10 assert err is None @@ -218,7 +218,7 @@ def unstable_step(): def test_step_failure_decorator(workflow_start_regular_shared, tmp_path): (tmp_path / "test").write_text("0") - @workflow.step(max_retries=11) + @workflow.step(max_retries=10) def unstable_step(): v = int((tmp_path / "test").read_text()) (tmp_path / "test").write_text(f"{v + 1}") @@ -244,7 +244,7 @@ def unstable_step_exception(): (tmp_path / "test").write_text("0") - @workflow.step(catch_exceptions=True, max_retries=4) + @workflow.step(catch_exceptions=True, max_retries=3) def unstable_step_exception(): v = int((tmp_path / "test").read_text()) (tmp_path / "test").write_text(f"{v + 1}") diff --git a/python/ray/workflow/tests/test_basic_workflows_2.py b/python/ray/workflow/tests/test_basic_workflows_2.py index 214368614e44..0c82fbe843e5 100644 --- a/python/ray/workflow/tests/test_basic_workflows_2.py +++ b/python/ray/workflow/tests/test_basic_workflows_2.py @@ -109,7 +109,7 @@ def incr(): return 10 with pytest.raises(ray.exceptions.RaySystemError): - incr.options(max_retries=1).step().run("incr") + incr.options(max_retries=0).step().run("incr") assert cnt_file.read_text() == "1" diff --git a/python/ray/workflow/virtual_actor_class.py b/python/ray/workflow/virtual_actor_class.py index 3d5efbf8e3b1..ee861bfb410e 100644 --- a/python/ray/workflow/virtual_actor_class.py +++ b/python/ray/workflow/virtual_actor_class.py @@ -92,7 +92,7 @@ def run_async(self, *args, **kwargs) -> "ObjectRef": def options( self, *, - max_retries: int = 1, + max_retries: int = 0, catch_exceptions: bool = False, name: str = None, metadata: Dict[str, Any] = None, @@ -262,7 +262,7 @@ def step(self, *args, **kwargs): def options( self, *, - max_retries=1, + max_retries=0, catch_exceptions=False, name=None, metadata=None,