Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CodeCamp2023-671 #2422

Merged
merged 31 commits into from
Oct 8, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
973bd58
add nms ops
yinfan98 Sep 11, 2023
21e376b
add some file
yinfan98 Sep 11, 2023
9ac0a3d
new file
yinfan98 Sep 18, 2023
ffb3e51
some change
yinfan98 Sep 18, 2023
dffbd10
Update nms_match.cpp
yinfan98 Sep 18, 2023
a3131bf
Update nms_match.cpp
yinfan98 Sep 18, 2023
50e7aa0
Update __init__.py
yinfan98 Sep 18, 2023
6bc40bd
Delete test_onnx_match.onnx
yinfan98 Sep 18, 2023
882808f
Delete tests/test_ops/test_onnx_match.onnx
yinfan98 Sep 18, 2023
367f49b
Update test_nms_match_small.py
yinfan98 Sep 18, 2023
f2bd9df
Update test_nms_match_small.py
yinfan98 Sep 18, 2023
2da88b9
Update nms_match.cpp
yinfan98 Sep 27, 2023
8a3f26a
Update nms_match.py
yinfan98 Sep 27, 2023
1c80d0a
Update test_nms_match_small.py
yinfan98 Sep 27, 2023
fdc1eaa
Update nms_match.cpp
yinfan98 Sep 29, 2023
b961f15
Update nms_match.py
yinfan98 Sep 29, 2023
dbbb867
Update test_nms_match_small.py
yinfan98 Sep 29, 2023
aef7b99
fix the lint
yinfan98 Oct 7, 2023
7040f1e
Update test_nms_match_small.py
yinfan98 Oct 7, 2023
caf5256
Update test_nms_match_small.py
yinfan98 Oct 7, 2023
f4691be
Update nms_match.cpp
yinfan98 Oct 7, 2023
7e50e33
Update test_nms_match_small.py
yinfan98 Oct 7, 2023
722ebdc
Update test_nms_match_small.py
yinfan98 Oct 7, 2023
38c7d88
Update onnxruntime.md
yinfan98 Oct 7, 2023
5d262a2
Update onnxruntime.md
yinfan98 Oct 7, 2023
f538a1c
Update test_nms_match_small.py
yinfan98 Oct 8, 2023
45e4364
Update onnxruntime.md
yinfan98 Oct 8, 2023
42b7709
Update onnxruntime.md
yinfan98 Oct 8, 2023
f203be5
Update test_nms_match_small.py
yinfan98 Oct 8, 2023
2fb9984
Update test_nms_match_small.py
yinfan98 Oct 8, 2023
354847e
Update test_nms_match_small.py
yinfan98 Oct 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "nms_match.h"

#include <assert.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <iterator>
#include <numeric>
#include <vector>

#include "ort_utils.h"

