
Maybe Wrong implementation of AttentionWithRoPE for GPTJ and GPT-NeoX? #747

Closed
PanQiWei opened this issue Aug 12, 2023 · 11 comments · Fixed by #941
Labels
bug Something isn't working

Comments

@PanQiWei

PanQiWei commented Aug 12, 2023

I think there may be a wrong implementation for GPT-J and GPT-NeoX when applying rotary embeddings.

The currently implemented PagedAttentionWithRoPE always uses the whole query and key, which is compatible with models like LLaMA and Baichuan; however, GPT-J and GPT-NeoX may only use part of the query and key when applying RoPE.

If the current implementation is not compatible with those two models, I would suggest that models which do not use the whole query and key when applying rotary embeddings get a separate attention class that inherits from PagedAttentionWithRoPE and does something like:

# Split each head into the slice that will be rotated (the first rotary_dim
# dims) and the pass-through remainder.
query_rot, query_pass = self._prepare_tensor_for_rope(query, self.rotary_dim)
key_rot, key_pass = self._prepare_tensor_for_rope(key, self.rotary_dim)
# Apply RoPE only to the rotated slices.
pos_encoding_ops.rotary_embedding_neox(
    positions,
    query_rot,
    key_rot,
    self.rotary_dim,
    self.cos_sin_cache,
)
# Re-assemble the full query/key from the rotated and pass-through parts.
query = self._cat_tensor_after_rope(query_rot, query_pass)
key = self._cat_tensor_after_rope(key_rot, key_pass)
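
For concreteness, here is a minimal sketch of what the two helpers above could do; the helper names come from the snippet, the bodies are only an illustration and assume query/key are shaped [num_tokens, num_heads * head_size]:

import torch

def prepare_tensor_for_rope(x: torch.Tensor, num_heads: int, head_size: int, rotary_dim: int):
    # Reshape to [num_tokens, num_heads, head_size] and split each head into
    # the slice that gets rotated and the slice that passes through unchanged.
    x = x.view(-1, num_heads, head_size)
    x_rot = x[..., :rotary_dim].contiguous()
    x_pass = x[..., rotary_dim:]
    return x_rot, x_pass

def cat_tensor_after_rope(x_rot: torch.Tensor, x_pass: torch.Tensor,
                          num_heads: int, head_size: int) -> torch.Tensor:
    # Re-join the rotated and pass-through slices and flatten back to 2-D.
    x = torch.cat((x_rot, x_pass), dim=-1)
    return x.view(-1, num_heads * head_size)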
@PanQiWei PanQiWei changed the title Wrong implementation of GPTJ and GPT-NeoX Maybe Wrong implementation of AttentionWithRoPE for GPTJ and GPT-NeoX? Aug 12, 2023
@syskn

syskn commented Aug 12, 2023

The implementation looks correct to me. pos_encoding_kernels.cu uses int rot_dim = cos_sin_cache.size(1);
which is determined by the rotary_dim passed from

rotary_dim = int(self.head_size * config.rotary_pct)
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size, scaling, rotary_dim)

in GPTNeoXAttention @ vllm/model_executor/models/gpt_neox.py
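
For context, the cos_sin_cache that the kernel reads rot_dim from is (roughly) the standard RoPE cache built over rotary_dim channels; the sketch below is a simplification, not the exact vLLM code:

import torch

def build_cos_sin_cache(rotary_dim: int, max_position: int, base: float = 10000.0) -> torch.Tensor:
    # Standard RoPE inverse frequencies over half of rotary_dim.
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    t = torch.arange(max_position).float()
    freqs = torch.einsum("i,j->ij", t, inv_freq)           # [max_position, rotary_dim // 2]
    return torch.cat((freqs.cos(), freqs.sin()), dim=-1)   # [max_position, rotary_dim]

# build_cos_sin_cache(...).size(1) == rotary_dim, which the CUDA kernel uses as rot_dim.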

Strange inference results have been reported especially for GPT-J, though: #590

@PanQiWei
Author

I think one of the main problems is how the tensor is rotated. Going by HF transformers' implementation, GPT-J uses rotate_every_two while GPT-NeoX and LLaMA use rotate_half, which produce different results.
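
For reference, the two rotation helpers look roughly like this in HF transformers (simplified; x has the head dimension last):

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # GPT-NeoX / LLaMA style: split the last dim in half and swap halves with a sign flip.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    # GPT-J style: pair adjacent (even, odd) channels and rotate each pair in place.
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

Both are valid rotations, but they pair different channels, so applying the rotate_half layout to a GPT-J checkpoint (or vice versa) silently scrambles the positional information.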

@PanQiWei
Author

PanQiWei commented Aug 12, 2023

I haven't read the .cu code yet. But if my understanding is correct, it always applies the position embeddings to the first rot_dim columns of the tensor's last dimension; if so, I think the GPT-NeoX implementation should be correct.

@lucasjinreal

It should be addressed: many users report that vLLM's output is not aligned with HF's, and judging from the results it comes out noticeably worse than HF's output.

@PanQiWei
Author

Additional note: I think one can only claim a vLLM model's generation quality is worse than HF's after doing the following:

  1. loading the model from a .bin or .pt file instead of a .safetensors file;
  2. running a perplexity benchmark and confirming there is a big difference between the two models from the different frameworks;
  3. using a beam search decoding strategy rather than a sampling decoding strategy.

@lucasjinreal

@PanQiWei

How can I make sure beam search is enabled in vLLM?

@PanQiWei
Author

PanQiWei commented Aug 15, 2023

@PanQiWei

How can I make sure beam search is enabled in vLLM?

Turn on the use_beam_search flag in SamplingParams or in your request payload, and make sure n > 1 (vLLM does not support greedy search in this mode).

@lucasjinreal

@PanQiWei Does n = 2 mean using beam search with 2 beams? Does it work in streaming mode?

@PanQiWei
Author

@PanQiWei Does n = 2 mean using beam search with 2 beams? Does it work in streaming mode?

Sorry, my bad: in the OpenAI API it should be best_of=2, i.e. beam_size=2; and I don't think it works in streaming mode.
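
For illustration, enabling beam search through the Python API could look roughly like this (parameter names as in the vLLM SamplingParams of that time; the model name is just a placeholder):

from vllm import LLM, SamplingParams

# Beam search with a beam width of 2; temperature is 0 since beam search is deterministic.
params = SamplingParams(
    n=1,
    best_of=2,
    use_beam_search=True,
    temperature=0.0,
)

llm = LLM(model="EleutherAI/gpt-j-6b")  # placeholder model
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)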

@lucasjinreal

@PanQiWei So there is still some discrepancy between HF and vLLM.

@WoosukKwon WoosukKwon added the bug Something isn't working label Sep 1, 2023
@WoosukKwon
Collaborator

WoosukKwon commented Sep 4, 2023

Hi @PanQiWei @lucasjinreal @syskn, thanks for letting us know about the bug and the solution. As you pointed out, I misunderstood the rotary embedding in GPT-J and treated it as identical to the RoPE used by GPT-NeoX. #941 fixes the bug. Apologies for the confusion and inconvenience.

@PanQiWei PanQiWei closed this as completed Sep 4, 2023