[RLlib] DreamerV3: Catalog enhancements 05 - GRU default model support. (ray-project#34284)

Signed-off-by: Jack He <[email protected]>
sven1977 authored and ProjectsByJackHe committed May 4, 2023
1 parent f8ea15c commit f1307cf
Showing 5 changed files with 303 additions and 64 deletions.
11 changes: 5 additions & 6 deletions rllib/core/models/catalog.py
@@ -11,9 +11,9 @@
from ray.rllib.core.models.base import ModelConfig
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.core.models.configs import (
MLPEncoderConfig,
LSTMEncoderConfig,
CNNEncoderConfig,
MLPEncoderConfig,
RecurrentEncoderConfig,
)
from ray.rllib.models.preprocessors import get_preprocessor, Preprocessor
from ray.rllib.models import MODEL_DEFAULTS
@@ -321,14 +321,13 @@ def get_encoder_config(
use_attention = model_config_dict["use_attention"]

if use_lstm:
encoder_config = LSTMEncoderConfig(
encoder_config = RecurrentEncoderConfig(
recurrent_layer_type="lstm",
hidden_dim=model_config_dict["lstm_cell_size"],
batch_first=not model_config_dict["_time_major"],
batch_major=not model_config_dict["_time_major"],
num_layers=1,
output_dims=[model_config_dict["lstm_cell_size"]],
output_activation=output_activation,
observation_space=observation_space,
action_space=action_space,
view_requirements_dict=view_requirements,
get_tokenizer_config=cls.get_tokenizer_config,
)
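For a sense of what this change gives the catalog, here is a minimal, self-contained sketch of the same construction: a RecurrentEncoderConfig with recurrent_layer_type="lstm" built from a hypothetical model_config_dict. The dict values and the input_dims are illustrative, and the observation/action-space and view-requirement arguments the catalog actually passes are omitted to keep the sketch standalone.

from ray.rllib.core.models.configs import RecurrentEncoderConfig

# Hypothetical stand-in for the relevant keys of an RLlib model_config_dict.
model_config_dict = {"lstm_cell_size": 256, "_time_major": False}

encoder_config = RecurrentEncoderConfig(
    recurrent_layer_type="lstm",  # new field: selects "lstm" or "gru"
    input_dims=[16],  # illustrative; must be a 1D shape
    hidden_dim=model_config_dict["lstm_cell_size"],
    batch_major=not model_config_dict["_time_major"],
    num_layers=1,
    output_dims=[model_config_dict["lstm_cell_size"]],
    output_activation="linear",
)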
75 changes: 48 additions & 27 deletions rllib/core/models/configs.py
@@ -609,20 +609,24 @@ def build(self, framework: str = "torch") -> Encoder:

@ExperimentalAPI
@dataclass
class LSTMEncoderConfig(ModelConfig):
"""Configuration for an LSTM encoder.
class RecurrentEncoderConfig(ModelConfig):
"""Configuration for an LSTM-based or a GRU-based encoder.
The encoder consists of N LSTM layers stacked on top of each other and feeding their
outputs as inputs to the respective next layer. The internal state is structured
as (num_layers, B, hidden-size) for both h- and c-states of the LSTM layer(s).
The encoder consists of N LSTM/GRU layers stacked on top of each other and feeding
their outputs as inputs to the respective next layer. The internal state is
structured as (num_layers, B, hidden-size) for all hidden state components, e.g.
h- and c-states of the LSTM layer(s) or h-state of the GRU layer(s).
For example, the hidden states of an LSTMEncoder with num_layers=2 and hidden_dim=8
would be: {"h": (2, B, 8), "c": (2, B, 8)}.
Example:
.. code-block:: python
# Configuration:
config = LSTMEncoderConfig(
config = RecurrentEncoderConfig(
recurrent_layer_type="lstm",
input_dims=[16], # must be 1D tensor
hidden_dim=128,
num_lstm_layers=2,
num_layers=2,
output_dims=[256], # may be None or a 1D tensor
output_activation="linear",
use_bias=True,
@@ -640,29 +644,33 @@ class LSTMEncoderConfig(ModelConfig):
Example:
.. code-block:: python
# Configuration:
config = LSTMEncoderConfig(
config = RecurrentEncoderConfig(
recurrent_layer_type="gru",
input_dims=[32], # must be 1D tensor
hidden_dim=64,
num_lstm_layers=1,
num_layers=1,
output_dims=None, # may be None or a 1D tensor
use_bias=False,
)
model = config.build(framework="torch")
# Resulting stack in pseudocode:
# LSTM(32, 64, bias=False)
# GRU(32, 64, bias=False)
# Resulting shape of the internal states (c- and h-states):
# (1, B, 64) for each of the c- and h-states.
# Resulting shape of the internal state:
# (1, B, 64)
Attributes:
input_dims: A 1D tensor indicating the input dimension, e.g. `[32]`.
hidden_dim: The size of the hidden internal states (h- and c-states) of the
LSTM layer(s).
num_lstm_layers: The number of LSTM layers to stack.
recurrent_layer_type: The type of the recurrent layer(s).
Either "lstm" or "gru".
input_dims: The input dimensions. Must be 1D. This is the 1D shape of the tensor
that goes into the first recurrent layer.
hidden_dim: The size of the hidden internal state(s) of the recurrent layer(s).
For example, for an LSTM, this would be the size of the c- and h-tensors.
num_layers: The number of recurrent (LSTM or GRU) layers to stack.
batch_major: Whether the input is batch major (B, T, ..) or
time major (T, B, ..).
output_activation: The activation function to use for the output layer.
output_activation: The activation function to use for the linear output layer.
use_bias: Whether to use bias on all layers in the network.
view_requirements_dict: The view requirements to use if anything other than
observation_space or action_space is to be encoded. This signifies an
@@ -672,8 +680,9 @@ class LSTMEncoderConfig(ModelConfig):
other spaces that might be present in the view_requirements_dict.
"""

recurrent_layer_type: str = "lstm"
hidden_dim: int = None
num_lstm_layers: int = None
num_layers: int = None
batch_major: bool = True
output_activation: str = "linear"
use_bias: bool = True
@@ -682,10 +691,15 @@

def _validate(self, framework: str = "torch"):
"""Makes sure that settings are valid."""
if self.recurrent_layer_type not in ["gru", "lstm"]:
raise ValueError(
f"`recurrent_layer_type` ({self.recurrent_layer_type}) of "
"RecurrentEncoderConfig must be 'gru' or 'lstm'!"
)
if self.input_dims is not None and len(self.input_dims) != 1:
raise ValueError(
f"`input_dims` ({self.input_dims}) of LSTMEncoderConfig must be 1D, "
"e.g. `[32]`!"
f"`input_dims` ({self.input_dims}) of RecurrentEncoderConfig must be "
"1D, e.g. `[32]`!"
)

# Call these already here to catch errors early on.
@@ -698,20 +712,27 @@ def build(self, framework: str = "torch") -> Encoder:
or self.view_requirements_dict is not None
):
raise NotImplementedError(
"LSTMEncoderConfig does not support configuring LSTMs that encode "
"depending on view_requirements or have a custom tokenizer. "
"RecurrentEncoderConfig does not support configuring Models that "
"encode depending on view_requirements or have a custom tokenizer. "
"Therefore, this config expects `view_requirements_dict=None` and "
"`get_tokenizer_config=None`."
)

if framework == "torch":
from ray.rllib.core.models.torch.encoder import TorchLSTMEncoder

return TorchLSTMEncoder(self)
from ray.rllib.core.models.torch.encoder import (
TorchGRUEncoder as GRU,
TorchLSTMEncoder as LSTM,
)
else:
from ray.rllib.core.models.tf.encoder import TfLSTMEncoder
from ray.rllib.core.models.tf.encoder import (
TfGRUEncoder as GRU,
TfLSTMEncoder as LSTM,
)

return TfLSTMEncoder(self)
if self.recurrent_layer_type == "lstm":
return LSTM(self)
else:
return GRU(self)


@ExperimentalAPI
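A minimal usage sketch of the new dispatch in build() (assuming PyTorch is installed; all dimensions below are illustrative): the same config class now yields either an LSTM- or a GRU-based encoder depending on recurrent_layer_type, and rejects anything else during validation.

from ray.rllib.core.models.configs import RecurrentEncoderConfig

def make_encoder(layer_type: str):
    # Illustrative dimensions; any valid 1D input_dims/output_dims work.
    config = RecurrentEncoderConfig(
        recurrent_layer_type=layer_type,
        input_dims=[16],
        hidden_dim=64,
        num_layers=2,
        output_dims=[32],
        output_activation="linear",
        use_bias=True,
    )
    return config.build(framework="torch")

lstm_encoder = make_encoder("lstm")  # built as TorchLSTMEncoder
gru_encoder = make_encoder("gru")    # built as TorchGRUEncoder

# An unsupported layer type is caught by the config's own validation
# (calling the private _validate directly here just to illustrate the check):
bad = RecurrentEncoderConfig(recurrent_layer_type="rnn", input_dims=[16], hidden_dim=64, num_layers=1)
bad._validate(framework="torch")  # raises ValueError: must be 'gru' or 'lstm'

Any non-"torch" framework string takes the TF branch instead, importing TfGRUEncoder/TfLSTMEncoder from ray.rllib.core.models.tf.encoder.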
88 changes: 79 additions & 9 deletions rllib/core/models/tests/test_recurrent_encoders.py
@@ -2,33 +2,102 @@
import unittest

from ray.rllib.core.models.base import ENCODER_OUT, STATE_OUT
from ray.rllib.core.models.configs import LSTMEncoderConfig
from ray.rllib.core.models.configs import RecurrentEncoderConfig
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import framework_iterator, ModelChecker

_, tf, _ = try_import_tf()
torch, _ = try_import_torch()


class TestRecurrentEncoders(unittest.TestCase):
def test_gru_encoders(self):
"""Tests building GRU encoders properly and checks for correct architecture."""

# Loop through different combinations of hyperparameters.
inputs_dimss = [[1], [100]]
output_dimss = [[1], [50]]
num_layerss = [1, 4]
hidden_dims = [128, 256]
output_activations = ["linear", "silu", "relu"]
use_biases = [False, True]

for permutation in itertools.product(
inputs_dimss,
num_layerss,
hidden_dims,
output_activations,
output_dimss,
use_biases,
):
(
inputs_dims,
num_layers,
hidden_dim,
output_activation,
output_dims,
use_bias,
) = permutation

print(
f"Testing ...\n"
f"input_dims: {inputs_dims}\n"
f"num_layers: {num_layers}\n"
f"hidden_dim: {hidden_dim}\n"
f"output_activation: {output_activation}\n"
f"output_dims: {output_dims}\n"
f"use_bias: {use_bias}\n"
)

config = RecurrentEncoderConfig(
recurrent_layer_type="gru",
input_dims=inputs_dims,
num_layers=num_layers,
hidden_dim=hidden_dim,
output_dims=output_dims,
output_activation=output_activation,
use_bias=use_bias,
)

# Use a ModelChecker to compare all added models (different frameworks)
# with each other.
model_checker = ModelChecker(config)

for fw in framework_iterator(frameworks=("tf2", "torch")):
# Add this framework version of the model to our checker.
outputs = model_checker.add(framework=fw)
# Output shape: [1=B, 1=T, [output_dim]]
self.assertEqual(outputs[ENCODER_OUT].shape, (1, 1, output_dims[0]))
# State shape: [1=B, num_layers, hidden_dim]
self.assertEqual(
outputs[STATE_OUT]["h"].shape,
(1, num_layers, hidden_dim),
)
# Check all added models against each other.
model_checker.check()

def test_lstm_encoders(self):
"""Tests building LSTM encoders properly and checks for correct architecture."""

# Loop through different combinations of hyperparameters.
inputs_dimss = [[1], [100]]
output_dimss = [[1], [100]]
num_lstm_layerss = [1, 3]
num_layerss = [1, 3]
hidden_dims = [16, 128]
output_activations = [None, "linear", "relu"]
use_biases = [False, True]

for permutation in itertools.product(
inputs_dimss,
num_lstm_layerss,
num_layerss,
hidden_dims,
output_activations,
output_dimss,
use_biases,
):
(
inputs_dims,
num_lstm_layers,
num_layers,
hidden_dim,
output_activation,
output_dims,
@@ -38,16 +107,17 @@ def test_lstm_encoders(self):
print(
f"Testing ...\n"
f"input_dims: {inputs_dims}\n"
f"num_lstm_layers: {num_lstm_layers}\n"
f"num_layers: {num_layers}\n"
f"hidden_dim: {hidden_dim}\n"
f"output_activation: {output_activation}\n"
f"output_dims: {output_dims}\n"
f"use_bias: {use_bias}\n"
)

config = LSTMEncoderConfig(
config = RecurrentEncoderConfig(
recurrent_layer_type="lstm",
input_dims=inputs_dims,
num_lstm_layers=num_lstm_layers,
num_layers=num_layers,
hidden_dim=hidden_dim,
output_dims=output_dims,
output_activation=output_activation,
Expand All @@ -66,11 +136,11 @@ def test_lstm_encoders(self):
# State shapes: [1=B, num_layers, hidden_dim]
self.assertEqual(
outputs[STATE_OUT]["h"].shape,
(1, num_lstm_layers, hidden_dim),
(1, num_layers, hidden_dim),
)
self.assertEqual(
outputs[STATE_OUT]["c"].shape,
(1, num_lstm_layers, hidden_dim),
(1, num_layers, hidden_dim),
)

# Check all added models against each other (only if bias=False).
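To run only the newly added GRU test locally, a minimal sketch (module, class, and method names taken from the diff above; both TF and PyTorch are needed since the test iterates over the "tf2" and "torch" frameworks):

import unittest

from ray.rllib.core.models.tests.test_recurrent_encoders import TestRecurrentEncoders

# Build a suite containing only the new GRU test case and run it.
suite = unittest.TestSuite()
suite.addTest(TestRecurrentEncoders("test_gru_encoders"))
unittest.TextTestRunner(verbosity=2).run(suite)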