# adan.py
# (Web-page navigation text and the line-number gutter from the source listing were removed here.)
# https://raw.githubusercontent.com/sail-sg/Adan/main/adan.py
# Copyright 2022 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from torch.optim.optimizer import Optimizer
from timm.utils import *
class Adan(Optimizer):
    """Implements a PyTorch variant of the Adan optimizer.

    Adan was proposed in
    "Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep
    Models", arXiv preprint arXiv:2208.06677, 2022.
    https://arxiv.org/abs/2208.06677

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts
            defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float, float], optional): decay coefficients of
            the running averages of, in order, the gradient, the gradient
            difference and the squared update.
            (default: (0.98, 0.92, 0.99))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): decoupled weight decay coefficient.
            (default: 0)
        max_grad_norm (float, optional): value used to clip the global
            gradient norm; 0.0 disables clipping. (default: 0.0)
        no_prox (bool): how the decoupled weight decay is applied. If True
            the parameter is scaled by ``1 - lr * weight_decay`` before the
            update; otherwise it is divided by ``1 + lr * weight_decay``
            after the update. (default: False)

    Raises:
        ValueError: if ``max_grad_norm``, ``lr`` or ``eps`` is negative, or
            if any of the three betas lies outside ``[0, 1)``.
    """

    def __init__(self, params, lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-8,
                 weight_decay=0.0, max_grad_norm=0.0, no_prox=False):
        if not 0.0 <= max_grad_norm:
            raise ValueError("Invalid Max grad norm: {}".format(max_grad_norm))
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        for index in range(3):
            if not 0.0 <= betas[index] < 1.0:
                raise ValueError(
                    "Invalid beta parameter at index {}: {}".format(index, betas[index]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm, no_prox=no_prox)
        super(Adan, self).__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state, filling in ``no_prox`` when a group lacks it."""
        super(Adan, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('no_prox', False)

    @torch.no_grad()
    def restart_opt(self):
        """Reset the step counter and zero the moving averages of every group."""
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                if p.requires_grad:
                    state = self.state[p]
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared update values
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    # Exponential moving average of gradient difference
                    state['exp_avg_diff'] = torch.zeros_like(p)

    def _global_clip_factor(self):
        """Return the factor (<= 1) that clips the global gradient norm, or None when clipping is off."""
        if not self.defaults['max_grad_norm'] > 0:
            return None

        device = self.param_groups[0]['params'][0].device
        max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
        # 0-dim accumulator: a shape-[1] factor cannot be broadcast in place
        # onto a 0-dim (scalar) gradient, which made scalar parameters fail.
        squared_norm = torch.zeros((), device=device)
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    squared_norm.add_(p.grad.pow(2).sum())

        # The epsilon of the last parameter group is used, which is what the
        # previous code read through its leftover loop variable.
        eps = self.param_groups[-1]['eps']
        return torch.clamp(max_grad_norm / (torch.sqrt(squared_norm) + eps), max=1.0)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): a closure that re-evaluates the
                model and returns the loss. It is called with gradients
                enabled before the parameters are updated.

        Returns:
            The value returned by ``closure``, or None when no closure is
            given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, so re-enable autograd for the closure.
            with torch.enable_grad():
                loss = closure()

        clip_global_grad_norm = self._global_clip_factor()

        for group in self.param_groups:
            beta1, beta2, beta3 = group['betas']
            # assume same step across group now to simplify things
            group['step'] = group.get('step', 0) + 1
            step_count = group['step']

            bias_correction1 = 1.0 - beta1 ** step_count
            bias_correction2 = 1.0 - beta2 ** step_count
            bias_correction3 = 1.0 - beta3 ** step_count

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    state['exp_avg_diff'] = torch.zeros_like(p)

                grad = p.grad
                if clip_global_grad_norm is not None:
                    # Clipping rescales p.grad itself, in place.
                    grad.mul_(clip_global_grad_norm)

                # On the first step there is no previous gradient, so the
                # current one is used and the difference is zero.
                if 'pre_grad' not in state or step_count == 1:
                    state['pre_grad'] = grad
                # Snapshot now: p.grad may be modified in place before the next step.
                copy_grad = grad.clone()

                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']
                exp_avg_diff = state['exp_avg_diff']

                diff = grad - state['pre_grad']
                update = grad + beta2 * diff

                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)  # m_t
                exp_avg_diff.mul_(beta2).add_(diff, alpha=1 - beta2)  # diff_t
                exp_avg_sq.mul_(beta3).addcmul_(update, update, value=1 - beta3)  # n_t

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction3)).add_(group['eps'])
                update = (exp_avg / bias_correction1
                          + beta2 * exp_avg_diff / bias_correction2).div_(denom)

                if group['no_prox']:
                    # Decay first, then take the gradient step.
                    p.mul_(1 - group['lr'] * group['weight_decay'])
                    p.add_(update, alpha=-group['lr'])
                else:
                    # Take the gradient step, then apply the decay as a division.
                    p.add_(update, alpha=-group['lr'])
                    p.div_(1 + group['lr'] * group['weight_decay'])

                state['pre_grad'] = copy_grad

        return loss