diff --git a/examples/cifar10/cifar10_full_sigmoid_solver.prototxt b/examples/cifar10/cifar10_full_sigmoid_solver.prototxt new file mode 100644 index 00000000000..7dd3ecb9d8e --- /dev/null +++ b/examples/cifar10/cifar10_full_sigmoid_solver.prototxt @@ -0,0 +1,28 @@ +# reduce learning rate after 120 epochs (60000 iters) by factor 0f 10 +# then another factor of 10 after 10 more epochs (5000 iters) + +# The train/test net protocol buffer definition +net: "examples/cifar10/cifar10_full_sigmoid_train_test.prototxt" +# test_iter specifies how many forward passes the test should carry out. +# In the case of CIFAR10, we have test batch size 100 and 100 test iterations, +# covering the full 10,000 testing images. +test_iter: 10 +# Carry out testing every 1000 training iterations. +test_interval: 1000 +# The base learning rate, momentum and the weight decay of the network. +base_lr: 0.001 +momentum: 0.9 +#weight_decay: 0.004 +# The learning rate policy +lr_policy: "step" +gamma: 1 +stepsize: 5000 +# Display every 200 iterations +display: 100 +# The maximum number of iterations +max_iter: 60000 +# snapshot intermediate results +snapshot: 10000 +snapshot_prefix: "examples/cifar10_full_sigmoid" +# solver mode: CPU or GPU +solver_mode: GPU diff --git a/examples/cifar10/cifar10_full_sigmoid_solver_bn.prototxt b/examples/cifar10/cifar10_full_sigmoid_solver_bn.prototxt new file mode 100644 index 00000000000..a57b280fd1e --- /dev/null +++ b/examples/cifar10/cifar10_full_sigmoid_solver_bn.prototxt @@ -0,0 +1,28 @@ +# reduce learning rate after 120 epochs (60000 iters) by factor 0f 10 +# then another factor of 10 after 10 more epochs (5000 iters) + +# The train/test net protocol buffer definition +net: "examples/cifar10/cifar10_full_sigmoid_train_test_bn.prototxt" +# test_iter specifies how many forward passes the test should carry out. +# In the case of CIFAR10, we have test batch size 100 and 100 test iterations, +# covering the full 10,000 testing images. +test_iter: 10 +# Carry out testing every 1000 training iterations. +test_interval: 1000 +# The base learning rate, momentum and the weight decay of the network. +base_lr: 0.001 +momentum: 0.9 +#weight_decay: 0.004 +# The learning rate policy +lr_policy: "step" +gamma: 1 +stepsize: 5000 +# Display every 200 iterations +display: 100 +# The maximum number of iterations +max_iter: 60000 +# snapshot intermediate results +snapshot: 10000 +snapshot_prefix: "examples/cifar10_full_sigmoid_bn" +# solver mode: CPU or GPU +solver_mode: GPU diff --git a/examples/cifar10/cifar10_full_sigmoid_train_test.prototxt b/examples/cifar10/cifar10_full_sigmoid_train_test.prototxt new file mode 100644 index 00000000000..fba69b814ad --- /dev/null +++ b/examples/cifar10/cifar10_full_sigmoid_train_test.prototxt @@ -0,0 +1,212 @@ +name: "CIFAR10_full" +layer { + name: "cifar" + type: "Data" + top: "data" + top: "label" + include { + phase: TRAIN + } + transform_param { + mean_file: "examples/cifar10/mean.binaryproto" + } + data_param { + source: "examples/cifar10/cifar10_train_lmdb" + batch_size: 111 + backend: LMDB + } +} +layer { + name: "cifar" + type: "Data" + top: "data" + top: "label" + include { + phase: TEST + } + transform_param { + mean_file: "examples/cifar10/mean.binaryproto" + } + data_param { + source: "examples/cifar10/cifar10_test_lmdb" + batch_size: 1000 + backend: LMDB + } +} +layer { + name: "conv1" + type: "Convolution" + bottom: "data" + top: "conv1" + param { + lr_mult: 1 + } + param { + lr_mult: 2 + } + convolution_param { + num_output: 32 + pad: 2 + kernel_size: 5 + stride: 1 + weight_filler { + type: "gaussian" + std: 0.0001 + } + bias_filler { + type: "constant" + } + } +} +layer { + name: "pool1" + type: "Pooling" + bottom: "conv1" + top: "pool1" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} + + + +layer { + name: "Sigmoid1" + type: "Sigmoid" + bottom: "pool1" + top: "Sigmoid1" +} + +layer { + name: "conv2" + type: "Convolution" + bottom: "Sigmoid1" + top: "conv2" + param { + lr_mult: 1 + } + param { + lr_mult: 2 + } + convolution_param { + num_output: 32 + pad: 2 + kernel_size: 5 + stride: 1 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + } + } +} + + +layer { + name: "Sigmoid2" + type: "Sigmoid" + bottom: "conv2" + top: "Sigmoid2" +} +layer { + name: "pool2" + type: "Pooling" + bottom: "Sigmoid2" + top: "pool2" + pooling_param { + pool: AVE + kernel_size: 3 + stride: 2 + } +} +layer { + name: "conv3" + type: "Convolution" + bottom: "pool2" + top: "conv3" + convolution_param { + num_output: 64 + pad: 2 + kernel_size: 5 + stride: 1 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + } + } + param { + lr_mult: 1 + } + param { + lr_mult: 1 + } + +} + +layer { + name: "Sigmoid3" + type: "Sigmoid" + bottom: "conv3" + top: "Sigmoid3" +} + +layer { + name: "pool3" + type: "Pooling" + bottom: "Sigmoid3" + top: "pool3" + pooling_param { + pool: AVE + kernel_size: 3 + stride: 2 + } +} + +layer { + name: "ip1" + type: "InnerProduct" + bottom: "pool3" + top: "ip1" + param { + lr_mult: 1 + decay_mult: 0 + } + param { + lr_mult: 2 + decay_mult: 0 + } + inner_product_param { + num_output: 10 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + } + } +} +layer { + name: "accuracy" + type: "Accuracy" + bottom: "ip1" + bottom: "label" + top: "accuracy" + include { + phase: TEST + } +} +layer { + name: "loss" + type: "SoftmaxWithLoss" + bottom: "ip1" + bottom: "label" + top: "loss" +} diff --git a/examples/cifar10/cifar10_full_sigmoid_train_test_bn.prototxt b/examples/cifar10/cifar10_full_sigmoid_train_test_bn.prototxt new file mode 100644 index 00000000000..1a810751177 --- /dev/null +++ b/examples/cifar10/cifar10_full_sigmoid_train_test_bn.prototxt @@ -0,0 +1,240 @@ +name: "CIFAR10_full" +layer { + name: "cifar" + type: "Data" + top: "data" + top: "label" + include { + phase: TRAIN + } + transform_param { + mean_file: "examples/cifar10/mean.binaryproto" + } + data_param { + source: "examples/cifar10/cifar10_train_lmdb" + batch_size: 100 + backend: LMDB + } +} +layer { + name: "cifar" + type: "Data" + top: "data" + top: "label" + include { + phase: TEST + } + transform_param { + mean_file: "examples/cifar10/mean.binaryproto" + } + data_param { + source: "examples/cifar10/cifar10_test_lmdb" + batch_size: 1000 + backend: LMDB + } +} +layer { + name: "conv1" + type: "Convolution" + bottom: "data" + top: "conv1" + param { + lr_mult: 1 + } + convolution_param { + num_output: 32 + pad: 2 + kernel_size: 5 + stride: 1 + bias_term: false + weight_filler { + type: "gaussian" + std: 0.0001 + } + } +} +layer { + name: "pool1" + type: "Pooling" + bottom: "conv1" + top: "pool1" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} + +layer { + name: "bn1" + type: "BatchNorm" + bottom: "pool1" + top: "bn1" + param { + lr_mult: 0 + } + param { + lr_mult: 0 + } + param { + lr_mult: 0 + } +} + +layer { + name: "Sigmoid1" + type: "Sigmoid" + bottom: "bn1" + top: "Sigmoid1" +} + +layer { + name: "conv2" + type: "Convolution" + bottom: "Sigmoid1" + top: "conv2" + param { + lr_mult: 1 + } + convolution_param { + num_output: 32 + pad: 2 + kernel_size: 5 + stride: 1 + bias_term: false + weight_filler { + type: "gaussian" + std: 0.01 + } + } +} + +layer { + name: "bn2" + type: "BatchNorm" + bottom: "conv2" + top: "bn2" + param { + lr_mult: 0 + } + param { + lr_mult: 0 + } + param { + lr_mult: 0 + } +} + +layer { + name: "Sigmoid2" + type: "Sigmoid" + bottom: "bn2" + top: "Sigmoid2" +} +layer { + name: "pool2" + type: "Pooling" + bottom: "Sigmoid2" + top: "pool2" + pooling_param { + pool: AVE + kernel_size: 3 + stride: 2 + } +} +layer { + name: "conv3" + type: "Convolution" + bottom: "pool2" + top: "conv3" + param { + lr_mult: 1 + } + convolution_param { + num_output: 64 + pad: 2 + kernel_size: 5 + stride: 1 + bias_term: false + weight_filler { + type: "gaussian" + std: 0.01 + } + } +} + +layer { + name: "bn3" + type: "BatchNorm" + bottom: "conv3" + top: "bn3" + param { + lr_mult: 0 + } + param { + lr_mult: 0 + } + param { + lr_mult: 0 + } +} + +layer { + name: "Sigmoid3" + type: "Sigmoid" + bottom: "bn3" + top: "Sigmoid3" +} +layer { + name: "pool3" + type: "Pooling" + bottom: "Sigmoid3" + top: "pool3" + pooling_param { + pool: AVE + kernel_size: 3 + stride: 2 + } +} + +layer { + name: "ip1" + type: "InnerProduct" + bottom: "pool3" + top: "ip1" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 1 + decay_mult: 0 + } + inner_product_param { + num_output: 10 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + } + } +} +layer { + name: "accuracy" + type: "Accuracy" + bottom: "ip1" + bottom: "label" + top: "accuracy" + include { + phase: TEST + } +} +layer { + name: "loss" + type: "SoftmaxWithLoss" + bottom: "ip1" + bottom: "label" + top: "loss" +} diff --git a/examples/cifar10/train_full_sigmoid.sh b/examples/cifar10/train_full_sigmoid.sh new file mode 100755 index 00000000000..9cff06d3e34 --- /dev/null +++ b/examples/cifar10/train_full_sigmoid.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env sh + +TOOLS=./build/tools + +$TOOLS/caffe train \ + --solver=examples/cifar10/cifar10_full_sigmoid_solver.prototxt + diff --git a/examples/cifar10/train_full_sigmoid_bn.sh b/examples/cifar10/train_full_sigmoid_bn.sh new file mode 100755 index 00000000000..011387c996e --- /dev/null +++ b/examples/cifar10/train_full_sigmoid_bn.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env sh + +TOOLS=./build/tools + +$TOOLS/caffe train \ + --solver=examples/cifar10/cifar10_full_sigmoid_solver_bn.prototxt + diff --git a/include/caffe/common_layers.hpp b/include/caffe/common_layers.hpp index 21a27d759a8..da38f1227ba 100644 --- a/include/caffe/common_layers.hpp +++ b/include/caffe/common_layers.hpp @@ -78,6 +78,73 @@ class ArgMaxLayer : public Layer { int axis_; }; +/** + * @brief Normalizes the input to have 0-mean and/or unit (1) variance across + * the batch. + * + * This layer computes Batch Normalization described in [1]. For + * each channel in the data (i.e. axis 1), it subtracts the mean and divides + * by the variance, where both statistics are computed across both spatial + * dimensions and across the different examples in the batch. + * + * By default, during training time, the network is computing global mean/ + * variance statistics via a running average, which is then used at test + * time to allow deterministic outputs for each input. You can manually + * toggle whether the network is accumulating or using the statistics via the + * use_global_stats option. IMPORTANT: for this feature to work, you MUST + * set the learning rate to zero for all three parameter blobs, i.e., + * param {lr_mult: 0} three times in the layer definition. + * + * Note that the original paper also included a per-channel learned bias and + * scaling factor. It is possible (though a bit cumbersome) to implement + * this in caffe using a single-channel DummyDataLayer filled with zeros, + * followed by a Convolution layer with output the same size as the current. + * This produces a channel-specific value that can be added or multiplied by + * the BatchNorm layer's output. + * + * [1] S. Ioffe and C. Szegedy, "Batch Normalization: Accelerating Deep Network + * Training by Reducing Internal Covariate Shift." arXiv preprint + * arXiv:1502.03167 (2015). + * + * TODO(dox): thorough documentation for Forward, Backward, and proto params. + */ +template +class BatchNormLayer : public Layer { + public: + explicit BatchNormLayer(const LayerParameter& param) + : Layer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline const char* type() const { return "BatchNorm"; } + virtual inline int ExactNumBottomBlobs() const { return 1; } + virtual inline int ExactNumTopBlobs() const { return 1; } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + Blob mean_, variance_, temp_, x_norm_; + bool use_global_stats_; + Dtype moving_average_fraction_; + int channels_; + Dtype eps_; + + // extra temporarary variables is used to carry out sums/broadcasting + // using BLAS + Blob batch_sum_multiplier_; + Blob num_by_chans_; + Blob spatial_sum_multiplier_; +}; + /** * @brief Index into the input blob along its first axis. * @@ -146,7 +213,6 @@ class BatchReindexLayer : public Layer { const Dtype* ridx_data); }; - /** * @brief Takes at least two Blob%s and concatenates them along either the num * or channel dimension, outputting the result. diff --git a/src/caffe/layers/batch_norm_layer.cpp b/src/caffe/layers/batch_norm_layer.cpp new file mode 100644 index 00000000000..94c2b96b9cd --- /dev/null +++ b/src/caffe/layers/batch_norm_layer.cpp @@ -0,0 +1,236 @@ +#include +#include + +#include "caffe/common_layers.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +template +void BatchNormLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + BatchNormParameter param = this->layer_param_.batch_norm_param(); + moving_average_fraction_ = param.moving_average_fraction(); + use_global_stats_ = this->phase_ == TEST; + if (param.has_use_global_stats()) + use_global_stats_ = param.use_global_stats(); + if (bottom[0]->num_axes() == 1) + channels_ = 1; + else + channels_ = bottom[0]->shape(1); + eps_ = param.eps(); + if (this->blobs_.size() > 0) { + LOG(INFO) << "Skipping parameter initialization"; + } else { + this->blobs_.resize(3); + vector sz; + sz.push_back(channels_); + this->blobs_[0].reset(new Blob(sz)); + this->blobs_[1].reset(new Blob(sz)); + sz[0]=1; + this->blobs_[2].reset(new Blob(sz)); + for (int i = 0; i < 3; ++i) { + caffe_set(this->blobs_[i]->count(), Dtype(0), + this->blobs_[i]->mutable_cpu_data()); + } + } +} + +template +void BatchNormLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + if (bottom[0]->num_axes() >= 1) + CHECK_EQ(bottom[0]->shape(1), channels_); + top[0]->ReshapeLike(*bottom[0]); + + vector sz; + sz.push_back(channels_); + mean_.Reshape(sz); + variance_.Reshape(sz); + temp_.ReshapeLike(*bottom[0]); + x_norm_.ReshapeLike(*bottom[0]); + sz[0]=bottom[0]->shape(0); + batch_sum_multiplier_.Reshape(sz); + + int spatial_dim = bottom[0]->count()/(channels_*bottom[0]->shape(0)); + if (spatial_sum_multiplier_.num_axes() == 0 || + spatial_sum_multiplier_.shape(0) != spatial_dim) { + sz[0] = spatial_dim; + spatial_sum_multiplier_.Reshape(sz); + Dtype* multiplier_data = spatial_sum_multiplier_.mutable_cpu_data(); + caffe_set(spatial_sum_multiplier_.count(), Dtype(1), multiplier_data); + } + + int numbychans = channels_*bottom[0]->shape(0); + if (num_by_chans_.num_axes() == 0 || + num_by_chans_.shape(0) != numbychans) { + sz[0] = numbychans; + num_by_chans_.Reshape(sz); + caffe_set(batch_sum_multiplier_.count(), Dtype(1), + batch_sum_multiplier_.mutable_cpu_data()); + } +} + +template +void BatchNormLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + Dtype* top_data = top[0]->mutable_cpu_data(); + int num = bottom[0]->shape(0); + int spatial_dim = bottom[0]->count()/(bottom[0]->shape(0)*channels_); + + // elementwise square + caffe_powx(bottom[0]->count(), bottom_data, Dtype(2), + temp_.mutable_cpu_data()); + + if (use_global_stats_) { + // use the stored mean/variance estimates. TODO(cdoersch): allow an option + // to use an unbiased variance estimate, like the paper does. + const Dtype scale_factor = 1 / this->blobs_[2]->cpu_data()[0]; + caffe_cpu_scale(variance_.count(), scale_factor, + this->blobs_[0]->cpu_data(), mean_.mutable_cpu_data()); + caffe_cpu_scale(variance_.count(), scale_factor, + this->blobs_[1]->cpu_data(), variance_.mutable_cpu_data()); + } else { + // computes variance using var(X) = E(X^2) - (EX)^2 + caffe_cpu_gemv(CblasNoTrans, channels_ * num, spatial_dim, + 1. / (num * spatial_dim), bottom_data, + spatial_sum_multiplier_.cpu_data(), 0., + num_by_chans_.mutable_cpu_data()); + caffe_cpu_gemv(CblasTrans, num, channels_, 1., + num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0., + mean_.mutable_cpu_data()); + caffe_cpu_gemv(CblasNoTrans, channels_ * num, spatial_dim, + 1. / (num * spatial_dim), temp_.cpu_data(), + spatial_sum_multiplier_.cpu_data(), 0., + num_by_chans_.mutable_cpu_data()); + caffe_cpu_gemv(CblasTrans, num, channels_, 1., + num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0., + variance_.mutable_cpu_data()); + this->blobs_[2]->mutable_cpu_data()[0] *= moving_average_fraction_; + this->blobs_[2]->mutable_cpu_data()[0] += 1; + caffe_cpu_axpby(mean_.count(), Dtype(1), mean_.cpu_data(), + moving_average_fraction_, this->blobs_[0]->mutable_cpu_data()); + Dtype m = Dtype(bottom[0]->count()/channels_); + caffe_cpu_axpby(variance_.count(), m/(m-1), variance_.cpu_data(), + moving_average_fraction_, this->blobs_[1]->mutable_cpu_data()); + } + // elementwise square of mean + caffe_powx(mean_.count(), mean_.cpu_data(), Dtype(2), + temp_.mutable_cpu_data()); + + caffe_sub(mean_.count(), variance_.cpu_data(), temp_.cpu_data(), + variance_.mutable_cpu_data()); // variance + + // normalize variance + caffe_add_scalar(variance_.count(), eps_, variance_.mutable_cpu_data()); + caffe_powx(variance_.count(), variance_.cpu_data(), Dtype(0.5), + variance_.mutable_cpu_data()); + + // do mean and variance normalization + if (bottom[0] != top[0]) { + caffe_copy(bottom[0]->count(), bottom_data, top_data); + } + // subtract mean + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1, + batch_sum_multiplier_.cpu_data(), mean_.cpu_data(), 0., + num_by_chans_.mutable_cpu_data()); + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, channels_ * num, + spatial_dim, 1, -1, num_by_chans_.cpu_data(), + spatial_sum_multiplier_.cpu_data(), 1., top_data); + // replicate variance to input size + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1, + batch_sum_multiplier_.cpu_data(), variance_.cpu_data(), 0., + num_by_chans_.mutable_cpu_data()); + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, channels_ * num, + spatial_dim, 1, 1., num_by_chans_.cpu_data(), + spatial_sum_multiplier_.cpu_data(), 0., temp_.mutable_cpu_data()); + caffe_div(temp_.count(), top_data, temp_.cpu_data(), top_data); + // TODO(cdoersch): The caching is only needed because later in-place layers + // might clobber the data. Can we skip this if they won't? + caffe_copy(x_norm_.count(), top_data, + x_norm_.mutable_cpu_data()); +} + +template +void BatchNormLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + CHECK(!use_global_stats_); + const Dtype* top_diff; + if (bottom[0] != top[0]) { + top_diff = top[0]->cpu_diff(); + } else { + caffe_copy(x_norm_.count(), top[0]->cpu_diff(), x_norm_.mutable_cpu_diff()); + top_diff = x_norm_.cpu_diff(); + } + const Dtype* top_data = x_norm_.cpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + int num = bottom[0]->shape()[0]; + int spatial_dim = bottom[0]->count()/(bottom[0]->shape(0)*channels_); + // if Y = (X-mean(X))/(sqrt(var(X)+eps)), then + // + // dE(Y)/dX = + // (dE/dY - mean(dE/dY) - mean(dE/dY \cdot Y) \cdot Y) + // ./ sqrt(var(X) + eps) + // + // where \cdot and ./ are hadamard product and elementwise division, + // respectively, dE/dY is the top diff, and mean/var/sum are all computed + // along all dimensions except the channels dimension. In the above + // equation, the operations allow for expansion (i.e. broadcast) along all + // dimensions except the channels dimension where required. + + // sum(dE/dY \cdot Y) + caffe_mul(temp_.count(), top_data, top_diff, bottom_diff); + caffe_cpu_gemv(CblasNoTrans, channels_ * num, spatial_dim, 1., + bottom_diff, spatial_sum_multiplier_.cpu_data(), 0., + num_by_chans_.mutable_cpu_data()); + caffe_cpu_gemv(CblasTrans, num, channels_, 1., + num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0., + mean_.mutable_cpu_data()); + + // reshape (broadcast) the above + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1, + batch_sum_multiplier_.cpu_data(), mean_.cpu_data(), 0., + num_by_chans_.mutable_cpu_data()); + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, channels_ * num, + spatial_dim, 1, 1., num_by_chans_.cpu_data(), + spatial_sum_multiplier_.cpu_data(), 0., bottom_diff); + + // sum(dE/dY \cdot Y) \cdot Y + caffe_mul(temp_.count(), top_data, bottom_diff, bottom_diff); + + // sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y + caffe_cpu_gemv(CblasNoTrans, channels_ * num, spatial_dim, 1., + top_diff, spatial_sum_multiplier_.cpu_data(), 0., + num_by_chans_.mutable_cpu_data()); + caffe_cpu_gemv(CblasTrans, num, channels_, 1., + num_by_chans_.cpu_data(), batch_sum_multiplier_.cpu_data(), 0., + mean_.mutable_cpu_data()); + // reshape (broadcast) the above to make + // sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1, + batch_sum_multiplier_.cpu_data(), mean_.cpu_data(), 0., + num_by_chans_.mutable_cpu_data()); + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, num * channels_, + spatial_dim, 1, 1., num_by_chans_.cpu_data(), + spatial_sum_multiplier_.cpu_data(), 1., bottom_diff); + + // dE/dY - mean(dE/dY)-mean(dE/dY \cdot Y) \cdot Y + caffe_cpu_axpby(temp_.count(), Dtype(1), top_diff, + Dtype(-1. / (num * spatial_dim)), bottom_diff); + + // note: temp_ still contains sqrt(var(X)+eps), computed during the forward + // pass. + caffe_div(temp_.count(), bottom_diff, temp_.cpu_data(), bottom_diff); +} + + +#ifdef CPU_ONLY +STUB_GPU(BatchNormLayer); +#endif + +INSTANTIATE_CLASS(BatchNormLayer); +REGISTER_LAYER_CLASS(BatchNorm); +} // namespace caffe diff --git a/src/caffe/layers/batch_norm_layer.cu b/src/caffe/layers/batch_norm_layer.cu new file mode 100644 index 00000000000..cd8924a451d --- /dev/null +++ b/src/caffe/layers/batch_norm_layer.cu @@ -0,0 +1,167 @@ +#include +#include + +#include "caffe/common_layers.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +template +void BatchNormLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + int num = bottom[0]->shape(0); + int spatial_dim = bottom[0]->count()/(channels_*bottom[0]->shape(0)); + + // elementwise square + caffe_gpu_powx(bottom[0]->count(), bottom_data, Dtype(2), + temp_.mutable_gpu_data()); + + if (use_global_stats_) { + // use the stored mean/variance estimates. TODO(cdoersch): allow an option + // to use an unbiased variance estimate, like the paper does. + const Dtype scale_factor = 1 / this->blobs_[2]->cpu_data()[0]; + caffe_gpu_scale(variance_.count(), scale_factor, + this->blobs_[0]->gpu_data(), mean_.mutable_gpu_data()); + caffe_gpu_scale(variance_.count(), scale_factor, + this->blobs_[1]->gpu_data(), variance_.mutable_gpu_data()); + } else { + // computes variance using var(X) = E(X^2) - (EX)^2 + caffe_gpu_gemv(CblasNoTrans, channels_ * num, spatial_dim, + 1. / (num * spatial_dim), bottom_data, + spatial_sum_multiplier_.gpu_data(), 0., + num_by_chans_.mutable_gpu_data()); + caffe_gpu_gemv(CblasTrans, num, channels_, 1., + num_by_chans_.gpu_data(), batch_sum_multiplier_.gpu_data(), 0., + mean_.mutable_gpu_data()); + caffe_gpu_gemv(CblasNoTrans, channels_ * num, spatial_dim, + 1. / (num * spatial_dim), temp_.gpu_data(), + spatial_sum_multiplier_.gpu_data(), 0., + num_by_chans_.mutable_gpu_data()); + caffe_gpu_gemv(CblasTrans, num, channels_, 1., + num_by_chans_.gpu_data(), batch_sum_multiplier_.gpu_data(), 0., + variance_.mutable_gpu_data()); + this->blobs_[2]->mutable_cpu_data()[0] *= moving_average_fraction_; + this->blobs_[2]->mutable_cpu_data()[0] += 1; + caffe_gpu_axpby(mean_.count(), Dtype(1), mean_.gpu_data(), + moving_average_fraction_, this->blobs_[0]->mutable_gpu_data()); + Dtype m = Dtype(bottom[0]->count()/channels_); + caffe_gpu_axpby(variance_.count(), m/(m-1), variance_.gpu_data(), + moving_average_fraction_, this->blobs_[1]->mutable_gpu_data()); + } + // elementwise square of mean + caffe_gpu_powx(mean_.count(), mean_.gpu_data(), Dtype(2), + temp_.mutable_gpu_data()); + + caffe_gpu_sub(mean_.count(), variance_.gpu_data(), temp_.gpu_data(), + variance_.mutable_gpu_data()); // variance + + // normalize variance + caffe_gpu_add_scalar(variance_.count(), eps_, variance_.mutable_gpu_data()); + caffe_gpu_powx(variance_.count(), variance_.gpu_data(), Dtype(0.5), + variance_.mutable_gpu_data()); + + // do mean and variance normalization + if (bottom[0] != top[0]) { + caffe_copy(bottom[0]->count(), bottom_data, top_data); + } + // subtract mean + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1, + batch_sum_multiplier_.gpu_data(), mean_.gpu_data(), 0., + num_by_chans_.mutable_gpu_data()); + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, channels_ * num, + spatial_dim, 1, -1, num_by_chans_.gpu_data(), + spatial_sum_multiplier_.gpu_data(), 1., top_data); + // replicate variance to input size + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1, + batch_sum_multiplier_.gpu_data(), variance_.gpu_data(), 0., + num_by_chans_.mutable_gpu_data()); + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, channels_ * num, + spatial_dim, 1, 1., num_by_chans_.gpu_data(), + spatial_sum_multiplier_.gpu_data(), 0., temp_.mutable_gpu_data()); + caffe_gpu_div(temp_.count(), top_data, temp_.gpu_data(), top_data); + // TODO(cdoersch): The caching is only needed because later in-place layers + // might clobber the data. Can we skip this if they won't? + caffe_copy(x_norm_.count(), top_data, + x_norm_.mutable_gpu_data()); +} + +template +void BatchNormLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + CHECK(!use_global_stats_); + const Dtype* top_diff; + if (bottom[0] != top[0]) { + top_diff = top[0]->gpu_diff(); + } else { + caffe_copy(x_norm_.count(), top[0]->gpu_diff(), x_norm_.mutable_gpu_diff()); + top_diff = x_norm_.gpu_diff(); + } + const Dtype* top_data = x_norm_.gpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + int num = bottom[0]->shape()[0]; + int spatial_dim = bottom[0]->count()/(channels_*bottom[0]->shape(0)); + // if Y = (X-mean(X))/(sqrt(var(X)+eps)), then + // + // dE(Y)/dX = + // (dE/dY - mean(dE/dY) - mean(dE/dY \cdot Y) \cdot Y) + // ./ sqrt(var(X) + eps) + // + // where \cdot and ./ are hadamard product and elementwise division, + // respectively, dE/dY is the top diff, and mean/var/sum are all computed + // along all dimensions except the channels dimension. In the above + // equation, the operations allow for expansion (i.e. broadcast) along all + // dimensions except the channels dimension where required. + + // sum(dE/dY \cdot Y) + caffe_gpu_mul(temp_.count(), top_data, top_diff, bottom_diff); + caffe_gpu_gemv(CblasNoTrans, channels_ * num, spatial_dim, 1., + bottom_diff, spatial_sum_multiplier_.gpu_data(), 0., + num_by_chans_.mutable_gpu_data()); + caffe_gpu_gemv(CblasTrans, num, channels_, 1., + num_by_chans_.gpu_data(), batch_sum_multiplier_.gpu_data(), 0., + mean_.mutable_gpu_data()); + + // reshape (broadcast) the above + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1, + batch_sum_multiplier_.gpu_data(), mean_.gpu_data(), 0., + num_by_chans_.mutable_gpu_data()); + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, channels_ * num, + spatial_dim, 1, 1., num_by_chans_.gpu_data(), + spatial_sum_multiplier_.gpu_data(), 0., bottom_diff); + + // sum(dE/dY \cdot Y) \cdot Y + caffe_gpu_mul(temp_.count(), top_data, bottom_diff, bottom_diff); + + // sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y + caffe_gpu_gemv(CblasNoTrans, channels_ * num, spatial_dim, 1., + top_diff, spatial_sum_multiplier_.gpu_data(), 0., + num_by_chans_.mutable_gpu_data()); + caffe_gpu_gemv(CblasTrans, num, channels_, 1., + num_by_chans_.gpu_data(), batch_sum_multiplier_.gpu_data(), 0., + mean_.mutable_gpu_data()); + // reshape (broadcast) the above to make + // sum(dE/dY)-sum(dE/dY \cdot Y) \cdot Y + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num, channels_, 1, 1, + batch_sum_multiplier_.gpu_data(), mean_.gpu_data(), 0., + num_by_chans_.mutable_gpu_data()); + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num * channels_, + spatial_dim, 1, 1., num_by_chans_.gpu_data(), + spatial_sum_multiplier_.gpu_data(), 1., bottom_diff); + + // dE/dY - mean(dE/dY)-mean(dE/dY \cdot Y) \cdot Y + caffe_gpu_axpby(temp_.count(), Dtype(1), top_diff, + Dtype(-1. / (num * spatial_dim)), bottom_diff); + + // note: temp_ still contains sqrt(var(X)+eps), computed during the forward + // pass. + caffe_gpu_div(temp_.count(), bottom_diff, temp_.gpu_data(), bottom_diff); +} + +INSTANTIATE_LAYER_GPU_FUNCS(BatchNormLayer); + + +} // namespace caffe diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index a8747c12b37..99dd3c90eef 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -301,7 +301,7 @@ message ParamSpec { // NOTE // Update the next available ID when you add a new LayerParameter field. // -// LayerParameter next available layer-specific ID: 139 (last added: tile_param) +// LayerParameter next available layer-specific ID: 140 (last added: batch_norm_param) message LayerParameter { optional string name = 1; // the layer name optional string type = 2; // the layer type @@ -350,6 +350,7 @@ message LayerParameter { // The default for the engine is set by the ENGINE switch at compile-time. optional AccuracyParameter accuracy_param = 102; optional ArgMaxParameter argmax_param = 103; + optional BatchNormParameter batch_norm_param = 139; optional ConcatParameter concat_param = 104; optional ContrastiveLossParameter contrastive_loss_param = 105; optional ConvolutionParameter convolution_param = 106; @@ -461,6 +462,18 @@ message ConcatParameter { optional uint32 concat_dim = 1 [default = 1]; } +message BatchNormParameter { + // If false, accumulate global mean/variance values via a moving average. If + // true, use those accumulated values instead of computing mean/variance + // across the batch. + optional bool use_global_stats = 1; + // How much does the moving average decay each iteration? + optional float moving_average_fraction = 2 [default = .999]; + // Small value to add to the variance estimate so that we don't divide by + // zero. + optional float eps = 3 [default = 1e-5]; +} + message ContrastiveLossParameter { // margin for dissimilar pair optional float margin = 1 [default = 1.0]; diff --git a/src/caffe/test/test_batch_norm_layer.cpp b/src/caffe/test/test_batch_norm_layer.cpp new file mode 100644 index 00000000000..22b9667f31b --- /dev/null +++ b/src/caffe/test/test_batch_norm_layer.cpp @@ -0,0 +1,133 @@ +#include +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/common_layers.hpp" +#include "caffe/filler.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +#define BATCH_SIZE 2 +#define INPUT_DATA_SIZE 3 + +namespace caffe { + + template + class BatchNormLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + protected: + BatchNormLayerTest() + : blob_bottom_(new Blob(5, 2, 3, 4)), + blob_top_(new Blob()) { + // fill the values + FillerParameter filler_param; + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~BatchNormLayerTest() { delete blob_bottom_; delete blob_top_; } + Blob* const blob_bottom_; + Blob* const blob_top_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; + }; + + TYPED_TEST_CASE(BatchNormLayerTest, TestDtypesAndDevices); + + TYPED_TEST(BatchNormLayerTest, TestForward) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + + BatchNormLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + // Test mean + int num = this->blob_bottom_->num(); + int channels = this->blob_bottom_->channels(); + int height = this->blob_bottom_->height(); + int width = this->blob_bottom_->width(); + + for (int j = 0; j < channels; ++j) { + Dtype sum = 0, var = 0; + for (int i = 0; i < num; ++i) { + for ( int k = 0; k < height; ++k ) { + for ( int l = 0; l < width; ++l ) { + Dtype data = this->blob_top_->data_at(i, j, k, l); + sum += data; + var += data * data; + } + } + } + sum /= height * width * num; + var /= height * width * num; + + const Dtype kErrorBound = 0.001; + // expect zero mean + EXPECT_NEAR(0, sum, kErrorBound); + // expect unit variance + EXPECT_NEAR(1, var, kErrorBound); + } + } + + TYPED_TEST(BatchNormLayerTest, TestForwardInplace) { + typedef typename TypeParam::Dtype Dtype; + Blob blob_inplace(5, 2, 3, 4); + vector*> blob_bottom_vec; + vector*> blob_top_vec; + LayerParameter layer_param; + FillerParameter filler_param; + GaussianFiller filler(filler_param); + filler.Fill(&blob_inplace); + blob_bottom_vec.push_back(&blob_inplace); + blob_top_vec.push_back(&blob_inplace); + + BatchNormLayer layer(layer_param); + layer.SetUp(blob_bottom_vec, blob_top_vec); + layer.Forward(blob_bottom_vec, blob_top_vec); + + // Test mean + int num = blob_inplace.num(); + int channels = blob_inplace.channels(); + int height = blob_inplace.height(); + int width = blob_inplace.width(); + + for (int j = 0; j < channels; ++j) { + Dtype sum = 0, var = 0; + for (int i = 0; i < num; ++i) { + for ( int k = 0; k < height; ++k ) { + for ( int l = 0; l < width; ++l ) { + Dtype data = blob_inplace.data_at(i, j, k, l); + sum += data; + var += data * data; + } + } + } + sum /= height * width * num; + var /= height * width * num; + + const Dtype kErrorBound = 0.001; + // expect zero mean + EXPECT_NEAR(0, sum, kErrorBound); + // expect unit variance + EXPECT_NEAR(1, var, kErrorBound); + } + } + + TYPED_TEST(BatchNormLayerTest, TestGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + + BatchNormLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-4); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); + } + +} // namespace caffe