-
Notifications
You must be signed in to change notification settings - Fork 16
/
iMAML.py
298 lines (224 loc) · 11.5 KB
/
iMAML.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
"""
Meta-learning Omniglot and mini-imagenet experiments with iMAML-GD (see [1] for more details).
The code is short and easy to read thanks to the following two libraries, both of which must be installed:
- higher: https://github.com/facebookresearch/higher (used to get stateless version of torch nn.Module-s)
- torchmeta: https://github.com/tristandeleu/pytorch-meta (used for meta-dataset loading and minibatching)
[1] Rajeswaran, A., Finn, C., Kakade, S. M., & Levine, S. (2019).
Meta-learning with implicit gradients. In Advances in Neural Information Processing Systems (pp. 113-124).
https://arxiv.org/abs/1909.04630
"""
import math
import argparse
import time
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torchmeta.datasets.helpers import omniglot, miniimagenet
from torchmeta.utils.data import BatchMetaDataLoader
import higher
import hypergrad as hg
class Task:
    """Training and validation objectives for a single few-shot task.

    ``meta_model`` is wrapped into a stateless (functional) module with
    ``higher.monkeypatch`` so the same architecture can be evaluated on any
    list of weights: the task-specific weights (``params``) or the
    meta-parameters (``hparams``).

    Args:
        reg_param: weight of the L2 proximal term that pulls ``params``
            towards ``hparams`` in the training loss.
        meta_model: ``nn.Module`` holding the meta-parameters.
        data: tuple ``(train_input, train_target, test_input, test_target)``.
        batch_size: number of tasks in the meta-batch. The validation loss is
            divided by it, so summing the per-task values over a meta-batch
            gives a mean. ``None`` or 0 is treated as 1.
    """

    def __init__(self, reg_param, meta_model, data, batch_size=None):
        model_device = next(meta_model.parameters()).device
        # stateless (functional) version of meta_model
        self.fmodel = higher.monkeypatch(meta_model, device=model_device, copy_initial_weights=True)
        self.n_params = sum(1 for _ in meta_model.parameters())
        (self.train_input, self.train_target,
         self.test_input, self.test_target) = data
        self.reg_param = reg_param
        self.batch_size = batch_size or 1
        # populated by val_loss_f as plain Python floats
        self.val_loss = None
        self.val_acc = None

    def bias_reg_f(self, bias, params):
        """Return sum_i ||bias_i - params_i||^2 over two parallel lists of tensors."""
        return sum(((ref - cur) ** 2).sum() for ref, cur in zip(bias, params))

    def train_loss_f(self, params, hparams):
        """Inner objective: training cross-entropy of ``params`` plus the
        proximal term ``0.5 * reg_param * ||hparams - params||^2``, whose bias
        is the meta-parameters ``hparams``."""
        logits = self.fmodel(self.train_input, params=params)
        data_term = F.cross_entropy(logits, self.train_target)
        prox_term = 0.5 * self.reg_param * self.bias_reg_f(hparams, params)
        return data_term + prox_term

    def val_loss_f(self, params, hparams):
        """Outer objective: test-split cross-entropy of the weights ``params``.

        Only ``params`` enter the loss. ``hparams`` is accepted so the method
        matches the ``(params, hparams)`` signature of the ``outer_loss``
        passed to the hypergradient routines. The returned loss is divided by
        ``batch_size``. As a side effect, the loss (as a float) and the
        accuracy are stored in ``self.val_loss`` and ``self.val_acc``.
        """
        logits = self.fmodel(self.test_input, params=params)
        loss = F.cross_entropy(logits, self.test_target) / self.batch_size
        # keep a plain float so no autograd graph is kept alive (avoids memory leaks)
        self.val_loss = loss.item()
        # predicted class = index of the largest logit
        predicted = logits.argmax(dim=1, keepdim=True)
        n_correct = predicted.eq(self.test_target.view_as(predicted)).sum().item()
        self.val_acc = n_correct / len(self.test_target)
        return loss
