Commit

[RLlib] Added benchmark experiment for SAC with MuJoCo, PPO with MuJoCo and DQN with Atari. (#44262)

simonsays1980 authored Apr 11, 2024
1 parent 6c68acf commit 61ef56f
Showing 11 changed files with 1,358 additions and 12 deletions.
2 changes: 2 additions & 0 deletions rllib/BUILD
@@ -65,6 +65,8 @@ doctest(
"**/examples/**",
"**/tests/**",
"**/test_*.py",
# Exclude benchmark runs.
"tuned_examples/**/benchmark_*.py",
# Deprecated modules
"utils/window_stat.py",
"utils/timer.py",
2 changes: 1 addition & 1 deletion rllib/algorithms/algorithm_config.py
@@ -180,7 +180,7 @@ def from_dict(cls, config_dict: dict) -> "AlgorithmConfig":
config_dict: The legacy formatted python config dict for some algorithm.
Returns:
A new AlgorithmConfig object that matches the given python config dict.
A new AlgorithmConfig object that matches the given python config dict.
"""
# Create a default config object of this class.
config_obj = cls()
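For context on the method whose docstring is touched above: AlgorithmConfig.from_dict() builds a typed config object from a legacy-style config dict. A minimal usage sketch (the concrete keys and values here are illustrative and not taken from this commit):

from ray.rllib.algorithms.ppo import PPOConfig

# Legacy, dict-style configuration with standard AlgorithmConfig keys.
legacy_config = {"lr": 0.0003, "gamma": 0.99, "train_batch_size": 4000}

# from_dict() creates a default config object of the class and then applies
# the values from the legacy dict to it.
config = PPOConfig.from_dict(legacy_config)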
4 changes: 3 additions & 1 deletion rllib/algorithms/dqn/dqn_rainbow_rl_module.py
@@ -39,6 +39,8 @@ def setup(self):
self.uses_double_q: bool = self.config.model_config_dict.get("double_q")
# If we use noisy layers.
self.uses_noisy: bool = self.config.model_config_dict.get("noisy")
# If we use a noisy encoder.
self.uses_noisy_encoder: bool = False
# The number of atoms for a distribution support.
self.num_atoms: int = self.config.model_config_dict.get("num_atoms")
# If distributional learning is requested configure the support.
@@ -93,7 +95,7 @@ def get_initial_state(self) -> Any:

@override(RLModule)
def input_specs_exploration(self) -> SpecType:
return [Columns.OBS, Columns.T]
return [Columns.OBS]

@override(RLModule)
def input_specs_inference(self) -> SpecType:
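The flags read in setup() above (double_q, noisy, num_atoms) correspond to standard DQN/Rainbow settings. A hedged sketch of how they might be configured; the values are illustrative, and the exact path into model_config_dict goes through the algorithm's catalog:

from ray.rllib.algorithms.dqn import DQNConfig

config = (
    DQNConfig()
    .environment("CartPole-v1")
    .training(
        double_q=True,  # picked up in setup() via model_config_dict.get("double_q")
        noisy=True,     # use noisy layers for exploration instead of epsilon-greedy
        num_atoms=51,   # size of the distributional support (C51 / Rainbow)
    )
)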
7 changes: 4 additions & 3 deletions rllib/algorithms/dqn/torch/dqn_rainbow_torch_noisy_net.py
@@ -78,7 +78,7 @@ def _forward(self, inputs: dict, **kwargs) -> dict:

def _reset_noise(self):
# Reset the noise in the complete network.
self.net.reset_noise()
self.net._reset_noise()


class TorchNoisyMLPHead(TorchModel):
@@ -127,7 +127,7 @@ def _forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:

def _reset_noise(self) -> None:
# Reset the noise in the complete network.
self.net.reset_noise()
self.net._reset_noise()


class TorchNoisyMLP(nn.Module):
@@ -305,4 +305,5 @@ def forward(self, x):
def _reset_noise(self):
# Reset the noise for all modules (layers).
for module in self.modules():
module.reset_noise()
if hasattr(module, "reset_noise"):
module.reset_noise()
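The hasattr() guard added above is needed because self.modules() yields every submodule, including containers and standard layers that define no reset_noise(). A self-contained toy sketch of the pattern (this is not RLlib's noisy-layer implementation; all names and sizes are made up):

import torch
from torch import nn


class ToyNoisyLinear(nn.Module):
    # Toy noisy linear layer; only the noise-reset logic is of interest here.

    def __init__(self, in_features: int, out_features: int, std_init: float = 0.5):
        super().__init__()
        self.weight_mu = nn.Parameter(
            torch.empty(out_features, in_features).uniform_(-0.1, 0.1)
        )
        self.weight_sigma = nn.Parameter(
            torch.full((out_features, in_features), std_init)
        )
        self.register_buffer("weight_epsilon", torch.zeros(out_features, in_features))
        self.reset_noise()

    def reset_noise(self) -> None:
        # Draw fresh Gaussian noise for the weight perturbation.
        self.weight_epsilon.normal_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
        return x @ weight.t()


class ToyNoisyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            ToyNoisyLinear(4, 32), nn.ReLU(), ToyNoisyLinear(32, 2)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

    def _reset_noise(self) -> None:
        # self.modules() also yields the ToyNoisyMLP itself, the nn.Sequential
        # container, and nn.ReLU, none of which have reset_noise() -- hence the guard.
        for module in self.modules():
            if hasattr(module, "reset_noise"):
                module.reset_noise()

Calling ToyNoisyMLP()._reset_noise() then redraws noise for both noisy layers and silently skips everything else, which is what the unguarded loop could not do.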
24 changes: 17 additions & 7 deletions rllib/algorithms/dqn/torch/dqn_rainbow_torch_rl_module.py
@@ -5,11 +5,14 @@
ATOMS,
QF_LOGITS,
QF_NEXT_PREDS,
QF_PREDS,
QF_PROBS,
QF_TARGET_NEXT_PREDS,
QF_TARGET_NEXT_PROBS,
)
from ray.rllib.algorithms.sac.sac_rl_module import QF_PREDS
from ray.rllib.algorithms.dqn.torch.dqn_rainbow_torch_noisy_net import (
TorchNoisyMLPEncoder,
)
from ray.rllib.core.columns import Columns
from ray.rllib.core.models.base import Encoder, ENCODER_OUT, Model
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
@@ -31,6 +34,9 @@ class DQNRainbowTorchRLModule(TorchRLModule, DQNRainbowRLModule):
def setup(self):
super().setup()

# If we use a noisy encoder. Note, only if the observation
# space is a flat space we can use a noisy encoder.
self.uses_noisy_encoder = isinstance(self.encoder, TorchNoisyMLPEncoder)
# We do not want to train the target networks.
# AND sync all target nets with the actual (trained) ones.
self.target_encoder.requires_grad_(False)
@@ -76,7 +82,7 @@ def _forward_exploration(
# Resample the noise for the noisy layers, if needed.
if self.uses_noisy:
# We want to resample the noise every time we step.
self.reset_noise(target=False)
self._reset_noise(target=False)
if not self.training:
# Set the module into training mode. This sets
# the weights and bias to their noisy version.
@@ -214,9 +220,11 @@ def _qf_target(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]:
return self._qf_forward_helper(
batch,
self.target_encoder,
{"af": self.af_target, "vf": self.vf_target}
if self.uses_dueling
else self.af_target,
(
{"af": self.af_target, "vf": self.vf_target}
if self.uses_dueling
else self.af_target
),
)

@override(DQNRainbowRLModule)
@@ -383,15 +391,17 @@ def _reset_noise(self, target: bool = False) -> None:
target: Whether to reset the noise of the target networks.
"""
if self.uses_noisy:
self.encoder._reset_noise()
if self.uses_noisy_encoder:
self.encoder._reset_noise()
self.af._reset_noise()
# If we have a dueling architecture we need to reset the noise
# of the value stream, too.
if self.uses_dueling:
self.vf._reset_noise()
# Reset the noise of the target networks, if requested.
if target:
self.target_encoder._reset_noise()
if self.uses_noisy_encoder:
self.target_encoder._reset_noise()
self.af_target._reset_noise()
# If we have a dueling architecture we need to reset the noise
# of the value stream, too.
(Diffs for the remaining 6 changed files were not loaded.)
