Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cross-channel LRN bounds checking for GPU implementation #1922

Merged
merged 2 commits into from
Mar 18, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 16 additions & 23 deletions src/caffe/layers/lrn_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,24 @@ __global__ void LRNFillScale(const int nthreads, const Dtype* in,
Dtype accum_scale = 0;
// fill the scale at [n, :, h, w]
// accumulate values
while (head < post_pad) {
while (head < post_pad && head < channels) {
accum_scale += in[head * step] * in[head * step];
++head;
}
// until we reach size, nothing needs to be subtracted
while (head < size) {
accum_scale += in[head * step] * in[head * step];
scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size;
++head;
}
// both add and subtract
while (head < channels) {
accum_scale += in[head * step] * in[head * step];
accum_scale -= in[(head - size) * step] * in[(head - size) * step];
if (head - size >= 0) {
accum_scale -= in[(head - size) * step] * in[(head - size) * step];
}
scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size;
++head;
}
// subtract only
while (head < channels + post_pad) {
accum_scale -= in[(head - size) * step] * in[(head - size) * step];
if (head - size >= 0) {
accum_scale -= in[(head - size) * step] * in[(head - size) * step];
}
scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size;
++head;
}
Expand Down Expand Up @@ -143,35 +141,30 @@ __global__ void LRNComputeDiff(const int nthreads, const Dtype* bottom_data,
int post_pad = size - pre_pad - 1;
Dtype accum_ratio = 0;
// accumulate values
while (head < post_pad) {
while (head < post_pad && head < channels) {
accum_ratio += top_diff[head * step] * top_data[head * step] /
scale[head * step];
++head;
}
// until we reach size, nothing needs to be subtracted
while (head < size) {
accum_ratio += top_diff[head * step] * top_data[head * step] /
scale[head * step];
bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step]
* pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
bottom_data[(head - post_pad) * step] * accum_ratio;
++head;
}
// both add and subtract
while (head < channels) {
accum_ratio += top_diff[head * step] * top_data[head * step] /
scale[head * step];
accum_ratio -= top_diff[(head - size) * step] *
top_data[(head - size) * step] / scale[(head - size) * step];
if (head - size >= 0) {
accum_ratio -= top_diff[(head - size) * step] *
top_data[(head - size) * step] / scale[(head - size) * step];
}
bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step]
* pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
bottom_data[(head - post_pad) * step] * accum_ratio;
++head;
}
// subtract only
while (head < channels + post_pad) {
accum_ratio -= top_diff[(head - size) * step] *
top_data[(head - size) * step] / scale[(head - size) * step];
if (head - size >= 0) {
accum_ratio -= top_diff[(head - size) * step] *
top_data[(head - size) * step] / scale[(head - size) * step];
}
bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step]
* pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
bottom_data[(head - post_pad) * step] * accum_ratio;
Expand Down
38 changes: 38 additions & 0 deletions src/caffe/test/test_lrn_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,22 @@ TYPED_TEST(LRNLayerTest, TestForwardAcrossChannels) {
}
}

TYPED_TEST(LRNLayerTest, TestForwardAcrossChannelsLargeRegion) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
layer_param.mutable_lrn_param()->set_local_size(15);
LRNLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
Blob<Dtype> top_reference;
this->ReferenceLRNForward(*(this->blob_bottom_), layer_param,
&top_reference);
for (int i = 0; i < this->blob_bottom_->count(); ++i) {
EXPECT_NEAR(this->blob_top_->cpu_data()[i], top_reference.cpu_data()[i],
this->epsilon_);
}
}

TYPED_TEST(LRNLayerTest, TestGradientAcrossChannels) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
Expand All @@ -159,6 +175,28 @@ TYPED_TEST(LRNLayerTest, TestGradientAcrossChannels) {
this->blob_top_vec_);
}

TYPED_TEST(LRNLayerTest, TestGradientAcrossChannelsLargeRegion) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
layer_param.mutable_lrn_param()->set_local_size(15);
LRNLayer<Dtype> layer(layer_param);
GradientChecker<Dtype> checker(1e-2, 1e-2);
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
for (int i = 0; i < this->blob_top_->count(); ++i) {
this->blob_top_->mutable_cpu_diff()[i] = 1.;
}
vector<bool> propagate_down(this->blob_bottom_vec_.size(), true);
layer.Backward(this->blob_top_vec_, propagate_down,
this->blob_bottom_vec_);
// for (int i = 0; i < this->blob_bottom_->count(); ++i) {
// std::cout << "CPU diff " << this->blob_bottom_->cpu_diff()[i]
// << std::endl;
// }
checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
this->blob_top_vec_);
}

TYPED_TEST(LRNLayerTest, TestSetupWithinChannel) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
Expand Down