
Fixing Llama2 numerical errors #456

Merged

Conversation

obalcells
Contributor

Description

The TransformerLens implementation of Llama2-7B-chat doesn't match the Hugging Face one. This probably also affects the other Llama2 models, but I haven't tested them.

In the following code, the torch.allclose assert is triggered with an error tolerance of atol=1.0, but not with atol=1.5.

import torch
from transformer_lens import HookedTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

def check_similarity_with_hf_model(
    tl_model: HookedTransformer,
    hf_model: AutoModelForCausalLM,
    atol: float,
    prompt="Hello world!",
):
    # Tokenize once and feed the exact same tokens to both models (no extra BOS prepended).
    tokens = tl_model.tokenizer.encode(prompt, return_tensors="pt").cuda()
    tl_logits = tl_model(tokens, prepend_bos=False)
    hf_logits = hf_model(tokens).logits
    assert torch.allclose(tl_logits, hf_logits, atol=atol)


MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
hf_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
).cuda()
# Weight-processing options (fold_ln, centering, etc.) are disabled so the computation
# matches the Hugging Face implementation as closely as possible.
tl_model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    hf_model=AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float32,
    ),
    tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME),
    device="cuda",
    n_devices=1,
    move_to_device=True,
    fold_ln=False,
    fold_value_biases=False,
    center_writing_weights=False,
    center_unembed=False,
    torch_dtype=torch.float32,
)

with torch.no_grad():
    check_similarity_with_hf_model(tl_model, hf_model, atol=1)

This issue has already been reported (#385) and I have been able to reproduce it using two RTX A6000 (48GB) GPUs (one for each model).

I believe the issue lies in the value of the constant (eps) that the RMSNorm module adds to avoid division by zero (see implementation here).

The TransformerLens Llama2-7B-chat model uses the value 1e-6, but the Hugging Face Llama2-7B-chat model uses 1e-5 by default.

Confusingly, both the LlamaRMSNorm class and the default Llama config use 1e-6 by default. However, this gets overridden by the model's own config on Hugging Face. The Hugging Face configs for the other Llama2 models (chat and base, as well as the larger ones) also use the correct value of 1e-5. This is probably how the error was introduced.
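
As a quick sanity check, here's a minimal sketch of how the two values can be compared directly (it assumes the rms_norm_eps field of the Hugging Face LlamaConfig and the eps field of the TransformerLens config, and reuses the tl_model loaded above):

from transformers import AutoConfig

hf_config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
print(hf_config.rms_norm_eps)  # 1e-5 in the Hugging Face config
print(tl_model.cfg.eps)        # 1e-6 before this fix, 1e-5 after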

After the Fix

I've calculated the mean absolute error (MAE) and the mean relative error between the residual stream vectors of the two models. The plots with eps=1e-6 are from before the fix, and the ones with eps=1e-5 are from after:

[Plots: mean absolute error and relative error between the residual streams, before the fix (eps=1e-6) and after the fix (eps=1e-5).]

The relative error was calculated as abs((hf_resid_post - tl_resid_post) / (abs(hf_resid_post) + 1e-5)).
Note that the y-axis scale differs in the last two plots.
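
For reference, here's a minimal sketch of how such a per-layer comparison could be computed (this is not the notebook code; it assumes the tl_model and hf_model loaded above, TransformerLens's blocks.{i}.hook_resid_post activations, and output_hidden_states=True on the Hugging Face side):

import torch

tokens = tl_model.tokenizer.encode("Hello world!", return_tensors="pt").cuda()

with torch.no_grad():
    _, cache = tl_model.run_with_cache(tokens, prepend_bos=False)
    # hidden_states holds the embedding output followed by one tensor per layer.
    hf_hidden = hf_model(tokens, output_hidden_states=True).hidden_states

# The last Hugging Face hidden state already has the final RMSNorm applied, so skip the last layer.
for layer in range(tl_model.cfg.n_layers - 1):
    tl_resid_post = cache[f"blocks.{layer}.hook_resid_post"]
    hf_resid_post = hf_hidden[layer + 1]
    mae = (hf_resid_post - tl_resid_post).abs().mean()
    rel = ((hf_resid_post - tl_resid_post) / (hf_resid_post.abs() + 1e-5)).abs().mean()
    print(f"layer {layer}: MAE={mae:.2e}, mean relative error={rel:.2e}")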

After the fix, we're able to run

with torch.no_grad():
    check_similarity_with_hf_model(tl_model, hf_model, atol=1e-4)

without triggering the assert. The assert is triggered at atol=1e-5.

Type of change

  • Bug fix (non-breaking change which fixes an issue)

The only change is setting eps to 1e-5 instead of 1e-6 in the default config for all Llama2 models.
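
For intuition on why such a tiny constant can shift the logits by around 1, here's a minimal standalone sketch of RMSNorm (no learned scale; a hypothetical helper, not the TransformerLens implementation). The per-call difference between eps=1e-5 and eps=1e-6 is tiny, but Llama2-7B applies the norm twice in each of its 32 layers, so the discrepancy presumably compounds through the residual stream and the unembedding:

import torch

def rms_norm(x, eps):
    # Minimal RMSNorm without the learned scale, for illustration only.
    return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)

torch.manual_seed(0)
x = torch.randn(4096)  # a residual-stream-sized vector
print((rms_norm(x, 1e-5) - rms_norm(x, 1e-6)).abs().max())  # small per call, but it compounds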

Here's the notebook where I reproduce the error and verify the fix.

I'd also like to thank @andyrdt for helping me find the bug and create my first PR on GitHub.

@ArthurConmy
Collaborator

Thank god someone found this -- fantastic!

I'm pretty sure we want to keep Llama-1 models with 1e-6 though, from reading this back from March. Can you implement this?

@ArthurConmy ArthurConmy self-requested a review November 29, 2023 21:10
@obalcells
Contributor Author

Done! The change now only applies to the new Llama2 models.
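
(For context, a minimal sketch of the kind of conditional this implies; the function name and check here are hypothetical, and the real logic lives in TransformerLens's pretrained-config loading code:)

def llama_rms_norm_eps(official_model_name: str) -> float:
    # Llama2 checkpoints ship with rms_norm_eps=1e-5 in their Hugging Face configs,
    # while the original Llama-1 checkpoints use 1e-6.
    if "Llama-2" in official_model_name:
        return 1e-5
    return 1e-6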

Btw, there's an acceptance test here that's made for catching exactly these kinds of bugs, but it doesn't seem to be turned on (neither for Llama2 nor for any other model)?

@neelnanda-io
Collaborator

How did a tiny detail cause that big an atol difference?! Thanks a lot for catching this :)

@ArthurConmy ArthurConmy merged commit 3d0db85 into TransformerLensOrg:main Nov 29, 2023
8 checks passed