From c625be652e8406ba7d35c9eaedb978a4dba8dd48 Mon Sep 17 00:00:00 2001 From: Rishabh Thakur Date: Wed, 5 Jul 2023 13:02:54 +0530 Subject: [PATCH] NMS op added (#2313) * NMS Op Implementation Signed-off-by: Rishabh Thakur --------- Signed-off-by: Rishabh Thakur --- .../src/python/aimet_torch/elementwise_ops.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py b/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py index 5b8108f084..691eaeecee 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py @@ -341,3 +341,28 @@ def forward(self, inp: torch.Tensor, roi: torch.Tensor, batch_indices: torch.Ten """ roi = torch.cat((torch.reshape(batch_indices, (batch_indices.shape[0], 1)), roi), dim=1) return torchvision.ops.roi_align(inp, roi, self.output_size, self.spatial_scale, self.sampling_ratio) + +class NonMaxSuppression(torch.nn.Module): + """ + Implementation of NMS Op in the form of nn.Module + """ + def __init__(self, iou_threshold: float, max_output_boxes_per_class: int): + super().__init__() + self.iou_threshold = iou_threshold + self.max_output_boxes_per_class = max_output_boxes_per_class + + def forward(self, *args) -> torch.Tensor: + """ + Forward-pass routine for NMS op + """ + batches_boxes = args[0] + batch_scores = args[1] + + res = [] + for index, (boxes, scores) in enumerate(zip(batches_boxes, batch_scores)): + for class_index, classes_score in enumerate(scores): + res_ = torchvision.ops.nms(boxes, classes_score, self.iou_threshold) + for val in res_: + res.append([index, class_index, val.detach()]) + res = res[:(self.max_output_boxes_per_class *(index+1))] + return torch.Tensor(res).type(torch.int64)