Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RTMDet-tiny enablement for detection task (export/optimize) #3564

Merged
merged 11 commits into from
Jun 14, 2024
76 changes: 76 additions & 0 deletions src/otx/algo/detection/heads/rtmdet_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from torch import Tensor, nn

from otx.algo.detection.heads.atss_head import ATSSHead
from otx.algo.detection.ops.nms import multiclass_nms
from otx.algo.detection.utils.utils import (
anchor_inside_flags,
distance2bbox,
Expand All @@ -26,6 +29,9 @@
from otx.algo.utils.mmengine_utils import InstanceData
from otx.algo.utils.weight_init import bias_init_with_prob, constant_init, normal_init

if TYPE_CHECKING:
from omegaconf import DictConfig


class RTMDetHead(ATSSHead):
"""Detection Head of RTMDet.
Expand Down Expand Up @@ -314,6 +320,76 @@ def loss_by_feat( # type: ignore[override]
losses_bbox = [x / bbox_avg_factor for x in losses_bbox]
return {"loss_cls": losses_cls, "loss_bbox": losses_bbox}

def export_by_feat( # type: ignore[override]
self,
cls_scores: list[Tensor],
bbox_preds: list[Tensor],
batch_img_metas: list[dict] | None = None,
cfg: DictConfig | None = None,
rescale: bool = False,
with_nms: bool = True,
) -> tuple[Tensor, Tensor] | tuple[Tensor, Tensor, Tensor]:
"""Transform network output for a batch into bbox predictions.

Reference : https://github.com/open-mmlab/mmdeploy/blob/v1.3.1/mmdeploy/codebase/mmdet/models/dense_heads/rtmdet_head.py#L18-L108

Args:
cls_scores (list[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
batch_img_metas (list[dict], Optional): Batch image meta info.
Defaults to None.
cfg (DictConfig, optional): Test / postprocessing
configuration, if None, test_cfg would be used.
Defaults to None.
rescale (bool): If True, return boxes in original image space.
Defaults to False.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.

Returns:
tuple[Tensor, Tensor]: The first item is an (N, num_box, 5) tensor,
where 5 represent (tl_x, tl_y, br_x, br_y, score), N is batch
size and the score between 0 and 1. The shape of the second
tensor in the tuple is (N, num_box), and each element
represents the class label of the corresponding box.
"""
assert len(cls_scores) == len(bbox_preds) # noqa: S101
device = cls_scores[0].device
cfg = self.test_cfg if cfg is None else cfg
batch_size = bbox_preds[0].shape[0]
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
mlvl_priors = self.prior_generator.grid_priors(featmap_sizes, device=device)

flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1, self.cls_out_channels) for cls_score in cls_scores
]
flatten_bbox_preds = [bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4) for bbox_pred in bbox_preds]
flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
priors = torch.cat(mlvl_priors)
tl_x = priors[..., 0] - flatten_bbox_preds[..., 0] # type: ignore[call-overload]
tl_y = priors[..., 1] - flatten_bbox_preds[..., 1] # type: ignore[call-overload]
br_x = priors[..., 0] + flatten_bbox_preds[..., 2] # type: ignore[call-overload]
br_y = priors[..., 1] + flatten_bbox_preds[..., 3] # type: ignore[call-overload]
bboxes = torch.stack([tl_x, tl_y, br_x, br_y], -1)
scores = flatten_cls_scores
if not with_nms:
return bboxes, scores

return multiclass_nms(
bboxes,
scores,
max_output_boxes_per_class=200,
iou_threshold=cfg.nms.iou_threshold,
score_threshold=cfg.score_thr,
pre_top_k=5000,
keep_top_k=cfg.max_per_img,
)

def get_targets( # type: ignore[override]
self,
cls_scores: Tensor,
Expand Down
8 changes: 4 additions & 4 deletions src/otx/algo/detection/ops/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,17 +315,17 @@ def _select_nms_index(
batched_labels = cls_inds.unsqueeze(0).repeat(batch_size, 1)
batched_labels = batched_labels.where((batch_inds == batch_template.unsqueeze(1)), batched_labels.new_ones(1) * -1)

batch_size = batched_dets.shape[0]
new_batch_size = batched_dets.shape[0]

# expand tensor to eliminate [0, ...] tensor
batched_dets = torch.cat((batched_dets, batched_dets.new_zeros((batch_size, 1, 5))), 1)
batched_labels = torch.cat((batched_labels, batched_labels.new_zeros((batch_size, 1))), 1)
batched_dets = torch.cat((batched_dets, batched_dets.new_zeros((new_batch_size, 1, 5))), 1)
batched_labels = torch.cat((batched_labels, batched_labels.new_zeros((new_batch_size, 1))), 1)
if output_index and pre_inds is not None:
# batch all
pre_inds = pre_inds[batch_inds, box_inds]
pre_inds = pre_inds.unsqueeze(0).repeat(batch_size, 1)
pre_inds = pre_inds.where((batch_inds == batch_template.unsqueeze(1)), pre_inds.new_zeros(1))
pre_inds = torch.cat((pre_inds, -pre_inds.new_ones((batch_size, 1))), 1)
pre_inds = torch.cat((pre_inds, -pre_inds.new_ones((new_batch_size, 1))), 1)
# sort
is_use_topk = keep_top_k > 0 and (torch.onnx.is_in_onnx_export() or keep_top_k < batched_dets.shape[1])
if is_use_topk:
Expand Down
8 changes: 3 additions & 5 deletions src/otx/algo/detection/rtmdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,21 +169,19 @@ def _exporter(self) -> OTXModelExporter:
std=self.std,
resize_mode="fit_to_window_letterbox",
pad_value=114,
swap_rgb=False,
swap_rgb=True,
via_onnx=True,
onnx_export_configuration={
"input_names": ["image"],
"output_names": ["boxes", "labels", "masks"],
"output_names": ["boxes", "labels"],
"dynamic_axes": {
"image": {0: "batch", 2: "height", 3: "width"},
"boxes": {0: "batch", 1: "num_dets"},
"labels": {0: "batch", 1: "num_dets"},
"masks": {0: "batch", 1: "num_dets", 2: "height", 3: "width"},
},
"opset_version": 11,
"autograd_inlining": False,
},
output_names=["bboxes", "labels", "masks", "feature_vector", "saliency_map"] if self.explain_mode else None,
output_names=["bboxes", "labels", "feature_vector", "saliency_map"] if self.explain_mode else None,
)

def forward_for_tracing(
Expand Down
4 changes: 4 additions & 0 deletions src/otx/algo/detection/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,10 @@ def clip_bboxes_export(
Returns:
tuple(Tensor): The clipped x1, y1, x2, y2.
"""
if len(max_shape) != 2:
msg = "`max_shape` should be [h, w]."
raise ValueError(msg)

if isinstance(max_shape, torch.Tensor):
# scale by 1/max_shape
x1 = x1 / max_shape[1]
Expand Down
4 changes: 2 additions & 2 deletions src/otx/core/model/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def _customize_outputs(
if label_shift:
log.warning(f"label_shift: {label_shift}")

for output in outputs:
for i, output in enumerate(outputs):
output_objects = output.objects
if len(output_objects):
bbox = [[output.xmin, output.ymin, output.xmax, output.ymax] for output in output_objects]
Expand All @@ -564,7 +564,7 @@ def _customize_outputs(
tv_tensors.BoundingBoxes(
bbox,
format="XYXY",
canvas_size=inputs.imgs_info[-1].img_shape,
canvas_size=inputs.imgs_info[i].img_shape,
device=self.device,
),
)
Expand Down
12 changes: 7 additions & 5 deletions src/otx/engine/utils/auto_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,18 +381,20 @@ def update_ov_subset_pipeline(self, datamodule: OTXDataModule, subset: str = "te
OTXDataModule: The modified OTXDataModule object with OpenVINO subset transforms applied.
"""
data_configuration = datamodule.config
ov_test_config = self._load_default_config(model_name="openvino_model")["data"]["config"][f"{subset}_subset"]
ov_config = self._load_default_config(model_name="openvino_model")["data"]["config"]
subset_config = getattr(data_configuration, f"{subset}_subset")
subset_config.batch_size = ov_test_config["batch_size"]
subset_config.transform_lib_type = ov_test_config["transform_lib_type"]
subset_config.transforms = ov_test_config["transforms"]
subset_config.to_tv_image = ov_test_config["to_tv_image"]
subset_config.batch_size = ov_config[f"{subset}_subset"]["batch_size"]
subset_config.transform_lib_type = ov_config[f"{subset}_subset"]["transform_lib_type"]
subset_config.transforms = ov_config[f"{subset}_subset"]["transforms"]
subset_config.to_tv_image = ov_config[f"{subset}_subset"]["to_tv_image"]
data_configuration.image_color_channel = ov_config["image_color_channel"]
data_configuration.tile_config.enable_tiler = False
msg = (
f"For OpenVINO IR models, Update the following {subset} \n"
f"\t transforms: {subset_config.transforms} \n"
f"\t transform_lib_type: {subset_config.transform_lib_type} \n"
f"\t batch_size: {subset_config.batch_size} \n"
f"\t image_color_channel: {data_configuration.image_color_channel} \n"
"And the tiler is disabled."
)
warn(msg, stacklevel=1)
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/algo/detection/heads/test_rtmdet_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,30 @@ def test_loss_by_feat_single(self, rtmdet_head) -> None:
assert loss_cls is not None
assert loss_bbox is not None

def test_export_by_feat(self, mocker, rtmdet_head) -> None:
batch_size = 2
num_priors = 1
num_classes = 80
cls_scores = [torch.rand(batch_size, num_priors * num_classes, 20, 20) for _ in range(3)]
bbox_preds = [torch.rand(batch_size, num_priors * 4, 20, 20) for _ in range(3)]
batch_img_metas = [{"img_shape": (320, 320, 3), "scale_factor": 1.0} for _ in range(2)]
mocker_multiclass_nms = mocker.patch(
"otx.algo.detection.heads.rtmdet_head.multiclass_nms",
return_value=(torch.rand(2, 300, 5), torch.randint(0, 80, (2, 300))),
)

bboxes, scores = rtmdet_head.export_by_feat(cls_scores, bbox_preds, batch_img_metas)

# Verify that the multiclass_nms function was called
mocker_multiclass_nms.assert_called_once()

# Check the shape of the output
assert bboxes.shape[0] == 2 # batch size
assert bboxes.shape[1] == 300 # max_per_img
assert bboxes.shape[2] == 5 # 4 bbox coordinates + score
assert scores.shape[0] == 2 # batch size
assert scores.shape[1] == 300 # max_per_img

def test_get_anchors(self, rtmdet_head) -> None:
featmap_sizes = [(40, 40), (20, 20), (10, 10)]
batch_img_metas = [{"img_shape": (320, 320, 3)} for _ in range(2)]
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/algo/detection/test_rtmdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_exporter(self) -> None:
otx_rtmdet_tiny = RTMDetTiny(label_info=3)
otx_rtmdet_tiny_exporter = otx_rtmdet_tiny._exporter
assert isinstance(otx_rtmdet_tiny_exporter, OTXNativeModelExporter)
assert otx_rtmdet_tiny_exporter.swap_rgb is False
assert otx_rtmdet_tiny_exporter.swap_rgb is True

@pytest.mark.parametrize("model", [RTMDetTiny(3)])
def test_loss(self, model, fxt_data_module):
Expand Down
Loading