-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] created check_specs decorator, RLModule PR 1/N #29599
Changes from 10 commits
b22a6d8
422d5ad
8098120
5be7182
30fe44f
9c74871
f7dcb90
a019dd5
11fe4f4
f0ec25f
e9ffeb1
9cfb161
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,42 +1,45 @@ | ||
import functools | ||
from typing import Union, Type, Mapping, Any | ||
|
||
from ray.rllib.utils.annotations import override | ||
from ray.rllib.utils.nested_dict import NestedDict | ||
from ray.rllib.models.specs.specs_base import TensorSpecs | ||
from ray.rllib.models.specs.specs_base import TensorSpec | ||
|
||
|
||
_MISSING_KEYS_FROM_SPEC = ( | ||
"The data dict does not match the model specs. Keys {} are " | ||
"in the data dict but not on the given spec dict, and exact_match is set to True." | ||
"in the data dict but not on the given spec dict, and exact_match is set to True" | ||
) | ||
_MISSING_KEYS_FROM_DATA = ( | ||
"The data dict does not match the model specs. Keys {} are " | ||
"in the spec dict but not on the data dict." | ||
"in the spec dict but not on the data dict. Data keys are {}" | ||
) | ||
_TYPE_MISMATCH = ( | ||
"The data does not match the spec. The data element " | ||
"{} has type {} (expected type {})." | ||
) | ||
|
||
SPEC_LEAF_TYPE = Union[Type, TensorSpecs] | ||
SPEC_LEAF_TYPE = Union[Type, TensorSpec] | ||
DATA_TYPE = Union[NestedDict[Any], Mapping[str, Any]] | ||
|
||
IS_NOT_PROPERTY = "Spec {} must be a property of the class {}." | ||
|
||
class ModelSpecDict(NestedDict[SPEC_LEAF_TYPE]): | ||
"""A NestedDict containing `TensorSpecs` and `Types`. | ||
|
||
class ModelSpec(NestedDict[SPEC_LEAF_TYPE]): | ||
"""A NestedDict containing `TensorSpec` and `Types`. | ||
|
||
It can be used to validate an incoming data against a nested dictionary of specs. | ||
|
||
Examples: | ||
|
||
Basic validation: | ||
----------------- | ||
>>> spec_dict = ModelSpecDict({ | ||
>>> spec_dict = ModelSpec({ | ||
... "obs": { | ||
... "arm": TensorSpecs("b, d_a", d_a=64), | ||
... "gripper": TensorSpecs("b, d_g", d_g=12) | ||
... "arm": TensorSpec("b, d_a", d_a=64), | ||
... "gripper": TensorSpec("b, d_g", d_g=12) | ||
... }, | ||
... "action": TensorSpecs("b, d_a", h=12), | ||
... "action": TensorSpec("b, d_a", h=12), | ||
... "action_dist": torch.distributions.Categorical | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have a random question. do you intend for these specs to be checkpointed and restored? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. This is a good point. This should be incorporated into the RLModule PR then. I'll look into this there. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should be able to write this directly to the |
||
... }) | ||
|
||
|
@@ -91,7 +94,7 @@ def __init__(self, *args, **kwargs): | |
def validate( | ||
self, | ||
data: DATA_TYPE, | ||
exact_match: bool = True, | ||
exact_match: bool = False, | ||
) -> None: | ||
"""Checks whether the data matches the spec. | ||
|
||
|
@@ -107,24 +110,184 @@ def validate( | |
data_keys_set = set(data.keys()) | ||
missing_keys = self._keys_set.difference(data_keys_set) | ||
if missing_keys: | ||
raise ValueError(_MISSING_KEYS_FROM_DATA.format(missing_keys)) | ||
raise ValueError( | ||
_MISSING_KEYS_FROM_DATA.format(missing_keys, data_keys_set) | ||
) | ||
if exact_match: | ||
data_spec_missing_keys = data_keys_set.difference(self._keys_set) | ||
if data_spec_missing_keys: | ||
raise ValueError(_MISSING_KEYS_FROM_SPEC.format(data_spec_missing_keys)) | ||
|
||
for spec_name, spec in self.items(): | ||
data_to_validate = data[spec_name] | ||
if isinstance(spec, TensorSpecs): | ||
spec.validate(data_to_validate) | ||
if isinstance(spec, TensorSpec): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for safety, this needs an "else" clause that just raises error if spec is some random data? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed. |
||
try: | ||
spec.validate(data_to_validate) | ||
except ValueError as e: | ||
raise ValueError( | ||
f"Mismatch found in data element {spec_name}, " | ||
f"which is a TensorSpec: {e}" | ||
) | ||
elif isinstance(spec, Type): | ||
if not isinstance(data_to_validate, spec): | ||
raise ValueError( | ||
_TYPE_MISMATCH.format( | ||
spec_name, type(data_to_validate).__name__, spec.__name__ | ||
) | ||
) | ||
else: | ||
raise ValueError( | ||
f"The spec type has to be either TensorSpec or Type. " | ||
f"got {type(spec)}" | ||
) | ||
|
||
@override(NestedDict) | ||
def __repr__(self) -> str: | ||
return f"ModelSpecDict({repr(self._data)})" | ||
return f"ModelSpec({repr(self._data)})" | ||
|
||
|
||
def check_specs( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Super nice! This is going to be dope for early-catching user errors. |
||
input_spec: str = "", | ||
output_spec: str = "", | ||
filter: bool = True, | ||
cache: bool = False, | ||
input_exact_match: bool = False, | ||
output_exact_match: bool = False, | ||
): | ||
"""A general-purpose [stateful decorator] | ||
(https://realpython.com/primer-on-python-decorators/#stateful-decorators) to | ||
enforce input/output specs for any instance method that has `input_data` in input | ||
args and returns a single object. | ||
|
||
|
||
It adds the ability to filter the input data if it is a mapping to only contain | ||
the keys in the spec. It can also cache the validation to make sure the spec is | ||
only validated once in the lifetime of the instance. | ||
|
||
Examples (See more examples in ../tests/test_specs_dict.py): | ||
|
||
>>> class MyModel(nn.Module): | ||
... def input_spec(self): | ||
... return ModelSpec({"obs": TensorSpec("b, d", d=64)}) | ||
... | ||
... @check_specs(input_spec="input_spec") | ||
... def forward(self, input_data, return_loss=False): | ||
... ... | ||
... output_dict = ... | ||
... return output_dict | ||
|
||
>>> model = MyModel() | ||
>>> model.forward({"obs": torch.randn(32, 64)}) # No error | ||
>>> model.forward({"obs": torch.randn(32, 32)}) # raises ValueError | ||
|
||
Args: | ||
func: The instance method to decorate. It should be a callable that takes | ||
`self` as the first argument, `input_data` as the second argument and any | ||
other keyword argument thereafter. It should return a single object. | ||
input_spec: `self` should have an instance method that is named input_spec and | ||
returns the `ModelSpec`, `TensorSpec`, or simply the `Type` that the | ||
`input_data` should comply with. | ||
output_spec: `self` should have an instance method that is named output_spec | ||
and returns the spec that the output should comply with. | ||
filter: If True, and `input_data` is a nested dict the `input_data` will be | ||
filtered by its corresponding spec tree structure and then passed into the | ||
implemented function to make sure user is not confounded with unnecessary | ||
data. | ||
cache: If True, only checks the input/output validation for the first time the | ||
instance method is called. | ||
input_exact_match: If True, the input data (should be a nested dict) must match | ||
the spec exactly. Otherwise, the data is validated as long as it contains | ||
at least the elements of the spec, but can contain more entries. | ||
output_exact_match: If True, the output data (should be a nested dict) must | ||
match the spec exactly. Otherwise, the data is validated as long as it | ||
contains at least the elements of the spec, but can contain more entries. | ||
|
||
Returns: | ||
A wrapped instance method. In case of `cache=True`, after the first invocation | ||
of the decorated method, the instance will have `__checked_specs_cache__` | ||
attribute that stores which method has been invoked at least once. This is a | ||
special attribute that can be used for the cache itself. The wrapped class method | ||
also has a special attribute `__checked_specs__` that marks the method as | ||
decorated. | ||
""" | ||
|
||
if not input_spec and not output_spec: | ||
raise ValueError("At least one of input_spec or output_spec must be provided.") | ||
|
||
def decorator(func): | ||
@functools.wraps(func) | ||
def wrapper(self, input_data, **kwargs): | ||
|
||
if cache and not hasattr(self, "__checked_specs_cache__"): | ||
self.__checked_specs_cache__ = {} | ||
|
||
def should_validate(): | ||
return not cache or func.__name__ not in self.__checked_specs_cache__ | ||
|
||
def validate(data, spec, exact_match, tag="data"): | ||
is_mapping = isinstance(spec, ModelSpec) | ||
is_tensor = isinstance(spec, TensorSpec) | ||
|
||
if is_mapping: | ||
if not isinstance(data, Mapping): | ||
raise ValueError( | ||
f"{tag} must be a Mapping, got {type(data).__name__}" | ||
) | ||
data = NestedDict(data) | ||
|
||
if should_validate(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should you check should_validate() first thing, so we don't waste time building NestedDict, etc? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think so, You may still need to filter the data if it's a mapping regardless of whether you should validate or not. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👌 👌 |
||
try: | ||
if is_mapping: | ||
spec.validate(data, exact_match=exact_match) | ||
elif is_tensor: | ||
spec.validate(data) | ||
except ValueError as e: | ||
raise ValueError( | ||
f"{tag} spec validation failed on " | ||
f"{self.__class__.__name__}.{func.__name__}, {e}." | ||
) | ||
|
||
if not (is_tensor or is_mapping): | ||
if not isinstance(data, spec): | ||
raise ValueError( | ||
f"Input spec validation failed on " | ||
f"{self.__class__.__name__}.{func.__name__}, " | ||
f"expected {spec.__name__}, got " | ||
f"{type(data).__name__}." | ||
) | ||
return data | ||
|
||
input_data_ = input_data | ||
if input_spec: | ||
input_spec_ = getattr(self, input_spec)() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can input_spec_ just be a member variable on self? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That will assume some hard-coded names on the base class and is not the intention of this general-purpose decorator. The decorator let's the user choose their own spec names. so it can be applied to any base class essentially. It may become handy in defining Pi base class that looks different than RLModule base class. You can see the use-case in the RLModule PR. |
||
|
||
input_data_ = validate( | ||
input_data, | ||
input_spec_, | ||
exact_match=input_exact_match, | ||
tag="input_data", | ||
) | ||
|
||
if filter and isinstance(input_spec_, ModelSpec): | ||
# filtering should happen regardless of cache | ||
input_data_ = input_data_.filter(input_spec_) | ||
|
||
output_data = func(self, input_data_, **kwargs) | ||
if output_spec: | ||
output_spec_ = getattr(self, output_spec)() | ||
validate( | ||
output_data, | ||
output_spec_, | ||
exact_match=output_exact_match, | ||
tag="output_data", | ||
) | ||
|
||
if cache: | ||
self.__checked_specs_cache__[func.__name__] = True | ||
|
||
return output_data | ||
|
||
wrapper.__check_specs__ = True | ||
return wrapper | ||
|
||
return decorator |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
|
||
from ray.rllib.utils.annotations import DeveloperAPI, override | ||
from ray.rllib.utils.framework import try_import_jax | ||
from ray.rllib.models.specs.specs_base import TensorSpecs | ||
from ray.rllib.models.specs.specs_base import TensorSpec | ||
|
||
jax, _ = try_import_jax() | ||
jnp = None | ||
|
@@ -11,20 +11,20 @@ | |
|
||
|
||
@DeveloperAPI | ||
class JAXSpecs(TensorSpecs): | ||
@override(TensorSpecs) | ||
class JAXTensorSpec(TensorSpec): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why call this JAXTensorSpec, but the others only There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changed everything to |
||
@override(TensorSpec) | ||
def get_type(cls) -> Type: | ||
return jnp.ndarray | ||
|
||
@override(TensorSpecs) | ||
@override(TensorSpec) | ||
def get_shape(self, tensor: jnp.ndarray) -> Tuple[int]: | ||
return tuple(tensor.shape) | ||
|
||
@override(TensorSpecs) | ||
@override(TensorSpec) | ||
def get_dtype(self, tensor: jnp.ndarray) -> Any: | ||
return tensor.dtype | ||
|
||
@override(TensorSpecs) | ||
@override(TensorSpec) | ||
def _full( | ||
self, shape: Tuple[int], fill_value: Union[float, int] = 0 | ||
) -> jnp.ndarray: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Realized they can represent only one spec, hence renaming :)