New Optimizers (#159)
* Added implementations of RMSProp, AdaGrad, and AdaDelta

* Added AdaMax and Nadam
Arkoniak authored and pluskid committed Dec 29, 2016
1 parent c06b211 commit b81b26c
Showing 9 changed files with 497 additions and 6 deletions.
29 changes: 29 additions & 0 deletions docs/src/api/optimizer.md
@@ -19,3 +19,32 @@ Modules = [MXNet.mx]
Pages = ["optimizers/adam.jl"]
```

### AdaGrad
```@autodocs
Modules = [MXNet.mx]
Pages = ["optimizers/adagrad.jl"]
```

### AdaDelta
```@autodocs
Modules = [MXNet.mx]
Pages = ["optimizers/adadelta.jl"]
```

### AdaMax
```@autodocs
Modules = [MXNet.mx]
Pages = ["optimizers/adamax.jl"]
```

### RMSProp
```@autodocs
Modules = [MXNet.mx]
Pages = ["optimizers/rmsprop.jl"]
```

### Nadam
```@autodocs
Modules = [MXNet.mx]
Pages = ["optimizers/nadam.jl"]
```
5 changes: 5 additions & 0 deletions examples/mnist/mlp-test.jl
@@ -74,6 +74,11 @@ end
function test_mnist_mlp()
@test mnist_fit_and_predict(mx.SGD(lr=0.1, momentum=0.9), mx.UniformInitializer(0.01), 2) > 90
@test mnist_fit_and_predict(mx.ADAM(), mx.NormalInitializer(), 2) > 90
@test mnist_fit_and_predict(mx.AdaGrad(), mx.NormalInitializer(), 2) > 90
@test mnist_fit_and_predict(mx.AdaDelta(), mx.NormalInitializer(), 2) > 90
@test mnist_fit_and_predict(mx.AdaMax(), mx.NormalInitializer(), 2) > 90
@test mnist_fit_and_predict(mx.RMSProp(), mx.NormalInitializer(), 2) > 90
@test mnist_fit_and_predict(mx.Nadam(), mx.NormalInitializer(), 2) > 90
end

test_mnist_mlp()
58 changes: 55 additions & 3 deletions src/optimizer.jl
@@ -158,6 +158,42 @@ type Fixed <: AbstractMomentumScheduler
momentum :: Float64
end
get_momentum(self :: Fixed, state :: OptimizationState) = self.momentum

"""
Momentum.NadamScheduler
Nesterov-accelerated adaptive momentum scheduler.
Described in "Incorporating Nesterov Momentum into Adam"
[http://cs229.stanford.edu/proj2015/054_report.pdf](http://cs229.stanford.edu/proj2015/054_report.pdf).
``\mu_t = \mu_0 * (1 - \gamma * \alpha^{t * \delta})``.
Here
* ``t`` is the iteration count,
* ``\delta``: default `0.004`, the scheduler decay,
* ``\gamma``: default `0.5`,
* ``\alpha``: default `0.96`,
* ``\mu_0``: default `0.99`.
"""
type NadamScheduler <: AbstractMomentumScheduler
mu0 :: Float64
delta :: Float64
gamma :: Float64
alpha :: Float64
end
function NadamScheduler(;mu0::Real=0.99, delta::Real=0.004,
gamma::Real=0.5, alpha::Real=0.96)
@assert(0.0 <= delta)
@assert(0.0 <= alpha <= 1.0)
@assert(0.0 <= mu0 <= 1.0)
@assert(0.0 <= gamma <= 1.0)
NadamScheduler(Float64(mu0), Float64(delta), Float64(gamma), Float64(alpha))
end
get_momentum(self :: NadamScheduler, state :: OptimizationState) =
self.mu0 * (1.0 - self.gamma*self.alpha^(state.curr_iter * self.delta)),
self.mu0 * (1.0 - self.gamma*self.alpha^((state.curr_iter + 1) * self.delta))
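For reference, a minimal standalone sketch (plain Julia, hypothetical variable names, outside the `Momentum` module) of the schedule the docstring describes, evaluated at iterations `t` and `t + 1` as `get_momentum` returns:

```julia
# Illustrative only: mirrors mu_t = mu0 * (1 - gamma * alpha^(t * delta))
mu0, delta, gamma, alpha = 0.99, 0.004, 0.5, 0.96    # constructor defaults
nadam_mu(t) = mu0 * (1.0 - gamma * alpha^(t * delta))

t = 100
mu_t, mu_next = nadam_mu(t), nadam_mu(t + 1)         # pair returned by get_momentum
```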

end # module Momentum
################################################################################
function get_momentum_scheduler(scheduler :: Any, momentum :: Real)
@@ -170,6 +206,15 @@ function get_momentum_scheduler(scheduler :: Any, momentum :: Real)
end
end

function get_momentum_scheduler(scheduler :: Any,
another_scheduler :: AbstractMomentumScheduler)

if isa(scheduler, AbstractMomentumScheduler)
return scheduler
else
return another_scheduler
end
end

"""
get_updater(optimizer)
@@ -198,10 +243,10 @@ Base class for all optimizer options.
abstract AbstractOptimizerOptions

"""
normalized_gradient(opts, state, grad)
normalized_gradient(opts, state, weight, grad)
* `opts::AbstractOptimizerOptions`: options for the optimizer, should contain the field
`grad_scale`, `grad_clip` and `weight_decay`.
`grad_clip` and `weight_decay`.
* `state::OptimizationState`: the current optimization state.
* `weight::NDArray`: the trainable weights.
* `grad::NDArray`: the original gradient of the weights.
@@ -216,10 +261,17 @@ function normalized_gradient(opts::AbstractOptimizerOptions, state::Optimization
if opts.grad_clip > 0
grad = clip(grad, -opts.grad_clip, opts.grad_clip)
end
@inplace grad += opts.weight_decay * weight
if opts.weight_decay > 0
@inplace grad += opts.weight_decay * weight
end

return grad
end
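As an aside, a hedged sketch of the same normalization steps on plain Julia arrays; the hyperparameter values below are illustrative, not library defaults, and the real code operates on `NDArray` with `@inplace`:

```julia
grad_scale, grad_clip, weight_decay = 1.0, 5.0, 1e-5
weight, grad = randn(3), randn(3)

g = grad_scale .* grad                       # rescale
if grad_clip > 0
    g = clamp.(g, -grad_clip, grad_clip)     # clip into [-grad_clip, grad_clip]
end
if weight_decay > 0
    g = g .+ weight_decay .* weight          # L2 / weight-decay term
end
```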

include("optimizers/sgd.jl")
include("optimizers/adam.jl")
include("optimizers/adagrad.jl")
include("optimizers/adadelta.jl")
include("optimizers/adamax.jl")
include("optimizers/rmsprop.jl")
include("optimizers/nadam.jl")
91 changes: 91 additions & 0 deletions src/optimizers/adadelta.jl
@@ -0,0 +1,91 @@
@defstruct AdaDeltaOptions <: AbstractOptimizerOptions (
(lr :: Real = 1.0, lr > 0),
(rho :: Real = 0.95, rho > 0 && rho < 1),
(epsilon :: Real = 1e-6, epsilon > 0),
(grad_clip :: Real = 0, grad_clip >= 0),
(weight_decay :: Real = 0.00001, weight_decay >= 0),
lr_scheduler :: Any = nothing
)

"""
AdaDelta
Scale learning rates by the ratio of accumulated gradients to accumulated
updates, see [1] and notes for further description.
AdaDelta(; kwargs...)
# Attributes
* `lr::Real`: default `1.0`, the learning rate controlling the
size of update steps
* `rho::Real`: default `0.95`, squared gradient moving average decay factor
* `epsilon::Real`: default `1e-6`, small value added for
numerical stability
* `grad_clip::Real`: default `0`, if positive, will clip the gradient
into the range `[-grad_clip, grad_clip]`.
* `weight_decay::Real`: default `0.00001`, weight decay is equivalent
to adding a global l2 regularizer for all the parameters.
# Notes
`rho` should be between 0 and 1. A value of `rho` close to 1 will decay the
moving average slowly and a value close to 0 will decay the moving average
fast.
`rho` = 0.95 and `epsilon` = 1e-6 are suggested in the paper and reported to
work for multiple datasets (MNIST, speech). In the paper, no learning rate is
considered (so `lr` = 1.0). Probably best to keep it at this value.
`epsilon` is important for the very first update (so the numerator does
not become 0).
Using the step size `lr` and a decay factor `rho` the learning rate is
calculated as:
``r_t &= \rho r_{t-1} + (1-\rho)*g^2\\
\eta_t &= \eta \frac{\sqrt{s_{t-1} + \epsilon}} {\sqrt{r_t + \epsilon}}\\
s_t &= \rho s_{t-1} + (1-\rho)*(\eta_t*g)^2``
# References
* [1]: Zeiler, M. D. (2012):
ADADELTA: An Adaptive Learning Rate Method. arXiv Preprint arXiv:1212.5701.
"""

type AdaDelta <: AbstractOptimizer
opts :: AdaDeltaOptions
state :: OptimizationState

function AdaDelta(; kwargs...)
opts = AdaDeltaOptions(;kwargs...)
opts.lr_scheduler = get_lr_scheduler(opts.lr_scheduler, opts.lr)

new(opts)
end
end

type AdaDeltaState
acc :: NDArray
delta_acc :: NDArray
end

function create_state(self :: AdaDelta, index :: Int, weight :: NDArray)
return AdaDeltaState(zeros(size(weight), context(weight)),
zeros(size(weight), context(weight)))
end

function update(self :: AdaDelta, index :: Int, weight :: NDArray,
grad :: NDArray, state :: AdaDeltaState)
lr = get_learning_rate(self.opts.lr_scheduler, self.state)
grad = normalized_gradient(self.opts, self.state, weight, grad)

# Update state.acc as in RMSProp
@inplace state.acc .*= self.opts.rho
@inplace state.acc .+= (1 - self.opts.rho) * grad .* grad

# Compute update using the "old" state.delta_acc
update = grad .* sqrt(state.delta_acc + self.opts.epsilon) ./
(sqrt(state.acc + self.opts.epsilon))
@inplace weight .+= -lr * update

# update state.delta_acc using update
@inplace state.delta_acc .*= self.opts.rho
@inplace state.delta_acc .+= (1 - self.opts.rho) * update .* update
end
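To make the update rule concrete, here is a minimal self-contained AdaDelta step on plain Julia arrays (illustrative values; `acc` and `delta_acc` play the roles of the `AdaDeltaState` fields above):

```julia
rho, epsilon, lr = 0.95, 1e-6, 1.0
w, g = randn(4), randn(4)
acc, delta_acc = zeros(4), zeros(4)

acc = rho .* acc .+ (1 - rho) .* g .^ 2                          # running E[g^2]
upd = g .* sqrt.(delta_acc .+ epsilon) ./ sqrt.(acc .+ epsilon)  # scaled step
w   = w .- lr .* upd
delta_acc = rho .* delta_acc .+ (1 - rho) .* upd .^ 2            # running E[update^2]
```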
66 changes: 66 additions & 0 deletions src/optimizers/adagrad.jl
@@ -0,0 +1,66 @@
@defstruct AdaGradOptions <: AbstractOptimizerOptions (
(lr :: Real = 0.1, lr > 0),
(epsilon :: Real = 1e-6, epsilon > 0),
(grad_clip :: Real = 0, grad_clip >= 0),
(weight_decay :: Real = 0.00001, weight_decay >= 0),
lr_scheduler :: Any = nothing
)

"""
AdaGrad
Scale learning rates by dividing with the square root of accumulated
squared gradients. See [1] for further description.
AdaGrad(; kwargs...)
# Attributes
* `lr::Real`: default `0.1`, the learning rate controlling the
size of update steps
* `epsilon::Real`: default `1e-6`, small value added for
numerical stability
* `grad_clip::Real`: default `0`, if positive, will clip the gradient
into the range `[-grad_clip, grad_clip]`.
* `weight_decay::Real`: default `0.00001`, weight decay is equivalent
to adding a global l2 regularizer for all the parameters.
# Notes
Using step size `lr`, AdaGrad calculates the learning rate for feature ``i`` at
time step ``t`` as:
``\eta_{t,i} = \frac{lr}{\sqrt{\sum_{t^\prime=1}^{t} g^2_{t^\prime,i} + \epsilon}} g_{t,i}``
as such the learning rate is monotonically decreasing.
Epsilon is not included in the typical formula, see [2].
# References
* [1]: Duchi, J., Hazan, E., & Singer, Y. (2011):
Adaptive subgradient methods for online learning and
stochastic optimization. JMLR, 12:2121-2159.
* [2]: Chris Dyer: Notes on AdaGrad.
[http://www.ark.cs.cmu.edu/cdyer/adagrad.pdf]
(http://www.ark.cs.cmu.edu/cdyer/adagrad.pdf)
"""

type AdaGrad <: AbstractOptimizer
opts :: AdaGradOptions
state :: OptimizationState

function AdaGrad(; kwargs...)
opts = AdaGradOptions(;kwargs...)
opts.lr_scheduler = get_lr_scheduler(opts.lr_scheduler, opts.lr)

new(opts)
end
end

function create_state(self :: AdaGrad, index :: Int, weight :: NDArray)
return zeros(size(weight), context(weight))
end

function update(self :: AdaGrad, index :: Int, weight :: NDArray,
grad :: NDArray, state :: NDArray)
lr = get_learning_rate(self.opts.lr_scheduler, self.state)
grad = normalized_gradient(self.opts, self.state, weight, grad)

@inplace state .+= grad .* grad
@inplace weight .+= -lr * grad ./ (sqrt(state + self.opts.epsilon))
end
6 changes: 3 additions & 3 deletions src/optimizers/adam.jl
@@ -63,11 +63,11 @@ function update(self :: ADAM, index :: Int, weight :: NDArray, grad :: NDArray,
state.mt = self.opts.beta1 * state.mt + (1 - self.opts.beta1) * grad
state.vt = self.opts.beta2 * state.vt + (1 - self.opts.beta2) * (grad .* grad)

mt = state.mt / (1 - state.beta1Power)
vt = state.vt / (1 - state.beta2Power)
at = sqrt(1.0 - state.beta2Power)/(1.0 - state.beta1Power)

state.beta1Power *= self.opts.beta1
state.beta2Power *= self.opts.beta2

@inplace weight .+= -lr * mt ./ (sqrt(vt) + self.opts.epsilon)
@inplace weight .+= -lr * at * state.mt ./
(sqrt(state.vt) + self.opts.epsilon)
end
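The change above folds both bias corrections into a single scalar factor instead of materializing corrected `mt`/`vt` arrays. A scalar sketch of the idea (illustrative values, not the NDArray code path):

```julia
beta1, beta2, lr, epsilon = 0.9, 0.999, 1e-3, 1e-8
m, v, t = 0.05, 0.002, 10                     # first/second moment, step count
beta1power, beta2power = beta1^t, beta2^t

at   = sqrt(1.0 - beta2power) / (1.0 - beta1power)
step = lr * at * m / (sqrt(v) + epsilon)      # ≈ lr * m̂ / (sqrt(v̂) + epsilon)
```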
77 changes: 77 additions & 0 deletions src/optimizers/adamax.jl
@@ -0,0 +1,77 @@
@defstruct AdaMaxOptions <: AbstractOptimizerOptions (
(lr :: Real = 0.002, lr > 0),
(beta1 :: Real = 0.9, beta1 > 0 && beta1 < 1),
(beta2 :: Real = 0.999, beta2 > 0 && beta2 < 1),
(epsilon :: Real = 1e-8, epsilon > 0),
(grad_clip :: Real = 0, grad_clip >= 0),
(weight_decay :: Real = 0.00001, weight_decay >= 0),
lr_scheduler :: Any = nothing
)

"""
AdaMax
This is a variant of the Adam algorithm based on the infinity norm.
See [1] for further description.
AdaMax(; kwargs...)
# Attributes
* `lr::Real`: default `0.002`, the learning rate controlling the
size of update steps
* `beta1::Real`: default `0.9`, exponential decay rate
for the first moment estimates
* `beta2::Real`: default `0.999`, exponential decay rate for the
weighted infinity norm estimates
* `epsilon::Real`: default `1e-8`, small value added for
numerical stability
* `grad_clip::Real`: default `0`, if positive, will clip the gradient
into the range `[-grad_clip, grad_clip]`.
* `weight_decay::Real`: default `0.00001`, weight decay is equivalent
to adding a global l2 regularizer for all the parameters.
# References
* [1]: Kingma, Diederik, and Jimmy Ba (2014):
Adam: A Method for Stochastic Optimization.
[http://arxiv.org/abs/1412.6980v8]
(http://arxiv.org/abs/1412.6980v8).
"""

type AdaMax <: AbstractOptimizer
opts :: AdaMaxOptions
state :: OptimizationState

function AdaMax(; kwargs...)
opts = AdaMaxOptions(; kwargs...)
opts.lr_scheduler = get_lr_scheduler(opts.lr_scheduler, opts.lr)

new(opts)
end
end

type AdaMaxState
mt :: NDArray
ut :: NDArray
beta1Power :: Float64
end

function create_state(self :: AdaMax, index :: Int, weight :: NDArray)
return AdaMaxState( zeros(size(weight), context(weight)),
zeros(size(weight), context(weight)),
self.opts.beta1 )
end

function update(self :: AdaMax, index :: Int, weight :: NDArray,
grad :: NDArray, state :: AdaMaxState)
lr = get_learning_rate(self.opts.lr_scheduler, self.state)
grad = normalized_gradient(self.opts, self.state, weight, grad)

@inplace state.mt .*= self.opts.beta1
@inplace state.mt .+= (1 - self.opts.beta1) * grad
state.ut = _maximum(self.opts.beta2 * state.ut, abs(grad))

@inplace weight .+= - lr / (1 - state.beta1Power) *
state.mt ./ (state.ut + self.opts.epsilon)

state.beta1Power *= self.opts.beta1
end
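A minimal AdaMax step on plain Julia arrays for comparison (illustrative values; `m` and `u` correspond to `mt` and `ut` in `AdaMaxState`):

```julia
lr, beta1, beta2, epsilon = 0.002, 0.9, 0.999, 1e-8
w, g = randn(4), randn(4)
m, u = zeros(4), zeros(4)
beta1power = beta1                                 # as in create_state

m = beta1 .* m .+ (1 - beta1) .* g                 # first moment estimate
u = max.(beta2 .* u, abs.(g))                      # weighted infinity norm
w = w .- (lr / (1 - beta1power)) .* m ./ (u .+ epsilon)
beta1power *= beta1
```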
