Commit

Implement RMSprop
erogol authored and ronghanghu committed Aug 6, 2015
1 parent d4aa5fe commit ab5cb0b
Showing 7 changed files with 369 additions and 13 deletions.
29 changes: 29 additions & 0 deletions examples/mnist/lenet_solver_rmsprop.prototxt
@@ -0,0 +1,29 @@
# The train/test net protocol buffer definition
net: "examples/mnist/lenet_train_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01
momentum: 0.0
weight_decay: 0.0005
# The learning rate policy
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# Display every 100 iterations
display: 100
# The maximum number of iterations
max_iter: 10000
# snapshot intermediate results
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet_rmsprop"
# solver mode: CPU or GPU
solver_mode: GPU
solver_type: RMSPROP
rms_decay: 0.98

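For reference, rms_decay and the existing delta field feed the update rule implemented in src/caffe/solver.cpp below. The solver keeps a decaying average of each parameter's squared gradient and divides the gradient by its root before applying the learning rate:

history = rms_decay * history + (1 - rms_decay) * gradient^2
step = local_rate * gradient / (sqrt(history) + delta)

The "inv" policy above anneals the learning rate as base_lr * (1 + gamma * iter)^(-power).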

3 changes: 3 additions & 0 deletions examples/mnist/train_lenet_rmsprop.sh
@@ -0,0 +1,3 @@
#!/usr/bin/env sh

./build/tools/caffe train --solver=examples/mnist/lenet_solver_rmsprop.prototxt
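Note: as with the other MNIST example scripts, this assumes it is run from the Caffe root directory, since both the tool and solver paths are relative:

sh examples/mnist/train_lenet_rmsprop.sh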
17 changes: 17 additions & 0 deletions include/caffe/solver.hpp
@@ -128,6 +128,21 @@ class AdaGradSolver : public SGDSolver<Dtype> {
DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
};


template <typename Dtype>
class RMSpropSolver : public SGDSolver<Dtype> {
public:
explicit RMSpropSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) { }
explicit RMSpropSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { }

protected:
virtual void ComputeUpdateValue();

DISABLE_COPY_AND_ASSIGN(RMSpropSolver);
};

template <typename Dtype>
Solver<Dtype>* GetSolver(const SolverParameter& param) {
SolverParameter_SolverType type = param.solver_type();
@@ -139,6 +154,8 @@ Solver<Dtype>* GetSolver(const SolverParameter& param) {
return new NesterovSolver<Dtype>(param);
case SolverParameter_SolverType_ADAGRAD:
return new AdaGradSolver<Dtype>(param);
case SolverParameter_SolverType_RMSPROP:
return new RMSpropSolver<Dtype>(param);
default:
LOG(FATAL) << "Unknown SolverType: " << type;
}
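A minimal usage sketch of the factory path added here (hypothetical driver code, not part of this diff; it assumes the umbrella caffe/caffe.hpp header and the example solver definition above):

#include "caffe/caffe.hpp"

int main() {
  caffe::SolverParameter param;
  // solver_type: RMSPROP in the prototxt routes GetSolver to RMSpropSolver.
  caffe::ReadProtoFromTextFileOrDie(
      "examples/mnist/lenet_solver_rmsprop.prototxt", &param);
  caffe::Solver<float>* solver = caffe::GetSolver<float>(param);
  solver->Solve();
  delete solver;
  return 0;
}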
17 changes: 17 additions & 0 deletions python/caffe/classifier.py
@@ -23,7 +23,24 @@ class Classifier(caffe.Net):
def __init__(self, model_file, pretrained_file, image_dims=None,
mean=None, input_scale=None, raw_scale=None,
channel_swap=None):
"""
Take
image_dims: dimensions to scale input for cropping/sampling.
Default is to scale to net input size for whole-image crop.
mean, input_scale, raw_scale, channel_swap: params for
preprocessing options.
"""
caffe.Net.__init__(self, model_file, pretrained_file, caffe.TEST)

# configure pre-processing
in_ = self.inputs[0]
8 changes: 8 additions & 0 deletions src/caffe/proto/caffe.proto
@@ -96,7 +96,11 @@ message NetParameter {
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
// SolverParameter next available ID: 37 (last added: rms_decay)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
@@ -191,10 +195,14 @@ message SolverParameter {
SGD = 0;
NESTEROV = 1;
ADAGRAD = 2;
RMSPROP = 3;
}
optional SolverType solver_type = 30 [default = SGD];
// numerical stability for AdaGrad and RMSProp
optional float delta = 31 [default = 1e-8];

// RMSProp decay value
// MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
optional float rms_decay = 36;

// If true, print information about the state of the net that may help with
// debugging learning problems.
188 changes: 188 additions & 0 deletions src/caffe/solver.cpp
@@ -644,6 +644,7 @@ void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
Dtype momentum = this->param_.momentum();
Dtype local_rate = rate * net_params_lr[param_id];
switch (Caffe::mode()) {
case Caffe::CPU: {
// save history momentum for stepping back
caffe_copy(net_params[param_id]->count(),
@@ -664,6 +665,53 @@ void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
caffe_copy(net_params[param_id]->count(),
this->update_[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
break;
}
case Caffe::GPU: {
@@ -775,9 +823,149 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
}
}

template <typename Dtype>
void RMSpropSolver<Dtype>::ComputeUpdateValue() {
const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
const vector<float>& net_params_lr = this->net_->params_lr();
const vector<float>& net_params_weight_decay =
this->net_->params_weight_decay();

// get the learning rate
Dtype rate = this->GetLearningRate();
Dtype delta = this->param_.delta();
Dtype rms_decay = this->param_.rms_decay();

if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
}
Dtype weight_decay = this->param_.weight_decay();
string regularization_type = this->param_.regularization_type();
switch (Caffe::mode()) {
case Caffe::CPU:
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
Dtype local_rate = rate * net_params_lr[param_id];
Dtype local_decay = weight_decay * net_params_weight_decay[param_id];

if (local_decay) {
if (regularization_type == "L2") {
// add weight decay
caffe_axpy(net_params[param_id]->count(),
local_decay,
net_params[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
} else if (regularization_type == "L1") {
caffe_cpu_sign(net_params[param_id]->count(),
net_params[param_id]->cpu_data(),
this->temp_[param_id]->mutable_cpu_data());
caffe_axpy(net_params[param_id]->count(),
local_decay,
this->temp_[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
} else {
LOG(FATAL) << "Unknown regularization type: " << regularization_type;
}
}

// Compute RMSProp step
// compute square of gradient in update
caffe_powx(net_params[param_id]->count(),
net_params[param_id]->cpu_diff(), Dtype(2),
this->update_[param_id]->mutable_cpu_data());

// update history: history = rms_decay*history + (1-rms_decay)*diff^2
caffe_cpu_axpby(net_params[param_id]->count(),
Dtype(1-rms_decay), this->update_[param_id]->cpu_data(),
rms_decay, this->history_[param_id]->mutable_cpu_data());

// prepare update: update = sqrt(history)
caffe_powx(net_params[param_id]->count(),
this->history_[param_id]->cpu_data(), Dtype(0.5),
this->update_[param_id]->mutable_cpu_data());

// add delta for numerical stability
caffe_add_scalar(net_params[param_id]->count(),
delta, this->update_[param_id]->mutable_cpu_data());

// divide gradient by RMS: update = diff / (sqrt(history) + delta)
caffe_div(net_params[param_id]->count(),
net_params[param_id]->cpu_diff(),
this->update_[param_id]->cpu_data(),
this->update_[param_id]->mutable_cpu_data());

// scale and copy
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
this->update_[param_id]->cpu_data(), Dtype(0),
net_params[param_id]->mutable_cpu_diff());
}
break;
case Caffe::GPU:
#ifndef CPU_ONLY
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
Dtype local_rate = rate * net_params_lr[param_id];
Dtype local_decay = weight_decay * net_params_weight_decay[param_id];

if (local_decay) {
if (regularization_type == "L2") {
// add weight decay
caffe_gpu_axpy(net_params[param_id]->count(),
local_decay,
net_params[param_id]->gpu_data(),
net_params[param_id]->mutable_gpu_diff());
} else if (regularization_type == "L1") {
caffe_gpu_sign(net_params[param_id]->count(),
net_params[param_id]->gpu_data(),
this->temp_[param_id]->mutable_gpu_data());
caffe_gpu_axpy(net_params[param_id]->count(),
local_decay,
this->temp_[param_id]->gpu_data(),
net_params[param_id]->mutable_gpu_diff());
} else {
LOG(FATAL) << "Unknown regularization type: " << regularization_type;
}
}

// Compute RMSProp step
// compute square of gradient in update
caffe_gpu_powx(net_params[param_id]->count(),
net_params[param_id]->gpu_diff(), Dtype(2),
this->update_[param_id]->mutable_gpu_data());

// update history: history = rms_decay*history + (1-rms_decay)*diff^2
caffe_gpu_axpby(net_params[param_id]->count(),
Dtype(1-rms_decay), this->update_[param_id]->gpu_data(),
rms_decay, this->history_[param_id]->mutable_gpu_data());

// prepare update: update = sqrt(history)
caffe_gpu_powx(net_params[param_id]->count(),
this->history_[param_id]->gpu_data(), Dtype(0.5),
this->update_[param_id]->mutable_gpu_data());

// add delta for numerical stability
caffe_gpu_add_scalar(net_params[param_id]->count(),
delta, this->update_[param_id]->mutable_gpu_data());

// divide gradient by RMS: update = diff / (sqrt(history) + delta)
caffe_gpu_div(net_params[param_id]->count(),
net_params[param_id]->gpu_diff(),
this->update_[param_id]->gpu_data(),
this->update_[param_id]->mutable_gpu_data());

// scale and copy
caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
this->update_[param_id]->gpu_data(), Dtype(0),
net_params[param_id]->mutable_gpu_diff());
}
#else
NO_GPU;
#endif
break;
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}
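For readers tracing the BLAS calls, a minimal self-contained sketch (plain C++, illustrative names, weight decay and the GPU path omitted; not part of this diff) of the per-element arithmetic in the CPU branch above:

#include <cmath>
#include <vector>

// One RMSProp update over a parameter blob:
//   history = rms_decay * history + (1 - rms_decay) * diff^2
//   diff    = local_rate * diff / (sqrt(history) + delta)
void rmsprop_step(std::vector<float>& diff, std::vector<float>& history,
                  float local_rate, float rms_decay, float delta) {
  for (size_t i = 0; i < diff.size(); ++i) {
    history[i] = rms_decay * history[i]
        + (1.0f - rms_decay) * diff[i] * diff[i];
    diff[i] = local_rate * diff[i] / (std::sqrt(history[i]) + delta);
  }
}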

INSTANTIATE_CLASS(Solver);
INSTANTIATE_CLASS(SGDSolver);
INSTANTIATE_CLASS(NesterovSolver);
INSTANTIATE_CLASS(AdaGradSolver);
INSTANTIATE_CLASS(RMSpropSolver);

} // namespace caffe
