-
Notifications
You must be signed in to change notification settings - Fork 202
/
finetune.py
277 lines (218 loc) · 9.37 KB
/
finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
import torch
from torch.autograd import Variable
from torchvision import models
import cv2
import sys
import numpy as np
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import dataset
from prune import *
import argparse
from operator import itemgetter
from heapq import nsmallest
import time
class ModifiedVGG16Model(torch.nn.Module):
    """VGG16 with a frozen pretrained feature extractor and a new 2-way head.

    ``features`` is taken from torchvision's pretrained VGG16 and all of its
    parameters are frozen; ``classifier`` is freshly initialised and is the
    only trainable part right after construction.
    """

    def __init__(self):
        super(ModifiedVGG16Model, self).__init__()

        backbone = models.vgg16(pretrained=True)
        self.features = backbone.features

        # Freeze the pretrained convolutional weights.
        for weight in self.features.parameters():
            weight.requires_grad = False

        # 25088 == 512 * 7 * 7, the flattened feature size; presumably this
        # assumes 224x224 inputs -- TODO confirm against the data loader.
        head = [
            nn.Dropout(),
            nn.Linear(25088, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 2),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Extract features, flatten them per sample and classify."""
        feature_maps = self.features(x)
        flat = feature_maps.view(feature_maps.size(0), -1)
        return self.classifier(flat)
class FilterPrunner:
    """Rank convolutional filters with a first-order Taylor criterion.

    ``forward`` runs ``model.features`` module by module, keeps every Conv2d
    output and registers a gradient hook on it. During backward,
    ``compute_rank`` accumulates ``mean(activation * grad)`` per output
    channel into ``filter_ranks``. The ranks can then be normalised per layer
    and turned into an ordered pruning plan of ``(layer, filter)`` pairs.
    """

    def __init__(self, model):
        # ``model`` must expose ``features`` (iterated via ``_modules``) and
        # ``classifier``.
        self.model = model
        self.reset()

    def reset(self):
        """Forget all accumulated filter ranks."""
        # activation index -> 1-D tensor with one accumulated score per filter
        self.filter_ranks = {}

    def forward(self, x):
        """Run the model while recording Conv2d activations for ranking.

        Returns the classifier applied to the flattened feature output.
        """
        self.activations = []
        self.gradients = []  # kept for attribute compatibility; unused here
        self.grad_index = 0
        # activation index -> position of that Conv2d inside model.features
        self.activation_to_layer = {}

        activation_index = 0
        for layer, (_, module) in enumerate(self.model.features._modules.items()):
            x = module(x)
            if isinstance(module, torch.nn.modules.conv.Conv2d):
                x.register_hook(self.compute_rank)
                self.activations.append(x)
                self.activation_to_layer[activation_index] = layer
                activation_index += 1

        return self.model.classifier(x.view(x.size(0), -1))

    def compute_rank(self, grad):
        """Gradient hook: accumulate the Taylor score of one Conv2d output.

        The index arithmetic assumes hooks fire in reverse forward order, so
        the n-th call belongs to the n-th recorded activation from the end.
        """
        activation_index = len(self.activations) - self.grad_index - 1
        activation = self.activations[activation_index]

        # Average activation * gradient for every filter (dim 1) across the
        # batch and spatial dimensions.
        taylor = (activation * grad).mean(dim=(0, 2, 3)).data

        if activation_index not in self.filter_ranks:
            # Allocate on the same device/dtype as the scores themselves,
            # rather than consulting a script-level ``args`` global.
            self.filter_ranks[activation_index] = torch.zeros_like(taylor)

        self.filter_ranks[activation_index] += taylor
        self.grad_index += 1

    def lowest_ranking_filters(self, num):
        """Return the ``num`` lowest-ranked filters, smallest first.

        Each entry is ``(layer_index, filter_index, rank)``.
        """
        data = []
        for i in sorted(self.filter_ranks):
            layer = self.activation_to_layer[i]
            ranks = self.filter_ranks[i]
            for j in range(ranks.size(0)):
                data.append((layer, j, ranks[j]))
        return nsmallest(num, data, itemgetter(2))

    def normalize_ranks_per_layer(self):
        """L2-normalise the absolute ranks of each layer and move them to CPU.

        A layer whose ranks are all zero is left as zeros instead of NaN.
        """
        for i in self.filter_ranks:
            v = torch.abs(self.filter_ranks[i])
            # torch.sqrt rather than np.sqrt: NumPy cannot consume CUDA tensors.
            norm = torch.sqrt(torch.sum(v * v))
            if norm > 0:
                v = v / norm
            self.filter_ranks[i] = v.cpu()

    def get_prunning_plan(self, num_filters_to_prune):
        """Return ``(layer_index, filter_index)`` pairs to prune sequentially.

        Filters are removed one at a time, and each removal shifts the later
        filter indices of the same layer down by one. Each layer's targets
        are therefore sorted and offset by the number removed before them.
        """
        filters_to_prune = self.lowest_ranking_filters(num_filters_to_prune)

        filters_to_prune_per_layer = {}
        for layer, filter_index, _ in filters_to_prune:
            filters_to_prune_per_layer.setdefault(layer, []).append(filter_index)

        plan = []
        for layer, indices in filters_to_prune_per_layer.items():
            for shift, filter_index in enumerate(sorted(indices)):
                plan.append((layer, filter_index - shift))
        return plan
class PrunningFineTuner_VGG16:
    """Fine-tune a VGG16-style model and iteratively prune its conv filters.

    Filters are ranked with ``FilterPrunner`` and removed with
    ``prune_vgg16_conv_layer``. The training/eval methods read the
    script-level ``args.use_cuda`` flag, so ``args`` must be defined at
    module level before they run.
    """

    def __init__(self, train_path, test_path, model):
        self.train_data_loader = dataset.loader(train_path)
        self.test_data_loader = dataset.test_loader(test_path)

        self.model = model
        self.criterion = torch.nn.CrossEntropyLoss()
        self.prunner = FilterPrunner(self.model)
        self.model.train()

    def test(self):
        """Evaluate on the test loader; print and return the accuracy.

        Returns the accuracy as a float, or ``None`` if the loader yields no
        samples. The model is switched back to training mode afterwards.
        """
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for batch, label in self.test_data_loader:
                if args.use_cuda:
                    batch = batch.cuda()
                # Use the current (possibly pruned) model, not a global.
                output = self.model(Variable(batch))
                pred = output.data.max(1)[1]
                correct += pred.cpu().eq(label).sum().item()
                total += label.size(0)

        self.model.train()

        if total == 0:
            print("Accuracy : no test samples")
            return None
        accuracy = float(correct) / total
        print("Accuracy :", accuracy)
        return accuracy

    def train(self, optimizer=None, epoches=10):
        """Train for ``epoches`` epochs, evaluating after each one.

        Without an explicit optimizer, SGD (lr=0.0001, momentum=0.9) over the
        classifier parameters only is used.
        """
        if optimizer is None:
            optimizer = optim.SGD(self.model.classifier.parameters(), lr=0.0001, momentum=0.9)

        for i in range(epoches):
            print("Epoch: ", i)
            self.train_epoch(optimizer)
            self.test()
        print("Finished fine tuning.")

    def train_batch(self, optimizer, batch, label, rank_filters):
        """Run one forward/backward pass on a single batch.

        With ``rank_filters`` the forward goes through the prunner so backward
        accumulates filter ranks, and no optimizer step is taken (``optimizer``
        may be ``None``). Otherwise an optimizer step follows the backward.
        """
        if args.use_cuda:
            batch = batch.cuda()
            label = label.cuda()

        self.model.zero_grad()
        inputs = Variable(batch)
        target = Variable(label)

        if rank_filters:
            output = self.prunner.forward(inputs)
            self.criterion(output, target).backward()
        else:
            self.criterion(self.model(inputs), target).backward()
            optimizer.step()

    def train_epoch(self, optimizer=None, rank_filters=False):
        """Feed every training batch through ``train_batch``."""
        for batch, label in self.train_data_loader:
            self.train_batch(optimizer, batch, label, rank_filters)

    def get_candidates_to_prune(self, num_filters_to_prune):
        """Rank filters over one training epoch and return a pruning plan."""
        # ``prune`` rebinds ``self.model``; keep the prunner pointed at it.
        self.prunner.model = self.model
        self.prunner.reset()
        self.train_epoch(rank_filters=True)
        self.prunner.normalize_ranks_per_layer()
        return self.prunner.get_prunning_plan(num_filters_to_prune)

    def total_num_filters(self):
        """Return the summed ``out_channels`` of all Conv2d modules in ``features``."""
        return sum(
            module.out_channels
            for module in self.model.features._modules.values()
            if isinstance(module, torch.nn.modules.conv.Conv2d)
        )

    def prune(self):
        """Iteratively prune roughly two thirds of the conv filters.

        Each iteration ranks filters, removes 512 of them and fine-tunes for
        10 epochs. A final 15-epoch fine-tune follows, and the resulting state
        dict is saved to ``model_prunned``.
        """
        # Get the accuracy before prunning
        self.test()
        self.model.train()

        # Make sure all the layers are trainable
        for param in self.model.features.parameters():
            param.requires_grad = True

        number_of_filters = self.total_num_filters()
        num_filters_to_prune_per_iteration = 512
        iterations = int(float(number_of_filters) / num_filters_to_prune_per_iteration)
        iterations = int(iterations * 2.0 / 3)

        print("Number of prunning iterations to reduce 67% filters", iterations)

        # Stays None when there are no iterations, so the final fine-tune
        # falls back to train()'s default optimizer instead of failing.
        optimizer = None
        for _ in range(iterations):
            print("Ranking filters.. ")
            prune_targets = self.get_candidates_to_prune(num_filters_to_prune_per_iteration)

            layers_prunned = {}
            for layer_index, _ in prune_targets:
                layers_prunned[layer_index] = layers_prunned.get(layer_index, 0) + 1
            print("Layers that will be prunned", layers_prunned)

            print("Prunning filters.. ")
            model = self.model.cpu()
            for layer_index, filter_index in prune_targets:
                model = prune_vgg16_conv_layer(model, layer_index, filter_index, use_cuda=args.use_cuda)
            self.model = model
            if args.use_cuda:
                self.model = self.model.cuda()

            # NOTE: this value is the percentage of filters that *remain*.
            message = str(100 * float(self.total_num_filters()) / number_of_filters) + "%"
            print("Filters prunned", str(message))
            self.test()

            print("Fine tuning to recover from prunning iteration.")
            optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
            self.train(optimizer, epoches=10)

        print("Finished. Going to fine tune the model a bit more")
        self.train(optimizer, epoches=15)
        # Save the current model; the loop-local ``model`` is unbound when
        # ``iterations`` is 0.
        torch.save(self.model.state_dict(), "model_prunned")
def get_args():
    """Parse this script's command line.

    Returns an ``argparse.Namespace`` holding the ``train``/``prune`` mode
    flags, the ``train_path``/``test_path`` dataset locations, and
    ``use_cuda``, which is true only when ``--use-cuda`` was passed *and*
    ``torch.cuda.is_available()`` reports a device.
    """
    parser = argparse.ArgumentParser()

    for flag, destination in (("--train", "train"), ("--prune", "prune")):
        parser.add_argument(flag, dest=destination, action="store_true")

    for flag, fallback in (("--train_path", "train"), ("--test_path", "test")):
        parser.add_argument(flag, type=str, default=fallback)

    parser.add_argument('--use-cuda', action='store_true', default=False,
                        help='Use NVIDIA GPU acceleration')
    parser.set_defaults(train=False, prune=False)

    parsed = parser.parse_args()
    if parsed.use_cuda and not torch.cuda.is_available():
        # GPU requested but none present: fall back to the CPU.
        parsed.use_cuda = False
    return parsed
if __name__ == '__main__':
    # ``args`` is deliberately module-level: the fine-tuning code in this
    # file reads ``args.use_cuda`` as a global.
    args = get_args()

    if args.train:
        model = ModifiedVGG16Model()
    elif args.prune:
        # NOTE(review): torch.load unpickles a whole model object; only load
        # checkpoints that come from a trusted source.
        model = torch.load("model", map_location=lambda storage, loc: storage)
    else:
        # Without a mode, ``model`` would be undefined below (NameError).
        sys.exit("Nothing to do: pass --train or --prune")

    if args.use_cuda:
        model = model.cuda()

    fine_tuner = PrunningFineTuner_VGG16(args.train_path, args.test_path, model)

    if args.train:
        fine_tuner.train(epoches=10)
        torch.save(fine_tuner.model, "model")
    elif args.prune:
        fine_tuner.prune()