def main():
    """Meta-train on Omniglot or mini-ImageNet with iMAML and periodically meta-test.

    Command-line options: ``--seed``, ``--dataset`` (omniglot | miniimagenet),
    ``--hg-mode`` (CG | fixed_point) and ``--no-cuda``. Training iterates over
    the meta-train dataloader. Every ``log_interval`` meta-iterations it prints
    the meta-batch-averaged validation loss/accuracy. Every ``eval_interval``
    meta-iterations it evaluates on ``n_tasks_test`` meta-test tasks.
    """
    parser = argparse.ArgumentParser(description='iMAML meta-learning on Omniglot and mini-ImageNet')
    parser.add_argument('--seed', type=int, default=0)
    # `choices` rejects unknown values up front instead of failing later
    parser.add_argument('--dataset', type=str, default='omniglot', metavar='N',
                        choices=['omniglot', 'miniimagenet'], help='omniglot or miniimagenet')
    parser.add_argument('--hg-mode', type=str, default='CG', metavar='N',
                        choices=['CG', 'fixed_point'],
                        help='hypergradient approximation: CG or fixed_point')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    args = parser.parse_args()

    log_interval = 100
    eval_interval = 500
    inner_log_interval = None
    inner_log_interval_test = None
    ways = 5
    batch_size = 16
    n_tasks_test = 1000  # usually 1000 tasks are used for testing

    # reg_param: weight of the proximal term in Task.train_loss_f
    # T: number of inner-loop steps; K: iterations of the hypergradient solver
    if args.dataset == 'omniglot':
        reg_param = 2
        T, K = 16, 5
    elif args.dataset == 'miniimagenet':
        reg_param = 0.5
        T, K = 10, 5
    else:  # unreachable with argparse choices; kept as a safeguard
        raise NotImplementedError('{} not implemented!'.format(args.dataset))

    T_test = T
    inner_lr = .1

    # run configuration, printed once for the log
    config = dict(log_interval=log_interval, eval_interval=eval_interval,
                  inner_log_interval=inner_log_interval,
                  inner_log_interval_test=inner_log_interval_test,
                  ways=ways, batch_size=batch_size, n_tasks_test=n_tasks_test,
                  reg_param=reg_param, T=T, K=K, T_test=T_test, inner_lr=inner_lr)
    print(args, '\n', config, '\n')

    cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

    # the following are for reproducibility on GPU, see https://pytorch.org/docs/master/notes/randomness.html
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    torch.random.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.dataset == 'omniglot':
        dataset = omniglot("data", ways=ways, shots=1, test_shots=15, meta_train=True, download=True)
        test_dataset = omniglot("data", ways=ways, shots=1, test_shots=15, meta_test=True, download=True)
        meta_model = get_cnn_omniglot(64, ways).to(device)
    elif args.dataset == 'miniimagenet':
        dataset = miniimagenet("data", ways=ways, shots=1, test_shots=15, meta_train=True, download=True)
        test_dataset = miniimagenet("data", ways=ways, shots=1, test_shots=15, meta_test=True, download=True)
        meta_model = get_cnn_miniimagenet(32, ways).to(device)
    else:
        raise NotImplementedError("DATASET NOT IMPLEMENTED! only omniglot and miniimagenet ")

    dataloader = BatchMetaDataLoader(dataset, batch_size=batch_size, **kwargs)
    test_dataloader = BatchMetaDataLoader(test_dataset, batch_size=batch_size, **kwargs)

    outer_opt = torch.optim.Adam(params=meta_model.parameters())
    # outer_opt = torch.optim.SGD(lr=0.1, params=meta_model.parameters())

    inner_opt_class = hg.GradientDescent
    inner_opt_kwargs = {'step_size': inner_lr}

    def get_inner_opt(train_loss):
        # a fresh inner optimizer bound to one task's training loss
        return inner_opt_class(train_loss, **inner_opt_kwargs)

    for k, batch in enumerate(dataloader):
        start_time = time.time()
        meta_model.train()

        tr_xs, tr_ys = batch["train"][0].to(device), batch["train"][1].to(device)
        tst_xs, tst_ys = batch["test"][0].to(device), batch["test"][1].to(device)

        # hg.CG / hg.fixed_point run once per task between zero_grad() and step(),
        # presumably accumulating each task's hypergradient into .grad
        outer_opt.zero_grad()

        val_loss, val_acc = 0, 0
        forward_time, backward_time = 0, 0
        for tr_x, tr_y, tst_x, tst_y in zip(tr_xs, tr_ys, tst_xs, tst_ys):
            start_time_task = time.time()

            # single task set up
            task = Task(reg_param, meta_model, (tr_x, tr_y, tst_x, tst_y), batch_size=tr_xs.shape[0])
            inner_opt = get_inner_opt(task.train_loss_f)

            # single task inner loop, starting from a detached copy of the meta-parameters
            params = [p.detach().clone().requires_grad_(True) for p in meta_model.parameters()]
            last_param = inner_loop(meta_model.parameters(), params, inner_opt, T, log_interval=inner_log_interval)[-1]
            forward_time_task = time.time() - start_time_task

            # single task hypergradient computation
            if args.hg_mode == 'CG':
                # This is the approximation used in the paper; CG stands for conjugate gradient
                cg_fp_map = hg.GradientDescent(loss_f=task.train_loss_f, step_size=1.)
                hg.CG(last_param, list(meta_model.parameters()), K=K, fp_map=cg_fp_map, outer_loss=task.val_loss_f)
            elif args.hg_mode == 'fixed_point':
                hg.fixed_point(last_param, list(meta_model.parameters()), K=K, fp_map=inner_opt,
                               outer_loss=task.val_loss_f)
            else:
                # never silently skip the hypergradient (task.val_loss would stay None)
                raise NotImplementedError('{} not implemented!'.format(args.hg_mode))
            backward_time_task = time.time() - start_time_task - forward_time_task

            # val_loss is already divided by batch_size inside Task.val_loss_f
            val_loss += task.val_loss
            val_acc += task.val_acc / task.batch_size
            forward_time += forward_time_task
            backward_time += backward_time_task

        outer_opt.step()
        step_time = time.time() - start_time

        if k % log_interval == 0:
            print('MT k={} ({:.3f}s F: {:.3f}s, B: {:.3f}s) Val Loss: {:.2e}, Val Acc: {:.2f}.'
                  .format(k, step_time, forward_time, backward_time, val_loss, 100. * val_acc))

        if k % eval_interval == 0:
            test_losses, test_accs = evaluate(n_tasks_test, test_dataloader, meta_model, T_test, get_inner_opt,
                                              reg_param, log_interval=inner_log_interval_test)

            print("Test loss {:.2e} +- {:.2e}: Test acc: {:.2f} +- {:.2e} (mean +- std over {} tasks)."
                  .format(test_losses.mean(), test_losses.std(), 100. * test_accs.mean(),
                          100. * test_accs.std(), len(test_losses)))
