
Llama: device/type-invariant RoPE sin/cos computation, eager attention matches original implementation #28837

Closed
wants to merge 12 commits

Conversation


@gante (Member) commented on Feb 2, 2024

What does this PR do?

This PR fixes the following problems, all related to RoPE:

  1. Casting a model with .from_pretrained(..., torch_dtype=...) or .to(dtype=...) would produce different sin/cos tensors at recomputation time. The underlying cause was inv_freq being a buffer, which means it was subject to buffer manipulation (like a .to() operation in the wrapping module). Note that the original repo assumed it was always a torch.float32 tensor. In some models, there was a visible performance degradation when doing inference with seq_len > max_position_embeddings (see here);
  2. The inv_freq tensor was being loaded from the state dict, due to a previous version of the code where it was a persistent buffer;
  3. ⚠️ Perhaps more importantly, the sin/cos tensors are now always computed on CPU. As pointed out in this comment, there are subtle numerical differences that depend on the initialization device, which quickly escalate into further downstream issues. This particular change results in the following:
    a. Smaller modeling performance differences across devices, as CPUs are ubiquitous (as opposed to accelerators, which may change);
    b. Prevention of loss spikes at train time, possibly due to the more accurate sin/cos computation (see this comment and the whole issue);
    c. Slightly slower throughput when recomputing the sin/cos tensors, i.e. when going beyond self.max_seq_len_cached.

See additional data and experiments below for the impact of this PR. Most of the diff in this PR is tests, to ensure we don't regress 🤗
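
For intuition, here is a minimal sketch of the approach (illustrative only, not the exact transformers code; the class and helper names are made up): inv_freq lives as a non-persistent float32 buffer, and the sin/cos cache is always rebuilt in float32 on CPU, so a later .to() cast on the wrapping module cannot change the values.

import torch
import torch.nn as nn

class InvariantRotaryEmbedding(nn.Module):
    # Illustrative rotary embedding whose sin/cos cache is device/dtype-invariant.
    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        # float32 inverse frequencies; persistent=False keeps them out of the state dict
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._set_cache(max_position_embeddings)

    def _set_cache(self, seq_len):
        self.max_seq_len_cached = seq_len
        # always recompute in float32 on CPU, even if the module was cast with .to()
        inv_freq = self.inv_freq.to(device="cpu", dtype=torch.float32)
        t = torch.arange(seq_len, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # plain attributes (not buffers), so .to() on the parent module leaves them alone
        self.cos_cached = emb.cos()
        self.sin_cached = emb.sin()

    def forward(self, x, seq_len):
        if seq_len > self.max_seq_len_cached:
            self._set_cache(seq_len)
        # hand back float32 sin/cos on the right device; the rotation is applied in
        # float32 and only cast back to the model dtype afterwards
        return self.cos_cached[:seq_len].to(x.device), self.sin_cached[:seq_len].to(x.device)

The actual diff also touches the attention forward pass (so that eager attention matches the original implementation), but cache handling along these lines is the core of the device/type invariance.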

Suggested review order:

  1. Llama modelling changes
  2. Llama test changes
  3. GPTNeoX changes (fixes dtype cast as intended, see experiments below :) )
  4. Other models (direct #Copied from changes)
  5. Other tests (copy/paste)
    (Other RoPE models will follow in a future PR)

Related GH issues

Fixes #28685
Fixes #25681
Fixes #28596
Fixes #27179
Should fix/help microsoft/DeepSpeed#4932

Additional data and experiments

Perplexity, memory, and latency results before/after this PR

NOTE: these results use the .to() casting method. The torch_dtype path shows no differences, as inv_freq is not cast there.
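
For reference, the two loading paths being compared look roughly like this (checkpoint name is illustrative); only the .to() path touches the fp32 inv_freq buffer:

import torch
from transformers import AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint

# torch_dtype: weights are loaded in fp16, the fp32 inv_freq buffer is left untouched
model_a = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

# .to(): also casts registered buffers, so on main inv_freq was downcast to fp16
model_b = AutoModelForCausalLM.from_pretrained(model_id).to(dtype=torch.float16)

print(model_a.model.layers[0].self_attn.rotary_emb.inv_freq.dtype)  # torch.float32
print(model_b.model.layers[0].self_attn.rotary_emb.inv_freq.dtype)  # torch.float16 on main, before this PR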

Llama 2 -- negligible perplexity differences

Dtype: bfloat16
(ignore the vram -- the latest commit has the same GPU memory footprint as main)

[plot: perplexity / VRAM]
[plot: latency]

Dtype: float16
(ignore the vram -- the latest commit has the same GPU memory footprint as main)
[plot: perplexity / VRAM]
[plot: latency]

TinyLlama -- visible perplexity improvement

Dtype: bfloat16
(ignore the vram -- the latest commit has the same GPU memory footprint as main)
[plot: perplexity / VRAM]
[plot: latency]

Dtype: float16
(ignore the vram -- the latest commit has the same GPU memory footprint as main)
[plot: perplexity / VRAM]
[plot: latency]

How sensitive is the sin/cos computation to device placement?

Consider the following script:

import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

TEST_DTYPE = torch.bfloat16

for dim in (64, 256, 1024):
    for max_position_embeddings in (1024, 2048, 4096):
        for base in (10000, 100000, 1000000):
            rope_gpu = LlamaRotaryEmbedding(dim=dim, max_position_embeddings=max_position_embeddings, base=base, device='cuda')
            rope_cpu = LlamaRotaryEmbedding(dim=dim, max_position_embeddings=max_position_embeddings, base=base, device='cpu')

            rope_cpu = rope_cpu.to(device='cuda', dtype=TEST_DTYPE)
            rope_gpu = rope_gpu.to(device='cuda', dtype=TEST_DTYPE)
            max_sin_diff = (rope_gpu.sin_cached - rope_cpu.sin_cached).abs().max()
            max_cos_diff = (rope_gpu.cos_cached - rope_cpu.cos_cached).abs().max()
            max_diff = max(max_sin_diff, max_cos_diff)
            if max_diff > 0.0:
                print(f"dim={dim}, max_position_embeddings={max_position_embeddings}, base={base}, max_diff={max_diff:.2e}")

On main, before this PR, we can see differences as large as ~1e-3 regardless of TEST_DTYPE (even in torch.float64!). After this PR, the difference is 0.0.

Original Llama codebase vs our codebase after this PR?

Key takeaways (about the original codebase):
👉 sin/cos are created on the available device (and not on CPU)
👉 sin/cos are not only kept in FP32, but also applied in FP32!
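
To make the second point concrete, here is a minimal illustration (not the exact modeling code) of applying the rotation in FP32: upcast q/k and sin/cos, rotate, and only cast back to the model dtype at the end. rotate_half follows the half-splitting layout used by the HF Llama checkpoints.

import torch

def rotate_half(x):
    # split the last dimension into two halves and swap them with a sign flip
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_fp32(q, k, cos, sin):
    # upcast everything to float32, apply the rotation, then cast back
    q32, k32 = q.float(), k.float()
    cos32, sin32 = cos.float(), sin.float()
    q_rot = q32 * cos32 + rotate_half(q32) * sin32
    k_rot = k32 * cos32 + rotate_half(k32) * sin32
    return q_rot.to(q.dtype), k_rot.to(k.dtype)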

Consider the following script, which compares the Hugging Face implementation against Meta's original repo:

# run as `torchrun this_script.py`
from llama import Llama
from transformers import AutoModelForCausalLM
import torch

# Loaded in FP16 on GPU
original_llama = Llama.build(
  ckpt_dir="/home/joao/meta_llama/Llama-2-7b/",
  tokenizer_path="/home/joao/meta_llama/Llama-2-7b/tokenizer.model",
  max_seq_len=2048,  # internally, 2048*2 is considered to compute sin/cos
  max_batch_size=1,
)
og_logits = original_llama.model(tokens=torch.tensor([list(range(1000))]), start_pos=0)
og_sin = original_llama.model.freqs_cis.imag
og_cos = original_llama.model.freqs_cis.real
del original_llama
torch.cuda.empty_cache()

# Loaded in FP16 on GPU
transformers_llama = AutoModelForCausalLM.from_pretrained(
  "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
)
logits = transformers_llama(torch.tensor([list(range(1000))])).logits.float()
sin = transformers_llama.model.layers[0].self_attn.rotary_emb.sin_cached
cos = transformers_llama.model.layers[0].self_attn.rotary_emb.cos_cached


logits_diff = (og_logits.cpu() - logits.cpu()).abs().max()
print(f"Max logits diff: {logits_diff.item()}")

# .cat -> our sin/cos have a period of 4pi (2 cycles), the original ones have a period of 2pi (1 cycle)
# .float() -> on main, we cast sin/cos to the model dtype
sin_diff = (torch.cat([og_sin, og_sin], dim=1).cpu() - sin.float().cpu()).abs().max()
cos_diff = (torch.cat([og_cos, og_cos], dim=1).cpu() - cos.float().cpu()).abs().max()
print(f"Max sin diff: {sin_diff.item()}")
print(f"Max cos diff: {cos_diff.item()}")

On main + GPU + FP16, before this PR, we can see sin/cos and logits differences as large as 2e-4 and 6e-2 (respectively). After this PR, the difference is 0.0.


@@ -505,6 +506,120 @@ def test_eager_matches_sdpa_generate(self):
res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
self.assertTrue(torch.allclose(res_eager, res_sdpa))

@require_torch_gpu
def test_rope_cast_strategy_invariant(self):

gante (Member, Author): this test fails on main because inv_freq was being cast with .to()
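
A minimal sketch of the kind of check this test performs (not the actual test code; checkpoint name is illustrative): load the same model through both casting strategies and require identical sin/cos caches.

import torch
from transformers import AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint

# strategy 1: cast at load time via torch_dtype
model_direct = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
# strategy 2: load in fp32, then cast with .to() (on main this downcasts inv_freq)
model_cast = AutoModelForCausalLM.from_pretrained(model_id).to("cuda", dtype=torch.float16)

rope_direct = model_direct.model.layers[0].self_attn.rotary_emb
rope_cast = model_cast.model.layers[0].self_attn.rotary_emb
# the real test also generates past the initial cache, so sin/cos get recomputed with the casted inv_freq
torch.testing.assert_close(rope_direct.sin_cached.float(), rope_cast.sin_cached.float(), rtol=0.0, atol=0.0)
torch.testing.assert_close(rope_direct.cos_cached.float(), rope_cast.cos_cached.float(), rtol=0.0, atol=0.0)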

@require_torch_gpu
def test_rope_initialization_invariant(self):

gante (Member, Author): this test fails on main, as initialization is device-dependent there

@gante gante changed the title Llama: device and type-invariant RoPE sin/cos computation Llama: device/type-invariant RoPE sin/cos computation and FP32 application Feb 5, 2024
@gante gante changed the title Llama: device/type-invariant RoPE sin/cos computation and FP32 application Llama: device/type-invariant RoPE sin/cos computation, eager attention matches original implementation Feb 7, 2024
