Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance yolobox trt plugin #34128

Merged
merged 4 commits into from
Oct 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion paddle/fluid/inference/tensorrt/convert/yolo_box_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,20 @@ class YoloBoxOpConverter : public OpConverter {
float conf_thresh = BOOST_GET_CONST(float, op_desc.GetAttr("conf_thresh"));
bool clip_bbox = BOOST_GET_CONST(bool, op_desc.GetAttr("clip_bbox"));
float scale_x_y = BOOST_GET_CONST(float, op_desc.GetAttr("scale_x_y"));
bool iou_aware = op_desc.HasAttr("iou_aware")
? BOOST_GET_CONST(bool, op_desc.GetAttr("iou_aware"))
: false;
float iou_aware_factor =
op_desc.HasAttr("iou_aware_factor")
? BOOST_GET_CONST(float, op_desc.GetAttr("iou_aware_factor"))
: 0.5;

int type_id = static_cast<int>(engine_->WithFp16());
auto input_dim = X_tensor->getDimensions();
auto* yolo_box_plugin = new plugin::YoloBoxPlugin(
type_id ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT,
anchors, class_num, conf_thresh, downsample_ratio, clip_bbox, scale_x_y,
input_dim.d[1], input_dim.d[2]);
iou_aware, iou_aware_factor, input_dim.d[1], input_dim.d[2]);

std::vector<nvinfer1::ITensor*> yolo_box_inputs;
yolo_box_inputs.push_back(X_tensor);
Expand Down
65 changes: 49 additions & 16 deletions paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cassert>

Expand All @@ -29,14 +27,17 @@ YoloBoxPlugin::YoloBoxPlugin(const nvinfer1::DataType data_type,
const std::vector<int>& anchors,
const int class_num, const float conf_thresh,
const int downsample_ratio, const bool clip_bbox,
const float scale_x_y, const int input_h,
const float scale_x_y, const bool iou_aware,
const float iou_aware_factor, const int input_h,
const int input_w)
: data_type_(data_type),
class_num_(class_num),
conf_thresh_(conf_thresh),
downsample_ratio_(downsample_ratio),
clip_bbox_(clip_bbox),
scale_x_y_(scale_x_y),
iou_aware_(iou_aware),
iou_aware_factor_(iou_aware_factor),
input_h_(input_h),
input_w_(input_w) {
anchors_.insert(anchors_.end(), anchors.cbegin(), anchors.cend());
Expand All @@ -45,6 +46,7 @@ YoloBoxPlugin::YoloBoxPlugin(const nvinfer1::DataType data_type,
assert(class_num_ > 0);
assert(input_h_ > 0);
assert(input_w_ > 0);
assert((iou_aware_factor_ > 0 && iou_aware_factor_ < 1));

cudaMalloc(&anchors_device_, anchors.size() * sizeof(int));
cudaMemcpy(anchors_device_, anchors.data(), anchors.size() * sizeof(int),
Expand All @@ -59,6 +61,8 @@ YoloBoxPlugin::YoloBoxPlugin(const void* data, size_t length) {
DeserializeValue(&data, &length, &downsample_ratio_);
DeserializeValue(&data, &length, &clip_bbox_);
DeserializeValue(&data, &length, &scale_x_y_);
DeserializeValue(&data, &length, &iou_aware_);
DeserializeValue(&data, &length, &iou_aware_factor_);
DeserializeValue(&data, &length, &input_h_);
DeserializeValue(&data, &length, &input_w_);
}
Expand Down Expand Up @@ -133,8 +137,19 @@ __device__ inline void GetYoloBox(float* box, const T* x, const int* anchors,

// Computes a flat element offset from a batch index, an anchor index, a
// position inside one plane (hw_idx) and an entry number.
//   iou_aware == false:
//     (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx
//   iou_aware == true:
//     same anchor term, but the entry term becomes
//     (batch * an_num + an_num + entry) * stride
// NOTE(review): presumably the additional (batch * an_num + an_num) planes
// account for IoU values stored ahead of the per-anchor entries -- confirm
// against the input layout used by the caller.
__device__ inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
int an_num, int an_stride, int stride,
int entry) {
return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx;
int entry, bool iou_aware) {
if (iou_aware) {
return (batch * an_num + an_idx) * an_stride +
(batch * an_num + an_num + entry) * stride + hw_idx;
} else {
return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx;
}
}

// Computes the flat element offset
//   batch * an_num * an_stride + (batch * an_num + an_idx) * stride + hw_idx
// for anchor `an_idx` at in-plane position `hw_idx` of image `batch`.
// NOTE(review): presumably this addresses the IoU value of that anchor in an
// iou_aware input tensor -- confirm against the caller's layout.
__device__ inline int GetIoUIndex(int batch, int an_idx, int hw_idx, int an_num,
                                  int an_stride, int stride) {
  // Offset contributed by the an_stride-sized anchor blocks of earlier images.
  const int batch_offset = batch * an_num * an_stride;
  // Index of the stride-sized plane selected by (batch, an_idx).
  const int plane = batch * an_num + an_idx;
  return batch_offset + plane * stride + hw_idx;
}

template <typename T>
Expand Down Expand Up @@ -178,7 +193,8 @@ __global__ void KeYoloBoxFw(const T* const input, const int* const imgsize,
const int w, const int an_num, const int class_num,
const int box_num, int input_size_h,
int input_size_w, bool clip_bbox, const float scale,
const float bias) {
const float bias, bool iou_aware,
const float iou_aware_factor) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
float box[4];
Expand All @@ -193,11 +209,16 @@ __global__ void KeYoloBoxFw(const T* const input, const int* const imgsize,
int img_height = imgsize[2 * i];
int img_width = imgsize[2 * i + 1];

int obj_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4);
int obj_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4,
iou_aware);
float conf = sigmoid(static_cast<float>(input[obj_idx]));
int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0);
if (iou_aware) {
int iou_idx = GetIoUIndex(i, j, k * w + l, an_num, an_stride, grid_num);
float iou = sigmoid<float>(input[iou_idx]);
conf = powf(conf, 1. - iou_aware_factor) * powf(iou, iou_aware_factor);
}
int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0,
iou_aware);

