
[fix] Fix activation checkpointing of SwiGLU when AMP is enabled. #1152

Merged
1 commit merged into facebookresearch:main on Nov 14, 2024

Conversation

@warpuv (Contributor) commented Nov 14, 2024

Without this fix, the number of tensors saved during recomputation is 0.

  • Moved the at::AutoDispatchBelowADInplaceOrView guard so it is constructed after ctx->get_saved_variables() (see the sketch below).

What does this PR do?

Fixes #1151.
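For context, here is a minimal sketch of the reordering, assuming a generic torch::autograd::Function; the class name, include, and gradient math are illustrative and are not the actual xformers SwiGLU implementation:

```cpp
// Illustrative sketch only -- a simplified custom autograd function, not the
// real xformers SwiGLU code.
#include <torch/extension.h>

struct SwiGLULikeFunction
    : public torch::autograd::Function<SwiGLULikeFunction> {
  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_outputs) {
    // Old placement (before this PR): the guard was constructed first, so the
    // recomputation of the forward pass triggered inside
    // ctx->get_saved_variables() under activation checkpointing ran with the
    // guard active and ended up saving zero tensors.
    //
    //   at::AutoDispatchBelowADInplaceOrView guard;
    //   auto saved = ctx->get_saved_variables();

    // New placement: fetch (and, when checkpointing, recompute) the saved
    // variables first, then drop below ADInplaceOrView for the backward math.
    auto saved = ctx->get_saved_variables();
    at::AutoDispatchBelowADInplaceOrView guard;

    // ... compute the actual gradients from `saved` and `grad_outputs` ...
    return {};
  }
};
```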

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

Without this fix, the number of tensors saved during recomputation is 0.
Moved the at::AutoDispatchBelowADInplaceOrView guard after ctx->get_saved_variables().
ctx->get_saved_variables() is the call where the recomputation of the forward pass occurs.
@facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Nov 14, 2024
@lw (Contributor) left a comment


I'm not sure I really understand what is going on here but the change looks harmless and if you claim it fixes your issue it's good with me.

@lw merged commit a561291 into facebookresearch:main on Nov 14, 2024
8 of 9 checks passed
@warpuv (Contributor, Author) commented Nov 14, 2024

> I'm not sure I really understand what is going on here but the change looks harmless and if you claim it fixes your issue it's good with me.

I don't understand it either.

I also tried removing the guard from the backward; that fixes the issue as well, but it adds some extra redispatch operations in the backward according to the LoggingTensorMode context manager.
In the docs (https://pytorch.org/cppdocs/notes/inference_mode.html) and in the examples I've found, that guard is placed only in the forward of an autograd Function; I don't know why. A sketch of that pattern is below.
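For reference, this is roughly the pattern those docs and examples show, with the guard only in forward(); the class name and the toy op are illustrative, not the xformers code:

```cpp
// Sketch of the documented pattern: the dispatch guard wraps the tensor ops
// inside forward() of a custom autograd function; backward() fetches the
// saved variables without any guard. The op itself is a stand-in.
#include <torch/extension.h>

struct MyFusedOp : public torch::autograd::Function<MyFusedOp> {
  static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
                               const torch::Tensor& input) {
    at::AutoDispatchBelowADInplaceOrView guard;  // guard lives only in forward
    auto out = input.relu();                     // stand-in for a fused kernel
    ctx->save_for_backward({input});
    return out;
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_outputs) {
    auto saved = ctx->get_saved_variables();     // no guard around this call
    auto grad_input = grad_outputs[0] * (saved[0] > 0);
    return {grad_input};
  }
};
```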

P.S. Thanks for the review and merge!

Labels
CLA Signed

Projects
None yet

Development

Successfully merging this pull request may close these issues: Activation checkpointing on fused SwiGLU is not working when AMP is enabled.

3 participants