[Train] Improvements to fault tolerance (#22511)
Various improvements to Ray Train fault tolerance.

Adds more log statements to make Ray Train failure handling easier to debug.
Fixes #22349 ([Bug] [Train] Cannot reproduce fault-tolerance, script hangs upon any node shutdown).
Simplifies fault tolerance by removing the backend-specific handle_failure hook: if any worker fails, all workers are restarted and training continues from the last checkpoint.
Also adds a fault tolerance test that runs an actual Torch example. When run locally, the test hangs before the fix and passes after it.
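
Resuming from the last checkpoint only works if the user's training function saves checkpoints and restores from one when present. Below is a minimal sketch of such a function; it assumes the train.save_checkpoint / train.load_checkpoint / train.report helpers from the Ray Train API of this era, and the config key and placeholder loss are purely illustrative.

```python
# Illustrative sketch only -- not part of this commit.
# Assumes the legacy Ray Train helpers train.load_checkpoint /
# train.save_checkpoint / train.report; model state handling is elided.
from ray import train


def fault_tolerant_train_func(config):
    # After a failure, all workers are restarted and the latest saved
    # checkpoint (if any) is made available again via load_checkpoint().
    checkpoint = train.load_checkpoint() or {}
    start_epoch = checkpoint.get("epoch", 0)

    for epoch in range(start_epoch, config["epochs"]):
        loss = 0.0  # placeholder for one real epoch of training
        # Saving a checkpoint every epoch is what makes resumption possible.
        train.save_checkpoint(epoch=epoch + 1)
        train.report(loss=loss, epoch=epoch)
```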
amogkam authored Mar 29, 2022
1 parent da7901f commit 0b8c219
Showing 4 changed files with 44 additions and 30 deletions.
32 changes: 10 additions & 22 deletions python/ray/train/backend.py
@@ -57,20 +57,6 @@ def on_shutdown(self, worker_group: WorkerGroup, backend_config: BackendConfig):
        """Logic for shutting down the backend."""
        pass

    def handle_failure(
        self,
        worker_group: WorkerGroup,
        failed_worker_indexes: List[int],
        backend_config: BackendConfig,
    ):
        """Logic for handling failures.
        By default, restart all workers.
        """
        worker_group.shutdown()
        worker_group.start()
        self.on_start(worker_group, backend_config)

    @staticmethod
    def encode_data(data_dict: Dict) -> EncodedData:
        """Logic to encode a data dict before sending to the driver.
@@ -184,6 +170,10 @@ def start(
            self._backend.on_start(self.worker_group, self._backend_config)
        except RayActorError as exc:
            logger.exception(str(exc))
            logger.warning(
                "Failure occurred during startup. Restarting all workers and "
                "attempting to startup again."
            )
            self._increment_failures()
            self._restart()

@@ -560,18 +550,16 @@ def get_with_failure_handling(self, remote_values):
        Returns:
            The resolved objects represented by the passed in ObjectRefs.
        """
        success, failed_worker_indexes = check_for_failure(remote_values)
        success = check_for_failure(remote_values)
        if success:
            return ray.get(remote_values)
        else:
            self._increment_failures()
            try:
                self._backend.handle_failure(
                    self.worker_group, failed_worker_indexes, self._backend_config
                )
            except RayActorError as exc:
                logger.exception(str(exc))
                self._restart()
            logger.warning(
                "Failure identified during training. Restarting all workers and "
                "continuing training from latest checkpoint."
            )
            self._restart()
            raise TrainingWorkerError

    def shutdown(self):
18 changes: 18 additions & 0 deletions python/ray/train/tests/test_examples.py
@@ -105,6 +105,24 @@ def test_torch_linear(ray_start_2_cpus, num_workers):
        assert result[-1]["loss"] < result[0]["loss"]


def test_torch_linear_failure(ray_start_2_cpus):
    num_workers = 2
    epochs = 3

    trainer = Trainer("torch", num_workers=num_workers)
    config = {"lr": 1e-2, "hidden_size": 1, "batch_size": 4, "epochs": epochs}
    trainer.start()
    kill_callback = KillCallback(fail_on=1, trainer=trainer)
    results = trainer.run(linear_train_func, config, callbacks=[kill_callback])
    trainer.shutdown()

    assert len(results) == num_workers

    for result in results:
        assert len(result) == epochs
        assert result[-1]["loss"] < result[0]["loss"]


def test_torch_fashion_mnist(ray_start_2_cpus):
    num_workers = 2
    epochs = 3
7 changes: 7 additions & 0 deletions python/ray/train/trainer.py
@@ -687,6 +687,13 @@ def _run_with_error_handling(self, func: Callable):
            return func()
        except TrainingWorkerError:
            # Workers have already been restarted.
            logger.info(
                "Workers have been successfully restarted. Resuming "
                "training from latest checkpoint."
            )
            logger.debug(
                f"Latest checkpoint: {self._checkpoint_manager.latest_checkpoint}"
            )
            self._start_training(
                self._train_func,
                self._run_dir,
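
The trainer.py hunk above is the driver-side half of the recovery path: once the workers have been restarted, training is simply kicked off again from the latest checkpoint. A condensed, self-contained sketch of that retry pattern follows; run_with_error_handling, start_training, and max_retries here are hypothetical stand-ins, not the Trainer's real internals.

```python
# Hedged sketch of the retry pattern; not the actual Trainer implementation.
class TrainingWorkerError(Exception):
    """Raised after failed workers have already been restarted."""


def run_with_error_handling(func, start_training, max_retries=3):
    """Run `func`, re-launching training whenever the workers were replaced."""
    for _ in range(max_retries + 1):
        try:
            return func()
        except TrainingWorkerError:
            # The worker group was already restarted by the backend executor;
            # start training again so it resumes from the latest checkpoint.
            start_training()
    raise RuntimeError("Exceeded the maximum number of worker failures.")
```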
17 changes: 9 additions & 8 deletions python/ray/train/utils.py
@@ -33,33 +33,34 @@
logger = logging.getLogger(__name__)


def check_for_failure(remote_values: List[ObjectRef]) -> Tuple[bool, List[int]]:
def check_for_failure(remote_values: List[ObjectRef]) -> bool:
    """Check for actor failure when retrieving the remote values.
    Args:
        remote_values (list): List of object references from Ray actor methods.
    Returns:
        Returns Tuple of success boolean and list of workers indexes that fail.
        True if evaluating all object references is successful, False otherwise.
    """
    unfinished = remote_values.copy()
    dead_worker_indexes = []  # Store the indexes of the failed workers.

    while len(unfinished) > 0:
        finished, unfinished = ray.wait(unfinished)

        # If a failure occurs the ObjectRef will be marked as finished.
        # Calling ray.get will expose the failure as a RayActorError.
        for object_ref in finished:
            # Everything in finished has either failed or completed
            # successfully.
            try:
                ray.get(object_ref)
            except RayActorError as exc:
                logger.exception(str(exc))
                failed_actor_rank = remote_values.index(object_ref)
                logger.info(f"Worker {failed_actor_rank} has failed.")
                dead_worker_indexes.append(failed_actor_rank)
    if len(dead_worker_indexes) > 0:
        return False, dead_worker_indexes
    else:
        return True, []
                return False

    return True


def get_address_and_port() -> Tuple[str, int]:
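
To show the simplified boolean contract of check_for_failure in isolation, here is a small self-contained example. The Worker actor and its train_step method are made up for the demo, and the helper is re-declared locally to mirror the utility above rather than imported from Ray internals.

```python
import ray
from ray.exceptions import RayActorError


@ray.remote
class Worker:
    def train_step(self):
        return {"loss": 0.1}


def check_for_failure(remote_values):
    """Return True only if every ObjectRef resolves without an actor failure."""
    unfinished = list(remote_values)
    while unfinished:
        finished, unfinished = ray.wait(unfinished)
        for object_ref in finished:
            try:
                ray.get(object_ref)
            except RayActorError:
                # A dead actor marks its ObjectRefs as finished; ray.get then
                # surfaces the failure as a RayActorError.
                return False
    return True


if __name__ == "__main__":
    ray.init()
    workers = [Worker.remote() for _ in range(2)]
    refs = [w.train_step.remote() for w in workers]
    if check_for_failure(refs):
        print(ray.get(refs))  # all workers finished their step successfully
    else:
        print("At least one worker died; restart all workers and retry.")
```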
