Fixing Llama2 numerical errors (#456)
* Changing RMS layer norm eps value

* Changing eps value for Llama2 only
obalcells authored Nov 29, 2023
1 parent c427735 commit 3d0db85
Showing 1 changed file with 2 additions and 2 deletions.
transformer_lens/loading_from_pretrained.py
@@ -561,7 +561,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
     "d_mlp": 11008,
     "n_layers": 32,
     "n_ctx": 2048 if official_model_name.startswith("llama-7b") else 4096,
-    "eps": 1e-6,
+    "eps": 1e-6 if official_model_name.startswith("llama-7b") else 1e-5,
     "d_vocab": 32000,
     "act_fn": "silu",
     "normalization_type": "RMS",
@@ -580,7 +580,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
     "d_mlp": 13824,
     "n_layers": 40,
     "n_ctx": 2048 if official_model_name.startswith("llama-13b") else 4096,
-    "eps": 1e-6,
+    "eps": 1e-6 if official_model_name.startswith("llama-13b") else 1e-5,
     "d_vocab": 32000,
     "act_fn": "silu",
     "normalization_type": "RMS",
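The eps value matters because it sits inside the RMS normalization denominator: loading a checkpoint with a different eps than it was trained with perturbs every layer's activations slightly, which is the numerical error this commit fixes. Below is a minimal, illustrative sketch of where eps enters RMSNorm; rms_norm here is a hand-rolled stand-in, not TransformerLens's actual implementation.

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Divide by the root-mean-square of the activations along the model
    # dimension; eps is added before the sqrt for numerical stability and
    # must match the value used at training time.
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * weight

# Illustrative values per this commit: Llama-1 uses eps=1e-6, Llama2 uses eps=1e-5.
x = torch.randn(2, 4096)
w = torch.ones(4096)
out_llama1 = rms_norm(x, w, eps=1e-6)
out_llama2 = rms_norm(x, w, eps=1e-5)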
