[RLlib] Algorithm Level Checkpointing with Learner and RL Modules #34717

Merged
7 changes: 7 additions & 0 deletions rllib/BUILD
@@ -1982,6 +1982,13 @@ py_test(
srcs = ["core/learner/torch/tests/test_torch_learner.py"]
)

py_test(
name ="tests/test_algorithm_save_load_checkpoint_learner",
tags = ["team:rllib", "core"],
size = "medium",
srcs = ["tests/test_algorithm_save_load_checkpoint_learner.py"]
)

py_test(
name = "test_bc_algorithm",
tags = ["team:rllib", "core"],
35 changes: 25 additions & 10 deletions rllib/algorithms/algorithm.py
@@ -76,6 +76,7 @@
)
from ray.rllib.utils.checkpoints import (
CHECKPOINT_VERSION,
CHECKPOINT_VERSION_LEARNER,
get_checkpoint_info,
try_import_msgpack,
)
@@ -2077,6 +2078,14 @@ def save_checkpoint(self, checkpoint_dir: str) -> str:
policy_state.pkl
pol_2/
policy_state.pkl
learner/
learner_state.json
module_state/
module_1/
...
optimizer_state/
optimizers_module_1/
...
rllib_checkpoint.json
algorithm_state.pkl
@@ -2099,7 +2108,10 @@ def save_checkpoint(self, checkpoint_dir: str) -> str:
policy_states = state["worker"].pop("policy_states", {})

# Add RLlib checkpoint version.
state["checkpoint_version"] = CHECKPOINT_VERSION
if self.config._enable_learner_api:
state["checkpoint_version"] = CHECKPOINT_VERSION_LEARNER
else:
state["checkpoint_version"] = CHECKPOINT_VERSION

# Write state (w/o policies) to disk.
state_file = os.path.join(checkpoint_dir, "algorithm_state.pkl")
@@ -2130,21 +2142,24 @@ def save_checkpoint(self, checkpoint_dir: str) -> str:
policy = self.get_policy(pid)
policy.export_checkpoint(policy_dir, policy_state=policy_state)

# if we are using the learner API, save the learner group state
if self.config._enable_learner_api:
learner_state_dir = os.path.join(checkpoint_dir, "learner")
self.learner_group.save_state(learner_state_dir)

return checkpoint_dir

@override(Trainable)
def load_checkpoint(self, checkpoint: Union[Dict, str]) -> None:
def load_checkpoint(self, checkpoint: str) -> None:
# Checkpoint is provided as a directory name.
# Restore from the checkpoint file or dir.
if isinstance(checkpoint, str):
checkpoint_info = get_checkpoint_info(checkpoint)
checkpoint_data = Algorithm._checkpoint_info_to_algorithm_state(
checkpoint_info
)
# Checkpoint is a checkpoint-as-dict -> Restore state from it as-is.
else:
checkpoint_data = checkpoint

checkpoint_info = get_checkpoint_info(checkpoint)
checkpoint_data = Algorithm._checkpoint_info_to_algorithm_state(checkpoint_info)
self.__setstate__(checkpoint_data)
if self.config._enable_learner_api:
learner_state_dir = os.path.join(checkpoint, "learner")
self.learner_group.load_state(learner_state_dir)

@override(Trainable)
def log_result(self, result: ResultDict) -> None:
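As a usage sketch (not part of the diff), the new end-to-end flow with the Learner API enabled could look roughly like this; the `_enable_learner_api` / `_enable_rl_module_api` flags mirror the ones used in the tests below, and the environment choice is arbitrary:

import tempfile

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .rollouts(num_rollout_workers=0)
    .training(_enable_learner_api=True)
    .rl_module(_enable_rl_module_api=True)
)
algo = config.build()
algo.train()

with tempfile.TemporaryDirectory() as tmpdir:
    # Writes algorithm_state.pkl, rllib_checkpoint.json, the per-policy state,
    # and the new learner/ subdirectory holding the learner group's state.
    algo.save_checkpoint(tmpdir)

    # A freshly built algorithm restores both policy and learner state.
    algo2 = config.build()
    algo2.load_checkpoint(tmpdir)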
53 changes: 49 additions & 4 deletions rllib/algorithms/ppo/tests/test_ppo_learner.py
@@ -2,6 +2,7 @@
import unittest
import numpy as np
import torch
import tempfile
import tensorflow as tf
import tree # pip install dm-tree

@@ -74,8 +75,8 @@ def test_loss(self):
)

for fw in framework_iterator(config, ("tf2", "torch"), with_eager_tracing=True):
trainer = config.build()
policy = trainer.get_policy()
algo = config.build()
policy = algo.get_policy()

train_batch = SampleBatch(FAKE_BATCH)
train_batch = compute_gae_for_sample_batch(policy, train_batch)
@@ -109,14 +110,58 @@ def test_loss(self):
)
learner_group = learner_group_config.build()

# load the trainer weights onto the learner_group
learner_group.set_weights(trainer.get_weights())
# load the algo weights onto the learner_group
learner_group.set_weights(algo.get_weights())
results = learner_group.update(train_batch.as_multi_agent())

learner_group_loss = results[ALL_MODULES]["total_loss"]

check(learner_group_loss, policy_loss)

def test_save_load_state(self):
"""Tests saving and loading the state of the PPO Learner Group."""
config = (
ppo.PPOConfig()
.environment("CartPole-v1")
.rollouts(
num_rollout_workers=0,
)
.training(
gamma=0.99,
model=dict(
fcnet_hiddens=[10, 10],
fcnet_activation="linear",
vf_share_layers=False,
),
_enable_learner_api=True,
)
.rl_module(
_enable_rl_module_api=True,
)
)
algo = config.build()
policy = algo.get_policy()

for fw in framework_iterator(config, ("tf2", "torch"), with_eager_tracing=True):
algo_config = config.copy(copy_frozen=False)
algo_config.validate()
algo_config.freeze()
learner_group_config = algo_config.get_learner_group_config(
SingleAgentRLModuleSpec(
module_class=algo_config.rl_module_spec.module_class,
observation_space=policy.observation_space,
action_space=policy.action_space,
model_config_dict=policy.config["model"],
catalog_class=PPOCatalog,
)
)
learner_group1 = learner_group_config.build()
learner_group2 = learner_group_config.build()
with tempfile.TemporaryDirectory() as tmpdir:
learner_group1.save_state(tmpdir)
learner_group2.load_state(tmpdir)
check(learner_group1.get_state(), learner_group2.get_state())


if __name__ == "__main__":
import pytest
10 changes: 10 additions & 0 deletions rllib/core/learner/learner.py
@@ -906,6 +906,16 @@ def save_state(self, path: Union[str, pathlib.Path]) -> None:
NOTE: If path doesn't exist, a new directory will be created; otherwise, it
will be appended to.
The state of the learner is saved in the following format:
checkpoint_dir/
learner_state.json
module_state/
module_1/
...
optimizer_state/
optimizers_module_1/
...
Args:
path: The path to the directory to save the state to.
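The same layout is produced when the whole learner group is checkpointed. A minimal round trip through `save_state` / `load_state` (a sketch mirroring the PPO learner test above; `learner_group_config` is assumed to be built the same way as in that test) might look like:

import tempfile

# `learner_group_config` is assumed to have been built as in
# test_ppo_learner.py's test_save_load_state above.
learner_group = learner_group_config.build()

with tempfile.TemporaryDirectory() as tmpdir:
    learner_group.save_state(tmpdir)
    # tmpdir now contains learner_state.json plus the module_state/ and
    # optimizer_state/ subdirectories described in the docstring above.

    restored = learner_group_config.build()
    restored.load_state(tmpdir)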
2 changes: 1 addition & 1 deletion rllib/core/learner/learner_group.py
@@ -475,10 +475,10 @@ def load_state(self, path: str) -> None:
if not path.exists():
raise ValueError(f"Path {path} does not exist.")
path = str(path.absolute())
assert len(self._workers) == self._worker_manager.num_healthy_actors()
if self.is_local:
self._learner.load_state(path)
else:
assert len(self._workers) == self._worker_manager.num_healthy_actors()
head_node_ip = socket.gethostbyname(socket.gethostname())
workers = self._worker_manager.healthy_actor_ids()

2 changes: 1 addition & 1 deletion rllib/core/rl_module/rl_module.py
@@ -186,7 +186,7 @@ def to_dict(self):
"""
catalog_class_path = (
serialize_type(type(self.catalog_class)) if self.catalog_class else ""
serialize_type(self.catalog_class) if self.catalog_class else ""
)
return {
"observation_space": gym_space_to_dict(self.observation_space),
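The fix above is subtle: `catalog_class` already holds a class object, so wrapping it in `type()` would serialize the metaclass rather than the catalog class itself. A tiny illustration with a stand-in class (not the real catalog):

class DummyCatalog:           # Stand-in for an RLlib catalog class.
    pass


catalog_class = DummyCatalog  # A class object, not an instance.

print(type(catalog_class))    # <class 'type'> -- the metaclass (old, wrong input).
print(catalog_class)          # <class '__main__.DummyCatalog'> (new, correct input).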
2 changes: 1 addition & 1 deletion rllib/evaluation/rollout_worker.py
@@ -134,7 +134,7 @@ def _update_env_seed_if_necessary(
NOTE: this may not work with remote environments (issue #18154).
"""
if not seed:
if seed is None:
return

# A single RL job is unlikely to have more than 10K
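The one-line change above matters because 0 is a valid seed but is falsy in Python, so the old check silently skipped seeding. A small standalone illustration (hypothetical helper names, not from the diff):

def skips_seeding_old(seed) -> bool:
    return not seed        # True for both None and 0 -> seed=0 is dropped.


def skips_seeding_new(seed) -> bool:
    return seed is None    # Only skips when no seed was given at all.


assert skips_seeding_old(0) is True      # Old behavior: valid seed 0 ignored.
assert skips_seeding_new(0) is False     # New behavior: seed 0 is applied.
assert skips_seeding_new(None) is True   # None still means "don't seed".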
128 changes: 128 additions & 0 deletions rllib/tests/test_algorithm_save_load_checkpoint_learner.py
@@ -0,0 +1,128 @@
import tempfile
import unittest

import ray
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.utils.test_utils import check, framework_iterator
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY


algorithms_and_configs = {
"PPO": (PPOConfig().training(train_batch_size=2, sgd_minibatch_size=2))
}


@ray.remote
def save_and_train(algo_cfg: AlgorithmConfig, env: str, tmpdir):
"""Create an algo, checkpoint it, then train for 2 iterations.
Note: This function uses a seeded algorithm that can modify the global random state.
Running it multiple times in the same process can affect other algorithms.
Making it a Ray task runs it in a separate process and prevents it from
affecting other algorithms' random state.
Args:
algo_cfg: The algorithm config to build the algo from.
env: The gym environment to train on.
tmpdir: The temporary directory to save the checkpoint to.
Returns:
The learner stats after 2 iterations of training.
"""
algo_cfg = (
algo_cfg.training(_enable_learner_api=True)
.rl_module(_enable_rl_module_api=True)
.rollouts(num_rollout_workers=0)
# setting min_time_s_per_iteration=0 and min_sample_timesteps_per_iteration=1
# to make sure that we get results as soon as sampling/training is done at
# least once
.reporting(min_time_s_per_iteration=0, min_sample_timesteps_per_iteration=1)
.debugging(seed=10)
)
algo = algo_cfg.environment(env).build()

tmpdir = str(tmpdir)
algo.save_checkpoint(tmpdir)
for _ in range(2):
results = algo.train()
return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY]


@ray.remote
def load_and_train(algo_cfg: AlgorithmConfig, env: str, tmpdir):
"""Loads the checkpoint saved by save_and_train and trains for 2 iterations.
Note: This function uses a seeded algorithm that can modify the global random state.
Running it multiple times in the same process can affect other algorithms.
Making it a Ray task runs it in a separate process and prevents it from
affecting other algorithms' random state.
Args:
algo_cfg: The algorithm config to build the algo from.
env: The gym environment to train on.
tmpdir: The temporary directory to load the checkpoint from.
Returns:
The learner stats after 2 iterations of training.
"""
algo_cfg = (
algo_cfg.training(_enable_learner_api=True)
.rl_module(_enable_rl_module_api=True)
.rollouts(num_rollout_workers=0)
# setting min_time_s_per_iteration=0 and min_sample_timesteps_per_iteration=1
# to make sure that we get results as soon as sampling/training is done at
# least once
.reporting(min_time_s_per_iteration=0, min_sample_timesteps_per_iteration=1)
.debugging(seed=10)
)
algo = algo_cfg.environment(env).build()
tmpdir = str(tmpdir)
algo.load_checkpoint(tmpdir)
for _ in range(2):
results = algo.train()
return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY]


class TestAlgorithmWithLearnerSaveAndRestore(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init()

@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()

def test_save_and_restore(self):
for algo_name in algorithms_and_configs:
config = algorithms_and_configs[algo_name]
for _ in framework_iterator(config, frameworks=["torch", "tf2"]):
with tempfile.TemporaryDirectory() as tmpdir:
# create an algorithm, checkpoint it, then train for 2 iterations
ray.get(save_and_train.remote(config, "CartPole-v1", tmpdir))
# load that checkpoint into a new algorithm and train for 2
# iterations
results_algo_2 = ray.get(
load_and_train.remote(config, "CartPole-v1", tmpdir)
)

# load that checkpoint into another new algorithm and train for 2
# iterations
results_algo_3 = ray.get(
load_and_train.remote(config, "CartPole-v1", tmpdir)
)

# check that the results are the same across loaded algorithms
# they won't be the same as the first algorithm since the random
# state that is used for each algorithm is not preserved across
# checkpoints.
check(results_algo_3, results_algo_2)


if __name__ == "__main__":
import sys
import pytest

sys.exit(pytest.main(["-v", __file__]))
27 changes: 17 additions & 10 deletions rllib/utils/checkpoints.py
@@ -29,7 +29,11 @@

# 1.1: Same as 1.0, but has a new "format" field in the rllib_checkpoint.json file
# indicating, whether the checkpoint is `cloudpickle` (default) or `msgpack`.

# 1.2: Introduces the checkpoint format for the new Learner API, used when the Learner API is enabled.

CHECKPOINT_VERSION = version.Version("1.1")
CHECKPOINT_VERSION_LEARNER = version.Version("1.2")


@PublicAPI(stability="alpha")
@@ -102,15 +106,15 @@ def get_checkpoint_info(checkpoint: Union[str, Checkpoint]) -> Dict[str, Any]:
rllib_checkpoint_info["checkpoint_version"]
)
info.update(rllib_checkpoint_info)

# No rllib_checkpoint.json file present: Warn and continue trying to figure out
# checkpoint info ourselves.
if log_once("no_rllib_checkpoint_json_file"):
logger.warning(
"No `rllib_checkpoint.json` file found in checkpoint directory "
f"{checkpoint}! Trying to extract checkpoint info from other files "
f"found in that dir."
)
else:
# No rllib_checkpoint.json file present: Warn and continue trying to figure
# out checkpoint info ourselves.
if log_once("no_rllib_checkpoint_json_file"):
logger.warning(
"No `rllib_checkpoint.json` file found in checkpoint directory "
f"{checkpoint}! Trying to extract checkpoint info from other files "
f"found in that dir."
)

# Policy checkpoint file found.
for extension in ["pkl", "msgpck"]:
@@ -222,7 +226,10 @@ def convert_to_msgpack_checkpoint(
state["worker"]["is_policy_to_train"] = NOT_SERIALIZABLE

# Add RLlib checkpoint version (as string).
state["checkpoint_version"] = str(CHECKPOINT_VERSION)
if state["config"]["_enable_learner_api"]:
state["checkpoint_version"] = str(CHECKPOINT_VERSION_LEARNER)
else:
state["checkpoint_version"] = str(CHECKPOINT_VERSION)

# Write state (w/o policies) to disk.
state_file = os.path.join(msgpack_checkpoint_dir, "algorithm_state.msgpck")
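As a quick check (a sketch, assuming `checkpoint_dir` points at a checkpoint written with `_enable_learner_api=True` as in the algorithm sketch earlier), the bumped version can be read back via `get_checkpoint_info`:

from ray.rllib.utils.checkpoints import (
    CHECKPOINT_VERSION_LEARNER,
    get_checkpoint_info,
)

checkpoint_dir = "/tmp/my_rllib_checkpoint"  # Hypothetical path to a saved checkpoint.

info = get_checkpoint_info(checkpoint_dir)
# The version string from rllib_checkpoint.json is parsed into a Version
# object, so it compares directly against the new 1.2 constant.
assert info["checkpoint_version"] == CHECKPOINT_VERSION_LEARNER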