diff --git a/python/ray/data/impl/compute.py b/python/ray/data/impl/compute.py index f3ef94f4746c..8f0a7fb8e41f 100644 --- a/python/ray/data/impl/compute.py +++ b/python/ray/data/impl/compute.py @@ -51,8 +51,23 @@ def apply(self, fn: Any, remote_args: dict, ] new_blocks, new_metadata = zip(*refs) - map_bar.block_until_complete(list(new_blocks)) - new_metadata = ray.get(list(new_metadata)) + new_metadata = list(new_metadata) + try: + new_metadata = map_bar.fetch_until_complete(new_metadata) + except (ray.exceptions.RayTaskError, KeyboardInterrupt) as e: + # One or more mapper tasks failed, or we received a SIGINT signal + # while waiting; either way, we cancel all map tasks. + for ref in new_metadata: + ray.cancel(ref) + # Wait until all tasks have failed or been cancelled. + for ref in new_metadata: + try: + ray.get(ref) + except (ray.exceptions.RayTaskError, + ray.exceptions.TaskCancelledError): + pass + # Reraise the original task failure exception. + raise e from None return BlockList(list(new_blocks), list(new_metadata)) diff --git a/python/ray/data/impl/progress_bar.py b/python/ray/data/impl/progress_bar.py index c9c1caa43cb5..fc28da681f3e 100644 --- a/python/ray/data/impl/progress_bar.py +++ b/python/ray/data/impl/progress_bar.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Any import ray from ray.types import ObjectRef @@ -50,6 +50,16 @@ def block_until_complete(self, remaining: List[ObjectRef]) -> None: done, remaining = ray.wait(remaining, fetch_local=False) self.update(len(done)) + def fetch_until_complete(self, refs: List[ObjectRef]) -> List[Any]: + ref_to_result = {} + remaining = refs + while remaining: + done, remaining = ray.wait(remaining, fetch_local=True) + for ref, result in zip(done, ray.get(done)): + ref_to_result[ref] = result + self.update(len(done)) + return [ref_to_result[ref] for ref in refs] + def set_description(self, name: str) -> None: if self._bar: self._bar.set_description(name) diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py index a315f004cacb..9bb745ca0862 100644 --- a/python/ray/data/tests/test_dataset.py +++ b/python/ray/data/tests/test_dataset.py @@ -144,6 +144,19 @@ def __call__(self, x): assert len(actor_reuse) == 10, actor_reuse +def test_transform_failure(shutdown_only): + ray.init(num_cpus=2) + ds = ray.data.from_items([0, 10], parallelism=2) + + def mapper(x): + time.sleep(x) + assert False + return x + + with pytest.raises(ray.exceptions.RayTaskError): + ds.map(mapper) + + @pytest.mark.parametrize( "block_sizes,num_splits", [