[RLlib] created check_specs decorator, RLModule PR 1/N #29599

Merged: 12 commits, Oct 25, 2022
17 changes: 12 additions & 5 deletions rllib/BUILD
@@ -1750,6 +1750,13 @@ py_test(
    srcs = ["models/tests/test_attention_nets.py"]
)

py_test(
    name = "test_check_specs",
    tags = ["team:rllib", "models"],
    size = "medium",
    srcs = ["models/specs/tests/test_check_specs.py"]
)

py_test(
    name = "test_conv2d_default_stacks",
    tags = ["team:rllib", "models"],
@@ -1795,10 +1802,10 @@ py_test(

# Test Tensor specs
py_test(
-    name = "test_tensor_specs",
+    name = "test_tensor_spec",
    tags = ["team:rllib", "models"],
    size = "small",
-    srcs = ["models/specs/tests/test_tensor_specs.py"]
+    srcs = ["models/specs/tests/test_tensor_spec.py"]
)

# test abstract base models
@@ -1817,12 +1824,12 @@ py_test(
    srcs = ["models/tests/test_torch_model.py"]
)

-# test ModelSpecDict
+# test ModelSpec
py_test(
-    name = "test_tensor_specs_dict",
+    name = "test_model_spec",
    tags = ["team:rllib", "models"],
    size = "small",
-    srcs = ["models/specs/tests/test_tensor_specs_dict.py"]
+    srcs = ["models/specs/tests/test_model_spec.py"]
)


8 changes: 4 additions & 4 deletions rllib/models/specs/specs_base.py
@@ -28,7 +28,7 @@ def validate(self, data: Any) -> None:


Review comment (Contributor Author): Realized they can represent only one spec, hence renaming :)

@DeveloperAPI
-class TensorSpecs(SpecsAbstract):
+class TensorSpec(SpecsAbstract):
"""A base class that specifies the shape and dtype of a tensor.

Args:
@@ -230,11 +230,11 @@ def _validate_shape_vals(
    def __repr__(self) -> str:
        return f"TensorSpec(shape={tuple(self.shape)}, dtype={self.dtype})"

-    def __eq__(self, other: "TensorSpecs") -> bool:
+    def __eq__(self, other: "TensorSpec") -> bool:
        """Checks if the shape and dtype of two specs are equal."""
-        if not isinstance(other, TensorSpecs):
+        if not isinstance(other, TensorSpec):
            return False
        return self.shape == other.shape and self.dtype == other.dtype

-    def __ne__(self, other: "TensorSpecs") -> bool:
+    def __ne__(self, other: "TensorSpec") -> bool:
        return not self == other
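
To make the renamed API concrete, here is a rough usage sketch based on the dimension-string syntax that appears in this PR's docstrings ("b, d" with named dimensions bound via keyword arguments). It uses NPTensorSpec, the NumPy subclass renamed below in specs_np.py; the exact constructor and error messages may differ from this sketch:

```python
import numpy as np

from ray.rllib.models.specs.specs_np import NPTensorSpec

# "b, d" declares a rank-2 tensor: "b" stays unbound (any batch size),
# while "d" is pinned to 64 through the keyword argument.
spec = NPTensorSpec("b, d", d=64)

spec.validate(np.zeros((32, 64)))  # passes: any b, and d == 64

try:
    spec.validate(np.zeros((32, 32)))  # fails: d != 64
except ValueError as e:
    print(f"spec mismatch: {e}")
```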
193 changes: 178 additions & 15 deletions rllib/models/specs/specs_dict.py
@@ -1,42 +1,45 @@
import functools
from typing import Union, Type, Mapping, Any

from ray.rllib.utils.annotations import override
from ray.rllib.utils.nested_dict import NestedDict
-from ray.rllib.models.specs.specs_base import TensorSpecs
+from ray.rllib.models.specs.specs_base import TensorSpec


_MISSING_KEYS_FROM_SPEC = (
    "The data dict does not match the model specs. Keys {} are "
-    "in the data dict but not on the given spec dict, and exact_match is set to True."
+    "in the data dict but not on the given spec dict, and exact_match is set to True"
)
_MISSING_KEYS_FROM_DATA = (
    "The data dict does not match the model specs. Keys {} are "
-    "in the spec dict but not on the data dict."
+    "in the spec dict but not on the data dict. Data keys are {}"
)
_TYPE_MISMATCH = (
    "The data does not match the spec. The data element "
    "{} has type {} (expected type {})."
)

-SPEC_LEAF_TYPE = Union[Type, TensorSpecs]
+SPEC_LEAF_TYPE = Union[Type, TensorSpec]
DATA_TYPE = Union[NestedDict[Any], Mapping[str, Any]]

IS_NOT_PROPERTY = "Spec {} must be a property of the class {}."

-class ModelSpecDict(NestedDict[SPEC_LEAF_TYPE]):
-    """A NestedDict containing `TensorSpecs` and `Types`.
+class ModelSpec(NestedDict[SPEC_LEAF_TYPE]):
+    """A NestedDict containing `TensorSpec` and `Types`.

    It can be used to validate incoming data against a nested dictionary of specs.

    Examples:

        Basic validation:
        -----------------
-        >>> spec_dict = ModelSpecDict({
+        >>> spec_dict = ModelSpec({
        ...     "obs": {
-        ...         "arm": TensorSpecs("b, d_a", d_a=64),
-        ...         "gripper": TensorSpecs("b, d_g", d_g=12)
+        ...         "arm": TensorSpec("b, d_a", d_a=64),
+        ...         "gripper": TensorSpec("b, d_g", d_g=12)
        ...     },
-        ...     "action": TensorSpecs("b, d_a", h=12),
+        ...     "action": TensorSpec("b, d_a", h=12),
        ...     "action_dist": torch.distributions.Categorical
Review comment (Member): I have a random question: do you intend for these specs to be checkpointed and restored? If so, maybe a registry of distributions is a good idea.

Reply (Contributor Author): I see, this is a good point. This should be incorporated into the RLModule PR then. I'll look into it there.

Reply (Contributor): We should be able to write this directly to the state_dict, but this might break logic that assumes state_dict only contains tensors.
        ...     })

@@ -91,7 +94,7 @@ def __init__(self, *args, **kwargs):
    def validate(
        self,
        data: DATA_TYPE,
-        exact_match: bool = True,
+        exact_match: bool = False,
    ) -> None:
        """Checks whether the data matches the spec.

@@ -107,24 +110,184 @@ def validate(
        data_keys_set = set(data.keys())
        missing_keys = self._keys_set.difference(data_keys_set)
        if missing_keys:
-            raise ValueError(_MISSING_KEYS_FROM_DATA.format(missing_keys))
+            raise ValueError(
+                _MISSING_KEYS_FROM_DATA.format(missing_keys, data_keys_set)
+            )
        if exact_match:
            data_spec_missing_keys = data_keys_set.difference(self._keys_set)
            if data_spec_missing_keys:
                raise ValueError(_MISSING_KEYS_FROM_SPEC.format(data_spec_missing_keys))

        for spec_name, spec in self.items():
            data_to_validate = data[spec_name]
-            if isinstance(spec, TensorSpecs):
-                spec.validate(data_to_validate)
+            if isinstance(spec, TensorSpec):
Review comment (Member): For safety, this needs an "else" clause that just raises an error if spec is some random data?

Reply (Contributor Author): fixed.
                try:
                    spec.validate(data_to_validate)
                except ValueError as e:
                    raise ValueError(
                        f"Mismatch found in data element {spec_name}, "
                        f"which is a TensorSpec: {e}"
                    )
            elif isinstance(spec, Type):
                if not isinstance(data_to_validate, spec):
                    raise ValueError(
                        _TYPE_MISMATCH.format(
                            spec_name, type(data_to_validate).__name__, spec.__name__
                        )
                    )
            else:
                raise ValueError(
                    f"The spec type has to be either TensorSpec or Type. "
                    f"got {type(spec)}"
                )

    @override(NestedDict)
    def __repr__(self) -> str:
-        return f"ModelSpecDict({repr(self._data)})"
+        return f"ModelSpec({repr(self._data)})"


Review comment (Contributor): Super nice! This is going to be dope for early-catching user errors.

def check_specs(

    input_spec: str = "",
    output_spec: str = "",
    filter: bool = True,
    cache: bool = False,
    input_exact_match: bool = False,
    output_exact_match: bool = False,
):
"""A general-purpose [stateful decorator]
(https://realpython.com/primer-on-python-decorators/#stateful-decorators) to
enforce input/output specs for any instance method that has `input_data` in input
args and returns and a single object.


It adds the ability to filter the input data if it is a mappinga to only contain
the keys in the spec. It can also cache the validation to make sure the spec is
only validated once in the lifetime of the instance.

Examples (See more exmaples in ../tests/test_specs_dict.py):

>>> class MyModel(nn.Module):
... def input_spec(self):
... return ModelSpec({"obs": TensorSpec("b, d", d=64)})
...
... @check_specs(input_spec="input_spec")
... def forward(self, input_data, return_loss=False):
... ...
... output_dict = ...
... return output_dict

>>> model = MyModel()
>>> model.forward({"obs": torch.randn(32, 64)}) # No error
>>> model.forward({"obs": torch.randn(32, 32)}) # raises ValueError

Args:
func: The instance method to decorate. It should be a callable that takes
`self` as the first argument, `input_data` as the second argument and any
other keyword argument thereafter. It should return a single object.
input_spec: `self` should have an instance method that is named input_spec and
returns the `ModelSpec`, `TensorSpec`, or simply the `Type` that the
`input_data` should comply with.
output_spec: `self` should have an instance method that is named output_spec
and returns the spec that the output should comply with.
filter: If True, and `input_data` is a nested dict the `input_data` will be
filtered by its corresponding spec tree structure and then passed into the
implemented function to make sure user is not confounded with unnecessary
data.
cache: If True, only checks the input/output validation for the first time the
instance method is called.
input_exact_match: If True, the input data (should be a nested dict) must match
the spec exactly. Otherwise, the data is validated as long as it contains
at least the elements of the spec, but can contain more entries.
output_exact_match: If True, the output data (should be a nested dict) must
match the spec exactly. Otherwise, the data is validated as long as it
contains at least the elements of the spec, but can contain more entries.

Returns:
A wrapped instance method. In case of `cache=True`, after the first invokation
of the decorated method, the intance will have `__checked_specs_cache__`
attribute that store which method has been invoked at least once. This is a
special attribute can be used for the cache itself. The wrapped class method
also has a special attribute `__checked_specs__` that marks the method as
decorated.
"""

    if not input_spec and not output_spec:
        raise ValueError("At least one of input_spec or output_spec must be provided.")

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, input_data, **kwargs):

            if cache and not hasattr(self, "__checked_specs_cache__"):
                self.__checked_specs_cache__ = {}

            def should_validate():
                return not cache or func.__name__ not in self.__checked_specs_cache__

            def validate(data, spec, exact_match, tag="data"):
                is_mapping = isinstance(spec, ModelSpec)
                is_tensor = isinstance(spec, TensorSpec)

                if is_mapping:
                    if not isinstance(data, Mapping):
                        raise ValueError(
                            f"{tag} must be a Mapping, got {type(data).__name__}"
                        )
                    data = NestedDict(data)

                if should_validate():
Review comment (Member): Should you check should_validate() first thing, so we don't waste time building NestedDict, etc.?

Reply (Contributor Author): I don't think so. You may still need to filter the data if it's a mapping, regardless of whether you should validate or not.

Reply (Member): 👌 👌
                    try:
                        if is_mapping:
                            spec.validate(data, exact_match=exact_match)
                        elif is_tensor:
                            spec.validate(data)
                    except ValueError as e:
                        raise ValueError(
                            f"{tag} spec validation failed on "
                            f"{self.__class__.__name__}.{func.__name__}, {e}."
                        )

                if not (is_tensor or is_mapping):
                    if not isinstance(data, spec):
                        raise ValueError(
                            f"Input spec validation failed on "
                            f"{self.__class__.__name__}.{func.__name__}, "
                            f"expected {spec.__name__}, got "
                            f"{type(data).__name__}."
                        )
                return data

            input_data_ = input_data
            if input_spec:
                input_spec_ = getattr(self, input_spec)()
Review comment (Member): Can input_spec_ just be a member variable on self? Or maybe we can simply assume that the specs will be provided by RLModule under some hardcoded key names, so by default folks can simply use the decorator without specifying any parameters.

Reply (kouroshHakha, Contributor Author, Oct 24, 2022): That would assume some hard-coded names on the base class, which is not the intention of this general-purpose decorator. The decorator lets the user choose their own spec names, so it can be applied to any base class, essentially. It may become handy in defining a Pi base class that looks different than the RLModule base class. You can see the use case in the RLModule PR.

input_data_ = validate(
input_data,
input_spec_,
exact_match=input_exact_match,
tag="input_data",
)

if filter and isinstance(input_spec_, ModelSpec):
# filtering should happen regardless of cache
input_data_ = input_data_.filter(input_spec_)

output_data = func(self, input_data_, **kwargs)
if output_spec:
output_spec_ = getattr(self, output_spec)()
validate(
output_data,
output_spec_,
exact_match=output_exact_match,
tag="output_data",
)

if cache:
self.__checked_specs_cache__[func.__name__] = True

return output_data

wrapper.__check_specs__ = True
return wrapper

return decorator
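
Putting the decorator together with ModelSpec, here is a minimal end-to-end sketch. It assumes the NPTensorSpec subclass from specs_np.py; the class name MyModel and its spec contents are illustrative only, and NestedDict-specific behavior inside the method body may differ:

```python
import numpy as np

from ray.rllib.models.specs.specs_dict import ModelSpec, check_specs
from ray.rllib.models.specs.specs_np import NPTensorSpec


class MyModel:
    def input_spec(self) -> ModelSpec:
        return ModelSpec({"obs": NPTensorSpec("b, d", d=64)})

    def output_spec(self) -> ModelSpec:
        return ModelSpec({"action": NPTensorSpec("b, a", a=8)})

    # cache=True validates each method's specs only on its first call;
    # filter=True (the default) strips keys outside input_spec before
    # the method body runs.
    @check_specs(input_spec="input_spec", output_spec="output_spec", cache=True)
    def forward(self, input_data, **kwargs):
        # By this point input_data has been filtered down to {"obs": ...}.
        batch_size = input_data["obs"].shape[0]
        return {"action": np.zeros((batch_size, 8))}


model = MyModel()
out = model.forward({"obs": np.zeros((32, 64)), "stale_key": None})
assert hasattr(model, "__checked_specs_cache__")  # populated because cache=True
```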
12 changes: 6 additions & 6 deletions rllib/models/specs/specs_jax.py
@@ -2,7 +2,7 @@

from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.framework import try_import_jax
-from ray.rllib.models.specs.specs_base import TensorSpecs
+from ray.rllib.models.specs.specs_base import TensorSpec

jax, _ = try_import_jax()
jnp = None
@@ -11,20 +11,20 @@


Review comment (Contributor): Why call this JAXTensorSpec, but the others only XYZSpec (w/o the "tensor", e.g. "TFSpec")?

Reply (Contributor Author): Changed everything to XXXTensorSpec to be more precise; we may have XXXDistributionSpecs down the line too.

@DeveloperAPI
-class JAXSpecs(TensorSpecs):
-    @override(TensorSpecs)
+class JAXTensorSpec(TensorSpec):
+    @override(TensorSpec)
    def get_type(cls) -> Type:
        return jnp.ndarray

-    @override(TensorSpecs)
+    @override(TensorSpec)
    def get_shape(self, tensor: jnp.ndarray) -> Tuple[int]:
        return tuple(tensor.shape)

-    @override(TensorSpecs)
+    @override(TensorSpec)
    def get_dtype(self, tensor: jnp.ndarray) -> Any:
        return tensor.dtype

-    @override(TensorSpecs)
+    @override(TensorSpec)
    def _full(
        self, shape: Tuple[int], fill_value: Union[float, int] = 0
    ) -> jnp.ndarray:
12 changes: 6 additions & 6 deletions rllib/models/specs/specs_np.py
@@ -2,23 +2,23 @@
import numpy as np

from ray.rllib.utils.annotations import DeveloperAPI, override
-from ray.rllib.models.specs.specs_base import TensorSpecs
+from ray.rllib.models.specs.specs_base import TensorSpec


@DeveloperAPI
-class NPSpecs(TensorSpecs):
-    @override(TensorSpecs)
+class NPTensorSpec(TensorSpec):
+    @override(TensorSpec)
    def get_type(cls) -> Type:
        return np.ndarray

-    @override(TensorSpecs)
+    @override(TensorSpec)
    def get_shape(self, tensor: np.ndarray) -> Tuple[int]:
        return tuple(tensor.shape)

-    @override(TensorSpecs)
+    @override(TensorSpec)
    def get_dtype(self, tensor: np.ndarray) -> Any:
        return tensor.dtype

-    @override(TensorSpecs)
+    @override(TensorSpec)
    def _full(self, shape: Tuple[int], fill_value: Union[float, int] = 0) -> np.ndarray:
        return np.full(shape, fill_value, dtype=self.dtype)
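
The JAX and NumPy subclasses above follow a single pattern: override get_type, get_shape, get_dtype, and _full for the framework's tensor type. As a hedged illustration of that pattern (not code from this PR), a TensorFlow variant might look like the following; the dtype handling in _full is a guess modeled on the NumPy version:

```python
from typing import Any, Tuple, Type, Union

import tensorflow as tf

from ray.rllib.models.specs.specs_base import TensorSpec
from ray.rllib.utils.annotations import DeveloperAPI, override


@DeveloperAPI
class TFTensorSpec(TensorSpec):  # hypothetical, mirrors NPTensorSpec above
    @override(TensorSpec)
    def get_type(cls) -> Type:
        return tf.Tensor

    @override(TensorSpec)
    def get_shape(self, tensor: tf.Tensor) -> Tuple[int]:
        return tuple(tensor.shape)

    @override(TensorSpec)
    def get_dtype(self, tensor: tf.Tensor) -> Any:
        return tensor.dtype

    @override(TensorSpec)
    def _full(
        self, shape: Tuple[int], fill_value: Union[float, int] = 0
    ) -> tf.Tensor:
        # Assumes self.dtype may be None, as in the NumPy variant.
        t = tf.fill(shape, fill_value)
        return tf.cast(t, self.dtype) if self.dtype else t
```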