Inconsistent Rotation Base for Dynamic NTK Scaling RoPE #25104
cc @ArthurZucker, who works more on this kind of model. However, @NormXU, could you provide a short but self-contained code snippet to demonstrate the issue?
Sure. The inconsistency happens here: `transformers/src/transformers/models/llama/modeling_llama.py`, lines 314 to 330 (at commit a5cc30d).
While the LLM generates token by token beyond its maximum trained length at inference time, the key_states are first rotated with cos and sin computed from the base for the current sequence length. Then, when we come to the next token, the base is recomputed for the longer sequence, and the new key_states are rotated with cos and sin derived from this new base (see `transformers/src/transformers/models/llama/modeling_llama.py`, lines 156 to 163 at commit a5cc30d).
Therefore, we have an inconsistency between cached keys: each one was rotated with a different base.
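For concreteness, here is a minimal, self-contained sketch of the effect (illustrative parameter values, not the exact transformers code): the cos/sin table for the *same* position changes between two consecutive decoding steps once the base is recomputed, while the cached key keeps its old rotation.

```python
import torch

# Illustrative values: head dim, trained context length, NTK scaling factor, base
dim, max_pos, scaling_factor, base0 = 128, 2048, 1.0, 10000.0

def cos_sin(seq_len):
    # Dynamic NTK recomputes the base once seq_len exceeds the trained length
    base = base0
    if seq_len > max_pos:
        base = base0 * ((scaling_factor * seq_len / max_pos)
                        - (scaling_factor - 1)) ** (dim / (dim - 2))
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()

# Step t: the key at position 2048 is rotated (and cached) using base(2049).
cos_t, _ = cos_sin(2049)
# Step t+1: cos/sin are rebuilt with base(2050), but the cached key keeps its
# old rotation -- so the query and the cached key no longer share a base.
cos_t1, _ = cos_sin(2050)
print((cos_t[2048] - cos_t1[2048]).abs().max())  # > 0 -> inconsistent
```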
Actually, @gante is the king of dynamic RoPE, so I will let him handle this! 🤗
Hey @NormXU 👋 I agree with the consistency issue you pointed out. However, our users' biggest concern (and ours, by extension) is empirical results, regardless of correctness. If you can share a few benchmarks where the change has clear benefits (and little to no downside), I'll be more than happy to include it!
@gante Of course. How about the perplexity experiments I did in my repo (link)? The way we currently compute perplexity already keeps the rotation base consistent, since the whole sequence is processed in one forward pass. To bridge this gap in rotation base between perplexity evaluation and autoregressive inference with DynamicNTKScale, I modified the code that applies the rotary embedding to keys and queries and ran simple experiments on LLaMA-1-7B. After the modification, the perplexity is computed with every key re-rotated under the current base (sketched below). I then compared the perplexities; the results are shown below. Can this experiment convincingly demonstrate that a consistent DynamicNTK achieves better perplexity on long contexts than an inconsistent one? Besides, could you give me advice on which benchmarks I should test this on? I have access to 8 x A100s, so I can run many experiments quickly.
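The idea of the modification can be sketched roughly like this (a simplified illustration with hypothetical helper names, not the exact code from the repo): cache the keys *before* rotation, and at every step re-rotate the whole cache with the current base, so the query and all keys share one rotation base, just as in a single full-context perplexity pass.

```python
import torch

def rotate_half(x):
    # Standard RoPE helper: maps (x1, x2) to (-x2, x1) on the last-dim halves
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, cos, sin):
    # x, cos, sin: (seq_len, dim)
    return x * cos + rotate_half(x) * sin

def consistent_attention_step(q_t, unrotated_key_cache, cos, sin):
    """One decoding step with a consistent rotation base.

    q_t:                 (1, dim) query for the current position
    unrotated_key_cache: (t, dim) keys cached BEFORE RoPE was applied
    cos, sin:            (t, dim) recomputed with the CURRENT base for all positions
    """
    q = apply_rope(q_t, cos[-1:], sin[-1:])        # rotate the current query
    k = apply_rope(unrotated_key_cache, cos, sin)  # re-rotate the whole cache
    return q @ k.T                                 # attention logits (pre-softmax)
```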
@NormXU I see, that makes sense! A final question, before I make a decision: have you measured throughput vs. the original dynamic scaling? Since we need to reapply RoPE to the cached keys, it should introduce slowdowns. The decision to include the technique in `transformers` will depend on that.
@gante The main difference between my implementation and Hugging Face's is as follows: in the former, all keys are cached before RoPE is applied, so RoPE is re-applied to a length-increasing key_states list at every step; the latter applies RoPE only once, to a single key_state. Therefore, we just need to confirm whether applying RoPE to a length-increasing key_states list takes more time than applying it to a single key_state. Here is the exec time of both approaches:
You can find the exec-time eval script here:
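The linked script is authoritative; as a rough, self-contained sketch, the comparison it performs looks like this (hypothetical shapes and step count):

```python
import time
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, cos, sin):
    return x * cos + rotate_half(x) * sin

dim, steps = 128, 4096
cos, sin = torch.randn(steps, dim), torch.randn(steps, dim)
keys = torch.randn(steps, dim)

# Hugging Face style: rotate only the single new key at each step
t0 = time.perf_counter()
for t in range(steps):
    apply_rope(keys[t:t + 1], cos[t:t + 1], sin[t:t + 1])
print("single key:  ", time.perf_counter() - t0)

# Consistent style: re-rotate the length-increasing key list at each step
t0 = time.perf_counter()
for t in range(steps):
    apply_rope(keys[:t + 1], cos[:t + 1], sin[:t + 1])
print("growing list:", time.perf_counter() - t0)
```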
@NormXU I see -- if I understand correctly, the execution time of applying RoPE grows with the sequence length. Since DynamicNTK is used precisely in the large-sequence-length regime, this means we would be incurring a high execution-speed penalty, which is highly undesirable. Unless we find a way to work around this speed issue, I am against adding this modification -- execution speed is paramount in LLMs nowadays :) Follow-up question: wouldn't these proposed modifications be equivalent to running DynamicNTK with `use_cache=False`?
@gante Indeed, no caching = no inconsistency. In fact, I haven't found any practical downstream task where the consistent RoPE brings a significant performance boost. The only advantage that convinces me to adopt it is its potential to achieve better perplexity scores on very long contexts. Therefore, it looks like it is not necessary to correct this inconsistency in RoPE. Speed does matter more than correctness :)
Thank you for discussing and iterating with us, @NormXU 💪 I'll close the issue for now. TL;DR of the discussion: the inconsistency in DynamicNTK RoPE scaling can be fixed with `use_cache=False`, at the cost of much slower generation.
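For readers landing here: disabling the KV cache makes every step recompute and rotate all keys with the current base, so all positions share one rotation base, at the cost of quadratic recomputation. A minimal usage sketch (the checkpoint name is just an example):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    rope_scaling={"type": "dynamic", "factor": 2.0},  # dynamic NTK scaling
)

inputs = tokenizer("A very long prompt ...", return_tensors="pt")
# use_cache=False -> keys are recomputed and re-rotated at every step,
# so queries and keys stay base-consistent (slower, but correct)
out = model.generate(**inputs, max_new_tokens=32, use_cache=False)
```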
I wrote a blog post about this problem based on our discussion. I hope it is helpful if you happen to find this issue and want to learn more details.
System Info

- `transformers` version: 4.31.0

Who can help?

@ArthurZucker @younesbelkada

Information

Tasks

- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)

Reproduction
The inconsistency problem
There is a subtle inconsistency in the rotation base of the DynamicNTKRope implementation in transformers 4.31.0.
Suppose we have a decoder model, like LLaMA-1, that uses DynamicNTKRope for interpolation, and we want to evaluate it with perplexity. In any layer of this decoder, after the key_states and query_states are computed from the hidden features, they are rotated based on a fixed seq_len, namely the context length.
However, while generating token by token beyond its maximum trained length at inference time, the LLM reuses previously cached keys, which were rotated with bases derived from earlier values of seq_len. As the sequence length keeps increasing, each cached key ends up rotated with respect to a different base, and consequently an inconsistency between the cached keys and the current queries arises.
Expected behavior
I have conducted some experiments on the inconsistency and edited the code that applies the rotation to the keys and queries so that the rotation base stays consistent; see here. Please check the repo for further details.
While I haven't yet tested whether a consistent rotation benefits perplexity or downstream tasks on any particular dataset or model, I believe that, from a mathematical perspective, keeping the rotation base consistent should help the model capture relative position information more effectively: RoPE's guarantee that attention scores depend only on relative offsets holds only when queries and keys share one base.
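To spell out the mathematical point (notation follows the RoPE paper; a sketch of the argument, not a proof): RoPE rotates query $q$ at position $m$ and key $k$ at position $n$ with rotation matrices $R_{\Theta,m}$ and $R_{\Theta,n}$, built from one fixed frequency set $\Theta$, so that

$$
\langle R_{\Theta,m}\,q,\; R_{\Theta,n}\,k \rangle \;=\; q^{\top} R_{\Theta,\,n-m}\, k ,
$$

i.e. the attention score depends only on the offset $n-m$. With dynamic NTK, the query at step $t{+}1$ uses frequencies $\Theta_{t+1}$ while a cached key was rotated with $\Theta_t$, so $\langle R_{\Theta_{t+1},m}\,q,\; R_{\Theta_t,n}\,k \rangle$ is no longer a function of $n-m$ alone, and the relative-position property is lost.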