Skip to content

Commit

Permalink
Merge pull request #1456 from jyegerlehner/load-weights-from-multiple…
Browse files Browse the repository at this point in the history
…-caffemodels

Load weights from multiple models by listing comma separated caffemodels
as the `-weights` arg to the caffe command.
  • Loading branch information
shelhamer committed Mar 8, 2015
2 parents c3aee35 + a0087e4 commit a9bf7b9
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions tools/caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <string>
#include <vector>

#include "boost/algorithm/string.hpp"
#include "caffe/caffe.hpp"

using caffe::Blob;
Expand Down Expand Up @@ -76,6 +77,19 @@ int device_query() {
}
RegisterBrewFunction(device_query);

// Load the weights from the specified caffemodel(s) into the train and
// test nets.
void CopyLayers(caffe::Solver<float>* solver, const std::string& model_list) {
std::vector<std::string> model_names;
boost::split(model_names, model_list, boost::is_any_of(",") );
for (int i = 0; i < model_names.size(); ++i) {
LOG(INFO) << "Finetuning from " << model_names[i];
solver->net()->CopyTrainedLayersFrom(model_names[i]);
for (int j = 0; j < solver->test_nets().size(); ++j) {
solver->test_nets()[j]->CopyTrainedLayersFrom(model_names[i]);
}
}
}

// Train / Finetune a model.
int train() {
Expand Down Expand Up @@ -112,8 +126,7 @@ int train() {
LOG(INFO) << "Resuming from " << FLAGS_snapshot;
solver->Solve(FLAGS_snapshot);
} else if (FLAGS_weights.size()) {
LOG(INFO) << "Finetuning from " << FLAGS_weights;
solver->net()->CopyTrainedLayersFrom(FLAGS_weights);
CopyLayers(&*solver, FLAGS_weights);
solver->Solve();
} else {
solver->Solve();
Expand Down

0 comments on commit a9bf7b9

Please sign in to comment.