[tune] Only sync down from cloud if needed (ray-project#26725)
Currently, trainables will try to sync up/down temporary checkpoints from cloud storage, leading to errors. These errors come up e.g. with PBT, which heavily uses saving/restoring from objects.

Instead, we should not sync these temporary checkpoints up at all, and we should generally not sync down if a local checkpoint directory exists; this also prevents us from trying to sync down non-existent temporary checkpoint directories.

See ray-project#26714

Signed-off-by: Kai Fricke <[email protected]>
Signed-off-by: Rohan138 <[email protected]>
krfricke authored and Rohan138 committed Jul 28, 2022
1 parent 7f7f4a7 commit 3111df8
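
Before the diff, a minimal sketch of the behavior this patch establishes. This is illustrative only, not code from the commit: the MyTrainable subclass and the memory:///test/location URI are assumptions (the URI mirrors the new test below).

# Sketch only: a trivial Trainable subclass to exercise the patched flow.
from ray.tune import Trainable

class MyTrainable(Trainable):
    def step(self):
        return {"score": 1}

    def save_checkpoint(self, checkpoint_dir):
        return checkpoint_dir

    def load_checkpoint(self, checkpoint_dir):
        pass

trainable = MyTrainable(remote_checkpoint_dir="memory:///test/location")

# Regular checkpoints are still uploaded to the remote checkpoint dir.
trainable.save()

# save_to_object() now calls save(..., prevent_upload=True) internally,
# so the temporary checkpoint it creates is never synced up.
obj = trainable.save_to_object()

# restore_from_object() materializes the checkpoint locally; because a
# local checkpoint directory then exists, no sync-down is attempted.
trainable.restore_from_object(obj)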
Showing 2 changed files with 62 additions and 3 deletions.
45 changes: 45 additions & 0 deletions python/ray/tune/tests/test_trainable.py
@@ -8,6 +8,7 @@
import ray
from ray import tune
from ray.air import session, Checkpoint
from ray.air._internal.remote_storage import download_from_uri
from ray.tune.trainable import wrap_function


@@ -81,6 +82,11 @@ def function_trainable_directory(config):

@pytest.mark.parametrize("return_type", ["object", "root", "subdir", "checkpoint"])
def test_save_load_checkpoint_path_class(ray_start_2_cpus, return_type):
"""Assert that restoring from a Trainable.save() future works with
class trainables.
Needs Ray cluster so we get actual futures.
"""
trainable = ray.remote(SavingTrainable).remote(return_type=return_type)

saving_future = trainable.save.remote()
@@ -95,6 +101,11 @@ def test_save_load_checkpoint_path_class(ray_start_2_cpus, return_type):

@pytest.mark.parametrize("return_type", ["object", "root", "subdir", "checkpoint"])
def test_save_load_checkpoint_object_class(ray_start_2_cpus, return_type):
"""Assert that restoring from a Trainable.save_to_object() future works with
class trainables.
Needs Ray cluster so we get actual futures.
"""
trainable = ray.remote(SavingTrainable).remote(return_type=return_type)

saving_future = trainable.save_to_object.remote()
@@ -111,6 +122,11 @@ def test_save_load_checkpoint_object_class(ray_start_2_cpus, return_type):
"fn_trainable", [function_trainable_dict, function_trainable_directory]
)
def test_save_load_checkpoint_path_fn(ray_start_2_cpus, fn_trainable):
"""Assert that restoring from a Trainable.save() future works with
function trainables.
Needs Ray cluster so we get actual futures.
"""
trainable_cls = wrap_function(fn_trainable)
trainable = ray.remote(trainable_cls).remote()
ray.get(trainable.train.remote())
@@ -129,6 +145,11 @@ def test_save_load_checkpoint_path_fn(ray_start_2_cpus, fn_trainable):
"fn_trainable", [function_trainable_dict, function_trainable_directory]
)
def test_save_load_checkpoint_object_fn(ray_start_2_cpus, fn_trainable):
"""Assert that restoring from a Trainable.save_to_object() future works with
function trainables.
Needs Ray cluster so we get actual futures.
"""
trainable_cls = wrap_function(fn_trainable)
trainable = ray.remote(trainable_cls).remote()
ray.get(trainable.train.remote())
@@ -143,6 +164,30 @@ def test_save_load_checkpoint_object_fn(ray_start_2_cpus, fn_trainable):
    ray.get(restoring_future)


def test_checkpoint_object_no_sync(tmpdir):
    """Asserts that save_to_object() and restore_from_object() do not sync up/down"""
    trainable = SavingTrainable(
        "object", remote_checkpoint_dir="memory:///test/location"
    )

    # Save checkpoint
    trainable.save()

    check_dir = tmpdir / "check_save"
    download_from_uri(uri="memory:///test/location", local_path=str(check_dir))
    assert os.listdir(str(check_dir)) == ["checkpoint_000000"]

    # Save to object
    obj = trainable.save_to_object()

    check_dir = tmpdir / "check_save_obj"
    download_from_uri(uri="memory:///test/location", local_path=str(check_dir))
    assert os.listdir(str(check_dir)) == ["checkpoint_000000"]

    # Restore from object
    trainable.restore_from_object(obj)


if __name__ == "__main__":
    import sys

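A note on the new test: memory:/// URIs resolve to an in-memory (mock) filesystem in Ray's internal remote-storage helpers, which makes it possible to assert exactly what was uploaded without touching real cloud storage. A small round-trip sketch, assuming upload_to_uri is available alongside download_from_uri in the same internal module:

import os
import tempfile

from ray.air._internal.remote_storage import download_from_uri, upload_to_uri

with tempfile.TemporaryDirectory() as src:
    with open(os.path.join(src, "data.txt"), "w") as f:
        f.write("hello")
    # "Upload" into the in-memory filesystem.
    upload_to_uri(local_path=src, uri="memory:///demo/location")

with tempfile.TemporaryDirectory() as dst:
    download_from_uri(uri="memory:///demo/location", local_path=dst)
    assert os.listdir(dst) == ["data.txt"]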
20 changes: 17 additions & 3 deletions python/ray/tune/trainable/trainable.py
@@ -425,7 +425,9 @@ def _create_checkpoint_dir(
        )
        return checkpoint_dir

    def save(self, checkpoint_dir: Optional[str] = None) -> str:
    def save(
        self, checkpoint_dir: Optional[str] = None, prevent_upload: bool = False
    ) -> str:
        """Saves the current model state to a checkpoint.

        Subclasses should override ``save_checkpoint()`` instead to save state.
@@ -436,6 +438,7 @@ def save(self, checkpoint_dir: Optional[str] = None) -> str:
        Args:
            checkpoint_dir: Optional dir to place the checkpoint.
            prevent_upload: If True, will not upload the saved checkpoint to cloud.

        Returns:
            The given or created checkpoint directory.
@@ -487,7 +490,8 @@ def save(self, checkpoint_dir: Optional[str] = None) -> str:
        TrainableUtil.write_metadata(checkpoint_dir, metadata)

        # Maybe sync to cloud
        self._maybe_save_to_cloud(checkpoint_dir)
        if not prevent_upload:
            self._maybe_save_to_cloud(checkpoint_dir)

        return checkpoint_dir

@@ -512,6 +516,15 @@ def _maybe_save_to_cloud(self, checkpoint_dir: str) -> bool:
        return True

    def _maybe_load_from_cloud(self, checkpoint_path: str) -> bool:
        if os.path.exists(checkpoint_path):
            try:
                TrainableUtil.find_checkpoint_dir(checkpoint_path)
            except Exception:
                pass
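                # No valid checkpoint found at this path; fall through and
                # try the cloud sync-down below.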
            else:
                # If the path exists locally, we don't have to download
                return True

        if not self.uses_cloud_checkpointing:
            return False

@@ -541,12 +554,13 @@ def save_to_object(self):
"""Saves the current model state to a Python object.
It also saves to disk but does not return the checkpoint path.
It does not save the checkpoint to cloud storage.
Returns:
Object holding checkpoint data.
"""
temp_container_dir = tempfile.mkdtemp("save_to_object", dir=self.logdir)
checkpoint_dir = self.save(temp_container_dir)
checkpoint_dir = self.save(temp_container_dir, prevent_upload=True)

obj_ref = Checkpoint.from_directory(checkpoint_dir).to_bytes()
shutil.rmtree(temp_container_dir)
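Taken together, the restore path now effectively applies a "local checkpoint wins" rule. A standalone sketch of that decision, with a stand-in for TrainableUtil.find_checkpoint_dir(); the .is_checkpoint marker check here is illustrative, not the exact implementation:

import os

def should_sync_down(checkpoint_path: str, uses_cloud_checkpointing: bool) -> bool:
    """Sketch of the patched decision: prefer an existing local checkpoint."""

    def looks_like_checkpoint(path: str) -> bool:
        # Stand-in for TrainableUtil.find_checkpoint_dir(), which raises if
        # no checkpoint is found; here we just look for a marker file.
        return any(".is_checkpoint" in files for _, _, files in os.walk(path))

    if os.path.exists(checkpoint_path) and looks_like_checkpoint(checkpoint_path):
        # A valid local checkpoint exists: no download needed.
        return False
    # Otherwise, only sync down if cloud checkpointing is configured at all.
    return uses_cloud_checkpointing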