if (conf < conf_thresh) {
for (int i = 0; i < 4; ++i) {
Expand All @@ -212,8 +233,8 @@ __global__ void KeYoloBoxFw(const T* const input, const int* const imgsize,
box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width, clip_bbox);

int label_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5);
int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num,
5, iou_aware);
int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num;
CalcLabelScore<T>(scores, input, label_idx, score_idx, class_num, conf,
grid_num);
Expand All @@ -240,7 +261,8 @@ int YoloBoxPlugin::enqueue_impl(int batch_size, const void* const* inputs,
reinterpret_cast<const int* const>(inputs[1]),
reinterpret_cast<T*>(outputs[0]), reinterpret_cast<T*>(outputs[1]),
conf_thresh_, anchors_device_, n, h, w, an_num, class_num_, box_num,
input_size_h, input_size_w, clip_bbox_, scale_x_y_, bias);
input_size_h, input_size_w, clip_bbox_, scale_x_y_, bias, iou_aware_,
iou_aware_factor_);
return cudaGetLastError() != cudaSuccess;
}

Expand Down Expand Up @@ -274,6 +296,8 @@ size_t YoloBoxPlugin::getSerializationSize() const TRT_NOEXCEPT {
serialize_size += SerializedSize(scale_x_y_);
serialize_size += SerializedSize(input_h_);
serialize_size += SerializedSize(input_w_);
serialize_size += SerializedSize(iou_aware_);
serialize_size += SerializedSize(iou_aware_factor_);
return serialize_size;
}

Expand All @@ -285,6 +309,8 @@ void YoloBoxPlugin::serialize(void* buffer) const TRT_NOEXCEPT {
SerializeValue(&buffer, downsample_ratio_);
SerializeValue(&buffer, clip_bbox_);
SerializeValue(&buffer, scale_x_y_);
SerializeValue(&buffer, iou_aware_);
SerializeValue(&buffer, iou_aware_factor_);
SerializeValue(&buffer, input_h_);
SerializeValue(&buffer, input_w_);
}
Expand Down Expand Up @@ -326,8 +352,8 @@ void YoloBoxPlugin::configurePlugin(

// Returns a newly allocated YoloBoxPlugin constructed from this instance's
// data type, anchors, class count, confidence threshold, downsample ratio,
// clip flag, scale_x_y, IoU-aware settings and input height/width.
// NOTE(review): the caller presumably takes ownership of the returned
// pointer -- confirm against the TensorRT plugin contract.
nvinfer1::IPluginV2Ext* YoloBoxPlugin::clone() const TRT_NOEXCEPT {
return new YoloBoxPlugin(data_type_, anchors_, class_num_, conf_thresh_,
downsample_ratio_, clip_bbox_, scale_x_y_, input_h_,
input_w_);
downsample_ratio_, clip_bbox_, scale_x_y_,
iou_aware_, iou_aware_factor_, input_h_, input_w_);
}

