
[RLlib] DreamerV3: Catalog enhancements (MLP/CNN encoders/heads completed and unified across DL frameworks). #33967

Merged

Conversation

sven1977 (Contributor) commented Mar 31, 2023

DreamerV3: Catalog enhancements (MLP/CNN encoders/heads completed and unified across DL frameworks).

  • MLP/CNN heads/encoders catalog completed
  • added use_bias option
  • added use_layernorm option
  • unified across DL frameworks
  • more tests (among other things: verifying that corresponding torch and tf2 models have the exact same number of trainable and non-trainable params and compute the exact same output values given equal weights and inputs)
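The cross-framework parity check described in the last bullet can be sketched framework-agnostically. This is a minimal illustration, not the actual RLlib test code: parameter tensors are represented as `(shape, trainable)` pairs, and the key point is that torch and tf2 kernels may be stored transposed yet must still yield identical totals.

```python
from math import prod

def count_params(param_shapes):
    """Return (num_trainable, num_non_trainable) from (shape, trainable) pairs."""
    trainable = sum(prod(s) for s, t in param_shapes if t)
    non_trainable = sum(prod(s) for s, t in param_shapes if not t)
    return trainable, non_trainable

# Example: a Dense(4 -> 8) layer with bias plus a LayerNorm(8) (gamma + beta).
torch_params = [((4, 8), True), ((8,), True), ((8,), True), ((8,), True)]
tf_params = [((8, 4), True), ((8,), True), ((8,), True), ((8,), True)]
# Same totals even though kernel shapes are transposed between frameworks.
assert count_params(torch_params) == count_params(tf_params) == (56, 0)
```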

Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit (using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: sven1977 <[email protected]>
@@ -1861,6 +1861,27 @@ py_test(
srcs = ["core/models/tests/test_catalog.py"]
)

py_test(
sven1977 (author):

Unified the tests between tf and torch to be able to compare the exact number of model parameters (asserting that the architectures are the same).

@@ -20,16 +20,25 @@
CRITIC: str = "critic"


def _raise_not_decorated_exception(class_and_method, input_or_output):
sven1977 (author):

Moved this here. This was duplicated code in tf and torch versions.

Reviewer:

Thanks!

@@ -193,6 +202,11 @@ def _forward(self, input_dict: NestedDict, **kwargs) -> NestedDict:
"""
raise NotImplementedError

@abc.abstractmethod
sven1977 (author):

Added this convenience method to the top-level API.

Attributes:
hidden_layer_dims: The sizes of the hidden layers.
hidden_layer_activation: The activation function to use after each layer (
except for the output).
hidden_layer_use_layernorm: Whether to insert a LayerNorm functionality
sven1977 (author):

Added to all primitives (MLP and CNN):

  • option to switch on layernorm'ing in between layers
  • use bias or not
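The two new primitive options can be sketched as a framework-neutral MLP layer-stack builder. This is a hypothetical helper (the real RLlib configs build actual torch/keras layers); it only shows where `use_bias` and `use_layernorm` slot into the stack:

```python
def build_mlp_spec(input_dim, hidden_dims, activation="relu",
                   use_bias=True, use_layernorm=False):
    """Return a framework-neutral layer list for an MLP stack."""
    layers, in_dim = [], input_dim
    for dim in hidden_dims:
        layers.append(("dense", in_dim, dim, use_bias))
        if use_layernorm:
            # LayerNorm is inserted between the dense layer and its activation.
            layers.append(("layernorm", dim))
        layers.append(("activation", activation))
        in_dim = dim
    return layers

spec = build_mlp_spec(4, [8, 8], use_layernorm=True, use_bias=False)
assert spec[0] == ("dense", 4, 8, False)
assert ("layernorm", 8) in spec
```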

@@ -216,37 +236,47 @@ class CNNEncoderConfig(ModelConfig):
Attributes:
input_dims: The input dimension of the network. These must be given in the
form of `(width, height, channels)`.
filter_specifiers: A list of lists, where each element of an inner list
cnn_filter_specifiers: A list of lists, where each element of an inner list
sven1977 (author):

Trying to choose more accurate terms for the CNN configs.

@@ -0,0 +1,129 @@
import unittest
sven1977 (author):

Mostly the exact same file as before, but now unified between tf and torch.

Reviewer:

Thanks!

@@ -0,0 +1,116 @@
import unittest
sven1977 (author):

Mostly the exact same file as before, but now unified between tf and torch.

@@ -0,0 +1,65 @@
import abc
sven1977 (author):

Moved this from another file. Now this is analogous to the already existing torch/base.py.
Both define the DL-specific base RLlib Model classes TfModel and TorchModel.

@@ -63,12 +71,75 @@ def _forward(self, inputs: NestedDict) -> NestedDict:
)


class TfCNNEncoder(TfModel, Encoder):
sven1977 (author):

A new class (tf did not have CNN encoder before).

self.network = tf.keras.Sequential(layers)

def __call__(self, inputs):
def call(self, inputs, **kwargs):
sven1977 (author):

I don't think we should touch __call__ directly, but override call instead. At least that's what keras docs say.
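The keras convention referenced here can be illustrated with a pure-Python sketch (not the actual keras implementation): the framework owns `__call__` as a wrapper that does bookkeeping (building weights, casting, masking in real keras) and then delegates to `call`, which is the method subclasses are meant to override.

```python
class KerasStyleLayer:
    """Sketch of why keras users override call(), not __call__()."""

    def __call__(self, inputs, **kwargs):
        # The framework-owned wrapper runs its bookkeeping first...
        self._called = True
        # ...and only then delegates to the user-defined call().
        return self.call(inputs, **kwargs)

    def call(self, inputs, **kwargs):  # subclasses override this
        raise NotImplementedError


class Doubler(KerasStyleLayer):
    def call(self, inputs, **kwargs):
        return [2 * x for x in inputs]


layer = Doubler()
assert layer([1, 2, 3]) == [2, 4, 6]
assert layer._called  # the wrapper logic still ran
```

Overriding `__call__` directly would silently skip the wrapper's bookkeeping, which is exactly what the keras docs warn against.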

sven1977 (author):

Yeah, this turned out to be a problem. Now that we override call, we ran into a naming conflict with keras models, which have their own input_spec. We therefore renamed our properties to input_specs (plural) and output_specs, which resolved the issue. Should be ok now.

)

@override(Model)
def get_input_spec(self) -> Union[Spec, None]:
return SpecDict(
{
SampleBatch.OBS: TorchTensorSpec("b, h", h=self.config.input_dims[0]),
SampleBatch.OBS: TorchTensorSpec("b, d", d=self.config.input_dims[0]),
sven1977 (author):

I renamed these "hanging right-side dims" to d. I feel like h should be used only for image height and LSTM internal h-state.

sven1977 (author):

d=dims
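The named-dim spec strings discussed here ("b, d", "b, w, h, c") can be sketched with a hypothetical checker (the real RLlib TensorSpec classes are richer than this): letters name dims, kwargs pin specific sizes, and unpinned dims (like the batch dim b) stay free.

```python
def check_spec(spec: str, shape: tuple, **dims) -> bool:
    """Match a named-dim spec like 'b, d' against a concrete shape."""
    names = [n.strip() for n in spec.split(",")]
    if len(names) != len(shape):
        return False
    # Dims pinned via kwargs must match exactly; unpinned dims are free.
    return all(dims.get(n, size) == size for n, size in zip(names, shape))

assert check_spec("b, d", (32, 16), d=16)       # batch free, d pinned
assert not check_spec("b, d", (32, 16), d=8)
assert check_spec("b, w, h, c", (1, 64, 64, 3), c=3)
```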

layers.append(nn.Flatten())

# Add a final linear layer to make sure that the outputs have the correct
# dimensionality.
layers.append(
nn.Linear(
int(cnn.output_width) * int(cnn.output_height), config.output_dims[0]
int(cnn.output_width) * int(cnn.output_height) * int(cnn.output_depth),
sven1977 (author):

We forgot the output depth here.
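The bug being fixed is simple arithmetic: after `nn.Flatten()`, the final linear layer's input size must cover all three CNN output dims, not just width and height. A quick sketch of the corrected computation:

```python
def flatten_dim(width: int, height: int, depth: int) -> int:
    """Input features of the final linear layer after flattening a CNN output."""
    return int(width) * int(height) * int(depth)

# E.g. a [4, 4, 32] CNN output flattens to 512 features:
assert flatten_dim(4, 4, 32) == 512
# The old computation (width * height only) under-counted by a factor of depth:
assert 4 * 4 != flatten_dim(4, 4, 32)
```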

@@ -104,10 +113,10 @@ def get_input_spec(self) -> Union[Spec, None]:
return SpecDict(
{
SampleBatch.OBS: TorchTensorSpec(
"b, w, h, d",
"b, w, h, c",
sven1977 (author):

c=channels

"""
def _validate(self, framework: str = "torch"):
super()._validate(framework)
if self.output_dims is None:
Reviewer:

You say in the docstring that this may be None. This conflicts with this check.

sven1977 (author):

fixed.

We now allow output_dims=None for any MLP config, since for MLPs the last hidden dim can simply serve as the output dim.
This is useful for homogeneous dense nets, where all layers share the same activation and no special output-layer logic is needed.
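The fallback logic can be sketched in a few lines (hypothetical helper name; the real check lives inside the MLP config's validation):

```python
def resolve_output_dims(hidden_layer_dims, output_dims=None):
    """If output_dims is None, fall back to the last hidden dim (MLPs only)."""
    if output_dims is not None:
        return tuple(output_dims)
    # Homogeneous dense net: the final hidden layer doubles as the output.
    return (hidden_layer_dims[-1],)

assert resolve_output_dims([256, 256]) == (256,)
assert resolve_output_dims([256, 256], (8,)) == (8,)
```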

hidden_layer_use_layernorm=False,
output_dims=None, # maybe None or a 1D tensor
)
model = config.build()
Reviewer:

build should always take in a framework!

sven1977 (author):

great catch! fixed.

output_activation="tanh",
use_bias=False,
)
model = config.build()
Reviewer:

Same here!

sven1977 (author):

fixed

tf.keras.layers.Dense(config.output_dims[0], activation=output_activation),
)

self.net = tf.keras.Sequential(layers)
Reviewer:

Nit: Can we unify the constructor a little more? So that the order of things and the comments are the same between Tf and Torch?

Reviewer:

@kouroshHakha I've been thinking about our specs and think we should pull the functionality of TfTensorSpec and TorchTensorSpec into TensorSpec and give TensorSpec a framework kwarg.

If the kwarg is None, simply don't enforce a tensor framework and check based on the incoming tensor.
That way we could unify our specs:
-> Saves many LOCs
-> Reduces the maintenance burden of constantly checking for equality between frameworks
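The proposed unification can be sketched as follows. This is a hypothetical design sketch of the idea, not the eventual RLlib implementation: one spec class, where framework=None skips the tensor-type check and validates any tensor-like object (anything exposing .shape).

```python
class TensorSpec:
    """Sketch: unified spec; framework=None means duck-typed validation."""

    # In a real implementation this would map e.g. "tf" -> tf.Tensor,
    # "torch" -> torch.Tensor. Kept empty here to stay framework-free.
    _framework_types = {}

    def __init__(self, shape, framework=None):
        self.shape, self.framework = tuple(shape), framework

    def validate(self, tensor) -> bool:
        if self.framework is not None:
            expected = self._framework_types[self.framework]
            if not isinstance(tensor, expected):
                raise TypeError(f"expected a {self.framework} tensor")
        # framework None: accept anything with a .shape attribute.
        return tuple(tensor.shape) == self.shape


class FakeTensor:  # stand-in for a framework tensor in this sketch
    def __init__(self, shape):
        self.shape = shape


spec = TensorSpec((32, 16))  # framework-agnostic
assert spec.validate(FakeTensor((32, 16)))
assert not spec.validate(FakeTensor((32, 8)))
```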

Reviewer:

Especially when rolling this out over RLlib and writing many new RLModules and possibly Models over time, the specs would become less of a burden.
Wdyt?

Reviewer:

This would get us another step closer to having 99% of RLModule and Model code in base classes and having the framework-specific classes only be separated by an attribute self.framework.
We are almost there for the PPORLModule.

sven1977 (author):

Love the idea. Let's do this!

sven1977 (author):

Unified the constructors a little more (comments, structure, etc.).

):
"""Initialize a TorchCNN object.
Reviewer:

Can we unify the TorchCNN and TfCNN docstrings?
What is now shown here as "Attributes" should be "Args", right?
I see the same issue with TorchMLP.
Most of what is now listed in the class docstrings as attributes are not actually attributes but args.

sven1977 (author):

completely unified these now.

sven1977 (author):

You are right, the primitives should have an Args list, not Attributes.
Fixed.

ArturNiederfahrenhorst (Contributor) left a comment:

Thanks for the massive amounts of small cleanups here!

kouroshHakha (Contributor) left a comment:

Just a few nits and a question.

Reviewer:

Should we move this to test_utils.py?

sven1977 (author):

done

else:
inputs[key] = None
else:
inputs = model.input_specs.fill(self.random_fill_input_value)
Reviewer:

haha, I kinda forgot about this fill thing. What a nice use-case man :)

# Bring model into a reproducible, comparable state (so we can compare
# computations across frameworks). Use only a value-sequence of len=1 here
# as it could possibly be that the layers are stored in different order
# across the different frameworks.
Reviewer:

if we can't reliably match the order across frameworks, does it make sense for us to support a sequence of values in _set_to_dummy_weights?

sven1977 (author):

Yeah, good point. Probably not. But one might also want to make a network repeatably compute the same outputs across network instantiations (over time), not necessarily across different frameworks.
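The trade-off being discussed can be sketched with a hypothetical version of the dummy-weight filler (the real _set_to_dummy_weights operates on actual framework tensors; here weight tensors are just flat lists of a given size):

```python
from itertools import cycle

def set_to_dummy_weights(weight_sizes, value_sequence=(-0.02, -0.01, 0.01, 0.02)):
    """Fill each weight tensor entirely with the next value in the sequence."""
    values = cycle(value_sequence)
    return [[next(values)] * n for n in weight_sizes]

# With len(value_sequence) == 1, parameter iteration order no longer matters:
a = set_to_dummy_weights([3, 2, 4], value_sequence=(0.01,))
b = set_to_dummy_weights([2, 4, 3], value_sequence=(0.01,))  # different order
assert all(v == 0.01 for t in a for v in t)
assert all(v == 0.01 for t in b for v in t)
```

With a longer sequence, which tensor receives which value depends on iteration order, which is framework-specific; hence the tests use a length-1 sequence for cross-framework comparison.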

main_key = next(iter(self.models.keys()))
# Compare number of trainable and non-trainable params between all
# frameworks.
for c in self.param_counts.values():
Reviewer:

can we separate out param count and output checker into two different functions with their own control over accepted tolerance?
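The suggested split can be sketched as two independent checkers (hypothetical helper names and signatures): param counts must match exactly, while output values only need to agree within a tolerance.

```python
def check_param_counts(param_counts: dict, main_key: str) -> bool:
    """Param counts must match exactly across all frameworks."""
    main = param_counts[main_key]
    return all(c == main for c in param_counts.values())

def check_outputs(outputs: dict, main_key: str, atol: float = 1e-5) -> bool:
    """Output values only need to match within an absolute tolerance."""
    main = outputs[main_key]
    return all(abs(o - m) <= atol
               for out in outputs.values() for o, m in zip(out, main))

counts = {"torch": (56, 0), "tf2": (56, 0)}
outs = {"torch": [0.1, 0.2], "tf2": [0.1000004, 0.2]}
assert check_param_counts(counts, "torch")
assert check_outputs(outs, "torch")
```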

@@ -27,11 +27,11 @@ def __init__(self, config: MLPHeadConfig) -> None:
)

@override(Model)
def get_input_spec(self) -> Union[Spec, None]:
def get_input_specs(self) -> Union[Spec, None]:
Reviewer:

Nit: Optional[Spec]
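For context (standard typing behavior, nothing RLlib-specific), the nit is purely cosmetic: Optional[X] is defined as exactly Union[X, None].

```python
from typing import Optional, Union

# Optional[Spec] and Union[Spec, None] are the same type; Optional is
# just the more idiomatic spelling for "may be None".
assert Optional[int] == Union[int, None]
```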

sven1977 (author):

fixed

)

@override(Model)
def get_input_spec(self) -> Union[Spec, None]:
def get_input_specs(self) -> Union[Spec, None]:
Reviewer:

same here

sven1977 (author):

all fixed

@@ -104,3 +106,12 @@ def get_num_parameters(self) -> Tuple[int, int]:
num_trainable_params,
num_all_params - num_trainable_params,
)

@override(Model)
def _set_to_dummy_weights(self, value_sequence=(-0.02, -0.01, 0.01, 0.02)):
Reviewer:

What is the mechanism to ensure the order of parameters are the same between frameworks?

sven1977 (author):

I'm not sure either. For torch, it's neither the order in which you define the Parameter properties in the constructor, nor alphabetical. Didn't have time to investigate.

Reviewer:

But using the same value for all trainable and non-trainable parameters should result in the same behavior between tf and torch. We are relying on this assumption, right?

@@ -28,11 +28,11 @@ def __init__(self, config: MLPHeadConfig) -> None:
)

@override(Model)
def get_input_spec(self) -> Union[Spec, None]:
def get_input_specs(self) -> Union[Spec, None]:
Reviewer:

same

)

self.log_std = torch.nn.Parameter(
torch.as_tensor([0.0] * self._half_output_dim)
)

@override(Model)
def get_input_spec(self) -> Union[Spec, None]:
def get_input_specs(self) -> Union[Spec, None]:
Reviewer:

same: Optional[Spec]

kouroshHakha (Contributor):

Let's merge upon addressing the questions.

kouroshHakha (Contributor):

Also tests are failing.

@sven1977 sven1977 added the tests-ok The tagger certifies test failures are unrelated and assumes personal liability. label Apr 6, 2023
@sven1977 sven1977 requested a review from a team as a code owner April 7, 2023 10:06
@sven1977 sven1977 merged commit 3f6e084 into ray-project:master Apr 7, 2023
elliottower pushed a commit to elliottower/ray that referenced this pull request Apr 22, 2023
…eted and unified accross DL frameworks). (ray-project#33967)

Signed-off-by: elliottower <[email protected]>
ProjectsByJackHe pushed a commit to ProjectsByJackHe/ray that referenced this pull request May 4, 2023
…eted and unified accross DL frameworks). (ray-project#33967)

Signed-off-by: Jack He <[email protected]>
@sven1977 sven1977 deleted the dreamer_v3_catalog_enhancements_01 branch May 5, 2023 20:04
Labels
tests-ok The tagger certifies test failures are unrelated and assumes personal liability.