From 761c0ff4fdcf79ded38da4d40f10a0559a3e1339 Mon Sep 17 00:00:00 2001 From: Dragonborn <78375175+MLDovakin@users.noreply.github.com> Date: Sun, 26 Feb 2023 21:57:05 +0300 Subject: [PATCH] Added heuristic for bounding bbox ordering (#835) --- sahi/predict.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/sahi/predict.py b/sahi/predict.py index 50bb40434..8863eb367 100644 --- a/sahi/predict.py +++ b/sahi/predict.py @@ -12,6 +12,8 @@ if is_available("torch"): import torch +from functools import cmp_to_key + import numpy as np from tqdm import tqdm @@ -289,6 +291,43 @@ def get_sliced_prediction( ) +def bbox_sort(a, b, thresh): + """ + a, b - function receives two bounding bboxes + + thresh - the threshold takes into account how far two bounding bboxes differ in + Y where thresh is the threshold we set for the + minimum allowable difference in height between adjacent bboxes + and sorts them by the X coordinate + """ + + bbox_a = a + bbox_b = b + + if abs(bbox_a[1] - bbox_b[1]) <= thresh: + return bbox_a[0] - bbox_b[0] + + return bbox_a[1] - bbox_b[1] + + +def agg_prediction(result: PredictionResult, thresh): + coord_list = [] + res = result.to_coco_annotations() + for ann in res: + current_bbox = ann["bbox"] + x = current_bbox[0] + y = current_bbox[1] + w = current_bbox[2] + h = current_bbox[3] + + coord_list.append((x, y, w, h)) + cnts = sorted(coord_list, key=cmp_to_key(lambda a, b: bbox_sort(a, b, thresh))) + for pred in range(len(res) - 1): + res[pred]["image_id"] = cnts.index(tuple(res[pred]["bbox"])) + + return res + + def predict( detection_model: DetectionModel = None, model_type: str = "mmdet",