After training both the GILL model and the decision model, loading the trained model (load_gill) failed:
Traceback (most recent call last)
in <cell line: 2>:2

/content/gill/gill/models.py:873 in load_gill

  870     decision_model_path = None
  871
  872 # Initialize model for inference.
❱ 873 model = GILL(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix,
  874             load_sd=True, num_gen_images=1, decision_model_path=decision_model_path)
  875 model = model.eval()
  876 model = model.bfloat16()

/content/gill/gill/models.py:560 in __init__

  557         nn.Linear(768, 2),
  558     ])
  559     mlp_checkpoint = torch.load(decision_model_path)
❱ 560     self.decision_model.load_state_dict(mlp_checkpoint['state_dict'], strict=False)
  561     self.decision_model.eval()
  562
  563     def __call__(self, images: Tensor, tgt_tokens: Optional[Tensor] = None, caption_len: O

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1671 in load_state_dict

  1668                         ', '.join('"{}"'.format(k) for k in missing_keys)))
  1669
  1670         if len(error_msgs) > 0:
❱ 1671             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
  1672                             self.__class__.__name__, "\n\t".join(error_msgs)))
  1673         return _IncompatibleKeys(missing_keys, unexpected_keys)
  1674

RuntimeError: Error(s) in loading state_dict for Sequential:
    size mismatch for 1.weight: copying a param with shape torch.Size([2, 768]) from checkpoint, the shape in current model is torch.Size([2, 4096]).
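It looks like 4096 is hardcoded as the input size of the linear layer when the decision model is initialized, while my trained checkpoint has an input size of 768.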
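For anyone hitting the same error, here is a minimal sketch that reproduces it outside of GILL. The build_decision_model helper and its layer layout (a parameter-free layer at index 0, a linear layer at index 1) are assumptions made to match the 1.weight key in the error, not the repository's actual code. It also shows that strict=False does not help here: that flag only tolerates missing or unexpected keys, and shape mismatches still raise.

```python
import torch
from torch import nn


def build_decision_model(in_dim: int) -> nn.Sequential:
    # Hypothetical stand-in for the decision model: index 0 has no
    # parameters, index 1 is the linear layer named "1.weight" in the error.
    return nn.Sequential(nn.Dropout(0.5), nn.Linear(in_dim, 2))


# Stands in for mlp_checkpoint['state_dict'] from a model trained with 768-d inputs.
checkpoint_sd = build_decision_model(768).state_dict()

# Model built with the hardcoded 4096 input size.
model = build_decision_model(4096)

try:
    model.load_state_dict(checkpoint_sd, strict=False)
    mismatch_raised = False
except RuntimeError as err:
    # strict=False does not suppress size mismatches.
    mismatch_raised = True
    print(err)

# Shape saved in the checkpoint: (out_features, in_features).
checkpoint_shape = tuple(checkpoint_sd["1.weight"].shape)
print(checkpoint_shape)
```

Printing the shape of 1.weight from the checkpoint is a quick way to check which input size the decision model was actually trained with.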
My hotfix is here, but maybe this should be a parameter instead of a hardcoded value?
Glad you solved it. Yeah, I think the right way to do this would be to add this as a param to the model args.
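A rough sketch of what that could look like, under stated assumptions: DecisionArgs, decision_emb_dim, infer_in_dim, and build_decision_model are hypothetical names for illustration and do not exist in the GILL codebase, and the layer layout is guessed from the 1.weight key in the error. When the param is not set, the sketch falls back to reading the input size from the checkpoint itself.

```python
from dataclasses import dataclass
from typing import Mapping, Optional

import torch
from torch import nn


@dataclass
class DecisionArgs:
    # Hypothetical model arg; None means "infer it from the checkpoint".
    decision_emb_dim: Optional[int] = None


def infer_in_dim(state_dict: Mapping[str, torch.Tensor], key: str = "1.weight") -> int:
    # nn.Linear stores its weight as (out_features, in_features).
    return int(state_dict[key].shape[1])


def build_decision_model(args: DecisionArgs,
                         state_dict: Mapping[str, torch.Tensor]) -> nn.Sequential:
    in_dim = args.decision_emb_dim or infer_in_dim(state_dict)
    # Assumed layout: parameter-free layer at index 0, linear layer at index 1.
    model = nn.Sequential(nn.Dropout(0.5), nn.Linear(in_dim, 2))
    model.load_state_dict(state_dict)
    return model.eval()


# Usage with a fake checkpoint trained on 768-d inputs.
fake_sd = nn.Sequential(nn.Dropout(0.5), nn.Linear(768, 2)).state_dict()
decision_model = build_decision_model(DecisionArgs(), fake_sd)
print(decision_model[1].in_features)
```

Setting the value explicitly in the args keeps the configuration visible, while the fallback avoids having to remember which embedding size a given checkpoint was trained with.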