Skip to content

Commit

Permalink
[RLlib] Enable cloud checkpointing. (#47682)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonsays1980 authored Sep 25, 2024
1 parent b1624c9 commit a6cf9d7
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 41 deletions.
8 changes: 4 additions & 4 deletions doc/source/rllib/rllib-learner.rst
Original file line number Diff line number Diff line change
Expand Up @@ -319,12 +319,12 @@ Getting and setting state


.. testcode::
:hide:
:hide:

import tempfile
import tempfile

LEARNER_CKPT_DIR = str(tempfile.TemporaryDirectory())
LEARNER_GROUP_CKPT_DIR = str(tempfile.TemporaryDirectory())
LEARNER_CKPT_DIR = tempfile.mkdtemp()
LEARNER_GROUP_CKPT_DIR = tempfile.mkdtemp()


Checkpointing
Expand Down
6 changes: 5 additions & 1 deletion rllib/algorithms/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import os
from packaging import version
import pathlib
import pyarrow.fs
import re
import tempfile
import time
Expand Down Expand Up @@ -305,6 +306,7 @@ class Algorithm(Checkpointable, Trainable, AlgorithmBase):
def from_checkpoint(
cls,
path: Optional[Union[str, Checkpoint]] = None,
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
*,
# @OldAPIStack
policy_ids: Optional[Collection[PolicyID]] = None,
Expand All @@ -324,6 +326,8 @@ def from_checkpoint(
Args:
path: The path (str) to the checkpoint directory to use
or an AIR Checkpoint instance to restore from.
filesystem: PyArrow FileSystem to use to access data at the `path`. If not
specified, this is inferred from the URI scheme of `path`.
policy_ids: Optional list of PolicyIDs to recover. This allows users to
restore an Algorithm with only a subset of the originally present
Policies.
Expand Down Expand Up @@ -371,7 +375,7 @@ def from_checkpoint(
)
# New API stack -> Use Checkpointable's default implementation.
elif checkpoint_info["checkpoint_version"] >= version.Version("2.0"):
return super().from_checkpoint(path, **kwargs)
return super().from_checkpoint(path, filesystem=filesystem, **kwargs)

# This is a msgpack checkpoint.
if checkpoint_info["format"] == "msgpack":
Expand Down
2 changes: 1 addition & 1 deletion rllib/algorithms/tests/test_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
class TestAlgorithm(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init()
ray.init(local_mode=True)
register_env("multi_cart", lambda cfg: MultiAgentCartPole(cfg))

@classmethod
Expand Down
Loading

0 comments on commit a6cf9d7

Please sign in to comment.