-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_vision.py
214 lines (178 loc) · 8.44 KB
/
train_vision.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
# Copyright (c) 2024. Samsung Electronics Co., Ltd.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Based on https://github.com/pytorch/examples/blob/main/mnist/main.py
"""
Example usage (NiNo's checkpoint checkpoints/nino.pt is used by default):
python train_vision.py --task C10-32
"""
import argparse
import os.path
import numpy as np
import platform
import subprocess
import time
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import torch.backends.cudnn as cudnn
from optim import NiNo
from utils import set_seed, Net, TASKS, mem
def test(model, data, target, verbose=0):
model.eval()
with torch.no_grad():
output = model(data)
test_loss = F.cross_entropy(output, target).item()
correct = torch.sum(torch.argmax(output, -1).eq(target)).item()
acc = correct / len(data) * 100
if verbose > 2:
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(data), acc))
return {'loss': test_loss, 'acc': acc}
def parse_args():
    """Parse command-line options and print the run environment.

    Besides the parsed options, the returned namespace carries an ``env``
    attribute: a dict with the git commit (or the error text if it could not
    be queried), hostname, torch/torchvision versions, CUDA/cuDNN info and
    the start time. Both the environment and all arguments are printed.

    Returns:
        argparse.Namespace with the script options plus ``env``.
    """
    parser = argparse.ArgumentParser(description='PyTorch training with predicting future parameters using NiNo')
    add = parser.add_argument
    add('--nino_ckpt', type=str, default='checkpoints/nino.pt')
    add('--task', type=str, default='FM-16', help='see utils/vision.py for all the tasks')
    add('--lr', type=float, default=None)
    add('--wd', type=float, default=0)
    add('--batch_size', type=int, default=128)
    add('--num_workers', type=int, default=4)
    add('--max_steps', type=int, default=10000,
        help='maximum number of iterations to train, early stopping when target is reached')
    add('--period', type=int, default=1000,
        help='number of base opt steps after which to apply NiNo')
    add('--seed', type=int, default=1000, help='random seed')
    add('--verbose', type=int, default=0)
    add('--save_path', type=str, default=None)
    add('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    add('--log_interval', type=int, default=100,
        help='how many batches to wait before logging training status')
    args = parser.parse_args()

    row = '{:20s}: {}'  # shared layout for the "name: value" listings below

    print('\nEnvironment:')
    try:
        # print git commit to ease code reproducibility
        git_commit = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode('ascii').strip()
    except Exception as err:
        # best effort: record why the commit could not be determined
        git_commit = str(err)
    env = {
        'git commit': git_commit,
        'hostname': platform.node(),
        'torch': torch.__version__,
        'torchvision': torchvision.__version__,
        'cuda available': torch.cuda.is_available(),
        'cudnn enabled': cudnn.enabled,
        'cuda version': torch.version.cuda,
        'start time': time.strftime('%Y%m%d-%H%M%S'),
    }
    for key, value in env.items():
        print(row.format(key[:20], value))
    args.env = env

    print('\nScript Arguments:', flush=True)
    # `env` was attached above, so it is listed together with the options
    for key, value in sorted(vars(args).items()):
        print(row.format(key[:20], value))
    print('\n', flush=True)

    return args
def _load_dataset(name, train, transform):
    """Build the ``torchvision.datasets`` dataset called ``name`` under '../data'.

    The class is looked up by attribute access instead of ``eval`` so that
    the task config can never execute arbitrary code; dotted names
    (e.g. ``'mnist.MNIST'``) are resolved one attribute at a time.
    """
    dataset_cls = datasets
    for attr in str(name).split('.'):
        dataset_cls = getattr(dataset_cls, attr)
    return dataset_cls('../data', train=train, download=True, transform=transform)


