
Commit

[RLlib] Algorithm Level Checkpointing with Learner and RL Modules (#34717)

Signed-off-by: Avnish <[email protected]>
avnishn authored Apr 26, 2023
1 parent d99ae15 commit 6b59692
Showing 9 changed files with 239 additions and 27 deletions.
7 changes: 7 additions & 0 deletions rllib/BUILD
@@ -1982,6 +1982,13 @@ py_test(
srcs = ["core/learner/torch/tests/test_torch_learner.py"]
)

py_test(
name ="tests/test_algorithm_save_load_checkpoint_learner",
tags = ["team:rllib", "core"],
size = "medium",
srcs = ["tests/test_algorithm_save_load_checkpoint_learner.py"]
)

py_test(
name = "test_bc_algorithm",
tags = ["team:rllib", "core"],
35 changes: 25 additions & 10 deletions rllib/algorithms/algorithm.py
@@ -76,6 +76,7 @@
)
from ray.rllib.utils.checkpoints import (
CHECKPOINT_VERSION,
CHECKPOINT_VERSION_LEARNER,
get_checkpoint_info,
try_import_msgpack,
)
@@ -2077,6 +2078,14 @@ def save_checkpoint(self, checkpoint_dir: str) -> str:
policy_state.pkl
pol_2/
policy_state.pkl
learner/
learner_state.json
module_state/
module_1/
...
optimizer_state/
optimizers_module_1/
...
rllib_checkpoint.json
algorithm_state.pkl
@@ -2099,7 +2108,10 @@ def save_checkpoint(self, checkpoint_dir: str) -> str:
policy_states = state["worker"].pop("policy_states", {})

# Add RLlib checkpoint version.
state["checkpoint_version"] = CHECKPOINT_VERSION
if self.config._enable_learner_api:
state["checkpoint_version"] = CHECKPOINT_VERSION_LEARNER
else:
state["checkpoint_version"] = CHECKPOINT_VERSION

# Write state (w/o policies) to disk.
state_file = os.path.join(checkpoint_dir, "algorithm_state.pkl")
@@ -2130,21 +2142,24 @@ def save_checkpoint(self, checkpoint_dir: str) -> str:
policy = self.get_policy(pid)
policy.export_checkpoint(policy_dir, policy_state=policy_state)

# If we are using the learner API, also save the learner group state.
if self.config._enable_learner_api:
learner_state_dir = os.path.join(checkpoint_dir, "learner")
self.learner_group.save_state(learner_state_dir)

return checkpoint_dir

@override(Trainable)
def load_checkpoint(self, checkpoint: Union[Dict, str]) -> None:
def load_checkpoint(self, checkpoint: str) -> None:
# Checkpoint is provided as a directory name.
# Restore from the checkpoint file or dir.
if isinstance(checkpoint, str):
checkpoint_info = get_checkpoint_info(checkpoint)
checkpoint_data = Algorithm._checkpoint_info_to_algorithm_state(
checkpoint_info
)
# Checkpoint is a checkpoint-as-dict -> Restore state from it as-is.
else:
checkpoint_data = checkpoint

checkpoint_info = get_checkpoint_info(checkpoint)
checkpoint_data = Algorithm._checkpoint_info_to_algorithm_state(checkpoint_info)
self.__setstate__(checkpoint_data)
if self.config._enable_learner_api:
learner_state_dir = os.path.join(checkpoint, "learner")
self.learner_group.load_state(learner_state_dir)

@override(Trainable)
def log_result(self, result: ResultDict) -> None:
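With the Learner API enabled, save_checkpoint now also writes a learner/ subdirectory and load_checkpoint restores it via learner_group.load_state. A minimal sketch of that round trip, assuming a small CartPole PPO setup with both new APIs enabled (the exact config values here are illustrative):

import tempfile

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .rollouts(num_rollout_workers=0)
    .training(_enable_learner_api=True)
    .rl_module(_enable_rl_module_api=True)
)
algo = config.build()

with tempfile.TemporaryDirectory() as tmpdir:
    # Writes algorithm_state.pkl, rllib_checkpoint.json, the per-policy
    # directories, and (because the Learner API is on) the learner/ dir.
    algo.save_checkpoint(tmpdir)

    # A second algorithm built from the same config restores both the
    # algorithm state and the learner group state from that directory.
    algo2 = config.copy(copy_frozen=False).build()
    algo2.load_checkpoint(tmpdir)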
53 changes: 49 additions & 4 deletions rllib/algorithms/ppo/tests/test_ppo_learner.py
@@ -2,6 +2,7 @@
import unittest
import numpy as np
import torch
import tempfile
import tensorflow as tf
import tree # pip install dm-tree

@@ -74,8 +75,8 @@ def test_loss(self):
)

for fw in framework_iterator(config, ("tf2", "torch"), with_eager_tracing=True):
trainer = config.build()
policy = trainer.get_policy()
algo = config.build()
policy = algo.get_policy()

train_batch = SampleBatch(FAKE_BATCH)
train_batch = compute_gae_for_sample_batch(policy, train_batch)
@@ -109,14 +110,58 @@ def test_loss(self):
)
learner_group = learner_group_config.build()

# load the trainer weights onto the learner_group
learner_group.set_weights(trainer.get_weights())
# load the algo weights onto the learner_group
learner_group.set_weights(algo.get_weights())
results = learner_group.update(train_batch.as_multi_agent())

learner_group_loss = results[ALL_MODULES]["total_loss"]

check(learner_group_loss, policy_loss)

def test_save_load_state(self):
"""Tests saving and loading the state of the PPO Learner Group."""
config = (
ppo.PPOConfig()
.environment("CartPole-v1")
.rollouts(
num_rollout_workers=0,
)
.training(
gamma=0.99,
model=dict(
fcnet_hiddens=[10, 10],
fcnet_activation="linear",
vf_share_layers=False,
),
_enable_learner_api=True,
)
.rl_module(
_enable_rl_module_api=True,
)
)
algo = config.build()
policy = algo.get_policy()

for fw in framework_iterator(config, ("tf2", "torch"), with_eager_tracing=True):
algo_config = config.copy(copy_frozen=False)
algo_config.validate()
algo_config.freeze()
learner_group_config = algo_config.get_learner_group_config(
SingleAgentRLModuleSpec(
module_class=algo_config.rl_module_spec.module_class,
observation_space=policy.observation_space,
action_space=policy.action_space,
model_config_dict=policy.config["model"],
catalog_class=PPOCatalog,
)
)
learner_group1 = learner_group_config.build()
learner_group2 = learner_group_config.build()
with tempfile.TemporaryDirectory() as tmpdir:
learner_group1.save_state(tmpdir)
learner_group2.load_state(tmpdir)
check(learner_group1.get_state(), learner_group2.get_state())


if __name__ == "__main__":
import pytest
10 changes: 10 additions & 0 deletions rllib/core/learner/learner.py
@@ -906,6 +906,16 @@ def save_state(self, path: Union[str, pathlib.Path]) -> None:
NOTE: If the path doesn't exist, a new directory will be created; otherwise, the
existing directory will be appended to.
The state of the learner is saved in the following format:
checkpoint_dir/
learner_state.json
module_state/
module_1/
...
optimizer_state/
optimizers_module_1/
...
Args:
path: The path to the directory to save the state to.
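To see that layout on disk without driving a full training run, one can call save_state on an algorithm's LearnerGroup directly; a hedged sketch (the CartPole PPO config is an assumption, and algo.learner_group is the same attribute used by save_checkpoint above):

import os
import tempfile

from ray.rllib.algorithms.ppo import PPOConfig

algo = (
    PPOConfig()
    .environment("CartPole-v1")
    .rollouts(num_rollout_workers=0)
    .training(_enable_learner_api=True)
    .rl_module(_enable_rl_module_api=True)
    .build()
)

with tempfile.TemporaryDirectory() as tmpdir:
    # Expected contents per the docstring above: learner_state.json,
    # module_state/<module_id>/..., and optimizer_state/...
    algo.learner_group.save_state(tmpdir)
    for root, _, files in os.walk(tmpdir):
        print(root, files)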
2 changes: 1 addition & 1 deletion rllib/core/learner/learner_group.py
@@ -475,10 +475,10 @@ def load_state(self, path: str) -> None:
if not path.exists():
raise ValueError(f"Path {path} does not exist.")
path = str(path.absolute())
assert len(self._workers) == self._worker_manager.num_healthy_actors()
if self.is_local:
self._learner.load_state(path)
else:
assert len(self._workers) == self._worker_manager.num_healthy_actors()
head_node_ip = socket.gethostbyname(socket.gethostname())
workers = self._worker_manager.healthy_actor_ids()

2 changes: 1 addition & 1 deletion rllib/core/rl_module/rl_module.py
@@ -186,7 +186,7 @@ def to_dict(self):
"""
catalog_class_path = (
serialize_type(type(self.catalog_class)) if self.catalog_class else ""
serialize_type(self.catalog_class) if self.catalog_class else ""
)
return {
"observation_space": gym_space_to_dict(self.observation_space),
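The dropped type() call matters because catalog_class already holds a class object; wrapping it in type() hands serialize_type the metaclass instead of the catalog class itself. A plain-Python illustration (the stand-in class below is hypothetical):

class PPOCatalog:  # stand-in for the real catalog class stored on the spec
    pass

catalog_class = PPOCatalog
print(type(catalog_class))     # <class 'type'> -- the metaclass, not the catalog
print(catalog_class.__name__)  # 'PPOCatalog'   -- the class serialize_type should record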
2 changes: 1 addition & 1 deletion rllib/evaluation/rollout_worker.py
@@ -134,7 +134,7 @@ def _update_env_seed_if_necessary(
NOTE: this may not work with remote environments (issue #18154).
"""
if not seed:
if seed is None:
return

# A single RL job is unlikely to have more than 10K
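The switch from a truthiness check to an identity check only changes behavior for seed=0, which the old condition treated like a missing seed. A short plain-Python illustration:

seed = 0
if not seed:
    print("old check: a seed of 0 would have been ignored")
if seed is None:
    print("new check: only a truly missing seed is skipped")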
128 changes: 128 additions & 0 deletions rllib/tests/test_algorithm_save_load_checkpoint_learner.py
@@ -0,0 +1,128 @@
import tempfile
import unittest

import ray
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.utils.test_utils import check, framework_iterator
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY


algorithms_and_configs = {
"PPO": (PPOConfig().training(train_batch_size=2, sgd_minibatch_size=2))
}


@ray.remote
def save_and_train(algo_cfg: AlgorithmConfig, env: str, tmpdir):
"""Create an algo, checkpoint it, then train for 2 iterations.
Note: This function uses a seeded algorithm that can modify the global random state.
Running it multiple times in the same process can affect other algorithms.
Making it a Ray task runs it in a separate process and prevents it from
affecting other algorithms' random state.
Args:
algo_cfg: The algorithm config to build the algo from.
env: The gym environment to train on.
tmpdir: The temporary directory to save the checkpoint to.
Returns:
The learner stats after 2 iterations of training.
"""
algo_cfg = (
algo_cfg.training(_enable_learner_api=True)
.rl_module(_enable_rl_module_api=True)
.rollouts(num_rollout_workers=0)
# setting min_time_s_per_iteration=0 and min_sample_timesteps_per_iteration=1
# to make sure that we get results as soon as sampling/training is done at
# least once
.reporting(min_time_s_per_iteration=0, min_sample_timesteps_per_iteration=1)
.debugging(seed=10)
)
algo = algo_cfg.environment(env).build()

tmpdir = str(tmpdir)
algo.save_checkpoint(tmpdir)
for _ in range(2):
results = algo.train()
return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY]


@ray.remote
def load_and_train(algo_cfg: AlgorithmConfig, env: str, tmpdir):
"""Loads the checkpoint saved by save_and_train and trains for 2 iterations.
Note: This function uses a seeded algorithm that can modify the global random state.
Running it multiple times in the same process can affect other algorithms.
Making it a Ray task runs it in a separate process and prevents it from
affecting other algorithms' random state.
Args:
algo_cfg: The algorithm config to build the algo from.
env: The gym environment to train on.
tmpdir: The temporary directory to save the checkpoint to.
Returns:
The learner stats after 2 iterations of training.
"""
algo_cfg = (
algo_cfg.training(_enable_learner_api=True)
.rl_module(_enable_rl_module_api=True)
.rollouts(num_rollout_workers=0)
# setting min_time_s_per_iteration=0 and min_sample_timesteps_per_iteration=1
# to make sure that we get results as soon as sampling/training is done at
# least once
.reporting(min_time_s_per_iteration=0, min_sample_timesteps_per_iteration=1)
.debugging(seed=10)
)
algo = algo_cfg.environment(env).build()
tmpdir = str(tmpdir)
algo.load_checkpoint(tmpdir)
for _ in range(2):
results = algo.train()
return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY]


class TestAlgorithmWithLearnerSaveAndRestore(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init()

@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()

def test_save_and_restore(self):
for algo_name in algorithms_and_configs:
config = algorithms_and_configs[algo_name]
for _ in framework_iterator(config, frameworks=["torch", "tf2"]):
with tempfile.TemporaryDirectory() as tmpdir:
# create an algorithm, checkpoint it, then train for 2 iterations
ray.get(save_and_train.remote(config, "CartPole-v1", tmpdir))
# load that checkpoint into a new algorithm and train for 2
# iterations
results_algo_2 = ray.get(
load_and_train.remote(config, "CartPole-v1", tmpdir)
)

# load that checkpoint into another new algorithm and train for 2
# iterations
results_algo_3 = ray.get(
load_and_train.remote(config, "CartPole-v1", tmpdir)
)

# Check that the results are the same across the loaded algorithms.
# They won't match the first algorithm's results, since the random
# state used by each algorithm is not preserved across checkpoints.
check(results_algo_3, results_algo_2)


if __name__ == "__main__":
import sys
import pytest

sys.exit(pytest.main(["-v", __file__]))
27 changes: 17 additions & 10 deletions rllib/utils/checkpoints.py
@@ -29,7 +29,11 @@

# 1.1: Same as 1.0, but has a new "format" field in the rllib_checkpoint.json file
# indicating whether the checkpoint is `cloudpickle` (default) or `msgpack`.

# 1.2: Introduces the Learner API checkpoint format, used when the Learner API is enabled.

CHECKPOINT_VERSION = version.Version("1.1")
CHECKPOINT_VERSION_LEARNER = version.Version("1.2")


@PublicAPI(stability="alpha")
@@ -102,15 +106,15 @@ def get_checkpoint_info(checkpoint: Union[str, Checkpoint]) -> Dict[str, Any]:
rllib_checkpoint_info["checkpoint_version"]
)
info.update(rllib_checkpoint_info)

# No rllib_checkpoint.json file present: Warn and continue trying to figure out
# checkpoint info ourselves.
if log_once("no_rllib_checkpoint_json_file"):
logger.warning(
"No `rllib_checkpoint.json` file found in checkpoint directory "
f"{checkpoint}! Trying to extract checkpoint info from other files "
f"found in that dir."
)
else:
# No rllib_checkpoint.json file present: Warn and continue trying to figure
# out checkpoint info ourselves.
if log_once("no_rllib_checkpoint_json_file"):
logger.warning(
"No `rllib_checkpoint.json` file found in checkpoint directory "
f"{checkpoint}! Trying to extract checkpoint info from other files "
f"found in that dir."
)

# Policy checkpoint file found.
for extension in ["pkl", "msgpck"]:
@@ -222,7 +226,10 @@ def convert_to_msgpack_checkpoint(
state["worker"]["is_policy_to_train"] = NOT_SERIALIZABLE

# Add RLlib checkpoint version (as string).
state["checkpoint_version"] = str(CHECKPOINT_VERSION)
if state["config"]["_enable_learner_api"]:
state["checkpoint_version"] = str(CHECKPOINT_VERSION_LEARNER)
else:
state["checkpoint_version"] = str(CHECKPOINT_VERSION)

# Write state (w/o policies) to disk.
state_file = os.path.join(msgpack_checkpoint_dir, "algorithm_state.msgpck")
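Because the version is also written into rllib_checkpoint.json as a string, downstream code can detect whether a checkpoint carries learner state by comparing against the new constant. A hedged sketch; the helper name is hypothetical, and it assumes rllib_checkpoint.json sits at the checkpoint root as shown in the layout above:

import json
import os

from packaging import version

from ray.rllib.utils.checkpoints import CHECKPOINT_VERSION_LEARNER


def checkpoint_has_learner_state(checkpoint_dir: str) -> bool:
    # Hypothetical helper: True if the checkpoint was written with the
    # Learner API enabled (checkpoint version >= 1.2).
    with open(os.path.join(checkpoint_dir, "rllib_checkpoint.json")) as f:
        info = json.load(f)
    return version.Version(info["checkpoint_version"]) >= CHECKPOINT_VERSION_LEARNER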
