Skip to content

Commit

Permalink
Merge pull request #2037 from shelhamer/expose-solver-restore
Browse files Browse the repository at this point in the history
Expose Solver::Restore() as public for restoring without solving
  • Loading branch information
shelhamer committed Mar 7, 2015
2 parents 65d84a5 + c0219cc commit 4d0103b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
8 changes: 4 additions & 4 deletions include/caffe/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ class Solver {
virtual void Solve(const char* resume_file = NULL);
inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
void Step(int iters);
// The Restore function implements how one should restore the solver to a
// previously snapshotted state. You should implement the RestoreSolverState()
// function that restores the state from a SolverState protocol buffer.
void Restore(const char* resume_file);
virtual ~Solver() {}
inline shared_ptr<Net<Dtype> > net() { return net_; }
inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
Expand All @@ -46,10 +50,6 @@ class Solver {
void TestAll();
void Test(const int test_net_id = 0);
virtual void SnapshotSolverState(SolverState* state) = 0;
// The Restore function implements how one should restore the solver to a
// previously snapshotted state. You should implement the RestoreSolverState()
// function that restores the state from a SolverState protocol buffer.
void Restore(const char* resume_file);
virtual void RestoreSolverState(const SolverState& state) = 0;
void DisplayOutputBlobs(const int net_id);

Expand Down
3 changes: 2 additions & 1 deletion python/caffe/_caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,8 @@ BOOST_PYTHON_MODULE(_caffe) {
.add_property("iter", &Solver<Dtype>::iter)
.def("solve", static_cast<void (Solver<Dtype>::*)(const char*)>(
&Solver<Dtype>::Solve), SolveOverloads())
.def("step", &Solver<Dtype>::Step);
.def("step", &Solver<Dtype>::Step)
.def("restore", &Solver<Dtype>::Restore);

bp::class_<SGDSolver<Dtype>, bp::bases<Solver<Dtype> >,
shared_ptr<SGDSolver<Dtype> >, boost::noncopyable>(
Expand Down

0 comments on commit 4d0103b

Please sign in to comment.