Error size mismatch when load decision model #38

Open
haunt98 opened this issue Mar 24, 2024 · 2 comments

haunt98 commented Mar 24, 2024

After training both the GILL model and the decision model, loading the model fails:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <cell line: 2>:2                                                                              │
│                                                                                                  │
│ /content/gill/gill/models.py:873 in load_gill                                                    │
│                                                                                                  │
│   870 │   decision_model_path = None                                                             │
│   871                                                                                            │
│   872   # Initialize model for inference.                                                        │
│ ❱ 873   model = GILL(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix,              │
│   874 │   │   │      load_sd=True, num_gen_images=1, decision_model_path=decision_model_path)    │
│   875   model = model.eval()                                                                     │
│   876   model = model.bfloat16()                                                                 │
│                                                                                                  │
│ /content/gill/gill/models.py:560 in __init__                                                     │
│                                                                                                  │
│   557 │   │     nn.Linear(768, 2),                                                               │
│   558 │     ])                                                                                   │
│   559 │     mlp_checkpoint = torch.load(decision_model_path)                                     │
│ ❱ 560 │     self.decision_model.load_state_dict(mlp_checkpoint['state_dict'], strict=False)      │
│   561 │     self.decision_model.eval()                                                           │
│   562                                                                                            │
│   563   def __call__(self, images: Tensor, tgt_tokens: Optional[Tensor] = None, caption_len: O   │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1671 in load_state_dict       │
│                                                                                                  │
│   1668 │   │   │   │   │   │   ', '.join('"{}"'.format(k) for k in missing_keys)))               │
│   1669 │   │                                                                                     │
│   1670 │   │   if len(error_msgs) > 0:                                                           │
│ ❱ 1671 │   │   │   raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(     │
│   1672 │   │   │   │   │   │   │      self.__class__.__name__, "\n\t".join(error_msgs)))         │
│   1673 │   │   return _IncompatibleKeys(missing_keys, unexpected_keys)                           │
│   1674                                                                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Error(s) in loading state_dict for Sequential:
        size mismatch for 1.weight: copying a param with shape torch.Size([2, 768]) from checkpoint, the shape in current model is torch.Size([2, 4096]).
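
A note on reading this error: strict=False only relaxes missing and unexpected keys in load_state_dict; a shape mismatch on a key present in both the checkpoint and the model still raises. Below is a minimal sketch of a helper for listing such mismatches before loading. The helper name shape_mismatches is hypothetical, and the 1.bias shapes are assumed (a two-class linear layer with a bias); only the 1.weight shapes come from the error above.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def shape_mismatches(
    model_shapes: Dict[str, Shape], ckpt_shapes: Dict[str, Shape]
) -> List[Tuple[str, Shape, Shape]]:
    """Return (key, checkpoint_shape, model_shape) for every key that is
    present in both mappings but has a different shape."""
    shared = sorted(model_shapes.keys() & ckpt_shapes.keys())
    return [
        (key, tuple(ckpt_shapes[key]), tuple(model_shapes[key]))
        for key in shared
        if tuple(ckpt_shapes[key]) != tuple(model_shapes[key])
    ]


# Shapes of 1.weight are taken from the error message; 1.bias is assumed.
# With PyTorch, such a mapping can be built from a module or a checkpoint via
#   {k: tuple(v.shape) for k, v in state_dict.items()}
model_shapes = {"1.weight": (2, 4096), "1.bias": (2,)}
ckpt_shapes = {"1.weight": (2, 768), "1.bias": (2,)}

print(shape_mismatches(model_shapes, ckpt_shapes))
# [('1.weight', (2, 768), (2, 4096))]
```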

haunt98 commented Mar 24, 2024

Looks like 4096 is hardcoded as the input size when the decision model is initialized.

My hotfix is here, but maybe this should be a parameter or something similar?
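
For reference, one alternative to a hardcoded width (this is not the linked hotfix, which is not reproduced here) would be to read the input size from the checkpoint itself. A linear layer's weight is stored with shape (out_features, in_features), so the second dimension of 1.weight gives the width the decision model was trained with. A rough sketch, where infer_in_features is a hypothetical helper and the key name 1.weight is taken from the error message:

```python
def infer_in_features(state_dict, key="1.weight"):
    """Return in_features of a linear layer from its saved weight.

    Works with any mapping whose values expose a .shape attribute, such as
    the tensors in a PyTorch state_dict. The weight of a linear layer has
    shape (out_features, in_features).
    """
    shape = tuple(state_dict[key].shape)
    if len(shape) != 2:
        raise ValueError(f"expected a 2-D weight for {key!r}, got shape {shape}")
    return shape[1]


# Hypothetical usage with the checkpoint layout seen in the traceback:
#   mlp_checkpoint = torch.load(decision_model_path)
#   emb_dim = infer_in_features(mlp_checkpoint["state_dict"])
#   ... build the decision model with nn.Linear(emb_dim, 2) ...
```

This keeps loading robust to whichever backbone the decision model was trained on, at the cost of trusting the checkpoint's layer layout.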

kohjingyu (Owner) commented

Glad you solved it. Yeah, I think the right way to do this would be to add this as a param to the model args.
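
A minimal sketch of what that could look like, not the repository's actual code: DecisionArgs and its decision_emb_dim field are hypothetical names, and the Dropout at index 0 is an assumption (the error only shows that the linear layer sits at index 1 of the Sequential).

```python
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class DecisionArgs:
    # Hypothetical field: width of the embeddings fed to the decision model.
    # 4096 mirrors the value reported as hardcoded in this issue.
    decision_emb_dim: int = 4096


def build_decision_model(emb_dim: int) -> nn.Sequential:
    # Assumed layout: a parameter-free layer at index 0 and the two-class
    # linear layer at index 1, matching the "1.weight" key in the error.
    return nn.Sequential(nn.Dropout(0.5), nn.Linear(emb_dim, 2))


def load_decision_model(state_dict, args: DecisionArgs) -> nn.Sequential:
    model = build_decision_model(args.decision_emb_dim)
    # Default strict=True, so any key or shape disagreement surfaces here.
    model.load_state_dict(state_dict)
    return model.eval()


# Usage: a decision model trained on 768-dimensional embeddings loads once
# the same width is passed through the args.
checkpoint = build_decision_model(768).state_dict()
model = load_decision_model(checkpoint, DecisionArgs(decision_emb_dim=768))
print(tuple(model[1].weight.shape))
# (2, 768)
```

Saving the same value alongside the checkpoint at training time would let inference pick it up without the user having to pass it again.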
