[RLlib] Make the KL coefficient traced in appo tf (#34293)
Signed-off-by: Avnish <[email protected]>
avnishn authored Apr 13, 2023
1 parent 10d0b8a commit f1b14d2
Showing 3 changed files with 56 additions and 6 deletions.
47 changes: 45 additions & 2 deletions rllib/algorithms/appo/tests/tf/test_appo_learner.py
@@ -3,9 +3,13 @@
 
 import ray
 import ray.rllib.algorithms.appo as appo
+from ray.rllib.algorithms.appo.tf.appo_tf_learner import (
+    LEARNER_RESULTS_CURR_KL_COEFF_KEY,
+)
 from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
-from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID
 from ray.rllib.utils.metrics import ALL_MODULES
+from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.test_utils import check, framework_iterator
 
@@ -36,7 +40,7 @@
 }
 
 
-class TestImpalaTfLearner(unittest.TestCase):
+class TestAPPOTfLearner(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         ray.init()
@@ -105,6 +109,45 @@ def test_appo_loss(self):
 
         check(learner_group_loss, policy_loss)
 
+    def test_kl_coeff_changes(self):
+        initial_kl_coeff = 0.01
+        config = (
+            appo.APPOConfig()
+            .environment("CartPole-v1")
+            .rollouts(
+                num_rollout_workers=0,
+                rollout_fragment_length=frag_length,
+            )
+            .resources(num_gpus=0)
+            .framework(eager_tracing=True)
+            .training(
+                gamma=0.99,
+                model=dict(
+                    fcnet_hiddens=[10, 10],
+                    fcnet_activation="linear",
+                    vf_share_layers=False,
+                ),
+                _enable_learner_api=True,
+                kl_coeff=initial_kl_coeff,
+            )
+            .rl_module(
+                _enable_rl_module_api=True,
+            )
+            .exploration(exploration_config={})
+        )
+        for _ in framework_iterator(config, ("tf2",)):
+            algo = config.build()
+            # Keep calling train() until results come back: APPO is an
+            # asynchronous algorithm, so early calls may return nothing.
+            while True:
+                results = algo.train()
+                if results:
+                    break
+            curr_kl_coeff = results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][
+                LEARNER_STATS_KEY
+            ][LEARNER_RESULTS_CURR_KL_COEFF_KEY]
+            self.assertNotEqual(curr_kl_coeff, initial_kl_coeff)
+
 
 if __name__ == "__main__":
     import pytest
13 changes: 9 additions & 4 deletions rllib/algorithms/appo/tf/appo_tf_learner.py
@@ -18,6 +18,7 @@
 
 
 LEARNER_RESULTS_KL_KEY = "mean_kl_loss"
+LEARNER_RESULTS_CURR_KL_COEFF_KEY = "curr_kl_coeff"
 
 
 @dataclass
@@ -72,8 +73,11 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.kl_target = self._hps.kl_target
         self.clip_param = self._hps.clip_param
-        self.kl_coeffs = defaultdict(lambda: self._hps.kl_coeff)
-        self.kl_coeff = self._hps.kl_coeff
+        # TODO: (avnishn) Make creating the kl coeff a utility function when we add
+        # torch APPO as well.
+        self.kl_coeffs = defaultdict(
+            lambda: tf.Variable(self._hps.kl_coeff, trainable=False, dtype=tf.float32)
+        )
         self.tau = self._hps.tau
 
     @override(TfLearner)
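Why this change matters: with eager_tracing=True, the learner's loss and update paths are compiled via tf.function. A plain Python float (the old defaultdict default) is baked into the traced graph as a constant, so the multiplicative updates applied later in _update_module_kl_coeff would never reach the compiled loss. A non-trainable tf.Variable, mutated in place with assign(), is read fresh on every traced call. A minimal standalone sketch of the difference (toy code, not RLlib code):

import tensorflow as tf

py_coeff = 0.01  # Plain float: frozen into the graph at trace time.
var_coeff = tf.Variable(0.01, trainable=False, dtype=tf.float32)

@tf.function
def loss_with_float(kl):
    return py_coeff * kl

@tf.function
def loss_with_variable(kl):
    return var_coeff * kl

kl = tf.constant(1.0)
loss_with_float(kl)     # Traces; captures 0.01 as a constant.
loss_with_variable(kl)  # Traces; reads the variable at call time.

# "Update" the coefficient the way the KL controller does.
py_coeff *= 1.5
var_coeff.assign(var_coeff * 1.5)

print(loss_with_float(kl).numpy())     # Still 0.01 -- the update is invisible.
print(loss_with_variable(kl).numpy())  # 0.015 -- the update took effect.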
@@ -192,6 +196,7 @@ def compute_loss_per_module(
             VF_LOSS_KEY: mean_vf_loss,
             ENTROPY_KEY: mean_entropy_loss,
             LEARNER_RESULTS_KL_KEY: mean_kl_loss,
+            LEARNER_RESULTS_CURR_KL_COEFF_KEY: self.kl_coeffs[module_id],
         }
 
     @override(ImpalaTfLearner)
@@ -238,10 +243,10 @@ def _update_module_kl_coeff(
         # Update the current KL value based on the recently measured value.
         # Increase.
         if sampled_kl > 2.0 * self.kl_target:
-            self.kl_coeffs[module_id] *= 1.5
+            self.kl_coeffs[module_id].assign(self.kl_coeffs[module_id] * 1.5)
         # Decrease.
         elif sampled_kl < 0.5 * self.kl_target:
-            self.kl_coeffs[module_id] *= 0.5
+            self.kl_coeffs[module_id].assign(self.kl_coeffs[module_id] * 0.5)
 
     @override(ImpalaTfLearner)
     def additional_update_per_module(
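This is the standard PPO-style adaptive KL controller, now expressed with in-place assign() so it works on the traced variable: the penalty coefficient grows by 1.5x when the sampled KL drifts above twice the target, is halved when it drops below half the target, and is left untouched inside the [0.5, 2.0] * target band. A self-contained sketch of the rule (the thresholds and factors follow the diff; the names and toy loop are illustrative only):

import tensorflow as tf

kl_target = 0.01
kl_coeff = tf.Variable(0.01, trainable=False, dtype=tf.float32)

def update_kl_coeff(sampled_kl: float) -> None:
    if sampled_kl > 2.0 * kl_target:
        # KL too high: strengthen the penalty.
        kl_coeff.assign(kl_coeff * 1.5)
    elif sampled_kl < 0.5 * kl_target:
        # KL comfortably low: relax the penalty.
        kl_coeff.assign(kl_coeff * 0.5)
    # Otherwise leave the coefficient unchanged.

for sampled_kl in (0.05, 0.03, 0.001):
    update_kl_coeff(sampled_kl)
    print(float(kl_coeff.numpy()))  # ~0.015, ~0.0225, ~0.01125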
2 changes: 2 additions & 0 deletions rllib/tuned_examples/appo/cartpole-appo-learner.yaml
@@ -27,3 +27,5 @@ cartpole-appo-learner:
         eager_tracing: True
         lr: 0.001
         entropy_coeff: 0.1
+        kl_coeff: 0.01
+        exploration_config: null
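Here kl_coeff: 0.01 gives the adaptive controller a nonzero starting point, and exploration_config: null mirrors the exploration_config={} override in the test above. One hedged way to launch this tuned example from Python (assuming the Tune API of this Ray release and the usual env/run/stop/config layout of RLlib tuned-example YAMLs):

import yaml
from ray import air, tune

with open("rllib/tuned_examples/appo/cartpole-appo-learner.yaml") as f:
    # Tuned-example files have a single top-level experiment key.
    experiment = next(iter(yaml.safe_load(f).values()))

tuner = tune.Tuner(
    experiment["run"],  # Expected to be "APPO".
    param_space={"env": experiment["env"], **experiment["config"]},
    run_config=air.RunConfig(stop=experiment.get("stop", {})),
)
tuner.fit()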
