Skip to content

Commit

Permalink
Extending crop to work for ND Blobs.
Browse files Browse the repository at this point in the history
  • Loading branch information
BlGene committed Jan 20, 2016
1 parent 0af80f0 commit a141f45
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 36 deletions.
16 changes: 14 additions & 2 deletions include/caffe/layers/crop_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,21 @@ class CropLayer : public Layer<Dtype> {
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

int crop_h_, crop_w_;
vector<int> offsets;

private:
void crop_copy(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top,
const vector<int>& offsets,
vector<int> indices,
int cur_dim);

void crop_copy_diff(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top,
const vector<int>& offsets,
vector<int> indices,
int cur_dim);
};

} // namespace caffe

#endif // CAFFE_CROP_LAYER_HPP_
126 changes: 95 additions & 31 deletions src/caffe/layers/crop_layer.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#include <algorithm>
#include <functional>
#include <map>
#include <set>
#include <vector>


#include "caffe/layer.hpp"
#include "caffe/layers/crop_layer.hpp"
#include "caffe/net.hpp"
Expand All @@ -15,56 +17,118 @@ void CropLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const CropParameter& param = this->layer_param_.crop_param();
CHECK_EQ(bottom.size(), 2) << "Wrong number of bottom blobs.";
CHECK_EQ(bottom[0]->num_axes(), 4) << "Only works with 4D blobs.";
CHECK_EQ(bottom[1]->num_axes(), 4) << "Only works with 4D blobs.";
crop_h_ = param.offset_height();
crop_w_ = param.offset_width();
// parameter setup moved to Reshape because it depends on size.
CHECK_EQ(param.crop_axis_size(), param.crop_offset_size());
}

template <typename Dtype>
void CropLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
// Check that the image we are cropping minus the margin is bigger than the
// destination image.
CHECK_GT(bottom[0]->height()-crop_h_, bottom[1]->height())
<< "invalid offset";
CHECK_GT(bottom[0]->width()-crop_w_, bottom[1]->width()) << "invalid offset";
top[0]->Reshape(bottom[0]->num(), bottom[0]->channels(), bottom[1]->height(),
bottom[1]->width());
const CropParameter& param = this->layer_param_.crop_param();
int input_dim = bottom[0]->num_axes();
// initialize all offsets to 0
offsets = vector<int>(input_dim, 0);
// initialize new shape to bottom[0]
vector<int> new_shape(bottom[0]->shape());
// apply crops
for (int i = 0; i < param.crop_axis_size(); ++i) {
int crop_axis = param.crop_axis(i);
int crop_offset = param.crop_offset(i);
CHECK_LT(crop_axis, input_dim) << "crop axis bigger than input dim";
// Check that the image we are cropping minus the margin is bigger
// than the destination image.
CHECK_GE(bottom[0]->shape(crop_axis) - crop_offset,
bottom[1]->shape(crop_axis))
<< "invalid offset in dimension: " << crop_axis;
// Now set new size and offsets
new_shape[crop_axis] = bottom[1]->shape(crop_axis);
offsets[crop_axis] = crop_offset;
}
top[0]->Reshape(new_shape);
}

