Skip to content

Commit

Permalink
(TaskPool) Cancel all transformation tasks when one task fails or whe…
Browse files Browse the repository at this point in the history
…n SIGINT is received.
  • Loading branch information
clarkzinzow committed Sep 30, 2021
1 parent 828f5d2 commit 2b6c765
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 3 deletions.
19 changes: 17 additions & 2 deletions python/ray/data/impl/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,23 @@ def apply(self, fn: Any, remote_args: dict,
]
new_blocks, new_metadata = zip(*refs)

map_bar.block_until_complete(list(new_blocks))
new_metadata = ray.get(list(new_metadata))
new_metadata = list(new_metadata)
try:
new_metadata = map_bar.fetch_until_complete(new_metadata)
except (ray.exceptions.RayTaskError, KeyboardInterrupt) as e:
# One or more mapper tasks failed, or we received a SIGINT signal
# while waiting; either way, we cancel all map tasks.
for ref in new_metadata:
ray.cancel(ref)
# Wait until all tasks have failed or been cancelled.
for ref in new_metadata:
try:
ray.get(ref)
except (ray.exceptions.RayTaskError,
ray.exceptions.TaskCancelledError):
pass
# Reraise the original task failure exception.
raise e from None
return BlockList(list(new_blocks), list(new_metadata))


Expand Down
12 changes: 11 additions & 1 deletion python/ray/data/impl/progress_bar.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Any

import ray
from ray.types import ObjectRef
Expand Down Expand Up @@ -50,6 +50,16 @@ def block_until_complete(self, remaining: List[ObjectRef]) -> None:
done, remaining = ray.wait(remaining, fetch_local=False)
self.update(len(done))

def fetch_until_complete(self, refs: List[ObjectRef]) -> List[Any]:
ref_to_result = {}
remaining = refs
while remaining:
done, remaining = ray.wait(remaining, fetch_local=True)
for ref, result in zip(done, ray.get(done)):
ref_to_result[ref] = result
self.update(len(done))
return [ref_to_result[ref] for ref in refs]

def set_description(self, name: str) -> None:
if self._bar:
self._bar.set_description(name)
Expand Down
13 changes: 13 additions & 0 deletions python/ray/data/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,19 @@ def __call__(self, x):
assert len(actor_reuse) == 10, actor_reuse


def test_transform_failure(shutdown_only):
ray.init(num_cpus=2)
ds = ray.data.from_items([0, 10], parallelism=2)

def mapper(x):
time.sleep(x)
assert False
return x

with pytest.raises(ray.exceptions.RayTaskError):
ds.map(mapper)


@pytest.mark.parametrize(
"block_sizes,num_splits",
[
Expand Down

0 comments on commit 2b6c765

Please sign in to comment.