diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp b/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp index 73a0b8acef..075c3277bc 100644 --- a/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp +++ b/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp @@ -14,7 +14,7 @@ void parallel_unroll_gemm(const float *A, const float *B, const float *V, const const int32_t M, const int32_t N, const int32_t K, const float alpha, const float beta, float *Y, const int32_t start_row, const int32_t end_row) { - float tmp[N]; // tmp + std::vector tmp(N); for (int32_t m = start_row; m < end_row; ++m) { for (int32_t n = 0; n < N; n++) { tmp[n] = 0; diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.cpp b/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.cpp new file mode 100644 index 0000000000..784be2c987 --- /dev/null +++ b/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.cpp @@ -0,0 +1,129 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "nms_match.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "ort_utils.h" + +namespace mmdeploy { +struct Box { + float x1, y1, x2, y2; +}; + +float nms_match_iou(Box box1, Box box2) { + auto inter_x1 = std::max(box1.x1, box2.x1); + auto inter_y1 = std::max(box1.y1, box2.y1); + auto inter_x2 = std::min(box1.x2, box2.x2); + auto inter_y2 = std::min(box1.y2, box2.y2); + + auto eps = 1e-10; + + auto w = std::max(static_cast(0), inter_x2 - inter_x1); + auto h = std::max(static_cast(0), inter_y2 - inter_y1); + + auto area1 = (box1.x2 - box1.x1) * (box1.y2 - box1.y1); + auto area2 = (box2.x2 - box2.x1) * (box2.y2 - box2.y1); + auto inter = w * h; + auto ovr = inter / (area1 + area2 - inter + eps); + return ovr; +} +NMSMatchKernel::NMSMatchKernel(const OrtApi& api, const OrtKernelInfo* info) + : ort_(api), info_(info) { + // create allocator + allocator_ = Ort::AllocatorWithDefaultOptions(); +} + +void NMSMatchKernel::Compute(OrtKernelContext* context) { + const OrtValue* boxes = ort_.KernelContext_GetInput(context, 0); + const float* boxes_data = reinterpret_cast(ort_.GetTensorData(boxes)); + const OrtValue* scores = ort_.KernelContext_GetInput(context, 1); + const float* scores_data = reinterpret_cast(ort_.GetTensorData(scores)); + const OrtValue* iou_threshold_ = ort_.KernelContext_GetInput(context, 2); + const float iou_threshold_data = ort_.GetTensorData(iou_threshold_)[0]; + const OrtValue* score_threshold_ = ort_.KernelContext_GetInput(context, 3); + const float score_threshold_data = ort_.GetTensorData(score_threshold_)[0]; + + OrtTensorDimensions boxes_dim(ort_, boxes); + OrtTensorDimensions scores_dim(ort_, scores); + // loop over batch + int64_t nbatch = boxes_dim[0]; + int64_t nboxes = boxes_dim[1]; + int64_t nclass = scores_dim[1]; + assert(boxes_dim[2] == 4); //(x1, x2, y1, y2) + // alloc some temp memory + bool* select = (bool*)allocator_.Alloc(sizeof(bool) * nbatch * nboxes); + + std::vector res_order; + for (int64_t k = 0; k < nbatch; k++) { + for (int64_t g = 0; g < nclass; g++) { + for (int64_t i = 0; i < nboxes; i++) { + select[i] = true; + } + // scores + // k * nboxes * nclass means per batch + // g * nboxes means per class + // batch = 2 boxes = 3 classes = 4 + std::vector tmp_sc; + // get the class scores + for (int i = 0; i < nboxes; i++) { + tmp_sc.push_back(scores_data[k * nboxes * nclass + g * nboxes + i]); + } + + std::vector order(tmp_sc.size()); + std::iota(order.begin(), order.end(), 0); + std::sort(order.begin(), order.end(), + [&tmp_sc](int64_t id1, int64_t id2) { return tmp_sc[id1] > tmp_sc[id2]; }); + for (int64_t _i = 0; _i < nboxes; _i++) { + auto i = order[_i]; + if (select[i] == false) continue; + std::vector v_i; + for (int64_t _j = _i + 1; _j < nboxes; _j++) { + auto j = order[_j]; + if (select[j] == false) continue; + Box vbox1, vbox2; + vbox1.x1 = boxes_data[k * nboxes * 4 + i * 4]; + vbox1.y1 = boxes_data[k * nboxes * 4 + i * 4 + 1]; + vbox1.x2 = boxes_data[k * nboxes * 4 + i * 4 + 2]; + vbox1.y2 = boxes_data[k * nboxes * 4 + i * 4 + 3]; + + vbox2.x1 = boxes_data[k * nboxes * 4 + j * 4]; + vbox2.y1 = boxes_data[k * nboxes * 4 + j * 4 + 1]; + vbox2.x2 = boxes_data[k * nboxes * 4 + j * 4 + 2]; + vbox2.y2 = boxes_data[k * nboxes * 4 + j * 4 + 3]; + + auto ovr = nms_match_iou(vbox1, vbox2); + if (ovr >= iou_threshold_data) { + select[j] = false; + v_i.push_back(j); + } + } + if (tmp_sc[i] > score_threshold_data && v_i.size() != 0) { + for (int v_i_idx = 0; v_i_idx < v_i.size(); v_i_idx++) { + res_order.push_back(k); + res_order.push_back(g); + res_order.push_back(i); + res_order.push_back(v_i[v_i_idx]); + } + } + } + } + } + std::vector inds_dims({(int64_t)res_order.size() / 4, 4}); + + OrtValue* res = ort_.KernelContext_GetOutput(context, 0, inds_dims.data(), inds_dims.size()); + int64_t* res_data = ort_.GetTensorMutableData(res); + + memcpy(res_data, res_order.data(), sizeof(int64_t) * res_order.size()); + + allocator_.Free(select); +} +REGISTER_ONNXRUNTIME_OPS(mmdeploy, NMSMatchOp); +} // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.h b/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.h new file mode 100644 index 0000000000..57aa94d964 --- /dev/null +++ b/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.h @@ -0,0 +1,46 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#ifndef ONNXRUNTIME_NMS_MATCH_H +#define ONNXRUNTIME_NMS_MATCH_H + +#include +#include + +#include +#include +#include +#include + +namespace mmdeploy { +struct NMSMatchKernel { + NMSMatchKernel(const OrtApi& api, const OrtKernelInfo* info); + + void Compute(OrtKernelContext* context); + + private: + Ort::CustomOpApi ort_; + const OrtKernelInfo* info_; + Ort::AllocatorWithDefaultOptions allocator_; +}; + +struct NMSMatchOp : Ort::CustomOpBase { + void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const { + return new NMSMatchKernel(api, info); + } + const char* GetName() const { return "NMSMatch"; } + + size_t GetInputTypeCount() const { return 4; } + ONNXTensorElementDataType GetInputType(size_t) const { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + } + + size_t GetOutputTypeCount() const { return 1; } + ONNXTensorElementDataType GetOutputType(size_t) const { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; + } + + // force cpu + const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; } +}; +} // namespace mmdeploy + +#endif // ONNXRUNTIME_NMS_MATCH_H diff --git a/docs/en/06-custom-ops/onnxruntime.md b/docs/en/06-custom-ops/onnxruntime.md index a3d96332d6..5d748db1ca 100644 --- a/docs/en/06-custom-ops/onnxruntime.md +++ b/docs/en/06-custom-ops/onnxruntime.md @@ -27,6 +27,12 @@ - [Inputs](#inputs-3) - [Outputs](#outputs-3) - [Type Constraints](#type-constraints-3) +- [NMSMatch](#nmsmatch) + - [Description](#description-2) + - [Parameters](#parameters-2) + - [Inputs](#inputs-2) + - [Outputs](#outputs-2) + - [Type Constraints](#type-constraints-2) @@ -174,3 +180,36 @@ Perform RoIAlignRotated on output feature, used in bbox_head of most two-stage r #### Type Constraints - T:tensor(float32) + +### NMSMatch + +#### Description + +Non Max Suppression with the suppression box match. + +#### Parameters + +| Type | Parameter | Description | +| ------- | ----------- | --------------------------------- | +| `float` | `iou_thr` | The IoU threshold for NMSMatch. | +| `float` | `score_thr` | The score threshold for NMSMatch. | + +#### Inputs + +
+
inputs[0]: T
+
Input boxes; 3-D tensor of shape (b, N, 4), where b is the batch size, N is the number of boxes and 4 means the coordinate.
+
inputs[1]: T
+
Input scores; 3-D tensor of shape (b, c, N), where b is the batch size, c is the class size and N is the number of boxes.
+
+ +#### Outputs + +
+
outputs[0]: T
+
Output feature; 2-D tensor of shape (K, 4), K is the number of matched boxes, 4 is batch id, class id, select boxes, suppressed boxes.
+
+ +#### Type Constraints + +- T:tensor(float32) diff --git a/docs/zh_cn/06-custom-ops/onnxruntime.md b/docs/zh_cn/06-custom-ops/onnxruntime.md index e4f0779efa..eb5cba3781 100644 --- a/docs/zh_cn/06-custom-ops/onnxruntime.md +++ b/docs/zh_cn/06-custom-ops/onnxruntime.md @@ -27,6 +27,12 @@ - [Inputs](#inputs-3) - [Outputs](#outputs-3) - [Type Constraints](#type-constraints-3) +- [NMSMatch](#nmsmatch) + - [Description](#description-2) + - [Parameters](#parameters-2) + - [Inputs](#inputs-2) + - [Outputs](#outputs-2) + - [Type Constraints](#type-constraints-2) @@ -174,3 +180,36 @@ Perform RoIAlignRotated on output feature, used in bbox_head of most two-stage r #### Type Constraints - T:tensor(float32) + +### NMSMatch + +#### Description + +Non Max Suppression with the suppression box match. + +#### Parameters + +| Type | Parameter | Description | +| ------- | ----------- | --------------------------------- | +| `float` | `iou_thr` | The IoU threshold for NMSMatch. | +| `float` | `score_thr` | The score threshold for NMSMatch. | + +#### Inputs + +
+
inputs[0]: T
+
Input boxes; 3-D tensor of shape (b, N, 4), where b is the batch size, N is the number of boxes and 4 means the coordinate.
+
inputs[1]: T
+
Input scores; 3-D tensor of shape (b, c, N), where b is the batch size, c is the class size and N is the number of boxes.
+
+ +#### Outputs + +
+
outputs[0]: T
+
Output feature; 2-D tensor of shape (K, 4), K is the number of matched boxes, 4 is batch id, class id, select boxes, suppressed boxes.
+
+ +#### Type Constraints + +- T:tensor(float32) diff --git a/mmdeploy/mmcv/ops/__init__.py b/mmdeploy/mmcv/ops/__init__.py index ed55820b02..21f9e6f936 100644 --- a/mmdeploy/mmcv/ops/__init__.py +++ b/mmdeploy/mmcv/ops/__init__.py @@ -6,11 +6,13 @@ from . import roi_align # noqa: F401,F403 from . import roi_align_rotated # noqa: F401,F403 from . import transformer # noqa: F401,F403 -from .nms import ONNXNMSop, TRTBatchedNMSop, multiclass_nms -from .nms_rotated import (ONNXNMSRotatedOp, TRTBatchedRotatedNMSop, - multiclass_nms_rotated) +from .nms import ONNXNMSop, TRTBatchedNMSop, multiclass_nms # noqa: F401,F403 +from .nms_match import ONNXNMSMatchOp, multiclass_nms_match +from .nms_rotated import multiclass_nms_rotated # noqa: F401,F403 +from .nms_rotated import ONNXNMSRotatedOp, TRTBatchedRotatedNMSop __all__ = [ 'ONNXNMSop', 'TRTBatchedNMSop', 'TRTBatchedRotatedNMSop', - 'ONNXNMSRotatedOp', 'multiclass_nms', 'multiclass_nms_rotated' + 'ONNXNMSRotatedOp', 'multiclass_nms_rotated' + 'multiclass_nms', 'ONNXNMSMatchOp', 'multiclass_nms_match' ] diff --git a/mmdeploy/mmcv/ops/nms.py b/mmdeploy/mmcv/ops/nms.py index dc22c55380..ec80232781 100644 --- a/mmdeploy/mmcv/ops/nms.py +++ b/mmdeploy/mmcv/ops/nms.py @@ -7,6 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER, mark from mmdeploy.utils import IR, is_dynamic_batch from mmdeploy.utils.constants import Backend +from .nms_match import multiclass_nms_match from .nms_rotated import multiclass_nms_rotated @@ -529,6 +530,15 @@ def multiclass_nms(boxes: Tensor, score_threshold=score_threshold, pre_top_k=pre_top_k, keep_top_k=keep_top_k) + elif nms_type == 'nms_match': + return multiclass_nms_match( + boxes, + scores, + max_output_boxes_per_class=max_output_boxes_per_class, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k) else: raise NotImplementedError(f'Unsupported nms type: {nms_type}.') diff --git a/mmdeploy/mmcv/ops/nms_match.py b/mmdeploy/mmcv/ops/nms_match.py new file mode 100644 index 0000000000..4fec6a59cf --- /dev/null +++ b/mmdeploy/mmcv/ops/nms_match.py @@ -0,0 +1,209 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import Tensor +from torch.onnx import symbolic_helper as sym_help + +import mmdeploy +from mmdeploy.core import mark + + +class ONNXNMSMatchOp(torch.autograd.Function): + """Create onnx::NonMaxSuppressionMatch op. + + NMS_Match in mmcv only supports one class with no batch info. This class + assists in exporting NMS_Match of ONNX's definition. + """ + + @staticmethod + def forward(ctx, boxes: Tensor, scores: Tensor, iou_threshold: float, + score_threshold: float) -> Tensor: + """Get NMS_Match_Fake output indices. + + Args: + ctx (Context): The context with meta information. + boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4]. + scores (Tensor): The detection scores of shape + [N, num_boxes, num_classes]. + iou_threshold (float): IOU threshold of nms. + score_threshold (float): score threshold of nms. + + Returns: + Tensor: Selected indices of boxes. 2-D tensor of shape + (num_selected_indices, 4) with each row of + [batch_index, class_index, box_index, suppresion_index]. + """ + from mmcv.ops import nms_match + batch_size, num_class, _ = scores.shape + + indices = [] + score_threshold = float(score_threshold) + iou_threshold = float(iou_threshold) + for batch_id in range(batch_size): + for cls_id in range(num_class): + _boxes = boxes[batch_id, ...] + _scores = scores[batch_id, cls_id, ...].contiguous() + _dets = torch.cat((_boxes, _scores.unsqueeze(1)), dim=1) + box_inds = nms_match(_dets, iou_threshold) + batch_inds = torch.zeros(1) + batch_id + cls_inds = torch.zeros(1) + cls_id + both_inds = torch.cat([batch_inds, cls_inds]) + for box in box_inds: + if box.size() == 1: + continue + keep = box[0] + box = box[1:] + if _dets[keep][-1] < score_threshold: + continue + for supp in box: + indices.append( + torch.cat((both_inds, keep.unsqueeze(0), + supp.unsqueeze(0)))) + return torch.stack(indices).to(torch.int64) + + @staticmethod + def symbolic(g, boxes: Tensor, scores: Tensor, iou_threshold: float, + score_threshold: float): + """Symbolic function for mmdeploy::NMSMatch. + + Args: + g (Graph): The traced onnx graph. + boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4]. + scores (Tensor): The detection scores of shape + [N, num_boxes, num_classes]. + iou_threshold (float): IOU threshold of nms. + score_threshold (float): score threshold of nms. + + Returns: + NonMaxSuppressionMatch op for onnx. + """ + if not sym_help._is_value(iou_threshold): + iou_threshold = g.op( + 'Constant', + value_t=torch.tensor([iou_threshold], dtype=torch.float)) + + if not sym_help._is_value(score_threshold): + score_threshold = g.op( + 'Constant', + value_t=torch.tensor([score_threshold], dtype=torch.float)) + return g.op('mmdeploy::NMSMatch', boxes, scores, iou_threshold, + score_threshold) + + +def _select_nms_index(scores: torch.Tensor, + boxes: torch.Tensor, + nms_index, + batch_size: int, + keep_top_k: int = -1): + """Transform NMS_Match output. + + Args: + scores (Tensor): The detection scores of shape + [N, num_classes, num_boxes]. + boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4]. + nms_index (Tensor): NMS output of bounding boxes indexing. + here is [K, ?] + batch_size (int): Batch size of the input image. + keep_top_k (int): Number of top K boxes to keep after nms. + Defaults to -1. + + + Returns: + tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5] + and `labels` of shape [N, num_det]. + """ + batch_inds, cls_inds = nms_index[:, 0], nms_index[:, 1] + box_inds = nms_index[:, 2] + scores = scores[batch_inds, cls_inds, box_inds].unsqueeze(1) + boxes = boxes[batch_inds, box_inds, ...] + dets = torch.cat([boxes, scores], dim=1) + + # batch all + batched_dets = dets.unsqueeze(0).repeat(batch_size, 1, 1) + batch_template = torch.arange( + 0, batch_size, dtype=batch_inds.dtype, device=batch_inds.device) + batched_dets = batched_dets.where( + (batch_inds == batch_template.unsqueeze(1)).unsqueeze(-1), + batched_dets.new_zeros(1)) + + batched_labels = cls_inds.unsqueeze(0).repeat(batch_size, 1) + batched_labels = batched_labels.where( + (batch_inds == batch_template.unsqueeze(1)), + batched_labels.new_ones(1) * -1) + + N = batched_dets.shape[0] + + # expand tensor to eliminate [0, ...] tensor + batched_dets = torch.cat((batched_dets, batched_dets.new_zeros((N, 1, 5))), + 1) + batched_labels = torch.cat((batched_labels, batched_labels.new_zeros( + (N, 1))), 1) + # sort + is_use_topk = keep_top_k > 0 and \ + (torch.onnx.is_in_onnx_export() or keep_top_k < batched_dets.shape[1]) + if is_use_topk: + _, topk_inds = batched_dets[:, :, -1].topk(keep_top_k, dim=1) + else: + _, topk_inds = batched_dets[:, :, -1].sort(dim=1, descending=True) + topk_batch_inds = torch.arange( + batch_size, dtype=topk_inds.dtype, + device=topk_inds.device).view(-1, 1) + batched_dets = batched_dets[topk_batch_inds, topk_inds, ...] + batched_labels = batched_labels[topk_batch_inds, topk_inds, ...] + # slice and recover the tensor + return batched_dets, batched_labels + + +def _multiclass_nms_match(boxes: Tensor, + scores: Tensor, + max_output_boxes_per_class: int = 1000, + iou_threshold: float = 0.5, + score_threshold: float = 0.05, + pre_top_k: int = -1, + keep_top_k: int = -1, + output_index: bool = False): + """Create a dummy onnx::NonMaxSuppressionMatch op while exporting to ONNX. + + This function helps exporting to onnx with batch and multiclass NMSMatch + op. It only supports class-agnostic detection results. That is, the scores + is of shape (N, num_bboxes, num_classes) and the boxes is of shape (N, + num_boxes, 4). + """ + iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32) + score_threshold = torch.tensor([score_threshold], dtype=torch.float32) + batch_size = scores.shape[0] + topk_inds = None + if pre_top_k > 0: + max_scores, _ = scores.max(-1) + _, topk_inds = max_scores.topk(pre_top_k) + batch_inds = torch.arange( + batch_size, device=scores.device).view(-1, 1).long() + boxes = boxes[batch_inds, topk_inds, :] + scores = scores[batch_inds, topk_inds, :] + + scores = scores.permute(0, 2, 1) + selected_indices = ONNXNMSMatchOp.apply(boxes, scores, iou_threshold, + score_threshold) + return _select_nms_index( + scores, boxes, selected_indices, batch_size, keep_top_k=keep_top_k) + + +@mark( + 'multiclass_nms_match', + inputs=['boxes', 'scores'], + outputs=['dets', 'labels']) +def multiclass_nms_match(boxes: Tensor, + scores: Tensor, + max_output_boxes_per_class: int = 1000, + iou_threshold: float = 0.1, + score_threshold: float = 0.05, + pre_top_k: int = -1, + keep_top_k: int = -1): + """Wrapper function for `_multiclass_nms_match`.""" + return mmdeploy.mmcv.ops.nms_match._multiclass_nms_match( + boxes, + scores, + max_output_boxes_per_class=max_output_boxes_per_class, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k) diff --git a/tests/test_ops/test_nms_match_small.py b/tests/test_ops/test_nms_match_small.py new file mode 100644 index 0000000000..1c08fc5272 --- /dev/null +++ b/tests/test_ops/test_nms_match_small.py @@ -0,0 +1,119 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import tempfile + +import numpy +import onnxruntime +import pytest +import torch + +from mmdeploy.backend.onnxruntime.init_plugins import get_ops_path +from mmdeploy.mmcv.ops import ONNXNMSMatchOp + +cur_dir = os.path.dirname(os.path.abspath(__file__)) +boxes = torch.tensor([ + [ + [291.1746, 316.2263, 343.5029, 347.7312], + [288.4846, 315.0447, 343.7267, 346.5630], + [288.5307, 318.1989, 341.6425, 349.7222], + [918.9102, 83.7463, 933.3920, 164.9041], + [895.5786, 78.2361, 907.8049, 172.0883], + [292.5816, 316.5563, 340.3462, 352.9989], + [609.4592, 83.5447, 631.2532, 144.0749], + [917.7308, 85.5870, 933.2839, 168.4530], + [895.5138, 79.3596, 908.2865, 171.0418], + [291.4747, 318.6987, 347.1208, 349.5754], + ], + [ + [291.1746, 316.2263, 343.5029, 347.7312], + [288.4846, 315.0447, 343.7267, 346.5630], + [288.5307, 318.1989, 341.6425, 349.7222], + [918.9102, 83.7463, 933.3920, 164.9041], + [895.5786, 78.2361, 907.8049, 172.0883], + [292.5816, 316.5563, 340.3462, 352.9989], + [609.4592, 83.5447, 631.2532, 144.0749], + [917.7308, 85.5870, 933.2839, 168.4530], + [895.5138, 79.3596, 908.2865, 171.0418], + [291.4747, 318.6987, 347.1208, 349.5754], + ], +]) +scores = torch.tensor([ + [ + [0.9577, 0.9745, 0.3030, 0.6589, 0.2742], + [0.1618, 0.7963, 0.5124, 0.6964, 0.6850], + [0.8425, 0.4843, 0.9489, 0.8068, 0.7340], + [0.7337, 0.4340, 0.9923, 0.0704, 0.4506], + [0.3090, 0.5606, 0.6939, 0.3764, 0.6920], + [0.0044, 0.7986, 0.2221, 0.2782, 0.4378], + [0.7293, 0.2735, 0.8381, 0.0264, 0.6278], + [0.7144, 0.1066, 0.4125, 0.4041, 0.8819], + [0.4963, 0.7891, 0.6908, 0.1499, 0.5584], + [0.4385, 0.6035, 0.0508, 0.0662, 0.5938], + ], + [ + [0.9577, 0.9745, 0.3030, 0.6589, 0.2742], + [0.1618, 0.7963, 0.5124, 0.6964, 0.6850], + [0.8425, 0.4843, 0.9489, 0.8068, 0.7340], + [0.7337, 0.4340, 0.9923, 0.0704, 0.4506], + [0.3090, 0.5606, 0.6939, 0.3764, 0.6920], + [0.0044, 0.7986, 0.2221, 0.2782, 0.4378], + [0.7293, 0.2735, 0.8381, 0.0264, 0.6278], + [0.7144, 0.1066, 0.4125, 0.4041, 0.8819], + [0.4963, 0.7891, 0.6908, 0.1499, 0.5584], + [0.4385, 0.6035, 0.0508, 0.0662, 0.5938], + ], +]) +scores = scores.permute(0, 2, 1) +iou_threshold = torch.tensor([0.1]) +score_threshold = torch.tensor([0.1]) +match_op = ONNXNMSMatchOp.apply + + +class test_ONNX_Match(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, boxes, scores, iou_threshold, score_threshold): + return match_op(boxes, scores, iou_threshold, score_threshold) + + +@pytest.mark.skipif( + reason='Need to build onnxrumtime custom op', + condition=get_ops_path() == '') +def test_nms_match(): + print('Running compilation...') + # here is a PyTorch test + model = test_ONNX_Match() + torch_output = model(boxes, scores, iou_threshold, + score_threshold).detach().numpy() + # export the onnx file with a tempfile + temp_onnx = tempfile.NamedTemporaryFile( + suffix='.onnx', delete=False, mode='wb', dir=cur_dir) + input_name = ['boxes', 'scores', 'iou_thr', 'score_thr'] + torch.onnx.export( + model, + (boxes, scores, iou_threshold, score_threshold), + temp_onnx.name, + input_names=input_name, + ) + temp_onnx.close() + options = onnxruntime.SessionOptions() + options.register_custom_ops_library(get_ops_path()) + + sess = onnxruntime.InferenceSession( + temp_onnx.name, options, providers=['CPUExecutionProvider']) + ort_output = sess.run( + None, + { + 'boxes': boxes.numpy(), + 'scores': scores.numpy(), + 'iou_thr': iou_threshold.numpy(), + 'score_thr': score_threshold.numpy(), + }, + ) + + assert numpy.array_equal( + numpy.array(torch_output), + numpy.array(ort_output[0])), 'list are not equal' + os.remove(temp_onnx.name)