From 4ecea59b7fe215876961450b272f23099eb9e73c Mon Sep 17 00:00:00 2001 From: Nick Carlevaris-Bianco Date: Mon, 16 Feb 2015 15:49:43 +1030 Subject: [PATCH 1/2] Added MSR weight filler, which implements Xavier-like filler designed for use with ReLUs instead of tanh. Based on paper: He et al, "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification," 2015 --- include/caffe/filler.hpp | 23 ++++++++++++++++++++++ src/caffe/test/test_filler.cpp | 36 ++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/include/caffe/filler.hpp b/include/caffe/filler.hpp index eebf565b1d5..9bcf0d813d2 100644 --- a/include/caffe/filler.hpp +++ b/include/caffe/filler.hpp @@ -157,6 +157,27 @@ class XavierFiller : public Filler { } }; +/** + * @brief Fills a Blob with values @f$ x \sim N(0, \sigma^2) @f$ where + * @f$ \sigma^2 = 2 / n @f$ and @f$ n @f$ is the number of incoming nodes, + * based on the paper [He, Zhang, Ren and Sun 2015] + */ +template +class MSRFiller : public Filler { + public: + explicit MSRFiller(const FillerParameter& param) + : Filler(param) {} + virtual void Fill(Blob* blob) { + CHECK(blob->count()); + int fan_in = blob->count() / blob->num(); + Dtype std = sqrt(Dtype(2) / fan_in); + caffe_rng_gaussian(blob->count(), Dtype(0), std, + blob->mutable_cpu_data()); + CHECK_EQ(this->filler_param_.sparse(), -1) + << "Sparsity not supported by this Filler."; + } +}; + /** * @brief Get a specific filler from the specification given in FillerParameter. 
@@ -177,6 +198,8 @@ Filler* GetFiller(const FillerParameter& param) { return new UniformFiller(param); } else if (type == "xavier") { return new XavierFiller(param); + } else if (type == "msr") { + return new MSRFiller(param); } else { CHECK(false) << "Unknown filler name: " << param.type(); } diff --git a/src/caffe/test/test_filler.cpp b/src/caffe/test/test_filler.cpp index e04b0fd22af..90478a0aaeb 100644 --- a/src/caffe/test/test_filler.cpp +++ b/src/caffe/test/test_filler.cpp @@ -142,4 +142,40 @@ TYPED_TEST(GaussianFillerTest, TestFill) { EXPECT_LE(var, target_var * 5.); } +template +class MSRFillerTest : public ::testing::Test { + protected: + MSRFillerTest() + : blob_(new Blob(1000, 3, 4, 5)), + filler_param_() { + filler_.reset(new MSRFiller(filler_param_)); + filler_->Fill(blob_); + } + virtual ~MSRFillerTest() { delete blob_; } + Blob* const blob_; + FillerParameter filler_param_; + shared_ptr > filler_; +}; + +TYPED_TEST_CASE(MSRFillerTest, TestDtypes); + +TYPED_TEST(MSRFillerTest, TestFill) { + EXPECT_TRUE(this->blob_); + const int count = this->blob_->count(); + const TypeParam* data = this->blob_->cpu_data(); + TypeParam mean = 0.; + TypeParam ex2 = 0.; + for (int i = 0; i < count; ++i) { + mean += data[i]; + ex2 += data[i] * data[i]; + } + mean /= count; + ex2 /= count; + TypeParam std = sqrt(ex2 - mean*mean); + int fan_in = 3*4*5; + TypeParam target_std = sqrt(2.0 / fan_in); + EXPECT_NEAR(mean, 0.0, 0.1); + EXPECT_NEAR(std, target_std, 0.1); +} + } // namespace caffe From e2b4ae51b763424361a6448a19e7412da8b4f6a7 Mon Sep 17 00:00:00 2001 From: Nick Carlevaris-Bianco Date: Tue, 17 Feb 2015 13:23:06 +1030 Subject: [PATCH 2/2] Added fan_in and fan_out filler parameters. 
--- include/caffe/filler.hpp | 11 ++++++- src/caffe/proto/caffe.proto | 4 +++ src/caffe/test/test_filler.cpp | 53 +++++++++++++++++++++------------- 3 files changed, 47 insertions(+), 21 deletions(-) diff --git a/include/caffe/filler.hpp b/include/caffe/filler.hpp index 9bcf0d813d2..03f89504e97 100644 --- a/include/caffe/filler.hpp +++ b/include/caffe/filler.hpp @@ -169,8 +169,17 @@ class MSRFiller : public Filler { : Filler(param) {} virtual void Fill(Blob* blob) { CHECK(blob->count()); + CHECK(this->filler_param_.fan_in() || this->filler_param_.fan_out()) + << "MSR Filler requires either fan_in, fan_out, or both to be true."; int fan_in = blob->count() / blob->num(); - Dtype std = sqrt(Dtype(2) / fan_in); + int fan_out = blob->count() / blob->channels(); + Dtype n = fan_in; // default to fan_in + if (this->filler_param_.fan_in() && this->filler_param_.fan_out()) { + n = (fan_in + fan_out) / Dtype(2); + } else if (this->filler_param_.fan_out()) { + n = fan_out; + } + Dtype std = sqrt(Dtype(2) / n); caffe_rng_gaussian(blob->count(), Dtype(0), std, blob->mutable_cpu_data()); CHECK_EQ(this->filler_param_.sparse(), -1) diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 8d937420eba..15186a1aa4a 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -41,6 +41,10 @@ message FillerParameter { // The expected number of non-zero output weights for a given input in // Gaussian filler -- the default -1 means don't perform sparsification. 
optional int32 sparse = 7 [default = -1]; + // for the msr filler we can consider the fan_in size, the fan_out size, or + // both (by averaging) + optional bool fan_in = 8 [default = true]; + optional bool fan_out = 9 [default = false]; } message NetParameter { diff --git a/src/caffe/test/test_filler.cpp b/src/caffe/test/test_filler.cpp index 90478a0aaeb..7c2b043e663 100644 --- a/src/caffe/test/test_filler.cpp +++ b/src/caffe/test/test_filler.cpp @@ -146,10 +146,29 @@ template class MSRFillerTest : public ::testing::Test { protected: MSRFillerTest() - : blob_(new Blob(1000, 3, 4, 5)), + : blob_(new Blob(1000, 2, 4, 5)), filler_param_() { - filler_.reset(new MSRFiller(filler_param_)); - filler_->Fill(blob_); + } + virtual void test_params(bool fan_in, bool fan_out, Dtype n) { + this->filler_param_.set_fan_in(fan_in); + this->filler_param_.set_fan_out(fan_out); + this->filler_.reset(new MSRFiller(this->filler_param_)); + this->filler_->Fill(blob_); + EXPECT_TRUE(this->blob_); + const int count = this->blob_->count(); + const Dtype* data = this->blob_->cpu_data(); + Dtype mean = 0.; + Dtype ex2 = 0.; + for (int i = 0; i < count; ++i) { + mean += data[i]; + ex2 += data[i] * data[i]; + } + mean /= count; + ex2 /= count; + Dtype std = sqrt(ex2 - mean*mean); + Dtype target_std = sqrt(2.0 / n); + EXPECT_NEAR(mean, 0.0, 0.1); + EXPECT_NEAR(std, target_std, 0.1); } virtual ~MSRFillerTest() { delete blob_; } Blob* const blob_; @@ -159,23 +178,17 @@ class MSRFillerTest : public ::testing::Test { TYPED_TEST_CASE(MSRFillerTest, TestDtypes); -TYPED_TEST(MSRFillerTest, TestFill) { - EXPECT_TRUE(this->blob_); - const int count = this->blob_->count(); - const TypeParam* data = this->blob_->cpu_data(); - TypeParam mean = 0.; - TypeParam ex2 = 0.; - for (int i = 0; i < count; ++i) { - mean += data[i]; - ex2 += data[i] * data[i]; - } - mean /= count; - ex2 /= count; - TypeParam std = sqrt(ex2 - mean*mean); - int fan_in = 3*4*5; - TypeParam target_std = sqrt(2.0 / fan_in); - 
EXPECT_NEAR(mean, 0.0, 0.1); - EXPECT_NEAR(std, target_std, 0.1); +TYPED_TEST(MSRFillerTest, TestFillFanIn) { + TypeParam n = 2*4*5; + this->test_params(true, false, n); +} +TYPED_TEST(MSRFillerTest, TestFillFanOut) { + TypeParam n = 1000*4*5; + this->test_params(false, true, n); +} +TYPED_TEST(MSRFillerTest, TestFillFanInFanOut) { + TypeParam n = (2*4*5 + 1000*4*5) / 2.0; + this->test_params(true, true, n); } } // namespace caffe