Skip to content

Commit

Permalink
fix-gpu-test
Browse files Browse the repository at this point in the history
  • Loading branch information
ronghanghu committed Aug 15, 2015
1 parent 65c7fa6 commit 1cb3c44
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions src/caffe/test/test_gradient_based_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,19 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
const int iter_size = 1, const int devices = 1,
const bool snapshot = false, const char* from_snapshot = NULL) {
ostringstream proto;
int device_id = 0;
#ifndef CPU_ONLY
if (Caffe::mode() == Caffe::GPU) {
CUDA_CHECK(cudaGetDevice(&device_id));
}
#endif
proto <<
"snapshot_after_train: " << snapshot << " "
"max_iter: " << num_iters << " "
"base_lr: " << learning_rate << " "
"lr_policy: 'fixed' "
"iter_size: " << iter_size << " "
"device_id: " << device_id << " "
"net_param { "
" name: 'TestNetwork' "
" layer { "
Expand Down Expand Up @@ -188,9 +195,12 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
this->solver_->Solve();
} else {
LOG(INFO) << "Multi-GPU test on " << devices << " devices";
int device_id = solver_->param().device_id();
vector<int> gpus;
for (int i = 0; i < devices; ++i) {
gpus.push_back(i);
gpus.push_back(device_id);
for (int i = 0; gpus.size() < devices; ++i) {
if (i != device_id)
gpus.push_back(i);
}
Caffe::set_solver_count(gpus.size());
this->sync_.reset(new P2PSync<Dtype>(
Expand Down

0 comments on commit 1cb3c44

Please sign in to comment.