
LlamaRotaryEmbedding (wrong cache value when casting model to float16/bfloat16) #25681

Closed
1 of 4 tasks
KeremTurgutlu opened this issue Aug 23, 2023 · 12 comments



KeremTurgutlu commented Aug 23, 2023

System Info

  • transformers version: 4.31.0
  • Platform: Linux-5.15.0-79-generic-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • Huggingface_hub version: 0.16.4
  • Safetensors version: 0.3.2
  • Accelerate version: 0.22.0.dev0
  • Accelerate config: - compute_environment: LOCAL_MACHINE
    - distributed_type: FSDP
    - mixed_precision: bf16
    - use_cpu: False
    - debug: False
    - num_processes: 1
    - machine_rank: 0
    - num_machines: 1
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - fsdp_config: {'fsdp_auto_wrap_policy': 'SIZE_BASED_WRAP', 'fsdp_backward_prefetch_policy': 'BACKWARD_PRE', 'fsdp_forward_prefetch': False, 'fsdp_min_num_params': 100000000, 'fsdp_offload_params': False, 'fsdp_sharding_strategy': 2, 'fsdp_state_dict_type': 'FULL_STATE_DICT', 'fsdp_sync_module_states': True, 'fsdp_use_orig_params': True}
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env: []
    - dynamo_config: {'dynamo_backend': 'INDUCTOR'}
  • PyTorch version (GPU?): 2.1.0.dev20230809+cu121 (True)
  • Tensorflow version (GPU?): 2.13.0 (False)
  • Flax version (CPU?/GPU?/TPU?): 0.7.2 (cpu)
  • Jax version: 0.4.13
  • JaxLib version: 0.4.13
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@ArthurZucker would be the best person to discuss this.

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

TL;DR: If a model with a LlamaRotaryEmbedding layer is cast to bfloat16/float16 after initialization, and the forward pass then sees a sequence longer than self.max_position_embeddings, the cached cos and sin buffer values will most likely differ from the ones the model was trained with, giving unexpected results.

I came across this very subtle error in the following way, and I am not sure what the best solution would be.

I finetuned the Llama-2 model using accelerate FSDP with a bfloat16 mixed-precision policy. I used a slightly different config from the original one, in which max_position_embeddings=2048 was set. FSDP + accelerate uses autocast under the hood, which keeps the ops inside LlamaRotaryEmbedding in full precision, which is great.

The problem happens when we feed a sequence longer than that and also cast the model to a lower precision instead of using autocast. I loaded the trained model using

load_checkpoint_and_dispatch(
    custom_config_model,
    str(fn),
    device_map={
        "model": torch.cuda.current_device(),
        "lm_head": torch.cuda.current_device(),
    },
    dtype=torch.bfloat16,
)

My custom config looked like this; note "max_position_embeddings": 2048:

 LlamaConfig {
   "block_size": 2960,
   "bos_token_id": 1,
   "eos_token_id": 2,
   "hidden_act": "silu",
   "hidden_size": 4096,
   "initializer_range": 0.02,
   "intermediate_size": 11008,
   "max_position_embeddings": 2048,
   "model_type": "llama",
   "num_attention_heads": 32,
   "num_hidden_layers": 32,
   "num_key_value_heads": 32,
   "packed_inputs": false,
   "pad_token_id": 0,
   "prefix_lm": false,
   "pretraining_tp": 1,
   "rms_norm_eps": 1e-06,
   "rope_scaling": null,
   "tie_word_embeddings": false,
   "transformers_version": "4.31.0",
   "use_cache": true,
   "vocab_size": 64008
 }

During inference, when testing the trained model, my training/validation perplexity increased from ~2.5 to ~20.0. It took me 2 days to figure out that the exact issue was model casting combined with sequence lengths > max_position_embeddings.

Potential Fixes:

  • Add a warning about this, and suggest using autocast during inference.
  • Add a warning about this, and suggest initializing the model with a very high self.max_position_embeddings value so that the cos-sin caches won't be re-initialized with wrong values due to lower precision. Even self.max_position_embeddings=80k should be fine given the relatively small size of the buffer compared to the total model size.
  • Modify LlamaRotaryEmbedding so that float32 is always used in the ops and the result is cast to x.dtype only at the very end. This is a bit difficult because if a model is cast to bfloat16/float16, it will still produce different cache values even if it is cast back to float32. I don't know if there is a way to disable model casting for certain layers - but I guess that would be autocast 😄

This modified version will produce closer but still wrong cache values:

import torch

# Modified version: the cos/sin cache is always (re)built in the default dtype
# (normally float32) instead of x.dtype.
class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(dtype))
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.get_default_dtype())

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

I will personally keep self.max_position_embeddings as high as my maximum intended sequence length and also use autocast where possible.
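As an illustration of the autocast route (a minimal sketch, not part of the original report: it assumes model and input_ids already exist on a CUDA device and that the weights are kept in float32 rather than cast):

import torch

# The model itself is *not* cast to bfloat16, so the registered cos/sin buffers keep
# their float32 values even if the cache has to be rebuilt for a longer sequence;
# autocast only lowers the precision of eligible ops inside this context.
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    outputs = model(input_ids)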

Reproduction

# from https://github.com/huggingface/transformers/blob/3d1edb6c5d36bf6426e72223f534266ff29c45c4/src/transformers/models/llama/modeling_llama.py#L92C1-L125C10
import torch

class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
# expected cache values
rotary_emb = LlamaRotaryEmbedding(2048)
rotary_emb.cos_cached[:,:,:1024]

tensor([[[[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 0.5403,  0.5478,  0.5552,  ...,  1.0000,  1.0000,  1.0000],
          [-0.4161, -0.3998, -0.3835,  ...,  1.0000,  1.0000,  1.0000],
          ...,
          [-0.9998,  0.9651, -0.8084,  ...,  0.9945,  0.9946,  0.9947],
          [-0.5550,  0.3096,  0.0407,  ...,  0.9945,  0.9946,  0.9947],
          [ 0.4001, -0.6259,  0.8536,  ...,  0.9945,  0.9946,  0.9947]]]])


# Wrong cache values when cast to bfloat16
rotary_emb.to(torch.bfloat16);
# create an input > 2048
x = torch.randn(2, 32, 4096, 128)
_ = rotary_emb(x, seq_len=4096)
rotary_emb.cos_cached[:,:,:1024]

tensor([[[[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 0.5391,  0.5469,  0.5547,  ...,  1.0000,  1.0000,  1.0000],
          [-0.4160, -0.4023, -0.3809,  ...,  1.0000,  1.0000,  1.0000],
          ...,
          [-0.5273,  0.9180,  0.5625,  ...,  0.9961,  0.9961,  0.9961],
          [ 0.9883, -0.3008,  0.2578,  ...,  0.9961,  0.9961,  0.9961],
          [ 0.9883, -0.3008,  0.2578,  ...,  0.9961,  0.9961,  0.9961]]]])

# try with float16 this time
rotary_emb = LlamaRotaryEmbedding(2048)
# cast model to float16
rotary_emb.to(torch.float16);
rotary_emb.cos_cached[:,:,:1024]
# create an input > 2048
x = torch.randn(2, 32, 4096, 128)
_ = rotary_emb(x, seq_len=4096)
rotary_emb.cos_cached[:,:,:1024]

tensor([[[[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 0.5405,  0.5479,  0.5552,  ...,  1.0000,  1.0000,  1.0000],
          [-0.4163, -0.4001, -0.3831,  ...,  1.0000,  1.0000,  1.0000],
          ...,
          [-1.0000,  0.9185, -0.9453,  ...,  0.9946,  0.9946,  0.9946],
          [-0.5552,  0.1628, -0.2366,  ...,  0.9946,  0.9946,  0.9946],
          [ 0.4001, -0.7422,  0.6899,  ...,  0.9946,  0.9946,  0.9946]]]])
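A rough way to quantify the drift (a sketch, not part of the original report): rebuild a float32 reference module, extend its cache to the same length, and compare it with the float16-derived cache.

# Fresh module in float32; its cache extension for seq_len=4096 is computed in float32.
ref = LlamaRotaryEmbedding(2048)
_ = ref(torch.randn(2, 32, 4096, 128), seq_len=4096)
# Maximum deviation of the float16-derived cache from the float32 reference; for the
# later positions this is far larger than ordinary float16 rounding error.
drift = (rotary_emb.cos_cached.float() - ref.cos_cached).abs().max()
print(drift)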

cc: @ArthurZucker

Expected behavior

Same cache values for rotary embeddings.

@ArthurZucker
Collaborator

cc @gante, you recently worked on extending the cache for RotaryEmbeddings! This might affect the other (dynamic) ones as well.

@gante
Member

gante commented Aug 23, 2023

Hey @KeremTurgutlu 👋

It is known that, when casting to 16 bits for inference purposes, you should use the exact same casting strategy that was used with the model at train time. We try to store that in the torch_dtype config field whenever we have access to that information (e.g. here).
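For illustration (a sketch of that mechanism, using an example checkpoint id; not something prescribed in this thread), torch_dtype="auto" picks the stored dtype up at load time instead of casting the model afterwards:

from transformers import AutoModelForCausalLM
import torch

# "auto" reads the torch_dtype saved in the checkpoint's config, so the weights are
# loaded directly in the dtype used when the model was saved rather than being cast
# after initialization; a fixed dtype such as torch.bfloat16 can also be passed.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype="auto")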

In this particular case, the issue is compounded by the fact that the RoPE layer has buffers, which mask the issue in some cases.

@ArthurZucker should we emit a warning when the model gets converted to a 16-bit format different from the torch_dtype field? 🤔

@jph00

jph00 commented Aug 23, 2023

This is the same bug that's discussed here

EleutherAI/gpt-neox#1003

The fix is to calculate the sin and cos values in __init__ and ensure they're not stored in buffers. Or don't cast the model, but instead use autocast, which avoids this issue. Note that with DeepSpeed it will always cast, so you need the fix.
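A minimal, illustrative sketch of that "no buffers" idea (the class name and details here are assumptions, not the gpt-neox patch): compute cos/sin in float32 in __init__ and keep them as plain tensor attributes, so model.to(torch.bfloat16) never touches them.

import torch

class RotaryEmbeddingNoBuffers(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_position_embeddings, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Plain attributes, not registered buffers: Module.to()/.half() only cast
        # parameters and buffers, so these stay float32 however the model is cast.
        self.cos_cached = emb.cos()[None, None, :, :]
        self.sin_cached = emb.sin()[None, None, :, :]

    def forward(self, x, seq_len=None):
        # Cast to the activation dtype (and move to its device) only at the point of use.
        return (
            self.cos_cached[:, :, :seq_len].to(device=x.device, dtype=x.dtype),
            self.sin_cached[:, :, :seq_len].to(device=x.device, dtype=x.dtype),
        )

One trade-off: tensors that are not buffers are also not moved by model.to(device), hence the explicit .to(device=x.device) in forward; extending past max_position_embeddings would likewise need to recompute in float32.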

@ArthurZucker
Collaborator

There's also #24262; if we can have a code fix, that would be better than adding another warning.

@kikutakou

@KeremTurgutlu
Is this just an inaccuracy problem caused by float16 precision?
The last value shown in your snippet may be calculated in the following way:

>>> import torch
>>> e = torch.tensor(0.1032)
>>> e.cos()
tensor(0.9947)
>>> e.cos().to(torch.bfloat16)
tensor(0.9961, dtype=torch.bfloat16)
>>> e.cos().to(torch.float16)
tensor(0.9946, dtype=torch.float16)

inv_freq is always float32 since it's converted using .float(). Hence, the variable t in _set_cos_sin_cache is also always float32.

@jph00

jph00 commented Aug 25, 2023

inv_freq is always float32 since it's converted using .float(). Hence, the variable t in _set_cos_sin_cache is also always float32.

No, it's stored as a buffer, so it gets cast in some situations. See the full description of the bug and code to fix it here: EleutherAI/gpt-neox#1003
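A tiny self-contained illustration of that point (the class and attribute names are made up for the example): Module.to(dtype) casts registered floating-point buffers along with the parameters, while plain tensor attributes are left alone.

import torch

class Demo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("as_buffer", torch.arange(4, dtype=torch.float32))
        self.as_attribute = torch.arange(4, dtype=torch.float32)

m = Demo().to(torch.bfloat16)
print(m.as_buffer.dtype)     # torch.bfloat16 -- the buffer follows the module cast
print(m.as_attribute.dtype)  # torch.float32  -- the plain attribute is untouched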

@jph00

jph00 commented Sep 23, 2023

@ArthurZucker @gante I don't think this issue should be closed AFAICT.

@ArthurZucker
Collaborator

Yep, it's on my todo list for when I do a deep dive on all the Llama-related issues.

@ArthurZucker
Collaborator

Sorry I'll get to this soon 🤗

@ArthurZucker
Collaborator

cc @fxmarty, this is related to your #26836 and to why we have to be extra careful with RoPE and float16!

@ArthurZucker
Collaborator

I don't know what took me so long, but this is similar to #25306 and can be fixed by something close to #27033 (which slows things down a lot); this should be fixed for all RoPE implementations that copy from Llama / use dynamic scaling.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
