Skip to content

Commit

Permalink
[AIR] Checkpoint improvements (ray-project#30948)
Browse files Browse the repository at this point in the history
Boston dataset (used in tests) is/will be removed from sklearn.

Signed-off-by: Antoni Baum <[email protected]>
Co-authored-by: Balaji Veeramani <[email protected]>
Signed-off-by: tmynn <[email protected]>
  • Loading branch information
2 people authored and tamohannes committed Jan 25, 2023
1 parent da54a10 commit 0a11946
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 9 deletions.
34 changes: 27 additions & 7 deletions python/ray/air/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class Checkpoint:
@DeveloperAPI
def __init__(
self,
local_path: Optional[str] = None,
local_path: Optional[Union[str, os.PathLike]] = None,
data_dict: Optional[dict] = None,
uri: Optional[str] = None,
obj_ref: Optional[ray.ObjectRef] = None,
Expand Down Expand Up @@ -210,7 +210,9 @@ def __init__(
else:
raise ValueError("Cannot create checkpoint without data.")

self._local_path: Optional[str] = local_path
self._local_path: Optional[str] = (
str(Path(local_path).resolve()) if local_path else local_path
)
self._data_dict: Optional[Dict[str, Any]] = data_dict
self._uri: Optional[str] = uri
self._obj_ref: Optional[ray.ObjectRef] = obj_ref
Expand All @@ -230,6 +232,22 @@ def _metadata(self) -> _CheckpointMetadata:
},
)

def _copy_metadata_attrs_from(self, source: "Checkpoint") -> None:
    """Copy serialized metadata attributes from ``source`` onto ``self``.

    Walks ``source._metadata.checkpoint_state`` and assigns each entry to
    ``self`` in place, but only when the attribute name is listed in
    ``self._SERIALIZED_ATTRS``. Entries not listed there are ignored.

    Args:
        source: Checkpoint whose metadata state is read.
    """
    for name, state_value in source._metadata.checkpoint_state.items():
        # The filter uses the destination's attribute list, not the
        # source's, so only attributes this class serializes are copied.
        if name not in self._SERIALIZED_ATTRS:
            continue
        setattr(self, name, state_value)

@_metadata.setter
def _metadata(self, metadata: _CheckpointMetadata):
    """Restore this checkpoint's attributes from ``metadata``.

    Every entry in ``metadata.checkpoint_state`` is set as an attribute
    on ``self``.

    Args:
        metadata: Metadata whose ``checkpoint_type`` must be exactly this
            instance's class.

    Raises:
        ValueError: If ``metadata.checkpoint_type`` is not
            ``self.__class__``.
    """
    type_matches = metadata.checkpoint_type is self.__class__
    if not type_matches:
        raise ValueError(
            f"Checkpoint type in metadata must match {self.__class__}, "
            f"got {metadata.checkpoint_type}"
        )

    state = metadata.checkpoint_state
    for name, state_value in state.items():
        setattr(self, name, state_value)

@property
def uri(self) -> Optional[str]:
"""Return checkpoint URI, if available.
Expand Down Expand Up @@ -259,7 +277,7 @@ def uri(self) -> Optional[str]:
return self._uri

if self._local_path and Path(self._local_path).exists():
return "file://" + self._local_path
return f"file://{self._local_path}"

return None

Expand Down Expand Up @@ -290,7 +308,7 @@ def to_bytes(self) -> bytes:
data_dict = self.to_dict()
if "bytes_data" in data_dict:
return data_dict["bytes_data"]
return pickle.dumps(self.to_dict())
return pickle.dumps(data_dict)

@classmethod
def from_dict(cls, data: dict) -> "Checkpoint":
Expand Down Expand Up @@ -421,7 +439,7 @@ def to_object_ref(self) -> ray.ObjectRef:
)

@classmethod
def from_directory(cls, path: str) -> "Checkpoint":
def from_directory(cls, path: Union[str, os.PathLike]) -> "Checkpoint":
"""Create checkpoint object from directory.
Args:
Expand Down Expand Up @@ -463,12 +481,14 @@ def from_checkpoint(cls, other: "Checkpoint") -> "Checkpoint":
if type(other) is cls:
return other

return cls(
new_checkpoint = cls(
local_path=other._local_path,
data_dict=other._data_dict,
uri=other._uri,
obj_ref=other._obj_ref,
)
new_checkpoint._copy_metadata_attrs_from(other)
return new_checkpoint

def _get_temporary_checkpoint_dir(self) -> str:
"""Return the name for the temporary checkpoint dir."""
Expand Down Expand Up @@ -532,7 +552,7 @@ def _to_directory(self, path: str) -> None:
local_path = self._local_path
external_path = _get_external_path(self._uri)
if local_path:
if local_path != path:
if Path(local_path).resolve() != Path(path).resolve():
# If this exists on the local path, just copy over
if path and os.path.exists(path):
shutil.rmtree(path)
Expand Down
14 changes: 13 additions & 1 deletion python/ray/air/tests/test_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import shutil
import tempfile
import unittest
from pathlib import Path
from typing import Any

import pytest
Expand Down Expand Up @@ -47,6 +48,10 @@ class OtherStubCheckpoint(Checkpoint):
pass


# Stub Checkpoint subclass that reuses StubCheckpoint's ``_SERIALIZED_ATTRS``
# list; used by test_from_checkpoint to check that such attributes are kept
# when a StubCheckpoint is cast to this type via ``from_checkpoint``.
class OtherStubCheckpointWithAttrs(Checkpoint):
_SERIALIZED_ATTRS = StubCheckpoint._SERIALIZED_ATTRS


def test_from_checkpoint():
checkpoint = Checkpoint.from_dict({"spam": "ham"})
assert type(StubCheckpoint.from_checkpoint(checkpoint)) is StubCheckpoint
Expand All @@ -56,6 +61,13 @@ def test_from_checkpoint():
checkpoint.foo = "bar"
assert StubCheckpoint.from_checkpoint(checkpoint).foo == "bar"

# Check that attributes persist when casting to a different checkpoint
# type, as long as the new type has them as well.
checkpoint = StubCheckpoint.from_dict({"spam": "ham"})
checkpoint.foo = "bar"
assert OtherStubCheckpointWithAttrs.from_checkpoint(checkpoint).foo == "bar"


class TestCheckpointTypeCasting:
def test_dict(self):
Expand Down Expand Up @@ -491,7 +503,7 @@ def test_obj_store_cp_as_directory(self):

with checkpoint.as_directory() as checkpoint_dir:
assert os.path.exists(checkpoint_dir)
assert checkpoint_dir.endswith(checkpoint._uuid.hex)
assert Path(checkpoint_dir).stem.endswith(checkpoint._uuid.hex)

assert not os.path.exists(checkpoint_dir)

Expand Down
3 changes: 2 additions & 1 deletion python/ray/tune/tests/test_trainable_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
import shutil
import unittest
from pathlib import Path
from unittest.mock import patch

import ray
Expand Down Expand Up @@ -61,7 +62,7 @@ def tune_one(config=None, checkpoint_dir=None):
)
df = a.dataframe()
checkpoint_dir = a.get_best_checkpoint(df["logdir"].iloc[0])._local_path
assert checkpoint_dir.endswith("/checkpoint_000001/")
assert Path(checkpoint_dir).stem == "checkpoint_000001"

def testFindCheckpointDir(self):
checkpoint_path = os.path.join(self.checkpoint_dir, "0/my/nested/chkpt")
Expand Down

0 comments on commit 0a11946

Please sign in to comment.