[RLlib] New API stack: (Multi)RLModule overhaul vol 05 (deprecate Specs, SpecDict, TensorSpec). #47915

Merged
66 changes: 0 additions & 66 deletions doc/source/rllib/doc_code/rlmodule_guide.py
@@ -1,7 +1,5 @@
# flake8: noqa
from ray.rllib.utils.annotations import override
from ray.rllib.core.models.specs.typing import SpecType
from ray.rllib.core.models.specs.specs_base import TensorSpec


# __enabling-rlmodules-in-configs-begin__
@@ -224,70 +222,6 @@ def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]:
# __write-custom-sa-rlmodule-tf-end__


# __extend-spec-checking-single-level-begin__
class DiscreteBCTorchModule(TorchRLModule):
...

@override(TorchRLModule)
def input_specs_exploration(self) -> SpecType:
# Enforce that input nested dict to exploration method has a key "obs"
return ["obs"]

@override(TorchRLModule)
def output_specs_exploration(self) -> SpecType:
# Enforce that output nested dict from exploration method has a key
# "action_dist"
return ["action_dist"]


# __extend-spec-checking-single-level-end__


# __extend-spec-checking-nested-begin__
class DiscreteBCTorchModule(TorchRLModule):
...

@override(TorchRLModule)
def input_specs_exploration(self) -> SpecType:
# Enforce that input nested dict to exploration method has a key "obs"
# and within that key, it has a key "global" and "local". There should
# also be a key "action_mask"
return [("obs", "global"), ("obs", "local"), "action_mask"]


# __extend-spec-checking-nested-end__


# __extend-spec-checking-torch-specs-begin__
class DiscreteBCTorchModule(TorchRLModule):
...

@override(TorchRLModule)
def input_specs_exploration(self) -> SpecType:
# Enforce that input nested dict to exploration method has a key "obs"
# and its value is a torch.Tensor with shape (b, h) where b is the
# batch size (determined at run-time) and h is the hidden size
# (fixed at 10).
return {"obs": TensorSpec("b, h", h=10, framework="torch")}


# __extend-spec-checking-torch-specs-end__


# __extend-spec-checking-type-specs-begin__
class DiscreteBCTorchModule(TorchRLModule):
...

@override(TorchRLModule)
def output_specs_exploration(self) -> SpecType:
# Enforce that output nested dict from exploration method has a key
# "action_dist" and its value is a torch.distribution.Categorical
return {"action_dist": torch.distribution.Categorical}


# __extend-spec-checking-type-specs-end__


# __write-custom-multirlmodule-shared-enc-begin__
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleConfig, MultiRLModule
48 changes: 0 additions & 48 deletions doc/source/rllib/rllib-rlmodule.rst
@@ -255,54 +255,6 @@ When writing RL Modules, you need to use these fields to construct your model.
:end-before: __write-custom-sa-rlmodule-tf-end__


In :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule`, you can enforce checks for the existence of certain input or output keys in the data communicated into and out of RL Modules. This serves multiple purposes:

- It makes the I/O requirements of each method self-documenting.
- It surfaces failures quickly. If users extend the modules and implement something that doesn't match the assumptions of the I/O specs, the check reports the missing keys and their expected format. For example, an RLModule should always have an ``obs`` key in the input batch and an ``action_dist`` key in the output.

.. tab-set::

.. tab-item:: Single Level Keys

.. literalinclude:: doc_code/rlmodule_guide.py
:language: python
:start-after: __extend-spec-checking-single-level-begin__
:end-before: __extend-spec-checking-single-level-end__

.. tab-item:: Nested Keys

.. literalinclude:: doc_code/rlmodule_guide.py
:language: python
:start-after: __extend-spec-checking-nested-begin__
:end-before: __extend-spec-checking-nested-end__


.. tab-item:: TensorShape Spec

.. literalinclude:: doc_code/rlmodule_guide.py
:language: python
:start-after: __extend-spec-checking-torch-specs-begin__
:end-before: __extend-spec-checking-torch-specs-end__


.. tab-item:: Type Spec

.. literalinclude:: doc_code/rlmodule_guide.py
:language: python
:start-after: __extend-spec-checking-type-specs-begin__
:end-before: __extend-spec-checking-type-specs-end__

:py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` has two spec methods for each forward method, totaling six methods that you can override to describe the input and output specs of each forward method:

- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.input_specs_inference`
- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.output_specs_inference`
- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.input_specs_exploration`
- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.output_specs_exploration`
- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.input_specs_train`
- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.output_specs_train`

To learn more, see the `SpecType` documentation.
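
Because the literalinclude anchors above point at doc-code snippets deleted in this PR, here is a compact sketch of how the deprecated overrides fit together. It is assembled only from the removed snippets and applies only to Ray versions that still ship `ray.rllib.core.models.specs`:

from ray.rllib.core.models.specs.specs_base import TensorSpec
from ray.rllib.core.models.specs.typing import SpecType
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from ray.rllib.utils.annotations import override


class DiscreteBCTorchModule(TorchRLModule):
    ...

    @override(TorchRLModule)
    def input_specs_exploration(self) -> SpecType:
        # The exploration input batch must carry an "obs" key whose value is a
        # torch.Tensor of shape (b, 10), with the batch size b resolved at run time.
        return {"obs": TensorSpec("b, h", h=10, framework="torch")}

    @override(TorchRLModule)
    def output_specs_exploration(self) -> SpecType:
        # The exploration output must carry an "action_dist" key.
        return ["action_dist"]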


Writing Custom Multi-Agent RL Modules (Advanced)
------------------------------------------------
22 changes: 0 additions & 22 deletions rllib/BUILD
@@ -1456,28 +1456,6 @@ py_test(
srcs = ["core/models/tests/test_recurrent_encoders.py"]
)

# Specs
py_test(
name = "test_check_specs",
tags = ["team:rllib", "models"],
size = "small",
srcs = ["core/models/specs/tests/test_check_specs.py"]
)

py_test(
name = "test_tensor_spec",
tags = ["team:rllib", "models"],
size = "small",
srcs = ["core/models/specs/tests/test_tensor_spec.py"]
)

py_test(
name = "test_spec_dict",
tags = ["team:rllib", "models"],
size = "small",
srcs = ["core/models/specs/tests/test_spec_dict.py"]
)

# RLModule
py_test(
name = "test_torch_rl_module",
33 changes: 0 additions & 33 deletions rllib/algorithms/dqn/torch/dqn_rainbow_torch_noisy_net.py
@@ -7,11 +7,7 @@
from ray.rllib.algorithms.dqn.torch.torch_noisy_linear import NoisyLinear
from ray.rllib.core.columns import Columns
from ray.rllib.core.models.base import Encoder, ENCODER_OUT, Model
from ray.rllib.core.models.specs.specs_base import Spec
from ray.rllib.core.models.specs.specs_base import TensorSpec
from ray.rllib.core.models.specs.specs_dict import SpecDict
from ray.rllib.core.models.torch.base import TorchModel
from ray.rllib.core.models.torch.heads import auto_fold_unfold_time
from ray.rllib.models.utils import get_activation_fn, get_initializer_fn
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
@@ -52,26 +48,6 @@ def __init__(self, config: NoisyMLPEncoderConfig) -> None:
std_init=config.std_init,
)

@override(Model)
def get_input_specs(self) -> Optional[Spec]:
return SpecDict(
{
Columns.OBS: TensorSpec(
"b, d", d=self.config.input_dims[0], framework="torch"
),
}
)

@override(Model)
def get_output_specs(self) -> Optional[Spec]:
return SpecDict(
{
ENCODER_OUT: TensorSpec(
"b, d", d=self.config.output_dims[0], framework="torch"
),
}
)

@override(Model)
def _forward(self, inputs: dict, **kwargs) -> dict:
return {ENCODER_OUT: self.net(inputs[Columns.OBS])}
@@ -113,15 +89,6 @@ def __init__(self, config: NoisyMLPHeadConfig) -> None:
)

@override(Model)
def get_input_specs(self) -> Optional[Spec]:
return TensorSpec("b, d", d=self.config.input_dims[0], framework="torch")

@override(Model)
def get_output_specs(self) -> Optional[Spec]:
return TensorSpec("b, d", d=self.config.output_dims[0], framework="torch")

@override(Model)
@auto_fold_unfold_time("input_specs")
def _forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
return self.net(inputs)

38 changes: 0 additions & 38 deletions rllib/algorithms/dreamerv3/dreamerv3_rl_module.py
@@ -14,7 +14,6 @@
from ray.rllib.algorithms.dreamerv3.tf.models.dreamer_model import DreamerModel
from ray.rllib.algorithms.dreamerv3.tf.models.world_model import WorldModel
from ray.rllib.core.columns import Columns
from ray.rllib.core.models.specs.specs_dict import SpecDict
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.policy.eager_tf_policy import _convert_to_tf
from ray.rllib.utils.annotations import override
@@ -123,43 +122,6 @@ def get_initial_state(self) -> Dict:
# Use `DreamerModel`'s `get_initial_state` method.
return self.dreamer_model.get_initial_state()

@override(RLModule)
def input_specs_inference(self) -> SpecDict:
return [Columns.OBS, Columns.STATE_IN, "is_first"]

@override(RLModule)
def output_specs_inference(self) -> SpecDict:
return [Columns.ACTIONS, Columns.STATE_OUT]

@override(RLModule)
def input_specs_exploration(self):
return self.input_specs_inference()

@override(RLModule)
def output_specs_exploration(self) -> SpecDict:
return self.output_specs_inference()

@override(RLModule)
def input_specs_train(self) -> SpecDict:
return [Columns.OBS, Columns.ACTIONS, "is_first"]

@override(RLModule)
def output_specs_train(self) -> SpecDict:
return [
Collaborator comment: Can we leave these as a kind of comment/docstring in the forward method? Maybe we do this for all default algos so that we, new RLlib team members, and users always know what is returned (received). (A sketch of this idea follows the removed method below.)

"sampled_obs_symlog_BxT",
"obs_distribution_means_BxT",
"reward_logits_BxT",
"rewards_BxT",
"continue_distribution_BxT",
"continues_BxT",
# Sampled, discrete posterior z-states (t1 to T).
"z_posterior_states_BxT",
"z_posterior_probs_BxT",
"z_prior_probs_BxT",
# Deterministic, continuous h-states (t1 to T).
"h_states_BxT",
]
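
One hypothetical shape for the docstring suggested in the collaborator comment above (the class placement, wording, and listed keys are illustrative only and not part of this PR):

from typing import Any, Dict

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import override


class DreamerV3RLModule(RLModule):  # hypothetical sketch; the real class lives in dreamerv3_rl_module.py
    @override(RLModule)
    def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Training forward pass.

        Expects `batch` to contain Columns.OBS, Columns.ACTIONS, and "is_first".
        Returns a dict with the keys formerly listed in `output_specs_train`,
        e.g. "sampled_obs_symlog_BxT", "reward_logits_BxT", "rewards_BxT",
        "continues_BxT", "z_posterior_states_BxT", "z_prior_probs_BxT", and
        "h_states_BxT".
        """
        ...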

@override(RLModule)
def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]:
# Call the Dreamer-Model's forward_inference method and return a dict.
26 changes: 0 additions & 26 deletions rllib/algorithms/marwil/marwil_rl_module.py
@@ -1,7 +1,5 @@
import abc

from ray.rllib.core.columns import Columns
from ray.rllib.core.models.specs.specs_dict import SpecDict
from ray.rllib.core.rl_module import RLModule
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.utils.annotations import override
@@ -25,27 +23,3 @@ def get_initial_state(self) -> dict:
return self.encoder.get_initial_state()
else:
return {}

@override(RLModule)
def input_specs_inference(self) -> SpecDict:
return [Columns.OBS]

@override(RLModule)
def output_specs_inference(self) -> SpecDict:
return [Columns.ACTION_DIST_INPUTS]

@override(RLModule)
def input_specs_exploration(self):
return self.input_specs_inference()

@override(RLModule)
def output_specs_exploration(self) -> SpecDict:
return self.output_specs_inference()

@override(RLModule)
def input_specs_train(self) -> SpecDict:
return self.input_specs_exploration()

@override(RLModule)
def output_specs_train(self) -> SpecDict:
return [Columns.ACTION_DIST_INPUTS]
26 changes: 0 additions & 26 deletions rllib/algorithms/ppo/ppo_rl_module.py
@@ -5,9 +5,7 @@
import abc
from typing import List

from ray.rllib.core.columns import Columns
from ray.rllib.core.models.configs import RecurrentEncoderConfig
from ray.rllib.core.models.specs.specs_dict import SpecDict
from ray.rllib.core.rl_module.apis import InferenceOnlyAPI, ValueFunctionAPI
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import (
@@ -50,30 +48,6 @@ def get_initial_state(self) -> dict:
else:
return {}

@override(RLModule)
def input_specs_inference(self) -> SpecDict:
return [Columns.OBS]

@override(RLModule)
def output_specs_inference(self) -> SpecDict:
return [Columns.ACTION_DIST_INPUTS]

@override(RLModule)
def input_specs_exploration(self):
return self.input_specs_inference()

@override(RLModule)
def output_specs_exploration(self) -> SpecDict:
return self.output_specs_inference()

@override(RLModule)
def input_specs_train(self) -> SpecDict:
return self.input_specs_exploration()

@override(RLModule)
def output_specs_train(self) -> SpecDict:
return [Columns.ACTION_DIST_INPUTS]

@OverrideToImplementCustomLogic_CallToSuperRecommended
@override(InferenceOnlyAPI)
def get_non_inference_attributes(self) -> List[str]:
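
Applied to the default PPO and MARWIL modules in this diff, the same docstring idea would collapse the six removed spec methods into two documented facts: Columns.OBS in, Columns.ACTION_DIST_INPUTS out. A hypothetical sketch (class name for illustration only, not part of this PR):

from typing import Any, Dict

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import override


class DefaultPPOTorchRLModule(RLModule):  # hypothetical name for illustration
    @override(RLModule)
    def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Inference/exploration forward pass.

        Expects `batch` to contain Columns.OBS and returns a dict with
        Columns.ACTION_DIST_INPUTS (the keys formerly enforced by the removed
        input_specs_* / output_specs_* methods).
        """
        ...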