diff --git a/rllib/BUILD b/rllib/BUILD index 21a2c675da4e..f074a42b0075 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1916,10 +1916,10 @@ py_test( # test ModelSpec py_test( - name = "test_model_spec", + name = "test_spec_dict", tags = ["team:rllib", "models"], size = "small", - srcs = ["models/specs/tests/test_model_spec.py"] + srcs = ["models/specs/tests/test_spec_dict.py"] ) # test TorchVectorEncoder diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index cba78d3b41a2..4331f046a2f4 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -8,7 +8,7 @@ from ray.rllib.utils.annotations import override from ray.rllib.utils.nested_dict import NestedDict from ray.rllib.utils.framework import try_import_torch -from ray.rllib.models.specs.specs_dict import ModelSpec +from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.models.specs.specs_torch import TorchTensorSpec from ray.rllib.models.torch.torch_distributions import ( TorchCategorical, @@ -227,12 +227,12 @@ def get_initial_state(self) -> NestedDict: return {} @override(RLModule) - def input_specs_inference(self) -> ModelSpec: + def input_specs_inference(self) -> SpecDict: return self.input_specs_exploration() @override(RLModule) - def output_specs_inference(self) -> ModelSpec: - return ModelSpec({SampleBatch.ACTION_DIST: TorchDeterministic}) + def output_specs_inference(self) -> SpecDict: + return SpecDict({SampleBatch.ACTION_DIST: TorchDeterministic}) @override(RLModule) def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]: @@ -255,7 +255,7 @@ def input_specs_exploration(self): return self.shared_encoder.input_spec() @override(RLModule) - def output_specs_exploration(self) -> ModelSpec: + def output_specs_exploration(self) -> SpecDict: specs = {SampleBatch.ACTION_DIST: self.__get_action_dist_type()} if self._is_discrete: specs[SampleBatch.ACTION_DIST_INPUTS] = { @@ -267,7 +267,7 @@ def output_specs_exploration(self) -> ModelSpec: "scale": TorchTensorSpec("b, h", h=self.config.action_space.shape[0]), } - return ModelSpec(specs) + return SpecDict(specs) @override(RLModule) def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: @@ -299,7 +299,7 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: return output @override(RLModule) - def input_specs_train(self) -> ModelSpec: + def input_specs_train(self) -> SpecDict: if self._is_discrete: action_spec = TorchTensorSpec("b") else: @@ -310,12 +310,12 @@ def input_specs_train(self) -> ModelSpec: spec_dict.update({SampleBatch.ACTIONS: action_spec}) if SampleBatch.OBS in spec_dict: spec_dict[SampleBatch.NEXT_OBS] = spec_dict[SampleBatch.OBS] - spec = ModelSpec(spec_dict) + spec = SpecDict(spec_dict) return spec @override(RLModule) - def output_specs_train(self) -> ModelSpec: - spec = ModelSpec( + def output_specs_train(self) -> SpecDict: + spec = SpecDict( { SampleBatch.ACTION_DIST: self.__get_action_dist_type(), SampleBatch.ACTION_LOGP: TorchTensorSpec("b", dtype=torch.float32), diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py index 2b5e02bed9ae..09c71417157e 100644 --- a/rllib/core/rl_module/encoder.py +++ b/rllib/core/rl_module/encoder.py @@ -7,7 +7,8 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.rnn_sequencing import add_time_dimension -from ray.rllib.models.specs.specs_dict import ModelSpec, check_specs +from 
ray.rllib.models.specs.specs_dict import SpecDict +from ray.rllib.models.specs.checker import check_input_specs, check_output_specs from ray.rllib.models.specs.specs_torch import TorchTensorSpec from ray.rllib.models.torch.primitives import FCNet @@ -66,12 +67,13 @@ def get_inital_state(self): raise [] def input_spec(self): - return ModelSpec() + return SpecDict() def output_spec(self): - return ModelSpec() + return SpecDict() - @check_specs(input_spec="_input_spec", output_spec="_output_spec") + @check_input_specs("_input_spec") + @check_output_specs("_output_spec") def forward(self, input_dict): return self._forward(input_dict) @@ -91,12 +93,12 @@ def __init__(self, config: FCConfig) -> None: ) def input_spec(self): - return ModelSpec( + return SpecDict( {SampleBatch.OBS: TorchTensorSpec("b, h", h=self.config.input_dim)} ) def output_spec(self): - return ModelSpec( + return SpecDict( {"embedding": TorchTensorSpec("b, h", h=self.config.output_dim)} ) @@ -125,7 +127,7 @@ def get_inital_state(self): def input_spec(self): config = self.config - return ModelSpec( + return SpecDict( { # bxt is just a name for better readability to indicated padded batch SampleBatch.OBS: TorchTensorSpec("bxt, h", h=config.input_dim), @@ -142,7 +144,7 @@ def input_spec(self): def output_spec(self): config = self.config - return ModelSpec( + return SpecDict( { "embedding": TorchTensorSpec("bxt, h", h=config.output_dim), "state_out": { @@ -181,10 +183,10 @@ def __init__(self, config: EncoderConfig) -> None: super().__init__(config) def input_spec(self): - return ModelSpec() + return SpecDict() def output_spec(self): - return ModelSpec() + return SpecDict() def _forward(self, input_dict): return input_dict diff --git a/rllib/core/rl_module/marl_module.py b/rllib/core/rl_module/marl_module.py index 43b253c080a2..5caf9059016f 100644 --- a/rllib/core/rl_module/marl_module.py +++ b/rllib/core/rl_module/marl_module.py @@ -5,7 +5,7 @@ from ray.util.annotations import PublicAPI from ray.rllib.utils.annotations import override -from ray.rllib.models.specs.specs_dict import ModelSpec +from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.core.rl_module import RLModule @@ -243,32 +243,32 @@ def __getitem__(self, module_id: ModuleID) -> RLModule: return self._rl_modules[module_id] @override(RLModule) - def output_specs_train(self) -> ModelSpec: + def output_specs_train(self) -> SpecDict: return self._get_specs_for_modules("output_specs_train") @override(RLModule) - def output_specs_inference(self) -> ModelSpec: + def output_specs_inference(self) -> SpecDict: return self._get_specs_for_modules("output_specs_inference") @override(RLModule) - def output_specs_exploration(self) -> ModelSpec: + def output_specs_exploration(self) -> SpecDict: return self._get_specs_for_modules("output_specs_exploration") @override(RLModule) - def input_specs_train(self) -> ModelSpec: + def input_specs_train(self) -> SpecDict: return self._get_specs_for_modules("input_specs_train") @override(RLModule) - def input_specs_inference(self) -> ModelSpec: + def input_specs_inference(self) -> SpecDict: return self._get_specs_for_modules("input_specs_inference") @override(RLModule) - def input_specs_exploration(self) -> ModelSpec: + def input_specs_exploration(self) -> SpecDict: return self._get_specs_for_modules("input_specs_exploration") - def _get_specs_for_modules(self, method_name: str) -> ModelSpec: + def _get_specs_for_modules(self, method_name: str) -> SpecDict: """Returns a 
ModelSpec from the given method_name for all modules.""" - return ModelSpec( + return SpecDict( { module_id: getattr(module, method_name)() for module_id, module in self._rl_modules.items() diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py index 46ec77a0efa2..c66ed068a2f9 100644 --- a/rllib/core/rl_module/rl_module.py +++ b/rllib/core/rl_module/rl_module.py @@ -11,7 +11,8 @@ OverrideToImplementCustomLogic_CallToSuperRecommended, ) -from ray.rllib.models.specs.specs_dict import ModelSpec, check_specs +from ray.rllib.models.specs.typing import SpecType +from ray.rllib.models.specs.checker import check_input_specs, check_output_specs from ray.rllib.models.distributions import Distribution from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.nested_dict import NestedDict @@ -188,7 +189,7 @@ def get_initial_state(self) -> NestedDict: return {} @OverrideToImplementCustomLogic_CallToSuperRecommended - def output_specs_inference(self) -> ModelSpec: + def output_specs_inference(self) -> SpecType: """Returns the output specs of the forward_inference method. Override this method to customize the output specs of the inference call. @@ -196,10 +197,10 @@ def output_specs_inference(self) -> ModelSpec: has `action_dist` key and its value is an instance of `Distribution`. This assumption must always hold. """ - return ModelSpec({"action_dist": Distribution}) + return {"action_dist": Distribution} @OverrideToImplementCustomLogic_CallToSuperRecommended - def output_specs_exploration(self) -> ModelSpec: + def output_specs_exploration(self) -> SpecType: """Returns the output specs of the forward_exploration method. Override this method to customize the output specs of the inference call. @@ -207,27 +208,26 @@ def output_specs_exploration(self) -> ModelSpec: that has `action_dist` key and its value is an instance of `Distribution`. This assumption must always hold. """ - return ModelSpec({"action_dist": Distribution}) + return {"action_dist": Distribution} - def output_specs_train(self) -> ModelSpec: + def output_specs_train(self) -> SpecType: """Returns the output specs of the forward_train method.""" - return ModelSpec() + return {} - def input_specs_inference(self) -> ModelSpec: + def input_specs_inference(self) -> SpecType: """Returns the input specs of the forward_inference method.""" - return ModelSpec() + return {} - def input_specs_exploration(self) -> ModelSpec: + def input_specs_exploration(self) -> SpecType: """Returns the input specs of the forward_exploration method.""" - return ModelSpec() + return {} - def input_specs_train(self) -> ModelSpec: + def input_specs_train(self) -> SpecType: """Returns the input specs of the forward_train method.""" - return ModelSpec() + return {} - @check_specs( - input_spec="_input_specs_inference", output_spec="_output_specs_inference" - ) + @check_input_specs("_input_specs_inference") + @check_output_specs("_output_specs_inference") def forward_inference(self, batch: SampleBatchType, **kwargs) -> Mapping[str, Any]: """Forward-pass during evaluation, called from the sampler. This method should not be overriden. Instead, override the _forward_inference method. @@ -247,9 +247,8 @@ def forward_inference(self, batch: SampleBatchType, **kwargs) -> Mapping[str, An def _forward_inference(self, batch: NestedDict, **kwargs) -> Mapping[str, Any]: """Forward-pass during evaluation. 
See forward_inference for details.""" - @check_specs( - input_spec="_input_specs_exploration", output_spec="_output_specs_exploration" - ) + @check_input_specs("_input_specs_exploration") + @check_output_specs("_output_specs_exploration") def forward_exploration( self, batch: SampleBatchType, **kwargs ) -> Mapping[str, Any]: @@ -271,7 +270,8 @@ def forward_exploration( def _forward_exploration(self, batch: NestedDict, **kwargs) -> Mapping[str, Any]: """Forward-pass during exploration. See forward_exploration for details.""" - @check_specs(input_spec="_input_specs_train", output_spec="_output_specs_train") + @check_input_specs("_input_specs_train") + @check_output_specs("_output_specs_train") def forward_train( self, batch: SampleBatchType, diff --git a/rllib/core/testing/tf/bc_module.py b/rllib/core/testing/tf/bc_module.py index 2d6f4621c4e8..ed6499dd2a08 100644 --- a/rllib/core/testing/tf/bc_module.py +++ b/rllib/core/testing/tf/bc_module.py @@ -5,7 +5,8 @@ from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule -from ray.rllib.models.specs.specs_dict import ModelSpec +from ray.rllib.models.specs.specs_dict import SpecDict +from ray.rllib.models.specs.typing import SpecType from ray.rllib.models.specs.specs_tf import TFTensorSpecs from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import override @@ -33,28 +34,28 @@ def __init__( self._output_dim = output_dim @override(RLModule) - def input_specs_exploration(self) -> ModelSpec: - return ModelSpec(self._default_inputs()) + def input_specs_exploration(self) -> SpecType: + return self._default_inputs() @override(RLModule) - def input_specs_inference(self) -> ModelSpec: - return ModelSpec(self._default_inputs()) + def input_specs_inference(self) -> SpecType: + return self._default_inputs() @override(RLModule) - def input_specs_train(self) -> ModelSpec: - return ModelSpec(self._default_inputs()) + def input_specs_train(self) -> SpecType: + return self._default_inputs() @override(RLModule) - def output_specs_exploration(self) -> ModelSpec: - return ModelSpec(self._default_outputs()) + def output_specs_exploration(self) -> SpecType: + return self._default_outputs() @override(RLModule) - def output_specs_inference(self) -> ModelSpec: - return ModelSpec(self._default_outputs()) + def output_specs_inference(self) -> SpecType: + return self._default_outputs() @override(RLModule) - def output_specs_train(self) -> ModelSpec: - return ModelSpec(self._default_outputs()) + def output_specs_train(self) -> SpecType: + return self._default_outputs() @override(RLModule) def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]: @@ -104,9 +105,9 @@ def from_model_config( return cls(**config) - def _default_inputs(self) -> ModelSpec: + def _default_inputs(self) -> SpecDict: obs_dim = self._input_dim return {"obs": TFTensorSpecs("b, do", do=obs_dim)} - def _default_outputs(self) -> ModelSpec: + def _default_outputs(self) -> SpecDict: return {"action_dist": tfp.distributions.Distribution} diff --git a/rllib/core/testing/torch/bc_module.py b/rllib/core/testing/torch/bc_module.py index e11df5b43dcd..f3f27448d371 100644 --- a/rllib/core/testing/torch/bc_module.py +++ b/rllib/core/testing/torch/bc_module.py @@ -7,8 +7,8 @@ from ray.rllib.core.rl_module import RLModule from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule -from ray.rllib.models.specs.specs_dict import ModelSpec from ray.rllib.models.specs.specs_torch import TorchTensorSpec 
+from ray.rllib.models.specs.typing import SpecType from ray.rllib.utils.annotations import override from ray.rllib.utils.nested_dict import NestedDict @@ -30,28 +30,28 @@ def __init__( self.input_dim = input_dim @override(RLModule) - def input_specs_exploration(self) -> ModelSpec: - return ModelSpec(self._default_inputs()) + def input_specs_exploration(self) -> SpecType: + return self._default_inputs() @override(RLModule) - def input_specs_inference(self) -> ModelSpec: - return ModelSpec(self._default_inputs()) + def input_specs_inference(self) -> SpecType: + return self._default_inputs() @override(RLModule) - def input_specs_train(self) -> ModelSpec: - return ModelSpec(self._default_inputs()) + def input_specs_train(self) -> SpecType: + return self._default_inputs() @override(RLModule) - def output_specs_exploration(self) -> ModelSpec: - return ModelSpec(self._default_outputs()) + def output_specs_exploration(self) -> SpecType: + return self._default_outputs() @override(RLModule) - def output_specs_inference(self) -> ModelSpec: - return ModelSpec(self._default_outputs()) + def output_specs_inference(self) -> SpecType: + return self._default_outputs() @override(RLModule) - def output_specs_train(self) -> ModelSpec: - return ModelSpec(self._default_outputs()) + def output_specs_train(self) -> SpecType: + return self._default_outputs() @override(RLModule) def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]: diff --git a/rllib/models/configs/encoder.py b/rllib/models/configs/encoder.py index 07d87e6f7df2..38a7f305123a 100644 --- a/rllib/models/configs/encoder.py +++ b/rllib/models/configs/encoder.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Tuple -from ray.rllib.models.specs.specs_dict import ModelSpec +from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.models.torch.encoders.vector import TorchVectorEncoder if TYPE_CHECKING: @@ -32,7 +32,7 @@ class EncoderConfig: framework_str: str = "torch" @abc.abstractmethod - def build(self, input_spec: ModelSpec, **kwargs) -> "Encoder": + def build(self, input_spec: SpecDict, **kwargs) -> "Encoder": """Builds the EncoderConfig into an Encoder instance""" @@ -56,7 +56,7 @@ class VectorEncoderConfig(EncoderConfig): hidden_layer_sizes: Tuple[int, ...] = (128, 128) output_key: str = "encoding" - def build(self, input_spec: ModelSpec) -> TorchVectorEncoder: + def build(self, input_spec: SpecDict) -> TorchVectorEncoder: """Build the config into a VectorEncoder model instance. Args: diff --git a/rllib/models/specs/checker.py b/rllib/models/specs/checker.py new file mode 100644 index 000000000000..b7ca04c74325 --- /dev/null +++ b/rllib/models/specs/checker.py @@ -0,0 +1,338 @@ +from collections import abc +import functools +from typing import Union, Mapping, Any, Callable + +from ray.util.annotations import DeveloperAPI + +from ray.rllib.utils.nested_dict import NestedDict +from ray.rllib.models.specs.specs_base import Spec, TypeSpec +from ray.rllib.models.specs.specs_dict import SpecDict +from ray.rllib.models.specs.typing import SpecType + + +def _convert_to_canonical_format(spec: SpecType) -> Union[Spec, SpecDict]: + """Converts a spec type input to the canonical format. + + The canonical format is either + + 1. A nested SpecDict when the input spec is a dict-like tree of specs and types or + a nested list of nested_keys. + 2. A single Spec object, if the spec is a single constraint. + + The input can be any of the following: + - a list of nested_keys. 
nested_keys are either strings or tuples of strings + specifying the path to a leaf in the tree. + - a tree of constraints. The tree structure can be specified as any nested + hash-map structure (e.g. dict, SpecDict, NestedDict, etc.). The leaves of the + tree can be either a Spec object, a type, or None. If the leaf is a type, it is + converted to a TypeSpec. If the leaf is None, only the existence of the key is + checked and the value will be None in the canonical format. + - a single constraint. The constraint can be a Spec object, a type, or None. + + + Examples of canonical format #1: + + .. code-block:: python + spec = ["foo", ("bar", "baz")] + output = _convert_to_canonical_format(spec) + # output = SpecDict({"foo": None, ("bar", "baz"): None}) + + spec = {"foo": int, "bar": {"baz": None}} + output = _convert_to_canonical_format(spec) + # output = SpecDict( + # {"foo": TypeSpec(int), "bar": SpecDict({"baz": None})} + # ) + + spec = {"foo": int, "bar": {"baz": str}} + output = _convert_to_canonical_format(spec) + # output = SpecDict( + # {"foo": TypeSpec(int), "bar": SpecDict({"baz": TypeSpec(str)})} + # ) + + spec = {"foo": int, "bar": {"baz": TorchTensorSpec("b,h")}} + output = _convert_to_canonical_format(spec) + # output = SpecDict( + # {"foo": TypeSpec(int), "bar": SpecDict({"baz": TorchTensorSpec("b,h")})} + # ) + + + # Example of canonical format #2: + + .. code-block:: python + spec = int + output = _convert_to_canonical_format(spec) + # output = TypeSpec(int) + + spec = None + output = _convert_to_canonical_format(spec) + # output = None + + spec = TorchTensorSpec("b,h") + output = _convert_to_canonical_format(spec) + # output = TorchTensorSpec("b,h") + + Args: + spec: The spec to convert to canonical format. + + Returns: + The canonical format of the spec. + """ + # convert spec of form list of nested_keys to a SpecDict with None leaves + if isinstance(spec, list): + spec = [(k,) if isinstance(k, str) else k for k in spec] + return SpecDict({k: None for k in spec}) + + # convert spec of form tree of constraints to a SpecDict + if isinstance(spec, abc.Mapping): + spec = SpecDict(spec) + for key in spec: + # if values are types or tuple of types, convert to TypeSpec + if isinstance(spec[key], (type, tuple)): + spec[key] = TypeSpec(spec[key]) + return spec + + if isinstance(spec, type): + return TypeSpec(spec) + + # otherwise, assume spec is already in canonical format + return spec + + +def _should_validate( + cls_instance: object, method: Callable, tag: str = "input" +) -> bool: + """Returns True if the spec should be validated, False otherwise. + + The spec should be validated if the method is not cached (i.e. there is no cache + storage attribute in the instance) or if the method has not been validated yet + (i.e. its name does not exist in the cache storage attribute). + + Args: + cls_instance: The class instance that the method belongs to. + method: The method to apply the spec checking to. + tag: The tag of the spec to check. Either "input" or "output". This is used + internally to define an internal cache storage attribute based on the tag. + + Returns: + True if the spec should be validated, False otherwise. + """ + cache_store = getattr(cls_instance, f"__checked_{tag}_specs_cache__", None) + return cache_store is None or method.__name__ not in cache_store + + +def _validate( + *, + cls_instance: object, + method: Callable, + data: Mapping[str, Any], + spec: Spec, + filter: bool = False, + tag: str = "input", +) -> NestedDict: + """Validate the data against the spec. 
+ + Args: + cls_instance: The class instance that the method belongs to. + method: The method to apply the spec checking to. + data: The data to validate. + spec: The spec to validate against. + filter: If True, the data will be filtered to only include the keys that are + specified in the spec. + tag: The tag of the spec to check. Either "input" or "output". This is used + internally to define an internal cache storage attribute based on the tag. + + Returns: + The data, filtered if filter is True. + """ + cache_miss = _should_validate(cls_instance, method, tag=tag) + + if isinstance(spec, SpecDict): + if not isinstance(data, abc.Mapping): + raise ValueError(f"{tag} must be a Mapping, got {type(data).__name__}") + if cache_miss or filter: + data = NestedDict(data) + + # a spec of None enforces nothing, so there is nothing to validate + if cache_miss and spec is not None: + try: + spec.validate(data) + except ValueError as e: + raise ValueError( + f"{tag} spec validation failed on " + f"{cls_instance.__class__.__name__}.{method.__name__}, {e}." + ) + + return data + + +@DeveloperAPI(stability="alpha") +def check_input_specs( + input_spec: str, + *, + filter: bool = False, + cache: bool = False, +): + """A general-purpose spec checker decorator for neural network base classes. + + This is a stateful decorator + (https://realpython.com/primer-on-python-decorators/#stateful-decorators) to + enforce input specs for any instance method that has an argument named + `input_data` in its args. + + It also allows you to filter the input data dictionary to only include those keys + that are specified in the spec. It also allows you to cache the validation + to make sure the spec is only validated once in the entire lifetime of the instance. + + Examples (See more examples in ../tests/test_spec_dict.py): + + >>> class MyModel(nn.Module): + ... @property + ... def input_spec(self): + ... return {"obs": TensorSpec("b, d", d=64)} + ... + ... @check_input_specs("input_spec") + ... def forward(self, input_data, return_loss=False): + ... ... + + >>> model = MyModel() + >>> model.forward({"obs": torch.randn(32, 64)}) # No error + >>> model.forward({"obs": torch.randn(32, 32)}) # raises ValueError + + Args: + func: The instance method to decorate. It should be a callable that takes + `self` as the first argument, `input_data` as the second argument and any + other keyword argument thereafter. + input_spec: `self` should have an instance attribute whose name matches the + string in input_spec and returns the `SpecDict`, `Spec`, or simply the + `Type` that the `input_data` should comply with. It can also be None or an + empty list / dict to enforce no input spec. + filter: If True and `input_data` is a nested dict, the `input_data` will be + filtered by its corresponding spec tree structure and then passed into the + implemented function, to make sure the user is not confused by unnecessary + data. + cache: If True, only checks the data validation for the first time the + instance method is called. + + Returns: + A wrapped instance method. In case of `cache=True`, after the first invocation + of the decorated method, the instance will have a `__checked_input_specs_cache__` + attribute that stores which methods have been invoked at least once. This is a + special attribute that can be used for the cache itself. The wrapped class + method also has a special attribute `__checked_input_specs__` that marks the + method as decorated. 
+ """ + + def decorator(func): + @functools.wraps(func) + def wrapper(self, input_data, **kwargs): + if cache and not hasattr(self, "__checked_input_specs_cache__"): + self.__checked_input_specs_cache__ = {} + + checked_data = input_data + if input_spec: + spec = getattr(self, input_spec, "___NOT_FOUND___") + if spec == "___NOT_FOUND___": + raise ValueError(f"object {self} has no attribute {input_spec}.") + spec = _convert_to_canonical_format(spec) + checked_data = _validate( + cls_instance=self, + method=func, + data=input_data, + spec=spec, + filter=filter, + tag="input", + ) + + if filter and isinstance(checked_data, NestedDict): + # filtering should happen regardless of cache + checked_data = checked_data.filter(spec) + + output_data = func(self, checked_data, **kwargs) + + if cache and func.__name__ not in self.__checked_input_specs_cache__: + self.__checked_input_specs_cache__[func.__name__] = True + + return output_data + + wrapper.__checked_input_specs__ = True + return wrapper + + return decorator + + +@DeveloperAPI(stability="alpha") +def check_output_specs( + output_spec: str, + *, + cache: bool = False, +): + """A general-purpose spec checker decorator for Neural Network base classes. + + This is a stateful decorator + (https://realpython.com/primer-on-python-decorators/#stateful-decorators) to + enforce output specs for any instance method that outputs a single dict-like object. + + It also allows you to cache the validation to make sure the spec is only validated + once in the entire lifetime of the instance. + + Examples (See more examples in ../tests/test_specs_dict.py): + + >>> class MyModel(nn.Module): + ... @property + ... def output_spec(self): + ... return {"obs": TensorSpec("b, d", d=64)} + ... + ... @check_output_specs("output_spec") + ... def forward(self, input_data, return_loss=False): + ... return {"obs": torch.randn(32, 64)} + + Args: + func: The instance method to decorate. It should be a callable that takes + `self` as the first argument, `input_data` as the second argument and any + other keyword argument thereafter. It should return a single dict-like + object (i.e. not a tuple). + input_spec: `self` should have an instance attribute whose name matches the + string in input_spec and returns the `SpecDict`, `Spec`, or simply the + `Type` that the `input_data` should comply with. It can alos be None or + empty list / dict to enforce no input spec. + cache: If True, only checks the data validation for the first time the + instance method is called. + + Returns: + A wrapped instance method. In case of `cache=True`, after the first invokation + of the decorated method, the intance will have `__checked_output_specs_cache__` + attribute that stores which method has been invoked at least once. This is a + special attribute that can be used for the cache itself. The wrapped class + method also has a special attribute `__checked_output_specs__` that marks the + method as decorated. 
+ """ + + def decorator(func): + @functools.wraps(func) + def wrapper(self, input_data, **kwargs): + if cache and not hasattr(self, "__checked_output_specs_cache__"): + self.__checked_output_specs_cache__ = {} + + output_data = func(self, input_data, **kwargs) + + if output_spec: + spec = getattr(self, output_spec, "___NOT_FOUND___") + if spec == "___NOT_FOUND___": + raise ValueError(f"object {self} has no attribute {output_spec}.") + spec = _convert_to_canonical_format(spec) + _validate( + cls_instance=self, + method=func, + data=output_data, + spec=spec, + tag="output", + ) + + if cache and func.__name__ not in self.__checked_output_specs_cache__: + self.__checked_output_specs_cache__[func.__name__] = True + + return output_data + + wrapper.__checked_output_specs__ = True + return wrapper + + return decorator diff --git a/rllib/models/specs/specs_base.py b/rllib/models/specs/specs_base.py index 3d1e71735ec5..6069874eb0f0 100644 --- a/rllib/models/specs/specs_base.py +++ b/rllib/models/specs/specs_base.py @@ -10,11 +10,11 @@ _INVALID_INPUT_POSITIVE = "Dimension {} in ({}) must be positive, got {}" _INVALID_INPUT_INT_DIM = "Dimension {} in ({}) must be integer, got {}" _INVALID_SHAPE = "Expected shape {} but found {}" -_INVALID_TYPE = "Expected tensor type {} but found {}" +_INVALID_TYPE = "Expected data type {} but found {}" @DeveloperAPI -class SpecsAbstract(abc.ABC): +class Spec(abc.ABC): @DeveloperAPI @abc.abstractstaticmethod def validate(self, data: Any) -> None: @@ -29,7 +29,37 @@ def validate(self, data: Any) -> None: @DeveloperAPI -class TensorSpec(SpecsAbstract): +class TypeSpec(Spec): + """A base class that checks the type of the input data. + + Args: + dtype: The type of the object. + + Examples: + >>> spec = TypeSpec(tf.Tensor) + >>> spec.validate(tf.ones((2, 3))) # passes + >>> spec.validate(torch.ones((2, 3))) # ValueError + """ + + def __init__(self, dtype: Type) -> None: + self.dtype = dtype + + @override(Spec) + def validate(self, data: Any) -> None: + if not isinstance(data, self.dtype): + raise ValueError(_INVALID_TYPE.format(self.dtype, type(data))) + + def __eq__(self, other: "TypeSpec") -> bool: + if not isinstance(other, TypeSpec): + return False + return self.dtype == other.dtype + + def __ne__(self, other: "TypeSpec") -> bool: + return not self == other + + +@DeveloperAPI +class TensorSpec(Spec): """A base class that specifies the shape and dtype of a tensor. Args: @@ -125,7 +155,7 @@ def dtype(self) -> Any: """Returns a dtype specifying the tensor dtype.""" return self._dtype - @override(SpecsAbstract) + @override(Spec) def validate(self, tensor: TensorType) -> None: """Checks if the shape and dtype of the tensor matches the specification. diff --git a/rllib/models/specs/specs_dict.py b/rllib/models/specs/specs_dict.py index 897e2bc05187..a4ee97e8d540 100644 --- a/rllib/models/specs/specs_dict.py +++ b/rllib/models/specs/specs_dict.py @@ -1,9 +1,8 @@ -import functools -from typing import Union, Type, Mapping, Any +from typing import Union, Mapping, Any from ray.rllib.utils.annotations import ExperimentalAPI, override from ray.rllib.utils.nested_dict import NestedDict -from ray.rllib.models.specs.specs_base import TensorSpec +from ray.rllib.models.specs.specs_base import Spec _MISSING_KEYS_FROM_SPEC = ( @@ -19,14 +18,13 @@ "{} has type {} (expected type {})." ) -SPEC_LEAF_TYPE = Union[Type, TensorSpec] DATA_TYPE = Union[NestedDict[Any], Mapping[str, Any]] IS_NOT_PROPERTY = "Spec {} must be a property of the class {}." 
@ExperimentalAPI -class ModelSpec(NestedDict[SPEC_LEAF_TYPE]): +class SpecDict(NestedDict[Spec], Spec): """A NestedDict containing `TensorSpec` and `Types`. It can be used to validate an incoming data against a nested dictionary of specs. @@ -35,7 +33,7 @@ class ModelSpec(NestedDict[SPEC_LEAF_TYPE]): Basic validation: ----------------- - >>> spec_dict = ModelSpec({ + >>> spec_dict = SpecDict({ ... "obs": { ... "arm": TensorSpec("b, dim_arm", dim_arm=64), ... "gripper": TensorSpec("b, dim_grip", dim_grip=12) @@ -92,6 +90,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._keys_set = set(self.keys()) + @override(Spec) def validate( self, data: DATA_TYPE, @@ -121,7 +120,10 @@ def validate( for spec_name, spec in self.items(): data_to_validate = data[spec_name] - if isinstance(spec, TensorSpec): + if spec is None: + continue + + if isinstance(spec, Spec): try: spec.validate(data_to_validate) except ValueError as e: @@ -129,7 +131,7 @@ def validate( f"Mismatch found in data element {spec_name}, " f"which is a TensorSpec: {e}" ) - elif isinstance(spec, (Type, tuple)): + elif isinstance(spec, (type, tuple)): if not isinstance(data_to_validate, spec): raise ValueError( _TYPE_MISMATCH.format( @@ -144,156 +146,4 @@ def validate( @override(NestedDict) def __repr__(self) -> str: - return f"ModelSpec({repr(self._data)})" - - -@ExperimentalAPI -def check_specs( - input_spec: str = "", - output_spec: str = "", - filter: bool = False, - cache: bool = True, - input_exact_match: bool = False, - output_exact_match: bool = False, -): - """A general-purpose check_specs decorator for Neural Network modules. - - This is a stateful decorator - (https://realpython.com/primer-on-python-decorators/#stateful-decorators) to - enforce input/output specs for any instance method that has an argument named - `input_data` in its args and returns a single object. - - It also allows you to filter the input data dictionary to only include those keys - that are specified in the model specs. It also allows you to cache the validation - to make sure the spec is only validated once in the entire lifetime of the instance. - - Examples (See more exmaples in ../tests/test_specs_dict.py): - - >>> class MyModel(nn.Module): - ... def input_spec(self): - ... return ModelSpec({"obs": TensorSpec("b, d", d=64)}) - ... - ... @check_specs(input_spec="input_spec") - ... def forward(self, input_data, return_loss=False): - ... ... - ... output_dict = ... - ... return output_dict - - >>> model = MyModel() - >>> model.forward({"obs": torch.randn(32, 64)}) # No error - >>> model.forward({"obs": torch.randn(32, 32)}) # raises ValueError - - Args: - func: The instance method to decorate. It should be a callable that takes - `self` as the first argument, `input_data` as the second argument and any - other keyword argument thereafter. It should return a single object - (i.e. not a tuple). - input_spec: `self` should have an instance method whose name matches the string - in input_spec and returns the `ModelSpec`, `TensorSpec`, or simply the - `Type` that the `input_data` should comply with. - output_spec: `self` should have an instance method whose name matches the - string in output_spec and returns the spec that the output should comply - with. - filter: If True, and `input_data` is a nested dict the `input_data` will be - filtered by its corresponding spec tree structure and then passed into the - implemented function to make sure user is not confounded with unnecessary - data. 
- cache: If True, only checks the input/output validation for the first time the - instance method is called. - input_exact_match: If True, the input data (should be a nested dict) must match - the spec exactly. Otherwise, the data is validated as long as it contains - at least the elements of the spec, but can contain more entries. - output_exact_match: If True, the output data (should be a nested dict) must - match the spec exactly. Otherwise, the data is validated as long as it - contains at least the elements of the spec, but can contain more entries. - - Returns: - A wrapped instance method. In case of `cache=True`, after the first invokation - of the decorated method, the intance will have `__checked_specs_cache__` - attribute that store which method has been invoked at least once. This is a - special attribute that can be used for the cache itself. The wrapped class - method also has a special attribute `__checked_specs__` that marks the method as - decorated. - """ - - if not input_spec and not output_spec: - raise ValueError("At least one of input_spec or output_spec must be provided.") - - def decorator(func): - @functools.wraps(func) - def wrapper(self, input_data, **kwargs): - if cache and not hasattr(self, "__checked_specs_cache__"): - self.__checked_specs_cache__ = {} - - def should_validate(): - return not cache or func.__name__ not in self.__checked_specs_cache__ - - def validate(data, spec, exact_match, tag="data"): - is_mapping = isinstance(spec, ModelSpec) - is_tensor = isinstance(spec, TensorSpec) - cache_miss = should_validate() - - if is_mapping: - if not isinstance(data, Mapping): - raise ValueError( - f"{tag} must be a Mapping, got {type(data).__name__}" - ) - if cache_miss or filter: - data = NestedDict(data) - - if cache_miss: - try: - if is_mapping: - spec.validate(data, exact_match=exact_match) - elif is_tensor: - spec.validate(data) - except ValueError as e: - raise ValueError( - f"{tag} spec validation failed on " - f"{self.__class__.__name__}.{func.__name__}, {e}." - ) - - if not (is_tensor or is_mapping): - if not isinstance(data, spec): - raise ValueError( - f"Input spec validation failed on " - f"{self.__class__.__name__}.{func.__name__}, " - f"expected {spec.__name__}, got " - f"{type(data).__name__}." 
- ) - return data - - input_data_ = input_data - if input_spec: - input_spec_ = getattr(self, input_spec) - - input_data_ = validate( - input_data, - input_spec_, - exact_match=input_exact_match, - tag="input_data", - ) - - if filter and isinstance(input_spec_, (ModelSpec, TensorSpec)): - # filtering should happen regardless of cache - input_data_ = input_data_.filter(input_spec_) - - output_data = func(self, input_data_, **kwargs) - if output_spec: - output_spec_ = getattr(self, output_spec) - validate( - output_data, - output_spec_, - exact_match=output_exact_match, - tag="output_data", - ) - - if cache and func.__name__ not in self.__checked_specs_cache__: - self.__checked_specs_cache__[func.__name__] = True - - return output_data - - wrapper.__checked_specs__ = True - return wrapper - - return decorator + return f"SpecDict({repr(self._data)})" diff --git a/rllib/models/specs/tests/test_check_specs.py b/rllib/models/specs/tests/test_check_specs.py index 4b83a727549c..60a026c15a61 100644 --- a/rllib/models/specs/tests/test_check_specs.py +++ b/rllib/models/specs/tests/test_check_specs.py @@ -5,11 +5,16 @@ from typing import Dict, Any, Type import unittest -from ray.rllib.models.specs.specs_base import TensorSpec -from ray.rllib.models.specs.specs_dict import ModelSpec, check_specs +from ray.rllib.models.specs.specs_base import TensorSpec, TypeSpec +from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.models.specs.specs_torch import TorchTensorSpec from ray.rllib.utils.annotations import override from ray.rllib.utils.nested_dict import NestedDict +from ray.rllib.models.specs.checker import ( + _convert_to_canonical_format, + check_input_specs, + check_output_specs, +) ONLY_ONE_KEY_ALLOWED = "Only one key is allowed in the data dict." 
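For orientation before the test changes below, a rough before/after sketch of the decorator migration these tests exercise (illustrative only; `MyModel` and its specs are hypothetical):

from ray.rllib.models.specs.checker import check_input_specs, check_output_specs

class MyModel:
    @property
    def input_spec(self):
        # Plain SpecType: a None leaf only checks that the key exists.
        return {"obs": None}

    @property
    def output_spec(self):
        return {"action_dist": None}

    # before: @check_specs(input_spec="input_spec", output_spec="output_spec",
    #                      filter=True, cache=False)
    @check_input_specs("input_spec", filter=True, cache=False)
    @check_output_specs("output_spec", cache=False)
    def forward(self, input_data):
        return {"action_dist": 1.0, "extra": "fine, exact match is not required"}

model = MyModel()
model.forward({"obs": 42})      # passes
# model.forward({"not_obs": 42})  # would raise ValueError: "obs" is missing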
@@ -19,16 +24,15 @@ class AbstractInterfaceClass(abc.ABC): input/output constraints.""" @property - def input_spec(self) -> ModelSpec: + def input_spec(self) -> SpecDict: pass @property - def output_spec(self) -> ModelSpec: + def output_spec(self) -> SpecDict: pass - @check_specs( - input_spec="input_spec", output_spec="output_spec", filter=True, cache=False - ) + @check_input_specs("input_spec", filter=True, cache=False) + @check_output_specs("output_spec", cache=False) def check_input_and_output(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: return self._check_input_and_output(input_dict) @@ -36,7 +40,7 @@ def check_input_and_output(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: def _check_input_and_output(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: pass - @check_specs(input_spec="input_spec", filter=True, cache=False) + @check_input_specs("input_spec", filter=True, cache=False) def check_only_input(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: """should not override this method""" return self._check_only_input(input_dict) @@ -45,7 +49,7 @@ def check_only_input(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: def _check_only_input(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: pass - @check_specs(output_spec="output_spec", filter=True, cache=False) + @check_output_specs("output_spec", cache=False) def check_only_output(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: """should not override this method""" return self._check_only_output(input_dict) @@ -54,18 +58,16 @@ def check_only_output(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: def _check_only_output(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: pass - @check_specs( - input_spec="input_spec", output_spec="output_spec", filter=True, cache=True - ) + @check_input_specs("input_spec", filter=True, cache=True) + @check_output_specs("output_spec", cache=True) def check_input_and_output_with_cache( self, input_dict: Dict[str, Any] ) -> Dict[str, Any]: """should not override this method""" return self._check_input_and_output(input_dict) - @check_specs( - input_spec="input_spec", output_spec="output_spec", filter=False, cache=False - ) + @check_input_specs("input_spec", filter=False, cache=False) + @check_output_specs("output_spec", cache=False) def check_input_and_output_wo_filter(self, input_dict) -> Dict[str, Any]: """should not override this method""" return self._check_input_and_output(input_dict) @@ -75,12 +77,12 @@ class InputNumberOutputFloat(AbstractInterfaceClass): """This is an abstract class enforcing a contraint on input/output""" @property - def input_spec(self) -> ModelSpec: - return ModelSpec({"input": (float, int)}) + def input_spec(self) -> SpecDict: + return SpecDict({"input": (float, int)}) @property - def output_spec(self) -> ModelSpec: - return ModelSpec({"output": float}) + def output_spec(self) -> SpecDict: + return SpecDict({"output": float}) class CorrectImplementation(InputNumberOutputFloat): @@ -236,7 +238,7 @@ class ClassWithTensorSpec: def input_spec1(self) -> TensorSpec: return TorchTensorSpec("b, h", h=4) - @check_specs(input_spec="input_spec1", cache=False) + @check_input_specs("input_spec1", cache=False) def forward(self, input_data) -> Any: return input_data @@ -256,11 +258,11 @@ class ClassWithTypeSpec: def output_spec(self) -> Type: return SpecialOutputType - @check_specs(output_spec="output_spec", cache=False) + @check_output_specs("output_spec", cache=False) def forward_pass(self, input_data) -> Any: return SpecialOutputType() - 
@check_specs(output_spec="output_spec", cache=False) + @check_output_specs("output_spec", cache=False) def forward_fail(self, input_data) -> Any: return WrongOutputType() @@ -269,6 +271,47 @@ def forward_fail(self, input_data) -> Any: self.assertIsInstance(output, SpecialOutputType) self.assertRaises(ValueError, lambda: module.forward_fail(torch.rand(2, 3))) + def test_convert_to_canonical_format(self): + + # Case: input is a list of strs + self.assertDictEqual( + _convert_to_canonical_format(["foo", "bar"]).asdict(), + SpecDict({"foo": None, "bar": None}).asdict(), + ) + + # Case: input is a list of strs and nested strs + self.assertDictEqual( + _convert_to_canonical_format(["foo", ("bar", "jar")]).asdict(), + SpecDict({"foo": None, "bar": {"jar": None}}).asdict(), + ) + + # Case: input is a Nested Mapping + returned = _convert_to_canonical_format( + {"foo": {"bar": TorchTensorSpec("b")}, "jar": {"tar": int, "car": None}} + ) + self.assertIsInstance(returned, SpecDict) + self.assertDictEqual( + returned.asdict(), + SpecDict( + { + "foo": {"bar": TorchTensorSpec("b")}, + "jar": {"tar": TypeSpec(int), "car": None}, + } + ).asdict(), + ) + + # Case: input is a SpecDict already + returned = _convert_to_canonical_format( + SpecDict({"foo": {"bar": TorchTensorSpec("b")}, "jar": {"tar": int}}) + ) + self.assertIsInstance(returned, SpecDict) + self.assertDictEqual( + returned.asdict(), + SpecDict( + {"foo": {"bar": TorchTensorSpec("b")}, "jar": {"tar": TypeSpec(int)}} + ).asdict(), + ) + if __name__ == "__main__": import pytest diff --git a/rllib/models/specs/tests/test_model_spec.py b/rllib/models/specs/tests/test_model_spec.py deleted file mode 100644 index 1239edc039b9..000000000000 --- a/rllib/models/specs/tests/test_model_spec.py +++ /dev/null @@ -1,102 +0,0 @@ -import unittest -import numpy as np - -from ray.rllib.models.specs.specs_np import NPTensorSpec -from ray.rllib.models.specs.specs_dict import ModelSpec - - -class TypeClass1: - pass - - -class TypeClass2: - pass - - -class TestModelSpec(unittest.TestCase): - def test_basic_validation(self): - - h1, h2 = 3, 4 - spec_1 = ModelSpec( - { - "out_tensor_1": NPTensorSpec("b, h", h=h1), - "out_tensor_2": NPTensorSpec("b, h", h=h2), - "out_class_1": TypeClass1, - } - ) - - # test validation. 
- tensor_1 = { - "out_tensor_1": np.random.randn(2, h1), - "out_tensor_2": np.random.randn(2, h2), - "out_class_1": TypeClass1(), - } - - spec_1.validate(tensor_1) - - # test missing key in specs - tensor_2 = { - "out_tensor_1": np.random.randn(2, h1), - "out_tensor_2": np.random.randn(2, h2), - } - - self.assertRaises(ValueError, lambda: spec_1.validate(tensor_2)) - - # test missing key in data - tensor_3 = { - "out_tensor_1": np.random.randn(2, h1), - "out_tensor_2": np.random.randn(2, h2), - "out_class_1": TypeClass1(), - "out_class_2": TypeClass1(), - } - - # this should pass because exact_match is False - spec_1.validate(tensor_3, exact_match=False) - - # this should fail because exact_match is True - self.assertRaises( - ValueError, lambda: spec_1.validate(tensor_3, exact_match=True) - ) - - # raise type mismatch - tensor_4 = { - "out_tensor_1": np.random.randn(2, h1), - "out_tensor_2": np.random.randn(2, h2), - "out_class_1": TypeClass2(), - } - - self.assertRaises(ValueError, lambda: spec_1.validate(tensor_4)) - - # test nested specs - spec_2 = ModelSpec( - { - "encoder": { - "input": NPTensorSpec("b, h", h=h1), - "output": NPTensorSpec("b, h", h=h2), - }, - "decoder": { - "input": NPTensorSpec("b, h", h=h2), - "output": NPTensorSpec("b, h", h=h1), - }, - } - ) - - tensor_5 = { - "encoder": { - "input": np.random.randn(2, h1), - "output": np.random.randn(2, h2), - }, - "decoder": { - "input": np.random.randn(2, h2), - "output": np.random.randn(2, h1), - }, - } - - spec_2.validate(tensor_5) - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/models/specs/tests/test_spec_dict.py b/rllib/models/specs/tests/test_spec_dict.py new file mode 100644 index 000000000000..f247804c55e1 --- /dev/null +++ b/rllib/models/specs/tests/test_spec_dict.py @@ -0,0 +1,192 @@ +import unittest +import numpy as np + +from ray.rllib.models.specs.specs_np import NPTensorSpec +from ray.rllib.models.specs.specs_dict import SpecDict +from ray.rllib.models.specs.checker import check_input_specs + + +class TypeClass1: + pass + + +class TypeClass2: + pass + + +class TestSpecDict(unittest.TestCase): + def test_basic_validation(self): + """Tests basic validation of SpecDict.""" + + h1, h2 = 3, 4 + spec_1 = SpecDict( + { + "out_tensor_1": NPTensorSpec("b, h", h=h1), + "out_tensor_2": NPTensorSpec("b, h", h=h2), + "out_class_1": TypeClass1, + } + ) + + # test validation. 
+ tensor_1 = { + "out_tensor_1": np.random.randn(2, h1), + "out_tensor_2": np.random.randn(2, h2), + "out_class_1": TypeClass1(), + } + + spec_1.validate(tensor_1) + + # test missing key in tensor + tensor_2 = { + "out_tensor_1": np.random.randn(2, h1), + "out_tensor_2": np.random.randn(2, h2), + } + + self.assertRaises(ValueError, lambda: spec_1.validate(tensor_2)) + + # test additional key in tensor (not mentioned in spec) + tensor_3 = { + "out_tensor_1": np.random.randn(2, h1), + "out_tensor_2": np.random.randn(2, h2), + "out_class_1": TypeClass1(), + "out_class_2": TypeClass1(), + } + + # this should pass because exact_match is False + spec_1.validate(tensor_3, exact_match=False) + + # this should fail because exact_match is True + self.assertRaises( + ValueError, lambda: spec_1.validate(tensor_3, exact_match=True) + ) + + # raise type mismatch + tensor_4 = { + "out_tensor_1": np.random.randn(2, h1), + "out_tensor_2": np.random.randn(2, h2), + "out_class_1": TypeClass2(), + } + + self.assertRaises(ValueError, lambda: spec_1.validate(tensor_4)) + + # test nested specs + spec_2 = SpecDict( + { + "encoder": { + "input": NPTensorSpec("b, h", h=h1), + "output": NPTensorSpec("b, h", h=h2), + }, + "decoder": { + "input": NPTensorSpec("b, h", h=h2), + "output": NPTensorSpec("b, h", h=h1), + }, + } + ) + + tensor_5 = { + "encoder": { + "input": np.random.randn(2, h1), + "output": np.random.randn(2, h2), + }, + "decoder": { + "input": np.random.randn(2, h2), + "output": np.random.randn(2, h1), + }, + } + + spec_2.validate(tensor_5) + + def test_spec_check_integration(self): + """Tests the integration of SpecDict with the check_input_specs.""" + + class Model: + @property + def nested_key_spec(self): + return ["a", ("b", "c"), ("d",), ("e", "f"), ("e", "g")] + + @property + def dict_key_spec_with_none_leaves(self): + return { + "a": None, + "b": { + "c": None, + }, + "d": None, + "e": { + "f": None, + "g": None, + }, + } + + @property + def spec_with_type_and_tensor_leaves(self): + return {"a": TypeClass1, "b": NPTensorSpec("b, h", h=3)} + + @check_input_specs("nested_key_spec") + def forward_nested_key(self, input_dict): + return input_dict + + @check_input_specs("dict_key_spec_with_none_leaves") + def forward_dict_key_with_none_leaves(self, input_dict): + return input_dict + + @check_input_specs("spec_with_type_and_tensor_leaves") + def forward_spec_with_type_and_tensor_leaves(self, input_dict): + return input_dict + + model = Model() + + # test nested key spec + input_dict_1 = { + "a": 1, + "b": { + "c": 2, + "foo": 3, + }, + "d": 3, + "e": { + "f": 4, + "g": 5, + }, + } + + # should run fine + model.forward_nested_key(input_dict_1) + model.forward_dict_key_with_none_leaves(input_dict_1) + + # test missing key + input_dict_2 = { + "a": 1, + "b": { + "c": 2, + "foo": 3, + }, + "d": 3, + "e": { + "f": 4, + }, + } + + self.assertRaises(ValueError, lambda: model.forward_nested_key(input_dict_2)) + + self.assertRaises( + ValueError, lambda: model.forward_dict_key_with_none_leaves(input_dict_2) + ) + + input_dict_3 = { + "a": TypeClass1(), + "b": np.array([1, 2, 3]), + } + + # should raise shape mismatch + self.assertRaises( + ValueError, + lambda: model.forward_spec_with_type_and_tensor_leaves(input_dict_3), + ) + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/models/specs/typing.py b/rllib/models/specs/typing.py new file mode 100644 index 000000000000..26ef387e175e --- /dev/null +++ b/rllib/models/specs/typing.py @@ -0,0 
+1,11 @@ +from typing import Union, Type, Tuple, Optional, List, TYPE_CHECKING + +if TYPE_CHECKING: + from ray.rllib.utils.nested_dict import NestedDict + from ray.rllib.models.specs.specs_base import Spec + + +NestedKeys = List[Union[str, Tuple[str, ...]]] +Constraint = Union[Type, Tuple[Type, ...], "Spec"] +# Either a flat list of nested keys or a tree of constraints +SpecType = Union[NestedKeys, "NestedDict[Optional[Constraint]]"] diff --git a/rllib/models/torch/encoders/tests/test_torch_vector_encoder.py b/rllib/models/torch/encoders/tests/test_torch_vector_encoder.py index 49b02a29daf5..0f3743be2f43 100644 --- a/rllib/models/torch/encoders/tests/test_torch_vector_encoder.py +++ b/rllib/models/torch/encoders/tests/test_torch_vector_encoder.py @@ -3,7 +3,7 @@ import torch from ray.rllib.models.configs.encoder import VectorEncoderConfig -from ray.rllib.models.specs.specs_dict import ModelSpec +from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.models.specs.specs_torch import TorchTensorSpec from ray.rllib.utils.nested_dict import NestedDict @@ -11,25 +11,25 @@ class TestConfig(unittest.TestCase): def test_error_no_feature_dim(self): """Ensure we error out if we don't know the input dim""" - input_spec = ModelSpec({"bork": TorchTensorSpec("a, b, c")}) + input_spec = SpecDict({"bork": TorchTensorSpec("a, b, c")}) c = VectorEncoderConfig() with self.assertRaises(AssertionError): c.build(input_spec) def test_default_build(self): """Test building with the default config""" - input_spec = ModelSpec({"bork": TorchTensorSpec("a, b, c", c=3)}) + input_spec = SpecDict({"bork": TorchTensorSpec("a, b, c", c=3)}) c = VectorEncoderConfig() c.build(input_spec) def test_nonlinear_final_build(self): - input_spec = ModelSpec({"bork": TorchTensorSpec("a, b, c", c=3)}) + input_spec = SpecDict({"bork": TorchTensorSpec("a, b, c", c=3)}) c = VectorEncoderConfig(final_activation="relu") c.build(input_spec) def test_default_forward(self): """Test the default config/model _forward implementation""" - input_spec = ModelSpec({"bork": TorchTensorSpec("a, b, c", c=3)}) + input_spec = SpecDict({"bork": TorchTensorSpec("a, b, c", c=3)}) c = VectorEncoderConfig() m = c.build(input_spec) inputs = NestedDict({"bork": torch.rand((2, 4, 3))}) @@ -41,7 +41,7 @@ def test_two_inputs_forward(self): """Test the default model when we have two items in the input_spec. 
These two items will be concatenated and fed thru the mlp.""" """Test the default config/model _forward implementation""" - input_spec = ModelSpec( + input_spec = SpecDict( { "bork": TorchTensorSpec("a, b, c", c=3), "dork": TorchTensorSpec("x, y, z", z=5), @@ -58,7 +58,7 @@ def test_two_inputs_forward(self): self.assertEqual(outputs[c.output_key].shape[:-1], (2, 4)) def test_deep_build(self): - input_spec = ModelSpec({"bork": TorchTensorSpec("a, b, c", c=3)}) + input_spec = SpecDict({"bork": TorchTensorSpec("a, b, c", c=3)}) c = VectorEncoderConfig() c.build(input_spec) diff --git a/rllib/models/torch/encoders/vector.py b/rllib/models/torch/encoders/vector.py index 2feed58dd11d..91ef65d71f44 100644 --- a/rllib/models/torch/encoders/vector.py +++ b/rllib/models/torch/encoders/vector.py @@ -4,7 +4,7 @@ import torch from torch import nn -from ray.rllib.models.specs.specs_dict import ModelSpec +from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.models.torch.model import TorchModel from ray.rllib.models.utils import get_activation_fn from ray.rllib.utils.nested_dict import NestedDict @@ -23,16 +23,16 @@ class TorchVectorEncoder(TorchModel): """ @property - def input_spec(self) -> ModelSpec: + def input_spec(self) -> SpecDict: return self._input_spec @property - def output_spec(self) -> ModelSpec: + def output_spec(self) -> SpecDict: return self._output_spec def __init__( self, - input_spec: ModelSpec, + input_spec: SpecDict, config: "VectorEncoderConfig", ): super().__init__(config=config) diff --git a/rllib/models/utils.py b/rllib/models/utils.py index f4aa6e588ea6..d79fa7562a03 100644 --- a/rllib/models/utils.py +++ b/rllib/models/utils.py @@ -1,18 +1,18 @@ from typing import Optional from ray.rllib.models.specs.specs_base import TensorSpec -from ray.rllib.models.specs.specs_dict import ModelSpec +from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.utils.annotations import DeveloperAPI, ExperimentalAPI from ray.rllib.utils.framework import try_import_jax, try_import_tf, try_import_torch @ExperimentalAPI def input_to_output_spec( - input_spec: ModelSpec, + input_spec: SpecDict, num_input_feature_dims: int, output_key: str, output_feature_spec: TensorSpec, -) -> ModelSpec: +) -> SpecDict: """Convert an input spec to an output spec, based on a module. Drops the feature dimension(s) from an input_spec, replacing them with @@ -20,7 +20,7 @@ def input_to_output_spec( Examples: input_to_output_spec( - input_spec=ModelSpec({ + input_spec=SpecDict({ "bork": "batch, time, feature0", "dork": "batch, time, feature1" }, feature0=2, feature1=3 @@ -31,10 +31,10 @@ def input_to_output_spec( ) will return: - ModelSpec({"outer_product": "batch, time, row, col", row=2, col=3}) + SpecDict({"outer_product": "batch, time, row, col", row=2, col=3}) input_to_output_spec( - input_spec=ModelSpec({ + input_spec=SpecDict({ "bork": "batch, time, h, w, c", }, h=32, w=32, c=3, ), @@ -44,11 +44,11 @@ def input_to_output_spec( ) will return: - ModelSpec({"latent_image_representation": "batch, time, feature"}, feature=128) + SpecDict({"latent_image_representation": "batch, time, feature"}, feature=128) Args: - input_spec: ModelSpec describing input to a specified module + input_spec: SpecDict describing input to a specified module num_input_dims: How many feature dimensions the module will process. E.g. 
a linear layer will only process the last dimension (1), while a CNN might process the last two dimensions (2) @@ -57,7 +57,7 @@ def input_to_output_spec( specified module Returns: - A ModelSpec based on the input_spec, with the trailing dimensions replaced + A SpecDict based on the input_spec, with the trailing dimensions replaced by the output_feature_spec """ @@ -72,7 +72,7 @@ def input_to_output_spec( key = list(input_spec.keys())[0] batch_spec = input_spec[key].rdrop(num_input_feature_dims) full_spec = batch_spec.append(output_feature_spec) - return ModelSpec({output_key: full_spec}) + return SpecDict({output_key: full_spec}) @DeveloperAPI diff --git a/rllib/utils/nested_dict.py b/rllib/utils/nested_dict.py index 903c60716e59..1e4d308d1ef5 100644 --- a/rllib/utils/nested_dict.py +++ b/rllib/utils/nested_dict.py @@ -1,5 +1,5 @@ """Custom NestedDict datatype.""" - +from collections import abc import itertools from typing import ( AbstractSet, @@ -121,10 +121,10 @@ def __init__( x = x or {} if isinstance(x, NestedDict): self._data = x._data - elif isinstance(x, Mapping): + elif isinstance(x, abc.Mapping): for k in x: self[k] = x[k] - elif isinstance(x, Iterable): + elif isinstance(x, abc.Iterable): for k, v in x: self[k] = v else:
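Taken together, a minimal end-to-end sketch of the renamed API as exercised by this change (standalone illustration; the `Encoder` class, keys, and shapes are made up):

import numpy as np

from ray.rllib.models.specs.specs_np import NPTensorSpec
from ray.rllib.models.specs.specs_dict import SpecDict
from ray.rllib.models.specs.checker import check_input_specs

class Encoder:
    @property
    def input_spec(self) -> SpecDict:
        # Leaves may be full tensor specs, bare types, or None (existence only).
        return SpecDict({"obs": NPTensorSpec("b, h", h=8), "step": int})

    @check_input_specs("input_spec", cache=True)
    def forward(self, input_data):
        return {"embedding": input_data["obs"]}

enc = Encoder()
batch = {"obs": np.zeros((4, 8)), "step": 7}
enc.forward(batch)  # validated on the first call
enc.forward(batch)  # cache hit: validation is skipped from here on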