Merge pull request #3942 from jeffdonahue/propagate-down-true
Set propagate_down=true to force backprop to a particular bottom
jeffdonahue committed Apr 5, 2016
2 parents 389db96 + 77cde9c commit 843575e
Showing 3 changed files with 112 additions and 6 deletions.
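
To see what this merge enables before reading the diffs: an explicit propagate_down entry now forces backpropagation to the corresponding bottom even where Caffe's automatic analysis would skip it. An illustrative prototxt fragment (ours, not from the commit; names are hypothetical):

    # Force backprop to 'data', which Caffe would otherwise skip because
    # nothing upstream of it has learnable parameters.
    layer {
      name: "innerproduct"
      type: "InnerProduct"
      bottom: "data"
      top: "innerproduct"
      propagate_down: true  # one entry per bottom; true forces the gradient
      inner_product_param { num_output: 1 }
    }
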
9 changes: 4 additions & 5 deletions src/caffe/net.cpp
@@ -427,12 +427,11 @@ int Net<Dtype>::AppendBottom(const NetParameter& param, const int layer_id,
   bottom_vecs_[layer_id].push_back(blobs_[blob_id].get());
   bottom_id_vecs_[layer_id].push_back(blob_id);
   available_blobs->erase(blob_name);
-  bool propagate_down = true;
+  bool need_backward = blob_need_backward_[blob_id];
   // Check if the backpropagation on bottom_id should be skipped
-  if (layer_param.propagate_down_size() > 0)
-    propagate_down = layer_param.propagate_down(bottom_id);
-  const bool need_backward = blob_need_backward_[blob_id] &&
-      propagate_down;
+  if (layer_param.propagate_down_size() > 0) {
+    need_backward = layer_param.propagate_down(bottom_id);
+  }
   bottom_need_backward_[layer_id].push_back(need_backward);
   return blob_id;
 }
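
Note the semantic change: the old code ANDed an explicit propagate_down with the inferred blob_need_backward_, so the flag could only disable backprop; the new code lets an explicit entry override the inference in either direction, so true now forces backprop. One practical use, sketched below under assumptions (Forward(), Backward(), and blob_by_name() are standard Net methods; 'force_prop.prototxt' is a hypothetical model file whose first layer sets propagate_down: true on its 'data' bottom, like the test net later in this commit):

    #include <caffe/caffe.hpp>

    // Hedged sketch, not part of this commit: read the gradient with
    // respect to an input blob that propagate_down was forced on.
    int main() {
      caffe::Net<float> net("force_prop.prototxt", caffe::TRAIN);
      net.Forward();   // nets fed by data layers need no external input
      net.Backward();  // forced propagate_down populates the input's diff
      const float* data_grad = net.blob_by_name("data")->cpu_diff();
      (void) data_grad;  // inspect or save the gradient w.r.t. the input
      return 0;
    }
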
7 changes: 6 additions & 1 deletion src/caffe/proto/caffe.proto
@@ -328,7 +328,12 @@ message LayerParameter {
   // The blobs containing the numeric parameters of the layer.
   repeated BlobProto blobs = 7;
 
-  // Specifies on which bottoms the backpropagation should be skipped.
+  // Specifies whether to backpropagate to each bottom. If unspecified,
+  // Caffe will automatically infer whether each input needs backpropagation
+  // to compute parameter gradients. If set to true for some inputs,
+  // backpropagation to those inputs is forced; if set false for some inputs,
+  // backpropagation to those inputs is skipped.
+  //
   // The size must be either 0 or equal to the number of bottoms.
   repeated bool propagate_down = 11;
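
Concretely, a repeated bool takes one entry per bottom in prototxt text format, and the list must either be empty or match the bottom count. A hypothetical fragment (not part of the commit) mixing both directions on a two-bottom loss layer:

    # Hypothetical two-bottom loss layer; blob names are ours.
    layer {
      name: "loss"
      type: "SigmoidCrossEntropyLoss"
      bottom: "pred"
      bottom: "label"
      top: "loss"
      propagate_down: true   # keep backprop to 'pred'
      propagate_down: false  # skip backprop to 'label'
    }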

102 changes: 102 additions & 0 deletions src/caffe/test/test_net.cpp
@@ -716,6 +716,61 @@ class NetTest : public MultiDeviceTest<TypeParam> {
     InitNetFromProtoString(proto);
   }
 
+  virtual void InitForcePropNet(bool test_force_true) {
+    string proto =
+        "name: 'ForcePropTestNetwork' "
+        "layer { "
+        "  name: 'data' "
+        "  type: 'DummyData' "
+        "  dummy_data_param { "
+        "    shape { "
+        "      dim: 5 "
+        "      dim: 2 "
+        "      dim: 3 "
+        "      dim: 4 "
+        "    } "
+        "    data_filler { "
+        "      type: 'gaussian' "
+        "      std: 0.01 "
+        "    } "
+        "    shape { "
+        "      dim: 5 "
+        "    } "
+        "    data_filler { "
+        "      type: 'constant' "
+        "      value: 0 "
+        "    } "
+        "  } "
+        "  top: 'data' "
+        "  top: 'label' "
+        "} "
+        "layer { "
+        "  name: 'innerproduct' "
+        "  type: 'InnerProduct' "
+        "  inner_product_param { "
+        "    num_output: 1 "
+        "    weight_filler { "
+        "      type: 'gaussian' "
+        "      std: 0.01 "
+        "    } "
+        "  } "
+        "  bottom: 'data' "
+        "  top: 'innerproduct' ";
+    if (test_force_true) {
+      proto += "  propagate_down: true ";
+    }
+    proto +=
+        "} "
+        "layer { "
+        "  name: 'loss' "
+        "  bottom: 'innerproduct' "
+        "  bottom: 'label' "
+        "  top: 'cross_entropy_loss' "
+        "  type: 'SigmoidCrossEntropyLoss' "
+        "} ";
+    InitNetFromProtoString(proto);
+  }
+
   int seed_;
   shared_ptr<Net<Dtype> > net_;
 };
@@ -2371,4 +2426,51 @@ TYPED_TEST(NetTest, TestSkipPropagateDown) {
   }
 }

+TYPED_TEST(NetTest, TestForcePropagateDown) {
+  this->InitForcePropNet(false);
+  vector<bool> layer_need_backward = this->net_->layer_need_backward();
+  for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
+    const string& layer_name = this->net_->layer_names()[layer_id];
+    const vector<bool> need_backward =
+        this->net_->bottom_need_backward()[layer_id];
+    if (layer_name == "data") {
+      ASSERT_EQ(need_backward.size(), 0);
+      EXPECT_FALSE(layer_need_backward[layer_id]);
+    } else if (layer_name == "innerproduct") {
+      ASSERT_EQ(need_backward.size(), 1);
+      EXPECT_FALSE(need_backward[0]);  // data
+      EXPECT_TRUE(layer_need_backward[layer_id]);
+    } else if (layer_name == "loss") {
+      ASSERT_EQ(need_backward.size(), 2);
+      EXPECT_TRUE(need_backward[0]);  // innerproduct
+      EXPECT_FALSE(need_backward[1]);  // label
+      EXPECT_TRUE(layer_need_backward[layer_id]);
+    } else {
+      LOG(FATAL) << "Unknown layer: " << layer_name;
+    }
+  }
+  this->InitForcePropNet(true);
+  layer_need_backward = this->net_->layer_need_backward();
+  for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
+    const string& layer_name = this->net_->layer_names()[layer_id];
+    const vector<bool> need_backward =
+        this->net_->bottom_need_backward()[layer_id];
+    if (layer_name == "data") {
+      ASSERT_EQ(need_backward.size(), 0);
+      EXPECT_FALSE(layer_need_backward[layer_id]);
+    } else if (layer_name == "innerproduct") {
+      ASSERT_EQ(need_backward.size(), 1);
+      EXPECT_TRUE(need_backward[0]);  // data
+      EXPECT_TRUE(layer_need_backward[layer_id]);
+    } else if (layer_name == "loss") {
+      ASSERT_EQ(need_backward.size(), 2);
+      EXPECT_TRUE(need_backward[0]);  // innerproduct
+      EXPECT_FALSE(need_backward[1]);  // label
+      EXPECT_TRUE(layer_need_backward[layer_id]);
+    } else {
+      LOG(FATAL) << "Unknown layer: " << layer_name;
+    }
+  }
+}

} // namespace caffe
