
Channel softmax #940

Merged (4 commits) on Aug 19, 2014
4 changes: 4 additions & 0 deletions include/caffe/util/math_functions.hpp
@@ -94,6 +94,10 @@ void caffe_abs(const int n, const Dtype* a, Dtype* y);
template <typename Dtype>
Dtype caffe_cpu_dot(const int n, const Dtype* x, const Dtype* y);

template <typename Dtype>
Dtype caffe_cpu_strided_dot(const int n, const Dtype* x, const int incx,
const Dtype* y, const int incy);

template <typename Dtype>
int caffe_cpu_hamming_distance(const int n, const Dtype* x, const Dtype* y);

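The header above only adds the declaration of caffe_cpu_strided_dot. As a rough sketch (an assumption, not shown in this diff: the real definitions belong in math_functions.cpp and may differ), the float and double specializations could simply forward to the strided CBLAS dot routines:

// Sketch only: plausible CBLAS-backed specializations of the declaration
// above. The strides incx/incy let the backward pass dot along the channel
// axis of an N x C x H x W blob without copying it.
#include <cblas.h>

template <>
float caffe_cpu_strided_dot<float>(const int n, const float* x,
    const int incx, const float* y, const int incy) {
  return cblas_sdot(n, x, incx, y, incy);
}

template <>
double caffe_cpu_strided_dot<double>(const int n, const double* x,
    const int incx, const double* y, const int incy) {
  return cblas_ddot(n, x, incx, y, incy);
}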
63 changes: 37 additions & 26 deletions src/caffe/layers/softmax_layer.cpp
@@ -13,13 +13,12 @@ void SoftmaxLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
(*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
bottom[0]->height(), bottom[0]->width());
sum_multiplier_.Reshape(1, bottom[0]->channels(),
bottom[0]->height(), bottom[0]->width());
sum_multiplier_.Reshape(1, bottom[0]->channels(), 1, 1);
Dtype* multiplier_data = sum_multiplier_.mutable_cpu_data();
for (int i = 0; i < sum_multiplier_.count(); ++i) {
multiplier_data[i] = 1.;
}
scale_.Reshape(bottom[0]->num(), 1, 1, 1);
scale_.Reshape(bottom[0]->num(), 1, bottom[0]->height(), bottom[0]->width());
}

template <typename Dtype>
@@ -29,27 +28,34 @@ void SoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
Dtype* top_data = (*top)[0]->mutable_cpu_data();
Dtype* scale_data = scale_.mutable_cpu_data();
int num = bottom[0]->num();
int channels = bottom[0]->channels();
int dim = bottom[0]->count() / bottom[0]->num();
int spatial_dim = bottom[0]->height() * bottom[0]->width();
caffe_copy(bottom[0]->count(), bottom_data, top_data);
// we need to subtract the max to avoid numerical issues, compute the exp,
// We need to subtract the max to avoid numerical issues, compute the exp,
// and then normalize.
for (int i = 0; i < num; ++i) {
scale_data[i] = bottom_data[i*dim];
for (int j = 0; j < dim; ++j) {
scale_data[i] = std::max(scale_data[i], bottom_data[i * dim + j]);
// initialize scale_data to the first plane
caffe_copy(spatial_dim, bottom_data + i * dim, scale_data);
for (int j = 0; j < channels; j++) {
for (int k = 0; k < spatial_dim; k++) {
scale_data[k] = std::max(scale_data[k],
bottom_data[i * dim + j * spatial_dim + k]);
}
}
// subtraction
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, spatial_dim,
1, -1., sum_multiplier_.cpu_data(), scale_data, 1., top_data + i * dim);
// exponentiation
caffe_exp<Dtype>(dim, top_data + i * dim, top_data + i * dim);
// sum after exp
caffe_cpu_gemv<Dtype>(CblasTrans, channels, spatial_dim, 1.,
top_data + i * dim, sum_multiplier_.cpu_data(), 0., scale_data);
// division
for (int j = 0; j < channels; j++) {
caffe_div(spatial_dim, top_data + (*top)[0]->offset(i, j), scale_data,
top_data + (*top)[0]->offset(i, j));
}
}
// subtraction
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
scale_data, sum_multiplier_.cpu_data(), 1., top_data);
// Perform exponentiation
caffe_exp<Dtype>(num * dim, top_data, top_data);
// sum after exp
caffe_cpu_gemv<Dtype>(CblasNoTrans, num, dim, 1., top_data,
sum_multiplier_.cpu_data(), 0., scale_data);
// Do division
for (int i = 0; i < num; ++i) {
caffe_scal<Dtype>(dim, Dtype(1.) / scale_data[i], top_data + i * dim);
}
}
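To make the new indexing concrete: for every sample i and every spatial position k, the softmax now runs over the channel axis, with element (i, c, k) of the blob stored at i * dim + c * spatial_dim + k. A standalone reference in plain C++ (illustration only, not part of the PR; the layer above does the same work with BLAS calls instead of explicit loops):

#include <algorithm>
#include <cmath>
#include <limits>

// Reference channel-wise softmax for an N x C x H x W array in NCHW layout.
void channel_softmax(const float* bottom, float* top,
                     int num, int channels, int height, int width) {
  const int spatial_dim = height * width;
  const int dim = channels * spatial_dim;
  for (int i = 0; i < num; ++i) {
    for (int k = 0; k < spatial_dim; ++k) {
      // 1. max over channels at this position, for numerical stability
      float maxval = -std::numeric_limits<float>::infinity();
      for (int c = 0; c < channels; ++c) {
        maxval = std::max(maxval, bottom[i * dim + c * spatial_dim + k]);
      }
      // 2. exponentiate the shifted values and accumulate their sum
      float sum = 0.f;
      for (int c = 0; c < channels; ++c) {
        const float e = std::exp(bottom[i * dim + c * spatial_dim + k] - maxval);
        top[i * dim + c * spatial_dim + k] = e;
        sum += e;
      }
      // 3. normalize so the channels at (i, k) sum to one
      for (int c = 0; c < channels; ++c) {
        top[i * dim + c * spatial_dim + k] /= sum;
      }
    }
  }
}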

@@ -62,18 +68,23 @@ void SoftmaxLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
Dtype* scale_data = scale_.mutable_cpu_data();
int num = top[0]->num();
int channels = top[0]->channels();
int dim = top[0]->count() / top[0]->num();
int spatial_dim = top[0]->height() * top[0]->width();
caffe_copy(top[0]->count(), top_diff, bottom_diff);
// Compute inner1d(top_diff, top_data) and subtract them from the bottom diff
for (int i = 0; i < num; ++i) {
scale_data[i] = caffe_cpu_dot<Dtype>(dim, top_diff + i * dim,
top_data + i * dim);
// compute dot(top_diff, top_data) and subtract them from the bottom diff
for (int k = 0; k < spatial_dim; ++k) {
scale_data[k] = caffe_cpu_strided_dot<Dtype>(channels,
bottom_diff + i * dim + k, spatial_dim,
top_data + i * dim + k, spatial_dim);
}
// subtraction
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, spatial_dim, 1,
-1., sum_multiplier_.cpu_data(), scale_data, 1., bottom_diff + i * dim);
}
// subtraction
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
scale_data, sum_multiplier_.cpu_data(), 1., bottom_diff);
// elementwise multiplication
caffe_mul<Dtype>(top[0]->count(), bottom_diff, top_data, bottom_diff);
caffe_mul(top[0]->count(), bottom_diff, top_data, bottom_diff);
}
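The backward pass above is the usual softmax Jacobian contraction, now applied independently at every spatial position: with y the softmax output and the incoming top_diff as the gradient with respect to y, each sample and each position computes

\frac{\partial L}{\partial x_c} = y_c \left( \frac{\partial L}{\partial y_c} - \sum_{c'} \frac{\partial L}{\partial y_{c'}} \, y_{c'} \right).

The strided dot evaluates the inner sum over channels, the gemm broadcasts and subtracts it across channels, and the final caffe_mul supplies the leading factor y_c.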


141 changes: 95 additions & 46 deletions src/caffe/layers/softmax_layer.cu
@@ -11,90 +11,139 @@
namespace caffe {

template <typename Dtype>
__global__ void kernel_get_max(const int num, const int dim,
const Dtype* data, Dtype* out) {
CUDA_KERNEL_LOOP(index, num) {
__global__ void kernel_channel_max(const int num, const int channels,
const int spatial_dim, const Dtype* data, Dtype* out) {
CUDA_KERNEL_LOOP(index, num * spatial_dim) {
int n = index / spatial_dim;
int s = index % spatial_dim;
Dtype maxval = -FLT_MAX;
for (int i = 0; i < dim; ++i) {
maxval = max(data[index * dim + i], maxval);
for (int c = 0; c < channels; ++c) {
maxval = max(data[(n * channels + c) * spatial_dim + s], maxval);
}
out[index] = maxval;
}
}

template <typename Dtype>
__global__ void kernel_softmax_div(const int num, const int dim,
const Dtype* scale, Dtype* data) {
CUDA_KERNEL_LOOP(index, num * dim) {
int n = index / dim;
data[index] /= scale[n];
__global__ void kernel_channel_subtract(const int num, const int channels,
const int spatial_dim, Dtype* data, const Dtype* channel_max) {
CUDA_KERNEL_LOOP(index, num * spatial_dim) {
int n = index / spatial_dim;
int s = index % spatial_dim;
for (int c = 0; c < channels; ++c) {
data[(n * channels + c) * spatial_dim + s] -= channel_max[index];
}
}
}

template <typename Dtype>
__global__ void kernel_exp(const int num, const Dtype* data, Dtype* out) {
CUDA_KERNEL_LOOP(index, num) {
__global__ void kernel_exp(const int count, const Dtype* data, Dtype* out) {
Contributor:
This kernel is the same as caffe_gpu_exp, isn't it? Let's remove it and replace it with caffe_gpu_exp, unless I'm misunderstanding somehow. (I know it wasn't added in this PR; I just noticed it from seeing the diff.)

Member (Author):
I can't find caffe_gpu_exp. I only found caffe_exp, which calls vsExp in MKL.

Contributor:
Whoops, my bad; I think I was thinking of caffe_gpu_powx. caffe_gpu_exp should probably exist, but device abstraction (#610) will probably take care of this, so never mind, sorry!
CUDA_KERNEL_LOOP(index, count) {
out[index] = exp(data[index]);
}
}

template <typename Dtype>
__global__ void kernel_channel_sum(const int num, const int channels,
const int spatial_dim, const Dtype* data, Dtype* channel_sum) {
CUDA_KERNEL_LOOP(index, num * spatial_dim) {
int n = index / spatial_dim;
int s = index % spatial_dim;
Dtype sum = 0;
for (int c = 0; c < channels; ++c) {
sum += data[(n * channels + c) * spatial_dim + s];
}
channel_sum[index] = sum;
}
}

template <typename Dtype>
__global__ void kernel_channel_div(const int num, const int channels,
const int spatial_dim, Dtype* data, const Dtype* channel_sum) {
CUDA_KERNEL_LOOP(index, num * spatial_dim) {
int n = index / spatial_dim;
int s = index % spatial_dim;
for (int c = 0; c < channels; ++c) {
data[(n * channels + c) * spatial_dim + s] /= channel_sum[index];
}
}
}

template <typename Dtype>
__global__ void kernel_channel_dot(const int num, const int channels,
const int spatial_dim, const Dtype* data_1, const Dtype* data_2,
Dtype* channel_dot) {
CUDA_KERNEL_LOOP(index, num * spatial_dim) {
int n = index / spatial_dim;
int s = index % spatial_dim;
Dtype dot = 0;
for (int c = 0; c < channels; ++c) {
dot += (data_1[(n * channels + c) * spatial_dim + s]
* data_2[(n * channels + c) * spatial_dim + s]);
}
channel_dot[index] = dot;
}
}

template <typename Dtype>
void SoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
Dtype* scale_data = scale_.mutable_gpu_data();
int num = bottom[0]->num();
int dim = bottom[0]->count() / bottom[0]->num();
int channels = bottom[0]->channels();
int spatial_dim = bottom[0]->height() * bottom[0]->width();
caffe_copy(bottom[0]->count(), bottom_data, top_data);
// we need to subtract the max to avoid numerical issues, compute the exp,
// We need to subtract the max to avoid numerical issues, compute the exp,
// and then normalize.
// Compute max
// compute max
// NOLINT_NEXT_LINE(whitespace/operators)
kernel_channel_max<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim),
CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, top_data,
scale_data);
// subtract
// NOLINT_NEXT_LINE(whitespace/operators)
kernel_get_max<Dtype><<<CAFFE_GET_BLOCKS(num), CAFFE_CUDA_NUM_THREADS>>>(
num, dim, bottom_data, scale_data);
// subtraction
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
scale_data, sum_multiplier_.gpu_data(), 1., top_data);
// Perform exponentiation
kernel_channel_subtract<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim),
CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, top_data,
scale_data);
// exponentiate
// NOLINT_NEXT_LINE(whitespace/operators)
kernel_exp<Dtype><<<CAFFE_GET_BLOCKS(num * dim), CAFFE_CUDA_NUM_THREADS>>>(
num * dim, top_data, top_data);
kernel_exp<Dtype><<<CAFFE_GET_BLOCKS(num * channels * spatial_dim),
CAFFE_CUDA_NUM_THREADS>>>(num * channels * spatial_dim, top_data,
top_data);
// sum after exp
caffe_gpu_gemv<Dtype>(CblasNoTrans, num, dim, 1., top_data,
sum_multiplier_.gpu_data(), 0., scale_data);
// Do division
// NOLINT_NEXT_LINE(whitespace/operators)
kernel_softmax_div<Dtype><<<CAFFE_GET_BLOCKS(num * dim),
CAFFE_CUDA_NUM_THREADS>>>(
num, dim, scale_data, top_data);
kernel_channel_sum<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim),
CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, top_data,
scale_data);
// divide
// NOLINT_NEXT_LINE(whitespace/operators)
kernel_channel_div<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim),
CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, top_data,
scale_data);
}

// TODO(Yangqing): implement the GPU version of softmax.
template <typename Dtype>
void SoftmaxLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* top_data = top[0]->gpu_data();
Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
Dtype* scale_data = scale_.mutable_gpu_data();
int num = top[0]->num();
int dim = top[0]->count() / top[0]->num();
int channels = top[0]->channels();
int spatial_dim = top[0]->height() * top[0]->width();
caffe_copy(top[0]->count(), top_diff, bottom_diff);
// Compute inner1d(top_diff, top_data) and subtract them from the bottom diff
// cuda dot returns the result to cpu, so we temporarily change the pointer
// mode
CUBLAS_CHECK(cublasSetPointerMode(Caffe::cublas_handle(),
CUBLAS_POINTER_MODE_DEVICE));
Dtype* scale_data = scale_.mutable_gpu_data();
for (int i = 0; i < num; ++i) {
caffe_gpu_dot<Dtype>(dim, top_diff + i * dim,
top_data + i * dim, scale_data + i);
}
CUBLAS_CHECK(cublasSetPointerMode(Caffe::cublas_handle(),
CUBLAS_POINTER_MODE_HOST));
// subtraction
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
scale_.gpu_data(), sum_multiplier_.gpu_data(), 1., bottom_diff);
// Compute inner1d(top_diff, top_data) and subtract them from the bottom diff.
// NOLINT_NEXT_LINE(whitespace/operators)
kernel_channel_dot<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim),
CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, top_diff, top_data,
scale_data);
// NOLINT_NEXT_LINE(whitespace/operators)
kernel_channel_subtract<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim),
CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, bottom_diff,
scale_data);
// elementwise multiplication
caffe_gpu_mul<Dtype>(top[0]->count(), bottom_diff, top_data, bottom_diff);
}
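All of the new kernels share one launch pattern: one thread per (sample, spatial position) pair, with a serial loop over channels inside the thread, so each thread reads and writes a column of the NCHW blob with stride spatial_dim. A small host-side check of that offset arithmetic (illustration only, not part of the PR):

#include <cassert>

// Offset of element (n, c, s) in an NCHW blob, as used by the kernels above,
// where s = h * width + w is the flattened spatial index.
inline int nchw_offset(int n, int c, int s, int channels, int spatial_dim) {
  return (n * channels + c) * spatial_dim + s;
}

int main() {
  const int num = 2, channels = 10, spatial_dim = 2 * 3;
  // index = n * spatial_dim + s is the global thread id in CUDA_KERNEL_LOOP.
  for (int index = 0; index < num * spatial_dim; ++index) {
    const int n = index / spatial_dim;
    const int s = index % spatial_dim;
    for (int c = 1; c < channels; ++c) {
      // Consecutive channels seen by one thread are spatial_dim apart,
      // the same stride passed to caffe_cpu_strided_dot on the CPU path.
      assert(nchw_offset(n, c, s, channels, spatial_dim) -
             nchw_offset(n, c - 1, s, channels, spatial_dim) == spatial_dim);
    }
  }
  return 0;
}

The serial channel loop trades parallelism over channels for simpler kernels; with num * spatial_dim threads there is typically enough parallelism for the blob shapes this layer sees.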
19 changes: 13 additions & 6 deletions src/caffe/layers/softmax_loss_layer.cpp
@@ -27,18 +27,21 @@ template <typename Dtype>
void SoftmaxWithLossLayer<Dtype>::Forward_cpu(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
// The forward pass computes the softmax prob values.
softmax_bottom_vec_[0] = bottom[0];
softmax_layer_->Forward(softmax_bottom_vec_, &softmax_top_vec_);
const Dtype* prob_data = prob_.cpu_data();
const Dtype* label = bottom[1]->cpu_data();
int num = prob_.num();
int dim = prob_.count() / num;
int spatial_dim = prob_.height() * prob_.width();
Dtype loss = 0;
for (int i = 0; i < num; ++i) {
loss += -log(std::max(prob_data[i * dim + static_cast<int>(label[i])],
Dtype(FLT_MIN)));
for (int j = 0; j < spatial_dim; j++) {
loss -= log(std::max(prob_data[i * dim +
static_cast<int>(label[i * spatial_dim + j]) * spatial_dim + j],
Dtype(FLT_MIN)));
}
}
(*top)[0]->mutable_cpu_data()[0] = loss / num;
(*top)[0]->mutable_cpu_data()[0] = loss / num / spatial_dim;
if (top->size() == 2) {
(*top)[1]->ShareData(prob_);
}
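Written out, with N = num, S = spatial_dim, l_{i,j} the label of image i at spatial position j, and p the softmax probabilities, the loss accumulated by the loop above is

L = -\frac{1}{N S} \sum_{i=1}^{N} \sum_{j=1}^{S} \log \max\left( p_{i,\, l_{i,j},\, j},\ \text{FLT\_MIN} \right),

so the reported scalar is now averaged over spatial positions as well as over the batch; with a 1 x 1 spatial extent (S = 1) this reduces to the previous behavior.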
@@ -59,12 +62,16 @@ void SoftmaxWithLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const Dtype* label = (*bottom)[1]->cpu_data();
int num = prob_.num();
int dim = prob_.count() / num;
int spatial_dim = prob_.height() * prob_.width();
for (int i = 0; i < num; ++i) {
bottom_diff[i * dim + static_cast<int>(label[i])] -= 1;
for (int j = 0; j < spatial_dim; ++j) {
bottom_diff[i * dim + static_cast<int>(label[i * spatial_dim + j])
* spatial_dim + j] -= 1;
}
}
// Scale gradient
const Dtype loss_weight = top[0]->cpu_diff()[0];
caffe_scal(prob_.count(), loss_weight / num, bottom_diff);
caffe_scal(prob_.count(), loss_weight / num / spatial_dim, bottom_diff);
}
}
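Correspondingly (assuming, as in the elided lines above, that bottom_diff has first been set to the softmax probabilities), the gradient written by the backward pass is

\frac{\partial L}{\partial z_{i,c,j}} = \frac{w}{N S} \left( p_{i,c,j} - \mathbf{1}[c = l_{i,j}] \right), \qquad w = \text{loss\_weight},

the standard softmax-with-cross-entropy gradient, applied per spatial position and scaled by the same 1/(N S) normalization as the forward loss.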

44 changes: 23 additions & 21 deletions src/caffe/test/test_softmax_layer.cpp
@@ -19,7 +19,7 @@ class SoftmaxLayerTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
protected:
SoftmaxLayerTest()
: blob_bottom_(new Blob<Dtype>(2, 10, 1, 1)),
: blob_bottom_(new Blob<Dtype>(2, 10, 2, 3)),
blob_top_(new Blob<Dtype>()) {
// fill the values
FillerParameter filler_param;
@@ -45,26 +45,28 @@ TYPED_TEST(SoftmaxLayerTest, TestForward) {
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
// Test sum
for (int i = 0; i < this->blob_bottom_->num(); ++i) {
Dtype sum = 0;
for (int j = 0; j < this->blob_top_->channels(); ++j) {
sum += this->blob_top_->data_at(i, j, 0, 0);
}
EXPECT_GE(sum, 0.999);
EXPECT_LE(sum, 1.001);
}
// Test exact values
for (int i = 0; i < this->blob_bottom_->num(); ++i) {
Dtype scale = 0;
for (int j = 0; j < this->blob_bottom_->channels(); ++j) {
scale += exp(this->blob_bottom_->data_at(i, j, 0, 0));
}
for (int j = 0; j < this->blob_bottom_->channels(); ++j) {
EXPECT_GE(this->blob_top_->data_at(i, j, 0, 0) + 1e-4,
exp(this->blob_bottom_->data_at(i, j, 0, 0)) / scale)
<< "debug: " << i << " " << j;
EXPECT_LE(this->blob_top_->data_at(i, j, 0, 0) - 1e-4,
exp(this->blob_bottom_->data_at(i, j, 0, 0)) / scale)
<< "debug: " << i << " " << j;
for (int k = 0; k < this->blob_bottom_->height(); ++k) {
for (int l = 0; l < this->blob_bottom_->width(); ++l) {
Dtype sum = 0;
for (int j = 0; j < this->blob_top_->channels(); ++j) {
sum += this->blob_top_->data_at(i, j, k, l);
}
EXPECT_GE(sum, 0.999);
EXPECT_LE(sum, 1.001);
// Test exact values
Dtype scale = 0;
for (int j = 0; j < this->blob_bottom_->channels(); ++j) {
scale += exp(this->blob_bottom_->data_at(i, j, k, l));
}
for (int j = 0; j < this->blob_bottom_->channels(); ++j) {
EXPECT_GE(this->blob_top_->data_at(i, j, k, l) + 1e-4,
exp(this->blob_bottom_->data_at(i, j, k, l)) / scale)
<< "debug: " << i << " " << j;
EXPECT_LE(this->blob_top_->data_at(i, j, k, l) - 1e-4,
exp(this->blob_bottom_->data_at(i, j, k, l)) / scale)
<< "debug: " << i << " " << j;
}
}
}
}
}