Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] MARWIL RLModule #44970

Merged
merged 38 commits into from
Aug 1, 2024
Merged

Conversation

simonsays1980
Copy link
Collaborator

@simonsays1980 simonsays1980 commented Apr 25, 2024

Why are these changes needed?

This PR implements MARWIL in the new API stack using Learner API and RLModule API. It relates to the proposal for the new Offline Data API to be used in its training step.

Related issue number

#44969
Closes #37775

Checks

  • I've signed off every commit(by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: Simon Zehnder <[email protected]>
Signed-off-by: simonsays1980 <[email protected]>
…matted commit merely for securing the work.

Signed-off-by: simonsays1980 <[email protected]>
…' and 'MARWILTorchPolicy', fixed imports and tested MARWIL on non-recurrent policies.

Signed-off-by: simonsays1980 <[email protected]>
Signed-off-by: simonsays1980 <[email protected]>
Signed-off-by: simonsays1980 <[email protected]>
@sven1977 sven1977 changed the title [RLlib] - MARWIL RLModule [RLlib] MARWIL RLModule Jul 24, 2024
@sven1977 sven1977 marked this pull request as ready for review July 24, 2024 15:36
@@ -59,7 +63,7 @@ def test_bc_compilation_and_learning_from_offline_file(self):
results = algo.train()
print(results)

eval_results = results.get("evaluation", {})
eval_results = results.get(EVALUATION_RESULTS, {})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

@@ -49,11 +50,13 @@ def possibly_masked_mean(t):
mask = None
possibly_masked_mean = torch.mean

action_dist_class_train = self.module[module_id].get_train_action_dist_cls()
action_dist_class_train = (
self.module[module_id].unwrapped().get_train_action_dist_cls()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for this PR, but I wonder whether we should make the MARLModule return already the unwrapped RLModule automatically when we access a sub-module through __getitem__(). ...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need anywhere the wrapped module? I guess not even in the DDP case.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, hold on, good point. The most user-friendly way is probably to make sure the wrapper exposes:

  • all RLModule base methods.
  • all API methods that the wrapped module implements -> So the wrapper should check, what RLModule APIs (ValueFunctionAPI, TargetNetworkAPI, etc..) its wrapped module implements, then expose all this API's methods as well automatically. ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is how it could work. Its not a very elegant design but I tried around some weeks ago to make it elegant by just defining all methods not yet published from the wrapped module and this is non-trivial. Have to check agaon what the reasons were.


config = (
marwil.MARWILConfig()
.env_runners(num_env_runners=1)
.api_stack(
enable_rl_module_and_learner=True,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!


config = (
marwil.MARWILConfig()
.api_stack(
enable_rl_module_and_learner=True,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!!

… to 'OfflineData'. Set return to reach higher for tuned example.

Signed-off-by: simonsays1980 <[email protected]>
Signed-off-by: simonsays1980 <[email protected]>
… in linting and building.

Signed-off-by: simonsays1980 <[email protected]>
Signed-off-by: simonsays1980 <[email protected]>
@ray-project ray-project deleted a comment from simonsays1980 Jul 29, 2024
Copy link
Contributor

@sven1977 sven1977 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! Let's get this merged :) Thanks @simonsays1980 for this great PR!

@sven1977 sven1977 enabled auto-merge (squash) July 29, 2024 06:59
@github-actions github-actions bot added the go add ONLY when ready to merge, run all tests label Jul 29, 2024
@github-actions github-actions bot disabled auto-merge July 29, 2024 08:57
…nectors request finalized episodes.

Signed-off-by: simonsays1980 <[email protected]>
…g as this was giving an error when 'MARWILOfflinePreLearner' tried to call a value function unneeded by BC. Deprecated hybrid stack.

Signed-off-by: simonsays1980 <[email protected]>
…. BC depends now fully on MARWIL.

Signed-off-by: simonsays1980 <[email protected]>
Signed-off-by: simonsays1980 <[email protected]>
…epcrecated. Moved to old stack as it uses policies.

Signed-off-by: simonsays1980 <[email protected]>
Signed-off-by: simonsays1980 <[email protected]>
@sven1977 sven1977 merged commit 9df091a into ray-project:master Aug 1, 2024
4 of 5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
go add ONLY when ready to merge, run all tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[RLlib] Enable RLModule by default on MARWIL
2 participants