This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

[MXNET-#16167] Refactor Optimizer (#17400)
* refactor optimizer

* refactor optimizer

* fix svrg test

* fix rmsprop param naming

* fix signum test

* fix pylint and perl test

* fix perl test and signsgd test

* fix

* retrigger ci

* reduce ci overheads
szhengac authored Feb 29, 2020
1 parent 88b3051 commit f70c7b7
Showing 48 changed files with 3,934 additions and 3,278 deletions.
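
At a glance, the refactor renames RMSProp's `gamma1`/`gamma2` hyperparameters to `rho`/`momentum` across the R, Clojure, C++, and Python surfaces, drops the `ccsgd` alias in favor of plain `sgd`, and removes LBSGD from the optimizer tutorial and example scripts. As orientation, here is a minimal sketch of the renamed Python call; it mirrors the updated tutorial snippet further down in this diff, and the parameter values are illustrative only.

```python
from mxnet import optimizer

# Sketch of the post-refactor RMSProp constructor: `rho` replaces `gamma1`
# (decay factor for the moving average of the squared gradient) and `momentum`
# replaces `gamma2` (used only when centered=True). Values mirror the tutorial
# snippet included in this commit.
rmsprop_optimizer = optimizer.RMSProp(
    learning_rate=0.001,
    rho=0.9,
    momentum=0.9,
    epsilon=1e-07,
    centered=False,
)
```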
14 changes: 7 additions & 7 deletions R-package/R/optimizer.R
@@ -109,9 +109,9 @@ mx.opt.sgd <- function(learning.rate = 0.01,
#'
#' @param learning.rate float, default=0.002
#' The initial learning rate.
#' @param gamma1 float, default=0.95
#' @param rho float, default=0.95
#' decay factor of moving average for gradient, gradient^2.
#' @param gamma2 float, default=0.9
#' @param momentum float, default=0.9
#' "momentum" factor.
#' @param epsilon float, default=1e-4
#' @param wd float, default=0.0
@@ -125,8 +125,8 @@ mx.opt.sgd <- function(learning.rate = 0.01,
#'
mx.opt.rmsprop <- function(learning.rate = 0.002,
centered = TRUE,
gamma1 = 0.95,
gamma2 = 0.9,
rho = 0.95,
momentum = 0.9,
epsilon = 1e-4,
wd = 0,
rescale.grad = 1,
@@ -158,8 +158,8 @@ mx.opt.rmsprop <- function(learning.rate = 0.002,
g,
delta,
lr = lr,
gamma1 = gamma1,
gamma2 = gamma2,
rho = rho,
momentum = momentum,
epsilon = epsilon,
wd = wd,
rescale_grad = rescale.grad,
@@ -174,7 +174,7 @@ mx.opt.rmsprop <- function(learning.rate = 0.002,
grad,
n,
lr = lr,
gamma1 = gamma1,
rho = rho,
epsilon = epsilon,
wd = wd,
rescale_grad = rescale.grad,
4 changes: 2 additions & 2 deletions R-package/tests/testthat/test_optimizer.R
@@ -73,8 +73,8 @@ test_that("rmsprop", {
fc1_weight = w1, label = y), aux.arrays = NULL, grad.reqs = c("null", "write",
"null"))

optimizer <- mx.opt.create("rmsprop", learning.rate = 1, centered = TRUE, gamma1 = 0.95,
gamma2 = 0.9, epsilon = 1e-04, wd = 0, rescale.grad = 1, clip_gradient = -1)
optimizer <- mx.opt.create("rmsprop", learning.rate = 1, centered = TRUE, rho = 0.95,
momentum = 0.9, epsilon = 1e-04, wd = 0, rescale.grad = 1, clip_gradient = -1)

updaters <- mx.opt.get.updater(optimizer, exec$ref.arg.arrays, ctx = mx.ctx.default())

59 changes: 28 additions & 31 deletions benchmark/opperf/rules/default_params.py
@@ -278,10 +278,10 @@
DEFAULT_R1 = [(1, 1024), (1, 1), (1, 100)]
DEFAULT_R2 = [(1, 1024), (1, 1), (1, 100)]
DEFAULT_DELTA = [(1024, 1024), (10000, 1), (10000, 100)]
DEFAULT_LRS = [(0.1, 0.1)]
DEFAULT_LR = [0.1, 0.5, 0.9]
DEFAULT_GAMMA_1 = [0.1, 0.5, 0.9]
DEFAULT_GAMMA_2 = [0.1, 0.5, 0.9]
DEFAULT_LRS = [(0.1,0.1)]
DEFAULT_LR = [0.1,0.5,0.9]
DEFAULT_RHO = [0.1,0.5,0.9]
DEFAULT_MOMENTUM = [0.1,0.5,0.9]
DEFAULT_EPSILON = [1e-08]
DEFAULT_BETA_1 = [0.1, 0.5, 0.9]
DEFAULT_BETA_2 = [0.1, 0.5, 0.9]
@@ -417,33 +417,30 @@
"p_nd": DEFAULT_P_ND,
"axis_shape": DEFAULT_AXIS_SHAPE,
"axis": DEFAULT_AXIS,
"weight": DEFAULT_WEIGHT,
"weight32": DEFAULT_WEIGHT,
"grad": DEFAULT_GRAD,
"mean": DEFAULT_MEAN,
"var": DEFAULT_VAR,
"mom": DEFAULT_MOM,
"r1": DEFAULT_R1,
"r2": DEFAULT_R2,
"n": DEFAULT_N,
"d": DEFAULT_D,
"v": DEFAULT_V,
"z": DEFAULT_Z,
"g": DEFAULT_G,
"delta": DEFAULT_DELTA,
"lr": DEFAULT_LR,
"lrs": DEFAULT_LRS,
"wds": DEFAULT_LRS,
"wd": DEFAULT_LR,
"gamma1": DEFAULT_GAMMA_1,
"gamma2": DEFAULT_GAMMA_2,
"epsilon": DEFAULT_EPSILON,
"beta1": DEFAULT_BETA_1,
"beta2": DEFAULT_BETA_2,
"t": DEFAULT_T,
"rescale_grad": DEFAULT_RESCALE_GRAD,
"clip_grad": DEFAULT_CLIP_GRADIENT,
"lazy_update": DEFAULT_LAZY_UPDATE,
"weight" : DEFAULT_WEIGHT,
"weight32" : DEFAULT_WEIGHT,
"grad" : DEFAULT_GRAD,
"mean" : DEFAULT_MEAN,
"var" : DEFAULT_VAR,
"mom" : DEFAULT_MOM,
"n" : DEFAULT_N,
"d" : DEFAULT_D,
"v" : DEFAULT_V,
"z" : DEFAULT_Z,
"g" : DEFAULT_G,
"delta" : DEFAULT_DELTA,
"lr" : DEFAULT_LR,
"lrs" : DEFAULT_LRS,
"wds" : DEFAULT_LRS,
"rho" : DEFAULT_RHO,
"momentum" : DEFAULT_MOMENTUM,
"epsilon" : DEFAULT_EPSILON,
"beta1" : DEFAULT_BETA_1,
"beta2" : DEFAULT_BETA_2,
"t" : DEFAULT_T,
"rescale_grad" : DEFAULT_RESCALE_GRAD,
"clip_grad" : DEFAULT_CLIP_GRADIENT,
"lazy_update" : DEFAULT_LAZY_UPDATE,
"data_4d": DEFAULT_DATA_4d,
"dim1": DEFAULT_DIM_1,
"dim2": DEFAULT_DIM_2,
28 changes: 14 additions & 14 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj
@@ -96,30 +96,30 @@
([]
(ada-delta {})))

(s/def gamma1 number?)
(s/def gamma2 number?)
(s/def ::rms-prop-opts (s/keys :opt-un [::learning-rate ::rescale-gradient ::gamma1 ::gamma2 ::wd ::clip-gradient]))
(s/def rho number?)
(s/def momentum number?)
(s/def ::rms-prop-opts (s/keys :opt-un [::learning-rate ::rescale-gradient ::rho ::momentum ::wd ::clip-gradient]))

(defn rms-prop
"RMSProp optimizer as described in Tieleman & Hinton, 2012.
http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45) by Alex Graves, 2013.
- learningRate Step size.
- gamma1 decay factor of moving average for gradient, gradient^^2.
- gamma2 momentum factor of moving average for gradient.
- rescale-gradient rescaling factor of gradient.
- wd L2 regularization coefficient add to all the weights
- clip-gradient clip gradient in range [-clip_gradient, clip_gradient]
- lr-scheduler The learning rate scheduler"
([{:keys [learning-rate rescale-gradient gamma1 gamma2 wd lr-scheduler clip-gradient] :as opts
- rho decay factor of moving average for gradient, gradient^^2.
- momentum momentum factor of moving average for gradient.
- rescale-gradient rescaling factor of gradient.
- wd L2 regularization coefficient add to all the weights
- clip-gradient clip gradient in range [-clip_gradient, clip_gradient]
- lr-scheduler The learning rate scheduler"
([{:keys [learning-rate rescale-gradient rho momentum wd lr-scheduler clip-gradient] :as opts
:or {learning-rate 0.002
rescale-gradient 1.0
gamma1 0.95
gamma2 0.9
rho 0.95
momentum 0.9
wd 0.0
clip-gradient 0}}]
(util/validate! ::rms-prop-opts opts "Incorrect rms-prop optimizer options")
(new RMSProp (float learning-rate) (float rescale-gradient) (float gamma1)
(float gamma2) (float wd) lr-scheduler (float clip-gradient)))
(new RMSProp (float learning-rate) (float rescale-gradient) (float rho)
(float momentum) (float wd) lr-scheduler (float clip-gradient)))
([]
(rms-prop {})))

@@ -50,7 +50,7 @@
(is (thrown? Exception (optimizer/dcasgd {:lambda 'a})))
(is (thrown? Exception (optimizer/nag {:momentum 'a})))
(is (thrown? Exception (optimizer/ada-delta {:epsilon 'a})))
(is (thrown? Exception (optimizer/rms-prop {:gamma1 'a})))
(is (thrown? Exception (optimizer/rms-prop {:rho 'a})))
(is (thrown? Exception (optimizer/ada-grad {:rescale-gradient 'a})))
(is (thrown? Exception (optimizer/adam {:beta1 'a})))
(is (thrown? Exception (optimizer/sgld {:lr-scheduler 0.1}))))
2 changes: 1 addition & 1 deletion cpp-package/example/charRNN.cpp
@@ -553,7 +553,7 @@ void trainWithBuiltInRNNOp(const std::string file, int batch_size, int max_epoch
}
start_epoch++;

Optimizer* opt = OptimizerRegistry::Find("ccsgd");
Optimizer* opt = OptimizerRegistry::Find("sgd");
// opt->SetParam("momentum", 0.9)->SetParam("rescale_grad", 1.0 / batch_size)
// ->SetParam("clip_gradient", 10);

2 changes: 1 addition & 1 deletion cpp-package/example/lenet.cpp
@@ -136,7 +136,7 @@ class Lenet {
// args_map["fc1_b"] = 0;

lenet.InferArgsMap(ctx_dev, &args_map, args_map);
Optimizer* opt = OptimizerRegistry::Find("ccsgd");
Optimizer* opt = OptimizerRegistry::Find("sgd");
opt->SetParam("momentum", 0.9)
->SetParam("rescale_grad", 1.0)
->SetParam("clip_gradient", 10)
5 changes: 2 additions & 3 deletions cpp-package/include/mxnet-cpp/optimizer.hpp
@@ -128,7 +128,6 @@ inline Optimizer* OptimizerRegistry::Find(const std::string& name) {
if (cmap().empty()) {
// Optimizers should only be registered once
MXNETCPP_REGISTER_OPTIMIZER(sgd, SGDOptimizer);
MXNETCPP_REGISTER_OPTIMIZER(ccsgd, SGDOptimizer); // For backward compatibility
MXNETCPP_REGISTER_OPTIMIZER(rmsprop, RMSPropOptimizer);
MXNETCPP_REGISTER_OPTIMIZER(adam, AdamOptimizer);
MXNETCPP_REGISTER_OPTIMIZER(adagrad, AdaGradOptimizer);
@@ -271,8 +270,8 @@ inline RMSPropOptimizer::RMSPropOptimizer(unsigned begin_num_update)
: Optimizer(begin_num_update) {
update_handle_ = op_map()->GetSymbolCreator("rmsprop_update");
alex_update_handle_ = op_map()->GetSymbolCreator("rmspropalex_update");
SetParam("gamma1", 0.9f);
SetParam("gamma2", 0.9f);
SetParam("rho", 0.9f);
SetParam("momentum", 0.9f);
SetParam("epsilon", 1e-8);
}

30 changes: 2 additions & 28 deletions docs/python_docs/python/tutorials/packages/optimizer/index.md
@@ -181,10 +181,10 @@ Here is an example snippet creating the RMSProp optimizer in MXNet.


```python
rmsprop_optimizer = optimizer.RMSProp(learning_rate=0.001, gamma1=0.9, gamma2=0.9, epsilon=1e-07, centered=False)
rmsprop_optimizer = optimizer.RMSProp(learning_rate=0.001, rho=0.9, momentum=0.9, epsilon=1e-07, centered=False)
```

In the code snippet above, `gamma1` is $\beta$ in the equations above and `gamma2` is $\gamma$, which is only used where `centered=True`.
In the code snippet above, `rho` is $\beta$ in the equations above and `momentum` is $\gamma$, which is only used where `centered=True`.
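
For reference, this mapping corresponds to the usual centered RMSProp update (Graves, 2013, Eqs. 38-45, the same reference the Clojure docstring in this commit cites). A sketch, with $\eta$ the learning rate, $g_t$ the rescaled gradient, and $\epsilon$ the stabilizing constant; the tutorial's own equations may use slightly different notation:

$$
\begin{aligned}
E[g^2]_t &= \beta\, E[g^2]_{t-1} + (1-\beta)\, g_t^2 \\
E[g]_t   &= \beta\, E[g]_{t-1} + (1-\beta)\, g_t \\
\Delta_t &= \gamma\, \Delta_{t-1} - \frac{\eta\, g_t}{\sqrt{E[g^2]_t - (E[g]_t)^2 + \epsilon}} \\
w_{t+1}  &= w_t + \Delta_t
\end{aligned}
$$

When `centered=False`, the $E[g]_t$ and $\Delta_t$ recursions drop out and the update reduces to $w_{t+1} = w_t - \eta\, g_t / \sqrt{E[g^2]_t + \epsilon}$.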

### [AdaDelta](/api/python/docs/api/optimizer/index.html#mxnet.optimizer.AdaDelta)

@@ -281,32 +281,6 @@ Here is how to create the signum optimizer in MXNet.
signum_optimizer = optimizer.Signum(learning_rate=0.01, momentum=0.9, wd_lh=0.0)
```

### [LBSGD](/api/python/docs/api/optimizer/index.html#mxnet.optimizer.LBSGD)
LBSGD stands for Large Batch Stochastic Gradient Descent and implements a technique where Layer-wise Adaptive Rate Scaling (LARS) is used to maintain a separate learning rate for each layer of the neural network. LBSGD has no additional modifications to SGD and performs the same parameter update steps as the SGD optimizer described above.

LBSGD was introduced by [You et al](https://arxiv.org/pdf/1708.03888.pdf) for distributed training with data-parallel synchronous SGD across multiple worker nodes to overcome the issue of reduced model accuracy when the number of workers, and by extension effective batch size, is increased.

Here is how to initialize the LBSGD optimizer in MXNet.


```python
lbsgd_optimizer = optimizer.LBSGD(momentum=0.0,
multi_precision=False,
warmup_strategy='linear',
warmup_epochs=5,
batch_scale=1,
updates_per_epoch=32,
begin_epoch=0,
num_epochs=60)
```

LBSGD has a number of extra keyword arguments described below
* `multi_precision` - When True performs updates with float32 precision weights regardless of whether weights are initialized with lower precision. When False perform updates with same precision as the weights when initialized. Set to True to improve performance when training with low precision weight represenations.
* `warmup_strategy` - The warmup is period where the learning rate is increased through the first few epochs. The following strategies are supported: ['linear', 'power2', 'sqrt','lars']
* `warmup_epochs` - How many epochs to perform warmup for
* `batch_scale` - use batch size*numworkers
* `updates_per_epoch` - How many updates to the learning rate to perform every epoch. For example during warmup the warmup strategy is applied to increase the learning rate a total of `warmup_epochs*updates_per_epoch` number of times.
* `begin_epoch` - The epoch at which to start warmup.

### [DCASGD](/api/python/docs/api/optimizer/index.html#mxnet.optimizer.DCASGD)

4 changes: 2 additions & 2 deletions example/image-classification/common/fit.py
@@ -235,15 +235,15 @@ def fit(args, network, data_loader, **kwargs):
'multi_precision': True}

# Only a limited number of optimizers have 'momentum' property
has_momentum = {'sgd', 'dcasgd', 'nag', 'signum', 'lbsgd'}
has_momentum = {'sgd', 'dcasgd', 'nag', 'signum'}
if args.optimizer in has_momentum:
optimizer_params['momentum'] = args.mom

monitor = mx.mon.Monitor(
args.monitor, pattern=".*") if args.monitor > 0 else None

# A limited number of optimizers have a warmup period
has_warmup = {'lbsgd', 'lbnag'}
has_warmup = {'lbnag'}
if args.optimizer in has_warmup:
nworkers = kv.num_workers
if epoch_size < 1:
2 changes: 1 addition & 1 deletion example/profiler/profiler_executor.py
@@ -102,7 +102,7 @@ def get_module(ctx, sym, provide_data, provide_label, batch_size=None, is_train=
mod.bind(data_shapes=provide_data, label_shapes=provide_label, for_training=False, inputs_need_grad=False)

mod.init_params(initializer=mx.init.Xavier(magnitude=2.))
mod.init_optimizer(optimizer='ccsgd',
mod.init_optimizer(optimizer='sgd',
optimizer_params={
'learning_rate': 0.0001,
'momentum': 0.0,
2 changes: 1 addition & 1 deletion example/speech_recognition/deepspeech.cfg
@@ -112,7 +112,7 @@ optimizer_params_dictionary={"momentum":0.9}
# adagrad
# optimizer_params_dictionary={"eps":1e-08}
# rmsprop
# optimizer_params_dictionary={"gamma1":0.9, "gamma2":0.9,"epsilon":1e-08}
# optimizer_params_dictionary={"rho":0.9, "momentum":0.9,"epsilon":1e-08}
# adadelta
# optimizer_params_dictionary={"rho":0.95, "epsilon":1e-08}
# set to 0 to disable gradient clipping
2 changes: 1 addition & 1 deletion example/speech_recognition/default.cfg
@@ -109,7 +109,7 @@ optimizer_params_dictionary={"beta1":0.9,"beta2":0.999}
# adagrad
# optimizer_params_dictionary={"eps":1e-08}
# rmsprop
# optimizer_params_dictionary={"gamma1":0.9, "gamma2":0.9,"epsilon":1e-08}
# optimizer_params_dictionary={"rho":0.9, "momentum":0.9,"epsilon":1e-08}
# adadelta
# optimizer_params_dictionary={"rho":0.95, "epsilon":1e-08}
# set to 0 to disable gradient clipping