
Can't load models with a gamma or beta parameter #29554

Closed
malik-ali opened this issue Mar 9, 2024 · 8 comments · Fixed by #31654 · May be fixed by #33192
Labels: Feature request · Good Difficult Issue · Should Fix

Comments

@malik-ali

It seems that you cannot create parameters whose names contain the string gamma or beta in any modules you write if you intend to save/load them with the transformers library. There is a small function called _fix_key implemented in the model-loading code (link). It renames every occurrence of beta or gamma found anywhere in a state_dict key to bias and weight, respectively. This means that if your modules actually have a parameter with one of these names, it won't be loaded when using a pretrained model.

As far as I can tell, it's completely undocumented that people shouldn't create any parameters with the string gamma or beta in them.

Here is a minimal reproducible example:

import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig

class Model(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.gamma = nn.Parameter(torch.zeros(4))

    def forward(self):
        return self.gamma.sum()


if __name__ == '__main__':
    config = PretrainedConfig()

    # 1) First, run this step on its own to save a checkpoint:
    # model = Model(config)
    # print(model())
    # model.save_pretrained('test_out')

    # 2) Then comment out step 1 and run this:
    model = Model.from_pretrained('test_out', config=config)
    print(model())

When you run step 2, you get the following warning and the saved value of gamma is silently discarded:

Some weights of Model were not initialized from the model checkpoint at test_out and are newly initialized: ['gamma']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
@malik-ali malik-ali changed the title Can't have models with a gamma or beta parameter Can't load models with a gamma or beta parameter Mar 10, 2024
@NielsRogge
Contributor

Yes that's correct, it's a bug I pointed out in my video series on contributing to Transformers.

This is due to these lines:

if "gamma" in key:
new_key = key.replace("gamma", "weight")
if "beta" in key:
new_key = key.replace("beta", "bias")
.

I assume they are there for backwards-compatibility reasons. If we knew which models require this exception, we could fix it.
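To make the consequence concrete, here is a minimal standalone sketch that mirrors the renaming quoted above (the function name and demo keys are illustrative; this is not the library's actual code path):

import re  # not needed; plain substring replace, as in the quoted code

def fix_key(key: str) -> str:
    # Mirrors the quoted logic: blind substring replacement on every key.
    new_key = key
    if "gamma" in key:
        new_key = key.replace("gamma", "weight")
    if "beta" in key:
        new_key = key.replace("beta", "bias")
    return new_key

# Any legitimate parameter name containing these substrings is rewritten:
print(fix_key("decoder.gamma"))        # -> decoder.weight
print(fix_key("film.beta_proj.bias"))  # -> film.bias_proj.bias (substring match!)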

@malik-ali
Author

malik-ali commented Mar 10, 2024

I assumed the same, but it's a pretty annoying bug to have to find on your own. Would it be worth adding a warning to the __init__ method of the PreTrainedModel class to let users know if their parameter names contain the string "gamma" or "beta", and to encourage them to change them? At least while this block of code still exists in the codebase.

It's further complicated by the fact that accelerate loads checkpoints with a plain load_state_dict, which doesn't do this renaming. So there is an incompatibility between two highly coupled libraries.
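A rough sketch of what such an early warning could look like (warn_on_renamed_params is a hypothetical helper, not an existing transformers API):

import warnings
import torch.nn as nn

def warn_on_renamed_params(model: nn.Module) -> None:
    """Hypothetical helper: warn at model-creation time about parameter
    names that the loading code would later rewrite."""
    for name, _ in model.named_parameters():
        if "gamma" in name or "beta" in name:
            warnings.warn(
                f"Parameter '{name}' contains 'gamma' or 'beta'; transformers "
                "renames these substrings to 'weight'/'bias' when loading, so "
                "this parameter will not survive a save_pretrained/"
                "from_pretrained round trip."
            )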

@amyeroberts
Collaborator

Hi @malik-ali, thanks for raising this issue! Indeed, this isn't a desired behaviour.

If we knew which models require this exception, we could fix it.

I think this would be very hard to do. There are many saved checkpoints both on and off the hub, as well as all sorts of custom models which might rely on this behaviour.

Would it be worth adding a warning to the __init__ method of the PreTrainedModel class to let users know if their parameter names contain the string "gamma" or "beta", and to encourage them to change them? At least while this block of code still exists in the codebase.

Yes, I think a warning for a few release cycles is the best way to go. I would put it in the _load_state_dict_into_model function and trigger it if "gamma" or "beta" is in the key.

It won't be possible to tell whether the parameter comes from an "old" state dict or a new model, but we can warn that the renaming is happening, that the behaviour will be removed in a future release, and that users should update the weights in their state dict to use "weight" or "bias" so they are loaded properly.

@malik-ali Would you like to open a PR to add this? That way you get the GitHub contribution for your suggested solution.
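For illustration, the suggested warning might look roughly like this, attached to the renaming itself (a sketch only; the real _load_state_dict_into_model carries much more machinery, and the helper name here is assumed):

import logging

logger = logging.getLogger(__name__)

def _fix_key_with_warning(key: str) -> str:
    # Sketch: same renaming as today, plus the proposed deprecation warning.
    new_key = key
    if "gamma" in key:
        new_key = new_key.replace("gamma", "weight")
    if "beta" in key:
        new_key = new_key.replace("beta", "bias")
    if new_key != key:
        logger.warning(
            f"State dict key '{key}' was renamed to '{new_key}'. This legacy "
            "renaming will be removed in a future release; update the weights "
            "in your state dict to use 'weight'/'bias' so they load correctly."
        )
    return new_key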

@malik-ali
Author

@amyeroberts I'd be happy to! Just one question: if we add this to _load_state_dict_into_model, is it correct that users would only see the warning when loading their pretrained model?

I ask because I ran into this issue after training a model for several days and later loading it. It would have been nice to see the warning before doing all the training, so that I could rename the parameters on the spot. Do you think a warning like that would be feasible?

(My fix was to manually rename the keys of the saved state_dict and then rename the parameters in my model.)
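For anyone hitting the same wall, that workaround might look roughly like this (a sketch: the checkpoint filename and the replacement name 'scale' are assumptions, and recent versions may save model.safetensors instead of pytorch_model.bin):

import torch

# Load the raw checkpoint, bypassing from_pretrained and its renaming.
path = "test_out/pytorch_model.bin"  # assumed default torch save location
state_dict = torch.load(path)

# Rename the offending keys (here 'gamma' -> 'scale', an assumed new name);
# the matching nn.Parameter attributes in the model must be renamed to match.
state_dict = {k.replace("gamma", "scale"): v for k, v in state_dict.items()}
torch.save(state_dict, path)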

@amyeroberts
Collaborator

Good point! In this case, we'll need to add a warning in two places to make sure we catch both new model creations and old state dicts being loaded in.

@fzyzcjy
Contributor

fzyzcjy commented Mar 21, 2024

+1, found this problem today...

@amyeroberts added the 'Should Fix' and 'Feature request' labels Mar 21, 2024
@malik-ali
Author

@amyeroberts I might not have a chance to push a fix for this for at least a few weeks, so please feel free to make any changes as you (or anyone else) wish!

@amyeroberts
Collaborator

@malik-ali OK - thanks for letting us know. I've added a 'Good Difficult Issue' label to flag this for anyone in the community who might want to tackle it in the meantime.
