Maybe Wrong implementation of AttentionWithRoPE for GPTJ and GPT-NeoX? #747
Comments
The implementation looks correct to me, in GPTNeoXAttention @ vllm/model_executor/models/gpt_neox.py. Strange inference results have especially been reported for GPT-J though: #590
I think one of the main problems is how the tensor is rotated. Referencing HF transformers' implementation, GPT-J rotates every pair of adjacent elements (interleaved), while GPT-NeoX rotates the two halves of the rotary dimensions.
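For reference, a paraphrase of the two rotation helpers as they appear in HF transformers (the names rotate_every_two and rotate_half come from the transformers source; this is an illustrative sketch, not the snippet originally posted in this comment):

```python
import torch

def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    # GPT-J style: rotate adjacent (even, odd) element pairs, i.e. an interleaved layout.
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # GPT-NeoX style: rotate the first and second halves of the last dimension.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
```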
I haven't read the .cu code yet. But if my understanding is correct, it should always apply the position embeddings to the top rot_dim columns of the tensor's last dimension; if so, I think GPT-NeoX's implementation should be correct.
It should be addressed, since many users report that vLLM's output is not aligned with HF's; judging from the results, it looks worse than HF's output.
Additional notes: I think one can claim a vLLM model's generation quality is worse than HF's only after doing the following:
How can I make sure beam search is enabled in vLLM?
Turn on the flag.
@PanQiWei Does n = 2 mean using beam search with 2 beams? Does it work in streaming mode?
Sorry, my bad. In the OpenAI API it should be best_of=2, i.e. beam_size=2; and I don't think it works in streaming mode.
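For reference, a minimal sketch of enabling beam search with vLLM's offline API, assuming the SamplingParams flags available around the time of this thread (use_beam_search and best_of, with temperature required to be 0 for beam search):

```python
from vllm import LLM, SamplingParams

# Sketch: beam search with a beam width of 2, assuming the use_beam_search
# flag on SamplingParams (present in vLLM releases of this era).
llm = LLM(model="EleutherAI/gpt-j-6b")
params = SamplingParams(
    n=1,                    # number of sequences to return
    best_of=2,              # beam width
    use_beam_search=True,
    temperature=0.0,        # beam search requires temperature 0
    max_tokens=64,
)
outputs = llm.generate(["The quick brown fox"], params)
print(outputs[0].outputs[0].text)
```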
@PanQiWei So there is still some discrepancy between HF and vLLM.
Hi @PanQiWei @lucasjinreal @syskn, thanks for letting us know about the bug and the solution. As you pointed out, I misunderstood the rotary embedding in GPT-J and treated it as equal to the RoPE used by GPT-NeoX. #941 fixes the bug. Apologies for the confusion and inconvenience.
I think there may be a wrong implementation for GPT-J and GPT-NeoX when applying rotary embeddings.

The currently implemented PagedAttentionWithRoPE always uses the whole query and key, which is compatible with models like llama and baichuan; however, GPT-J and GPT-NeoX may only use part of the query and key when doing RoPE.

If the current implementation is not compatible with those two models, I would suggest that models which do not use the whole query and key when applying rotary embeddings get another attention class that inherits from PagedAttentionWithRoPE and does something like: