Adam solver #2856

Closed · wants to merge 1 commit
26 changes: 26 additions & 0 deletions examples/mnist/lenet_solver_adam.prototxt
@@ -0,0 +1,26 @@
# The train/test net protocol buffer definition
# this follows "ADAM: A METHOD FOR STOCHASTIC OPTIMIZATION"
net: "examples/mnist/lenet_train_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# All parameters are from the cited paper above
base_lr: 0.001
momentum: 0.9
momentum2: 0.999
# Adam adapts the effective per-parameter step size on its own, so we keep the
# base learning rate fixed
lr_policy: "fixed"
# Display every 100 iterations
display: 100
# The maximum number of iterations
max_iter: 10000
# snapshot intermediate results
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet"
solver_type: ADAM
# solver mode: CPU or GPU
solver_mode: GPU
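
For reference, these fields map onto the update from the cited paper as follows: momentum is \beta_1, momentum2 is \beta_2, delta is \epsilon, and base_lr is the step size \alpha. A sketch of the update in the paper's notation:

m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
\hat{m}_t = m_t / (1 - \beta_1^t), \qquad \hat{v}_t = v_t / (1 - \beta_2^t)
\theta_t = \theta_{t-1} - \alpha \, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)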
3 changes: 3 additions & 0 deletions examples/mnist/train_lenet_adam.sh
@@ -0,0 +1,3 @@
#!/usr/bin/env sh

./build/tools/caffe train --solver=examples/mnist/lenet_solver_adam.prototxt
19 changes: 19 additions & 0 deletions include/caffe/solver.hpp
@@ -158,6 +158,23 @@ class RMSPropSolver : public SGDSolver<Dtype> {
DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
};

template <typename Dtype>
class AdamSolver : public SGDSolver<Dtype> {
public:
explicit AdamSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) { initAdam(); }
explicit AdamSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { initAdam(); }

protected:
virtual void ComputeUpdateValue(int param_id, Dtype rate);
void initAdam();

// Per-parameter history of the first (m_) and second (v_) moment estimates.
vector<shared_ptr<Blob<Dtype> > > m_, v_;

DISABLE_COPY_AND_ASSIGN(AdamSolver);
};

template <typename Dtype>
Solver<Dtype>* GetSolver(const SolverParameter& param) {
SolverParameter_SolverType type = param.solver_type();
@@ -171,6 +188,8 @@ Solver<Dtype>* GetSolver(const SolverParameter& param) {
return new AdaGradSolver<Dtype>(param);
case SolverParameter_SolverType_RMSPROP:
return new RMSPropSolver<Dtype>(param);
case SolverParameter_SolverType_ADAM:
return new AdamSolver<Dtype>(param);
default:
LOG(FATAL) << "Unknown SolverType: " << type;
}
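With the ADAM case added to the factory, a solver built from the example prototxt is dispatched to AdamSolver automatically. A minimal usage sketch, mirroring the pattern in tools/caffe.cpp rather than code taken from this PR:

#include <boost/shared_ptr.hpp>
#include "caffe/caffe.hpp"
#include "caffe/util/io.hpp"

int main() {
  // Parse the solver definition; solver_type: ADAM selects the new branch above.
  caffe::SolverParameter solver_param;
  caffe::ReadProtoFromTextFileOrDie(
      "examples/mnist/lenet_solver_adam.prototxt", &solver_param);
  boost::shared_ptr<caffe::Solver<float> > solver(
      caffe::GetSolver<float>(solver_param));
  solver->Solve();  // runs training with AdamSolver::ComputeUpdateValue
  return 0;
}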
7 changes: 5 additions & 2 deletions src/caffe/proto/caffe.proto
@@ -98,7 +98,7 @@ message NetParameter {
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
// SolverParameter next available ID: 39 (last added: rms_decay)
// SolverParameter next available ID: 40 (last added: momentum2)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
@@ -215,10 +215,13 @@ message SolverParameter {
NESTEROV = 1;
ADAGRAD = 2;
RMSPROP = 3;
ADAM = 4;
}
optional SolverType solver_type = 30 [default = SGD];
// numerical stability for AdaGrad
// numerical stability for AdaGrad and Adam
optional float delta = 31 [default = 1e-8];
// parameters for the Adam solver
optional float momentum2 = 39 [default = 0.999];

// RMSProp decay value
// MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
136 changes: 135 additions & 1 deletion src/caffe/solver.cpp
@@ -934,10 +934,144 @@ void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
}
}


// template <typename Dtype>
// void AdamSolver<Dtype>::SnapshotSolverStateToBinaryProto(
// const string& model_filename) {
// SolverState state;
// state.set_iter(this->iter_);
// state.set_learned_net(model_filename);
// state.set_current_step(this->current_step_);
// state.clear_history();
// state.clear_history_m();
// state.clear_history_v();
// for (int i = 0; i < SGDSolver<Dtype>::history_.size(); ++i) {
// // Add history
// BlobProto* history_blob = state.add_history();
// SGDSolver<Dtype>::history_[i]->ToProto(history_blob);

// BlobProto* history_blob_m = state.add_history_m();
// m_[i]->ToProto(history_blob_m);

// BlobProto* history_blob_v = state.add_history_v();
// v_[i]->ToProto(history_blob_v);
// }
// string snapshot_filename = Solver<Dtype>::SnapshotFilename(".solverstate");
// LOG(INFO)
// << "Snapshotting solver state to binary proto file" << snapshot_filename;
// WriteProtoToBinaryFile(state, snapshot_filename.c_str());
// }


template <typename Dtype>
void AdamSolver<Dtype>::initAdam() {
// Initialize the history
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
m_.clear();
v_.clear();
for (int i = 0; i < net_params.size(); ++i) {
const vector<int>& shape = net_params[i]->shape();
m_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
v_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
}
}

template <typename Dtype>
void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
Review comment (Member): To be consistent with #2866, use const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params(); instead.

const vector<float>& net_params_lr = this->net_->params_lr();
Dtype local_rate = rate * net_params_lr[param_id];
const Dtype beta1 = this->param_.momentum();
const Dtype beta2 = this->param_.momentum2();

// we create aliases for convenience
// shared_ptr<Blob<Dtype> > val_m = this->history_[param_id];
shared_ptr<Blob<Dtype> > val_m = this->m_[param_id];
// shared_ptr<Blob<Dtype> > val_v = this->update_[param_id];
shared_ptr<Blob<Dtype> > val_v = this->v_[param_id];
shared_ptr<Blob<Dtype> > val_t = this->temp_[param_id];

const int t = this->iter_ + 1;
const Dtype correction = std::sqrt(Dtype(1)-pow(beta2, t))/
(Dtype(1.)-pow(beta1, t));
const int N = net_params[param_id]->count();
const Dtype eps_hat = this->param_.delta();

switch (Caffe::mode()) {
case Caffe::CPU: {
// update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t
caffe_cpu_axpby(N, Dtype(1)-beta1,
net_params[param_id]->cpu_diff(), beta1,
val_m->mutable_cpu_data());

// update v <- \beta_2 v_{t-1} + (1-\beta_2)g_t^2
caffe_mul(N,
net_params[param_id]->cpu_diff(),
net_params[param_id]->cpu_diff(),
val_t->mutable_cpu_data());
caffe_cpu_axpby(N, Dtype(1)-beta2,
val_t->cpu_data(), beta2,
val_v->mutable_cpu_data());

// set update
caffe_powx(N,
val_v->cpu_data(), Dtype(0.5),
val_t->mutable_cpu_data());
caffe_add_scalar(N, eps_hat, val_t->mutable_cpu_data());
caffe_div(N,
val_m->cpu_data(),
val_t->cpu_data(),
val_t->mutable_cpu_data());

caffe_cpu_scale(N, local_rate*correction,
val_t->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
break;
}
case Caffe::GPU: {
#ifndef CPU_ONLY
// update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t
caffe_gpu_axpby(N, Dtype(1)-beta1,
net_params[param_id]->gpu_diff(), beta1,
val_m->mutable_gpu_data());

// update v <- \beta_2 v_{t-1} + (1-\beta_2)g_t^2
caffe_gpu_mul(N,
net_params[param_id]->gpu_diff(),
net_params[param_id]->gpu_diff(),
val_t->mutable_gpu_data());
caffe_gpu_axpby(N, Dtype(1)-beta2,
val_t->gpu_data(), beta2,
val_v->mutable_gpu_data());

// set update
caffe_gpu_powx(N,
val_v->gpu_data(), Dtype(0.5),
val_t->mutable_gpu_data());
caffe_gpu_add_scalar(N, eps_hat,
val_t->mutable_gpu_data());
caffe_gpu_div(N,
val_m->gpu_data(),
val_t->gpu_data(),
val_t->mutable_gpu_data());

caffe_gpu_scale(N, local_rate*correction,
val_t->gpu_data(),
net_params[param_id]->mutable_gpu_diff());
#else
NO_GPU;
#endif
break;
}
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}

INSTANTIATE_CLASS(Solver);
INSTANTIATE_CLASS(SGDSolver);
INSTANTIATE_CLASS(NesterovSolver);
INSTANTIATE_CLASS(AdaGradSolver);
INSTANTIATE_CLASS(RMSPropSolver);
INSTANTIATE_CLASS(AdamSolver);
} // namespace caffe
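
A note on the correction factor in ComputeUpdateValue above: folding the bias correction into the step scale matches the paper's update up to where the epsilon term enters, since

\frac{\hat{m}_t}{\sqrt{\hat{v}_t}} = \frac{m_t / (1 - \beta_1^t)}{\sqrt{v_t} / \sqrt{1 - \beta_2^t}} = \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \cdot \frac{m_t}{\sqrt{v_t}}

so scaling m_t / (\sqrt{v_t} + eps_hat) by correction = \sqrt{1 - \beta_2^t} / (1 - \beta_1^t) and by the local rate reproduces \alpha \, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon), except that eps_hat is added to \sqrt{v_t} before the correction is applied rather than to \sqrt{\hat{v}_t}.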
42 changes: 42 additions & 0 deletions src/caffe/test/test_gradient_based_solver.cpp
@@ -236,6 +236,10 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
Blob<Dtype>& updated_bias = *(*updated_params)[1];
updated_bias.ReshapeLike(bias);

const Dtype momentum2 = 0.999;
Dtype val_m = 0;
Dtype val_v = 0;

for (int i = 0; i <= D; ++i) {
// Compute the derivative with respect to the ith weight (i.e., the ith
// element of the gradient).
@@ -289,6 +293,15 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
+ grad * grad * (1 - rms_decay)) + delta_;
}
break;
case SolverParameter_SolverType_ADAM: {
// The moment histories are taken to be zero here, which is only valid
// when checking the very first update (the Adam test below uses k = 0).
const Dtype m_ = 0.0;
const Dtype v_ = 0.0;
val_m = (1-momentum)*grad + momentum*m_;
val_v = (1-momentum2)*grad*grad + momentum2*v_;
Dtype alpha_t = learning_rate*std::sqrt(1-momentum2)/(1-momentum);
update_value = alpha_t * val_m / (std::sqrt(val_v) + delta_);
break;
}
default:
LOG(FATAL) << "Unknown solver type: " << solver_type();
}
@@ -739,6 +752,35 @@ TYPED_TEST(AdaGradSolverTest, TestSnapshotShare) {
}


template <typename TypeParam>
class AdamSolverTest : public GradientBasedSolverTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;

protected:
virtual void InitSolver(const SolverParameter& param) {
SolverParameter new_param = param;
const Dtype momentum = 0.9;
new_param.set_momentum(momentum);
const Dtype momentum2 = 0.999;
new_param.set_momentum2(momentum2);
this->solver_.reset(new AdamSolver<Dtype>(new_param));
}
virtual SolverParameter_SolverType solver_type() {
return SolverParameter_SolverType_ADAM;
}
};

TYPED_TEST_CASE(AdamSolverTest, TestDtypesAndDevices);

TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdate) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.001;
const Dtype kWeightDecay = 0.0;
const Dtype kMomentum = 0.9;
const int k = 0;
this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, k);
}

Review comment (Member):
Please add 4 more test cases:

  • TestAdamLeastSquaresUpdateLROneTenth (where you set kLearningRate = 0.1).
  • TestAdamLeastSquaresUpdateWithWeightDecay (where you may set kWeightDecay = 0.5).
  • TestAdamLeastSquaresUpdateWithEverything (where you may set kWeightDecay = 0.5 and kNumIters = 4 and loop for kNumIters times)
  • TestLeastSquaresUpdateWithEverythingAccum (where you may set kNumIters = 4 and kIterSize = 2 and use CheckAccumulation)

You may take a look at the AdaGradSolverTest cases for details.
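
As a rough illustration, one of the requested cases might look like the sketch below, modeled on the existing AdaGrad tests and on TestAdamLeastSquaresUpdate above; the TestLeastSquaresUpdate call signature and the constant values are assumptions, not code from this PR:

TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithWeightDecay) {
  typedef typename TypeParam::Dtype Dtype;
  const Dtype kLearningRate = 0.001;
  const Dtype kWeightDecay = 0.5;
  const Dtype kMomentum = 0.9;
  const int kIter = 0;
  // Same harness call as TestAdamLeastSquaresUpdate, with weight decay enabled.
  this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, kIter);
}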

Review comment (Member):
Note: #2836 and #2866 modify the solver and introduce new shared-parameter and snapshot tests. To be consistent, let's also add 4 new test cases:

TestLeastSquaresUpdateWithEverythingShare
TestLeastSquaresUpdateWithEverythingAccumShare
TestSnapshot
TestSnapshotShare

For TestSnapshot, you can take a look at TestSnapshot in SGDSolverTest as an example. For the other 3 shared cases, you just need to add this->share_ = true; to the corresponding test cases.
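
A rough sketch of what one of the shared-weight cases could look like, following the pattern described above; the share_ flag and the kNumIters loop mirror the harness introduced in #2836/#2866 and are assumptions here, not code from this PR:

TYPED_TEST(AdamSolverTest, TestLeastSquaresUpdateWithEverythingShare) {
  typedef typename TypeParam::Dtype Dtype;
  const Dtype kLearningRate = 0.001;
  const Dtype kWeightDecay = 0.5;
  const Dtype kMomentum = 0.9;
  const int kNumIters = 4;
  // Run the update check with shared weights for several consecutive iterations.
  this->share_ = true;
  for (int i = 0; i <= kNumIters; ++i) {
    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
  }
}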

template <typename TypeParam>
class NesterovSolverTest : public GradientBasedSolverTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;