diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py index 366460197116..e70cce4059c8 100644 --- a/rllib/algorithms/algorithm.py +++ b/rllib/algorithms/algorithm.py @@ -32,6 +32,7 @@ from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag from ray.actor import ActorHandle from ray.air.checkpoint import Checkpoint +from ray.train._checkpoint import Checkpoint as NewCheckpoint import ray.cloudpickle as pickle from ray.rllib.algorithms.algorithm_config import AlgorithmConfig @@ -261,7 +262,7 @@ class Algorithm(Trainable, AlgorithmBase): @staticmethod def from_checkpoint( - checkpoint: Union[str, Checkpoint], + checkpoint: Union[str, Checkpoint, NewCheckpoint], policy_ids: Optional[Container[PolicyID]] = None, policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None, policies_to_train: Optional[ @@ -2066,8 +2067,8 @@ def import_policy_model_from_h5( self._sync_weights_to_workers(worker_set=self.workers) @override(Trainable) - def save_checkpoint(self, checkpoint_dir: str) -> str: - """Exports AIR Checkpoint to a local directory and returns its directory path. + def save_checkpoint(self, checkpoint_dir: str) -> None: + """Exports checkpoint to a local directory. The structure of an Algorithm checkpoint dir will be as follows:: @@ -2093,9 +2094,6 @@ def save_checkpoint(self, checkpoint_dir: str) -> str: Args: checkpoint_dir: The directory where the checkpoint files will be stored. - - Returns: - The path to the created AIR Checkpoint directory. """ state = self.__getstate__() @@ -2145,18 +2143,16 @@ def save_checkpoint(self, checkpoint_dir: str) -> str: learner_state_dir = os.path.join(checkpoint_dir, "learner") self.learner_group.save_state(learner_state_dir) - return checkpoint_dir - @override(Trainable) - def load_checkpoint(self, checkpoint: str) -> None: - # Checkpoint is provided as a directory name. + def load_checkpoint(self, checkpoint_dir: str) -> None: + # Checkpoint is provided as a local directory. # Restore from the checkpoint file or dir. - checkpoint_info = get_checkpoint_info(checkpoint) + checkpoint_info = get_checkpoint_info(checkpoint_dir) checkpoint_data = Algorithm._checkpoint_info_to_algorithm_state(checkpoint_info) self.__setstate__(checkpoint_data) if self.config._enable_learner_api: - learner_state_dir = os.path.join(checkpoint, "learner") + learner_state_dir = os.path.join(checkpoint_dir, "learner") self.learner_group.load_state(learner_state_dir) @override(Trainable) diff --git a/rllib/offline/estimators/tests/test_ope.py b/rllib/offline/estimators/tests/test_ope.py index fa15f780cb3d..dd876a84c6aa 100644 --- a/rllib/offline/estimators/tests/test_ope.py +++ b/rllib/offline/estimators/tests/test_ope.py @@ -280,8 +280,8 @@ def test_dr_on_estimate_on_dataset(self): def test_algo_with_ope_from_checkpoint(self): algo = self.config_dqn_on_cartpole.build() tmpdir = tempfile.mkdtemp() - checkpoint = algo.save_checkpoint(tmpdir) - algo = Algorithm.from_checkpoint(checkpoint) + algo.save_checkpoint(tmpdir) + algo = Algorithm.from_checkpoint(tmpdir) shutil.rmtree(tmpdir) diff --git a/rllib/utils/checkpoints.py b/rllib/utils/checkpoints.py index 19e5cc145b31..428ef3d64e8d 100644 --- a/rllib/utils/checkpoints.py +++ b/rllib/utils/checkpoints.py @@ -8,6 +8,7 @@ import ray from ray.air.checkpoint import Checkpoint +from ray.train._checkpoint import Checkpoint as NewCheckpoint from ray.rllib.utils.serialization import NOT_SERIALIZABLE, serialize_type from ray.util import log_once from ray.util.annotations import PublicAPI @@ -37,7 +38,9 @@ @PublicAPI(stability="alpha") -def get_checkpoint_info(checkpoint: Union[str, Checkpoint]) -> Dict[str, Any]: +def get_checkpoint_info( + checkpoint: Union[str, Checkpoint, NewCheckpoint] +) -> Dict[str, Any]: """Returns a dict with information about a Algorithm/Policy checkpoint. If the given checkpoint is a >=v1.0 checkpoint directory, try reading all @@ -74,6 +77,8 @@ def get_checkpoint_info(checkpoint: Union[str, Checkpoint]) -> Dict[str, Any]: tmp_dir = tempfile.mkdtemp() checkpoint.to_directory(tmp_dir) checkpoint = tmp_dir + elif isinstance(checkpoint, NewCheckpoint): + checkpoint: str = checkpoint.to_directory() # Checkpoint is dir. if os.path.isdir(checkpoint): @@ -181,7 +186,7 @@ def get_checkpoint_info(checkpoint: Union[str, Checkpoint]) -> Dict[str, Any]: @PublicAPI(stability="beta") def convert_to_msgpack_checkpoint( - checkpoint: Union[str, Checkpoint], + checkpoint: Union[str, Checkpoint, NewCheckpoint], msgpack_checkpoint_dir: str, ) -> str: """Converts an Algorithm checkpoint (pickle based) to a msgpack based one.