forked from chenzhi1992/TensorRT-SSD
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pluginImplement.h
276 lines (234 loc) · 10.6 KB
/
pluginImplement.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
#ifndef __PLUGIN_LAYER_H__
#define __PLUGIN_LAYER_H__
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cudnn.h>
#include "NvCaffeParser.h"
#include "NvInferPlugin.h"
// CHECK(expr): evaluates a CUDA runtime call and, if its result is non-zero
// (i.e. not cudaSuccess), prints the CUDA error string plus the source line
// and aborts the process.
//
// The argument is evaluated exactly once and cached in a local; evaluating it
// again inside cudaGetErrorString() would re-issue the failing CUDA call.
// The do { } while (0) wrapper makes `CHECK(x);` a single statement, so it
// composes correctly with an unbraced if/else.
#define CHECK(status)                                              \
    do                                                             \
    {                                                              \
        auto checkStatus_ = (status);                              \
        if (checkStatus_ != 0)                                     \
        {                                                          \
            std::cout << "Cuda failure: "                          \
                      << cudaGetErrorString(checkStatus_)          \
                      << " at line " << __LINE__                   \
                      << std::endl;                                \
            std::abort();                                          \
        }                                                          \
    } while (0)
using namespace nvinfer1;
using namespace nvcaffeparser1;
using namespace plugin;
// Operation selector. Both enumerator values are spelled out explicitly and
// match the implicit numbering (SELECT = 0, SUMMARY = 1).
enum FunctionType
{
    SELECT  = 0,
    SUMMARY = 1,
};
class bboxProfile {
public:
bboxProfile(float4& p, int idx): pos(p), bboxNum(idx) {}
float4 pos;
int bboxNum = -1;
int labelID = -1;
};
// Pairs a bounding-box identifier with a label value.
class tagProfile {
public:
    tagProfile(int boxIndex, int classLabel)
        : bboxID{boxIndex}
        , label{classLabel}
    {
    }

    int bboxID; // first constructor argument (presumably an index into a box list — confirm)
    int label;  // second constructor argument
};
//SSD Reshape layer : shape{0,-1,21}
// Reshape plugin.
//
// For a 3-D input with dims (d0, d1, d2) the output dims are
// (OutC, d0 * d1 / OutC, d2); d0 * d1 must be divisible by OutC (asserted).
// A reshape does not move data, but a plugin cannot run in place, so
// enqueue() copies the input buffer verbatim into the output buffer.
//
// OutC: leading dimension of the output.
template<int OutC>
class Reshape : public IPlugin
{
public:
    Reshape() {}

    // Deserialization constructor.
    // buffer: bytes produced by serialize() — a single size_t copy size.
    // size  : must equal sizeof(size_t).
    Reshape(const void* buffer, size_t size)
    {
        assert(size == sizeof(mCopySize));
        // memcpy instead of dereferencing a reinterpret_cast pointer: the
        // serialized blob carries no alignment guarantee for size_t.
        std::memcpy(&mCopySize, buffer, sizeof(mCopySize));
    }

    int getNbOutputs() const override
    {
        return 1;
    }

    Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override
    {
        assert(nbInputDims == 1);
        assert(index == 0);
        assert(inputs[index].nbDims == 3);
        assert((inputs[0].d[0]) * (inputs[0].d[1]) % OutC == 0);
        return DimsCHW(OutC, inputs[0].d[0] * inputs[0].d[1] / OutC, inputs[0].d[2]);
    }

    int initialize() override
    {
        return 0;
    }

    void terminate() override
    {
    }

    size_t getWorkspaceSize(int) const override
    {
        return 0;
    }

    // currently it is not possible for a plugin to execute "in place". Therefore we memcpy the data from the input to the output buffer
    int enqueue(int batchSize, const void*const *inputs, void** outputs, void*, cudaStream_t stream) override
    {
        CHECK(cudaMemcpyAsync(outputs[0], inputs[0], mCopySize * batchSize, cudaMemcpyDeviceToDevice, stream));
        return 0;
    }

    size_t getSerializationSize() override
    {
        return sizeof(mCopySize);
    }

    void serialize(void* buffer) override
    {
        std::memcpy(buffer, &mCopySize, sizeof(mCopySize));
    }

    // Records the per-sample byte size of the float input so enqueue() knows how much to copy.
    void configure(const Dims*inputs, int nbInputs, const Dims* outputs, int nbOutputs, int) override
    {
        // Widen before multiplying so the element count is computed in size_t rather than int.
        mCopySize = static_cast<size_t>(inputs[0].d[0]) * inputs[0].d[1] * inputs[0].d[2] * sizeof(float);
    }

protected:
    // Bytes copied per batch item; 0 until configure() or deserialization sets it.
    size_t mCopySize{0};
};
//SSD Flatten layer
// Flatten plugin: maps a 3-D input (C, H, W) to an output of dims (C*H*W, 1, 1).
// The element order is unchanged, so enqueue() is a plain device-to-device copy.
class FlattenLayer : public IPlugin
{
public:
    FlattenLayer(){}

