
Update Gemma to reflect upstream HF changes #596

Merged: 5 commits into main on May 15, 2024

Conversation

@cmathw (Contributor) commented May 15, 2024

Description

Fixes #594.

Since Gemma was merged on March 14th, there have been a number of upstream HuggingFace changes, which meant our activations/logits no longer matched the HF implementation well. This PR implements the changes described here; they include the following (sketched briefly after the list):

  • Changing the activation from standard GeLU to the tanh-approximated GeLU.
  • Keeping RMSNorm weights in float32.
  • Setting the dtype of the embedding scale to match that in the HookedTransformer config.
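
To make the first two bullets concrete, here is a minimal sketch of the tanh-approximated GeLU and a float32 RMSNorm in plain PyTorch. It is illustrative only, not the exact TransformerLens or HF code; the eps value and the exact casting points are assumptions.

```python
import torch
import torch.nn.functional as F

def gelu_tanh(x: torch.Tensor) -> torch.Tensor:
    # Tanh-approximated GeLU (HF's "gelu_pytorch_tanh"), replacing the exact erf-based GeLU
    return F.gelu(x, approximate="tanh")

def rms_norm_fp32(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalise and apply the learned scale in float32, then cast back to the input dtype
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return (x * weight.to(torch.float32)).to(orig_dtype)

# For the third bullet, the embedding scale (sqrt(d_model) for Gemma) would be built with
# the dtype from the HookedTransformer config, e.g. torch.tensor(d_model**0.5, dtype=cfg.dtype).
```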

A demo showing this agreement can be found here and is summarised below. Note: there is currently a bug in various HF/transformers models where the default attention implementation does not causally mask attention patterns; I opened an issue here, but it is already being addressed here. For this reason, it is important to specify attn_implementation="eager" in HF/transformers' from_pretrained method when comparing activations/logits.
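
For reference, a logit comparison along those lines might look something like the sketch below; the prompt, model alias, and float32 settings are illustrative assumptions rather than taken from the linked demo.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

model_name = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Force the eager attention path so the HF reference model causally masks attention patterns
hf_model = AutoModelForCausalLM.from_pretrained(
    model_name, attn_implementation="eager", torch_dtype=torch.float32
)
tl_model = HookedTransformer.from_pretrained("gemma-2b", dtype=torch.float32)

tokens = tokenizer("The quick brown fox", return_tensors="pt").input_ids
with torch.no_grad():
    hf_logits = hf_model(tokens).logits
    tl_logits = tl_model(tokens, return_type="logits")

# 5e-04 is the float32 logit tolerance reported for the 2b model in the table below
print(torch.allclose(hf_logits, tl_logits, atol=5e-4))
```

Without attn_implementation="eager", the HF side can pick an attention kernel affected by the masking bug above, so the two sets of logits would diverge for reasons unrelated to this PR.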

Tolerances for both 2b and 7b models when running in float32 can be found below:

|                 | 2b    | 7b    |
| --------------- | ----- | ----- |
| Logits          | 5e-04 | 5e-03 |
| Resid Pre Cache | 5e-05 | 5e-04 |
| Next Token Loss | 1e-06 | 5e-05 |
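
Reusing hf_model, tl_model, and tokens from the earlier sketch, a resid-pre check at roughly the 2b tolerance could look like the snippet below. The correspondence between HF's output_hidden_states and TransformerLens's resid_pre hooks is an assumption on my part, and weight-processing options (e.g. from_pretrained_no_processing) may matter for an exact match.

```python
hf_out = hf_model(tokens, output_hidden_states=True)
_, cache = tl_model.run_with_cache(tokens)

for layer in range(tl_model.cfg.n_layers):
    hf_resid = hf_out.hidden_states[layer]  # hidden state entering block `layer`
    tl_resid = cache["resid_pre", layer]    # TL residual stream before block `layer`
    assert torch.allclose(hf_resid, tl_resid, atol=5e-5), f"resid_pre mismatch at layer {layer}"
```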

Tolerances when running in float16 are much worse, but I think this is commonly the case when adding models to TL given the internal upcasting/downcasting. If that is not the case and these tolerances should also be similarly tight, I can investigate further.

Type of change


  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

@jbloomAus merged commit 0fd85b9 into TransformerLensOrg:main on May 15, 2024
10 checks passed
@jbloomAus (Collaborator) commented:

Thanks Chris! Looks good.

bryce13950 pushed a commit that referenced this pull request May 18, 2024
* update activation function to tanh approximation

* keep RMSNorm calcs in float32 and match cfg dtype for embedding scaling

* formatting

* keep mypy happy

* formatting
bryce13950 added a commit that referenced this pull request May 24, 2024
* Initial Commit (add pyright + test by adding few annotations)

* Slightly more typing added

* more typing

* Additional typing

* Completed typing for hook_points.py file

* todo clarifications

* formatting changes to hook_points.py

* Apply some suggestions from code review

Co-authored-by: Alan <[email protected]>

* Added typing for Literals and changed some assertions to if statements

* formatting

* update to account for merged code

* small typing issue

* changing hookfunction protocol + more assertions

* change the slice input

* change from isinstance to callable checks

* fix: Update Gemma to reflect upstream HF changes (#596)

* update activation function to tanh approximation

* keep RMSNorm calcs in float32 and match cfg dtype for embedding scaling

* formatting

* keep mypy happy

* formatting

* allow user to force trust_remote_code=true via from_pretrained kwargs (#597)

* change + revert HookFunctionProtocol

* format

* module_output is now just a tensor

* set module output to be any type

---------

Co-authored-by: Alan <[email protected]>
Co-authored-by: Bryce Meyer <[email protected]>
Co-authored-by: cmathw <[email protected]>
Co-authored-by: Clement Dumas <[email protected]>
Successfully merging this pull request may close these issues.

[Bug Report] Updates to Gemma