Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIR] Checkpoint improvements #30948

Merged
merged 14 commits into from
Dec 9, 2022
34 changes: 27 additions & 7 deletions python/ray/air/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class Checkpoint:
@DeveloperAPI
def __init__(
self,
local_path: Optional[str] = None,
local_path: Optional[Union[str, os.PathLike]] = None,
data_dict: Optional[dict] = None,
uri: Optional[str] = None,
obj_ref: Optional[ray.ObjectRef] = None,
Expand Down Expand Up @@ -210,7 +210,9 @@ def __init__(
else:
raise ValueError("Cannot create checkpoint without data.")

self._local_path: Optional[str] = local_path
self._local_path: Optional[str] = (
str(Path(local_path).resolve()) if local_path else local_path
)
self._data_dict: Optional[Dict[str, Any]] = data_dict
self._uri: Optional[str] = uri
self._obj_ref: Optional[ray.ObjectRef] = obj_ref
Expand All @@ -230,6 +232,22 @@ def _metadata(self) -> _CheckpointMetadata:
},
)

def _copy_metadata_attrs_from(self, source: "Checkpoint") -> None:
    """Copy serialized metadata attributes from ``source`` onto this object.

    Iterates over ``source._metadata.checkpoint_state`` and, for each entry
    whose name appears in ``self._SERIALIZED_ATTRS``, sets the attribute of
    the same name on ``self``. Entries whose name is not listed there are
    skipped. The update happens in place and nothing is returned.

    Args:
        source: Object whose ``_metadata.checkpoint_state`` mapping is read.
    """
    for name, value in source._metadata.checkpoint_state.items():
        # The receiver's attribute list decides what is copied, not the
        # source's: names this object does not declare are ignored.
        if name not in self._SERIALIZED_ATTRS:
            continue
        setattr(self, name, value)

@_metadata.setter
def _metadata(self, metadata: _CheckpointMetadata):
    """Apply the attribute values stored in ``metadata`` to this object.

    Args:
        metadata: Metadata whose ``checkpoint_type`` must be exactly this
            object's class, and whose ``checkpoint_state`` maps attribute
            names to the values to set.

    Raises:
        ValueError: If ``metadata.checkpoint_type`` is not
            ``self.__class__``.
    """
    expected_type = self.__class__
    # Identity check: a parent class or subclass is rejected as well.
    if metadata.checkpoint_type is not expected_type:
        raise ValueError(
            f"Checkpoint type in metadata must match {expected_type}, "
            f"got {metadata.checkpoint_type}"
        )
    # Every entry in the state is applied; no filtering by attribute name
    # happens here.
    for attr, value in metadata.checkpoint_state.items():
        setattr(self, attr, value)

@property
def uri(self) -> Optional[str]:
"""Return checkpoint URI, if available.
Expand Down Expand Up @@ -259,7 +277,7 @@ def uri(self) -> Optional[str]:
return self._uri

if self._local_path and Path(self._local_path).exists():
return "file://" + self._local_path
return f"file://{self._local_path}"

return None

Expand Down Expand Up @@ -290,7 +308,7 @@ def to_bytes(self) -> bytes:
data_dict = self.to_dict()
if "bytes_data" in data_dict:
return data_dict["bytes_data"]
return pickle.dumps(self.to_dict())
return pickle.dumps(data_dict)

@classmethod
def from_dict(cls, data: dict) -> "Checkpoint":
Expand Down Expand Up @@ -421,7 +439,7 @@ def to_object_ref(self) -> ray.ObjectRef:
)

@classmethod
def from_directory(cls, path: str) -> "Checkpoint":
def from_directory(cls, path: Union[str, os.PathLike]) -> "Checkpoint":
"""Create checkpoint object from directory.

Args:
Expand Down Expand Up @@ -463,12 +481,14 @@ def from_checkpoint(cls, other: "Checkpoint") -> "Checkpoint":
if type(other) is cls:
return other

return cls(
new_checkpoint = cls(
local_path=other._local_path,
data_dict=other._data_dict,
uri=other._uri,
obj_ref=other._obj_ref,
)
new_checkpoint._copy_metadata_attrs_from(other)
return new_checkpoint

def _get_temporary_checkpoint_dir(self) -> str:
"""Return the name for the temporary checkpoint dir."""
Expand Down Expand Up @@ -532,7 +552,7 @@ def _to_directory(self, path: str) -> None:
local_path = self._local_path
external_path = _get_external_path(self._uri)
if local_path:
if local_path != path:
if Path(local_path).resolve() != Path(path).resolve():
# If this exists on the local path, just copy over
if path and os.path.exists(path):
shutil.rmtree(path)
Expand Down
14 changes: 13 additions & 1 deletion python/ray/air/tests/test_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import shutil
import tempfile
import unittest
from pathlib import Path
from typing import Any

import pytest
Expand Down Expand Up @@ -47,6 +48,10 @@ class OtherStubCheckpoint(Checkpoint):
pass


class OtherStubCheckpointWithAttrs(Checkpoint):
    """Checkpoint subclass sharing ``StubCheckpoint``'s serialized attributes.

    It is a distinct type from ``StubCheckpoint`` but declares the same
    ``_SERIALIZED_ATTRS``.
    """

    # Reuse the exact attribute list rather than redefining it.
    _SERIALIZED_ATTRS = StubCheckpoint._SERIALIZED_ATTRS


def test_from_checkpoint():
checkpoint = Checkpoint.from_dict({"spam": "ham"})
assert type(StubCheckpoint.from_checkpoint(checkpoint)) is StubCheckpoint
Expand All @@ -56,6 +61,13 @@ def test_from_checkpoint():
checkpoint.foo = "bar"
assert StubCheckpoint.from_checkpoint(checkpoint).foo == "bar"

# Check that attributes persist when converting to a different checkpoint
# type that declares the same serialized attributes.
checkpoint = StubCheckpoint.from_dict({"spam": "ham"})
checkpoint.foo = "bar"
assert OtherStubCheckpointWithAttrs.from_checkpoint(checkpoint).foo == "bar"


class TestCheckpointTypeCasting:
def test_dict(self):
Expand Down Expand Up @@ -491,7 +503,7 @@ def test_obj_store_cp_as_directory(self):

with checkpoint.as_directory() as checkpoint_dir:
assert os.path.exists(checkpoint_dir)
assert checkpoint_dir.endswith(checkpoint._uuid.hex)
assert Path(checkpoint_dir).stem.endswith(checkpoint._uuid.hex)

assert not os.path.exists(checkpoint_dir)

Expand Down
3 changes: 2 additions & 1 deletion python/ray/tune/tests/test_trainable_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
import shutil
import unittest
from pathlib import Path
from unittest.mock import patch

import ray
Expand Down Expand Up @@ -61,7 +62,7 @@ def tune_one(config=None, checkpoint_dir=None):
)
df = a.dataframe()
checkpoint_dir = a.get_best_checkpoint(df["logdir"].iloc[0])._local_path
assert checkpoint_dir.endswith("/checkpoint_000001/")
assert Path(checkpoint_dir).stem == "checkpoint_000001"

def testFindCheckpointDir(self):
checkpoint_path = os.path.join(self.checkpoint_dir, "0/my/nested/chkpt")
Expand Down