    // Deserialization constructor: reads the three ints (C, H, W) written by serialize().
    FlattenLayer(const void* buffer,size_t size)
    {
        assert(size == 3 * sizeof(int));
        // Copy out instead of dereferencing a reinterpret_cast pointer (no alignment guarantee).
        int d[3];
        std::memcpy(d, buffer, sizeof(d));
        _size = d[0] * d[1] * d[2];
        dimBottom = DimsCHW{d[0], d[1], d[2]};
    }

    inline int getNbOutputs() const override { return 1; }

    Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override
    {
        assert(1 == nbInputDims);
        assert(0 == index);
        assert(3 == inputs[index].nbDims);
        _size = inputs[0].d[0] * inputs[0].d[1] * inputs[0].d[2];
        return DimsCHW(_size, 1, 1);
    }

    int initialize() override
    {
        return 0;
    }

    inline void terminate() override
    {
    }

    inline size_t getWorkspaceSize(int) const override { return 0; }

    int enqueue(int batchSize, const void*const *inputs, void** outputs, void*, cudaStream_t stream) override
    {
        // NOTE(review): this debug print runs on every inference call and flushes stdout.
        std::cout<<"flatten enqueue:"<<batchSize<<";"<<_size<<std::endl;
        // Widen before multiplying so the byte count is computed in size_t, not int.
        const size_t bytes = static_cast<size_t>(batchSize) * _size * sizeof(float);
        CHECK(cudaMemcpyAsync(outputs[0],inputs[0],bytes,cudaMemcpyDeviceToDevice,stream));
        return 0;
    }

    size_t getSerializationSize() override
    {
        return 3 * sizeof(int);
    }

    void serialize(void* buffer) override
    {
        const int d[3] = {dimBottom.c(), dimBottom.h(), dimBottom.w()};
        std::memcpy(buffer, d, sizeof(d));
    }

    void configure(const Dims*inputs, int nbInputs, const Dims* outputs, int nbOutputs, int) override
    {
        dimBottom = DimsCHW(inputs[0].d[0], inputs[0].d[1], inputs[0].d[2]);
        // Keep _size consistent with dimBottom even if getOutputDimensions()
        // was not called on this particular instance.
        _size = inputs[0].d[0] * inputs[0].d[1] * inputs[0].d[2];
    }

protected:
    DimsCHW dimBottom;
    // Elements per batch item (C*H*W); 0 until set.
    int _size{0};
};
//Concat layer . TensorRT Concat only support cross channel
// Concatenation plugin. The built-in TensorRT concat only supports the channel
// axis, so this plugin concatenates along `_axis`. All member functions except
// the constructors/destructor shown here are defined outside this header.
class ConcatPlugin : public IPlugin
{
public:
    ConcatPlugin(int axis){ _axis = axis; };
    ConcatPlugin(int axis, const void* buffer, size_t size);

    // Frees the per-input bookkeeping arrays allocated by the default member
    // initializers below; without this every ConcatPlugin leaked three arrays.
    ~ConcatPlugin()
    {
        delete[] bottom_concat_axis;
        delete[] concat_input_size_;
        delete[] num_concats_;
    }

    // The class owns raw heap arrays; a member-wise copy would double-free them.
    ConcatPlugin(const ConcatPlugin&) = delete;
    ConcatPlugin& operator=(const ConcatPlugin&) = delete;

