diff --git a/rllib/core/models/catalog.py b/rllib/core/models/catalog.py
index e61adbd2b7ed..b8dde4673ad0 100644
--- a/rllib/core/models/catalog.py
+++ b/rllib/core/models/catalog.py
@@ -11,9 +11,9 @@
 from ray.rllib.core.models.base import ModelConfig
 from ray.rllib.utils.deprecation import deprecation_warning
 from ray.rllib.core.models.configs import (
-    MLPEncoderConfig,
-    LSTMEncoderConfig,
     CNNEncoderConfig,
+    MLPEncoderConfig,
+    RecurrentEncoderConfig,
 )
 from ray.rllib.models.preprocessors import get_preprocessor, Preprocessor
 from ray.rllib.models import MODEL_DEFAULTS
@@ -321,14 +321,13 @@ def get_encoder_config(
         use_attention = model_config_dict["use_attention"]
 
         if use_lstm:
-            encoder_config = LSTMEncoderConfig(
+            encoder_config = RecurrentEncoderConfig(
+                recurrent_layer_type="lstm",
                 hidden_dim=model_config_dict["lstm_cell_size"],
-                batch_first=not model_config_dict["_time_major"],
+                batch_major=not model_config_dict["_time_major"],
                 num_layers=1,
                 output_dims=[model_config_dict["lstm_cell_size"]],
                 output_activation=output_activation,
-                observation_space=observation_space,
-                action_space=action_space,
                 view_requirements_dict=view_requirements,
                 get_tokenizer_config=cls.get_tokenizer_config,
             )
diff --git a/rllib/core/models/configs.py b/rllib/core/models/configs.py
index 04f9e0660e58..2011aed2e1d2 100644
--- a/rllib/core/models/configs.py
+++ b/rllib/core/models/configs.py
@@ -609,20 +609,24 @@ def build(self, framework: str = "torch") -> Encoder:
 
 @ExperimentalAPI
 @dataclass
-class LSTMEncoderConfig(ModelConfig):
-    """Configuration for an LSTM encoder.
+class RecurrentEncoderConfig(ModelConfig):
+    """Configuration for an LSTM-based or a GRU-based encoder.
 
-    The encoder consists of N LSTM layers stacked on top of each other and feeding their
-    outputs as inputs to the respective next layer. The internal state is structued
-    as (num_layers, B, hidden-size) for both h- and c-states of the LSTM layer(s).
+    The encoder consists of N LSTM/GRU layers stacked on top of each other and feeding
+    their outputs as inputs to the respective next layer. The internal state is
+    structured as (num_layers, B, hidden-size) for all hidden state components, e.g.
+    h- and c-states of the LSTM layer(s) or h-state of the GRU layer(s).
+    For example, the hidden states of an LSTMEncoder with num_layers=2 and hidden_dim=8
+    would be: {"h": (2, B, 8), "c": (2, B, 8)}.
 
     Example:
     .. code-block:: python
         # Configuration:
-        config = LSTMEncoderConfig(
+        config = RecurrentEncoderConfig(
+            recurrent_layer_type="lstm",
             input_dims=[16],  # must be 1D tensor
             hidden_dim=128,
-            num_lstm_layers=2,
+            num_layers=2,
             output_dims=[256],  # maybe None or a 1D tensor
             output_activation="linear",
             use_bias=True,
@@ -640,29 +644,33 @@ class LSTMEncoderConfig(ModelConfig):
     Example:
     .. code-block:: python
         # Configuration:
-        config = LSTMEncoderConfig(
+        config = RecurrentEncoderConfig(
+            recurrent_layer_type="gru",
             input_dims=[32],  # must be 1D tensor
             hidden_dim=64,
-            num_lstm_layers=1,
+            num_layers=1,
             output_dims=None,  # maybe None or a 1D tensor
             use_bias=False,
         )
         model = config.build(framework="torch")
 
         # Resulting stack in pseudocode:
-        # LSTM(32, 64, bias=False)
+        # GRU(32, 64, bias=False)
 
-        # Resulting shape of the internal states (c- and h-states):
-        # (1, B, 64) for each c- and h-states.
+        # Resulting shape of the internal state:
+        # (1, B, 64)
 
     Attributes:
-        input_dims: A 1D tensor indicating the input dimension, e.g. `[32]`.
-        hidden_dim: The size of the hidden internal states (h- and c-states) of the
-            LSTM layer(s).
-        num_lstm_layers: The number of LSTM layers to stack.
+        recurrent_layer_type: The type of the recurrent layer(s).
+            Either "lstm" or "gru".
+        input_dims: The input dimensions; must be 1D, e.g. `[32]`. This is the
+            shape of the tensor that goes into the first recurrent layer.
+        hidden_dim: The size of the hidden internal state(s) of the recurrent layer(s).
+            For example, for an LSTM, this would be the size of the c- and h-tensors.
+        num_layers: The number of recurrent (LSTM or GRU) layers to stack.
         batch_major: Wether the input is batch major (B, T, ..) or time major
             (T, B, ..).
-        output_activation: The activation function to use for the output layer.
+        output_activation: The activation function to use for the linear output layer.
         use_bias: Whether to use bias on all layers in the network.
         view_requirements_dict: The view requirements to use if anything else than
             observation_space or action_space is to be encoded. This signifies an
@@ -672,8 +680,9 @@ class LSTMEncoderConfig(ModelConfig):
             other spaces that might be present in the view_requirements_dict.
     """
 
+    recurrent_layer_type: str = "lstm"
     hidden_dim: int = None
-    num_lstm_layers: int = None
+    num_layers: int = None
     batch_major: bool = True
     output_activation: str = "linear"
     use_bias: bool = True
@@ -682,10 +691,15 @@ def _validate(self, framework: str = "torch"):
         """Makes sure that settings are valid."""
+        if self.recurrent_layer_type not in ["gru", "lstm"]:
+            raise ValueError(
+                f"`recurrent_layer_type` ({self.recurrent_layer_type}) of "
+                "RecurrentEncoderConfig must be 'gru' or 'lstm'!"
+            )
         if self.input_dims is not None and len(self.input_dims) != 1:
             raise ValueError(
-                f"`input_dims` ({self.input_dims}) of LSTMEncoderConfig must be 1D, "
-                "e.g. `[32]`!"
+                f"`input_dims` ({self.input_dims}) of RecurrentEncoderConfig must be "
+                "1D, e.g. `[32]`!"
             )
 
         # Call these already here to catch errors early on.
@@ -698,20 +712,27 @@ def build(self, framework: str = "torch") -> Encoder:
             or self.view_requirements_dict is not None
         ):
             raise NotImplementedError(
-                "LSTMEncoderConfig does not support configuring LSTMs that encode "
-                "depending on view_requirements or have a custom tokenizer. "
+                "RecurrentEncoderConfig does not support configuring Models that "
+                "encode depending on view_requirements or have a custom tokenizer. "
                 "Therefore, this config expects `view_requirements_dict=None` and "
                 "`get_tokenizer_config=None`."
             )
 
         if framework == "torch":
-            from ray.rllib.core.models.torch.encoder import TorchLSTMEncoder
-
-            return TorchLSTMEncoder(self)
+            from ray.rllib.core.models.torch.encoder import (
+                TorchGRUEncoder as GRU,
+                TorchLSTMEncoder as LSTM,
+            )
         else:
-            from ray.rllib.core.models.tf.encoder import TfLSTMEncoder
+            from ray.rllib.core.models.tf.encoder import (
+                TfGRUEncoder as GRU,
+                TfLSTMEncoder as LSTM,
+            )
 
-            return TfLSTMEncoder(self)
+        if self.recurrent_layer_type == "lstm":
+            return LSTM(self)
+        else:
+            return GRU(self)
 
 
 @ExperimentalAPI
diff --git a/rllib/core/models/tests/test_recurrent_encoders.py b/rllib/core/models/tests/test_recurrent_encoders.py
index bee4ac36945a..4708f0787eaa 100644
--- a/rllib/core/models/tests/test_recurrent_encoders.py
+++ b/rllib/core/models/tests/test_recurrent_encoders.py
@@ -2,25 +2,94 @@
 import unittest
 
 from ray.rllib.core.models.base import ENCODER_OUT, STATE_OUT
-from ray.rllib.core.models.configs import LSTMEncoderConfig
+from ray.rllib.core.models.configs import RecurrentEncoderConfig
+from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.test_utils import framework_iterator, ModelChecker
 
+_, tf, _ = try_import_tf()
+torch, _ = try_import_torch()
 
 class TestRecurrentEncoders(unittest.TestCase):
+    def test_gru_encoders(self):
+        """Tests building GRU encoders properly and checks for correct architecture."""
+
+        # Loop through different combinations of hyperparameters.
+        inputs_dimss = [[1], [100]]
+        output_dimss = [[1], [50]]
+        num_layerss = [1, 4]
+        hidden_dims = [128, 256]
+        output_activations = ["linear", "silu", "relu"]
+        use_biases = [False, True]
+
+        for permutation in itertools.product(
+            inputs_dimss,
+            num_layerss,
+            hidden_dims,
+            output_activations,
+            output_dimss,
+            use_biases,
+        ):
+            (
+                inputs_dims,
+                num_layers,
+                hidden_dim,
+                output_activation,
+                output_dims,
+                use_bias,
+            ) = permutation
+
+            print(
+                f"Testing ...\n"
+                f"input_dims: {inputs_dims}\n"
+                f"num_layers: {num_layers}\n"
+                f"hidden_dim: {hidden_dim}\n"
+                f"output_activation: {output_activation}\n"
+                f"output_dims: {output_dims}\n"
+                f"use_bias: {use_bias}\n"
+            )
+
+            config = RecurrentEncoderConfig(
+                recurrent_layer_type="gru",
+                input_dims=inputs_dims,
+                num_layers=num_layers,
+                hidden_dim=hidden_dim,
+                output_dims=output_dims,
+                output_activation=output_activation,
+                use_bias=use_bias,
+            )
+
+            # Use a ModelChecker to compare all added models (different frameworks)
+            # with each other.
+            model_checker = ModelChecker(config)
+
+            for fw in framework_iterator(frameworks=("tf2", "torch")):
+                # Add this framework version of the model to our checker.
+                outputs = model_checker.add(framework=fw)
+                # Output shape: [1=B, 1=T, [output_dim]]
+                self.assertEqual(outputs[ENCODER_OUT].shape, (1, 1, output_dims[0]))
+                # State shape: [1=B, num_layers, hidden_dim]
+                self.assertEqual(
+                    outputs[STATE_OUT]["h"].shape,
+                    (1, num_layers, hidden_dim),
+                )
+
+            # Check all added models against each other.
+            model_checker.check()
+
     def test_lstm_encoders(self):
         """Tests building LSTM encoders properly and checks for correct architecture."""
 
         # Loop through different combinations of hyperparameters.
         inputs_dimss = [[1], [100]]
         output_dimss = [[1], [100]]
-        num_lstm_layerss = [1, 3]
+        num_layerss = [1, 3]
         hidden_dims = [16, 128]
         output_activations = [None, "linear", "relu"]
         use_biases = [False, True]
 
         for permutation in itertools.product(
             inputs_dimss,
-            num_lstm_layerss,
+            num_layerss,
             hidden_dims,
             output_activations,
             output_dimss,
@@ -28,7 +97,7 @@ def test_lstm_encoders(self):
         ):
             (
                 inputs_dims,
-                num_lstm_layers,
+                num_layers,
                 hidden_dim,
                 output_activation,
                 output_dims,
@@ -38,16 +107,17 @@ def test_lstm_encoders(self):
             print(
                 f"Testing ...\n"
                 f"input_dims: {inputs_dims}\n"
-                f"num_lstm_layers: {num_lstm_layers}\n"
+                f"num_layers: {num_layers}\n"
                 f"hidden_dim: {hidden_dim}\n"
                 f"output_activation: {output_activation}\n"
                 f"output_dims: {output_dims}\n"
                 f"use_bias: {use_bias}\n"
             )
 
-            config = LSTMEncoderConfig(
+            config = RecurrentEncoderConfig(
+                recurrent_layer_type="lstm",
                 input_dims=inputs_dims,
-                num_lstm_layers=num_lstm_layers,
+                num_layers=num_layers,
                 hidden_dim=hidden_dim,
                 output_dims=output_dims,
                 output_activation=output_activation,
@@ -66,11 +136,11 @@ def test_lstm_encoders(self):
                 # State shapes: [1=B, 1=num_layers, [hidden_dim]]
                 self.assertEqual(
                     outputs[STATE_OUT]["h"].shape,
-                    (1, num_lstm_layers, hidden_dim),
+                    (1, num_layers, hidden_dim),
                 )
                 self.assertEqual(
                     outputs[STATE_OUT]["c"].shape,
-                    (1, num_lstm_layers, hidden_dim),
+                    (1, num_layers, hidden_dim),
                 )
 
         # Check all added models against each other (only if bias=False).
diff --git a/rllib/core/models/tf/encoder.py b/rllib/core/models/tf/encoder.py
index 33282e81289f..8f739934dcb0 100644
--- a/rllib/core/models/tf/encoder.py
+++ b/rllib/core/models/tf/encoder.py
@@ -13,8 +13,8 @@
 from ray.rllib.core.models.configs import (
     ActorCriticEncoderConfig,
     CNNEncoderConfig,
-    LSTMEncoderConfig,
     MLPEncoderConfig,
+    RecurrentEncoderConfig,
 )
 from ray.rllib.core.models.tf.base import TfModel
 from ray.rllib.core.models.tf.primitives import TfMLP, TfCNN
@@ -150,15 +150,97 @@ def _forward(self, inputs: NestedDict, **kwargs) -> NestedDict:
         )
 
 
+class TfGRUEncoder(TfModel, Encoder):
+    """An encoder that uses one or more GRU layers and a linear output layer."""
+
+    def __init__(self, config: RecurrentEncoderConfig) -> None:
+        TfModel.__init__(self, config)
+
+        # Create the tf GRU layers.
+        self.grus = []
+        for _ in range(config.num_layers):
+            self.grus.append(
+                tf.keras.layers.GRU(
+                    config.hidden_dim,
+                    time_major=not config.batch_major,
+                    use_bias=config.use_bias,
+                    return_sequences=True,
+                    return_state=True,
+                )
+            )
+
+        # Create the final dense layer.
+        self.linear = tf.keras.layers.Dense(
+            units=config.output_dims[0],
+            use_bias=config.use_bias,
+        )
+
+    @override(Model)
+    def get_input_specs(self) -> Optional[Spec]:
+        return SpecDict(
+            {
+                # b, t for batch major; t, b for time major.
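+                # Dims bound via kwargs below (d, h, l) are checked against the
+                # config's values; unbound dims such as b and t may take any size.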
+                SampleBatch.OBS: TfTensorSpec("b, t, d", d=self.config.input_dims[0]),
+                STATE_IN: {
+                    "h": TfTensorSpec(
+                        "b, l, h", h=self.config.hidden_dim, l=self.config.num_layers
+                    ),
+                },
+            }
+        )
+
+    @override(Model)
+    def get_output_specs(self) -> Optional[Spec]:
+        return SpecDict(
+            {
+                ENCODER_OUT: TfTensorSpec("b, t, d", d=self.config.output_dims[0]),
+                STATE_OUT: {
+                    "h": TfTensorSpec(
+                        "b, l, h", h=self.config.hidden_dim, l=self.config.num_layers
+                    ),
+                },
+            }
+        )
+
+    @override(Model)
+    def get_initial_state(self):
+        return {
+            "h": tf.zeros((self.config.num_layers, self.config.hidden_dim)),
+        }
+
+    @override(Model)
+    def _forward(self, inputs: NestedDict, **kwargs) -> NestedDict:
+        out = tf.cast(inputs[SampleBatch.OBS], tf.float32)
+
+        # States are batch-first when coming in. Make them layers-first.
+        states_in = tree.map_structure(
+            lambda s: tf.transpose(s, perm=[1, 0] + list(range(2, len(s.shape)))),
+            inputs[STATE_IN],
+        )
+
+        states_out = []
+        for i, layer in enumerate(self.grus):
+            out, h = layer(out, states_in["h"][i])
+            states_out.append(h)
+
+        out = self.linear(out)
+
+        return {
+            ENCODER_OUT: out,
+            # Make state_out batch-first.
+            STATE_OUT: {"h": tf.stack(states_out, 1)},
+        }
+
+
 class TfLSTMEncoder(TfModel, Encoder):
     """An encoder that uses an LSTM cell and a linear layer."""
 
-    def __init__(self, config: LSTMEncoderConfig) -> None:
+    def __init__(self, config: RecurrentEncoderConfig) -> None:
         TfModel.__init__(self, config)
 
         # Create the tf LSTM layers.
         self.lstms = []
-        for _ in range(config.num_lstm_layers):
+        for _ in range(config.num_layers):
             self.lstms.append(
                 tf.keras.layers.LSTM(
                     config.hidden_dim,
@@ -183,14 +265,10 @@ def get_input_specs(self) -> Optional[Spec]:
                 SampleBatch.OBS: TfTensorSpec("b, t, d", d=self.config.input_dims[0]),
                 STATE_IN: {
                     "h": TfTensorSpec(
-                        "b, l, h",
-                        h=self.config.hidden_dim,
-                        l=self.config.num_lstm_layers,
+                        "b, l, h", h=self.config.hidden_dim, l=self.config.num_layers
                     ),
                     "c": TfTensorSpec(
-                        "b, l, h",
-                        h=self.config.hidden_dim,
-                        l=self.config.num_lstm_layers,
+                        "b, l, h", h=self.config.hidden_dim, l=self.config.num_layers
                     ),
                 },
             }
@@ -203,14 +281,10 @@ def get_output_specs(self) -> Optional[Spec]:
                 ENCODER_OUT: TfTensorSpec("b, t, d", d=self.config.output_dims[0]),
                 STATE_OUT: {
                     "h": TfTensorSpec(
-                        "b, l, h",
-                        h=self.config.hidden_dim,
-                        l=self.config.num_lstm_layers,
+                        "b, l, h", h=self.config.hidden_dim, l=self.config.num_layers
                     ),
                     "c": TfTensorSpec(
-                        "b, l, h",
-                        h=self.config.hidden_dim,
-                        l=self.config.num_lstm_layers,
+                        "b, l, h", h=self.config.hidden_dim, l=self.config.num_layers
                     ),
                 },
             }
diff --git a/rllib/core/models/torch/encoder.py b/rllib/core/models/torch/encoder.py
index d9e587f2b723..8fbd1c70c03b 100644
--- a/rllib/core/models/torch/encoder.py
+++ b/rllib/core/models/torch/encoder.py
@@ -13,8 +13,8 @@
 from ray.rllib.core.models.configs import (
     ActorCriticEncoderConfig,
     CNNEncoderConfig,
-    LSTMEncoderConfig,
     MLPEncoderConfig,
+    RecurrentEncoderConfig,
 )
 from ray.rllib.core.models.torch.base import TorchModel
 from ray.rllib.core.models.torch.primitives import TorchMLP, TorchCNN
@@ -155,10 +155,85 @@ def _forward(self, inputs: NestedDict, **kwargs) -> NestedDict:
         )
 
 
+class TorchGRUEncoder(TorchModel, Encoder):
+    """An encoder that uses one or more GRU cells and a linear output layer."""
+
+    def __init__(self, config: RecurrentEncoderConfig) -> None:
+        TorchModel.__init__(self, config)
+
+        # Create the torch GRU layer.
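+        # Note: unlike the tf encoder above, which stacks individual keras GRU
+        # layers, nn.GRU handles `num_layers` stacked layers internally and
+        # consumes the whole (num_layers, B, hidden_dim) h-state in one call.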
+        self.gru = nn.GRU(
+            config.input_dims[0],
+            config.hidden_dim,
+            config.num_layers,
+            batch_first=config.batch_major,
+            bias=config.use_bias,
+        )
+        # Create the final dense layer.
+        self.linear = nn.Linear(
+            config.hidden_dim,
+            config.output_dims[0],
+            bias=config.use_bias,
+        )
+
+    @override(Model)
+    def get_input_specs(self) -> Optional[Spec]:
+        return SpecDict(
+            {
+                # b, t for batch major; t, b for time major.
+                SampleBatch.OBS: TorchTensorSpec(
+                    "b, t, d", d=self.config.input_dims[0]
+                ),
+                STATE_IN: {
+                    "h": TorchTensorSpec(
+                        "b, l, h", h=self.config.hidden_dim, l=self.config.num_layers
+                    ),
+                },
+            }
+        )
+
+    @override(Model)
+    def get_output_specs(self) -> Optional[Spec]:
+        return SpecDict(
+            {
+                ENCODER_OUT: TorchTensorSpec("b, t, d", d=self.config.output_dims[0]),
+                STATE_OUT: {
+                    "h": TorchTensorSpec(
+                        "b, l, h", h=self.config.hidden_dim, l=self.config.num_layers
+                    ),
+                },
+            }
+        )
+
+    @override(Model)
+    def get_initial_state(self):
+        return {
+            "h": torch.zeros(self.config.num_layers, self.config.hidden_dim),
+        }
+
+    @override(Model)
+    def _forward(self, inputs: NestedDict, **kwargs) -> NestedDict:
+        out = inputs[SampleBatch.OBS].float()
+
+        # States are batch-first when coming in. Make them layers-first.
+        states_in = tree.map_structure(lambda s: s.transpose(0, 1), inputs[STATE_IN])
+
+        out, states_out = self.gru(out, states_in["h"])
+        states_out = {"h": states_out}
+
+        out = self.linear(out)
+
+        return {
+            ENCODER_OUT: out,
+            # Make the states batch-first again.
+            STATE_OUT: tree.map_structure(lambda s: s.transpose(0, 1), states_out),
+        }
+
+
 class TorchLSTMEncoder(TorchModel, Encoder):
     """An encoder that uses an LSTM cell and a linear layer."""
 
-    def __init__(self, config: LSTMEncoderConfig) -> None:
+    def __init__(self, config: RecurrentEncoderConfig) -> None:
         TorchModel.__init__(self, config)
 
         # Create the torch LSTM layer.
@@ -166,7 +241,7 @@ def __init__(self, config: LSTMEncoderConfig) -> None:
             # We only support 1D spaces right now.
             config.input_dims[0],
             config.hidden_dim,
-            config.num_lstm_layers,
+            config.num_layers,
             batch_first=config.batch_major,
             bias=config.use_bias,
         )
@@ -189,12 +264,12 @@ def get_input_specs(self) -> Optional[Spec]:
                     "h": TorchTensorSpec(
                         "b, l, h",
                         h=self.config.hidden_dim,
-                        l=self.config.num_lstm_layers,
+                        l=self.config.num_layers,
                     ),
                     "c": TorchTensorSpec(
                         "b, l, h",
                         h=self.config.hidden_dim,
-                        l=self.config.num_lstm_layers,
+                        l=self.config.num_layers,
                     ),
                 },
             }
@@ -209,12 +284,12 @@ def get_output_specs(self) -> Optional[Spec]:
                     "h": TorchTensorSpec(
                         "b, l, h",
                         h=self.config.hidden_dim,
-                        l=self.config.num_lstm_layers,
+                        l=self.config.num_layers,
                     ),
                     "c": TorchTensorSpec(
                         "b, l, h",
                         h=self.config.hidden_dim,
-                        l=self.config.num_lstm_layers,
+                        l=self.config.num_layers,
                    ),
                 },
             }
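
For reference, a minimal usage sketch of the new config (not part of the diff above). It assumes `STATE_IN` is exported by `ray.rllib.core.models.base` alongside `ENCODER_OUT` and `STATE_OUT` (as used in the encoder modules), and that the built encoder accepts a plain dict as its input batch:

.. code-block:: python

    import torch

    from ray.rllib.core.models.base import ENCODER_OUT, STATE_IN, STATE_OUT
    from ray.rllib.core.models.configs import RecurrentEncoderConfig
    from ray.rllib.policy.sample_batch import SampleBatch

    # Two stacked GRU layers, followed by a linear output layer (32 -> 8).
    config = RecurrentEncoderConfig(
        recurrent_layer_type="gru",
        input_dims=[16],
        hidden_dim=32,
        num_layers=2,
        output_dims=[8],
    )
    encoder = config.build(framework="torch")

    B, T = 4, 20  # batch size and rollout length
    obs = torch.rand(B, T, 16)
    # The initial state is batch-first, matching the input specs:
    # (B, num_layers, hidden_dim).
    state = {"h": torch.zeros(B, 2, 32)}

    out = encoder({SampleBatch.OBS: obs, STATE_IN: state})
    assert out[ENCODER_OUT].shape == (B, T, 8)
    assert out[STATE_OUT]["h"].shape == (B, 2, 32)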