[tune] Retry cloud sync up/down/delete on fail (ray-project#22029)
krfricke authored and simonsays1980 committed Feb 27, 2022
1 parent 4353b2e commit 7d7d119
Showing 4 changed files with 88 additions and 13 deletions.
54 changes: 44 additions & 10 deletions python/ray/tune/sync_client.py
@@ -5,6 +5,7 @@
import pathlib
import subprocess
import tempfile
import time
import types
import warnings

@@ -30,7 +31,7 @@ def noop(*args):
return


def get_sync_client(sync_function, delete_function=None):
def get_sync_client(sync_function, delete_function=None) -> Optional["SyncClient"]:
"""Returns a sync client.
Args:
@@ -58,7 +59,7 @@ def get_sync_client(sync_function, delete_function=None):
return client_cls(sync_function, sync_function, delete_function)


def get_cloud_sync_client(remote_path):
def get_cloud_sync_client(remote_path) -> "CommandBasedClient":
"""Returns a CommandBasedClient that can sync to/from remote storage.
Args:
@@ -158,6 +159,10 @@ def wait(self):
"""Waits for current sync to complete, if asynchronously started."""
pass

def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5):
"""Wait for current sync to complete or retries on error."""
pass

def reset(self):
"""Resets state."""
pass
@@ -251,6 +256,8 @@ def __init__(
self.logfile = None
self._closed = False
self.cmd_process = None
# Keep track of last command for retry
self._last_cmd = None

def set_logdir(self, logdir):
"""Sets the directory to log sync execution output in.
@@ -273,6 +280,11 @@ def _get_logfile(self):
else:
return self.logfile

def _start_process(self, cmd: str) -> subprocess.Popen:
return subprocess.Popen(
cmd, shell=True, stderr=subprocess.PIPE, stdout=self._get_logfile()
)

def sync_up(self, source, target, exclude: Optional[List] = None):
return self._execute(self.sync_up_template, source, target, exclude)

@@ -284,13 +296,15 @@ def sync_down(self, source, target, exclude: Optional[List] = None):

def delete(self, target):
if self.is_running:
logger.warning("Last sync client cmd still in progress, skipping.")
logger.warning(
f"Last sync client cmd still in progress, "
f"skipping deletion of {target}"
)
return False
final_cmd = self.delete_template.format(target=quote(target), options="")
logger.debug("Running delete: {}".format(final_cmd))
self.cmd_process = subprocess.Popen(
final_cmd, shell=True, stderr=subprocess.PIPE, stdout=self._get_logfile()
)
self._last_cmd = final_cmd
self.cmd_process = self._start_process(final_cmd)
return True

def wait(self):
@@ -306,10 +320,28 @@ def wait(self):
"Error message ({}): {}".format(args, code, error_msg)
)

def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5):
assert max_retries > 0
for i in range(max_retries - 1):
try:
self.wait()
except TuneError as e:
logger.error(
f"Caught sync error: {e}. "
f"Retrying after sleeping for {backoff_s} seconds..."
)
time.sleep(backoff_s)
self.cmd_process = self._start_process(self._last_cmd)
continue
return
self.cmd_process = None
raise TuneError(f"Failed sync even after {max_retries} retries.")

def reset(self):
if self.is_running:
logger.warning("Sync process still running but resetting anyways.")
self.cmd_process = None
self._last_cmd = None

def close(self):
if self.logfile:
@@ -329,7 +361,10 @@ def is_running(self):
def _execute(self, sync_template, source, target, exclude: Optional[List] = None):
"""Executes sync_template on source and target."""
if self.is_running:
logger.warning("Last sync client cmd still in progress, skipping.")
logger.warning(
f"Last sync client cmd still in progress, "
f"skipping sync from {source} to {target}."
)
return False

if exclude and self.exclude_template:
@@ -355,9 +390,8 @@ def _to_regex(pattern: str) -> str:
source=quote(source), target=quote(target), options=option_str
)
logger.debug("Running sync: {}".format(final_cmd))
self.cmd_process = subprocess.Popen(
final_cmd, shell=True, stderr=subprocess.PIPE, stdout=self._get_logfile()
)
self._last_cmd = final_cmd
self.cmd_process = self._start_process(final_cmd)
return True

@staticmethod
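The heart of the change is the new wait_or_retry() method above: it wraps the existing wait() in a loop, and on a TuneError it sleeps for backoff_s seconds and relaunches the remembered _last_cmd, giving up only after max_retries attempts. As a rough illustration of that retry-with-backoff idea, here is a minimal, standalone sketch; the run_with_retry helper and the use of RuntimeError are illustrative stand-ins, not part of the commit:

import subprocess
import time


def run_with_retry(cmd: str, max_retries: int = 3, backoff_s: int = 5) -> None:
    """Run a shell command, retrying with a fixed backoff until it succeeds
    or the retry budget is exhausted (simplified version of wait_or_retry)."""
    assert max_retries > 0
    for attempt in range(1, max_retries + 1):
        process = subprocess.Popen(cmd, shell=True)
        if process.wait() == 0:
            return  # command finished successfully
        if attempt < max_retries:
            print(f"Attempt {attempt} failed, retrying in {backoff_s}s ...")
            time.sleep(backoff_s)
    raise RuntimeError(f"Failed sync even after {max_retries} retries.")


# A command that always fails raises after three attempts:
# run_with_retry("false", max_retries=3, backoff_s=0)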
35 changes: 35 additions & 0 deletions python/ray/tune/tests/test_sync.py
@@ -415,6 +415,41 @@ def __init__(self, id, logdir):
trial_syncer = syncer_callback._get_trial_syncer(trial)
self.assertEqual(trial_syncer.sync_client, NOOP)

def testSyncWaitRetry(self):
class CountingClient(CommandBasedClient):
def __init__(self, *args, **kwargs):
self._sync_ups = 0
self._sync_downs = 0
super(CountingClient, self).__init__(*args, **kwargs)

def _start_process(self, cmd):
if "UPLOAD" in cmd:
self._sync_ups += 1
elif "DOWNLOAD" in cmd:
self._sync_downs += 1
if self._sync_downs == 1:
self._last_cmd = "echo DOWNLOAD && true"
return super(CountingClient, self)._start_process(cmd)

client = CountingClient(
"echo UPLOAD {source} {target} && false",
"echo DOWNLOAD {source} {target} && false",
"echo DELETE {target}",
)

# Fail always
with self.assertRaisesRegex(TuneError, "Failed sync even after"):
client.sync_up("test_source", "test_target")
client.wait_or_retry(max_retries=3, backoff_s=0)

self.assertEquals(client._sync_ups, 3)

# Succeed after second try
client.sync_down("test_source", "test_target")
client.wait_or_retry(max_retries=3, backoff_s=0)

self.assertEquals(client._sync_downs, 2)


if __name__ == "__main__":
import pytest
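The test above makes the shell templates fail on purpose ("&& false" forces a non-zero exit code, so wait() raises) and counts how often _start_process() is invoked; the download case swaps _last_cmd for a succeeding command after the first attempt to cover the recovery path. The same behavior can be checked by hand against a plain CommandBasedClient; a small sketch, assuming the import paths from this commit (ray.tune.sync_client and ray.tune.error):

from ray.tune.error import TuneError
from ray.tune.sync_client import CommandBasedClient

client = CommandBasedClient(
    "echo UPLOAD {source} {target} && false",    # upload template: always fails
    "echo DOWNLOAD {source} {target} && false",  # download template: always fails
    "echo DELETE {target}",                      # delete template: succeeds
)

client.sync_up("test_source", "test_target")
try:
    client.wait_or_retry(max_retries=3, backoff_s=0)
except TuneError as e:
    print(e)  # Failed sync even after 3 retries.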
5 changes: 3 additions & 2 deletions python/ray/tune/trainable.py
Expand Up @@ -441,7 +441,7 @@ def _maybe_save_to_cloud(self, checkpoint_dir):
self.storage_client.sync_up(
checkpoint_dir, self._storage_path(checkpoint_dir)
)
self.storage_client.wait()
self.storage_client.wait_or_retry()

def save_to_object(self):
"""Saves the current model state to a Python object.
@@ -488,7 +488,7 @@ def restore(self, checkpoint_path):
os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir),
os.path.join(self.logdir, rel_checkpoint_dir),
)
self.storage_client.wait()
self.storage_client.wait_or_retry()

# Ensure TrialCheckpoints are converted
if isinstance(checkpoint_path, TrialCheckpoint):
@@ -557,6 +557,7 @@ def delete_checkpoint(self, checkpoint_path):
else:
if self.uses_cloud_checkpointing:
self.storage_client.delete(self._storage_path(checkpoint_dir))
self.storage_client.wait_or_retry()

if os.path.exists(checkpoint_dir):
shutil.rmtree(checkpoint_dir)
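On the Trainable side, each cloud operation is now followed by wait_or_retry() instead of wait(), so a transient storage failure no longer fails the checkpoint save, restore, or delete outright. Condensed into a sketch (client stands in for Trainable.storage_client; the path arguments are placeholders rather than the real Trainable attributes):

def save_to_cloud(client, checkpoint_dir: str, remote_dir: str) -> None:
    client.sync_up(checkpoint_dir, remote_dir)
    client.wait_or_retry()  # previously client.wait(); now retries transient errors


def restore_from_cloud(client, remote_dir: str, checkpoint_dir: str) -> None:
    client.sync_down(remote_dir, checkpoint_dir)
    client.wait_or_retry()  # previously client.wait()


def delete_cloud_checkpoint(client, remote_dir: str) -> None:
    client.delete(remote_dir)
    client.wait_or_retry()  # new: deletions are now awaited and retried as well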
7 changes: 6 additions & 1 deletion release/tune_tests/cloud_tests/workloads/run_cloud_test.py
@@ -266,7 +266,12 @@ def send_signal_after_wait(process: subprocess.Popen, signal: int, wait: int = 3
time.sleep(wait)

if process.poll() is not None:
raise RuntimeError(f"Process {process.pid} already terminated.")
raise RuntimeError(
f"Process {process.pid} already terminated. This usually means "
f"that some of the trials ERRORed (e.g. because they couldn't be "
f"restored. Try re-running this test to see if this fixes the "
f"issue."
)

print(f"Sending signal {signal} to process {process.pid}")
process.send_signal(signal)