namespace mmdeploy {
struct Box {
float x1, y1, x2, y2;
};

float nms_match_iou(Box box1, Box box2) {
auto max_x1 = std::max(box1.x1, box2.x1);
yinfan98 marked this conversation as resolved.
Show resolved Hide resolved
auto max_y1 = std::max(box1.y1, box2.y1);
auto max_x2 = std::min(box1.x2, box2.x2);
auto max_y2 = std::min(box1.y2, box2.y2);

auto w = std::max(static_cast<float>(0), max_x2 - max_x1);
auto h = std::max(static_cast<float>(0), max_y2 - max_y1);

auto area1 = (box1.x2 - box1.x1) * (box1.y2 - box1.y1);
auto area2 = (box1.x2 - box1.x1) * (box1.y2 - box1.y1);
yinfan98 marked this conversation as resolved.
Show resolved Hide resolved
auto inter = w * h;
auto ovr = inter / (area1 + area2 - inter);
yinfan98 marked this conversation as resolved.
Show resolved Hide resolved

return ovr;
}
// Kernel constructor: keep a handle to the ORT API wrapper and the node
// info, and create the default CPU allocator. The IoU / score thresholds
// are not node attributes here — Compute() reads them as runtime tensor
// inputs 2 and 3.
NMSMatchKernel::NMSMatchKernel(const OrtApi& api, const OrtKernelInfo* info)
    : ort_(api), info_(info), allocator_() {}

// Run NMS-match on the CPU.
//
// Inputs:
//   0: boxes  (nbatch, nboxes, 4) float, corners (x1, y1, x2, y2)
//   1: scores (nbatch, nclass, nboxes) float
//   2: iou_threshold   scalar float tensor
//   3: score_threshold scalar float tensor
// Output:
//   0: int64 tensor of shape (n_match, 4); each row is
//      (batch, class, kept_box_idx, suppressed_box_idx).
void NMSMatchKernel::Compute(OrtKernelContext* context) {
  const OrtValue* boxes = ort_.KernelContext_GetInput(context, 0);
  const float* boxes_data = ort_.GetTensorData<float>(boxes);
  const OrtValue* scores = ort_.KernelContext_GetInput(context, 1);
  const float* scores_data = ort_.GetTensorData<float>(scores);
  const OrtValue* iou_threshold_ = ort_.KernelContext_GetInput(context, 2);
  const float iou_threshold_data = ort_.GetTensorData<float>(iou_threshold_)[0];
  const OrtValue* score_threshold_ = ort_.KernelContext_GetInput(context, 3);
  const float score_threshold_data = ort_.GetTensorData<float>(score_threshold_)[0];

  OrtTensorDimensions boxes_dim(ort_, boxes);
  OrtTensorDimensions scores_dim(ort_, scores);
  int64_t nbatch = boxes_dim[0];
  int64_t nboxes = boxes_dim[1];
  int64_t nclass = scores_dim[1];
  assert(boxes_dim[2] == 4);  // (x1, y1, x2, y2)

  // RAII temporaries replace the previous allocator_.Alloc/Free pairs: they
  // cannot leak on an early exit, and the inputs are only read, so we index
  // the ORT buffers directly instead of memcpy-ing them into scratch memory.
  std::vector<char> select(nboxes);  // "still alive" flag per *sorted* position
  std::vector<int64_t> res_order;    // flattened (k, g, i, j) output rows

  for (int64_t k = 0; k < nbatch; k++) {
    const float* batch_boxes = boxes_data + k * nboxes * 4;
    for (int64_t g = 0; g < nclass; g++) {
      std::fill(select.begin(), select.end(), 1);
      // Scores of batch k, class g, indexed by original box order.
      const float* cls_scores = scores_data + k * nboxes * nclass + g * nboxes;

      // Box indices sorted by descending score.
      std::vector<int64_t> order(nboxes);
      std::iota(order.begin(), order.end(), 0);
      std::sort(order.begin(), order.end(), [cls_scores](int64_t id1, int64_t id2) {
        return cls_scores[id1] > cls_scores[id2];
      });

      for (int64_t _i = 0; _i < nboxes; _i++) {
        if (!select[_i]) continue;  // already suppressed by a higher-score box
        const auto i = order[_i];
        const float* bi = batch_boxes + i * 4;
        Box vbox1{bi[0], bi[1], bi[2], bi[3]};
        std::vector<int64_t> matched;  // boxes suppressed by box i
        for (int64_t _j = _i + 1; _j < nboxes; _j++) {
          if (!select[_j]) continue;
          const auto j = order[_j];
          const float* bj = batch_boxes + j * 4;
          Box vbox2{bj[0], bj[1], bj[2], bj[3]};
          if (nms_match_iou(vbox1, vbox2) > iou_threshold_data) {
            select[_j] = 0;
            matched.push_back(j);
          }
        }
        // Only boxes above the score threshold report their matches.
        if (cls_scores[i] > score_threshold_data && !matched.empty()) {
          for (const auto j : matched) {
            res_order.push_back(k);
            res_order.push_back(g);
            res_order.push_back(i);
            res_order.push_back(j);
          }
        }
      }
    }
  }

  std::vector<int64_t> inds_dims({(int64_t)res_order.size() / 4, 4});

  OrtValue* res = ort_.KernelContext_GetOutput(context, 0, inds_dims.data(), inds_dims.size());
  int64_t* res_data = ort_.GetTensorMutableData<int64_t>(res);

  // std::copy instead of memcpy: well-defined even when res_order is empty
  // (memcpy from a null data() pointer is undefined behavior).
  std::copy(res_order.begin(), res_order.end(), res_data);
}
REGISTER_ONNXRUNTIME_OPS(mmdeploy, NMSMatchOp);
}
46 changes: 46 additions & 0 deletions csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef ONNXRUNTIME_NMS_MATCH_H
#define ONNXRUNTIME_NMS_MATCH_H