    inline int getNbOutputs() const override {return 1;};
    Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override ;
    int initialize() override;
    inline void terminate() override;
    inline size_t getWorkspaceSize(int) const override { return 0; };
    int enqueue(int batchSize, const void*const *inputs, void** outputs, void*, cudaStream_t stream) override;
    size_t getSerializationSize() override;
    void serialize(void* buffer) override;
    void configure(const Dims*inputs, int nbInputs, const Dims* outputs, int nbOutputs, int) override;

protected:
    // Capacity of each per-input array below (presumably the maximum number of
    // concat inputs; no bounds check is visible in this header — confirm).
    static const int kMaxConcatInputs = 9;
    DimsCHW dimsConv4_3, dimsFc7, dimsConv6, dimsConv7, dimsConv8, dimsConv9;
    int inputs_size;
    int top_concat_axis; // extent of the concat axis in the top (output) blob
    int* bottom_concat_axis = new int[kMaxConcatInputs]; // per bottom (input) blob: extent of its concat axis
    int* concat_input_size_ = new int[kMaxConcatInputs];
    int* num_concats_ = new int[kMaxConcatInputs];
    int _axis;
};
//Softmax layer . TensorRT softmax only support cross channel
// Placeholder for a softmax plugin (the built-in TensorRT softmax only works
// across channels). NOTE(review): no IPlugin member functions are overridden
// yet, so — assuming IPlugin's interface is pure virtual — this class is
// abstract and cannot be instantiated until they are implemented.
class SoftmaxPlugin : public IPlugin
{
    //You need to implement it when softmax parameter axis is 2.
    //
}; // a class definition must end with ';' or the next declaration is ill-formed
// Plugin factory implementing both nvinfer1::IPluginFactory and
// nvcaffeparser1::IPluginFactory. Each data member below holds the plugin
// instance for one named SSD layer. The bodies of createPlugin / isPlugin /
// destroyPlugin are defined outside this header.
// Members are destroyed in reverse declaration order, so reordering them
// changes teardown order as well as layout.
class PluginFactory : public nvinfer1::IPluginFactory, public nvcaffeparser1::IPluginFactory
{
public:
// nvcaffeparser1::IPluginFactory overload: creates a plugin from a layer name and its weights.
virtual nvinfer1::IPlugin* createPlugin(const char* layerName, const nvinfer1::Weights* weights, int nbWeights) override;
// nvinfer1::IPluginFactory overload: recreates a plugin from serialized bytes.
IPlugin* createPlugin(const char* layerName, const void* serialData, size_t serialLength) override;
// Deleter for the INvPlugin members: they are released via destroy(), not delete.
void(*nvPluginDeleter)(INvPlugin*) { [](INvPlugin* ptr) {ptr->destroy(); } };
// Whether the named layer is handled by this factory (definition not shown).
bool isPlugin(const char* name) override;
// Presumably releases the plugin members below — definition not shown; TODO confirm.
void destroyPlugin();
//normalize layer
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mNormalizeLayer{ nullptr, nvPluginDeleter };
//permute layers
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mConv4_3_norm_mbox_conf_perm_layer{ nullptr, nvPluginDeleter };
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mConv4_3_norm_mbox_loc_perm_layer{ nullptr, nvPluginDeleter };
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mFc7_mbox_conf_perm_layer{ nullptr, nvPluginDeleter };
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mFc7_mbox_loc_perm_layer{ nullptr, nvPluginDeleter };
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mConv6_2_mbox_conf_perm_layer{ nullptr, nvPluginDeleter };
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mConv6_2_mbox_loc_perm_layer{ nullptr, nvPluginDeleter };
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mConv7_2_mbox_conf_perm_layer{ nullptr, nvPluginDeleter };
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mConv7_2_mbox_loc_perm_layer{ nullptr, nvPluginDeleter };
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mConv8_2_mbox_conf_perm_layer{ nullptr, nvPluginDeleter };
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mConv8_2_mbox_loc_perm_layer{ nullptr, nvPluginDeleter };
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mConv9_2_mbox_conf_perm_layer{ nullptr, nvPluginDeleter };
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mConv9_2_mbox_loc_perm_layer{ nullptr, nvPluginDeleter };
//priorbox layers
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mConv4_3_norm_mbox_priorbox_layer{ nullptr, nvPluginDeleter };
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mFc7_mbox_priorbox_layer{ nullptr, nvPluginDeleter };
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mConv6_2_mbox_priorbox_layer{ nullptr, nvPluginDeleter };
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mConv7_2_mbox_priorbox_layer{ nullptr, nvPluginDeleter };
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mConv8_2_mbox_priorbox_layer{ nullptr, nvPluginDeleter };
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mConv9_2_mbox_priorbox_layer{ nullptr, nvPluginDeleter };
//detection output layer
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mDetection_out{ nullptr, nvPluginDeleter };
//concat layers
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mBox_loc_layer{ nullptr, nvPluginDeleter };
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mBox_conf_layer{ nullptr, nvPluginDeleter };
std::unique_ptr<INvPlugin, decltype(nvPluginDeleter)> mBox_priorbox_layer{ nullptr, nvPluginDeleter };
//reshape layer
// The template argument 21 is Reshape's OutC (presumably the class count, 20 VOC classes + background — confirm).
std::unique_ptr<Reshape<21>> mMbox_conf_reshape{ nullptr };
//flatten layers
std::unique_ptr<FlattenLayer> mConv4_3_norm_mbox_conf_flat_layer{ nullptr };
std::unique_ptr<FlattenLayer> mConv4_3_norm_mbox_loc_flat_layer{ nullptr };
std::unique_ptr<FlattenLayer> mFc7_mbox_conf_flat_layer{ nullptr };
std::unique_ptr<FlattenLayer> mFc7_mbox_loc_flat_layer{ nullptr };
std::unique_ptr<FlattenLayer> mConv6_2_mbox_conf_flat_layer{ nullptr };
std::unique_ptr<FlattenLayer> mConv6_2_mbox_loc_flat_layer{ nullptr };
std::unique_ptr<FlattenLayer> mConv7_2_mbox_conf_flat_layer{ nullptr };
std::unique_ptr<FlattenLayer> mConv7_2_mbox_loc_flat_layer{ nullptr };
std::unique_ptr<FlattenLayer> mConv8_2_mbox_conf_flat_layer{ nullptr };
std::unique_ptr<FlattenLayer> mConv8_2_mbox_loc_flat_layer{ nullptr };
std::unique_ptr<FlattenLayer> mConv9_2_mbox_conf_flat_layer{ nullptr };
std::unique_ptr<FlattenLayer> mConv9_2_mbox_loc_flat_layer{ nullptr };
//softmax layer
std::unique_ptr<SoftmaxPlugin> mPluginSoftmax{ nullptr };
// Flatten layer; the name suggests it flattens the mbox_conf softmax output — confirm.
std::unique_ptr<FlattenLayer> mMbox_conf_flat_layer{ nullptr };
};
#endif