[RLlib] Algorithm Level Checkpointing with Learner and RL Modules #34717
Conversation
rllib/algorithms/algorithm.py
Outdated
@@ -2131,6 +2148,17 @@ def load_checkpoint(self, checkpoint: Union[Dict, str]) -> None:
        else:
            checkpoint_data = checkpoint
        self.__setstate__(checkpoint_data)
        if isinstance(checkpoint, str) and self.config._enable_learner_api:
I don't think the location of this logic fits well with the existing code, where checkpoint can take either a dict or a str value. You need to map the checkpoint input (str or dict) to checkpoint data first, and then use that checkpoint data inside the __setstate__() API to set the state of the learner group.
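The suggested pattern can be sketched roughly as follows. This is an illustrative sketch, not the actual RLlib code; the file name `algorithm_state.pkl` and the helper `checkpoint_to_dict` are assumptions for illustration.

```python
import os
import pickle


def checkpoint_to_dict(checkpoint):
    """Normalize a checkpoint input (str path or dict) into checkpoint data."""
    if isinstance(checkpoint, str):
        # Assumed file layout: the state dict lives in the checkpoint dir.
        with open(os.path.join(checkpoint, "algorithm_state.pkl"), "rb") as f:
            return pickle.load(f)
    return checkpoint  # already a dict


def load_checkpoint(self, checkpoint):
    # Map str-or-dict to checkpoint data first; __setstate__() then owns
    # *all* restoration, including the learner group's state.
    checkpoint_data = checkpoint_to_dict(checkpoint)
    self.__setstate__(checkpoint_data)
```

This keeps the str/dict distinction in one place instead of scattering `isinstance` checks after the state has already been applied.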
My question is more: what do you need to do if checkpoint is a dict? When would that happen, and what would that mean for the learner group?
OK, so this is interesting. Upon further inspection, the reason that this is supposed to accept a dict is in case trainable.save_checkpoint
ever returns a dictionary. However, we never do this, which means we don't need to support dicts inside load_checkpoint
at all. I just ended up removing all the logic related to handling dicts.
rllib/algorithms/algorithm.py
Outdated
if self.config._enable_learner_api:
    learner_state_dir = os.path.join(checkpoint_dir, "learner")
    self.learner_group.save_state(learner_state_dir)
    state["learner_state_dir"] = "learner/"
The state dict has already been dumped to a file by the time we reach this line, so what's the point of writing new key-value pairs into it?
Leftover from experimenting; you're right :)
Approved, contingent on tests passing. Thanks @avnishn
Signed-off-by: Avnish [email protected]
This PR introduces algorithm-level checkpointing with the RL Modules stack. It also introduces a test to make sure that checkpointing runs. Checkpointing, however, isn't seed-reproducible: upon some inspection by me and @kouroshHakha, some portion of the sampler is not seed-reproducible.
That being said, if I take an algorithm, checkpoint it, and then restore and train it multiple times, the restored versions are seed-reproducible with respect to each other. I've added a test that reflects this.
The more I think about it, the more I realize that the algorithm won't be seed-reproducible across interrupts. This is because when loading from a checkpoint, we first construct an algorithm instance, then seed it, then load the training state in. We aren't restoring the seeded state from the time the algorithm was checkpointed, so the random state won't carry across checkpoints.
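The reproducibility property described above can be illustrated with a toy sketch: two runs restored from the same checkpoint and trained with the same seed match each other, even though neither necessarily matches the original pre-checkpoint run. The RNG stands in for an algorithm's training dynamics; all names here are illustrative assumptions, not RLlib APIs.

```python
import random


def train_from_checkpoint(checkpoint_state, seed, steps=5):
    # Mirrors the restore order discussed above:
    # construct, then seed, then load training state in.
    rng = random.Random(seed)
    total = checkpoint_state["total"]
    for _ in range(steps):
        total += rng.random()
    return total


ckpt = {"total": 1.0}
run_a = train_from_checkpoint(ckpt, seed=42)
run_b = train_from_checkpoint(ckpt, seed=42)
assert run_a == run_b  # restored runs are reproducible w.r.t. each other
```

The seeded state at checkpoint time is never restored, which is exactly why reproducibility holds between restored runs but not across an interrupt of the original run.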
Why are these changes needed?
Related issue number
Checks
- I've signed off every commit (git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I've added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.