
Update attention.py #1416

Merged: 4 commits merged on Oct 9, 2023
Conversation

DongHande (Contributor)

Modify the code for bigcode.
This modification makes the KV cache work correctly with multiple new tokens.

What does this PR do?

When we use StarCoder to generate text/code with a KV cache and multiple new tokens, the output is wrong because of a possible error in the torch.nn.functional.scaled_dot_product_attention() function. I have opened an issue in PyTorch at pytorch/pytorch#110144, but until PyTorch fixes it, Optimum can work correctly with a minor change.

How to reproduce the error with the current version:

    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch

    checkpoint = "bigcode/starcoderbase-1b"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True, torch_dtype=torch.bfloat16).to("cuda:0").to_bettertransformer()

    # Prefill: run the prefix once and keep its KV cache.
    prefix = tokenizer.encode("def quick_sort", return_tensors="pt").to("cuda:0")
    outputs = model(prefix)
    past_key_values = outputs.past_key_values

    # Decode step that feeds multiple new tokens while reusing the KV cache.
    next_token = tokenizer.encode("(arr", return_tensors="pt").to("cuda:0")  # TWO NEW TOKENS
    logit = model(next_token, past_key_values=past_key_values).logits[:, -1, :]
    idx_next = torch.argmax(logit, dim=1, keepdim=True)
    print(tokenizer.decode(idx_next[0], skip_special_tokens=True))
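As a sanity check, the cached two-token step above should produce the same last-position logits as a single full forward pass. A minimal sketch (not part of the original report) that reuses model, prefix, next_token, and logit from the snippet above:

    # Reference: feed the whole sequence at once, without reusing the cache.
    full_input = torch.cat([prefix, next_token], dim=1)
    full_logit = model(full_input).logits[:, -1, :]

    # With the bug these can disagree; after the fix they should match
    # (loose tolerances because the model runs in bfloat16).
    print(torch.allclose(full_logit, logit, atol=1e-2, rtol=1e-2))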

Before submitting

  • [Y] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [Y] Did you make sure to update the documentation with your changes?
  • [Y] Did you write any new necessary tests?

fxmarty (Contributor) commented Oct 5, 2023

Hi @DongHande, thank you for the PR. I will have a look shortly!

fxmarty (Contributor) commented Oct 6, 2023

Thank you @DongHande for the notice; this is indeed a significant bug in our code base.

Passing a non-None attn_mask to SDPA currently prevents dispatching to Flash Attention, so I would suggest the following in order to still enable the FA dispatch during training and when query_length == 1, if that sounds good to you:

    # We treat self.training and (batch_size == 1 and query_length == 1) cases separately to still allow the dispatch to Flash Attention.
    if self.training:
        is_causal = True
        attn_mask = None
    elif batch_size == 1 and query_length == 1:
        is_causal = False
        attn_mask = None
    elif batch_size == 1 and kv_seq_len == query_length:
        is_causal = True
        attn_mask = None
    elif attention_mask is not None:
        mask_value = self._get_mask_value(query.device, query.dtype)

        # gpt_bigcode has the bad taste to use a causal mask of shape
        # [batch_size, target_length, 1, source_length], which is different from
        # **all** other architectures and not compatible with SDPA.
        # We could avoid this transpose by overriding the forward from GPTBigCodeModel,
        # but it is probably not worth it.
        attention_mask = attention_mask.transpose(1, 2)
        attn_mask = torch.where(attention_mask, 0.0, mask_value)
        is_causal = False
    else:
        attn_mask = None
        is_causal = True

    sdpa_result = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
    )

WDYT?

DongHande (Contributor, Author)

Thank you for your reply. I still have two questions:

(1) I don't understand why we should consider batch_size == 1 here. The attn_mask has been calculated in the outer forward function. Why not use it directly?

In other words, this function is an SDPA implementation that replaces the attention operation of the Transformers library (https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L128-L203). The Transformers library does not treat batch_size == 1 as a special case, so why should the Optimum library?

(2) Maybe in your reply, the last statement should be modified
from

    sdpa_result = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
    )

to

    sdpa_result = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=is_causal
    )

?

Regarding the first question, you likely have other reasons for writing it this way; to save your time, you don't have to explain them if the context is long and hard to summarize.
But the second point may cause other errors, so please review it. Thank you!

fxmarty (Contributor) commented Oct 9, 2023

  1. The reason is that when the attn_mask input to F.scaled_dot_product_attention is not None, SDPA is unable to dispatch to the Flash Attention (or FA2 in nightly) kernel; see https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html and the sketch after this comment. Transformers does not treat batch_size == 1 as a special case because there is no optimized path in Transformers.

What concerns me about your proposed change is that it would never dispatch to FA/FA2.

  2. Yes, it should indeed be is_causal=is_causal.
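A minimal sketch of that dispatch constraint, assuming PyTorch 2.1 on a CUDA GPU that supports the Flash Attention backend (the shapes and the use of the torch.backends.cuda.sdp_kernel context manager are illustrative, not taken from the PR): with only the Flash Attention backend enabled, the is_causal=True call goes through, while a call that passes a non-None attn_mask cannot be served by that backend.

    import torch
    import torch.nn.functional as F

    # Shapes are (batch, num_heads, seq_len, head_dim); half precision is required by the FA kernel.
    q = torch.rand(1, 16, 128, 64, device="cuda", dtype=torch.float16)
    k, v = q.clone(), q.clone()
    bool_mask = torch.ones(1, 1, 128, 128, device="cuda", dtype=torch.bool).tril()

    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        # Eligible for the Flash Attention kernel: no explicit mask, is_causal=True.
        F.scaled_dot_product_attention(q, k, v, is_causal=True)
        try:
            # An explicit attn_mask rules out the Flash Attention backend, and the
            # other backends are disabled here, so this is expected to raise.
            F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
        except RuntimeError as err:
            print("SDPA could not use Flash Attention with an explicit mask:", err)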

DongHande (Contributor, Author)

OK. I have modified my PR according to your suggestions. Please review and merge it. Thank you.

fxmarty (Contributor) left a review:

LGTM thank you!

I'll keep in mind to update other archs as well :)

fxmarty merged commit c8cf353 into huggingface:main on Oct 9, 2023
44 of 52 checks passed