diff --git a/python/ray/tune/sync_client.py b/python/ray/tune/sync_client.py index c5f89eedee3b..b85dec32c1fb 100644 --- a/python/ray/tune/sync_client.py +++ b/python/ray/tune/sync_client.py @@ -5,6 +5,7 @@ import pathlib import subprocess import tempfile +import time import types import warnings @@ -30,7 +31,7 @@ def noop(*args): return -def get_sync_client(sync_function, delete_function=None): +def get_sync_client(sync_function, delete_function=None) -> Optional["SyncClient"]: """Returns a sync client. Args: @@ -58,7 +59,7 @@ def get_sync_client(sync_function, delete_function=None): return client_cls(sync_function, sync_function, delete_function) -def get_cloud_sync_client(remote_path): +def get_cloud_sync_client(remote_path) -> "CommandBasedClient": """Returns a CommandBasedClient that can sync to/from remote storage. Args: @@ -158,6 +159,10 @@ def wait(self): """Waits for current sync to complete, if asynchronously started.""" pass + def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5): + """Wait for current sync to complete or retries on error.""" + pass + def reset(self): """Resets state.""" pass @@ -251,6 +256,8 @@ def __init__( self.logfile = None self._closed = False self.cmd_process = None + # Keep track of last command for retry + self._last_cmd = None def set_logdir(self, logdir): """Sets the directory to log sync execution output in. @@ -273,6 +280,11 @@ def _get_logfile(self): else: return self.logfile + def _start_process(self, cmd: str) -> subprocess.Popen: + return subprocess.Popen( + cmd, shell=True, stderr=subprocess.PIPE, stdout=self._get_logfile() + ) + def sync_up(self, source, target, exclude: Optional[List] = None): return self._execute(self.sync_up_template, source, target, exclude) @@ -284,13 +296,15 @@ def sync_down(self, source, target, exclude: Optional[List] = None): def delete(self, target): if self.is_running: - logger.warning("Last sync client cmd still in progress, skipping.") + logger.warning( + f"Last sync client cmd still in progress, " + f"skipping deletion of {target}" + ) return False final_cmd = self.delete_template.format(target=quote(target), options="") logger.debug("Running delete: {}".format(final_cmd)) - self.cmd_process = subprocess.Popen( - final_cmd, shell=True, stderr=subprocess.PIPE, stdout=self._get_logfile() - ) + self._last_cmd = final_cmd + self.cmd_process = self._start_process(final_cmd) return True def wait(self): @@ -306,10 +320,28 @@ def wait(self): "Error message ({}): {}".format(args, code, error_msg) ) + def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5): + assert max_retries > 0 + for i in range(max_retries - 1): + try: + self.wait() + except TuneError as e: + logger.error( + f"Caught sync error: {e}. " + f"Retrying after sleeping for {backoff_s} seconds..." + ) + time.sleep(backoff_s) + self.cmd_process = self._start_process(self._last_cmd) + continue + return + self.cmd_process = None + raise TuneError(f"Failed sync even after {max_retries} retries.") + def reset(self): if self.is_running: logger.warning("Sync process still running but resetting anyways.") self.cmd_process = None + self._last_cmd = None def close(self): if self.logfile: @@ -329,7 +361,10 @@ def is_running(self): def _execute(self, sync_template, source, target, exclude: Optional[List] = None): """Executes sync_template on source and target.""" if self.is_running: - logger.warning("Last sync client cmd still in progress, skipping.") + logger.warning( + f"Last sync client cmd still in progress, " + f"skipping sync from {source} to {target}." + ) return False if exclude and self.exclude_template: @@ -355,9 +390,8 @@ def _to_regex(pattern: str) -> str: source=quote(source), target=quote(target), options=option_str ) logger.debug("Running sync: {}".format(final_cmd)) - self.cmd_process = subprocess.Popen( - final_cmd, shell=True, stderr=subprocess.PIPE, stdout=self._get_logfile() - ) + self._last_cmd = final_cmd + self.cmd_process = self._start_process(final_cmd) return True @staticmethod diff --git a/python/ray/tune/tests/test_sync.py b/python/ray/tune/tests/test_sync.py index 99151da2546b..2c4435963d5e 100644 --- a/python/ray/tune/tests/test_sync.py +++ b/python/ray/tune/tests/test_sync.py @@ -415,6 +415,41 @@ def __init__(self, id, logdir): trial_syncer = syncer_callback._get_trial_syncer(trial) self.assertEqual(trial_syncer.sync_client, NOOP) + def testSyncWaitRetry(self): + class CountingClient(CommandBasedClient): + def __init__(self, *args, **kwargs): + self._sync_ups = 0 + self._sync_downs = 0 + super(CountingClient, self).__init__(*args, **kwargs) + + def _start_process(self, cmd): + if "UPLOAD" in cmd: + self._sync_ups += 1 + elif "DOWNLOAD" in cmd: + self._sync_downs += 1 + if self._sync_downs == 1: + self._last_cmd = "echo DOWNLOAD && true" + return super(CountingClient, self)._start_process(cmd) + + client = CountingClient( + "echo UPLOAD {source} {target} && false", + "echo DOWNLOAD {source} {target} && false", + "echo DELETE {target}", + ) + + # Fail always + with self.assertRaisesRegex(TuneError, "Failed sync even after"): + client.sync_up("test_source", "test_target") + client.wait_or_retry(max_retries=3, backoff_s=0) + + self.assertEquals(client._sync_ups, 3) + + # Succeed after second try + client.sync_down("test_source", "test_target") + client.wait_or_retry(max_retries=3, backoff_s=0) + + self.assertEquals(client._sync_downs, 2) + if __name__ == "__main__": import pytest diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index fddb442ca268..88e143957e84 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -441,7 +441,7 @@ def _maybe_save_to_cloud(self, checkpoint_dir): self.storage_client.sync_up( checkpoint_dir, self._storage_path(checkpoint_dir) ) - self.storage_client.wait() + self.storage_client.wait_or_retry() def save_to_object(self): """Saves the current model state to a Python object. @@ -488,7 +488,7 @@ def restore(self, checkpoint_path): os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir), os.path.join(self.logdir, rel_checkpoint_dir), ) - self.storage_client.wait() + self.storage_client.wait_or_retry() # Ensure TrialCheckpoints are converted if isinstance(checkpoint_path, TrialCheckpoint): @@ -557,6 +557,7 @@ def delete_checkpoint(self, checkpoint_path): else: if self.uses_cloud_checkpointing: self.storage_client.delete(self._storage_path(checkpoint_dir)) + self.storage_client.wait_or_retry() if os.path.exists(checkpoint_dir): shutil.rmtree(checkpoint_dir) diff --git a/release/tune_tests/cloud_tests/workloads/run_cloud_test.py b/release/tune_tests/cloud_tests/workloads/run_cloud_test.py index 73755ba766d1..716c06fe0b9f 100644 --- a/release/tune_tests/cloud_tests/workloads/run_cloud_test.py +++ b/release/tune_tests/cloud_tests/workloads/run_cloud_test.py @@ -266,7 +266,12 @@ def send_signal_after_wait(process: subprocess.Popen, signal: int, wait: int = 3 time.sleep(wait) if process.poll() is not None: - raise RuntimeError(f"Process {process.pid} already terminated.") + raise RuntimeError( + f"Process {process.pid} already terminated. This usually means " + f"that some of the trials ERRORed (e.g. because they couldn't be " + f"restored. Try re-running this test to see if this fixes the " + f"issue." + ) print(f"Sending signal {signal} to process {process.pid}") process.send_signal(signal)