diff --git a/doc/source/rllib/doc_code/rlmodule_guide.py b/doc/source/rllib/doc_code/rlmodule_guide.py index 4e5b8bc45245..7e0be191ebf7 100644 --- a/doc/source/rllib/doc_code/rlmodule_guide.py +++ b/doc/source/rllib/doc_code/rlmodule_guide.py @@ -1,7 +1,5 @@ # flake8: noqa from ray.rllib.utils.annotations import override -from ray.rllib.core.models.specs.typing import SpecType -from ray.rllib.core.models.specs.specs_base import TensorSpec # __enabling-rlmodules-in-configs-begin__ @@ -224,70 +222,6 @@ def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]: # __write-custom-sa-rlmodule-tf-end__ -# __extend-spec-checking-single-level-begin__ -class DiscreteBCTorchModule(TorchRLModule): - ... - - @override(TorchRLModule) - def input_specs_exploration(self) -> SpecType: - # Enforce that input nested dict to exploration method has a key "obs" - return ["obs"] - - @override(TorchRLModule) - def output_specs_exploration(self) -> SpecType: - # Enforce that output nested dict from exploration method has a key - # "action_dist" - return ["action_dist"] - - -# __extend-spec-checking-single-level-end__ - - -# __extend-spec-checking-nested-begin__ -class DiscreteBCTorchModule(TorchRLModule): - ... - - @override(TorchRLModule) - def input_specs_exploration(self) -> SpecType: - # Enforce that input nested dict to exploration method has a key "obs" - # and within that key, it has a key "global" and "local". There should - # also be a key "action_mask" - return [("obs", "global"), ("obs", "local"), "action_mask"] - - -# __extend-spec-checking-nested-end__ - - -# __extend-spec-checking-torch-specs-begin__ -class DiscreteBCTorchModule(TorchRLModule): - ... - - @override(TorchRLModule) - def input_specs_exploration(self) -> SpecType: - # Enforce that input nested dict to exploration method has a key "obs" - # and its value is a torch.Tensor with shape (b, h) where b is the - # batch size (determined at run-time) and h is the hidden size - # (fixed at 10). - return {"obs": TensorSpec("b, h", h=10, framework="torch")} - - -# __extend-spec-checking-torch-specs-end__ - - -# __extend-spec-checking-type-specs-begin__ -class DiscreteBCTorchModule(TorchRLModule): - ... - - @override(TorchRLModule) - def output_specs_exploration(self) -> SpecType: - # Enforce that output nested dict from exploration method has a key - # "action_dist" and its value is a torch.distribution.Categorical - return {"action_dist": torch.distribution.Categorical} - - -# __extend-spec-checking-type-specs-end__ - - # __write-custom-multirlmodule-shared-enc-begin__ from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleConfig, MultiRLModule diff --git a/doc/source/rllib/rllib-rlmodule.rst b/doc/source/rllib/rllib-rlmodule.rst index 80b36aab54c0..b6505c71f794 100644 --- a/doc/source/rllib/rllib-rlmodule.rst +++ b/doc/source/rllib/rllib-rlmodule.rst @@ -255,54 +255,6 @@ When writing RL Modules, you need to use these fields to construct your model. :end-before: __write-custom-sa-rlmodule-tf-end__ -In :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` you can enforce the checking for the existence of certain input or output keys in the data that is communicated into and out of RL Modules. This serves multiple purposes: - -- For the I/O requirement of each method to be self-documenting. -- For failures to happen quickly. 
If users extend the modules and implement something that does not match the assumptions of the I/O specs, the check reports missing keys and their expected format. For example, RLModule should always have an ``obs`` key in the input batch and an ``action_dist`` key in the output.
-
-.. tab-set::
-
-    .. tab-item:: Single Level Keys
-
-        .. literalinclude:: doc_code/rlmodule_guide.py
-            :language: python
-            :start-after: __extend-spec-checking-single-level-begin__
-            :end-before: __extend-spec-checking-single-level-end__
-
-    .. tab-item:: Nested Keys
-
-        .. literalinclude:: doc_code/rlmodule_guide.py
-            :language: python
-            :start-after: __extend-spec-checking-nested-begin__
-            :end-before: __extend-spec-checking-nested-end__
-
-
-    .. tab-item:: TensorShape Spec
-
-        .. literalinclude:: doc_code/rlmodule_guide.py
-            :language: python
-            :start-after: __extend-spec-checking-torch-specs-begin__
-            :end-before: __extend-spec-checking-torch-specs-end__
-
-
-    .. tab-item:: Type Spec
-
-        .. literalinclude:: doc_code/rlmodule_guide.py
-            :language: python
-            :start-after: __extend-spec-checking-type-specs-begin__
-            :end-before: __extend-spec-checking-type-specs-end__
-
-:py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` has two spec methods for each forward method, six in total, that you can override to describe the expected input and output of each method:
-
-- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.input_specs_inference`
-- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.output_specs_inference`
-- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.input_specs_exploration`
-- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.output_specs_exploration`
-- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.input_specs_train`
-- :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.output_specs_train`
-
-To learn more, see the `SpecType` documentation.
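For reference, the tabbed examples above come from ``doc_code/rlmodule_guide.py``, which this diff deletes as well. A condensed sketch of the single-level and tensor-shape variants, as they looked before removal:

```python
from ray.rllib.core.models.specs.specs_base import TensorSpec
from ray.rllib.core.models.specs.typing import SpecType
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from ray.rllib.utils.annotations import override


class DiscreteBCTorchModule(TorchRLModule):
    ...

    @override(TorchRLModule)
    def input_specs_exploration(self) -> SpecType:
        # Require an "obs" key in the exploration input batch. Using a
        # TensorSpec additionally pins the shape: the batch dim "b" stays
        # free, while the hidden dim "h" is fixed to 10.
        return {"obs": TensorSpec("b, h", h=10, framework="torch")}

    @override(TorchRLModule)
    def output_specs_exploration(self) -> SpecType:
        # Require an "action_dist" key in the exploration output.
        return ["action_dist"]
```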
- Writing Custom Multi-Agent RL Modules (Advanced) ------------------------------------------------ diff --git a/rllib/BUILD b/rllib/BUILD index 47183691b1d6..fce667712b3a 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1456,28 +1456,6 @@ py_test( srcs = ["core/models/tests/test_recurrent_encoders.py"] ) -# Specs -py_test( - name = "test_check_specs", - tags = ["team:rllib", "models"], - size = "small", - srcs = ["core/models/specs/tests/test_check_specs.py"] -) - -py_test( - name = "test_tensor_spec", - tags = ["team:rllib", "models"], - size = "small", - srcs = ["core/models/specs/tests/test_tensor_spec.py"] -) - -py_test( - name = "test_spec_dict", - tags = ["team:rllib", "models"], - size = "small", - srcs = ["core/models/specs/tests/test_spec_dict.py"] -) - # RLModule py_test( name = "test_torch_rl_module", diff --git a/rllib/algorithms/dqn/torch/dqn_rainbow_torch_noisy_net.py b/rllib/algorithms/dqn/torch/dqn_rainbow_torch_noisy_net.py index d8f01c80181a..ddd8492e6eb9 100644 --- a/rllib/algorithms/dqn/torch/dqn_rainbow_torch_noisy_net.py +++ b/rllib/algorithms/dqn/torch/dqn_rainbow_torch_noisy_net.py @@ -7,11 +7,7 @@ from ray.rllib.algorithms.dqn.torch.torch_noisy_linear import NoisyLinear from ray.rllib.core.columns import Columns from ray.rllib.core.models.base import Encoder, ENCODER_OUT, Model -from ray.rllib.core.models.specs.specs_base import Spec -from ray.rllib.core.models.specs.specs_base import TensorSpec -from ray.rllib.core.models.specs.specs_dict import SpecDict from ray.rllib.core.models.torch.base import TorchModel -from ray.rllib.core.models.torch.heads import auto_fold_unfold_time from ray.rllib.models.utils import get_activation_fn, get_initializer_fn from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch @@ -52,26 +48,6 @@ def __init__(self, config: NoisyMLPEncoderConfig) -> None: std_init=config.std_init, ) - @override(Model) - def get_input_specs(self) -> Optional[Spec]: - return SpecDict( - { - Columns.OBS: TensorSpec( - "b, d", d=self.config.input_dims[0], framework="torch" - ), - } - ) - - @override(Model) - def get_output_specs(self) -> Optional[Spec]: - return SpecDict( - { - ENCODER_OUT: TensorSpec( - "b, d", d=self.config.output_dims[0], framework="torch" - ), - } - ) - @override(Model) def _forward(self, inputs: dict, **kwargs) -> dict: return {ENCODER_OUT: self.net(inputs[Columns.OBS])} @@ -113,15 +89,6 @@ def __init__(self, config: NoisyMLPHeadConfig) -> None: ) @override(Model) - def get_input_specs(self) -> Optional[Spec]: - return TensorSpec("b, d", d=self.config.input_dims[0], framework="torch") - - @override(Model) - def get_output_specs(self) -> Optional[Spec]: - return TensorSpec("b, d", d=self.config.output_dims[0], framework="torch") - - @override(Model) - @auto_fold_unfold_time("input_specs") def _forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor: return self.net(inputs) diff --git a/rllib/algorithms/dreamerv3/dreamerv3_rl_module.py b/rllib/algorithms/dreamerv3/dreamerv3_rl_module.py index add836028015..d7bbd32825a1 100644 --- a/rllib/algorithms/dreamerv3/dreamerv3_rl_module.py +++ b/rllib/algorithms/dreamerv3/dreamerv3_rl_module.py @@ -14,7 +14,6 @@ from ray.rllib.algorithms.dreamerv3.tf.models.dreamer_model import DreamerModel from ray.rllib.algorithms.dreamerv3.tf.models.world_model import WorldModel from ray.rllib.core.columns import Columns -from ray.rllib.core.models.specs.specs_dict import SpecDict from ray.rllib.core.rl_module.rl_module import RLModule from 
ray.rllib.policy.eager_tf_policy import _convert_to_tf from ray.rllib.utils.annotations import override @@ -123,43 +122,6 @@ def get_initial_state(self) -> Dict: # Use `DreamerModel`'s `get_initial_state` method. return self.dreamer_model.get_initial_state() - @override(RLModule) - def input_specs_inference(self) -> SpecDict: - return [Columns.OBS, Columns.STATE_IN, "is_first"] - - @override(RLModule) - def output_specs_inference(self) -> SpecDict: - return [Columns.ACTIONS, Columns.STATE_OUT] - - @override(RLModule) - def input_specs_exploration(self): - return self.input_specs_inference() - - @override(RLModule) - def output_specs_exploration(self) -> SpecDict: - return self.output_specs_inference() - - @override(RLModule) - def input_specs_train(self) -> SpecDict: - return [Columns.OBS, Columns.ACTIONS, "is_first"] - - @override(RLModule) - def output_specs_train(self) -> SpecDict: - return [ - "sampled_obs_symlog_BxT", - "obs_distribution_means_BxT", - "reward_logits_BxT", - "rewards_BxT", - "continue_distribution_BxT", - "continues_BxT", - # Sampled, discrete posterior z-states (t1 to T). - "z_posterior_states_BxT", - "z_posterior_probs_BxT", - "z_prior_probs_BxT", - # Deterministic, continuous h-states (t1 to T). - "h_states_BxT", - ] - @override(RLModule) def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]: # Call the Dreamer-Model's forward_inference method and return a dict. diff --git a/rllib/algorithms/marwil/marwil_rl_module.py b/rllib/algorithms/marwil/marwil_rl_module.py index a0e5a40db4f9..be72909e1cd9 100644 --- a/rllib/algorithms/marwil/marwil_rl_module.py +++ b/rllib/algorithms/marwil/marwil_rl_module.py @@ -1,7 +1,5 @@ import abc -from ray.rllib.core.columns import Columns -from ray.rllib.core.models.specs.specs_dict import SpecDict from ray.rllib.core.rl_module import RLModule from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI from ray.rllib.utils.annotations import override @@ -25,27 +23,3 @@ def get_initial_state(self) -> dict: return self.encoder.get_initial_state() else: return {} - - @override(RLModule) - def input_specs_inference(self) -> SpecDict: - return [Columns.OBS] - - @override(RLModule) - def output_specs_inference(self) -> SpecDict: - return [Columns.ACTION_DIST_INPUTS] - - @override(RLModule) - def input_specs_exploration(self): - return self.input_specs_inference() - - @override(RLModule) - def output_specs_exploration(self) -> SpecDict: - return self.output_specs_inference() - - @override(RLModule) - def input_specs_train(self) -> SpecDict: - return self.input_specs_exploration() - - @override(RLModule) - def output_specs_train(self) -> SpecDict: - return [Columns.ACTION_DIST_INPUTS] diff --git a/rllib/algorithms/ppo/ppo_rl_module.py b/rllib/algorithms/ppo/ppo_rl_module.py index 5c48ab7af7b6..8fceaa715e0e 100644 --- a/rllib/algorithms/ppo/ppo_rl_module.py +++ b/rllib/algorithms/ppo/ppo_rl_module.py @@ -5,9 +5,7 @@ import abc from typing import List -from ray.rllib.core.columns import Columns from ray.rllib.core.models.configs import RecurrentEncoderConfig -from ray.rllib.core.models.specs.specs_dict import SpecDict from ray.rllib.core.rl_module.apis import InferenceOnlyAPI, ValueFunctionAPI from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.utils.annotations import ( @@ -50,30 +48,6 @@ def get_initial_state(self) -> dict: else: return {} - @override(RLModule) - def input_specs_inference(self) -> SpecDict: - return [Columns.OBS] - - @override(RLModule) - def output_specs_inference(self) -> 
SpecDict:
-        return [Columns.ACTION_DIST_INPUTS]
-
-    @override(RLModule)
-    def input_specs_exploration(self):
-        return self.input_specs_inference()
-
-    @override(RLModule)
-    def output_specs_exploration(self) -> SpecDict:
-        return self.output_specs_inference()
-
-    @override(RLModule)
-    def input_specs_train(self) -> SpecDict:
-        return self.input_specs_exploration()
-
-    @override(RLModule)
-    def output_specs_train(self) -> SpecDict:
-        return [Columns.ACTION_DIST_INPUTS]
-
     @OverrideToImplementCustomLogic_CallToSuperRecommended
     @override(InferenceOnlyAPI)
     def get_non_inference_attributes(self) -> List[str]:
diff --git a/rllib/core/models/base.py b/rllib/core/models/base.py
index 8339be323118..3bb6304449a5 100644
--- a/rllib/core/models/base.py
+++ b/rllib/core/models/base.py
@@ -230,8 +230,6 @@ def __init__(self, config):
         super().__init__(config)
         self.factor = config.factor
 
-    @check_input_specs("input_specs")
-    @check_output_specs("output_specs")
     def __call__(self, *args, **kwargs):
         # This is a dummy method to do checked forward passes.
         return self._forward(*args, **kwargs)
@@ -263,14 +261,6 @@ def build(self, framework: str):
         """
 
-    @override(Model)
-    def get_input_specs(self) -> Optional[Spec]:
-        return [Columns.OBS]
-
-    @override(Model)
-    def get_output_specs(self) -> Optional[Spec]:
-        return []
-
     @abc.abstractmethod
     def _forward(self, input_dict: dict, **kwargs) -> dict:
         """Returns the latent of the encoder for the given inputs.
@@ -324,18 +314,6 @@ def __init__(self, config: ModelConfig) -> None:
             framework=self.framework
         )
 
-    @override(Model)
-    def get_input_specs(self) -> Optional[Spec]:
-        return [Columns.OBS]
-
-    @override(Model)
-    def get_output_specs(self) -> Optional[Spec]:
-        return [(ENCODER_OUT, ACTOR)] + (
-            [(ENCODER_OUT, CRITIC)]
-            if not self.config.shared and self.critic_encoder
-            else []
-        )
-
     @override(Model)
     def _forward(self, inputs: dict, **kwargs) -> dict:
         if self.config.shared:
@@ -399,14 +377,6 @@ def __init__(self, config: ModelConfig) -> None:
             framework=self.framework
         )
 
-    @override(Model)
-    def get_input_specs(self) -> Optional[Spec]:
-        return [Columns.OBS, Columns.STATE_IN]
-
-    @override(Model)
-    def get_output_specs(self) -> Optional[Spec]:
-        return [(ENCODER_OUT, ACTOR), (ENCODER_OUT, CRITIC), (Columns.STATE_OUT,)]
-
     @override(Model)
     def get_initial_state(self):
         if self.config.shared:
diff --git a/rllib/core/models/specs/checker.py b/rllib/core/models/specs/checker.py
deleted file mode 100644
index 1ff3ead16ecd..000000000000
--- a/rllib/core/models/specs/checker.py
+++ /dev/null
@@ -1,385 +0,0 @@
-import functools
-import logging
-from collections import abc
-from typing import Any, Callable, Dict
-
-from ray.rllib.core.models.specs.specs_base import Spec, TypeSpec
-from ray.rllib.core.models.specs.specs_dict import SpecDict
-from ray.rllib.core.models.specs.typing import SpecType
-from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
-from ray.util.annotations import DeveloperAPI
-
-logger = logging.getLogger(__name__)
-
-
-@DeveloperAPI
-class SpecCheckingError(Exception):
-    """Raised when there is an error in the spec checking.
-
-    This Error is raised when inputs or outputs do not match the defined specs.
-    """
-
-
-@DeveloperAPI
-def convert_to_canonical_format(spec: SpecType):
-    """Converts a spec type input to the canonical format.
-
-    The canonical format is either
-
-    1. A nested SpecDict when the input spec is a dict-like tree of specs and types
-        or a nested list of nested_keys.
-    2.
A single SpecType object, if the spec is a single constraint. - - The input can be any of the following: - - a list of nested_keys. nested_keys are either strings or tuples of strings - specifying the path to a leaf in the tree. - - a tree of constraints. The tree structure can be specified as any nested - hash-map structure (e.g. dict, SpecDict, etc.) The leaves of the - tree can be either a Spec object, a type, or None. If the leaf is a type, it is - converted to a TypeSpec. If the leaf is None, only the existance of the key is - checked and the value will be None in the canonical format. - - a single constraint. The constraint can be a Spec object, a type, or None. - - Args: - spec: The spec to convert to canonical format. - - Returns: - The canonical format of the spec. - """ - # convert spec of form list of nested_keys to model_spec with None leaves - if isinstance(spec, list): - - def _to_nested(tup): - nested_dict = current = {} - last_dict = {} - key = None - for key in tup: - current[key] = {} - last_dict = current - current = current[key] - last_dict[key] = None # Set the innermost value to None - return nested_dict - - spec_dict = {} - for k in spec: - if not isinstance(k, tuple): - spec_dict[k] = None - elif len(k) == 1: - spec_dict[k[0]] = None - else: - spec_dict[k[0]] = _to_nested(k[1:]) - return SpecDict(spec_dict) - - # convert spec of form tree of constraints to model_spec - if isinstance(spec, abc.Mapping): - spec = SpecDict(spec) - for key in spec: - # If values are types or tuple of types, convert to TypeSpec. - if isinstance(spec[key], (type, tuple)): - spec[key] = TypeSpec(spec[key]) - elif isinstance(spec[key], list): - # This enables nested conversion of none-canonical formats. - spec[key] = convert_to_canonical_format(spec[key]) - return spec - - if isinstance(spec, type): - return TypeSpec(spec) - - # otherwise, assume spec is already in canonical format - return spec - - -def _should_validate( - cls_instance: object, method: Callable, tag: str = "input" -) -> bool: - """Returns True if the spec should be validated, False otherwise. - - The spec should be validated if the method is not cached (i.e. there is no cache - storage attribute in the instance) or if the method is already cached. (i.e. it - exists in the cache storage attribute) - - Args: - cls_instance: The class instance that the method belongs to. - method: The method to apply the spec checking to. - tag: The tag of the spec to check. Either "input" or "output". This is used - internally to defined an internal cache storage attribute based on the tag. - - Returns: - True if the spec should be validated, False otherwise. - """ - cache_store = getattr(cls_instance, f"__checked_{tag}_specs_cache__", None) - return cache_store is None or method.__name__ not in cache_store - - -def _validate( - *, - cls_instance: object, - method: Callable, - data: Dict[str, Any], - spec: Spec, - tag: str = "input", - filter=DEPRECATED_VALUE, -) -> Dict: - """Validate the data against the spec. - - Args: - cls_instance: The class instance that the method belongs to. - method: The method to apply the spec checking to. - data: The data to validate. - spec: The spec to validate against. - tag: The tag of the spec to check. Either "input" or "output". This is used - internally to defined an internal cache storage attribute based on the tag. - - Returns: - The data, filtered if filter is True. 
- """ - if filter != DEPRECATED_VALUE: - deprecation_warning(old="_validate(filter=...)", error=True) - - cache_miss = _should_validate(cls_instance, method, tag=tag) - - if isinstance(spec, SpecDict): - if not isinstance(data, abc.Mapping): - raise ValueError(f"{tag} must be a Mapping, got {type(data).__name__}") - - if cache_miss: - try: - spec.validate(data) - except ValueError as e: - raise SpecCheckingError( - f"{tag} spec validation failed on " - f"{cls_instance.__class__.__name__}.{method.__name__}, {e}." - ) - - return data - - -@DeveloperAPI(stability="alpha") -def check_input_specs( - input_specs: str, - *, - only_check_on_retry: bool = True, - cache: bool = True, - filter=DEPRECATED_VALUE, -): - """A general-purpose spec checker decorator for neural network base classes. - - This is a stateful decorator - (https://realpython.com/primer-on-python-decorators/#stateful-decorators) to - enforce input specs for any instance method that has an argument named - `batch` in its args. - - See more examples in ../tests/test_specs_dict.py) - - .. testcode:: - - import torch - from torch import nn - from ray.rllib.core.models.specs.specs_base import TensorSpec - - class MyModel(nn.Module): - @property - def input_specs(self): - return {"obs": TensorSpec("b, d", d=64)} - - @check_input_specs("input_specs", only_check_on_retry=False) - def forward(self, batch, return_loss=False): - ... - - model = MyModel() - model.forward({"obs": torch.randn(32, 64)}) - - # The following would raise an Error - # model.forward({"obs": torch.randn(32, 32)}) - - Args: - func: The instance method to decorate. It should be a callable that takes - `self` as the first argument, `batch` as the second argument and any - other keyword argument thereafter. - input_specs: `self` should have an instance attribute whose name matches the - string in input_specs and returns the `SpecDict`, `Spec`, or simply the - `Type` that the `batch` should comply with. It can also be None or - empty list / dict to enforce no input spec. - only_check_on_retry: If True, the spec will not be checked. Only if the - decorated method raises an Exception, we check the spec to provide a more - informative error message. - cache: If True, only checks the data validation for the first time the - instance method is called. - - Returns: - A wrapped instance method. In case of `cache=True`, after the first invokation - of the decorated method, the intance will have `__checked_input_specs_cache__` - attribute that stores which method has been invoked at least once. This is a - special attribute that can be used for the cache itself. The wrapped class - method also has a special attribute `__checked_input_specs__` that marks the - method as decorated. - """ - - if filter != DEPRECATED_VALUE: - deprecation_warning(old="check_input_specs(filter=...)", error=True) - - def decorator(func): - @functools.wraps(func) - def wrapper(self, batch, **kwargs): - if cache and not hasattr(self, "__checked_input_specs_cache__"): - self.__checked_input_specs_cache__ = {} - - initial_exception = None - if only_check_on_retry: - # Attempt to run the function without spec checking - try: - return func(self, batch, **kwargs) - except SpecCheckingError as e: - raise e - except Exception as e: - # We store the initial exception to raise it later if the spec - # check fails. - initial_exception = e - logger.error( - f"Exception {e} raised on function call without checking " - f"input specs. RLlib will now attempt to check the spec " - f"before calling the function again ..." 
- ) - - # If the function was not executed successfully yet, we check specs - checked_data = batch - - if input_specs and ( - initial_exception - or not cache - or func.__name__ not in self.__checked_input_specs_cache__ - or filter - ): - if hasattr(self, input_specs): - spec = getattr(self, input_specs) - else: - raise SpecCheckingError( - f"object {self} has no attribute {input_specs}." - ) - - if spec is not None: - spec = convert_to_canonical_format(spec) - checked_data = _validate( - cls_instance=self, - method=func, - data=batch, - spec=spec, - tag="input", - ) - - # If we have encountered an exception from calling `func` already, - # we raise it again here and don't need to call func again. - if initial_exception: - raise initial_exception - - if cache and func.__name__ not in self.__checked_input_specs_cache__: - self.__checked_input_specs_cache__[func.__name__] = True - - return func(self, checked_data, **kwargs) - - wrapper.__checked_input_specs__ = True - return wrapper - - return decorator - - -@DeveloperAPI(stability="alpha") -def check_output_specs( - output_specs: str, - *, - cache: bool = True, -): - """A general-purpose spec checker decorator for Neural Network base classes. - - This is a stateful decorator - (https://realpython.com/primer-on-python-decorators/#stateful-decorators) to - enforce output specs for any instance method that outputs a single dict-like object. - - It also allows you to cache the validation to make sure the spec is only validated - once in the entire lifetime of the instance. - - Examples (See more examples in ../tests/test_specs_dict.py): - - .. testcode:: - - import torch - from torch import nn - from ray.rllib.core.models.specs.specs_base import TensorSpec - - class MyModel(nn.Module): - @property - def output_specs(self): - return {"obs": TensorSpec("b, d", d=64)} - - @check_output_specs("output_specs") - def forward(self, batch, return_loss=False): - return {"obs": torch.randn(32, 64)} - - Args: - func: The instance method to decorate. It should be a callable that takes - `self` as the first argument, `batch` as the second argument and any - other keyword argument thereafter. It should return a single dict-like - object (i.e. not a tuple). - output_specs: `self` should have an instance attribute whose name matches the - string in output_specs and returns the `SpecDict`, `Spec`, or simply the - `Type` that the `batch` should comply with. It can alos be None or - empty list / dict to enforce no input spec. - cache: If True, only checks the data validation for the first time the - instance method is called. - - Returns: - A wrapped instance method. In case of `cache=True`, after the first invokation - of the decorated method, the intance will have `__checked_output_specs_cache__` - attribute that stores which method has been invoked at least once. This is a - special attribute that can be used for the cache itself. The wrapped class - method also has a special attribute `__checked_output_specs__` that marks the - method as decorated. 
- """ - - def decorator(func): - @functools.wraps(func) - def wrapper(self, batch, **kwargs): - if cache and not hasattr(self, "__checked_output_specs_cache__"): - self.__checked_output_specs_cache__ = {} - - output_data = func(self, batch, **kwargs) - - if output_specs and ( - not cache or func.__name__ not in self.__checked_output_specs_cache__ - ): - if hasattr(self, output_specs): - spec = getattr(self, output_specs) - else: - raise ValueError(f"object {self} has no attribute {output_specs}.") - - if spec is not None: - spec = convert_to_canonical_format(spec) - _validate( - cls_instance=self, - method=func, - data=output_data, - spec=spec, - tag="output", - ) - - if cache and func.__name__ not in self.__checked_output_specs_cache__: - self.__checked_output_specs_cache__[func.__name__] = True - - return output_data - - wrapper.__checked_output_specs__ = True - return wrapper - - return decorator - - -@DeveloperAPI -def is_input_decorated(obj: object) -> bool: - """Returns True if the object is decorated with `check_input_specs`.""" - return hasattr(obj, "__checked_input_specs__") - - -@DeveloperAPI -def is_output_decorated(obj: object) -> bool: - """Returns True if the object is decorated with `check_output_specs`.""" - return hasattr(obj, "__checked_output_specs__") diff --git a/rllib/core/models/specs/specs_base.py b/rllib/core/models/specs/specs_base.py index 4bb07741be49..9099da941002 100644 --- a/rllib/core/models/specs/specs_base.py +++ b/rllib/core/models/specs/specs_base.py @@ -3,9 +3,7 @@ import numpy as np from typing import Any, Optional, Dict, List, Tuple, Union, Type from ray.rllib.utils import try_import_jax, try_import_tf, try_import_torch -from ray.rllib.utils.annotations import OverrideToImplementCustomLogic - -from ray.rllib.utils.annotations import DeveloperAPI, override +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.typing import TensorType torch, _ = try_import_torch() @@ -20,44 +18,30 @@ _INVALID_TYPE = "Expected data type {} but found {}" -@DeveloperAPI +@Deprecated( + help="The Spec checking APIs have been deprecated and cancelled without " + "replacement.", + error=False, +) class Spec(abc.ABC): - @DeveloperAPI @staticmethod @abc.abstractmethod def validate(self, data: Any) -> None: - """Validates the given data against this spec. - - Args: - data: The input to validate. + pass - Raises: - ValueError: If the data does not match this spec. - """ - -@DeveloperAPI +@Deprecated( + help="The Spec checking APIs have been deprecated and cancelled without " + "replacement.", + error=False, +) class TypeSpec(Spec): - """A base class that checks the type of the input data. - - Args: - dtype: The type of the object. - - .. testcode:: - :skipif: True - - spec = TypeSpec(tf.Tensor) - spec.validate(tf.ones((2, 3))) # passes - spec.validate(torch.ones((2, 3))) # ValueError - """ - def __init__(self, dtype: Type) -> None: self.dtype = dtype def __repr__(self): return f"TypeSpec({str(self.dtype)})" - @override(Spec) def validate(self, data: Any) -> None: if not isinstance(data, self.dtype): raise ValueError(_INVALID_TYPE.format(self.dtype, type(data))) @@ -71,42 +55,12 @@ def __ne__(self, other: "TypeSpec") -> bool: return not self == other -@DeveloperAPI +@Deprecated( + help="The Spec checking APIs have been deprecated and cancelled without " + "replacement.", + error=False, +) class TensorSpec(Spec): - """A base class that specifies the shape and dtype of a tensor. - - Args: - shape: A string representing einops notation of the shape of the tensor. 
- For example, "B, C" represents a tensor with two dimensions, the first - of which has size B and the second of which has size C. shape must - consist of unique dimension names. For example having "B B" is invalid. - dtype: The dtype of the tensor. If None, the dtype is not checked during - validation. Also during Sampling the dtype is set the default dtype of - the backend. - framework: The framework of the tensor. If None, the framework is not - checked during validation. - shape_vals: An optional dictionary mapping some dimension names to their - values. For example, if shape is "B, C" and shape_vals is {"C": 3}, then - the shape of the tensor is (B, 3). B is to be determined during - run-time but C is fixed to 3. - - .. testcode:: - :skipif: True - - spec = TensorSpec("b, d", d=128, dtype=tf.float32) - spec.shape # ('b', 128) - spec.validate(torch.rand(32, 128, dtype=torch.float32)) # passes - spec.validate(torch.rand(32, 64, dtype=torch.float32)) # raises ValueError - spec.validate(torch.rand(32, 128, dtype=torch.float64)) # raises ValueError - - Public Methods: - validate: Checks if the shape and dtype of the tensor matches the - specification. - fill: creates a tensor with the specified value that is an - example of a tensor that matches the specification. This can only be - called if `framework` is specified. - """ - def __init__( self, shape: str, @@ -125,9 +79,7 @@ def __init__( self._type = self._get_expected_type() - @OverrideToImplementCustomLogic def _get_expected_type(self) -> Type: - """Returns the expected type of the checked tensor.""" if self._framework == "torch": return torch.Tensor elif self._framework == "tf2": @@ -141,68 +93,33 @@ def _get_expected_type(self) -> Type: # Don't restrict the type of the tensor if no framework is specified. return object - @OverrideToImplementCustomLogic def get_shape(self, tensor: TensorType) -> Tuple[int]: - """Returns the shape of a tensor. - - Args: - tensor: The tensor whose shape is to be returned. - Returns: - A `tuple` specifying the shape of the tensor. - """ if self._framework == "tf2": - # tf2 returns `Dimension` objects instead of `int` objects. return tuple( int(i) if i is not None else None for i in tensor.shape.as_list() ) return tuple(tensor.shape) - @OverrideToImplementCustomLogic def get_dtype(self, tensor: TensorType) -> Any: - """Returns the expected data type of the checked tensor. - - Args: - tensor: The tensor whose data type is to be returned. - Returns: - The data type of the tensor. - """ return tensor.dtype @property def dtype(self) -> Any: - """Returns the expected data type of the checked tensor.""" return self._dtype @property def shape(self) -> Tuple[Union[int, str]]: - """Returns a `tuple` specifying the abstract tensor shape (int and str).""" return self._expected_shape @property def type(self) -> Type: - """Returns the expected type of the checked tensor.""" return self._type @property def full_shape(self) -> Tuple[int]: - """Returns a `tuple` specifying the concrete tensor shape (only ints).""" return self._full_shape def rdrop(self, n: int) -> "TensorSpec": - """Drops the last n dimensions. - - Returns a copy of this TensorSpec with the rightmost n dimensions removed. 
- - Args: - n: A positive number of dimensions to remove from the right - - Returns: - A copy of the tensor spec with the last n dims removed - - Raises: - IndexError: If n is greater than the number of indices in self - AssertionError: If n is negative or not an int - """ assert isinstance(n, int) and n >= 0, "n must be a positive integer or zero" copy_ = deepcopy(self) copy_._expected_shape = copy_.shape[:-n] @@ -210,31 +127,12 @@ def rdrop(self, n: int) -> "TensorSpec": return copy_ def append(self, spec: "TensorSpec") -> "TensorSpec": - """Appends the given TensorSpec to the self TensorSpec. - - Args: - spec: The TensorSpec to append to the current TensorSpec - - Returns: - A new tensor spec resulting from the concatenation of self and spec - - """ copy_ = deepcopy(self) copy_._expected_shape = (*copy_.shape, *spec.shape) copy_._full_shape = self._get_full_shape() return copy_ - @override(Spec) def validate(self, tensor: TensorType) -> None: - """Checks if the shape and dtype of the tensor matches the specification. - - Args: - tensor: The tensor to be validated. - - Raises: - ValueError: If the shape or dtype of the tensor does not match the - """ - if not isinstance(tensor, self.type): raise ValueError(_INVALID_TYPE.format(self.type, type(tensor).__name__)) @@ -250,20 +148,7 @@ def validate(self, tensor: TensorType) -> None: if self.dtype and dtype != self.dtype: raise ValueError(_INVALID_TYPE.format(self.dtype, tensor.dtype)) - @DeveloperAPI def fill(self, fill_value: Union[float, int] = 0) -> TensorType: - """Creates a tensor filled with `fill_value` that matches the specs. - - Args: - fill_value: The value to fill the tensor with. - - Returns: - A tensor with the specified value that matches the specs. - - Raises: - ValueError: If `framework` is not specified. - """ - if self._framework == "torch": return torch.full(self.full_shape, fill_value, dtype=self.dtype) @@ -285,8 +170,6 @@ def fill(self, fill_value: Union[float, int] = 0) -> TensorType: ) def _get_full_shape(self) -> Tuple[int]: - """Converts the expected shape to a shape by replacing the unknown dimension - sizes with a value of 1.""" sampled_shape = tuple() for d in self._expected_shape: if isinstance(d, int): @@ -296,9 +179,6 @@ def _get_full_shape(self) -> Tuple[int]: return sampled_shape def _parse_expected_shape(self, shape: str, shape_vals: Dict[str, int]) -> tuple: - """Converts the input shape to a tuple of integers and strings.""" - - # check the validity of shape_vals and get a list of dimension names d_names = shape.replace(" ", "").split(",") self._validate_shape_vals(d_names, shape_vals) @@ -309,12 +189,6 @@ def _parse_expected_shape(self, shape: str, shape_vals: Dict[str, int]) -> tuple def _validate_shape_vals( self, d_names: List[str], shape_vals: Dict[str, int] ) -> None: - """Checks if the shape_vals is valid. - - Valid means that shape consist of unique dimension names and shape_vals only - consists of keys that are in shape. Also shape_vals can only contain postive - integers. 
- """ d_names_set = set(d_names) if len(d_names_set) != len(d_names): raise ValueError(_INVALID_INPUT_DUP_DIM.format(",".join(d_names))) @@ -344,7 +218,6 @@ def __repr__(self) -> str: return f"TensorSpec(shape={tuple(self.shape)}, dtype={self.dtype})" def __eq__(self, other: "TensorSpec") -> bool: - """Checks if the shape and dtype of two specs are equal.""" if not isinstance(other, TensorSpec): return False return self.shape == other.shape and self.dtype == other.dtype diff --git a/rllib/core/models/specs/specs_dict.py b/rllib/core/models/specs/specs_dict.py index 7d33e546402d..adc2c46a9412 100644 --- a/rllib/core/models/specs/specs_dict.py +++ b/rllib/core/models/specs/specs_dict.py @@ -2,14 +2,9 @@ import tree from ray.rllib.core.models.specs.specs_base import Spec -from ray.rllib.utils.annotations import ExperimentalAPI, override from ray.rllib.utils import force_tuple -_MISSING_KEYS_FROM_SPEC = ( - "The data dict does not match the model specs. Keys {} are " - "in the data dict but not on the given spec dict, and exact_match is set to True" -) _MISSING_KEYS_FROM_DATA = ( "The data dict does not match the model specs. Keys {} are " "in the spec dict but not on the data dict. Data keys are {}" @@ -24,97 +19,14 @@ IS_NOT_PROPERTY = "Spec {} must be a property of the class {}." -@ExperimentalAPI class SpecDict(dict, Spec): - """A dict containing `TensorSpec` and `Types`. - - It can be used to validate an incoming data against a nested dictionary of specs. - - Examples: - - Basic validation: - ----------------- - - .. testcode:: - :skipif: True - - spec_dict = SpecDict({ - "obs": { - "arm": TensorSpec("b, dim_arm", dim_arm=64), - "gripper": TensorSpec("b, dim_grip", dim_grip=12) - }, - "action": TensorSpec("b, dim_action", dim_action=12), - "action_dist": torch.distributions.Categorical - - spec_dict.validate({ - "obs": { - "arm": torch.randn(32, 64), - "gripper": torch.randn(32, 12) - }, - "action": torch.randn(32, 12), - "action_dist": torch.distributions.Categorical(torch.randn(32, 12)) - }) # No er - spec_dict.validate({ - "obs": { - "arm": torch.randn(32, 32), # Wrong shape - "gripper": torch.randn(32, 12) - }, - "action": torch.randn(32, 12), - "action_dist": torch.distributions.Categorical(torch.randn(32, 12)) - }) # raises ValueError - - Filtering input data: - --------------------- - - .. testcode:: - :skipif: True - - input_data = { - "obs": { - "arm": torch.randn(32, 64), - "gripper": torch.randn(32, 12), - "unused": torch.randn(32, 12) - }, - "action": torch.randn(32, 12), - "action_dist": torch.distributions.Categorical(torch.randn(32, 12)), - "unused": torch.randn(32, 12) - } - input_data.filter(spec_dict) # returns a dict with only the keys in the spec - - .. testoutput:: - - { - "obs": { - "arm": input_data["obs"]["arm"], - "gripper": input_data["obs"]["gripper"] - }, - "action": input_data["action"], - "action_dist": input_data["action_dist"] - } - - Raises: - ValueError: If the data doesn't match the spec. - """ - - @override(Spec) def validate( self, data: DATA_TYPE, exact_match: bool = False, ) -> None: - """Checks whether the data matches the spec. - - Args: - data: The data which should match the spec. It can also be a spec. - exact_match: If true, the data and the spec must be exactly identical. - Otherwise, the data is considered valid as long as it contains at least - the elements of the spec, but can contain more entries. - Raises: - ValueError: If the data doesn't match the spec. 
- """ check = self.is_subset(self, data, exact_match) if not check[0]: - # Collect all (nested) keys in `data`. data_keys_set = set() def _map(path, s): @@ -126,42 +38,26 @@ def _map(path, s): @staticmethod def is_subset(spec_dict, data_dict, exact_match=False): - """Whether `spec_dict` is a subset of `data_dict`.""" - if exact_match: tree.assert_same_structure(data_dict, spec_dict, check_types=False) for key in spec_dict: - # A key of `spec_dict` cannot be found in `data_dict` -> `spec_dict` is not - # a subset. if key not in data_dict: return False, key - # `data_dict` has same key. - - # `spec_dict`'s leaf value is None -> User does not want to specify - # further -> continue. if spec_dict[key] is None: continue - # `data_dict`'s leaf is another dict. elif isinstance(data_dict[key], dict): - # `spec_dict`'s leaf is NOT another dict -> No match, return False and - # the unmatched key. if not isinstance(spec_dict[key], dict): return False, key - # `spec_dict`'s leaf is another dict -> Recurse. res = SpecDict.is_subset(spec_dict[key], data_dict[key], exact_match) if not res[0]: return res - # `spec_dict`'s leaf is a dict (`data_dict`'s is not) -> No match, return - # False and the unmatched key. elif isinstance(spec_dict[key], dict): return False, key - # Neither `spec_dict`'s leaf nor `data_dict`'s leaf are dicts (and not None) - # -> Compare spec with data. elif isinstance(spec_dict[key], Spec): try: spec_dict[key].validate(data_dict[key]) @@ -186,6 +82,3 @@ def is_subset(spec_dict, data_dict, exact_match=False): ) return True, None - - def __repr__(self) -> str: - return f"SpecDict({repr(self.keys())})" diff --git a/rllib/core/models/specs/tests/__init__.py b/rllib/core/models/specs/tests/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/rllib/core/models/specs/tests/test_check_specs.py b/rllib/core/models/specs/tests/test_check_specs.py deleted file mode 100644 index 158da6f86aac..000000000000 --- a/rllib/core/models/specs/tests/test_check_specs.py +++ /dev/null @@ -1,332 +0,0 @@ -import abc -import time -import unittest -from typing import Dict, Any, Type - -import numpy as np -import torch - -from ray.rllib.core.models.specs.checker import SpecCheckingError -from ray.rllib.core.models.specs.checker import ( - convert_to_canonical_format, - check_input_specs, - check_output_specs, -) -from ray.rllib.core.models.specs.specs_base import TensorSpec, TypeSpec -from ray.rllib.core.models.specs.specs_dict import SpecDict -from ray.rllib.utils.annotations import override - -ONLY_ONE_KEY_ALLOWED = "Only one key is allowed in the data dict." 
- - -class AbstractInterfaceClass(abc.ABC): - """An abstract class that has a couple of methods, each having their own - input/output constraints.""" - - @property - @abc.abstractmethod - def input_specs(self) -> SpecDict: - pass - - @property - @abc.abstractmethod - def output_specs(self) -> SpecDict: - pass - - @check_input_specs("input_specs", cache=False, only_check_on_retry=False) - @check_output_specs("output_specs", cache=False) - def check_input_and_output(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: - return self._check_input_and_output(input_dict) - - @abc.abstractmethod - def _check_input_and_output(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: - pass - - @check_input_specs("input_specs", cache=False, only_check_on_retry=False) - def check_only_input(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: - """should not override this method""" - return self._check_only_input(input_dict) - - @abc.abstractmethod - def _check_only_input(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: - pass - - @check_output_specs("output_specs", cache=False) - def check_only_output(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: - """should not override this method""" - return self._check_only_output(input_dict) - - @abc.abstractmethod - def _check_only_output(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: - pass - - @check_input_specs("input_specs", cache=True, only_check_on_retry=False) - @check_output_specs("output_specs", cache=True) - def check_input_and_output_with_cache( - self, input_dict: Dict[str, Any] - ) -> Dict[str, Any]: - """should not override this method""" - return self._check_input_and_output(input_dict) - - @check_input_specs("input_specs", cache=False, only_check_on_retry=False) - @check_output_specs("output_specs", cache=False) - def check_input_and_output_wo_filter(self, input_dict) -> Dict[str, Any]: - """should not override this method""" - return self._check_input_and_output(input_dict) - - -class InputNumberOutputFloat(AbstractInterfaceClass): - """This is an abstract class enforcing a contraint on input/output""" - - @property - def input_specs(self) -> SpecDict: - return SpecDict({"input": (float, int)}) - - @property - def output_specs(self) -> SpecDict: - return SpecDict({"output": float}) - - -class CorrectImplementation(InputNumberOutputFloat): - def run(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: - output = float(input_dict["input"]) * 2 - return {"output": output} - - @override(AbstractInterfaceClass) - def _check_input_and_output(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: - # check if there is any key other than input in the input_dict - if "input" not in input_dict: - raise ValueError(ONLY_ONE_KEY_ALLOWED) - return self.run(input_dict) - - @override(AbstractInterfaceClass) - def _check_only_input(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: - # check if there is any key other than input in the input_dict - if "input" not in input_dict: - raise ValueError(ONLY_ONE_KEY_ALLOWED) - - out = self.run(input_dict) - - # Output can be anything since there are no `output_specs`. 
- return {"output": str(out)} - - @override(AbstractInterfaceClass) - def _check_only_output(self, input_dict) -> Dict[str, Any]: - # there is no input spec, so we can pass anything - if "input" in input_dict: - raise ValueError( - "input_dict should not have `input` key in check_only_output" - ) - - return self.run({"input": input_dict["not_input"]}) - - -class IncorrectImplementation(CorrectImplementation): - @override(CorrectImplementation) - def run(self, input_dict) -> Dict[str, Any]: - output = str(input_dict["input"] * 2) - return {"output": output} - - -class TestCheckSpecs(unittest.TestCase): - def test_check_input_and_output(self): - - correct_module = CorrectImplementation() - - output = correct_module.check_input_and_output({"input": 2}) - # Output should also match the `output_specs`. - correct_module.output_specs.validate(output) - - # This should raise an error saying that the `input` key is missing. - self.assertRaises( - SpecCheckingError, - lambda: correct_module.check_input_and_output({"not_input": 2}), - ) - - def test_check_only_input(self): - correct_module = CorrectImplementation() - # this should not raise any error since input matches the input specs - output = correct_module.check_only_input({"input": 2}) - # Output can be anything since ther is no `output_specs`. - self.assertRaises( - ValueError, - lambda: correct_module.output_specs.validate(output), - ) - - def test_check_only_output(self): - correct_module = CorrectImplementation() - # This should not raise any error since input does not have to match - # `input_specs`. - output = correct_module.check_only_output({"not_input": 2}) - # Output should match the `output_specs`. - correct_module.output_specs.validate(output) - - def test_incorrect_implementation(self): - incorrect_module = IncorrectImplementation() - # this should raise an error saying that the output does not match the - # `output_specs`. - self.assertRaises( - SpecCheckingError, - lambda: incorrect_module.check_input_and_output({"input": 2}), - ) - - # this should not raise an error because output is not forced to be checked - incorrect_module.check_only_input({"input": 2}) - - # This should raise an error because output does not match the `output_specs`. - self.assertRaises( - SpecCheckingError, - lambda: incorrect_module.check_only_output({"not_input": 2}), - ) - - def test_filter(self): - # create an arbitrary large input dict and test the behavior with and without a - # filter - input_dict = {"input": 2} - for i in range(100): - inds = (str(i),) + tuple(str(j) for j in range(i + 1, i + 11)) - input_dict[inds] = i - - correct_module = CorrectImplementation() - - # should run without errors - correct_module.check_input_and_output(input_dict) - - def test_cache(self): - # warning: this could be a flakey test - # for N times, run the function twice and compare the time of each run. - # the second run should be faster since the output is cached - # to make sure the time is not too small, we run this on an input dict that is - # arbitrarily large and nested. - # we also check if cache is not working the second run is as slow as the first - # run. 
- - input_dict = {"input": 2} - for i in range(100): - inds = (str(i),) + tuple(str(j) for j in range(i + 1, i + 11)) - input_dict[inds] = i - - N = 500 - time1, time2 = [], [] - for _ in range(N): - - module = CorrectImplementation() - - fn = getattr(module, "check_input_and_output_with_cache") - start = time.time() - fn(input_dict) - end = time.time() - time1.append(end - start) - - start = time.time() - fn(input_dict) - end = time.time() - time2.append(end - start) - - lower_bound_time1 = np.mean(time1) # - 3 * np.std(time1) - upper_bound_time2 = np.mean(time2) # + 3 * np.std(time2) - print(f"time1: {np.mean(time1)}") - print(f"time2: {np.mean(time2)}") - - self.assertGreater(lower_bound_time1, upper_bound_time2) - - def test_tensor_specs(self): - # Test if the `input_specs` can be a tensor spec. - class ClassWithTensorSpec: - @property - def input_spec1(self) -> TensorSpec: - return TensorSpec("b, h", h=4, framework="torch") - - @check_input_specs("input_spec1", cache=False, only_check_on_retry=False) - def forward(self, input_data) -> Any: - return input_data - - module = ClassWithTensorSpec() - module.forward(torch.rand(2, 4)) - self.assertRaises(SpecCheckingError, lambda: module.forward(torch.rand(2, 3))) - - def test_type_specs(self): - class SpecialOutputType: - pass - - class WrongOutputType: - pass - - class ClassWithTypeSpec: - @property - def output_specs(self) -> Type: - return SpecialOutputType - - @check_output_specs("output_specs", cache=False) - def forward_pass(self, input_data) -> Any: - return SpecialOutputType() - - @check_output_specs("output_specs", cache=False) - def forward_fail(self, input_data) -> Any: - return WrongOutputType() - - module = ClassWithTypeSpec() - output = module.forward_pass(torch.rand(2, 4)) - self.assertIsInstance(output, SpecialOutputType) - self.assertRaises( - SpecCheckingError, lambda: module.forward_fail(torch.rand(2, 3)) - ) - - def test_convert_to_canonical_format(self): - - # Case: input is a list of strs - self.assertDictEqual( - convert_to_canonical_format(["foo", "bar"]), - SpecDict({"foo": None, "bar": None}), - ) - - # Case: input is a list of strs and nested strs - self.assertDictEqual( - convert_to_canonical_format(["foo", ("bar", "jar")]), - SpecDict({"foo": None, "bar": {"jar": None}}), - ) - - # Case: input is a Nested Mapping - returned = convert_to_canonical_format( - { - "foo": {"bar": TensorSpec("b", framework="torch")}, - "jar": {"tar": TypeSpec(int), "car": None}, - } - ) - self.assertIsInstance(returned, SpecDict) - self.assertDictEqual( - returned, - SpecDict( - { - "foo": {"bar": TensorSpec("b", framework="torch")}, - "jar": {"tar": TypeSpec(int), "car": None}, - } - ), - ) - - # Case: input is a SpecDict already - returned = convert_to_canonical_format( - SpecDict( - { - "foo": {"bar": TensorSpec("b", framework="torch")}, - "jar": {"tar": TypeSpec(int)}, - } - ) - ) - self.assertIsInstance(returned, SpecDict) - self.assertDictEqual( - returned, - SpecDict( - { - "foo": {"bar": TensorSpec("b", framework="torch")}, - "jar": {"tar": TypeSpec(int)}, - } - ), - ) - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/core/models/specs/tests/test_spec_dict.py b/rllib/core/models/specs/tests/test_spec_dict.py deleted file mode 100644 index f5ae12b47fe5..000000000000 --- a/rllib/core/models/specs/tests/test_spec_dict.py +++ /dev/null @@ -1,244 +0,0 @@ -import unittest -import numpy as np - -from ray.rllib.core.models.specs.specs_base import TensorSpec 
-from ray.rllib.core.models.specs.specs_dict import SpecDict -from ray.rllib.core.models.specs.checker import ( - check_input_specs, - convert_to_canonical_format, -) -from ray.rllib.core.models.specs.checker import SpecCheckingError - - -class TypeClass1: - pass - - -class TypeClass2: - pass - - -class TestSpecDict(unittest.TestCase): - def test_basic_validation(self): - """Tests basic validation of SpecDict.""" - - h1, h2 = 3, 4 - spec_1 = SpecDict( - { - "out_tensor_1": TensorSpec("b, h", h=h1, framework="np"), - "out_tensor_2": TensorSpec("b, h", h=h2, framework="np"), - "out_class_1": TypeClass1, - } - ) - - # test validation. - tensor_1 = { - "out_tensor_1": np.random.randn(2, h1), - "out_tensor_2": np.random.randn(2, h2), - "out_class_1": TypeClass1(), - } - - spec_1.validate(tensor_1) - - # test missing key in tensor - tensor_2 = { - "out_tensor_1": np.random.randn(2, h1), - "out_tensor_2": np.random.randn(2, h2), - } - - self.assertRaises(ValueError, lambda: spec_1.validate(tensor_2)) - - # test additional key in tensor (not mentioned in spec) - tensor_3 = { - "out_tensor_1": np.random.randn(2, h1), - "out_tensor_2": np.random.randn(2, h2), - "out_class_1": TypeClass1(), - "out_class_2": TypeClass1(), - } - - # this should pass because exact_match is False - spec_1.validate(tensor_3, exact_match=False) - - # this should fail because exact_match is True - self.assertRaises( - ValueError, lambda: spec_1.validate(tensor_3, exact_match=True) - ) - - # raise type mismatch - tensor_4 = { - "out_tensor_1": np.random.randn(2, h1), - "out_tensor_2": np.random.randn(2, h2), - "out_class_1": TypeClass2(), - } - - self.assertRaises(ValueError, lambda: spec_1.validate(tensor_4)) - - # test nested specs - spec_2 = SpecDict( - { - "encoder": { - "input": TensorSpec("b, h", h=h1, framework="np"), - "output": TensorSpec("b, h", h=h2, framework="np"), - }, - "decoder": { - "input": TensorSpec("b, h", h=h2, framework="np"), - "output": TensorSpec("b, h", h=h1, framework="np"), - }, - } - ) - - tensor_5 = { - "encoder": { - "input": np.random.randn(2, h1), - "output": np.random.randn(2, h2), - }, - "decoder": { - "input": np.random.randn(2, h2), - "output": np.random.randn(2, h1), - }, - } - - spec_2.validate(tensor_5) - - def test_key_existance_specs(self): - - # One level of keys - spec1 = convert_to_canonical_format(["foo", "bar"]) - - # # This should pass - data1 = {"foo": 1, "bar": 2} - spec1.validate(data1) - - # This should also pass - data2 = {"foo": {"tar": 1}, "bar": 2} - spec1.validate(data2) - - # This should fail - data3 = {"foo": 1} - self.assertRaises(ValueError, lambda: spec1.validate(data3)) - - # nested specs for keys - spec2 = convert_to_canonical_format([("foo", "bar"), "tar"]) - - # This should pass - data4 = {"foo": {"bar": 1}, "tar": 2} - spec2.validate(data4) - - # This should fail - data5 = {"foo": 2, "tar": 2} - self.assertRaises(ValueError, lambda: spec2.validate(data5)) - - # Another way to describe nested specs for keys - spec3 = convert_to_canonical_format({"foo": ["bar"], "tar": None}) - - # This should pass - spec3.validate(data4) - - # This should fail - self.assertRaises(ValueError, lambda: spec3.validate(data5)) - - def test_spec_check_integration(self): - """Tests the integration of SpecDict with the check_input_specs.""" - - class Model: - @property - def nested_key_spec(self): - return ["a", ("b", "c"), ("d",), ("e", "f"), ("e", "g")] - - @property - def dict_key_spec_with_none_leaves(self): - return { - "a": None, - "b": { - "c": None, - }, - "d": None, - "e": { 
- "f": None, - "g": None, - }, - } - - @property - def spec_with_type_and_tensor_leaves(self): - return {"a": TypeClass1, "b": TensorSpec("b, h", h=3, framework="np")} - - @check_input_specs( - "nested_key_spec", - only_check_on_retry=False, - cache=False, - ) - def forward_nested_key(self, input_dict): - return input_dict - - @check_input_specs( - "dict_key_spec_with_none_leaves", only_check_on_retry=False, cache=False - ) - def forward_dict_key_with_none_leaves(self, input_dict): - return input_dict - - @check_input_specs( - "spec_with_type_and_tensor_leaves", only_check_on_retry=False - ) - def forward_spec_with_type_and_tensor_leaves(self, input_dict): - return input_dict - - model = Model() - - # test nested key spec - input_dict_1 = { - "a": 1, - "b": { - "c": 2, - "foo": 3, - }, - "d": 3, - "e": { - "f": 4, - "g": 5, - }, - } - - # should run fine - model.forward_nested_key(input_dict_1) - model.forward_dict_key_with_none_leaves(input_dict_1) - - # test missing key - input_dict_2 = { - "a": 1, - "b": { - "c": 2, - "foo": 3, - }, - "d": 3, - "e": { - "f": 4, - }, - } - - self.assertRaises( - SpecCheckingError, lambda: model.forward_nested_key(input_dict_2) - ) - - self.assertRaises( - SpecCheckingError, - lambda: model.forward_dict_key_with_none_leaves(input_dict_2), - ) - - input_dict_3 = { - "a": TypeClass1(), - "b": np.array([1, 2, 3]), - } - - # should raise shape mismatch - self.assertRaises( - SpecCheckingError, - lambda: model.forward_spec_with_type_and_tensor_leaves(input_dict_3), - ) - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/core/models/specs/tests/test_tensor_spec.py b/rllib/core/models/specs/tests/test_tensor_spec.py deleted file mode 100644 index 11d005489852..000000000000 --- a/rllib/core/models/specs/tests/test_tensor_spec.py +++ /dev/null @@ -1,220 +0,0 @@ -import itertools -import unittest -import numpy as np -from ray.rllib.utils import try_import_jax, try_import_tf, try_import_torch - -from ray.rllib.utils.test_utils import check -from ray.rllib.core.models.specs.specs_base import TensorSpec - -_, tf, _ = try_import_tf() -torch, _ = try_import_torch() -jax, _ = try_import_jax() -jnp = jax.numpy - -# This makes it so that does not convert 64-bit floats to 32-bit -jax.config.update("jax_enable_x64", True) - -FRAMEWORKS_TO_TEST = {"torch", "np", "tf2", "jax"} -DOUBLE_TYPE = { - "torch": torch.float64, - "np": np.float64, - "tf2": tf.float64, - "jax": jnp.float64, -} -FLOAT_TYPE = { - "torch": torch.float32, - "np": np.float32, - "tf2": tf.float32, - "jax": jnp.float32, -} - - -class TestSpecs(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - pass - - def test_fill(self): - - for fw in FRAMEWORKS_TO_TEST: - double_type = DOUBLE_TYPE[fw] - - # if un-specified dims should be 1, dtype is not important - x = TensorSpec("b,h", framework=fw).fill(float(2.0)) - - # check the shape - self.assertEqual(x.shape, (1, 1)) - # check the value - check(x, np.array([[2.0]])) - - x = TensorSpec("b,h", b=2, h=3, framework=fw).fill(2.0) - self.assertEqual(x.shape, (2, 3)) - - x = TensorSpec( - "b,h1,h2,h3", h1=2, h2=3, h3=3, framework=fw, dtype=double_type - ).fill(2) - self.assertEqual(x.shape, (1, 2, 3, 3)) - self.assertEqual(x.dtype, double_type) - - def test_validation(self): - - b, h = 2, 3 - - for fw in FRAMEWORKS_TO_TEST: - double_type = DOUBLE_TYPE[fw] - float_type = FLOAT_TYPE[fw] - - tensor_2d = TensorSpec( - "b,h", b=b, h=h, framework=fw, dtype=double_type - ).fill() - - 
matching_specs = [ - TensorSpec("b,h", framework=fw), - TensorSpec("b,h", h=h, framework=fw), - TensorSpec("b,h", h=h, b=b, framework=fw), - TensorSpec("b,h", b=b, framework=fw, dtype=double_type), - ] - - # check if get_shape returns a tuple of ints - shape = matching_specs[0].get_shape(tensor_2d) - self.assertIsInstance(shape, tuple) - print(fw) - print(shape) - self.assertTrue(all(isinstance(x, int) for x in shape)) - - # check matching - for spec in matching_specs: - spec.validate(tensor_2d) - - non_matching_specs = [ - TensorSpec("b", framework=fw), - TensorSpec("b,h1,h2", framework=fw), - TensorSpec("b,h", h=h + 1, framework=fw), - ] - if fw != "jax": - non_matching_specs.append( - TensorSpec("b,h", framework=fw, dtype=float_type) - ) - - for spec in non_matching_specs: - self.assertRaises(ValueError, lambda: spec.validate(tensor_2d)) - - # non unique dimensions - self.assertRaises(ValueError, lambda: TensorSpec("b,b", framework=fw)) - # unknown dimensions - self.assertRaises( - ValueError, lambda: TensorSpec("b,h", b=1, h=2, c=3, framework=fw) - ) - self.assertRaises(ValueError, lambda: TensorSpec("b1", b2=1, framework=fw)) - # zero dimensions - self.assertRaises( - ValueError, lambda: TensorSpec("b,h", b=1, h=0, framework=fw) - ) - # non-integer dimension - self.assertRaises( - ValueError, lambda: TensorSpec("b,h", b=1, h="h", framework=fw) - ) - - def test_equal(self): - - for fw in FRAMEWORKS_TO_TEST: - spec_eq_1 = TensorSpec("b,h", b=2, h=3, framework=fw) - spec_eq_2 = TensorSpec("b, h", b=2, h=3, framework=fw) - spec_eq_3 = TensorSpec(" b, h", b=2, h=3, framework=fw) - spec_neq_1 = TensorSpec("b, h", h=3, b=3, framework=fw) - spec_neq_2 = TensorSpec( - "b, h", h=3, b=3, framework=fw, dtype=DOUBLE_TYPE[fw] - ) - - self.assertTrue(spec_eq_1 == spec_eq_2) - self.assertTrue(spec_eq_2 == spec_eq_3) - self.assertTrue(spec_eq_1 != spec_neq_1) - self.assertTrue(spec_eq_1 != spec_neq_2) - - def test_type_validation(self): - # check all combinations of spec fws with tensor fws - for spec_fw, tensor_fw in itertools.product( - FRAMEWORKS_TO_TEST, FRAMEWORKS_TO_TEST - ): - - spec = TensorSpec("b, h", b=2, h=3, framework=spec_fw) - tensor = TensorSpec("b, h", b=2, h=3, framework=tensor_fw).fill(0) - - print("spec:", type(spec), ", tensor: ", type(tensor)) - - if spec_fw == tensor_fw: - spec.validate(tensor) - else: - self.assertRaises(ValueError, lambda: spec.validate(tensor)) - - def test_no_framework_arg(self): - """ - Test that a TensorSpec without a framework can be created and used except - for filling. - """ - spec = TensorSpec("b, h", b=2, h=3) - self.assertRaises(ValueError, lambda: spec.fill(0)) - - for fw in FRAMEWORKS_TO_TEST: - tensor = TensorSpec("b, h", b=2, h=3, framework=fw).fill(0) - spec.validate(tensor) - - def test_validate_framework(self): - """ - Test that a TensorSpec with a framework raises an error - when being used with a tensor from a different framework. - """ - for spec_fw, tensor_fw in itertools.product( - FRAMEWORKS_TO_TEST, FRAMEWORKS_TO_TEST - ): - spec = TensorSpec("b, h", b=2, h=3, framework=spec_fw) - tensor = TensorSpec("b, h", b=2, h=3, framework=tensor_fw).fill(0) - if spec_fw == tensor_fw: - spec.validate(tensor) - else: - self.assertRaises(ValueError, lambda: spec.validate(tensor)) - - def test_validate_dtype(self): - """ - Test that a TensorSpec with a dtype raises an error - when being used with a tensor from a different dtype but works otherwise. 
-        """
-
-        all_types = [DOUBLE_TYPE, FLOAT_TYPE]
-
-        for spec_types, tensor_types in itertools.product(all_types, all_types):
-            for spec_fw, tensor_fw in itertools.product(
-                FRAMEWORKS_TO_TEST, FRAMEWORKS_TO_TEST
-            ):
-
-                # Pick the correct types for the frameworks
-                spec_type = spec_types[spec_fw]
-                tensor_type = tensor_types[tensor_fw]
-
-                print(
-                    "\nTesting.." "\nspec_fw: ",
-                    spec_fw,
-                    "\ntensor_fw: ",
-                    tensor_fw,
-                    "\nspec_type: ",
-                    spec_type,
-                    "\ntensor_type: ",
-                    tensor_type,
-                )
-
-                spec = TensorSpec("b, h", b=2, h=3, dtype=spec_type)
-                tensor = TensorSpec(
-                    "b, h", b=2, h=3, framework=tensor_fw, dtype=tensor_type
-                ).fill(0)
-
-                if spec_type != tensor_type:
-                    self.assertRaises(ValueError, lambda: spec.validate(tensor))
-                else:
-                    spec.validate(tensor)
-
-
-if __name__ == "__main__":
-    import pytest
-    import sys
-
-    sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/core/models/tests/test_base_models.py b/rllib/core/models/tests/test_base_models.py
index bd0d1cb20311..bc951aa0f6af 100644
--- a/rllib/core/models/tests/test_base_models.py
+++ b/rllib/core/models/tests/test_base_models.py
@@ -4,9 +4,6 @@
 import gymnasium as gym
 
 from ray.rllib.core.models.configs import ModelConfig
-from ray.rllib.core.models.specs.checker import SpecCheckingError
-from ray.rllib.core.models.specs.specs_base import TensorSpec
-from ray.rllib.core.models.specs.specs_dict import SpecDict
 from ray.rllib.core.models.torch.base import TorchModel
 from ray.rllib.core.rl_module.rl_module import RLModuleSpec
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
@@ -27,150 +24,6 @@
 
 
 class TestModelBase(unittest.TestCase):
-    def test_model_input_spec_checking(self):
-        """Tests if model input spec checking works correctly.
-
-        This test is centered around the `always_check_shapes` flag of the
-        ModelConfig class. If this flag is set to True, the model will always
-        check if the inputs conform to the specs. If this flag is set to False,
-        the model will only check the input if we encounter an error inside
-        the forward call.
-        """
-
-        class CatModel:
-            """Simple model that concatenates parts of its input."""
-
-            def __init__(self, config):
-                super().__init__(config)
-
-            def get_output_specs(self):
-                return SpecDict(
-                    {
-                        "out_1": TensorSpec("b, h", h=1, framework="torch"),
-                        # out_2 is simply 2x stacked in_2
-                        "out_2": TensorSpec("b, h", h=4, framework="torch"),
-                    }
-                )
-
-            def get_input_specs(self):
-                return SpecDict(
-                    {
-                        "in_1": TensorSpec("b, h", h=1, framework="torch"),
-                        "in_2": TensorSpec("b, h", h=2, framework="torch"),
-                    }
-                )
-
-        class TestModel(CatModel, TorchModel):
-            def _forward(self, input_dict):
-                out_2 = torch.cat([input_dict["in_2"], input_dict["in_2"]], dim=1)
-                return {"out_1": input_dict["in_1"], "out_2": out_2}
-
-        @dataclass
-        class CatModelConfig(ModelConfig):
-            def build(self, framework: str):
-                # Since we define the correct model above anyway, we don't need
-                # to distinguish between frameworks here.
-                return TestModel(self)
-
-        # 1) Check if model behaves correctly with always_check_shapes=True first
-        # We expect model to raise an error if the input shapes are not correct.
-        # This is the behaviour we use for debugging with model specs.
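-        # (always_check_shapes=True trades speed for eager validation; with
-        # False, the specs are only validated once the wrapped forward itself
-        # raises. Case 2) below exercises that path.)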
- - config = CatModelConfig(always_check_shapes=True) - - model = config.build(framework="spam") - - # We want to raise an input spec validation error here since the input - # consists of lists and not torch Tensors - with self.assertRaisesRegex(SpecCheckingError, "input spec validation failed"): - model({"in_1": [1], "in_2": [1, 2]}) - - # We don't want to raise an input spec validation error here since the - # input consists of valid tensors - model({"in_1": torch.Tensor([[1]]), "in_2": torch.Tensor([[1, 2]])}) - - # 2) Check if model behaves correctly with always_check_shapes=False. - # We don't expect model to raise an error if the input shapes are not - # correct. - # This is the more performant default behaviour - - config = CatModelConfig(always_check_shapes=False) - - model = config.build(framework="spam") - - # This should not raise an error since the specs are correct and the - # model does not raise an error either. - model({"in_1": torch.Tensor([[1]]), "in_2": torch.Tensor([[1, 2]])}) - - # This should not raise an error since specs would be violated, but they - # are not checked and the model does not raise an error. - model({"in_1": torch.Tensor([[1]]), "in_2": torch.Tensor([[1, 2, 3, 4]])}) - - # We want to raise an input spec validation error here since the model - # raises an exception that stems from inputs that could have been caught - # with input spec checking. - with self.assertRaisesRegex(SpecCheckingError, "input spec validation failed"): - model({"in_1": [1], "in_2": [1, 2]}) - - def test_model_output_spec_checking(self): - """Tests if model output spec checking works correctly. - - This test is centered around the `always_check_shapes` flag of the - ModelConfig class. If this flag is set to True, the model will always - check if the outputs conform to the specs. If this flag is set to False, - the model will never check the outputs. - """ - - class BadModel: - """Simple model that produces bad outputs.""" - - def get_output_specs(self): - return SpecDict( - { - "out": TensorSpec("b, h", h=1, framework="torch"), - } - ) - - def get_input_specs(self): - return SpecDict( - { - "in": TensorSpec("b, h", h=1, framework="torch"), - } - ) - - class TestModel(BadModel, TorchModel): - def _forward(self, input_dict): - return {"out": torch.tensor([[1, 2]])} - - @dataclass - class CatModelConfig(ModelConfig): - def build(self, framework: str): - # Since we define the correct model above anyway, we don't need - # to distinguish between frameworks here. - return TestModel(self) - - # 1) Check if model behaves correctly with always_check_shapes=True first. - # We expect model to raise an error if the output shapes are not correct. - # This is the behaviour we use for debugging with model specs. - - config = CatModelConfig(always_check_shapes=True) - - model = config.build(framework="spam") - - # We want to raise an output spec validation error here since the output - # has the wrong shape - with self.assertRaisesRegex(SpecCheckingError, "output spec validation failed"): - model({"in": torch.Tensor([[1]])}) - - # 2) Check if model behaves correctly with always_check_shapes=False. - # We don't expect model to raise an error. 
- # This is the more performant default behaviour - - config = CatModelConfig(always_check_shapes=False) - - model = config.build(framework="spam") - - model({"in_1": [[1]]}) # Todo (rllib-team): Fix for torch 2.0+ @unittest.skip("Failing with torch >= 2.0") @@ -189,20 +42,6 @@ def __init__(self, config): super().__init__(config) self._model = torch.nn.Linear(1, 1) - def get_output_specs(self): - return SpecDict( - { - "out": TensorSpec("b, h", h=1, framework="torch"), - } - ) - - def get_input_specs(self): - return SpecDict( - { - "in": TensorSpec("b, h", h=1, framework="torch"), - } - ) - def _forward(self, input_dict): return {"out": self._model(input_dict["in"])} diff --git a/rllib/core/models/tests/test_cnn_transpose_heads.py b/rllib/core/models/tests/test_cnn_transpose_heads.py index 2c5a0d13c037..3248ce17b24e 100644 --- a/rllib/core/models/tests/test_cnn_transpose_heads.py +++ b/rllib/core/models/tests/test_cnn_transpose_heads.py @@ -92,7 +92,7 @@ def test_cnn_transpose_heads(self): model_checker = ModelChecker(config) # Add this framework version of the model to our checker. - outputs = model_checker.add(framework="torch") + outputs = model_checker.add(framework="torch", obs=False) self.assertEqual(outputs.shape, (1,) + tuple(expected_output_dims)) # Check all added models against each other. diff --git a/rllib/core/models/tests/test_mlp_heads.py b/rllib/core/models/tests/test_mlp_heads.py index fcdcf0ac9695..d40f9880a5af 100644 --- a/rllib/core/models/tests/test_mlp_heads.py +++ b/rllib/core/models/tests/test_mlp_heads.py @@ -77,7 +77,7 @@ def test_mlp_heads(self): model_checker = ModelChecker(config) # Add this framework version of the model to our checker. - outputs = model_checker.add(framework="torch") + outputs = model_checker.add(framework="torch", obs=False) self.assertEqual(outputs.shape, (1, output_dim)) # Check all added models against each other. diff --git a/rllib/core/models/tests/test_recurrent_encoders.py b/rllib/core/models/tests/test_recurrent_encoders.py index e2ba68be01b7..3ac411bc0945 100644 --- a/rllib/core/models/tests/test_recurrent_encoders.py +++ b/rllib/core/models/tests/test_recurrent_encoders.py @@ -1,6 +1,8 @@ import itertools import unittest +import numpy as np + from ray.rllib.core.columns import Columns from ray.rllib.core.models.base import ENCODER_OUT from ray.rllib.core.models.configs import RecurrentEncoderConfig @@ -54,7 +56,9 @@ def test_gru_encoders(self): model_checker = ModelChecker(config) # Add this framework version of the model to our checker. - outputs = model_checker.add(framework="torch") + outputs = model_checker.add( + framework="torch", state={"h": np.array([num_layers, hidden_dim])} + ) # Output shape: [1=B, 1=T, [output_dim]] self.assertEqual( outputs[ENCODER_OUT].shape, @@ -111,7 +115,13 @@ def test_lstm_encoders(self): model_checker = ModelChecker(config) # Add this framework version of the model to our checker. 
- outputs = model_checker.add(framework="torch") + outputs = model_checker.add( + framework="torch", + state={ + "h": np.array([num_layers, hidden_dim]), + "c": np.array([num_layers, hidden_dim]), + }, + ) # Output shape: [1=B, 1=T, [output_dim]] self.assertEqual( outputs[ENCODER_OUT].shape, diff --git a/rllib/core/models/tf/base.py b/rllib/core/models/tf/base.py index e38137bff46e..48e346812c42 100644 --- a/rllib/core/models/tf/base.py +++ b/rllib/core/models/tf/base.py @@ -6,15 +6,8 @@ from ray.rllib.core.models.base import Model from ray.rllib.core.models.configs import ModelConfig -from ray.rllib.core.models.specs.checker import ( - check_input_specs, - is_input_decorated, - is_output_decorated, - check_output_specs, -) from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf -from ray.util import log_once logger = logging.getLogger(__name__) _, tf, _ = try_import_tf() @@ -32,25 +25,6 @@ def __init__(self, config: ModelConfig): tf.keras.Model.__init__(self) Model.__init__(self, config) - # Raise errors if forward method is not decorated to check input specs. - if not is_input_decorated(self.call): - raise ValueError( - f"`{type(self).__name__}.call()` not decorated with input " - f"specification. Decorate it with @check_input_specs() to define a " - f"specification and resolve this Error. If you don't want to check " - f"anything, you can use an empty spec." - ) - - if is_output_decorated(self.call): - if log_once("tf_model_forward_output_decorated"): - logger.warning( - f"`{type(self).__name__}.call()` decorated with output " - f"specification. This is not recommended because it can lead to " - f"slower execution. Remove @check_output_specs() from the " - f"forward method to resolve this." - ) - - @check_input_specs("input_specs") def call(self, input_dict: dict, **kwargs) -> dict: """Returns the output of this model for the given input. @@ -63,19 +37,6 @@ def call(self, input_dict: dict, **kwargs) -> dict: Returns: dict: The output tensors. """ - - # When `always_check_shapes` is set, we always check input and output specs. - # Note that we check the input specs twice because we need the following - # check to always check the input specs. - if self.config.always_check_shapes: - - @check_input_specs("input_specs", only_check_on_retry=False) - @check_output_specs("output_specs") - def checked_forward(self, input_data, **kwargs): - return self._forward(input_data, **kwargs) - - return checked_forward(self, input_dict, **kwargs) - return self._forward(input_dict, **kwargs) @override(Model) diff --git a/rllib/core/models/tf/encoder.py b/rllib/core/models/tf/encoder.py index ff4956df4a8b..3d280e23cda7 100644 --- a/rllib/core/models/tf/encoder.py +++ b/rllib/core/models/tf/encoder.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Dict import tree # pip install dm_tree @@ -19,8 +19,6 @@ ) from ray.rllib.core.models.tf.base import TfModel from ray.rllib.core.models.tf.primitives import TfMLP, TfCNN -from ray.rllib.core.models.specs.specs_base import Spec, TensorSpec -from ray.rllib.core.models.specs.specs_dict import SpecDict from ray.rllib.models.utils import get_initializer_fn from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf @@ -82,38 +80,6 @@ def __init__(self, config: CNNEncoderConfig) -> None: # Create the network from gathered layers. 
         self.net = tf.keras.Sequential(layers)
 
-    @override(Model)
-    def get_input_specs(self) -> Optional[Spec]:
-        return SpecDict(
-            {
-                Columns.OBS: TensorSpec(
-                    "b, w, h, c",
-                    w=self.config.input_dims[0],
-                    h=self.config.input_dims[1],
-                    c=self.config.input_dims[2],
-                    framework="tf2",
-                ),
-            }
-        )
-
-    @override(Model)
-    def get_output_specs(self) -> Optional[Spec]:
-        return SpecDict(
-            {
-                ENCODER_OUT: (
-                    TensorSpec("b, d", d=self.config.output_dims[0], framework="tf2")
-                    if self.config.flatten_at_end
-                    else TensorSpec(
-                        "b, w, h, c",
-                        w=self.config.output_dims[0],
-                        h=self.config.output_dims[1],
-                        c=self.config.output_dims[2],
-                        framework="tf2",
-                    )
-                )
-            }
-        )
-
     @override(Model)
     def _forward(self, inputs: dict, **kwargs) -> dict:
         return {ENCODER_OUT: self.net(inputs[Columns.OBS])}
@@ -150,26 +116,6 @@ def __init__(self, config: MLPEncoderConfig) -> None:
             output_bias_initializer_config=config.output_layer_bias_initializer_config,
         )
 
-    @override(Model)
-    def get_input_specs(self) -> Optional[Spec]:
-        return SpecDict(
-            {
-                Columns.OBS: TensorSpec(
-                    "b, d", d=self.config.input_dims[0], framework="tf2"
-                ),
-            }
-        )
-
-    @override(Model)
-    def get_output_specs(self) -> Optional[Spec]:
-        return SpecDict(
-            {
-                ENCODER_OUT: TensorSpec(
-                    "b, d", d=self.config.output_dims[0], framework="tf2"
-                ),
-            }
-        )
-
     @override(Model)
     def _forward(self, inputs: Dict, **kwargs) -> Dict:
         return {ENCODER_OUT: self.net(inputs[Columns.OBS])}
@@ -235,43 +181,6 @@ def __init__(self, config: RecurrentEncoderConfig) -> None:
                 input_dims = (1, 1, config.hidden_dim)
             self.grus.append(layer)
 
-    @override(Model)
-    def get_input_specs(self) -> Optional[Spec]:
-        return SpecDict(
-            {
-                # b, t for batch major; t, b for time major.
-                Columns.OBS: TensorSpec(
-                    "b, t, d", d=self.config.input_dims[0], framework="tf2"
-                ),
-                Columns.STATE_IN: {
-                    "h": TensorSpec(
-                        "b, l, h",
-                        h=self.config.hidden_dim,
-                        l=self.config.num_layers,
-                        framework="tf2",
-                    ),
-                },
-            }
-        )
-
-    @override(Model)
-    def get_output_specs(self) -> Optional[Spec]:
-        return SpecDict(
-            {
-                ENCODER_OUT: TensorSpec(
-                    "b, t, d", d=self.config.output_dims[0], framework="tf2"
-                ),
-                Columns.STATE_OUT: {
-                    "h": TensorSpec(
-                        "b, l, h",
-                        h=self.config.hidden_dim,
-                        l=self.config.num_layers,
-                        framework="tf2",
-                    ),
-                },
-            }
-        )
-
     @override(Model)
     def get_initial_state(self):
         return {
@@ -366,55 +275,6 @@ def __init__(self, config: RecurrentEncoderConfig) -> None:
                 input_dims = (1, 1, config.hidden_dim)
             self.lstms.append(layer)
 
-    @override(Model)
-    def get_input_specs(self) -> Optional[Spec]:
-        return SpecDict(
-            {
-                # b, t for batch major; t, b for time major.
-                Columns.OBS: TensorSpec(
-                    "b, t, d", d=self.config.input_dims[0], framework="tf2"
-                ),
-                Columns.STATE_IN: {
-                    "h": TensorSpec(
-                        "b, l, h",
-                        h=self.config.hidden_dim,
-                        l=self.config.num_layers,
-                        framework="tf2",
-                    ),
-                    "c": TensorSpec(
-                        "b, l, h",
-                        h=self.config.hidden_dim,
-                        l=self.config.num_layers,
-                        framework="tf2",
-                    ),
-                },
-            }
-        )
-
-    @override(Model)
-    def get_output_specs(self) -> Optional[Spec]:
-        return SpecDict(
-            {
-                ENCODER_OUT: TensorSpec(
-                    "b, t, d", d=self.config.output_dims[0], framework="tf2"
-                ),
-                Columns.STATE_OUT: {
-                    "h": TensorSpec(
-                        "b, l, h",
-                        h=self.config.hidden_dim,
-                        l=self.config.num_layers,
-                        framework="tf2",
-                    ),
-                    "c": TensorSpec(
-                        "b, l, h",
-                        h=self.config.hidden_dim,
-                        l=self.config.num_layers,
-                        framework="tf2",
-                    ),
-                },
-            }
-        )
-
     @override(Model)
     def get_initial_state(self):
         return {
diff --git a/rllib/core/models/tf/heads.py b/rllib/core/models/tf/heads.py
index 823946efc3e8..e92ee5e0577e 100644
--- a/rllib/core/models/tf/heads.py
+++ b/rllib/core/models/tf/heads.py
@@ -1,6 +1,3 @@
-import functools
-from typing import Optional
-
 import numpy as np
 
 from ray.rllib.core.models.base import Model
@@ -9,9 +6,6 @@
     FreeLogStdMLPHeadConfig,
     MLPHeadConfig,
 )
-from ray.rllib.core.models.specs.checker import SpecCheckingError
-from ray.rllib.core.models.specs.specs_base import Spec
-from ray.rllib.core.models.specs.specs_base import TensorSpec
 from ray.rllib.core.models.tf.base import TfModel
 from ray.rllib.core.models.tf.primitives import TfCNNTranspose, TfMLP
 from ray.rllib.models.utils import get_initializer_fn
@@ -21,80 +15,6 @@
 tf1, tf, tfv = try_import_tf()
 
 
-def auto_fold_unfold_time(input_spec: str):
-    """Automatically folds/unfolds the time dimension of a tensor.
-
-    This is useful when calling the model requires a batch dimension only, but the
-    input data has a batch- and a time-dimension. This decorator will automatically
-    fold the time dimension into the batch dimension before calling the model and
-    unfold the batch dimension back into the time dimension after calling the model.
-
-    Args:
-        input_spec: The input spec of the model.
-
-    Returns:
-        A decorator that automatically folds/unfolds the time_dimension if present.
-    """
-
-    def decorator(func):
-        @functools.wraps(func)
-        def wrapper(self, input_data, **kwargs):
-            if not hasattr(self, input_spec):
-                raise ValueError(
-                    "The model must have an input_specs attribute to "
-                    "automatically fold/unfold the time dimension."
-                )
-            if not tf.is_tensor(input_data):
-                raise ValueError(
-                    f"input_data must be a tf.Tensor to fold/unfold "
-                    f"time automatically, but got {type(input_data)}."
-                )
-            # Attempt to fold/unfold the time dimension.
-            actual_shape = tf.shape(input_data)
-            spec = getattr(self, input_spec)
-
-            try:
-                # Validate the input data against the input spec to find out if we
-                # should attempt to fold/unfold the time dimension.
-                spec.validate(input_data)
-            except ValueError as original_error:
-                # Attempt to fold/unfold the time dimension.
-                # Calculate a new shape for the input data.
-                b, t = actual_shape[0], actual_shape[1]
-                other_dims = actual_shape[2:]
-                reshaped_b = b * t
-                new_shape = tf.concat([[reshaped_b], other_dims], axis=0)
-                reshaped_inputs = tf.reshape(input_data, new_shape)
-                try:
-                    spec.validate(reshaped_inputs)
-                except ValueError as new_error:
-                    raise SpecCheckingError(
-                        f"Attempted to call {func} with input data of shape "
-                        f"{actual_shape}. 
RLlib attempts to automatically fold/unfold " - f"the time dimension because {actual_shape} does not match the " - f"input spec {spec}. In an attempt to fold the time " - f"dimensions to possibly fit the input specs of {func}, " - f"RLlib has calculated the new shape {new_shape} and " - f"reshaped the input data to {reshaped_inputs}. However, " - f"the input data still does not match the input spec. " - f"\nOriginal error: \n{original_error}. \nNew error:" - f" \n{new_error}." - ) - # Call the actual wrapped function - outputs = func(self, reshaped_inputs, **kwargs) - # Attempt to unfold the time dimension. - return tf.reshape( - outputs, tf.concat([[b, t], tf.shape(outputs)[1:]], axis=0) - ) - # If above we could validate the spec, we can call the actual wrapped - # function. - return func(self, input_data, **kwargs) - - return wrapper - - return decorator - - class TfMLPHead(TfModel): def __init__(self, config: MLPHeadConfig) -> None: TfModel.__init__(self, config) @@ -130,15 +50,6 @@ def __init__(self, config: MLPHeadConfig) -> None: self.log_std_clip_param = tf.constant([config.log_std_clip_param]) @override(Model) - def get_input_specs(self) -> Optional[Spec]: - return TensorSpec("b, d", d=self.config.input_dims[0], framework="tf2") - - @override(Model) - def get_output_specs(self) -> Optional[Spec]: - return TensorSpec("b, d", d=self.config.output_dims[0], framework="tf2") - - @override(Model) - @auto_fold_unfold_time("input_specs") def _forward(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor: # Only clip the log standard deviations, if the user wants to clip. This # avoids also clipping value heads. @@ -202,15 +113,6 @@ def __init__(self, config: FreeLogStdMLPHeadConfig) -> None: self.log_std_clip_param = tf.constant([config.log_std_clip_param]) @override(Model) - def get_input_specs(self) -> Optional[Spec]: - return TensorSpec("b, d", d=self.config.input_dims[0], framework="tf2") - - @override(Model) - def get_output_specs(self) -> Optional[Spec]: - return TensorSpec("b, d", d=self.config.output_dims[0], framework="tf2") - - @override(Model) - @auto_fold_unfold_time("input_specs") def _forward(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor: # Compute the mean first, then append the log_std. mean = self.net(inputs) @@ -281,21 +183,6 @@ def __init__(self, config: CNNTransposeHeadConfig) -> None: ) @override(Model) - def get_input_specs(self) -> Optional[Spec]: - return TensorSpec("b, d", d=self.config.input_dims[0], framework="tf2") - - @override(Model) - def get_output_specs(self) -> Optional[Spec]: - return TensorSpec( - "b, w, h, c", - w=self.config.output_dims[0], - h=self.config.output_dims[1], - c=self.config.output_dims[2], - framework="tf2", - ) - - @override(Model) - @auto_fold_unfold_time("input_specs") def _forward(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor: # Push through initial dense layer to get dimensions of first "image". 
         out = self.initial_dense(inputs)
diff --git a/rllib/core/models/torch/base.py b/rllib/core/models/torch/base.py
index cdf61f863d1c..a737ed6cfc91 100644
--- a/rllib/core/models/torch/base.py
+++ b/rllib/core/models/torch/base.py
@@ -6,16 +6,9 @@
 
 from ray.rllib.core.models.base import Model
 from ray.rllib.core.models.configs import ModelConfig
-from ray.rllib.core.models.specs.checker import (
-    is_input_decorated,
-    is_output_decorated,
-    check_input_specs,
-    check_output_specs,
-)
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.typing import TensorType
-from ray.util import log_once
 
 torch, nn = try_import_torch()
 
@@ -69,27 +62,6 @@ def __init__(self, config: ModelConfig):
         nn.Module.__init__(self)
         Model.__init__(self, config)
 
-        # Raise errors if forward method is not decorated to check input specs.
-        if not is_input_decorated(self.forward):
-            raise ValueError(
-                f"`{type(self).__name__}.forward()` not decorated with input "
-                f"specification. Decorate it with @check_input_specs() to define a "
-                f"specification and resolve this Error. If you don't want to check "
-                f"anything, you can use an empty spec."
-            )
-
-        if is_output_decorated(self.forward):
-            if log_once("torch_model_forward_output_decorated"):
-                logger.warning(
-                    f"`{type(self).__name__}.forward()` decorated with output "
-                    f"specification. This is not recommended for torch models "
-                    f"that are used with torch.compile() because it breaks "
-                    f"torch dynamo's graph. This can lead to slower execution. "
-                    f"Remove @check_output_specs() from the forward() method to "
-                    f"resolve this."
-                )
-
-    @check_input_specs("input_specs")
     def forward(
         self, inputs: Union[dict, TensorType], **kwargs
     ) -> Union[dict, TensorType]:
@@ -104,19 +76,6 @@ def forward(
         Returns:
             dict: The output tensors.
         """
-
-        # When `always_check_shapes` is set, we always check input and output specs.
-        # Note that we check the input specs twice because we need the following
-        # check to always check the input specs.
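-        # (By default, the @check_input_specs decorator on `forward` validates
-        # only when the wrapped call raises; `only_check_on_retry=False` below
-        # forces an unconditional check.)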
-        if self.config.always_check_shapes:
-
-            @check_input_specs("input_specs", only_check_on_retry=False)
-            @check_output_specs("output_specs")
-            def checked_forward(self, input_data, **kwargs):
-                return self._forward(input_data, **kwargs)
-
-            return checked_forward(self, inputs, **kwargs)
-
         return self._forward(inputs, **kwargs)
 
     @override(Model)
diff --git a/rllib/core/models/torch/encoder.py b/rllib/core/models/torch/encoder.py
index f9e59bdc6f2f..82812e43fc61 100644
--- a/rllib/core/models/torch/encoder.py
+++ b/rllib/core/models/torch/encoder.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
 import tree
 
 from ray.rllib.core.columns import Columns
@@ -16,9 +14,6 @@
     MLPEncoderConfig,
     RecurrentEncoderConfig,
 )
-from ray.rllib.core.models.specs.specs_base import Spec
-from ray.rllib.core.models.specs.specs_base import TensorSpec
-from ray.rllib.core.models.specs.specs_dict import SpecDict
 from ray.rllib.core.models.torch.base import TorchModel
 from ray.rllib.core.models.torch.primitives import TorchMLP, TorchCNN
 from ray.rllib.models.utils import get_initializer_fn
@@ -79,26 +74,6 @@ def __init__(self, config: MLPEncoderConfig) -> None:
             output_bias_initializer_config=config.output_layer_bias_initializer_config,
         )
 
-    @override(Model)
-    def get_input_specs(self) -> Optional[Spec]:
-        return SpecDict(
-            {
-                Columns.OBS: TensorSpec(
-                    "b, d", d=self.config.input_dims[0], framework="torch"
-                ),
-            }
-        )
-
-    @override(Model)
-    def get_output_specs(self) -> Optional[Spec]:
-        return SpecDict(
-            {
-                ENCODER_OUT: TensorSpec(
-                    "b, d", d=self.config.output_dims[0], framework="torch"
-                ),
-            }
-        )
-
     @override(Model)
     def _forward(self, inputs: dict, **kwargs) -> dict:
         return {ENCODER_OUT: self.net(inputs[Columns.OBS])}
@@ -131,38 +106,6 @@ def __init__(self, config: CNNEncoderConfig) -> None:
         # Create the network from gathered layers.
         self.net = nn.Sequential(*layers)
 
-    @override(Model)
-    def get_input_specs(self) -> Optional[Spec]:
-        return SpecDict(
-            {
-                Columns.OBS: TensorSpec(
-                    "b, w, h, c",
-                    w=self.config.input_dims[0],
-                    h=self.config.input_dims[1],
-                    c=self.config.input_dims[2],
-                    framework="torch",
-                ),
-            }
-        )
-
-    @override(Model)
-    def get_output_specs(self) -> Optional[Spec]:
-        return SpecDict(
-            {
-                ENCODER_OUT: (
-                    TensorSpec("b, d", d=self.config.output_dims[0], framework="torch")
-                    if self.config.flatten_at_end
-                    else TensorSpec(
-                        "b, w, h, c",
-                        w=self.config.output_dims[0],
-                        h=self.config.output_dims[1],
-                        c=self.config.output_dims[2],
-                        framework="torch",
-                    )
-                )
-            }
-        )
-
     @override(Model)
     def _forward(self, inputs: dict, **kwargs) -> dict:
         return {ENCODER_OUT: self.net(inputs[Columns.OBS])}
@@ -218,45 +161,6 @@ def __init__(self, config: RecurrentEncoderConfig) -> None:
             self.gru.weight, **config.hidden_bias_initializer_config or {}
         )
 
-    @override(Model)
-    def get_input_specs(self) -> Optional[Spec]:
-        return SpecDict(
-            {
-                # b, t for batch major; t, b for time major.
- Columns.OBS: TensorSpec( - "b, t, d", - d=self.config.input_dims[0], - framework="torch", - ), - Columns.STATE_IN: { - "h": TensorSpec( - "b, l, h", - h=self.config.hidden_dim, - l=self.config.num_layers, - framework="torch", - ), - }, - } - ) - - @override(Model) - def get_output_specs(self) -> Optional[Spec]: - return SpecDict( - { - ENCODER_OUT: TensorSpec( - "b, t, d", d=self.config.output_dims[0], framework="torch" - ), - Columns.STATE_OUT: { - "h": TensorSpec( - "b, l, h", - h=self.config.hidden_dim, - l=self.config.num_layers, - framework="torch", - ), - }, - } - ) - @override(Model) def get_initial_state(self): return { @@ -346,44 +250,6 @@ def __init__(self, config: RecurrentEncoderConfig) -> None: layer[3], **config.hidden_bias_initializer_config or {} ) - self._state_in_out_spec = { - "h": TensorSpec( - "b, l, d", - d=self.config.hidden_dim, - l=self.config.num_layers, - framework="torch", - ), - "c": TensorSpec( - "b, l, d", - d=self.config.hidden_dim, - l=self.config.num_layers, - framework="torch", - ), - } - - @override(Model) - def get_input_specs(self) -> Optional[Spec]: - return SpecDict( - { - # b, t for batch major; t, b for time major. - Columns.OBS: TensorSpec( - "b, t, d", d=self.config.input_dims[0], framework="torch" - ), - Columns.STATE_IN: self._state_in_out_spec, - } - ) - - @override(Model) - def get_output_specs(self) -> Optional[Spec]: - return SpecDict( - { - ENCODER_OUT: TensorSpec( - "b, t, d", d=self.config.output_dims[0], framework="torch" - ), - Columns.STATE_OUT: self._state_in_out_spec, - } - ) - @override(Model) def get_initial_state(self): return { diff --git a/rllib/core/models/torch/heads.py b/rllib/core/models/torch/heads.py index 2e9e23cd969a..844c40c4a44c 100644 --- a/rllib/core/models/torch/heads.py +++ b/rllib/core/models/torch/heads.py @@ -1,6 +1,3 @@ -import functools -from typing import Optional - import numpy as np from ray.rllib.core.models.base import Model @@ -9,9 +6,6 @@ FreeLogStdMLPHeadConfig, MLPHeadConfig, ) -from ray.rllib.core.models.specs.checker import SpecCheckingError -from ray.rllib.core.models.specs.specs_base import Spec -from ray.rllib.core.models.specs.specs_base import TensorSpec from ray.rllib.core.models.torch.base import TorchModel from ray.rllib.core.models.torch.primitives import TorchCNNTranspose, TorchMLP from ray.rllib.models.utils import get_initializer_fn @@ -21,78 +15,6 @@ torch, nn = try_import_torch() -def auto_fold_unfold_time(input_spec: str): - """Automatically folds/unfolds the time dimension of a tensor. - - This is useful when calling the model requires a batch dimension only, but the - input data has a batch- and a time-dimension. This decorator will automatically - fold the time dimension into the batch dimension before calling the model and - unfold the batch dimension back into the time dimension after calling the model. - - Args: - input_spec: The input spec of the model. - - Returns: - A decorator that automatically folds/unfolds the time_dimension if present. - """ - - def decorator(func): - @functools.wraps(func) - def wrapper(self, input_data, **kwargs): - if not hasattr(self, input_spec): - raise ValueError( - "The model must have an input_specs attribute to " - "automatically fold/unfold the time dimension." - ) - if not torch.is_tensor(input_data): - raise ValueError( - f"input_data must be a torch.Tensor to fold/unfold " - f"time automatically, but got {type(input_data)}." - ) - # Attempt to fold/unfold the time dimension. 
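-            # (Illustration: a [B=2, T=5, 16] input that fails a [B, 16] spec is
-            # folded to [10, 16], run through the wrapped forward, and the output
-            # is unfolded back to a leading [2, 5, ...] shape.)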
-            actual_shape = list(input_data.shape)
-            spec = getattr(self, input_spec)
-
-            try:
-                # Validate the input data against the input spec to find out if we
-                # should attempt to fold/unfold the time dimension.
-                spec.validate(input_data)
-            except ValueError as original_error:
-                # Attempt to fold/unfold the time dimension.
-                # Calculate a new shape for the input data.
-                b, t = actual_shape[:2]
-                other_dims = actual_shape[2:]
-                reshaped_b = b * t
-                new_shape = tuple([reshaped_b] + other_dims)
-                reshaped_inputs = input_data.reshape(new_shape)
-                try:
-                    spec.validate(reshaped_inputs)
-                except ValueError as new_error:
-                    raise SpecCheckingError(
-                        f"Attempted to call {func} with input data of shape "
-                        f"{actual_shape}. RLlib attempts to automatically fold/unfold "
-                        f"the time dimension because {actual_shape} does not match the "
-                        f"input spec {spec}. In an attempt to fold the time "
-                        f"dimensions to possibly fit the input specs of {func}, "
-                        f"RLlib has calculated the new shape {new_shape} and "
-                        f"reshaped the input data to {reshaped_inputs}. However, "
-                        f"the input data still does not match the input spec. "
-                        f"\nOriginal error: \n{original_error}. \nNew error:"
-                        f" \n{new_error}."
-                    )
-                # Call the actual wrapped function
-                outputs = func(self, reshaped_inputs, **kwargs)
-                # Attempt to unfold the time dimension.
-                return outputs.reshape((b, t) + tuple(outputs.shape[1:]))
-            # If above we could validate the spec, we can call the actual wrapped
-            # function.
-            return func(self, input_data, **kwargs)
-
-        return wrapper
-
-    return decorator
-
-
 class TorchMLPHead(TorchModel):
     def __init__(self, config: MLPHeadConfig) -> None:
         super().__init__(config)
@@ -130,15 +52,6 @@ def __init__(self, config: MLPHeadConfig) -> None:
         self.register_buffer("log_std_clip_param_const", self.log_std_clip_param)
 
     @override(Model)
-    def get_input_specs(self) -> Optional[Spec]:
-        return TensorSpec("b, d", d=self.config.input_dims[0], framework="torch")
-
-    @override(Model)
-    def get_output_specs(self) -> Optional[Spec]:
-        return TensorSpec("b, d", d=self.config.output_dims[0], framework="torch")
-
-    @override(Model)
-    @auto_fold_unfold_time("input_specs")
     def _forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
         # Only clip the log standard deviations, if the user wants to clip. This
         # avoids also clipping value heads.
@@ -203,15 +116,6 @@ def __init__(self, config: FreeLogStdMLPHeadConfig) -> None:
         self.register_buffer("log_std_clip_param_const", self.log_std_clip_param)
 
     @override(Model)
-    def get_input_specs(self) -> Optional[Spec]:
-        return TensorSpec("b, d", d=self.config.input_dims[0], framework="torch")
-
-    @override(Model)
-    def get_output_specs(self) -> Optional[Spec]:
-        return TensorSpec("b, d", d=self.config.output_dims[0], framework="torch")
-
-    @override(Model)
-    @auto_fold_unfold_time("input_specs")
     def _forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
         # Compute the mean first, then append the log_std.
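        # (In the "free" variant, the log-std is a standalone learned parameter
        # rather than a function of the input, which is why it is appended here.)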
         mean = self.net(inputs)
@@ -283,21 +187,6 @@ def __init__(self, config: CNNTransposeHeadConfig) -> None:
         )
 
     @override(Model)
-    def get_input_specs(self) -> Optional[Spec]:
-        return TensorSpec("b, d", d=self.config.input_dims[0], framework="torch")
-
-    @override(Model)
-    def get_output_specs(self) -> Optional[Spec]:
-        return TensorSpec(
-            "b, w, h, c",
-            w=self.config.output_dims[0],
-            h=self.config.output_dims[1],
-            c=self.config.output_dims[2],
-            framework="torch",
-        )
-
-    @override(Model)
-    @auto_fold_unfold_time("input_specs")
     def _forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
         out = self.initial_dense(inputs)
         # Reshape to initial 3D (image-like) format to enter CNN transpose stack.
diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py
index 1ddd33471a29..fe1ed3ec72b5 100644
--- a/rllib/core/rl_module/rl_module.py
+++ b/rllib/core/rl_module/rl_module.py
@@ -7,17 +7,11 @@
 from ray.rllib.core import DEFAULT_MODULE_ID
 from ray.rllib.core.columns import Columns
 from ray.rllib.core.models.specs.typing import SpecType
-from ray.rllib.core.models.specs.checker import (
-    check_input_specs,
-    check_output_specs,
-    convert_to_canonical_format,
-)
 from ray.rllib.models.distributions import Distribution
 from ray.rllib.utils.annotations import (
     ExperimentalAPI,
     override,
     OverrideToImplementCustomLogic,
-    OverrideToImplementCustomLogic_CallToSuperRecommended,
 )
 from ray.rllib.utils.checkpoints import Checkpointable
 from ray.rllib.utils.deprecation import Deprecated
@@ -411,60 +405,18 @@ def __init__(self, config: RLModuleConfig):
             framework=self.framework
         )
 
-        # Make sure `setup()` is only called once, no matter what. In some cases
-        # of multiple inheritance (and with our __post_init__ functionality in
-        # place), this might get called twice.
+        # Make sure `setup()` is only called once, no matter what.
         if hasattr(self, "_is_setup") and self._is_setup:
             raise RuntimeError(
                 "`RLModule.setup()` called twice within your RLModule implementation "
                 f"{self}! Make sure you are using the proper inheritance order "
                 "(TorchRLModule before [Algo]RLModule) or (TfRLModule before "
-                "[Algo]RLModule) and that you are using `super().__init__(...)` in "
-                "your custom constructor."
+                "[Algo]RLModule) and that you are NOT overriding the constructor, but "
+                "only the `setup()` method of your subclass."
             )
         self.setup()
         self._is_setup = True
 
-    def __init_subclass__(cls, **kwargs):
-        # Automatically add a __post_init__ method to all subclasses of RLModule.
-        # This method is called after the __init__ method of the subclass.
-        def init_decorator(previous_init):
-            def new_init(self, *args, **kwargs):
-                previous_init(self, *args, **kwargs)
-                if type(self) is cls:
-                    self.__post_init__()
-
-            return new_init
-
-        cls.__init__ = init_decorator(cls.__init__)
-
-    def __post_init__(self):
-        """Called automatically after the __init__ method of the subclass.
-
-        The module first calls the __init__ method of the subclass. Within
-        __init__, you should call the super().__init__ method. Then, after the
-        __init__ method of the subclass is called, the __post_init__ method is
-        called.
-
-        This is a good place to do any initialization that requires access to the
-        subclass's attributes.
-        """
-        self._input_specs_train = convert_to_canonical_format(self.input_specs_train())
-        self._output_specs_train = convert_to_canonical_format(
-            self.output_specs_train()
-        )
-        self._input_specs_exploration = convert_to_canonical_format(
-            self.input_specs_exploration()
-        )
-        self._output_specs_exploration = convert_to_canonical_format(
-            self.output_specs_exploration()
-        )
-        self._input_specs_inference = convert_to_canonical_format(
-            self.input_specs_inference()
-        )
-        self._output_specs_inference = convert_to_canonical_format(
-            self.output_specs_inference()
-        )
-
     @OverrideToImplementCustomLogic
     def setup(self):
         """Sets up the components of the module.
@@ -543,8 +495,6 @@ def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
         """
         return {}
 
-    @check_input_specs("_input_specs_inference")
-    @check_output_specs("_output_specs_inference")
     def forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
         """DO NOT OVERRIDE! Forward-pass during evaluation, called from the sampler.
@@ -574,8 +524,6 @@ def _forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
         """
         return self._forward(batch, **kwargs)
 
-    @check_input_specs("_input_specs_exploration")
-    @check_output_specs("_output_specs_exploration")
     def forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
         """DO NOT OVERRIDE! Forward-pass during exploration, called from the sampler.
@@ -605,8 +553,6 @@ def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any
         """
         return self._forward(batch, **kwargs)
 
-    @check_input_specs("_input_specs_train")
-    @check_output_specs("_output_specs_train")
     def forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
         """DO NOT OVERRIDE! Forward-pass during training, called from the learner.
@@ -709,48 +655,6 @@ def get_ctor_args_and_kwargs(self):
             {},  # **kwargs
         )
 
-    @OverrideToImplementCustomLogic_CallToSuperRecommended
-    def output_specs_inference(self) -> SpecType:
-        """Returns the output specs of the `forward_inference()` method.
-
-        Override this method to customize the output specs of the inference call.
-        The default implementation requires the `forward_inference()` method to return
-        a dict that has an `action_dist` key whose value is an instance of
-        `Distribution`.
-        """
-        return [Columns.ACTION_DIST_INPUTS]
-
-    @OverrideToImplementCustomLogic_CallToSuperRecommended
-    def output_specs_exploration(self) -> SpecType:
-        """Returns the output specs of the `forward_exploration()` method.
-
-        Override this method to customize the output specs of the exploration call.
-        The default implementation requires the `forward_exploration()` method to
-        return a dict that has an `action_dist` key whose value is an instance of
-        `Distribution`.
-        """
-        return [Columns.ACTION_DIST_INPUTS]
-
-    def output_specs_train(self) -> SpecType:
-        """Returns the output specs of the forward_train method."""
-        return {}
-
-    def input_specs_inference(self) -> SpecType:
-        """Returns the input specs of the forward_inference method."""
-        return self._default_input_specs()
-
-    def input_specs_exploration(self) -> SpecType:
-        """Returns the input specs of the forward_exploration method."""
-        return self._default_input_specs()
-
-    def input_specs_train(self) -> SpecType:
-        """Returns the input specs of the forward_train method."""
-        return self._default_input_specs()
-
-    def _default_input_specs(self) -> SpecType:
-        """Returns the default input specs."""
-        return [Columns.OBS]
-
     def as_multi_rl_module(self) -> "MultiRLModule":
         """Returns a multi-agent wrapper around this module."""
         from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule
@@ -785,3 +689,29 @@ def load_state(self, *args, **kwargs):
     @Deprecated(new="RLModule.save_to_path(...)", error=True)
     def save_to_checkpoint(self, *args, **kwargs):
         pass
+
+    def output_specs_inference(self) -> SpecType:
+        return [Columns.ACTION_DIST_INPUTS]
+
+    def output_specs_exploration(self) -> SpecType:
+        return [Columns.ACTION_DIST_INPUTS]
+
+    def output_specs_train(self) -> SpecType:
+        """Returns the output specs of the forward_train method."""
+        return {}
+
+    def input_specs_inference(self) -> SpecType:
+        """Returns the input specs of the forward_inference method."""
+        return self._default_input_specs()
+
+    def input_specs_exploration(self) -> SpecType:
+        """Returns the input specs of the forward_exploration method."""
+        return self._default_input_specs()
+
+    def input_specs_train(self) -> SpecType:
+        """Returns the input specs of the forward_train method."""
+        return self._default_input_specs()
+
+    def _default_input_specs(self) -> SpecType:
+        """Returns the default input specs."""
+        return [Columns.OBS]
diff --git a/rllib/examples/custom_recurrent_rnn_tokenizer.py b/rllib/examples/custom_recurrent_rnn_tokenizer.py
index f41c432a3519..fd7bab9edab5 100644
--- a/rllib/examples/custom_recurrent_rnn_tokenizer.py
+++ b/rllib/examples/custom_recurrent_rnn_tokenizer.py
@@ -21,8 +21,6 @@
 from ray.rllib.core.rl_module.rl_module import RLModuleSpec
 from ray.rllib.policy.sample_batch import SampleBatch
 from dataclasses import dataclass
-from ray.rllib.core.models.specs.specs_dict import SpecDict
-from ray.rllib.core.models.specs.specs_base import TensorSpec
 from ray.rllib.core.models.base import Encoder, ENCODER_OUT
 from ray.rllib.core.models.torch.base import TorchModel
 from ray.rllib.core.models.tf.base import TfModel
@@ -85,16 +83,6 @@ def __init__(self, config) -> None:
             nn.Linear(config.input_dims[0], config.output_dims[0]),
         )
 
-    # Since we use this model as a tokenizer, we need to define its output
-    # dimensions so that we know the input dim for the recurrent cells that follow.
-    def get_output_specs(self):
-        # In this example, the output dim will be 64, but we still fetch it from
-        # config so that this code is more reusable.
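-        # (E.g., with output_dims=[64] the recurrent cells that consume the
-        # tokenizer output are built with an input size of 64.)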
- output_dim = self.config.output_dims[0] - return SpecDict( - {ENCODER_OUT: TensorSpec("b, d", d=output_dim, framework="torch")} - ) - def _forward(self, inputs: dict, **kwargs): return {ENCODER_OUT: self.net(inputs[SampleBatch.OBS])} @@ -111,12 +99,6 @@ def __init__(self, config) -> None: ] ) - def get_output_specs(self): - output_dim = self.config.output_dims[0] - return SpecDict( - {ENCODER_OUT: TensorSpec("b, d", d=output_dim, framework="tf2")} - ) - def _forward(self, inputs: dict, **kwargs): return {ENCODER_OUT: self.net(inputs[SampleBatch.OBS])} diff --git a/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py b/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py index bbfcb6982151..2263ab740500 100644 --- a/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py +++ b/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py @@ -5,7 +5,6 @@ from ray.rllib.core import Columns from ray.rllib.core.models.base import ENCODER_OUT from ray.rllib.core.models.configs import MLPHeadConfig -from ray.rllib.core.models.specs.specs_dict import SpecDict from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule @@ -97,23 +96,6 @@ def setup(self): ) self.vf = vf_config.build(framework=self.framework) - @override(RLModule) - def output_specs_inference(self) -> SpecDict: - return [Columns.ACTIONS] - - @override(RLModule) - def output_specs_exploration(self) -> SpecDict: - return [Columns.ACTION_DIST_INPUTS, Columns.ACTIONS, Columns.ACTION_LOGP] - - @override(RLModule) - def output_specs_train(self) -> SpecDict: - return [ - Columns.ACTION_DIST_INPUTS, - Columns.ACTIONS, - Columns.ACTION_LOGP, - Columns.VF_PREDS, - ] - @abstractmethod def pi(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]: """Computes the policy outputs given a batch of data. diff --git a/rllib/models/utils.py b/rllib/models/utils.py index ac23a09e6fb2..c57b94bbfd18 100644 --- a/rllib/models/utils.py +++ b/rllib/models/utils.py @@ -1,81 +1,9 @@ from typing import Callable, Optional, Union -from ray.rllib.core.models.specs.specs_base import TensorSpec - -from ray.rllib.core.models.specs.specs_dict import SpecDict from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.framework import try_import_jax, try_import_tf, try_import_torch -@DeveloperAPI -def input_to_output_specs( - input_specs, - num_input_feature_dims: int, - output_key: str, - output_feature_spec: TensorSpec, -): - """Convert an input spec to an output spec, based on a module. - - Drops the feature dimension(s) from an input_specs, replacing them with - output_feature_spec dimension(s). 
-
-    Examples:
-        input_to_output_specs(
-            input_specs=SpecDict({
-                "bork": "batch, time, feature0",
-                "dork": "batch, time, feature1"
-                }, feature0=2, feature1=3
-            ),
-            num_input_feature_dims=1,
-            output_key="outer_product",
-            output_feature_spec=TensorSpec("row, col", row=2, col=3)
-        )
-
-        will return:
-        SpecDict({"outer_product": "batch, time, row, col"}, row=2, col=3)
-
-        input_to_output_specs(
-            input_specs=SpecDict({
-                "bork": "batch, time, h, w, c",
-            }, h=32, w=32, c=3,
-            ),
-            num_input_feature_dims=3,
-            output_key="latent_image_representation",
-            output_feature_spec=TensorSpec("feature", feature=128)
-        )
-
-        will return:
-        SpecDict({"latent_image_representation": "batch, time, feature"}, feature=128)
-
-
-    Args:
-        input_specs: SpecDict describing input to a specified module
-        num_input_feature_dims: How many feature dimensions the module will process.
-            E.g. a linear layer will only process the last dimension (1), while a CNN
-            might process the last two dimensions (2)
-        output_key: The key in the output spec we will write the resulting shape to
-        output_feature_spec: A spec denoting the feature dimensions output by a
-            specified module
-
-    Returns:
-        A SpecDict based on the input_specs, with the trailing dimensions replaced
-        by the output_feature_spec
-
-    """
-    assert num_input_feature_dims >= 1, "Must specify at least one feature dim"
-    num_dims = [len(v.shape) for v in input_specs.values()]
-    assert all(
-        nd == num_dims[0] for nd in num_dims
-    ), "All specs in input_specs must have the same number of dimensions"
-
-    # All keys in input should have the same number of dims,
-    # so it doesn't matter which key we use.
-    key = list(input_specs.keys())[0]
-    batch_spec = input_specs[key].rdrop(num_input_feature_dims)
-    full_spec = batch_spec.append(output_feature_spec)
-    return SpecDict({output_key: full_spec})
-
-
 @DeveloperAPI
 def get_activation_fn(
     name: Optional[Union[Callable, str]] = None,
diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py
index 7ffe87469557..7bbb3bf63f03 100644
--- a/rllib/utils/test_utils.py
+++ b/rllib/utils/test_utils.py
@@ -32,7 +32,7 @@
 from ray.air.constants import TRAINING_ITERATION
 from ray.air.integrations.wandb import WandbLoggerCallback, WANDB_ENV_VAR
 from ray.rllib.common import SupportedFileType
-from ray.rllib.core import DEFAULT_MODULE_ID
+from ray.rllib.core import DEFAULT_MODULE_ID, Columns
 from ray.rllib.env.wrappers.atari_wrappers import is_atari, wrap_deepmind
 from ray.rllib.train import load_experiments_from_file
 from ray.rllib.utils.annotations import OldAPIStack
@@ -1833,36 +1833,26 @@ def __init__(self, config):
         # Dict of models to check against each other.
         self.models = {}
 
-    def add(self, framework: str = "torch") -> Any:
+    def add(self, framework: str = "torch", obs=True, state=False) -> Any:
         """Builds a new Model for the given framework."""
         model = self.models[framework] = self.config.build(framework=framework)
 
         # Pass a B=1 observation through the model.
-        from ray.rllib.core.models.specs.specs_dict import SpecDict
-
-        if isinstance(model.input_specs, SpecDict):
-            # inputs = {}
+        inputs = np.full(
+            [1] + ([1] if state else []) + list(self.config.input_dims),
+            self.random_fill_input_value,
+        )
+        if obs:
+            inputs = {Columns.OBS: inputs}
+        if state:
+            inputs[Columns.STATE_IN] = tree.map_structure(
+                lambda s: np.zeros(shape=[1] + list(s)), state
+            )
+        if framework == "torch":
+            from ray.rllib.utils.torch_utils import convert_to_torch_tensor
 
-            def _fill(s):
-                if s is not None:
-                    return s.fill(self.random_fill_input_value)
-                else:
-                    return None
-
-            inputs = tree.map_structure(_fill, dict(model.input_specs))
-            # for key, spec in model.input_specs.items():
-            #     dict_ = inputs
-            #     for i, sub_key in enumerate(key):
-            #         if sub_key not in dict_:
-            #             dict_[sub_key] = {}
-            #         if i < len(key) - 1:
-            #             dict_ = dict_[sub_key]
-            #     if spec is not None:
-            #         dict_[sub_key] = spec.fill(self.random_fill_input_value)
-            #     else:
-            #         dict_[sub_key] = None
-        else:
-            inputs = model.input_specs.fill(self.random_fill_input_value)
+            inputs = convert_to_torch_tensor(inputs)
+        # Previously, inputs were built from the (now removed) specs via
+        # model.input_specs.fill(self.random_fill_input_value).
 
         outputs = model(inputs)
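To make the reworked `ModelChecker.add()` above concrete, here is a small, self-contained numpy sketch of the inputs it now builds for the three call patterns used in the tests (head models, feed-forward encoders, recurrent encoders). The dimensions, fill value, and local variable names below are illustrative assumptions only, not RLlib API.

import numpy as np

input_dims, fill = [4], 0.5  # assumed example values

# 1) Head models (obs=False, state=False): a bare [B=1, 4] array.
head_inputs = np.full([1] + input_dims, fill)

# 2) Feed-forward encoders (obs=True, state=False): {"obs": [B=1, 4]}.
encoder_inputs = {"obs": np.full([1] + input_dims, fill)}

# 3) Recurrent encoders (obs=True, state={...}): a T=1 axis is added to the
#    observation, and every state leaf is zero-filled with a leading batch dim,
#    e.g. {"obs": [1, 1, 4], "state_in": {"h": [1, num_layers, hidden_dim]}}.
state_shapes = {"h": np.array([2, 8])}  # shape pairs as in test_gru_encoders
recurrent_inputs = {
    "obs": np.full([1, 1] + input_dims, fill),
    "state_in": {k: np.zeros([1] + list(s)) for k, s in state_shapes.items()},
}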