From d0015a783f7e688597c649161c657076dc9ad97e Mon Sep 17 00:00:00 2001 From: Rohit Annigeri Date: Tue, 19 Jul 2022 15:58:43 +0000 Subject: [PATCH 1/5] [Tune] Fix storage client creation when sync function tpl is not provided (#26714) Signed-off-by: Rohit Annigeri --- python/ray/tune/trainable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 662580d40616..3bd36e74305f 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -171,7 +171,7 @@ def __init__( self.sync_function_tpl = sync_function_tpl or self._sync_function_tpl self.storage_client = None - if self.uses_cloud_checkpointing and self.sync_function_tpl: + if self.uses_cloud_checkpointing or self.sync_function_tpl: # Keep this only for custom sync functions and # backwards compatibility. # Todo (krfricke): We should find a way to register custom From 483090a73cd6a89736d8fc903caa439727f40b4e Mon Sep 17 00:00:00 2001 From: Rohit Annigeri Date: Wed, 20 Jul 2022 03:07:31 +0000 Subject: [PATCH 2/5] [Tune] Disable sync down if local path is present (#26714) Signed-off-by: Rohit Annigeri --- python/ray/tune/trainable.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 3bd36e74305f..18b7b7bce713 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -171,7 +171,7 @@ def __init__( self.sync_function_tpl = sync_function_tpl or self._sync_function_tpl self.storage_client = None - if self.uses_cloud_checkpointing or self.sync_function_tpl: + if self.uses_cloud_checkpointing and self.sync_function_tpl: # Keep this only for custom sync functions and # backwards compatibility. # Todo (krfricke): We should find a way to register custom @@ -544,6 +544,11 @@ def restore(self, checkpoint_path: str, checkpoint_node_ip: Optional[str] = None # Only keep for backwards compatibility self.storage_client.sync_down(external_uri, local_dir) self.storage_client.wait_or_retry() + elif os.path.exists(checkpoint_path): + try: + TrainableUtil.find_checkpoint_dir(checkpoint_path) + except Exception: + pass else: checkpoint = Checkpoint.from_uri(external_uri) retry_fn( From 748f98c9bc6524b79d603f734bea3d5509e2773a Mon Sep 17 00:00:00 2001 From: Rohit Annigeri Date: Wed, 20 Jul 2022 18:45:18 +0000 Subject: [PATCH 3/5] [Tune] Disable sync up for tmp objects(#26714) Signed-off-by: Rohit Annigeri --- python/ray/tune/trainable.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 18b7b7bce713..eaa2e4f8fc04 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -429,7 +429,7 @@ def get_state(self): "ray_version": ray.__version__, } - def save(self, checkpoint_dir: Optional[str] = None) -> str: + def save(self, checkpoint_dir: Optional[str] = None, prevent_upload: bool = False) -> str: """Saves the current model state to a checkpoint. Subclasses should override ``save_checkpoint()`` instead to save state. @@ -440,6 +440,7 @@ def save(self, checkpoint_dir: Optional[str] = None) -> str: Args: checkpoint_dir: Optional dir to place the checkpoint. + prevent_upload: bool flag to stop tmp folders from uploading Returns: str: path that points to xxx.pkl file. @@ -459,7 +460,8 @@ def save(self, checkpoint_dir: Optional[str] = None) -> str: ) # Maybe sync to cloud - self._maybe_save_to_cloud(checkpoint_dir) + if not prevent_upload: + self._maybe_save_to_cloud(checkpoint_dir) return checkpoint_path @@ -486,12 +488,12 @@ def save_to_object(self): """Saves the current model state to a Python object. It also saves to disk but does not return the checkpoint path. - + It doesn't save to cloud. Returns: Object holding checkpoint data. """ tmpdir = tempfile.mkdtemp("save_to_object", dir=self.logdir) - checkpoint_path = self.save(tmpdir) + checkpoint_path = self.save(tmpdir, prevent_upload=True) # Save all files in subtree and delete the tmpdir. obj = TrainableUtil.checkpoint_to_object(checkpoint_path) shutil.rmtree(tmpdir) From 47816aae65d3de6b6b9dee28ed764074727120f1 Mon Sep 17 00:00:00 2001 From: Rohit Annigeri Date: Thu, 21 Jul 2022 18:24:37 +0000 Subject: [PATCH 4/5] [Tune] Move checkpoint exists check outside(#26714) Signed-off-by: Rohit Annigeri --- python/ray/tune/trainable.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index eaa2e4f8fc04..0152c9e38f73 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -535,7 +535,18 @@ def restore(self, checkpoint_path: str, checkpoint_node_ip: Optional[str] = None if isinstance(checkpoint_path, TrialCheckpoint): checkpoint_path = checkpoint_path.local_path - if self.uses_cloud_checkpointing: + checkpoint_exists = False + if os.path.exists(checkpoint_path): + try: + TrainableUtil.find_checkpoint_dir(checkpoint_path) + except Exception: + pass + else: + checkpoint_exists = True + + if checkpoint_exists: + pass + elif self.uses_cloud_checkpointing: rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir( self.logdir, checkpoint_path ) @@ -546,11 +557,6 @@ def restore(self, checkpoint_path: str, checkpoint_node_ip: Optional[str] = None # Only keep for backwards compatibility self.storage_client.sync_down(external_uri, local_dir) self.storage_client.wait_or_retry() - elif os.path.exists(checkpoint_path): - try: - TrainableUtil.find_checkpoint_dir(checkpoint_path) - except Exception: - pass else: checkpoint = Checkpoint.from_uri(external_uri) retry_fn( From 0621d72f61455981e5f93af49af98e83e2a392c0 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Mon, 25 Jul 2022 14:42:24 +0100 Subject: [PATCH 5/5] Add test Signed-off-by: Kai Fricke --- .../ray/tune/tests/test_tune_save_restore.py | 78 ++++++++++++++++++- python/ray/tune/trainable.py | 11 ++- 2 files changed, 84 insertions(+), 5 deletions(-) diff --git a/python/ray/tune/tests/test_tune_save_restore.py b/python/ray/tune/tests/test_tune_save_restore.py index 76f7444b8365..3396d8da3a9d 100644 --- a/python/ray/tune/tests/test_tune_save_restore.py +++ b/python/ray/tune/tests/test_tune_save_restore.py @@ -1,12 +1,17 @@ # coding: utf-8 +import json import os import pickle +from typing import Union, Dict + +import pytest import shutil import tempfile import unittest import ray from ray import tune +from ray.ml.utils.remote_storage import download_from_uri from ray.rllib import _register_all from ray.tune import Trainable from ray.tune.utils import validate_save_restore @@ -178,8 +183,79 @@ def load_checkpoint(self, checkpoint_dir): validate_save_restore(MockTrainable, use_object_store=True) +class SavingTrainable(tune.Trainable): + def __init__(self, return_type: str, *args, **kwargs): + self.return_type = return_type + super(SavingTrainable, self).__init__(*args, **kwargs) + + def save_checkpoint(self, tmp_checkpoint_dir: str): + checkpoint_data = {"data": 1} + + if self.return_type == "object": + return checkpoint_data + + subdir = os.path.join(tmp_checkpoint_dir, "subdir") + os.makedirs(subdir, exist_ok=True) + checkpoint_file = os.path.join(subdir, "checkpoint.pkl") + with open(checkpoint_file, "w") as f: + f.write(json.dumps(checkpoint_data)) + + if self.return_type == "root": + return tmp_checkpoint_dir + elif self.return_type == "subdir": + return subdir + elif self.return_type == "checkpoint": + return checkpoint_file + + def load_checkpoint(self, checkpoint: Union[Dict, str]): + if self.return_type == "object": + assert isinstance(checkpoint, dict) + checkpoint_data = checkpoint + checkpoint_file = None + elif self.return_type == "root": + assert "subdir" not in checkpoint + checkpoint_file = os.path.join(checkpoint, "subdir", "checkpoint.pkl") + elif self.return_type == "subdir": + assert "subdir" in checkpoint + assert "checkpoint.pkl" not in checkpoint + checkpoint_file = os.path.join(checkpoint, "checkpoint.pkl") + else: # self.return_type == "checkpoint" + assert checkpoint.endswith("subdir/checkpoint.pkl") + checkpoint_file = checkpoint + + if checkpoint_file: + with open(checkpoint_file, "rb") as f: + checkpoint_data = json.load(f) + + assert checkpoint_data == {"data": 1}, checkpoint_data + + +# Note: In Ray 2.0, this test lives in test_trainable.py +def test_checkpoint_object_no_sync(tmpdir): + """Asserts that save_to_object() and restore_from_object() do not sync up/down""" + trainable = SavingTrainable( + "object", remote_checkpoint_dir="memory:///test/location" + ) + + # Save checkpoint + trainable.save() + + check_dir = tmpdir / "check_save" + download_from_uri(uri="memory:///test/location", local_path=str(check_dir)) + assert os.listdir(str(check_dir)) == ["checkpoint_000000"] + + # Save to object + obj = trainable.save_to_object() + + check_dir = tmpdir / "check_save_obj" + download_from_uri(uri="memory:///test/location", local_path=str(check_dir)) + assert os.listdir(str(check_dir)) == ["checkpoint_000000"] + + # Restore from object + trainable.restore_from_object(obj) + + if __name__ == "__main__": - import pytest import sys sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 0152c9e38f73..507966a30ebc 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -429,7 +429,9 @@ def get_state(self): "ray_version": ray.__version__, } - def save(self, checkpoint_dir: Optional[str] = None, prevent_upload: bool = False) -> str: + def save( + self, checkpoint_dir: Optional[str] = None, prevent_upload: bool = False + ) -> str: """Saves the current model state to a checkpoint. Subclasses should override ``save_checkpoint()`` instead to save state. @@ -440,7 +442,7 @@ def save(self, checkpoint_dir: Optional[str] = None, prevent_upload: bool = Fals Args: checkpoint_dir: Optional dir to place the checkpoint. - prevent_upload: bool flag to stop tmp folders from uploading + prevent_upload: If True, will not upload the saved checkpoint to cloud. Returns: str: path that points to xxx.pkl file. @@ -488,7 +490,8 @@ def save_to_object(self): """Saves the current model state to a Python object. It also saves to disk but does not return the checkpoint path. - It doesn't save to cloud. + It does not save the checkpoint to cloud storage. + Returns: Object holding checkpoint data. """ @@ -543,7 +546,7 @@ def restore(self, checkpoint_path: str, checkpoint_node_ip: Optional[str] = None pass else: checkpoint_exists = True - + if checkpoint_exists: pass elif self.uses_cloud_checkpointing: