-
Notifications
You must be signed in to change notification settings - Fork 114
/
amc_fine_tune.py
executable file
·225 lines (183 loc) · 8.19 KB
/
amc_fine_tune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
# Code for "AMC: AutoML for Model Compression and Acceleration on Mobile Devices"
# Yihui He*, Ji Lin*, Zhijian Liu, Hanrui Wang, Li-Jia Li, Song Han
# {jilin, songhan}@mit.edu
import os
import time
import argparse
import shutil
import math
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tensorboardX import SummaryWriter
from lib.utils import accuracy, AverageMeter, progress_bar, get_output_folder
from lib.data import get_dataset
from lib.net_measure import measure_model
def parse_args():
parser = argparse.ArgumentParser(description='AMC fine-tune script')
parser.add_argument('--model', default='mobilenet', type=str, help='name of the model to train')
parser.add_argument('--dataset', default='imagenet', type=str, help='name of the dataset to train')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--n_gpu', default=1, type=int, help='number of GPUs to use')
parser.add_argument('--batch_size', default=128, type=int, help='batch size')
parser.add_argument('--n_worker', default=4, type=int, help='number of data loader worker')
parser.add_argument('--lr_type', default='exp', type=str, help='lr scheduler (exp/cos/step3/fixed)')
parser.add_argument('--n_epoch', default=150, type=int, help='number of epochs to train')
parser.add_argument('--wd', default=4e-5, type=float, help='weight decay')
parser.add_argument('--seed', default=None, type=int, help='random seed to set')
parser.add_argument('--data_root', default=None, type=str, help='dataset path')
# resume
parser.add_argument('--ckpt_path', default=None, type=str, help='checkpoint path to resume from')
# run eval
parser.add_argument('--eval', action='store_true', help='Simply run eval')
return parser.parse_args()
def get_model():
print('=> Building model..')
if args.model == 'mobilenet':
from models.mobilenet import MobileNet
net = MobileNet(n_class=1000)
elif args.model == 'mobilenet_0.5flops':
from models.mobilenet import MobileNet
net = MobileNet(n_class=1000, profile='0.5flops')
else:
raise NotImplementedError
return net.cuda() if use_cuda else net
def train(epoch, train_loader):
print('\nEpoch: %d' % epoch)
net.train()
batch_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
end = time.time()
for batch_idx, (inputs, targets) in enumerate(train_loader):
if use_cuda:
inputs, targets = inputs.cuda(), targets.cuda()
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
# measure accuracy and record loss
prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update(prec1.item(), inputs.size(0))
top5.update(prec5.item(), inputs.size(0))
# timing
batch_time.update(time.time() - end)
end = time.time()
progress_bar(batch_idx, len(train_loader), 'Loss: {:.3f} | Acc1: {:.3f}% | Acc5: {:.3f}%'
.format(losses.avg, top1.avg, top5.avg))
writer.add_scalar('loss/train', losses.avg, epoch)
writer.add_scalar('acc/train_top1', top1.avg, epoch)
writer.add_scalar('acc/train_top5', top5.avg, epoch)
def test(epoch, test_loader, save=True):
global best_acc
net.eval()
batch_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
end = time.time()
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(test_loader):
if use_cuda:
inputs, targets = inputs.cuda(), targets.cuda()
outputs = net(inputs)
loss = criterion(outputs, targets)
# measure accuracy and record loss
prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update(prec1.item(), inputs.size(0))
top5.update(prec5.item(), inputs.size(0))
# timing
batch_time.update(time.time() - end)
end = time.time()
progress_bar(batch_idx, len(test_loader), 'Loss: {:.3f} | Acc1: {:.3f}% | Acc5: {:.3f}%'
.format(losses.avg, top1.avg, top5.avg))
if save:
writer.add_scalar('loss/test', losses.avg, epoch)
writer.add_scalar('acc/test_top1', top1.avg, epoch)
writer.add_scalar('acc/test_top5', top5.avg, epoch)
is_best = False
if top1.avg > best_acc:
best_acc = top1.avg
is_best = True
print('Current best acc: {}'.format(best_acc))
save_checkpoint({
'epoch': epoch,
'model': args.model,
'dataset': args.dataset,
'state_dict': net.module.state_dict() if isinstance(net, nn.DataParallel) else net.state_dict(),
'acc': top1.avg,
'optimizer': optimizer.state_dict(),
}, is_best, checkpoint_dir=log_dir)
def adjust_learning_rate(optimizer, epoch):
if args.lr_type == 'cos': # cos without warm-up
lr = 0.5 * args.lr * (1 + math.cos(math.pi * epoch / args.n_epoch))
elif args.lr_type == 'exp':
step = 1
decay = 0.96
lr = args.lr * (decay ** (epoch // step))
elif args.lr_type == 'fixed':
lr = args.lr
else:
raise NotImplementedError
print('=> lr: {}'.format(lr))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
return lr
def save_checkpoint(state, is_best, checkpoint_dir='.'):
filename = os.path.join(checkpoint_dir, 'ckpt.pth.tar')
print('=> Saving checkpoint to {}'.format(filename))
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, filename.replace('.pth.tar', '.best.pth.tar'))
if __name__ == '__main__':
args = parse_args()
use_cuda = torch.cuda.is_available()
if use_cuda:
torch.backends.cudnn.benchmark = True
best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
if args.seed is not None:
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
print('=> Preparing data..')
train_loader, val_loader, n_class = get_dataset(args.dataset, args.batch_size, args.n_worker,
data_root=args.data_root)
net = get_model() # for measure
IMAGE_SIZE = 224 if args.dataset == 'imagenet' else 32
n_flops, n_params = measure_model(net, IMAGE_SIZE, IMAGE_SIZE)
print('=> Model Parameter: {:.3f} M, FLOPs: {:.3f}M'.format(n_params / 1e6, n_flops / 1e6))
del net
net = get_model() # real training
if args.ckpt_path is not None: # assigned checkpoint path to resume from
print('=> Resuming from checkpoint..')
checkpoint = torch.load(args.ckpt_path)
sd = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
net.load_state_dict(sd)
if use_cuda and args.n_gpu > 1:
net = torch.nn.DataParallel(net, list(range(args.n_gpu)))
criterion = nn.CrossEntropyLoss()
print('Using SGD...')
print('weight decay = {}'.format(args.wd))
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.wd)
if args.eval: # just run eval
print('=> Start evaluation...')
test(0, val_loader, save=False)
else: # train
print('=> Start training...')
print('Training {} on {}...'.format(args.model, args.dataset))
log_dir = get_output_folder('./logs', '{}_{}_finetune'.format(args.model, args.dataset))
print('=> Saving logs to {}'.format(log_dir))
# tf writer
writer = SummaryWriter(logdir=log_dir)
for epoch in range(start_epoch, start_epoch + args.n_epoch):
lr = adjust_learning_rate(optimizer, epoch)
train(epoch, train_loader)
test(epoch, val_loader)
writer.close()
print('=> Model Parameter: {:.3f} M, FLOPs: {:.3f}M, best top-1 acc: {}%'.format(n_params / 1e6, n_flops / 1e6, best_acc))