Add masking of different samples in a long sequence for flash-attention mechanism #32875
Conversation
Isn't this already supported with #31629?
Hey! Don't you think that ragging the tensor would be more efficient?
It seems that both implementations are similar. But we need to consider the situation where position IDs are not reset between different samples (especially for LLM pre-training).
Yes, I didn't describe it well; I updated the description of this PR. The implementations of this PR and #31629 are very similar, but we need to consider the scenario where position IDs are not reset between different short samples, especially for LLM pre-training. Therefore, this implementation opts to store the short-sequence lengths in the attention mask matrix. While borrowing the attention_mask might not be elegant, introducing a new variable could entail a substantial amount of effort.
Hello, expert! What does "position IDs are not reset between different short samples" mean specifically, and why does this come up especially in pre-training? Very helpful work, thanks for the answers.
Ohh, I mean it is useful in long-context pre-training. For long-context training (for example, a 128k long sequence), we may utilize synthesized samples which are usually concatenations of short samples. In this case, position IDs should not be reset.
Does this imply that we have to properly compute the position IDs? That's something we'd rather avoid in general, as it forces the user to pass the position IDs.
No. There is a scenario where position IDs cannot be reset between different short samples. For long-context training (e.g., a 128k long sequence), we may utilize synthesized samples which are usually concatenations of short samples. If position IDs are reset, it is no longer long-context training.
A less native approach might be to tokenize the data and then splice it together piecewise via the map method, then use this dataset for training. Of course, the default batch_size for map is 1000, so we may need to set this value according to the data, as the total length of these 1000 samples may be less than 128k. It would be nice to have native support for this kind of processing. 🏃♂️
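A hedged sketch of that map-based packing (the model and dataset names below are illustrative placeholders, and the 128k chunk length comes from the discussion above):

```python
from itertools import chain

from datasets import load_dataset
from transformers import AutoTokenizer

CHUNK_LEN = 128 * 1024  # target packed sequence length from the discussion

# Illustrative choices; any tokenizer/dataset pair works the same way.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
raw = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")

def tokenize(batch):
    return tokenizer(batch["text"], add_special_tokens=False)

def pack(batch):
    # Concatenate all tokenized samples in this map batch, then cut the
    # stream into fixed-length chunks; a leftover shorter than CHUNK_LEN
    # is simply dropped here.
    ids = list(chain.from_iterable(batch["input_ids"]))
    total = (len(ids) // CHUNK_LEN) * CHUNK_LEN
    chunks = [ids[i : i + CHUNK_LEN] for i in range(0, total, CHUNK_LEN)]
    return {"input_ids": chunks, "labels": [c[:] for c in chunks]}

tokenized = raw.map(tokenize, batched=True, remove_columns=raw.column_names)
# batch_size controls how many samples are concatenated per map call; the
# default of 1000 may not add up to 128k tokens, so raise it if needed.
packed = tokenized.map(
    pack, batched=True, batch_size=10_000, remove_columns=tokenized.column_names
)
```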
You can take a look at #14767.
Why wouldn't we use position_ids? For example, let's say that the sequence length is 8:
When the information about sequences is stored in position_ids, …
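An illustrative reconstruction of the kind of length-8 example being referred to (the sample lengths 3, 3, and 2 are assumptions):

```python
# A packed sequence of length 8 built from three samples of lengths 3, 3 and 2.
# Resetting position IDs at each boundary is what encodes where samples start:
position_ids   = [0, 1, 2, 0, 1, 2, 0, 1]
# whereas a flat attention mask carries no boundary information at all:
attention_mask = [1, 1, 1, 1, 1, 1, 1, 1]
```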
For long-context training (e.g., 128k long sequence), we may utilize synthesized samples, which are usually concatenations of short samples. Thus, we don't want to reset position_ids in this case. For instance, consider four sequences of length 32k. The attention_mask would be [32k, 32k, 32k, 32k, 0, ..., 0, 0, 0], and the position_ids would be [0, 1, 2, 3, 4, 5, 6, 7, ..., 128k-1]. This allows the model to learn position embeddings for longer sequences. In contrast, if we use attention_mask = [1, 1, 1, ..., 1, 1] and position_ids = [0, 1, 2, ..., 32k-1, 0, 1, 2, ..., 32k-1, 0, 1, 2, ..., 32k-1, 0, 1, 2, ..., 32k-1], the model can only learn position embeddings in the range of [0, 32k-1].
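A small sketch of the two layouts described above (32k written out as 32768; purely illustrative):

```python
import torch

SEQ_LEN = 4 * 32768  # a 128k packed sequence built from four 32k samples

# Layout used by this PR: the lengths of the packed samples are stored in the
# attention mask, while position IDs keep running across sample boundaries.
attention_mask = torch.zeros(SEQ_LEN, dtype=torch.int32)
attention_mask[:4] = 32768                      # [32k, 32k, 32k, 32k, 0, ..., 0]
position_ids = torch.arange(SEQ_LEN)            # [0, 1, 2, ..., 128k-1]

# Alternative layout: an all-ones mask with position IDs reset per sample,
# which caps the positions the model ever sees at 32k-1.
flat_attention_mask = torch.ones(SEQ_LEN, dtype=torch.int32)
reset_position_ids = torch.arange(SEQ_LEN) % 32768
```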
Yep, agree with that, definitely! My proposal was to leave this choice to users to set in the data collator.
Well, the point is that when we treat concatenated sequences as a single sequence (i.e., without resetting position IDs), the model can learn position embeddings beyond the length of any single short sample.
Are you suggesting that they are using …?
^ I have interpreted this sentence in a different way. They concatenate sequences together and make sure that there is no cross-document attention, which would translate into `src/transformers/modeling_flash_attention_utils.py`, lines 270 to 293 (at ecd61c6).
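For readers without the permalink handy, a rough illustration of how "no cross-document attention" can be derived from reset position IDs for flash-attention's varlen kernels (this sketches the idea only; it is not a copy of the referenced transformers code):

```python
import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    """Derive cumulative sequence lengths from position IDs that restart at 0
    at every document boundary inside a flattened (packed) batch."""
    position_ids = position_ids.flatten()
    # A new document starts wherever the position ID drops back to 0.
    starts = torch.nonzero(position_ids == 0, as_tuple=True)[0].to(torch.int32)
    end = torch.tensor([position_ids.numel()], dtype=torch.int32)
    return torch.cat([starts, end])

# e.g. position_ids = [0,1,2, 0,1,2, 0,1]  ->  cu_seqlens = [0, 3, 6, 8]
# cu_seqlens (plus the max segment length) is what flash-attention's varlen
# entry points consume so attention never crosses document boundaries.
```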
Yes. We want to learn position embeddings for larger position IDs (32k to 128k-1).
Are there any ablation studies which show that this mix-and-match approach helps?
Since most LLMs use RoPE positions, which are relative, there is no difference between resetting position IDs or not when cross-document attention masking is applied.
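As a reminder of why the relative-position property gives this (standard RoPE algebra, nothing introduced by this PR): with rotation matrices $R_m$ applied to queries and keys, the pre-softmax score satisfies

$$
\langle R_m q,\; R_n k \rangle = q^{\top} R_m^{\top} R_n\, k = q^{\top} R_{n-m}\, k,
$$

so it depends only on the offset $n-m$. As long as cross-document attention is masked out, shifting all positions inside a document by a constant leaves its attention scores unchanged.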
What does this PR do?
Fixes # (issue)
In LLM training, we often choose to pack short samples into one long sequence for training efficiency. In this situation, it is ideal to mask attention between the different samples.
For a standard causal self-attention implementation, we can use a 3-D mask matrix to separate the different samples. But flash-attention does not support a causal 3-D mask matrix, so we need a shortcut.
The `attention_mask_in_length` is utilized to mask the other short samples packed into the same sequence; the motivation for this function is explained here. We parse `attention_mask_in_length` to get the indices and lengths of all short samples.

An example of `attention_mask_in_length`:
For example, if batch = 3 and seqlen = 6, the `attention_mask_in_length` is a (3, 6) matrix whose rows list the lengths of the samples packed into each sequence (right-padded with zeros), and it corresponds to a block-diagonal causal 3-D attention mask.
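The concrete matrices are easiest to see in code; the sketch below shows one plausible instance (the packed lengths 2/4, 3/3, and 6 are illustrative) and how it expands into the block-diagonal causal mask:

```python
import torch

# batch = 3, seqlen = 6: each row lists the lengths of the samples packed
# into that sequence, right-padded with zeros (illustrative values).
attention_mask_in_length = torch.tensor([
    [2, 4, 0, 0, 0, 0],   # two samples of lengths 2 and 4
    [3, 3, 0, 0, 0, 0],   # two samples of lengths 3 and 3
    [6, 0, 0, 0, 0, 0],   # one sample of length 6
])

def to_block_causal_mask(lengths_row: torch.Tensor, seqlen: int) -> torch.Tensor:
    """Expand one row of attention_mask_in_length into a (seqlen, seqlen)
    block-diagonal causal mask (1 = may attend, 0 = masked)."""
    mask = torch.zeros(seqlen, seqlen, dtype=torch.int64)
    start = 0
    for length in lengths_row[lengths_row > 0].tolist():
        mask[start:start + length, start:start + length] = torch.tril(
            torch.ones(length, length, dtype=torch.int64)
        )
        start += length
    return mask

# Shape (3, 6, 6): the "3-D attention mask" referred to above.
masks_3d = torch.stack(
    [to_block_causal_mask(row, seqlen=6) for row in attention_mask_in_length]
)
```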
Use sample masking of flash-attention for Llama3
https://github.com/sz128/LLMs_implementations/blob/main/sample_mask_with_flash-attn-2.ipynb
Here is the core code to get `attention_mask_in_length` for the Megatron-LM data collator (https://github.com/microsoft/Megatron-DeepSpeed/blob/b7b2d5ef330f43729b406630e6c5d38e873d7398/megatron/utils.py#L162).

We can also change the attention_mask produced by `DataCollatorWithFlattening` (`src/transformers/data/data_collator.py`, lines 1617 to 1663 at 23d2c69).
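A hedged sketch of what that collator-side construction can look like (this is not the Megatron-DeepSpeed code; it only illustrates turning per-sample lengths into an attention_mask_in_length row, and treating trailing padding as a final segment is just one possible convention):

```python
import torch

def build_attention_mask_in_length(sample_lengths, seqlen):
    """Pack the lengths of the samples concatenated into one `seqlen`-token
    sequence into a single row, right-padded with zeros."""
    lengths = list(sample_lengths)
    remainder = seqlen - sum(lengths)
    if remainder > 0:          # count trailing padding as its own segment
        lengths.append(remainder)
    row = torch.zeros(seqlen, dtype=torch.int32)
    row[: len(lengths)] = torch.tensor(lengths, dtype=torch.int32)
    return row

# Example: a 16-token sequence packed from samples of lengths 5, 7 and 3
# (plus 1 padding token) -> [5, 7, 3, 1, 0, 0, ...]
print(build_attention_mask_in_length([5, 7, 3], seqlen=16))
```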
Before submitting
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.