diff --git a/rllib/algorithms/ppo/tests/test_ppo_learner.py b/rllib/algorithms/ppo/tests/test_ppo_learner.py
index e16ea35641e6..1eb8e082c102 100644
--- a/rllib/algorithms/ppo/tests/test_ppo_learner.py
+++ b/rllib/algorithms/ppo/tests/test_ppo_learner.py
@@ -9,8 +9,12 @@
 import ray.rllib.algorithms.ppo as ppo
 from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
+from ray.rllib.algorithms.ppo.ppo_learner import LEARNER_RESULTS_CURR_KL_COEFF_KEY
 from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
+from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
 from ray.rllib.policy.sample_batch import SampleBatch
+from ray.tune.registry import register_env
+from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
 from ray.rllib.utils.test_utils import check, framework_iterator
 
 from ray.rllib.evaluation.postprocessing import (
@@ -156,6 +160,67 @@ def test_save_load_state(self):
             learner_group2.load_state(tmpdir)
             check(learner_group1.get_state(), learner_group2.get_state())
 
+    def test_kl_coeff_changes(self):
+        # Simple environment with 2 independent cartpole entities.
+        register_env(
+            "multi_agent_cartpole", lambda _: MultiAgentCartPole({"num_agents": 2})
+        )
+
+        initial_kl_coeff = 0.01
+        config = (
+            ppo.PPOConfig()
+            .environment("multi_agent_cartpole")
+            .rollouts(
+                num_rollout_workers=0,
+                rollout_fragment_length=50,
+            )
+            .training(
+                gamma=0.99,
+                model=dict(
+                    fcnet_hiddens=[10, 10],
+                    fcnet_activation="linear",
+                    vf_share_layers=False,
+                ),
+                _enable_learner_api=True,
+                kl_coeff=initial_kl_coeff,
+            )
+            .rl_module(
+                _enable_rl_module_api=True,
+            )
+            .exploration(exploration_config={})
+            .multi_agent(
+                policies={"p0", "p1"},
+                policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+                    "p{}".format(agent_id % 2)
+                ),
+            )
+        )
+
+        for _ in framework_iterator(config, ("torch", "tf2"), with_eager_tracing=True):
+            algo = config.build()
+            # Keep calling train() until both policies have reported their
+            # current KL coefficient at least once; the learner results may
+            # not contain it in every iteration.
+            curr_kl_coeff_1 = None
+            curr_kl_coeff_2 = None
+            while not curr_kl_coeff_1 or not curr_kl_coeff_2:
+                results = algo.train()
+
+                # Attempt to get the current KL coefficient from the learner.
+                # Iterate until we have found both coefficients at least once.
+                if results and "info" in results and LEARNER_INFO in results["info"]:
+                    if "p0" in results["info"][LEARNER_INFO]:
+                        curr_kl_coeff_1 = results["info"][LEARNER_INFO]["p0"][
+                            LEARNER_RESULTS_CURR_KL_COEFF_KEY
+                        ]
+                    if "p1" in results["info"][LEARNER_INFO]:
+                        curr_kl_coeff_2 = results["info"][LEARNER_INFO]["p1"][
+                            LEARNER_RESULTS_CURR_KL_COEFF_KEY
+                        ]
+
+            self.assertNotEqual(curr_kl_coeff_1, initial_kl_coeff)
+            self.assertNotEqual(curr_kl_coeff_2, initial_kl_coeff)
+
 
 if __name__ == "__main__":
     import pytest
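
Note for reviewers: the final assertNotEqual checks depend on PPO's adaptive KL controller actually moving kl_coeff away from its initial value once training runs. Below is a minimal sketch of the dual-threshold update rule this relies on; the thresholds (2.0x / 0.5x around kl_target) and factors (1.5 / 0.5) mirror RLlib's classic update_kl behavior, but the standalone helper name and form here are illustrative only, not the learner's actual API, and the exact rule inside the new Learner stack may differ.

def update_kl_coeff(kl_coeff: float, sampled_kl: float, kl_target: float) -> float:
    # If the measured mean KL between old and new policy overshoots the
    # target by more than 2x, strengthen the penalty; if it undershoots
    # by more than 2x, relax it; otherwise leave the coefficient alone.
    if sampled_kl > 2.0 * kl_target:
        return kl_coeff * 1.5
    elif sampled_kl < 0.5 * kl_target:
        return kl_coeff * 0.5
    return kl_coeff


import math

# With the test's initial_kl_coeff of 0.01 and RLlib's default kl_target
# of 0.01, a sampled KL of 0.03 raises the coefficient to 0.015. That is
# precisely the kind of drift away from initial_kl_coeff that the
# assertNotEqual checks at the end of the test detect.
assert math.isclose(update_kl_coeff(0.01, 0.03, 0.01), 0.015)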