Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
haunt98 committed Mar 24, 2024
1 parent 1e69983 commit bade6cb
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions gill/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,10 +554,10 @@ def __init__(self, tokenizer, model_args: Optional[GILLArgs] = None,
print('Loading decision model...')
self.decision_model = nn.Sequential(*[
nn.Dropout(0.5),
nn.Linear(4096, 2),
nn.Linear(768, 2),
])
mlp_checkpoint = torch.load(decision_model_path)
self.decision_model.load_state_dict(mlp_checkpoint['state_dict'], strict=True)
self.decision_model.load_state_dict(mlp_checkpoint['state_dict'])
self.decision_model.eval()

def __call__(self, images: Tensor, tgt_tokens: Optional[Tensor] = None, caption_len: Optional[Tensor] = None,
Expand Down

0 comments on commit bade6cb

Please sign in to comment.