RMSprop implementation based on G. Hinton Lecture 6 #1890

Closed
wants to merge 9 commits
29 changes: 29 additions & 0 deletions examples/mnist/lenet_solver_rmsprop.prototxt
@@ -0,0 +1,29 @@
# The train/test net protocol buffer definition
net: "examples/mnist/lenet_train_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01
momentum: 0.0
weight_decay: 0.0005
# The learning rate policy
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# Display every 100 iterations
display: 100
# The maximum number of iterations
max_iter: 10000
# snapshot intermediate results
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet_rmsprop"
# solver mode: CPU or GPU
solver_mode: GPU
solver_type: RMSPROP
rms_decay: 0.98


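For reference, rms_decay and delta enter the update implemented in solver.cpp below as a running mean of squared gradients that damps the step size:

$$ r_t = \text{rms\_decay} \cdot r_{t-1} + (1 - \text{rms\_decay})\, g_t^2, \qquad w_{t+1} = w_t - \frac{\alpha\, g_t}{\sqrt{r_t} + \delta} $$

where $g_t$ is the gradient (with any weight decay already folded in), $\alpha$ is the effective learning rate after the lr_policy and per-parameter multipliers, and $\delta$ is the delta field reused from AdaGrad for numerical stability.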
3 changes: 3 additions & 0 deletions examples/mnist/train_lenet_rmsprop.sh
@@ -0,0 +1,3 @@
#!/usr/bin/env sh

./build/tools/caffe train --solver=examples/mnist/lenet_solver_rmsprop.prototxt
17 changes: 17 additions & 0 deletions include/caffe/solver.hpp
@@ -125,6 +125,21 @@ class AdaGradSolver : public SGDSolver<Dtype> {
DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
};


template <typename Dtype>
class RMSpropSolver : public SGDSolver<Dtype> {
public:
explicit RMSpropSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) { }
explicit RMSpropSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { }

protected:
virtual void ComputeUpdateValue();

DISABLE_COPY_AND_ASSIGN(RMSpropSolver);
};
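Note that RMSpropSolver overrides only ComputeUpdateValue(); the history_, update_, and temp_ buffers it relies on are inherited from SGDSolver, mirroring the structure of AdaGradSolver above.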

template <typename Dtype>
Solver<Dtype>* GetSolver(const SolverParameter& param) {
SolverParameter_SolverType type = param.solver_type();
@@ -136,6 +151,8 @@ Solver<Dtype>* GetSolver(const SolverParameter& param) {
return new NesterovSolver<Dtype>(param);
case SolverParameter_SolverType_ADAGRAD:
return new AdaGradSolver<Dtype>(param);
case SolverParameter_SolverType_RMSPROP:
return new RMSpropSolver<Dtype>(param);
default:
LOG(FATAL) << "Unknown SolverType: " << type;
}
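With the new case in GetSolver, an RMSprop solver can be built through the usual factory path. A minimal sketch, assuming the caffe.hpp umbrella header and the existing ReadProtoFromTextFileOrDie helper from caffe/util/io.hpp:

#include "caffe/caffe.hpp"

int main() {
  caffe::SolverParameter param;
  // lenet_solver_rmsprop.prototxt sets solver_type: RMSPROP and rms_decay: 0.98
  caffe::ReadProtoFromTextFileOrDie(
      "examples/mnist/lenet_solver_rmsprop.prototxt", &param);
  // the factory dispatches on param.solver_type() and returns an RMSpropSolver
  caffe::Solver<float>* solver = caffe::GetSolver<float>(param);
  solver->Solve();
  delete solver;
  return 0;
}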
2 changes: 1 addition & 1 deletion python/caffe/classifier.py
@@ -23,7 +23,7 @@ def __init__(self, model_file, pretrained_file, image_dims=None,
gpu, mean, input_scale, raw_scale, channel_swap: params for
preprocessing options.
"""
caffe.Net.__init__(self, model_file, pretrained_file)
caffe.Net.__init__(self, model_file, pretrained_file, caffe.TEST)
caffe.set_phase_test()

if gpu:
6 changes: 5 additions & 1 deletion src/caffe/proto/caffe.proto
@@ -75,7 +75,7 @@ message NetParameter {
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
// SolverParameter next available ID: 36 (last added: clip_gradients)
// SolverParameter next available ID: 37 (last added: rms_decay)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
@@ -168,10 +168,14 @@ message SolverParameter {
SGD = 0;
NESTEROV = 1;
ADAGRAD = 2;
RMSPROP = 3;
}
optional SolverType solver_type = 30 [default = SGD];
// numerical stability for AdaGrad
optional float delta = 31 [default = 1e-8];

// RMSprop decay value
optional float rms_decay = 36;

// If true, print information about the state of the net that may help with
// debugging learning problems.
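One caveat: rms_decay is declared without an explicit default, and proto2 optional floats fall back to 0. With rms_decay = 0 the history collapses to the current squared gradient, so solvers selecting RMSPROP should set the field explicitly, as the example solver above does with rms_decay: 0.98.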
142 changes: 141 additions & 1 deletion src/caffe/solver.cpp
@@ -628,7 +628,7 @@ void NesterovSolver<Dtype>::ComputeUpdateValue() {
net_params[param_id]->cpu_diff(), momentum,
this->history_[param_id]->mutable_cpu_data());

// compute udpate: step back then over step
// compute uppate: step back then over step
Contributor

Well, you almost fixed the typo :-)

Contributor Author

yeah :)

caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
this->history_[param_id]->cpu_data(), -momentum,
this->update_[param_id]->mutable_cpu_data());
@@ -829,9 +829,149 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue() {
}
}

template <typename Dtype>
void RMSpropSolver<Dtype>::ComputeUpdateValue() {
const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
const vector<float>& net_params_lr = this->net_->params_lr();
const vector<float>& net_params_weight_decay =
this->net_->params_weight_decay();

// get the learning rate
Dtype rate = this->GetLearningRate();
Dtype delta = this->param_.delta();
Dtype rms_decay = this->param_.rms_decay();

if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
}
Dtype weight_decay = this->param_.weight_decay();
string regularization_type = this->param_.regularization_type();
switch (Caffe::mode()) {
case Caffe::CPU:
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
Dtype local_rate = rate * net_params_lr[param_id];
Dtype local_decay = weight_decay * net_params_weight_decay[param_id];

if (local_decay) {
if (regularization_type == "L2") {
// add weight decay
caffe_axpy(net_params[param_id]->count(),
local_decay,
net_params[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
} else if (regularization_type == "L1") {
caffe_cpu_sign(net_params[param_id]->count(),
net_params[param_id]->cpu_data(),
this->temp_[param_id]->mutable_cpu_data());
caffe_axpy(net_params[param_id]->count(),
local_decay,
this->temp_[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
} else {
LOG(FATAL) << "Unknown regularization type: " << regularization_type;
}
}

// compute RMS step
// compute square of gradient in update
caffe_powx(net_params[param_id]->count(),
net_params[param_id]->cpu_diff(), Dtype(2),
this->update_[param_id]->mutable_cpu_data());

// update history
caffe_cpu_axpby(net_params[param_id]->count(),
Dtype(1-rms_decay), this->update_[param_id]->cpu_data(),
rms_decay, this->history_[param_id]->mutable_cpu_data());

// prepare update
caffe_powx(net_params[param_id]->count(),
this->history_[param_id]->cpu_data(), Dtype(0.5),
this->update_[param_id]->mutable_cpu_data());
// add delta for numerical stability
caffe_add_scalar(net_params[param_id]->count(),
delta, this->update_[param_id]->mutable_cpu_data());

// scale the gradient by the damped RMS
caffe_div(net_params[param_id]->count(),
net_params[param_id]->cpu_diff(),
this->update_[param_id]->cpu_data(),
this->update_[param_id]->mutable_cpu_data());

// scale and copy
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
this->update_[param_id]->cpu_data(), Dtype(0),
net_params[param_id]->mutable_cpu_diff());
}
break;
case Caffe::GPU:
#ifndef CPU_ONLY
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
Dtype local_rate = rate * net_params_lr[param_id];
Dtype local_decay = weight_decay * net_params_weight_decay[param_id];

if (local_decay) {
if (regularization_type == "L2") {
// add weight decay
caffe_gpu_axpy(net_params[param_id]->count(),
local_decay,
net_params[param_id]->gpu_data(),
net_params[param_id]->mutable_gpu_diff());
} else if (regularization_type == "L1") {
caffe_gpu_sign(net_params[param_id]->count(),
net_params[param_id]->gpu_data(),
this->temp_[param_id]->mutable_gpu_data());
caffe_gpu_axpy(net_params[param_id]->count(),
local_decay,
this->temp_[param_id]->gpu_data(),
net_params[param_id]->mutable_gpu_diff());
} else {
LOG(FATAL) << "Unknown regularization type: " << regularization_type;
}
}

// compute RMS step
// compute square of gradient in update
caffe_gpu_powx(net_params[param_id]->count(),
net_params[param_id]->gpu_diff(), Dtype(2),
this->update_[param_id]->mutable_gpu_data());

// update history
caffe_gpu_axpby(net_params[param_id]->count(),
Dtype(1-rms_decay), this->update_[param_id]->gpu_data(),
rms_decay, this->history_[param_id]->mutable_gpu_data());

// prepare update
caffe_gpu_powx(net_params[param_id]->count(),
this->history_[param_id]->gpu_data(), Dtype(0.5),
this->update_[param_id]->mutable_gpu_data());
// add delta for numerical stability
caffe_gpu_add_scalar(net_params[param_id]->count(),
delta, this->update_[param_id]->mutable_gpu_data());

// scale the gradient by the damped RMS
caffe_gpu_div(net_params[param_id]->count(),
net_params[param_id]->gpu_diff(),
this->update_[param_id]->gpu_data(),
this->update_[param_id]->mutable_gpu_data());

// scale and copy
caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
this->update_[param_id]->gpu_data(), Dtype(0),
net_params[param_id]->mutable_gpu_diff());
}
#else
NO_GPU;
#endif
break;
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}

INSTANTIATE_CLASS(Solver);
INSTANTIATE_CLASS(SGDSolver);
INSTANTIATE_CLASS(NesterovSolver);
INSTANTIATE_CLASS(AdaGradSolver);
INSTANTIATE_CLASS(RMSpropSolver);

} // namespace caffe
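To sanity-check the arithmetic, here is a hypothetical standalone trace of the same update rule for a single scalar parameter (plain C++, no Caffe blobs; the toy loss w^2 and all constants are illustrative only):

#include <cmath>
#include <cstdio>

int main() {
  const float rms_decay = 0.98f, delta = 1e-8f, lr = 0.01f;
  float history = 0.0f;  // running mean of squared gradients
  float w = 1.0f;        // the single parameter
  for (int t = 0; t < 3; ++t) {
    const float grad = 2.0f * w;  // gradient of the toy loss w^2
    // same recurrence as ComputeUpdateValue above
    history = rms_decay * history + (1.0f - rms_decay) * grad * grad;
    const float step = lr * grad / (std::sqrt(history) + delta);
    w -= step;
    std::printf("t=%d grad=%.4f history=%.4f w=%.4f\n", t, grad, history, w);
  }
  return 0;
}

The first step, for instance, gives history = 0.02 * 4 = 0.08 and a step of 0.01 * 2 / sqrt(0.08) ≈ 0.0707, independent of the raw gradient's scale, which is exactly the normalization RMSprop is meant to provide.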