// Default constructor; performs no initialization of its own.
YoloBoxPluginCreator::YoloBoxPluginCreator() {}
Expand Down Expand Up @@ -367,6 +393,8 @@ nvinfer1::IPluginV2Ext* YoloBoxPluginCreator::createPlugin(
float scale_x_y = 1.;
int h = -1;
int w = -1;
bool iou_aware = false;
float iou_aware_factor = 0.5;

for (int i = 0; i < fc->nbFields; ++i) {
const std::string field_name(fc->fields[i].name);
Expand All @@ -386,6 +414,10 @@ nvinfer1::IPluginV2Ext* YoloBoxPluginCreator::createPlugin(
clip_bbox = *static_cast<const bool*>(fc->fields[i].data);
} else if (field_name.compare("scale_x_y")) {
scale_x_y = *static_cast<const float*>(fc->fields[i].data);
} else if (field_name.compare("iou_aware")) {
iou_aware = *static_cast<const bool*>(fc->fields[i].data);
} else if (field_name.compare("iou_aware_factor")) {
iou_aware_factor = *static_cast<const float*>(fc->fields[i].data);
} else if (field_name.compare("h")) {
h = *static_cast<const int*>(fc->fields[i].data);
} else if (field_name.compare("w")) {
Expand All @@ -397,7 +429,8 @@ nvinfer1::IPluginV2Ext* YoloBoxPluginCreator::createPlugin(

return new YoloBoxPlugin(
type_id ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, anchors,
class_num, conf_thresh, downsample_ratio, clip_bbox, scale_x_y, h, w);
class_num, conf_thresh, downsample_ratio, clip_bbox, scale_x_y, iou_aware,
iou_aware_factor, h, w);
}

nvinfer1::IPluginV2Ext* YoloBoxPluginCreator::deserializePlugin(
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class YoloBoxPlugin : public nvinfer1::IPluginV2Ext {
const std::vector<int>& anchors, const int class_num,
const float conf_thresh, const int downsample_ratio,
const bool clip_bbox, const float scale_x_y,
const bool iou_aware, const float iou_aware_factor,
const int input_h, const int input_w);
YoloBoxPlugin(const void* data, size_t length);
~YoloBoxPlugin() override;
Expand Down Expand Up @@ -89,6 +90,8 @@ class YoloBoxPlugin : public nvinfer1::IPluginV2Ext {
float scale_x_y_;
int input_h_;
int input_w_;
bool iou_aware_;
float iou_aware_factor_;
std::string namespace_;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,5 +116,56 @@ def test_check_output(self):
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))


class TRTYoloBoxIoUAwareTest(InferencePassTest):
    """Inference pass test that runs yolo_box with iou_aware enabled."""

    def setUp(self):
        """Build the program, the random feed data and the TensorRT settings."""
        self.set_params()
        input_shape = [self.bs, self.channel, self.height, self.width]

        with fluid.program_guard(self.main_program, self.startup_program):
            image_var = fluid.data(
                name='image', shape=input_shape, dtype='float32')
            size_var = fluid.data(
                name='image_size', shape=[self.bs, 2], dtype='int32')
            boxes, scores = self.append_yolobox(image_var, size_var)

        # Draw the image first and the sizes second, in that order.
        random_image = np.random.random(input_shape).astype('float32')
        random_sizes = np.random.randint(
            32, 64, size=(self.bs, 2)).astype('int32')
        self.feeds = {'image': random_image, 'image_size': random_sizes}

        self.enable_trt = True
        self.trt_parameters = TRTYoloBoxTest.TensorRTParam(
            1 << 30, self.bs, 1, AnalysisConfig.Precision.Float32, False, False)
        self.fetch_list = [scores, boxes]

    def set_params(self):
        """Set the input shape and the yolo_box attributes used by this test."""
        self.bs = 4
        # NOTE(review): presumably 3 anchors * (6 + class_num) channels for the
        # iou_aware layout -- confirm against the yolo_box op definition.
        self.channel = 258
        self.height = 64
        self.width = 64
        self.class_num = 80
        self.anchors = [10, 13, 16, 30, 33, 23]
        self.conf_thresh = 0.1
        self.downsample_ratio = 32
        self.iou_aware = True
        self.iou_aware_factor = 0.5

    def append_yolobox(self, image, image_size):
        """Append a yolo_box layer configured from this test's parameters.

        Returns whatever fluid.layers.yolo_box returns; setUp unpacks it as
        (boxes, scores).
        """
        yolo_box_kwargs = {
            'x': image,
            'img_size': image_size,
            'class_num': self.class_num,
            'anchors': self.anchors,
            'conf_thresh': self.conf_thresh,
            'downsample_ratio': self.downsample_ratio,
            'iou_aware': self.iou_aware,
            'iou_aware_factor': self.iou_aware_factor,
        }
        return fluid.layers.yolo_box(**yolo_box_kwargs)

    def test_check_output(self):
        """Check outputs on GPU; do nothing when Paddle is built without CUDA."""
        if not core.is_compiled_with_cuda():
            return
        use_gpu = True
        self.check_output_with_option(use_gpu, flatten=True)
        pass_is_compatible = PassVersionChecker.IsCompatible(
            'tensorrt_subgraph_pass')
        self.assertTrue(pass_is_compatible)


if __name__ == "__main__":
unittest.main()