Commit

[RLlib] Issue 22444: KL-coeff not stored in persistent policy state. (r…
sven1977 authored and simonsays1980 committed Feb 27, 2022
1 parent ae69100 commit 77ffa0b
Showing 2 changed files with 41 additions and 4 deletions.
31 changes: 27 additions & 4 deletions rllib/agents/ppo/ppo_tf_policy.py
@@ -18,6 +18,7 @@
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.tf_policy import LearningRateSchedule, EntropyCoeffSchedule
 from ray.rllib.policy.tf_policy_template import build_tf_policy
+from ray.rllib.utils.annotations import override
 from ray.rllib.utils.deprecation import (
     Deprecated,
     DEPRECATED_VALUE,
@@ -237,7 +238,7 @@ def compute_and_clip_gradients(
 
 
 class KLCoeffMixin:
-    """Assigns the `update_kl()` method to the PPOPolicy.
+    """Assigns the `update_kl()` and other KL-related methods to the PPOPolicy.
 
     This is used in PPO's execution plan (see ppo.py) for updating the KL
     coefficient after each learning step based on `config.kl_target` and
@@ -276,16 +277,38 @@ def update_kl(self, sampled_kl):
         else:
             return self.kl_coeff_val
 
-        # Update the tf Variable (via session call for tf).
+        # Make sure, new value is also stored in graph/tf variable.
+        self._set_kl_coeff(self.kl_coeff_val)
+
+        # Return the current KL value.
+        return self.kl_coeff_val
+
+    def _set_kl_coeff(self, new_kl_coeff):
+        # Set the (off graph) value.
+        self.kl_coeff_val = new_kl_coeff
+
+        # Update the tf/tf2 Variable (via session call for tf or `assign`).
         if self.framework == "tf":
             self.get_session().run(
                 self._kl_coeff_update,
                 feed_dict={self._kl_coeff_placeholder: self.kl_coeff_val},
             )
         else:
             self.kl_coeff.assign(self.kl_coeff_val, read_value=False)
-        # Return the current KL value.
-        return self.kl_coeff_val
+
+    @override(Policy)
+    def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
+        state = super().get_state()
+        # Add current kl-coeff value.
+        state["current_kl_coeff"] = self.kl_coeff_val
+        return state
+
+    @override(Policy)
+    def set_state(self, state: dict) -> None:
+        # Set current kl-coeff value first.
+        self._set_kl_coeff(state.pop("current_kl_coeff", self.config["kl_coeff"]))
+        # Call super's set_state with rest of the state dict.
+        super().set_state(state)
 
 
 class ValueNetworkMixin:
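For context, the hunk above only shows the tail of `update_kl()`: the part not shown adapts `kl_coeff_val` toward `config.kl_target` before `_set_kl_coeff()` pushes the new value into the tf graph. A rough sketch of that adaptive rule (based on RLlib's PPO; the thresholds and multipliers here are illustrative and not part of this diff):

    # Sketch of the adaptive-KL rule behind `update_kl()` (illustrative values,
    # not taken from this commit). `sampled_kl` is the mean KL measured on the
    # most recent train batch.
    def update_kl_sketch(kl_coeff_val: float, kl_target: float, sampled_kl: float) -> float:
        if sampled_kl > 2.0 * kl_target:
            kl_coeff_val *= 1.5  # policy drifted too far -> strengthen the KL penalty
        elif sampled_kl < 0.5 * kl_target:
            kl_coeff_val *= 0.5  # policy barely moved -> relax the KL penalty
        return kl_coeff_val

Factoring the graph update out into `_set_kl_coeff()` lets `set_state()` reuse it, so a restored policy updates both the Python-side value and the tf variable in a single call.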
14 changes: 14 additions & 0 deletions rllib/agents/ppo/ppo_torch_policy.py
@@ -261,3 +261,17 @@ def on_global_var_update(self, global_vars):
             self.entropy_coeff = self._entropy_coeff_schedule.value(
                 global_vars["timestep"]
             )
+
+    @override(TorchPolicy)
+    def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
+        state = super().get_state()
+        # Add current kl-coeff value.
+        state["current_kl_coeff"] = self.kl_coeff
+        return state
+
+    @override(TorchPolicy)
+    def set_state(self, state: dict) -> None:
+        # Set current kl-coeff value first.
+        self.kl_coeff = state.pop("current_kl_coeff", self.config["kl_coeff"])
+        # Call super's set_state with rest of the state dict.
+        super().set_state(state)
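The practical effect of these two overrides is that an adapted KL coefficient now survives a save/restore cycle instead of being reset to `config["kl_coeff"]`. A minimal round-trip sketch (assuming `policy` and `restored_policy` are already-built PPO policies; the names are illustrative):

    # `policy` has been training, so its KL coefficient may have drifted away
    # from the initial config value.
    state = policy.get_state()
    assert "current_kl_coeff" in state  # new key introduced by this commit

    # Restoring pops "current_kl_coeff" first, then hands the rest of the
    # state dict to the superclass.
    restored_policy.set_state(state)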
