Skip to content

Commit

Permalink
[air] Serialize additional files in dict checkpoints turned dir checkpoints (ray-project#26351)
Browse files Browse the repository at this point in the history

With this PR, files added to a directory checkpoint that was created from a dict checkpoint are serialized and retained when to_dict() is subsequently called. This enables storing additional files alongside the checkpoint data, e.g. as needed by Ray Tune.

Signed-off-by: Kai Fricke <[email protected]>
Signed-off-by: Stefan van der Kleij <[email protected]>
  • Loading branch information
krfricke authored and Stefan van der Kleij committed Aug 18, 2022
1 parent 1da306c commit 93ba4a6
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 6 deletions.
41 changes: 36 additions & 5 deletions python/ray/air/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ray.util.ml_utils.filelock import TempFileLock

# File name under which a dict checkpoint's data is pickled when the
# checkpoint is written to a directory.
_DICT_CHECKPOINT_FILE_NAME = "dict_checkpoint.pkl"
# Dict key under which extra files found next to the dict checkpoint file
# are stored (file/dir name -> packed content) so they survive to_dict().
_DICT_CHECKPOINT_ADDITIONAL_FILE_KEY = "_ray_additional_checkpoint_files"
# NOTE(review): presumably the file suffix used for pickled additional dict
# fields stored inside a directory checkpoint -- confirm against usage.
_METADATA_CHECKPOINT_SUFFIX = ".meta.pkl"
# Dict key holding the packed contents of a directory checkpoint.
_FS_CHECKPOINT_KEY = "fs_checkpoint"
# NOTE(review): presumably the dict key for raw bytes checkpoint data -- usage
# is not visible here, confirm.
_BYTES_DATA_KEY = "bytes_data"
Expand All @@ -48,9 +49,16 @@ class Checkpoint:
There are no guarantees made about compatibility of intermediate
representations.
New data can be added to Checkpoint during conversion. Consider the
New data can be added to a Checkpoint during conversion. Consider the
following conversion: directory --> dict (adding dict["foo"] = "bar")
--> directory --> dict (expect to see dict["foo"] = "bar").
--> directory --> dict (expect to see dict["foo"] = "bar"). Note that
the second directory will contain pickle files with the serialized additional
field data in them.
Similarly with a dict as a source: dict --> directory (add file "foo.txt")
--> dict --> directory (will have "foo.txt" in it again). Note that the second
dict representation will contain an extra field with the serialized additional
files in it.
Examples:
Expand Down Expand Up @@ -258,6 +266,23 @@ def to_dict(self) -> dict:
# from the checkpoint file.
with open(checkpoint_data_path, "rb") as f:
checkpoint_data = pickle.load(f)

# If there are additional files in the directory, add them as
# _DICT_CHECKPOINT_ADDITIONAL_FILE_KEY
additional_files = {}
for file_or_dir in os.listdir(local_path):
if file_or_dir in [".", "..", _DICT_CHECKPOINT_FILE_NAME]:
continue

additional_files[file_or_dir] = _pack(
os.path.join(local_path, file_or_dir)
)

if additional_files:
checkpoint_data[
_DICT_CHECKPOINT_ADDITIONAL_FILE_KEY
] = additional_files

else:
files = [
f
Expand Down Expand Up @@ -361,7 +386,15 @@ def _to_directory(self, path: str) -> None:
# This used to be a true fs checkpoint, so restore
_unpack(data_dict[_FS_CHECKPOINT_KEY], path)
else:
# This is a dict checkpoint. Dump data into checkpoint.pkl
# This is a dict checkpoint.
# First, restore any additional files
additional_files = data_dict.pop(
_DICT_CHECKPOINT_ADDITIONAL_FILE_KEY, {}
)
for file, content in additional_files.items():
_unpack(stream=content, path=os.path.join(path, file))

# Then dump data into checkpoint.pkl
checkpoint_data_path = os.path.join(path, _DICT_CHECKPOINT_FILE_NAME)
with open(checkpoint_data_path, "wb") as f:
pickle.dump(data_dict, f)
Expand Down Expand Up @@ -630,5 +663,3 @@ def _make_dir(path: str, acquire_del_lock: bool = True) -> None:
open(del_lock_path, "a").close()

os.makedirs(path, exist_ok=True)
# Drop marker
open(os.path.join(path, ".is_checkpoint"), "a").close()
77 changes: 76 additions & 1 deletion python/ray/air/tests/test_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any

import ray
from ray.air.checkpoint import Checkpoint
from ray.air.checkpoint import Checkpoint, _DICT_CHECKPOINT_ADDITIONAL_FILE_KEY
from ray.air._internal.remote_storage import delete_at_uri, _ensure_directory


Expand Down Expand Up @@ -328,6 +328,81 @@ def test_obj_store_cp_as_directory(self):

assert not os.path.exists(checkpoint_dir)

def test_dict_checkpoint_additional_files(self):
    """Extra files in a dict-turned-directory checkpoint survive conversions.

    Starts from a dict checkpoint, materializes it as a directory, drops
    extra files (one top-level, one nested) into that directory, and then
    checks that the files and the original dict data are still present
    after further directory and dict conversions.
    """
    source_dir = self._prepare_dict_checkpoint().to_directory()

    # Relative path inside the checkpoint directory -> expected file content.
    extra_files = {
        "additional_file.txt": "Additional data\n",
        os.path.join("subdir", "another.txt"): "Another additional file\n",
    }

    # The nested file needs its parent directory to exist first.
    os.mkdir(os.path.join(source_dir, "subdir"))
    for relative_path, text in extra_files.items():
        with open(os.path.join(source_dir, relative_path), "w") as handle:
            handle.write(text)

    # Re-load the directory as a fresh checkpoint object and write it out again.
    reloaded = Checkpoint.from_directory(source_dir)
    restored_dir = reloaded.to_directory()

    for relative_path, text in extra_files.items():
        restored_path = os.path.join(restored_dir, relative_path)
        assert os.path.exists(restored_path)
        with open(restored_path, "r") as handle:
            assert handle.read() == text

    # The dict view must still carry the original data plus the packed files.
    as_dict = reloaded.to_dict()
    for key, expected in self.checkpoint_dict_data.items():
        assert as_dict[key] == expected

    assert _DICT_CHECKPOINT_ADDITIONAL_FILE_KEY in as_dict

    # A newly added field has to survive a dict -> directory -> dict round trip
    # together with the additional files.
    as_dict["new_field"] = "Data"

    round_trip_dir = Checkpoint.from_dict(as_dict).to_directory()
    round_trip_dict = Checkpoint.from_directory(round_trip_dir).to_dict()
    assert _DICT_CHECKPOINT_ADDITIONAL_FILE_KEY in round_trip_dict
    assert round_trip_dict["new_field"] == "Data"

def test_fs_checkpoint_additional_fields(self):
    """Extra dict fields and extra files on an fs checkpoint survive conversions.

    Starts from a filesystem checkpoint, adds a field to its dict form,
    writes it back to a directory, adds a file there, and checks that both
    the field and the file are present after another dict -> directory
    round trip.
    """
    fs_as_dict = self._prepare_fs_checkpoint().to_dict()

    # Attach an extra field to the dict representation.
    fs_as_dict["additional_field"] = "data"

    # Back to a directory: original data file plus the field's metadata file.
    first_dir = Checkpoint.from_dict(fs_as_dict).to_directory()
    for file_name in ("test_data.pkl", "additional_field.meta.pkl"):
        assert os.path.exists(os.path.join(first_dir, file_name))

    # Drop a brand-new file into the directory.
    with open(os.path.join(first_dir, "even_more.txt"), "w") as handle:
        handle.write("More\n")

    # Directory -> dict keeps the added field.
    second_dict = Checkpoint.from_directory(first_dir).to_dict()

    assert second_dict["additional_field"] == "data"

    # Dict -> directory restores the data file, the metadata file and the new file.
    second_dir = Checkpoint.from_dict(second_dict).to_directory()

    for file_name in (
        "test_data.pkl",
        "additional_field.meta.pkl",
        "even_more.txt",
    ):
        assert os.path.exists(os.path.join(second_dir, file_name))


class CheckpointsSerdeTest(unittest.TestCase):
def setUp(self) -> None:
Expand Down

0 comments on commit 93ba4a6

Please sign in to comment.