[RLlib] DreamerV3: Catalog enhancements 05 - GRU default model support #34284

Merged
Changes from all commits (59 commits):
36a4c65 wip (sven1977, Mar 28, 2023)
fd432e1 wip (sven1977, Mar 28, 2023)
eb33ba0 Merge branch 'master' of https://github.com/ray-project/ray into drea… (sven1977, Mar 29, 2023)
36346cf wip (sven1977, Mar 29, 2023)
5995906 Merge branch 'master' of https://github.com/ray-project/ray into drea… (sven1977, Mar 29, 2023)
950eb81 wip (sven1977, Mar 30, 2023)
392b5f6 wip (sven1977, Mar 31, 2023)
2cce48c wip (sven1977, Mar 31, 2023)
46b3c4f fix (sven1977, Mar 31, 2023)
1ab97e7 removed torch SiLU (torch now has its own) (sven1977, Mar 31, 2023)
6383061 LINT (sven1977, Mar 31, 2023)
1413b98 Merge branch 'master' of https://github.com/ray-project/ray into drea… (sven1977, Mar 31, 2023)
dc46838 Update rllib/core/models/base.py (sven1977, Mar 31, 2023)
9427f9b wip (sven1977, Apr 2, 2023)
5f91341 wip (sven1977, Apr 2, 2023)
647545d LINT and cleanup (sven1977, Apr 2, 2023)
6e88ecc Merge branch 'master' of https://github.com/ray-project/ray into drea… (sven1977, Apr 2, 2023)
14a277b Merge branch 'model_spec_rename_to_specs_plural_to_avoid_tf_keras_mod… (sven1977, Apr 2, 2023)
fe955f6 Merge remote-tracking branch 'origin/dreamer_v3_catalog_enhancements_… (sven1977, Apr 2, 2023)
ebadcfe wip (sven1977, Apr 2, 2023)
fb183ea Merge branch 'master' of https://github.com/ray-project/ray into mode… (sven1977, Apr 3, 2023)
d802617 test case fix (sven1977, Apr 3, 2023)
4efa147 cleanup (sven1977, Apr 3, 2023)
3885cba Merge branch 'model_spec_rename_to_specs_plural_to_avoid_tf_keras_mod… (sven1977, Apr 3, 2023)
6a6370c wip (sven1977, Apr 3, 2023)
3de4cd1 merge (sven1977, Apr 4, 2023)
28d0c5e wip (sven1977, Apr 4, 2023)
a9c2897 wip (sven1977, Apr 4, 2023)
16101b8 wip (sven1977, Apr 4, 2023)
ed65960 Merge branch 'master' of https://github.com/ray-project/ray into drea… (sven1977, Apr 4, 2023)
205026f Merge branch 'dreamer_v3_catalog_enhancements_01' into dreamer_v3_cat… (sven1977, Apr 4, 2023)
4dbb26a Merge branch 'master' of https://github.com/ray-project/ray into drea… (sven1977, Apr 5, 2023)
9670662 wip (sven1977, Apr 5, 2023)
cbe48bf wip (sven1977, Apr 6, 2023)
bea7687 Merge branch 'master' of https://github.com/ray-project/ray into drea… (sven1977, Apr 6, 2023)
b876731 wip (sven1977, Apr 6, 2023)
f16fb63 Merge branch 'master' of https://github.com/ray-project/ray into drea… (sven1977, Apr 6, 2023)
f349292 wip (sven1977, Apr 6, 2023)
5d6f194 LINT (sven1977, Apr 6, 2023)
0288d21 Merge branch 'master' of https://github.com/ray-project/ray into drea… (sven1977, Apr 6, 2023)
6bc283a LINT (sven1977, Apr 6, 2023)
c0b04fa LINT (sven1977, Apr 6, 2023)
6506e15 wip (sven1977, Apr 6, 2023)
c6322b0 wip (sven1977, Apr 6, 2023)
e59f467 fix (sven1977, Apr 6, 2023)
2cdd41c Merge branch 'dreamer_v3_catalog_enhancements_01' into dreamer_v3_cat… (sven1977, Apr 6, 2023)
80c5a99 merge (sven1977, Apr 8, 2023)
4056ab6 Merge branch 'master' of https://github.com/ray-project/ray into drea… (sven1977, Apr 8, 2023)
e74db3b wip (sven1977, Apr 11, 2023)
5891668 Merge branch 'master' of https://github.com/ray-project/ray into drea… (sven1977, Apr 11, 2023)
5467f04 wip (sven1977, Apr 11, 2023)
883993c Merge branch 'master' of https://github.com/ray-project/ray into drea… (sven1977, Apr 11, 2023)
33b998e wip (sven1977, Apr 11, 2023)
d2e470b merge (sven1977, Apr 16, 2023)
9f783ea wip (sven1977, Apr 16, 2023)
29638da fix (sven1977, Apr 17, 2023)
cee2ee7 Merge branch 'master' of https://github.com/ray-project/ray into drea… (sven1977, Apr 17, 2023)
2989da3 fixes (sven1977, Apr 18, 2023)
283c39a Merge branch 'master' of https://github.com/ray-project/ray into drea… (sven1977, Apr 18, 2023)
11 changes: 5 additions & 6 deletions rllib/core/models/catalog.py
@@ -11,9 +11,9 @@
from ray.rllib.core.models.base import ModelConfig
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.core.models.configs import (
-    MLPEncoderConfig,
-    LSTMEncoderConfig,
    CNNEncoderConfig,
+    MLPEncoderConfig,
+    RecurrentEncoderConfig,
)
from ray.rllib.models.preprocessors import get_preprocessor, Preprocessor
from ray.rllib.models import MODEL_DEFAULTS
@@ -321,14 +321,13 @@ def get_encoder_config(
        use_attention = model_config_dict["use_attention"]

        if use_lstm:
-            encoder_config = LSTMEncoderConfig(
+            encoder_config = RecurrentEncoderConfig(
+                recurrent_layer_type="lstm",
                hidden_dim=model_config_dict["lstm_cell_size"],
-                batch_first=not model_config_dict["_time_major"],
+                batch_major=not model_config_dict["_time_major"],
                num_layers=1,
                output_dims=[model_config_dict["lstm_cell_size"]],
                output_activation=output_activation,
-                observation_space=observation_space,
-                action_space=action_space,
                view_requirements_dict=view_requirements,
                get_tokenizer_config=cls.get_tokenizer_config,
            )
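For illustration: with this change, the same config class can also describe a GRU stack simply by switching `recurrent_layer_type`. A minimal sketch, using only field names visible in this PR's diff; the dimension values are invented:

    from ray.rllib.core.models.configs import RecurrentEncoderConfig

    # Hypothetical example values; only the field names come from the diff above.
    gru_config = RecurrentEncoderConfig(
        recurrent_layer_type="gru",  # new field introduced by this PR
        input_dims=[16],             # must be 1D
        hidden_dim=128,              # size of the GRU h-state
        num_layers=1,                # number of stacked GRU layers
        output_dims=[128],           # size of the final linear output layer
        output_activation="linear",
    )
    gru_encoder = gru_config.build(framework="torch")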
75 changes: 48 additions & 27 deletions rllib/core/models/configs.py
@@ -609,20 +609,24 @@ def build(self, framework: str = "torch") -> Encoder:

@ExperimentalAPI
@dataclass
-class LSTMEncoderConfig(ModelConfig):
-    """Configuration for an LSTM encoder.
+class RecurrentEncoderConfig(ModelConfig):
+    """Configuration for an LSTM-based or a GRU-based encoder.

-    The encoder consists of N LSTM layers stacked on top of each other and feeding their
-    outputs as inputs to the respective next layer. The internal state is structured
-    as (num_layers, B, hidden-size) for both h- and c-states of the LSTM layer(s).
+    The encoder consists of N LSTM/GRU layers stacked on top of each other and feeding
+    their outputs as inputs to the respective next layer. The internal state is
+    structured as (num_layers, B, hidden-size) for all hidden state components, e.g.
+    the h- and c-states of the LSTM layer(s) or the h-state of the GRU layer(s).
+    For example, the hidden states of an LSTMEncoder with num_layers=2 and hidden_dim=8
+    would be: {"h": (2, B, 8), "c": (2, B, 8)}.

    Example:
    .. code-block:: python
        # Configuration:
-        config = LSTMEncoderConfig(
+        config = RecurrentEncoderConfig(
+            recurrent_layer_type="lstm",
            input_dims=[16],  # must be a 1D tensor
            hidden_dim=128,
-            num_lstm_layers=2,
+            num_layers=2,
            output_dims=[256],  # may be None or a 1D tensor
            output_activation="linear",
            use_bias=True,
Expand All @@ -640,29 +644,33 @@ class LSTMEncoderConfig(ModelConfig):
    Example:
    .. code-block:: python
        # Configuration:
-        config = LSTMEncoderConfig(
+        config = RecurrentEncoderConfig(
+            recurrent_layer_type="gru",
            input_dims=[32],  # must be a 1D tensor
            hidden_dim=64,
-            num_lstm_layers=1,
+            num_layers=1,
            output_dims=None,  # may be None or a 1D tensor
            use_bias=False,
        )
        model = config.build(framework="torch")

        # Resulting stack in pseudocode:
-        # LSTM(32, 64, bias=False)
+        # GRU(32, 64, bias=False)

-        # Resulting shape of the internal states (c- and h-states):
-        # (1, B, 64) for each of the c- and h-states.
+        # Resulting shape of the internal state:
+        # (1, B, 64)

    Attributes:
-        input_dims: A 1D tensor indicating the input dimension, e.g. `[32]`.
-        hidden_dim: The size of the hidden internal states (h- and c-states) of the
-            LSTM layer(s).
-        num_lstm_layers: The number of LSTM layers to stack.
+        recurrent_layer_type: The type of the recurrent layer(s).
+            Either "lstm" or "gru".
+        input_dims: The input dimensions. Must be 1D. This is the shape of the
+            tensor that goes into the first recurrent layer.
+        hidden_dim: The size of the hidden internal state(s) of the recurrent layer(s).
+            For example, for an LSTM, this would be the size of the c- and h-tensors.
+        num_layers: The number of recurrent (LSTM or GRU) layers to stack.
        batch_major: Whether the input is batch major (B, T, ..) or
            time major (T, B, ..).
-        output_activation: The activation function to use for the output layer.
+        output_activation: The activation function to use for the linear output layer.
        use_bias: Whether to use bias on all layers in the network.
        view_requirements_dict: The view requirements to use if anything else than
            observation_space or action_space is to be encoded. This signifies an
@@ -672,8 +680,9 @@ class LSTMEncoderConfig(ModelConfig):
            other spaces that might be present in the view_requirements_dict.
    """

+    recurrent_layer_type: str = "lstm"
    hidden_dim: int = None
-    num_lstm_layers: int = None
+    num_layers: int = None
    batch_major: bool = True
    output_activation: str = "linear"
    use_bias: bool = True
@@ -682,10 +691,15 @@

    def _validate(self, framework: str = "torch"):
        """Makes sure that settings are valid."""
+        if self.recurrent_layer_type not in ["gru", "lstm"]:
+            raise ValueError(
+                f"`recurrent_layer_type` ({self.recurrent_layer_type}) of "
+                "RecurrentEncoderConfig must be 'gru' or 'lstm'!"
+            )
        if self.input_dims is not None and len(self.input_dims) != 1:
            raise ValueError(
-                f"`input_dims` ({self.input_dims}) of LSTMEncoderConfig must be 1D, "
-                "e.g. `[32]`!"
+                f"`input_dims` ({self.input_dims}) of RecurrentEncoderConfig must be "
+                "1D, e.g. `[32]`!"
            )

        # Call these already here to catch errors early on.
@@ -698,20 +712,27 @@ def build(self, framework: str = "torch") -> Encoder:
            or self.view_requirements_dict is not None
        ):
            raise NotImplementedError(
-                "LSTMEncoderConfig does not support configuring LSTMs that encode "
-                "depending on view_requirements or have a custom tokenizer. "
+                "RecurrentEncoderConfig does not support configuring models that "
+                "encode depending on view_requirements or have a custom tokenizer. "
                "Therefore, this config expects `view_requirements_dict=None` and "
                "`get_tokenizer_config=None`."
            )

        if framework == "torch":
-            from ray.rllib.core.models.torch.encoder import TorchLSTMEncoder
-
-            return TorchLSTMEncoder(self)
+            from ray.rllib.core.models.torch.encoder import (
+                TorchGRUEncoder as GRU,
+                TorchLSTMEncoder as LSTM,
+            )
        else:
-            from ray.rllib.core.models.tf.encoder import TfLSTMEncoder
+            from ray.rllib.core.models.tf.encoder import (
+                TfGRUEncoder as GRU,
+                TfLSTMEncoder as LSTM,
+            )

-            return TfLSTMEncoder(self)
+        if self.recurrent_layer_type == "lstm":
+            return LSTM(self)
+        else:
+            return GRU(self)


@ExperimentalAPI
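As a usage sketch of the build dispatch and validation shown above (this assumes `build()` triggers `_validate()` first, which the surrounding code suggests but this excerpt does not prove; the hidden dimensions are example values):

    from ray.rllib.core.models.configs import RecurrentEncoderConfig

    # The same config class now builds either encoder type.
    lstm_encoder = RecurrentEncoderConfig(
        recurrent_layer_type="lstm", input_dims=[16], hidden_dim=32, num_layers=2
    ).build(framework="torch")  # returns a TorchLSTMEncoder

    gru_encoder = RecurrentEncoderConfig(
        recurrent_layer_type="gru", input_dims=[16], hidden_dim=32, num_layers=2
    ).build(framework="torch")  # returns a TorchGRUEncoder

    # Any other layer type should fail validation with a ValueError.
    try:
        RecurrentEncoderConfig(
            recurrent_layer_type="rnn", input_dims=[16], hidden_dim=32, num_layers=2
        ).build(framework="torch")
    except ValueError as e:
        print(e)  # `recurrent_layer_type` (rnn) of RecurrentEncoderConfig must be ...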
88 changes: 79 additions & 9 deletions rllib/core/models/tests/test_recurrent_encoders.py
@@ -2,33 +2,102 @@
import unittest

from ray.rllib.core.models.base import ENCODER_OUT, STATE_OUT
-from ray.rllib.core.models.configs import LSTMEncoderConfig
+from ray.rllib.core.models.configs import RecurrentEncoderConfig
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import framework_iterator, ModelChecker

_, tf, _ = try_import_tf()
torch, _ = try_import_torch()


class TestRecurrentEncoders(unittest.TestCase):
+    def test_gru_encoders(self):
+        """Tests building GRU encoders and checks for correct architecture."""
+
+        # Loop through different combinations of hyperparameters.
+        inputs_dimss = [[1], [100]]
+        output_dimss = [[1], [50]]
+        num_layerss = [1, 4]
+        hidden_dims = [128, 256]
+        output_activations = ["linear", "silu", "relu"]
+        use_biases = [False, True]
+
+        for permutation in itertools.product(
+            inputs_dimss,
+            num_layerss,
+            hidden_dims,
+            output_activations,
+            output_dimss,
+            use_biases,
+        ):
+            (
+                inputs_dims,
+                num_layers,
+                hidden_dim,
+                output_activation,
+                output_dims,
+                use_bias,
+            ) = permutation
+
+            print(
+                f"Testing ...\n"
+                f"input_dims: {inputs_dims}\n"
+                f"num_layers: {num_layers}\n"
+                f"hidden_dim: {hidden_dim}\n"
+                f"output_activation: {output_activation}\n"
+                f"output_dims: {output_dims}\n"
+                f"use_bias: {use_bias}\n"
+            )
+
+            config = RecurrentEncoderConfig(
+                recurrent_layer_type="gru",
+                input_dims=inputs_dims,
+                num_layers=num_layers,
+                hidden_dim=hidden_dim,
+                output_dims=output_dims,
+                output_activation=output_activation,
+                use_bias=use_bias,
+            )
+
+            # Use a ModelChecker to compare all added models (different frameworks)
+            # with each other.
+            model_checker = ModelChecker(config)
+
+            for fw in framework_iterator(frameworks=("tf2", "torch")):
+                # Add this framework version of the model to our checker.
+                outputs = model_checker.add(framework=fw)
+                # Output shape: (1=B, 1=T, output_dim).
+                self.assertEqual(outputs[ENCODER_OUT].shape, (1, 1, output_dims[0]))
+                # State shape: (1=B, num_layers, hidden_dim).
+                self.assertEqual(
+                    outputs[STATE_OUT]["h"].shape,
+                    (1, num_layers, hidden_dim),
+                )
+
+            # Check all added models against each other.
+            model_checker.check()
+
    def test_lstm_encoders(self):
        """Tests building LSTM encoders and checks for correct architecture."""

        # Loop through different combinations of hyperparameters.
        inputs_dimss = [[1], [100]]
        output_dimss = [[1], [100]]
-        num_lstm_layerss = [1, 3]
+        num_layerss = [1, 3]
        hidden_dims = [16, 128]
        output_activations = [None, "linear", "relu"]
        use_biases = [False, True]

        for permutation in itertools.product(
            inputs_dimss,
-            num_lstm_layerss,
+            num_layerss,
            hidden_dims,
            output_activations,
            output_dimss,
            use_biases,
        ):
            (
                inputs_dims,
-                num_lstm_layers,
+                num_layers,
                hidden_dim,
                output_activation,
                output_dims,
@@ -38,16 +107,17 @@ def test_lstm_encoders(self):
            print(
                f"Testing ...\n"
                f"input_dims: {inputs_dims}\n"
-                f"num_lstm_layers: {num_lstm_layers}\n"
+                f"num_layers: {num_layers}\n"
                f"hidden_dim: {hidden_dim}\n"
                f"output_activation: {output_activation}\n"
                f"output_dims: {output_dims}\n"
                f"use_bias: {use_bias}\n"
            )

-            config = LSTMEncoderConfig(
+            config = RecurrentEncoderConfig(
+                recurrent_layer_type="lstm",
                input_dims=inputs_dims,
-                num_lstm_layers=num_lstm_layers,
+                num_layers=num_layers,
                hidden_dim=hidden_dim,
                output_dims=output_dims,
                output_activation=output_activation,
@@ -66,11 +136,11 @@ def test_lstm_encoders(self):
                # State shape: (1=B, num_layers, hidden_dim).
                self.assertEqual(
                    outputs[STATE_OUT]["h"].shape,
-                    (1, num_lstm_layers, hidden_dim),
+                    (1, num_layers, hidden_dim),
                )
                self.assertEqual(
                    outputs[STATE_OUT]["c"].shape,
-                    (1, num_lstm_layers, hidden_dim),
+                    (1, num_layers, hidden_dim),
                )

            # Check all added models against each other (only if bias=False).
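Outside RLlib, the state-shape convention these tests assert can be sanity-checked with a plain PyTorch GRU. Note that PyTorch returns the recurrent state as (num_layers, B, hidden_dim), while the assertions above expect a batch-major (B, num_layers, hidden_dim) layout, so the RLlib encoders presumably transpose the state somewhere; this standalone sketch only demonstrates the underlying shapes:

    import torch

    B, T, input_dim, hidden_dim, num_layers = 1, 1, 16, 128, 2
    gru = torch.nn.GRU(input_dim, hidden_dim, num_layers=num_layers, batch_first=True)

    out, h = gru(torch.randn(B, T, input_dim))
    assert out.shape == (B, T, hidden_dim)         # per-timestep outputs
    assert h.shape == (num_layers, B, hidden_dim)  # PyTorch's native state layout
    assert h.transpose(0, 1).shape == (B, num_layers, hidden_dim)  # the tests' layout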