diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index 23d94c97c07..f0bf594936c 100644
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
@@ -427,12 +427,11 @@ int Net<Dtype>::AppendBottom(const NetParameter& param, const int layer_id,
   bottom_vecs_[layer_id].push_back(blobs_[blob_id].get());
   bottom_id_vecs_[layer_id].push_back(blob_id);
   available_blobs->erase(blob_name);
-  bool propagate_down = true;
+  bool need_backward = blob_need_backward_[blob_id];
   // Check if the backpropagation on bottom_id should be skipped
-  if (layer_param.propagate_down_size() > 0)
-    propagate_down = layer_param.propagate_down(bottom_id);
-  const bool need_backward = blob_need_backward_[blob_id] &&
-      propagate_down;
+  if (layer_param.propagate_down_size() > 0) {
+    need_backward = layer_param.propagate_down(bottom_id);
+  }
   bottom_need_backward_[layer_id].push_back(need_backward);
   return blob_id;
 }
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 6900bb71482..650c87ae3a6 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -328,7 +328,12 @@ message LayerParameter {
   // The blobs containing the numeric parameters of the layer.
   repeated BlobProto blobs = 7;
 
-  // Specifies on which bottoms the backpropagation should be skipped.
+  // Specifies whether to backpropagate to each bottom. If unspecified,
+  // Caffe will automatically infer whether each input needs backpropagation
+  // to compute parameter gradients. If set to true for some inputs,
+  // backpropagation to those inputs is forced; if set false for some inputs,
+  // backpropagation to those inputs is skipped.
+  //
   // The size must be either 0 or equal to the number of bottoms.
   repeated bool propagate_down = 11;
 
diff --git a/src/caffe/test/test_net.cpp b/src/caffe/test/test_net.cpp
index 1e0788ec127..92fd317fee8 100644
--- a/src/caffe/test/test_net.cpp
+++ b/src/caffe/test/test_net.cpp
@@ -716,6 +716,61 @@ class NetTest : public MultiDeviceTest<TypeParam> {
     InitNetFromProtoString(proto);
   }
 
+  virtual void InitForcePropNet(bool test_force_true) {
+    string proto =
+        "name: 'ForcePropTestNetwork' "
+        "layer { "
+        "  name: 'data' "
+        "  type: 'DummyData' "
+        "  dummy_data_param { "
+        "    shape { "
+        "      dim: 5 "
+        "      dim: 2 "
+        "      dim: 3 "
+        "      dim: 4 "
+        "    } "
+        "    data_filler { "
+        "      type: 'gaussian' "
+        "      std: 0.01 "
+        "    } "
+        "    shape { "
+        "      dim: 5 "
+        "    } "
+        "    data_filler { "
+        "      type: 'constant' "
+        "      value: 0 "
+        "    } "
+        "  } "
+        "  top: 'data' "
+        "  top: 'label' "
+        "} "
+        "layer { "
+        "  name: 'innerproduct' "
+        "  type: 'InnerProduct' "
+        "  inner_product_param { "
+        "    num_output: 1 "
+        "    weight_filler { "
+        "      type: 'gaussian' "
+        "      std: 0.01 "
+        "    } "
+        "  } "
+        "  bottom: 'data' "
+        "  top: 'innerproduct' ";
+    if (test_force_true) {
+      proto += "  propagate_down: true ";
+    }
+    proto +=
+        "} "
+        "layer { "
+        "  name: 'loss' "
+        "  bottom: 'innerproduct' "
+        "  bottom: 'label' "
+        "  top: 'cross_entropy_loss' "
+        "  type: 'SigmoidCrossEntropyLoss' "
+        "} ";
+    InitNetFromProtoString(proto);
+  }
+
   int seed_;
   shared_ptr<Net<Dtype> > net_;
 };
@@ -2371,4 +2426,51 @@ TYPED_TEST(NetTest, TestSkipPropagateDown) {
   }
 }
 
+TYPED_TEST(NetTest, TestForcePropagateDown) {
+  this->InitForcePropNet(false);
+  vector<bool> layer_need_backward = this->net_->layer_need_backward();
+  for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
+    const string& layer_name = this->net_->layer_names()[layer_id];
+    const vector<bool> need_backward =
+        this->net_->bottom_need_backward()[layer_id];
+    if (layer_name == "data") {
+      ASSERT_EQ(need_backward.size(), 0);
+      EXPECT_FALSE(layer_need_backward[layer_id]);
+    } else if (layer_name == "innerproduct") {
+      ASSERT_EQ(need_backward.size(), 1);
+      EXPECT_FALSE(need_backward[0]);  // data
+      EXPECT_TRUE(layer_need_backward[layer_id]);
+    } else if (layer_name == "loss") {
+      ASSERT_EQ(need_backward.size(), 2);
+      EXPECT_TRUE(need_backward[0]);  // innerproduct
+      EXPECT_FALSE(need_backward[1]);  // label
+      EXPECT_TRUE(layer_need_backward[layer_id]);
+    } else {
+      LOG(FATAL) << "Unknown layer: " << layer_name;
+    }
+  }
+  this->InitForcePropNet(true);
+  layer_need_backward = this->net_->layer_need_backward();
+  for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
+    const string& layer_name = this->net_->layer_names()[layer_id];
+    const vector<bool> need_backward =
+        this->net_->bottom_need_backward()[layer_id];
+    if (layer_name == "data") {
+      ASSERT_EQ(need_backward.size(), 0);
+      EXPECT_FALSE(layer_need_backward[layer_id]);
+    } else if (layer_name == "innerproduct") {
+      ASSERT_EQ(need_backward.size(), 1);
+      EXPECT_TRUE(need_backward[0]);  // data
+      EXPECT_TRUE(layer_need_backward[layer_id]);
+    } else if (layer_name == "loss") {
+      ASSERT_EQ(need_backward.size(), 2);
+      EXPECT_TRUE(need_backward[0]);  // innerproduct
+      EXPECT_FALSE(need_backward[1]);  // label
+      EXPECT_TRUE(layer_need_backward[layer_id]);
+    } else {
+      LOG(FATAL) << "Unknown layer: " << layer_name;
+    }
+  }
+}
+
 }  // namespace caffe
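
For illustration only (not part of the patch): under the new semantics, a prototxt layer definition can force or skip backpropagation per bottom. A minimal sketch modeled on the test network above; layer and blob names are illustrative:

  layer {
    name: "innerproduct"
    type: "InnerProduct"
    bottom: "data"
    top: "innerproduct"
    inner_product_param { num_output: 1 }
    # Force gradients to flow back into 'data' even though nothing below
    # it has learnable parameters. Omit propagate_down to let Caffe infer
    # the need automatically, or set it to false to skip this bottom.
    propagate_down: true
  }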