Skip to content

Commit

Permalink
NMS op added (#2313)
Browse files Browse the repository at this point in the history
* NMS Op Implementation

Signed-off-by: Rishabh Thakur <[email protected]>

---------

Signed-off-by: Rishabh Thakur <[email protected]>
  • Loading branch information
quic-ristha authored Jul 5, 2023
1 parent dff80f0 commit c625be6
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,3 +341,28 @@ def forward(self, inp: torch.Tensor, roi: torch.Tensor, batch_indices: torch.Ten
"""
roi = torch.cat((torch.reshape(batch_indices, (batch_indices.shape[0], 1)), roi), dim=1)
return torchvision.ops.roi_align(inp, roi, self.output_size, self.spatial_scale, self.sampling_ratio)

class NonMaxSuppression(torch.nn.Module):
"""
Implementation of NMS Op in the form of nn.Module
"""
def __init__(self, iou_threshold: float, max_output_boxes_per_class: int):
super().__init__()
self.iou_threshold = iou_threshold
self.max_output_boxes_per_class = max_output_boxes_per_class

def forward(self, *args) -> torch.Tensor:
"""
Forward-pass routine for NMS op
"""
batches_boxes = args[0]
batch_scores = args[1]

res = []
for index, (boxes, scores) in enumerate(zip(batches_boxes, batch_scores)):
for class_index, classes_score in enumerate(scores):
res_ = torchvision.ops.nms(boxes, classes_score, self.iou_threshold)
for val in res_:
res.append([index, class_index, val.detach()])
res = res[:(self.max_output_boxes_per_class *(index+1))]
return torch.Tensor(res).type(torch.int64)

0 comments on commit c625be6

Please sign in to comment.