Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-#16167] Refactor Optimizer #17400

Merged
merged 25 commits into from
Feb 29, 2020
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
4de4ef5
refactor optimizer
Jan 9, 2020
73dcf48
refactor optimizer
Jan 9, 2020
b4ae721
fix svrg test
Jan 22, 2020
7577eb9
Merge branch 'optim' of https://github.com/szhengac/mxnet into optim
Jan 22, 2020
56421a2
fix rmsprop param naming
Jan 22, 2020
302465c
Merge branch 'optim' of https://github.com/szhengac/mxnet into optim
Jan 22, 2020
f0519f8
fix signum test
Jan 22, 2020
a85edaf
fix pylint and perl test
Jan 23, 2020
2ff1ad4
fix perl test and signsgd test
Jan 23, 2020
2b182de
fix
Jan 23, 2020
b14a415
resolve conflict
Feb 13, 2020
d69dac7
Merge branch 'master' into optim
szhengac Feb 13, 2020
f58734a
retrigger
Feb 13, 2020
72a9a28
Merge branch 'optim' of https://github.com/szhengac/mxnet into optim
Feb 13, 2020
a014f4f
fix conflict
Feb 15, 2020
832189d
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Feb 15, 2020
90c2074
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Feb 17, 2020
1d5c47c
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Feb 20, 2020
716902b
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Feb 20, 2020
030a3be
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Feb 20, 2020
092f606
retrigger ci
Feb 21, 2020
35b86b9
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
szhengac Feb 23, 2020
0b98dc7
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
szhengac Feb 24, 2020
54333e5
reduce ci overheads
Feb 24, 2020
db893fa
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Feb 25, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions R-package/R/optimizer.R
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ mx.opt.sgd <- function(learning.rate = 0.01,
#'
#' @param learning.rate float, default=0.002
#' The initial learning rate.
#' @param gamma1 float, default=0.95
#' @param rho float, default=0.95
#' decay factor of moving average for gradient, gradient^2.
#' @param gamma2 float, default=0.9
#' @param momentum float, default=0.9
#' "momentum" factor.
#' @param epsilon float, default=1e-4
#' @param wd float, default=0.0
Expand All @@ -125,8 +125,8 @@ mx.opt.sgd <- function(learning.rate = 0.01,
#'
mx.opt.rmsprop <- function(learning.rate = 0.002,
centered = TRUE,
gamma1 = 0.95,
gamma2 = 0.9,
rho = 0.95,
momentum = 0.9,
epsilon = 1e-4,
wd = 0,
rescale.grad = 1,
Expand Down Expand Up @@ -158,8 +158,8 @@ mx.opt.rmsprop <- function(learning.rate = 0.002,
g,
delta,
lr = lr,
gamma1 = gamma1,
gamma2 = gamma2,
rho = rho,
momentum = momentum,
epsilon = epsilon,
wd = wd,
rescale_grad = rescale.grad,
Expand All @@ -174,7 +174,7 @@ mx.opt.rmsprop <- function(learning.rate = 0.002,
grad,
n,
lr = lr,
gamma1 = gamma1,
rho = rho,
epsilon = epsilon,
wd = wd,
rescale_grad = rescale.grad,
Expand Down
4 changes: 2 additions & 2 deletions R-package/tests/testthat/test_optimizer.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ test_that("rmsprop", {
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write",
"null"))

optimizer <- mx.opt.create("rmsprop", learning.rate = 1, centered = TRUE, gamma1 = 0.95,
gamma2 = 0.9, epsilon = 1e-04, wd = 0, rescale.grad = 1, clip_gradient = -1)
optimizer <- mx.opt.create("rmsprop", learning.rate = 1, centered = TRUE, rho = 0.95,
momentum = 0.9, epsilon = 1e-04, wd = 0, rescale.grad = 1, clip_gradient = -1)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

Expand Down
8 changes: 4 additions & 4 deletions benchmark/opperf/rules/default_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@
DEFAULT_DELTA = [(1024, 1024), (10000, 1), (10000, 100)]
DEFAULT_LRS = [(0.1,0.1)]
DEFAULT_LR = [0.1,0.5,0.9]
DEFAULT_GAMMA_1 = [0.1,0.5,0.9]
DEFAULT_GAMMA_2 = [0.1,0.5,0.9]
DEFAULT_RHO = [0.1,0.5,0.9]
DEFAULT_MOMENTUM = [0.1,0.5,0.9]
DEFAULT_EPSILON = [1e-08]
DEFAULT_BETA_1 = [0.1,0.5,0.9]
DEFAULT_BETA_2 = [0.1,0.5,0.9]
Expand Down Expand Up @@ -139,8 +139,8 @@
"lr" : DEFAULT_LR,
"lrs" : DEFAULT_LRS,
"wds" : DEFAULT_LRS,
"gamma1" : DEFAULT_GAMMA_1,
"gamma2" : DEFAULT_GAMMA_2,
"rho" : DEFAULT_RHO,
"momentum" : DEFAULT_MOMENTUM,
"epsilon" : DEFAULT_EPSILON,
"beta1" : DEFAULT_BETA_1,
"beta2" : DEFAULT_BETA_2,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,30 +96,30 @@
([]
(ada-delta {})))

(s/def gamma1 number?)
(s/def gamma2 number?)
(s/def ::rms-prop-opts (s/keys :opt-un [::learning-rate ::rescale-gradient ::gamma1 ::gamma2 ::wd ::clip-gradient]))
(s/def rho number?)
(s/def momentum number?)
(s/def ::rms-prop-opts (s/keys :opt-un [::learning-rate ::rescale-gradient ::rho ::momentum ::wd ::clip-gradient]))

(defn rms-prop
"RMSProp optimizer as described in Tieleman & Hinton, 2012.
http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45) by Alex Graves, 2013.
- learningRate Step size.
- gamma1 decay factor of moving average for gradient, gradient^^2.
- gamma2 momentum factor of moving average for gradient.
- rescale-gradient rescaling factor of gradient.
- wd L2 regularization coefficient add to all the weights
- clip-gradient clip gradient in range [-clip_gradient, clip_gradient]
- lr-scheduler The learning rate scheduler"
([{:keys [learning-rate rescale-gradient gamma1 gamma2 wd lr-scheduler clip-gradient] :as opts
- rho decay factor of moving average for gradient, gradient^^2.
- momentum momentum factor of moving average for gradient.
- rescale-gradient rescaling factor of gradient.
- wd L2 regularization coefficient add to all the weights
- clip-gradient clip gradient in range [-clip_gradient, clip_gradient]
- lr-scheduler The learning rate scheduler"
([{:keys [learning-rate rescale-gradient rho momentum wd lr-scheduler clip-gradient] :as opts
:or {learning-rate 0.002
rescale-gradient 1.0
gamma1 0.95
gamma2 0.9
rho 0.95
momentum 0.9
wd 0.0
clip-gradient 0}}]
(util/validate! ::rms-prop-opts opts "Incorrect rms-prop optimizer options")
(new RMSProp (float learning-rate) (float rescale-gradient) (float gamma1)
(float gamma2) (float wd) lr-scheduler (float clip-gradient)))
(new RMSProp (float learning-rate) (float rescale-gradient) (float rho)
(float momentum) (float wd) lr-scheduler (float clip-gradient)))
([]
(rms-prop {})))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
(is (thrown? Exception (optimizer/dcasgd {:lambda 'a})))
(is (thrown? Exception (optimizer/nag {:momentum 'a})))
(is (thrown? Exception (optimizer/ada-delta {:epsilon 'a})))
(is (thrown? Exception (optimizer/rms-prop {:gamma1 'a})))
(is (thrown? Exception (optimizer/rms-prop {:rho 'a})))
(is (thrown? Exception (optimizer/ada-grad {:rescale-gradient 'a})))
(is (thrown? Exception (optimizer/adam {:beta1 'a})))
(is (thrown? Exception (optimizer/sgld {:lr-scheduler 0.1}))))
2 changes: 1 addition & 1 deletion cpp-package/example/charRNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ void trainWithBuiltInRNNOp(const std::string file, int batch_size, int max_epoch
}
start_epoch++;

Optimizer* opt = OptimizerRegistry::Find("ccsgd");
Optimizer* opt = OptimizerRegistry::Find("sgd");
// opt->SetParam("momentum", 0.9)->SetParam("rescale_grad", 1.0 / batch_size)
// ->SetParam("clip_gradient", 10);

Expand Down
2 changes: 1 addition & 1 deletion cpp-package/example/lenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ class Lenet {
// args_map["fc1_b"] = 0;

lenet.InferArgsMap(ctx_dev, &args_map, args_map);
Optimizer* opt = OptimizerRegistry::Find("ccsgd");
Optimizer* opt = OptimizerRegistry::Find("sgd");
opt->SetParam("momentum", 0.9)
->SetParam("rescale_grad", 1.0)
->SetParam("clip_gradient", 10)
Expand Down
5 changes: 2 additions & 3 deletions cpp-package/include/mxnet-cpp/optimizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ inline Optimizer* OptimizerRegistry::Find(const std::string& name) {
if (cmap().empty()) {
// Optimizers should only be registered once
MXNETCPP_REGISTER_OPTIMIZER(sgd, SGDOptimizer);
MXNETCPP_REGISTER_OPTIMIZER(ccsgd, SGDOptimizer); // For backward compatibility
MXNETCPP_REGISTER_OPTIMIZER(rmsprop, RMSPropOptimizer);
MXNETCPP_REGISTER_OPTIMIZER(adam, AdamOptimizer);
MXNETCPP_REGISTER_OPTIMIZER(adagrad, AdaGradOptimizer);
Expand Down Expand Up @@ -271,8 +270,8 @@ inline RMSPropOptimizer::RMSPropOptimizer(unsigned begin_num_update)
: Optimizer(begin_num_update) {
update_handle_ = op_map()->GetSymbolCreator("rmsprop_update");
alex_update_handle_ = op_map()->GetSymbolCreator("rmspropalex_update");
SetParam("gamma1", 0.9f);
SetParam("gamma2", 0.9f);
SetParam("rho", 0.9f);
SetParam("momentum", 0.9f);
SetParam("epsilon", 1e-8);
}

Expand Down
30 changes: 2 additions & 28 deletions docs/python_docs/python/tutorials/packages/optimizer/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,10 @@ Here is an example snippet creating the RMSProp optimizer in MXNet.


```python
rmsprop_optimizer = optimizer.RMSProp(learning_rate=0.001, gamma1=0.9, gamma2=0.9, epsilon=1e-07, centered=False)
rmsprop_optimizer = optimizer.RMSProp(learning_rate=0.001, rho=0.9, momentum=0.9, epsilon=1e-07, centered=False)
```

In the code snippet above, `gamma1` is $\beta$ in the equations above and `gamma2` is $\gamma$, which is only used where `centered=True`.
In the code snippet above, `rho` is $\beta$ in the equations above and `momentum` is $\gamma$, which is only used where `centered=True`.

### [AdaDelta](/api/python/docs/api/optimizer/index.html#mxnet.optimizer.AdaDelta)

Expand Down Expand Up @@ -281,32 +281,6 @@ Here is how to create the signum optimizer in MXNet.
signum_optimizer = optimizer.Signum(learning_rate=0.01, momentum=0.9, wd_lh=0.0)
```

### [LBSGD](/api/python/docs/api/optimizer/index.html#mxnet.optimizer.LBSGD)
LBSGD stands for Large Batch Stochastic Gradient Descent and implements a technique where Layer-wise Adaptive Rate Scaling (LARS) is used to maintain a separate learning rate for each layer of the neural network. LBSGD has no additional modifications to SGD and performs the same parameter update steps as the SGD optimizer described above.

LBSGD was introduced by [You et al](https://arxiv.org/pdf/1708.03888.pdf) for distributed training with data-parallel synchronous SGD across multiple worker nodes to overcome the issue of reduced model accuracy when the number of workers, and by extension effective batch size, is increased.

Here is how to initialize the LBSGD optimizer in MXNet.


```python
lbsgd_optimizer = optimizer.LBSGD(momentum=0.0,
multi_precision=False,
warmup_strategy='linear',
warmup_epochs=5,
batch_scale=1,
updates_per_epoch=32,
begin_epoch=0,
num_epochs=60)
```

LBSGD has a number of extra keyword arguments described below
* `multi_precision` - When True performs updates with float32 precision weights regardless of whether weights are initialized with lower precision. When False perform updates with same precision as the weights when initialized. Set to True to improve performance when training with low precision weight represenations.
* `warmup_strategy` - The warmup is period where the learning rate is increased through the first few epochs. The following strategies are supported: ['linear', 'power2', 'sqrt','lars']
* `warmup_epochs` - How many epochs to perform warmup for
* `batch_scale` - use batch size*numworkers
* `updates_per_epoch` - How many updates to the learning rate to perform every epoch. For example during warmup the warmup strategy is applied to increase the learning rate a total of `warmup_epochs*updates_per_epoch` number of times.
* `begin_epoch` - The epoch at which to start warmup.

### [DCASGD](/api/python/docs/api/optimizer/index.html#mxnet.optimizer.DCASGD)

Expand Down
4 changes: 2 additions & 2 deletions example/image-classification/common/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,15 +235,15 @@ def fit(args, network, data_loader, **kwargs):
'multi_precision': True}

# Only a limited number of optimizers have 'momentum' property
has_momentum = {'sgd', 'dcasgd', 'nag', 'signum', 'lbsgd'}
has_momentum = {'sgd', 'dcasgd', 'nag', 'signum'}
if args.optimizer in has_momentum:
optimizer_params['momentum'] = args.mom

monitor = mx.mon.Monitor(
args.monitor, pattern=".*") if args.monitor > 0 else None

# A limited number of optimizers have a warmup period
has_warmup = {'lbsgd', 'lbnag'}
has_warmup = {'lbnag'}
if args.optimizer in has_warmup:
nworkers = kv.num_workers
if epoch_size < 1:
Expand Down
2 changes: 1 addition & 1 deletion example/profiler/profiler_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def get_module(ctx, sym, provide_data, provide_label, batch_size=None, is_train=
mod.bind(data_shapes=provide_data, label_shapes=provide_label, for_training=False, inputs_need_grad=False)

mod.init_params(initializer=mx.init.Xavier(magnitude=2.))
mod.init_optimizer(optimizer='ccsgd',
mod.init_optimizer(optimizer='sgd',
optimizer_params={
'learning_rate': 0.0001,
'momentum': 0.0,
Expand Down
2 changes: 1 addition & 1 deletion example/speech_recognition/deepspeech.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ optimizer_params_dictionary={"momentum":0.9}
# adagrad
# optimizer_params_dictionary={"eps":1e-08}
# rmsprop
# optimizer_params_dictionary={"gamma1":0.9, "gamma2":0.9,"epsilon":1e-08}
# optimizer_params_dictionary={"rho":0.9, "momentum":0.9,"epsilon":1e-08}
# adadelta
# optimizer_params_dictionary={"rho":0.95, "epsilon":1e-08}
# set to 0 to disable gradient clipping
Expand Down
2 changes: 1 addition & 1 deletion example/speech_recognition/default.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ optimizer_params_dictionary={"beta1":0.9,"beta2":0.999}
# adagrad
# optimizer_params_dictionary={"eps":1e-08}
# rmsprop
# optimizer_params_dictionary={"gamma1":0.9, "gamma2":0.9,"epsilon":1e-08}
# optimizer_params_dictionary={"rho":0.9, "momentum":0.9,"epsilon":1e-08}
# adadelta
# optimizer_params_dictionary={"rho":0.95, "epsilon":1e-08}
# set to 0 to disable gradient clipping
Expand Down
40 changes: 21 additions & 19 deletions perl-package/AI-MXNet/lib/AI/MXNet/Optimizer.pm
Original file line number Diff line number Diff line change
Expand Up @@ -1037,12 +1037,13 @@ method update($index, $weight, $grad, $state)
}
else
{
$grad += $wd * $weight;
my $mom = $state;
$mom *= $self->momentum;
$grad += $wd * $weight;
$mom += $grad;
$mom -= $lr * $grad;
$grad *= -$lr;
$grad += $self->momentum * $mom;
$weight += -$lr * $grad;
$weight += $grad;
}
}
else
Expand All @@ -1061,11 +1062,12 @@ method update($index, $weight, $grad, $state)
}
else
{
$grad32 += $wd * $weight32;
$mom *= $self->momentum;
$grad32 += $wd * $weight32;
$mom += $grad32;
$mom -= $lr * $grad32;
$grad32 *= -$lr;
$grad32 += $self->momentum * $mom;
$weight32 += -$lr * $grad32;
$weight32 += $grad32;
}
my $tmp = $weight32->astype($weight->dtype);
$tmp->copyto($weight);
Expand Down Expand Up @@ -1276,7 +1278,7 @@ __PACKAGE__->register;
rescale_grad : Num, optional
rescaling factor of gradient. Normally should be 1/batch_size.

eps: Num, optional
epsilon: Num, optional
A small float number to make the updating processing stable
Default value is set to 1e-7.

Expand All @@ -1288,7 +1290,7 @@ use Mouse;

extends 'AI::MXNet::Optimizer';

has 'eps' => (is => "rw", isa => "Num", default => 1e-7);
has 'epsilon' => (is => "rw", isa => "Num", default => 1e-7);

method create_state(Index $index, AI::MXNet::NDArray $weight)
{
Expand All @@ -1314,7 +1316,7 @@ method update(
if($is_sparse)
{
my %kwargs = (
epsilon => $self->eps,
epsilon => $self->epsilon,
rescale_grad => $self->rescale_grad
);
if($self->clip_gradient)
Expand All @@ -1330,9 +1332,10 @@ method update(
{
$grad = AI::MXNet::NDArray->clip($grad, -$self->clip_gradient, $self->clip_gradient);
}
$grad += $wd * $weight;
$history += $grad->square;
my $div = $grad / ($history + $self->eps)->sqrt;
$weight += ($div + $weight * $wd) * -$lr;
my $div = $grad / (($history)->sqrt + $self->epsilon);
$weight += $div * -$lr;
}
}

Expand All @@ -1359,11 +1362,10 @@ __PACKAGE__->register;
learning_rate : Num, optional
Step size.
Default value is set to 0.001.
gamma1: Num, optional
rho: Num, optional
decay factor of moving average for gradient^2.
Default value is set to 0.9.
gamma2: Num, optional
"momentum" factor.
momentum: Num, optional
Default value if set to 0.9.
Only used if centered=True
epsilon : Num, optional
Expand All @@ -1386,8 +1388,8 @@ use Mouse;
extends 'AI::MXNet::Optimizer';

has '+learning_rate' => (default => 0.001);
has 'gamma1' => (is => "ro", isa => "Num", default => 0.9);
has 'gamma2' => (is => "ro", isa => "Num", default => 0.9);
has 'rho' => (is => "ro", isa => "Num", default => 0.9);
has 'momentum' => (is => "ro", isa => "Num", default => 0.9);
has 'epsilon' => (is => "ro", isa => "Num", default => 1e-8);
has 'centered' => (is => "ro", isa => "Bool", default => 0);
has 'clip_weights' => (is => "ro", isa => "Num");
Expand All @@ -1397,12 +1399,12 @@ sub BUILD
{
my $self = shift;
$self->kwargs({
gamma1 => $self->gamma1,
rho => $self->rho,
epsilon => $self->epsilon
});
if($self->centered)
{
$self->kwargs->{gamma2} = $self->gamma2;
$self->kwargs->{momentum} = $self->momentum;
}
if($self->clip_gradient)
{
Expand Down Expand Up @@ -1461,7 +1463,7 @@ method update(
if($self->centered)
{
AI::MXNet::NDArray->rmspropalex_update(
$weight, $grad, $n, $g, $delta,
$weight, $grad, $g, $n, $delta,
{
out => $weight,
lr => $lr,
Expand Down
Loading