Skip to content

Commit

Permalink
[RLlib] New API stack: (Multi)RLModule overhaul vol 05 (deprecate Spe…
Browse files Browse the repository at this point in the history
…cs, SpecDict, TensorSpec). (#47915)
  • Loading branch information
sven1977 authored Oct 9, 2024
1 parent 5e4b1bc commit 616eef8
Show file tree
Hide file tree
Showing 30 changed files with 77 additions and 2,698 deletions.
66 changes: 0 additions & 66 deletions doc/source/rllib/doc_code/rlmodule_guide.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# flake8: noqa
from ray.rllib.utils.annotations import override
from ray.rllib.core.models.specs.typing import SpecType
from ray.rllib.core.models.specs.specs_base import TensorSpec


# __enabling-rlmodules-in-configs-begin__
Expand Down Expand Up @@ -224,70 +222,6 @@ def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]:
# __write-custom-sa-rlmodule-tf-end__


# __extend-spec-checking-single-level-begin__
class DiscreteBCTorchModule(TorchRLModule):
...

@override(TorchRLModule)
def input_specs_exploration(self) -> SpecType:
# Enforce that input nested dict to exploration method has a key "obs"
return ["obs"]

@override(TorchRLModule)
def output_specs_exploration(self) -> SpecType:
# Enforce that output nested dict from exploration method has a key
# "action_dist"
return ["action_dist"]


# __extend-spec-checking-single-level-end__


# __extend-spec-checking-nested-begin__
class DiscreteBCTorchModule(TorchRLModule):
...

@override(TorchRLModule)
def input_specs_exploration(self) -> SpecType:
# Enforce that input nested dict to exploration method has a key "obs"
# and within that key, it has a key "global" and "local". There should
# also be a key "action_mask"
return [("obs", "global"), ("obs", "local"), "action_mask"]


# __extend-spec-checking-nested-end__


# __extend-spec-checking-torch-specs-begin__
class DiscreteBCTorchModule(TorchRLModule):
...

@override(TorchRLModule)
def input_specs_exploration(self) -> SpecType:
# Enforce that input nested dict to exploration method has a key "obs"
# and its value is a torch.Tensor with shape (b, h) where b is the
# batch size (determined at run-time) and h is the hidden size
# (fixed at 10).
return {"obs": TensorSpec("b, h", h=10, framework="torch")}


# __extend-spec-checking-torch-specs-end__


# __extend-spec-checking-type-specs-begin__
class DiscreteBCTorchModule(TorchRLModule):
...

@override(TorchRLModule)
def output_specs_exploration(self) -> SpecType:
# Enforce that output nested dict from exploration method has a key
# "action_dist" and its value is a torch.distribution.Categorical
return {"action_dist": torch.distribution.Categorical}


# __extend-spec-checking-type-specs-end__


# __write-custom-multirlmodule-shared-enc-begin__
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleConfig, MultiRLModule
Expand Down
48 changes: 0 additions & 48 deletions doc/source/rllib/rllib-rlmodule.rst
Original file line number Diff line number Diff line change
Expand Up @@ -255,54 +255,6 @@ When writing RL Modules, you need to use these fields to construct your model.
:end-before: __write-custom-sa-rlmodule-tf-end__


In :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` you can enforce checks for the existence of certain input or output keys in the data that is communicated into and out of RL Modules. This serves multiple purposes:

- For the I/O requirement of each method to be self-documenting.
- For failures to happen quickly. If users extend the modules and implement something that does not match the assumptions of the I/O specs, the check reports missing keys and their expected format. For example, RLModule should always have an ``obs`` key in the input batch and an ``action_dist`` key in the output.

.. tab-set::

.. tab-item:: Single Level Keys

.. literalinclude:: doc_code/rlmodule_guide.py
:language: python
:start-after: __extend-spec-checking-single-level-begin__
:end-before: __extend-spec-checking-single-level-end__

.. tab-item:: Nested Keys

.. literalinclude:: doc_code/rlmodule_guide.py
:language: python
:start-after: __extend-spec-checking-nested-begin__
:end-before: __extend-spec-checking-nested-end__


.. tab-item:: TensorShape Spec

.. literalinclude:: doc_code/rlmodule_guide.py
:language: python
:start-after: __extend-spec-checking-torch-specs-begin__
:end-before: __extend-spec-checking-torch-specs-end__


.. tab-item:: Type Spec

.. literalinclude:: doc_code/rlmodule_guide.py
:language: python
:start-after: __extend-spec-checking-type-specs-begin__
:end-before: __extend-spec-checking-type-specs-end__

:py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` has two methods for each forward method, totaling 6 methods that can be overridden to describe the specs of the input and output of each method:

- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.input_specs_inference`
- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.output_specs_inference`
- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.input_specs_exploration`
- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.output_specs_exploration`
- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.input_specs_train`
- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.output_specs_train`

To learn more, see the `SpecType` documentation.


Writing Custom Multi-Agent RL Modules (Advanced)
------------------------------------------------
Expand Down
22 changes: 0 additions & 22 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1456,28 +1456,6 @@ py_test(
srcs = ["core/models/tests/test_recurrent_encoders.py"]
)

# Specs
py_test(
name = "test_check_specs",
tags = ["team:rllib", "models"],
size = "small",
srcs = ["core/models/specs/tests/test_check_specs.py"]
)

py_test(
name = "test_tensor_spec",
tags = ["team:rllib", "models"],
size = "small",
srcs = ["core/models/specs/tests/test_tensor_spec.py"]
)

py_test(
name = "test_spec_dict",
tags = ["team:rllib", "models"],
size = "small",
srcs = ["core/models/specs/tests/test_spec_dict.py"]
)

# RLModule
py_test(
name = "test_torch_rl_module",
Expand Down
33 changes: 0 additions & 33 deletions rllib/algorithms/dqn/torch/dqn_rainbow_torch_noisy_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@
from ray.rllib.algorithms.dqn.torch.torch_noisy_linear import NoisyLinear
from ray.rllib.core.columns import Columns
from ray.rllib.core.models.base import Encoder, ENCODER_OUT, Model
from ray.rllib.core.models.specs.specs_base import Spec
from ray.rllib.core.models.specs.specs_base import TensorSpec
from ray.rllib.core.models.specs.specs_dict import SpecDict
from ray.rllib.core.models.torch.base import TorchModel
from ray.rllib.core.models.torch.heads import auto_fold_unfold_time
from ray.rllib.models.utils import get_activation_fn, get_initializer_fn
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
Expand Down Expand Up @@ -52,26 +48,6 @@ def __init__(self, config: NoisyMLPEncoderConfig) -> None:
std_init=config.std_init,
)

@override(Model)
def get_input_specs(self) -> Optional[Spec]:
return SpecDict(
{
Columns.OBS: TensorSpec(
"b, d", d=self.config.input_dims[0], framework="torch"
),
}
)

@override(Model)
def get_output_specs(self) -> Optional[Spec]:
return SpecDict(
{
ENCODER_OUT: TensorSpec(
"b, d", d=self.config.output_dims[0], framework="torch"
),
}
)

@override(Model)
def _forward(self, inputs: dict, **kwargs) -> dict:
return {ENCODER_OUT: self.net(inputs[Columns.OBS])}
Expand Down Expand Up @@ -113,15 +89,6 @@ def __init__(self, config: NoisyMLPHeadConfig) -> None:
)

@override(Model)
def get_input_specs(self) -> Optional[Spec]:
return TensorSpec("b, d", d=self.config.input_dims[0], framework="torch")

@override(Model)
def get_output_specs(self) -> Optional[Spec]:
return TensorSpec("b, d", d=self.config.output_dims[0], framework="torch")

@override(Model)
@auto_fold_unfold_time("input_specs")
def _forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
return self.net(inputs)

Expand Down
38 changes: 0 additions & 38 deletions rllib/algorithms/dreamerv3/dreamerv3_rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from ray.rllib.algorithms.dreamerv3.tf.models.dreamer_model import DreamerModel
from ray.rllib.algorithms.dreamerv3.tf.models.world_model import WorldModel
from ray.rllib.core.columns import Columns
from ray.rllib.core.models.specs.specs_dict import SpecDict
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.policy.eager_tf_policy import _convert_to_tf
from ray.rllib.utils.annotations import override
Expand Down Expand Up @@ -123,43 +122,6 @@ def get_initial_state(self) -> Dict:
# Use `DreamerModel`'s `get_initial_state` method.
return self.dreamer_model.get_initial_state()

