diff --git a/python/ray/data/_internal/compute.py b/python/ray/data/_internal/compute.py index 99635e6892a6..50b61e89b8b1 100644 --- a/python/ray/data/_internal/compute.py +++ b/python/ray/data/_internal/compute.py @@ -136,7 +136,7 @@ def _apply( # Common wait for non-data refs. try: results = map_bar.fetch_until_complete(refs) - except (ray.exceptions.RayTaskError, KeyboardInterrupt) as e: + except (ray.exceptions.RayError, KeyboardInterrupt) as e: # One or more mapper tasks failed, or we received a SIGINT signal # while waiting; either way, we cancel all map tasks. for ref in refs: @@ -145,7 +145,10 @@ def _apply( for ref in refs: try: ray.get(ref) - except (ray.exceptions.RayTaskError, ray.exceptions.TaskCancelledError): + except ray.exceptions.RayError: + # Cancellation either succeeded, or the task had already failed with + # a different error, or cancellation failed. In all cases, we + # swallow the exception. pass # Reraise the original task failure exception. raise e from None @@ -418,7 +421,7 @@ def map_block_nosplit( except Exception as err: logger.exception(f"Error killing workers: {err}") finally: - raise e + raise e from None def get_compute(compute_spec: Union[str, ComputeStrategy]) -> ComputeStrategy: