forked from BVLC/caffe
-
Notifications
You must be signed in to change notification settings - Fork 263
/
crop_layer.cu
118 lines (107 loc) · 4.42 KB
/
crop_layer.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#include <vector>
#include "caffe/layers/crop_layer.hpp"
namespace caffe {
// Copies `n` rows of `width` contiguous elements from `src` to `dest`.
//
// Row r starts at (r / height) * outer_stride + (r % height) * inner_stride in
// each array, so source and destination may use different strides over their
// last two dimensions. Each thread copies one whole row; rows are distributed
// across threads by CUDA_KERNEL_LOOP.
template <typename Dtype>
__global__ void copy_kernel(const int n, const int height, const int width,
    const int src_outer_stride, const int src_inner_stride,
    const int dest_outer_stride, const int dest_inner_stride,
    const Dtype* src, Dtype* dest) {
  CUDA_KERNEL_LOOP(row, n) {
    const int plane = row / height;
    const int line_in_plane = row % height;
    const int src_start =
        plane * src_outer_stride + line_in_plane * src_inner_stride;
    const int dest_start =
        plane * dest_outer_stride + line_in_plane * dest_inner_stride;
    const Dtype* src_row = src + src_start;
    Dtype* dest_row = dest + dest_start;
    for (int col = 0; col < width; ++col) {
      dest_row[col] = src_row[col];
    }
  }
}
// Copies the cropped window between bottom[0] and top[0] on the GPU.
//
// The recursion walks every axis except the last two, fixing one top[0]
// coordinate per level in `indices`. At the last two axes it launches
// copy_kernel, which copies one row of the 2-D plane per thread.
//
//   bottom, top : layer blobs; only bottom[0] and top[0] are used.
//   offsets     : crop offset of top[0] inside bottom[0], indexed per axis
//                 (read for every axis up to num_axes - 1).
//   indices     : top[0] coordinate of the current sub-block. It is taken by
//                 value so each recursion level can overwrite its own entry.
//   cur_dim     : axis handled at this level; external callers pass 0.
//   src_data, dest_data : not read here. The data/diff pointers are fetched
//                 from the blobs at the leaf instead; these parameters are kept
//                 for interface compatibility.
//   is_forward  : true copies bottom[0] data -> top[0] data;
//                 false copies top[0] diff -> bottom[0] diff.
template <typename Ftype, typename Btype>
template <typename Dtype>
void CropLayer<Ftype, Btype>::crop_copy_gpu(const vector<Blob*>& bottom,
    const vector<Blob*>& top,
    const vector<int>& offsets,
    vector<int> indices,
    int cur_dim,
    const Dtype* src_data,
    Dtype* dest_data,
    bool is_forward) {
  if (cur_dim + 2 < top[0]->num_axes()) {
    // Not yet at the last two axes: fix this axis' top coordinate and
    // recurse into the next axis.
    for (int i = 0; i < top[0]->shape(cur_dim); ++i) {
      indices[cur_dim] = i;
      crop_copy_gpu(bottom, top, offsets, indices, cur_dim + 1,
          src_data, dest_data, is_forward);
    }
  } else {
    // At the last two axes, which are stored contiguously in memory.
    // For an (N,C,H,W) blob: cur_dim -> H, cur_dim+1 -> W.
    const int lines = top[0]->shape(cur_dim);
    const int height = top[0]->shape(cur_dim);
    const int width = top[0]->shape(cur_dim + 1);
    // Coordinates of this sub-block inside bottom[0]: the top coordinate plus
    // the crop offset on each outer axis. On the last two axes only the crop
    // offset is used, because the kernel itself walks those axes.
    std::vector<int> ind_off(cur_dim + 2, 0);
    for (int j = 0; j < cur_dim; ++j) {
      ind_off[j] = indices[j] + offsets[j];
    }
    ind_off[cur_dim] = offsets[cur_dim];
    ind_off[cur_dim + 1] = offsets[cur_dim + 1];
    // Plane and row strides of each blob's trailing 2-D slice.
    const int bottom_outer_stride =
        bottom[0]->shape(cur_dim) * bottom[0]->shape(cur_dim + 1);
    const int bottom_inner_stride = bottom[0]->shape(cur_dim + 1);
    const int top_outer_stride =
        top[0]->shape(cur_dim) * top[0]->shape(cur_dim + 1);
    const int top_inner_stride = top[0]->shape(cur_dim + 1);
    cudaStream_t stream = Caffe::thread_stream();
    if (is_forward) {
      const Dtype* bottom_data = bottom[0]->gpu_data<Dtype>() +
          bottom[0]->offset(ind_off);
      Dtype* top_data = top[0]->mutable_gpu_data<Dtype>() +
          top[0]->offset(indices);
      // NOLINT_NEXT_LINE(whitespace/operators)
      copy_kernel<<<CAFFE_GET_BLOCKS(lines), CAFFE_CUDA_NUM_THREADS, 0, stream>>>(
          lines, height, width,
          bottom_outer_stride, bottom_inner_stride,
          top_outer_stride, top_inner_stride,
          bottom_data, top_data);
    } else {
      const Dtype* top_diff = top[0]->gpu_diff<Dtype>() +
          top[0]->offset(indices);
      Dtype* bottom_diff = bottom[0]->mutable_gpu_diff<Dtype>() +
          bottom[0]->offset(ind_off);
      // Gradient flows top -> bottom, so top is the source here.
      // NOLINT_NEXT_LINE(whitespace/operators)
      copy_kernel<<<CAFFE_GET_BLOCKS(lines), CAFFE_CUDA_NUM_THREADS, 0, stream>>>(
          lines, height, width,
          top_outer_stride, top_inner_stride,
          bottom_outer_stride, bottom_inner_stride,
          top_diff, bottom_diff);
    }
    // Kernel launches do not return errors; report a bad launch configuration
    // here rather than at some unrelated later call.
    CUDA_CHECK(cudaGetLastError());
  }
  // Every launch in this call tree is enqueued on Caffe::thread_stream(). This
  // assumes that stream is the same for all calls from this thread. One
  // synchronization at the outermost level (cur_dim == 0, the value
  // Forward_gpu/Backward_gpu pass) therefore waits for all copies, instead of
  // stalling the host after every 2-D plane.
  if (cur_dim == 0) {
    CUDA_CHECK(cudaStreamSynchronize(Caffe::thread_stream()));
  }
}
// Forward pass: copies the cropped window of bottom[0] into top[0] on the GPU.
// The copy starts at axis 0 with every top coordinate set to zero.
template <typename Ftype, typename Btype>
void CropLayer<Ftype, Btype>::Forward_gpu(const vector<Blob*>& bottom,
    const vector<Blob*>& top) {
  std::vector<int> start_coords(top[0]->num_axes(), 0);
  const Ftype* src = bottom[0]->gpu_data<Ftype>();
  Ftype* dst = top[0]->mutable_gpu_data<Ftype>();
  const bool forward = true;
  crop_copy_gpu(bottom, top, offsets, start_coords, 0, src, dst, forward);
}
// Backward pass: routes top[0]'s gradient into the cropped window of
// bottom[0]'s gradient. Positions outside the window get zero.
// Nothing is done when bottom[0] does not need a gradient.
template <typename Ftype, typename Btype>
void CropLayer<Ftype, Btype>::Backward_gpu(const vector<Blob*>& top,
    const vector<bool>& propagate_down, const vector<Blob*>& bottom) {
  // Return before requesting either diff buffer. The mutable accessor would
  // otherwise be invoked on a gradient nobody uses.
  if (!propagate_down[0]) {
    return;
  }
  const Btype* top_diff = top[0]->gpu_diff<Btype>();
  Btype* bottom_diff = bottom[0]->mutable_gpu_diff<Btype>();
  // Cropped-away elements receive no gradient; the copy fills in the window.
  caffe_gpu_set(bottom[0]->count(), static_cast<Btype>(0), bottom_diff);
  std::vector<int> indices(top[0]->num_axes(), 0);
  crop_copy_gpu(bottom, top, offsets, indices, 0, top_diff, bottom_diff,
      false);
}
// Explicit template instantiation of CropLayer's GPU methods defined above.
// The macro is defined elsewhere; presumably it instantiates each supported
// forward/backward (Ftype, Btype) pair. Confirm against its definition.
INSTANTIATE_LAYER_GPU_FUNCS_FB(CropLayer);
} // namespace caffe