-
Notifications
You must be signed in to change notification settings - Fork 528
/
calibrate_torchvision_model.py
77 lines (59 loc) · 2.65 KB
/
calibrate_torchvision_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import argparse
import os
from tqdm import tqdm
from PIL import Image
import torch
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.models.quantization import *
class CalibrationDataset(Dataset):
    """Map-style dataset of calibration images listed in a text file.

    Args:
        root: Directory that the listed image file names are joined onto.
        files: Path to a text file containing one image file name per line.
            Surrounding whitespace is stripped and blank lines are ignored.
        transform: Callable applied to each RGB ``PIL.Image`` before it is
            returned (``main`` passes a torchvision ``transforms.Compose``).
    """

    def __init__(self, root, files, transform):
        with open(files, 'r', encoding='utf-8') as f:
            names = (line.strip() for line in f)
            # Skip blank lines (e.g. an extra newline at end of file);
            # otherwise they become os.path.join(root, '') entries that
            # inflate len() and fail when loaded in __getitem__.
            self.files = [os.path.join(root, name) for name in names if name]
        self.transform = transform

    def __getitem__(self, idx):
        """Load image ``idx``, convert it to RGB and apply ``transform``."""
        # Use a context manager so the underlying file handle is closed
        # immediately; ``convert`` returns a new image, which stays usable
        # after the source image is closed.
        with Image.open(self.files[idx]) as img:
            image = img.convert('RGB')
        return self.transform(image)

    def __len__(self):
        """Return the number of listed image files."""
        return len(self.files)
def quantize_model(model, dataloader, backend='fbgemm'):
    """Apply post-training static quantization to ``model`` in place.

    Selects the quantized engine ``backend``, moves the model to CPU in eval
    mode, calls ``model.fuse_model()``, attaches the default qconfig for
    ``backend``, inserts observers with ``torch.quantization.prepare``, runs
    every batch from ``dataloader`` through the model to calibrate them, and
    finally converts the model with ``torch.quantization.convert``.

    Args:
        model: A quantizable model exposing ``fuse_model()`` (``main`` passes
            a ``torchvision.models.quantization`` builder's output created
            with ``quantize=False``). Modified in place.
        dataloader: Iterable yielding CPU input batches for calibration.
        backend: Quantized engine name; must appear in
            ``torch.backends.quantized.supported_engines``.

    Returns:
        The same ``model`` object, now converted. The conversion happens in
        place; the return value is only a convenience.

    Raises:
        RuntimeError: If ``backend`` is not a supported quantized engine.
    """
    supported = torch.backends.quantized.supported_engines
    if backend not in supported:
        raise RuntimeError(
            f"Quantized backend not supported: {backend!r} (supported: {supported})"
        )
    torch.backends.quantized.engine = backend
    model.cpu()
    model.eval()
    model.fuse_model()
    # Make sure that weight qconfig matches that of the serialized models
    model.qconfig = torch.quantization.get_default_qconfig(backend)
    torch.quantization.prepare(model, inplace=True)
    print('calibrating...')
    # Calibration only needs forward passes so the observers can record
    # activation statistics; disable autograd so no backward graph is built.
    with torch.no_grad():
        for x in tqdm(dataloader):
            model(x)
    print('calibration DONE!')
    torch.quantization.convert(model, inplace=True)
    return model
def main():
    """Calibrate and quantize a torchvision model, then save it as TorchScript.

    Command-line options:
        --model       name of a builder in ``torchvision.models.quantization``
        --image-dir   directory the calibration image names are relative to
        --image-list  text file listing the calibration image names
        --backend     quantized engine passed to ``quantize_model``
        --batch-size  calibration batch size
        --workers     number of DataLoader worker processes
        --output      path for the saved TorchScript file (default: <model>.pt)
    """
    # Local import so the model builder can be looked up by name instead of
    # eval()-ing an arbitrary command-line string.
    import torchvision.models.quantization as quantized_models

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='resnet50')
    parser.add_argument('--image-dir', type=str, default='imagenet/val')
    parser.add_argument('--image-list', type=str, default='../../calibration/ImageNet/cal_image_list_option_1.txt')
    parser.add_argument('--backend', type=str, default='fbgemm')
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--workers', type=int, default=0)
    parser.add_argument('--output', type=str, default=None)
    args = parser.parse_args()
    print(args)

    builder = getattr(quantized_models, args.model, None)
    # Reject submodules, private names and anything else that is not a builder.
    if args.model.startswith('_') or not callable(builder):
        parser.error(f'unknown quantizable model: {args.model!r}')

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    dataset = CalibrationDataset(root=args.image_dir, files=args.image_list, transform=transform)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.workers)
    model = builder(pretrained=True, progress=True, quantize=False)
    quantize_model(model, dataloader, backend=args.backend)
    print(model)

    # Trace with a random input of the size the transform produces (224x224
    # after CenterCrop). The values only matter if the model has
    # data-dependent control flow, which tracing cannot capture anyway.
    example_input = torch.rand(1, 3, 224, 224)
    script_module = torch.jit.trace(model, example_input)
    save_path = args.output or f'{args.model}.pt'
    torch.jit.save(script_module, save_path)
    print(f'saved: {save_path}')


if __name__ == '__main__':
    main()