demo.py
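"""Run a single VQA (Visual Question Answering) inference from the command line.

Example invocation (a sketch; the file paths below are illustrative placeholders,
only the defaults shown in parse_args() come from this repo):

    python demo.py cat.jpg "What animal is this?" \
        questions_train.json annotations_train.json \
        --model DeeperLSTM \
        --checkpoint weights/deeper_lstm_best_weights.pth.tar
"""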
import argparse
import sys

import numpy as np
import torch
from PIL import Image, ImageDraw
from torch import nn
from torchvision import transforms

from dataset import process_vqa_dataset
from models import Models
from utils import image, text


def parse_args():
    parser = argparse.ArgumentParser("VQA Demo")
    parser.add_argument("image", help="Path to image file")
    parser.add_argument("question", help="Question text")
    parser.add_argument(
        "questions", help="Path to VQA Questions training file")
    parser.add_argument(
        "annotations", help="Path to COCO Annotations training file")
    parser.add_argument("--model", default="DeeperLSTM")
    parser.add_argument(
        "--checkpoint", default="weights/deeper_lstm_best_weights.pth.tar")
    parser.add_argument("--preprocessed_cache",
                        default="vqa_train_dataset_cache.pickle")
    parser.add_argument("--embedding_arch", default="vgg19_bn")
    # main() reads args.gpu but the original script never defined it;
    # expose it here so the CUDA device index can be selected.
    parser.add_argument("--gpu", type=int, default=0,
                        help="CUDA device index to use when a GPU is available")
    return parser.parse_args()


def generate(output, aid_to_ans):
    """
    Return the answer string for a predicted answer ID/label.

    :param output: The predicted answer label (integer answer ID)
    :param aid_to_ans: Mapping from answer ID to answer string
    :return: The answer string
    """
    ans = aid_to_ans[output]
    return ans


def display_result(img, question, answer):
    # Overlay the question and predicted answer on the image, then show it.
    # (Parameter renamed from `image` to avoid shadowing the utils.image module.)
    draw = ImageDraw.Draw(img)
    draw.text((10, 10), question)
    draw.text((10, 20), answer)
    print("{0}: {1}".format(question, answer))
    img.show()


def main():
    args = parse_args()

    print("Loading encoded data...")
    data, vocab, word_to_wid, wid_to_word, \
        ans_to_aid, aid_to_ans = process_vqa_dataset(
            args.questions, args.annotations, "train", maps=None)

    # Get VGG model to process the image
    vision_model, _ = image.get_model(args.embedding_arch)

    # Get our VQA model
    model = Models[args.model].value(len(vocab))

    if torch.cuda.is_available():
        device = torch.device('cuda:{0}'.format(args.gpu))
    else:
        device = torch.device('cpu')

    try:
        # map_location lets CUDA-trained weights load on a CPU-only machine
        weights = torch.load(args.checkpoint, map_location=device)
    except Exception:
        print("ERROR: could not load weights from '{0}'. "
              "Please specify weights for the VQA model with --checkpoint".format(args.checkpoint))
        sys.exit(1)
    model.load_state_dict(weights["model"])

    vision_model.eval().to(device)
    model.eval().to(device)

    # The original script calls an undefined `classifier`; a softmax over the
    # answer logits is assumed here (the argmax below is unchanged by it).
    classifier = nn.Softmax(dim=1)

    img_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    print("Processing image")
    im = Image.open(args.image).convert("RGB")  # force 3 channels
    img = img_transforms(im)
    img = img.unsqueeze(0)  # add batch dimension
    img = img.to(device)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        img_features = vision_model(img)

    print("Processing question")
    q = text.process_single_question(args.question, vocab, word_to_wid)
    q = torch.from_numpy(q['question_wids'])
    # Alternative: convert the question to a sequence of one-hot vectors over the vocab
    # one_hot_vec = np.zeros((len(q["question_wids"]), len(vocab)))
    # for k in range(len(q["question_wids"])):
    #     one_hot_vec[k, q['question_wids'][k]] = 1
    # q = torch.from_numpy(one_hot_vec)
    q = q.to(device)
    # Add the batch dimension
    q = q.unsqueeze(0).long()

    # Get the model output and classify for the final value
    with torch.no_grad():
        output = model(img_features, q, torch.LongTensor([q.size(1)]))
        output = classifier(output).data
    _, ans_id = torch.max(output, dim=1)

    ans = generate(ans_id.item(), aid_to_ans)
    display_result(im, args.question, ans)


if __name__ == "__main__":
    main()