#include <assert.h>
#include <onnxruntime_cxx_api.h>

#include <cmath>
#include <mutex>
#include <string>
#include <vector>

namespace mmdeploy {
// CPU kernel for the custom "NMSMatch" ONNX Runtime op.
//
// Inputs (all float tensors): boxes (batch, nboxes, 4),
// scores (batch, nclass, nboxes), iou_threshold (scalar),
// score_threshold (scalar).
// Output: an int64 tensor of shape (n_match, 4) pairing every kept box with
// each box it suppressed, as rows (batch, class, kept_idx, suppressed_idx).
struct NMSMatchKernel {
  NMSMatchKernel(const OrtApi& api, const OrtKernelInfo* info);

  // Reads the four inputs from the kernel context and writes the match
  // quadruples to output 0.
  void Compute(OrtKernelContext* context);

 private:
  Ort::CustomOpApi ort_;                        // thin wrapper over the ORT C API
  const OrtKernelInfo* info_;                   // node info handle
  Ort::AllocatorWithDefaultOptions allocator_;  // default CPU allocator
};

// Registration glue: op schema (name, input/output counts and element types)
// plus the factory that creates NMSMatchKernel instances.
struct NMSMatchOp : Ort::CustomOpBase<NMSMatchOp, NMSMatchKernel> {
  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
    return new NMSMatchKernel(api, info);
  }
  const char* GetName() const { return "NMSMatch"; }

  // All four inputs (boxes, scores, iou_threshold, score_threshold) are float.
  size_t GetInputTypeCount() const { return 4; }
  ONNXTensorElementDataType GetInputType(size_t) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  // Single int64 output holding the match quadruples.
  size_t GetOutputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetOutputType(size_t) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  }

  // force cpu
  const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; }
};
}  // namespace mmdeploy

#endif // ONNXRUNTIME_NMS_MATCH_H
7 changes: 4 additions & 3 deletions mmdeploy/mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from . import roi_align_rotated # noqa: F401,F403
from . import transformer # noqa: F401,F403
from .nms import ONNXNMSop, TRTBatchedNMSop, multiclass_nms
from .nms_rotated import (ONNXNMSRotatedOp, TRTBatchedRotatedNMSop,
multiclass_nms_rotated)
from .nms_match import ONNXNMSMatchOp, multiclass_nms_match
from .nms_rotated import ONNXNMSRotatedOp, TRTBatchedRotatedNMSop
yinfan98 marked this conversation as resolved.
Show resolved Hide resolved

__all__ = [
'ONNXNMSop', 'TRTBatchedNMSop', 'TRTBatchedRotatedNMSop',
'ONNXNMSRotatedOp', 'multiclass_nms', 'multiclass_nms_rotated'
'ONNXNMSRotatedOp', 'multiclass_nms', 'ONNXNMSMatchOp',
'multiclass_nms_match'
]
10 changes: 10 additions & 0 deletions mmdeploy/mmcv/ops/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from mmdeploy.core import FUNCTION_REWRITER, mark
from mmdeploy.utils import IR, is_dynamic_batch
from mmdeploy.utils.constants import Backend
from .nms_match import multiclass_nms_match
from .nms_rotated import multiclass_nms_rotated


Expand Down Expand Up @@ -529,6 +530,15 @@ def multiclass_nms(boxes: Tensor,
score_threshold=score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k)
elif nms_type == 'nms_match':
return multiclass_nms_match(
boxes,
scores,
max_output_boxes_per_class=max_output_boxes_per_class,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k)
else:
raise NotImplementedError(f'Unsupported nms type: {nms_type}.')

Expand Down
Loading