[RLlib] DreamerV3: Catalog enhancements (MLP/CNN encoders/heads completed and unified across DL frameworks). (ray-project#33967)

Signed-off-by: elliottower <[email protected]>
sven1977 authored and elliottower committed Apr 22, 2023
1 parent 2698ef8 commit 279a7d6
Showing 30 changed files with 1,331 additions and 833 deletions.
6 changes: 0 additions & 6 deletions doc/source/rllib/package_ref/utils.rst
@@ -178,12 +178,6 @@ Torch utilities
     ~set_torch_seed
     ~softmax_cross_entropy_with_logits

-.. autosummary::
-    :toctree: doc/
-    :template: autosummary/class_without_autosummary.rst
-
-    ~Swish
-

 Numpy utilities
 ~~~~~~~~~~~~~~~
23 changes: 8 additions & 15 deletions rllib/BUILD
@@ -1864,31 +1864,24 @@ py_test(

 # Default Models
 py_test(
-    name = "test_tf_mlp_head",
+    name = "test_cnn_encoders",
     tags = ["team:rllib", "core", "models"],
     size = "medium",
-    srcs = ["core/models/tf/tests/test_tf_mlp_head.py"]
+    srcs = ["core/models/tests/test_cnn_encoders.py"]
 )

 py_test(
-    name = "test_torch_cnn_encoder",
+    name = "test_mlp_encoders",
     tags = ["team:rllib", "core", "models"],
-    size = "small",
-    srcs = ["core/models/torch/tests/test_torch_cnn_encoder.py"]
-)
-
-py_test(
-    name = "test_torch_mlp_encoder",
-    tags = ["team:rllib", "core", "models"],
-    size = "small",
-    srcs = ["core/models/torch/tests/test_torch_mlp_encoder.py"]
+    size = "medium",
+    srcs = ["core/models/tests/test_mlp_encoders.py"]
 )

 py_test(
-    name = "test_torch_mlp_head",
+    name = "test_mlp_heads",
     tags = ["team:rllib", "core", "models"],
-    size = "small",
-    srcs = ["core/models/torch/tests/test_torch_mlp_head.py"]
+    size = "medium",
+    srcs = ["core/models/tests/test_mlp_heads.py"]
 )

 # Specs
11 changes: 6 additions & 5 deletions rllib/algorithms/ppo/ppo_catalog.py
@@ -79,7 +79,12 @@ def __init__(
         post_fcnet_hiddens = self.model_config_dict["post_fcnet_hiddens"]
         post_fcnet_activation = self.model_config_dict["post_fcnet_activation"]

-        self.pi_head_config = MLPHeadConfig(
+        pi_head_config_class = (
+            FreeLogStdMLPHeadConfig
+            if self.model_config_dict["free_log_std"]
+            else MLPHeadConfig
+        )
+        self.pi_head_config = pi_head_config_class(
             input_dims=self.latent_dims,
             hidden_layer_dims=post_fcnet_hiddens,
             hidden_layer_activation=post_fcnet_activation,

@@ -144,10 +149,6 @@ def build_pi_head(self, framework: str) -> Model:
             _check_if_diag_gaussian(
                 action_distribution_cls=action_distribution_cls, framework=framework
             )
-            self.pi_head_config = FreeLogStdMLPHeadConfig(
-                mlp_head_config=self.pi_head_config
-            )
-
         return self.pi_head_config.build(framework=framework)

     def build_vf_head(self, framework: str) -> Model:
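The two hunks above move the `free_log_std` decision from `build_pi_head()` into `__init__`, so building the pi head no longer mutates `self.pi_head_config` and can safely run more than once (e.g. once per DL framework). A minimal, self-contained sketch of the pattern; the config classes here are simplified stand-ins, not RLlib's real signatures:

```python
from dataclasses import dataclass, field
from typing import List, Tuple

# Simplified stand-ins for RLlib's head configs, just to show the pattern.
@dataclass
class MLPHeadConfig:
    input_dims: Tuple[int, ...] = ()
    hidden_layer_dims: List[int] = field(default_factory=lambda: [256, 256])

    def build(self, framework: str) -> str:
        return f"MLPHead[{framework}]"

@dataclass
class FreeLogStdMLPHeadConfig(MLPHeadConfig):
    def build(self, framework: str) -> str:
        return f"FreeLogStdMLPHead[{framework}]"

def make_pi_head_config(model_config_dict: dict, latent_dims: Tuple[int, ...]):
    # Decide the config class once, up front, based on `free_log_std` ...
    config_cls = (
        FreeLogStdMLPHeadConfig
        if model_config_dict["free_log_std"]
        else MLPHeadConfig
    )
    # ... so that build() stays a pure "config -> model" call.
    return config_cls(input_dims=latent_dims)

pi_head_config = make_pi_head_config({"free_log_std": True}, latent_dims=(64,))
# Safe to build repeatedly (e.g. torch and tf2); no re-wrapping in between.
assert pi_head_config.build("torch") == pi_head_config.build("torch")
```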
98 changes: 69 additions & 29 deletions rllib/core/models/base.py
@@ -3,12 +3,12 @@
 from typing import List, Optional, Tuple, Union

 from ray.rllib import SampleBatch
+from ray.rllib.core.models.specs.checker import convert_to_canonical_format
 from ray.rllib.core.models.specs.specs_base import Spec
 from ray.rllib.core.models.specs.specs_dict import SpecDict
 from ray.rllib.utils.annotations import ExperimentalAPI
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.nested_dict import NestedDict
-from ray.rllib.core.models.specs.checker import convert_to_canonical_format
 from ray.rllib.utils.typing import TensorType

 # Top level keys that unify model i/o.
@@ -20,18 +20,29 @@
 CRITIC: str = "critic"


+def _raise_not_decorated_exception(class_and_method, input_or_output):
+    raise ValueError(
+        f"`{class_and_method}()` not decorated with {input_or_output} specification. "
+        f"Decorate it with @check_{input_or_output}_specs() to define a specification."
+    )
+
+
 @ExperimentalAPI
 @dataclass
 class ModelConfig(abc.ABC):
-    """Base class for model configurations.
+    """Base class for configuring a `Model` instance.

-    ModelConfigs are framework-agnostic.
-    A ModelConfig is usually built by RLModules after getting it from a Catalog object.
-    It is therefore a means of configuration for RLModules. However, ModelConfigs are
-    not restricted to be used only with Catalog or RLModules.
-    A usage Example together with a Model can be found in the Model.
+    ModelConfigs are DL framework-agnostic.
+    A `Model` (as a sub-component of an `RLModule`) is built via calling the
+    respective ModelConfig's `build()` method.
+    RLModules build their sub-components this way after receiving one or more
+    `ModelConfig` instances from a Catalog object.
+    However, `ModelConfig` is not restricted to be used only with Catalog or RLModules.
+    Usage examples can be found in the individual Model classes', e.g.
+    see `ray.rllib.core.models.configs::MLPHeadConfig`.

-    Args:
+    Attributes:
         input_dims: The input dimensions of the network
         output_dims: The output dimensions of the network.
     """
@@ -191,19 +202,42 @@ def _forward(self, input_dict: NestedDict, **kwargs) -> NestedDict:
         Returns:
             NestedDict: The output tensors.
         """
         raise NotImplementedError

+    @abc.abstractmethod
+    def get_num_parameters(self) -> Tuple[int, int]:
+        """Returns a tuple of (num trainable params, num non-trainable params)."""
+
+    @abc.abstractmethod
+    def _set_to_dummy_weights(self, value_sequence=(-0.02, -0.01, 0.01, 0.02)) -> None:
+        """Helper method to set all weights to deterministic dummy values.
+
+        Calling this method on two `Models` that have the same architecture using
+        the exact same `value_sequence` arg should make both models output the exact
+        same values on arbitrary inputs. This will work, even if the two `Models`
+        are of different DL frameworks.
+
+        Args:
+            value_sequence: Looping through the list of all parameters (weight matrices,
+                bias tensors, etc..) of this model, in each iteration i, we set all
+                values in this parameter to `value_sequence[i % len(value_sequence)]`
+                (round robin).
+
+        Example:
+            TODO:
+        """


 class Encoder(Model, abc.ABC):
-    """The framework-agnostic base class for all encoders RLlib produces.
+    """The framework-agnostic base class for all RLlib encoders.

-    Encoders are used to encode observations into a latent space in RLModules.
+    Encoders are used to transform observations to a latent space.
     Therefore, their `input_specs` contains the observation space dimensions.
     Similarly, their `output_specs` contains the latent space dimensions.

     Encoders can be recurrent, in which case the state should be part of input- and
-    output_specs. The latents that are produced by an encoder are fed into subsequent
-    heads. Any implementation of Encoder should also be callable. This should be done
-    by also inheriting from a framework-specific model base-class, s.a. TorchModel.
+    output_specs. The latent vectors produced by an encoder are fed into subsequent
+    "heads". Any implementation of Encoder should also be callable. This should be done
+    by also inheriting from a framework-specific model base-class, s.a. TorchModel or
+    TfModel.

     Abstract illustration of typical flow of tensors:
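The `_set_to_dummy_weights()` contract added above (its `Example:` section is still a TODO in this commit) is what enables the unified cross-framework tests in `rllib/BUILD`: two same-architecture models, torch and tf, get identical deterministic weights and must then produce identical outputs. A sketch of the round-robin assignment on plain numpy arrays, written as a hypothetical free function rather than the actual method:

```python
import itertools
from typing import Sequence

import numpy as np

def set_to_dummy_weights(
    params: Sequence[np.ndarray],
    value_sequence: Sequence[float] = (-0.02, -0.01, 0.01, 0.02),
) -> None:
    # Parameter i gets value_sequence[i % len(value_sequence)] (round robin).
    for value, param in zip(itertools.cycle(value_sequence), params):
        param.fill(value)

weights = [np.empty((4, 8)), np.empty(8), np.empty((8, 2)), np.empty(2), np.empty((2, 1))]
set_to_dummy_weights(weights)
# 5 params, 4 values: the sequence wraps around for the fifth parameter.
assert weights[0].flat[0] == -0.02 and weights[4].flat[0] == -0.02
```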
@@ -219,16 +253,17 @@ class Encoder(Model, abc.ABC):
     That is, for time-series data, we encode into the latent space for each time step.
     This should be reflected in the `output_specs`.

-    Usage Example together with a ModelConfig:
+    Usage example together with a ModelConfig:

     .. testcode::

-        from ray.rllib.core.models.base import ModelConfig
-        from ray.rllib.core.models.base import ENCODER_OUT, STATE_IN, STATE_OUT, Encoder
-        from ray.rllib.policy.sample_batch import SampleBatch
+        from dataclasses import dataclass
+
+        import numpy as np
+
+        from ray.rllib.core.models.base import ModelConfig
+        from ray.rllib.core.models.base import Encoder, ENCODER_OUT, STATE_IN, STATE_OUT
+        from ray.rllib.policy.sample_batch import SampleBatch

         class NumpyEncoder(Encoder):
             def __init__(self, config):
                 super().__init__(config)
@@ -265,11 +300,11 @@ def build(self, framework: str):
         """

     @override(Model)
-    def get_input_specs(self) -> Union[Spec, None]:
+    def get_input_specs(self) -> Optional[Spec]:
         return convert_to_canonical_format([SampleBatch.OBS, STATE_IN])

     @override(Model)
-    def get_output_specs(self) -> Union[Spec, None]:
+    def get_output_specs(self) -> Optional[Spec]:
         return convert_to_canonical_format([ENCODER_OUT, STATE_OUT])

     @abc.abstractmethod
@@ -279,18 +314,23 @@ def _forward(self, input_dict: NestedDict, **kwargs) -> NestedDict:
         This method is called by the forwarding method of the respective framework
         that is itself wrapped by RLlib in order to check model inputs and outputs.

-        The input dict contains at minimum the observation and the state of the encoder.
-        The output dict contains at minimum the latent and the state of the encoder.
-        These values have the keys `SampleBatch.OBS` and `STATE_IN` in the inputs, and
-        `STATE_OUT` and `ENCODER_OUT` and outputs to establish an agreement
-        between the encoder and RLModules. For stateless encoders, states can be None.
+        The input dict contains at minimum the observation and the state of the encoder
+        (None for stateless encoders).
+        The output dict contains at minimum the latent and the state of the encoder
+        (None for stateless encoders).
+        To establish an agreement between the encoder and RLModules, these values
+        have the fixed keys `SampleBatch.OBS` and `STATE_IN` for the `input_dict`,
+        and `STATE_OUT` and `ENCODER_OUT` for the returned NestedDict.

         Args:
-            input_dict: The input tensors.
+            input_dict: The input tensors. Must contain at a minimum the keys
+                SampleBatch.OBS and STATE_IN (which might be None for stateless
+                encoders).
             **kwargs: Forward compatibility kwargs.

         Returns:
-            NestedDict: The output tensors.
+            NestedDict: The output tensors. Must contain at a minimum the keys
+                ENCODER_OUT and STATE_OUT (which might be None for stateless encoders).
         """
         raise NotImplementedError
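A dict-level sketch of the fixed-key contract this docstring pins down, for a stateless encoder. The literal strings assume the usual values of the `SampleBatch.OBS`, `STATE_IN`, `STATE_OUT`, and `ENCODER_OUT` constants ("obs", "state_in", "state_out", "encoder_out"):

```python
import numpy as np

def stateless_encoder_forward(input_dict: dict) -> dict:
    obs = input_dict["obs"]            # SampleBatch.OBS
    state_in = input_dict["state_in"]  # STATE_IN; None for stateless encoders
    latent = obs * 2.0                 # stand-in for the actual encoder network
    # ENCODER_OUT and STATE_OUT are the agreed-on output keys.
    return {"encoder_out": latent, "state_out": state_in}

out = stateless_encoder_forward({"obs": np.ones(3), "state_in": None})
assert out["state_out"] is None and out["encoder_out"].shape == (3,)
```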

@@ -322,7 +362,7 @@ def __init__(self, config: ModelConfig) -> None:
         super().__init__(config)

     @override(Model)
-    def get_input_specs(self) -> Union[Spec, None]:
+    def get_input_specs(self) -> Optional[Spec]:
         if self.config.shared:
             state_in_spec = self.encoder.input_specs[STATE_IN]
         else:

@@ -340,7 +380,7 @@ def get_input_specs(self) -> Union[Spec, None]:
         )

     @override(Model)
-    def get_output_specs(self) -> Union[Spec, None]:
+    def get_output_specs(self) -> Optional[Spec]:
         if self.config.shared:
             state_out_spec = self.encoder.output_specs[STATE_OUT]
         else:
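In both `ActorCriticEncoder` spec methods above, `config.shared` selects between one shared encoder state spec and per-role specs. The `else` branches are truncated in this view, so the sketch below is a guess about that part: the shared branch mirrors the diff, while the nesting under "actor"/"critic" (the `ACTOR`/`CRITIC` top-level keys defined earlier in this file) is an assumption:

```python
from types import SimpleNamespace

# Condensed, hypothetical version of ActorCriticEncoder.get_input_specs();
# real specs are Spec/SpecDict objects, shown here as plain values/dicts.
def state_in_spec(config, encoder, actor_encoder, critic_encoder):
    if config.shared:
        # One shared encoder -> one shared state spec.
        return encoder.input_specs["state_in"]
    # Separate encoders -> assumed: nest one state spec per role.
    return {
        "actor": actor_encoder.input_specs["state_in"],
        "critic": critic_encoder.input_specs["state_in"],
    }

enc = SimpleNamespace(input_specs={"state_in": "shared_state_spec"})
assert state_in_spec(SimpleNamespace(shared=True), enc, None, None) == "shared_state_spec"
```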
9 changes: 6 additions & 3 deletions rllib/core/models/catalog.py
@@ -363,10 +363,13 @@ def get_encoder_config(

             encoder_config = CNNEncoderConfig(
                 input_dims=observation_space.shape,
-                filter_specifiers=model_config_dict["conv_filters"],
-                filter_layer_activation=activation,
-                output_activation=output_activation,
+                cnn_filter_specifiers=model_config_dict["conv_filters"],
+                cnn_activation=activation,
+                cnn_use_layernorm=model_config_dict.get(
+                    "conv_use_layernorm", False
+                ),
                 output_dims=[encoder_latent_dim],
+                output_activation=output_activation,
             )
         # input_space is a 2D Box
         elif (
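Putting the renamed `CNNEncoderConfig` kwargs together in one construction call. The `cnn_`-prefixed names and the `conv_use_layernorm` model-config key are taken from the hunk above; the import path and the `[num_out_channels, kernel, stride]` filter convention are assumptions based on RLlib's usual `conv_filters` layout:

```python
from ray.rllib.core.models.configs import CNNEncoderConfig  # assumed path

config = CNNEncoderConfig(
    input_dims=(64, 64, 3),  # H x W x C observation
    # One specifier per conv layer: [num_out_channels, kernel, stride].
    cnn_filter_specifiers=[[16, [8, 8], 4], [32, [4, 4], 2]],
    cnn_activation="relu",
    cnn_use_layernorm=False,  # wired to model_config_dict["conv_use_layernorm"]
    output_dims=[256],
    output_activation="linear",
)
encoder = config.build(framework="torch")  # or "tf2"
```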