def inner_loop(hparams, params, optim, n_steps, log_interval, create_graph=False):
    """Run ``n_steps`` iterations of the inner optimizer ``optim``.

    Args:
        hparams: meta-parameters forwarded to every optimizer step.
        params: initial task-specific parameters.
        optim: inner optimizer (e.g. ``hg.GradientDescent``). It must provide
            ``get_opt_params(params)``, be callable as
            ``optim(params, hparams, create_graph=...)``, and expose
            ``curr_loss`` after each step.
        n_steps: number of inner iterations.
        log_interval: print the inner loss every ``log_interval`` steps and
            at the last step; a falsy value disables printing.
        create_graph: forwarded unchanged to every optimizer step.

    Returns:
        A list of length ``n_steps + 1``. Entry 0 is
        ``optim.get_opt_params(params)``; entry i is the output of the i-th step.
    """
    history = [optim.get_opt_params(params)]
    for step in range(n_steps):
        current = optim(history[-1], hparams, create_graph=create_graph)
        history.append(current)
        if not log_interval:
            continue
        is_last = step == n_steps - 1
        if step % log_interval == 0 or is_last:
            print('t={}, Loss: {:.6f}'.format(step, optim.curr_loss.item()))
    return history
def evaluate(n_tasks, dataloader, meta_model, n_steps, get_inner_opt, reg_param, log_interval=None):
    """Meta-test ``meta_model`` on the first ``n_tasks`` tasks produced by ``dataloader``.

    For every task, the inner problem is solved with ``n_steps`` iterations of
    the optimizer returned by ``get_inner_opt``, starting from a detached copy
    of the meta-parameters. The adapted weights are then scored on the task's
    test split.

    Args:
        n_tasks: number of tasks to evaluate (a value <= 0 evaluates none).
        dataloader: iterable of meta-batches. Each batch is indexed as
            ``batch["train"][0/1]`` and ``batch["test"][0/1]`` (inputs/targets),
            stacked over tasks along the first dimension.
        meta_model: model holding the meta-parameters.
        n_steps: number of inner-loop iterations per task.
        get_inner_opt: callable mapping a task's training loss to an inner optimizer.
        reg_param: proximal regularization weight passed to ``Task``.
        log_interval: inner-loop logging interval passed to ``inner_loop``.

    Returns:
        ``(val_losses, val_accs)`` as 1-D numpy arrays with one entry per
        evaluated task. There are exactly ``n_tasks`` entries, or fewer if
        ``dataloader`` is exhausted first. ``Task`` is built without
        ``batch_size`` here, so the losses are not divided by any batch size.
    """
    # train() mode, as in the training loop; presumably so BatchNorm uses the
    # statistics of the current task's data (transductive setting) — confirm
    meta_model.train()
    device = next(meta_model.parameters()).device

    val_losses, val_accs = [], []
    if n_tasks <= 0:
        return np.array(val_losses), np.array(val_accs)

    for batch in dataloader:
        tr_xs, tr_ys = batch["train"][0].to(device), batch["train"][1].to(device)
        tst_xs, tst_ys = batch["test"][0].to(device), batch["test"][1].to(device)

        for tr_x, tr_y, tst_x, tst_y in zip(tr_xs, tr_ys, tst_xs, tst_ys):
            task = Task(reg_param, meta_model, (tr_x, tr_y, tst_x, tst_y))
            inner_opt = get_inner_opt(task.train_loss_f)

            params = [p.detach().clone().requires_grad_(True) for p in meta_model.parameters()]
            last_param = inner_loop(meta_model.parameters(), params, inner_opt, n_steps, log_interval=log_interval)[-1]

            # val_loss_f stores the loss/accuracy on the task object as floats
            task.val_loss_f(last_param, meta_model.parameters())
            val_losses.append(task.val_loss)
            val_accs.append(task.val_acc)

            # stop at exactly n_tasks; checking only after a full meta-batch
            # overshoots to the next multiple of the batch size
            if len(val_accs) >= n_tasks:
                return np.array(val_losses), np.array(val_accs)

    # dataloader ran out before n_tasks: return what was collected instead of None
    return np.array(val_losses), np.array(val_accs)
