[RLlib] created check_specs decorator, RLModule PR 1/N #29599

Merged
merged 12 commits into from
Oct 25, 2022

Conversation

kouroshHakha
Contributor

@kouroshHakha kouroshHakha commented Oct 24, 2022

Why are these changes needed?

Submitting this PR in pieces:
The check_specs decorator can be added to any module method to enforce the structure and types of its inputs and outputs. This is useful for imposing a certain input/output contract in RLModule without taking implementation flexibility away from the user. Users are also informed directly about what needs to be implemented.
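To illustrate the idea, here is a minimal sketch of such a decorator. It is not the implementation from this PR: the name `check_specs_sketch`, the use of plain key sets as specs, and the `ToyModule` class are assumptions made for illustration. The only detail taken from the diff is that the spec is looked up by method name via `getattr(self, input_spec)()`.

```python
import functools
from typing import Any, Callable, Dict, Set


def check_specs_sketch(input_spec: str = "", output_spec: str = "") -> Callable:
    """Hypothetical decorator: checks dict keys against specs named on `self`."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(self, input_data: Dict[str, Any], **kwargs) -> Dict[str, Any]:
            if input_spec:
                # Look the spec up by name, so the class picks its own names.
                required: Set[str] = getattr(self, input_spec)()
                missing = required - set(input_data)
                if missing:
                    raise ValueError(f"input is missing keys {sorted(missing)}")
            output_data = func(self, input_data, **kwargs)
            if output_spec:
                required = getattr(self, output_spec)()
                missing = required - set(output_data)
                if missing:
                    raise ValueError(f"output is missing keys {sorted(missing)}")
            return output_data

        return wrapper

    return decorator


class ToyModule:
    """Illustrative class; the spec method names are arbitrary."""

    def my_input_spec(self) -> Set[str]:
        return {"obs"}

    def my_output_spec(self) -> Set[str]:
        return {"action"}

    @check_specs_sketch(input_spec="my_input_spec", output_spec="my_output_spec")
    def forward(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        return {"action": input_data["obs"] * 2}
```

With this sketch, `ToyModule().forward({"obs": 3})` passes both checks, while `ToyModule().forward({})` raises a `ValueError` before the method body runs.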

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

…he names a little bit for generality

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
@@ -28,7 +28,7 @@ def validate(self, data: Any) -> None:


@DeveloperAPI
class TensorSpecs(SpecsAbstract):
Contributor Author

Realized they can represent only one spec, hence renaming :)

@kouroshHakha kouroshHakha changed the title from "[RLlib] created check_specs decorator" to "[RLlib] created check_specs decorator, RLModule PR 1/N" on Oct 24, 2022
@@ -85,7 +85,7 @@ class NestedDict(Generic[T], MutableMapping[str, Union[T, "NestedDict"]]):
>>> # 'b': {'c': 200, 'd': 300}}
>>> # Getting elements, possibly nested:
>>> print(foo_dict['b', 'c']) # 200
- >>> print(foo_dict['b']) # IndexError("Use get for partial indexing.")
+ >>> print(foo_dict['b']) # IndexError
Contributor

Dumb question: Is there a reason why we make this seemingly arbitrary distinction between accessing by get (no error) vs direct (error)? Why should both not return the sub-dict? How would the user know this difference?

Contributor Author

Good point, @sven1977. In hindsight, I don't see any reason not to return the nested sub-dict when __getitem__ is used. It actually confused me at some point in a later PR.

Contributor Author

fixed.
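A rough sketch of the behavior agreed on in this thread, where a partial key returns the sub-structure instead of raising. The class `NestedDictSketch` and its `asdict` helper are hypothetical and much simpler than the real `NestedDict`; only the tuple-key indexing and the `'b': {'c': 200, 'd': 300}` data come from the docstring shown in the diff.

```python
from typing import Any, Dict, Tuple, Union

Key = Union[str, Tuple[str, ...]]


class NestedDictSketch:
    """Hypothetical read-only stand-in for a nested dict with tuple keys."""

    def __init__(self, data: Dict[str, Any]) -> None:
        self._data = data

    def __getitem__(self, key: Key) -> Any:
        path = (key,) if isinstance(key, str) else key
        node: Any = self._data
        for part in path:
            node = node[part]  # raises KeyError if the path does not exist
        # A partial key lands on a sub-dict; return it wrapped instead of raising.
        return NestedDictSketch(node) if isinstance(node, dict) else node

    def asdict(self) -> Dict[str, Any]:
        return self._data


foo_dict = NestedDictSketch({"b": {"c": 200, "d": 300}})
print(foo_dict["b", "c"])      # complete key: a leaf value
print(foo_dict["b"].asdict())  # partial key: the sub-structure
```

Returning a wrapped sub-dict keeps chained access such as `foo_dict["b"]["d"]` consistent with tuple access `foo_dict["b", "d"]`.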

raise IndexError(
    f"Key `{k}` is not a complete key in the given "
    f"{self.__class__.__name__}. It results in a container "
    f"with subkeys {set(output.keys())}. To get partial indexing, "
Contributor

-> "To use partial indexing and thus retrieve a sub-structure ..."

Contributor Author

I removed the partial-indexing error altogether, given the valid comment above. I also updated the examples in the docstring to show that an IndexError is no longer raised.

lambda: correct_module.check_input_and_output_wo_filter(input_dict),
)

def test_cache(self):
Contributor

Very nice test, going the extra mile!

)

# this should not raise an error because output is not forced to be checked
incorrect_module.check_only_input({"input": 2})
Contributor

Dumb question, why would the decorator itself not already complain when it's being instantiated b/c of the missing output check? In other words, why is it allowed to have an implementation that doesn't check both, in- and output?

Contributor Author

@kouroshHakha kouroshHakha Oct 24, 2022

incorrect_module's run implementation does not return the correct output type. Therefore, the functions that enforce output type checking should raise an error, and those that don't should simply skip the output spec enforcement. This can only be determined when the function is actually executed, not when it is decorated: what the decorated function's implementation returns is only visible once the function is invoked.

Contributor

Ah, yes, that makes sense! Thanks for clarifying.
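The point about decoration time versus call time can be shown with a small, self-contained sketch. The decorator `returns_dict` and the function `bad_run` are hypothetical and unrelated to the actual check_specs code; they only demonstrate that a return value cannot be inspected until the wrapped function runs.

```python
import functools
from typing import Any, Callable


def returns_dict(func: Callable) -> Callable:
    """Hypothetical output check: the wrapped function must return a dict."""

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = func(*args, **kwargs)
        # The return value only exists here, after the call has run.
        if not isinstance(result, dict):
            raise TypeError(f"{func.__name__} returned {type(result).__name__}")
        return result

    return wrapper


# Decoration succeeds: at this point the decorator only holds a function
# object and has no way to know what it will return.
@returns_dict
def bad_run(x: int) -> int:
    return x + 1


try:
    bad_run(2)
except TypeError as err:
    call_time_error = str(err)  # the mismatch surfaces only on invocation
```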

@@ -11,20 +11,20 @@


@DeveloperAPI
class JAXSpecs(TensorSpecs):
@override(TensorSpecs)
class JAXTensorSpec(TensorSpec):
Contributor

Why call this JAXTensorSpec, but the others only XYZSpec (w/o the "tensor", e.g. "TFSpec")?

Contributor Author

Changed everything to XXXTensorSpec to be more precise; we may have XXXDistributionSpecs down the line too.
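As a rough picture of the naming scheme, a framework-agnostic base spec can hold the shape description while each framework-specific subclass supplies the tensor access. Everything below is a hypothetical sketch: the class names, the `get_shape` hook, and the reading of the shape string (comma-separated dimension names, with keyword arguments pinning some sizes) are assumptions, and nested lists stand in for real tensors.

```python
from typing import Any, Dict, Tuple


class TensorSpecSketch:
    """Hypothetical base: describes ONE tensor by named dims, e.g. "b, h"."""

    def __init__(self, shape: str, **fixed_dims: int) -> None:
        self.dim_names = tuple(d.strip() for d in shape.split(","))
        self.fixed_dims: Dict[str, int] = fixed_dims

    def get_shape(self, tensor: Any) -> Tuple[int, ...]:
        raise NotImplementedError  # framework-specific subclasses fill this in

    def validate(self, tensor: Any) -> None:
        shape = self.get_shape(tensor)
        if len(shape) != len(self.dim_names):
            raise ValueError(f"expected {len(self.dim_names)} dims, got {shape}")
        for name, size in zip(self.dim_names, shape):
            expected = self.fixed_dims.get(name)
            if expected is not None and expected != size:
                raise ValueError(f"dim {name!r}: expected {expected}, got {size}")


class ListTensorSpecSketch(TensorSpecSketch):
    """Stand-in for a framework subclass; treats nested lists as 2-D tensors."""

    def get_shape(self, tensor: Any) -> Tuple[int, ...]:
        return (len(tensor), len(tensor[0]))


spec = ListTensorSpecSketch("b, h", h=3)
spec.validate([[1, 2, 3], [4, 5, 6]])  # batch of 2 with h == 3: passes
```

A JAX, TF, or Torch variant would override only `get_shape` (and any type check), which is one reason a per-framework `XXXTensorSpec` name reads naturally.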

return f"ModelSpec({repr(self._data)})"


def check_specs(
Contributor

Super nice! This is going to be dope for early-catching user errors.

Contributor

@sven1977 sven1977 left a comment

Awesome PR! Thanks for being very meticulous about designing these new APIs from the ground up. These will comprise ground-breaking advances for RLlib toward more user-friendliness and transparency.

Just a bunch of questions and nits.

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
@kouroshHakha
Contributor Author

@sven1977 Please re-review.

Contributor

@sven1977 sven1977 left a comment

LGTM.

@sven1977
Contributor

test_nested_dict failing
cc: @kouroshHakha

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Member

@gjoliver gjoliver left a comment

looks pretty good. just some minor issues.

... },
- ... "action": TensorSpecs("b, d_a", h=12),
+ ... "action": TensorSpec("b, d_a", h=12),
... "action_dist": torch.distributions.Categorical
Member

I have a random question. Do you intend for these specs to be checkpointed and restored?
If so, maybe a registry of distributions is a good idea.

Contributor Author

I see. This is a good point. This should be incorporated into the RLModule PR then. I'll look into this there.

Contributor

We should be able to write this directly to the state_dict, but this might break logic that assumes state_dict only contains tensors.
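One way to picture the registry idea: map stable string names to distribution classes, so a spec that refers to a class can be saved as plain strings and resolved again on restore. This is only a sketch of the suggestion; the registry, the helper names, and the placeholder `Categorical` class are all hypothetical and not part of this PR.

```python
from typing import Dict, Type


class Categorical:  # placeholder for a real distribution class
    pass


# Hypothetical registry: stable string names <-> distribution classes.
_DIST_REGISTRY: Dict[str, Type] = {}


def register_distribution(name: str, cls: Type) -> None:
    if name in _DIST_REGISTRY:
        raise KeyError(f"distribution {name!r} is already registered")
    _DIST_REGISTRY[name] = cls


def spec_to_state(spec: Dict[str, Type]) -> Dict[str, str]:
    """Replace classes with registered names so the spec is JSON-friendly."""
    names = {cls: name for name, cls in _DIST_REGISTRY.items()}
    return {key: names[cls] for key, cls in spec.items()}


def spec_from_state(state: Dict[str, str]) -> Dict[str, Type]:
    """Resolve names back to classes when restoring a checkpoint."""
    return {key: _DIST_REGISTRY[name] for key, name in state.items()}


register_distribution("categorical", Categorical)
state = spec_to_state({"action_dist": Categorical})
restored = spec_from_state(state)
```

Storing names rather than classes also sidesteps the concern raised above, since the saved state would hold strings next to tensors instead of arbitrary Python objects.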

if exact_match:
    data_spec_missing_keys = data_keys_set.difference(self._keys_set)
    if data_spec_missing_keys:
        raise ValueError(_MISSING_KEYS_FROM_SPEC.format(data_spec_missing_keys))

for spec_name, spec in self.items():
    data_to_validate = data[spec_name]
    if isinstance(spec, TensorSpecs):
        spec.validate(data_to_validate)
    if isinstance(spec, TensorSpec):
Member

For safety, does this need an "else" clause that raises an error if spec is some random data?

Contributor Author

fixed.
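The requested guard might look roughly like the sketch below. `LeafSpec` and `validate_all` are hypothetical names, and treating a bare class as an "instance of" check is an assumption; the part that mirrors the review comment is the final `else` branch, which rejects any spec kind the loop does not know how to handle.

```python
from typing import Any, Dict


class LeafSpec:
    """Hypothetical leaf spec: checks that a value has the expected type."""

    def __init__(self, expected_type: type) -> None:
        self.expected_type = expected_type

    def validate(self, value: Any) -> None:
        if not isinstance(value, self.expected_type):
            raise ValueError(f"expected {self.expected_type.__name__}")


def validate_all(specs: Dict[str, Any], data: Dict[str, Any]) -> None:
    for name, spec in specs.items():
        value = data[name]
        if isinstance(spec, LeafSpec):
            spec.validate(value)
        elif isinstance(spec, type):
            # A bare class is treated as "value must be an instance of it".
            if not isinstance(value, spec):
                raise ValueError(f"{name!r}: expected {spec.__name__}")
        else:
            # Fail loudly instead of silently skipping an unknown spec kind.
            raise TypeError(f"{name!r}: unsupported spec {spec!r}")
```

For example, `validate_all({"x": 42}, {"x": 1})` raises a `TypeError`, because `42` is neither a `LeafSpec` nor a class.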

)
data = NestedDict(data)

if should_validate():
Member

Should you check should_validate() first thing, so we don't waste time building NestedDict, etc.?

Contributor Author

I don't think so. You may still need to filter the data if it's a mapping, regardless of whether you validate or not.

Member

👌 👌


input_data_ = input_data
if input_spec:
    input_spec_ = getattr(self, input_spec)()
Member

Can input_spec_ just be a member variable on self?
Or maybe we can simply assume that the specs will be provided by RLModule under some hardcoded key names,
so that by default, folks can simply use the decorator without specifying any parameters.

Contributor Author

@kouroshHakha kouroshHakha Oct 24, 2022

That would assume some hard-coded names on the base class, which is not the intention of this general-purpose decorator. The decorator lets the user choose their own spec names, so it can be applied to essentially any base class. It may come in handy when defining a Pi base class that looks different from the RLModule base class. You can see the use case in the RLModule PR.

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
@kouroshHakha
Contributor Author

@gjoliver Can we merge this? The failed tests are again not related.

@kouroshHakha kouroshHakha added the tests-ok label (The tagger certifies test failures are unrelated and assumes personal liability.) Oct 25, 2022
@gjoliver
Member

mem leak test failures are not related.

@gjoliver gjoliver merged commit 45420f5 into ray-project:master Oct 25, 2022
WeichenXu123 pushed a commit to WeichenXu123/ray that referenced this pull request Dec 19, 2022
…9599)

* 1. created check_specs decorator 2. updated unittests 3. refactored the names a little bit for generality

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
4 participants