-
Notifications
You must be signed in to change notification settings - Fork 417
/
test_models.py
executable file
·331 lines (269 loc) · 12.2 KB
/
test_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
# Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
# arXiv:1811.08383
# Ji Lin*, Chuang Gan, Song Han
# {jilin, songhan}@mit.edu, [email protected]
# Notice that this file has been modified to support ensemble testing
import argparse
import time
import torch.nn.parallel
import torch.optim
from sklearn.metrics import confusion_matrix
from ops.dataset import TSNDataSet
from ops.models import TSN
from ops.transforms import *
from ops import dataset_config
from torch.nn import functional as F
# options
parser = argparse.ArgumentParser(description="TSM testing on the full validation set")
parser.add_argument('dataset', type=str)
# may contain splits
# --weights is split on ',' later, so it must always be supplied.
parser.add_argument('--weights', type=str, required=True,
                    help='comma-separated checkpoint paths, one per ensemble member')
# Kept as a string because it is split on ',' (one entry per checkpoint);
# the default must therefore also be a string, not the int 25.
parser.add_argument('--test_segments', type=str, default='25',
                    help='comma-separated segment counts, one per entry in --weights')
parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample as I3D')
parser.add_argument('--twice_sample', default=False, action="store_true", help='use twice sample for ensemble')
parser.add_argument('--full_res', default=False, action="store_true",
                    help='use full resolution 256x256 for test as in Non-local I3D')
parser.add_argument('--test_crops', type=int, default=1, help='number of crops: 1, 3, 5 or 10')
parser.add_argument('--coeff', type=str, default=None,
                    help='comma-separated ensemble weights, one per entry in --weights')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: 8)')
# for true test
parser.add_argument('--test_list', type=str, default=None,
                    help='comma-separated test list files (one shared, or one per entry in --weights)')
parser.add_argument('--csv_file', type=str, default=None)
parser.add_argument('--softmax', default=False, action="store_true", help='use softmax')
parser.add_argument('--max_num', type=int, default=-1)
parser.add_argument('--input_size', type=int, default=224)
parser.add_argument('--crop_fusion_type', type=str, default='avg')
parser.add_argument('--gpus', nargs='+', type=int, default=None)
parser.add_argument('--img_feature_dim', type=int, default=256)
parser.add_argument('--num_set_segments', type=int, default=1,
                    help='TODO: select multiple sets of n frames from a video')
