diff --git a/python/ray/air/checkpoint.py b/python/ray/air/checkpoint.py index 86c18b787216..8a0c830e7aec 100644 --- a/python/ray/air/checkpoint.py +++ b/python/ray/air/checkpoint.py @@ -22,6 +22,7 @@ from ray.util.ml_utils.filelock import TempFileLock _DICT_CHECKPOINT_FILE_NAME = "dict_checkpoint.pkl" +_DICT_CHECKPOINT_ADDITIONAL_FILE_KEY = "_ray_additional_checkpoint_files" _METADATA_CHECKPOINT_SUFFIX = ".meta.pkl" _FS_CHECKPOINT_KEY = "fs_checkpoint" _BYTES_DATA_KEY = "bytes_data" @@ -48,9 +49,16 @@ class Checkpoint: There are no guarantees made about compatibility of intermediate representations. - New data can be added to Checkpoint during conversion. Consider the + New data can be added to a Checkpoint during conversion. Consider the following conversion: directory --> dict (adding dict["foo"] = "bar") - --> directory --> dict (expect to see dict["foo"] = "bar"). + --> directory --> dict (expect to see dict["foo"] = "bar"). Note that + the second directory will contain pickle files with the serialized additional + field data in them. + + Similarly with a dict as a source: dict --> directory (add file "foo.txt") + --> dict --> directory (will have "foo.txt" in it again). Note that the second + dict representation will contain an extra field with the serialized additional + files in it. Examples: @@ -258,6 +266,23 @@ def to_dict(self) -> dict: # from the checkpoint file. with open(checkpoint_data_path, "rb") as f: checkpoint_data = pickle.load(f) + + # If there are additional files in the directory, add them as + # _DICT_CHECKPOINT_ADDITIONAL_FILE_KEY + additional_files = {} + for file_or_dir in os.listdir(local_path): + if file_or_dir in [".", "..", _DICT_CHECKPOINT_FILE_NAME]: + continue + + additional_files[file_or_dir] = _pack( + os.path.join(local_path, file_or_dir) + ) + + if additional_files: + checkpoint_data[ + _DICT_CHECKPOINT_ADDITIONAL_FILE_KEY + ] = additional_files + else: files = [ f @@ -361,7 +386,15 @@ def _to_directory(self, path: str) -> None: # This used to be a true fs checkpoint, so restore _unpack(data_dict[_FS_CHECKPOINT_KEY], path) else: - # This is a dict checkpoint. Dump data into checkpoint.pkl + # This is a dict checkpoint. + # First, restore any additional files + additional_files = data_dict.pop( + _DICT_CHECKPOINT_ADDITIONAL_FILE_KEY, {} + ) + for file, content in additional_files.items(): + _unpack(stream=content, path=os.path.join(path, file)) + + # Then dump data into checkpoint.pkl checkpoint_data_path = os.path.join(path, _DICT_CHECKPOINT_FILE_NAME) with open(checkpoint_data_path, "wb") as f: pickle.dump(data_dict, f) @@ -630,5 +663,3 @@ def _make_dir(path: str, acquire_del_lock: bool = True) -> None: open(del_lock_path, "a").close() os.makedirs(path, exist_ok=True) - # Drop marker - open(os.path.join(path, ".is_checkpoint"), "a").close() diff --git a/python/ray/air/tests/test_checkpoints.py b/python/ray/air/tests/test_checkpoints.py index b2657bdad209..06f835280fec 100644 --- a/python/ray/air/tests/test_checkpoints.py +++ b/python/ray/air/tests/test_checkpoints.py @@ -6,7 +6,7 @@ from typing import Any import ray -from ray.air.checkpoint import Checkpoint +from ray.air.checkpoint import Checkpoint, _DICT_CHECKPOINT_ADDITIONAL_FILE_KEY from ray.air._internal.remote_storage import delete_at_uri, _ensure_directory @@ -328,6 +328,81 @@ def test_obj_store_cp_as_directory(self): assert not os.path.exists(checkpoint_dir) + def test_dict_checkpoint_additional_files(self): + checkpoint = self._prepare_dict_checkpoint() + + # Convert to directory + checkpoint_dir = checkpoint.to_directory() + + # Add file into checkpoint directory + with open(os.path.join(checkpoint_dir, "additional_file.txt"), "w") as f: + f.write("Additional data\n") + os.mkdir(os.path.join(checkpoint_dir, "subdir")) + with open(os.path.join(checkpoint_dir, "subdir", "another.txt"), "w") as f: + f.write("Another additional file\n") + + # Create new checkpoint object + checkpoint = Checkpoint.from_directory(checkpoint_dir) + + new_dir = checkpoint.to_directory() + + assert os.path.exists(os.path.join(new_dir, "additional_file.txt")) + with open(os.path.join(new_dir, "additional_file.txt"), "r") as f: + assert f.read() == "Additional data\n" + + assert os.path.exists(os.path.join(new_dir, "subdir", "another.txt")) + with open(os.path.join(new_dir, "subdir", "another.txt"), "r") as f: + assert f.read() == "Another additional file\n" + + checkpoint_dict = checkpoint.to_dict() + for k, v in self.checkpoint_dict_data.items(): + assert checkpoint_dict[k] == v + + assert _DICT_CHECKPOINT_ADDITIONAL_FILE_KEY in checkpoint_dict + + # Add another field + checkpoint_dict["new_field"] = "Data" + + another_dict = Checkpoint.from_directory( + Checkpoint.from_dict(checkpoint_dict).to_directory() + ).to_dict() + assert _DICT_CHECKPOINT_ADDITIONAL_FILE_KEY in another_dict + assert another_dict["new_field"] == "Data" + + def test_fs_checkpoint_additional_fields(self): + checkpoint = self._prepare_fs_checkpoint() + + # Convert to dict + checkpoint_dict = checkpoint.to_dict() + + # Add field to dict + checkpoint_dict["additional_field"] = "data" + + # Create new checkpoint object + checkpoint = Checkpoint.from_dict(checkpoint_dict) + + # Turn into FS + checkpoint_dir = checkpoint.to_directory() + + assert os.path.exists(os.path.join(checkpoint_dir, "test_data.pkl")) + assert os.path.exists(os.path.join(checkpoint_dir, "additional_field.meta.pkl")) + + # Add new file + with open(os.path.join(checkpoint_dir, "even_more.txt"), "w") as f: + f.write("More\n") + + # Turn into dict + new_dict = Checkpoint.from_directory(checkpoint_dir).to_dict() + + assert new_dict["additional_field"] == "data" + + # Turn into fs + new_dir = Checkpoint.from_dict(new_dict).to_directory() + + assert os.path.exists(os.path.join(new_dir, "test_data.pkl")) + assert os.path.exists(os.path.join(new_dir, "additional_field.meta.pkl")) + assert os.path.exists(os.path.join(new_dir, "even_more.txt")) + class CheckpointsSerdeTest(unittest.TestCase): def setUp(self) -> None: