-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] created check_specs decorator, RLModule PR 1/N #29599
Conversation
…he names a little bit for generality Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
@@ -28,7 +28,7 @@ def validate(self, data: Any) -> None: | |||
|
|||
|
|||
@DeveloperAPI | |||
class TensorSpecs(SpecsAbstract): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Realized they can represent only one spec, hence renaming :)
rllib/utils/nested_dict.py
Outdated
@@ -85,7 +85,7 @@ class NestedDict(Generic[T], MutableMapping[str, Union[T, "NestedDict"]]): | |||
>>> # 'b': {'c': 200, 'd': 300}} | |||
>>> # Getting elements, possibly nested: | |||
>>> print(foo_dict['b', 'c']) # 200 | |||
>>> print(foo_dict['b']) # IndexError("Use get for partial indexing.") | |||
>>> print(foo_dict['b']) # IndexError |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Dumb question: Is there a reason why we make this seemingly arbitrary distinction between accessing by get
(no error) vs direct (error)? Why should both not return the sub-dict? How would the user know this difference?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good point @sven1977, In hindsight I don't see any reason for not returning the sub-nested dict if __getitem__
is used. It actually confused myself at some point during a later pr.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed.
rllib/utils/nested_dict.py
Outdated
raise IndexError( | ||
f"Key `{k}` is not a complete key in the given " | ||
f"{self.__class__.__name__}. It results in a container " | ||
f"with subkeys {set(output.keys())}. To get partial indexing, " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
-> "To use partial indexing and thus retrieve a sub-structure ..."
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I removed partial indexing error all together due the valid comment above. Also update the examples in the docstring to show that index error is not raised anymore.
lambda: correct_module.check_input_and_output_wo_filter(input_dict), | ||
) | ||
|
||
def test_cache(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice test, going the extra mile!
) | ||
|
||
# this should not raise an error because output is not forced to be checked | ||
incorrect_module.check_only_input({"input": 2}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Dumb question, why would the decorator itself not already complain when it's being instantiated b/c of the missing output check? In other words, why is it allowed to have an implementation that doesn't check both, in- and output?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
incorrect_module
's run
implementation does not have the correct output type. Therefore those functions that enforce output type checking should raise an error and those that don't should just ignore the output spec enforcement. This is only determined when the function is actually executed and not during the function decoration itself. Implementation if the decorated function is only visible when the function is invoked.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, yes, that makes sense! Thanks for clarifying.
@@ -11,20 +11,20 @@ | |||
|
|||
|
|||
@DeveloperAPI | |||
class JAXSpecs(TensorSpecs): | |||
@override(TensorSpecs) | |||
class JAXTensorSpec(TensorSpec): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why call this JAXTensorSpec, but the others only XYZSpec
(w/o the "tensor", e.g. "TFSpec")?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed everything to XXXTensorSpec
to be more precise, we may have XXXDistributionSpecs
down the line too.
return f"ModelSpec({repr(self._data)})" | ||
|
||
|
||
def check_specs( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Super nice! This is going to be dope for early-catching user errors.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome PR! Thanks for being very meticulous about designing these new APIs from the ground up. These will comprise ground-breaking advances for RLlib toward more user-friendliness and transparency.
Just a bunch of questions and nits.
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
@sven1977 Please re-review. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
|
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks pretty good. just some minor issues.
... }, | ||
... "action": TensorSpecs("b, d_a", h=12), | ||
... "action": TensorSpec("b, d_a", h=12), | ||
... "action_dist": torch.distributions.Categorical |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a random question. do you intend for these specs to be checkpointed and restored?
if so, maybe a registry of distribution is a good idea.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. This is a good point. This should be incorporated into the RLModule PR then. I'll look into this there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should be able to write this directly to the state_dict
, but this might break logic that assumes state_dict
only contains tensors.
if exact_match: | ||
data_spec_missing_keys = data_keys_set.difference(self._keys_set) | ||
if data_spec_missing_keys: | ||
raise ValueError(_MISSING_KEYS_FROM_SPEC.format(data_spec_missing_keys)) | ||
|
||
for spec_name, spec in self.items(): | ||
data_to_validate = data[spec_name] | ||
if isinstance(spec, TensorSpecs): | ||
spec.validate(data_to_validate) | ||
if isinstance(spec, TensorSpec): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for safety, this needs an "else" clause that just raises error if spec is some random data?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed.
) | ||
data = NestedDict(data) | ||
|
||
if should_validate(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should you check should_validate() first thing, so we don't waste time building NestedDict, etc?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think so, You may still need to filter the data if it's a mapping regardless of whether you should validate or not.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👌 👌
|
||
input_data_ = input_data | ||
if input_spec: | ||
input_spec_ = getattr(self, input_spec)() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can input_spec_ just be a member variable on self?
or maybe we can simply assume that the specs will be provided by RLModule under some hardcoded key names.
so by default, folks can simply use the decorator without specifying any parameters.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That will assume some hard-coded names on the base class and is not the intention of this general-purpose decorator. The decorator let's the user choose their own spec names. so it can be applied to any base class essentially. It may become handy in defining Pi base class that looks different than RLModule base class. You can see the use-case in the RLModule PR.
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
@gjoliver Can we merge this? The failed tests are again not related. |
mem leak test failures are not related. |
…9599) * 1. created check_specs decorator 2. updated unittests 3. refactored the names a little bit for generality Signed-off-by: Kourosh Hakhamaneshi <[email protected]> Signed-off-by: Weichen Xu <[email protected]>
Why are these changes needed?
submitting this PR in pieces:
check_specs decorator can be added to any module method to enforce input/output struct types. This is useful for imposing a certain input/output behavior in RLModule without taking away the flexibility of implementation details from the user. User would also be efficiently informed about what needs to be implemented.
Related issue number
Checks
git commit -s
) in this PR.scripts/format.sh
to lint the changes in this PR.