evaluate.py
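
"""Evaluate a trained facial-expression-recognition model on the FER2013 splits.

The script builds the network specified by the command-line hyperparameters,
loads the train/validation/test dataloaders without augmentation, and prints
top-1/top-2 accuracy, average loss, micro-averaged precision/recall/F1 and the
confusion matrix for each split.
"""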
import sys
import warnings

import torch
import torch.nn as nn
from sklearn.metrics import precision_score, f1_score, recall_score, confusion_matrix

from data.fer2013 import get_dataloaders
from utils.hparams import setup_hparams
from utils.setup_network import setup_network

warnings.filterwarnings("ignore")

device = torch.device("cpu")


def correct_count(output, target, topk=(1,)):
    """Computes the top-k correct count for the specified values of k."""
    maxk = max(topk)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        res.append(correct_k)
    return res


def evaluate(net, dataloader, criterion):
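    """Run one full pass over `dataloader` on CPU and print evaluation metrics.

    The evaluation transform is expected to return several crops per image
    (hence the extra `ncrops` dimension); logits are averaged over the crops
    before the loss, top-1/top-2 accuracy, precision, recall, F1 and confusion
    matrix are computed.
    """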
    net = net.eval()
    loss_tr, n_samples = 0.0, 0.0

    y_pred = []
    y_gt = []

    correct_count1 = 0
    correct_count2 = 0

    for data in dataloader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
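
        # `inputs` arrives as (batch, ncrops, channels, height, width) because
        # the evaluation transform produces several crops per image.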
        # fuse crops and batchsize
        bs, ncrops, c, h, w = inputs.shape
        inputs = inputs.view(-1, c, h, w)

        # forward
        outputs = net(inputs)

        # combine results across the crops
        outputs = outputs.view(bs, ncrops, -1)
        outputs = torch.sum(outputs, dim=1) / ncrops

        loss = criterion(outputs, labels)

        # calculate performance metrics
        loss_tr += loss.item()

        # accuracy
        counts = correct_count(outputs, labels, topk=(1, 2))
        correct_count1 += counts[0].item()
        correct_count2 += counts[1].item()

        _, preds = torch.max(outputs.data, 1)
        preds = preds.to("cpu")
        labels = labels.to("cpu")
        n_samples += labels.size(0)

        y_pred.extend(pred.item() for pred in preds)
        y_gt.extend(y.item() for y in labels)

    acc1 = 100 * correct_count1 / n_samples
    acc2 = 100 * correct_count2 / n_samples
    loss = loss_tr / n_samples

    print("--------------------------------------------------------")
    print("Top 1 Accuracy: %2.6f %%" % acc1)
    print("Top 2 Accuracy: %2.6f %%" % acc2)
    print("Loss: %2.6f" % loss)
print("Precision: %2.6f" % precision_score(y_gt, y_pred, average='micro'))
print("Recall: %2.6f" % recall_score(y_gt, y_pred, average='micro'))
print("F1 Score: %2.6f" % f1_score(y_gt, y_pred, average='micro'))
print("Confusion Matrix:\n", confusion_matrix(y_gt, y_pred), '\n')
if __name__ == "__main__":
    # Important parameters
    hps = setup_hparams(sys.argv[1:])

    # build network
    logger, net = setup_network(hps)
    net = net.to(device)
    print(net)

    criterion = nn.CrossEntropyLoss()

    # Get data with no augmentation
    trainloader, valloader, testloader = get_dataloaders(augment=False)

    print("Train")
    evaluate(net, trainloader, criterion)

    print("Val")
    evaluate(net, valloader, criterion)

    print("Test")
    evaluate(net, testloader, criterion)