[RLlib] APPO on new API stack (w/ EnvRunners). #46216

Merged
15 commits merged on Jun 26, 2024
5 changes: 2 additions & 3 deletions doc/source/rllib/package_ref/learner.rst
@@ -81,10 +81,9 @@ Performing Updates

Learner.update_from_batch
Learner.update_from_episodes
Learner.before_gradient_based_update
Learner._update
Learner.additional_update
Learner.additional_update_for_module
Learner._convert_batch_type
Learner.after_gradient_based_update


Computing Losses
33 changes: 13 additions & 20 deletions doc/source/rllib/rllib-learner.rst
@@ -215,10 +215,7 @@ Updates
}
default_batch = SampleBatch(DUMMY_BATCH)
DUMMY_BATCH = default_batch.as_multi_agent()
ADDITIONAL_UPDATE_KWARGS = {
"timestep": 0,
"sampled_kl_values": {DEFAULT_MODULE_ID: 1e-4},
}
TIMESTEPS = {"num_env_steps_sampled_lifetime": 0}

learner.build() # needs to be called on the learner before calling any functions

@@ -230,43 +227,39 @@ Updates
.. testcode::

# This is a blocking update.
results = learner_group.update_from_batch(batch=DUMMY_BATCH)
results = learner_group.update_from_batch(batch=DUMMY_BATCH, timesteps=TIMESTEPS)

# This is a non-blocking update. The results are returned in a future
# call to `update_from_batch(..., async_update=True)`
_ = learner_group.update_from_batch(batch=DUMMY_BATCH, async_update=True)
_ = learner_group.update_from_batch(batch=DUMMY_BATCH, async_update=True, timesteps=TIMESTEPS)

# Artificially wait for async request to be done to get the results
# in the next call to
# `LearnerGroup.update_from_batch(..., async_update=True)`.
time.sleep(5)
results = learner_group.update_from_batch(
batch=DUMMY_BATCH, async_update=True
batch=DUMMY_BATCH, async_update=True, timesteps=TIMESTEPS
)
# `results` is an already reduced dict, which is the result of
# reducing over the individual async `update_from_batch(..., async_update=True)`
# calls.
assert isinstance(results, dict), results

# This is an additional non-gradient based update.
learner_group.additional_update(**ADDITIONAL_UPDATE_KWARGS)

When updating a :py:class:`~ray.rllib.core.learner.learner_group.LearnerGroup` you can perform blocking or async updates on batches of data. Async updates are necessary for implementing async algorithms such as APPO/IMPALA.
You can perform non-gradient based updates using :py:meth:`~ray.rllib.core.learner.learner_group.LearnerGroup.additional_update`.
When updating a :py:class:`~ray.rllib.core.learner.learner_group.LearnerGroup` you can perform blocking or async updates on batches of data.
Async updates are necessary for implementing async algorithms such as APPO/IMPALA.
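
As a rough sketch of that async pattern (not taken from the RLlib docs; it assumes the `learner_group`, `DUMMY_BATCH`, and `TIMESTEPS` objects set up in the snippet above), a training loop repeatedly submits batches and picks up whatever earlier requests have finished:

# Hedged sketch: each call submits a new async update and returns the reduced
# results of previously completed requests, which may be empty while requests
# are still in flight.
for _ in range(3):
    results = learner_group.update_from_batch(
        batch=DUMMY_BATCH, async_update=True, timesteps=TIMESTEPS
    )
    if results:
        print(results)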

.. tab-item:: Updating a Learner

.. testcode::

# This is a blocking update (given a training batch).
result = learner.update_from_batch(batch=DUMMY_BATCH)

# This is an additional non-gradient based update.
learner_group.additional_update(**ADDITIONAL_UPDATE_KWARGS)
result = learner.update_from_batch(batch=DUMMY_BATCH, timesteps=TIMESTEPS)

When updating a :py:class:`~ray.rllib.core.learner.learner.Learner` you can only perform blocking updates on batches of data.
You can perform non-gradient based updates using :py:meth:`~ray.rllib.core.learner.learner.Learner.additional_update`.

You can perform non-gradient based updates before or after the gradient-based ones by overriding
:py:meth:`~ray.rllib.core.learner.learner.Learner.before_gradient_based_update` and
:py:meth:`~ray.rllib.core.learner.learner.Learner.after_gradient_based_update`.
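
For illustration, a minimal sketch of such an override (the subclass and its body are hypothetical, not RLlib code; the keyword-only `timesteps` argument matches the `after_gradient_based_update` signature introduced elsewhere in this PR):

from typing import Any, Dict

from ray.rllib.core.learner.learner import Learner


class MyLearner(Learner):  # Illustrative only; in practice, subclass your algorithm's Learner.
    def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
        super().after_gradient_based_update(timesteps=timesteps)
        # Example non-gradient-based update: read the lifetime env-step count
        # passed in by the caller and act on it (here, just print it).
        print(timesteps.get("num_env_steps_sampled_lifetime", 0))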


Getting and setting state
-------------------------
@@ -368,8 +361,8 @@ Implementation
- set up any optimizers for an RLModule.
* - :py:meth:`~ray.rllib.core.learner.learner.Learner.compute_loss_for_module()`
- calculate the loss for gradient based update to a module.
* - :py:meth:`~ray.rllib.core.learner.learner.Learner.additional_update_for_module()`
- do any non gradient based updates to a RLModule, e.g. target network updates.
* - :py:meth:`~ray.rllib.core.learner.learner.Learner.before_gradient_based_update()`
- do any non-gradient-based updates to an RLModule before (not after) the gradient-based ones, e.g. add noise to your network.
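
A hedged sketch of the noise example from the row above (the class, the noise scale, and the assumption that `before_gradient_based_update` takes the same keyword-only `timesteps` argument as its `after_...` counterpart are all illustrative, not RLlib code):

from typing import Any, Dict

import torch

from ray.rllib.core.learner.learner import Learner


class ParamNoiseLearner(Learner):  # Illustrative only.
    def before_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
        super().before_gradient_based_update(timesteps=timesteps)
        # Add a little Gaussian noise to every (torch) RLModule's parameters
        # before gradients are computed (purely illustrative "parameter noise").
        for module_id, module in self.module._rl_modules.items():
            for param in module.parameters():
                param.data.add_(1e-3 * torch.randn_like(param))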

Starter Example
---------------
18 changes: 17 additions & 1 deletion rllib/BUILD
@@ -152,6 +152,22 @@ py_test(
# --------------------------------------------------------------------

# APPO
py_test(
name = "learning_tests_cartpole_appo",
main = "tuned_examples/appo/cartpole_appo.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "torch_only"],
size = "large",
srcs = ["tuned_examples/appo/cartpole_appo.py"],
args = ["--as-test", "--enable-new-api-stack"]
)
py_test(
name = "learning_tests_multi_agent_cartpole_appo",
main = "tuned_examples/appo/multi_agent_cartpole_appo.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "torch_only"],
size = "large",
srcs = ["tuned_examples/appo/multi_agent_cartpole_appo.py"],
args = ["--as-test", "--enable-new-api-stack"]
)

#@OldAPIStack
py_test(
@@ -173,7 +189,7 @@ py_test(
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"],
size = "medium",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/appo/multi_agent_cartpole_appo.py"],
data = ["tuned_examples/appo/multi_agent_cartpole_appo_old_api_stack.py"],
args = ["--dir=tuned_examples/appo"]
)

126 changes: 59 additions & 67 deletions rllib/algorithms/appo/appo_learner.py
@@ -1,24 +1,36 @@
import abc
from typing import Any, Dict

from ray.rllib.algorithms.appo.appo import APPOConfig
from ray.rllib.algorithms.impala.impala_learner import ImpalaLearner
from ray.rllib.core.learner.learner import Learner
from ray.rllib.utils.annotations import override
from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict
from ray.rllib.utils.metrics import LAST_TARGET_UPDATE_TS, NUM_TARGET_UPDATES
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
NUM_ENV_STEPS_SAMPLED_LIFETIME,
NUM_MODULE_STEPS_TRAINED,
NUM_TARGET_UPDATES,
)
from ray.rllib.utils.schedules.scheduler import Scheduler
from ray.rllib.utils.typing import ModuleID


class AppoLearner(ImpalaLearner):
"""Adds KL coeff updates via `additional_update_for_module()` to Impala logic.
"""Adds KL coeff updates via `after_gradient_based_update()` to Impala logic.

Framework-specific sub-classes must override `_update_module_kl_coeff()`
Framework-specific sub-classes must override `_update_module_kl_coeff()`.
"""

@override(ImpalaLearner)
def build(self):
super().build()

# Initially sync target networks (w/ tau=1.0 -> full overwrite).
self.module.foreach_module(
lambda mid, module: module.sync_target_networks(tau=1.0)
)

# The current kl coefficients per module as (framework specific) tensor
# variables.
self.curr_kl_coeffs_per_module: LambdaDefaultDict[
@@ -34,69 +46,52 @@ def remove_module(self, module_id: str):
super().remove_module(module_id)
self.curr_kl_coeffs_per_module.pop(module_id)

@override(ImpalaLearner)
def additional_update_for_module(
self,
*,
module_id: ModuleID,
config: APPOConfig,
timestep: int,
last_update: int,
mean_kl_loss_per_module: dict,
**kwargs,
) -> None:
"""Updates the target networks and KL loss coefficients (per module).

Args:
module_id:
"""

# return dict(
# last_update=self._counters[LAST_TARGET_UPDATE_TS],
# mean_kl_loss_per_module={
# module_id: r[LEARNER_RESULTS_KL_KEY]
# for module_id, r in train_results.items()
# if module_id != ALL_MODULES
# },
# )

# TODO (avnish) Using steps trained here instead of sampled ... I'm not sure
# why the other implementation uses sampled.
# The difference in steps sampled/trained is pretty
# much always going to be larger than self.config.num_sgd_iter *
# self.config.minibatch_buffer_size unless the number of steps collected
# is really small. The thing is that the default rollout fragment length
# is 50, so the minibatch buffer size * num_sgd_iter is going to be
# have to be 50 to even meet the threshold of having delayed target
# updates.
# We should instead have the target / kl threshold update be based off
# of the train_batch_size * some target update frequency * num_sgd_iter.
super().additional_update_for_module(
module_id=module_id, config=config, timestep=timestep
)

# TODO (Sven): DQN uses `config.target_network_update_freq`. Can we
# choose a standard here?
last_update_ts_key = (module_id, LAST_TARGET_UPDATE_TS)
if (
timestep - self.metrics.peek(last_update_ts_key, default=0)
>= config.target_update_frequency
):
self.module._synch_target_network(module_id=module_id, config=config)
# Increase lifetime target network update counter by one.
self.metrics.log_value((module_id, NUM_TARGET_UPDATES), 1, reduce="sum")
# Update the (single-value -> window=1) last updated timestep metric.
self.metrics.log_value(last_update_ts_key, timestep, window=1)

if config.use_kl_loss and module_id in mean_kl_loss_per_module:
self._update_module_kl_coeff(
module_id, config, mean_kl_loss_per_module[module_id]
)
@override(Learner)
def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
"""Updates the target Q Networks."""
super().after_gradient_based_update(timesteps=timesteps)

timestep = timesteps.get(NUM_ENV_STEPS_SAMPLED_LIFETIME, 0)

# TODO (sven): Maybe we should have a `after_gradient_based_update`
# method per module?
for module_id, module in self.module._rl_modules.items():
config = self.config.get_config_for_module(module_id)

# TODO (avnish) Using steps trained here instead of sampled ... I'm not sure
# why the other implementation uses sampled.
# The difference in steps sampled/trained is pretty
# much always going to be larger than self.config.num_sgd_iter *
# self.config.minibatch_buffer_size unless the number of steps collected
# is really small. The thing is that the default rollout fragment length
# is 50, so the minibatch buffer size * num_sgd_iter is going to be
# have to be 50 to even meet the threshold of having delayed target
# updates.
# We should instead have the target / kl threshold update be based off
# of the train_batch_size * some target update frequency * num_sgd_iter.

last_update_ts_key = (module_id, LAST_TARGET_UPDATE_TS)
# TODO (Sven): DQN uses `config.target_network_update_freq`. Can we
# choose a standard here?
if (
timestep - self.metrics.peek(last_update_ts_key, default=0)
>= config.target_update_frequency
):
module.sync_target_networks(tau=config.tau)
# Increase lifetime target network update counter by one.
self.metrics.log_value((module_id, NUM_TARGET_UPDATES), 1, reduce="sum")
# Update the (single-value -> window=1) last updated timestep metric.
self.metrics.log_value(last_update_ts_key, timestep, window=1)

if (
config.use_kl_loss
and self.metrics.peek((module_id, NUM_MODULE_STEPS_TRAINED), default=0)
> 0
):
self._update_module_kl_coeff(module_id=module_id, config=config)

@abc.abstractmethod
def _update_module_kl_coeff(
self, module_id: ModuleID, config: APPOConfig, sampled_kl: float
) -> None:
def _update_module_kl_coeff(self, module_id: ModuleID, config: APPOConfig) -> None:
(PR author comment: We'll take the KL directly from the metrics now.)

"""Dynamically update the KL loss coefficients of each module with.

The update is completed using the mean KL divergence between the action
@@ -106,7 +101,4 @@ def _update_module_kl_coeff(
Args:
module_id: The module whose KL loss coefficient to update.
config: The AlgorithmConfig specific to the given `module_id`.
sampled_kl: The computed KL loss for the given Module
(KL divergence between the action distributions of the current
(most recently updated) module and the old module version).
"""
10 changes: 5 additions & 5 deletions rllib/algorithms/appo/tf/appo_tf_learner.py
@@ -14,6 +14,7 @@
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.nested_dict import NestedDict
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.typing import ModuleID, TensorType

_, tf, _ = try_import_tf()
@@ -180,18 +181,17 @@ def compute_loss_for_module(
return total_loss

@override(AppoLearner)
def _update_module_kl_coeff(
self, module_id: ModuleID, config: APPOConfig, sampled_kl: float
) -> None:
def _update_module_kl_coeff(self, module_id: ModuleID, config: APPOConfig) -> None:
# Update the current KL value based on the recently measured value.
# Increase.
kl = convert_to_numpy(self.metrics.peek((module_id, LEARNER_RESULTS_KL_KEY)))
kl_coeff_var = self.curr_kl_coeffs_per_module[module_id]

if sampled_kl > 2.0 * config.kl_target:
if kl > 2.0 * config.kl_target:
# TODO (Kourosh) why not *2.0?
kl_coeff_var.assign(kl_coeff_var * 1.5)
# Decrease.
elif sampled_kl < 0.5 * config.kl_target:
elif kl < 0.5 * config.kl_target:
kl_coeff_var.assign(kl_coeff_var * 0.5)

self.metrics.log_value(
10 changes: 5 additions & 5 deletions rllib/algorithms/appo/torch/appo_torch_learner.py
@@ -26,6 +26,7 @@
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.nested_dict import NestedDict
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.typing import ModuleID, TensorType

torch, nn = try_import_torch()
@@ -219,18 +220,17 @@ def _make_modules_ddp_if_necessary(self) -> None:
)

@override(AppoLearner)
def _update_module_kl_coeff(
self, module_id: ModuleID, config: APPOConfig, sampled_kl: float
) -> None:
def _update_module_kl_coeff(self, module_id: ModuleID, config: APPOConfig) -> None:
# Update the current KL value based on the recently measured value.
# Increase.
kl = convert_to_numpy(self.metrics.peek((module_id, LEARNER_RESULTS_KL_KEY)))
kl_coeff_var = self.curr_kl_coeffs_per_module[module_id]

if sampled_kl > 2.0 * config.kl_target:
if kl > 2.0 * config.kl_target:
# TODO (Kourosh) why not *2.0?
kl_coeff_var.data *= 1.5
# Decrease.
elif sampled_kl < 0.5 * config.kl_target:
elif kl < 0.5 * config.kl_target:
kl_coeff_var.data *= 0.5

self.metrics.log_value(
16 changes: 4 additions & 12 deletions rllib/algorithms/appo/torch/appo_torch_rl_module.py
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Tuple

from ray.rllib.algorithms.appo.appo import (
OLD_ACTION_DIST_LOGITS_KEY,
@@ -13,6 +13,7 @@
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.nested_dict import NestedDict
from ray.rllib.utils.typing import NetworkType


class APPOTorchRLModule(PPOTorchRLModule, APPORLModule):
@@ -29,17 +30,8 @@ def setup(self):
self.old_encoder.requires_grad_(False)

@override(RLModuleWithTargetNetworksInterface)
def sync_target_networks(self, tau: float) -> None:
for target_network, current_network in [
(self.old_pi, self.pi),
(self.old_encoder, self.encoder),
]:
current_state_dict = current_network.state_dict()
new_state_dict = {
k: tau * current_state_dict[k] + (1 - tau) * v
for k, v in target_network.state_dict().items()
}
target_network.load_state_dict(new_state_dict)
def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
return [(self.old_pi, self.pi), (self.old_encoder, self.encoder)]

@override(PPOTorchRLModule)
def output_specs_train(self) -> List[str]:
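
For context on the APPOTorchRLModule change above: the module no longer performs the target-network sync itself; it only exposes its (target, current) network pairs via `get_target_network_pairs()`, while the Learner keeps calling `module.sync_target_networks(tau=config.tau)`, presumably implemented once in a shared base class not shown in this diff. As a minimal sketch only (the helper name below is made up, and it assumes the networks are torch `nn.Module`s, mirroring the removed `sync_target_networks()` body), a generic soft update over those pairs could look like this:

def soft_update_from_pairs(rl_module, tau: float) -> None:
    # Hypothetical helper (not the actual RLlib implementation): polyak-average
    # each target network toward its current counterpart, using the pairs the
    # RLModule exposes via `get_target_network_pairs()`.
    for target_net, current_net in rl_module.get_target_network_pairs():
        current_state = current_net.state_dict()
        new_state = {
            k: tau * current_state[k] + (1.0 - tau) * v
            for k, v in target_net.state_dict().items()
        }
        target_net.load_state_dict(new_state)

With tau=1.0 this is a full overwrite, matching the initial sync the APPO learner does in `build()`; with the configured `config.tau` it corresponds to the soft update performed in `after_gradient_based_update()`.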