[RLlib] MARWIL RLModule #44970
Signed-off-by: Simon Zehnder <[email protected]>
Signed-off-by: simonsays1980 <[email protected]>
…matted commit merely for securing the work. Signed-off-by: simonsays1980 <[email protected]>
…' and 'MARWILTorchPolicy', fixed imports and tested MARWIL on non-recurrent policies. Signed-off-by: simonsays1980 <[email protected]>
…unction. Signed-off-by: simonsays1980 <[email protected]>
@@ -59,7 +63,7 @@ def test_bc_compilation_and_learning_from_offline_file(self):
 results = algo.train()
 print(results)

-eval_results = results.get("evaluation", {})
+eval_results = results.get(EVALUATION_RESULTS, {})
👍
@@ -49,11 +50,13 @@ def possibly_masked_mean(t):
 mask = None
 possibly_masked_mean = torch.mean

-action_dist_class_train = self.module[module_id].get_train_action_dist_cls()
+action_dist_class_train = (
+    self.module[module_id].unwrapped().get_train_action_dist_cls()
Not for this PR, but I wonder whether we should make the MARLModule automatically return the already unwrapped RLModule when we access a sub-module through __getitem__(). ...
Do we need the wrapped module anywhere? I guess not even in the DDP case.
Ah, hold on, good point. The most user-friendly way is probably to make sure the wrapper exposes:
- all RLModule base methods.
- all API methods that the wrapped module implements. So the wrapper should check which RLModule APIs (ValueFunctionAPI, TargetNetworkAPI, etc.) its wrapped module implements, then automatically expose all of those APIs' methods as well?
This is how it could work. It's not a very elegant design, but some weeks ago I tried to make it elegant by simply defining all methods of the wrapped module that are not yet exposed, and this is non-trivial. I have to check again what the reasons were.
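A minimal sketch of the idea discussed in this thread: a wrapper that checks which API classes its wrapped module implements and forwards only those methods. Everything here is a stand-in for illustration. The classes `ValueFunctionAPI` and `TargetNetworkAPI` are toy re-definitions rather than the RLlib ones, and `ModuleWrapper`, `KNOWN_APIS`, `compute_values`, and `sync_target` are hypothetical names, not the design adopted in RLlib.

```python
import inspect


# Toy stand-ins for the API mixins named in the discussion (not the RLlib classes).
class ValueFunctionAPI:
    def compute_values(self, batch):
        raise NotImplementedError


class TargetNetworkAPI:
    def sync_target(self):  # hypothetical method name
        raise NotImplementedError


# Hypothetical registry of APIs the wrapper knows how to expose.
KNOWN_APIS = (ValueFunctionAPI, TargetNetworkAPI)


class ModuleWrapper:
    """Wraps a module and forwards the methods of every known API it implements."""

    def __init__(self, module):
        self._module = module
        for api in KNOWN_APIS:
            if not isinstance(module, api):
                continue
            # Expose each public method declared on the API class.
            for name, _ in inspect.getmembers(api, predicate=inspect.isfunction):
                if not name.startswith("_"):
                    setattr(self, name, getattr(module, name))

    def unwrapped(self):
        return self._module


# Usage: a module that implements only the value-function API.
class MyModule(ValueFunctionAPI):
    def compute_values(self, batch):
        return [0.0 for _ in batch]


module = MyModule()
wrapper = ModuleWrapper(module)
values = wrapper.compute_values([1, 2])
```

With this approach the wrapper never exposes a method from an API the wrapped module does not implement, so callers can still use `hasattr` checks. It does not cover methods added to the module after wrapping, which may be one of the non-trivial parts mentioned above.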
config = (
    marwil.MARWILConfig()
    .env_runners(num_env_runners=1)
    .api_stack(
        enable_rl_module_and_learner=True,
Nice!
config = (
    marwil.MARWILConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
Nice!!
… to 'OfflineData'. Set return to reach higher for tuned example. Signed-off-by: simonsays1980 <[email protected]>
… in linting and building. Signed-off-by: simonsays1980 <[email protected]>
Looks great! Let's get this merged :) Thanks @simonsays1980 for this great PR!
…nectors request finalized episodes. Signed-off-by: simonsays1980 <[email protected]>
…g as this was giving an error when 'MARWILOfflinePreLearner' tried to call a value function unneeded by BC. Deprecated hybrid stack. Signed-off-by: simonsays1980 <[email protected]>
…tting 'beta=0.0'. Signed-off-by: simonsays1980 <[email protected]>
…. BC depends now fully on MARWIL. Signed-off-by: simonsays1980 <[email protected]>
…epcrecated. Moved to old stack as it uses policies. Signed-off-by: simonsays1980 <[email protected]>
…he learner from MARWIL. Signed-off-by: simonsays1980 <[email protected]>
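The commit notes above mention setting 'beta=0.0' and BC depending fully on MARWIL. A rough sketch of why that works, assuming the usual exponentially weighted imitation loss: with beta set to zero every advantage weight becomes 1, so the policy loss reduces to plain behavior cloning. The function names and the fixed `adv_norm` scale are simplifications for illustration and are not taken from this PR's code.

```python
import math


def marwil_policy_loss(logps, advantages, beta, adv_norm=1.0):
    """Advantage-weighted negative log-likelihood (simplified sketch).

    `adv_norm` is a fixed scale here; it is an assumption that stands in
    for whatever advantage normalization the real learner applies.
    """
    weights = [math.exp(beta * adv / adv_norm) for adv in advantages]
    return -sum(w * lp for w, lp in zip(weights, logps)) / len(logps)


def bc_loss(logps):
    """Plain behavior cloning: mean negative log-likelihood of the data actions."""
    return -sum(logps) / len(logps)


# Toy values for illustration only.
logps = [-0.5, -1.2, -0.1]
advantages = [2.0, -1.0, 0.3]

loss_beta_zero = marwil_policy_loss(logps, advantages, beta=0.0)
loss_bc = bc_loss(logps)
loss_beta_one = marwil_policy_loss(logps, advantages, beta=1.0)
```

Here `loss_beta_zero` equals `loss_bc`, while a positive beta up-weights samples with higher advantages.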
Why are these changes needed?
This PR implements MARWIL in the new API stack using Learner API and RLModule API. It relates to the proposal for the new Offline Data API to be used in its training step.
Related issue number
#44969
Closes #37775
Checks
… git commit -s) in this PR.
… scripts/format.sh to lint the changes in this PR.
… method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.