[tune] Chunk file transfers in cross-node checkpoint syncing #23804

Merged · 14 commits · Apr 12, 2022
python/ray/tune/sync_client.py — 147 changes: 41 additions & 106 deletions
@@ -2,24 +2,21 @@
import distutils
import distutils.spawn
import inspect
import io
import logging
import os
import pathlib
import shutil
import subprocess
import tarfile
import tempfile
import time
import types
import warnings

from typing import Optional, List, Callable, Union, Tuple, Dict
from typing import Optional, List, Callable, Union, Tuple

from shlex import quote

import ray
from ray.tune.error import TuneError
from ray.tune.utils.file_transfer import sync_dir_between_nodes, delete_on_node
from ray.util.annotations import PublicAPI
from ray.util.debug import log_once
from ray.util.ml_utils.cloud import (
@@ -435,73 +432,6 @@ def _validate_exclude_template(exclude_template):
)


def _get_recursive_files_and_stats(path: str) -> Dict[str, Tuple[float, int]]:
"""Return dict of files mapping to stats in ``path``.

This function scans a directory ``path`` recursively and returns a dict
mapping each contained file to a tuple of (mtime, filesize).

mtime and filesize are returned from ``os.lstat`` and are usually a
floating point number (timestamp) and an int (filesize in bytes).
"""
files_stats = {}
for root, dirs, files in os.walk(path, topdown=False):
rel_root = os.path.relpath(root, path)
for file in files:
key = os.path.join(rel_root, file)
stat = os.lstat(os.path.join(path, key))
files_stats[key] = stat.st_mtime, stat.st_size

return files_stats


# Only export once
_remote_get_recursive_files_and_stats = ray.remote(_get_recursive_files_and_stats)


@ray.remote
def _pack_dir(
source_dir: str, files_stats: Optional[Dict[str, Tuple[float, int]]]
) -> bytes:
stream = io.BytesIO()
with tarfile.open(fileobj=stream, mode="w:gz", format=tarfile.PAX_FORMAT) as tar:
if not files_stats:
# If no `files_stats` is passed, pack whole directory
tar.add(source_dir, arcname="", recursive=True)
else:
# Otherwise, only pack differing files
tar.add(source_dir, arcname="", recursive=False)
for root, dirs, files in os.walk(source_dir, topdown=False):
rel_root = os.path.relpath(root, source_dir)
# Always add all directories
for dir in dirs:
key = os.path.join(rel_root, dir)
tar.add(os.path.join(source_dir, key), arcname=key, recursive=False)
# Add files where our information differs
for file in files:
key = os.path.join(rel_root, file)
stat = os.lstat(os.path.join(source_dir, key))
file_stat = stat.st_mtime, stat.st_size
if key not in files_stats or file_stat != files_stats[key]:
tar.add(os.path.join(source_dir, key), arcname=key)

return stream.getvalue()


@ray.remote
def _unpack_dir(stream: bytes, target_dir: str):
with tarfile.open(fileobj=io.BytesIO(stream)) as tar:
tar.extractall(target_dir)


@ray.remote
def _delete_dir(target_dir: str) -> bool:
if os.path.exists(target_dir):
shutil.rmtree(target_dir)
return True
return False


class RemoteTaskClient(SyncClient):
"""Sync client that uses remote tasks to synchronize two directories.

@@ -523,16 +453,18 @@ class RemoteTaskClient(SyncClient):
will not kill the previous sync command, so it may still be executed.
"""

def __init__(self, store_pack_future: bool = False):
def __init__(self, _store_remotes: bool = False):
Review thread on _store_remotes:

Member: Any reason to make this a private argument? store_remotes should be fine, no?

Contributor Author: It is only needed for testing (so we can access and inspect the futures), so it's nothing a user would usually do or use.

Member: Oh, I see. Can you please make it super clear that nobody should flip this parameter except for testing? A docstring, please?

Contributor Author: I can add this in the next update (coming soon :-)). I think it's not urgent, as users never instantiate SyncClients themselves. It's an internal concept instantiated by Ray Tune automatically, so nobody ever calls SomeSyncClient(..).
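
A minimal sketch of that testing-only usage, pieced together from the test changes later in this diff (the temp directories and single-node IP are illustrative, and Ray must be initialized):

    import tempfile

    import ray
    from ray.tune.sync_client import RemoteTaskClient

    ray.init()

    # Testing-only: _store_remotes=True keeps references to the pack actor and
    # the file-stats future so a test can inspect them after starting a sync.
    this_node_ip = ray.util.get_node_ip_address()
    temp_source = tempfile.mkdtemp()
    temp_up_target = tempfile.mkdtemp()

    client = RemoteTaskClient(_store_remotes=True)
    client.sync_up(source=temp_source, target=(this_node_ip, temp_up_target))

    files_stats = ray.get(client._stored_files_stats)  # per-file (mtime, size) map
    tarball = ray.get(client._stored_pack_actor_ref.get_full_data.remote())  # packed tar bytes
    client.wait()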

# Used for testing
self._store_pack_future = store_pack_future
self._store_remotes = _store_remotes
self._stored_pack_actor_ref = None
self._stored_files_stats_future = None

self._pack_future = None
self._sync_future = None

self._last_source_tuple = None
self._last_target_tuple = None
self._last_files_stats = None

self._max_size_bytes = None # No file size limit

def _sync_still_running(self) -> bool:
if not self._sync_future:
@@ -560,12 +492,7 @@ def sync_down(
self._last_source_tuple = source_ip, source_path
self._last_target_tuple = target_ip, target

# Get existing files on local node before packing on remote node
self._last_files_stats = _get_recursive_files_and_stats(target)

return self._execute_sync(
self._last_source_tuple, self._last_target_tuple, self._last_files_stats
)
return self._execute_sync(self._last_source_tuple, self._last_target_tuple)

def sync_up(
self, source: str, target: Tuple[str, str], exclude: Optional[List] = None
@@ -583,39 +510,46 @@ def sync_up(
self._last_source_tuple = source_ip, source
self._last_target_tuple = target_ip, target_path

# Get existing files on remote node before packing on local node
self._last_files_stats = _remote_get_recursive_files_and_stats.options(
num_cpus=0, resources={f"node:{target_ip}": 0.01}
).remote(target_path)

return self._execute_sync(
self._last_source_tuple, self._last_target_tuple, self._last_files_stats
)
return self._execute_sync(self._last_source_tuple, self._last_target_tuple)

def _execute_sync(
self,
source_tuple: Tuple[str, str],
target_tuple: Tuple[str, str],
files_stats: Optional[Dict[str, Tuple[float, int]]] = None,
) -> bool:
source_ip, source_path = source_tuple
target_ip, target_path = target_tuple

pack_on_source_node = _pack_dir.options(
num_cpus=0, resources={f"node:{source_ip}": 0.01}
)
unpack_on_target_node = _unpack_dir.options(
num_cpus=0, resources={f"node:{target_ip}": 0.01}
self._sync_future, pack_actor, files_stats = sync_dir_between_nodes(
source_ip=source_ip,
source_path=source_path,
target_ip=target_ip,
target_path=target_path,
return_futures=True,
max_size_bytes=self._max_size_bytes,
)

pack_future = pack_on_source_node.remote(source_path, files_stats)
if self._store_pack_future:
self._pack_future = pack_future
self._sync_future = unpack_on_target_node.remote(pack_future, target_path)
if self._store_remotes:
self._stored_pack_actor_ref = pack_actor
self._stored_files_stats = files_stats

return True

def delete(self, target: str):
pass
if not self._last_target_tuple:
logger.error(
f"Could not delete path {target} as the target node is not known."
)
return

node_ip = self._last_target_tuple[0]

try:
delete_on_node(node_ip=node_ip, path=target)
except Exception as e:
logger.error(
f"Could not delete path {target} on remote node {node_ip}: {e}"
)

def wait(self):
Review thread on wait():

Contributor: Since we can sync arbitrarily large files now, we may run into this method more often than before and end up in a blocking situation. We need to think about how to provide visibility if/when that happens.

Contributor Author: FWIW, we synced arbitrarily large files before as well, just with rsync rather than remote tasks. And we already warn via with warn_if_slow("callbacks.on_trial_save"), though we may want to make that message a bit more insightful.
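
For reference, a rough sketch of that warning pattern, assuming warn_if_slow is the context manager exposed by ray.tune.utils.util (the wrapped sleep merely stands in for a blocking checkpoint sync):

    import time

    from ray.tune.utils.util import warn_if_slow  # assumed import path

    # If the wrapped block exceeds the helper's duration threshold, Tune logs a
    # warning that names the slow operation ("callbacks.on_trial_save" here).
    with warn_if_slow("callbacks.on_trial_save"):
        time.sleep(1.0)  # placeholder for waiting on the sync future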

if self._sync_future:
@@ -628,7 +562,8 @@ def wait(self):
f"{self._last_target_tuple}: {e}"
) from e
self._sync_future = None
self._pack_future = None
self._stored_pack_actor_ref = None
self._stored_files_stats_future = None

def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5):
assert max_retries > 0
@@ -646,22 +581,22 @@ def wait_or_retry(self, max_retries: int = 3, backoff_s: int = 5):
self._execute_sync(
self._last_source_tuple,
self._last_target_tuple,
self._last_files_stats,
)
continue
return
self._sync_future = None
self._pack_future = None
self._stored_pack_actor_ref = None
self._stored_files_stats_future = None
raise TuneError(f"Failed sync even after {max_retries} retries.")

def reset(self):
if self._sync_future:
logger.warning("Sync process still running but resetting anyways.")
self._sync_future = None
self._pack_future = None
self._last_source_tuple = None
self._last_target_tuple = None
self._last_files_stats = None
self._stored_pack_actor_ref = None
self._stored_files_stats_future = None

def close(self):
self._sync_future = None # Avoid warning
python/ray/tune/tests/test_sync.py — 70 changes: 65 additions & 5 deletions
@@ -14,6 +14,7 @@
from collections import deque

import ray
from ray.exceptions import RayTaskError
from ray.rllib import _register_all

from ray import tune
@@ -28,6 +29,7 @@
SyncerCallback,
)
from ray.tune.utils.callback import create_default_callbacks
from ray.tune.utils.file_transfer import sync_dir_between_nodes, delete_on_node


class TestSyncFunctionality(unittest.TestCase):
@@ -455,6 +457,64 @@ def _start_process(self, cmd):

self.assertEquals(client._sync_downs, 2)

def testSyncBetweenNodesAndDelete(self):
temp_source = tempfile.mkdtemp()
temp_up_target = tempfile.mkdtemp()
temp_down_target = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, temp_source)
self.addCleanup(shutil.rmtree, temp_up_target, ignore_errors=True)
self.addCleanup(shutil.rmtree, temp_down_target)

os.makedirs(os.path.join(temp_source, "dir_level0", "dir_level1"))
with open(os.path.join(temp_source, "dir_level0", "file_level1.txt"), "w") as f:
f.write("Data\n")

def check_dir_contents(path: str):
assert os.path.exists(os.path.join(path, "dir_level0"))
assert os.path.exists(os.path.join(path, "dir_level0", "dir_level1"))
assert os.path.exists(os.path.join(path, "dir_level0", "file_level1.txt"))
with open(os.path.join(path, "dir_level0", "file_level1.txt"), "r") as f:
assert f.read() == "Data\n"

# Sanity check
check_dir_contents(temp_source)

sync_dir_between_nodes(
source_ip=ray.util.get_node_ip_address(),
source_path=temp_source,
target_ip=ray.util.get_node_ip_address(),
target_path=temp_up_target,
)

# Check sync up
check_dir_contents(temp_up_target)

# Max size exceeded
with self.assertRaises(RayTaskError):
sync_dir_between_nodes(
source_ip=ray.util.get_node_ip_address(),
source_path=temp_up_target,
target_ip=ray.util.get_node_ip_address(),
target_path=temp_down_target,
max_size_bytes=2,
)
assert not os.listdir(temp_down_target)

sync_dir_between_nodes(
source_ip=ray.util.get_node_ip_address(),
source_path=temp_up_target,
target_ip=ray.util.get_node_ip_address(),
target_path=temp_down_target,
)

# Check sync down
check_dir_contents(temp_down_target)

# Delete in some dir
delete_on_node(node_ip=ray.util.get_node_ip_address(), path=temp_up_target)

assert not os.path.exists(temp_up_target)

def testSyncRemoteTaskOnlyDifferences(self):
"""Tests the RemoteTaskClient sync client.

@@ -489,7 +549,7 @@ def testSyncRemoteTaskOnlyDifferences(self):
this_node_ip = ray.util.get_node_ip_address()

# Sync everything up
client = RemoteTaskClient(store_pack_future=True)
client = RemoteTaskClient(_store_remotes=True)
client.sync_up(source=temp_source, target=(this_node_ip, temp_up_target))
client.wait()

@@ -526,8 +586,8 @@ def testSyncRemoteTaskOnlyDifferences(self):
client.sync_up(source=temp_source, target=(this_node_ip, temp_up_target))

# Hi-jack futures
files_stats = ray.get(client._last_files_stats)
tarball = ray.get(client._pack_future)
files_stats = ray.get(client._stored_files_stats)
tarball = ray.get(client._stored_pack_actor_ref.get_full_data.remote())
client.wait()

# Existing file should have new content
@@ -559,8 +619,8 @@ def testSyncRemoteTaskOnlyDifferences(self):
client.sync_down(source=(this_node_ip, temp_source), target=temp_down_target)

# Hi-jack futures
files_stats = client._last_files_stats
tarball = ray.get(client._pack_future)
files_stats = ray.get(client._stored_files_stats)
tarball = ray.get(client._stored_pack_actor_ref.get_full_data.remote())
client.wait()

# Existing file should have new content