-
Notifications
You must be signed in to change notification settings - Fork 150
/
base_model.py
240 lines (207 loc) · 10.1 KB
/
base_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
import os
from abc import ABC, abstractmethod
from collections import OrderedDict
import torch
from torch.nn import DataParallel
from models import networks
from utils import util
class BaseModel(ABC):
"""This class is an abstract base class (ABC) for models.
To create a subclass, you need to implement the following five functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
-- <set_input>: unpack data from dataset and apply preprocessing.
-- <forward>: produce intermediate results.
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
"""
def __init__(self, opt):
"""Initialize the BaseModel class.
Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
When creating your custom class, you need to implement your own initialization.
In this fucntion, you should first call <BaseModel.__init__(self, opt)>
Then, you need to define four lists:
-- self.loss_names (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): specify the images that you want to display and save.
-- self.visual_names (str list): define networks used in our training.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
"""
self.opt = opt
self.gpu_ids = opt.gpu_ids
self.isTrain = opt.isTrain
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids \
else torch.device('cpu') # get device name: CPU or GPU
if opt.isTrain:
self.save_dir = os.path.join(opt.log_dir, 'checkpoints') # save all the checkpoints to save_dir
os.makedirs(self.save_dir, exist_ok=True)
if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
torch.backends.cudnn.benchmark = True
self.model_names = []
self.visual_names = []
self.image_paths = []
self.metric = 0 # used for learning rate policy 'plateau'
@staticmethod
def modify_commandline_options(parser, is_train):
"""Add new model-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
return parser
@abstractmethod
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): includes the data itself and its metadata information.
"""
pass
@abstractmethod
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
pass
@abstractmethod
def optimize_parameters(self, steps):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
pass
def setup(self, opt, verbose=True):
"""Load and print networks; create schedulers
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
self.load_networks(verbose=verbose)
if self.isTrain:
self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
if verbose:
self.print_networks()
def print_networks(self):
print('---------- Networks initialized -------------')
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
num_params = 0
for param in net.parameters():
num_params += param.numel()
print(net)
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
if hasattr(self.opt, 'log_dir'):
with open(os.path.join(self.opt.log_dir, 'net' + name + '.txt'), 'w') as f:
f.write(str(net) + '\n')
f.write('[Network %s] Total number of parameters : %.3f M\n' % (name, num_params / 1e6))
print('-----------------------------------------------')
def eval(self):
"""Make models eval mode during test time"""
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
net.eval()
def train(self):
"""Make models eval mode during test time"""
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
net.train()
def test(self):
"""Forward function used in test time.
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
It also calls <compute_visuals> to produce additional visualization results
"""
with torch.no_grad():
self.forward()
def get_image_paths(self):
""" Return image paths that are used to load current data"""
return self.image_paths
def update_learning_rate(self, epoch, total_iter, logger=None):
opt = logger.opt
old_lr = float(self.optimizers[0].param_groups[0]['lr'])
for scheduler in self.schedulers:
if self.opt.lr_policy == 'plateau':
scheduler.step(self.opt.metric)
else:
scheduler.step()
lr = self.optimizers[0].param_groups[0]['lr']
if logger is not None:
if opt.scheduler_counter == 'epoch' or abs(old_lr - lr) >= 1e-12:
logger.print_info('(epoch: %d, iters: %d) learning rate = %.7f\n' % (epoch, total_iter, lr))
if opt.scheduler_counter == 'epoch' or total_iter % opt.print_freq == 0:
logger.plot({'lr': lr}, total_iter)
else:
if opt.scheduler_counter == 'epoch' or abs(old_lr - lr) >= 1e-12:
print('(epoch: %d, iters: %d) learning rate = %.7f\n' % (epoch, total_iter, lr))
def get_current_visuals(self):
"""Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
visual_ret = OrderedDict()
for name in self.visual_names:
if isinstance(name, str) and hasattr(self, name):
visual_ret[name] = getattr(self, name)
return visual_ret
def get_current_losses(self):
errors_set = OrderedDict()
for name in self.loss_names:
if not hasattr(self, 'loss_' + name):
continue
key = name
def has_number(s):
for i in range(10):
if str(i) in s:
return True
return False
if has_number(key):
key = 'Specific_loss/' + key
elif key.startswith('D_'):
key = 'D_loss/' + key
elif key.startswith('G_'):
key = 'G_loss/' + key
else:
assert False
errors_set[key] = float(getattr(self, 'loss_' + name))
return errors_set
def load_networks(self, verbose=True):
for name in self.model_names:
net = getattr(self, 'net' + name, None)
path = getattr(self.opt, 'restore_%s_path' % name, None)
if path is not None:
util.load_network(net, path, verbose)
if self.isTrain:
if self.opt.restore_O_path is not None:
for i, optimizer in enumerate(self.optimizers):
path = '%s-%d.pth' % (self.opt.restore_O_path, i)
util.load_optimizer(optimizer, path, verbose)
for param_group in optimizer.param_groups:
param_group['lr'] = self.opt.lr
def save_networks(self, epoch):
for name in self.model_names:
if isinstance(name, str):
save_filename = '%s_net_%s.pth' % (epoch, name)
save_path = os.path.join(self.save_dir, save_filename)
net = getattr(self, 'net' + name)
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
if isinstance(net, DataParallel):
torch.save(net.module.cpu().state_dict(), save_path)
else:
torch.save(net.cpu().state_dict(), save_path)
net.cuda(self.gpu_ids[0])
else:
torch.save(net.cpu().state_dict(), save_path)
if self.isTrain:
for i, optimizer in enumerate(self.optimizers):
save_filename = '%s_optim-%d.pth' % (epoch, i)
save_path = os.path.join(self.save_dir, save_filename)
torch.save(optimizer.state_dict(), save_path)
def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
Parameters:
nets (network list) -- a list of networks
requires_grad (bool) -- whether the networks require gradients or not
"""
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad
def evaluate_model(self, step):
raise NotImplementedError
def profile(self):
raise NotImplementedError