Skip to content

Commit

Permalink
Merge branch 'qipeng-solvers' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffdonahue committed Sep 1, 2014
2 parents 5f350aa + b0f97fd commit 88e1797
Show file tree
Hide file tree
Showing 12 changed files with 647 additions and 61 deletions.
30 changes: 25 additions & 5 deletions examples/mnist/mnist_autoencoder.prototxt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,23 @@ layers {
}
include: { phase: TRAIN }
}
# Training-set data layer used during TEST under the 'test-on-train' stage,
# so the solver can report loss on the training data as well.
layers {
  top: "data"
  name: "data"
  type: DATA
  data_param {
    source: "examples/mnist/mnist_train_lmdb"
    backend: LMDB
    batch_size: 100
  }
  # FIX: transform_param is a layer-level field (TransformationParameter),
  # not a member of data_param; it was previously nested inside data_param,
  # which does not parse against caffe.proto.
  transform_param {
    scale: 0.0039215684  # ~1/255: scale byte pixels into [0, 1]
  }
  include: {
    phase: TEST
    stage: 'test-on-train'
  }
}
layers {
top: "data"
name: "data"
Expand All @@ -25,7 +42,10 @@ layers {
scale: 0.0039215684
}
}
include: { phase: TEST }
include: {
phase: TEST
stage: 'test-on-test'
}
}
layers {
bottom: "data"
Expand Down Expand Up @@ -248,22 +268,22 @@ layers {
layers {
bottom: "decode1"
bottom: "flatdata"
top: "cross_entropy_loss"
name: "loss"
type: SIGMOID_CROSS_ENTROPY_LOSS
include: { phase: TRAIN }
loss_weight: 1
}
layers {
bottom: "decode1"
top: "decode1neuron"
name: "decode1neuron"
type: SIGMOID
include: { phase: TEST }
}
layers {
bottom: "decode1neuron"
bottom: "flatdata"
top: "l2_error"
name: "loss"
type: EUCLIDEAN_LOSS
top: "loss"
include: { phase: TEST }
loss_weight: 0
}
17 changes: 11 additions & 6 deletions examples/mnist/mnist_autoencoder_solver.prototxt
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
net: "examples/mnist/mnist_autoencoder.prototxt"
test_iter: 50
test_interval: 100
test_state: { stage: 'test-on-train' }
test_iter: 500
test_state: { stage: 'test-on-test' }
test_iter: 100
test_interval: 500
test_compute_loss: true
base_lr: 0.0001
lr_policy: "fixed"
display: 20
max_iter: 4000000
base_lr: 0.01
lr_policy: "step"
gamma: 0.1
stepsize: 10000
display: 100
max_iter: 65000
weight_decay: 0.0005
snapshot: 10000
snapshot_prefix: "examples/mnist/mnist_autoencoder"
Expand Down
17 changes: 17 additions & 0 deletions examples/mnist/mnist_autoencoder_solver_adagrad.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
net: "examples/mnist/mnist_autoencoder.prototxt"
# First test net: evaluate on the training data (net stage 'test-on-train').
test_state: { stage: 'test-on-train' }
test_iter: 500
# Second test net: evaluate on the held-out test data (stage 'test-on-test').
test_state: { stage: 'test-on-test' }
test_iter: 100
test_interval: 500
test_compute_loss: true
base_lr: 0.01
lr_policy: "fixed"
display: 100
max_iter: 65000
weight_decay: 0.0005
snapshot: 10000
# FIX: write snapshots under examples/mnist/ like the other autoencoder
# solvers; the bare prefix "mnist_autoencoder_train" dropped them in the
# current working directory.
snapshot_prefix: "examples/mnist/mnist_autoencoder_adagrad_train"
# solver mode: CPU or GPU
solver_mode: GPU
solver_type: ADAGRAD
20 changes: 20 additions & 0 deletions examples/mnist/mnist_autoencoder_solver_nesterov.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
net: "examples/mnist/mnist_autoencoder.prototxt"
# First test net: evaluate on the training data (net stage 'test-on-train').
test_state: { stage: 'test-on-train' }
test_iter: 500
# Second test net: evaluate on the held-out test data (stage 'test-on-test').
test_state: { stage: 'test-on-test' }
test_iter: 100
test_interval: 500
# Also report the total loss averaged over the test iterations.
test_compute_loss: true
base_lr: 0.01
# Step policy: multiply the learning rate by gamma every stepsize iterations.
lr_policy: "step"
gamma: 0.1
stepsize: 10000
display: 100
max_iter: 65000
weight_decay: 0.0005
snapshot: 10000
snapshot_prefix: "examples/mnist/mnist_autoencoder_nesterov_train"
momentum: 0.95
# solver mode: CPU or GPU
solver_mode: GPU
solver_type: NESTEROV
4 changes: 4 additions & 0 deletions examples/mnist/train_mnist_autoencoder_adagrad.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/bin/bash
# Train the MNIST autoencoder with the AdaGrad solver configuration.
./build/tools/caffe train --solver=examples/mnist/mnist_autoencoder_solver_adagrad.prototxt
4 changes: 4 additions & 0 deletions examples/mnist/train_mnist_autoencoder_nesterov.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/bin/bash
# Train the MNIST autoencoder with the Nesterov solver configuration.
./build/tools/caffe train --solver=examples/mnist/mnist_autoencoder_solver_nesterov.prototxt
55 changes: 54 additions & 1 deletion include/caffe/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,18 +65,71 @@ class SGDSolver : public Solver<Dtype> {
explicit SGDSolver(const string& param_file)
: Solver<Dtype>(param_file) {}

const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }

protected:
virtual void PreSolve();
Dtype GetLearningRate();
virtual void ComputeUpdateValue();
virtual void SnapshotSolverState(SolverState * state);
virtual void RestoreSolverState(const SolverState& state);
// history maintains the historical momentum data.
vector<shared_ptr<Blob<Dtype> > > history_;
// update maintains update related data and is not needed in snapshots.
// temp maintains other information that might be needed in computation
// of gradients/updates and is not needed in snapshots
vector<shared_ptr<Blob<Dtype> > > history_, update_, temp_;

DISABLE_COPY_AND_ASSIGN(SGDSolver);
};

// SGD variant that overrides only ComputeUpdateValue(); by its name and
// solver_type mapping this implements Nesterov's accelerated gradient —
// see the definition in solver.cpp for the actual update rule.
template <typename Dtype>
class NesterovSolver : public SGDSolver<Dtype> {
public:
// Construct from an in-memory SolverParameter.
explicit NesterovSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) {}
// Construct by parsing a solver prototxt file.
explicit NesterovSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) {}

