diff --git a/src/otx/algo/detection/heads/rtmdet_head.py b/src/otx/algo/detection/heads/rtmdet_head.py index b05aab77cef..179f47946f4 100644 --- a/src/otx/algo/detection/heads/rtmdet_head.py +++ b/src/otx/algo/detection/heads/rtmdet_head.py @@ -5,10 +5,13 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import torch from torch import Tensor, nn from otx.algo.detection.heads.atss_head import ATSSHead +from otx.algo.detection.ops.nms import multiclass_nms from otx.algo.detection.utils.utils import ( anchor_inside_flags, distance2bbox, @@ -26,6 +29,9 @@ from otx.algo.utils.mmengine_utils import InstanceData from otx.algo.utils.weight_init import bias_init_with_prob, constant_init, normal_init +if TYPE_CHECKING: + from omegaconf import DictConfig + class RTMDetHead(ATSSHead): """Detection Head of RTMDet. @@ -314,6 +320,76 @@ def loss_by_feat( # type: ignore[override] losses_bbox = [x / bbox_avg_factor for x in losses_bbox] return {"loss_cls": losses_cls, "loss_bbox": losses_bbox} + def export_by_feat( # type: ignore[override] + self, + cls_scores: list[Tensor], + bbox_preds: list[Tensor], + batch_img_metas: list[dict] | None = None, + cfg: DictConfig | None = None, + rescale: bool = False, + with_nms: bool = True, + ) -> tuple[Tensor, Tensor] | tuple[Tensor, Tensor, Tensor]: + """Transform network output for a batch into bbox predictions. + + Reference : https://github.com/open-mmlab/mmdeploy/blob/v1.3.1/mmdeploy/codebase/mmdet/models/dense_heads/rtmdet_head.py#L18-L108 + + Args: + cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * 4, H, W). + batch_img_metas (list[dict], Optional): Batch image meta info. + Defaults to None. + cfg (DictConfig, optional): Test / postprocessing + configuration, if None, test_cfg would be used. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + tuple[Tensor, Tensor]: The first item is an (N, num_box, 5) tensor, + where 5 represent (tl_x, tl_y, br_x, br_y, score), N is batch + size and the score between 0 and 1. The shape of the second + tensor in the tuple is (N, num_box), and each element + represents the class label of the corresponding box. + """ + assert len(cls_scores) == len(bbox_preds) # noqa: S101 + device = cls_scores[0].device + cfg = self.test_cfg if cfg is None else cfg + batch_size = bbox_preds[0].shape[0] + featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores] + mlvl_priors = self.prior_generator.grid_priors(featmap_sizes, device=device) + + flatten_cls_scores = [ + cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1, self.cls_out_channels) for cls_score in cls_scores + ] + flatten_bbox_preds = [bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4) for bbox_pred in bbox_preds] + flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid() + flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1) + priors = torch.cat(mlvl_priors) + tl_x = priors[..., 0] - flatten_bbox_preds[..., 0] # type: ignore[call-overload] + tl_y = priors[..., 1] - flatten_bbox_preds[..., 1] # type: ignore[call-overload] + br_x = priors[..., 0] + flatten_bbox_preds[..., 2] # type: ignore[call-overload] + br_y = priors[..., 1] + flatten_bbox_preds[..., 3] # type: ignore[call-overload] + bboxes = torch.stack([tl_x, tl_y, br_x, br_y], -1) + scores = flatten_cls_scores + if not with_nms: + return bboxes, scores + + return multiclass_nms( + bboxes, + scores, + max_output_boxes_per_class=200, + iou_threshold=cfg.nms.iou_threshold, + score_threshold=cfg.score_thr, + pre_top_k=5000, + keep_top_k=cfg.max_per_img, + ) + def get_targets( # type: ignore[override] self, cls_scores: Tensor, diff --git a/src/otx/algo/detection/ops/nms.py b/src/otx/algo/detection/ops/nms.py index 000f7e69609..f79393a686e 100644 --- a/src/otx/algo/detection/ops/nms.py +++ b/src/otx/algo/detection/ops/nms.py @@ -315,17 +315,17 @@ def _select_nms_index( batched_labels = cls_inds.unsqueeze(0).repeat(batch_size, 1) batched_labels = batched_labels.where((batch_inds == batch_template.unsqueeze(1)), batched_labels.new_ones(1) * -1) - batch_size = batched_dets.shape[0] + new_batch_size = batched_dets.shape[0] # expand tensor to eliminate [0, ...] tensor - batched_dets = torch.cat((batched_dets, batched_dets.new_zeros((batch_size, 1, 5))), 1) - batched_labels = torch.cat((batched_labels, batched_labels.new_zeros((batch_size, 1))), 1) + batched_dets = torch.cat((batched_dets, batched_dets.new_zeros((new_batch_size, 1, 5))), 1) + batched_labels = torch.cat((batched_labels, batched_labels.new_zeros((new_batch_size, 1))), 1) if output_index and pre_inds is not None: # batch all pre_inds = pre_inds[batch_inds, box_inds] pre_inds = pre_inds.unsqueeze(0).repeat(batch_size, 1) pre_inds = pre_inds.where((batch_inds == batch_template.unsqueeze(1)), pre_inds.new_zeros(1)) - pre_inds = torch.cat((pre_inds, -pre_inds.new_ones((batch_size, 1))), 1) + pre_inds = torch.cat((pre_inds, -pre_inds.new_ones((new_batch_size, 1))), 1) # sort is_use_topk = keep_top_k > 0 and (torch.onnx.is_in_onnx_export() or keep_top_k < batched_dets.shape[1]) if is_use_topk: diff --git a/src/otx/algo/detection/rtmdet.py b/src/otx/algo/detection/rtmdet.py index d3cf5477f5e..d7efbdf91d6 100644 --- a/src/otx/algo/detection/rtmdet.py +++ b/src/otx/algo/detection/rtmdet.py @@ -169,21 +169,19 @@ def _exporter(self) -> OTXModelExporter: std=self.std, resize_mode="fit_to_window_letterbox", pad_value=114, - swap_rgb=False, + swap_rgb=True, via_onnx=True, onnx_export_configuration={ "input_names": ["image"], - "output_names": ["boxes", "labels", "masks"], + "output_names": ["boxes", "labels"], "dynamic_axes": { "image": {0: "batch", 2: "height", 3: "width"}, "boxes": {0: "batch", 1: "num_dets"}, "labels": {0: "batch", 1: "num_dets"}, - "masks": {0: "batch", 1: "num_dets", 2: "height", 3: "width"}, }, - "opset_version": 11, "autograd_inlining": False, }, - output_names=["bboxes", "labels", "masks", "feature_vector", "saliency_map"] if self.explain_mode else None, + output_names=["bboxes", "labels", "feature_vector", "saliency_map"] if self.explain_mode else None, ) def forward_for_tracing( diff --git a/src/otx/algo/detection/utils/utils.py b/src/otx/algo/detection/utils/utils.py index 94f5bc487b6..738b0194e72 100644 --- a/src/otx/algo/detection/utils/utils.py +++ b/src/otx/algo/detection/utils/utils.py @@ -464,6 +464,10 @@ def clip_bboxes_export( Returns: tuple(Tensor): The clipped x1, y1, x2, y2. """ + if len(max_shape) != 2: + msg = "`max_shape` should be [h, w]." + raise ValueError(msg) + if isinstance(max_shape, torch.Tensor): # scale by 1/max_shape x1 = x1 / max_shape[1] diff --git a/src/otx/core/model/detection.py b/src/otx/core/model/detection.py index b1f284cbb0e..94b5d5300a9 100644 --- a/src/otx/core/model/detection.py +++ b/src/otx/core/model/detection.py @@ -554,7 +554,7 @@ def _customize_outputs( if label_shift: log.warning(f"label_shift: {label_shift}") - for output in outputs: + for i, output in enumerate(outputs): output_objects = output.objects if len(output_objects): bbox = [[output.xmin, output.ymin, output.xmax, output.ymax] for output in output_objects] @@ -564,7 +564,7 @@ def _customize_outputs( tv_tensors.BoundingBoxes( bbox, format="XYXY", - canvas_size=inputs.imgs_info[-1].img_shape, + canvas_size=inputs.imgs_info[i].img_shape, device=self.device, ), ) diff --git a/src/otx/engine/utils/auto_configurator.py b/src/otx/engine/utils/auto_configurator.py index d662f0f70ce..25983ce4e9d 100644 --- a/src/otx/engine/utils/auto_configurator.py +++ b/src/otx/engine/utils/auto_configurator.py @@ -381,18 +381,20 @@ def update_ov_subset_pipeline(self, datamodule: OTXDataModule, subset: str = "te OTXDataModule: The modified OTXDataModule object with OpenVINO subset transforms applied. """ data_configuration = datamodule.config - ov_test_config = self._load_default_config(model_name="openvino_model")["data"]["config"][f"{subset}_subset"] + ov_config = self._load_default_config(model_name="openvino_model")["data"]["config"] subset_config = getattr(data_configuration, f"{subset}_subset") - subset_config.batch_size = ov_test_config["batch_size"] - subset_config.transform_lib_type = ov_test_config["transform_lib_type"] - subset_config.transforms = ov_test_config["transforms"] - subset_config.to_tv_image = ov_test_config["to_tv_image"] + subset_config.batch_size = ov_config[f"{subset}_subset"]["batch_size"] + subset_config.transform_lib_type = ov_config[f"{subset}_subset"]["transform_lib_type"] + subset_config.transforms = ov_config[f"{subset}_subset"]["transforms"] + subset_config.to_tv_image = ov_config[f"{subset}_subset"]["to_tv_image"] + data_configuration.image_color_channel = ov_config["image_color_channel"] data_configuration.tile_config.enable_tiler = False msg = ( f"For OpenVINO IR models, Update the following {subset} \n" f"\t transforms: {subset_config.transforms} \n" f"\t transform_lib_type: {subset_config.transform_lib_type} \n" f"\t batch_size: {subset_config.batch_size} \n" + f"\t image_color_channel: {data_configuration.image_color_channel} \n" "And the tiler is disabled." ) warn(msg, stacklevel=1) diff --git a/tests/unit/algo/detection/heads/test_rtmdet_head.py b/tests/unit/algo/detection/heads/test_rtmdet_head.py index 06feae4466a..10af4f2c734 100644 --- a/tests/unit/algo/detection/heads/test_rtmdet_head.py +++ b/tests/unit/algo/detection/heads/test_rtmdet_head.py @@ -91,6 +91,30 @@ def test_loss_by_feat_single(self, rtmdet_head) -> None: assert loss_cls is not None assert loss_bbox is not None + def test_export_by_feat(self, mocker, rtmdet_head) -> None: + batch_size = 2 + num_priors = 1 + num_classes = 80 + cls_scores = [torch.rand(batch_size, num_priors * num_classes, 20, 20) for _ in range(3)] + bbox_preds = [torch.rand(batch_size, num_priors * 4, 20, 20) for _ in range(3)] + batch_img_metas = [{"img_shape": (320, 320, 3), "scale_factor": 1.0} for _ in range(2)] + mocker_multiclass_nms = mocker.patch( + "otx.algo.detection.heads.rtmdet_head.multiclass_nms", + return_value=(torch.rand(2, 300, 5), torch.randint(0, 80, (2, 300))), + ) + + bboxes, scores = rtmdet_head.export_by_feat(cls_scores, bbox_preds, batch_img_metas) + + # Verify that the multiclass_nms function was called + mocker_multiclass_nms.assert_called_once() + + # Check the shape of the output + assert bboxes.shape[0] == 2 # batch size + assert bboxes.shape[1] == 300 # max_per_img + assert bboxes.shape[2] == 5 # 4 bbox coordinates + score + assert scores.shape[0] == 2 # batch size + assert scores.shape[1] == 300 # max_per_img + def test_get_anchors(self, rtmdet_head) -> None: featmap_sizes = [(40, 40), (20, 20), (10, 10)] batch_img_metas = [{"img_shape": (320, 320, 3)} for _ in range(2)] diff --git a/tests/unit/algo/detection/test_rtmdet.py b/tests/unit/algo/detection/test_rtmdet.py index fa59cae63f4..30b88a9bd5b 100644 --- a/tests/unit/algo/detection/test_rtmdet.py +++ b/tests/unit/algo/detection/test_rtmdet.py @@ -25,7 +25,7 @@ def test_exporter(self) -> None: otx_rtmdet_tiny = RTMDetTiny(label_info=3) otx_rtmdet_tiny_exporter = otx_rtmdet_tiny._exporter assert isinstance(otx_rtmdet_tiny_exporter, OTXNativeModelExporter) - assert otx_rtmdet_tiny_exporter.swap_rgb is False + assert otx_rtmdet_tiny_exporter.swap_rgb is True @pytest.mark.parametrize("model", [RTMDetTiny(3)]) def test_loss(self, model, fxt_data_module):