Add initial support for torch.utils.checkpoint #1127
Conversation
Cool! @t-vi, do you want to take a look?
I'm not convinced of the design.
Why would we not just let the function be any function and have a state "currently checkpointing" that informs Thunder to add a tag to the proxies that are generated during the checkpointing instead?
We would need to clear that tag on the outputs, but that would be easier than reentrant jit and higher order functions.
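A rough sketch of how such a "currently checkpointing" state could look (purely illustrative; none of these names are existing Thunder APIs):

```python
from contextlib import contextmanager

# Hypothetical flag set while tracing a checkpointed region.
_currently_checkpointing = False

@contextmanager
def checkpointing():
    # While active, newly generated proxies get tagged for recomputation.
    global _currently_checkpointing
    prev, _currently_checkpointing = _currently_checkpointing, True
    try:
        yield
    finally:
        _currently_checkpointing = prev

def on_proxy_created(proxy):
    # Hypothetical hook invoked wherever proxies are produced during tracing.
    if _currently_checkpointing:
        proxy.tags.add("RECOMPUTE_IN_BACKWARD")

def end_checkpoint(outputs):
    # The region's outputs are real results, so the tag is cleared on them.
    for out in outputs:
        out.tags.discard("RECOMPUTE_IN_BACKWARD")
    return outputs
```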
Do you have ideas about how the "currently checkpointing" approach would generalize to supporting, for example, …
I don't have immediate ideas, but I don't see why we should have higher order functions right now.
@IvanYashchuk You might want to check out selective activation checkpointing, available in PyTorch nightlies, to specify which activations to save for backward: https://pytorch.org/docs/main/checkpoint.html#torch.utils.checkpoint.create_selective_checkpoint_contexts
Awesome, thanks for the link, Syed! Not a fan of ATen ops leaking into the PyTorch Python interface with …
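For context, the selective checkpointing API from the linked page looks roughly like the sketch below (adapted from the nightly docs, so details may still change); note how ATen op names appear directly in the Python-level policy:

```python
import functools
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Ops whose outputs should be saved for backward instead of recomputed.
ops_to_save = [torch.ops.aten.mm.default]

def policy_fn(ctx, op, *args, **kwargs):
    # ATen ops are matched here, in the user-facing Python policy.
    if op in ops_to_save:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)

def fn(x):
    return torch.sigmoid(x @ x @ x)

x = torch.randn(4, 4, requires_grad=True)
out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
out.sum().backward()
```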
A checkpointed function doesn't save any intermediates from forward to backward. Instead, all required values are recomputed during the backward pass. Because fewer intermediates are saved, peak memory usage is usually lower.
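For illustration, in plain PyTorch a block can be checkpointed roughly like this (a minimal sketch using the standard `torch.utils.checkpoint.checkpoint` API):

```python
import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    # The intermediates produced here are not saved for backward;
    # `block` is re-run during the backward pass to recompute them.
    return torch.relu(x @ x).sum()

x = torch.randn(128, 128, requires_grad=True)
# Only the inputs (and, by default, the RNG state) are stashed.
out = checkpoint(block, x, use_reentrant=False)
out.backward()
```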
This PR introduces support for recognizing `torch.utils.checkpoint.checkpoint` calls and inserting a new bound symbol into the initial trace. In the forward-backward generation pass, this bound symbol is then converted into the augmented forward and backward parts of the computation. This step requires the function argument to `thunder.torch.checkpoint` to be a Thunder function. Currently, there's no PyTorch->Thunder conversion implemented, so this works only for simple functions that are recognized by both Thunder and PyTorch, for example when only methods are used. The PyTorch function needs to be converted to a Thunder function in Thunder's JIT. Previously, we could simply use `thunder.preprocess`, which is not available today. When I attempted to implement a redispatching/reinterpretation of PyTorch functions using `general_thunder_jit`, I hit the following bug: #1126.

Example:
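A minimal sketch of the intended usage under these constraints, assuming `thunder.jit` as the entry point and a checkpointed function that uses only tensor methods (so it is recognized by both PyTorch and Thunder):

```python
import torch
import torch.utils.checkpoint
import thunder

def inner(t):
    # Only tensor methods are used, no other PyTorch operations.
    return t.sin().cos()

def fn(x):
    return torch.utils.checkpoint.checkpoint(inner, x)

jfn = thunder.jit(fn)
x = torch.randn(4, 4, requires_grad=True)
out = jfn(x)
out.sum().backward()
```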
Forward execution trace:
Backward execution trace: