Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MSR weight filler #1883

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions include/caffe/filler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,27 @@ class XavierFiller : public Filler<Dtype> {
}
};

/**
* @brief Fills a Blob with values @f$ x \sim N(-a, +a) @f$ where @f$ a @f$ is
* set by the number of incoming nodes, based on the paper [He,
* Zhang, Ren and Sun 2015]
*/
template <typename Dtype>
class MSRFiller : public Filler<Dtype> {
public:
explicit MSRFiller(const FillerParameter& param)
: Filler<Dtype>(param) {}
virtual void Fill(Blob<Dtype>* blob) {
CHECK(blob->count());
int fan_in = blob->count() / blob->num();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my understanding, they use number of output channels instead of input in order to avoid increasing/decreasing variances of gradients though backward pass, don't they? In this case, it should be the following.

fan_in = blob->count() / blobs->channels()

I am sorry if my understanding is not correct.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, the current version implements the forward propagation case, which is equation (10) in the paper. If you used the fan_out

fan_out = blob->count() / blobs->channels()

that would implement the backward propagation case in equation (14). They say at the end of that section that "We note that it is sufficient to use either Eqn.(14) or Eqn.(10) alone." and that " For all models in this paper, both forms can make them converge."

I don't know which is better. The current Caffe Xavier implementation only considers the fan_in, so this PR follows that lead.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nickcarlevaris I havn't read those. It might be better to name it like MSRForwardFiller? Anyways, this PR should be helpful since we no longer need to set the filler variances by hand. Thanks!

Dtype std = sqrt(Dtype(2) / fan_in);
caffe_rng_gaussian<Dtype>(blob->count(), Dtype(0), std,
blob->mutable_cpu_data());
CHECK_EQ(this->filler_param_.sparse(), -1)
<< "Sparsity not supported by this Filler.";
}
};


/**
* @brief Get a specific filler from the specification given in FillerParameter.
Expand All @@ -177,6 +198,8 @@ Filler<Dtype>* GetFiller(const FillerParameter& param) {
return new UniformFiller<Dtype>(param);
} else if (type == "xavier") {
return new XavierFiller<Dtype>(param);
} else if (type == "msr") {
return new MSRFiller<Dtype>(param);
} else {
CHECK(false) << "Unknown filler name: " << param.type();
}
Expand Down
36 changes: 36 additions & 0 deletions src/caffe/test/test_filler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,40 @@ TYPED_TEST(GaussianFillerTest, TestFill) {
EXPECT_LE(var, target_var * 5.);
}

template <typename Dtype>
class MSRFillerTest : public ::testing::Test {
protected:
MSRFillerTest()
: blob_(new Blob<Dtype>(1000, 3, 4, 5)),
filler_param_() {
filler_.reset(new MSRFiller<Dtype>(filler_param_));
filler_->Fill(blob_);
}
virtual ~MSRFillerTest() { delete blob_; }
Blob<Dtype>* const blob_;
FillerParameter filler_param_;
shared_ptr<MSRFiller<Dtype> > filler_;
};

TYPED_TEST_CASE(MSRFillerTest, TestDtypes);

TYPED_TEST(MSRFillerTest, TestFill) {
EXPECT_TRUE(this->blob_);
const int count = this->blob_->count();
const TypeParam* data = this->blob_->cpu_data();
TypeParam mean = 0.;
TypeParam ex2 = 0.;
for (int i = 0; i < count; ++i) {
mean += data[i];
ex2 += data[i] * data[i];
}
mean /= count;
ex2 /= count;
TypeParam std = sqrt(ex2 - mean*mean);
int fan_in = 3*4*5;
TypeParam target_std = sqrt(2.0 / fan_in);
EXPECT_NEAR(mean, 0.0, 0.1);
EXPECT_NEAR(std, target_std, 0.1);
}

} // namespace caffe