Adam solver #2856
New file: examples/mnist/lenet_solver_adam.prototxt

@@ -0,0 +1,26 @@
# The train/test net protocol buffer definition.
# This follows "ADAM: A METHOD FOR STOCHASTIC OPTIMIZATION".
net: "examples/mnist/lenet_train_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# All parameters are from the cited paper above.
base_lr: 0.001
momentum: 0.9
momentum2: 0.999
# Since Adam dynamically adapts each parameter's step size, we keep the base
# learning rate fixed.
lr_policy: "fixed"
# Display every 100 iterations.
display: 100
# The maximum number of iterations.
max_iter: 10000
# Snapshot intermediate results every 5000 iterations.
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet"
# Use the Adam solver.
solver_type: ADAM
# Solver mode: CPU or GPU.
solver_mode: GPU
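For reference, these settings map onto the update rule in the cited paper (Kingma & Ba): `base_lr` is α, `momentum` is β₁, and `momentum2` is β₂. A sketch of the math, with ε playing the role of the solver's delta:

```latex
% Adam update rule from the cited paper, for the settings above:
% base_lr = \alpha, momentum = \beta_1, momentum2 = \beta_2.
\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1 - \beta_1)\, g_t   \\ % first-moment estimate
v_t &= \beta_2\, v_{t-1} + (1 - \beta_2)\, g_t^2 \\ % second-moment estimate
\alpha_t &= \alpha\, \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \\ % bias-corrected step size
\theta_t &= \theta_{t-1} - \alpha_t\, \frac{m_t}{\sqrt{v_t} + \varepsilon}
\end{aligned}
```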
New file (training script):

@@ -0,0 +1,3 @@
#!/usr/bin/env sh

./build/tools/caffe train --solver=examples/mnist/lenet_solver_adam.prototxt
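Because `snapshot: 5000` is set in the solver above, training can also be resumed from a saved solver state with the `caffe` tool's `--snapshot` flag. A hedged usage sketch; the snapshot filename below follows Caffe's `<snapshot_prefix>_iter_<N>.solverstate` convention and is assumed, not taken from this PR:

```sh
# Run from the Caffe repository root, since all paths above are relative.
./build/tools/caffe train \
    --solver=examples/mnist/lenet_solver_adam.prototxt \
    --snapshot=examples/mnist/lenet_iter_5000.solverstate  # assumed snapshot path
```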
Changes to the gradient-based solver tests (src/caffe/test/test_gradient_based_solver.cpp):

@@ -236,6 +236,10 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
    Blob<Dtype>& updated_bias = *(*updated_params)[1];
    updated_bias.ReshapeLike(bias);

    // Adam's second decay rate and running moment estimates, used below
    // when computing the expected update.
    const Dtype momentum2 = 0.999;
    Dtype val_m = 0;
    Dtype val_v = 0;

    for (int i = 0; i <= D; ++i) {
      // Compute the derivative with respect to the ith weight (i.e., the ith
      // element of the gradient).
@@ -289,6 +293,15 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
              + grad * grad * (1 - rms_decay)) + delta_;
        }
        break;
      case SolverParameter_SolverType_ADAM: {
        // Expected Adam update; the history (m_, v_) starts at zero, so this
        // checks the bias-corrected first-iteration step.
        const Dtype m_ = 0.0;
        const Dtype v_ = 0.0;
        val_m = (1 - momentum) * grad + momentum * m_;
        val_v = (1 - momentum2) * grad * grad + momentum2 * v_;
        // Bias correction at t = 1: sqrt(1 - beta2^1) / (1 - beta1^1).
        Dtype alpha_t = learning_rate * std::sqrt(1 - momentum2) / (1 - momentum);
        update_value = alpha_t * val_m / (std::sqrt(val_v) + delta_);
        break;
      }
      default:
        LOG(FATAL) << "Unknown solver type: " << solver_type();
    }
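As a standalone illustration of what this case computes (not part of the PR; the function name and test values below are mine): with zero history, the bias correction cancels the `(1 - momentum) / sqrt(1 - momentum2)` factor, so the first Adam step has magnitude close to the learning rate regardless of the gradient's scale.

```cpp
// Standalone sketch of the first-iteration Adam step checked above.
// Not Caffe code; the function name and values are illustrative only.
#include <cmath>
#include <cstdio>

double adam_first_step(double grad, double learning_rate,
                       double momentum,    // beta1
                       double momentum2,   // beta2
                       double delta) {     // epsilon
  // With zero history (m_ = v_ = 0), the momentum terms vanish.
  const double val_m = (1 - momentum) * grad;
  const double val_v = (1 - momentum2) * grad * grad;
  // Bias correction at t = 1: sqrt(1 - beta2^1) / (1 - beta1^1).
  const double alpha_t =
      learning_rate * std::sqrt(1 - momentum2) / (1 - momentum);
  return alpha_t * val_m / (std::sqrt(val_v) + delta);
}

int main() {
  const double lr = 0.001, beta1 = 0.9, beta2 = 0.999, eps = 1e-8;
  // Both print a step of magnitude ~lr: val_m / sqrt(val_v) reduces to
  // sign(grad) * (1 - beta1) / sqrt(1 - beta2), which alpha_t cancels.
  std::printf("%g\n", adam_first_step(5.0, lr, beta1, beta2, eps));
  std::printf("%g\n", adam_first_step(-0.01, lr, beta1, beta2, eps));
  return 0;
}
```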
@@ -739,6 +752,35 @@ TYPED_TEST(AdaGradSolverTest, TestSnapshotShare) {
}

template <typename TypeParam>
class AdamSolverTest : public GradientBasedSolverTest<TypeParam> {
  typedef typename TypeParam::Dtype Dtype;

 protected:
  virtual void InitSolver(const SolverParameter& param) {
    // Copy the test's solver parameters and fill in Adam's two decay rates.
    SolverParameter new_param = param;
    const Dtype momentum = 0.9;
    new_param.set_momentum(momentum);
    const Dtype momentum2 = 0.999;
    new_param.set_momentum2(momentum2);
    this->solver_.reset(new AdamSolver<Dtype>(new_param));
  }
  virtual SolverParameter_SolverType solver_type() {
    return SolverParameter_SolverType_ADAM;
  }
};

TYPED_TEST_CASE(AdamSolverTest, TestDtypesAndDevices);

TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdate) {
  typedef typename TypeParam::Dtype Dtype;
  const Dtype kLearningRate = 0.001;
  const Dtype kWeightDecay = 0.0;
  const Dtype kMomentum = 0.9;
  // Check the update at iteration 0 (the first Adam step).
  const int k = 0;
  this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, k);
}
Review comment:

Please add 4 more test cases; you may take a look at those.

Review comment:

Note: #2836 and #2866 modify the solver and also introduce new shared-parameter and snapshot tests. To be consistent, let's add 4 new test cases here as well. For TestSnapshot, you can take a look at TestSnapshot in SGDSolverTest as an example; for the other 3 shared cases, you just need to add
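For illustration, one of the requested additions could mirror the existing test above almost verbatim. A hypothetical sketch; the test name and the `kWeightDecay` value are assumptions, not taken from this PR:

```cpp
// Hypothetical weight-decay variant, mirroring TestAdamLeastSquaresUpdate
// above; the name and constants here are assumed, not from the PR.
TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithWeightDecay) {
  typedef typename TypeParam::Dtype Dtype;
  const Dtype kLearningRate = 0.001;
  const Dtype kWeightDecay = 0.5;  // nonzero decay, unlike the base test
  const Dtype kMomentum = 0.9;
  const int k = 0;                 // check the first update, as above
  this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, k);
}
```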
template <typename TypeParam>
class NesterovSolverTest : public GradientBasedSolverTest<TypeParam> {
  typedef typename TypeParam::Dtype Dtype;
Review comment:

To be consistent with #2866, use

const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();

instead.
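In context, the suggested change would look something like this; the replaced line is an assumption based on the pre-#2866 `Net::params()` accessor and is not shown in this diff:

```cpp
// Before (assumed): the older accessor returning shared_ptrs to all params.
// const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();

// After, consistent with #2866: raw pointers to the learnable parameters.
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
```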