From d8dcb1dfce772cdb3d6b2063f1e434b2b4fd0d23 Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Wed, 2 Sep 2015 18:24:55 -0700 Subject: [PATCH 1/7] Add ScalarLayer to multiply two Blobs, broadcasting the shape of the second as needed --- include/caffe/layers/scalar_layer.hpp | 70 +++++++ src/caffe/layers/scalar_layer.cpp | 124 +++++++++++++ src/caffe/layers/scalar_layer.cu | 84 +++++++++ src/caffe/proto/caffe.proto | 20 +- src/caffe/test/test_scalar_layer.cpp | 255 ++++++++++++++++++++++++++ 5 files changed, 552 insertions(+), 1 deletion(-) create mode 100644 include/caffe/layers/scalar_layer.hpp create mode 100644 src/caffe/layers/scalar_layer.cpp create mode 100644 src/caffe/layers/scalar_layer.cu create mode 100644 src/caffe/test/test_scalar_layer.cpp diff --git a/include/caffe/layers/scalar_layer.hpp b/include/caffe/layers/scalar_layer.hpp new file mode 100644 index 00000000000..b57677c640d --- /dev/null +++ b/include/caffe/layers/scalar_layer.hpp @@ -0,0 +1,70 @@ +#ifndef CAFFE_INNER_PRODUCT_LAYER_HPP_ +#define CAFFE_INNER_PRODUCT_LAYER_HPP_ + +#include + +#include "caffe/blob.hpp" +#include "caffe/layer.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +/** + * @brief Computes a product of two input Blobs, with the shape of the + * latter Blob "broadcast" to match the shape of the former. + * Equivalent to tiling the latter Blob, then computing the elementwise + * product. 
+ */ +template +class ScalarLayer: public Layer { + public: + explicit ScalarLayer(const LayerParameter& param) + : Layer(param) {} + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline const char* type() const { return "Scalar"; } + // Scalar + virtual inline int MinBottomBlobs() const { return 1; } + virtual inline int MaxBottomBlobs() const { return 2; } + virtual inline int ExactNumTopBlobs() const { return 1; } + + protected: + /** + * In the below shape specifications, @f$ i @f$ denotes the value of the + * `axis` field given by `this->layer_param_.scalar_param().axis()`, after + * canonicalization (i.e., conversion from negative to positive index, + * if applicable). + * + * @param bottom input Blob vector (length 2) + * -# @f$ (d_0 \times ... \times + * d_i \times ... \times d_j \times ... \times d_n) @f$ + * the first factor @f$ x @f$ + * -# @f$ (d_i \times ... \times d_j) @f$ + * the second factor @f$ y @f$ + * @param top output Blob vector (length 1) + * -# @f$ (d_0 \times ... \times + * d_i \times ... \times d_j \times ... \times d_n) @f$ + * the product @f$ z = x y @f$ computed after "broadcasting" y. + * Equivalent to tiling @f$ y @f$ to have the same shape as @f$ x @f$, + * then computing the elementwise product. 
+ */ + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + Blob sum_multiplier_; + Blob sum_result_; + int axis_; + int outer_dim_, scalar_dim_, inner_dim_; +}; + + +} // namespace caffe + +#endif // CAFFE_INNER_PRODUCT_LAYER_HPP_ diff --git a/src/caffe/layers/scalar_layer.cpp b/src/caffe/layers/scalar_layer.cpp new file mode 100644 index 00000000000..578faf9dc25 --- /dev/null +++ b/src/caffe/layers/scalar_layer.cpp @@ -0,0 +1,124 @@ +#include +#include + +#include "caffe/layers/scalar_layer.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +template +void ScalarLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + // TODO: make ScalarLayer usable in-place. + // Currently, in-place computation is broken during Backward with + // propagate_down[0] && propagate_down[1], as bottom[0]'s diff is used for + // temporary storage of an intermediate result, overwriting top[0]'s diff + // if using in-place computation. + CHECK_NE(bottom[0], top[0]) << "ScalarLayer cannot be used in-place"; + // Always set axis_ == 0 in special case where scalar is an actual scalar + // (num_axes == 0). Mathematically equivalent for any choice of axis_, so the + // actual setting can be safely ignored; and computation is most efficient + // with axis_ == 0 and (therefore) outer_dim_ == 1. (Setting axis_ to + // bottom[0]->num_axes() - 1, giving inner_dim_ == 1, would be equally + // performant.) + const ScalarParameter& param = this->layer_param_.scalar_param(); + axis_ = (bottom[1]->num_axes() == 0) ? 
+ 0 : bottom[0]->CanonicalAxisIndex(param.axis()); + CHECK_GE(bottom[0]->num_axes(), axis_ + bottom[1]->num_axes()) + << "bottom[1]'s shape extends past bottom[0]'s shape when applied " + << "starting with bottom[0] axis = " << axis_; + for (int i = 0; i < bottom[1]->num_axes(); ++i) { + CHECK_EQ(bottom[0]->shape(axis_ + i), bottom[1]->shape(i)) + << "dimension mismatch between bottom[0]->shape(" << axis_ + i + << ") and bottom[1]->shape(" << i << ")"; + } + outer_dim_ = bottom[0]->count(0, axis_); + scalar_dim_ = bottom[1]->count(); + inner_dim_ = bottom[0]->count(axis_ + bottom[1]->num_axes()); + top[0]->ReshapeLike(*bottom[0]); + sum_result_.Reshape(vector(1, outer_dim_ * scalar_dim_)); + const int sum_mult_size = std::max(outer_dim_, inner_dim_); + sum_multiplier_.Reshape(vector(1, sum_mult_size)); + if (sum_multiplier_.cpu_data()[sum_mult_size - 1] != Dtype(1)) { + caffe_set(sum_mult_size, Dtype(1), sum_multiplier_.mutable_cpu_data()); + } +} + +template +void ScalarLayer::Forward_cpu( + const vector*>& bottom, const vector*>& top) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + const Dtype* scalar_data = bottom[1]->cpu_data(); + Dtype* top_data = top[0]->mutable_cpu_data(); + for (int n = 0; n < outer_dim_; ++n) { + for (int d = 0; d < scalar_dim_; ++d) { + const Dtype factor = scalar_data[d]; + caffe_cpu_scale(inner_dim_, factor, bottom_data, top_data); + bottom_data += inner_dim_; + top_data += inner_dim_; + } + } +} + +template +void ScalarLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + if (propagate_down[1]) { + const Dtype* top_diff = top[0]->cpu_diff(); + const Dtype* bottom_data = bottom[0]->cpu_data(); + // Hack: store big eltwise product in bottom[0] diff, except in the special + // case where this layer itself does the eltwise product, in which case we + // can store it directly in the scalar diff, and we're done. 
+ const bool is_eltwise = (bottom[0]->count() == bottom[1]->count()); + Dtype* product = (is_eltwise ? bottom[1] : bottom[0])->mutable_cpu_diff(); + caffe_mul(top[0]->count(), top_diff, bottom_data, product); + if (!is_eltwise) { + Dtype* sum_result = NULL; + if (inner_dim_ == 1) { + sum_result = product; + } else if (sum_result_.count() == 1) { + const Dtype* sum_mult = sum_multiplier_.cpu_data(); + Dtype* scalar_diff = bottom[1]->mutable_cpu_diff(); + *scalar_diff = caffe_cpu_dot(inner_dim_, product, sum_mult); + } else { + const Dtype* sum_mult = sum_multiplier_.cpu_data(); + sum_result = (outer_dim_ == 1) ? + bottom[1]->mutable_cpu_diff() : sum_result_.mutable_cpu_data(); + caffe_cpu_gemv(CblasNoTrans, sum_result_.count(), inner_dim_, + Dtype(1), product, sum_mult, Dtype(0), sum_result); + } + if (outer_dim_ != 1) { + const Dtype* sum_mult = sum_multiplier_.cpu_data(); + Dtype* scalar_diff = bottom[1]->mutable_cpu_diff(); + if (scalar_dim_ == 1) { + *scalar_diff = caffe_cpu_dot(outer_dim_, sum_mult, sum_result); + } else { + caffe_cpu_gemv(CblasTrans, outer_dim_, scalar_dim_, + Dtype(1), sum_result, sum_mult, Dtype(0), scalar_diff); + } + } + } + } + if (propagate_down[0]) { + const Dtype* top_diff = top[0]->cpu_diff(); + const Dtype* scalar_data = bottom[1]->cpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + for (int n = 0; n < outer_dim_; ++n) { + for (int d = 0; d < scalar_dim_; ++d) { + const Dtype factor = scalar_data[d]; + caffe_cpu_scale(inner_dim_, factor, top_diff, bottom_diff); + bottom_diff += inner_dim_; + top_diff += inner_dim_; + } + } + } +} + +#ifdef CPU_ONLY +STUB_GPU(ScalarLayer); +#endif + +INSTANTIATE_CLASS(ScalarLayer); +REGISTER_LAYER_CLASS(Scalar); + +} // namespace caffe diff --git a/src/caffe/layers/scalar_layer.cu b/src/caffe/layers/scalar_layer.cu new file mode 100644 index 00000000000..5cf0f9c7bec --- /dev/null +++ b/src/caffe/layers/scalar_layer.cu @@ -0,0 +1,84 @@ +#include +#include + +#include 
"caffe/layers/scalar_layer.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +template +__global__ void ScalarForward(const int n, const Dtype* in, + const Dtype* scalar, const int scalar_dim, const int inner_dim, + Dtype* out) { + CUDA_KERNEL_LOOP(index, n) { + const int scalar_index = (index / inner_dim) % scalar_dim; + out[index] = in[index] * scalar[scalar_index]; + } +} + +template +void ScalarLayer::Forward_gpu( + const vector*>& bottom, const vector*>& top) { + const int count = top[0]->count(); + const Dtype* bottom_data = bottom[0]->gpu_data(); + const Dtype* scalar_data = bottom[1]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + ScalarForward // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + count, bottom_data, scalar_data, scalar_dim_, inner_dim_, top_data); +} + +template +void ScalarLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + if (propagate_down[1]) { + const Dtype* top_diff = top[0]->gpu_diff(); + const Dtype* bottom_data = bottom[0]->gpu_data(); + // Hack: store big eltwise product in bottom[0] diff, except in the special + // case where this layer itself does the eltwise product, in which case we + // can store it directly in the scalar diff, and we're done. + const bool is_eltwise = (bottom[0]->count() == bottom[1]->count()); + Dtype* product = (is_eltwise ? bottom[1] : bottom[0])->mutable_gpu_diff(); + caffe_gpu_mul(top[0]->count(), top_diff, bottom_data, product); + if (!is_eltwise) { + Dtype* sum_result = NULL; + if (inner_dim_ == 1) { + sum_result = product; + } else if (sum_result_.count() == 1) { + const Dtype* sum_mult = sum_multiplier_.gpu_data(); + Dtype* scalar_diff = bottom[1]->mutable_cpu_diff(); + caffe_gpu_dot(inner_dim_, product, sum_mult, scalar_diff); + } else { + const Dtype* sum_mult = sum_multiplier_.gpu_data(); + sum_result = (outer_dim_ == 1) ? 
+ bottom[1]->mutable_gpu_diff() : sum_result_.mutable_gpu_data(); + caffe_gpu_gemv(CblasNoTrans, sum_result_.count(), inner_dim_, + Dtype(1), product, sum_mult, Dtype(0), sum_result); + } + if (outer_dim_ != 1) { + const Dtype* sum_mult = sum_multiplier_.gpu_data(); + if (scalar_dim_ == 1) { + Dtype* scalar_diff = bottom[1]->mutable_cpu_diff(); + caffe_gpu_dot(outer_dim_, sum_mult, sum_result, scalar_diff); + } else { + Dtype* scalar_diff = bottom[1]->mutable_gpu_diff(); + caffe_gpu_gemv(CblasTrans, outer_dim_, scalar_dim_, + Dtype(1), sum_result, sum_mult, Dtype(0), scalar_diff); + } + } + } + } + if (propagate_down[0]) { + const int count = top[0]->count(); + const Dtype* top_diff = top[0]->gpu_diff(); + const Dtype* scalar_data = bottom[1]->gpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + ScalarForward // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + count, top_diff, scalar_data, scalar_dim_, inner_dim_, bottom_diff); + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(ScalarLayer); + +} // namespace caffe diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 019aa614373..b62c3b9872c 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -306,7 +306,7 @@ message ParamSpec { // NOTE // Update the next available ID when you add a new LayerParameter field. 
// -// LayerParameter next available layer-specific ID: 140 (last added: batch_norm_param) +// LayerParameter next available layer-specific ID: 141 (last added: scalar_param) message LayerParameter { optional string name = 1; // the layer name optional string type = 2; // the layer type @@ -383,6 +383,7 @@ message LayerParameter { optional ReductionParameter reduction_param = 136; optional ReLUParameter relu_param = 123; optional ReshapeParameter reshape_param = 133; + optional ScalarParameter scalar_param = 140; optional SigmoidParameter sigmoid_param = 124; optional SoftmaxParameter softmax_param = 125; optional SPPParameter spp_param = 132; @@ -951,6 +952,23 @@ message ReshapeParameter { optional int32 num_axes = 3 [default = -1]; } +message ScalarParameter { + // The first axis of bottom[0] (the first input Blob) along which to apply + // bottom[1] (the second input Blob). May be negative to index from the end + // (e.g., -1 for the last axis). + // + // For example, if bottom[0] is 4D with shape 100x3x40x60, the output + // top[0] will have the same shape, and bottom[1] may have any of the + // following shapes (for the given value of axis): + // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 + // (axis == 1 == -3) 3; 3x40; 3x40x60 + // (axis == 2 == -2) 40; 40x60 + // (axis == 3 == -1) 60 + // Furthermore, bottom[1] may have the empty shape (regardless of the value of + // "axis") -- a literal scalar. 
+ optional int32 axis = 1 [default = 0]; +} + message SigmoidParameter { enum Engine { DEFAULT = 0; diff --git a/src/caffe/test/test_scalar_layer.cpp b/src/caffe/test/test_scalar_layer.cpp new file mode 100644 index 00000000000..f7bf63f2119 --- /dev/null +++ b/src/caffe/test/test_scalar_layer.cpp @@ -0,0 +1,255 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/layers/scalar_layer.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template +class ScalarLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + ScalarLayerTest() + : blob_bottom_(new Blob(2, 3, 4, 5)), + blob_bottom_eltwise_(new Blob(2, 3, 4, 5)), + blob_bottom_broadcast_0_(new Blob()), + blob_bottom_broadcast_1_(new Blob()), + blob_bottom_broadcast_2_(new Blob()), + blob_bottom_scalar_(new Blob(vector())), + blob_top_(new Blob()) { + Caffe::set_random_seed(1701); + vector broadcast_shape(2); + broadcast_shape[0] = 2; broadcast_shape[1] = 3; + this->blob_bottom_broadcast_0_->Reshape(broadcast_shape); + broadcast_shape[0] = 3; broadcast_shape[1] = 4; + this->blob_bottom_broadcast_1_->Reshape(broadcast_shape); + broadcast_shape[0] = 4; broadcast_shape[1] = 5; + this->blob_bottom_broadcast_2_->Reshape(broadcast_shape); + FillerParameter filler_param; + filler_param.set_min(1); + filler_param.set_max(10); + UniformFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + filler.Fill(this->blob_bottom_eltwise_); + filler.Fill(this->blob_bottom_broadcast_0_); + filler.Fill(this->blob_bottom_broadcast_1_); + filler.Fill(this->blob_bottom_broadcast_2_); + filler.Fill(this->blob_bottom_scalar_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~ScalarLayerTest() { + delete blob_bottom_; + delete blob_bottom_eltwise_; + delete 
blob_bottom_broadcast_0_; + delete blob_bottom_broadcast_1_; + delete blob_bottom_broadcast_2_; + delete blob_bottom_scalar_; + delete blob_top_; + } + Blob* const blob_bottom_; + Blob* const blob_bottom_eltwise_; + Blob* const blob_bottom_broadcast_0_; + Blob* const blob_bottom_broadcast_1_; + Blob* const blob_bottom_broadcast_2_; + Blob* const blob_bottom_scalar_; + Blob* const blob_top_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(ScalarLayerTest, TestDtypesAndDevices); + +TYPED_TEST(ScalarLayerTest, TestForwardEltwise) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_eltwise_); + LayerParameter layer_param; + shared_ptr > layer(new ScalarLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype* data = this->blob_top_->cpu_data(); + const int count = this->blob_top_->count(); + const Dtype* in_data_a = this->blob_bottom_->cpu_data(); + const Dtype* in_data_b = this->blob_bottom_eltwise_->cpu_data(); + for (int i = 0; i < count; ++i) { + EXPECT_NEAR(data[i], in_data_a[i] * in_data_b[i], 1e-5); + } +} + +TYPED_TEST(ScalarLayerTest, TestForwardBroadcastBegin) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_0_); + LayerParameter layer_param; + shared_ptr > layer(new ScalarLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int n = 0; n < this->blob_bottom_->num(); ++n) { + for (int c = 0; c < this->blob_bottom_->channels(); ++c) { + for (int h = 0; h < this->blob_bottom_->height(); ++h) { + for (int w = 0; w < this->blob_bottom_->width(); ++w) { + EXPECT_NEAR(this->blob_top_->data_at(n, c, h, 
w), + this->blob_bottom_->data_at(n, c, h, w) * + this->blob_bottom_broadcast_0_->data_at(n, c, 0, 0), + 1e-5); + } + } + } + } +} + +TYPED_TEST(ScalarLayerTest, TestForwardBroadcastMiddle) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_1_); + LayerParameter layer_param; + layer_param.mutable_scalar_param()->set_axis(1); + shared_ptr > layer(new ScalarLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int n = 0; n < this->blob_bottom_->num(); ++n) { + for (int c = 0; c < this->blob_bottom_->channels(); ++c) { + for (int h = 0; h < this->blob_bottom_->height(); ++h) { + for (int w = 0; w < this->blob_bottom_->width(); ++w) { + EXPECT_NEAR(this->blob_top_->data_at(n, c, h, w), + this->blob_bottom_->data_at(n, c, h, w) * + this->blob_bottom_broadcast_1_->data_at(c, h, 0, 0), + 1e-5); + } + } + } + } +} + +TYPED_TEST(ScalarLayerTest, TestForwardBroadcastEnd) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_2_); + LayerParameter layer_param; + layer_param.mutable_scalar_param()->set_axis(2); + shared_ptr > layer(new ScalarLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int n = 0; n < this->blob_bottom_->num(); ++n) { + for (int c = 0; c < this->blob_bottom_->channels(); ++c) { + for (int h = 0; h < this->blob_bottom_->height(); ++h) { + for (int w = 0; w < this->blob_bottom_->width(); ++w) { + EXPECT_NEAR(this->blob_top_->data_at(n, c, h, w), + this->blob_bottom_->data_at(n, c, h, w) * + this->blob_bottom_broadcast_2_->data_at(h, w, 0, 0), + 1e-5); + } + } + } + } +} + +TYPED_TEST(ScalarLayerTest, TestForwardScalar) { 
+ typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_scalar_); + LayerParameter layer_param; + shared_ptr > layer(new ScalarLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype* data = this->blob_top_->cpu_data(); + const int count = this->blob_top_->count(); + const Dtype* in_data = this->blob_bottom_->cpu_data(); + const Dtype scalar = *this->blob_bottom_scalar_->cpu_data(); + for (int i = 0; i < count; ++i) { + EXPECT_NEAR(data[i], in_data[i] * scalar, 1e-5); + } +} + +TYPED_TEST(ScalarLayerTest, TestForwardScalarAxis2) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_scalar_); + LayerParameter layer_param; + layer_param.mutable_scalar_param()->set_axis(2); + shared_ptr > layer(new ScalarLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype* data = this->blob_top_->cpu_data(); + const int count = this->blob_top_->count(); + const Dtype* in_data = this->blob_bottom_->cpu_data(); + const Dtype scalar = *this->blob_bottom_scalar_->cpu_data(); + for (int i = 0; i < count; ++i) { + EXPECT_NEAR(data[i], in_data[i] * scalar, 1e-5); + } +} + +TYPED_TEST(ScalarLayerTest, TestGradientEltwise) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_eltwise_); + LayerParameter layer_param; + ScalarLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(ScalarLayerTest, TestGradientBroadcastBegin) { + typedef typename TypeParam::Dtype Dtype; + 
this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_0_); + LayerParameter layer_param; + ScalarLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(ScalarLayerTest, TestGradientBroadcastMiddle) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_1_); + LayerParameter layer_param; + layer_param.mutable_scalar_param()->set_axis(1); + ScalarLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(ScalarLayerTest, TestGradientBroadcastEnd) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_2_); + LayerParameter layer_param; + layer_param.mutable_scalar_param()->set_axis(2); + ScalarLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(ScalarLayerTest, TestGradientScalar) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_scalar_); + LayerParameter layer_param; + ScalarLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(ScalarLayerTest, TestGradientScalarAxis2) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_scalar_); + LayerParameter layer_param; + layer_param.mutable_scalar_param()->set_axis(2); + ScalarLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +} // namespace caffe From d60389d840336395eb371394a9d3235a840b01c5 Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Wed, 23 Dec 
2015 12:06:13 -0800 Subject: [PATCH 2/7] ScalarLayer learns scalar as a parameter if only one bottom given --- include/caffe/layers/scalar_layer.hpp | 5 ++ src/caffe/layers/scalar_layer.cpp | 93 +++++++++++++++++++++------ src/caffe/layers/scalar_layer.cu | 41 ++++++++---- src/caffe/proto/caffe.proto | 15 +++++ src/caffe/test/test_scalar_layer.cpp | 71 ++++++++++++++++++++ 5 files changed, 193 insertions(+), 32 deletions(-) diff --git a/include/caffe/layers/scalar_layer.hpp b/include/caffe/layers/scalar_layer.hpp index b57677c640d..59882e4d5f6 100644 --- a/include/caffe/layers/scalar_layer.hpp +++ b/include/caffe/layers/scalar_layer.hpp @@ -14,12 +14,17 @@ namespace caffe { * latter Blob "broadcast" to match the shape of the former. * Equivalent to tiling the latter Blob, then computing the elementwise * product. + * + * The second input may be omitted, in which case it's learned as a parameter + * of the layer. */ template class ScalarLayer: public Layer { public: explicit ScalarLayer(const LayerParameter& param) : Layer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); virtual void Reshape(const vector*>& bottom, const vector*>& top); diff --git a/src/caffe/layers/scalar_layer.cpp b/src/caffe/layers/scalar_layer.cpp index 578faf9dc25..ef52b986908 100644 --- a/src/caffe/layers/scalar_layer.cpp +++ b/src/caffe/layers/scalar_layer.cpp @@ -1,11 +1,48 @@ #include #include +#include "caffe/filler.hpp" #include "caffe/layers/scalar_layer.hpp" #include "caffe/util/math_functions.hpp" namespace caffe { +template +void ScalarLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + if (bottom.size() == 1 && this->blobs_.size() > 0) { + LOG(INFO) << "Skipping parameter initialization"; + } else if (bottom.size() == 1) { + // scalar is a learned parameter; initialize it + const ScalarParameter& param = this->layer_param_.scalar_param(); + axis_ = bottom[0]->CanonicalAxisIndex(param.axis()); + const int num_axes = param.num_axes(); 
+ CHECK_GE(num_axes, -1) << "num_axes must be non-negative, " + << "or -1 to extend to the end of bottom[0]"; + if (num_axes >= 0) { + CHECK_GE(bottom[0]->num_axes(), axis_ + num_axes) + << "scalar blob's shape extends past bottom[0]'s shape when applied " + << "starting with bottom[0] axis = " << axis_; + } + this->blobs_.resize(1); + const vector::const_iterator& shape_start = + bottom[0]->shape().begin() + axis_; + const vector::const_iterator& shape_end = + (num_axes == -1) ? bottom[0]->shape().end() : (shape_start + num_axes); + vector scalar_shape(shape_start, shape_end); + this->blobs_[0].reset(new Blob(scalar_shape)); + FillerParameter filler_param(param.filler()); + if (!param.has_filler()) { + // Default to unit (1) filler for identity operation. + filler_param.set_type("constant"); + filler_param.set_value(1); + } + shared_ptr > filler(GetFiller(filler_param)); + filler->Fill(this->blobs_[0].get()); + } + this->param_propagate_down_.resize(this->blobs_.size(), true); +} + template void ScalarLayer::Reshape(const vector*>& bottom, const vector*>& top) { @@ -15,26 +52,27 @@ void ScalarLayer::Reshape(const vector*>& bottom, // temporary storage of an intermediate result, overwriting top[0]'s diff // if using in-place computation. CHECK_NE(bottom[0], top[0]) << "ScalarLayer cannot be used in-place"; + const ScalarParameter& param = this->layer_param_.scalar_param(); + Blob* scalar = (bottom.size() > 1) ? bottom[1] : this->blobs_[0].get(); // Always set axis_ == 0 in special case where scalar is an actual scalar // (num_axes == 0). Mathematically equivalent for any choice of axis_, so the // actual setting can be safely ignored; and computation is most efficient // with axis_ == 0 and (therefore) outer_dim_ == 1. (Setting axis_ to // bottom[0]->num_axes() - 1, giving inner_dim_ == 1, would be equally // performant.) - const ScalarParameter& param = this->layer_param_.scalar_param(); - axis_ = (bottom[1]->num_axes() == 0) ? 
+ axis_ = (scalar->num_axes() == 0) ? 0 : bottom[0]->CanonicalAxisIndex(param.axis()); - CHECK_GE(bottom[0]->num_axes(), axis_ + bottom[1]->num_axes()) - << "bottom[1]'s shape extends past bottom[0]'s shape when applied " + CHECK_GE(bottom[0]->num_axes(), axis_ + scalar->num_axes()) + << "scalar blob's shape extends past bottom[0]'s shape when applied " << "starting with bottom[0] axis = " << axis_; - for (int i = 0; i < bottom[1]->num_axes(); ++i) { - CHECK_EQ(bottom[0]->shape(axis_ + i), bottom[1]->shape(i)) + for (int i = 0; i < scalar->num_axes(); ++i) { + CHECK_EQ(bottom[0]->shape(axis_ + i), scalar->shape(i)) << "dimension mismatch between bottom[0]->shape(" << axis_ + i - << ") and bottom[1]->shape(" << i << ")"; + << ") and scalar->shape(" << i << ")"; } outer_dim_ = bottom[0]->count(0, axis_); - scalar_dim_ = bottom[1]->count(); - inner_dim_ = bottom[0]->count(axis_ + bottom[1]->num_axes()); + scalar_dim_ = scalar->count(); + inner_dim_ = bottom[0]->count(axis_ + scalar->num_axes()); top[0]->ReshapeLike(*bottom[0]); sum_result_.Reshape(vector(1, outer_dim_ * scalar_dim_)); const int sum_mult_size = std::max(outer_dim_, inner_dim_); @@ -48,7 +86,8 @@ template void ScalarLayer::Forward_cpu( const vector*>& bottom, const vector*>& top) { const Dtype* bottom_data = bottom[0]->cpu_data(); - const Dtype* scalar_data = bottom[1]->cpu_data(); + const Dtype* scalar_data = + ((bottom.size() > 1) ? bottom[1] : this->blobs_[0].get())->cpu_data(); Dtype* top_data = top[0]->mutable_cpu_data(); for (int n = 0; n < outer_dim_; ++n) { for (int d = 0; d < scalar_dim_; ++d) { @@ -63,14 +102,17 @@ void ScalarLayer::Forward_cpu( template void ScalarLayer::Backward_cpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom) { - if (propagate_down[1]) { + const bool scalar_param = (bottom.size() == 1); + Blob* scalar = scalar_param ? 
this->blobs_[0].get() : bottom[1]; + if ((!scalar_param && propagate_down[1]) || + (scalar_param && this->param_propagate_down_[0])) { const Dtype* top_diff = top[0]->cpu_diff(); const Dtype* bottom_data = bottom[0]->cpu_data(); // Hack: store big eltwise product in bottom[0] diff, except in the special // case where this layer itself does the eltwise product, in which case we // can store it directly in the scalar diff, and we're done. - const bool is_eltwise = (bottom[0]->count() == bottom[1]->count()); - Dtype* product = (is_eltwise ? bottom[1] : bottom[0])->mutable_cpu_diff(); + const bool is_eltwise = (bottom[0]->count() == scalar->count()); + Dtype* product = (is_eltwise ? scalar : bottom[0])->mutable_cpu_diff(); caffe_mul(top[0]->count(), top_diff, bottom_data, product); if (!is_eltwise) { Dtype* sum_result = NULL; @@ -78,30 +120,41 @@ void ScalarLayer::Backward_cpu(const vector*>& top, sum_result = product; } else if (sum_result_.count() == 1) { const Dtype* sum_mult = sum_multiplier_.cpu_data(); - Dtype* scalar_diff = bottom[1]->mutable_cpu_diff(); - *scalar_diff = caffe_cpu_dot(inner_dim_, product, sum_mult); + Dtype* scalar_diff = scalar->mutable_cpu_diff(); + if (scalar_param) { + Dtype result = caffe_cpu_dot(inner_dim_, product, sum_mult); + *scalar_diff += result; + } else { + *scalar_diff = caffe_cpu_dot(inner_dim_, product, sum_mult); + } } else { const Dtype* sum_mult = sum_multiplier_.cpu_data(); sum_result = (outer_dim_ == 1) ? 
- bottom[1]->mutable_cpu_diff() : sum_result_.mutable_cpu_data(); + scalar->mutable_cpu_diff() : sum_result_.mutable_cpu_data(); caffe_cpu_gemv(CblasNoTrans, sum_result_.count(), inner_dim_, Dtype(1), product, sum_mult, Dtype(0), sum_result); } if (outer_dim_ != 1) { const Dtype* sum_mult = sum_multiplier_.cpu_data(); - Dtype* scalar_diff = bottom[1]->mutable_cpu_diff(); + Dtype* scalar_diff = scalar->mutable_cpu_diff(); if (scalar_dim_ == 1) { - *scalar_diff = caffe_cpu_dot(outer_dim_, sum_mult, sum_result); + if (scalar_param) { + Dtype result = caffe_cpu_dot(outer_dim_, sum_mult, sum_result); + *scalar_diff += result; + } else { + *scalar_diff = caffe_cpu_dot(outer_dim_, sum_mult, sum_result); + } } else { caffe_cpu_gemv(CblasTrans, outer_dim_, scalar_dim_, - Dtype(1), sum_result, sum_mult, Dtype(0), scalar_diff); + Dtype(1), sum_result, sum_mult, Dtype(scalar_param), + scalar_diff); } } } } if (propagate_down[0]) { const Dtype* top_diff = top[0]->cpu_diff(); - const Dtype* scalar_data = bottom[1]->cpu_data(); + const Dtype* scalar_data = scalar->cpu_data(); Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); for (int n = 0; n < outer_dim_; ++n) { for (int d = 0; d < scalar_dim_; ++d) { diff --git a/src/caffe/layers/scalar_layer.cu b/src/caffe/layers/scalar_layer.cu index 5cf0f9c7bec..b1af488d769 100644 --- a/src/caffe/layers/scalar_layer.cu +++ b/src/caffe/layers/scalar_layer.cu @@ -21,7 +21,8 @@ void ScalarLayer::Forward_gpu( const vector*>& bottom, const vector*>& top) { const int count = top[0]->count(); const Dtype* bottom_data = bottom[0]->gpu_data(); - const Dtype* scalar_data = bottom[1]->gpu_data(); + const Dtype* scalar_data = + ((bottom.size() > 1) ? 
bottom[1] : this->blobs_[0].get())->gpu_data(); Dtype* top_data = top[0]->mutable_gpu_data(); ScalarForward // NOLINT_NEXT_LINE(whitespace/operators) <<>>( @@ -31,14 +32,17 @@ void ScalarLayer::Forward_gpu( template void ScalarLayer::Backward_gpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom) { - if (propagate_down[1]) { + const bool scalar_param = (bottom.size() == 1); + Blob* scalar = scalar_param ? this->blobs_[0].get() : bottom[1]; + if ((!scalar_param && propagate_down[1]) || + (scalar_param && this->param_propagate_down_[0])) { const Dtype* top_diff = top[0]->gpu_diff(); const Dtype* bottom_data = bottom[0]->gpu_data(); // Hack: store big eltwise product in bottom[0] diff, except in the special // case where this layer itself does the eltwise product, in which case we // can store it directly in the scalar diff, and we're done. - const bool is_eltwise = (bottom[0]->count() == bottom[1]->count()); - Dtype* product = (is_eltwise ? bottom[1] : bottom[0])->mutable_gpu_diff(); + const bool is_eltwise = (bottom[0]->count() == scalar->count()); + Dtype* product = (is_eltwise ? scalar : bottom[0])->mutable_gpu_diff(); caffe_gpu_mul(top[0]->count(), top_diff, bottom_data, product); if (!is_eltwise) { Dtype* sum_result = NULL; @@ -46,24 +50,37 @@ void ScalarLayer::Backward_gpu(const vector*>& top, sum_result = product; } else if (sum_result_.count() == 1) { const Dtype* sum_mult = sum_multiplier_.gpu_data(); - Dtype* scalar_diff = bottom[1]->mutable_cpu_diff(); - caffe_gpu_dot(inner_dim_, product, sum_mult, scalar_diff); + Dtype* scalar_diff = scalar->mutable_cpu_diff(); + if (scalar_param) { + Dtype result; + caffe_gpu_dot(inner_dim_, product, sum_mult, &result); + *scalar_diff += result; + } else { + caffe_gpu_dot(inner_dim_, product, sum_mult, scalar_diff); + } } else { const Dtype* sum_mult = sum_multiplier_.gpu_data(); sum_result = (outer_dim_ == 1) ? 
- bottom[1]->mutable_gpu_diff() : sum_result_.mutable_gpu_data(); + scalar->mutable_gpu_diff() : sum_result_.mutable_gpu_data(); caffe_gpu_gemv(CblasNoTrans, sum_result_.count(), inner_dim_, Dtype(1), product, sum_mult, Dtype(0), sum_result); } if (outer_dim_ != 1) { const Dtype* sum_mult = sum_multiplier_.gpu_data(); if (scalar_dim_ == 1) { - Dtype* scalar_diff = bottom[1]->mutable_cpu_diff(); - caffe_gpu_dot(outer_dim_, sum_mult, sum_result, scalar_diff); + Dtype* scalar_diff = scalar->mutable_cpu_diff(); + if (scalar_param) { + Dtype result; + caffe_gpu_dot(outer_dim_, sum_mult, sum_result, &result); + *scalar_diff += result; + } else { + caffe_gpu_dot(outer_dim_, sum_mult, sum_result, scalar_diff); + } } else { - Dtype* scalar_diff = bottom[1]->mutable_gpu_diff(); + Dtype* scalar_diff = scalar->mutable_gpu_diff(); caffe_gpu_gemv(CblasTrans, outer_dim_, scalar_dim_, - Dtype(1), sum_result, sum_mult, Dtype(0), scalar_diff); + Dtype(1), sum_result, sum_mult, Dtype(scalar_param), + scalar_diff); } } } @@ -71,7 +88,7 @@ void ScalarLayer::Backward_gpu(const vector*>& top, if (propagate_down[0]) { const int count = top[0]->count(); const Dtype* top_diff = top[0]->gpu_diff(); - const Dtype* scalar_data = bottom[1]->gpu_data(); + const Dtype* scalar_data = scalar->gpu_data(); Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); ScalarForward // NOLINT_NEXT_LINE(whitespace/operators) <<>>( diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index b62c3b9872c..6f8cf101bb0 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -967,6 +967,21 @@ message ScalarParameter { // Furthermore, bottom[1] may have the empty shape (regardless of the value of // "axis") -- a literal scalar. optional int32 axis = 1 [default = 0]; + + // (num_axes is ignored unless just one bottom is given and the scalar is + // a learned parameter of the layer. Otherwise, num_axes is determined by the + // number of axes by the second bottom.) 
+ // The number of axes of the input (bottom[0]) covered by the scalar + // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. + // The default, 0, learns a zero-axis Blob: an actual scalar multiplier. + optional int32 num_axes = 2 [default = 0]; + + // (filler is ignored unless just one bottom is given and the scalar is + // a learned parameter of the layer.) + // The initialization for the learned scalar parameter. + // Default is the unit (1) initialization, resulting in the ScalarLayer + // initially performing the identity operation. + optional FillerParameter filler = 3; } message SigmoidParameter { diff --git a/src/caffe/test/test_scalar_layer.cpp b/src/caffe/test/test_scalar_layer.cpp index f7bf63f2119..caba89a0d81 100644 --- a/src/caffe/test/test_scalar_layer.cpp +++ b/src/caffe/test/test_scalar_layer.cpp @@ -86,6 +86,26 @@ TYPED_TEST(ScalarLayerTest, TestForwardEltwise) { } } +TYPED_TEST(ScalarLayerTest, TestForwardEltwiseWithParam) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ScalarParameter* scalar_param = layer_param.mutable_scalar_param(); + scalar_param->set_axis(0); + scalar_param->set_num_axes(-1); + scalar_param->mutable_filler()->set_type("gaussian"); + shared_ptr > layer(new ScalarLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype* data = this->blob_top_->cpu_data(); + const int count = this->blob_top_->count(); + const Dtype* in_data_a = this->blob_bottom_->cpu_data(); + const Dtype* in_data_b = layer->blobs()[0]->cpu_data(); + for (int i = 0; i < count; ++i) { + EXPECT_NEAR(data[i], in_data_a[i] * in_data_b[i], 1e-5); + } +} + TYPED_TEST(ScalarLayerTest, TestForwardBroadcastBegin) { typedef typename TypeParam::Dtype Dtype; this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_0_); @@ -131,6 +151,30 @@ 
TYPED_TEST(ScalarLayerTest, TestForwardBroadcastMiddle) { } } +TYPED_TEST(ScalarLayerTest, TestForwardBroadcastMiddleWithParam) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ScalarParameter* scalar_param = layer_param.mutable_scalar_param(); + scalar_param->set_axis(1); + scalar_param->set_num_axes(2); + scalar_param->mutable_filler()->set_type("gaussian"); + shared_ptr > layer(new ScalarLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int n = 0; n < this->blob_bottom_->num(); ++n) { + for (int c = 0; c < this->blob_bottom_->channels(); ++c) { + for (int h = 0; h < this->blob_bottom_->height(); ++h) { + for (int w = 0; w < this->blob_bottom_->width(); ++w) { + EXPECT_NEAR(this->blob_top_->data_at(n, c, h, w), + this->blob_bottom_->data_at(n, c, h, w) * + layer->blobs()[0]->data_at(c, h, 0, 0), 1e-5); + } + } + } + } +} + TYPED_TEST(ScalarLayerTest, TestForwardBroadcastEnd) { typedef typename TypeParam::Dtype Dtype; this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_2_); @@ -199,6 +243,19 @@ TYPED_TEST(ScalarLayerTest, TestGradientEltwise) { this->blob_top_vec_); } +TYPED_TEST(ScalarLayerTest, TestGradientEltwiseWithParam) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ScalarParameter* scalar_param = layer_param.mutable_scalar_param(); + scalar_param->set_axis(0); + scalar_param->set_num_axes(-1); + scalar_param->mutable_filler()->set_type("gaussian"); + ScalarLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + TYPED_TEST(ScalarLayerTest, TestGradientBroadcastBegin) { typedef typename TypeParam::Dtype Dtype; this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_0_); @@ -220,6 +277,20 @@ TYPED_TEST(ScalarLayerTest, 
TestGradientBroadcastMiddle) { this->blob_top_vec_); } +TYPED_TEST(ScalarLayerTest, TestGradientBroadcastMiddleWithParam) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_1_); + LayerParameter layer_param; + ScalarParameter* scalar_param = layer_param.mutable_scalar_param(); + scalar_param->set_axis(1); + scalar_param->set_num_axes(2); + scalar_param->mutable_filler()->set_type("gaussian"); + ScalarLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + TYPED_TEST(ScalarLayerTest, TestGradientBroadcastEnd) { typedef typename TypeParam::Dtype Dtype; this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_2_); From f30ffad057883bb21b42e18fd1d0129e5afc426d Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Wed, 23 Dec 2015 23:15:51 -0800 Subject: [PATCH 3/7] ScalarLayer supports in-place computation --- include/caffe/layers/scalar_layer.hpp | 1 + src/caffe/layers/scalar_layer.cpp | 28 ++++-- src/caffe/layers/scalar_layer.cu | 16 ++- src/caffe/test/test_scalar_layer.cpp | 135 ++++++++++++++++++++++++++ 4 files changed, 169 insertions(+), 11 deletions(-) diff --git a/include/caffe/layers/scalar_layer.hpp b/include/caffe/layers/scalar_layer.hpp index 59882e4d5f6..f679622dde4 100644 --- a/include/caffe/layers/scalar_layer.hpp +++ b/include/caffe/layers/scalar_layer.hpp @@ -65,6 +65,7 @@ class ScalarLayer: public Layer { Blob sum_multiplier_; Blob sum_result_; + Blob temp_; int axis_; int outer_dim_, scalar_dim_, inner_dim_; }; diff --git a/src/caffe/layers/scalar_layer.cpp b/src/caffe/layers/scalar_layer.cpp index ef52b986908..58f5b9b00e4 100644 --- a/src/caffe/layers/scalar_layer.cpp +++ b/src/caffe/layers/scalar_layer.cpp @@ -46,12 +46,6 @@ void ScalarLayer::LayerSetUp(const vector*>& bottom, template void ScalarLayer::Reshape(const vector*>& bottom, const vector*>& top) { - // TODO: make ScalarLayer 
usable in-place. - // Currently, in-place computation is broken during Backward with - // propagate_down[0] && propagate_down[1], as bottom[0]'s diff is used for - // temporary storage of an intermediate result, overwriting top[0]'s diff - // if using in-place computation. - CHECK_NE(bottom[0], top[0]) << "ScalarLayer cannot be used in-place"; const ScalarParameter& param = this->layer_param_.scalar_param(); Blob* scalar = (bottom.size() > 1) ? bottom[1] : this->blobs_[0].get(); // Always set axis_ == 0 in special case where scalar is an actual scalar @@ -73,7 +67,11 @@ void ScalarLayer::Reshape(const vector*>& bottom, outer_dim_ = bottom[0]->count(0, axis_); scalar_dim_ = scalar->count(); inner_dim_ = bottom[0]->count(axis_ + scalar->num_axes()); - top[0]->ReshapeLike(*bottom[0]); + if (bottom[0] == top[0]) { // in-place computation + temp_.ReshapeLike(*bottom[0]); + } else { + top[0]->ReshapeLike(*bottom[0]); + } sum_result_.Reshape(vector(1, outer_dim_ * scalar_dim_)); const int sum_mult_size = std::max(outer_dim_, inner_dim_); sum_multiplier_.Reshape(vector(1, sum_mult_size)); @@ -86,6 +84,14 @@ template void ScalarLayer::Forward_cpu( const vector*>& bottom, const vector*>& top) { const Dtype* bottom_data = bottom[0]->cpu_data(); + if (bottom[0] == top[0]) { + // In-place computation; need to store bottom data before overwriting it. + // Note that this is only necessary for Backward; we could skip this if not + // doing Backward, but Caffe currently provides no way of knowing whether + // we'll need to do Backward at the time of the Forward call. + caffe_copy(bottom[0]->count(), bottom[0]->cpu_data(), + temp_.mutable_cpu_data()); + } const Dtype* scalar_data = ((bottom.size() > 1) ? 
bottom[1] : this->blobs_[0].get())->cpu_data(); Dtype* top_data = top[0]->mutable_cpu_data(); @@ -107,12 +113,16 @@ void ScalarLayer::Backward_cpu(const vector*>& top, if ((!scalar_param && propagate_down[1]) || (scalar_param && this->param_propagate_down_[0])) { const Dtype* top_diff = top[0]->cpu_diff(); - const Dtype* bottom_data = bottom[0]->cpu_data(); + const bool in_place = (bottom[0] == top[0]); + const Dtype* bottom_data = (in_place ? &temp_ : bottom[0])->cpu_data(); // Hack: store big eltwise product in bottom[0] diff, except in the special // case where this layer itself does the eltwise product, in which case we // can store it directly in the scalar diff, and we're done. + // If we're computing in-place (and not doing eltwise computation), this + // hack doesn't work and we store the product in temp_. const bool is_eltwise = (bottom[0]->count() == scalar->count()); - Dtype* product = (is_eltwise ? scalar : bottom[0])->mutable_cpu_diff(); + Dtype* product = (is_eltwise ? scalar->mutable_cpu_diff() : + (in_place ? temp_.mutable_cpu_data() : bottom[0]->mutable_cpu_diff())); caffe_mul(top[0]->count(), top_diff, bottom_data, product); if (!is_eltwise) { Dtype* sum_result = NULL; diff --git a/src/caffe/layers/scalar_layer.cu b/src/caffe/layers/scalar_layer.cu index b1af488d769..9c6932723af 100644 --- a/src/caffe/layers/scalar_layer.cu +++ b/src/caffe/layers/scalar_layer.cu @@ -21,6 +21,14 @@ void ScalarLayer::Forward_gpu( const vector*>& bottom, const vector*>& top) { const int count = top[0]->count(); const Dtype* bottom_data = bottom[0]->gpu_data(); + if (bottom[0] == top[0]) { + // in-place computation; need to store bottom data before overwriting it. + // Note that this is only necessary for Backward; we could skip this if not + // doing Backward, but Caffe currently provides no way of knowing whether + // we'll need to do Backward at the time of the Forward call. 
+ caffe_copy(bottom[0]->count(), bottom[0]->gpu_data(), + temp_.mutable_gpu_data()); + } const Dtype* scalar_data = ((bottom.size() > 1) ? bottom[1] : this->blobs_[0].get())->gpu_data(); Dtype* top_data = top[0]->mutable_gpu_data(); @@ -37,12 +45,16 @@ void ScalarLayer::Backward_gpu(const vector*>& top, if ((!scalar_param && propagate_down[1]) || (scalar_param && this->param_propagate_down_[0])) { const Dtype* top_diff = top[0]->gpu_diff(); - const Dtype* bottom_data = bottom[0]->gpu_data(); + const bool in_place = (bottom[0] == top[0]); + const Dtype* bottom_data = (in_place ? &temp_ : bottom[0])->gpu_data(); // Hack: store big eltwise product in bottom[0] diff, except in the special // case where this layer itself does the eltwise product, in which case we // can store it directly in the scalar diff, and we're done. + // If we're computing in-place (and not doing eltwise computation), this + // hack doesn't work and we store the product in temp_. const bool is_eltwise = (bottom[0]->count() == scalar->count()); - Dtype* product = (is_eltwise ? scalar : bottom[0])->mutable_gpu_diff(); + Dtype* product = (is_eltwise ? scalar->mutable_gpu_diff() : + (in_place ? 
temp_.mutable_gpu_data() : bottom[0]->mutable_gpu_diff())); caffe_gpu_mul(top[0]->count(), top_diff, bottom_data, product); if (!is_eltwise) { Dtype* sum_result = NULL; diff --git a/src/caffe/test/test_scalar_layer.cpp b/src/caffe/test/test_scalar_layer.cpp index caba89a0d81..399d54a395e 100644 --- a/src/caffe/test/test_scalar_layer.cpp +++ b/src/caffe/test/test_scalar_layer.cpp @@ -86,6 +86,70 @@ TYPED_TEST(ScalarLayerTest, TestForwardEltwise) { } } +TYPED_TEST(ScalarLayerTest, TestForwardEltwiseInPlace) { + typedef typename TypeParam::Dtype Dtype; + this->blob_top_vec_[0] = this->blob_bottom_; // in-place computation + Blob orig_bottom(this->blob_bottom_->shape()); + orig_bottom.CopyFrom(*this->blob_bottom_); + this->blob_bottom_vec_.push_back(this->blob_bottom_eltwise_); + LayerParameter layer_param; + shared_ptr > layer(new ScalarLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype* data = this->blob_bottom_->cpu_data(); + const int count = this->blob_bottom_->count(); + const Dtype* in_data_a = orig_bottom.cpu_data(); + const Dtype* in_data_b = this->blob_bottom_eltwise_->cpu_data(); + for (int i = 0; i < count; ++i) { + EXPECT_NEAR(data[i], in_data_a[i] * in_data_b[i], 1e-5); + } +} + +TYPED_TEST(ScalarLayerTest, TestBackwardEltwiseInPlace) { + typedef typename TypeParam::Dtype Dtype; + Blob orig_bottom(this->blob_bottom_->shape()); + orig_bottom.CopyFrom(*this->blob_bottom_); + this->blob_bottom_vec_.push_back(this->blob_bottom_eltwise_); + LayerParameter layer_param; + shared_ptr > layer(new ScalarLayer(layer_param)); + Blob top_diff(this->blob_bottom_->shape()); + FillerParameter filler_param; + filler_param.set_type("gaussian"); + filler_param.set_std(1); + GaussianFiller filler(filler_param); + filler.Fill(&top_diff); + vector propagate_down(2, true); + // Run forward + backward without in-place computation; + // save resulting bottom diffs. 
+ layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + caffe_copy(top_diff.count(), top_diff.cpu_data(), + this->blob_top_->mutable_cpu_diff()); + layer->Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_); + const bool kReshape = true; + const bool kCopyDiff = true; + Blob orig_bottom_diff; + orig_bottom_diff.CopyFrom(*this->blob_bottom_, kCopyDiff, kReshape); + Blob orig_scalar_diff; + orig_scalar_diff.CopyFrom(*this->blob_bottom_eltwise_, + kCopyDiff, kReshape); + // Rerun forward + backward with in-place computation; + // check that resulting bottom diffs are the same. + this->blob_top_vec_[0] = this->blob_bottom_; // in-place computation + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + caffe_copy(top_diff.count(), top_diff.cpu_data(), + this->blob_bottom_->mutable_cpu_diff()); + layer->Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + EXPECT_NEAR(orig_bottom_diff.cpu_diff()[i], + this->blob_bottom_->cpu_diff()[i], 1e-5); + } + for (int i = 0; i < this->blob_bottom_eltwise_->count(); ++i) { + EXPECT_NEAR(orig_scalar_diff.cpu_diff()[i], + this->blob_bottom_eltwise_->cpu_diff()[i], 1e-5); + } +} + TYPED_TEST(ScalarLayerTest, TestForwardEltwiseWithParam) { typedef typename TypeParam::Dtype Dtype; LayerParameter layer_param; @@ -151,6 +215,77 @@ TYPED_TEST(ScalarLayerTest, TestForwardBroadcastMiddle) { } } +TYPED_TEST(ScalarLayerTest, TestForwardBroadcastMiddleInPlace) { + typedef typename TypeParam::Dtype Dtype; + this->blob_top_vec_[0] = this->blob_bottom_; // in-place computation + Blob orig_bottom(this->blob_bottom_->shape()); + orig_bottom.CopyFrom(*this->blob_bottom_); + this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_1_); + LayerParameter layer_param; + layer_param.mutable_scalar_param()->set_axis(1); + shared_ptr > layer(new ScalarLayer(layer_param)); + 
layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int n = 0; n < this->blob_bottom_->num(); ++n) { + for (int c = 0; c < this->blob_bottom_->channels(); ++c) { + for (int h = 0; h < this->blob_bottom_->height(); ++h) { + for (int w = 0; w < this->blob_bottom_->width(); ++w) { + EXPECT_NEAR(this->blob_bottom_->data_at(n, c, h, w), + orig_bottom.data_at(n, c, h, w) * + this->blob_bottom_broadcast_1_->data_at(c, h, 0, 0), + 1e-5); + } + } + } + } +} + +TYPED_TEST(ScalarLayerTest, TestBackwardBroadcastMiddleInPlace) { + typedef typename TypeParam::Dtype Dtype; + Blob orig_bottom(this->blob_bottom_->shape()); + orig_bottom.CopyFrom(*this->blob_bottom_); + this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_1_); + LayerParameter layer_param; + layer_param.mutable_scalar_param()->set_axis(1); + shared_ptr > layer(new ScalarLayer(layer_param)); + Blob top_diff(this->blob_bottom_->shape()); + FillerParameter filler_param; + filler_param.set_type("gaussian"); + filler_param.set_std(1); + GaussianFiller filler(filler_param); + filler.Fill(&top_diff); + vector propagate_down(2, true); + // Run forward + backward without in-place computation; + // save resulting bottom diffs. + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + caffe_copy(top_diff.count(), top_diff.cpu_data(), + this->blob_top_->mutable_cpu_diff()); + layer->Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_); + const bool kReshape = true; + const bool kCopyDiff = true; + Blob orig_bottom_diff; + orig_bottom_diff.CopyFrom(*this->blob_bottom_, kCopyDiff, kReshape); + Blob orig_scalar_diff; + orig_scalar_diff.CopyFrom(*this->blob_bottom_broadcast_1_, + kCopyDiff, kReshape); + // Rerun forward + backward with in-place computation; + // check that resulting bottom diffs are the same. 
+ this->blob_top_vec_[0] = this->blob_bottom_; // in-place computation + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + caffe_copy(top_diff.count(), top_diff.cpu_data(), + this->blob_bottom_->mutable_cpu_diff()); + layer->Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + EXPECT_NEAR(orig_bottom_diff.cpu_diff()[i], + this->blob_bottom_->cpu_diff()[i], 1e-5); + } + for (int i = 0; i < this->blob_bottom_broadcast_1_->count(); ++i) { + EXPECT_NEAR(orig_scalar_diff.cpu_diff()[i], + this->blob_bottom_broadcast_1_->cpu_diff()[i], 1e-5); + } +} + TYPED_TEST(ScalarLayerTest, TestForwardBroadcastMiddleWithParam) { typedef typename TypeParam::Dtype Dtype; LayerParameter layer_param; From 8626dde739eb88bc2bd0b31c583c016541b2c60f Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Wed, 13 Jan 2016 12:35:51 -0800 Subject: [PATCH 4/7] Add BiasLayer to add two blobs with broadcasting --- include/caffe/layers/bias_layer.hpp | 54 ++++ src/caffe/layers/bias_layer.cpp | 121 ++++++++ src/caffe/layers/bias_layer.cu | 54 ++++ src/caffe/proto/caffe.proto | 35 ++- src/caffe/test/test_bias_layer.cpp | 461 ++++++++++++++++++++++++++++ 5 files changed, 724 insertions(+), 1 deletion(-) create mode 100644 include/caffe/layers/bias_layer.hpp create mode 100644 src/caffe/layers/bias_layer.cpp create mode 100644 src/caffe/layers/bias_layer.cu create mode 100644 src/caffe/test/test_bias_layer.cpp diff --git a/include/caffe/layers/bias_layer.hpp b/include/caffe/layers/bias_layer.hpp new file mode 100644 index 00000000000..4f396737963 --- /dev/null +++ b/include/caffe/layers/bias_layer.hpp @@ -0,0 +1,54 @@ +#ifndef CAFFE_INNER_PRODUCT_LAYER_HPP_ +#define CAFFE_INNER_PRODUCT_LAYER_HPP_ + +#include + +#include "caffe/blob.hpp" +#include "caffe/layer.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +/** + * @brief Computes a sum of two input Blobs, with the shape of the + * latter Blob 
"broadcast" to match the shape of the former. + * Equivalent to tiling the latter Blob, then computing the elementwise + * sum. + * + * The second input may be omitted, in which case it's learned as a parameter + * of the layer. + */ +template +class BiasLayer : public Layer { + public: + explicit BiasLayer(const LayerParameter& param) + : Layer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline const char* type() const { return "Bias"; } + virtual inline int MinBottomBlobs() const { return 1; } + virtual inline int MaxBottomBlobs() const { return 2; } + virtual inline int ExactNumTopBlobs() const { return 1; } + + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + private: + Blob bias_multiplier_; + int outer_dim_, bias_dim_, inner_dim_, dim_; +}; + + + +} // namespace caffe + +#endif // CAFFE_INNER_PRODUCT_LAYER_HPP_ diff --git a/src/caffe/layers/bias_layer.cpp b/src/caffe/layers/bias_layer.cpp new file mode 100644 index 00000000000..0a786b5db98 --- /dev/null +++ b/src/caffe/layers/bias_layer.cpp @@ -0,0 +1,121 @@ +#include + +#include "caffe/filler.hpp" +#include "caffe/layers/bias_layer.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +template +void BiasLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + if (bottom.size() == 1 && this->blobs_.size() > 0) { + LOG(INFO) << "Skipping parameter initialization"; + } else if (bottom.size() == 1) { + // bias is a learned parameter; initialize it + const BiasParameter& param = this->layer_param_.bias_param(); + const int axis = 
bottom[0]->CanonicalAxisIndex(param.axis()); + const int num_axes = param.num_axes(); + CHECK_GE(num_axes, -1) << "num_axes must be non-negative, " + << "or -1 to extend to the end of bottom[0]"; + if (num_axes >= 0) { + CHECK_GE(bottom[0]->num_axes(), axis + num_axes) + << "bias blob's shape extends past bottom[0]'s shape when applied " + << "starting with bottom[0] axis = " << axis; + } + this->blobs_.resize(1); + const vector::const_iterator& shape_start = + bottom[0]->shape().begin() + axis; + const vector::const_iterator& shape_end = + (num_axes == -1) ? bottom[0]->shape().end() : (shape_start + num_axes); + vector bias_shape(shape_start, shape_end); + this->blobs_[0].reset(new Blob(bias_shape)); + shared_ptr > filler(GetFiller(param.filler())); + filler->Fill(this->blobs_[0].get()); + } + this->param_propagate_down_.resize(this->blobs_.size(), true); +} + +template +void BiasLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + const BiasParameter& param = this->layer_param_.bias_param(); + Blob* bias = (bottom.size() > 1) ? bottom[1] : this->blobs_[0].get(); + // Always set axis == 0 in special case where bias is a scalar + // (num_axes == 0). Mathematically equivalent for any choice of axis, so the + // actual setting can be safely ignored; and computation is most efficient + // with axis == 0 and (therefore) outer_dim_ == 1. + const int axis = (bias->num_axes() == 0) ? 
+ 0 : bottom[0]->CanonicalAxisIndex(param.axis()); + CHECK_GE(bottom[0]->num_axes(), axis + bias->num_axes()) + << "bias blob's shape extends past bottom[0]'s shape when applied " + << "starting with bottom[0] axis = " << axis; + for (int i = 0; i < bias->num_axes(); ++i) { + CHECK_EQ(bottom[0]->shape(axis + i), bias->shape(i)) + << "dimension mismatch between bottom[0]->shape(" << axis + i + << ") and bias->shape(" << i << ")"; + } + outer_dim_ = bottom[0]->count(0, axis); + bias_dim_ = bias->count(); + inner_dim_ = bottom[0]->count(axis + bias->num_axes()); + dim_ = bias_dim_ * inner_dim_; + if (bottom[0] != top[0]) { + top[0]->ReshapeLike(*bottom[0]); + } + bias_multiplier_.Reshape(vector(1, inner_dim_)); + if (bias_multiplier_.cpu_data()[inner_dim_ - 1] != Dtype(1)) { + caffe_set(inner_dim_, Dtype(1), bias_multiplier_.mutable_cpu_data()); + } +} + +template +void BiasLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bias_data = + ((bottom.size() > 1) ? 
bottom[1] : this->blobs_[0].get())->cpu_data(); + Dtype* top_data = top[0]->mutable_cpu_data(); + if (bottom[0] != top[0]) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + caffe_copy(bottom[0]->count(), bottom_data, top_data); + } + for (int n = 0; n < outer_dim_; ++n) { + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, bias_dim_, + inner_dim_, Dtype(1), Dtype(1), bias_data, + bias_multiplier_.cpu_data(), Dtype(1), top_data); + top_data += dim_; + } +} + +template +void BiasLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + if (propagate_down[0] && bottom[0] != top[0]) { + const Dtype* top_diff = top[0]->cpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + caffe_copy(bottom[0]->count(), top_diff, bottom_diff); + } + // in-place, we don't need to do anything with the data diff + const bool bias_param = (bottom.size() == 1); + if ((!bias_param && propagate_down[1]) || + (bias_param && this->param_propagate_down_[0])) { + const Dtype* top_diff = top[0]->cpu_diff(); + Dtype* bias_diff = (bias_param ? 
this->blobs_[0].get() : bottom[1]) + ->mutable_cpu_diff(); + bool accum = bias_param; + for (int n = 0; n < outer_dim_; ++n) { + caffe_cpu_gemv(CblasNoTrans, bias_dim_, inner_dim_, Dtype(1), + top_diff, bias_multiplier_.cpu_data(), Dtype(accum), bias_diff); + top_diff += dim_; + accum = true; + } + } +} + +#ifdef CPU_ONLY +STUB_GPU(BiasLayer); +#endif + +INSTANTIATE_CLASS(BiasLayer); +REGISTER_LAYER_CLASS(Bias); + +} // namespace caffe diff --git a/src/caffe/layers/bias_layer.cu b/src/caffe/layers/bias_layer.cu new file mode 100644 index 00000000000..7711e89e6e5 --- /dev/null +++ b/src/caffe/layers/bias_layer.cu @@ -0,0 +1,54 @@ +#include + +#include "caffe/filler.hpp" +#include "caffe/layers/bias_layer.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +template +void BiasLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bias_data = + ((bottom.size() > 1) ? bottom[1] : this->blobs_[0].get())->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + if (bottom[0] != top[0]) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + caffe_copy(bottom[0]->count(), bottom_data, top_data); + } + for (int n = 0; n < outer_dim_; ++n) { + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, bias_dim_, + inner_dim_, Dtype(1), Dtype(1), bias_data, + bias_multiplier_.gpu_data(), Dtype(1), top_data); + top_data += dim_; + } +} + +template +void BiasLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + if (propagate_down[0] && bottom[0] != top[0]) { + const Dtype* top_diff = top[0]->gpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + caffe_copy(bottom[0]->count(), top_diff, bottom_diff); + } + // in-place, we don't need to do anything with the data diff + const bool bias_param = (bottom.size() == 1); + if ((!bias_param && propagate_down[1]) || + (bias_param && this->param_propagate_down_[0])) { + const Dtype* top_diff = top[0]->gpu_diff(); + Dtype* bias_diff = (bias_param 
? this->blobs_[0].get() : bottom[1]) + ->mutable_gpu_diff(); + bool accum = bias_param; + for (int n = 0; n < outer_dim_; ++n) { + caffe_gpu_gemv(CblasNoTrans, bias_dim_, inner_dim_, Dtype(1), + top_diff, bias_multiplier_.gpu_data(), Dtype(accum), bias_diff); + top_diff += dim_; + accum = true; + } + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(BiasLayer); + +} // namespace caffe diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 6f8cf101bb0..e0029a15dd9 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -306,7 +306,7 @@ message ParamSpec { // NOTE // Update the next available ID when you add a new LayerParameter field. // -// LayerParameter next available layer-specific ID: 141 (last added: scalar_param) +// LayerParameter next available layer-specific ID: 142 (last added: bias_param) message LayerParameter { optional string name = 1; // the layer name optional string type = 2; // the layer type @@ -356,6 +356,7 @@ message LayerParameter { optional AccuracyParameter accuracy_param = 102; optional ArgMaxParameter argmax_param = 103; optional BatchNormParameter batch_norm_param = 139; + optional BiasParameter bias_param = 141; optional ConcatParameter concat_param = 104; optional ContrastiveLossParameter contrastive_loss_param = 105; optional ConvolutionParameter convolution_param = 106; @@ -498,6 +499,38 @@ message BatchNormParameter { optional float eps = 3 [default = 1e-5]; } +message BiasParameter { + // The first axis of bottom[0] (the first input Blob) along which to apply + // bottom[1] (the second input Blob). May be negative to index from the end + // (e.g., -1 for the last axis). 
+ // + // For example, if bottom[0] is 4D with shape 100x3x40x60, the output + // top[0] will have the same shape, and bottom[1] may have any of the + // following shapes (for the given value of axis): + // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 + // (axis == 1 == -3) 3; 3x40; 3x40x60 + // (axis == 2 == -2) 40; 40x60 + // (axis == 3 == -1) 60 + // Furthermore, bottom[1] may have the empty shape (regardless of the value of + // "axis") -- a scalar bias. + optional int32 axis = 1 [default = 0]; + + // (num_axes is ignored unless just one bottom is given and the bias is + // a learned parameter of the layer. Otherwise, num_axes is determined by the + // number of axes of the second bottom.) + // The number of axes of the input (bottom[0]) covered by the bias + // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. + // The default, 0, learns a zero-axis Blob: an actual scalar bias. + optional int32 num_axes = 2 [default = 0]; + + // (filler is ignored unless just one bottom is given and the bias is + // a learned parameter of the layer.) + // The initialization for the learned bias parameter. + // Default is the zero (0) initialization, resulting in the BiasLayer + // initially performing the identity operation. 
+ optional FillerParameter filler = 3; +} + message ContrastiveLossParameter { // margin for dissimilar pair optional float margin = 1 [default = 1.0]; diff --git a/src/caffe/test/test_bias_layer.cpp b/src/caffe/test/test_bias_layer.cpp new file mode 100644 index 00000000000..0d23d3f453c --- /dev/null +++ b/src/caffe/test/test_bias_layer.cpp @@ -0,0 +1,461 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/layers/bias_layer.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template +class BiasLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + BiasLayerTest() + : blob_bottom_(new Blob(2, 3, 4, 5)), + blob_bottom_eltwise_(new Blob(2, 3, 4, 5)), + blob_bottom_broadcast_0_(new Blob()), + blob_bottom_broadcast_1_(new Blob()), + blob_bottom_broadcast_2_(new Blob()), + blob_bottom_bias_(new Blob(vector())), + blob_top_(new Blob()) { + Caffe::set_random_seed(1701); + vector broadcast_shape(2); + broadcast_shape[0] = 2; broadcast_shape[1] = 3; + this->blob_bottom_broadcast_0_->Reshape(broadcast_shape); + broadcast_shape[0] = 3; broadcast_shape[1] = 4; + this->blob_bottom_broadcast_1_->Reshape(broadcast_shape); + broadcast_shape[0] = 4; broadcast_shape[1] = 5; + this->blob_bottom_broadcast_2_->Reshape(broadcast_shape); + FillerParameter filler_param; + filler_param.set_min(1); + filler_param.set_max(10); + UniformFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + filler.Fill(this->blob_bottom_eltwise_); + filler.Fill(this->blob_bottom_broadcast_0_); + filler.Fill(this->blob_bottom_broadcast_1_); + filler.Fill(this->blob_bottom_broadcast_2_); + filler.Fill(this->blob_bottom_bias_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~BiasLayerTest() { + delete blob_bottom_; + delete 
blob_bottom_eltwise_; + delete blob_bottom_broadcast_0_; + delete blob_bottom_broadcast_1_; + delete blob_bottom_broadcast_2_; + delete blob_bottom_bias_; + delete blob_top_; + } + Blob* const blob_bottom_; + Blob* const blob_bottom_eltwise_; + Blob* const blob_bottom_broadcast_0_; + Blob* const blob_bottom_broadcast_1_; + Blob* const blob_bottom_broadcast_2_; + Blob* const blob_bottom_bias_; + Blob* const blob_top_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(BiasLayerTest, TestDtypesAndDevices); + +TYPED_TEST(BiasLayerTest, TestForwardEltwise) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_eltwise_); + LayerParameter layer_param; + shared_ptr > layer(new BiasLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype* data = this->blob_top_->cpu_data(); + const int count = this->blob_top_->count(); + const Dtype* in_data_a = this->blob_bottom_->cpu_data(); + const Dtype* in_data_b = this->blob_bottom_eltwise_->cpu_data(); + for (int i = 0; i < count; ++i) { + EXPECT_NEAR(data[i], in_data_a[i] + in_data_b[i], 1e-5); + } +} + +TYPED_TEST(BiasLayerTest, TestForwardEltwiseInPlace) { + typedef typename TypeParam::Dtype Dtype; + this->blob_top_vec_[0] = this->blob_bottom_; // in-place computation + Blob orig_bottom(this->blob_bottom_->shape()); + orig_bottom.CopyFrom(*this->blob_bottom_); + this->blob_bottom_vec_.push_back(this->blob_bottom_eltwise_); + LayerParameter layer_param; + shared_ptr > layer(new BiasLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype* data = this->blob_bottom_->cpu_data(); + const int count = this->blob_bottom_->count(); + const Dtype* in_data_a = orig_bottom.cpu_data(); + const Dtype* 
in_data_b = this->blob_bottom_eltwise_->cpu_data(); + for (int i = 0; i < count; ++i) { + EXPECT_NEAR(data[i], in_data_a[i] + in_data_b[i], 1e-5); + } +} + +TYPED_TEST(BiasLayerTest, TestBackwardEltwiseInPlace) { + typedef typename TypeParam::Dtype Dtype; + Blob orig_bottom(this->blob_bottom_->shape()); + orig_bottom.CopyFrom(*this->blob_bottom_); + this->blob_bottom_vec_.push_back(this->blob_bottom_eltwise_); + LayerParameter layer_param; + shared_ptr > layer(new BiasLayer(layer_param)); + Blob top_diff(this->blob_bottom_->shape()); + FillerParameter filler_param; + filler_param.set_type("gaussian"); + filler_param.set_std(1); + GaussianFiller filler(filler_param); + filler.Fill(&top_diff); + vector propagate_down(2, true); + // Run forward + backward without in-place computation; + // save resulting bottom diffs. + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + caffe_copy(top_diff.count(), top_diff.cpu_data(), + this->blob_top_->mutable_cpu_diff()); + layer->Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_); + const bool kReshape = true; + const bool kCopyDiff = true; + Blob orig_bottom_diff; + orig_bottom_diff.CopyFrom(*this->blob_bottom_, kCopyDiff, kReshape); + Blob orig_bias_diff; + orig_bias_diff.CopyFrom(*this->blob_bottom_eltwise_, + kCopyDiff, kReshape); + // Rerun forward + backward with in-place computation; + // check that resulting bottom diffs are the same. 
+ this->blob_top_vec_[0] = this->blob_bottom_; // in-place computation + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + caffe_copy(top_diff.count(), top_diff.cpu_data(), + this->blob_bottom_->mutable_cpu_diff()); + layer->Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + EXPECT_NEAR(orig_bottom_diff.cpu_diff()[i], + this->blob_bottom_->cpu_diff()[i], 1e-5); + } + for (int i = 0; i < this->blob_bottom_eltwise_->count(); ++i) { + EXPECT_NEAR(orig_bias_diff.cpu_diff()[i], + this->blob_bottom_eltwise_->cpu_diff()[i], 1e-5); + } +} + +TYPED_TEST(BiasLayerTest, TestForwardEltwiseWithParam) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + BiasParameter* bias_param = layer_param.mutable_bias_param(); + bias_param->set_axis(0); + bias_param->set_num_axes(-1); + bias_param->mutable_filler()->set_type("gaussian"); + shared_ptr > layer(new BiasLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype* data = this->blob_top_->cpu_data(); + const int count = this->blob_top_->count(); + const Dtype* in_data_a = this->blob_bottom_->cpu_data(); + const Dtype* in_data_b = layer->blobs()[0]->cpu_data(); + for (int i = 0; i < count; ++i) { + EXPECT_NEAR(data[i], in_data_a[i] + in_data_b[i], 1e-5); + } +} + +TYPED_TEST(BiasLayerTest, TestForwardBroadcastBegin) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_0_); + LayerParameter layer_param; + shared_ptr > layer(new BiasLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int n = 0; n < this->blob_bottom_->num(); ++n) { + 
for (int c = 0; c < this->blob_bottom_->channels(); ++c) { + for (int h = 0; h < this->blob_bottom_->height(); ++h) { + for (int w = 0; w < this->blob_bottom_->width(); ++w) { + EXPECT_NEAR(this->blob_top_->data_at(n, c, h, w), + this->blob_bottom_->data_at(n, c, h, w) + + this->blob_bottom_broadcast_0_->data_at(n, c, 0, 0), + 1e-5); + } + } + } + } +} + +TYPED_TEST(BiasLayerTest, TestForwardBroadcastMiddle) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_1_); + LayerParameter layer_param; + layer_param.mutable_bias_param()->set_axis(1); + shared_ptr > layer(new BiasLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int n = 0; n < this->blob_bottom_->num(); ++n) { + for (int c = 0; c < this->blob_bottom_->channels(); ++c) { + for (int h = 0; h < this->blob_bottom_->height(); ++h) { + for (int w = 0; w < this->blob_bottom_->width(); ++w) { + EXPECT_NEAR(this->blob_top_->data_at(n, c, h, w), + this->blob_bottom_->data_at(n, c, h, w) + + this->blob_bottom_broadcast_1_->data_at(c, h, 0, 0), + 1e-5); + } + } + } + } +} + +TYPED_TEST(BiasLayerTest, TestForwardBroadcastMiddleInPlace) { + typedef typename TypeParam::Dtype Dtype; + this->blob_top_vec_[0] = this->blob_bottom_; // in-place computation + Blob orig_bottom(this->blob_bottom_->shape()); + orig_bottom.CopyFrom(*this->blob_bottom_); + this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_1_); + LayerParameter layer_param; + layer_param.mutable_bias_param()->set_axis(1); + shared_ptr > layer(new BiasLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int n = 0; n < this->blob_bottom_->num(); ++n) { + for (int c = 0; c < this->blob_bottom_->channels(); ++c) { + for (int h = 0; h < 
this->blob_bottom_->height(); ++h) { + for (int w = 0; w < this->blob_bottom_->width(); ++w) { + EXPECT_NEAR(this->blob_bottom_->data_at(n, c, h, w), + orig_bottom.data_at(n, c, h, w) + + this->blob_bottom_broadcast_1_->data_at(c, h, 0, 0), + 1e-5); + } + } + } + } +} + +TYPED_TEST(BiasLayerTest, TestBackwardBroadcastMiddleInPlace) { + typedef typename TypeParam::Dtype Dtype; + Blob orig_bottom(this->blob_bottom_->shape()); + orig_bottom.CopyFrom(*this->blob_bottom_); + this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_1_); + LayerParameter layer_param; + layer_param.mutable_bias_param()->set_axis(1); + shared_ptr > layer(new BiasLayer(layer_param)); + Blob top_diff(this->blob_bottom_->shape()); + FillerParameter filler_param; + filler_param.set_type("gaussian"); + filler_param.set_std(1); + GaussianFiller filler(filler_param); + filler.Fill(&top_diff); + vector propagate_down(2, true); + // Run forward + backward without in-place computation; + // save resulting bottom diffs. + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + caffe_copy(top_diff.count(), top_diff.cpu_data(), + this->blob_top_->mutable_cpu_diff()); + layer->Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_); + const bool kReshape = true; + const bool kCopyDiff = true; + Blob orig_bottom_diff; + orig_bottom_diff.CopyFrom(*this->blob_bottom_, kCopyDiff, kReshape); + Blob orig_bias_diff; + orig_bias_diff.CopyFrom(*this->blob_bottom_broadcast_1_, + kCopyDiff, kReshape); + // Rerun forward + backward with in-place computation; + // check that resulting bottom diffs are the same. 
+ this->blob_top_vec_[0] = this->blob_bottom_; // in-place computation + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + caffe_copy(top_diff.count(), top_diff.cpu_data(), + this->blob_bottom_->mutable_cpu_diff()); + layer->Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + EXPECT_NEAR(orig_bottom_diff.cpu_diff()[i], + this->blob_bottom_->cpu_diff()[i], 1e-5); + } + for (int i = 0; i < this->blob_bottom_broadcast_1_->count(); ++i) { + EXPECT_NEAR(orig_bias_diff.cpu_diff()[i], + this->blob_bottom_broadcast_1_->cpu_diff()[i], 1e-5); + } +} + +TYPED_TEST(BiasLayerTest, TestForwardBroadcastMiddleWithParam) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + BiasParameter* bias_param = layer_param.mutable_bias_param(); + bias_param->set_axis(1); + bias_param->set_num_axes(2); + bias_param->mutable_filler()->set_type("gaussian"); + shared_ptr > layer(new BiasLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int n = 0; n < this->blob_bottom_->num(); ++n) { + for (int c = 0; c < this->blob_bottom_->channels(); ++c) { + for (int h = 0; h < this->blob_bottom_->height(); ++h) { + for (int w = 0; w < this->blob_bottom_->width(); ++w) { + EXPECT_NEAR(this->blob_top_->data_at(n, c, h, w), + this->blob_bottom_->data_at(n, c, h, w) + + layer->blobs()[0]->data_at(c, h, 0, 0), 1e-5); + } + } + } + } +} + +TYPED_TEST(BiasLayerTest, TestForwardBroadcastEnd) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_2_); + LayerParameter layer_param; + layer_param.mutable_bias_param()->set_axis(2); + shared_ptr > layer(new BiasLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(this->blob_bottom_->shape(), 
this->blob_top_->shape()); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int n = 0; n < this->blob_bottom_->num(); ++n) { + for (int c = 0; c < this->blob_bottom_->channels(); ++c) { + for (int h = 0; h < this->blob_bottom_->height(); ++h) { + for (int w = 0; w < this->blob_bottom_->width(); ++w) { + EXPECT_NEAR(this->blob_top_->data_at(n, c, h, w), + this->blob_bottom_->data_at(n, c, h, w) + + this->blob_bottom_broadcast_2_->data_at(h, w, 0, 0), + 1e-5); + } + } + } + } +} + +TYPED_TEST(BiasLayerTest, TestForwardBias) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_bias_); + LayerParameter layer_param; + shared_ptr > layer(new BiasLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype* data = this->blob_top_->cpu_data(); + const int count = this->blob_top_->count(); + const Dtype* in_data = this->blob_bottom_->cpu_data(); + const Dtype bias = *this->blob_bottom_bias_->cpu_data(); + for (int i = 0; i < count; ++i) { + EXPECT_NEAR(data[i], in_data[i] + bias, 1e-5); + } +} + +TYPED_TEST(BiasLayerTest, TestForwardBiasAxis2) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_bias_); + LayerParameter layer_param; + layer_param.mutable_bias_param()->set_axis(2); + shared_ptr > layer(new BiasLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype* data = this->blob_top_->cpu_data(); + const int count = this->blob_top_->count(); + const Dtype* in_data = this->blob_bottom_->cpu_data(); + const Dtype bias = *this->blob_bottom_bias_->cpu_data(); + for (int i = 0; i < count; ++i) { + EXPECT_NEAR(data[i], in_data[i] + 
bias, 1e-5); + } +} + +TYPED_TEST(BiasLayerTest, TestGradientEltwise) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_eltwise_); + LayerParameter layer_param; + BiasLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(BiasLayerTest, TestGradientEltwiseWithParam) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + BiasParameter* bias_param = layer_param.mutable_bias_param(); + bias_param->set_axis(0); + bias_param->set_num_axes(-1); + bias_param->mutable_filler()->set_type("gaussian"); + BiasLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(BiasLayerTest, TestGradientBroadcastBegin) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_0_); + LayerParameter layer_param; + BiasLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(BiasLayerTest, TestGradientBroadcastMiddle) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_1_); + LayerParameter layer_param; + layer_param.mutable_bias_param()->set_axis(1); + BiasLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(BiasLayerTest, TestGradientBroadcastMiddleWithParam) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_1_); + LayerParameter layer_param; + BiasParameter* bias_param = layer_param.mutable_bias_param(); + bias_param->set_axis(1); + bias_param->set_num_axes(2); + 
bias_param->mutable_filler()->set_type("gaussian"); + BiasLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(BiasLayerTest, TestGradientBroadcastEnd) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_2_); + LayerParameter layer_param; + layer_param.mutable_bias_param()->set_axis(2); + BiasLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(BiasLayerTest, TestGradientBias) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_bias_); + LayerParameter layer_param; + BiasLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(BiasLayerTest, TestGradientBiasAxis2) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_bias_); + LayerParameter layer_param; + layer_param.mutable_bias_param()->set_axis(2); + BiasLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +} // namespace caffe From 0a4156f8bdaf3681767e260b3c1409ab459c2218 Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Sun, 27 Dec 2015 10:18:05 -0800 Subject: [PATCH 5/7] BiasLayer Forward GPU kernel --- src/caffe/layers/bias_layer.cu | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/caffe/layers/bias_layer.cu b/src/caffe/layers/bias_layer.cu index 7711e89e6e5..8ac913a5d7b 100644 --- a/src/caffe/layers/bias_layer.cu +++ b/src/caffe/layers/bias_layer.cu @@ -6,22 +6,27 @@ namespace caffe { +template +__global__ void BiasForward(const int n, const Dtype* in, + const 
Dtype* bias, const int bias_dim, const int inner_dim, + Dtype* out) { + CUDA_KERNEL_LOOP(index, n) { + const int bias_index = (index / inner_dim) % bias_dim; + out[index] = in[index] + bias[bias_index]; + } +} + template void BiasLayer::Forward_gpu(const vector*>& bottom, const vector*>& top) { + const int count = top[0]->count(); + const Dtype* bottom_data = bottom[0]->gpu_data(); const Dtype* bias_data = ((bottom.size() > 1) ? bottom[1] : this->blobs_[0].get())->gpu_data(); Dtype* top_data = top[0]->mutable_gpu_data(); - if (bottom[0] != top[0]) { - const Dtype* bottom_data = bottom[0]->gpu_data(); - caffe_copy(bottom[0]->count(), bottom_data, top_data); - } - for (int n = 0; n < outer_dim_; ++n) { - caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, bias_dim_, - inner_dim_, Dtype(1), Dtype(1), bias_data, - bias_multiplier_.gpu_data(), Dtype(1), top_data); - top_data += dim_; - } + BiasForward // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + count, bottom_data, bias_data, bias_dim_, inner_dim_, top_data); } template From debf245f935f61389a45ce782232eeac5bc99d2a Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Wed, 30 Dec 2015 00:04:15 -0800 Subject: [PATCH 6/7] ScalarLayer bias_term option --- include/caffe/layers/scalar_layer.hpp | 7 +++++ src/caffe/layers/scalar_layer.cpp | 34 ++++++++++++++++++++++- src/caffe/layers/scalar_layer.cu | 7 +++++ src/caffe/proto/caffe.proto | 5 ++++ src/caffe/test/test_scalar_layer.cpp | 40 +++++++++++++++++++++++++++ 5 files changed, 92 insertions(+), 1 deletion(-) diff --git a/include/caffe/layers/scalar_layer.hpp b/include/caffe/layers/scalar_layer.hpp index f679622dde4..6c1853a5d92 100644 --- a/include/caffe/layers/scalar_layer.hpp +++ b/include/caffe/layers/scalar_layer.hpp @@ -7,6 +7,8 @@ #include "caffe/layer.hpp" #include "caffe/proto/caffe.pb.h" +#include "caffe/layers/bias_layer.hpp" + namespace caffe { /** @@ -63,6 +65,11 @@ class ScalarLayer: public Layer { virtual void Backward_gpu(const vector*>& top, const vector& 
propagate_down, const vector*>& bottom); + shared_ptr > bias_layer_; + vector*> bias_bottom_vec_; + vector bias_propagate_down_; + int bias_param_id_; + Blob sum_multiplier_; Blob sum_result_; Blob temp_; diff --git a/src/caffe/layers/scalar_layer.cpp b/src/caffe/layers/scalar_layer.cpp index 58f5b9b00e4..d46b6cd64cd 100644 --- a/src/caffe/layers/scalar_layer.cpp +++ b/src/caffe/layers/scalar_layer.cpp @@ -2,6 +2,7 @@ #include #include "caffe/filler.hpp" +#include "caffe/layer_factory.hpp" #include "caffe/layers/scalar_layer.hpp" #include "caffe/util/math_functions.hpp" @@ -10,11 +11,11 @@ namespace caffe { template void ScalarLayer::LayerSetUp(const vector*>& bottom, const vector*>& top) { + const ScalarParameter& param = this->layer_param_.scalar_param(); if (bottom.size() == 1 && this->blobs_.size() > 0) { LOG(INFO) << "Skipping parameter initialization"; } else if (bottom.size() == 1) { // scalar is a learned parameter; initialize it - const ScalarParameter& param = this->layer_param_.scalar_param(); axis_ = bottom[0]->CanonicalAxisIndex(param.axis()); const int num_axes = param.num_axes(); CHECK_GE(num_axes, -1) << "num_axes must be non-negative, " @@ -40,6 +41,26 @@ void ScalarLayer::LayerSetUp(const vector*>& bottom, shared_ptr > filler(GetFiller(filler_param)); filler->Fill(this->blobs_[0].get()); } + if (param.bias_term()) { + LayerParameter layer_param(this->layer_param_); + layer_param.set_type("Bias"); + BiasParameter* bias_param = layer_param.mutable_bias_param(); + bias_param->set_axis(param.axis()); + if (bottom.size() > 1) { + bias_param->set_num_axes(bottom[1]->num_axes()); + } else { + bias_param->set_num_axes(param.num_axes()); + } + bias_param->mutable_filler()->CopyFrom(param.bias_filler()); + bias_layer_ = LayerRegistry::CreateLayer(layer_param); + bias_bottom_vec_.resize(1); + bias_bottom_vec_[0] = bottom[0]; + bias_layer_->SetUp(bias_bottom_vec_, top); + bias_param_id_ = this->blobs_.size(); + this->blobs_.resize(bias_param_id_ + 1); + 
this->blobs_[bias_param_id_] = bias_layer_->blobs()[0]; + bias_propagate_down_.resize(1, false); + } this->param_propagate_down_.resize(this->blobs_.size(), true); } @@ -78,6 +99,10 @@ void ScalarLayer::Reshape(const vector*>& bottom, if (sum_multiplier_.cpu_data()[sum_mult_size - 1] != Dtype(1)) { caffe_set(sum_mult_size, Dtype(1), sum_multiplier_.mutable_cpu_data()); } + if (bias_layer_) { + bias_bottom_vec_[0] = top[0]; + bias_layer_->Reshape(bias_bottom_vec_, top); + } } template @@ -103,11 +128,18 @@ void ScalarLayer::Forward_cpu( top_data += inner_dim_; } } + if (bias_layer_) { + bias_layer_->Forward(bias_bottom_vec_, top); + } } template void ScalarLayer::Backward_cpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom) { + if (bias_layer_ && + this->param_propagate_down_[this->param_propagate_down_.size() - 1]) { + bias_layer_->Backward(top, bias_propagate_down_, bias_bottom_vec_); + } const bool scalar_param = (bottom.size() == 1); Blob* scalar = scalar_param ? this->blobs_[0].get() : bottom[1]; if ((!scalar_param && propagate_down[1]) || diff --git a/src/caffe/layers/scalar_layer.cu b/src/caffe/layers/scalar_layer.cu index 9c6932723af..4f4bd16f4bd 100644 --- a/src/caffe/layers/scalar_layer.cu +++ b/src/caffe/layers/scalar_layer.cu @@ -35,11 +35,18 @@ void ScalarLayer::Forward_gpu( ScalarForward // NOLINT_NEXT_LINE(whitespace/operators) <<>>( count, bottom_data, scalar_data, scalar_dim_, inner_dim_, top_data); + if (bias_layer_) { + bias_layer_->Forward(bias_bottom_vec_, top); + } } template void ScalarLayer::Backward_gpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom) { + if (bias_layer_ && + this->param_propagate_down_[this->param_propagate_down_.size() - 1]) { + bias_layer_->Backward(top, bias_propagate_down_, bias_bottom_vec_); + } const bool scalar_param = (bottom.size() == 1); Blob* scalar = scalar_param ? 
this->blobs_[0].get() : bottom[1]; if ((!scalar_param && propagate_down[1]) || diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index e0029a15dd9..4a28034ed6c 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -1015,6 +1015,11 @@ message ScalarParameter { // Default is the unit (1) initialization, resulting in the ScalarLayer // initially performing the identity operation. optional FillerParameter filler = 3; + + // Whether to also learn a bias (equivalent to a ScalarLayer+BiasLayer, but + // may be more efficient). Initialized with bias_filler (defaults to 0). + optional bool bias_term = 4 [default = false]; + optional FillerParameter bias_filler = 5; } message SigmoidParameter { diff --git a/src/caffe/test/test_scalar_layer.cpp b/src/caffe/test/test_scalar_layer.cpp index 399d54a395e..8eb1554348c 100644 --- a/src/caffe/test/test_scalar_layer.cpp +++ b/src/caffe/test/test_scalar_layer.cpp @@ -310,6 +310,33 @@ TYPED_TEST(ScalarLayerTest, TestForwardBroadcastMiddleWithParam) { } } +TYPED_TEST(ScalarLayerTest, TestForwardBroadcastMiddleWithParamAndBias) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ScalarParameter* scalar_param = layer_param.mutable_scalar_param(); + scalar_param->set_axis(1); + scalar_param->set_num_axes(2); + scalar_param->mutable_filler()->set_type("gaussian"); + scalar_param->set_bias_term(true); + scalar_param->mutable_bias_filler()->set_type("gaussian"); + shared_ptr > layer(new ScalarLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + ASSERT_EQ(this->blob_bottom_->shape(), this->blob_top_->shape()); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int n = 0; n < this->blob_bottom_->num(); ++n) { + for (int c = 0; c < this->blob_bottom_->channels(); ++c) { + for (int h = 0; h < this->blob_bottom_->height(); ++h) { + for (int w = 0; w < this->blob_bottom_->width(); ++w) { + EXPECT_NEAR(this->blob_top_->data_at(n, 
c, h, w), + this->blob_bottom_->data_at(n, c, h, w) * + layer->blobs()[0]->data_at(c, h, 0, 0) + + layer->blobs()[1]->data_at(c, h, 0, 0), 1e-5); + } + } + } + } +} + TYPED_TEST(ScalarLayerTest, TestForwardBroadcastEnd) { typedef typename TypeParam::Dtype Dtype; this->blob_bottom_vec_.push_back(this->blob_bottom_broadcast_2_); @@ -447,6 +474,19 @@ TYPED_TEST(ScalarLayerTest, TestGradientScalar) { this->blob_top_vec_); } +TYPED_TEST(ScalarLayerTest, TestGradientScalarAndBias) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_scalar_); + LayerParameter layer_param; + ScalarParameter* scalar_param = layer_param.mutable_scalar_param(); + scalar_param->set_bias_term(true); + scalar_param->mutable_bias_filler()->set_type("gaussian"); + ScalarLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + TYPED_TEST(ScalarLayerTest, TestGradientScalarAxis2) { typedef typename TypeParam::Dtype Dtype; this->blob_bottom_vec_.push_back(this->blob_bottom_scalar_); From fd9f9ba14bcf9c3af3e69d68652b38c9258a1b2b Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Wed, 30 Dec 2015 00:12:08 -0800 Subject: [PATCH 7/7] ScalarLayer with bias single GPU kernel in Forward --- src/caffe/layers/scalar_layer.cu | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/src/caffe/layers/scalar_layer.cu b/src/caffe/layers/scalar_layer.cu index 4f4bd16f4bd..105634e2564 100644 --- a/src/caffe/layers/scalar_layer.cu +++ b/src/caffe/layers/scalar_layer.cu @@ -16,6 +16,16 @@ __global__ void ScalarForward(const int n, const Dtype* in, } } +template +__global__ void ScalarBiasForward(const int n, const Dtype* in, + const Dtype* scalar, const Dtype* bias, + const int scalar_dim, const int inner_dim, Dtype* out) { + CUDA_KERNEL_LOOP(index, n) { + const int scalar_index = (index / inner_dim) % scalar_dim; + out[index] = in[index] 
* scalar[scalar_index] + bias[scalar_index]; + } +} + template void ScalarLayer::Forward_gpu( const vector*>& bottom, const vector*>& top) { @@ -32,11 +42,16 @@ void ScalarLayer::Forward_gpu( const Dtype* scalar_data = ((bottom.size() > 1) ? bottom[1] : this->blobs_[0].get())->gpu_data(); Dtype* top_data = top[0]->mutable_gpu_data(); - ScalarForward // NOLINT_NEXT_LINE(whitespace/operators) - <<>>( - count, bottom_data, scalar_data, scalar_dim_, inner_dim_, top_data); if (bias_layer_) { - bias_layer_->Forward(bias_bottom_vec_, top); + const Dtype* bias_data = this->blobs_[bias_param_id_]->gpu_data(); + ScalarBiasForward // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + count, bottom_data, scalar_data, bias_data, scalar_dim_, inner_dim_, + top_data); + } else { + ScalarForward // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + count, bottom_data, scalar_data, scalar_dim_, inner_dim_, top_data); } }