-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval.py
92 lines (75 loc) · 3.46 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
from model.model import Encoder
import numpy as np
import argparse
from data_loader.load_images import ImageList
import data_loader.transforms as transforms
class AverageMeter(object):
    """Accumulate a running total and count and expose their ratio.

    ``update(val, n)`` adds ``val`` to the running sum and ``n`` to the
    running count without weighting, so ``val`` is expected to already be a
    total over ``n`` items (for example, the number of correct predictions in
    a batch of size ``n``) rather than a per-item mean.

    Attributes:
        val: the most recent ``val`` passed to ``update`` (0 after reset).
        sum: running total of every ``val``.
        count: running total of every ``n``.
        avg: ``sum / count``; stays 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Add ``val`` (a total over ``n`` items) to the running statistics."""
        self.val = val
        self.sum += val
        self.count += n
        # Guard against n == 0 on a fresh meter (e.g. an empty batch), which
        # would otherwise raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0
if __name__ == '__main__':
    # Evaluate a trained Encoder checkpoint on the test list of one target
    # domain and print the number of correct predictions and the accuracy.
    parser = argparse.ArgumentParser(description='MemSAC')
    # NOTE(review): --gpu_id and --multi_gpu are parsed but never read in this
    # script; the network is always moved to the device with a plain .cuda().
    parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
    parser.add_argument('--dataset', type=str, nargs='?', default='c', help="target dataset")
    parser.add_argument('--target', type=str, nargs='?', default='c', help="target domain")
    parser.add_argument('--batch_size', type=int, default=64, help="batch size should be samples * classes")
    parser.add_argument('--nClasses', type=int, help="#Classes")
    # torch.load(None) can never succeed, so fail at argument parsing instead.
    parser.add_argument('--checkpoint', type=str, required=True, help="Checkpoint to load from.")
    parser.add_argument('--multi_gpu', type=int, default=0, help="use dataparallel if 1")
    parser.add_argument('--data_dir', required=True)
    parser.add_argument('--resnet', default="resnet50", help="Resnet backbone")
    parser.add_argument('--num_workers', type=int, default=16, help="number of DataLoader worker processes")
    args = parser.parse_args()

    # Test image-list files, keyed by dataset name and then by domain name.
    test_file_paths = {
        "cub2011": {
            "cub": "./data_files/cub200/cub200_2011.txt",
            "drawing": "./data_files/cub200/cub200_drawing.txt",
        },
        "domainNet": {
            "real": "./data_files/DomainNet/real_test.txt",
            "sketch": "./data_files/DomainNet/sketch_test.txt",
            "painting": "./data_files/DomainNet/painting_test.txt",
            "clipart": "./data_files/DomainNet/clipart_test.txt",
        },
    }
    if args.dataset not in test_file_paths:
        raise NotImplementedError("Unsupported dataset: {}".format(args.dataset))
    file_path = test_file_paths[args.dataset]
    if args.target not in file_path:
        parser.error("unknown target domain {!r} for dataset {!r}; choose from: {}".format(
            args.target, args.dataset, ", ".join(sorted(file_path))))
    dataset_test = file_path[args.target]
    print("Target", args.target)

    # Read the image list and close the file before building the dataset.
    with open(dataset_test) as list_file:
        image_lines = list_file.readlines()
    dataset_list = ImageList(args.data_dir, image_lines,
                             transform=transforms.image_test(resize_size=256, crop_size=224))
    print("Size of target dataset:", len(dataset_list))
    test_loader = torch.utils.data.DataLoader(dataset_list, batch_size=args.batch_size, shuffle=False,
                                              num_workers=args.num_workers, drop_last=False)

    # network construction
    print(args.nClasses)
    base_network = Encoder(args.resnet, 256, args.nClasses).cuda()
    accuracy = AverageMeter()
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    saved_state_dict = torch.load(args.checkpoint)
    base_network.load_state_dict(saved_state_dict, strict=True)
    base_network.eval()

    num_batches = len(test_loader)
    with torch.no_grad():
        # Iterate the loader directly: its iterator exposes __next__ only, so the
        # old iter_test.next() call fails on Python 3 / recent PyTorch.
        for i, data in enumerate(test_loader):
            print("{0}/{1}".format(i, num_batches), end="\r")
            inputs = data[0].cuda()
            labels = data[1].cuda()
            # The second output of the network is argmax'd over dim 1 to get
            # one predicted class per sample.
            _, outputs = base_network(inputs)
            predictions = outputs.argmax(1)
            correct = torch.sum((predictions == labels).float())
            # `correct` is a per-batch total, so the batch size is the count.
            accuracy.update(correct, len(outputs))
    print_str = "\nCorrect Predictions: {}/{}".format(int(accuracy.sum), accuracy.count)
    print_str1 = '\ntest_acc:{:.4f}'.format(accuracy.avg)
    print(print_str + print_str1)