diff --git a/rllib/BUILD b/rllib/BUILD
index f2d4537eedda..5bdc5b33fb19 100644
--- a/rllib/BUILD
+++ b/rllib/BUILD
@@ -1982,6 +1982,13 @@ py_test(
     srcs = ["core/learner/torch/tests/test_torch_learner.py"]
 )
 
+py_test(
+    name = "tests/test_algorithm_save_load_checkpoint_learner",
+    tags = ["team:rllib", "core"],
+    size = "medium",
+    srcs = ["tests/test_algorithm_save_load_checkpoint_learner.py"]
+)
+
 py_test(
     name = "test_bc_algorithm",
     tags = ["team:rllib", "core"],
diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py
index 8385b3638607..90595bb79c9c 100644
--- a/rllib/algorithms/algorithm.py
+++ b/rllib/algorithms/algorithm.py
@@ -76,6 +76,7 @@
 )
 from ray.rllib.utils.checkpoints import (
     CHECKPOINT_VERSION,
+    CHECKPOINT_VERSION_LEARNER,
     get_checkpoint_info,
     try_import_msgpack,
 )
@@ -2077,6 +2078,14 @@ def save_checkpoint(self, checkpoint_dir: str) -> str:
                     policy_state.pkl
                 pol_2/
                     policy_state.pkl
+            learner/
+                learner_state.json
+                module_state/
+                    module_1/
+                        ...
+                optimizer_state/
+                    optimizers_module_1/
+                        ...
             rllib_checkpoint.json
             algorithm_state.pkl
 
@@ -2099,7 +2108,10 @@ def save_checkpoint(self, checkpoint_dir: str) -> str:
         policy_states = state["worker"].pop("policy_states", {})
 
         # Add RLlib checkpoint version.
-        state["checkpoint_version"] = CHECKPOINT_VERSION
+        if self.config._enable_learner_api:
+            state["checkpoint_version"] = CHECKPOINT_VERSION_LEARNER
+        else:
+            state["checkpoint_version"] = CHECKPOINT_VERSION
 
         # Write state (w/o policies) to disk.
         state_file = os.path.join(checkpoint_dir, "algorithm_state.pkl")
@@ -2130,21 +2142,24 @@ def save_checkpoint(self, checkpoint_dir: str) -> str:
             policy = self.get_policy(pid)
             policy.export_checkpoint(policy_dir, policy_state=policy_state)
 
+        # If we are using the learner API, save the learner group state.
+        if self.config._enable_learner_api:
+            learner_state_dir = os.path.join(checkpoint_dir, "learner")
+            self.learner_group.save_state(learner_state_dir)
+
         return checkpoint_dir
 
     @override(Trainable)
-    def load_checkpoint(self, checkpoint: Union[Dict, str]) -> None:
+    def load_checkpoint(self, checkpoint: str) -> None:
         # Checkpoint is provided as a directory name.
         # Restore from the checkpoint file or dir.
-        if isinstance(checkpoint, str):
-            checkpoint_info = get_checkpoint_info(checkpoint)
-            checkpoint_data = Algorithm._checkpoint_info_to_algorithm_state(
-                checkpoint_info
-            )
-        # Checkpoint is a checkpoint-as-dict -> Restore state from it as-is.
-        else:
-            checkpoint_data = checkpoint
+
+        checkpoint_info = get_checkpoint_info(checkpoint)
+        checkpoint_data = Algorithm._checkpoint_info_to_algorithm_state(checkpoint_info)
         self.__setstate__(checkpoint_data)
+        if self.config._enable_learner_api:
+            learner_state_dir = os.path.join(checkpoint, "learner")
+            self.learner_group.load_state(learner_state_dir)
 
     @override(Trainable)
     def log_result(self, result: ResultDict) -> None:
diff --git a/rllib/algorithms/ppo/tests/test_ppo_learner.py b/rllib/algorithms/ppo/tests/test_ppo_learner.py
index 12e910ed8599..40aa98389539 100644
--- a/rllib/algorithms/ppo/tests/test_ppo_learner.py
+++ b/rllib/algorithms/ppo/tests/test_ppo_learner.py
@@ -2,6 +2,7 @@
 import unittest
 import numpy as np
 import torch
+import tempfile
 import tensorflow as tf
 import tree  # pip install dm-tree
 
@@ -74,8 +75,8 @@ def test_loss(self):
         )
 
         for fw in framework_iterator(config, ("tf2", "torch"), with_eager_tracing=True):
-            trainer = config.build()
-            policy = trainer.get_policy()
+            algo = config.build()
+            policy = algo.get_policy()
 
             train_batch = SampleBatch(FAKE_BATCH)
             train_batch = compute_gae_for_sample_batch(policy, train_batch)
@@ -109,14 +110,58 @@ def test_loss(self):
             )
 
             learner_group = learner_group_config.build()
 
-            # load the trainer weights onto the learner_group
-            learner_group.set_weights(trainer.get_weights())
+            # load the algo weights onto the learner_group
+            learner_group.set_weights(algo.get_weights())
             results = learner_group.update(train_batch.as_multi_agent())
 
             learner_group_loss = results[ALL_MODULES]["total_loss"]
 
             check(learner_group_loss, policy_loss)
 
+    def test_save_load_state(self):
+        """Tests saving and loading the state of the PPO Learner Group."""
+        config = (
+            ppo.PPOConfig()
+            .environment("CartPole-v1")
+            .rollouts(
+                num_rollout_workers=0,
+            )
+            .training(
+                gamma=0.99,
+                model=dict(
+                    fcnet_hiddens=[10, 10],
+                    fcnet_activation="linear",
+                    vf_share_layers=False,
+                ),
+                _enable_learner_api=True,
+            )
+            .rl_module(
+                _enable_rl_module_api=True,
+            )
+        )
+        algo = config.build()
+        policy = algo.get_policy()
+
+        for fw in framework_iterator(config, ("tf2", "torch"), with_eager_tracing=True):
+            algo_config = config.copy(copy_frozen=False)
+            algo_config.validate()
+            algo_config.freeze()
+            learner_group_config = algo_config.get_learner_group_config(
+                SingleAgentRLModuleSpec(
+                    module_class=algo_config.rl_module_spec.module_class,
+                    observation_space=policy.observation_space,
+                    action_space=policy.action_space,
+                    model_config_dict=policy.config["model"],
+                    catalog_class=PPOCatalog,
+                )
+            )
+            learner_group1 = learner_group_config.build()
+            learner_group2 = learner_group_config.build()
+            with tempfile.TemporaryDirectory() as tmpdir:
+                learner_group1.save_state(tmpdir)
+                learner_group2.load_state(tmpdir)
+                check(learner_group1.get_state(), learner_group2.get_state())
+
 
 if __name__ == "__main__":
     import pytest
diff --git a/rllib/core/learner/learner.py b/rllib/core/learner/learner.py
index e4ef00dd6ad5..4847567dac23 100644
--- a/rllib/core/learner/learner.py
+++ b/rllib/core/learner/learner.py
@@ -906,6 +906,16 @@ def save_state(self, path: Union[str, pathlib.Path]) -> None:
         NOTE: if path doesn't exist, then a new directory will be created.
             otherwise, it will be appended to.
 
+        The state of the learner is saved in the following format:
+
+        checkpoint_dir/
+            learner_state.json
+            module_state/
+                module_1/
+                    ...
+            optimizer_state/
+                optimizers_module_1/
+                    ...
 
         Args:
             path: The path to the directory to save the state to.
diff --git a/rllib/core/learner/learner_group.py b/rllib/core/learner/learner_group.py
index c53ee9b78dd7..9b2774438b69 100644
--- a/rllib/core/learner/learner_group.py
+++ b/rllib/core/learner/learner_group.py
@@ -475,10 +475,10 @@ def load_state(self, path: str) -> None:
         if not path.exists():
             raise ValueError(f"Path {path} does not exist.")
         path = str(path.absolute())
-        assert len(self._workers) == self._worker_manager.num_healthy_actors()
         if self.is_local:
             self._learner.load_state(path)
         else:
+            assert len(self._workers) == self._worker_manager.num_healthy_actors()
             head_node_ip = socket.gethostbyname(socket.gethostname())
             workers = self._worker_manager.healthy_actor_ids()
diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py
index 905865f52d57..c403f54a7435 100644
--- a/rllib/core/rl_module/rl_module.py
+++ b/rllib/core/rl_module/rl_module.py
@@ -186,7 +186,7 @@ def to_dict(self):
         """
 
         catalog_class_path = (
-            serialize_type(type(self.catalog_class)) if self.catalog_class else ""
+            serialize_type(self.catalog_class) if self.catalog_class else ""
         )
         return {
             "observation_space": gym_space_to_dict(self.observation_space),
diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py
index 003083066c17..a30c07fa33fc 100644
--- a/rllib/evaluation/rollout_worker.py
+++ b/rllib/evaluation/rollout_worker.py
@@ -134,7 +134,7 @@ def _update_env_seed_if_necessary(
 
     NOTE: this may not work with remote environments (issue #18154).
     """
-    if not seed:
+    if seed is None:
         return
 
     # A single RL job is unlikely to have more than 10K
diff --git a/rllib/tests/test_algorithm_save_load_checkpoint_learner.py b/rllib/tests/test_algorithm_save_load_checkpoint_learner.py
new file mode 100644
index 000000000000..211d0dac10c7
--- /dev/null
+++ b/rllib/tests/test_algorithm_save_load_checkpoint_learner.py
@@ -0,0 +1,128 @@
+import tempfile
+import unittest
+
+import ray
+from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
+from ray.rllib.utils.test_utils import check, framework_iterator
+from ray.rllib.algorithms.ppo import PPOConfig
+from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
+from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY
+
+
+algorithms_and_configs = {
+    "PPO": (PPOConfig().training(train_batch_size=2, sgd_minibatch_size=2))
+}
+
+
+@ray.remote
+def save_and_train(algo_cfg: AlgorithmConfig, env: str, tmpdir):
+    """Create an algo, checkpoint it, then train for 2 iterations.
+
+    Note: This function uses a seeded algorithm that can modify the global random state.
+        Running it multiple times in the same process can affect other algorithms.
+        Making it a Ray task runs it in a separate process and prevents it from
+        affecting other algorithms' random state.
+
+    Args:
+        algo_cfg: The algorithm config to build the algo from.
+        env: The gym environment to train on.
+        tmpdir: The temporary directory to save the checkpoint to.
+
+    Returns:
+        The learner stats after 2 iterations of training.
+ """ + algo_cfg = ( + algo_cfg.training(_enable_learner_api=True) + .rl_module(_enable_rl_module_api=True) + .rollouts(num_rollout_workers=0) + # setting min_time_s_per_iteration=0 and min_sample_timesteps_per_iteration=1 + # to make sure that we get results as soon as sampling/training is done at + # least once + .reporting(min_time_s_per_iteration=0, min_sample_timesteps_per_iteration=1) + .debugging(seed=10) + ) + algo = algo_cfg.environment(env).build() + + tmpdir = str(tmpdir) + algo.save_checkpoint(tmpdir) + for _ in range(2): + results = algo.train() + return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY] + + +@ray.remote +def load_and_train(algo_cfg: AlgorithmConfig, env: str, tmpdir): + """Loads the checkpoint saved by save_and_train and trains for 2 iterations. + + Note: This function uses a seeded algorithm that can modify the global random state. + Running it multiple times in the same process can affect other algorithms. + Making it a Ray task runs it in a separate process and prevents it from + affecting other algorithms' random state. + + Args: + algo_cfg: The algorithm config to build the algo from. + env: The gym genvironment to train on. + tmpdir: The temporary directory to save the checkpoint to. + + Returns: + The learner stats after 2 iterations of training. + + """ + algo_cfg = ( + algo_cfg.training(_enable_learner_api=True) + .rl_module(_enable_rl_module_api=True) + .rollouts(num_rollout_workers=0) + # setting min_time_s_per_iteration=0 and min_sample_timesteps_per_iteration=1 + # to make sure that we get results as soon as sampling/training is done at + # least once + .reporting(min_time_s_per_iteration=0, min_sample_timesteps_per_iteration=1) + .debugging(seed=10) + ) + algo = algo_cfg.environment(env).build() + tmpdir = str(tmpdir) + algo.load_checkpoint(tmpdir) + for _ in range(2): + results = algo.train() + return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY] + + +class TestAlgorithmWithLearnerSaveAndRestore(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + ray.init() + + @classmethod + def tearDowClass(cls) -> None: + ray.shutdown() + + def test_save_and_restore(self): + for algo_name in algorithms_and_configs: + config = algorithms_and_configs[algo_name] + for _ in framework_iterator(config, frameworks=["torch", "tf2"]): + with tempfile.TemporaryDirectory() as tmpdir: + # create an algorithm, checkpoint it, then train for 2 iterations + ray.get(save_and_train.remote(config, "CartPole-v1", tmpdir)) + # load that checkpoint into a new algorithm and train for 2 + # iterations + results_algo_2 = ray.get( + load_and_train.remote(config, "CartPole-v1", tmpdir) + ) + + # load that checkpoint into another new algorithm and train for 2 + # iterations + results_algo_3 = ray.get( + load_and_train.remote(config, "CartPole-v1", tmpdir) + ) + + # check that the results are the same across loaded algorithms + # they won't be the same as the first algorithm since the random + # state that is used for each algorithm is not preserved across + # checkpoints. 
+                    check(results_algo_3, results_algo_2)
+
+
+if __name__ == "__main__":
+    import sys
+    import pytest
+
+    sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/utils/checkpoints.py b/rllib/utils/checkpoints.py
index 7bbc456ca148..19e5cc145b31 100644
--- a/rllib/utils/checkpoints.py
+++ b/rllib/utils/checkpoints.py
@@ -29,7 +29,11 @@
 # 1.1: Same as 1.0, but has a new "format" field in the rllib_checkpoint.json file
 # indicating, whether the checkpoint is `cloudpickle` (default) or `msgpack`.
+
+# 1.2: Introduces the checkpoint for the new Learner API if the Learner API is enabled.
+
 CHECKPOINT_VERSION = version.Version("1.1")
+CHECKPOINT_VERSION_LEARNER = version.Version("1.2")
 
 
 @PublicAPI(stability="alpha")
@@ -102,15 +106,15 @@ def get_checkpoint_info(checkpoint: Union[str, Checkpoint]) -> Dict[str, Any]:
                 rllib_checkpoint_info["checkpoint_version"]
             )
         info.update(rllib_checkpoint_info)
-
-    # No rllib_checkpoint.json file present: Warn and continue trying to figure out
-    # checkpoint info ourselves.
-    if log_once("no_rllib_checkpoint_json_file"):
-        logger.warning(
-            "No `rllib_checkpoint.json` file found in checkpoint directory "
-            f"{checkpoint}! Trying to extract checkpoint info from other files "
-            f"found in that dir."
-        )
+    else:
+        # No rllib_checkpoint.json file present: Warn and continue trying to figure
+        # out checkpoint info ourselves.
+        if log_once("no_rllib_checkpoint_json_file"):
+            logger.warning(
+                "No `rllib_checkpoint.json` file found in checkpoint directory "
+                f"{checkpoint}! Trying to extract checkpoint info from other files "
+                f"found in that dir."
+            )
 
     # Policy checkpoint file found.
     for extension in ["pkl", "msgpck"]:
@@ -222,7 +226,10 @@ def convert_to_msgpack_checkpoint(
     state["worker"]["is_policy_to_train"] = NOT_SERIALIZABLE
 
     # Add RLlib checkpoint version (as string).
-    state["checkpoint_version"] = str(CHECKPOINT_VERSION)
+    if state["config"]["_enable_learner_api"]:
+        state["checkpoint_version"] = str(CHECKPOINT_VERSION_LEARNER)
+    else:
+        state["checkpoint_version"] = str(CHECKPOINT_VERSION)
 
     # Write state (w/o policies) to disk.
     state_file = os.path.join(msgpack_checkpoint_dir, "algorithm_state.msgpck")
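
Usage note (not part of the patch): a minimal sketch of how the learner-aware
checkpointing introduced here is exercised end to end. It assumes a CartPole-v1
PPO config with the learner and RLModule APIs enabled and a temporary directory,
and mirrors what the new test does via Ray tasks, but in a single process.

import tempfile

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .rollouts(num_rollout_workers=0)
    .training(_enable_learner_api=True)
    .rl_module(_enable_rl_module_api=True)
)

with tempfile.TemporaryDirectory() as tmpdir:
    # Saving a checkpoint with the learner API enabled also writes the learner/
    # subdirectory and stamps the checkpoint as version 1.2.
    algo = config.build()
    algo.save_checkpoint(tmpdir)
    algo.stop()

    # load_checkpoint() now also restores the LearnerGroup state from learner/.
    restored = config.build()
    restored.load_checkpoint(tmpdir)
    restored.train()
    restored.stop()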