[RLlib] Fix `policy_map` always loading all policies from disk due to (not always needed) `global_vars` update. #22010
This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 14 days if no further activity occurs. Thank you for your contributions.
sorry to miss this PR until now.
If those policies are necessary, they are still accessible to be loaded from disk. However, they should only be loaded when they are actually needed. Without this change, having more policies than the policy_map_cache is unusable, as the disk will constantly be accessed.
ok, a couple more questions.
thanks a lot for the contribution!
- func(policy, pid, **kwargs)
- for pid, policy in self.policy_map.items()
+ func(self.policy_map[pid], pid, **kwargs)
+ for pid in self.policy_map.keys()
I can't really see the difference between current and updated version here.
can you say a few words about the intention here?
same below in get_weights() function.
Yes, from the view of the `rollout_worker` the behavior shouldn't change.
The difference is that iterating over `items()` will access both the key and the policy, i.e. internally `PolicyMap` will call `__getitem__`. `__getitem__` will access the disk if the selected policy is not in the cache. So every policy is accessed, and only afterwards does the if statement in the next line check whether we are actually interested in this policy. With my changes we only call `__getitem__` for those policies that pass the if statement. This way we reduce the number of unnecessary policy accesses and in turn the number of unnecessary `_read_from_disk()` calls (compare `rllib/policy/policy_map.py`).
The same goes for `get_weights()`.
In short: don't do `policy_map[pid]` if we skip this `pid` anyway due to the if statement. This is important because `policy_map[pid]` might trigger disk reads.
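To make the difference concrete, here is a minimal sketch of the effect described above. `ToyPolicyMap`, `foreach_eager`, and `foreach_lazy` are hypothetical names for illustration only; this is not RLlib's `PolicyMap`, just a small capacity-limited mapping whose `__getitem__` counts simulated disk reads.

```python
from collections import OrderedDict
from collections.abc import Mapping


class ToyPolicyMap(Mapping):
    """Hypothetical stand-in for a capacity-limited policy map (not RLlib's class)."""

    def __init__(self, policies, capacity):
        self._disk = dict(policies)  # pretend this dict lives on disk
        self._cache = OrderedDict()  # in-memory LRU cache
        self._capacity = capacity
        self.disk_reads = 0          # number of simulated disk reads

    def __getitem__(self, pid):
        if pid in self._cache:
            self._cache.move_to_end(pid)  # mark as most recently used
            return self._cache[pid]
        if pid not in self._disk:
            raise KeyError(pid)
        self.disk_reads += 1              # cache miss: simulated disk read
        policy = self._disk[pid]
        self._cache[pid] = policy
        if len(self._cache) > self._capacity:
            self._cache.popitem(last=False)  # evict least recently used
        return policy

    def __iter__(self):
        return iter(self._disk)

    def __len__(self):
        return len(self._disk)


def foreach_eager(policy_map, wanted, func):
    # items() fetches every policy via __getitem__ before the filter runs.
    return [func(policy, pid) for pid, policy in policy_map.items() if pid in wanted]


def foreach_lazy(policy_map, wanted, func):
    # keys() only yields ids; __getitem__ runs only for ids that pass the filter.
    return [func(policy_map[pid], pid) for pid in policy_map.keys() if pid in wanted]
```

With five policies, a capacity of two, and a single wanted id, the eager variant touches every policy while the lazy variant only touches the one that passes the filter. Both return the same result, which matches the claim that the behavior seen by the caller does not change.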
🤯 this is some highly optimized stuff.
can you please help add a big comment above these 2 places, about what you just said, so we don't come in and accidentally "clean up" these loops some day.
appreciate the thoughtful changes.
understand! this is great.
I think global_vars are mostly for training purposes, so it should be fine to skip the update if a policy is not being trained.
@@ -1561,7 +1561,7 @@ def set_global_vars(self, global_vars: dict) -> None:
 Examples:
 >>> global_vars = worker.set_global_vars({"timestep": 4242})
 """
- self.foreach_policy(lambda p, _: p.on_global_var_update(global_vars))
+ self.foreach_policy_to_train(lambda p, _: p.on_global_var_update(global_vars))
I see. so this PR doesn't really change what policies get loaded to the map, it only changes what policies get updated. is my understanding correct?
thanks.
Yes, this change only affects which policies are updated with global vars.
The problem:
- In order to update a policy that isn't in the in-memory cache, it needs to be loaded from disk.
- Looking through the project, it looks like every time `set_global_vars()` is called, `_get_global_vars` is passed as argument.
- I'm not sure what kind of global vars there usually are, but looking into `rllib/execution/common.py` it looks like `_get_global_vars()` basically always returns an updated timestep.
- Consequently, if any policies are written to disk, they are read from disk all the time to update their timesteps and we have constant disk access again, slowing everything down and defeating the reason why we put them on disk in the first place.
The solution:
- What at least helps with this problem is only updating global vars for policies that are being trained.
- If all policies that are being trained fit into the in-memory cache, the disk won't constantly be accessed.
- If they don't fit, then I see no way to stop constant disk access anyway.
- I'm not 100% sure this doesn't have other consequences; someone with better knowledge of the code base would have to judge that. However, I see no reason why policies that are not being trained would need their vars updated (are there even other vars besides timesteps?).
- At any rate, without this change we are again at a point where any policy that is offloaded to disk is constantly being accessed, making it basically unusable to have more policies than the `policy_map_capacity`.
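The argument above can be illustrated with a small simulation. This is only a sketch under assumptions: `ToyPolicy` and `ToyWorker` are hypothetical classes, the cache uses simple oldest-first eviction, and nothing here reflects how RLlib's rollout worker is actually implemented. It just counts how often a policy has to be fetched from the simulated disk stash when global vars are pushed repeatedly.

```python
class ToyPolicy:
    """Hypothetical policy that only tracks a global timestep."""

    def __init__(self):
        self.global_timestep = 0

    def on_global_var_update(self, global_vars):
        self.global_timestep = global_vars["timestep"]


class ToyWorker:
    """Hypothetical worker with a small in-memory cache in front of a disk stash."""

    def __init__(self, policy_ids, policies_to_train, capacity):
        self.stash = {pid: ToyPolicy() for pid in policy_ids}  # simulated disk
        self.cache = {}  # insertion-ordered; the oldest entry is evicted first
        self.capacity = capacity
        self.policies_to_train = set(policies_to_train)
        self.disk_reads = 0

    def get_policy(self, pid):
        if pid not in self.cache:
            self.disk_reads += 1  # cache miss: simulated disk read
            self.cache[pid] = self.stash[pid]
            if len(self.cache) > self.capacity:
                oldest = next(iter(self.cache))
                del self.cache[oldest]
        return self.stash[pid]

    def set_global_vars_all(self, global_vars):
        # Old behavior in this sketch: every policy is fetched and updated.
        for pid in self.stash:
            self.get_policy(pid).on_global_var_update(global_vars)

    def set_global_vars_trained_only(self, global_vars):
        # New behavior in this sketch: only policies being trained are fetched.
        for pid in self.stash:
            if pid in self.policies_to_train:
                self.get_policy(pid).on_global_var_update(global_vars)
```

With four policies, two of them trained, and a cache capacity of two, updating all policies misses the cache on every access, while updating only the trained policies reads each of them from disk once and then stays in memory. The untrained policies keep their old timestep, which is the trade-off discussed in this thread.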
hey @PavelCz, one thing though, can you rebase to the latest master, so we can make sure all the tests pass?
I added comments to the two changes and also added a short comment to the third change. Let me know if those are ok.
`policy_map` always loading all policies from disk due to (not always needed) `global_vars` update.
… (not always needed) `global_vars` update. (ray-project#22010) (cherry picked from commit de0c6f6)
Why are these changes needed?
- When there are more policies than `policy_map_capacity`, the in-memory cache should contain the frequently used policies, and the less frequently used ones are stashed to disk.
- However, `timesteps` is updated for all policies, causing all policies to be loaded from disk and making the cache useless. My change makes it so that only policies that are being trained need to be loaded from disk.
Related issue number
n/a
Checks
`scripts/format.sh` to lint the changes in this PR.