From 4fe2185ef4bc9d113e19761b254fb349cbb37694 Mon Sep 17 00:00:00 2001
From: vekosek
Date: Tue, 13 Aug 2024 22:25:29 +0700
Subject: [PATCH] Added option to get boxes in yolo output after pre/post
 processing

---
 .../tools/add_pre_post_processing_to_model.py | 30 +++++++++++--------
 1 file changed, 18 insertions(+), 12 deletions(-)

diff --git a/onnxruntime_extensions/tools/add_pre_post_processing_to_model.py b/onnxruntime_extensions/tools/add_pre_post_processing_to_model.py
index 66cd3b01e..cb262d100 100644
--- a/onnxruntime_extensions/tools/add_pre_post_processing_to_model.py
+++ b/onnxruntime_extensions/tools/add_pre_post_processing_to_model.py
@@ -163,7 +163,8 @@ def superresolution(model_file: Path, output_file: Path, output_format: str, onn
 def yolo_detection(model_file: Path, output_file: Path, output_format: str = 'jpg',
-                   onnx_opset: int = 16, num_classes: int = 80, input_shape: List[int] = None):
+                   onnx_opset: int = 16, num_classes: int = 80, input_shape: List[int] = None,
+                   output_as_image: bool = True):
     """
     SSD-like model and Faster-RCNN-like model are including NMS inside already, You can find it from onnx model zoo.
@@ -185,6 +186,7 @@ def yolo_detection(model_file: Path, output_file: Path, output_format: str = 'jp
     :param onnx_opset: The opset version of onnx model, default(16).
     :param num_classes: The number of classes, default(80).
     :param input_shape: The shape of input image (height,width), default will be asked from model input.
+    :param output_as_image: If True, draw the detected boxes on the original image and output an encoded jpg/png. If False, output the scaled bounding box data instead of an image.
     """
     model = onnx.load(str(model_file.resolve(strict=True)))
     inputs = [create_named_value("image", onnx.TensorProto.UINT8, ["num_bytes"])]
@@ -284,19 +286,23 @@ def yolo_detection(model_file: Path, output_file: Path, output_format: str = 'jp
             utils.IoMapEntry("Resize", producer_idx=0, consumer_idx=2),
             utils.IoMapEntry("LetterBox", producer_idx=0, consumer_idx=3),
         ]),
-        # DrawBoundingBoxes on the original image
-        # Model imported from pytorch has CENTER_XYWH format
-        # two mode for how to color box,
-        # 1. colour_by_classes=True, (colour_by_classes), 2. colour_by_classes=False,(colour_by_confidence)
-        (DrawBoundingBoxes(mode='CENTER_XYWH', num_classes=num_classes, colour_by_classes=True),
-         [
-             utils.IoMapEntry("ConvertImageToBGR", producer_idx=0, consumer_idx=0),
-             utils.IoMapEntry("ScaleBoundingBoxes", producer_idx=0, consumer_idx=1),
-         ]),
-        # Encode to jpg/png
-        ConvertBGRToImage(image_format=output_format),
     ]
+    if output_as_image:
+        post_processing_steps += [
+            # DrawBoundingBoxes on the original image
+            # Model imported from pytorch has CENTER_XYWH format
+            # two modes for colouring the boxes:
+            # 1. colour_by_classes=True (colour by class), 2. colour_by_classes=False (colour by confidence)
+            (DrawBoundingBoxes(mode='CENTER_XYWH', num_classes=num_classes, colour_by_classes=True),
+             [
+                 utils.IoMapEntry("ConvertImageToBGR", producer_idx=0, consumer_idx=0),
+                 utils.IoMapEntry("ScaleBoundingBoxes", producer_idx=0, consumer_idx=1),
+             ]),
+            # Encode to jpg/png
+            ConvertBGRToImage(image_format=output_format),
+        ]
+
     pipeline.add_post_processing(post_processing_steps)
     new_model = pipeline.run(model)
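
Note for reviewers: a minimal usage sketch of the new flag. It assumes the function is imported from the module this patch touches and uses placeholder model file names; it is not part of the patch itself.

    from pathlib import Path

    from onnxruntime_extensions.tools.add_pre_post_processing_to_model import yolo_detection

    # Default behaviour (output_as_image=True): the post-processing pipeline draws the
    # detected boxes on the original image and encodes the result as jpg/png.
    yolo_detection(Path("yolov8n.onnx"), Path("yolov8n.with_drawn_boxes.onnx"),
                   output_format="jpg", onnx_opset=16, num_classes=80)

    # New behaviour (output_as_image=False): the DrawBoundingBoxes and ConvertBGRToImage
    # steps are skipped, so the exported model ends at ScaleBoundingBoxes and returns the
    # scaled box data instead of an encoded image.
    yolo_detection(Path("yolov8n.onnx"), Path("yolov8n.boxes_only.onnx"),
                   onnx_opset=16, num_classes=80, output_as_image=False)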