Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Added benchmark experiment for SAC with MuJoCo, PPO with MuJoCo and DQN with Atari. #44262

Merged
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion rllib/algorithms/dqn/dqn_rainbow_rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def setup(self):
self.uses_double_q: bool = self.config.model_config_dict.get("double_q")
# If we use noisy layers.
self.uses_noisy: bool = self.config.model_config_dict.get("noisy")
# If we use a noisy encoder.
self.uses_noisy_encoder: bool = False
# The number of atoms for a distribution support.
self.num_atoms: int = self.config.model_config_dict.get("num_atoms")
# If distributional learning is requested configure the support.
Expand Down Expand Up @@ -93,7 +95,7 @@ def get_initial_state(self) -> Any:

@override(RLModule)
def input_specs_exploration(self) -> SpecType:
return [Columns.OBS, Columns.T]
return [Columns.OBS]

@override(RLModule)
def input_specs_inference(self) -> SpecType:
Expand Down
7 changes: 4 additions & 3 deletions rllib/algorithms/dqn/torch/dqn_rainbow_torch_noisy_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _forward(self, inputs: dict, **kwargs) -> dict:

def _reset_noise(self):
# Reset the noise in the complete network.
self.net.reset_noise()
self.net._reset_noise()


class TorchNoisyMLPHead(TorchModel):
Expand Down Expand Up @@ -127,7 +127,7 @@ def _forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:

def _reset_noise(self) -> None:
# Reset the noise in the complete network.
self.net.reset_noise()
self.net._reset_noise()


class TorchNoisyMLP(nn.Module):
Expand Down Expand Up @@ -305,4 +305,5 @@ def forward(self, x):
def _reset_noise(self):
# Reset the noise for all modules (layers).
for module in self.modules():
module.reset_noise()
if hasattr(module, "reset_noise"):
module.reset_noise()
16 changes: 12 additions & 4 deletions rllib/algorithms/dqn/torch/dqn_rainbow_torch_rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
ATOMS,
QF_LOGITS,
QF_NEXT_PREDS,
QF_PREDS,
QF_PROBS,
QF_TARGET_NEXT_PREDS,
QF_TARGET_NEXT_PROBS,
)
from ray.rllib.algorithms.sac.sac_rl_module import QF_PREDS
from ray.rllib.algorithms.dqn.torch.dqn_rainbow_torch_noisy_net import (
TorchNoisyMLPEncoder,
)
from ray.rllib.core.columns import Columns
from ray.rllib.core.models.base import Encoder, ENCODER_OUT, Model
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
Expand All @@ -31,6 +34,9 @@ class DQNRainbowTorchRLModule(TorchRLModule, DQNRainbowRLModule):
def setup(self):
super().setup()

# Whether we use a noisy encoder. Note, a noisy encoder can only be
# used if the observation space is a flat space.
self.uses_noisy_encoder = isinstance(self.encoder, TorchNoisyMLPEncoder)
# We do not want to train the target networks.
# AND sync all target nets with the actual (trained) ones.
self.target_encoder.requires_grad_(False)
Expand Down Expand Up @@ -76,7 +82,7 @@ def _forward_exploration(
# Resample the noise for the noisy layers, if needed.
if self.uses_noisy:
# We want to resample the noise every time we step.
self.reset_noise(target=False)
self._reset_noise(target=False)
if not self.training:
# Set the module into training mode. This sets
# the weights and bias to their noisy version.
Expand Down Expand Up @@ -383,15 +389,17 @@ def _reset_noise(self, target: bool = False) -> None:
target: Whether to reset the noise of the target networks.
"""
if self.uses_noisy:
self.encoder._reset_noise()
if self.uses_noisy_encoder:
self.encoder._reset_noise()
self.af._reset_noise()
# If we have a dueling architecture we need to reset the noise
# of the value stream, too.
if self.uses_dueling:
self.vf._reset_noise()
# Reset the noise of the target networks, if requested.
if target:
self.target_encoder._reset_noise()
if self.uses_noisy_encoder:
self.target_encoder._reset_noise()
self.af_target._reset_noise()
# If we have a dueling architecture we need to reset the noise
# of the value stream, too.
Expand Down
Loading
Loading