diff --git a/examples/cifar10/cifar10_full_sigmoid_train_test_bn.prototxt b/examples/cifar10/cifar10_full_sigmoid_train_test_bn.prototxt index 85c2dffe3f6..9ccd7ad862d 100644 --- a/examples/cifar10/cifar10_full_sigmoid_train_test_bn.prototxt +++ b/examples/cifar10/cifar10_full_sigmoid_train_test_bn.prototxt @@ -12,7 +12,7 @@ layer { } data_param { source: "examples/cifar10/cifar10_train_lmdb" - batch_size: 111 + batch_size: 100 backend: LMDB } } @@ -75,26 +75,43 @@ layer { type: "BatchNorm" bottom: "pool1" top: "bn1" - bn_param { - scale_filler { - type: "constant" - value: 1 - } - shift_filler { - type: "constant" - value: 0.001 - } + batch_norm_param { + use_global_stats: false } param { - lr_mult: 1.00001 - decay_mult: 0 + lr_mult: 0 } param { - lr_mult: 1.00001 - decay_mult: 0 + lr_mult: 0 + } + param { + lr_mult: 0 + } + include { + phase: TRAIN + } +} +layer { + name: "bn1" + type: "BatchNorm" + bottom: "pool1" + top: "bn1" + batch_norm_param { + use_global_stats: true + } + param { + lr_mult: 0 + } + param { + lr_mult: 0 + } + param { + lr_mult: 0 + } + include { + phase: TEST } } - layer { name: "Sigmoid1" type: "Sigmoid" @@ -129,29 +146,47 @@ layer { } +layer { + name: "bn2" + type: "BatchNorm" + bottom: "conv2" + top: "bn2" + batch_norm_param { + use_global_stats: false + } + param { + lr_mult: 0 + } + param { + lr_mult: 0 + } + param { + lr_mult: 0 + } + include { + phase: TRAIN + } +} layer { name: "bn2" type: "BatchNorm" bottom: "conv2" top: "bn2" - bn_param { - scale_filler { - type: "constant" - value: 1 - } - shift_filler { - type: "constant" - value: 0.001 - } + batch_norm_param { + use_global_stats: true } param { - lr_mult: 1.00001 - decay_mult: 0 + lr_mult: 0 } param { - lr_mult: 1.00001 - decay_mult: 0 + lr_mult: 0 + } + param { + lr_mult: 0 + } + include { + phase: TEST } } layer { @@ -204,23 +239,41 @@ layer { type: "BatchNorm" bottom: "conv3" top: "bn3" - bn_param { - scale_filler { - type: "constant" - value: 1 - } - shift_filler { - type: "constant" - value: 0.001 - } + batch_norm_param { + use_global_stats: false } param { - lr_mult: 1.00001 - decay_mult: 0 + lr_mult: 0 } param { - lr_mult: 1.00001 - decay_mult: 0 + lr_mult: 0 + } + param { + lr_mult: 0 + } + include { + phase: TRAIN + } +} +layer { + name: "bn3" + type: "BatchNorm" + bottom: "conv3" + top: "bn3" + batch_norm_param { + use_global_stats: true + } + param { + lr_mult: 0 + } + param { + lr_mult: 0 + } + param { + lr_mult: 0 + } + include { + phase: TEST } } layer { diff --git a/include/caffe/common_layers.hpp b/include/caffe/common_layers.hpp index 09605db9a53..da38f1227ba 100644 --- a/include/caffe/common_layers.hpp +++ b/include/caffe/common_layers.hpp @@ -79,9 +79,35 @@ class ArgMaxLayer : public Layer { }; /** -* @brief Batch Normalization per-channel with scale & shift linear transform. -* -*/ + * @brief Normalizes the input to have 0-mean and/or unit (1) variance across + * the batch. + * + * This layer computes Batch Normalization described in [1]. For + * each channel in the data (i.e. axis 1), it subtracts the mean and divides + * by the variance, where both statistics are computed across both spatial + * dimensions and across the different examples in the batch. + * + * By default, during training time, the network is computing global mean/ + * variance statistics via a running average, which is then used at test + * time to allow deterministic outputs for each input. You can manually + * toggle whether the network is accumulating or using the statistics via the + * use_global_stats option. IMPORTANT: for this feature to work, you MUST + * set the learning rate to zero for all three parameter blobs, i.e., + * param {lr_mult: 0} three times in the layer definition. + * + * Note that the original paper also included a per-channel learned bias and + * scaling factor. It is possible (though a bit cumbersome) to implement + * this in caffe using a single-channel DummyDataLayer filled with zeros, + * followed by a Convolution layer with output the same size as the current. + * This produces a channel-specific value that can be added or multiplied by + * the BatchNorm layer's output. + * + * [1] S. Ioffe and C. Szegedy, "Batch Normalization: Accelerating Deep Network + * Training by Reducing Internal Covariate Shift." arXiv preprint + * arXiv:1502.03167 (2015). + * + * TODO(dox): thorough documentation for Forward, Backward, and proto params. + */ template class BatchNormLayer : public Layer { public: @@ -89,11 +115,10 @@ class BatchNormLayer : public Layer { : Layer(param) {} virtual void LayerSetUp(const vector*>& bottom, const vector*>& top); - virtual void Reshape(const vector*>& bottom, const vector*>& top); - virtual inline const char* type() const { return "BN"; } + virtual inline const char* type() const { return "BatchNorm"; } virtual inline int ExactNumBottomBlobs() const { return 1; } virtual inline int ExactNumTopBlobs() const { return 1; } @@ -105,26 +130,19 @@ class BatchNormLayer : public Layer { virtual void Backward_cpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom); virtual void Backward_gpu(const vector*>& top, - const vector& propagate_down, const vector*>& bottom); - - // spatial mean & variance - Blob spatial_mean_, spatial_variance_; - // batch mean & variance - Blob batch_mean_, batch_variance_; - // buffer blob - Blob buffer_blob_; + const vector& propagate_down, const vector*>& bottom); - Blob x_norm_; - // x_sum_multiplier is used to carry out sum using BLAS - Blob spatial_sum_multiplier_, batch_sum_multiplier_; + Blob mean_, variance_, temp_, x_norm_; + bool use_global_stats_; + Dtype moving_average_fraction_; + int channels_; + Dtype eps_; - // dimension - int N_; - int C_; - int H_; - int W_; - // eps - Dtype var_eps_; + // extra temporarary variables is used to carry out sums/broadcasting + // using BLAS + Blob batch_sum_multiplier_; + Blob num_by_chans_; + Blob spatial_sum_multiplier_; }; /** diff --git a/src/caffe/layers/batch_norm_layer.cpp b/src/caffe/layers/batch_norm_layer.cpp index 8dea34932f3..b0484ff1fcd 100644 --- a/src/caffe/layers/batch_norm_layer.cpp +++ b/src/caffe/layers/batch_norm_layer.cpp @@ -2,350 +2,229 @@ #include #include "caffe/common_layers.hpp" -#include "caffe/filler.hpp" #include "caffe/layer.hpp" #include "caffe/util/math_functions.hpp" namespace caffe { - template - void BatchNormLayer::Reshape(const vector*>& bottom, - const vector*>& top) { - top[0]->Reshape(bottom[0]->num(), bottom[0]->channels(), - bottom[0]->height(), bottom[0]->width()); - - x_norm_.Reshape(bottom[0]->num(), bottom[0]->channels(), - bottom[0]->height(), bottom[0]->width()); - - // Figure out the dimensions - N_ = bottom[0]->num(); - C_ = bottom[0]->channels(); - H_ = bottom[0]->height(); - W_ = bottom[0]->width(); - // mean - spatial_mean_.Reshape(N_, C_, 1, 1); - batch_mean_.Reshape(1, C_, 1, 1); - // variance - spatial_variance_.Reshape(N_, C_, 1, 1); - batch_variance_.Reshape(1, C_, 1, 1); - // buffer blod - buffer_blob_.Reshape(N_, C_, H_, W_); - - // fill spatial multiplier - spatial_sum_multiplier_.Reshape(1, 1, H_, W_); - Dtype* spatial_multipl_data = spatial_sum_multiplier_.mutable_cpu_data(); - caffe_set(spatial_sum_multiplier_.count(), Dtype(1), - spatial_multipl_data); - caffe_set(spatial_sum_multiplier_.count(), Dtype(0), - spatial_sum_multiplier_.mutable_cpu_diff()); - // fill batch multiplier - batch_sum_multiplier_.Reshape(N_, 1, 1, 1); - Dtype* batch_multiplier_data = batch_sum_multiplier_.mutable_cpu_data(); - caffe_set(batch_sum_multiplier_.count(), Dtype(1), - batch_multiplier_data); - caffe_set(batch_sum_multiplier_.count(), Dtype(0), - batch_sum_multiplier_.mutable_cpu_diff()); - this->param_propagate_down_.resize(this->blobs_.size(), true); - } - template - void BatchNormLayer::LayerSetUp(const vector*>& bottom, +template +void BatchNormLayer::LayerSetUp(const vector*>& bottom, const vector*>& top) { - CHECK_NE(top[0], bottom[0]) << this->type() << " Layer does not " - "allow in-place computation."; - - top[0]->Reshape(bottom[0]->num(), bottom[0]->channels(), - bottom[0]->height(), bottom[0]->width()); - - x_norm_.Reshape(bottom[0]->num(), bottom[0]->channels(), - bottom[0]->height(), bottom[0]->width()); - // Figure out the dimensions - N_ = bottom[0]->num(); - C_ = bottom[0]->channels(); - H_ = bottom[0]->height(); - W_ = bottom[0]->width(); - var_eps_ = 1e-9; - - // mean - spatial_mean_.Reshape(N_, C_, 1, 1); - batch_mean_.Reshape(1, C_, 1, 1); - // variance - spatial_variance_.Reshape(N_, C_, 1, 1); - batch_variance_.Reshape(1, C_, 1, 1); - // buffer blod - buffer_blob_.Reshape(N_, C_, H_, W_); - - // fill spatial multiplier - spatial_sum_multiplier_.Reshape(1, 1, H_, W_); - Dtype* spatial_multipl_data = spatial_sum_multiplier_.mutable_cpu_data(); - caffe_set(spatial_sum_multiplier_.count(), Dtype(1), - spatial_multipl_data); - caffe_set(spatial_sum_multiplier_.count(), Dtype(0), - spatial_sum_multiplier_.mutable_cpu_diff()); - - // fill batch multiplier - batch_sum_multiplier_.Reshape(N_, 1, 1, 1); - Dtype* batch_multiplier_data = batch_sum_multiplier_.mutable_cpu_data(); - caffe_set(batch_sum_multiplier_.count(), Dtype(1), - batch_multiplier_data); - caffe_set(batch_sum_multiplier_.count(), Dtype(0), - batch_sum_multiplier_.mutable_cpu_diff()); - - // Check if we need to set up the weights - if (this->blobs_.size() > 0) { - LOG(INFO) << "Skipping parameter initialization"; - } else { - this->blobs_.resize(2); - - // fill scale with scale_filler - this->blobs_[0].reset(new Blob(1, C_, 1, 1)); - caffe_set(this->blobs_[0]->count(), Dtype(1), - this->blobs_[0]->mutable_cpu_data()); - - // fill shift with shift_filler - this->blobs_[1].reset(new Blob(1, C_, 1, 1)); - caffe_set(this->blobs_[1]->count(), Dtype(0), - this->blobs_[1]->mutable_cpu_data()); - } // parameter initialization - this->param_propagate_down_.resize(this->blobs_.size(), true); + BatchNormParameter param = this->layer_param_.batch_norm_param(); + moving_average_fraction_ = param.moving_average_fraction(); + use_global_stats_ = this->phase_ == TEST; + if (param.has_use_global_stats()) + use_global_stats_ = param.use_global_stats(); + if (bottom[0]->num_axes() == 1) + channels_ = 1; + else + channels_ = bottom[0]->channels(); + eps_ = param.eps(); + if (this->blobs_.size() > 0) { + LOG(INFO) << "Skipping parameter initialization"; + } else { + this->blobs_.resize(3); + vector sz; + sz.push_back(channels_); + this->blobs_[0].reset(new Blob(sz)); + this->blobs_[1].reset(new Blob(sz)); + sz[0]=1; + this->blobs_[2].reset(new Blob(sz)); + for (int i = 0; i < 3; ++i) { + caffe_set(this->blobs_[i]->count(), Dtype(0), + this->blobs_[i]->mutable_cpu_data()); + } } +} - template - void BatchNormLayer::Forward_cpu(const vector*>& bottom, +template +void BatchNormLayer::Reshape(const vector*>& bottom, const vector*>& top) { - const Dtype* bottom_data = bottom[0]->cpu_data(); - Dtype* top_data = top[0]->mutable_cpu_data(); - const Dtype* const_top_data = top[0]->cpu_data(); - - const Dtype* scale_data = this->blobs_[0]->cpu_data(); - const Dtype* shift_data = this->blobs_[1]->cpu_data(); - - // put the squares of bottom into buffer_blob_ - caffe_powx(bottom[0]->count(), bottom_data, Dtype(2), - buffer_blob_.mutable_cpu_data()); + if (bottom[0]->num_axes() >= 1) + CHECK_EQ(bottom[0]->shape(1), channels_); + top[0]->ReshapeLike(*bottom[0]); + + vector sz; + sz.push_back(channels_); + mean_.Reshape(sz); + variance_.Reshape(sz); + temp_.ReshapeLike(*bottom[0]); + x_norm_.ReshapeLike(*bottom[0]); + sz[0]=bottom[0]->shape(0); + batch_sum_multiplier_.Reshape(sz); + + int spatial_dim = bottom[0]->count()/(channels_*bottom[0]->shape(0)); + if (spatial_sum_multiplier_.num_axes() == 0 || + spatial_sum_multiplier_.shape(0) != spatial_dim) { + sz[0] = spatial_dim; + spatial_sum_multiplier_.Reshape(sz); + Dtype* multiplier_data = spatial_sum_multiplier_.mutable_cpu_data(); + caffe_set(spatial_sum_multiplier_.count(), Dtype(1), multiplier_data); + } + int numbychans = channels_*bottom[0]->shape(0); + if (num_by_chans_.num_axes() == 0 || + num_by_chans_.shape(0) != numbychans) { + sz[0] = numbychans; + num_by_chans_.Reshape(sz); + caffe_set(batch_sum_multiplier_.count(), Dtype(1), + batch_sum_multiplier_.mutable_cpu_data()); + } +} + +template +void BatchNormLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + Dtype* top_data = top[0]->mutable_cpu_data(); + int num = bottom[0]->shape(0); + int spatial_dim = bottom[0]->height() * bottom[0]->width(); + + // elementwise square + caffe_powx(bottom[0]->count(), bottom_data, Dtype(2), + temp_.mutable_cpu_data()); + + if (use_global_stats_) { + // use the stored mean/variance estimates. TODO(cdoersch): allow an option + // to use an unbiased variance estimate, like the paper does. + const Dtype scale_factor = 1 / this->blobs_[2]->cpu_data()[0]; + caffe_cpu_scale(variance_.count(), scale_factor, + this->blobs_[0]->cpu_data(), mean_.mutable_cpu_data()); + caffe_cpu_scale(variance_.count(), scale_factor, + this->blobs_[1]->cpu_data(), variance_.mutable_cpu_data()); + } else { // computes variance using var(X) = E(X^2) - (EX)^2 - // EX across spatial - caffe_cpu_gemv(CblasNoTrans, N_ * C_, H_ * W_, - Dtype(1. / (H_ * W_)), bottom_data, - spatial_sum_multiplier_.cpu_data(), Dtype(0), - spatial_mean_.mutable_cpu_data()); - // EX across batch - caffe_cpu_gemv(CblasTrans, N_, C_, Dtype(1. / N_), - spatial_mean_.cpu_data(), - batch_sum_multiplier_.cpu_data(), Dtype(0), - batch_mean_.mutable_cpu_data()); - - // E(X^2) across spatial - caffe_cpu_gemv(CblasNoTrans, N_ * C_, H_ * W_, - Dtype(1. / (H_ * W_)), buffer_blob_.cpu_data(), - spatial_sum_multiplier_.cpu_data(), Dtype(0), - spatial_variance_.mutable_cpu_data()); - // E(X^2) across batch - caffe_cpu_gemv(CblasTrans, N_, C_, Dtype(1. / N_), - spatial_variance_.cpu_data(), - batch_sum_multiplier_.cpu_data(), Dtype(0), - batch_variance_.mutable_cpu_data()); - - caffe_powx(batch_mean_.count(), batch_mean_.cpu_data(), Dtype(2), - buffer_blob_.mutable_cpu_data()); // (EX)^2 - caffe_sub(batch_mean_.count(), batch_variance_.cpu_data(), - buffer_blob_.cpu_data(), - batch_variance_.mutable_cpu_data()); // variance - - // do mean and variance normalization - // subtract mean - caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, N_, - C_, 1, Dtype(1), - batch_sum_multiplier_.cpu_data(), - batch_mean_.cpu_data(), Dtype(0), - spatial_mean_.mutable_cpu_data()); - - caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, N_ * C_, - H_ * W_, 1, Dtype(-1), - spatial_mean_.cpu_data(), - spatial_sum_multiplier_.cpu_data(), Dtype(0), - buffer_blob_.mutable_cpu_data()); - - caffe_add(buffer_blob_.count(), bottom_data, - buffer_blob_.cpu_data(), top_data); - - // normalize variance - caffe_add_scalar(batch_variance_.count(), var_eps_, - batch_variance_.mutable_cpu_data()); - caffe_powx(batch_variance_.count(), - batch_variance_.cpu_data(), Dtype(0.5), - batch_variance_.mutable_cpu_data()); - - caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, N_, - C_, 1, Dtype(1), - batch_sum_multiplier_.cpu_data(), - batch_variance_.cpu_data(), Dtype(0), - spatial_variance_.mutable_cpu_data()); - caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, - N_ * C_, H_ * W_, 1, Dtype(1), - spatial_variance_.cpu_data(), - spatial_sum_multiplier_.cpu_data(), Dtype(0), - buffer_blob_.mutable_cpu_data()); - - caffe_div(buffer_blob_.count(), const_top_data, - buffer_blob_.cpu_data(), top_data); - - // Saving x_norm - caffe_copy(buffer_blob_.count(), const_top_data, - x_norm_.mutable_cpu_data()); - // scale - caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1), - batch_sum_multiplier_.cpu_data(), scale_data, Dtype(0), - spatial_variance_.mutable_cpu_data()); - caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, N_ * C_, - H_ * W_, 1, Dtype(1), - spatial_variance_.cpu_data(), - spatial_sum_multiplier_.cpu_data(), Dtype(0), - buffer_blob_.mutable_cpu_data()); - caffe_mul(buffer_blob_.count(), top_data, - buffer_blob_.cpu_data(), top_data); - - // shift - caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1), - batch_sum_multiplier_.cpu_data(), shift_data, Dtype(0), - spatial_mean_.mutable_cpu_data()); - caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, - N_ * C_, H_ * W_, 1, Dtype(1), - spatial_mean_.cpu_data(), - spatial_sum_multiplier_.cpu_data(), Dtype(0), - buffer_blob_.mutable_cpu_data()); - caffe_add(buffer_blob_.count(), const_top_data, - buffer_blob_.cpu_data(), top_data); + caffe_cpu_gemv(CblasNoTrans, channels_ * num, spatial_dim, + 1. / (num * spatial_dim), bottom_data, + spatial_sum_multiplier_.cpu_data(), 0., + num_by_chans_.mutable_cpu_data()); + caffe_cpu_gemv(CblasTrans, num, channels_, 1., + num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0., + mean_.mutable_cpu_data()); + caffe_cpu_gemv(CblasNoTrans, channels_ * num, spatial_dim, + 1. / (num * spatial_dim), temp_.cpu_data(), + spatial_sum_multiplier_.cpu_data(), 0., + num_by_chans_.mutable_cpu_data()); + caffe_cpu_gemv(CblasTrans, num, channels_, 1., + num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0., + variance_.mutable_cpu_data()); + this->blobs_[2]->mutable_cpu_data()[0] *= moving_average_fraction_; + this->blobs_[2]->mutable_cpu_data()[0] += 1; + caffe_cpu_axpby(mean_.count(), Dtype(1), mean_.cpu_data(), + moving_average_fraction_, this->blobs_[0]->mutable_cpu_data()); + Dtype m = Dtype(bottom[0]->count()/channels_); + caffe_cpu_axpby(variance_.count(), m/(m-1), variance_.cpu_data(), + moving_average_fraction_, this->blobs_[1]->mutable_cpu_data()); } + // elementwise square of mean + caffe_powx(mean_.count(), mean_.cpu_data(), Dtype(2), + temp_.mutable_cpu_data()); - template - void BatchNormLayer::Backward_cpu(const vector*>& top, - const vector& propagate_down, - const vector*>& bottom) { - const Dtype* top_diff = top[0]->cpu_diff(); - const Dtype* bottom_data = bottom[0]->cpu_data(); - Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); - - Dtype* scale_diff = this->blobs_[0]->mutable_cpu_diff(); - Dtype* shift_diff = this->blobs_[1]->mutable_cpu_diff(); - const Dtype* scale_data = this->blobs_[0]->cpu_data(); - -// Propagate layer to parameters - // gradient w.r.t. scale - caffe_mul(buffer_blob_.count(), x_norm_.cpu_data(), - top_diff, buffer_blob_.mutable_cpu_data()); - // EX across spatial - caffe_cpu_gemv(CblasNoTrans, N_ * C_, - H_ * W_, Dtype(1), buffer_blob_.cpu_data(), - spatial_sum_multiplier_.cpu_data(), Dtype(0), - spatial_variance_.mutable_cpu_diff()); - // EX across batch - caffe_cpu_gemv(CblasTrans, N_, C_, Dtype(1), - spatial_variance_.cpu_diff(), - batch_sum_multiplier_.cpu_data(), Dtype(0), scale_diff); - - // gradient w.r.t. shift - // EX across spatial - caffe_cpu_gemv(CblasNoTrans, N_ * C_, - H_ * W_, Dtype(1), top_diff, - spatial_sum_multiplier_.cpu_data(), - Dtype(0), spatial_mean_.mutable_cpu_diff()); - // EX across batch - caffe_cpu_gemv(CblasTrans, N_, C_, - Dtype(1), spatial_mean_.cpu_diff(), - batch_sum_multiplier_.cpu_data(), - Dtype(0), shift_diff); + caffe_sub(mean_.count(), variance_.cpu_data(), temp_.cpu_data(), + variance_.mutable_cpu_data()); // variance -// Propagate down + // normalize variance + caffe_add_scalar(variance_.count(), eps_, variance_.mutable_cpu_data()); + caffe_powx(variance_.count(), variance_.cpu_data(), Dtype(0.5), + variance_.mutable_cpu_data()); - // put scale * top_diff to buffer_blob_ - caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1), - batch_sum_multiplier_.cpu_data(), scale_data, Dtype(0), - spatial_variance_.mutable_cpu_data()); - caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, N_ * C_, - H_ * W_, 1, Dtype(1), - spatial_variance_.cpu_data(), - spatial_sum_multiplier_.cpu_data(), Dtype(0), - buffer_blob_.mutable_cpu_data()); - caffe_mul(buffer_blob_.count(), top_diff, buffer_blob_.cpu_data(), - buffer_blob_.mutable_cpu_data()); - - // use new top diff for computation - caffe_mul(buffer_blob_.count(), x_norm_.cpu_data(), - buffer_blob_.cpu_data(), bottom_diff); - // EX across spatial - caffe_cpu_gemv(CblasNoTrans, N_ * C_, H_ * W_, - Dtype(1), bottom_diff, - spatial_sum_multiplier_.cpu_data(), Dtype(0), - spatial_mean_.mutable_cpu_data()); - // EX across batch - caffe_cpu_gemv(CblasTrans, N_, C_, Dtype(1), - spatial_mean_.cpu_data(), - batch_sum_multiplier_.cpu_data(), Dtype(0), - batch_mean_.mutable_cpu_data()); - - caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, - N_, C_, 1, Dtype(1), - batch_sum_multiplier_.cpu_data(), - batch_mean_.cpu_data(), Dtype(0), - spatial_mean_.mutable_cpu_data()); - caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, N_ * C_, - H_ * W_, 1, Dtype(1), - spatial_mean_.cpu_data(), - spatial_sum_multiplier_.cpu_data(), Dtype(0), - bottom_diff); - - caffe_mul(buffer_blob_.count(), - x_norm_.cpu_data(), bottom_diff, bottom_diff); - - // EX across spatial - caffe_cpu_gemv(CblasNoTrans, N_ * C_, - H_ * W_, Dtype(1), buffer_blob_.cpu_data(), - spatial_sum_multiplier_.cpu_data(), Dtype(0), - spatial_mean_.mutable_cpu_data()); - // EX across batch - caffe_cpu_gemv(CblasTrans, N_, C_, Dtype(1), - spatial_mean_.cpu_data(), - batch_sum_multiplier_.cpu_data(), Dtype(0), - batch_mean_.mutable_cpu_data()); - - caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, - N_, C_, 1, Dtype(1), - batch_sum_multiplier_.cpu_data(), - batch_mean_.cpu_data(), Dtype(0), - spatial_mean_.mutable_cpu_data()); - caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, - N_ * C_, H_ * W_, 1, Dtype(1), - spatial_mean_.cpu_data(), - spatial_sum_multiplier_.cpu_data(), Dtype(1), bottom_diff); - - caffe_cpu_axpby(buffer_blob_.count(), Dtype(1), - buffer_blob_.cpu_data(), Dtype(-1. / (N_ * H_ * W_)), - bottom_diff); - - // put the squares of bottom into buffer_blob_ -// caffe_powx(buffer_blob_.count(), bottom_data, Dtype(2), -// buffer_blob_.mutable_cpu_data()); + // do mean and variance normalization + if (bottom[0] != top[0]) { + caffe_copy(bottom[0]->count(), bottom_data, top_data); + } + // subtract mean + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1, + batch_sum_multiplier_.cpu_data(), mean_.cpu_data(), 0., + num_by_chans_.mutable_cpu_data()); + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, channels_ * num, + spatial_dim, 1, -1, num_by_chans_.cpu_data(), + spatial_sum_multiplier_.cpu_data(), 1., top_data); + // replicate variance to input size + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1, + batch_sum_multiplier_.cpu_data(), variance_.cpu_data(), 0., + num_by_chans_.mutable_cpu_data()); + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, channels_ * num, + spatial_dim, 1, 1., num_by_chans_.cpu_data(), + spatial_sum_multiplier_.cpu_data(), 0., temp_.mutable_cpu_data()); + caffe_div(temp_.count(), top_data, temp_.cpu_data(), top_data); + // TODO(cdoersch): The caching is only needed because later in-place layers + // might clobber the data. Can we skip this if they won't? + caffe_copy(x_norm_.count(), top_data, + x_norm_.mutable_cpu_data()); +} + +template +void BatchNormLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + CHECK(!use_global_stats_); + const Dtype* top_diff = top[0]->cpu_diff(); + const Dtype* top_data = x_norm_.cpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + int num = bottom[0]->shape()[0]; + int spatial_dim = bottom[0]->height() * bottom[0]->width(); + // if Y = (X-mean(X))/(sqrt(var(X)+eps)), then + // + // dE(Y)/dX = + // (dE/dY - mean(dE/dY) - mean(dE/dY \cdot Y) \cdot Y) + // ./ sqrt(var(X) + eps) + // + // where \cdot and ./ are hadamard product and elementwise division, + // respectively, dE/dY is the top diff, and mean/var/sum are all computed + // along all dimensions except the channels dimension. In the above + // equation, the operations allow for expansion (i.e. broadcast) along all + // dimensions except the channels dimension where required. + + // sum(dE/dY \cdot Y) + caffe_mul(temp_.count(), top_data, top_diff, bottom_diff); + caffe_cpu_gemv(CblasNoTrans, channels_ * num, spatial_dim, 1., + bottom_diff, spatial_sum_multiplier_.cpu_data(), 0., + num_by_chans_.mutable_cpu_data()); + caffe_cpu_gemv(CblasTrans, num, channels_, 1., + num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0., + mean_.mutable_cpu_data()); + + // reshape (broadcast) the above + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1, + batch_sum_multiplier_.cpu_data(), mean_.cpu_data(), 0., + num_by_chans_.mutable_cpu_data()); + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, channels_ * num, + spatial_dim, 1, 1., num_by_chans_.cpu_data(), + spatial_sum_multiplier_.cpu_data(), 0., bottom_diff); + + // sum(dE/dY \cdot Y) \cdot Y + caffe_mul(temp_.count(), top_data, bottom_diff, bottom_diff); + + // sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y + caffe_cpu_gemv(CblasNoTrans, channels_ * num, spatial_dim, 1., + top_diff, spatial_sum_multiplier_.cpu_data(), 0., + num_by_chans_.mutable_cpu_data()); + caffe_cpu_gemv(CblasTrans, num, channels_, 1., + num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0., + mean_.mutable_cpu_data()); + // reshape (broadcast) the above to make + // sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1, + batch_sum_multiplier_.cpu_data(), mean_.cpu_data(), 0., + num_by_chans_.mutable_cpu_data()); + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, num * channels_, + spatial_dim, 1, 1., num_by_chans_.cpu_data(), + spatial_sum_multiplier_.cpu_data(), 1., bottom_diff); + + // dE/dY - mean(dE/dY)-mean(dE/dY \cdot Y) \cdot Y + caffe_cpu_axpby(temp_.count(), Dtype(1), top_diff, + Dtype(-1. / (num * spatial_dim)), bottom_diff); + + // note: temp_ still contains sqrt(var(X)+eps), computed during the forward + // pass. + caffe_div(temp_.count(), bottom_diff, temp_.cpu_data(), bottom_diff); +} - caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, - N_, C_, 1, Dtype(1), - batch_sum_multiplier_.cpu_data(), - batch_variance_.cpu_data(), Dtype(0), - spatial_variance_.mutable_cpu_data()); - caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, - N_ * C_, H_ * W_, 1, Dtype(1), - spatial_variance_.cpu_data(), - spatial_sum_multiplier_.cpu_data(), Dtype(0), - buffer_blob_.mutable_cpu_data()); - caffe_div(buffer_blob_.count(), bottom_diff, - buffer_blob_.cpu_data(), bottom_diff); - } #ifdef CPU_ONLY STUB_GPU(BatchNormLayer); #endif - INSTANTIATE_CLASS(BatchNormLayer); - REGISTER_LAYER_CLASS(BatchNorm); +INSTANTIATE_CLASS(BatchNormLayer); +REGISTER_LAYER_CLASS(BatchNorm); } // namespace caffe - diff --git a/src/caffe/layers/batch_norm_layer.cu b/src/caffe/layers/batch_norm_layer.cu index e87f8c62f43..bdc79b8abf0 100644 --- a/src/caffe/layers/batch_norm_layer.cu +++ b/src/caffe/layers/batch_norm_layer.cu @@ -2,227 +2,160 @@ #include #include "caffe/common_layers.hpp" -#include "caffe/filler.hpp" #include "caffe/layer.hpp" #include "caffe/util/math_functions.hpp" namespace caffe { - template - void BatchNormLayer::Forward_gpu(const vector*>& bottom, - const vector*>& top) { - const Dtype* bottom_data = bottom[0]->gpu_data(); - const Dtype* const_top_data = top[0]->gpu_data(); - Dtype* top_data = top[0]->mutable_gpu_data(); - Dtype* spatial_mean_data = spatial_mean_.mutable_gpu_data(); - Dtype* buffer_data = buffer_blob_.mutable_gpu_data(); - const Dtype* const_buffer_data = buffer_blob_.gpu_data(); - - - // put the squares of bottom into buffer_blob_ - caffe_gpu_powx(bottom[0]->count(), bottom_data, Dtype(2), - buffer_blob_.mutable_gpu_data()); +template +void BatchNormLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + int num = bottom[0]->shape(0); + int spatial_dim = bottom[0]->height() * bottom[0]->width(); + + // elementwise square + caffe_gpu_powx(bottom[0]->count(), bottom_data, Dtype(2), + temp_.mutable_gpu_data()); + + if (use_global_stats_) { + // use the stored mean/variance estimates. TODO(cdoersch): allow an option + // to use an unbiased variance estimate, like the paper does. + const Dtype scale_factor = 1 / this->blobs_[2]->cpu_data()[0]; + caffe_gpu_scale(variance_.count(), scale_factor, + this->blobs_[0]->gpu_data(), mean_.mutable_gpu_data()); + caffe_gpu_scale(variance_.count(), scale_factor, + this->blobs_[1]->gpu_data(), variance_.mutable_gpu_data()); + } else { // computes variance using var(X) = E(X^2) - (EX)^2 - // EX across spatial - caffe_gpu_gemv(CblasNoTrans, N_ * C_, H_ * W_, - Dtype(1. / (H_ * W_)), - bottom_data, spatial_sum_multiplier_.gpu_data(), - Dtype(0), spatial_mean_data); - // EX across batch - caffe_gpu_gemv(CblasTrans, N_, C_, Dtype(1. / N_), - spatial_mean_.gpu_data(), - batch_sum_multiplier_.gpu_data(), Dtype(0), - batch_mean_.mutable_gpu_data()); - - // E(X^2) across spatial - caffe_gpu_gemv(CblasNoTrans, N_ * C_, H_ * W_, - Dtype(1. / (H_ * W_)), buffer_data, - spatial_sum_multiplier_.gpu_data(), Dtype(0), - spatial_variance_.mutable_gpu_data()); - // E(X^2) across batch - caffe_gpu_gemv(CblasTrans, N_, C_, Dtype(1. / N_), - spatial_variance_.gpu_data(), - batch_sum_multiplier_.gpu_data(), Dtype(0), - batch_variance_.mutable_gpu_data()); - - caffe_gpu_powx(batch_mean_.count(), batch_mean_.gpu_data(), - Dtype(2), buffer_blob_.mutable_gpu_data()); // (EX)^2 - caffe_gpu_sub(batch_mean_.count(), batch_variance_.gpu_data(), - buffer_data, batch_variance_.mutable_gpu_data()); // variance - - // do mean and variance normalization - // subtract mean - caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1), - batch_sum_multiplier_.gpu_data(), batch_mean_.gpu_data(), Dtype(0), - spatial_mean_data); - caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, N_ * C_, H_ * W_, - 1, -Dtype(1), - spatial_mean_.gpu_data(), spatial_sum_multiplier_.gpu_data(), Dtype(0), - buffer_blob_.mutable_gpu_data()); - - caffe_gpu_add(buffer_blob_.count(), bottom_data, buffer_data, top_data); - - // normalize variance - caffe_gpu_add_scalar(batch_variance_.count(), var_eps_, - batch_variance_.mutable_gpu_data()); - caffe_gpu_powx(batch_variance_.count(), batch_variance_.gpu_data(), - Dtype(0.5), batch_variance_.mutable_gpu_data()); - - caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1), - batch_sum_multiplier_.gpu_data(), batch_variance_.gpu_data(), Dtype(0), - spatial_variance_.mutable_gpu_data()); - caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, N_ * C_, - H_ * W_, 1, Dtype(1), - spatial_variance_.gpu_data(), spatial_sum_multiplier_.gpu_data(), - Dtype(0), buffer_blob_.mutable_gpu_data()); - - caffe_gpu_div(buffer_blob_.count(), top_data, buffer_data, top_data); - - // Saving x_norm - caffe_copy(top[0]->count(), const_top_data, x_norm_.mutable_gpu_data()); - - // scale - caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1), - batch_sum_multiplier_.gpu_data(), this->blobs_[0]->gpu_data(), - Dtype(0), spatial_variance_.mutable_gpu_data()); - caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, N_ * C_, - H_ * W_, 1, Dtype(1), - spatial_variance_.gpu_data(), spatial_sum_multiplier_.gpu_data(), - Dtype(0), buffer_blob_.mutable_gpu_data()); - - caffe_gpu_mul(buffer_blob_.count(), top_data, buffer_data, top_data); - - // shift - caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1), - batch_sum_multiplier_.gpu_data(), - this->blobs_[1]->gpu_data(), Dtype(0), - spatial_mean_data); - caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, N_ * C_, H_ * W_, 1, - Dtype(1), - spatial_mean_.gpu_data(), spatial_sum_multiplier_.gpu_data(), Dtype(0), - buffer_blob_.mutable_gpu_data()); - caffe_gpu_add(buffer_blob_.count(), top_data, buffer_data, top_data); + caffe_gpu_gemv(CblasNoTrans, channels_ * num, spatial_dim, + 1. / (num * spatial_dim), bottom_data, + spatial_sum_multiplier_.gpu_data(), 0., + num_by_chans_.mutable_gpu_data()); + caffe_gpu_gemv(CblasTrans, num, channels_, 1., + num_by_chans_.gpu_data(), batch_sum_multiplier_.gpu_data(), 0., + mean_.mutable_gpu_data()); + caffe_gpu_gemv(CblasNoTrans, channels_ * num, spatial_dim, + 1. / (num * spatial_dim), temp_.gpu_data(), + spatial_sum_multiplier_.gpu_data(), 0., + num_by_chans_.mutable_gpu_data()); + caffe_gpu_gemv(CblasTrans, num, channels_, 1., + num_by_chans_.gpu_data(), batch_sum_multiplier_.gpu_data(), 0., + variance_.mutable_gpu_data()); + this->blobs_[2]->mutable_cpu_data()[0] *= moving_average_fraction_; + this->blobs_[2]->mutable_cpu_data()[0] += 1; + caffe_gpu_axpby(mean_.count(), Dtype(1), mean_.gpu_data(), + moving_average_fraction_, this->blobs_[0]->mutable_gpu_data()); + Dtype m = Dtype(bottom[0]->count()/channels_); + caffe_gpu_axpby(variance_.count(), m/(m-1), variance_.gpu_data(), + moving_average_fraction_, this->blobs_[1]->mutable_gpu_data()); } + // elementwise square of mean + caffe_gpu_powx(mean_.count(), mean_.gpu_data(), Dtype(2), + temp_.mutable_gpu_data()); + + caffe_gpu_sub(mean_.count(), variance_.gpu_data(), temp_.gpu_data(), + variance_.mutable_gpu_data()); // variance - template - void BatchNormLayer::Backward_gpu(const vector*>& top, - const vector& propagate_down, - const vector*>& bottom) { - const Dtype* top_diff = top[0]->gpu_diff(); - const Dtype* top_data = top[0]->gpu_data(); - const Dtype* bottom_data = bottom[0]->gpu_data(); - Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); - const Dtype* const_bottom_diff = bottom[0]->gpu_diff(); - Dtype* spatial_mean_data = spatial_mean_.mutable_gpu_data(); - Dtype* buffer_data = buffer_blob_.mutable_gpu_data(); - const Dtype* const_buffer_data = buffer_blob_.gpu_data(); - - // Propage to layer params - // gradient w.r.t. scale - caffe_gpu_mul(buffer_blob_.count(), x_norm_.gpu_data(), - top_diff, buffer_blob_.mutable_gpu_data()); - // EX across spatial - caffe_gpu_gemv(CblasNoTrans, N_ * C_, H_ * W_, Dtype(1), - buffer_data, spatial_sum_multiplier_.gpu_data(), Dtype(0), - spatial_variance_.mutable_gpu_data()); - // EX across batch - caffe_gpu_gemv(CblasTrans, N_, C_, Dtype(1), - spatial_variance_.gpu_data(), - batch_sum_multiplier_.gpu_data(), Dtype(0), - this->blobs_[0]->mutable_gpu_diff()); - - // gradient w.r.t. shift - // EX across spatial - caffe_gpu_gemv(CblasNoTrans, N_ * C_, H_ * W_, Dtype(1), - top_diff, spatial_sum_multiplier_.gpu_data(), - Dtype(0), spatial_mean_data); - // EX across batch - caffe_gpu_gemv(CblasTrans, N_, C_, Dtype(1), - spatial_mean_.gpu_data(), - batch_sum_multiplier_.gpu_data(), Dtype(0), - this->blobs_[1]->mutable_gpu_diff()); - - // Propagate down - // scale top diff - caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1), - batch_sum_multiplier_.gpu_data(), this->blobs_[0]->gpu_data(), - Dtype(0), spatial_variance_.mutable_gpu_data()); - caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, N_ * C_, - H_ * W_, 1, Dtype(1), - spatial_variance_.gpu_data(), spatial_sum_multiplier_.gpu_data(), - Dtype(0), - buffer_blob_.mutable_gpu_data()); - caffe_gpu_mul(buffer_blob_.count(), top_diff, buffer_data, - buffer_blob_.mutable_gpu_data()); - - // use new top diff for computation - caffe_gpu_mul(buffer_blob_.count(), x_norm_.gpu_data(), - buffer_data, bottom_diff); - // EX across spatial - caffe_gpu_gemv(CblasNoTrans, N_ * C_, H_ * W_, - Dtype(1), bottom_diff, - spatial_sum_multiplier_.gpu_data(), Dtype(0), spatial_mean_data); - // EX across batch - caffe_gpu_gemv(CblasTrans, N_, C_, Dtype(1), - spatial_mean_.gpu_data(), - batch_sum_multiplier_.gpu_data(), Dtype(0), - batch_mean_.mutable_gpu_data()); - - caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1), - batch_sum_multiplier_.gpu_data(), - batch_mean_.gpu_data(), Dtype(0), - spatial_mean_data); - caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, N_ * C_, - H_ * W_, 1, Dtype(1), spatial_mean_.gpu_data(), - spatial_sum_multiplier_.gpu_data(), Dtype(0), - bottom_diff); - - caffe_gpu_mul(buffer_blob_.count(), x_norm_.gpu_data(), - bottom_diff, bottom_diff); - - // EX across spatial - caffe_gpu_gemv(CblasNoTrans, N_ * C_, H_ * W_, Dtype(1), - buffer_data, spatial_sum_multiplier_.gpu_data(), - Dtype(0), spatial_mean_data); - - // EX across batch - caffe_gpu_gemv(CblasTrans, N_, C_, Dtype(1), - spatial_mean_.gpu_data(), - batch_sum_multiplier_.gpu_data(), Dtype(0), - batch_mean_.mutable_gpu_data()); - - caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, N_, - C_, 1, Dtype(1), - batch_sum_multiplier_.gpu_data(), - batch_mean_.gpu_data(), Dtype(0), - spatial_mean_data); - caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, N_ * C_, - H_ * W_, 1, Dtype(1), - spatial_mean_.gpu_data(), spatial_sum_multiplier_.gpu_data(), - Dtype(1), - bottom_diff); - - caffe_gpu_axpby(buffer_blob_.count(), Dtype(1), buffer_data, - Dtype(-1. / (N_ * H_ * W_)), - bottom_diff); - - // put the squares of bottom into buffer_blob_ -// caffe_gpu_powx(buffer_blob_.count(), bottom_data, Dtype(2), -// buffer_blob_.mutable_gpu_data()); - - caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1), - batch_sum_multiplier_.gpu_data(), batch_variance_.gpu_data(), Dtype(0), - spatial_variance_.mutable_gpu_data()); - caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, N_ * C_, - H_ * W_, 1, Dtype(1), - spatial_variance_.gpu_data(), spatial_sum_multiplier_.gpu_data(), - Dtype(0), - buffer_blob_.mutable_gpu_data()); - - caffe_gpu_div(buffer_blob_.count(), const_bottom_diff, - const_buffer_data, bottom_diff); + // normalize variance + caffe_gpu_add_scalar(variance_.count(), eps_, variance_.mutable_gpu_data()); + caffe_gpu_powx(variance_.count(), variance_.gpu_data(), Dtype(0.5), + variance_.mutable_gpu_data()); + + // do mean and variance normalization + if (bottom[0] != top[0]) { + caffe_copy(bottom[0]->count(), bottom_data, top_data); } + // subtract mean + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1, + batch_sum_multiplier_.gpu_data(), mean_.gpu_data(), 0., + num_by_chans_.mutable_gpu_data()); + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, channels_ * num, + spatial_dim, 1, -1, num_by_chans_.gpu_data(), + spatial_sum_multiplier_.gpu_data(), 1., top_data); + // replicate variance to input size + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1, + batch_sum_multiplier_.gpu_data(), variance_.gpu_data(), 0., + num_by_chans_.mutable_gpu_data()); + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, channels_ * num, + spatial_dim, 1, 1., num_by_chans_.gpu_data(), + spatial_sum_multiplier_.gpu_data(), 0., temp_.mutable_gpu_data()); + caffe_gpu_div(temp_.count(), top_data, temp_.gpu_data(), top_data); + // TODO(cdoersch): The caching is only needed because later in-place layers + // might clobber the data. Can we skip this if they won't? + caffe_copy(x_norm_.count(), top_data, + x_norm_.mutable_gpu_data()); +} + +template +void BatchNormLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + CHECK(!use_global_stats_); + const Dtype* top_diff = top[0]->gpu_diff(); + const Dtype* top_data = x_norm_.gpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + int num = bottom[0]->shape()[0]; + int spatial_dim = bottom[0]->height() * bottom[0]->width(); + // if Y = (X-mean(X))/(sqrt(var(X)+eps)), then + // + // dE(Y)/dX = + // (dE/dY - mean(dE/dY) - mean(dE/dY \cdot Y) \cdot Y) + // ./ sqrt(var(X) + eps) + // + // where \cdot and ./ are hadamard product and elementwise division, + // respectively, dE/dY is the top diff, and mean/var/sum are all computed + // along all dimensions except the channels dimension. In the above + // equation, the operations allow for expansion (i.e. broadcast) along all + // dimensions except the channels dimension where required. + + // sum(dE/dY \cdot Y) + caffe_gpu_mul(temp_.count(), top_data, top_diff, bottom_diff); + caffe_gpu_gemv(CblasNoTrans, channels_ * num, spatial_dim, 1., + bottom_diff, spatial_sum_multiplier_.gpu_data(), 0., + num_by_chans_.mutable_gpu_data()); + caffe_gpu_gemv(CblasTrans, num, channels_, 1., + num_by_chans_.gpu_data(), batch_sum_multiplier_.gpu_data(), 0., + mean_.mutable_gpu_data()); + + // reshape (broadcast) the above + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1, + batch_sum_multiplier_.gpu_data(), mean_.gpu_data(), 0., + num_by_chans_.mutable_gpu_data()); + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, channels_ * num, + spatial_dim, 1, 1., num_by_chans_.gpu_data(), + spatial_sum_multiplier_.gpu_data(), 0., bottom_diff); + + // sum(dE/dY \cdot Y) \cdot Y + caffe_gpu_mul(temp_.count(), top_data, bottom_diff, bottom_diff); + + // sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y + caffe_gpu_gemv(CblasNoTrans, channels_ * num, spatial_dim, 1., + top_diff, spatial_sum_multiplier_.gpu_data(), 0., + num_by_chans_.mutable_gpu_data()); + caffe_gpu_gemv(CblasTrans, num, channels_, 1., + num_by_chans_.gpu_data(), batch_sum_multiplier_.gpu_data(), 0., + mean_.mutable_gpu_data()); + // reshape (broadcast) the above to make + // sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1, + batch_sum_multiplier_.gpu_data(), mean_.gpu_data(), 0., + num_by_chans_.mutable_gpu_data()); + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num * channels_, + spatial_dim, 1, 1., num_by_chans_.gpu_data(), + spatial_sum_multiplier_.gpu_data(), 1., bottom_diff); + + // dE/dY - mean(dE/dY)-mean(dE/dY \cdot Y) \cdot Y + caffe_gpu_axpby(temp_.count(), Dtype(1), top_diff, + Dtype(-1. / (num * spatial_dim)), bottom_diff); + + // note: temp_ still contains sqrt(var(X)+eps), computed during the forward + // pass. + caffe_gpu_div(temp_.count(), bottom_diff, temp_.gpu_data(), bottom_diff); +} + +INSTANTIATE_LAYER_GPU_FUNCS(BatchNormLayer); - INSTANTIATE_LAYER_GPU_FUNCS(BatchNormLayer); -} // namespace caffe +} // namespace caffe diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index a8747c12b37..99dd3c90eef 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -301,7 +301,7 @@ message ParamSpec { // NOTE // Update the next available ID when you add a new LayerParameter field. // -// LayerParameter next available layer-specific ID: 139 (last added: tile_param) +// LayerParameter next available layer-specific ID: 140 (last added: batch_norm_param) message LayerParameter { optional string name = 1; // the layer name optional string type = 2; // the layer type @@ -350,6 +350,7 @@ message LayerParameter { // The default for the engine is set by the ENGINE switch at compile-time. optional AccuracyParameter accuracy_param = 102; optional ArgMaxParameter argmax_param = 103; + optional BatchNormParameter batch_norm_param = 139; optional ConcatParameter concat_param = 104; optional ContrastiveLossParameter contrastive_loss_param = 105; optional ConvolutionParameter convolution_param = 106; @@ -461,6 +462,18 @@ message ConcatParameter { optional uint32 concat_dim = 1 [default = 1]; } +message BatchNormParameter { + // If false, accumulate global mean/variance values via a moving average. If + // true, use those accumulated values instead of computing mean/variance + // across the batch. + optional bool use_global_stats = 1; + // How much does the moving average decay each iteration? + optional float moving_average_fraction = 2 [default = .999]; + // Small value to add to the variance estimate so that we don't divide by + // zero. + optional float eps = 3 [default = 1e-5]; +} + message ContrastiveLossParameter { // margin for dissimilar pair optional float margin = 1 [default = 1.0]; diff --git a/src/caffe/test/test_batch_norm_layer.cpp b/src/caffe/test/test_batch_norm_layer.cpp index 704efd5df3d..97fbb26dc9b 100644 --- a/src/caffe/test/test_batch_norm_layer.cpp +++ b/src/caffe/test/test_batch_norm_layer.cpp @@ -60,7 +60,6 @@ namespace caffe { for ( int k = 0; k < height; ++k ) { for ( int l = 0; l < width; ++l ) { Dtype data = this->blob_top_->data_at(i, j, k, l); - Dtype bottom_data = this->blob_bottom_->data_at(i, j, k, l); sum += data; var += data * data; }