Skip to content

Commit

Permalink
Added heuristic for bounding bbox ordering (#835)
Browse files Browse the repository at this point in the history
  • Loading branch information
MLDovakin authored Feb 26, 2023
1 parent 906dd4d commit 761c0ff
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions sahi/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
if is_available("torch"):
import torch

from functools import cmp_to_key

import numpy as np
from tqdm import tqdm

Expand Down Expand Up @@ -289,6 +291,43 @@ def get_sliced_prediction(
)


def bbox_sort(a, b, thresh):
"""
a, b - function receives two bounding bboxes
thresh - the threshold takes into account how far two bounding bboxes differ in
Y where thresh is the threshold we set for the
minimum allowable difference in height between adjacent bboxes
and sorts them by the X coordinate
"""

bbox_a = a
bbox_b = b

if abs(bbox_a[1] - bbox_b[1]) <= thresh:
return bbox_a[0] - bbox_b[0]

return bbox_a[1] - bbox_b[1]


def agg_prediction(result: PredictionResult, thresh):
coord_list = []
res = result.to_coco_annotations()
for ann in res:
current_bbox = ann["bbox"]
x = current_bbox[0]
y = current_bbox[1]
w = current_bbox[2]
h = current_bbox[3]

coord_list.append((x, y, w, h))
cnts = sorted(coord_list, key=cmp_to_key(lambda a, b: bbox_sort(a, b, thresh)))
for pred in range(len(res) - 1):
res[pred]["image_id"] = cnts.index(tuple(res[pred]["bbox"]))

return res


def predict(
detection_model: DetectionModel = None,
model_type: str = "mmdet",
Expand Down

0 comments on commit 761c0ff

Please sign in to comment.