Error in Load model weights #76

Open
inakanishi opened this issue Sep 8, 2023 · 3 comments

Comments

@inakanishi

When I ran ClipCap on Google Colab (Open in Colab), I got the following error at the "Load model weights" step. How can I solve this problem?

[image: screenshot of the error]
@abrar-fahim

I did some digging and found this to be a version issue with the transformers module.

TLDR: Simply deleting the bias and masked_bias keys from the state dictionary seems to work. Something like:

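import torch

# "cpu" stands in for the CPU constant, which is not defined in this snippet
altered_state_dict = torch.load(model_path, map_location="cpu")
for i in range(12):  # one pair of buffers per GPT-2 block
    prefix = f"gpt.transformer.h.{i}.attn."
    altered_state_dict.pop(prefix + "bias", None)  # pop() tolerates a missing key
    altered_state_dict.pop(prefix + "masked_bias", None)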

Then load the altered state dictionary into the model:

model.load_state_dict(altered_state_dict)
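
A variation on the same idea, as a minimal sketch: instead of hard-coding 12 layers, filter the keys by pattern. The helper name `drop_stale_attention_buffers` and the toy dictionary are made up for illustration; the pattern is built from the key names mentioned above, so adjust it if your checkpoint uses a different prefix.

```python
import re

# Matches only the two attention buffers discussed above, for any layer index,
# under the "gpt." prefix. Real parameters such as attn.c_attn.bias do not match.
_STALE_BUFFER = re.compile(r"^gpt\.transformer\.h\.\d+\.attn\.(bias|masked_bias)$")


def drop_stale_attention_buffers(state_dict):
    """Return a new dict without the stale attention buffer entries.

    The input dict is left unmodified. (Hypothetical helper, not part of
    ClipCap or transformers.)
    """
    return {k: v for k, v in state_dict.items() if not _STALE_BUFFER.match(k)}


# Toy stand-in for a checkpoint: the values would be tensors in practice.
toy = {
    "gpt.transformer.h.0.attn.bias": "buffer",
    "gpt.transformer.h.0.attn.masked_bias": "buffer",
    "gpt.transformer.h.0.attn.c_attn.bias": "parameter",
    "gpt.transformer.h.0.attn.c_attn.weight": "parameter",
}
cleaned = drop_stale_attention_buffers(toy)
print(sorted(cleaned))
```

With a real checkpoint you would pass the result of `torch.load(...)` to the helper and hand the returned dictionary to `model.load_state_dict`.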

Most likely explanation

More specifically, in transformers 4.11.3 (the version used in the original notebook), the source code for GPT2Attention registers "bias" and "masked_bias" as persistent buffers, so they are saved to the state dictionary:
[image: GPT2Attention source in transformers 4.11.3]

In the latest transformers version at the time of writing (4.33.1), GPT2Attention registers the "bias" and "masked_bias" buffers with persistent=False, so they are no longer part of the model's state dictionary, and a checkpoint saved under the older version carries keys the newer model does not expect:
[image: GPT2Attention source in transformers 4.33.1]
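
To see the effect in isolation, here is a small sketch with two toy modules (the class names `OldStyle` and `NewStyle` are made up and are not the transformers code). It only relies on `register_buffer` and its `persistent` argument from PyTorch, and shows that a state dict saved from the persistent variant fails a strict load into the non-persistent one.

```python
import torch
from torch import nn


class OldStyle(nn.Module):
    """Toy module: buffer is persistent (the default), so it is saved."""

    def __init__(self):
        super().__init__()
        self.register_buffer("bias", torch.ones(2, 2))


class NewStyle(nn.Module):
    """Toy module: same buffer, but excluded from the state dict."""

    def __init__(self):
        super().__init__()
        self.register_buffer("bias", torch.ones(2, 2), persistent=False)


old_keys = list(OldStyle().state_dict().keys())
new_keys = list(NewStyle().state_dict().keys())
print(old_keys, new_keys)

# Loading the "old" state dict into the "new" module with the default
# strict=True raises, because "bias" is now an unexpected key.
try:
    NewStyle().load_state_dict(OldStyle().state_dict())
    strict_error = None
except RuntimeError as exc:
    strict_error = str(exc)
print(strict_error)
```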

@inakanishi
Author

Thank you for your careful explanation. I did it!

@Hsn37

Hsn37 commented Nov 8, 2023

An easier workaround for this is to set strict=False in the load_state_dict function.

model.load_state_dict(model_weights, strict=False) 
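
One caveat with strict=False is that it ignores every missing or unexpected key, not only these two buffers, so a genuinely wrong checkpoint would also load without complaint. A possible safeguard, sketched below, is to inspect what was skipped. The helper name `unexplained_key_mismatches` and the example key names are made up for illustration; the pattern assumes the "gpt.transformer.h.N.attn" key layout mentioned earlier in this thread.

```python
import re

# Unexpected keys that the persistent=False change accounts for.
_EXPECTED_EXTRA = re.compile(r"^gpt\.transformer\.h\.\d+\.attn\.(bias|masked_bias)$")


def unexplained_key_mismatches(missing_keys, unexpected_keys):
    """Return mismatches that the buffer change does NOT explain.

    An empty list means only the stale attention buffers were skipped.
    (Hypothetical helper, not part of PyTorch or ClipCap.)
    """
    problems = [f"missing: {k}" for k in missing_keys]
    problems += [
        f"unexpected: {k}" for k in unexpected_keys if not _EXPECTED_EXTRA.match(k)
    ]
    return problems


# Only a stale buffer was skipped: nothing to worry about.
ok = unexplained_key_mismatches([], ["gpt.transformer.h.3.attn.masked_bias"])

# Illustrative key names: a missing weight and an unrelated extra key are reported.
bad = unexplained_key_mismatches(
    ["clip_project.model.0.weight"],
    ["gpt.transformer.h.3.attn.bias", "gpt.lm_head.extra"],
)
print(ok, bad)
```

In PyTorch, `load_state_dict` returns an object with `missing_keys` and `unexpected_keys` fields, so you could write `result = model.load_state_dict(model_weights, strict=False)` and pass `result.missing_keys` and `result.unexpected_keys` to the helper.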
