[RLlib] DreamerV3: Catalog enhancements (MLP/CNN encoders/heads completed and unified across DL frameworks). #33967
Conversation
Signed-off-by: sven1977 <[email protected]>
@@ -1861,6 +1861,27 @@ py_test(
    srcs = ["core/models/tests/test_catalog.py"]
)

py_test(
Unified tests between tf and torch to be able to compare the exact number of model parameters (asserting that the architectures are the same).
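For illustration, here is a minimal, self-contained sketch of that kind of check (not the PR's actual test code; the layer sizes are invented): build the same architecture in tf and torch and assert that the exact trainable-parameter counts match.

```python
# Hypothetical example, not taken from the PR: compare exact trainable-parameter
# counts of equivalent tf.keras and torch models to assert identical architectures.
import numpy as np
import tensorflow as tf
import torch.nn as nn

tf_model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(4,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(2),
])
torch_model = nn.Sequential(
    nn.Linear(4, 16),
    nn.ReLU(),
    nn.Linear(16, 2),
)

num_tf = int(sum(np.prod(v.shape) for v in tf_model.trainable_variables))
num_torch = sum(p.numel() for p in torch_model.parameters() if p.requires_grad)
assert num_tf == num_torch == 114  # (4*16 + 16) + (16*2 + 2)
```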
@@ -20,16 +20,25 @@
CRITIC: str = "critic"


def _raise_not_decorated_exception(class_and_method, input_or_output):
Moved this here. This was duplicated code in tf and torch versions.
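A hypothetical sketch of what such a shared helper could look like (the exact message and signature in the PR may differ):

```python
# Hypothetical sketch of the shared helper; the real error message/signature in
# RLlib may differ. Both the tf and torch Model base classes can call this one
# function instead of duplicating the error text.
def _raise_not_decorated_exception(class_and_method: str, input_or_output: str) -> None:
    raise ValueError(
        f"`{class_and_method}` is not decorated with a `check_{input_or_output}_specs`"
        " style decorator!"
    )
```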
Thanks!
@@ -193,6 +202,11 @@ def _forward(self, input_dict: NestedDict, **kwargs) -> NestedDict:
    """
    raise NotImplementedError

@abc.abstractmethod
Added this convenience method to the top-level API.
Attributes:
    hidden_layer_dims: The sizes of the hidden layers.
    hidden_layer_activation: The activation function to use after each layer (
        except for the output).
    hidden_layer_use_layernorm: Whether to insert a LayerNorm functionality
Added to all primitives (MLP and CNN):
- an option to switch on LayerNorm between layers
- an option to use bias or not
(a minimal sketch of both options follows below)
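A minimal sketch of what these two options mean in practice (assumed parameter names; not RLlib's actual builder code):

```python
# Sketch only: optional LayerNorm between hidden layers plus a bias toggle.
import torch.nn as nn

def build_mlp(dims, use_layernorm=False, use_bias=True, activation=nn.ReLU):
    layers = []
    for i in range(len(dims) - 1):
        layers.append(nn.Linear(dims[i], dims[i + 1], bias=use_bias))
        if use_layernorm:
            # Normalize the layer outputs before applying the activation.
            layers.append(nn.LayerNorm(dims[i + 1]))
        layers.append(activation())
    return nn.Sequential(*layers)

mlp = build_mlp([8, 32, 32], use_layernorm=True, use_bias=False)
```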
@@ -216,37 +236,47 @@ class CNNEncoderConfig(ModelConfig):
    Attributes:
        input_dims: The input dimension of the network. These must be given in the
            form of `(width, height, channels)`.
        filter_specifiers: A list of lists, where each element of an inner list
        cnn_filter_specifiers: A list of lists, where each element of an inner list
Trying to choose more accurate terms for the CNN configs.
@@ -0,0 +1,129 @@
import unittest
Mostly the exact same file as before, but now unified between tf and torch.
Thanks!
@@ -0,0 +1,116 @@
import unittest
Mostly the exact same file as before, but now unified between tf and torch.
@@ -0,0 +1,65 @@
import abc
Moved this from another file. It is now analogous to the already existing torch/base.py. Both define the DL-specific base RLlib Model classes, TfModel and TorchModel.
@@ -63,12 +71,75 @@ def _forward(self, inputs: NestedDict) -> NestedDict:
    )


class TfCNNEncoder(TfModel, Encoder):
A new class (tf did not have a CNN encoder before).
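A rough sketch of what a Keras-based CNN encoder can look like (assumed helper name and a `[num_filters, kernel, stride]` filter-specifier format; not the PR's actual TfCNNEncoder, which is driven by a CNNEncoderConfig):

```python
# Illustrative sketch only: build a Keras CNN encoder from filter specifiers.
import tensorflow as tf

def build_cnn_encoder(cnn_filter_specifiers, output_dim):
    layers = []
    for num_filters, kernel, stride in cnn_filter_specifiers:
        layers.append(
            tf.keras.layers.Conv2D(num_filters, kernel, strides=stride, activation="relu")
        )
    layers.append(tf.keras.layers.Flatten())
    # Final dense layer producing the encoder's latent output.
    layers.append(tf.keras.layers.Dense(output_dim))
    return tf.keras.Sequential(layers)

encoder = build_cnn_encoder([[16, 4, 2], [32, 4, 2]], output_dim=64)
```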
self.network = tf.keras.Sequential(layers)

def __call__(self, inputs):
def call(self, inputs, **kwargs):
I don't think we should touch `__call__` directly, but override `call` instead. At least that's what the keras docs say.
Yeah, this turned out to be a problem. Now that we override `call`, we ran into a naming conflict with keras models, which have their own `input_spec`. We therefore decided to rename our properties to `input_specs` (plural) and `output_specs`, which alleviated this issue. Should be ok now.
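A minimal sketch of the resulting pattern (the `input_specs` property and class below are illustrative, not RLlib's actual spec API): subclass `tf.keras.Model`, override `call` rather than `__call__`, and keep the spec property plural to avoid Keras' built-in `input_spec`.

```python
import tensorflow as tf

class SimpleEncoder(tf.keras.Model):
    def __init__(self, hidden_dim: int):
        super().__init__()
        self.dense = tf.keras.layers.Dense(hidden_dim, activation="relu")

    @property
    def input_specs(self):
        # Plural name avoids clashing with Keras' own `input_spec` attribute.
        return {"obs": tf.TensorSpec(shape=(None, 4), dtype=tf.float32)}

    def call(self, inputs, **kwargs):
        return self.dense(inputs)

model = SimpleEncoder(16)
out = model(tf.zeros([2, 4]))  # Keras' __call__ dispatches to call().
```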
)

@override(Model)
def get_input_spec(self) -> Union[Spec, None]:
    return SpecDict(
        {
            SampleBatch.OBS: TorchTensorSpec("b, h", h=self.config.input_dims[0]),
            SampleBatch.OBS: TorchTensorSpec("b, d", d=self.config.input_dims[0]),
I renamed these "hanging right-side dims" to `d`. I feel like `h` should be used only for image height and LSTM internal h-state.
d=dims
layers.append(nn.Flatten())

# Add a final linear layer to make sure that the outputs have the correct
# dimensionality.
layers.append(
    nn.Linear(
        int(cnn.output_width) * int(cnn.output_height), config.output_dims[0]
        int(cnn.output_width) * int(cnn.output_height) * int(cnn.output_depth),
We forgot the output depth here.
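A tiny, self-contained illustration of why the depth term matters (shapes invented for the example): flattening a conv output of shape (depth, height, width) yields depth * height * width features, so the final linear layer's input size must include the depth.

```python
import torch
import torch.nn as nn

conv_out = torch.zeros(1, 8, 5, 5)      # (batch, depth, height, width)
flat = nn.Flatten()(conv_out)           # -> shape (1, 8 * 5 * 5) = (1, 200)
final = nn.Linear(8 * 5 * 5, 16)        # in_features must include the depth
assert final(flat).shape == (1, 16)
```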
@@ -104,10 +113,10 @@ def get_input_spec(self) -> Union[Spec, None]:
    return SpecDict(
        {
            SampleBatch.OBS: TorchTensorSpec(
                "b, w, h, d",
                "b, w, h, c",
c=channels
""" | ||
def _validate(self, framework: str = "torch"): | ||
super()._validate(framework) | ||
if self.output_dims is None: |
You say in the docstring that this may be None. This conflicts with this check.
Fixed. We now allow `output_dims=None` for any MLP config, since for MLPs the last hidden dim can simply be used as the output dim.
This is useful for homogeneous dense nets, where all layers have the same activation and no special output-layer logic is needed.
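A minimal sketch of that fallback (field names assumed; not RLlib's actual config class): when `output_dims` is None, the last hidden dim acts as the output dim and no extra output layer is appended.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class SimpleMLPConfig:
    hidden_layer_dims: Tuple[int, ...] = (256, 256)
    output_dims: Optional[Tuple[int, ...]] = None

    @property
    def effective_output_dims(self) -> Tuple[int, ...]:
        if self.output_dims is None:
            # Homogeneous dense net: the last hidden layer is the output layer.
            return (self.hidden_layer_dims[-1],)
        return self.output_dims

assert SimpleMLPConfig().effective_output_dims == (256,)
assert SimpleMLPConfig(output_dims=(8,)).effective_output_dims == (8,)
```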
rllib/core/models/configs.py (Outdated)
hidden_layer_use_layernorm=False,
output_dims=None,  # maybe None or a 1D tensor
)
model = config.build()
build should always take in a framework!
great catch! fixed.
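A self-contained sketch of the calling convention being asked for here (not RLlib's actual ModelConfig; the config class and dispatch below are made up): the config stays framework-agnostic and only picks an implementation when `build(framework=...)` is called.

```python
from dataclasses import dataclass

@dataclass
class TinyMLPConfig:
    input_dim: int = 4
    output_dim: int = 2

    def build(self, framework: str):
        # Dispatch to a framework-specific implementation only at build time.
        if framework == "torch":
            import torch.nn as nn
            return nn.Linear(self.input_dim, self.output_dim)
        elif framework == "tf2":
            import tensorflow as tf
            return tf.keras.layers.Dense(self.output_dim)
        raise ValueError(f"Unknown framework: {framework}")

model = TinyMLPConfig().build(framework="torch")
```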
rllib/core/models/configs.py (Outdated)
output_activation="tanh",
use_bias=False,
)
model = config.build()
Same here!
fixed
tf.keras.layers.Dense(config.output_dims[0], activation=output_activation),
)

self.net = tf.keras.Sequential(layers)
Nit: Can we unify the constructor a little more? So that the order of things and the comments are the same between Tf and Torch?
@kouroshHakha I've been thinking about our specs and think we should pull the functionality of TfTensorSpec and TorchTensorSpec into TensorSpec and give TensorSpec a framework kwarg.
If the kwarg is None, simply don't enforce the tensor framework and check based on the incoming tensor.
That way we could unify our specs:
-> Saves many LOCs
-> Reduces the burden of maintenance and of always having to check for equality between frameworks
Especially when rolling this out over RLlib and writing many new RLModules and possibly Models over time, the specs would become less of a burden.
Wdyt?
This would get us another step closer to having 99% of the RLModule and Model code in base classes and having the framework-specific classes only be separated by a `self.framework` attribute.
We are almost there for the PPORLModule.
Love the idea. Let's do this!
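A rough sketch of the unified-spec idea (not the actual RLlib TensorSpec API; the class and method names below are made up): one TensorSpec class with an optional `framework` kwarg that, when None, validates whatever tensor type comes in.

```python
from typing import Optional, Tuple

import numpy as np

class UnifiedTensorSpec:
    def __init__(self, shape: Tuple[int, ...], framework: Optional[str] = None):
        self.shape = shape
        self.framework = framework

    def validate(self, tensor) -> None:
        if self.framework == "torch":
            import torch
            assert isinstance(tensor, torch.Tensor), "expected a torch.Tensor"
        elif self.framework == "tf2":
            import tensorflow as tf
            assert tf.is_tensor(tensor), "expected a tf tensor"
        # framework=None: don't enforce a tensor type, just check the shape.
        assert tuple(tensor.shape) == self.shape, (
            f"shape mismatch: {tuple(tensor.shape)} vs {self.shape}"
        )

UnifiedTensorSpec(shape=(2, 4)).validate(np.zeros((2, 4)))
```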
Unified the c'tors a little more: comments, structure, etc.
):
    """Initialize a TorchCNN object.
Can we unify the TorchCNN and TfCNN docstrings?
What is now shown here as "Attributes" should be "Args", right?
I see similar stuff happening with TorchMLP!
Most of what is now listed in the class docstrings as attributes are not actually attributes but constructor args.
completely unified these now.
You are right, the primitives should have an `Args` list, not `Attributes`. Fixed.
Thanks for the massive amount of small cleanups here!
Just a few nits and a question.
rllib/core/models/utils.py (Outdated)
Should we move this to test_utils.py?
done
rllib/core/models/utils.py (Outdated)
else:
    inputs[key] = None
else:
    inputs = model.input_specs.fill(self.random_fill_input_value)
haha, I kinda forgot about this fill thing. What a nice use-case man :)
rllib/core/models/utils.py (Outdated)
# Bring model into a reproducible, comparable state (so we can compare
# computations across frameworks). Use only a value-sequence of len=1 here
# as it could possibly be that the layers are stored in different order
# across the different frameworks.
If we can't reliably match the order across frameworks, does it make sense for us to support a sequence of values in `_set_to_dummy_weights`?
Yeah, good point. Probably not. But one might also want to just make a network repeatably compute the same outputs across network instantiations (over time), not necessarily across different frameworks.
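A minimal sketch of the dummy-weights idea in torch (the function name and layer sizes are invented): setting every parameter to the same constant makes a network's outputs repeatable across instantiations, independent of parameter ordering.

```python
import torch
import torch.nn as nn

def set_to_dummy_weights(model: nn.Module, value: float = 0.01) -> None:
    with torch.no_grad():
        for p in model.parameters():
            p.fill_(value)

net_a = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))
net_b = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))
set_to_dummy_weights(net_a)
set_to_dummy_weights(net_b)
x = torch.ones(1, 4)
assert torch.allclose(net_a(x), net_b(x))  # identical outputs despite random init
```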
rllib/core/models/utils.py (Outdated)
main_key = next(iter(self.models.keys()))
# Compare number of trainable and non-trainable params between all
# frameworks.
for c in self.param_counts.values():
Can we separate out the param-count and output checkers into two different functions, each with its own control over the accepted tolerance?
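A sketch of that separation (function names and signatures are made up): one checker for exact parameter counts and one for outputs with its own tolerance controls.

```python
import numpy as np

def check_param_counts(param_counts: dict) -> None:
    """Asserts that all frameworks report the exact same parameter count."""
    counts = list(param_counts.values())
    assert all(c == counts[0] for c in counts), f"param counts differ: {param_counts}"

def check_outputs(outputs: dict, rtol: float = 1e-5, atol: float = 1e-5) -> None:
    """Asserts that all frameworks' outputs match within the given tolerances."""
    ref_key, *other_keys = outputs.keys()
    for key in other_keys:
        np.testing.assert_allclose(
            np.asarray(outputs[key]), np.asarray(outputs[ref_key]), rtol=rtol, atol=atol
        )

check_param_counts({"tf2": 114, "torch": 114})
check_outputs({"tf2": [0.1, 0.2], "torch": [0.1, 0.2]})
```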
rllib/core/models/tf/mlp.py (Outdated)
@@ -27,11 +27,11 @@ def __init__(self, config: MLPHeadConfig) -> None:
    )

@override(Model)
def get_input_spec(self) -> Union[Spec, None]:
def get_input_specs(self) -> Union[Spec, None]:
Nit: Optional[Spec]
fixed
rllib/core/models/tf/mlp.py (Outdated)
)

@override(Model)
def get_input_spec(self) -> Union[Spec, None]:
def get_input_specs(self) -> Union[Spec, None]:
same here
all fixed
@@ -104,3 +106,12 @@ def get_num_parameters(self) -> Tuple[int, int]:
        num_trainable_params,
        num_all_params - num_trainable_params,
    )

@override(Model)
def _set_to_dummy_weights(self, value_sequence=(-0.02, -0.01, 0.01, 0.02)):
What is the mechanism to ensure the order of parameters is the same between frameworks?
I'm not sure either. For torch, it's neither the order in which you define Parameter properties in the ctor, nor alphabetical. Didn't have time to investigate.
But using the same value for all trainable and non-trainable parameters should result in the same behavior between tf and torch. We are relying on this assumption, right?
rllib/core/models/torch/mlp.py (Outdated)
@@ -28,11 +28,11 @@ def __init__(self, config: MLPHeadConfig) -> None:
    )

@override(Model)
def get_input_spec(self) -> Union[Spec, None]:
def get_input_specs(self) -> Union[Spec, None]:
same
rllib/core/models/torch/mlp.py (Outdated)
)

self.log_std = torch.nn.Parameter(
    torch.as_tensor([0.0] * self._half_output_dim)
)

@override(Model)
def get_input_spec(self) -> Union[Spec, None]:
def get_input_specs(self) -> Union[Spec, None]:
same: Optional[Spec]
Let's merge upon addressing the questions.
Also, the tests are failing.
DreamerV3: Catalog enhancements (MLP/CNN encoders/heads completed and unified across DL frameworks).
Why are these changes needed?
Related issue number
Checks
I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
I've run scripts/format.sh to lint the changes in this PR.
If I added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.