Commit
[RLlib] Add a KL ratio test to PPO and remove APPO's own learner keys (ray-project#35476)

Signed-off-by: Artur Niederfahrenhorst <[email protected]>
ArturNiederfahrenhorst authored and scv119 committed Jun 11, 2023
1 parent 291cb73 commit 9dab6c3
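
The new test asserts that PPO's KL penalty coefficient moves away from its configured initial value once training starts. For context, a minimal sketch of the adaptive update rule this relies on is shown below; the helper name and the 2.0 / 0.5 / 1.5 constants are illustrative of the commonly used PPO schedule, not a verbatim copy of RLlib's learner code.

def update_kl_coeff(curr_kl_coeff: float, mean_kl: float, kl_target: float) -> float:
    # Illustrative adaptive-KL rule: grow the penalty when the measured KL
    # divergence overshoots the target, shrink it when it undershoots.
    if mean_kl > 2.0 * kl_target:
        return curr_kl_coeff * 1.5
    elif mean_kl < 0.5 * kl_target:
        return curr_kl_coeff * 0.5
    return curr_kl_coeff

Because the measured KL is almost never exactly on target, the coefficient drifts away from initial_kl_coeff within a few training iterations, which is what test_kl_coeff_changes checks for each policy.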
Showing 1 changed file with 65 additions and 0 deletions.
rllib/algorithms/ppo/tests/test_ppo_learner.py (65 additions, 0 deletions)
@@ -9,8 +9,12 @@
import ray.rllib.algorithms.ppo as ppo
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog

from ray.rllib.algorithms.ppo.ppo_learner import LEARNER_RESULTS_CURR_KL_COEFF_KEY
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.policy.sample_batch import SampleBatch
from ray.tune.registry import register_env
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.test_utils import check, framework_iterator

from ray.rllib.evaluation.postprocessing import (
@@ -156,6 +160,67 @@ def test_save_load_state(self):
            learner_group2.load_state(tmpdir)
            check(learner_group1.get_state(), learner_group2.get_state())

    def test_kl_coeff_changes(self):
        # Simple multi-agent environment with two independent CartPole agents.
        register_env(
            "multi_agent_cartpole", lambda _: MultiAgentCartPole({"num_agents": 2})
        )

        initial_kl_coeff = 0.01
        config = (
            ppo.PPOConfig()
            .environment("CartPole-v1")
            .rollouts(
                num_rollout_workers=0,
                rollout_fragment_length=50,
            )
            .training(
                gamma=0.99,
                model=dict(
                    fcnet_hiddens=[10, 10],
                    fcnet_activation="linear",
                    vf_share_layers=False,
                ),
                _enable_learner_api=True,
                kl_coeff=initial_kl_coeff,
            )
            .rl_module(
                _enable_rl_module_api=True,
            )
            .exploration(exploration_config={})
            # Overrides the "CartPole-v1" setting above with the multi-agent env.
            .environment("multi_agent_cartpole")
            .multi_agent(
                policies={"p0", "p1"},
                policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
                    "p{}".format(agent_id % 2)
                ),
            )
        )

        for _ in framework_iterator(config, ("torch", "tf2"), with_eager_tracing=True):
            algo = config.build()
            # Train until the learner has reported a current KL coefficient for
            # both policies; early iterations may not yet contain results for
            # every policy.
            curr_kl_coeff_1 = None
            curr_kl_coeff_2 = None
            while not curr_kl_coeff_1 or not curr_kl_coeff_2:
                results = algo.train()

                # Attempt to get the current KL coefficients from the learner
                # results and keep iterating until both have been seen at least
                # once.
                if results and "info" in results and LEARNER_INFO in results["info"]:
                    if "p0" in results["info"][LEARNER_INFO]:
                        curr_kl_coeff_1 = results["info"][LEARNER_INFO]["p0"][
                            LEARNER_RESULTS_CURR_KL_COEFF_KEY
                        ]
                    if "p1" in results["info"][LEARNER_INFO]:
                        curr_kl_coeff_2 = results["info"][LEARNER_INFO]["p1"][
                            LEARNER_RESULTS_CURR_KL_COEFF_KEY
                        ]

        self.assertNotEqual(curr_kl_coeff_1, initial_kl_coeff)
        self.assertNotEqual(curr_kl_coeff_2, initial_kl_coeff)


if __name__ == "__main__":
    import pytest
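To run just the new test from a Ray source checkout, one option is the small driver below; the pytest arguments and the relative path are illustrative and may need adjusting to your environment.

import sys

import pytest

# Select only the KL-coefficient test from this file; "-k" filters by test name.
sys.exit(
    pytest.main(
        [
            "-v",
            "-k",
            "test_kl_coeff_changes",
            "rllib/algorithms/ppo/tests/test_ppo_learner.py",
        ]
    )
)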
