trainer.py
import torch
import torch.optim as optim
import math
from net import *
import util

class Trainer():
    def __init__(self, model, lrate, wdecay, clip, step_size, seq_out_len, scaler, device, cl=True):
        self.scaler = scaler
        self.model = model
        self.model.to(device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lrate, weight_decay=wdecay)
        self.loss = util.masked_mae
        self.clip = clip
        self.step = step_size
        self.iter = 1
        self.task_level = 1
        self.seq_out_len = seq_out_len
        self.cl = cl

    def train(self, input, real_val, idx=None):
        self.model.train()
        self.optimizer.zero_grad()
        output = self.model(input, idx=idx)
        output = output.transpose(1, 3)
        real = torch.unsqueeze(real_val, dim=1)
        predict = self.scaler.inverse_transform(output)
        # Curriculum learning: every `step_size` iterations, grow the
        # prediction horizon by one step until it spans seq_out_len.
        if self.iter % self.step == 0 and self.task_level <= self.seq_out_len:
            self.task_level += 1
        if self.cl:
            loss = self.loss(predict[:, :, :, :self.task_level], real[:, :, :, :self.task_level], 0.0)
        else:
            loss = self.loss(predict, real, 0.0)
        loss.backward()
        if self.clip is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
        self.optimizer.step()
        # mae = util.masked_mae(predict, real, 0.0).item()
        mape = util.masked_mape(predict, real, 0.0).item()
        rmse = util.masked_rmse(predict, real, 0.0).item()
        self.iter += 1
        return loss.item(), mape, rmse

    def eval(self, input, real_val):
        self.model.eval()
        with torch.no_grad():
            output = self.model(input)
        output = output.transpose(1, 3)
        real = torch.unsqueeze(real_val, dim=1)
        predict = self.scaler.inverse_transform(output)
        loss = self.loss(predict, real, 0.0)
        mape = util.masked_mape(predict, real, 0.0).item()
        rmse = util.masked_rmse(predict, real, 0.0).item()
        return loss.item(), mape, rmse
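
# Illustrative usage sketch for Trainer, as a rough guide only. `build_model`,
# `scaler`, and `dataloader` are hypothetical placeholders for objects the
# surrounding project supplies (e.g. a model from net.py and a util scaler);
# the tensor layout [batch, features, nodes, seq_len] is an assumption
# inferred from the transposes in train()/eval().
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   model = build_model().to(device)                  # hypothetical factory
#   engine = Trainer(model, lrate=1e-3, wdecay=1e-4, clip=5, step_size=100,
#                    seq_out_len=12, scaler=scaler, device=device, cl=True)
#   for x, y in dataloader:
#       x = x.transpose(1, 3).to(device)              # -> [batch, feat, node, seq]
#       y = y.transpose(1, 3).to(device)
#       loss, mape, rmse = engine.train(x, y[:, 0, :, :])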

class Optim(object):

    def _makeOptimizer(self):
        # Note: lr_decay is also passed as the weight_decay coefficient of the
        # underlying torch optimizer.
        if self.method == 'sgd':
            self.optimizer = optim.SGD(self.params, lr=self.lr, weight_decay=self.lr_decay)
        elif self.method == 'adagrad':
            self.optimizer = optim.Adagrad(self.params, lr=self.lr, weight_decay=self.lr_decay)
        elif self.method == 'adadelta':
            self.optimizer = optim.Adadelta(self.params, lr=self.lr, weight_decay=self.lr_decay)
        elif self.method == 'adam':
            self.optimizer = optim.Adam(self.params, lr=self.lr, weight_decay=self.lr_decay)
        else:
            raise RuntimeError("Invalid optim method: " + self.method)

    def __init__(self, params, method, lr, clip, lr_decay=1, start_decay_at=None):
        self.params = params  # careful: params may be a generator
        self.last_ppl = None
        self.lr = lr
        self.clip = clip
        self.method = method
        self.lr_decay = lr_decay
        self.start_decay_at = start_decay_at
        self.start_decay = False
        self._makeOptimizer()

    def step(self):
        # Clip gradients when a clip value is given (capturing the total
        # gradient norm), then take an optimizer step.
        grad_norm = 0
        if self.clip is not None:
            grad_norm = torch.nn.utils.clip_grad_norm_(self.params, self.clip)
        self.optimizer.step()
        return grad_norm

    # Decay the learning rate if validation performance does not improve or we
    # hit the start_decay_at epoch.
    def updateLearningRate(self, ppl, epoch):
        if self.start_decay_at is not None and epoch >= self.start_decay_at:
            self.start_decay = True
        if self.last_ppl is not None and ppl > self.last_ppl:
            self.start_decay = True

        if self.start_decay:
            self.lr = self.lr * self.lr_decay
            print("Decaying learning rate to %g" % self.lr)
            # Only decay for one epoch.
            self.start_decay = False

        self.last_ppl = ppl

        # Rebuild the optimizer so the new learning rate takes effect.
        self._makeOptimizer()
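
# Illustrative usage sketch for Optim, as a rough guide only. `model`,
# `dataloader`, `compute_loss`, and `val_loss` are hypothetical placeholders.
# Parameters are passed as a list because _makeOptimizer is re-run on every
# decay and would otherwise exhaust a generator (see the note in __init__).
#
#   wrapped = Optim(list(model.parameters()), method='adam', lr=1e-3, clip=5,
#                   lr_decay=0.5, start_decay_at=20)
#   for epoch in range(1, num_epochs + 1):
#       for x, y in dataloader:
#           wrapped.optimizer.zero_grad()
#           loss = compute_loss(model, x, y)          # hypothetical loss fn
#           loss.backward()
#           wrapped.step()                            # clips, then updates
#       wrapped.updateLearningRate(val_loss, epoch)   # decay on plateau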