
"Dynamic" Issue in LlamaDynamicNTKScalingRotaryEmbedding - Long context inference will impact short context inference. #25306

Closed
LetianLee opened this issue Aug 4, 2023 · 7 comments

Comments

@LetianLee

System Info

  • transformers version: 4.32.0.dev0
  • Platform: Linux-5.15.109+-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.16.4
  • Safetensors version: 0.3.1
  • Accelerate version: 0.22.0.dev0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.1+cu118 (True)
  • Tensorflow version (GPU?): 2.12.0 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.7.0 (gpu)
  • Jax version: 0.4.13
  • JaxLib version: 0.4.13
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help?

@sgugger

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Please see my colab code:
https://colab.research.google.com/drive/1SnQQxW7WMHgSOvAwF_HIlIDrAuXZ4IKp?usp=sharing

I ran the same prompt twice, with a long-context prompt in between. That intermediate long-context inference caused the model to give different answers to the same question before and after it.
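
A minimal reproduction of this pattern might look like the following (the model name, prompts, and generation settings are illustrative placeholders, not the exact contents of the Colab):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any Llama checkpoint works, as long as dynamic NTK scaling is enabled.
model_id = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    rope_scaling={"type": "dynamic", "factor": 2.0},  # selects LlamaDynamicNTKScalingRotaryEmbedding
    torch_dtype=torch.float16,
    device_map="auto",
)

def generate(prompt: str) -> str:
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=64, do_sample=False)  # greedy, so decoding is deterministic
    return tokenizer.decode(output[0], skip_special_tokens=True)

short_prompt = "What is the capital of France?"  # placeholder short question
long_prompt = "..."                              # placeholder: any prompt longer than max_position_embeddings

answer_before = generate(short_prompt)
generate(long_prompt)                 # long-context inference in between
answer_after = generate(short_prompt)

# With the behaviour reported here, the two answers differ even though the prompt is identical.
print(answer_before == answer_after)
```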

Expected behavior

Since the tested prompts are within the maximum input token capacity the model can handle, the point of "Dynamic" scaling is that the rotary embeddings for those inputs should remain the same before and after the long-context inference, and consequently the outputs should also be identical.

I reviewed the code of the LlamaDynamicNTKScalingRotaryEmbedding class, and I think that, due to caching, when the model runs inference on a long context the cached cos_cached and sin_cached values are updated to adapt to the longer sequence. Because the cache is never reset, subsequent inference on shorter contexts still uses the rescaled values, which causes the issue. A sketch of this caching path is below.
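
The following is a stripped-down sketch of that caching path (a simplified illustration of the behaviour, not the actual transformers source; attribute names such as cos_cached and sin_cached mirror the class above, but the details are reduced):

```python
import torch

class DynamicNTKRotarySketch(torch.nn.Module):
    """Simplified illustration of the grow-only cos/sin caching described above."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000.0, scaling_factor=1.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.scaling_factor = scaling_factor
        self._build_cache(max_position_embeddings)

    def _build_cache(self, seq_len):
        base = self.base
        if seq_len > self.max_position_embeddings:
            # "Dynamic" NTK: rescale the rotary base when the sequence exceeds the original window.
            base = base * (
                self.scaling_factor * seq_len / self.max_position_embeddings - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        t = torch.arange(seq_len, dtype=inv_freq.dtype)
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()
        self.sin_cached = emb.sin()
        self.max_seq_len_cached = seq_len

    def forward(self, seq_len):
        # The cache is only rebuilt when the sequence grows, never when it shrinks,
        # so one long prompt permanently changes the embeddings served to short prompts.
        if seq_len > self.max_seq_len_cached:
            self._build_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]


rope = DynamicNTKRotarySketch(dim=128, max_position_embeddings=2048)
cos_before, _ = rope(100)   # short prompt: original, unscaled embeddings
rope(4096)                  # long prompt: cache rebuilt with a rescaled base
cos_after, _ = rope(100)    # same short prompt again
print(torch.allclose(cos_before, cos_after))  # False -- the short prompt now gets different embeddings
```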

@amyeroberts
Collaborator

cc @gante @ArthurZucker

@ArthurZucker
Collaborator

Hey! Thanks for reporting, this is a duplicate of #25104. Will link it in the PR as well

@github-actions

github-actions bot commented Sep 3, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@i4never

i4never commented Oct 16, 2023

> Hey! Thanks for reporting, this is a duplicate of #25104. Will link it in the PR as well

No, they're not the same. I understand that #25104 is about the trade-off between using the KV cache and rotary embedding inconsistency. But when you freeze everything during generation, including random seeds, the same input should give the same output sequence.

The dynamic NTK rotary embedding only recalculates when the input sequence is longer than the cached length. What if the longest sequence is processed first? The cached embeddings will never change again. PR #25308 is a correct fix without extra computation; I think it should be merged.
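
With the simplified sketch from the issue description above, the "longest sequence first" case looks like this (illustrative only):

```python
rope = DynamicNTKRotarySketch(dim=128, max_position_embeddings=2048)
rope(4096)                       # the longest sequence arrives first; the cache is built with a rescaled base
cos_short, _ = rope(100)         # every later short prompt ...
cos_short_again, _ = rope(100)   # ... keeps reading from the rescaled cache
# Because `seq_len > max_seq_len_cached` never becomes true again, the cache is never rebuilt,
# and short prompts always see the long-context embeddings instead of the original ones.
```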
@gante

@ArthurZucker
Collaborator

I see, makes sense to me. @gante, could you have a look? 🤗

@gante
Member

gante commented Oct 23, 2023

@i4never I agree, it is a limitation of the technique when implemented as the authors suggest. #25308 is not the correct fix either -- we should only resize the sin and cos caches down to the original size, as smaller values will likely have a negative impact.

Would you like to open a PR to fix it? :)
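
One possible reading of that suggestion, built on the simplified class from the issue description (a sketch only, not the patch actually adopted in transformers): shrink the cache back to the original max_position_embeddings when a shorter sequence arrives, but never below it.

```python
class DynamicNTKRotaryWithReset(DynamicNTKRotarySketch):
    """Variant of the earlier sketch that shrinks the cache back to the original window."""

    def forward(self, seq_len):
        if seq_len > self.max_seq_len_cached:
            # Grow the cache for sequences longer than anything cached so far.
            self._build_cache(seq_len)
        elif (
            self.max_seq_len_cached > self.max_position_embeddings
            and seq_len <= self.max_position_embeddings
        ):
            # An earlier long prompt enlarged the cache, but this sequence fits in the
            # original window: rebuild the original, unscaled cache. Resize only down to
            # max_position_embeddings, never smaller.
            self._build_cache(self.max_position_embeddings)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]


rope = DynamicNTKRotaryWithReset(dim=128, max_position_embeddings=2048)
cos_before, _ = rope(100)
rope(4096)                   # a long prompt rescales the cache ...
cos_after, _ = rope(100)     # ... but the next short prompt restores the original cache
print(torch.allclose(cos_before, cos_after))  # True with this variant
```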

@i4never

i4never commented Oct 24, 2023

> @i4never I agree, it is a limitation of the technique when implemented as the authors suggest. #25308 is not the correct fix either -- we should only resize the sin and cos caches down to the original size, as smaller values will likely have a negative impact.
>
> Would you like to open a PR to fix it? :)

#27033
