[RLlib] Move DQN into the TargetNetworkAPI (and deprecate `RLModuleWithTargetNetworksInterface`). (#46752)
sven1977 authored Jul 23, 2024
1 parent 37ce64d commit 54e314f
Showing 9 changed files with 127 additions and 233 deletions.
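
The core of the change is that DQN's RLModule now implements the generic `TargetNetworkAPI` (from `ray.rllib.core.rl_module.apis.target_network_api`) instead of the deprecated `RLModuleWithTargetNetworksInterface`. The sketch below shows the three methods the API asks for, reduced to a toy module with a single Q-head. `ToyQModule`, `q_net`, and the observation/action sizes are illustrative inventions, and the sketch assumes `make_target_network` accepts a plain `torch.nn.Module`; in the diff it is applied to catalog-built encoder and head models.

```python
from typing import Any, Dict, List, Tuple

import torch

from ray.rllib.core.learner.utils import make_target_network
from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI
from ray.rllib.utils.typing import NetworkType


class ToyQModule(TargetNetworkAPI):
    """Illustrative module: one trained Q-net plus one frozen target copy."""

    def __init__(self, obs_dim: int = 4, num_actions: int = 2):
        self.q_net = torch.nn.Linear(obs_dim, num_actions)

    def make_target_networks(self) -> None:
        # Called once by the Learner (see `build()`/`add_module()` in the
        # learner diff below); creates the frozen target copy of the main net.
        self._target_q_net = make_target_network(self.q_net)

    def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
        # (main, target) pairs that the Learner soft-updates with `tau`.
        return [(self.q_net, self._target_q_net)]

    def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        # Target-side forward pass, used when computing TD targets.
        return {"qf_preds": self._target_q_net(batch["obs"])}
```

In the diff, `DQNRainbowRLModule` follows the same pattern for its encoder, advantage head, and (in the dueling case) value head.
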
25 changes: 19 additions & 6 deletions rllib/algorithms/appo/appo.py
@@ -18,6 +18,7 @@
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
NUM_AGENT_STEPS_SAMPLED,
@@ -105,7 +106,7 @@ def __init__(self, algo_class=None):
self.num_multi_gpu_tower_stacks = 1
self.minibatch_buffer_size = 1
self.num_sgd_iter = 1
self.target_update_frequency = 1
self.target_network_update_freq = 1
self.replay_proportion = 0.0
self.replay_buffer_num_slots = 100
self.learner_queue_size = 16
@@ -142,6 +143,8 @@ def __init__(self, algo_class=None):
# __sphinx_doc_end__
# fmt: on

self.target_update_frequency = DEPRECATED_VALUE

@override(IMPALAConfig)
def training(
self,
@@ -155,7 +158,9 @@ def training(
kl_coeff: Optional[float] = NotProvided,
kl_target: Optional[float] = NotProvided,
tau: Optional[float] = NotProvided,
target_update_frequency: Optional[int] = NotProvided,
target_network_update_freq: Optional[int] = NotProvided,
# Deprecated keys.
target_update_frequency=DEPRECATED_VALUE,
**kwargs,
) -> "APPOConfig":
"""Sets the training related configuration.
@@ -177,10 +182,10 @@ def training(
tau: The factor by which to update the target policy network towards
the current policy network. Can range between 0 and 1.
e.g. updated_param = tau * current_param + (1 - tau) * target_param
target_update_frequency: The frequency to update the target policy and
target_network_update_freq: The frequency to update the target policy and
tune the kl loss coefficients that are used during training. After
setting this parameter, the algorithm waits for at least
`target_update_frequency * minibatch_size * num_sgd_iter` number of
`target_network_update_freq * minibatch_size * num_sgd_iter` number of
samples to be trained on by the learner group before updating the target
networks and tuning the kl loss coefficients that are used during
training.
@@ -191,6 +196,14 @@ def training(
Returns:
This updated AlgorithmConfig object.
"""
if target_update_frequency != DEPRECATED_VALUE:
deprecation_warning(
old="target_update_frequency",
new="target_network_update_freq",
error=False,
)
target_network_update_freq = target_update_frequency

# Pass kwargs onto super's `training()` method.
super().training(**kwargs)

@@ -212,8 +225,8 @@ def training(
self.kl_target = kl_target
if tau is not NotProvided:
self.tau = tau
if target_update_frequency is not NotProvided:
self.target_update_frequency = target_update_frequency
if target_network_update_freq is not NotProvided:
self.target_network_update_freq = target_network_update_freq

return self

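For users, the visible effect of the appo.py change is the renamed config key. A minimal usage sketch follows; the numeric values are arbitrary illustrations, not defaults from this commit:

```python
from ray.rllib.algorithms.appo import APPOConfig

config = APPOConfig().training(
    tau=0.5,
    target_network_update_freq=2,  # new key name introduced by this commit
)

# The old keyword is still accepted for now: it only triggers a deprecation
# warning (error=False) and is forwarded to the new attribute.
config.training(target_update_frequency=2)
assert config.target_network_update_freq == 2
```
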
9 changes: 4 additions & 5 deletions rllib/algorithms/appo/appo_learner.py
@@ -95,11 +95,10 @@ def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
# of the train_batch_size * some target update frequency * num_sgd_iter.

last_update_ts_key = (module_id, LAST_TARGET_UPDATE_TS)
# TODO (Sven): DQN uses `config.target_network_update_freq`. Can we
# choose a standard here?
if (
timestep - self.metrics.peek(last_update_ts_key, default=0)
>= config.target_update_frequency
if timestep - self.metrics.peek(
last_update_ts_key, default=0
) >= config.target_network_update_freq and isinstance(
module.unwrapped(), TargetNetworkAPI
):
for (
main_net,
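With this commit, the APPO and DQN/Rainbow learners gate their target syncs on the same condition and the same `TargetNetworkAPI` calls. Distilled into a free-standing helper for clarity (a sketch; `module`, `config`, `timestep`, and `last_update_ts` stand in for the learner-internal metrics and per-module config used above):

```python
from ray.rllib.core.learner.utils import update_target_network
from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI


def maybe_sync_targets(module, config, timestep: int, last_update_ts: int) -> bool:
    """Soft-updates the module's target nets once enough timesteps have passed."""
    if (
        timestep - last_update_ts >= config.target_network_update_freq
        and isinstance(module.unwrapped(), TargetNetworkAPI)
    ):
        for main_net, target_net in module.unwrapped().get_target_network_pairs():
            update_target_network(
                main_net=main_net, target_net=target_net, tau=config.tau
            )
        return True
    return False
```
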
61 changes: 37 additions & 24 deletions rllib/algorithms/dqn/dqn_rainbow_learner.py
@@ -1,8 +1,11 @@
import abc
from typing import Any, Dict
from typing import Any, Dict, Optional

from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.learner.utils import update_target_network
from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
AddObservationsFromEpisodesToBatch,
)
@@ -18,6 +21,7 @@
NUM_ENV_STEPS_SAMPLED_LIFETIME,
NUM_TARGET_UPDATES,
)
from ray.rllib.utils.typing import ModuleID, ShouldModuleBeUpdatedFn


# Now, this is double defined: In `SACRLModule` and here. I would keep it here
@@ -42,12 +46,11 @@ class DQNRainbowLearner(Learner):
def build(self) -> None:
super().build()

# Initially sync target networks (w/ tau=1.0 -> full overwrite).
# TODO (sven): Use TargetNetworkAPI as soon as DQN implements it.
# Make target networks.
self.module.foreach_module(
lambda mid, module: (
module.sync_target_networks(tau=1.0)
if hasattr(module, "sync_target_networks")
lambda mid, mod: (
mod.make_target_networks()
if isinstance(mod, TargetNetworkAPI)
else None
)
)
@@ -60,6 +63,21 @@ def build(self) -> None:
AddNextObservationsFromEpisodesToTrainBatch(),
)

@override(Learner)
def add_module(
self,
*,
module_id: ModuleID,
module_spec: SingleAgentRLModuleSpec,
config_overrides: Optional[Dict] = None,
new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
) -> MultiAgentRLModuleSpec:
marl_spec = super().add_module(module_id=module_id)
# Create target networks for added Module, if applicable.
if isinstance(self.module[module_id].unwrapped(), TargetNetworkAPI):
self.module[module_id].unwrapped().make_target_networks()
return marl_spec

@override(Learner)
def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
"""Updates the target Q Networks."""
@@ -71,26 +89,21 @@ def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
# method per module?
for module_id, module in self.module._rl_modules.items():
config = self.config.get_config_for_module(module_id)
# TODO (Sven): APPO uses `config.target_update_frequency`. Can we
# choose a standard here?
last_update_ts_key = (module_id, LAST_TARGET_UPDATE_TS)
if (
timestep - self.metrics.peek(last_update_ts_key, default=0)
>= config.target_network_update_freq
if timestep - self.metrics.peek(
last_update_ts_key, default=0
) >= config.target_network_update_freq and isinstance(
module.unwrapped(), TargetNetworkAPI
):
# TODO (sven): Use TargetNetworkAPI as soon as DQN implements it.
if hasattr(module, "sync_target_networks"):
module.sync_target_networks(tau=config.tau)
else:
for (
main_net,
target_net,
) in module.unwrapped().get_target_network_pairs():
update_target_network(
main_net=main_net,
target_net=target_net,
tau=config.tau,
)
for (
main_net,
target_net,
) in module.unwrapped().get_target_network_pairs():
update_target_network(
main_net=main_net,
target_net=target_net,
tau=config.tau,
)
# Increase lifetime target network update counter by one.
self.metrics.log_value((module_id, NUM_TARGET_UPDATES), 1, reduce="sum")
# Update the (single-value -> window=1) last updated timestep metric.
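The `tau` passed to `update_target_network` is the Polyak factor described in the APPO docstring above (`updated_param = tau * current_param + (1 - tau) * target_param`). A plain-torch sketch of that semantics, not RLlib's actual implementation:

```python
import torch


@torch.no_grad()
def soft_update(
    main_net: torch.nn.Module, target_net: torch.nn.Module, tau: float
) -> None:
    # target <- tau * main + (1 - tau) * target, parameter by parameter.
    for p_main, p_target in zip(main_net.parameters(), target_net.parameters()):
        p_target.mul_(1.0 - tau).add_(tau * p_main)
```

With `tau=1.0` this is a full overwrite, which is what the removed `sync_target_networks(tau=1.0)` call in `build()` relied on; the new `make_target_networks()` path reaches the same initial state by constructing the targets as copies of the main nets.
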
90 changes: 50 additions & 40 deletions rllib/algorithms/dqn/dqn_rainbow_rl_module.py
@@ -4,12 +4,11 @@
from ray.rllib.algorithms.dqn.dqn_rainbow_catalog import DQNRainbowCatalog
from ray.rllib.algorithms.sac.sac_learner import QF_PREDS
from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.utils import make_target_network
from ray.rllib.core.models.base import Encoder, Model
from ray.rllib.core.models.specs.typing import SpecType
from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core.rl_module.rl_module_with_target_networks_interface import (
RLModuleWithTargetNetworksInterface,
)
from ray.rllib.models.distributions import Distribution
from ray.rllib.utils.annotations import (
ExperimentalAPI,
@@ -28,7 +27,7 @@


@ExperimentalAPI
class DQNRainbowRLModule(RLModule, RLModuleWithTargetNetworksInterface):
class DQNRainbowRLModule(RLModule, TargetNetworkAPI):
@override(RLModule)
def setup(self):
# Get the DQN Rainbow catalog.
@@ -65,8 +64,6 @@ def setup(self):
# If not an inference-only module (e.g., for evaluation), set up the
# target networks and state dict keys to be taken care of when syncing.
if not self.config.inference_only or self.framework != "torch":
# Build the same encoder for the target network(s).
self.target_encoder = catalog.build_encoder(framework=self.framework)
# Holds the parameter names to be removed or renamed when synching
# from the learner to the inference module.
self._inference_only_state_dict_keys = {}
@@ -76,30 +73,57 @@ def setup(self):
if self.uses_dueling:
# If in a dueling setting setup the value function head.
self.vf = catalog.build_vf_head(framework=self.framework)
if not self.config.inference_only or self.framework != "torch":
# Implement the same heads for the target network(s).
self.af_target = catalog.build_af_head(framework=self.framework)
if self.uses_dueling:
# If in a dueling setting setup the target value function head.
self.vf_target = catalog.build_vf_head(framework=self.framework)

# Define the action distribution for sampling the exploit action
# during exploration.
self.action_dist_cls = catalog.get_action_dist_cls(framework=self.framework)

@override(RLModuleWithTargetNetworksInterface)
@override(TargetNetworkAPI)
def make_target_networks(self) -> None:
self._target_encoder = make_target_network(self.encoder)
self._target_af = make_target_network(self.af)
if self.uses_dueling:
self._target_vf = make_target_network(self.vf)

@override(TargetNetworkAPI)
def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
"""Returns target Q and Q network(s) to update the target network(s)."""
return [(self.target_encoder, self.encoder), (self.af_target, self.af)] + (
return [(self.encoder, self._target_encoder), (self.af, self._target_af)] + (
# If we have a dueling architecture we need to update the value stream
# target, too.
[
(self.vf_target, self.vf),
(self.vf, self._target_vf),
]
if self.uses_dueling
else []
)

@override(TargetNetworkAPI)
def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]:
"""Computes Q-values from the target network.
Note, these can be accompanied by logits and probabilities
in case of distributional Q-learning, i.e. `self.num_atoms > 1`.
Args:
batch: The batch received in the forward pass.
Results:
A dictionary containing the target Q-value predictions ("qf_preds")
and in case of distributional Q-learning in addition to the target
Q-value predictions ("qf_preds") the support atoms ("atoms"), the target
Q-logits ("qf_logits"), and the probabilities ("qf_probs").
"""
# If we have a dueling architecture we have to add the value stream.
return self._qf_forward_helper(
batch,
self._target_encoder,
(
{"af": self._target_af, "vf": self._target_vf}
if self.uses_dueling
else self._target_af
),
)

@override(RLModule)
def get_exploration_action_dist_cls(self) -> Type[Distribution]:
"""Returns the action distribution class for exploration.
@@ -159,41 +183,27 @@ def output_specs_train(self) -> SpecType:
),
]

@abc.abstractmethod
@OverrideToImplementCustomLogic
def _qf(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
"""Computes Q-values.
"""Computes Q-values, given encoder, q-net and (optionally), advantage net.
Note, these can be accompanied with logits and pobabilities
Note, these can be accompanied by logits and probabilities
in case of distributional Q-learning, i.e. `self.num_atoms > 1`.
Args:
batch: The batch recevied in the forward pass.
batch: The batch received in the forward pass.
Results:
A dictionary containing the Q-value predictions ("qf_preds")
and in case of distributional Q-learning in addition to the Q-value
predictions ("qf_preds") the support atoms ("atoms"), the Q-logits
("qf_logits"), and the probabilities ("qf_probs").
"""

@abc.abstractmethod
@OverrideToImplementCustomLogic
def _qf_target(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
"""Computes Q-values from the target network.
Note, these can be accompanied with logits and pobabilities
in case of distributional Q-learning, i.e. `self.num_atoms > 1`.
Args:
batch: The batch recevied in the forward pass.
Results:
A dictionary containing the target Q-value predictions ("qf_preds")
and in case of distributional Q-learning in addition to the target
Q-value predictions ("qf_preds") the support atoms ("atoms"), the target
Q-logits ("qf_logits"), and the probabilities ("qf_probs").
"""
# If we have a dueling architecture we have to add the value stream.
return self._qf_forward_helper(
batch,
self.encoder,
{"af": self.af, "vf": self.vf} if self.uses_dueling else self.af,
)

@abc.abstractmethod
@OverrideToImplementCustomLogic
@@ -229,7 +239,7 @@ def _qf_forward_helper(
Q-learning or not.
Args:
batch: The batch recevied in the forward pass.
batch: The batch received in the forward pass.
encoder: The encoder network to use. Here we have a single encoder
for all heads (Q or advantages and value in case of a dueling
architecture).
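`_qf_forward_helper` (only its docstring is touched in this diff) is what combines the value head ("vf") and advantage head ("af") in the dueling case. For reference, the standard dueling aggregation it corresponds to, as a generic torch sketch; the distributional case (`num_atoms > 1`) adds an atom dimension and is not covered here:

```python
import torch


def dueling_q_values(value: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
    """Q(s, a) = V(s) + A(s, a) - mean_a A(s, a).

    value: [batch, 1], advantages: [batch, num_actions].
    Subtracting the mean advantage keeps V and A identifiable.
    """
    return value + advantages - advantages.mean(dim=-1, keepdim=True)
```
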