// recursive copy function
template <typename Dtype>
void CropLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
for (int n = 0; n < top[0]->num(); ++n) {
for (int c = 0; c < top[0]->channels(); ++c) {
for (int h = 0; h < top[0]->height(); ++h) {
caffe_copy(top[0]->width(),
bottom_data + bottom[0]->offset(n, c, crop_h_ + h, crop_w_),
top_data + top[0]->offset(n, c, h));
void CropLayer<Dtype>::crop_copy(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top,
const vector<int>& offsets,
vector<int> indices,
int cur_dim) {
if (cur_dim + 1 < top[0]->num_axes()) {
// We are not yet at the final dimension, call copy recursivley
for (int i = 0; i < top[0]->shape(cur_dim); ++i) {
indices[cur_dim] = i;
crop_copy(bottom, top, offsets, indices, cur_dim+1);
}
} else {
// We are at the last dimensions, which is stored continously in memory
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
for (int i = 0; i < top[0]->shape(cur_dim); ++i) {
// prepare index vector reduced(red) and with offsets(off)
std::vector<int> ind_red(cur_dim, 0);
std::vector<int> ind_off(cur_dim+1, 0);
for (int j = 0; j < cur_dim; ++j) {
ind_red[j] = indices[j];
ind_off[j] = indices[j] + offsets[j];
}
ind_off[cur_dim] = offsets[cur_dim];
// do the copy
caffe_copy(top[0]->shape(cur_dim),
bottom_data + bottom[0]->offset(ind_off),
top_data + top[0]->offset(ind_red));
}
}
}

template <typename Dtype>
void CropLayer<Dtype>::crop_copy_diff(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top,
const vector<int>& offsets,
vector<int> indices,
int cur_dim) {
if (cur_dim + 1 < top[0]->num_axes()) {
// We are not yet at the final dimension, call copy recursivley
for (int i = 0; i < top[0]->shape(cur_dim); ++i) {
indices[cur_dim] = i;
crop_copy(bottom, top, offsets, indices, cur_dim+1);
}
} else {
// We are at the last dimensions, which is stored continously in memory
const Dtype* top_diff = top[0]->cpu_diff();
Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
for (int i = 0; i < top[0]->shape(cur_dim); ++i) {
// prepare index vector reduced(red) and with offsets(off)
std::vector<int> ind_red(cur_dim, 0);
std::vector<int> ind_off(cur_dim+1, 0);
for (int j = 0; j < cur_dim; ++j) {
ind_red[j] = indices[j];
ind_off[j] = indices[j] + offsets[j];
}
ind_off[cur_dim] = offsets[cur_dim];
// do the copy
caffe_copy(top[0]->shape(cur_dim),
top_diff + top[0]->offset(ind_red),
bottom_diff + bottom[0]->offset(ind_off));
}
}
}

template <typename Dtype>
void CropLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
std::vector<int> indices(top[0]->num_axes(), 0);
crop_copy(bottom, top, offsets, indices, 0);
}

template <typename Dtype>
void CropLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
const Dtype* top_diff = top[0]->cpu_diff();
Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
if (propagate_down[0]) {
caffe_set(bottom[0]->count(), static_cast<Dtype>(0), bottom_diff);
for (int n = 0; n < top[0]->num(); ++n) {
for (int c = 0; c < top[0]->channels(); ++c) {
for (int h = 0; h < top[0]->height(); ++h) {
caffe_copy(top[0]->width(),
top_diff + top[0]->offset(n, c, h),
bottom_diff + bottom[0]->offset(n, c, crop_h_ + h, crop_w_));
}
}
}
std::vector<int> indices(top[0]->num_axes(), 0);
crop_copy(bottom, top, offsets, indices, 0);
}
}

Expand Down
12 changes: 12 additions & 0 deletions src/caffe/layers/crop_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,17 @@ __global__ void copy_kernel(const int n, const int height, const int width,
template <typename Dtype>
void CropLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
CHECK_EQ(top[0]->num_axes(), 4) << "only 4D crop implemented for GPU";
CHECK_EQ(offsets[0], 0) << "only H,W cropping implemented for GPU";
CHECK_EQ(offsets[1], 0) << "only H,W cropping implemented for GPU";

const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
const int lines = top[0]->count() / top[0]->width();

int crop_h_ = offsets[2];
int crop_w_ = offsets[3];

// NOLINT_NEXT_LINE(whitespace/operators)
copy_kernel<<<CAFFE_GET_BLOCKS(lines), CAFFE_CUDA_NUM_THREADS>>>(
lines, top[0]->height(), top[0]->width(),
Expand All @@ -40,9 +47,14 @@ void CropLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
template <typename Dtype>
void CropLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
CHECK_EQ(top[0]->num_axes(), 4) << "only 4D crop implemented for GPU";
CHECK_EQ(offsets[0], 0) << "only H,W cropping implemented for GPU";
CHECK_EQ(offsets[1], 0) << "only H,W cropping implemented for GPU";
const Dtype* top_diff = top[0]->gpu_diff();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
const int lines = top[0]->count() / top[0]->width();
int crop_h_ = offsets[2];
int crop_w_ = offsets[3];

if (propagate_down[0]) {
caffe_gpu_set(bottom[0]->count(), static_cast<Dtype>(0), bottom_diff);
Expand Down
13 changes: 10 additions & 3 deletions src/caffe/proto/caffe.proto
Original file line number Diff line number Diff line change
Expand Up @@ -565,9 +565,16 @@ message ConvolutionParameter {

message CropParameter {
// Assumes standard dimensions: ( N,C,H,W )
// This could possibly be extended to use "optional BlobShape offsets"
optional uint32 offset_height = 1[default = 0];
optional uint32 offset_width = 2[default = 0];
//optional BlobShape offsets = 1;
//repeated bool skip_axis = 2;

// For values where crop_axis is defined:
// 1. use size from second bottom
// 2. consider crop_offset

repeated uint32 crop_axis = 1;
repeated uint32 crop_offset = 2;

}

message DataParameter {
Expand Down

0 comments on commit a141f45

Please sign in to comment.