@override(RLModule)
def input_specs_inference(self) -> SpecDict:
return [Columns.OBS, Columns.STATE_IN, "is_first"]

@override(RLModule)
def output_specs_inference(self) -> SpecDict:
return [Columns.ACTIONS, Columns.STATE_OUT]

@override(RLModule)
def input_specs_exploration(self):
return self.input_specs_inference()

@override(RLModule)
def output_specs_exploration(self) -> SpecDict:
return self.output_specs_inference()

@override(RLModule)
def input_specs_train(self) -> SpecDict:
return [Columns.OBS, Columns.ACTIONS, "is_first"]

@override(RLModule)
def output_specs_train(self) -> SpecDict:
return [
"sampled_obs_symlog_BxT",
"obs_distribution_means_BxT",
"reward_logits_BxT",
"rewards_BxT",
"continue_distribution_BxT",
"continues_BxT",
# Sampled, discrete posterior z-states (t1 to T).
"z_posterior_states_BxT",
"z_posterior_probs_BxT",
"z_prior_probs_BxT",
# Deterministic, continuous h-states (t1 to T).
"h_states_BxT",
]

@override(RLModule)
def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]:
# Call the Dreamer-Model's forward_inference method and return a dict.
Expand Down
26 changes: 0 additions & 26 deletions rllib/algorithms/marwil/marwil_rl_module.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import abc

from ray.rllib.core.columns import Columns
from ray.rllib.core.models.specs.specs_dict import SpecDict
from ray.rllib.core.rl_module import RLModule
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.utils.annotations import override
Expand All @@ -25,27 +23,3 @@ def get_initial_state(self) -> dict:
return self.encoder.get_initial_state()
else:
return {}

@override(RLModule)
def input_specs_inference(self) -> SpecDict:
return [Columns.OBS]

@override(RLModule)
def output_specs_inference(self) -> SpecDict:
return [Columns.ACTION_DIST_INPUTS]

@override(RLModule)
def input_specs_exploration(self):
return self.input_specs_inference()

@override(RLModule)
def output_specs_exploration(self) -> SpecDict:
return self.output_specs_inference()

@override(RLModule)
def input_specs_train(self) -> SpecDict:
return self.input_specs_exploration()

@override(RLModule)
def output_specs_train(self) -> SpecDict:
return [Columns.ACTION_DIST_INPUTS]
26 changes: 0 additions & 26 deletions rllib/algorithms/ppo/ppo_rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
import abc
from typing import List

from ray.rllib.core.columns import Columns
from ray.rllib.core.models.configs import RecurrentEncoderConfig
from ray.rllib.core.models.specs.specs_dict import SpecDict
from ray.rllib.core.rl_module.apis import InferenceOnlyAPI, ValueFunctionAPI
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import (
Expand Down Expand Up @@ -50,30 +48,6 @@ def get_initial_state(self) -> dict:
else:
return {}

@override(RLModule)
def input_specs_inference(self) -> SpecDict:
return [Columns.OBS]

@override(RLModule)
def output_specs_inference(self) -> SpecDict:
return [Columns.ACTION_DIST_INPUTS]

@override(RLModule)
def input_specs_exploration(self):
return self.input_specs_inference()

@override(RLModule)
def output_specs_exploration(self) -> SpecDict:
return self.output_specs_inference()

@override(RLModule)
def input_specs_train(self) -> SpecDict:
return self.input_specs_exploration()

@override(RLModule)
def output_specs_train(self) -> SpecDict:
return [Columns.ACTION_DIST_INPUTS]

@OverrideToImplementCustomLogic_CallToSuperRecommended
@override(InferenceOnlyAPI)
def get_non_inference_attributes(self) -> List[str]:
Expand Down
Loading

0 comments on commit 616eef8

Please sign in to comment.