From 0a119462800789fbf5dee6f6abebfccf4a34c347 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 9 Dec 2022 15:01:04 -0800 Subject: [PATCH] [AIR] `Checkpoint` improvements (#30948) Boston dataset (used in tests) is/will be removed from sklearn. Signed-off-by: Antoni Baum Co-authored-by: Balaji Veeramani Signed-off-by: tmynn --- python/ray/air/checkpoint.py | 34 ++++++++++++++++---- python/ray/air/tests/test_checkpoints.py | 14 +++++++- python/ray/tune/tests/test_trainable_util.py | 3 +- 3 files changed, 42 insertions(+), 9 deletions(-) diff --git a/python/ray/air/checkpoint.py b/python/ray/air/checkpoint.py index 337c4859f6af..afd269e851ce 100644 --- a/python/ray/air/checkpoint.py +++ b/python/ray/air/checkpoint.py @@ -156,7 +156,7 @@ class Checkpoint: @DeveloperAPI def __init__( self, - local_path: Optional[str] = None, + local_path: Optional[Union[str, os.PathLike]] = None, data_dict: Optional[dict] = None, uri: Optional[str] = None, obj_ref: Optional[ray.ObjectRef] = None, @@ -210,7 +210,9 @@ def __init__( else: raise ValueError("Cannot create checkpoint without data.") - self._local_path: Optional[str] = local_path + self._local_path: Optional[str] = ( + str(Path(local_path).resolve()) if local_path else local_path + ) self._data_dict: Optional[Dict[str, Any]] = data_dict self._uri: Optional[str] = uri self._obj_ref: Optional[ray.ObjectRef] = obj_ref @@ -230,6 +232,22 @@ def _metadata(self) -> _CheckpointMetadata: }, ) + def _copy_metadata_attrs_from(self, source: "Checkpoint") -> None: + """Copy in-place metadata attributes from ``source`` to self.""" + for attr, value in source._metadata.checkpoint_state.items(): + if attr in self._SERIALIZED_ATTRS: + setattr(self, attr, value) + + @_metadata.setter + def _metadata(self, metadata: _CheckpointMetadata): + if metadata.checkpoint_type is not self.__class__: + raise ValueError( + f"Checkpoint type in metadata must match {self.__class__}, " + f"got {metadata.checkpoint_type}" + ) + for attr, value in metadata.checkpoint_state.items(): + setattr(self, attr, value) + @property def uri(self) -> Optional[str]: """Return checkpoint URI, if available. @@ -259,7 +277,7 @@ def uri(self) -> Optional[str]: return self._uri if self._local_path and Path(self._local_path).exists(): - return "file://" + self._local_path + return f"file://{self._local_path}" return None @@ -290,7 +308,7 @@ def to_bytes(self) -> bytes: data_dict = self.to_dict() if "bytes_data" in data_dict: return data_dict["bytes_data"] - return pickle.dumps(self.to_dict()) + return pickle.dumps(data_dict) @classmethod def from_dict(cls, data: dict) -> "Checkpoint": @@ -421,7 +439,7 @@ def to_object_ref(self) -> ray.ObjectRef: ) @classmethod - def from_directory(cls, path: str) -> "Checkpoint": + def from_directory(cls, path: Union[str, os.PathLike]) -> "Checkpoint": """Create checkpoint object from directory. Args: @@ -463,12 +481,14 @@ def from_checkpoint(cls, other: "Checkpoint") -> "Checkpoint": if type(other) is cls: return other - return cls( + new_checkpoint = cls( local_path=other._local_path, data_dict=other._data_dict, uri=other._uri, obj_ref=other._obj_ref, ) + new_checkpoint._copy_metadata_attrs_from(other) + return new_checkpoint def _get_temporary_checkpoint_dir(self) -> str: """Return the name for the temporary checkpoint dir.""" @@ -532,7 +552,7 @@ def _to_directory(self, path: str) -> None: local_path = self._local_path external_path = _get_external_path(self._uri) if local_path: - if local_path != path: + if Path(local_path).resolve() != Path(path).resolve(): # If this exists on the local path, just copy over if path and os.path.exists(path): shutil.rmtree(path) diff --git a/python/ray/air/tests/test_checkpoints.py b/python/ray/air/tests/test_checkpoints.py index 3323e68a22eb..3c60926126eb 100644 --- a/python/ray/air/tests/test_checkpoints.py +++ b/python/ray/air/tests/test_checkpoints.py @@ -4,6 +4,7 @@ import shutil import tempfile import unittest +from pathlib import Path from typing import Any import pytest @@ -47,6 +48,10 @@ class OtherStubCheckpoint(Checkpoint): pass +class OtherStubCheckpointWithAttrs(Checkpoint): + _SERIALIZED_ATTRS = StubCheckpoint._SERIALIZED_ATTRS + + def test_from_checkpoint(): checkpoint = Checkpoint.from_dict({"spam": "ham"}) assert type(StubCheckpoint.from_checkpoint(checkpoint)) is StubCheckpoint @@ -56,6 +61,13 @@ def test_from_checkpoint(): checkpoint.foo = "bar" assert StubCheckpoint.from_checkpoint(checkpoint).foo == "bar" + # Check that attributes persist if the new checkpoint + # has them as well. + # Check that attributes persist if same checkpoint type. + checkpoint = StubCheckpoint.from_dict({"spam": "ham"}) + checkpoint.foo = "bar" + assert OtherStubCheckpointWithAttrs.from_checkpoint(checkpoint).foo == "bar" + class TestCheckpointTypeCasting: def test_dict(self): @@ -491,7 +503,7 @@ def test_obj_store_cp_as_directory(self): with checkpoint.as_directory() as checkpoint_dir: assert os.path.exists(checkpoint_dir) - assert checkpoint_dir.endswith(checkpoint._uuid.hex) + assert Path(checkpoint_dir).stem.endswith(checkpoint._uuid.hex) assert not os.path.exists(checkpoint_dir) diff --git a/python/ray/tune/tests/test_trainable_util.py b/python/ray/tune/tests/test_trainable_util.py index dc1657a79c49..51b32894d086 100644 --- a/python/ray/tune/tests/test_trainable_util.py +++ b/python/ray/tune/tests/test_trainable_util.py @@ -5,6 +5,7 @@ import sys import shutil import unittest +from pathlib import Path from unittest.mock import patch import ray @@ -61,7 +62,7 @@ def tune_one(config=None, checkpoint_dir=None): ) df = a.dataframe() checkpoint_dir = a.get_best_checkpoint(df["logdir"].iloc[0])._local_path - assert checkpoint_dir.endswith("/checkpoint_000001/") + assert Path(checkpoint_dir).stem == "checkpoint_000001" def testFindCheckpointDir(self): checkpoint_path = os.path.join(self.checkpoint_dir, "0/my/nested/chkpt")