protected:
// Computes the parameter updates; defined out of line (solver.cpp).
virtual void ComputeUpdateValue();

DISABLE_COPY_AND_ASSIGN(NesterovSolver);
};

// SGD variant that overrides only ComputeUpdateValue(); by its name and
// solver_type mapping this implements the AdaGrad update — see solver.cpp.
// Both constructors reject configurations that set momentum, since the
// update rule here does not use it.
template <typename Dtype>
class AdaGradSolver : public SGDSolver<Dtype> {
public:
// Construct from an in-memory SolverParameter.
explicit AdaGradSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) { constructor_sanity_check(); }
// Construct by parsing a solver prototxt file.
explicit AdaGradSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }

protected:
// Computes the parameter updates; defined out of line (solver.cpp).
virtual void ComputeUpdateValue();
// Fails fast (CHECK_EQ aborts) if the solver config sets a nonzero momentum.
void constructor_sanity_check() {
CHECK_EQ(0, this->param_.momentum())
<< "Momentum cannot be used with AdaGrad.";
}

DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
};

// Factory: instantiate the solver subclass selected by param.solver_type().
// The caller takes ownership of the returned heap-allocated solver.
// Aborts (LOG(FATAL)) on an unrecognized solver type.
template <typename Dtype>
Solver<Dtype>* GetSolver(const SolverParameter& param) {
  SolverParameter_SolverType type = param.solver_type();

  switch (type) {
  case SolverParameter_SolverType_SGD:
      return new SGDSolver<Dtype>(param);
  case SolverParameter_SolverType_NESTEROV:
      return new NesterovSolver<Dtype>(param);
  case SolverParameter_SolverType_ADAGRAD:
      return new AdaGradSolver<Dtype>(param);
  default:
      LOG(FATAL) << "Unknown SolverType: " << type;
  }
  // Not reached: LOG(FATAL) aborts. Kept so compilers that cannot see that
  // do not warn about a missing return value. NULL converts to
  // Solver<Dtype>* implicitly; the previous C-style cast was redundant.
  return NULL;
}

} // namespace caffe

Expand Down
15 changes: 14 additions & 1 deletion src/caffe/proto/caffe.proto
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ message NetParameter {
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
// SolverParameter next available ID: 27 (last added: test_state)
// SolverParameter next available ID: 32 (last added: delta)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
Expand Down Expand Up @@ -116,6 +116,9 @@ message SolverParameter {
optional float power = 10; // The parameter to compute the learning rate.
optional float momentum = 11; // The momentum value.
optional float weight_decay = 12; // The weight decay.
// regularization types supported: L1 and L2
// controlled by weight_decay
optional string regularization_type = 29 [default = "L2"];
optional int32 stepsize = 13; // the stepsize for learning rate policy "step"
optional int32 snapshot = 14 [default = 0]; // The snapshot interval
optional string snapshot_prefix = 15; // The prefix for the snapshot.
Expand All @@ -135,6 +138,16 @@ message SolverParameter {
// (and by default) initialize using a seed derived from the system clock.
optional int64 random_seed = 20 [default = -1];

// Solver type
enum SolverType {
SGD = 0;
NESTEROV = 1;
ADAGRAD = 2;
}
optional SolverType solver_type = 30 [default = SGD];
// numerical stability for AdaGrad
optional float delta = 31 [default = 1e-8];

// If true, print information about the state of the net that may help with
// debugging learning problems.
optional bool debug_info = 23 [default = false];
Expand Down
Loading

0 comments on commit 88e1797

Please sign in to comment.