Description
The TransformerLens implementation of Llama2-7B-chat doesn't match the Huggingface one. This probably also occurs with the other Llama2 models, but I haven't tested them.

This issue has already been reported (#385), and I have been able to reproduce it using two RTX A6000 (48GB) GPUs (one for each model). In the code below, the `torch.allclose` assert is triggered if we use an error tolerance of `atol=1.0`; it doesn't get triggered if we use `atol=1.5` instead.
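Here is a minimal sketch of that check (the prompt and device placement are assumptions; `from_pretrained_no_processing` is used so that TransformerLens's weight processing doesn't shift the logits):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

MODEL = "meta-llama/Llama-2-7b-chat-hf"

tokenizer = AutoTokenizer.from_pretrained(MODEL)
hf_model = AutoModelForCausalLM.from_pretrained(MODEL).to("cuda:0")
tl_model = HookedTransformer.from_pretrained_no_processing(
    MODEL, hf_model=hf_model, tokenizer=tokenizer, device="cuda:1"
)

tokens = tokenizer("The capital of Germany is", return_tensors="pt")["input_ids"]

with torch.no_grad():
    hf_logits = hf_model(tokens.to("cuda:0")).logits.cpu()
    tl_logits = tl_model(tokens.to("cuda:1")).cpu()

# With eps=1e-6 this fails at atol=1.0 but passes at atol=1.5
assert torch.allclose(hf_logits, tl_logits, atol=1.0)
```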
I believe the issue lies in the value of the constant in the layer norm module (called `eps`) that's added to avoid division by zero (see implementation here). The TransformerLens Llama2-7B-chat model uses the value `1e-6`, but the Huggingface Llama2-7B-chat model uses `1e-5`.

Confusingly, both the `LlamaRMSNorm` class and the default Llama config use `1e-6` by default. However, this gets overridden by the model's own config in Huggingface, and the Huggingface configs for the other Llama2 models (chat and base, as well as the larger sizes) also use the correct value `1e-5`.
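Both values can be checked directly (access to the meta-llama repo on the Hub is assumed):

```python
from transformers import AutoConfig, LlamaConfig

# Default baked into the LlamaConfig class itself
print(LlamaConfig().rms_norm_eps)  # 1e-06

# Value shipped in the model repo's config.json, which overrides the class default
print(AutoConfig.from_pretrained("meta-llama/Llama-2-7b-chat-hf").rms_norm_eps)  # 1e-05
```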
This is probably how the error was introduced.

After Fix
I've calculated the mean absolute error and the mean relative error between the residual stream vectors of both models. The plots with `eps=1e-6` are from before the fix, and the ones with `eps=1e-5` are from after it. The relative error was calculated as `abs((hf_resid_post - tl_resid_post) / (abs(hf_resid_post) + 1e-5))`. Notice that the y-axis is scaled in the last two plots.
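Continuing the sketch above, the metrics were computed along these lines (treating Huggingface's `hidden_states[layer + 1]` as the counterpart of the `resid_post` hook is an assumption):

```python
with torch.no_grad():
    hf_out = hf_model(tokens.to("cuda:0"), output_hidden_states=True)
    _, cache = tl_model.run_with_cache(tokens.to("cuda:1"))

# hidden_states[-1] in Huggingface is taken after the final RMSNorm,
# so only the first n_layers - 1 entries line up with resid_post.
for layer in range(tl_model.cfg.n_layers - 1):
    hf_resid_post = hf_out.hidden_states[layer + 1].cpu()
    tl_resid_post = cache["resid_post", layer].cpu()
    mae = (hf_resid_post - tl_resid_post).abs().mean()
    rel = ((hf_resid_post - tl_resid_post) / (hf_resid_post.abs() + 1e-5)).abs().mean()
    print(f"layer {layer:2d}: MAE={mae:.2e}, mean relative error={rel:.2e}")
```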
After the fix, we're able to run the same `torch.allclose` check at much tighter tolerances without triggering the assert. The assert is only triggered at `atol=1e-5`.
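For instance, rerunning the comparison from the sketch above (that `atol=1e-4` passes is an assumption; only the failure at `atol=1e-5` is pinned down):

```python
# With eps=1e-5 the models now agree to a much tighter tolerance...
assert torch.allclose(hf_logits, tl_logits, atol=1e-4)
# ...and only the strictest check still fails:
assert torch.allclose(hf_logits, tl_logits, atol=1e-5)  # raises AssertionError
```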
Type of change
The only change is setting `eps` to `1e-5` instead of `1e-6` in the default config for all Llama2 models.

Here's the notebook where I reproduce and check the fix for the error.
I'd also like to thank @andyrdt for helping me find the bug and create my first PR on GitHub.