
Refactor freqs_cis slice to be safer for PP #321

Merged 1 commit into gh/wconstab/13/base
May 13, 2024

Conversation

@wconstab (Contributor) commented May 10, 2024


Unchanged: we precompute freqs_cis for max_seqlen, which is >> seqlen for a given batch.

Changed: instead of slicing self.freqs_cis down to seqlen in the top-level transformer based on the input token shape, we slice it down to seqlen inside each transformer layer, after the activation has been re-expanded to the full seqlen in cases where TP has sharded across the sequence dimension. A sketch of the two approaches follows below.

In the PP case, stage 1's input may be seqlen/TP rather than seqlen, but we do not generally know this, which makes it hard for stage 1 to slice freqs_cis correctly. It is easy to do the slicing deeper inside, since at that point the full seqlen is known unambiguously.

Note: the full self.freqs_cis is stored in memory either way, and what is passed into each layer is just a view, so this change should not be material for memory usage or otherwise.
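To make the change concrete, here is a minimal sketch of the before/after slicing. This is simplified stand-in code, not the actual torchtitan model; the class and shapes are illustrative assumptions.

```python
import torch

def precompute_freqs_cis(dim: int, max_seqlen: int, theta: float = 10000.0) -> torch.Tensor:
    # Precompute rotary-embedding frequencies once, for the maximum sequence length.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    t = torch.arange(max_seqlen, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64, [max_seqlen, dim // 2]

# Before: the top-level transformer sliced based on the input token shape:
#
#     freqs_cis = self.freqs_cis[: tokens.shape[1]]
#
# Under PP, stage 1's input may already be sharded to [bs, seqlen / TP], so
# tokens.shape[1] is not a reliable way to recover the true seqlen there.

class TransformerBlock(torch.nn.Module):
    # After: every layer receives the full precomputed table (a view, so no
    # extra memory is materialized) and slices it itself, once the activation
    # has been re-expanded to the full seqlen.
    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        seqlen = x.shape[1]  # unambiguous here: x is [bs, seqlen, dim]
        freqs_cis = freqs_cis[:seqlen]
        # ... apply rotary embeddings, attention, MLP, etc.
        return x
```

The key point is that the slice is a view into the same precomputed buffer either way; the change only moves where seqlen is read, from a place where it is ambiguous under PP to a place where it is not.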

wconstab added a commit that referenced this pull request May 10, 2024
ghstack-source-id: 20ef05e0734e53260366878dfe0fac5e1ab48f1d
Pull Request resolved: #321
@facebook-github-bot added the CLA Signed label May 10, 2024
@lessw2020 (Contributor) left a comment:

makes sense - lgtm!

@wanchaol (Contributor) left a comment:

lgtm!

@wconstab merged commit 231ebc1 into gh/wconstab/13/base May 13, 2024
4 checks passed
wconstab added a commit that referenced this pull request May 13, 2024
ghstack-source-id: 20ef05e0734e53260366878dfe0fac5e1ab48f1d
Pull Request resolved: #321
@wconstab deleted the gh/wconstab/13/head branch May 13, 2024 21:46
@@ -76,7 +79,9 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""
ndim = x.ndim
assert 0 <= 1 < ndim
A reviewer (Contributor) commented:

not from this PR: I wonder what the point of the 0 <= 1 part is 😃 .

@wconstab (Contributor, Author) replied:

lol. it's always good to check your assumptions
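For anyone puzzled by the exchange above: this is Python's chained-comparison behavior, sketched below with a hypothetical input tensor.

```python
import torch

x = torch.randn(2, 16, 64)  # hypothetical input: [bs, seqlen, dim]
ndim = x.ndim

# Python parses `0 <= 1 < ndim` as `(0 <= 1) and (1 < ndim)`.
# The `0 <= 1` half is trivially true, so the assert is equivalent to
# requiring ndim > 1, i.e. that x has a dimension at index 1 (the seqlen
# dimension) for freqs_cis to broadcast over.
assert 0 <= 1 < ndim
assert ndim > 1  # the same check, spelled directly
```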

tianyu-l pushed a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 16, 2024
ghstack-source-id: 20ef05e0734e53260366878dfe0fac5e1ab48f1d
Pull Request resolved: pytorch#321
philippguevorguian pushed a commit to YerevaNN/YNNtitan that referenced this pull request Aug 17, 2024
ghstack-source-id: 20ef05e0734e53260366878dfe0fac5e1ab48f1d
Pull Request resolved: pytorch#321
Labels: CLA Signed
5 participants