parser.add_argument('--pretrain', type=str, default='imagenet')
args = parser.parse_args()
class AverageMeter(object):
    """Keep the latest observed value together with a weighted running mean.

    Each ``update(val, n)`` call stores ``val`` as the latest value and counts
    it ``n`` times toward the running sum, count and average.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running average, running sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
def accuracy(output, target, topk=(1,)):
    """Compute precision@k, in percent, for each k in ``topk``.

    Args:
        output: score tensor indexed as ``output[sample, class]``.
        target: tensor holding one ground-truth class index per sample.
        topk: the k values to evaluate.

    Returns:
        A list with one 0-dim tensor per k: the percentage of samples whose
        target is among the k highest-scoring classes.
    """
    # Clamp so that topk() does not fail when there are fewer classes than
    # max(topk); in that case every target is trivially in the top-k.
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # ``correct`` derives from a transposed tensor and may be
        # non-contiguous, so ``view(-1)`` can raise on recent PyTorch;
        # ``reshape`` copies only when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
def parse_shift_option_from_log_name(log_name):
    """Extract temporal-shift settings encoded in a checkpoint file name.

    The name is split on ``'_'`` and the first token of the exact form
    ``shift<digits>`` (e.g. ``shift8``) supplies the shift divisor; the token
    immediately after it is returned as the shift place.

    Args:
        log_name: checkpoint path or file name.

    Returns:
        ``(True, shift_div, shift_place)`` when a ``shift<digits>`` token is
        found, otherwise ``(False, None, None)``.

    Raises:
        ValueError: if the ``shift<digits>`` token is the last token, so no
            shift place follows it.
    """
    prefix = 'shift'
    tokens = log_name.split('_')
    for idx, token in enumerate(tokens):
        digits = token[len(prefix):]
        # Require an exact ``shift<digits>`` token so that unrelated text that
        # merely contains "shift" (e.g. a directory named
        # "temporal-shift-module") is not mistaken for the shift option.
        if token.startswith(prefix) and digits.isdecimal():
            if idx + 1 >= len(tokens):
                raise ValueError('No shift place follows {!r} in {!r}'.format(token, log_name))
            return True, int(digits), tokens[idx + 1]
    return False, None, None
# ---------------------------------------------------------------------------
# Build one (network, data iterator) pair per ensemble member.
# --weights, --test_segments and --coeff are comma-separated lists with one
# entry per model; --test_list may hold one shared file or one per model.
# ---------------------------------------------------------------------------
if args.weights is None:
    raise ValueError('--weights is required')
weights_list = args.weights.split(',')
# str() guards against a non-string default for --test_segments.
test_segments_list = [int(s) for s in str(args.test_segments).split(',')]
if len(weights_list) != len(test_segments_list):
    raise ValueError('--weights and --test_segments need the same number of entries, got {} and {}'.format(
        len(weights_list), len(test_segments_list)))

if args.coeff is None:
    coeff_list = [1] * len(weights_list)
else:
    coeff_list = [float(c) for c in args.coeff.split(',')]
    if len(coeff_list) != len(weights_list):
        raise ValueError('--coeff needs one entry per checkpoint, got {} for {} checkpoints'.format(
            len(coeff_list), len(weights_list)))

if args.test_list is not None:
    test_file_list = args.test_list.split(',')
    # A single list is shared by all models; any other mismatch would make
    # zip() below silently drop ensemble members.
    if len(test_file_list) == 1:
        test_file_list = test_file_list * len(weights_list)
    elif len(test_file_list) != len(weights_list):
        raise ValueError('--test_list needs 1 entry or one per checkpoint, got {} for {} checkpoints'.format(
            len(test_file_list), len(weights_list)))
else:
    test_file_list = [None] * len(weights_list)

data_iter_list = []
net_list = []
modality_list = []
total_num = None

for this_weights, this_test_segments, test_file in zip(weights_list, test_segments_list, test_file_list):
    is_shift, shift_div, shift_place = parse_shift_option_from_log_name(this_weights)
    modality = 'RGB' if 'RGB' in this_weights else 'Flow'
    # Backbone name is the third '_'-separated field after 'TSM_'
    # (presumably TSM_<dataset>_<modality>_<arch>_... -- naming convention assumed).
    this_arch = this_weights.split('TSM_')[1].split('_')[2]
    modality_list.append(modality)
    num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(args.dataset,
                                                                                            modality)
    print('=> shift: {}, shift_div: {}, shift_place: {}'.format(is_shift, shift_div, shift_place))
    net = TSN(num_class, this_test_segments if is_shift else 1, modality,
              base_model=this_arch,
              consensus_type=args.crop_fusion_type,
              img_feature_dim=args.img_feature_dim,
              pretrain=args.pretrain,
              is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
              non_local='_nl' in this_weights,
              )

    if 'tpool' in this_weights:
        from ops.temporal_shift import make_temporal_pool
        make_temporal_pool(net.base_model, this_test_segments)  # since DataParallel

    # Load onto the CPU so the checkpoint opens regardless of which GPU saved it;
    # the weights are copied into the (CPU) network before it is moved to CUDA.
    checkpoint = torch.load(this_weights, map_location='cpu')
    checkpoint = checkpoint['state_dict']

    # Drop the first dotted component of every key (presumably the 'module.'
    # prefix added by DataParallel at training time -- confirm).
    base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
    replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                    'base_model.classifier.bias': 'new_fc.bias',
                    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)

    net.load_state_dict(base_dict)

    input_size = net.scale_size if args.full_res else net.input_size
    if args.test_crops == 1:
        cropping = torchvision.transforms.Compose([
            GroupScale(net.scale_size),
            GroupCenterCrop(input_size),
        ])
    elif args.test_crops == 3:  # 3 full-resolution crops, no flipping
        cropping = torchvision.transforms.Compose([
            GroupFullResSample(input_size, net.scale_size, flip=False)
        ])
    elif args.test_crops == 5:  # do not flip, so only 5 crops
        cropping = torchvision.transforms.Compose([
            GroupOverSample(input_size, net.scale_size, flip=False)
        ])
    elif args.test_crops == 10:
        cropping = torchvision.transforms.Compose([
            GroupOverSample(input_size, net.scale_size)
        ])
    else:
        raise ValueError("Only 1, 3, 5, 10 crops are supported while we got {}".format(args.test_crops))

    data_loader = torch.utils.data.DataLoader(
        TSNDataSet(root_path, test_file if test_file is not None else val_list, num_segments=this_test_segments,
                   new_length=1 if modality == "RGB" else 5,
                   modality=modality,
                   image_tmpl=prefix,
                   test_mode=True,
                   remove_missing=len(weights_list) == 1,
                   transform=torchvision.transforms.Compose([
                       cropping,
                       Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(this_arch not in ['BNInception', 'InceptionV3'])),
                       GroupNormalize(net.input_mean, net.input_std),
                   ]), dense_sample=args.dense_sample, twice_sample=args.twice_sample),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True,
    )

    # Honour --gpus: DataParallel requires the module to live on device_ids[0].
    # (The old code indexed --gpus by the worker count and never used the result.)
    if args.gpus:
        net = torch.nn.DataParallel(net.cuda(args.gpus[0]), device_ids=args.gpus)
    else:
        net = torch.nn.DataParallel(net.cuda())
    net.eval()

    data_gen = enumerate(data_loader)

    # Every ensemble member must iterate over exactly the same videos.
    if total_num is None:
        total_num = len(data_loader.dataset)
    elif total_num != len(data_loader.dataset):
        raise ValueError('All models must be evaluated on the same number of videos ({} vs {})'.format(
            total_num, len(data_loader.dataset)))

    data_iter_list.append(data_gen)
    net_list.append(net)

# Accumulates [scores (1, num_class), label] per video.
output = []
def eval_video(video_data, net, this_test_segments, modality):
    """Score one batch of test clips with a single ensemble member.

    Args:
        video_data: ``(i, data, label)`` tuple. ``i`` is passed through
            unchanged, ``data`` is the input tensor from the data loader, and
            ``label`` holds one ground-truth index per video.
        net: the ``DataParallel``-wrapped network to evaluate (``net.module``
            is read).
        this_test_segments: number of segments sampled per clip.
        modality: ``'RGB'``, ``'Flow'`` or ``'RGBDiff'``.

    Returns:
        ``(i, scores, label)`` where ``scores`` is a NumPy array of shape
        ``(batch_size, num_class)``. It is averaged over crops and, for
        non-shift models, over the remaining per-segment outputs.

    Raises:
        ValueError: for an unknown modality.
    """
    net.eval()
    # Use this model's own flag. In an ensemble the members can differ, and the
    # module-level ``is_shift`` only reflects the last model that was built.
    is_shift = net.module.is_shift
    with torch.no_grad():
        i, data, label = video_data
        batch_size = label.numel()
        num_crop = args.test_crops
        if args.dense_sample:
            num_crop *= 10  # 10 clips for testing when using dense sample
        if args.twice_sample:
            num_crop *= 2

        # Number of input channels per segment (dim 1 of the reshaped input).
        if modality == 'RGB':
            length = 3
        elif modality == 'Flow':
            length = 10
        elif modality == 'RGBDiff':
            length = 18
        else:
            raise ValueError("Unknown modality " + modality)

        data_in = data.view(-1, length, data.size(2), data.size(3))
        if is_shift:
            data_in = data_in.view(batch_size * num_crop, this_test_segments, length,
                                   data_in.size(2), data_in.size(3))
        rst = net(data_in)
        # Average the scores over crops.
        rst = rst.reshape(batch_size, num_crop, -1).mean(1)

        if args.softmax:
            # take the softmax to normalize the output to probability
            rst = F.softmax(rst, dim=1)

        rst = rst.data.cpu().numpy().copy()

        if is_shift:
            rst = rst.reshape(batch_size, num_class)
        else:
            rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))

    return i, rst, label
# ---------------------------------------------------------------------------
# Run every ensemble member over the test set in lock-step, combine the
# weighted scores, and track running top-1 / top-5 precision.
# ---------------------------------------------------------------------------
proc_start_time = time.time()
# max_num is a number of videos (its default, total_num, is a video count).
max_num = args.max_num if args.max_num > 0 else total_num

top1 = AverageMeter()
top5 = AverageMeter()

for i, data_label_pairs in enumerate(zip(*data_iter_list)):
    # Compare against the number of videos already consumed, not the batch
    # index; with batch_size == 1 these are the same.
    if i * args.batch_size >= max_num:
        break
    with torch.no_grad():
        this_rst_list = []
        this_label = None
        for n_seg, (_, (data, label)), net, modality in zip(test_segments_list, data_label_pairs,
                                                             net_list, modality_list):
            _, scores, _ = eval_video((i, data, label), net, n_seg, modality)
            this_rst_list.append(scores)
            this_label = label

        assert len(this_rst_list) == len(coeff_list)
        weighted = [scores * coeff for scores, coeff in zip(this_rst_list, coeff_list)]
        ensembled_predict = sum(weighted) / len(weighted)

        for p, g in zip(ensembled_predict, this_label.cpu().numpy()):
            output.append([p[None, ...], g])

        cnt_time = time.time() - proc_start_time
        prec1, prec5 = accuracy(torch.from_numpy(ensembled_predict), this_label, topk=(1, 5))
        top1.update(prec1.item(), this_label.numel())
        top5.update(prec5.item(), this_label.numel())

        if i % 20 == 0:
            # Batch i is finished, so (i + 1) batches' worth of videos are done.
            videos_done = min((i + 1) * args.batch_size, total_num)
            print('video {} done, total {}/{}, average {:.3f} sec/video, '
                  'moving Prec@1 {:.3f} Prec@5 {:.3f}'.format(videos_done, videos_done, total_num,
                                                              float(cnt_time) / (i + 1) / args.batch_size,
                                                              top1.avg, top5.avg))
# ---------------------------------------------------------------------------
# Turn the collected scores into predictions, optionally write a submission
# CSV, and report per-class / overall accuracy.
# ---------------------------------------------------------------------------
video_pred = [np.argmax(x[0]) for x in output]
video_pred_top5 = [np.argsort(np.mean(x[0], axis=0).reshape(-1))[::-1][:5] for x in output]
video_labels = [x[1] for x in output]

if args.csv_file is not None:
    print('=> Writing result to csv file: {}'.format(args.csv_file))
    # Video names and the category file are both derived from the test list.
    if test_file_list[0] is None:
        raise ValueError('--csv_file requires --test_list (video names are read from it)')
    with open(test_file_list[0].replace('test_videofolder.txt', 'category.txt')) as cat_file:
        categories = [line.strip() for line in cat_file]
    with open(test_file_list[0]) as list_file:
        vid_names = [line.split(' ')[0] for line in list_file]
    if len(vid_names) != len(video_pred):
        raise ValueError('Test list has {} videos but {} predictions were made'.format(
            len(vid_names), len(video_pred)))
    with open(args.csv_file, 'w') as f:
        if args.dataset != 'somethingv2':  # only output top1
            for n, pred in zip(vid_names, video_pred):
                f.write('{};{}\n'.format(n, categories[pred]))
        else:
            # join() also copes with fewer than 5 classes, unlike a fixed format string.
            for n, pred5 in zip(vid_names, video_pred_top5):
                f.write(';'.join([n] + [str(p) for p in pred5]) + '\n')

cf = confusion_matrix(video_labels, video_pred).astype(float)
np.save('cm.npy', cf)
cls_cnt = cf.sum(axis=1)
cls_hit = np.diag(cf)
# Classes that only appear among predictions have no ground-truth samples;
# mark them NaN instead of dividing by zero, and skip them in the means.
has_samples = cls_cnt > 0
cls_acc = np.divide(cls_hit, cls_cnt, out=np.full_like(cls_hit, np.nan), where=has_samples)
print(cls_acc)
upper = np.nanmean(np.divide(np.max(cf, axis=1), cls_cnt,
                             out=np.full_like(cls_hit, np.nan), where=has_samples))
print('upper bound: {}'.format(upper))
print('-----Evaluation is finished------')
print('Class Accuracy {:.02f}%'.format(np.nanmean(cls_acc) * 100))
print('Overall Prec@1 {:.02f}% Prec@5 {:.02f}%'.format(top1.avg, top5.avg))