
Add masking of different samples in a long sequence for flash-attention mechanism #32875

Closed
wants to merge 5 commits

Conversation


@sz128 sz128 commented Aug 19, 2024

What does this PR do?

Fixes # (issue)

In LLM training, we often choose to pack short samples into one long sequence for efficient training. In this situation, it is desirable to mask attention between the different samples.

For the eager causal self-attention implementation, we can use a 3-D mask matrix to separate the samples. But flash-attention does not support a causal 3-D mask matrix, so we need a shortcut.

The attention_mask_in_length tensor is utilized to mask out the other short samples; the motivation for this format is explained here. The procedure, sketched in code after the list, is:

  1. We can utilize attention_mask_in_length to get the indices and lengths of all short samples.
  2. Next, the long-sequence embeddings and the length indicators are fed into the flash-attention kernel to obtain its outputs.
  3. Finally, through an inverse operation, we rearrange the outputs to match the shape of the original batch.
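A minimal sketch of these three steps (the helper names are illustrative, flash-attn is assumed to be installed, and trailing tokens after the listed lengths are treated as one extra segment, matching the mask example below):

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func  # assumes flash-attn is installed


def cu_seqlens_from_mask_in_length(attention_mask_in_length):
    """attention_mask_in_length: (batch, seqlen) int tensor whose non-zero entries
    are the lengths of the samples packed into each row."""
    batch, seqlen = attention_mask_in_length.shape
    all_lengths = []
    for row in attention_mask_in_length.tolist():
        lengths = [length for length in row if length > 0]
        if sum(lengths) < seqlen:
            # Any trailing tokens form one extra segment of their own.
            lengths.append(seqlen - sum(lengths))
        all_lengths.extend(lengths)
    lengths = torch.tensor(all_lengths, dtype=torch.int32)
    # [0, l1, l1+l2, ...] as expected by flash_attn_varlen_func.
    cu_seqlens = F.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0))
    return cu_seqlens, int(lengths.max())


def flash_attention_with_sample_mask(q, k, v, attention_mask_in_length, causal=True):
    """q, k, v: (batch, seqlen, num_heads, head_dim) fp16/bf16 CUDA tensors."""
    batch, seqlen = q.shape[:2]
    cu_seqlens, max_seqlen = cu_seqlens_from_mask_in_length(attention_mask_in_length)
    cu_seqlens = cu_seqlens.to(q.device)

    # Step 2: flatten the batch into one long token stream and run the varlen kernel,
    # which attends only within each [cu_seqlens[i], cu_seqlens[i+1]) segment.
    def flatten(x):
        return x.reshape(batch * seqlen, *x.shape[2:])

    out = flash_attn_varlen_func(
        flatten(q), flatten(k), flatten(v),
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
        causal=causal,
    )
    # Step 3: the inverse operation -- restore the original (batch, seqlen, ...) shape.
    return out.reshape(batch, seqlen, *out.shape[1:])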
An example of attention_mask_in_length

For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:

        [
          [2, 3],
          [3, 2],
          [6, 0]
        ]

which corresponds to the following 3-D attention mask:

        [
          [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0],
            [0, 0, 1, 1, 0, 0],
            [0, 0, 1, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
          ],
          [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [0, 0, 0, 1, 0, 0],
            [0, 0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
          ],
          [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [1, 1, 1, 1, 0, 0],
            [1, 1, 1, 1, 1, 0],
            [1, 1, 1, 1, 1, 1]
          ]
        ]
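A small helper (illustrative, not part of the PR) that expands attention_mask_in_length into the 3-D mask above; it is handy for checking the flash-attention path against an eager masked implementation:

import torch


def expand_to_3d_mask(attention_mask_in_length):
    """attention_mask_in_length: (batch, seqlen) int tensor of packed-sample lengths."""
    batch, seqlen = attention_mask_in_length.shape
    mask = torch.zeros(batch, seqlen, seqlen, dtype=torch.long)
    causal = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.long))
    for b, row in enumerate(attention_mask_in_length.tolist()):
        lengths = [length for length in row if length > 0]
        if sum(lengths) < seqlen:
            lengths.append(seqlen - sum(lengths))  # trailing tokens: their own block
        start = 0
        for length in lengths:
            # Block-diagonal causal mask: each sample attends only within itself.
            mask[b, start:start + length, start:start + length] = causal[:length, :length]
            start += length
    return mask


# Reproduces the 3-D mask above (length rows padded with zeros up to seqlen = 6):
print(expand_to_3d_mask(torch.tensor([[2, 3, 0, 0, 0, 0],
                                      [3, 2, 0, 0, 0, 0],
                                      [6, 0, 0, 0, 0, 0]])))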
Using sample masking with flash-attention for Llama 3

https://github.com/sz128/LLMs_implementations/blob/main/sample_mask_with_flash-attn-2.ipynb

Here is the core code to compute attention_mask_in_length for a Megatron-LM data collator (https://github.com/microsoft/Megatron-DeepSpeed/blob/b7b2d5ef330f43729b406630e6c5d38e873d7398/megatron/utils.py#L162):

import torch


def mask_concat_samples(batch_data, eos_token_id, reset_position_ids=False):
    input_ids = batch_data["input_ids"]
    labels = batch_data["labels"].clone()
    micro_batch_size, seq_length = input_ids.shape

    position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

    # Per-row sample lengths; the non-zero entries of each row are the lengths of the packed samples.
    inner_sample_lengths = torch.zeros((micro_batch_size, seq_length), dtype=torch.int)
    for b in range(micro_batch_size):
        # Find indices where the EOS/EOD token is.
        eod_index = position_ids[b, input_ids[b] == eos_token_id]
        # Detach indices from positions if we are going to modify positions.
        if reset_position_ids:
            eod_index = eod_index.clone()

        prev_index = -1
        for j in range(len(eod_index)):
            inner_sample_lengths[b, j] = eod_index[j] - prev_index
            prev_index = eod_index[j]
            if eod_index[j] < seq_length - 1:
                # Do not compute loss on the first token of the next sample
                # (it would be predicted from the previous sample's context).
                labels[b, eod_index[j] + 1] = -100

        # Trailing tokens after the last EOS form one more (truncated) sample.
        if prev_index < seq_length - 1:
            inner_sample_lengths[b, len(eod_index)] = seq_length - 1 - prev_index

        assert len(input_ids[b]) == sum(inner_sample_lengths[b]).item()

        if reset_position_ids and len(eod_index) > 1:
            for j in range(1, len(eod_index)):
                i = eod_index[j]
                prev_len = eod_index[j - 1]
                position_ids[b, i:] -= (i - prev_len)

    batch_data["labels"] = labels
    # Store the per-sample lengths in place of the usual binary attention mask.
    batch_data["attention_mask"] = inner_sample_lengths

    if reset_position_ids:
        batch_data["position_ids"] = position_ids

We can also change the attention_mask produced by DataCollatorWithFlattening accordingly:

import warnings

from transformers import DefaultDataCollator, default_data_collator


class DataCollatorWithFlattening(DefaultDataCollator):
    """
    Data collator used for the padding-free approach. Does the following:
    - concatenates the entire mini batch into a single long sequence [1, total_tokens]
    - no padding is added; returns `input_ids`, `labels` and `position_ids`
    """

    def __init__(self, *args, reset_position_ids=True, reset_attention_mask=False, **kwargs):
        if "return_position_ids" in kwargs:
            warnings.warn(
                "The `return_position_ids` argument is deprecated and will be removed in a future version, "
                "use `reset_position_ids` instead.",
            )
            reset_position_ids = kwargs.pop("return_position_ids")
        super().__init__(*args, **kwargs)
        self.reset_position_ids = reset_position_ids
        self.reset_attention_mask = reset_attention_mask
        warnings.warn(
            "Using `DataCollatorWithFlattening` will flatten the entire mini batch into a single long sequence. "
            "Make sure your attention computation is able to handle it!"
        )

    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors
        is_labels_provided = "labels" in features[0]
        ret = {"input_ids": [], "labels": []}
        if self.reset_position_ids:
            ret.update({"position_ids": []})
        if self.reset_attention_mask:
            ret.update({"attention_mask": []})
        for idx in range(0, len(features)):
            ret["input_ids"] += features[idx]["input_ids"]
            if is_labels_provided:
                ret["labels"] += [-100] + features[idx]["labels"][1:]
            else:
                ret["labels"] += [-100] + features[idx]["input_ids"][1:]
            if self.reset_position_ids:
                ret["position_ids"] += list(range(len(features[idx]["input_ids"])))
            if self.reset_attention_mask:
                ret["attention_mask"].append(len(features[idx]["input_ids"]))
        if self.reset_attention_mask:
            # Without at least one zero, the attention_mask would be misread as a dense
            # all-ones mask, so append a trailing 0 as a sentinel.
            ret["attention_mask"].append(0.0)
        return default_data_collator([ret], return_tensors)
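A toy invocation of the modified collator (assuming the class above is defined; the output values are shown approximately in the comments):

features = [
    {"input_ids": [10, 11, 12]},
    {"input_ids": [20, 21]},
]
collator = DataCollatorWithFlattening(reset_position_ids=False, reset_attention_mask=True)
batch = collator(features)

# batch["input_ids"]      -> [[10, 11, 12, 20, 21]]
# batch["labels"]         -> [[-100, 11, 12, -100, 21]]
# batch["attention_mask"] -> [[3, 2, 0]]   # per-sample lengths, terminated by the 0 sentinel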

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@amyeroberts
Collaborator

cc @ArthurZucker

@eldarkurtic
Contributor

Isn't this already supported with #31629 ?

@ArthurZucker ArthurZucker left a comment
Collaborator

Hey! Don't you think that ragging the tensor would be more efficient?

@sz128
Author

sz128 commented Aug 21, 2024

Isn't this already supported with #31629 ?

It seems the two implementations are similar. But we need to consider the situation where position IDs are not reset between different samples (especially for LLM pre-training).

@sz128
Author

sz128 commented Aug 21, 2024

Hey! Don't you think that ragging the tensor would be more efficient?

Yes. I didn't describe it well. I updated the description of this PR. The implementations of this PR and #31629 are very similar. But, we need to consider the scenario where position IDs are not reset between different short samples, especially for LLM pre-training. Therefore, this implementation opts to store short-sequence lengths in the attention mask matrix.

While borrowing the attention_mask might not be elegant, introducing a new variable could entail a substantial amount of effort.

@beep-bebop
Contributor

Hey! Don't you think that ragging the tensor would be more efficient?

Yes. I didn't describe it well. I updated the description of this PR. The implementations of this PR and #31629 are very similar. But, we need to consider the scenario where position IDs are not reset between different short samples, especially for LLM pre-training. Therefore, this implementation opts to store short-sequence lengths in the attention mask matrix.

While borrowing the attention_mask might not be elegant, introducing a new variable could entail a substantial amount of effort.

Hello, expert! What does ‘position IDs are not reset between different short samples’ mean specifically, and why is this especially relevant in pre-training? Very helpful work, thanks for the answers.

@sz128
Author

sz128 commented Aug 26, 2024

Hey! Don't you think that ragging the tensor would be more efficient?

Yes. I didn't describe it well. I updated the description of this PR. The implementations of this PR and #31629 are very similar. But, we need to consider the scenario where position IDs are not reset between different short samples, especially for LLM pre-training. Therefore, this implementation opts to store short-sequence lengths in the attention mask matrix.
While borrowing the attention_mask might not be elegant, introducing a new variable could entail a substantial amount of effort.

Hello, expert! What does ‘position IDs are not reset between different short samples’ mean specifically, and why is this especially relevant in pre-training? Very helpful work, thanks for the answers.

Ohh, I mean it is useful in long-context pre-training.

For long-context training (for example, 128k long sequence), we may utilize synthesized samples which are usually concatenations of short samples. In this case, position IDs should not be reset.

@ArthurZucker
Collaborator

we need to consider the scenario where position IDs are not reset between different short samples, especially for LLM pre-training

Does this imply that we compute the position IDs ourselves? That's something we'd generally rather avoid, preferring to have the user pass the position IDs.
Since passing the correct position IDs is already supported, IMO we should not add this!

@sz128
Author

sz128 commented Aug 28, 2024

we need to consider the scenario where position IDs are not reset between different short samples, especially for LLM pre-training

Does this imply that we compute the position IDs ourselves? That's something we'd generally rather avoid, preferring to have the user pass the position IDs. Since passing the correct position IDs is already supported, IMO we should not add this!

No. There is a scenario where position IDs cannot be reset between different short samples. For long-context training (e.g., a 128k-token sequence), we may utilize synthesized samples, which are usually concatenations of short samples. If position IDs are reset, it is no longer long-context training.

@beep-bebop
Contributor

beep-bebop commented Sep 2, 2024

we need to consider the scenario where position IDs are not reset between different short samples, especially for LLM pre-training

Does this imply that we compute the position IDs ourselves? That's something we'd generally rather avoid, preferring to have the user pass the position IDs. Since passing the correct position IDs is already supported, IMO we should not add this!

No. There is a scenario where position IDs cannot be reset between different short samples. For long-context training (e.g., a 128k-token sequence), we may utilize synthesized samples, which are usually concatenations of short samples. If position IDs are reset, it is no longer long-context training.

A less native approach might be to tokenize the data, splice it piecewise via the map method, and then use that dataset for training. Of course, the default batch_size for map is 1000, so we may need to set this value according to the data, as the total length of these 1000 samples may be less than 128k. It would be nice to have native support for this kind of processing. 🏃‍♂️

@sz128
Author

sz128 commented Sep 3, 2024

we need to consider the scenario where position IDs are not reset between different short samples, especially for LLM pre-training

Does this imply that we compute the position IDs ourselves? That's something we'd generally rather avoid, preferring to have the user pass the position IDs. Since passing the correct position IDs is already supported, IMO we should not add this!

No. There is a scenario where position IDs cannot be reset between different short samples. For long-context training (e.g., a 128k-token sequence), we may utilize synthesized samples, which are usually concatenations of short samples. If position IDs are reset, it is no longer long-context training.

A less native approach might be to tokenize the data, splice it piecewise via the map method, and then use that dataset for training. Of course, the default batch_size for map is 1000, so we may need to set this value according to the data, as the total length of these 1000 samples may be less than 128k. It would be nice to have native support for this kind of processing. 🏃‍♂️

You can take a look at #14767.

@eldarkurtic
Contributor

Why wouldn't we use position_ids to encode all information (packed, not packed, padded, not padded) in a slightly more elegant way without touching attention_mask?

For example let's say that the sequence length is 8:

  1. for perfectly packed sequences (e.g. two sequences of length 4): attention_mask = [1, 1, 1, 1, 1, 1, 1, 1] and position_ids = [0, 1, 2, 3, 0, 1, 2, 3]
  2. for partially packed sequences (e.g. two sequences of length 3 and the rest are padding tokens): attention_mask = [1, 1, 1, 1, 1, 1, 0, 0] and position_ids = [0, 1, 2, 0, 1, 2, 0, 1]

When the information about sequences is stored in position_ids (as opposed to attention_mask), the positional embeddings are automatically calculated correctly, so it is a minimally invasive approach relative to the existing transformers codebase.
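A minimal sketch of that idea (the helper name is illustrative): recover the sequence boundaries (cu_seqlens) from packed position_ids, assuming each packed sequence restarts its positions at 0:

import torch


def cu_seqlens_from_position_ids(position_ids):
    """position_ids: (1, total_tokens) tensor where each packed sequence restarts at 0."""
    flat = position_ids.flatten()
    # A new sequence starts wherever the position id does not increase.
    starts = torch.cat([
        torch.tensor([0], device=flat.device),
        (torch.diff(flat) <= 0).nonzero().flatten() + 1,
    ])
    ends = torch.tensor([flat.numel()], device=flat.device)
    return torch.cat([starts, ends]).to(torch.int32)


print(cu_seqlens_from_position_ids(torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3]])))
# tensor([0, 4, 8], dtype=torch.int32)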

@sz128
Author

sz128 commented Sep 4, 2024

Why wouldn't we use position_ids to encode all information (packed, not packed, padded, not padded) in a slightly more elegant way without touching attention_mask?

For example let's say that the sequence length is 8:

  1. for perfectly packed sequences (e.g. two sequences of length 4): attention_mask = [1, 1, 1, 1, 1, 1, 1, 1] and position_ids = [0, 1, 2, 3, 0, 1, 2, 3]
  2. for partially packed sequences (e.g. two sequences of length 3 and the rest are padding tokens): attention_mask = [1, 1, 1, 1, 1, 1, 0, 0] and position_ids = [0, 1, 2, 0, 1, 2, 0, 1]

When the information about sequences is stored in position_ids (as opposed to attention_mask), the positional embeddings are automatically calculated correctly, so it is a minimally invasive approach relative to the existing transformers codebase.

For long-context training (e.g., 128k long sequence), we may utilize synthesized samples, which are usually concatenations of short samples. Thus, we don't want to reset position_ids in this case.

For instance, consider four sequences of length 32k. The attention_mask would be [32k, 32k, 32k, 32k, 0, ..., 0, 0, 0], and the position_ids would be [0, 1, 2, 3, 4, 5, 6, 7, ..., 128k-1]. This allows the model to learn position embeddings for longer sequences.

In contrast, if we use attention_mask = [1, 1, 1, ..., 1, 1] and position_ids = [0, 1, 2, ..., 32k-1, 0, 1, 2, ..., 32k-1, 0, 1, 2, ..., 32k-1, 0, 1, 2, ..., 32k-1], the model can only learn position embeddings in the range of [0, 32k-1].

@eldarkurtic
Contributor

Yep, agree with that definitely! My proposal was to leave this choice to users to set in the data collator.
If they wish to treat such concatenated sequences as a single sequence, they would set position_ids = [0, 1, 2, ..., 128k - 1].
If they wish to treat them as multiple shorter sequences concatenated together, they would set position_ids = [0, 1, 2, ..., 32k - 1, 0, 1, 2, ..., 32k - 1, 0, 1, 2, ..., 32k - 1].

@sz128
Author

sz128 commented Sep 4, 2024

Yep, agree with that definitely! My proposal was to leave this choice to users to set in the data collator. If they wish to treat such concatenated sequences as a single sequence, they would set position_ids = [0, 1, 2, ..., 128k - 1]. If they wish to treat them as multiple shorter sequences concatenated together, they would set position_ids = [0, 1, 2, ..., 32k - 1, 0, 1, 2, ..., 32k - 1, 0, 1, 2, ..., 32k - 1].

Well, the point is that when we treat concatenated sequences as a single sequence (position_ids = [0, 1, 2, ..., 128k - 1]), we still need an attention mask to prevent self-attention between different samples within the same sequence. This attention mask is important for very long context training, which is discussed in the LLaMA-3 paper (https://arxiv.org/pdf/2407.21783, the bottom of page 6).

@eldarkurtic
Contributor

Are you suggesting that they are using position_ids = [0, 1, 2, ..., len(concatenated_sequences) - 1] with flash_attn_varlen_func to prevent cross-document attention? If yes, I feel this is doing mix-and-match: for positional embeddings (which are computed based on position_ids) we are pretending like we have a single sequence, but then when computing attention we decouple sequences and treat them separately.

We use an attention mask that prevents self-attention between different documents within the same sequence.

^I have interpreted this sentence in a different way. They concatenate sequences together, and make sure that there is no cross-document attention, which would translate into:
position_id = [0, 1, 2, ..., len(seq1) - 1, 0, 1, 2, ..., len(seq2) - 1, ...]
which we would use to correctly compute positional embeddings and to figure out where each sequence starts and ends (cu_seq_lens), and use that to call flash_attn_varlen_func:

elif position_ids is not None and not (torch.diff(position_ids, dim=-1) >= 0).all() and query_length != 1:
    batch_size = query_states.size(0)
    query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
        query_states, key_states, value_states, position_ids
    )
    cu_seqlens_q, cu_seqlens_k = cu_seq_lens
    max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
    attn_output = flash_attn_varlen_func(
        query_states,
        key_states,
        value_states,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_in_batch_q,
        max_seqlen_k=max_seqlen_in_batch_k,
        dropout_p=dropout,
        softmax_scale=softmax_scale,
        causal=causal,
        **flash_kwargs,
    )
    attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))

@sz128
Author

sz128 commented Sep 5, 2024

for positional embeddings (which are computed based on position_ids) we are pretending like we have a single sequence, but then when computing attention we decouple sequences and treat them separately.

Yes. We want to learn position embeddings for larger position ids (32k~128k-1).

@eldarkurtic
Contributor

for positional embeddings (which are computed based on position_ids) we are pretending like we have a single sequence, but then when computing attention we decouple sequences and treat them separately.

Yes. We want to learn position embeddings for larger position ids (32k~128k-1).

Are there any ablation studies which show that this mix-and-match approach helps?

@sz128 sz128 closed this Sep 10, 2024
@sz128
Author

sz128 commented Sep 10, 2024

Since most LLMs use RoPE position encodings, which are relative, there is no difference between resetting position IDs or not for cross-document attention masking.
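This can be checked with a minimal numerical sketch of the standard RoPE rotation: the score between a rotated query at position m and a rotated key at position n depends only on m - n, so adding a constant offset to both positions (i.e. not resetting position IDs) leaves the attention scores unchanged.

import torch


def rope(x, pos, theta=10000.0):
    # Standard RoPE: rotate consecutive (even, odd) feature pairs by pos * freq.
    d = x.shape[-1]
    freqs = theta ** (-torch.arange(0, d, 2, dtype=torch.float64) / d)
    angles = pos * freqs
    cos, sin = torch.cos(angles), torch.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


q = torch.randn(64, dtype=torch.float64)
k = torch.randn(64, dtype=torch.float64)
# Score for (query position 10, key position 3) ...
s1 = rope(q, 10.0) @ rope(k, 3.0)
# ... equals the score when both positions are shifted by 10_000 (no reset).
s2 = rope(q, 10_010.0) @ rope(k, 10_003.0)
print(torch.allclose(s1, s2))  # True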
