Skip to content

Commit

Permalink
[air] pyarrow.fs persistence: Prep removal of air.Checkpoint depe…
Browse files Browse the repository at this point in the history
…ndence in rllib (ray-project#38590)

This PR prepares us to be able to delete the dependency on `air.Checkpoint`, by accepting the new `train.Checkpoint` and keeping feature parity for `Algorithm.from_checkpoint`.

Signed-off-by: Victor <[email protected]>
  • Loading branch information
justinvyu authored and Victor committed Oct 11, 2023
1 parent 95ff104 commit f44e241
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 16 deletions.
20 changes: 8 additions & 12 deletions rllib/algorithms/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag
from ray.actor import ActorHandle
from ray.air.checkpoint import Checkpoint
from ray.train._checkpoint import Checkpoint as NewCheckpoint
import ray.cloudpickle as pickle

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
Expand Down Expand Up @@ -261,7 +262,7 @@ class Algorithm(Trainable, AlgorithmBase):

@staticmethod
def from_checkpoint(
checkpoint: Union[str, Checkpoint],
checkpoint: Union[str, Checkpoint, NewCheckpoint],
policy_ids: Optional[Container[PolicyID]] = None,
policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None,
policies_to_train: Optional[
Expand Down Expand Up @@ -2066,8 +2067,8 @@ def import_policy_model_from_h5(
self._sync_weights_to_workers(worker_set=self.workers)

@override(Trainable)
def save_checkpoint(self, checkpoint_dir: str) -> str:
"""Exports AIR Checkpoint to a local directory and returns its directory path.
def save_checkpoint(self, checkpoint_dir: str) -> None:
"""Exports checkpoint to a local directory.
The structure of an Algorithm checkpoint dir will be as follows::
Expand All @@ -2093,9 +2094,6 @@ def save_checkpoint(self, checkpoint_dir: str) -> str:
Args:
checkpoint_dir: The directory where the checkpoint files will be stored.
Returns:
The path to the created AIR Checkpoint directory.
"""
state = self.__getstate__()

Expand Down Expand Up @@ -2145,18 +2143,16 @@ def save_checkpoint(self, checkpoint_dir: str) -> str:
learner_state_dir = os.path.join(checkpoint_dir, "learner")
self.learner_group.save_state(learner_state_dir)

return checkpoint_dir

@override(Trainable)
def load_checkpoint(self, checkpoint: str) -> None:
# Checkpoint is provided as a directory name.
def load_checkpoint(self, checkpoint_dir: str) -> None:
# Checkpoint is provided as a local directory.
# Restore from the checkpoint file or dir.

checkpoint_info = get_checkpoint_info(checkpoint)
checkpoint_info = get_checkpoint_info(checkpoint_dir)
checkpoint_data = Algorithm._checkpoint_info_to_algorithm_state(checkpoint_info)
self.__setstate__(checkpoint_data)
if self.config._enable_learner_api:
learner_state_dir = os.path.join(checkpoint, "learner")
learner_state_dir = os.path.join(checkpoint_dir, "learner")
self.learner_group.load_state(learner_state_dir)

@override(Trainable)
Expand Down
4 changes: 2 additions & 2 deletions rllib/offline/estimators/tests/test_ope.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,8 @@ def test_dr_on_estimate_on_dataset(self):
def test_algo_with_ope_from_checkpoint(self):
algo = self.config_dqn_on_cartpole.build()
tmpdir = tempfile.mkdtemp()
checkpoint = algo.save_checkpoint(tmpdir)
algo = Algorithm.from_checkpoint(checkpoint)
algo.save_checkpoint(tmpdir)
algo = Algorithm.from_checkpoint(tmpdir)
shutil.rmtree(tmpdir)


Expand Down
9 changes: 7 additions & 2 deletions rllib/utils/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import ray
from ray.air.checkpoint import Checkpoint
from ray.train._checkpoint import Checkpoint as NewCheckpoint
from ray.rllib.utils.serialization import NOT_SERIALIZABLE, serialize_type
from ray.util import log_once
from ray.util.annotations import PublicAPI
Expand Down Expand Up @@ -37,7 +38,9 @@


@PublicAPI(stability="alpha")
def get_checkpoint_info(checkpoint: Union[str, Checkpoint]) -> Dict[str, Any]:
def get_checkpoint_info(
checkpoint: Union[str, Checkpoint, NewCheckpoint]
) -> Dict[str, Any]:
"""Returns a dict with information about a Algorithm/Policy checkpoint.
If the given checkpoint is a >=v1.0 checkpoint directory, try reading all
Expand Down Expand Up @@ -74,6 +77,8 @@ def get_checkpoint_info(checkpoint: Union[str, Checkpoint]) -> Dict[str, Any]:
tmp_dir = tempfile.mkdtemp()
checkpoint.to_directory(tmp_dir)
checkpoint = tmp_dir
elif isinstance(checkpoint, NewCheckpoint):
checkpoint: str = checkpoint.to_directory()

# Checkpoint is dir.
if os.path.isdir(checkpoint):
Expand Down Expand Up @@ -181,7 +186,7 @@ def get_checkpoint_info(checkpoint: Union[str, Checkpoint]) -> Dict[str, Any]:

@PublicAPI(stability="beta")
def convert_to_msgpack_checkpoint(
checkpoint: Union[str, Checkpoint],
checkpoint: Union[str, Checkpoint, NewCheckpoint],
msgpack_checkpoint_dir: str,
) -> str:
"""Converts an Algorithm checkpoint (pickle based) to a msgpack based one.
Expand Down

0 comments on commit f44e241

Please sign in to comment.