From b22a6d881fc3fbe26f7cb43d981d2526b27600c2 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Sun, 23 Oct 2022 23:39:12 -0700 Subject: [PATCH 01/11] 1. created check_specs decorator 2. updated unittests 3. refactored the names a little bit for generality Signed-off-by: Kourosh Hakhamaneshi --- rllib/BUILD | 10 +- rllib/models/specs/specs_base.py | 8 +- rllib/models/specs/specs_dict.py | 188 +++++++++++- rllib/models/specs/specs_jax.py | 12 +- rllib/models/specs/specs_np.py | 12 +- rllib/models/specs/specs_tf.py | 12 +- rllib/models/specs/specs_torch.py | 12 +- rllib/models/specs/tests/test_check_specs.py | 268 ++++++++++++++++++ ...ensor_specs_dict.py => test_model_spec.py} | 22 +- ...st_tensor_specs.py => test_tensor_spec.py} | 6 +- 10 files changed, 488 insertions(+), 62 deletions(-) create mode 100644 rllib/models/specs/tests/test_check_specs.py rename rllib/models/specs/tests/{test_tensor_specs_dict.py => test_model_spec.py} (79%) rename rllib/models/specs/tests/{test_tensor_specs.py => test_tensor_spec.py} (95%) diff --git a/rllib/BUILD b/rllib/BUILD index 28493db7da75..2c109e091eeb 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1801,10 +1801,10 @@ py_test( # Test Tensor specs py_test( - name = "test_tensor_specs", + name = "test_tensor_spec", tags = ["team:rllib", "models"], size = "small", - srcs = ["models/specs/tests/test_tensor_specs.py"] + srcs = ["models/specs/tests/test_tensor_spec.py"] ) # test abstract base models @@ -1823,12 +1823,12 @@ py_test( srcs = ["models/tests/test_torch_model.py"] ) -# test ModelSpecDict +# test ModelSpec py_test( - name = "test_tensor_specs_dict", + name = "test_model_spec", tags = ["team:rllib", "models"], size = "small", - srcs = ["models/specs/tests/test_tensor_specs_dict.py"] + srcs = ["models/specs/tests/test_model_spec.py"] ) diff --git a/rllib/models/specs/specs_base.py b/rllib/models/specs/specs_base.py index 81a94f00dd74..7b7498a68457 100644 --- a/rllib/models/specs/specs_base.py +++ 
b/rllib/models/specs/specs_base.py @@ -28,7 +28,7 @@ def validate(self, data: Any) -> None: @DeveloperAPI -class TensorSpecs(SpecsAbstract): +class TensorSpec(SpecsAbstract): """A base class that specifies the shape and dtype of a tensor. Args: @@ -230,11 +230,11 @@ def _validate_shape_vals( def __repr__(self) -> str: return f"TensorSpec(shape={tuple(self.shape)}, dtype={self.dtype})" - def __eq__(self, other: "TensorSpecs") -> bool: + def __eq__(self, other: "TensorSpec") -> bool: """Checks if the shape and dtype of two specs are equal.""" - if not isinstance(other, TensorSpecs): + if not isinstance(other, TensorSpec): return False return self.shape == other.shape and self.dtype == other.dtype - def __ne__(self, other: "TensorSpecs") -> bool: + def __ne__(self, other: "TensorSpec") -> bool: return not self == other diff --git a/rllib/models/specs/specs_dict.py b/rllib/models/specs/specs_dict.py index b46246c87fa0..51e170f0a93d 100644 --- a/rllib/models/specs/specs_dict.py +++ b/rllib/models/specs/specs_dict.py @@ -1,29 +1,32 @@ +import functools from typing import Union, Type, Mapping, Any from ray.rllib.utils.annotations import override from ray.rllib.utils.nested_dict import NestedDict -from ray.rllib.models.specs.specs_base import TensorSpecs +from ray.rllib.models.specs.specs_base import TensorSpec _MISSING_KEYS_FROM_SPEC = ( "The data dict does not match the model specs. Keys {} are " - "in the data dict but not on the given spec dict, and exact_match is set to True." + "in the data dict but not on the given spec dict, and exact_match is set to True" ) _MISSING_KEYS_FROM_DATA = ( "The data dict does not match the model specs. Keys {} are " - "in the spec dict but not on the data dict." + "in the spec dict but not on the data dict. Data keys are {}" ) _TYPE_MISMATCH = ( "The data does not match the spec. The data element " "{} has type {} (expected type {})." 
) -SPEC_LEAF_TYPE = Union[Type, TensorSpecs] +SPEC_LEAF_TYPE = Union[Type, TensorSpec] DATA_TYPE = Union[NestedDict[Any], Mapping[str, Any]] +IS_NOT_PROPERTY = "Spec {} must be a property of the class {}." -class ModelSpecDict(NestedDict[SPEC_LEAF_TYPE]): - """A NestedDict containing `TensorSpecs` and `Types`. + +class ModelSpec(NestedDict[SPEC_LEAF_TYPE]): + """A NestedDict containing `TensorSpec` and `Types`. It can be used to validate an incoming data against a nested dictionary of specs. @@ -31,12 +34,12 @@ class ModelSpecDict(NestedDict[SPEC_LEAF_TYPE]): Basic validation: ----------------- - >>> spec_dict = ModelSpecDict({ + >>> spec_dict = ModelSpec({ ... "obs": { - ... "arm": TensorSpecs("b, d_a", d_a=64), - ... "gripper": TensorSpecs("b, d_g", d_g=12) + ... "arm": TensorSpec("b, d_a", d_a=64), + ... "gripper": TensorSpec("b, d_g", d_g=12) ... }, - ... "action": TensorSpecs("b, d_a", h=12), + ... "action": TensorSpec("b, d_a", h=12), ... "action_dist": torch.distributions.Categorical ... }) @@ -91,7 +94,7 @@ def __init__(self, *args, **kwargs): def validate( self, data: DATA_TYPE, - exact_match: bool = True, + exact_match: bool = False, ) -> None: """Checks whether the data matches the spec. 
@@ -107,7 +110,9 @@ def validate( data_keys_set = set(data.keys()) missing_keys = self._keys_set.difference(data_keys_set) if missing_keys: - raise ValueError(_MISSING_KEYS_FROM_DATA.format(missing_keys)) + raise ValueError( + _MISSING_KEYS_FROM_DATA.format(missing_keys, data_keys_set) + ) if exact_match: data_spec_missing_keys = data_keys_set.difference(self._keys_set) if data_spec_missing_keys: @@ -115,8 +120,14 @@ def validate( for spec_name, spec in self.items(): data_to_validate = data[spec_name] - if isinstance(spec, TensorSpecs): - spec.validate(data_to_validate) + if isinstance(spec, TensorSpec): + try: + spec.validate(data_to_validate) + except ValueError as e: + raise ValueError( + f"Mismatch found in data element {spec_name}, " + f"which is a TensorSpec: {e}" + ) elif isinstance(spec, Type): if not isinstance(data_to_validate, spec): raise ValueError( @@ -127,4 +138,151 @@ def validate( @override(NestedDict) def __repr__(self) -> str: - return f"ModelSpecDict({repr(self._data)})" + return f"ModelSpec({repr(self._data)})" + + +def check_specs( + input_spec: str = "", + output_spec: str = "", + filter: bool = True, + cache: bool = False, + input_exact_match: bool = False, + output_exact_match: bool = False, +): + """A general-purpose [stateful decorator] + (https://realpython.com/primer-on-python-decorators/#stateful-decorators) to + enforce input/output specs for any instance method that has `input_data` in input + args and returns a single object. + + + It adds the ability to filter the input data if it is a mapping to only contain + the keys in the spec. It can also cache the validation to make sure the spec is + only validated once in the lifetime of the instance. + + Examples (See more examples in ./tests/test_check_specs.py): + + >>> class MyModel(nn.Module): + ... def input_spec(self): + ... return ModelSpec({"obs": TensorSpec("b, d", d=64)}) + ... + ... @check_specs(input_spec="input_spec") + ...
def forward(self, input_data, return_loss=False): + ... ... + ... output_dict = ... + ... return output_dict + + >>> model = MyModel() + >>> model.forward({"obs": torch.randn(32, 64)}) # No error + >>> model.forward({"obs": torch.randn(32, 32)}) # raises ValueError + + Args: + func: The instance method to decorate. It should be a callable that takes + `self` as the first argument, `input_data` as the second argument and any + other keyword argument thereafter. It should return a single object. + input_spec: `self` should have an instance method that is named input_spec and + returns the `ModelSpec`, `TensorSpec`, or simply the `Type` that the + `input_data` should comply with. + output_spec: `self` should have an instance method that is named output_spec + and returns the spec that the output should comply with. + filter: If True, and `input_data` is a nested dict the `input_data` will be + filtered by its corresponding spec tree structure and then passed into the + implemented function to make sure user is not confounded with unnecessary + data. + cache: If True, only checks the input/output validation for the first time the + instance method is called. + input_exact_match: If True, the input data (should be a nested dict) must match + the spec exactly. Otherwise, the data is validated as long as it contains + at least the elements of the spec, but can contain more entries. + output_exact_match: If True, the output data (should be a nested dict) must + match the spec exactly. Otherwise, the data is validated as long as it + contains at least the elements of the spec, but can contain more entries. + + Returns: + A wrapped instance method. In case of `cache=True`, after the first invocation + of the decorated method, the instance will have `__checked_specs_cache__` + attribute that stores which methods have been invoked at least once. This is a + special attribute that is used for the cache itself. The wrapped class method
The wrapped class method + also has a special attribute `__checked_specs__` that marks the method as + decorated. + """ + + if not input_spec and not output_spec: + raise ValueError("At least one of input_spec or output_spec must be provided.") + + def decorator(func): + @functools.wraps(func) + def wrapper(self, input_data, **kwargs): + + if cache and not hasattr(self, "__checked_specs_cache__"): + self.__checked_specs_cache__ = {} + + def should_validate(): + return not cache or func.__name__ not in self.__checked_specs_cache__ + + def validate(data, spec, exact_match, tag="data"): + is_mapping = isinstance(spec, ModelSpec) + is_tensor = isinstance(spec, TensorSpec) + + if is_mapping: + if not isinstance(data, Mapping): + raise ValueError( + f"{tag} must be a Mapping, got {type(data).__name__}" + ) + data = NestedDict(data) + + if should_validate(): + try: + if is_mapping: + spec.validate(data, exact_match=exact_match) + elif is_tensor: + spec.validate(data) + except ValueError as e: + raise ValueError( + f"{tag} spec validation failed on " + f"{self.__class__.__name__}.{func.__name__}, {e}." + ) + + if not (is_tensor or is_mapping): + if not isinstance(data, spec): + raise ValueError( + f"Input spec validation failed on " + f"{self.__class__.__name__}.{func.__name__}, " + f"expected {spec.__name__}, got " + f"{type(data).__name__}." 
+ ) + return data + + input_data_ = input_data + if input_spec: + input_spec_ = getattr(self, input_spec)() + + input_data_ = validate( + input_data, + input_spec_, + exact_match=input_exact_match, + tag="input_data", + ) + + if filter and isinstance(input_spec_, ModelSpec): + # filtering should happen regardless of cache + input_data_ = input_data_.filter(input_spec_) + + output_data = func(self, input_data_, **kwargs) + if output_spec: + output_spec_ = getattr(self, output_spec)() + validate( + output_data, + output_spec_, + exact_match=output_exact_match, + tag="output_data", + ) + + if cache: + self.__checked_specs_cache__[func.__name__] = True + + return output_data + + wrapper.__check_specs__ = True + return wrapper + + return decorator diff --git a/rllib/models/specs/specs_jax.py b/rllib/models/specs/specs_jax.py index ec9a7b672acc..58412cde8b3e 100644 --- a/rllib/models/specs/specs_jax.py +++ b/rllib/models/specs/specs_jax.py @@ -2,7 +2,7 @@ from ray.rllib.utils.annotations import DeveloperAPI, override from ray.rllib.utils.framework import try_import_jax -from ray.rllib.models.specs.specs_base import TensorSpecs +from ray.rllib.models.specs.specs_base import TensorSpec jax, _ = try_import_jax() jnp = None @@ -11,20 +11,20 @@ @DeveloperAPI -class JAXSpecs(TensorSpecs): - @override(TensorSpecs) +class JAXTensorSpec(TensorSpec): + @override(TensorSpec) def get_type(cls) -> Type: return jnp.ndarray - @override(TensorSpecs) + @override(TensorSpec) def get_shape(self, tensor: jnp.ndarray) -> Tuple[int]: return tuple(tensor.shape) - @override(TensorSpecs) + @override(TensorSpec) def get_dtype(self, tensor: jnp.ndarray) -> Any: return tensor.dtype - @override(TensorSpecs) + @override(TensorSpec) def _full( self, shape: Tuple[int], fill_value: Union[float, int] = 0 ) -> jnp.ndarray: diff --git a/rllib/models/specs/specs_np.py b/rllib/models/specs/specs_np.py index 262f3b6cfca5..8fefec98032a 100644 --- a/rllib/models/specs/specs_np.py +++ 
b/rllib/models/specs/specs_np.py @@ -2,23 +2,23 @@ import numpy as np from ray.rllib.utils.annotations import DeveloperAPI, override -from ray.rllib.models.specs.specs_base import TensorSpecs +from ray.rllib.models.specs.specs_base import TensorSpec @DeveloperAPI -class NPSpecs(TensorSpecs): - @override(TensorSpecs) +class NPSpec(TensorSpec): + @override(TensorSpec) def get_type(cls) -> Type: return np.ndarray - @override(TensorSpecs) + @override(TensorSpec) def get_shape(self, tensor: np.ndarray) -> Tuple[int]: return tuple(tensor.shape) - @override(TensorSpecs) + @override(TensorSpec) def get_dtype(self, tensor: np.ndarray) -> Any: return tensor.dtype - @override(TensorSpecs) + @override(TensorSpec) def _full(self, shape: Tuple[int], fill_value: Union[float, int] = 0) -> np.ndarray: return np.full(shape, fill_value, dtype=self.dtype) diff --git a/rllib/models/specs/specs_tf.py b/rllib/models/specs/specs_tf.py index 6c45754b8a75..17d83871fb79 100644 --- a/rllib/models/specs/specs_tf.py +++ b/rllib/models/specs/specs_tf.py @@ -2,26 +2,26 @@ from ray.rllib.utils.annotations import DeveloperAPI, override from ray.rllib.utils.framework import try_import_tf -from ray.rllib.models.specs.specs_base import TensorSpecs +from ray.rllib.models.specs.specs_base import TensorSpec _, tf, tfv = try_import_tf() @DeveloperAPI -class TFSpecs(TensorSpecs): - @override(TensorSpecs) +class TFSpecs(TensorSpec): + @override(TensorSpec) def get_type(cls) -> Type: return tf.Tensor - @override(TensorSpecs) + @override(TensorSpec) def get_shape(self, tensor: tf.Tensor) -> Tuple[int]: return tuple(tensor.shape) - @override(TensorSpecs) + @override(TensorSpec) def get_dtype(self, tensor: tf.Tensor) -> Any: return tensor.dtype - @override(TensorSpecs) + @override(TensorSpec) def _full(self, shape: Tuple[int], fill_value: Union[float, int] = 0) -> tf.Tensor: if self.dtype: return tf.ones(shape, dtype=self.dtype) * fill_value diff --git a/rllib/models/specs/specs_torch.py 
b/rllib/models/specs/specs_torch.py index 627e4d435022..e19e68ff7816 100644 --- a/rllib/models/specs/specs_torch.py +++ b/rllib/models/specs/specs_torch.py @@ -2,27 +2,27 @@ from ray.rllib.utils.annotations import DeveloperAPI, override from ray.rllib.utils.framework import try_import_torch -from ray.rllib.models.specs.specs_base import TensorSpecs +from ray.rllib.models.specs.specs_base import TensorSpec torch, _ = try_import_torch() @DeveloperAPI -class TorchSpecs(TensorSpecs): - @override(TensorSpecs) +class TorchTensorSpec(TensorSpec): + @override(TensorSpec) def get_type(cls) -> Type: return torch.Tensor - @override(TensorSpecs) + @override(TensorSpec) def get_shape(self, tensor: torch.Tensor) -> Tuple[int]: return tuple(tensor.shape) - @override(TensorSpecs) + @override(TensorSpec) def get_dtype(self, tensor: torch.Tensor) -> Any: return tensor.dtype - @override(TensorSpecs) + @override(TensorSpec) def _full( self, shape: Tuple[int], fill_value: Union[float, int] = 0 ) -> torch.Tensor: diff --git a/rllib/models/specs/tests/test_check_specs.py b/rllib/models/specs/tests/test_check_specs.py new file mode 100644 index 000000000000..78bd0afbb073 --- /dev/null +++ b/rllib/models/specs/tests/test_check_specs.py @@ -0,0 +1,268 @@ +import abc +import numpy as np +import time +import torch +from typing import Dict, Any, Type +import unittest + +from ray.rllib.models.specs.specs_base import TensorSpec +from ray.rllib.models.specs.specs_dict import ModelSpec, check_specs +from ray.rllib.models.specs.specs_torch import TorchTensorSpec +from ray.rllib.utils.annotations import override +from ray.rllib.utils.nested_dict import NestedDict + +ONLY_ONE_KEY_ALLOWED = "Only one key is allowed in the data dict." 
+ + +class AbstractInterfaceClass(abc.ABC): + """An abstract class that has a couple of methods, each having their own + input/output constraints.""" + + @abc.abstractmethod + def input_spec(self) -> ModelSpec: + pass + + @abc.abstractmethod + def output_spec(self) -> ModelSpec: + pass + + @check_specs(input_spec="input_spec", output_spec="output_spec") + def check_input_and_output(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: + return self._check_input_and_output(input_dict) + + @abc.abstractmethod + def _check_input_and_output(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: + pass + + @check_specs(input_spec="input_spec") + def check_only_input(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: + """should not override this method""" + return self._check_only_input(input_dict) + + @abc.abstractmethod + def _check_only_input(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: + pass + + @check_specs(output_spec="output_spec") + def check_only_output(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: + """should not override this method""" + return self._check_only_output(input_dict) + + @abc.abstractmethod + def _check_only_output(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: + pass + + @check_specs(input_spec="input_spec", output_spec="output_spec", cache=True) + def check_input_and_output_with_cache( + self, input_dict: Dict[str, Any] + ) -> Dict[str, Any]: + """should not override this method""" + return self._check_input_and_output(input_dict) + + @check_specs(input_spec="input_spec", output_spec="output_spec", filter=False) + def check_input_and_output_wo_filter(self, input_dict) -> Dict[str, Any]: + """should not override this method""" + return self._check_input_and_output(input_dict) + + +class InputNumberOutputFloat(AbstractInterfaceClass): + """This is an abstract class enforcing a contraint on input/output""" + + def input_spec(self) -> ModelSpec: + return ModelSpec({"input": (float, int)}) + + def output_spec(self) -> ModelSpec: 
+ return ModelSpec({"output": float}) + + +class CorrectImplementation(InputNumberOutputFloat): + def run(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: + output = float(input_dict["input"]) * 2 + return {"output": output} + + @override(AbstractInterfaceClass) + def _check_input_and_output(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: + # check if there is any key other than input in the input_dict + if len(input_dict) > 1 or "input" not in input_dict: + raise ValueError(ONLY_ONE_KEY_ALLOWED) + return self.run(input_dict) + + @override(AbstractInterfaceClass) + def _check_only_input(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: + # check if there is any key other than input in the input_dict + if len(input_dict) > 1 or "input" not in input_dict: + raise ValueError(ONLY_ONE_KEY_ALLOWED) + + out = self.run(input_dict) + + # output can be anything since ther is no output_spec + return {"output": str(out)} + + @override(AbstractInterfaceClass) + def _check_only_output(self, input_dict) -> Dict[str, Any]: + # there is no input spec, so we can pass anything + if "input" in input_dict: + raise ValueError( + "input_dict should not have `input` key in check_only_output" + ) + + return self.run({"input": input_dict["not_input"]}) + + +class IncorrectImplementation(CorrectImplementation): + @override(CorrectImplementation) + def run(self, input_dict) -> Dict[str, Any]: + output = str(input_dict["input"] * 2) + return {"output": output} + + +class TestCheckSpecs(unittest.TestCase): + def test_check_input_and_output(self): + + correct_module = CorrectImplementation() + + output = correct_module.check_input_and_output({"input": 2}) + # output should also match the output_spec + correct_module.output_spec().validate(NestedDict(output)) + + # this should raise an error saying that the `input` key is missing + self.assertRaises( + ValueError, lambda: correct_module.check_input_and_output({"not_input": 2}) + ) + + def test_check_only_input(self): + correct_module 
= CorrectImplementation() + # this should not raise any error since input matches the input specs + output = correct_module.check_only_input({"input": 2}) + # output can be anything since ther is no output_spec + self.assertRaises( + ValueError, + lambda: correct_module.output_spec().validate(NestedDict(output)), + ) + + def test_check_only_output(self): + correct_module = CorrectImplementation() + # this should not raise any error since input does not have to match input_spec + output = correct_module.check_only_output({"not_input": 2}) + # output should match the output specs + correct_module.output_spec().validate(NestedDict(output)) + + def test_incorrect_implementation(self): + incorrect_module = IncorrectImplementation() + # this should raise an error saying that the output does not match the + # output_spec + self.assertRaises( + ValueError, lambda: incorrect_module.check_input_and_output({"input": 2}) + ) + + # this should not raise an error because output is not forced to be checked + incorrect_module.check_only_input({"input": 2}) + + # this should raise an error because output does not match the output_spec + self.assertRaises( + ValueError, lambda: incorrect_module.check_only_output({"not_input": 2}) + ) + + def test_filter(self): + # create an arbitrary large input dict and test the behavior with and without a + # filter + input_dict = NestedDict({"input": 2}) + for i in range(100): + inds = (str(i),) + tuple(str(j) for j in range(i + 1, i + 11)) + input_dict[inds] = i + + correct_module = CorrectImplementation() + + # should run without errors + correct_module.check_input_and_output(input_dict) + + # should raise an error (read the implementation of + # check_input_and_output_wo_filter) + self.assertRaises( + ValueError, + lambda: correct_module.check_input_and_output_wo_filter(input_dict), + ) + + def test_cache(self): + # for N times, run the function twice and compare the time of each run. 
+ # the second run should be faster since the output is cached + # to make sure the time is not too small, we run this on an input dict that is + # arbitrarily large and nested. + # we also check if cache is not working the second run is as slow as the first + # run. + + input_dict = NestedDict({"input": 2}) + for i in range(100): + inds = (str(i),) + tuple(str(j) for j in range(i + 1, i + 11)) + input_dict[inds] = i + + N = 500 + for fname in ["check_input_and_output", "check_input_and_output_with_cache"]: + time1, time2 = [], [] + for _ in range(N): + + module = CorrectImplementation() + + fn = getattr(module, fname) + start = time.time() + fn(input_dict) + end = time.time() + time1.append(end - start) + + start = time.time() + fn(input_dict) + end = time.time() + time2.append(end - start) + + lower_bound_time1 = np.mean(time1) - 3 * np.std(time1) + upper_bound_time2 = np.mean(time2) + 3 * np.std(time2) + + if fname == "check_input_and_output_with_cache": + self.assertGreater(lower_bound_time1, upper_bound_time2) + else: + self.assertGreater(upper_bound_time2, lower_bound_time1) + + def test_tensor_specs(self): + # test if the input_spec can be a tensor spec + class ClassWithTensorSpec: + def input_spec1(self) -> TensorSpec: + return TorchTensorSpec("b, h", h=4) + + @check_specs(input_spec="input_spec1") + def forward(self, input_data) -> Any: + return input_data + + module = ClassWithTensorSpec() + module.forward(torch.rand(2, 4)) + self.assertRaises(ValueError, lambda: module.forward(torch.rand(2, 3))) + + def test_type_specs(self): + class SpecialOutputType: + pass + + class WrongOutputType: + pass + + class ClassWithTypeSpec: + def output_spec(self) -> Type: + return SpecialOutputType + + @check_specs(output_spec="output_spec") + def forward_pass(self, input_data) -> Any: + return SpecialOutputType() + + @check_specs(output_spec="output_spec") + def forward_fail(self, input_data) -> Any: + return WrongOutputType() + + module = ClassWithTypeSpec() + output = 
module.forward_pass(torch.rand(2, 4)) + self.assertIsInstance(output, SpecialOutputType) + self.assertRaises(ValueError, lambda: module.forward_fail(torch.rand(2, 3))) + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/models/specs/tests/test_tensor_specs_dict.py b/rllib/models/specs/tests/test_model_spec.py similarity index 79% rename from rllib/models/specs/tests/test_tensor_specs_dict.py rename to rllib/models/specs/tests/test_model_spec.py index bdacf4659ef1..d053d8afab88 100644 --- a/rllib/models/specs/tests/test_tensor_specs_dict.py +++ b/rllib/models/specs/tests/test_model_spec.py @@ -1,8 +1,8 @@ import unittest import numpy as np -from ray.rllib.models.specs.specs_np import NPSpecs -from ray.rllib.models.specs.specs_dict import ModelSpecDict +from ray.rllib.models.specs.specs_np import NPSpec +from ray.rllib.models.specs.specs_dict import ModelSpec class TypeClass1: @@ -13,14 +13,14 @@ class TypeClass2: pass -class TestModelSpecDict(unittest.TestCase): +class TestModelSpec(unittest.TestCase): def test_basic_validation(self): h1, h2 = 3, 4 - spec_1 = ModelSpecDict( + spec_1 = ModelSpec( { - "out_tensor_1": NPSpecs("b, h", h=h1), - "out_tensor_2": NPSpecs("b, h", h=h2), + "out_tensor_1": NPSpec("b, h", h=h1), + "out_tensor_2": NPSpec("b, h", h=h2), "out_class_1": TypeClass1, } ) @@ -68,15 +68,15 @@ def test_basic_validation(self): self.assertRaises(ValueError, lambda: spec_1.validate(tensor_4)) # test nested specs - spec_2 = ModelSpecDict( + spec_2 = ModelSpec( { "encoder": { - "input": NPSpecs("b, h", h=h1), - "output": NPSpecs("b, h", h=h2), + "input": NPSpec("b, h", h=h1), + "output": NPSpec("b, h", h=h2), }, "decoder": { - "input": NPSpecs("b, h", h=h2), - "output": NPSpecs("b, h", h=h1), + "input": NPSpec("b, h", h=h2), + "output": NPSpec("b, h", h=h1), }, } ) diff --git a/rllib/models/specs/tests/test_tensor_specs.py b/rllib/models/specs/tests/test_tensor_spec.py similarity index 
95% rename from rllib/models/specs/tests/test_tensor_specs.py rename to rllib/models/specs/tests/test_tensor_spec.py index 7389d04f0c81..8b02b9d22831 100644 --- a/rllib/models/specs/tests/test_tensor_specs.py +++ b/rllib/models/specs/tests/test_tensor_spec.py @@ -5,13 +5,13 @@ import tensorflow as tf from ray.rllib.utils.test_utils import check -from ray.rllib.models.specs.specs_torch import TorchSpecs -from ray.rllib.models.specs.specs_np import NPSpecs +from ray.rllib.models.specs.specs_torch import TorchTensorSpec +from ray.rllib.models.specs.specs_np import NPSpec from ray.rllib.models.specs.specs_tf import TFSpecs # TODO: add jax tests -SPEC_CLASSES = {"torch": TorchSpecs, "np": NPSpecs, "tf": TFSpecs} +SPEC_CLASSES = {"torch": TorchTensorSpec, "np": NPSpec, "tf": TFSpecs} DOUBLE_TYPE = { "torch": torch.float64, "np": np.float64, From 422d5ad868dc1c2dadd524facbfbfa2bf771fb16 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Sun, 23 Oct 2022 23:41:03 -0700 Subject: [PATCH 02/11] updated errors in nested dict Signed-off-by: Kourosh Hakhamaneshi --- rllib/utils/nested_dict.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/rllib/utils/nested_dict.py b/rllib/utils/nested_dict.py index 12bbb5877dcb..f41045550e02 100644 --- a/rllib/utils/nested_dict.py +++ b/rllib/utils/nested_dict.py @@ -85,7 +85,7 @@ class NestedDict(Generic[T], MutableMapping[str, Union[T, "NestedDict"]]): >>> # 'b': {'c': 200, 'd': 300}} >>> # Getting elements, possibly nested: >>> print(foo_dict['b', 'c']) # 200 - >>> print(foo_dict['b']) # IndexError("Use get for partial indexing.") + >>> print(foo_dict['b']) # IndexError >>> print(foo_dict.get('b')) # {'c': 200, 'd': 300} >>> print(foo_dict) # {'a': 100, 'b': {'c': 200, 'd': 300}} >>> # Converting to a dict: @@ -181,7 +181,12 @@ def get( def __getitem__(self, k: SeqStrType) -> T: output = self.get(k) if isinstance(output, NestedDict): - raise IndexError("Use get for partial indexing.") + raise 
IndexError( + f"Key `{k}` is not a complete key in the given " + f"{self.__class__.__name__}. It results in a container " + f"with subkeys {set(output.keys())}. To get partial indexing, " + f"use {self.__class__.__name__}.get(key) instead." + ) return output def __setitem__(self, k: SeqStrType, v: Union[T, _NestedMappingType]) -> None: @@ -190,7 +195,9 @@ def __setitem__(self, k: SeqStrType, v: Union[T, _NestedMappingType]) -> None: if isinstance(v, Mapping) and len(v) == 0: return if not k: - raise IndexError("Use valid index value.") + raise IndexError( + f"Key for {self.__class__.__name__} cannot be empty. Got {k}." + ) k = _flatten_index(k) v = self.__class__(v) if isinstance(v, Mapping) else v data_ptr = self._data From 5be71820ebd7883333b34167f3b3032b8da13f8d Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 24 Oct 2022 10:25:56 -0700 Subject: [PATCH 03/11] updated bazel Signed-off-by: Kourosh Hakhamaneshi --- rllib/BUILD | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/rllib/BUILD b/rllib/BUILD index 374ff9fb7ac7..b48c838782d9 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1750,6 +1750,13 @@ py_test( srcs = ["models/tests/test_attention_nets.py"] ) +py_test( + name = "test_check_specs", + tags = ["team:rllib", "models"], + size = "medium", + srcs = ["models/tests/test_check_specs.py"] +) + py_test( name = "test_conv2d_default_stacks", tags = ["team:rllib", "models"], From 30fe44f5a51730c06073d20291191082bb39db57 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 24 Oct 2022 10:41:42 -0700 Subject: [PATCH 04/11] made the names consistent Signed-off-by: Kourosh Hakhamaneshi --- rllib/models/specs/specs_np.py | 2 +- rllib/models/specs/specs_tf.py | 2 +- rllib/models/specs/tests/test_model_spec.py | 14 +++++++------- rllib/models/specs/tests/test_tensor_spec.py | 6 +++--- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/rllib/models/specs/specs_np.py b/rllib/models/specs/specs_np.py index 8fefec98032a..4782b2649516 
100644 --- a/rllib/models/specs/specs_np.py +++ b/rllib/models/specs/specs_np.py @@ -6,7 +6,7 @@ @DeveloperAPI -class NPSpec(TensorSpec): +class NPTensorSpec(TensorSpec): @override(TensorSpec) def get_type(cls) -> Type: return np.ndarray diff --git a/rllib/models/specs/specs_tf.py b/rllib/models/specs/specs_tf.py index 17d83871fb79..f438880c2d89 100644 --- a/rllib/models/specs/specs_tf.py +++ b/rllib/models/specs/specs_tf.py @@ -8,7 +8,7 @@ @DeveloperAPI -class TFSpecs(TensorSpec): +class TFTensorSpecs(TensorSpec): @override(TensorSpec) def get_type(cls) -> Type: return tf.Tensor diff --git a/rllib/models/specs/tests/test_model_spec.py b/rllib/models/specs/tests/test_model_spec.py index d053d8afab88..1239edc039b9 100644 --- a/rllib/models/specs/tests/test_model_spec.py +++ b/rllib/models/specs/tests/test_model_spec.py @@ -1,7 +1,7 @@ import unittest import numpy as np -from ray.rllib.models.specs.specs_np import NPSpec +from ray.rllib.models.specs.specs_np import NPTensorSpec from ray.rllib.models.specs.specs_dict import ModelSpec @@ -19,8 +19,8 @@ def test_basic_validation(self): h1, h2 = 3, 4 spec_1 = ModelSpec( { - "out_tensor_1": NPSpec("b, h", h=h1), - "out_tensor_2": NPSpec("b, h", h=h2), + "out_tensor_1": NPTensorSpec("b, h", h=h1), + "out_tensor_2": NPTensorSpec("b, h", h=h2), "out_class_1": TypeClass1, } ) @@ -71,12 +71,12 @@ def test_basic_validation(self): spec_2 = ModelSpec( { "encoder": { - "input": NPSpec("b, h", h=h1), - "output": NPSpec("b, h", h=h2), + "input": NPTensorSpec("b, h", h=h1), + "output": NPTensorSpec("b, h", h=h2), }, "decoder": { - "input": NPSpec("b, h", h=h2), - "output": NPSpec("b, h", h=h1), + "input": NPTensorSpec("b, h", h=h2), + "output": NPTensorSpec("b, h", h=h1), }, } ) diff --git a/rllib/models/specs/tests/test_tensor_spec.py b/rllib/models/specs/tests/test_tensor_spec.py index 8b02b9d22831..7040a2f6021f 100644 --- a/rllib/models/specs/tests/test_tensor_spec.py +++ b/rllib/models/specs/tests/test_tensor_spec.py @@ -6,12 
+6,12 @@ from ray.rllib.utils.test_utils import check from ray.rllib.models.specs.specs_torch import TorchTensorSpec -from ray.rllib.models.specs.specs_np import NPSpec -from ray.rllib.models.specs.specs_tf import TFSpecs +from ray.rllib.models.specs.specs_np import NPTensorSpec +from ray.rllib.models.specs.specs_tf import TFTensorSpecs # TODO: add jax tests -SPEC_CLASSES = {"torch": TorchTensorSpec, "np": NPSpec, "tf": TFSpecs} +SPEC_CLASSES = {"torch": TorchTensorSpec, "np": NPTensorSpec, "tf": TFTensorSpecs} DOUBLE_TYPE = { "torch": torch.float64, "np": np.float64, From 9c74871a193115782045509789d5cfb5e9d48f19 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 24 Oct 2022 10:44:01 -0700 Subject: [PATCH 05/11] removed the partial indexing error in __getitem__ Signed-off-by: Kourosh Hakhamaneshi --- rllib/utils/nested_dict.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/rllib/utils/nested_dict.py b/rllib/utils/nested_dict.py index f41045550e02..a16e8bffc7a3 100644 --- a/rllib/utils/nested_dict.py +++ b/rllib/utils/nested_dict.py @@ -85,7 +85,7 @@ class NestedDict(Generic[T], MutableMapping[str, Union[T, "NestedDict"]]): >>> # 'b': {'c': 200, 'd': 300}} >>> # Getting elements, possibly nested: >>> print(foo_dict['b', 'c']) # 200 - >>> print(foo_dict['b']) # IndexError + >>> print(foo_dict['b']) # {'c': 200, 'd': 300} >>> print(foo_dict.get('b')) # {'c': 200, 'd': 300} >>> print(foo_dict) # {'a': 100, 'b': {'c': 200, 'd': 300}} >>> # Converting to a dict: @@ -180,13 +180,6 @@ def get( def __getitem__(self, k: SeqStrType) -> T: output = self.get(k) - if isinstance(output, NestedDict): - raise IndexError( - f"Key `{k}` is not a complete key in the given " - f"{self.__class__.__name__}. It results in a container " - f"with subkeys {set(output.keys())}. To get partial indexing, " - f"use {self.__class__.__name__}.get(key) instead." 
- ) return output def __setitem__(self, k: SeqStrType, v: Union[T, _NestedMappingType]) -> None: From f7dcb9060a356df67f5d9c7d9e644cfeb239b098 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 24 Oct 2022 11:53:46 -0700 Subject: [PATCH 06/11] nested dict test update Signed-off-by: Kourosh Hakhamaneshi --- rllib/utils/tests/test_nested_dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/utils/tests/test_nested_dict.py b/rllib/utils/tests/test_nested_dict.py index 54c08a20ee63..6dd8b0a059e5 100644 --- a/rllib/utils/tests/test_nested_dict.py +++ b/rllib/utils/tests/test_nested_dict.py @@ -81,7 +81,7 @@ def set_invalid_item_2(): self.assertEqual(foo_dict["b", "c"], 200) self.assertEqual(foo_dict["c", "e", "f"], 400) self.assertEqual(foo_dict["d", "g", "h", "i"], 500) - self.assertRaises(IndexError, lambda: foo_dict["b"]) + self.assertEqual(foo_dict["b"], NestedDict({"c": 200, "d": 300})) # test __str__ self.assertEqual(str(foo_dict), str(desired_dict)) From a019dd565ce2326aadd73d12fffe996ca196cbe0 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 24 Oct 2022 11:54:42 -0700 Subject: [PATCH 07/11] bazel update Signed-off-by: Kourosh Hakhamaneshi --- rllib/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/BUILD b/rllib/BUILD index b48c838782d9..0886ea12162f 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1754,7 +1754,7 @@ py_test( name = "test_check_specs", tags = ["team:rllib", "models"], size = "medium", - srcs = ["models/tests/test_check_specs.py"] + srcs = ["models/specs/tests/test_check_specs.py"] ) py_test( From 11fe4f453bd977d4af6b64f481864057c1b965c1 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 24 Oct 2022 12:57:58 -0700 Subject: [PATCH 08/11] making cache test unflakey Signed-off-by: Kourosh Hakhamaneshi --- rllib/models/specs/tests/test_check_specs.py | 44 +++++++++----------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git 
a/rllib/models/specs/tests/test_check_specs.py b/rllib/models/specs/tests/test_check_specs.py index 78bd0afbb073..b7402ecda4ee 100644 --- a/rllib/models/specs/tests/test_check_specs.py +++ b/rllib/models/specs/tests/test_check_specs.py @@ -197,30 +197,26 @@ def test_cache(self): input_dict[inds] = i N = 500 - for fname in ["check_input_and_output", "check_input_and_output_with_cache"]: - time1, time2 = [], [] - for _ in range(N): - - module = CorrectImplementation() - - fn = getattr(module, fname) - start = time.time() - fn(input_dict) - end = time.time() - time1.append(end - start) - - start = time.time() - fn(input_dict) - end = time.time() - time2.append(end - start) - - lower_bound_time1 = np.mean(time1) - 3 * np.std(time1) - upper_bound_time2 = np.mean(time2) + 3 * np.std(time2) - - if fname == "check_input_and_output_with_cache": - self.assertGreater(lower_bound_time1, upper_bound_time2) - else: - self.assertGreater(upper_bound_time2, lower_bound_time1) + time1, time2 = [], [] + for _ in range(N): + + module = CorrectImplementation() + + fn = getattr(module, "check_input_and_output_with_cache") + start = time.time() + fn(input_dict) + end = time.time() + time1.append(end - start) + + start = time.time() + fn(input_dict) + end = time.time() + time2.append(end - start) + + lower_bound_time1 = np.mean(time1) - 3 * np.std(time1) + upper_bound_time2 = np.mean(time2) + 3 * np.std(time2) + + self.assertGreater(lower_bound_time1, upper_bound_time2) def test_tensor_specs(self): # test if the input_spec can be a tensor spec From f0ec25f80dba498560d1c72031d2851b75779632 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 24 Oct 2022 15:18:13 -0700 Subject: [PATCH 09/11] made the condition checking safer Signed-off-by: Kourosh Hakhamaneshi --- rllib/models/specs/specs_dict.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/rllib/models/specs/specs_dict.py b/rllib/models/specs/specs_dict.py index 51e170f0a93d..e1462a1517da 100644 --- 
a/rllib/models/specs/specs_dict.py +++ b/rllib/models/specs/specs_dict.py @@ -135,6 +135,11 @@ def validate( spec_name, type(data_to_validate).__name__, spec.__name__ ) ) + else: + raise ValueError( + f"The spec type has to be either TensorSpec or Type. " + f"got {type(spec)}" + ) @override(NestedDict) def __repr__(self) -> str: From e9ffeb1ceb14905879b87126592b2d287894a942 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 24 Oct 2022 16:09:27 -0700 Subject: [PATCH 10/11] fixed a hidden bug with tuple's getting skipped in ModelSpecs Signed-off-by: Kourosh Hakhamaneshi --- rllib/models/specs/specs_dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/models/specs/specs_dict.py b/rllib/models/specs/specs_dict.py index e1462a1517da..1547cca3ad12 100644 --- a/rllib/models/specs/specs_dict.py +++ b/rllib/models/specs/specs_dict.py @@ -128,7 +128,7 @@ def validate( f"Mismatch found in data element {spec_name}, " f"which is a TensorSpec: {e}" ) - elif isinstance(spec, Type): + elif isinstance(spec, (Type, tuple)): if not isinstance(data_to_validate, spec): raise ValueError( _TYPE_MISMATCH.format( From 9cfb1619237995684601c5651ef765422aca2d3b Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 24 Oct 2022 17:09:43 -0700 Subject: [PATCH 11/11] attempt to deflake Signed-off-by: Kourosh Hakhamaneshi --- rllib/models/specs/tests/test_check_specs.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/rllib/models/specs/tests/test_check_specs.py b/rllib/models/specs/tests/test_check_specs.py index b7402ecda4ee..4d172da02b7f 100644 --- a/rllib/models/specs/tests/test_check_specs.py +++ b/rllib/models/specs/tests/test_check_specs.py @@ -184,6 +184,7 @@ def test_filter(self): ) def test_cache(self): + # warning: this could be a flakey test # for N times, run the function twice and compare the time of each run. 
# the second run should be faster since the output is cached # to make sure the time is not too small, we run this on an input dict that is @@ -213,8 +214,10 @@ def test_cache(self): end = time.time() time2.append(end - start) - lower_bound_time1 = np.mean(time1) - 3 * np.std(time1) - upper_bound_time2 = np.mean(time2) + 3 * np.std(time2) + lower_bound_time1 = np.mean(time1) # - 3 * np.std(time1) + upper_bound_time2 = np.mean(time2) # + 3 * np.std(time2) + print(f"time1: {np.mean(time1)}") + print(f"time2: {np.mean(time2)}") self.assertGreater(lower_bound_time1, upper_bound_time2)