
Inconsistent Rotation Base for Dynamic NTK Scaling RoPE #25104

Closed
1 of 4 tasks
NormXU opened this issue Jul 26, 2023 · 12 comments

Comments

@NormXU
Contributor

NormXU commented Jul 26, 2023

System Info

  • transformers version: 4.31.0
  • Platform: Linux-5.19.0-42-generic-x86_64-with-glibc2.27
  • Python version: 3.8.0
  • PyTorch version (GPU?): 2.0.0+cu117 (True)

Who can help?

@ArthurZucker @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Inconsistent problem

There is a subtle inconsistency in the rotation base of the DynamicNTKRope implementation in transformers 4.31.0.

Suppose we have a decoder model, such as LLaMA-1, that uses DynamicNTKRope for interpolation, and we want to evaluate it with perplexity. In any layer of this decoder model, after the key_states and query_states are computed from the hidden features, they are rotated based on a fixed seq_len, which is the context length.

However, when generating token by token beyond the maximum trained length at the inference stage, the LLM usually reuses previously cached keys, which were rotated based on factors tied to the earlier seq_len. As the seq len keeps increasing, each cached key has been rotated with respect to a different base, and consequently an inconsistency arises between the cached keys, and between keys and queries.

Expected behavior

I have conducted some experiments on the inconsistency and edited the code that applies rotation to the keys and queries to keep the rotation base consistent here. Please check the repo for further details.

While I haven't tested whether a consistent rotation benefits perplexity or downstream tasks on any dataset or language model, I believe that, from a mathematical perspective, keeping the rotation base consistent could help the language model capture relative position information more effectively.

@ydshieh
Collaborator

ydshieh commented Jul 26, 2023

cc @ArthurZucker who works more on this (kind of) model(s).

However, @NormXU, could you provide a short but self-contained code snippet to demonstrate the inconsistency you mentioned? Thanks.

@NormXU
Contributor Author

NormXU commented Jul 26, 2023

Sure

The inconsistency happens here:

```python
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
    kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
    # reuse k, v, self_attention
    key_states = torch.cat([past_key_value[0], key_states], dim=2)
    value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None
```

While the LLM generates token by token beyond its maximum trained length at the inference stage, RoPE is first applied to key_states using cos and sin computed w.r.t. kv_seq_len; the rotated key_states are then cached.

When we come to the next token, RoPE is applied to the new key_states using cos and sin computed w.r.t. kv_seq_len + 1. Since DynamicNTKScale derives its rotation base from seq_len:

```python
if seq_len > self.max_position_embeddings:
    base = self.base * (
        (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
    ) ** (self.dim / (self.dim - 2))
    inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
    self.register_buffer("inv_freq", inv_freq, persistent=False)
```

Therefore, we have an inconsistency between the cached key_states, and between key_states and query_states.
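To make the drifting base concrete, here is a small standalone sketch of the base update above (plain Python; dim, base, and max_position_embeddings are illustrative LLaMA-7B-style defaults, not from a specific config):

```python
def dynamic_ntk_base(seq_len, base=10000.0, dim=128,
                     max_position_embeddings=2048, scaling_factor=1.0):
    # Mirrors the base update quoted above; default values are
    # illustrative LLaMA-7B-style numbers.
    if seq_len <= max_position_embeddings:
        return base
    return base * (
        (scaling_factor * seq_len / max_position_embeddings)
        - (scaling_factor - 1)
    ) ** (dim / (dim - 2))

# Two consecutive decoding steps past the trained length recompute the base,
# so a key cached at step t was rotated with a smaller base than the one
# used for the key at step t + 1.
print(dynamic_ntk_base(2049))  # slightly above 10000
print(dynamic_ntk_base(2050))  # larger still
```

Every step beyond max_position_embeddings therefore rotates its key with a base no earlier cached key has seen.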

@ArthurZucker
Collaborator

Actually @gante is the king of dynamic ROPE so will let him handle this! 🤗

@gante
Member

gante commented Jul 26, 2023

Hey @NormXU 👋

I agree with the consistency issue you pointed out. However, our users' biggest concern (and ours, by extension) is empirical results, more so than correctness.

If you can share with us a few benchmarks where we see the change has positive benefits (and little to no downsides), I'll be more than happy to include it!

@NormXU
Contributor Author

NormXU commented Jul 26, 2023

@gante Of course.

How about the perplexity experiments I did in my repo link?

The way we currently compute perplexity already keeps the rotation base consistent. Therefore, to bridge this gap in the rotation base between perplexity evaluation and inference with DynamicNTKScale, I modified the code that applies the rotary embedding to keys and queries and ran simple experiments on LLaMA1-7B.

After modification, the perplexity is computed in this way:

[image: equation for the modified perplexity computation]

$K(\alpha(x))$ means, key_states is rotated by a rotation matrix whose base $\alpha$ is a function of sequence length.
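With this notation, one plausible way to write the two variants (my reading, not the exact formula from the image): at decoding step $t$, the attention logit between the query and the key at position $i$ is

$$\text{consistent: } a_{t,i} = Q_t(\alpha(t))^{\top} K_i(\alpha(t)), \qquad \text{inconsistent: } a_{t,i} = Q_t(\alpha(t))^{\top} K_i(\alpha(i))$$

that is, the consistent variant rotates every cached key with the base for the current length $t$, while the inconsistent one leaves each key's base frozen at the length at which it was cached.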

Then I compared the perplexities; the results are shown below.

[image: perplexity curves on LLaMA1-7B]

This is the perplexity on LLaMA1-7B, a model with a 2k max sequence length; values above 12.0 are cut off for concision.

- Vanilla: RoPE w/o any interpolation
- NTK: DynamicNTK with scale=1
- Consistent DynamicNTK: the rotation base is kept consistent across keys; this is how we currently calculate perplexity
- Inconsistent DynamicNTK: the rotation base across keys varies with the context length

Can this experiment convincingly demonstrate that consistent DynamicNTK achieves better perplexity on long contexts than inconsistent DynamicNTK?

Besides, could you please advise which benchmarks I should test this on? I have access to 8 x A100s, which lets me run many experiments quickly.

@gante
Member

gante commented Aug 3, 2023

@NormXU I see, that makes sense!

A final question, before I make a decision: have you measured throughput vs. the original dynamic scaling? Since we need to reapply RoPE to the cached keys, it should introduce slowdowns. The decision to include the technique in transformers depends on the extent of the slowdown :)

@ArthurZucker
Collaborator

cc @gante #25104 is a duplicate, opened #25308

@NormXU
Contributor Author

NormXU commented Aug 7, 2023

@gante The main difference between my implementation and huggingface's is as follows:

In my approach, all keys are cached before RoPE is applied, so RoPE is applied to a length-increasing list of key_states; huggingface's applies RoPE only to the single new key_state. Therefore, we just need to check whether applying RoPE to a length-increasing key_states list takes more time than applying it to a single key_state.
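To illustrate the difference, here is a minimal, hypothetical sketch (function names are mine; a single 2-D frequency pair stands in for the full head dimension) contrasting the two caching strategies:

```python
import math

def rotate(vec, theta):
    # 2-D rotation: the core of RoPE for a single frequency pair.
    x, y = vec
    c, s = math.cos(theta), math.sin(theta)
    return (x * c - y * s, x * s + y * c)

def angle(pos, base):
    # One frequency component; with dynamic NTK, base depends on the
    # current sequence length.
    return pos / base

def step_inconsistent(cache, new_key, pos, base_fn):
    # huggingface's approach: rotate only the NEW key with the base for the
    # current length, then cache the already-rotated key -- O(1) per step.
    # Earlier cached keys keep the bases of earlier, shorter lengths.
    cache.append(rotate(new_key, angle(pos, base_fn(pos + 1))))
    return cache

def step_consistent(raw_cache, new_key, base_fn):
    # my approach: cache UNROTATED keys, then re-rotate the whole cache with
    # the current base each step -- every key shares one base, but the work
    # grows as O(t) per step.
    raw_cache.append(new_key)
    base = base_fn(len(raw_cache))
    return [rotate(k, angle(i, base)) for i, k in enumerate(raw_cache)]
```

With a length-dependent base_fn, the two loops diverge for every cached key except position 0, and the consistent variant's per-step cost scales with the cache length.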

Here is the exec time of apply_rotary_pos_emb in consistent DynamicNTKScale RoPE on LLaMA-7B (32 layers)

| seq_length | exec time (ms) | seq_length | exec time (ms) |
| --- | --- | --- | --- |
| 16 | 56.32 | 528 | 206.08 |
| 32 | 44.48 | 544 | 194.88 |
| 48 | 39.68 | 560 | 197.44 |
| 64 | 30.72 | 576 | 215.36 |
| 80 | 43.84 | 592 | 207.04 |
| 96 | 25.28 | 608 | 211.52 |
| 112 | 26.24 | 624 | 220.16 |
| 128 | 24.32 | 640 | 227.84 |
| 144 | 35.2 | 656 | 245.76 |
| 160 | 26.88 | 672 | 238.4 |
| 176 | 71.68 | 688 | 248.64 |
| 192 | 65.6 | 704 | 246.72 |
| 208 | 95.04 | 720 | 270.08 |
| 432 | 161.28 | 944 | 356.48 |
| 448 | 164.16 | 960 | 367.36 |
| 464 | 172.8 | 976 | 354.56 |
| 480 | 177.92 | 992 | 365.12 |
| 496 | 178.88 | 1008 | 407.68 |

You can find the exec time eval script here:

According to the table above, the answer is: the throughput of the consistent variant is impaired compared to that of the original dynamic scaling.

@gante
Member

gante commented Aug 7, 2023

@NormXU I see -- if I understand correctly, the execution time of apply_rotary_pos_emb with the modification grows quickly with the sequence length, whereas in the original inconsistent DynamicNTK it doesn't grow (assuming caching is used).

Since DynamicNTK will be used in the large sequence length regime, this means that we would be incurring a high execution speed penalty, which is highly undesirable. Unless we find a way to work around this speed issue, I am against adding this modification -- execution speed is paramount in LLMs nowadays :)

Follow-up question: wouldn't these proposed modifications be the same as running DynamicNTK with use_cache=False? No caching = no inconsistency, correct?

@NormXU
Contributor Author

NormXU commented Aug 7, 2023

@gante Indeed, no caching = no inconsistency. In fact, I haven't found any practical downstream task where the consistent RoPE brings a significant performance boost. The only advantage that convinces me to replace it is its potential for better perplexity scores on very long contexts. So it seems unnecessary to correct this inconsistency in RoPE. Speed does matter more than correctness :)

@gante
Member

gante commented Aug 7, 2023

Thank you for discussing and iterating with us, @NormXU 💪 I'll close the issue for now.


TL;DR of the discussion: the inconsistency in DynamicNTK RoPE scaling can be fixed with use_cache=False, at the cost of speed.

@gante gante closed this as completed Aug 7, 2023
@NormXU
Contributor Author

NormXU commented Aug 10, 2023

I wrote a blog post about this problem based on our discussion. I hope it is helpful if you happen to find this issue and want to learn more details.