def main():
    """Train a ``Net`` on the selected task with NiNo wrapped around AdamW.

    Steps:
      1. Parse arguments and look up the task in ``TASKS``.
      2. Build the train loader (dropping the last 1/12 of the training set)
         and preload one big evaluation batch from the ``train=False`` split.
      3. Run the training loop, evaluating after every optimizer step, until
         the task's target accuracy or ``--max_steps`` is reached.
      4. Optionally save the model, base optimizer state and args to
         ``<save_path>/ckpt.pt``.

    Raises:
        ValueError: if ``--task`` is not a key of ``TASKS``.
        RuntimeError: if the evaluation loader yields no batch.
    """
    args = parse_args()
    device = args.device

    try:
        task = TASKS[args.task]
    except KeyError:
        # `from None`: the KeyError context adds nothing to the message below
        raise ValueError(f"Task {args.task} not found in {list(TASKS.keys())}") from None
    print('task', task)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(*task['norm'])
    ])

    train_kwargs = {'batch_size': args.batch_size, 'shuffle': True, 'num_workers': args.num_workers}
    test_kwargs = dict(train_kwargs, shuffle=False, batch_size=10000)

    train_data = _load_dataset(task['dataset'], train=True, transform=transform)
    # in our experiments, we reserved a small subset of the training set for validation, keep the same for consistency
    n_all = len(train_data.targets)
    idx_train = torch.arange(n_all - n_all // 12)
    train_data.data = train_data.data[idx_train]
    train_data.targets = [train_data.targets[i] for i in idx_train]

    set_seed(args.seed)
    generator = torch.Generator()
    train_kwargs['generator'] = generator
    train_loader = torch.utils.data.DataLoader(
        train_data, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(
        _load_dataset(task['dataset'], train=False, transform=transform), **test_kwargs)

    # preload the test data to avoid overheads of loading them every time for evaluation
    for data_eval, target_eval in test_loader:
        data_eval, target_eval = data_eval.to(device, non_blocking=True), target_eval.to(device, non_blocking=True)
        break  # only one big batch is expected
    else:
        # without this, an empty loader would only surface later as a NameError
        raise RuntimeError('evaluation data loader is empty')

    set_seed(args.seed)  # set the seed again to make initial weights easily reproducible
    model = Net(**task['net_args']).to(device)
    print(model, 'params', sum(p.numel() for p in model.parameters()),
          'total param norm',
          torch.norm(torch.stack([torch.norm(p.data, 2) for p in model.parameters()]), 2).item())

    lr = args.lr if args.lr is not None else task['lr']
    # create a NiNo-based opt with AdamW as a base optimizer
    opt = NiNo(base_opt=optim.AdamW(model.parameters(), lr=lr, weight_decay=args.wd),
               ckpt=args.nino_ckpt,
               model=model,
               period=args.period,
               max_steps=args.max_steps,
               verbose=args.verbose)

    epochs = int(np.ceil(args.max_steps / len(train_loader)))
    losses = []
    start_time = time.time()
    done = False
    for epoch in range(epochs):
        set_seed(args.seed + epoch)  # set the seed again to make batches the same for nino and adam
        generator.manual_seed(args.seed + epoch)
        for data, target in train_loader:
            model.train()
            data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)

            if opt.need_grads:
                loss = F.cross_entropy(model(data), target)
                loss.backward()  # only compute gradients for the base optimizer
                closure = None
                losses.append(loss.item())
            else:
                def closure():
                    # eval the loss after the NiNo step to see how it affects the training
                    with torch.no_grad():
                        return F.cross_entropy(model(data), target)

            loss_ = opt.step(closure)  # base_opt step or nowcast params every args.period steps using NiNo
            opt.zero_grad()
            if loss_ is not None:
                losses.append(loss_.item())

            scores = test(model, data_eval, target_eval, verbose=args.verbose)

            if opt.step_idx % args.log_interval == 0:
                print('Train {:04d}/{}: \tTrain loss: {:.4f} \tVal loss: {:.4f} \tVal acc: {:.2f}% '
                      '\t(sec/b={:.3f}, {}={:.3f}G)'.format(
                          opt.step_idx, args.max_steps, losses[-1], scores['loss'], scores['acc'],
                          (time.time() - start_time) / opt.step_idx, device, mem(device)))

            if scores['acc'] >= task['target']:
                print('\nReached target accuracy of {:.2f}%>={:.2f}% in {} steps '
                      '({:.4f} seconds)'.format(scores['acc'],
                                                task["target"],
                                                opt.step_idx,
                                                time.time() - start_time))
                done = True
            if opt.step_idx >= args.max_steps:
                done = True
            if done:
                break
        if done:
            break

    if args.save_path not in [None, '', 'None', 'none']:
        # makedirs handles nested paths and an already-existing directory
        os.makedirs(args.save_path, exist_ok=True)
        save_path = os.path.join(args.save_path, 'ckpt.pt')
        torch.save({
            'model': model.state_dict(),
            'opt': opt.base_opt.state_dict(),  # can save the optimizer state with the history of the past states
            'args': args},
            save_path)
        print('Model and opt saved to', save_path)
# Script entry point: run the training pipeline, then report completion.
if __name__ == "__main__":
    main()
    print('Done!')