From 9d10569a79d4dc1a0aa3f42a0bd5214a085ab7c9 Mon Sep 17 00:00:00 2001 From: qipeng Date: Sat, 19 Jul 2014 14:14:21 -0700 Subject: [PATCH 01/23] Solver switching support & implementation of Nesterov's accelerated gradient and AdaGrad --- include/caffe/solver.hpp | 53 ++++++++++- src/caffe/solver.cpp | 189 +++++++++++++++++++++++++++++++++++++++ tools/train_net.cpp | 3 + 3 files changed, 241 insertions(+), 4 deletions(-) diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index 9012c5d50d9..87c6563eae4 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -66,17 +66,62 @@ class SGDSolver : public Solver { : Solver(param_file) {} protected: - virtual void PreSolve(); + void PreSolve(); Dtype GetLearningRate(); virtual void ComputeUpdateValue(); - virtual void SnapshotSolverState(SolverState * state); - virtual void RestoreSolverState(const SolverState& state); + void SnapshotSolverState(SolverState * state); + void RestoreSolverState(const SolverState& state); // history maintains the historical momentum data. - vector > > history_; + // update maintains update related data and is not needed in snapshots. + vector > > history_, update_; DISABLE_COPY_AND_ASSIGN(SGDSolver); }; +template +class NesterovSolver : public SGDSolver { + public: + explicit NesterovSolver(const SolverParameter& param) + : SGDSolver(param) {} + explicit NesterovSolver(const string& param_file) + : SGDSolver(param_file) {} + + protected: + virtual void ComputeUpdateValue(); + + DISABLE_COPY_AND_ASSIGN(NesterovSolver); +}; + +template +class AdaGradSolver : public SGDSolver { + public: + explicit AdaGradSolver(const SolverParameter& param) + : SGDSolver(param) {} + explicit AdaGradSolver(const string& param_file) + : SGDSolver(param_file) {} + + protected: + virtual void ComputeUpdateValue(); + + DISABLE_COPY_AND_ASSIGN(AdaGradSolver); +}; + +template +Solver* GetSolver(const SolverParameter& param) { + SolverParameter_SolverType type = param.solver_type(); + + switch (type) { + case SolverParameter_SolverType_SGD: + return new SGDSolver(param); + case SolverParameter_SolverType_NESTEROV: + return new NesterovSolver(param); + case SolverParameter_SolverType_ADAGRAD: + return new AdaGradSolver(param); + default: + LOG(FATAL) << "Unknown SolverType: " << type; + } +} + } // namespace caffe diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 80582b312fe..5632f249e33 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -377,11 +377,15 @@ void SGDSolver::PreSolve() { // Initialize the history vector > >& net_params = this->net_->params(); history_.clear(); + update_.clear(); for (int i = 0; i < net_params.size(); ++i) { const Blob* net_param = net_params[i].get(); history_.push_back(shared_ptr >(new Blob( net_param->num(), net_param->channels(), net_param->height(), net_param->width()))); + update_.push_back(shared_ptr >(new Blob( + net_param->num(), net_param->channels(), net_param->height(), + net_param->width()))); } } @@ -470,7 +474,192 @@ void SGDSolver::RestoreSolverState(const SolverState& state) { } } +template +void NesterovSolver::ComputeUpdateValue() { + vector > >& net_params = this->net_->params(); + vector& net_params_lr = this->net_->params_lr(); + vector& net_params_weight_decay = this->net_->params_weight_decay(); + // get the learning rate + Dtype rate = this->GetLearningRate(); + if (this->param_.display() && this->iter_ % this->param_.display() == 0) { + LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; + } + Dtype momentum = 
this->param_.momentum(); + Dtype weight_decay = this->param_.weight_decay(); + switch (Caffe::mode()) { + case Caffe::CPU: + for (int param_id = 0; param_id < net_params.size(); ++param_id) { + // save history momentum for stepping back + caffe_copy(net_params[param_id]->count(), + this->history_[param_id]->cpu_data(), + this->update_[param_id]->mutable_cpu_data()); + + Dtype local_rate = rate * net_params_lr[param_id]; + Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; + caffe_cpu_axpby(net_params[param_id]->count(), local_rate, + net_params[param_id]->cpu_diff(), momentum, + this->history_[param_id]->mutable_cpu_data()); + if (local_decay) { + // add weight decay + caffe_axpy(net_params[param_id]->count(), + local_decay * local_rate, + net_params[param_id]->cpu_data(), + this->history_[param_id]->mutable_cpu_data()); + } + // compute udpate: step back then over step + caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, + this->history_[param_id]->cpu_data(), -momentum, + this->update_[param_id]->mutable_cpu_data()); + + // copy + caffe_copy(net_params[param_id]->count(), + this->update_[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + } + break; + case Caffe::GPU: +#ifndef CPU_ONLY + for (int param_id = 0; param_id < net_params.size(); ++param_id) { + // save history momentum for stepping back + caffe_copy(net_params[param_id]->count(), + this->history_[param_id]->gpu_data(), + this->update_[param_id]->mutable_gpu_data()); + + Dtype local_rate = rate * net_params_lr[param_id]; + Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; + caffe_gpu_axpby(net_params[param_id]->count(), local_rate, + net_params[param_id]->gpu_diff(), momentum, + this->history_[param_id]->mutable_gpu_data()); + if (local_decay) { + // add weight decay + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay * local_rate, + net_params[param_id]->gpu_data(), + this->history_[param_id]->mutable_gpu_data()); + } + // compute udpate: step back then over step + caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, + this->history_[param_id]->gpu_data(), -momentum, + this->update_[param_id]->mutable_gpu_data()); + + // copy + caffe_copy(net_params[param_id]->count(), + this->update_[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); + } +#else + NO_GPU; +#endif + break; + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + +template +void AdaGradSolver::ComputeUpdateValue() { + vector > >& net_params = this->net_->params(); + vector& net_params_lr = this->net_->params_lr(); + vector& net_params_weight_decay = this->net_->params_weight_decay(); + // get the learning rate + Dtype rate = this->GetLearningRate(); + if (this->param_.display() && this->iter_ % this->param_.display() == 0) { + LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; + } + Dtype weight_decay = this->param_.weight_decay(); + switch (Caffe::mode()) { + case Caffe::CPU: + for (int param_id = 0; param_id < net_params.size(); ++param_id) { + Dtype local_rate = rate * net_params_lr[param_id]; + Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; + + if (local_decay) { + // add weight decay + caffe_axpy(net_params[param_id]->count(), + local_decay * local_rate, + net_params[param_id]->cpu_data(), + this->history_[param_id]->mutable_cpu_data()); + } + + // compute square of gradient in update + caffe_powx(net_params[param_id]->count(), + net_params[param_id]->cpu_data(), Dtype(2), + 
this->update_[param_id]->mutable_cpu_data()); + + // update history + caffe_add(net_params[param_id]->count(), + this->update_[param_id]->cpu_data(), + this->history_[param_id]->cpu_data(), + this->history_[param_id]->mutable_cpu_data()); + + // prepare update + caffe_powx(net_params[param_id]->count(), + this->history_[param_id]->cpu_data(), Dtype(-0.5), + this->update_[param_id]->mutable_cpu_data()); + + caffe_mul(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), + this->update_[param_id]->cpu_data(), + this->update_[param_id]->mutable_cpu_data()); + + // scale and copy + caffe_cpu_axpby(net_params[param_id]->count(), local_rate, + this->update_[param_id]->cpu_data(), Dtype(0), + net_params[param_id]->mutable_cpu_diff()); + } + break; + case Caffe::GPU: +#ifndef CPU_ONLY + for (int param_id = 0; param_id < net_params.size(); ++param_id) { + Dtype local_rate = rate * net_params_lr[param_id]; + Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; + + if (local_decay) { + // add weight decay + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay * local_rate, + net_params[param_id]->gpu_data(), + this->history_[param_id]->mutable_gpu_data()); + } + + // compute square of gradient in update + caffe_gpu_powx(net_params[param_id]->count(), + net_params[param_id]->gpu_data(), Dtype(2), + this->update_[param_id]->mutable_gpu_data()); + + // update history + caffe_gpu_add(net_params[param_id]->count(), + this->update_[param_id]->gpu_data(), + this->history_[param_id]->gpu_data(), + this->history_[param_id]->mutable_gpu_data()); + + // prepare update + caffe_gpu_powx(net_params[param_id]->count(), + this->history_[param_id]->gpu_data(), Dtype(-0.5), + this->update_[param_id]->mutable_gpu_data()); + + caffe_gpu_mul(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), + this->update_[param_id]->gpu_data(), + this->update_[param_id]->mutable_gpu_data()); + + // scale and copy + caffe_gpu_axpby(net_params[param_id]->count(), local_rate, + this->update_[param_id]->gpu_data(), Dtype(0), + net_params[param_id]->mutable_gpu_diff()); + } +#else + NO_GPU; +#endif + break; + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + INSTANTIATE_CLASS(Solver); INSTANTIATE_CLASS(SGDSolver); +INSTANTIATE_CLASS(NesterovSolver); +INSTANTIATE_CLASS(AdaGradSolver); } // namespace caffe diff --git a/tools/train_net.cpp b/tools/train_net.cpp index 622bca311c8..11767591991 100644 --- a/tools/train_net.cpp +++ b/tools/train_net.cpp @@ -1,6 +1,9 @@ #include "caffe/caffe.hpp" +using namespace caffe; // NOLINT(build/namespaces) + int main(int argc, char** argv) { + LOG(FATAL) << "Deprecated. Use caffe train --solver=... " "[--snapshot=...] 
instead."; return 0; From 8a9c268bd53767365fa0760c167bdcd0158a56f3 Mon Sep 17 00:00:00 2001 From: qipeng Date: Sat, 19 Jul 2014 17:33:35 -0700 Subject: [PATCH 02/23] restored vituals in solver.hpp --- include/caffe/solver.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index 87c6563eae4..03c655800a7 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -66,11 +66,11 @@ class SGDSolver : public Solver { : Solver(param_file) {} protected: - void PreSolve(); + virtual void PreSolve(); Dtype GetLearningRate(); virtual void ComputeUpdateValue(); - void SnapshotSolverState(SolverState * state); - void RestoreSolverState(const SolverState& state); + virtual void SnapshotSolverState(SolverState * state); + virtual void RestoreSolverState(const SolverState& state); // history maintains the historical momentum data. // update maintains update related data and is not needed in snapshots. vector > > history_, update_; From ed8b1da57fbbadb611d98372671fafd77d863234 Mon Sep 17 00:00:00 2001 From: qipeng Date: Sun, 20 Jul 2014 08:57:33 -0700 Subject: [PATCH 03/23] converted pointers to shared_ptr --- include/caffe/solver.hpp | 3 +++ tools/train_net.cpp | 30 ++++++++++++++++++++++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index 03c655800a7..9d5481cc368 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -120,8 +120,11 @@ Solver* GetSolver(const SolverParameter& param) { default: LOG(FATAL) << "Unknown SolverType: " << type; } + return (Solver*) NULL; } +template Solver* GetSolver(const SolverParameter& param); +template Solver* GetSolver(const SolverParameter& param); } // namespace caffe diff --git a/tools/train_net.cpp b/tools/train_net.cpp index 11767591991..2a2a522dfb9 100644 --- a/tools/train_net.cpp +++ b/tools/train_net.cpp @@ -1,10 +1,36 @@ +// Copyright 2014 BVLC and contributors. +// +// This is a simple script that allows one to quickly train a network whose +// parameters are specified by text format protocol buffers. +// Usage: +// train_net net_proto_file solver_proto_file [resume_point_file] + +#include + #include "caffe/caffe.hpp" using namespace caffe; // NOLINT(build/namespaces) int main(int argc, char** argv) { + ::google::InitGoogleLogging(argv[0]); + if (argc < 2 || argc > 3) { + LOG(ERROR) << "Usage: train_net solver_proto_file [resume_point_file]"; + return 1; + } + + SolverParameter solver_param; + ReadProtoFromTextFileOrDie(argv[1], &solver_param); + + LOG(INFO) << "Starting Optimization"; + shared_ptr > solver = + (shared_ptr >) GetSolver(solver_param); + if (argc == 3) { + LOG(INFO) << "Resuming from " << argv[2]; + solver->Solve(argv[2]); + } else { + solver->Solve(); + } + LOG(INFO) << "Optimization Done."; - LOG(FATAL) << "Deprecated. Use caffe train --solver=... " - "[--snapshot=...] 
instead."; return 0; } From 8b3dde08fc664fa2ba19694c918bc1d411adbd56 Mon Sep 17 00:00:00 2001 From: qipeng Date: Mon, 21 Jul 2014 10:47:55 -0700 Subject: [PATCH 04/23] fixed solver constructor in train_net.cpp --- tools/train_net.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tools/train_net.cpp b/tools/train_net.cpp index 2a2a522dfb9..3652182d606 100644 --- a/tools/train_net.cpp +++ b/tools/train_net.cpp @@ -22,8 +22,7 @@ int main(int argc, char** argv) { ReadProtoFromTextFileOrDie(argv[1], &solver_param); LOG(INFO) << "Starting Optimization"; - shared_ptr > solver = - (shared_ptr >) GetSolver(solver_param); + shared_ptr > solver(GetSolver(solver_param)); if (argc == 3) { LOG(INFO) << "Resuming from " << argv[2]; solver->Solve(argv[2]); From 0144de68e6f98d5158f8c0f94cd31c9bb6f79db6 Mon Sep 17 00:00:00 2001 From: qipeng Date: Tue, 22 Jul 2014 21:17:19 -0700 Subject: [PATCH 05/23] improved numerical stability for AdaGrad --- src/caffe/solver.cpp | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 5632f249e33..abcbe5e6a20 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -481,6 +481,7 @@ void NesterovSolver::ComputeUpdateValue() { vector& net_params_weight_decay = this->net_->params_weight_decay(); // get the learning rate Dtype rate = this->GetLearningRate(); + Dtype delta = this->param_.delta(); if (this->param_.display() && this->iter_ % this->param_.display() == 0) { LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; } @@ -594,10 +595,13 @@ void AdaGradSolver::ComputeUpdateValue() { // prepare update caffe_powx(net_params[param_id]->count(), - this->history_[param_id]->cpu_data(), Dtype(-0.5), + this->history_[param_id]->cpu_data(), Dtype(0.5), this->update_[param_id]->mutable_cpu_data()); - caffe_mul(net_params[param_id]->count(), + caffe_add_scalar(net_params[param_id]->count(), + delta, this->update_[param_id]->mutable_cpu_data()); + + caffe_div(net_params[param_id]->count(), net_params[param_id]->cpu_diff(), this->update_[param_id]->cpu_data(), this->update_[param_id]->mutable_cpu_data()); @@ -635,10 +639,13 @@ void AdaGradSolver::ComputeUpdateValue() { // prepare update caffe_gpu_powx(net_params[param_id]->count(), - this->history_[param_id]->gpu_data(), Dtype(-0.5), + this->history_[param_id]->gpu_data(), Dtype(0.5), this->update_[param_id]->mutable_gpu_data()); - caffe_gpu_mul(net_params[param_id]->count(), + caffe_gpu_add_scalar(net_params[param_id]->count(), + delta, this->update_[param_id]->mutable_gpu_data()); + + caffe_gpu_div(net_params[param_id]->count(), net_params[param_id]->gpu_diff(), this->update_[param_id]->gpu_data(), this->update_[param_id]->mutable_gpu_data()); From 76ef2ca1f6e326622090aa7d57a10e80d5831350 Mon Sep 17 00:00:00 2001 From: qipeng Date: Wed, 23 Jul 2014 10:25:44 -0700 Subject: [PATCH 06/23] bugfixes for AdaGrad --- src/caffe/solver.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index abcbe5e6a20..8928c7b29ae 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -577,14 +577,14 @@ void AdaGradSolver::ComputeUpdateValue() { if (local_decay) { // add weight decay caffe_axpy(net_params[param_id]->count(), - local_decay * local_rate, + local_decay, net_params[param_id]->cpu_data(), - this->history_[param_id]->mutable_cpu_data()); + net_params[param_id]->mutable_cpu_diff()); } // compute square of gradient in update 
caffe_powx(net_params[param_id]->count(),
-          net_params[param_id]->cpu_data(), Dtype(2),
+          net_params[param_id]->cpu_diff(), Dtype(2),
           this->update_[param_id]->mutable_cpu_data());

       // update history
@@ -621,14 +621,14 @@ void AdaGradSolver::ComputeUpdateValue() {
       if (local_decay) {
         // add weight decay
         caffe_gpu_axpy(net_params[param_id]->count(),
-            local_decay * local_rate,
+            local_decay,
             net_params[param_id]->gpu_data(),
-            this->history_[param_id]->mutable_gpu_data());
+            net_params[param_id]->mutable_gpu_diff());
       }

       // compute square of gradient in update
       caffe_gpu_powx(net_params[param_id]->count(),
-          net_params[param_id]->gpu_data(), Dtype(2),
+          net_params[param_id]->gpu_diff(), Dtype(2),
           this->update_[param_id]->mutable_gpu_data());

       // update history

From a683c40d63b79dcb4e2407d59335fb7f8129c82f Mon Sep 17 00:00:00 2001
From: qipeng
Date: Thu, 24 Jul 2014 13:09:28 -0700
Subject: [PATCH 07/23] Added L1 regularization support for the weights

---
 include/caffe/solver.hpp    |   4 +-
 src/caffe/proto/caffe.proto |   3 +
 src/caffe/solver.cpp        | 139 ++++++++++++++++++++++++++++--------
 3 files changed, 115 insertions(+), 31 deletions(-)

diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index 9d5481cc368..4bf50d413c8 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -73,7 +73,9 @@ class SGDSolver : public Solver {
   virtual void RestoreSolverState(const SolverState& state);
   // history maintains the historical momentum data.
   // update maintains update related data and is not needed in snapshots.
-  vector > > history_, update_;
+  // temp maintains other information that might be needed in computation
+  // of gradients/updates and is not needed in snapshots
+  vector > > history_, update_, temp_;

   DISABLE_COPY_AND_ASSIGN(SGDSolver);
 };

diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 49a6e142fcd..0bb5d11cfbe 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -116,6 +116,9 @@ message SolverParameter {
   optional float power = 10; // The parameter to compute the learning rate.
   optional float momentum = 11; // The momentum value.
   optional float weight_decay = 12; // The weight decay.
+  // regularization types supported: L1 and L2
+  // controlled by weight_decay
+  optional string regularization_type = 25 [default = "L2"];
   optional int32 stepsize = 13; // the stepsize for learning rate policy "step"
   optional int32 snapshot = 14 [default = 0]; // The snapshot interval
   optional string snapshot_prefix = 15; // The prefix for the snapshot.
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 8928c7b29ae..223194bdb09 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -378,6 +378,7 @@ void SGDSolver::PreSolve() { vector > >& net_params = this->net_->params(); history_.clear(); update_.clear(); + temp_.clear(); for (int i = 0; i < net_params.size(); ++i) { const Blob* net_param = net_params[i].get(); history_.push_back(shared_ptr >(new Blob( @@ -386,6 +387,9 @@ void SGDSolver::PreSolve() { update_.push_back(shared_ptr >(new Blob( net_param->num(), net_param->channels(), net_param->height(), net_param->width()))); + temp_.push_back(shared_ptr >(new Blob( + net_param->num(), net_param->channels(), net_param->height(), + net_param->width()))); } } @@ -402,6 +406,7 @@ void SGDSolver::ComputeUpdateValue() { } Dtype momentum = this->param_.momentum(); Dtype weight_decay = this->param_.weight_decay(); + string regularization_type = this->param_.regularization_type(); switch (Caffe::mode()) { case Caffe::CPU: for (int param_id = 0; param_id < net_params.size(); ++param_id) { @@ -412,11 +417,23 @@ void SGDSolver::ComputeUpdateValue() { net_params[param_id]->cpu_diff(), momentum, history_[param_id]->mutable_cpu_data()); if (local_decay) { - // add weight decay - caffe_axpy(net_params[param_id]->count(), - local_decay * local_rate, - net_params[param_id]->cpu_data(), - history_[param_id]->mutable_cpu_data()); + if (regularization_type == "L2") { + // add weight decay + caffe_axpy(net_params[param_id]->count(), + local_decay * local_rate, + net_params[param_id]->cpu_data(), + history_[param_id]->mutable_cpu_data()); + } else if (regularization_type == "L1") { + caffe_cpu_sign(net_params[param_id]->count(), + net_params[param_id]->cpu_data(), + temp_[param_id]->mutable_cpu_data()); + caffe_axpy(net_params[param_id]->count(), + local_decay * local_rate, + temp_[param_id]->cpu_data(), + history_[param_id]->mutable_cpu_data()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } } // copy caffe_copy(net_params[param_id]->count(), @@ -434,11 +451,23 @@ void SGDSolver::ComputeUpdateValue() { net_params[param_id]->gpu_diff(), momentum, history_[param_id]->mutable_gpu_data()); if (local_decay) { - // add weight decay - caffe_gpu_axpy(net_params[param_id]->count(), - local_decay * local_rate, - net_params[param_id]->gpu_data(), - history_[param_id]->mutable_gpu_data()); + if (regularization_type == "L2") { + // add weight decay + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay * local_rate, + net_params[param_id]->gpu_data(), + history_[param_id]->mutable_gpu_data()); + } else if (regularization_type == "L1") { + caffe_gpu_sign(net_params[param_id]->count(), + net_params[param_id]->gpu_data(), + temp_[param_id]->mutable_gpu_data()); + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay * local_rate, + temp_[param_id]->gpu_data(), + history_[param_id]->mutable_gpu_data()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } } // copy caffe_copy(net_params[param_id]->count(), @@ -487,6 +516,7 @@ void NesterovSolver::ComputeUpdateValue() { } Dtype momentum = this->param_.momentum(); Dtype weight_decay = this->param_.weight_decay(); + string regularization_type = this->param_.regularization_type(); switch (Caffe::mode()) { case Caffe::CPU: for (int param_id = 0; param_id < net_params.size(); ++param_id) { @@ -501,11 +531,23 @@ void NesterovSolver::ComputeUpdateValue() { net_params[param_id]->cpu_diff(), momentum, 
this->history_[param_id]->mutable_cpu_data()); if (local_decay) { - // add weight decay - caffe_axpy(net_params[param_id]->count(), - local_decay * local_rate, - net_params[param_id]->cpu_data(), - this->history_[param_id]->mutable_cpu_data()); + if (regularization_type == "L2") { + // add weight decay + caffe_axpy(net_params[param_id]->count(), + local_decay * local_rate, + net_params[param_id]->cpu_data(), + this->history_[param_id]->mutable_cpu_data()); + } else if (regularization_type == "L1") { + caffe_cpu_sign(net_params[param_id]->count(), + net_params[param_id]->cpu_data(), + this->temp_[param_id]->mutable_cpu_data()); + caffe_axpy(net_params[param_id]->count(), + local_decay * local_rate, + this->temp_[param_id]->cpu_data(), + this->history_[param_id]->mutable_cpu_data()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } } // compute udpate: step back then over step caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, @@ -532,11 +574,23 @@ void NesterovSolver::ComputeUpdateValue() { net_params[param_id]->gpu_diff(), momentum, this->history_[param_id]->mutable_gpu_data()); if (local_decay) { - // add weight decay - caffe_gpu_axpy(net_params[param_id]->count(), - local_decay * local_rate, - net_params[param_id]->gpu_data(), - this->history_[param_id]->mutable_gpu_data()); + if (regularization_type == "L2") { + // add weight decay + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay * local_rate, + net_params[param_id]->gpu_data(), + this->history_[param_id]->mutable_gpu_data()); + } else if (regularization_type == "L1") { + caffe_gpu_sign(net_params[param_id]->count(), + net_params[param_id]->gpu_data(), + this->temp_[param_id]->mutable_gpu_data()); + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay * local_rate, + this->temp_[param_id]->gpu_data(), + this->history_[param_id]->mutable_gpu_data()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } } // compute udpate: step back then over step caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, @@ -568,6 +622,7 @@ void AdaGradSolver::ComputeUpdateValue() { LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; } Dtype weight_decay = this->param_.weight_decay(); + string regularization_type = this->param_.regularization_type(); switch (Caffe::mode()) { case Caffe::CPU: for (int param_id = 0; param_id < net_params.size(); ++param_id) { @@ -575,11 +630,23 @@ void AdaGradSolver::ComputeUpdateValue() { Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; if (local_decay) { - // add weight decay - caffe_axpy(net_params[param_id]->count(), - local_decay, - net_params[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); + if (regularization_type == "L2") { + // add weight decay + caffe_axpy(net_params[param_id]->count(), + local_decay, + net_params[param_id]->cpu_data(), + this->history_[param_id]->mutable_cpu_data()); + } else if (regularization_type == "L1") { + caffe_cpu_sign(net_params[param_id]->count(), + net_params[param_id]->cpu_data(), + this->temp_[param_id]->mutable_cpu_data()); + caffe_axpy(net_params[param_id]->count(), + local_decay, + this->temp_[param_id]->cpu_data(), + this->history_[param_id]->mutable_cpu_data()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } } // compute square of gradient in update @@ -619,11 +686,23 @@ void AdaGradSolver::ComputeUpdateValue() { Dtype local_decay = weight_decay * 
net_params_weight_decay[param_id]; if (local_decay) { - // add weight decay - caffe_gpu_axpy(net_params[param_id]->count(), - local_decay, - net_params[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); + if (regularization_type == "L2") { + // add weight decay + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay, + net_params[param_id]->gpu_data(), + this->history_[param_id]->mutable_gpu_data()); + } else if (regularization_type == "L1") { + caffe_gpu_sign(net_params[param_id]->count(), + net_params[param_id]->gpu_data(), + this->temp_[param_id]->mutable_gpu_data()); + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay, + this->temp_[param_id]->gpu_data(), + this->history_[param_id]->mutable_gpu_data()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } } // compute square of gradient in update From b0ec5314f407f47bd2b593c609f88cf23ffa32cf Mon Sep 17 00:00:00 2001 From: qipeng Date: Tue, 29 Jul 2014 10:06:32 -0700 Subject: [PATCH 08/23] fixed caffe.proto after a mistaken rebase --- src/caffe/proto/caffe.proto | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 0bb5d11cfbe..73ae89d326a 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -137,6 +137,16 @@ message SolverParameter { // random number generator -- useful for reproducible results. Otherwise, // (and by default) initialize using a seed derived from the system clock. optional int64 random_seed = 20 [default = -1]; + + // Solver type + enum SolverType { + SGD = 0; + NESTEROV = 1; + ADAGRAD = 2; + } + optional SolverType solver_type = 26 [default = SGD]; + // numerical stability for AdaGrad + optional float delta = 27 [default = 1e-8]; // If true, print information about the state of the net that may help with // debugging learning problems. From 3f7a910e05c1ee6e9d09ac201831592357fc18ee Mon Sep 17 00:00:00 2001 From: qipeng Date: Tue, 29 Jul 2014 19:46:37 -0700 Subject: [PATCH 09/23] Addressed Yangqing's comments --- include/caffe/solver.hpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index 4bf50d413c8..2f51268f295 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -125,9 +125,6 @@ Solver* GetSolver(const SolverParameter& param) { return (Solver*) NULL; } -template Solver* GetSolver(const SolverParameter& param); -template Solver* GetSolver(const SolverParameter& param); - } // namespace caffe #endif // CAFFE_OPTIMIZATION_SOLVER_HPP_ From 23d44308692e469bd7b587d1697ac68068986608 Mon Sep 17 00:00:00 2001 From: qipeng Date: Tue, 29 Jul 2014 19:56:36 -0700 Subject: [PATCH 10/23] fixes after rebase --- src/caffe/proto/caffe.proto | 8 ++++---- src/caffe/solver.cpp | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 73ae89d326a..9afe8e83a8e 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -63,7 +63,7 @@ message NetParameter { // NOTE // Update the next available ID when you add a new SolverParameter field. // -// SolverParameter next available ID: 27 (last added: test_state) +// SolverParameter next available ID: 31 (last added: delta) message SolverParameter { ////////////////////////////////////////////////////////////////////////////// // Specifying the train and test networks @@ -118,7 +118,7 @@ message SolverParameter { optional float weight_decay = 12; // The weight decay. 
// regularization types supported: L1 and L2
   // controlled by weight_decay
-  optional string regularization_type = 25 [default = "L2"];
+  optional string regularization_type = 28 [default = "L2"];
   optional int32 stepsize = 13; // the stepsize for learning rate policy "step"
   optional int32 snapshot = 14 [default = 0]; // The snapshot interval
   optional string snapshot_prefix = 15; // The prefix for the snapshot.
@@ -144,9 +144,9 @@ message SolverParameter {
     NESTEROV = 1;
     ADAGRAD = 2;
   }
-  optional SolverType solver_type = 26 [default = SGD];
+  optional SolverType solver_type = 29 [default = SGD];
   // numerical stability for AdaGrad
-  optional float delta = 27 [default = 1e-8];
+  optional float delta = 30 [default = 1e-8];

   // If true, print information about the state of the net that may help with
   // debugging learning problems.
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 223194bdb09..52fd6523e95 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -510,7 +510,6 @@ void NesterovSolver::ComputeUpdateValue() {
   vector& net_params_weight_decay = this->net_->params_weight_decay();
   // get the learning rate
   Dtype rate = this->GetLearningRate();
-  Dtype delta = this->param_.delta();
   if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
     LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
   }
@@ -617,6 +617,7 @@ void AdaGradSolver::ComputeUpdateValue() {
   vector& net_params_weight_decay = this->net_->params_weight_decay();
   // get the learning rate
   Dtype rate = this->GetLearningRate();
+  Dtype delta = this->param_.delta();
   if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
     LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
   }

From 29b3b2442dc37d7c7481e96a3db747447181b1b8 Mon Sep 17 00:00:00 2001
From: qipeng
Date: Wed, 20 Aug 2014 15:19:30 -0700
Subject: [PATCH 11/23] proto conflict, lint, and math_functions (compiler complaint)

---
 src/caffe/proto/caffe.proto | 8 ++++----
 tools/train_net.cpp         | 9 ---------
 2 files changed, 4 insertions(+), 13 deletions(-)

diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 9afe8e83a8e..bb14b838f9b 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -63,7 +63,7 @@ message NetParameter {
 // NOTE
 // Update the next available ID when you add a new SolverParameter field.
 //
-// SolverParameter next available ID: 31 (last added: delta)
+// SolverParameter next available ID: 32 (last added: delta)
 message SolverParameter {
 //////////////////////////////////////////////////////////////////////////////
 // Specifying the train and test networks
@@ -118,7 +118,7 @@ message SolverParameter {
   optional float weight_decay = 12; // The weight decay.
   // regularization types supported: L1 and L2
   // controlled by weight_decay
-  optional string regularization_type = 28 [default = "L2"];
+  optional string regularization_type = 29 [default = "L2"];
   optional int32 stepsize = 13; // the stepsize for learning rate policy "step"
   optional int32 snapshot = 14 [default = 0]; // The snapshot interval
   optional string snapshot_prefix = 15; // The prefix for the snapshot.
@@ -144,9 +144,9 @@ message SolverParameter { NESTEROV = 1; ADAGRAD = 2; } - optional SolverType solver_type = 29 [default = SGD]; + optional SolverType solver_type = 30 [default = SGD]; // numerical stability for AdaGrad - optional float delta = 30 [default = 1e-8]; + optional float delta = 31 [default = 1e-8]; // If true, print information about the state of the net that may help with // debugging learning problems. diff --git a/tools/train_net.cpp b/tools/train_net.cpp index 3652182d606..9429fe892e9 100644 --- a/tools/train_net.cpp +++ b/tools/train_net.cpp @@ -1,12 +1,3 @@ -// Copyright 2014 BVLC and contributors. -// -// This is a simple script that allows one to quickly train a network whose -// parameters are specified by text format protocol buffers. -// Usage: -// train_net net_proto_file solver_proto_file [resume_point_file] - -#include - #include "caffe/caffe.hpp" using namespace caffe; // NOLINT(build/namespaces) From 7f2e66e6cc0133e18a659611eaffbadb7b0edf2c Mon Sep 17 00:00:00 2001 From: qipeng Date: Wed, 20 Aug 2014 20:07:53 -0700 Subject: [PATCH 12/23] added unit test for solvers and fixed solver bugs --- src/caffe/solver.cpp | 76 ++--- src/caffe/test/test_adagrad_solver.cpp | 351 ++++++++++++++++++++++++ src/caffe/test/test_nesterov_solver.cpp | 351 ++++++++++++++++++++++++ 3 files changed, 746 insertions(+), 32 deletions(-) create mode 100644 src/caffe/test/test_adagrad_solver.cpp create mode 100644 src/caffe/test/test_nesterov_solver.cpp diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 52fd6523e95..dcac4c1537c 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -413,28 +413,30 @@ void SGDSolver::ComputeUpdateValue() { // Compute the value to history, and then copy them to the blob's diff. Dtype local_rate = rate * net_params_lr[param_id]; Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; - caffe_cpu_axpby(net_params[param_id]->count(), local_rate, - net_params[param_id]->cpu_diff(), momentum, - history_[param_id]->mutable_cpu_data()); + if (local_decay) { if (regularization_type == "L2") { // add weight decay caffe_axpy(net_params[param_id]->count(), - local_decay * local_rate, + local_decay, net_params[param_id]->cpu_data(), - history_[param_id]->mutable_cpu_data()); + net_params[param_id]->mutable_cpu_diff()); } else if (regularization_type == "L1") { caffe_cpu_sign(net_params[param_id]->count(), net_params[param_id]->cpu_data(), temp_[param_id]->mutable_cpu_data()); caffe_axpy(net_params[param_id]->count(), - local_decay * local_rate, + local_decay, temp_[param_id]->cpu_data(), - history_[param_id]->mutable_cpu_data()); + net_params[param_id]->mutable_cpu_diff()); } else { LOG(FATAL) << "Unknown regularization type: " << regularization_type; } } + + caffe_cpu_axpby(net_params[param_id]->count(), local_rate, + net_params[param_id]->cpu_diff(), momentum, + history_[param_id]->mutable_cpu_data()); // copy caffe_copy(net_params[param_id]->count(), history_[param_id]->cpu_data(), @@ -447,28 +449,30 @@ void SGDSolver::ComputeUpdateValue() { // Compute the value to history, and then copy them to the blob's diff. 
Dtype local_rate = rate * net_params_lr[param_id]; Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; - caffe_gpu_axpby(net_params[param_id]->count(), local_rate, - net_params[param_id]->gpu_diff(), momentum, - history_[param_id]->mutable_gpu_data()); + if (local_decay) { if (regularization_type == "L2") { // add weight decay caffe_gpu_axpy(net_params[param_id]->count(), - local_decay * local_rate, + local_decay, net_params[param_id]->gpu_data(), - history_[param_id]->mutable_gpu_data()); + net_params[param_id]->mutable_gpu_diff()); } else if (regularization_type == "L1") { caffe_gpu_sign(net_params[param_id]->count(), net_params[param_id]->gpu_data(), temp_[param_id]->mutable_gpu_data()); caffe_gpu_axpy(net_params[param_id]->count(), - local_decay * local_rate, + local_decay, temp_[param_id]->gpu_data(), - history_[param_id]->mutable_gpu_data()); + net_params[param_id]->mutable_gpu_diff()); } else { LOG(FATAL) << "Unknown regularization type: " << regularization_type; } } + + caffe_gpu_axpby(net_params[param_id]->count(), local_rate, + net_params[param_id]->gpu_diff(), momentum, + history_[param_id]->mutable_gpu_data()); // copy caffe_copy(net_params[param_id]->count(), history_[param_id]->gpu_data(), @@ -526,28 +530,32 @@ void NesterovSolver::ComputeUpdateValue() { Dtype local_rate = rate * net_params_lr[param_id]; Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; - caffe_cpu_axpby(net_params[param_id]->count(), local_rate, - net_params[param_id]->cpu_diff(), momentum, - this->history_[param_id]->mutable_cpu_data()); + if (local_decay) { if (regularization_type == "L2") { // add weight decay caffe_axpy(net_params[param_id]->count(), - local_decay * local_rate, + local_decay, net_params[param_id]->cpu_data(), - this->history_[param_id]->mutable_cpu_data()); + net_params[param_id]->mutable_cpu_diff()); } else if (regularization_type == "L1") { caffe_cpu_sign(net_params[param_id]->count(), net_params[param_id]->cpu_data(), this->temp_[param_id]->mutable_cpu_data()); caffe_axpy(net_params[param_id]->count(), - local_decay * local_rate, + local_decay, this->temp_[param_id]->cpu_data(), - this->history_[param_id]->mutable_cpu_data()); + net_params[param_id]->mutable_cpu_diff()); } else { LOG(FATAL) << "Unknown regularization type: " << regularization_type; } } + + // update history + caffe_cpu_axpby(net_params[param_id]->count(), local_rate, + net_params[param_id]->cpu_diff(), momentum, + this->history_[param_id]->mutable_cpu_data()); + // compute udpate: step back then over step caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, this->history_[param_id]->cpu_data(), -momentum, @@ -569,28 +577,32 @@ void NesterovSolver::ComputeUpdateValue() { Dtype local_rate = rate * net_params_lr[param_id]; Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; - caffe_gpu_axpby(net_params[param_id]->count(), local_rate, - net_params[param_id]->gpu_diff(), momentum, - this->history_[param_id]->mutable_gpu_data()); + if (local_decay) { if (regularization_type == "L2") { // add weight decay caffe_gpu_axpy(net_params[param_id]->count(), - local_decay * local_rate, + local_decay, net_params[param_id]->gpu_data(), - this->history_[param_id]->mutable_gpu_data()); + net_params[param_id]->mutable_gpu_diff()); } else if (regularization_type == "L1") { caffe_gpu_sign(net_params[param_id]->count(), net_params[param_id]->gpu_data(), this->temp_[param_id]->mutable_gpu_data()); caffe_gpu_axpy(net_params[param_id]->count(), - local_decay * local_rate, + 
local_decay, this->temp_[param_id]->gpu_data(), - this->history_[param_id]->mutable_gpu_data()); + net_params[param_id]->mutable_gpu_diff()); } else { LOG(FATAL) << "Unknown regularization type: " << regularization_type; } } + + // update history + caffe_gpu_axpby(net_params[param_id]->count(), local_rate, + net_params[param_id]->gpu_diff(), momentum, + this->history_[param_id]->mutable_gpu_data()); + // compute udpate: step back then over step caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, this->history_[param_id]->gpu_data(), -momentum, @@ -635,7 +647,7 @@ void AdaGradSolver::ComputeUpdateValue() { caffe_axpy(net_params[param_id]->count(), local_decay, net_params[param_id]->cpu_data(), - this->history_[param_id]->mutable_cpu_data()); + net_params[param_id]->mutable_cpu_diff()); } else if (regularization_type == "L1") { caffe_cpu_sign(net_params[param_id]->count(), net_params[param_id]->cpu_data(), @@ -643,7 +655,7 @@ void AdaGradSolver::ComputeUpdateValue() { caffe_axpy(net_params[param_id]->count(), local_decay, this->temp_[param_id]->cpu_data(), - this->history_[param_id]->mutable_cpu_data()); + net_params[param_id]->mutable_cpu_diff()); } else { LOG(FATAL) << "Unknown regularization type: " << regularization_type; } @@ -691,7 +703,7 @@ void AdaGradSolver::ComputeUpdateValue() { caffe_gpu_axpy(net_params[param_id]->count(), local_decay, net_params[param_id]->gpu_data(), - this->history_[param_id]->mutable_gpu_data()); + net_params[param_id]->mutable_gpu_diff()); } else if (regularization_type == "L1") { caffe_gpu_sign(net_params[param_id]->count(), net_params[param_id]->gpu_data(), @@ -699,7 +711,7 @@ void AdaGradSolver::ComputeUpdateValue() { caffe_gpu_axpy(net_params[param_id]->count(), local_decay, this->temp_[param_id]->gpu_data(), - this->history_[param_id]->mutable_gpu_data()); + net_params[param_id]->mutable_gpu_diff()); } else { LOG(FATAL) << "Unknown regularization type: " << regularization_type; } diff --git a/src/caffe/test/test_adagrad_solver.cpp b/src/caffe/test/test_adagrad_solver.cpp new file mode 100644 index 00000000000..45cf2002e33 --- /dev/null +++ b/src/caffe/test/test_adagrad_solver.cpp @@ -0,0 +1,351 @@ +#include +#include +#include +#include + +#include "google/protobuf/text_format.h" + +#include "gtest/gtest.h" + +#include "caffe/common.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/solver.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +using std::ostringstream; + +namespace caffe { + +template +class AdaGradSolverTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + AdaGradSolverTest() : + seed_(1701), num_(5), channels_(3), height_(10), width_(10) {} + + // MockAdaGradSolver: an AdaGradSolver with public history. + class MockAdaGradSolver : public AdaGradSolver { + public: + explicit MockAdaGradSolver(const SolverParameter& param) : + AdaGradSolver(param) {} + vector > >& history() { return this->history_; } + Dtype delta() { return this->param_.delta(); } + }; + + shared_ptr solver_; + int seed_; + int num_, channels_, height_, width_; + + virtual void InitSolverFromProtoString(const string& proto) { + SolverParameter param; + CHECK(google::protobuf::TextFormat::ParseFromString(proto, ¶m)); + // Disable saving a final snapshot so the tests don't pollute the user's + // working directory with useless snapshots. + param.set_snapshot_after_train(false); + // Set the solver_mode according to current Caffe::mode. 
+ switch (Caffe::mode()) { + case Caffe::CPU: + param.set_solver_mode(SolverParameter_SolverMode_CPU); + break; + case Caffe::GPU: + param.set_solver_mode(SolverParameter_SolverMode_GPU); + break; + default: + LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode(); + } + solver_.reset(new MockAdaGradSolver(param)); + } + + void RunLeastSquaresSolver(const Dtype learning_rate, + const Dtype weight_decay, const Dtype momentum, const int num_iters) { + ostringstream proto; + proto << + "max_iter: " << num_iters << " " + "base_lr: " << learning_rate << " " + "lr_policy: 'fixed' " + "net_param { " + " name: 'TestNetwork' " + " layers: { " + " name: 'data' " + " type: DUMMY_DATA " + " dummy_data_param { " + " num: " << num_ << " " + " channels: " << channels_ << " " + " height: " << height_ << " " + " width: " << width_ << " " + " channels: 1 " + " height: 1 " + " width: 1 " + " data_filler { " + " type: 'gaussian' " + " std: 1.0 " + " } " + " } " + " top: 'data' " + " top: 'targets' " + " } " + " layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 1 " + " weight_filler { " + " type: 'gaussian' " + " std: 1.0 " + " } " + " bias_filler { " + " type: 'gaussian' " + " std: 1.0 " + " } " + " } " + " bottom: 'data' " + " top: 'innerprod' " + " } " + " layers: { " + " name: 'loss' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerprod' " + " bottom: 'targets' " + " } " + "} "; + if (weight_decay != 0) { + proto << "weight_decay: " << weight_decay << " "; + } + if (momentum != 0) { + proto << "momentum: " << momentum << " "; + } + Caffe::set_random_seed(this->seed_); + this->InitSolverFromProtoString(proto.str()); + this->solver_->Solve(); + } + + // Compute an update value given the current state of the train net, + // using the analytical formula for the least squares gradient. + // updated_params will store the updated weight and bias results, + // using the blobs' diffs to hold the update values themselves. + void ComputeLeastSquaresUpdate(const Dtype learning_rate, + const Dtype weight_decay, const Dtype momentum, + vector > >* updated_params) { + const int N = num_; + const int D = channels_ * height_ * width_; + + // Run a forward pass, and manually compute the update values from the + // result. + Net& net = *this->solver_->net(); + vector*> empty_bottom_vec; + net.Forward(empty_bottom_vec); + ASSERT_TRUE(net.has_blob("data")); + const Blob& data = *net.blob_by_name("data"); + ASSERT_TRUE(net.has_blob("targets")); + const Blob& targets = *net.blob_by_name("targets"); + ASSERT_TRUE(net.has_layer("innerprod")); + const vector > >& param_blobs = + net.layer_by_name("innerprod")->blobs(); + const int num_param_blobs = 2; + ASSERT_EQ(num_param_blobs, param_blobs.size()); + const Blob& weights = *param_blobs[0]; + const Blob& bias = *param_blobs[1]; + ASSERT_EQ(D * N, data.count()); + ASSERT_EQ(N, targets.count()); + ASSERT_EQ(D, weights.count()); + ASSERT_EQ(1, bias.count()); + + updated_params->clear(); + updated_params->resize(num_param_blobs); + for (int i = 0; i < num_param_blobs; ++i) { + (*updated_params)[i].reset(new Blob()); + } + Blob& updated_weights = *(*updated_params)[0]; + updated_weights.ReshapeLike(weights); + Blob& updated_bias = *(*updated_params)[1]; + updated_bias.ReshapeLike(bias); + + for (int i = 0; i <= D; ++i) { + // Compute the derivative with respect to the ith weight (i.e., the ith + // element of the gradient). + Dtype grad = 0; + for (int j = 0; j <= D; ++j) { + // Compute element (i, j) of X^T * X. 
+ Dtype element = 0; + for (int k = 0; k < N; ++k) { + // (i, k) in X^T (== (k, i) in X) times (k, j) in X. + const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i]; + const Dtype element_j = (j == D) ? 1 : data.cpu_data()[k * D + j]; + element += element_i * element_j; + } + if (j == D) { + grad += element * bias.cpu_data()[0]; + } else { + grad += element * weights.cpu_data()[j]; + } + } + for (int k = 0; k < N; ++k) { + const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i]; + grad -= element_i * targets.cpu_data()[k]; + } + // Scale the gradient over the N samples. + grad /= N; + // Add the weight decay to the gradient. + grad += weight_decay * + ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]); + // Finally, compute update + const vector > >& history = solver_->history(); + Dtype delta = solver_->delta(); + ASSERT_EQ(2, history.size()); // 1 blob for weights, 1 for bias + Dtype update_value, temp; + if (i == D) { + temp = history[1]->cpu_data()[0]; + temp += grad * grad; + update_value = learning_rate * grad / (std::sqrt(temp) + delta); + updated_bias.mutable_cpu_diff()[0] = update_value; + updated_bias.mutable_cpu_data()[0] = bias.cpu_data()[0] - update_value; + } else { + temp = history[0]->cpu_data()[i]; + temp += grad * grad; + update_value = learning_rate * grad / (std::sqrt(temp) + delta); + updated_weights.mutable_cpu_diff()[i] = update_value; + updated_weights.mutable_cpu_data()[i] = + weights.cpu_data()[i] - update_value; + } + } + } + + void CheckLeastSquaresUpdate( + const vector > >& updated_params) { + const int D = channels_ * height_ * width_; + + const Blob& updated_weights = *updated_params[0]; + const Blob& updated_bias = *updated_params[1]; + + Net& net = *this->solver_->net(); + ASSERT_TRUE(net.has_layer("innerprod")); + const vector > >& param_blobs = + net.layer_by_name("innerprod")->blobs(); + ASSERT_EQ(2, param_blobs.size()); + const Blob& solver_updated_weights = *param_blobs[0]; + ASSERT_EQ(D, solver_updated_weights.count()); + const double kPrecision = 1e-3; + const double kMinPrecision = 1e-7; + for (int i = 0; i < D; ++i) { + const Dtype expected_updated_weight = updated_weights.cpu_data()[i]; + const Dtype solver_updated_weight = solver_updated_weights.cpu_data()[i]; + const Dtype error_margin = std::max(kMinPrecision, kPrecision * + std::min(fabs(expected_updated_weight), fabs(solver_updated_weight))); + EXPECT_NEAR(expected_updated_weight, solver_updated_weight, error_margin); + } + const Blob& solver_updated_bias_blob = *param_blobs[1]; + ASSERT_EQ(1, solver_updated_bias_blob.count()); + const Dtype expected_updated_bias = updated_bias.cpu_data()[0]; + const Dtype solver_updated_bias = solver_updated_bias_blob.cpu_data()[0]; + const Dtype error_margin = std::max(kMinPrecision, kPrecision * + std::min(fabs(expected_updated_bias), fabs(solver_updated_bias))); + EXPECT_NEAR(expected_updated_bias, solver_updated_bias, error_margin); + + // Check the solver's history -- should contain the previous update value. 
+// vector > >& history = this->solver_->history(); +// ASSERT_EQ(2, history.size()); +// for (int i = 0; i < D; ++i) { +// const Dtype expected_history = updated_weights.cpu_diff()[i]; +// const Dtype solver_history = history[0]->cpu_data()[i]; +// const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision * +// std::min(fabs(expected_history), fabs(solver_history))); +// EXPECT_NEAR(expected_history, solver_history, error_margin_hist); +// } +// const Dtype expected_history = updated_bias.cpu_diff()[0]; +// const Dtype solver_history = history[1]->cpu_data()[0]; +// const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision * +// std::min(fabs(expected_history), fabs(solver_history))); +// EXPECT_NEAR(expected_history, solver_history, error_margin_hist); + } + + // Test that the correct update is computed for a regularized least squares + // problem: + // + // E = (1/(2n)) || X w - y ||^2 + (lambda / 2) || w ||^2 + // \nabla_w E = (1/n) (X^T X w - X^T y) + lambda * w + // + // X \in R^{n x (d+1)} (each example is a row, (d+1)th element is always 1) + // w \in R^{(d+1) x 1} ((d+1)th element is the bias) + // y \in R^{n x 1} + // lambda is weight_decay + // + // TestLeastSquaresUpdate works "inductively", assuming that the solver + // correctly updates the net K (= iter_to_check) times, then given the history + // from the Kth update, we compute the (K+1)th update and check that it + // matches the solver's (K+1)th update. + void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0, + const Dtype weight_decay = 0.0, const Dtype momentum = 0.0, + const int iter_to_check = 0) { + // Initialize the solver and run K (= iter_to_check) solver iterations. + RunLeastSquaresSolver(learning_rate, weight_decay, momentum, iter_to_check); + + // Compute the (K+1)th update using the analytic least squares gradient. + vector > > updated_params; + ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum, + &updated_params); + + // Reinitialize the solver and run K+1 solver iterations. + RunLeastSquaresSolver(learning_rate, weight_decay, momentum, + iter_to_check + 1); + + // Check that the solver's solution matches ours. 
+ CheckLeastSquaresUpdate(updated_params); + } +}; + +TYPED_TEST_CASE(AdaGradSolverTest, TestDtypesAndDevices); + +TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdate) { + typedef typename TypeParam::Dtype Dtype; + this->TestLeastSquaresUpdate(); +} + +TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateLROneTenth) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.1; + this->TestLeastSquaresUpdate(kLearningRate); +} + +TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithWeightDecay) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.5; + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay); +} +/* +TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithMomentum) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.0; + const Dtype kMomentum = 0.5; + const int kNumIters = 1; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.0; + const Dtype kMomentum = 0.5; + const int kNumIters = 5; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} +*/ +TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithEverything) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.1; + const Dtype kMomentum = 0.0; + const int kNumIters = 5; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +} // namespace caffe diff --git a/src/caffe/test/test_nesterov_solver.cpp b/src/caffe/test/test_nesterov_solver.cpp new file mode 100644 index 00000000000..f2fcba30b29 --- /dev/null +++ b/src/caffe/test/test_nesterov_solver.cpp @@ -0,0 +1,351 @@ +#include +#include +#include +#include + +#include "google/protobuf/text_format.h" + +#include "gtest/gtest.h" + +#include "caffe/common.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/solver.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +using std::ostringstream; + +namespace caffe { + +template +class NesterovSolverTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + NesterovSolverTest() : + seed_(1701), num_(5), channels_(3), height_(10), width_(10) {} + + // MockNesterovSolver: an NesterovSolver with public history. + class MockNesterovSolver : public NesterovSolver { + public: + explicit MockNesterovSolver(const SolverParameter& param) : + NesterovSolver(param) {} + vector > >& history() { return this->history_; } + }; + + shared_ptr solver_; + int seed_; + int num_, channels_, height_, width_; + + virtual void InitSolverFromProtoString(const string& proto) { + SolverParameter param; + CHECK(google::protobuf::TextFormat::ParseFromString(proto, ¶m)); + // Disable saving a final snapshot so the tests don't pollute the user's + // working directory with useless snapshots. + param.set_snapshot_after_train(false); + // Set the solver_mode according to current Caffe::mode. 
+ switch (Caffe::mode()) { + case Caffe::CPU: + param.set_solver_mode(SolverParameter_SolverMode_CPU); + break; + case Caffe::GPU: + param.set_solver_mode(SolverParameter_SolverMode_GPU); + break; + default: + LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode(); + } + solver_.reset(new MockNesterovSolver(param)); + } + + void RunLeastSquaresSolver(const Dtype learning_rate, + const Dtype weight_decay, const Dtype momentum, const int num_iters) { + ostringstream proto; + proto << + "max_iter: " << num_iters << " " + "base_lr: " << learning_rate << " " + "lr_policy: 'fixed' " + "net_param { " + " name: 'TestNetwork' " + " layers: { " + " name: 'data' " + " type: DUMMY_DATA " + " dummy_data_param { " + " num: " << num_ << " " + " channels: " << channels_ << " " + " height: " << height_ << " " + " width: " << width_ << " " + " channels: 1 " + " height: 1 " + " width: 1 " + " data_filler { " + " type: 'gaussian' " + " std: 1.0 " + " } " + " } " + " top: 'data' " + " top: 'targets' " + " } " + " layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 1 " + " weight_filler { " + " type: 'gaussian' " + " std: 1.0 " + " } " + " bias_filler { " + " type: 'gaussian' " + " std: 1.0 " + " } " + " } " + " bottom: 'data' " + " top: 'innerprod' " + " } " + " layers: { " + " name: 'loss' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerprod' " + " bottom: 'targets' " + " } " + "} "; + if (weight_decay != 0) { + proto << "weight_decay: " << weight_decay << " "; + } + if (momentum != 0) { + proto << "momentum: " << momentum << " "; + } + Caffe::set_random_seed(this->seed_); + this->InitSolverFromProtoString(proto.str()); + this->solver_->Solve(); + } + + // Compute an update value given the current state of the train net, + // using the analytical formula for the least squares gradient. + // updated_params will store the updated weight and bias results, + // using the blobs' diffs to hold the update values themselves. + void ComputeLeastSquaresUpdate(const Dtype learning_rate, + const Dtype weight_decay, const Dtype momentum, + vector > >* updated_params) { + const int N = num_; + const int D = channels_ * height_ * width_; + + // Run a forward pass, and manually compute the update values from the + // result. + Net& net = *this->solver_->net(); + vector*> empty_bottom_vec; + net.Forward(empty_bottom_vec); + ASSERT_TRUE(net.has_blob("data")); + const Blob& data = *net.blob_by_name("data"); + ASSERT_TRUE(net.has_blob("targets")); + const Blob& targets = *net.blob_by_name("targets"); + ASSERT_TRUE(net.has_layer("innerprod")); + const vector > >& param_blobs = + net.layer_by_name("innerprod")->blobs(); + const int num_param_blobs = 2; + ASSERT_EQ(num_param_blobs, param_blobs.size()); + const Blob& weights = *param_blobs[0]; + const Blob& bias = *param_blobs[1]; + ASSERT_EQ(D * N, data.count()); + ASSERT_EQ(N, targets.count()); + ASSERT_EQ(D, weights.count()); + ASSERT_EQ(1, bias.count()); + + updated_params->clear(); + updated_params->resize(num_param_blobs); + for (int i = 0; i < num_param_blobs; ++i) { + (*updated_params)[i].reset(new Blob()); + } + Blob& updated_weights = *(*updated_params)[0]; + updated_weights.ReshapeLike(weights); + Blob& updated_bias = *(*updated_params)[1]; + updated_bias.ReshapeLike(bias); + + for (int i = 0; i <= D; ++i) { + // Compute the derivative with respect to the ith weight (i.e., the ith + // element of the gradient). + Dtype grad = 0; + for (int j = 0; j <= D; ++j) { + // Compute element (i, j) of X^T * X. 
+ Dtype element = 0; + for (int k = 0; k < N; ++k) { + // (i, k) in X^T (== (k, i) in X) times (k, j) in X. + const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i]; + const Dtype element_j = (j == D) ? 1 : data.cpu_data()[k * D + j]; + element += element_i * element_j; + } + if (j == D) { + grad += element * bias.cpu_data()[0]; + } else { + grad += element * weights.cpu_data()[j]; + } + } + for (int k = 0; k < N; ++k) { + const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i]; + grad -= element_i * targets.cpu_data()[k]; + } + // Scale the gradient over the N samples. + grad /= N; + // Add the weight decay to the gradient. + grad += weight_decay * + ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]); + // Finally, add any momentum. + const vector > >& history = solver_->history(); + ASSERT_EQ(2, history.size()); // 1 blob for weights, 1 for bias + Dtype update_value = learning_rate * grad, temp; + if (i == D) { + temp = history[1]->cpu_data()[0] * momentum; + update_value += temp; // update history + // step back then over-step + update_value = (1 + momentum) * update_value - temp; + updated_bias.mutable_cpu_diff()[0] = update_value; + updated_bias.mutable_cpu_data()[0] = bias.cpu_data()[0] - update_value; + } else { + temp = history[0]->cpu_data()[i] * momentum; + update_value += temp; // update history + // step back then over-step + update_value = (1 + momentum) * update_value - temp; + updated_weights.mutable_cpu_diff()[i] = update_value; + updated_weights.mutable_cpu_data()[i] = + weights.cpu_data()[i] - update_value; + } + } + } + + void CheckLeastSquaresUpdate( + const vector > >& updated_params) { + const int D = channels_ * height_ * width_; + + const Blob& updated_weights = *updated_params[0]; + const Blob& updated_bias = *updated_params[1]; + + Net& net = *this->solver_->net(); + ASSERT_TRUE(net.has_layer("innerprod")); + const vector > >& param_blobs = + net.layer_by_name("innerprod")->blobs(); + ASSERT_EQ(2, param_blobs.size()); + const Blob& solver_updated_weights = *param_blobs[0]; + ASSERT_EQ(D, solver_updated_weights.count()); + const double kPrecision = 1e-3; + const double kMinPrecision = 1e-7; + for (int i = 0; i < D; ++i) { + const Dtype expected_updated_weight = updated_weights.cpu_data()[i]; + const Dtype solver_updated_weight = solver_updated_weights.cpu_data()[i]; + const Dtype error_margin = std::max(kMinPrecision, kPrecision * + std::min(fabs(expected_updated_weight), fabs(solver_updated_weight))); + EXPECT_NEAR(expected_updated_weight, solver_updated_weight, error_margin); + } + const Blob& solver_updated_bias_blob = *param_blobs[1]; + ASSERT_EQ(1, solver_updated_bias_blob.count()); + const Dtype expected_updated_bias = updated_bias.cpu_data()[0]; + const Dtype solver_updated_bias = solver_updated_bias_blob.cpu_data()[0]; + const Dtype error_margin = std::max(kMinPrecision, kPrecision * + std::min(fabs(expected_updated_bias), fabs(solver_updated_bias))); + EXPECT_NEAR(expected_updated_bias, solver_updated_bias, error_margin); + + // Check the solver's history -- should contain the previous update value. 
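+    // (Left disabled for Nesterov: the history blob stores the raw momentum
+    // vector v, while the applied diff is (1 + momentum) * v - momentum * v_prev,
+    // so the SGD-style comparison below would not hold.)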
+// vector > >& history = this->solver_->history(); +// ASSERT_EQ(2, history.size()); +// for (int i = 0; i < D; ++i) { +// const Dtype expected_history = updated_weights.cpu_diff()[i]; +// const Dtype solver_history = history[0]->cpu_data()[i]; +// const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision * +// std::min(fabs(expected_history), fabs(solver_history))); +// EXPECT_NEAR(expected_history, solver_history, error_margin_hist); +// } +// const Dtype expected_history = updated_bias.cpu_diff()[0]; +// const Dtype solver_history = history[1]->cpu_data()[0]; +// const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision * +// std::min(fabs(expected_history), fabs(solver_history))); +// EXPECT_NEAR(expected_history, solver_history, error_margin_hist); + } + + // Test that the correct update is computed for a regularized least squares + // problem: + // + // E = (1/(2n)) || X w - y ||^2 + (lambda / 2) || w ||^2 + // \nabla_w E = (1/n) (X^T X w - X^T y) + lambda * w + // + // X \in R^{n x (d+1)} (each example is a row, (d+1)th element is always 1) + // w \in R^{(d+1) x 1} ((d+1)th element is the bias) + // y \in R^{n x 1} + // lambda is weight_decay + // + // TestLeastSquaresUpdate works "inductively", assuming that the solver + // correctly updates the net K (= iter_to_check) times, then given the history + // from the Kth update, we compute the (K+1)th update and check that it + // matches the solver's (K+1)th update. + void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0, + const Dtype weight_decay = 0.0, const Dtype momentum = 0.0, + const int iter_to_check = 0) { + // Initialize the solver and run K (= iter_to_check) solver iterations. + RunLeastSquaresSolver(learning_rate, weight_decay, momentum, iter_to_check); + + // Compute the (K+1)th update using the analytic least squares gradient. + vector > > updated_params; + ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum, + &updated_params); + + // Reinitialize the solver and run K+1 solver iterations. + RunLeastSquaresSolver(learning_rate, weight_decay, momentum, + iter_to_check + 1); + + // Check that the solver's solution matches ours. 
+ CheckLeastSquaresUpdate(updated_params); + } +}; + +TYPED_TEST_CASE(NesterovSolverTest, TestDtypesAndDevices); + +TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdate) { + typedef typename TypeParam::Dtype Dtype; + this->TestLeastSquaresUpdate(); +} + +TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateLROneTenth) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.1; + this->TestLeastSquaresUpdate(kLearningRate); +} + +TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithWeightDecay) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.5; + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay); +} + +TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithMomentum) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.0; + const Dtype kMomentum = 0.5; + const int kNumIters = 1; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.0; + const Dtype kMomentum = 0.5; + const int kNumIters = 5; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithEverything) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.1; + const Dtype kMomentum = 0.9; + const int kNumIters = 5; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +} // namespace caffe From dbb9296050054095d8d2979992e969534b9b1430 Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Fri, 22 Aug 2014 11:51:29 -0700 Subject: [PATCH 13/23] cleanup caffe.proto --- src/caffe/proto/caffe.proto | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index bb14b838f9b..ff18779f62d 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -137,7 +137,7 @@ message SolverParameter { // random number generator -- useful for reproducible results. Otherwise, // (and by default) initialize using a seed derived from the system clock. 
optional int64 random_seed = 20 [default = -1]; - + // Solver type enum SolverType { SGD = 0; From f206c6479fd355f17bd821ceb6931fd0663c1e70 Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Fri, 22 Aug 2014 11:51:16 -0700 Subject: [PATCH 14/23] Merge Test{SGD,AdaGrad,Nesterov}Solver; they become subclasses of TestGradientBasedSolver --- include/caffe/solver.hpp | 2 + src/caffe/test/test_adagrad_solver.cpp | 351 ------------------ ...ver.cpp => test_gradient_based_solver.cpp} | 193 ++++++++-- src/caffe/test/test_nesterov_solver.cpp | 351 ------------------ 4 files changed, 170 insertions(+), 727 deletions(-) delete mode 100644 src/caffe/test/test_adagrad_solver.cpp rename src/caffe/test/{test_sgd_solver.cpp => test_gradient_based_solver.cpp} (66%) delete mode 100644 src/caffe/test/test_nesterov_solver.cpp diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index 2f51268f295..b3c7f6f35d6 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -65,6 +65,8 @@ class SGDSolver : public Solver { explicit SGDSolver(const string& param_file) : Solver(param_file) {} + const vector > >& history() { return history_; } + protected: virtual void PreSolve(); Dtype GetLearningRate(); diff --git a/src/caffe/test/test_adagrad_solver.cpp b/src/caffe/test/test_adagrad_solver.cpp deleted file mode 100644 index 45cf2002e33..00000000000 --- a/src/caffe/test/test_adagrad_solver.cpp +++ /dev/null @@ -1,351 +0,0 @@ -#include -#include -#include -#include - -#include "google/protobuf/text_format.h" - -#include "gtest/gtest.h" - -#include "caffe/common.hpp" -#include "caffe/proto/caffe.pb.h" -#include "caffe/solver.hpp" - -#include "caffe/test/test_caffe_main.hpp" - -using std::ostringstream; - -namespace caffe { - -template -class AdaGradSolverTest : public MultiDeviceTest { - typedef typename TypeParam::Dtype Dtype; - - protected: - AdaGradSolverTest() : - seed_(1701), num_(5), channels_(3), height_(10), width_(10) {} - - // MockAdaGradSolver: an AdaGradSolver with public history. - class MockAdaGradSolver : public AdaGradSolver { - public: - explicit MockAdaGradSolver(const SolverParameter& param) : - AdaGradSolver(param) {} - vector > >& history() { return this->history_; } - Dtype delta() { return this->param_.delta(); } - }; - - shared_ptr solver_; - int seed_; - int num_, channels_, height_, width_; - - virtual void InitSolverFromProtoString(const string& proto) { - SolverParameter param; - CHECK(google::protobuf::TextFormat::ParseFromString(proto, ¶m)); - // Disable saving a final snapshot so the tests don't pollute the user's - // working directory with useless snapshots. - param.set_snapshot_after_train(false); - // Set the solver_mode according to current Caffe::mode. 
- switch (Caffe::mode()) { - case Caffe::CPU: - param.set_solver_mode(SolverParameter_SolverMode_CPU); - break; - case Caffe::GPU: - param.set_solver_mode(SolverParameter_SolverMode_GPU); - break; - default: - LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode(); - } - solver_.reset(new MockAdaGradSolver(param)); - } - - void RunLeastSquaresSolver(const Dtype learning_rate, - const Dtype weight_decay, const Dtype momentum, const int num_iters) { - ostringstream proto; - proto << - "max_iter: " << num_iters << " " - "base_lr: " << learning_rate << " " - "lr_policy: 'fixed' " - "net_param { " - " name: 'TestNetwork' " - " layers: { " - " name: 'data' " - " type: DUMMY_DATA " - " dummy_data_param { " - " num: " << num_ << " " - " channels: " << channels_ << " " - " height: " << height_ << " " - " width: " << width_ << " " - " channels: 1 " - " height: 1 " - " width: 1 " - " data_filler { " - " type: 'gaussian' " - " std: 1.0 " - " } " - " } " - " top: 'data' " - " top: 'targets' " - " } " - " layers: { " - " name: 'innerprod' " - " type: INNER_PRODUCT " - " inner_product_param { " - " num_output: 1 " - " weight_filler { " - " type: 'gaussian' " - " std: 1.0 " - " } " - " bias_filler { " - " type: 'gaussian' " - " std: 1.0 " - " } " - " } " - " bottom: 'data' " - " top: 'innerprod' " - " } " - " layers: { " - " name: 'loss' " - " type: EUCLIDEAN_LOSS " - " bottom: 'innerprod' " - " bottom: 'targets' " - " } " - "} "; - if (weight_decay != 0) { - proto << "weight_decay: " << weight_decay << " "; - } - if (momentum != 0) { - proto << "momentum: " << momentum << " "; - } - Caffe::set_random_seed(this->seed_); - this->InitSolverFromProtoString(proto.str()); - this->solver_->Solve(); - } - - // Compute an update value given the current state of the train net, - // using the analytical formula for the least squares gradient. - // updated_params will store the updated weight and bias results, - // using the blobs' diffs to hold the update values themselves. - void ComputeLeastSquaresUpdate(const Dtype learning_rate, - const Dtype weight_decay, const Dtype momentum, - vector > >* updated_params) { - const int N = num_; - const int D = channels_ * height_ * width_; - - // Run a forward pass, and manually compute the update values from the - // result. - Net& net = *this->solver_->net(); - vector*> empty_bottom_vec; - net.Forward(empty_bottom_vec); - ASSERT_TRUE(net.has_blob("data")); - const Blob& data = *net.blob_by_name("data"); - ASSERT_TRUE(net.has_blob("targets")); - const Blob& targets = *net.blob_by_name("targets"); - ASSERT_TRUE(net.has_layer("innerprod")); - const vector > >& param_blobs = - net.layer_by_name("innerprod")->blobs(); - const int num_param_blobs = 2; - ASSERT_EQ(num_param_blobs, param_blobs.size()); - const Blob& weights = *param_blobs[0]; - const Blob& bias = *param_blobs[1]; - ASSERT_EQ(D * N, data.count()); - ASSERT_EQ(N, targets.count()); - ASSERT_EQ(D, weights.count()); - ASSERT_EQ(1, bias.count()); - - updated_params->clear(); - updated_params->resize(num_param_blobs); - for (int i = 0; i < num_param_blobs; ++i) { - (*updated_params)[i].reset(new Blob()); - } - Blob& updated_weights = *(*updated_params)[0]; - updated_weights.ReshapeLike(weights); - Blob& updated_bias = *(*updated_params)[1]; - updated_bias.ReshapeLike(bias); - - for (int i = 0; i <= D; ++i) { - // Compute the derivative with respect to the ith weight (i.e., the ith - // element of the gradient). - Dtype grad = 0; - for (int j = 0; j <= D; ++j) { - // Compute element (i, j) of X^T * X. 
- Dtype element = 0; - for (int k = 0; k < N; ++k) { - // (i, k) in X^T (== (k, i) in X) times (k, j) in X. - const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i]; - const Dtype element_j = (j == D) ? 1 : data.cpu_data()[k * D + j]; - element += element_i * element_j; - } - if (j == D) { - grad += element * bias.cpu_data()[0]; - } else { - grad += element * weights.cpu_data()[j]; - } - } - for (int k = 0; k < N; ++k) { - const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i]; - grad -= element_i * targets.cpu_data()[k]; - } - // Scale the gradient over the N samples. - grad /= N; - // Add the weight decay to the gradient. - grad += weight_decay * - ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]); - // Finally, compute update - const vector > >& history = solver_->history(); - Dtype delta = solver_->delta(); - ASSERT_EQ(2, history.size()); // 1 blob for weights, 1 for bias - Dtype update_value, temp; - if (i == D) { - temp = history[1]->cpu_data()[0]; - temp += grad * grad; - update_value = learning_rate * grad / (std::sqrt(temp) + delta); - updated_bias.mutable_cpu_diff()[0] = update_value; - updated_bias.mutable_cpu_data()[0] = bias.cpu_data()[0] - update_value; - } else { - temp = history[0]->cpu_data()[i]; - temp += grad * grad; - update_value = learning_rate * grad / (std::sqrt(temp) + delta); - updated_weights.mutable_cpu_diff()[i] = update_value; - updated_weights.mutable_cpu_data()[i] = - weights.cpu_data()[i] - update_value; - } - } - } - - void CheckLeastSquaresUpdate( - const vector > >& updated_params) { - const int D = channels_ * height_ * width_; - - const Blob& updated_weights = *updated_params[0]; - const Blob& updated_bias = *updated_params[1]; - - Net& net = *this->solver_->net(); - ASSERT_TRUE(net.has_layer("innerprod")); - const vector > >& param_blobs = - net.layer_by_name("innerprod")->blobs(); - ASSERT_EQ(2, param_blobs.size()); - const Blob& solver_updated_weights = *param_blobs[0]; - ASSERT_EQ(D, solver_updated_weights.count()); - const double kPrecision = 1e-3; - const double kMinPrecision = 1e-7; - for (int i = 0; i < D; ++i) { - const Dtype expected_updated_weight = updated_weights.cpu_data()[i]; - const Dtype solver_updated_weight = solver_updated_weights.cpu_data()[i]; - const Dtype error_margin = std::max(kMinPrecision, kPrecision * - std::min(fabs(expected_updated_weight), fabs(solver_updated_weight))); - EXPECT_NEAR(expected_updated_weight, solver_updated_weight, error_margin); - } - const Blob& solver_updated_bias_blob = *param_blobs[1]; - ASSERT_EQ(1, solver_updated_bias_blob.count()); - const Dtype expected_updated_bias = updated_bias.cpu_data()[0]; - const Dtype solver_updated_bias = solver_updated_bias_blob.cpu_data()[0]; - const Dtype error_margin = std::max(kMinPrecision, kPrecision * - std::min(fabs(expected_updated_bias), fabs(solver_updated_bias))); - EXPECT_NEAR(expected_updated_bias, solver_updated_bias, error_margin); - - // Check the solver's history -- should contain the previous update value. 
-// vector > >& history = this->solver_->history(); -// ASSERT_EQ(2, history.size()); -// for (int i = 0; i < D; ++i) { -// const Dtype expected_history = updated_weights.cpu_diff()[i]; -// const Dtype solver_history = history[0]->cpu_data()[i]; -// const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision * -// std::min(fabs(expected_history), fabs(solver_history))); -// EXPECT_NEAR(expected_history, solver_history, error_margin_hist); -// } -// const Dtype expected_history = updated_bias.cpu_diff()[0]; -// const Dtype solver_history = history[1]->cpu_data()[0]; -// const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision * -// std::min(fabs(expected_history), fabs(solver_history))); -// EXPECT_NEAR(expected_history, solver_history, error_margin_hist); - } - - // Test that the correct update is computed for a regularized least squares - // problem: - // - // E = (1/(2n)) || X w - y ||^2 + (lambda / 2) || w ||^2 - // \nabla_w E = (1/n) (X^T X w - X^T y) + lambda * w - // - // X \in R^{n x (d+1)} (each example is a row, (d+1)th element is always 1) - // w \in R^{(d+1) x 1} ((d+1)th element is the bias) - // y \in R^{n x 1} - // lambda is weight_decay - // - // TestLeastSquaresUpdate works "inductively", assuming that the solver - // correctly updates the net K (= iter_to_check) times, then given the history - // from the Kth update, we compute the (K+1)th update and check that it - // matches the solver's (K+1)th update. - void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0, - const Dtype weight_decay = 0.0, const Dtype momentum = 0.0, - const int iter_to_check = 0) { - // Initialize the solver and run K (= iter_to_check) solver iterations. - RunLeastSquaresSolver(learning_rate, weight_decay, momentum, iter_to_check); - - // Compute the (K+1)th update using the analytic least squares gradient. - vector > > updated_params; - ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum, - &updated_params); - - // Reinitialize the solver and run K+1 solver iterations. - RunLeastSquaresSolver(learning_rate, weight_decay, momentum, - iter_to_check + 1); - - // Check that the solver's solution matches ours. 
- CheckLeastSquaresUpdate(updated_params); - } -}; - -TYPED_TEST_CASE(AdaGradSolverTest, TestDtypesAndDevices); - -TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdate) { - typedef typename TypeParam::Dtype Dtype; - this->TestLeastSquaresUpdate(); -} - -TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateLROneTenth) { - typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 0.1; - this->TestLeastSquaresUpdate(kLearningRate); -} - -TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithWeightDecay) { - typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 1.0; - const Dtype kWeightDecay = 0.5; - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay); -} -/* -TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithMomentum) { - typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 1.0; - const Dtype kWeightDecay = 0.0; - const Dtype kMomentum = 0.5; - const int kNumIters = 1; - for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); - } -} - -TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) { - typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 1.0; - const Dtype kWeightDecay = 0.0; - const Dtype kMomentum = 0.5; - const int kNumIters = 5; - for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); - } -} -*/ -TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithEverything) { - typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 0.01; - const Dtype kWeightDecay = 0.1; - const Dtype kMomentum = 0.0; - const int kNumIters = 5; - for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); - } -} - -} // namespace caffe diff --git a/src/caffe/test/test_sgd_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp similarity index 66% rename from src/caffe/test/test_sgd_solver.cpp rename to src/caffe/test/test_gradient_based_solver.cpp index 1ec24b134c6..9c21499ea5e 100644 --- a/src/caffe/test/test_sgd_solver.cpp +++ b/src/caffe/test/test_gradient_based_solver.cpp @@ -18,24 +18,20 @@ using std::ostringstream; namespace caffe { template -class SGDSolverTest : public MultiDeviceTest { +class GradientBasedSolverTest : public MultiDeviceTest { typedef typename TypeParam::Dtype Dtype; protected: - SGDSolverTest() : + GradientBasedSolverTest() : seed_(1701), num_(5), channels_(3), height_(10), width_(10) {} - // MockSGDSolver: an SGDSolver with public history. - class MockSGDSolver : public SGDSolver { - public: - explicit MockSGDSolver(const SolverParameter& param) : - SGDSolver(param) {} - vector > >& history() { return this->history_; } - }; - - shared_ptr solver_; + shared_ptr > solver_; int seed_; int num_, channels_, height_, width_; + Dtype delta_; // Stability constant for AdaGrad. + + virtual SolverParameter_SolverType solver_type() = 0; + virtual void InitSolver(const SolverParameter& param) = 0; virtual void InitSolverFromProtoString(const string& proto) { SolverParameter param; @@ -54,7 +50,9 @@ class SGDSolverTest : public MultiDeviceTest { default: LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode(); } - solver_.reset(new MockSGDSolver(param)); + InitSolver(param); + delta_ = (solver_type() == SolverParameter_SolverType_ADAGRAD) ? 
+ param.delta() : 0; } void RunLeastSquaresSolver(const Dtype learning_rate, @@ -189,16 +187,32 @@ class SGDSolverTest : public MultiDeviceTest { // Add the weight decay to the gradient. grad += weight_decay * ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]); - // Finally, add any momentum. + // Finally, compute update. const vector > >& history = solver_->history(); ASSERT_EQ(2, history.size()); // 1 blob for weights, 1 for bias Dtype update_value = learning_rate * grad; + const Dtype history_value = (i == D) ? + history[1]->cpu_data()[0] : history[0]->cpu_data()[i]; + const Dtype temp = momentum * history_value; + switch (solver_type()) { + case SolverParameter_SolverType_SGD: + update_value += temp; + break; + case SolverParameter_SolverType_NESTEROV: + update_value += temp; + // step back then over-step + update_value = (1 + momentum) * update_value - temp; + break; + case SolverParameter_SolverType_ADAGRAD: + update_value /= std::sqrt(history_value + grad * grad) + delta_; + break; + default: + LOG(FATAL) << "Unknown solver type: " << solver_type(); + } if (i == D) { - update_value += momentum * history[1]->cpu_data()[0]; updated_bias.mutable_cpu_diff()[0] = update_value; updated_bias.mutable_cpu_data()[0] = bias.cpu_data()[0] - update_value; } else { - update_value += momentum * history[0]->cpu_data()[i]; updated_weights.mutable_cpu_diff()[i] = update_value; updated_weights.mutable_cpu_data()[i] = weights.cpu_data()[i] - update_value; @@ -238,20 +252,22 @@ class SGDSolverTest : public MultiDeviceTest { EXPECT_NEAR(expected_updated_bias, solver_updated_bias, error_margin); // Check the solver's history -- should contain the previous update value. - vector > >& history = this->solver_->history(); - ASSERT_EQ(2, history.size()); - for (int i = 0; i < D; ++i) { - const Dtype expected_history = updated_weights.cpu_diff()[i]; - const Dtype solver_history = history[0]->cpu_data()[i]; + if (solver_type() == SolverParameter_SolverType_SGD) { + const vector > >& history = solver_->history(); + ASSERT_EQ(2, history.size()); + for (int i = 0; i < D; ++i) { + const Dtype expected_history = updated_weights.cpu_diff()[i]; + const Dtype solver_history = history[0]->cpu_data()[i]; + const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision * + std::min(fabs(expected_history), fabs(solver_history))); + EXPECT_NEAR(expected_history, solver_history, error_margin_hist); + } + const Dtype expected_history = updated_bias.cpu_diff()[0]; + const Dtype solver_history = history[1]->cpu_data()[0]; const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision * std::min(fabs(expected_history), fabs(solver_history))); EXPECT_NEAR(expected_history, solver_history, error_margin_hist); } - const Dtype expected_history = updated_bias.cpu_diff()[0]; - const Dtype solver_history = history[1]->cpu_data()[0]; - const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision * - std::min(fabs(expected_history), fabs(solver_history))); - EXPECT_NEAR(expected_history, solver_history, error_margin_hist); } // Test that the correct update is computed for a regularized least squares @@ -289,6 +305,21 @@ class SGDSolverTest : public MultiDeviceTest { } }; + +template +class SGDSolverTest : public GradientBasedSolverTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + virtual void InitSolver(const SolverParameter& param) { + this->solver_.reset(new SGDSolver(param)); + } + + virtual SolverParameter_SolverType solver_type() { + return SolverParameter_SolverType_SGD; + } +}; + 
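The switch above condenses the three update rules into one test path. As a minimal scalar sketch of that arithmetic (illustrative only, not part of the patch; lr, grad, momentum, delta, and history are stand-in names rather than Caffe API):

    #include <cmath>

    // Scalar versions of the per-parameter updates the merged test reproduces.
    // 'history' is the solver state for this parameter; the caller applies
    // w -= returned value.
    double sgd_update(double grad, double lr, double momentum, double& history) {
      history = lr * grad + momentum * history;  // history keeps the last update
      return history;
    }

    double nesterov_update(double grad, double lr, double momentum,
                           double& history) {
      const double prev = history;
      history = lr * grad + momentum * prev;               // same recurrence as SGD
      return (1 + momentum) * history - momentum * prev;   // step back, then over-step
    }

    double adagrad_update(double grad, double lr, double delta, double& history) {
      history += grad * grad;                           // accumulate squared gradients
      return lr * grad / (std::sqrt(history) + delta);
    }

In all three cases the solver applies w -= update, and only SGD keeps the applied update itself in history, which is why the earlier history check is restricted to the SGD solver type.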
TYPED_TEST_CASE(SGDSolverTest, TestDtypesAndDevices); TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdate) { @@ -342,4 +373,116 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverything) { } } + +template +class AdaGradSolverTest : public GradientBasedSolverTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + virtual void InitSolver(const SolverParameter& param) { + this->solver_.reset(new AdaGradSolver(param)); + } + virtual SolverParameter_SolverType solver_type() { + return SolverParameter_SolverType_ADAGRAD; + } +}; + +TYPED_TEST_CASE(AdaGradSolverTest, TestDtypesAndDevices); + +TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdate) { + typedef typename TypeParam::Dtype Dtype; + this->TestLeastSquaresUpdate(); +} + +TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateLROneTenth) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.1; + this->TestLeastSquaresUpdate(kLearningRate); +} + +TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithWeightDecay) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.5; + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay); +} + +TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithEverything) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.1; + const Dtype kMomentum = 0.0; + const int kNumIters = 5; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + + +template +class NesterovSolverTest : public GradientBasedSolverTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + virtual void InitSolver(const SolverParameter& param) { + this->solver_.reset(new NesterovSolver(param)); + } + virtual SolverParameter_SolverType solver_type() { + return SolverParameter_SolverType_NESTEROV; + } +}; + +TYPED_TEST_CASE(NesterovSolverTest, TestDtypesAndDevices); + +TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdate) { + typedef typename TypeParam::Dtype Dtype; + this->TestLeastSquaresUpdate(); +} + +TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateLROneTenth) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.1; + this->TestLeastSquaresUpdate(kLearningRate); +} + +TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithWeightDecay) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.5; + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay); +} + +TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithMomentum) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.0; + const Dtype kMomentum = 0.5; + const int kNumIters = 1; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.0; + const Dtype kMomentum = 0.5; + const int kNumIters = 5; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithEverything) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 
0.1; + const Dtype kMomentum = 0.9; + const int kNumIters = 5; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + } // namespace caffe diff --git a/src/caffe/test/test_nesterov_solver.cpp b/src/caffe/test/test_nesterov_solver.cpp deleted file mode 100644 index f2fcba30b29..00000000000 --- a/src/caffe/test/test_nesterov_solver.cpp +++ /dev/null @@ -1,351 +0,0 @@ -#include -#include -#include -#include - -#include "google/protobuf/text_format.h" - -#include "gtest/gtest.h" - -#include "caffe/common.hpp" -#include "caffe/proto/caffe.pb.h" -#include "caffe/solver.hpp" - -#include "caffe/test/test_caffe_main.hpp" - -using std::ostringstream; - -namespace caffe { - -template -class NesterovSolverTest : public MultiDeviceTest { - typedef typename TypeParam::Dtype Dtype; - - protected: - NesterovSolverTest() : - seed_(1701), num_(5), channels_(3), height_(10), width_(10) {} - - // MockNesterovSolver: an NesterovSolver with public history. - class MockNesterovSolver : public NesterovSolver { - public: - explicit MockNesterovSolver(const SolverParameter& param) : - NesterovSolver(param) {} - vector > >& history() { return this->history_; } - }; - - shared_ptr solver_; - int seed_; - int num_, channels_, height_, width_; - - virtual void InitSolverFromProtoString(const string& proto) { - SolverParameter param; - CHECK(google::protobuf::TextFormat::ParseFromString(proto, ¶m)); - // Disable saving a final snapshot so the tests don't pollute the user's - // working directory with useless snapshots. - param.set_snapshot_after_train(false); - // Set the solver_mode according to current Caffe::mode. - switch (Caffe::mode()) { - case Caffe::CPU: - param.set_solver_mode(SolverParameter_SolverMode_CPU); - break; - case Caffe::GPU: - param.set_solver_mode(SolverParameter_SolverMode_GPU); - break; - default: - LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode(); - } - solver_.reset(new MockNesterovSolver(param)); - } - - void RunLeastSquaresSolver(const Dtype learning_rate, - const Dtype weight_decay, const Dtype momentum, const int num_iters) { - ostringstream proto; - proto << - "max_iter: " << num_iters << " " - "base_lr: " << learning_rate << " " - "lr_policy: 'fixed' " - "net_param { " - " name: 'TestNetwork' " - " layers: { " - " name: 'data' " - " type: DUMMY_DATA " - " dummy_data_param { " - " num: " << num_ << " " - " channels: " << channels_ << " " - " height: " << height_ << " " - " width: " << width_ << " " - " channels: 1 " - " height: 1 " - " width: 1 " - " data_filler { " - " type: 'gaussian' " - " std: 1.0 " - " } " - " } " - " top: 'data' " - " top: 'targets' " - " } " - " layers: { " - " name: 'innerprod' " - " type: INNER_PRODUCT " - " inner_product_param { " - " num_output: 1 " - " weight_filler { " - " type: 'gaussian' " - " std: 1.0 " - " } " - " bias_filler { " - " type: 'gaussian' " - " std: 1.0 " - " } " - " } " - " bottom: 'data' " - " top: 'innerprod' " - " } " - " layers: { " - " name: 'loss' " - " type: EUCLIDEAN_LOSS " - " bottom: 'innerprod' " - " bottom: 'targets' " - " } " - "} "; - if (weight_decay != 0) { - proto << "weight_decay: " << weight_decay << " "; - } - if (momentum != 0) { - proto << "momentum: " << momentum << " "; - } - Caffe::set_random_seed(this->seed_); - this->InitSolverFromProtoString(proto.str()); - this->solver_->Solve(); - } - - // Compute an update value given the current state of the train net, - // using the analytical formula for the least squares gradient. 
- // updated_params will store the updated weight and bias results, - // using the blobs' diffs to hold the update values themselves. - void ComputeLeastSquaresUpdate(const Dtype learning_rate, - const Dtype weight_decay, const Dtype momentum, - vector > >* updated_params) { - const int N = num_; - const int D = channels_ * height_ * width_; - - // Run a forward pass, and manually compute the update values from the - // result. - Net& net = *this->solver_->net(); - vector*> empty_bottom_vec; - net.Forward(empty_bottom_vec); - ASSERT_TRUE(net.has_blob("data")); - const Blob& data = *net.blob_by_name("data"); - ASSERT_TRUE(net.has_blob("targets")); - const Blob& targets = *net.blob_by_name("targets"); - ASSERT_TRUE(net.has_layer("innerprod")); - const vector > >& param_blobs = - net.layer_by_name("innerprod")->blobs(); - const int num_param_blobs = 2; - ASSERT_EQ(num_param_blobs, param_blobs.size()); - const Blob& weights = *param_blobs[0]; - const Blob& bias = *param_blobs[1]; - ASSERT_EQ(D * N, data.count()); - ASSERT_EQ(N, targets.count()); - ASSERT_EQ(D, weights.count()); - ASSERT_EQ(1, bias.count()); - - updated_params->clear(); - updated_params->resize(num_param_blobs); - for (int i = 0; i < num_param_blobs; ++i) { - (*updated_params)[i].reset(new Blob()); - } - Blob& updated_weights = *(*updated_params)[0]; - updated_weights.ReshapeLike(weights); - Blob& updated_bias = *(*updated_params)[1]; - updated_bias.ReshapeLike(bias); - - for (int i = 0; i <= D; ++i) { - // Compute the derivative with respect to the ith weight (i.e., the ith - // element of the gradient). - Dtype grad = 0; - for (int j = 0; j <= D; ++j) { - // Compute element (i, j) of X^T * X. - Dtype element = 0; - for (int k = 0; k < N; ++k) { - // (i, k) in X^T (== (k, i) in X) times (k, j) in X. - const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i]; - const Dtype element_j = (j == D) ? 1 : data.cpu_data()[k * D + j]; - element += element_i * element_j; - } - if (j == D) { - grad += element * bias.cpu_data()[0]; - } else { - grad += element * weights.cpu_data()[j]; - } - } - for (int k = 0; k < N; ++k) { - const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i]; - grad -= element_i * targets.cpu_data()[k]; - } - // Scale the gradient over the N samples. - grad /= N; - // Add the weight decay to the gradient. - grad += weight_decay * - ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]); - // Finally, add any momentum. 
- const vector > >& history = solver_->history(); - ASSERT_EQ(2, history.size()); // 1 blob for weights, 1 for bias - Dtype update_value = learning_rate * grad, temp; - if (i == D) { - temp = history[1]->cpu_data()[0] * momentum; - update_value += temp; // update history - // step back then over-step - update_value = (1 + momentum) * update_value - temp; - updated_bias.mutable_cpu_diff()[0] = update_value; - updated_bias.mutable_cpu_data()[0] = bias.cpu_data()[0] - update_value; - } else { - temp = history[0]->cpu_data()[i] * momentum; - update_value += temp; // update history - // step back then over-step - update_value = (1 + momentum) * update_value - temp; - updated_weights.mutable_cpu_diff()[i] = update_value; - updated_weights.mutable_cpu_data()[i] = - weights.cpu_data()[i] - update_value; - } - } - } - - void CheckLeastSquaresUpdate( - const vector > >& updated_params) { - const int D = channels_ * height_ * width_; - - const Blob& updated_weights = *updated_params[0]; - const Blob& updated_bias = *updated_params[1]; - - Net& net = *this->solver_->net(); - ASSERT_TRUE(net.has_layer("innerprod")); - const vector > >& param_blobs = - net.layer_by_name("innerprod")->blobs(); - ASSERT_EQ(2, param_blobs.size()); - const Blob& solver_updated_weights = *param_blobs[0]; - ASSERT_EQ(D, solver_updated_weights.count()); - const double kPrecision = 1e-3; - const double kMinPrecision = 1e-7; - for (int i = 0; i < D; ++i) { - const Dtype expected_updated_weight = updated_weights.cpu_data()[i]; - const Dtype solver_updated_weight = solver_updated_weights.cpu_data()[i]; - const Dtype error_margin = std::max(kMinPrecision, kPrecision * - std::min(fabs(expected_updated_weight), fabs(solver_updated_weight))); - EXPECT_NEAR(expected_updated_weight, solver_updated_weight, error_margin); - } - const Blob& solver_updated_bias_blob = *param_blobs[1]; - ASSERT_EQ(1, solver_updated_bias_blob.count()); - const Dtype expected_updated_bias = updated_bias.cpu_data()[0]; - const Dtype solver_updated_bias = solver_updated_bias_blob.cpu_data()[0]; - const Dtype error_margin = std::max(kMinPrecision, kPrecision * - std::min(fabs(expected_updated_bias), fabs(solver_updated_bias))); - EXPECT_NEAR(expected_updated_bias, solver_updated_bias, error_margin); - - // Check the solver's history -- should contain the previous update value. 
-// vector > >& history = this->solver_->history(); -// ASSERT_EQ(2, history.size()); -// for (int i = 0; i < D; ++i) { -// const Dtype expected_history = updated_weights.cpu_diff()[i]; -// const Dtype solver_history = history[0]->cpu_data()[i]; -// const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision * -// std::min(fabs(expected_history), fabs(solver_history))); -// EXPECT_NEAR(expected_history, solver_history, error_margin_hist); -// } -// const Dtype expected_history = updated_bias.cpu_diff()[0]; -// const Dtype solver_history = history[1]->cpu_data()[0]; -// const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision * -// std::min(fabs(expected_history), fabs(solver_history))); -// EXPECT_NEAR(expected_history, solver_history, error_margin_hist); - } - - // Test that the correct update is computed for a regularized least squares - // problem: - // - // E = (1/(2n)) || X w - y ||^2 + (lambda / 2) || w ||^2 - // \nabla_w E = (1/n) (X^T X w - X^T y) + lambda * w - // - // X \in R^{n x (d+1)} (each example is a row, (d+1)th element is always 1) - // w \in R^{(d+1) x 1} ((d+1)th element is the bias) - // y \in R^{n x 1} - // lambda is weight_decay - // - // TestLeastSquaresUpdate works "inductively", assuming that the solver - // correctly updates the net K (= iter_to_check) times, then given the history - // from the Kth update, we compute the (K+1)th update and check that it - // matches the solver's (K+1)th update. - void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0, - const Dtype weight_decay = 0.0, const Dtype momentum = 0.0, - const int iter_to_check = 0) { - // Initialize the solver and run K (= iter_to_check) solver iterations. - RunLeastSquaresSolver(learning_rate, weight_decay, momentum, iter_to_check); - - // Compute the (K+1)th update using the analytic least squares gradient. - vector > > updated_params; - ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum, - &updated_params); - - // Reinitialize the solver and run K+1 solver iterations. - RunLeastSquaresSolver(learning_rate, weight_decay, momentum, - iter_to_check + 1); - - // Check that the solver's solution matches ours. 
- CheckLeastSquaresUpdate(updated_params); - } -}; - -TYPED_TEST_CASE(NesterovSolverTest, TestDtypesAndDevices); - -TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdate) { - typedef typename TypeParam::Dtype Dtype; - this->TestLeastSquaresUpdate(); -} - -TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateLROneTenth) { - typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 0.1; - this->TestLeastSquaresUpdate(kLearningRate); -} - -TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithWeightDecay) { - typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 1.0; - const Dtype kWeightDecay = 0.5; - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay); -} - -TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithMomentum) { - typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 1.0; - const Dtype kWeightDecay = 0.0; - const Dtype kMomentum = 0.5; - const int kNumIters = 1; - for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); - } -} - -TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) { - typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 1.0; - const Dtype kWeightDecay = 0.0; - const Dtype kMomentum = 0.5; - const int kNumIters = 5; - for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); - } -} - -TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithEverything) { - typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 0.01; - const Dtype kWeightDecay = 0.1; - const Dtype kMomentum = 0.9; - const int kNumIters = 5; - for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); - } -} - -} // namespace caffe From c1ff97c947af1a6cea0d9f47b397e87b74334d87 Mon Sep 17 00:00:00 2001 From: qipeng Date: Mon, 25 Aug 2014 23:02:56 -0700 Subject: [PATCH 15/23] Added sanity check for AdaGradSolver; added MNIST examples for solvers --- .../mnist_autoencoder_solver_adagrad.prototxt | 15 +++++++++++++++ .../mnist_autoencoder_solver_nesterov.prototxt | 15 +++++++++++++++ examples/mnist/train_mnist_autoencoder_adagrad.sh | 4 ++++ .../mnist/train_mnist_autoencoder_nesterov.sh | 4 ++++ include/caffe/solver.hpp | 7 +++++-- 5 files changed, 43 insertions(+), 2 deletions(-) create mode 100644 examples/mnist/mnist_autoencoder_solver_adagrad.prototxt create mode 100644 examples/mnist/mnist_autoencoder_solver_nesterov.prototxt create mode 100755 examples/mnist/train_mnist_autoencoder_adagrad.sh create mode 100755 examples/mnist/train_mnist_autoencoder_nesterov.sh diff --git a/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt b/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt new file mode 100644 index 00000000000..6193351fcd2 --- /dev/null +++ b/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt @@ -0,0 +1,15 @@ +net: "mnist_autoencoder.prototxt" +test_iter: 50 +test_interval: 100 +test_compute_loss: true +base_lr: 0.01 +lr_policy: "fixed" +display: 20 +max_iter: 4000000 +weight_decay: 0.0005 +snapshot: 10000 +snapshot_prefix: "mnist_autoencoder_train" +momentum: 0.9 +# solver mode: CPU or GPU +solver_mode: GPU +solver_type: ADAGRAD diff --git a/examples/mnist/mnist_autoencoder_solver_nesterov.prototxt b/examples/mnist/mnist_autoencoder_solver_nesterov.prototxt new file mode 100644 index 00000000000..17487301554 --- /dev/null +++ 
b/examples/mnist/mnist_autoencoder_solver_nesterov.prototxt @@ -0,0 +1,15 @@ +net: "mnist_autoencoder.prototxt" +test_iter: 50 +test_interval: 100 +test_compute_loss: true +base_lr: 0.0001 +lr_policy: "fixed" +display: 20 +max_iter: 4000000 +weight_decay: 0.0005 +snapshot: 10000 +snapshot_prefix: "mnist_autoencoder_train" +momentum: 0.95 +# solver mode: CPU or GPU +solver_mode: GPU +solver_type: NESTEROV diff --git a/examples/mnist/train_mnist_autoencoder_adagrad.sh b/examples/mnist/train_mnist_autoencoder_adagrad.sh new file mode 100755 index 00000000000..628c74b969a --- /dev/null +++ b/examples/mnist/train_mnist_autoencoder_adagrad.sh @@ -0,0 +1,4 @@ +#!/bin/bash +TOOLS=../../build/tools + +$TOOLS/caffe.bin train --solver=mnist_autoencoder_solver.prototxt diff --git a/examples/mnist/train_mnist_autoencoder_nesterov.sh b/examples/mnist/train_mnist_autoencoder_nesterov.sh new file mode 100755 index 00000000000..8f004c4635d --- /dev/null +++ b/examples/mnist/train_mnist_autoencoder_nesterov.sh @@ -0,0 +1,4 @@ +#!/bin/bash +TOOLS=../../build/tools + +$TOOLS/caffe.bin train --solver=mnist_autoencoder_solver_nesterov.prototxt diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index b3c7f6f35d6..a73ba2ccd9f 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -100,12 +100,15 @@ template class AdaGradSolver : public SGDSolver { public: explicit AdaGradSolver(const SolverParameter& param) - : SGDSolver(param) {} + : SGDSolver(param) { constructor_sanity_check(); } explicit AdaGradSolver(const string& param_file) - : SGDSolver(param_file) {} + : SGDSolver(param_file) { constructor_sanity_check(); } protected: virtual void ComputeUpdateValue(); + void constructor_sanity_check() { + CHECK_EQ(0, this->param_.momentum()) << "Momentum cannot be used with AdaGrad."; + } DISABLE_COPY_AND_ASSIGN(AdaGradSolver); }; From 06f335f04b16321d73eeb667c39afe65798fef93 Mon Sep 17 00:00:00 2001 From: qipeng Date: Mon, 25 Aug 2014 23:17:51 -0700 Subject: [PATCH 16/23] lint --- include/caffe/solver.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index a73ba2ccd9f..5a7f31be017 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -107,7 +107,8 @@ class AdaGradSolver : public SGDSolver { protected: virtual void ComputeUpdateValue(); void constructor_sanity_check() { - CHECK_EQ(0, this->param_.momentum()) << "Momentum cannot be used with AdaGrad."; + CHECK_EQ(0, this->param_.momentum()) + << "Momentum cannot be used with AdaGrad."; } DISABLE_COPY_AND_ASSIGN(AdaGradSolver); From a464df45a782e2cd45412744de9c0abbd671df6a Mon Sep 17 00:00:00 2001 From: qipeng Date: Tue, 26 Aug 2014 12:01:26 -0700 Subject: [PATCH 17/23] Re-added solver switch into the new caffe main excutable; fixed AdaGrad MNIST example --- .../mnist/mnist_autoencoder_solver_adagrad.prototxt | 1 - examples/mnist/train_mnist_autoencoder_adagrad.sh | 2 +- tools/caffe.cpp | 10 +++++----- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt b/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt index 6193351fcd2..fa7d65cd70c 100644 --- a/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt +++ b/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt @@ -9,7 +9,6 @@ max_iter: 4000000 weight_decay: 0.0005 snapshot: 10000 snapshot_prefix: "mnist_autoencoder_train" -momentum: 0.9 # solver mode: CPU or GPU solver_mode: GPU solver_type: ADAGRAD diff --git 
a/examples/mnist/train_mnist_autoencoder_adagrad.sh b/examples/mnist/train_mnist_autoencoder_adagrad.sh index 628c74b969a..25a48c3c442 100755 --- a/examples/mnist/train_mnist_autoencoder_adagrad.sh +++ b/examples/mnist/train_mnist_autoencoder_adagrad.sh @@ -1,4 +1,4 @@ #!/bin/bash TOOLS=../../build/tools -$TOOLS/caffe.bin train --solver=mnist_autoencoder_solver.prototxt +$TOOLS/caffe.bin train --solver=mnist_autoencoder_solver_adagrad.prototxt diff --git a/tools/caffe.cpp b/tools/caffe.cpp index 5b3ad0b4691..9958ac3701a 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -88,7 +88,7 @@ int train() { caffe::ReadProtoFromTextFileOrDie(FLAGS_solver, &solver_param); LOG(INFO) << "Starting Optimization"; - caffe::SGDSolver solver(solver_param); + shared_ptr> solver(caffe::GetSolver(solver_param)); // Set device id and mode if (FLAGS_gpu >= 0) { @@ -102,13 +102,13 @@ int train() { if (FLAGS_snapshot.size()) { LOG(INFO) << "Resuming from " << FLAGS_snapshot; - solver.Solve(FLAGS_snapshot); + solver->Solve(FLAGS_snapshot); } else if (FLAGS_weights.size()) { LOG(INFO) << "Finetuning from " << FLAGS_weights; - solver.net()->CopyTrainedLayersFrom(FLAGS_weights); - solver.Solve(); + solver->net()->CopyTrainedLayersFrom(FLAGS_weights); + solver->Solve(); } else { - solver.Solve(); + solver->Solve(); } LOG(INFO) << "Optimization Done."; return 0; From 36f9de4303ddb0166d34a2104ec70abe0493d924 Mon Sep 17 00:00:00 2001 From: qipeng Date: Tue, 26 Aug 2014 12:02:20 -0700 Subject: [PATCH 18/23] lint --- tools/caffe.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/caffe.cpp b/tools/caffe.cpp index 9958ac3701a..9c49a0a9d15 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -88,7 +88,8 @@ int train() { caffe::ReadProtoFromTextFileOrDie(FLAGS_solver, &solver_param); LOG(INFO) << "Starting Optimization"; - shared_ptr> solver(caffe::GetSolver(solver_param)); + shared_ptr> + solver(caffe::GetSolver(solver_param)); // Set device id and mode if (FLAGS_gpu >= 0) { From 972264946cf8896689facfbec2f683347b1a8f7e Mon Sep 17 00:00:00 2001 From: qipeng Date: Tue, 26 Aug 2014 12:21:06 -0700 Subject: [PATCH 19/23] hot fix for warning --- tools/caffe.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/caffe.cpp b/tools/caffe.cpp index 9c49a0a9d15..10ba70d152e 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -88,7 +88,7 @@ int train() { caffe::ReadProtoFromTextFileOrDie(FLAGS_solver, &solver_param); LOG(INFO) << "Starting Optimization"; - shared_ptr> + shared_ptr > solver(caffe::GetSolver(solver_param)); // Set device id and mode From 5894f039a454c249ba65141b9c20f9edc9b71627 Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Mon, 1 Sep 2014 11:39:39 -0700 Subject: [PATCH 20/23] mnist_autoencoder: always compute both cross-entropy loss and L2 (euclidean) error --- examples/mnist/mnist_autoencoder.prototxt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/mnist/mnist_autoencoder.prototxt b/examples/mnist/mnist_autoencoder.prototxt index b6009fb50a0..b2bce47b008 100644 --- a/examples/mnist/mnist_autoencoder.prototxt +++ b/examples/mnist/mnist_autoencoder.prototxt @@ -248,22 +248,22 @@ layers { layers { bottom: "decode1" bottom: "flatdata" + top: "cross_entropy_loss" name: "loss" type: SIGMOID_CROSS_ENTROPY_LOSS - include: { phase: TRAIN } + loss_weight: 1 } layers { bottom: "decode1" top: "decode1neuron" name: "decode1neuron" type: SIGMOID - include: { phase: TEST } } layers { bottom: "decode1neuron" bottom: "flatdata" + top: "l2_error" 
name: "loss" type: EUCLIDEAN_LOSS - top: "loss" - include: { phase: TEST } + loss_weight: 0 } From b49b2d3e8e0f5ebe453404c79ac9528f426ab2f0 Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Mon, 1 Sep 2014 12:04:31 -0700 Subject: [PATCH 21/23] Add "test-on-train" stage to test accuracy on the training data; correct test_iter (should be 100 instead of 50) --- examples/mnist/mnist_autoencoder.prototxt | 22 ++++++++++++++++++- .../mnist/mnist_autoencoder_solver.prototxt | 7 ++++-- .../mnist_autoencoder_solver_adagrad.prototxt | 9 +++++--- ...mnist_autoencoder_solver_nesterov.prototxt | 11 ++++++---- 4 files changed, 39 insertions(+), 10 deletions(-) diff --git a/examples/mnist/mnist_autoencoder.prototxt b/examples/mnist/mnist_autoencoder.prototxt index b2bce47b008..45d0802339e 100644 --- a/examples/mnist/mnist_autoencoder.prototxt +++ b/examples/mnist/mnist_autoencoder.prototxt @@ -13,6 +13,23 @@ layers { } include: { phase: TRAIN } } +layers { + top: "data" + name: "data" + type: DATA + data_param { + source: "examples/mnist/mnist_train_lmdb" + backend: LMDB + batch_size: 100 + transform_param { + scale: 0.0039215684 + } + } + include: { + phase: TEST + stage: 'test-on-train' + } +} layers { top: "data" name: "data" @@ -25,7 +42,10 @@ layers { scale: 0.0039215684 } } - include: { phase: TEST } + include: { + phase: TEST + stage: 'test-on-test' + } } layers { bottom: "data" diff --git a/examples/mnist/mnist_autoencoder_solver.prototxt b/examples/mnist/mnist_autoencoder_solver.prototxt index af1202fc1fd..be0939d92db 100644 --- a/examples/mnist/mnist_autoencoder_solver.prototxt +++ b/examples/mnist/mnist_autoencoder_solver.prototxt @@ -1,6 +1,9 @@ net: "examples/mnist/mnist_autoencoder.prototxt" -test_iter: 50 -test_interval: 100 +test_state: { stage: 'test-on-train' } +test_iter: 500 +test_state: { stage: 'test-on-test' } +test_iter: 100 +test_interval: 500 test_compute_loss: true base_lr: 0.0001 lr_policy: "fixed" diff --git a/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt b/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt index fa7d65cd70c..641ce8a001c 100644 --- a/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt +++ b/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt @@ -1,6 +1,9 @@ -net: "mnist_autoencoder.prototxt" -test_iter: 50 -test_interval: 100 +net: "examples/mnist/mnist_autoencoder.prototxt" +test_state: { stage: 'test-on-train' } +test_iter: 500 +test_state: { stage: 'test-on-test' } +test_iter: 100 +test_interval: 500 test_compute_loss: true base_lr: 0.01 lr_policy: "fixed" diff --git a/examples/mnist/mnist_autoencoder_solver_nesterov.prototxt b/examples/mnist/mnist_autoencoder_solver_nesterov.prototxt index 17487301554..254dceecf9b 100644 --- a/examples/mnist/mnist_autoencoder_solver_nesterov.prototxt +++ b/examples/mnist/mnist_autoencoder_solver_nesterov.prototxt @@ -1,6 +1,9 @@ -net: "mnist_autoencoder.prototxt" -test_iter: 50 -test_interval: 100 +net: "examples/mnist/mnist_autoencoder.prototxt" +test_state: { stage: 'test-on-train' } +test_iter: 500 +test_state: { stage: 'test-on-test' } +test_iter: 100 +test_interval: 500 test_compute_loss: true base_lr: 0.0001 lr_policy: "fixed" @@ -8,7 +11,7 @@ display: 20 max_iter: 4000000 weight_decay: 0.0005 snapshot: 10000 -snapshot_prefix: "mnist_autoencoder_train" +snapshot_prefix: "examples/mnist/mnist_autoencoder_nesterov_train" momentum: 0.95 # solver mode: CPU or GPU solver_mode: GPU From eaf28fe5758164c4af73f9a721e41ee7952488b0 Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Mon, 1 Sep 2014 
12:32:54 -0700 Subject: [PATCH 22/23] make adagrad/nesterov train scripts follow new "run-from-root" convention --- examples/mnist/train_mnist_autoencoder_adagrad.sh | 4 ++-- examples/mnist/train_mnist_autoencoder_nesterov.sh | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/mnist/train_mnist_autoencoder_adagrad.sh b/examples/mnist/train_mnist_autoencoder_adagrad.sh index 25a48c3c442..95fe1b17bd5 100755 --- a/examples/mnist/train_mnist_autoencoder_adagrad.sh +++ b/examples/mnist/train_mnist_autoencoder_adagrad.sh @@ -1,4 +1,4 @@ #!/bin/bash -TOOLS=../../build/tools -$TOOLS/caffe.bin train --solver=mnist_autoencoder_solver_adagrad.prototxt +./build/tools/caffe train \ + --solver=examples/mnist/mnist_autoencoder_solver_adagrad.prototxt diff --git a/examples/mnist/train_mnist_autoencoder_nesterov.sh b/examples/mnist/train_mnist_autoencoder_nesterov.sh index 8f004c4635d..cf19ea749b3 100755 --- a/examples/mnist/train_mnist_autoencoder_nesterov.sh +++ b/examples/mnist/train_mnist_autoencoder_nesterov.sh @@ -1,4 +1,4 @@ #!/bin/bash -TOOLS=../../build/tools -$TOOLS/caffe.bin train --solver=mnist_autoencoder_solver_nesterov.prototxt +./build/tools/caffe train \ + --solver=examples/mnist/mnist_autoencoder_solver_nesterov.prototxt From b0f97fda761b8687503a36f4545bedd6497f3c92 Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Mon, 1 Sep 2014 12:35:03 -0700 Subject: [PATCH 23/23] make MNIST autoencoder solvers start from base_lr 0.01 and step (much better performance) and terminate at iter 65K --- examples/mnist/mnist_autoencoder_solver.prototxt | 10 ++++++---- .../mnist/mnist_autoencoder_solver_adagrad.prototxt | 4 ++-- .../mnist/mnist_autoencoder_solver_nesterov.prototxt | 10 ++++++---- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/examples/mnist/mnist_autoencoder_solver.prototxt b/examples/mnist/mnist_autoencoder_solver.prototxt index be0939d92db..6e35cb6a32d 100644 --- a/examples/mnist/mnist_autoencoder_solver.prototxt +++ b/examples/mnist/mnist_autoencoder_solver.prototxt @@ -5,10 +5,12 @@ test_state: { stage: 'test-on-test' } test_iter: 100 test_interval: 500 test_compute_loss: true -base_lr: 0.0001 -lr_policy: "fixed" -display: 20 -max_iter: 4000000 +base_lr: 0.01 +lr_policy: "step" +gamma: 0.1 +stepsize: 10000 +display: 100 +max_iter: 65000 weight_decay: 0.0005 snapshot: 10000 snapshot_prefix: "examples/mnist/mnist_autoencoder" diff --git a/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt b/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt index 641ce8a001c..b18a0cf5f27 100644 --- a/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt +++ b/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt @@ -7,8 +7,8 @@ test_interval: 500 test_compute_loss: true base_lr: 0.01 lr_policy: "fixed" -display: 20 -max_iter: 4000000 +display: 100 +max_iter: 65000 weight_decay: 0.0005 snapshot: 10000 snapshot_prefix: "mnist_autoencoder_train" diff --git a/examples/mnist/mnist_autoencoder_solver_nesterov.prototxt b/examples/mnist/mnist_autoencoder_solver_nesterov.prototxt index 254dceecf9b..2a59fd45c8d 100644 --- a/examples/mnist/mnist_autoencoder_solver_nesterov.prototxt +++ b/examples/mnist/mnist_autoencoder_solver_nesterov.prototxt @@ -5,10 +5,12 @@ test_state: { stage: 'test-on-test' } test_iter: 100 test_interval: 500 test_compute_loss: true -base_lr: 0.0001 -lr_policy: "fixed" -display: 20 -max_iter: 4000000 +base_lr: 0.01 +lr_policy: "step" +gamma: 0.1 +stepsize: 10000 +display: 100 +max_iter: 65000 weight_decay: 0.0005 snapshot: 10000 
snapshot_prefix: "examples/mnist/mnist_autoencoder_nesterov_train"
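To round out the series, a hedged usage sketch of the solver switching these prototxts depend on, essentially what the reworked tools/caffe.cpp already does; the prototxt path is just an example and the exact include set is an assumption:

    #include <boost/shared_ptr.hpp>

    #include "caffe/proto/caffe.pb.h"
    #include "caffe/solver.hpp"
    #include "caffe/util/io.hpp"

    // Minimal driver: read one of the solver prototxts above and let the
    // GetSolver factory pick SGDSolver, NesterovSolver, or AdaGradSolver
    // based on its solver_type field.
    int main() {
      caffe::SolverParameter solver_param;
      caffe::ReadProtoFromTextFileOrDie(
          "examples/mnist/mnist_autoencoder_solver_adagrad.prototxt",
          &solver_param);
      boost::shared_ptr<caffe::Solver<float> > solver(
          caffe::GetSolver<float>(solver_param));
      solver->Solve();
      return 0;
    }

Setting solver_type: NESTEROV or ADAGRAD in the prototxt is the only change needed to swap optimizers; the AdaGradSolver constructor additionally checks that momentum is 0.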