-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] Fix stateful module errors with inference only mode. #45465
[RLlib] Fix stateful module errors with inference only mode. #45465
Conversation
Signed-off-by: Simon Zehnder <[email protected]>
Signed-off-by: Simon Zehnder <[email protected]>
Signed-off-by: Simon Zehnder <[email protected]>
Signed-off-by: Simon Zehnder <[email protected]>
Signed-off-by: Simon Zehnder <[email protected]>
Signed-off-by: Simon Zehnder <[email protected]>
…' mode. Signed-off-by: Simon Zehnder <[email protected]>
Signed-off-by: Simon Zehnder <[email protected]>
@@ -20,6 +20,9 @@ class PPORLModule(RLModule, abc.ABC): | |||
def setup(self): | |||
# __sphinx_doc_begin__ | |||
catalog = self.config.get_catalog() | |||
# If we have a stateful model states for the critic need to be collected |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't we also use is_stateful()
here? What if the user doesn't use the built-in use_lstm option, but comes with their own stateful model?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@sven1977 this was my first intend, however at this point in time is_stateful()
cannot be called, yet b/c the encoder is not yet defined.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that this is not a nice solution, but at this point in the code we need to know, if the module is stateful or not, but the is_stateful()
depends on the encoder which is defined depending on inference-only
being True/False
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just one question about use_lstm
not being generic enough as a criterion to determine statefulnes..
…ecurrentEncoderConfig' and added an additional check for 'inference-only' b/c negation resulted in learner modules being 'inference-only'. This is fixed now. Signed-off-by: Simon Zehnder <[email protected]>
I replaced this now with a more generic approach using |
Why are these changes needed?
Stateful models need to collect states from the critic during sampling, therefore they cannot be
inference-only
. This PR fixes this error by settinginference-only=False
for stateful modules and checking statefulness in theget_state
.Related issue number
Related to #44758
Checks
git commit -s
) in this PR.scripts/format.sh
to lint the changes in this PR.method in Tune, I've added it in
doc/source/tune/api/
under thecorresponding
.rst
file.