Add BiasLayer to add two Blobs with broadcasting #3550

Closed
wants to merge 7 commits
54 changes: 54 additions & 0 deletions include/caffe/layers/bias_layer.hpp
@@ -0,0 +1,54 @@
#ifndef CAFFE_BIAS_LAYER_HPP_
#define CAFFE_BIAS_LAYER_HPP_

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

/**
* @brief Computes a sum of two input Blobs, with the shape of the
* latter Blob "broadcast" to match the shape of the former.
* Equivalent to tiling the latter Blob, then computing the elementwise
* sum.
*
* The second input may be omitted, in which case it's learned as a parameter
* of the layer.
*/
template <typename Dtype>
class BiasLayer : public Layer<Dtype> {
public:
explicit BiasLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

virtual inline const char* type() const { return "Bias"; }
virtual inline int MinBottomBlobs() const { return 1; }
virtual inline int MaxBottomBlobs() const { return 2; }
virtual inline int ExactNumTopBlobs() const { return 1; }

virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

private:
Blob<Dtype> bias_multiplier_;
int outer_dim_, bias_dim_, inner_dim_, dim_;
};



} // namespace caffe

#endif  // CAFFE_BIAS_LAYER_HPP_
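For reference, the broadcasting this header describes can be written out as a plain loop: bottom[0] is treated as an (outer_dim_, bias_dim_, inner_dim_) volume around `axis`, and every element picks up the bias entry for its middle coordinate. The helper below is a hypothetical reference sketch, not part of this PR; the names are illustrative only.

```cpp
#include <cstddef>
#include <vector>

// Reference (CPU) version of the broadcast sum BiasLayer computes.
// x is bottom[0] flattened to outer * bias_dim * inner elements around
// `axis`; bias holds bias_dim values.  Hypothetical helper, for illustration.
template <typename Dtype>
std::vector<Dtype> naive_bias_forward(const std::vector<Dtype>& x,
                                      const std::vector<Dtype>& bias,
                                      int bias_dim, int inner) {
  std::vector<Dtype> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    // Middle coordinate of the (outer, bias_dim, inner) layout.
    const int d = (static_cast<int>(i) / inner) % bias_dim;
    y[i] = x[i] + bias[d];
  }
  return y;
}
```

For example, a bottom[0] of shape (N, C, H, W) with axis = 1 and a bias of shape (C) corresponds to outer_dim_ = N, bias_dim_ = C, inner_dim_ = H * W, i.e. the usual per-channel bias.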
83 changes: 83 additions & 0 deletions include/caffe/layers/scalar_layer.hpp
@@ -0,0 +1,83 @@
#ifndef CAFFE_SCALAR_LAYER_HPP_
#define CAFFE_SCALAR_LAYER_HPP_

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

#include "caffe/layers/bias_layer.hpp"

namespace caffe {

/**
* @brief Computes a product of two input Blobs, with the shape of the
* latter Blob "broadcast" to match the shape of the former.
* Equivalent to tiling the latter Blob, then computing the elementwise
* product.
*
* The second input may be omitted, in which case it's learned as a parameter
* of the layer.
*/
template <typename Dtype>
class ScalarLayer: public Layer<Dtype> {
public:
explicit ScalarLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

virtual inline const char* type() const { return "Scalar"; }
virtual inline int MinBottomBlobs() const { return 1; }
virtual inline int MaxBottomBlobs() const { return 2; }
virtual inline int ExactNumTopBlobs() const { return 1; }

protected:
/**
* In the below shape specifications, @f$ i @f$ denotes the value of the
* `axis` field given by `this->layer_param_.scalar_param().axis()`, after
* canonicalization (i.e., conversion from negative to positive index,
* if applicable).
*
* @param bottom input Blob vector (length 2)
* -# @f$ (d_0 \times ... \times
* d_i \times ... \times d_j \times ... \times d_n) @f$
* the first factor @f$ x @f$
* -# @f$ (d_i \times ... \times d_j) @f$
* the second factor @f$ y @f$
* @param top output Blob vector (length 1)
* -# @f$ (d_0 \times ... \times
* d_i \times ... \times d_j \times ... \times d_n) @f$
* the product @f$ z = x y @f$ computed after "broadcasting" y.
* Equivalent to tiling @f$ y @f$ to have the same shape as @f$ x @f$,
* then computing the elementwise product.
*/
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

shared_ptr<Layer<Dtype> > bias_layer_;
vector<Blob<Dtype>*> bias_bottom_vec_;
vector<bool> bias_propagate_down_;
int bias_param_id_;

Blob<Dtype> sum_multiplier_;
Blob<Dtype> sum_result_;
Blob<Dtype> temp_;
int axis_;
int outer_dim_, scalar_dim_, inner_dim_;
};


} // namespace caffe

#endif  // CAFFE_SCALAR_LAYER_HPP_
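The gradient with respect to the broadcast factor is the part that is easy to get wrong: each element of y is reused outer_dim_ * inner_dim_ times, so its diff is a sum over all of those positions. The sketch below states that math with plain loops; it is illustrative only and does not reproduce the sum_multiplier_ / temp_ buffers that the layer's actual Backward uses.

```cpp
#include <vector>

// dL/dy for z = x * tile(y): accumulate top_diff * x over every position
// that a given y element was broadcast to.  Hypothetical reference code.
template <typename Dtype>
std::vector<Dtype> naive_scalar_backward_y(const std::vector<Dtype>& top_diff,
                                           const std::vector<Dtype>& x,
                                           int outer, int scalar_dim,
                                           int inner) {
  std::vector<Dtype> y_diff(scalar_dim, Dtype(0));
  for (int n = 0; n < outer; ++n) {
    for (int d = 0; d < scalar_dim; ++d) {
      for (int k = 0; k < inner; ++k) {
        const int idx = (n * scalar_dim + d) * inner + k;
        y_diff[d] += top_diff[idx] * x[idx];
      }
    }
  }
  return y_diff;
}
```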
121 changes: 121 additions & 0 deletions src/caffe/layers/bias_layer.cpp
@@ -0,0 +1,121 @@
#include <vector>

#include "caffe/filler.hpp"
#include "caffe/layers/bias_layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

template <typename Dtype>
void BiasLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
if (bottom.size() == 1 && this->blobs_.size() > 0) {
LOG(INFO) << "Skipping parameter initialization";
} else if (bottom.size() == 1) {
// bias is a learned parameter; initialize it
const BiasParameter& param = this->layer_param_.bias_param();
const int axis = bottom[0]->CanonicalAxisIndex(param.axis());
const int num_axes = param.num_axes();
CHECK_GE(num_axes, -1) << "num_axes must be non-negative, "
<< "or -1 to extend to the end of bottom[0]";
if (num_axes >= 0) {
CHECK_GE(bottom[0]->num_axes(), axis + num_axes)
<< "bias blob's shape extends past bottom[0]'s shape when applied "
<< "starting with bottom[0] axis = " << axis;
}
this->blobs_.resize(1);
const vector<int>::const_iterator& shape_start =
bottom[0]->shape().begin() + axis;
const vector<int>::const_iterator& shape_end =
(num_axes == -1) ? bottom[0]->shape().end() : (shape_start + num_axes);
vector<int> bias_shape(shape_start, shape_end);
this->blobs_[0].reset(new Blob<Dtype>(bias_shape));
shared_ptr<Filler<Dtype> > filler(GetFiller<Dtype>(param.filler()));
filler->Fill(this->blobs_[0].get());
}
this->param_propagate_down_.resize(this->blobs_.size(), true);
}

template <typename Dtype>
void BiasLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const BiasParameter& param = this->layer_param_.bias_param();
Blob<Dtype>* bias = (bottom.size() > 1) ? bottom[1] : this->blobs_[0].get();
// Always set axis == 0 in special case where bias is a scalar
// (num_axes == 0). Mathematically equivalent for any choice of axis, so the
// actual setting can be safely ignored; and computation is most efficient
// with axis == 0 and (therefore) outer_dim_ == 1.
const int axis = (bias->num_axes() == 0) ?
0 : bottom[0]->CanonicalAxisIndex(param.axis());
CHECK_GE(bottom[0]->num_axes(), axis + bias->num_axes())
<< "bias blob's shape extends past bottom[0]'s shape when applied "
<< "starting with bottom[0] axis = " << axis;
for (int i = 0; i < bias->num_axes(); ++i) {
CHECK_EQ(bottom[0]->shape(axis + i), bias->shape(i))
<< "dimension mismatch between bottom[0]->shape(" << axis + i
<< ") and bias->shape(" << i << ")";
}
outer_dim_ = bottom[0]->count(0, axis);
bias_dim_ = bias->count();
inner_dim_ = bottom[0]->count(axis + bias->num_axes());
dim_ = bias_dim_ * inner_dim_;
if (bottom[0] != top[0]) {
top[0]->ReshapeLike(*bottom[0]);
}
bias_multiplier_.Reshape(vector<int>(1, inner_dim_));
if (bias_multiplier_.cpu_data()[inner_dim_ - 1] != Dtype(1)) {
caffe_set(inner_dim_, Dtype(1), bias_multiplier_.mutable_cpu_data());
}
}

template <typename Dtype>
void BiasLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bias_data =
((bottom.size() > 1) ? bottom[1] : this->blobs_[0].get())->cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
if (bottom[0] != top[0]) {
const Dtype* bottom_data = bottom[0]->cpu_data();
caffe_copy(bottom[0]->count(), bottom_data, top_data);
}
for (int n = 0; n < outer_dim_; ++n) {
    caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, bias_dim_,
        inner_dim_, 1, Dtype(1), bias_data,
        bias_multiplier_.cpu_data(), Dtype(1), top_data);
top_data += dim_;
}
}

template <typename Dtype>
void BiasLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
if (propagate_down[0] && bottom[0] != top[0]) {
const Dtype* top_diff = top[0]->cpu_diff();
Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
caffe_copy(bottom[0]->count(), top_diff, bottom_diff);
}
// in-place, we don't need to do anything with the data diff
const bool bias_param = (bottom.size() == 1);
if ((!bias_param && propagate_down[1]) ||
(bias_param && this->param_propagate_down_[0])) {
const Dtype* top_diff = top[0]->cpu_diff();
Dtype* bias_diff = (bias_param ? this->blobs_[0].get() : bottom[1])
->mutable_cpu_diff();
bool accum = bias_param;
for (int n = 0; n < outer_dim_; ++n) {
caffe_cpu_gemv(CblasNoTrans, bias_dim_, inner_dim_, Dtype(1),
top_diff, bias_multiplier_.cpu_data(), Dtype(accum), bias_diff);
top_diff += dim_;
accum = true;
}
}
}

#ifdef CPU_ONLY
STUB_GPU(BiasLayer);
#endif

INSTANTIATE_CLASS(BiasLayer);
REGISTER_LAYER_CLASS(Bias);

} // namespace caffe
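The BLAS calls above pack the broadcasting into standard routines: the forward gemm is a rank-1 update of each (bias_dim_ x inner_dim_) slice of top against the ones vector bias_multiplier_, and the backward gemv reduces each slice of top_diff over inner_dim_. The loop forms below are a sketch of what those calls compute, assuming row-major slices; they are not part of the PR.

```cpp
// Forward, one outer slice: top_slice (bias_dim x inner, row-major)
// += bias (bias_dim x 1) * ones (1 x inner), i.e. a rank-1 update.
template <typename Dtype>
void bias_forward_slice(const Dtype* bias, int bias_dim, int inner,
                        Dtype* top_slice) {
  for (int d = 0; d < bias_dim; ++d) {
    for (int k = 0; k < inner; ++k) {
      top_slice[d * inner + k] += bias[d];
    }
  }
}

// Backward, one outer slice: bias_diff = top_diff_slice * ones, i.e. a
// row-wise reduction over inner; `accum` mirrors beta = Dtype(accum) in
// the gemv call (accumulate across outer slices and into param diffs).
template <typename Dtype>
void bias_backward_slice(const Dtype* top_diff_slice, int bias_dim, int inner,
                         bool accum, Dtype* bias_diff) {
  for (int d = 0; d < bias_dim; ++d) {
    Dtype sum = Dtype(0);
    for (int k = 0; k < inner; ++k) {
      sum += top_diff_slice[d * inner + k];
    }
    bias_diff[d] = accum ? bias_diff[d] + sum : sum;
  }
}
```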
59 changes: 59 additions & 0 deletions src/caffe/layers/bias_layer.cu
@@ -0,0 +1,59 @@
#include <vector>

#include "caffe/filler.hpp"
#include "caffe/layers/bias_layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

template <typename Dtype>
__global__ void BiasForward(const int n, const Dtype* in,
const Dtype* bias, const int bias_dim, const int inner_dim,
Dtype* out) {
CUDA_KERNEL_LOOP(index, n) {
const int bias_index = (index / inner_dim) % bias_dim;
out[index] = in[index] + bias[bias_index];
}
}

template <typename Dtype>
void BiasLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const int count = top[0]->count();
const Dtype* bottom_data = bottom[0]->gpu_data();
const Dtype* bias_data =
((bottom.size() > 1) ? bottom[1] : this->blobs_[0].get())->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
BiasForward<Dtype> // NOLINT_NEXT_LINE(whitespace/operators)
<<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, bias_data, bias_dim_, inner_dim_, top_data);
}

template <typename Dtype>
void BiasLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
if (propagate_down[0] && bottom[0] != top[0]) {
const Dtype* top_diff = top[0]->gpu_diff();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
caffe_copy(bottom[0]->count(), top_diff, bottom_diff);
}
// in-place, we don't need to do anything with the data diff
const bool bias_param = (bottom.size() == 1);
if ((!bias_param && propagate_down[1]) ||
(bias_param && this->param_propagate_down_[0])) {
const Dtype* top_diff = top[0]->gpu_diff();
Dtype* bias_diff = (bias_param ? this->blobs_[0].get() : bottom[1])
->mutable_gpu_diff();
bool accum = bias_param;
for (int n = 0; n < outer_dim_; ++n) {
caffe_gpu_gemv(CblasNoTrans, bias_dim_, inner_dim_, Dtype(1),
top_diff, bias_multiplier_.gpu_data(), Dtype(accum), bias_diff);
top_diff += dim_;
accum = true;
}
}
}

INSTANTIATE_LAYER_GPU_FUNCS(BiasLayer);

} // namespace caffe
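On the GPU side, the forward kernel assigns one thread per output element and recovers the bias coordinate from the flat index as (index / inner_dim) % bias_dim; the backward path reuses the same gemv reduction as the CPU code. A quick host-side check of that index arithmetic (illustrative only, not part of the PR):

```cpp
#include <cassert>

// For every (n, d, k) in the (outer, bias_dim, inner) layout, the flat
// index is (n * bias_dim + d) * inner + k, so the kernel's expression
// (index / inner) % bias_dim recovers d, the bias coordinate.
void check_bias_indexing(int outer, int bias_dim, int inner) {
  for (int n = 0; n < outer; ++n) {
    for (int d = 0; d < bias_dim; ++d) {
      for (int k = 0; k < inner; ++k) {
        const int index = (n * bias_dim + d) * inner + k;
        assert((index / inner) % bias_dim == d);
      }
    }
  }
}
```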