def get_cnn_omniglot(hidden_size, n_classes):
    """Build the four-stage convolutional classifier used for Omniglot.

    Every stage is Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2) -> BatchNorm2d.
    The first stage reads a single input channel. The final ``nn.Linear``
    expects ``hidden_size`` flattened features, i.e. a 1x1 spatial map after
    four poolings. This presumably corresponds to 28x28 inputs
    (28 -> 14 -> 7 -> 3 -> 1); TODO confirm against the dataset transform.

    Args:
        hidden_size: number of channels in every conv stage.
        n_classes: number of output logits.

    Returns:
        The ``nn.Sequential`` network after its weights are set by ``initialize``.
    """
    def make_stage(in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            # NOTE(review): running statistics are tracked here but not in the
            # mini-ImageNet net. Because main() and evaluate() keep the model in
            # train() mode, BatchNorm should normalize with batch statistics in
            # both cases — confirm whether the difference is intended.
            nn.BatchNorm2d(out_channels, momentum=1., affine=True, track_running_stats=True),
        )

    widths = [1] + [hidden_size] * 4
    stages = [make_stage(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])]
    net = nn.Sequential(*stages, nn.Flatten(), nn.Linear(hidden_size, n_classes))
    initialize(net)
    return net
def get_cnn_miniimagenet(hidden_size, n_classes):
    """Build the four-stage convolutional classifier used for mini-ImageNet.

    Every stage is Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2) -> BatchNorm2d
    without running statistics. The first stage reads three input channels.
    The final ``nn.Linear`` expects ``hidden_size * 5 * 5`` flattened features,
    i.e. a 5x5 map after four poolings. This presumably corresponds to 84x84
    inputs (84 -> 42 -> 21 -> 10 -> 5); TODO confirm against the dataset transform.

    Args:
        hidden_size: number of channels in every conv stage.
        n_classes: number of output logits.

    Returns:
        The ``nn.Sequential`` network after its weights are set by ``initialize``.
    """
    def make_stage(in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            # track_running_stats=False: no running buffers are kept, so batch
            # statistics are always used and `momentum` has no effect
            nn.BatchNorm2d(out_channels, momentum=1., affine=True, track_running_stats=False),
        )

    widths = [3] + [hidden_size] * 4
    stages = [make_stage(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])]
    net = nn.Sequential(*stages, nn.Flatten(), nn.Linear(hidden_size * 5 * 5, n_classes))
    initialize(net)
    return net
def initialize(net):
    """Re-initialize the weights of ``net`` in place and return ``net``.

    - ``nn.Conv2d``: weights drawn from a normal distribution with mean 0 and
      std ``sqrt(2 / (k_h * k_w * out_channels))`` (He / Kaiming-normal,
      fan-out mode); biases, when present, set to 0.
    - ``nn.BatchNorm2d``: scale set to 1 and shift set to 0 (only when the
      layer has affine parameters).
    - ``nn.Linear``: weights and bias (when present) set to 0, so the layer
      initially outputs all-zero logits.

    Other module types are left untouched.
    """
    for module in net.modules():
        if isinstance(module, nn.Conv2d):
            fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
            module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.BatchNorm2d):
            # BatchNorm2d(affine=False) has weight/bias set to None
            if module.weight is not None:
                module.weight.data.fill_(1)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.zero_()
            # nn.Linear(bias=False) has bias set to None
            if module.bias is not None:
                module.bias.data.zero_()
    return net
if __name__ == '__main__':
    # Run meta-training only when executed as a script, not on import.
    main()