diff --git a/CHANGELOG.md b/CHANGELOG.md
index cd63ad8702a..220ebddac28 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ All notable changes to this project will be documented in this file.
 - Support encrypted dataset training ()
 - Add custom max iou assigner to prevent CPU OOM when large annotations are used ()
 - Auto train type detection for Semi-SL, Self-SL and Incremental: "--train-type" now is optional (https://github.com/openvinotoolkit/training_extensions/pull/2195)
+- Add new object detector Deformable DETR ()
 
 ### Enhancements
 
diff --git a/otx/algorithms/common/adapters/mmcv/__init__.py b/otx/algorithms/common/adapters/mmcv/__init__.py
index 572eadc7f25..10276590646 100644
--- a/otx/algorithms/common/adapters/mmcv/__init__.py
+++ b/otx/algorithms/common/adapters/mmcv/__init__.py
@@ -34,6 +34,7 @@
 )
 from .nncf.hooks import CompressionHook
 from .nncf.runners import AccuracyAwareRunner
+from .ops import multi_scale_deformable_attn_pytorch
 from .runner import EpochRunnerWithCancel, IterBasedRunnerWithCancel
 
 __all__ = [
@@ -57,4 +58,5 @@
     "CompressionHook",
     "AccuracyAwareRunner",
     "TwoCropTransformHook",
+    "multi_scale_deformable_attn_pytorch",
 ]
diff --git a/otx/algorithms/common/adapters/mmcv/ops/__init__.py b/otx/algorithms/common/adapters/mmcv/ops/__init__.py
new file mode 100644
index 00000000000..eeae7de787f
--- /dev/null
+++ b/otx/algorithms/common/adapters/mmcv/ops/__init__.py
@@ -0,0 +1,8 @@
+"""Initial file for mmcv ops."""
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+from .multi_scale_deformable_attn_pytorch import multi_scale_deformable_attn_pytorch
+
+__all__ = ["multi_scale_deformable_attn_pytorch"]
diff --git a/otx/algorithms/common/adapters/mmcv/ops/multi_scale_deformable_attn_pytorch.py b/otx/algorithms/common/adapters/mmcv/ops/multi_scale_deformable_attn_pytorch.py
new file mode 100644
index 00000000000..a2f4d796731
--- /dev/null
+++ b/otx/algorithms/common/adapters/mmcv/ops/multi_scale_deformable_attn_pytorch.py
@@ -0,0 +1,140 @@
+"""Custom patch of multi_scale_deformable_attn_pytorch for OpenVINO export."""
+
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+import torch
+import torch.nn.functional as F
+from mmcv.ops import multi_scale_deform_attn
+
+
+def multi_scale_deformable_attn_pytorch(
+    value: torch.Tensor,
+    value_spatial_shapes: torch.Tensor,
+    sampling_locations: torch.Tensor,
+    attention_weights: torch.Tensor,
+) -> torch.Tensor:
+    """Custom patch for the multi_scale_deformable_attn_pytorch function.
+
+    The original implementation in mmcv.ops uses torch.nn.functional.grid_sample,
+    which raises errors during inference with an OpenVINO-exported model.
+    Therefore this function replaces grid_sample with _custom_grid_sample.
+    """
+
+    bs, _, num_heads, embed_dims = value.shape
+    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
+    sampling_grids = 2 * sampling_locations - 1
+    sampling_value_list = []
+    for level, (H_, W_) in enumerate(value_spatial_shapes):
+        # bs, H_*W_, num_heads, embed_dims ->
+        # bs, H_*W_, num_heads*embed_dims ->
+        # bs, num_heads*embed_dims, H_*W_ ->
+        # bs*num_heads, embed_dims, H_, W_
+        value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
+        # bs, num_queries, num_heads, num_points, 2 ->
+        # bs, num_heads, num_queries, num_points, 2 ->
+        # bs*num_heads, num_queries, num_points, 2
+        sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
+        # bs*num_heads, embed_dims, num_queries, num_points
+        sampling_value_l_ = _custom_grid_sample(
+            value_l_,
+            sampling_grid_l_,
+            # mode='bilinear',
+            # padding_mode='zeros',
+            align_corners=False,
+        )
+        sampling_value_list.append(sampling_value_l_)
+    # (bs, num_queries, num_heads, num_levels, num_points) ->
+    # (bs, num_heads, num_queries, num_levels, num_points) ->
+    # (bs, num_heads, 1, num_queries, num_levels*num_points)
+    attention_weights = attention_weights.transpose(1, 2).reshape(
+        bs * num_heads, 1, num_queries, num_levels * num_points
+    )
+    output = (
+        (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+        .sum(-1)
+        .view(bs, num_heads * embed_dims, num_queries)
+    )
+    return output.transpose(1, 2).contiguous()
+
+
+def _custom_grid_sample(im: torch.Tensor, grid: torch.Tensor, align_corners: bool = False) -> torch.Tensor:
+    """Custom patch for mmcv.ops.point_sample.bilinear_grid_sample.
+
+    This function is almost the same as mmcv.ops.point_sample.bilinear_grid_sample.
+    The only difference is that this function uses reshape instead of view.
+
+    Args:
+        im (torch.Tensor): Input feature map, shape (N, C, H, W)
+        grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
+        align_corners (bool): If set to True, the extrema (-1 and 1) are
+            considered as referring to the center points of the input’s
+            corner pixels. If set to False, they are instead considered as
+            referring to the corner points of the input’s corner pixels,
+            making the sampling more resolution agnostic.
+
+    Returns:
+        torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
+    """
+    n, c, h, w = im.shape
+    gn, gh, gw, _ = grid.shape
+    assert n == gn
+
+    x = grid[:, :, :, 0]
+    y = grid[:, :, :, 1]
+
+    if align_corners:
+        x = ((x + 1) / 2) * (w - 1)
+        y = ((y + 1) / 2) * (h - 1)
+    else:
+        x = ((x + 1) * w - 1) / 2
+        y = ((y + 1) * h - 1) / 2
+
+    x = x.reshape(n, -1)
+    y = y.reshape(n, -1)
+
+    x0 = torch.floor(x).long()
+    y0 = torch.floor(y).long()
+    x1 = x0 + 1
+    y1 = y0 + 1
+
+    wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
+    wb = ((x1 - x) * (y - y0)).unsqueeze(1)
+    wc = ((x - x0) * (y1 - y)).unsqueeze(1)
+    wd = ((x - x0) * (y - y0)).unsqueeze(1)
+
+    # Apply the grid_sample default of zero padding
+    im_padded = F.pad(im, pad=[1, 1, 1, 1], mode="constant", value=0)
+    padded_h = h + 2
+    padded_w = w + 2
+    # Save point positions after padding
+    x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1
+
+    # Clip coordinates to padded image size
+    x0 = torch.where(x0 < 0, torch.tensor(0), x0)
+    x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1), x0)
+    x1 = torch.where(x1 < 0, torch.tensor(0), x1)
+    x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1), x1)
+    y0 = torch.where(y0 < 0, torch.tensor(0), y0)
+    y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1), y0)
+    y1 = torch.where(y1 < 0, torch.tensor(0), y1)
+    y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1), y1)
+
+    im_padded = im_padded.view(n, c, -1)
+
+    x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
+    x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
+    x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
+    x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
+
+    Ia = torch.gather(im_padded, 2, x0_y0)
+    Ib = torch.gather(im_padded, 2, x0_y1)
+    Ic = torch.gather(im_padded, 2, x1_y0)
+    Id = torch.gather(im_padded, 2, x1_y1)
+
+    return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
+
+
+multi_scale_deform_attn.multi_scale_deformable_attn_pytorch = multi_scale_deformable_attn_pytorch
diff --git a/otx/algorithms/common/adapters/mmdeploy/__init__.py b/otx/algorithms/common/adapters/mmdeploy/__init__.py
index 01687800428..f0dceb2b560 100644
--- a/otx/algorithms/common/adapters/mmdeploy/__init__.py
+++ b/otx/algorithms/common/adapters/mmdeploy/__init__.py
@@ -1,10 +1,12 @@
 """Adapters for mmdeploy."""
 # Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
 #
-# SPDX-License-Identifier: MIT
 
+from .ops import squeeze__default
 from .utils.mmdeploy import is_mmdeploy_enabled
 
 __all__ = [
+    "squeeze__default",
     "is_mmdeploy_enabled",
 ]
diff --git a/otx/algorithms/common/adapters/mmdeploy/ops/__init__.py b/otx/algorithms/common/adapters/mmdeploy/ops/__init__.py
new file mode 100644
index 00000000000..1e8ac2cb03e
--- /dev/null
+++ b/otx/algorithms/common/adapters/mmdeploy/ops/__init__.py
@@ -0,0 +1,8 @@
+"""Initial file for mmdeploy ops."""
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+from .custom_ops import squeeze__default
+
+__all__ = ["squeeze__default"]
diff --git a/otx/algorithms/common/adapters/mmdeploy/ops/custom_ops.py b/otx/algorithms/common/adapters/mmdeploy/ops/custom_ops.py
new file mode 100644
index 00000000000..c64fbdcac2a
--- /dev/null
+++ b/otx/algorithms/common/adapters/mmdeploy/ops/custom_ops.py
@@ -0,0 +1,39 @@
+"""Custom patch of mmdeploy ops for OpenVINO export."""
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+import torch
+from mmdeploy.core import SYMBOLIC_REWRITER
+from mmdeploy.utils import get_ir_config
+from torch.onnx import symbolic_helper
+
+# Remove the previously registered symbolic function
+SYMBOLIC_REWRITER._registry._rewrite_records["squeeze"] = list()
+
+
+@SYMBOLIC_REWRITER.register_symbolic("squeeze", is_pytorch=True)
+def squeeze__default(ctx, g, self, dim=None):
+    """Register default symbolic function for `squeeze`.
+
+    `squeeze` might be exported with an IF node in ONNX, which is not supported by
+    many backends.
+
+    mmdeploy 0.x does not support the opset 13 version of `squeeze`, so this function
+    provides a custom patch that adds opset 13 support.
+
+    Once we adopt mmdeploy 1.x, this function will no longer be needed.
+    """
+    if dim is None:
+        dims = []
+        for i, size in enumerate(self.type().sizes()):
+            if size == 1:
+                dims.append(i)
+    else:
+        dims = [symbolic_helper._get_const(dim, "i", "dim")]
+
+    if get_ir_config(ctx.cfg).get("opset_version", 11) >= 13:
+        axes = g.op("Constant", value_t=torch.tensor(dims, dtype=torch.long))
+        return g.op("Squeeze", self, axes)
+
+    return g.op("Squeeze", self, axes_i=dims)
diff --git a/otx/algorithms/detection/adapters/mmdet/models/detectors/__init__.py b/otx/algorithms/detection/adapters/mmdet/models/detectors/__init__.py
index 029d4523207..962407bb091 100644
--- a/otx/algorithms/detection/adapters/mmdet/models/detectors/__init__.py
+++ b/otx/algorithms/detection/adapters/mmdet/models/detectors/__init__.py
@@ -4,6 +4,7 @@
 #
 
 from .custom_atss_detector import CustomATSS
+from .custom_deformable_detr_detector import CustomDeformableDETR
 from .custom_maskrcnn_detector import CustomMaskRCNN
 from .custom_maskrcnn_tile_optimized import CustomMaskRCNNTileOptimized
 from .custom_single_stage_detector import CustomSingleStageDetector
@@ -16,6 +17,7 @@
 
 __all__ = [
     "CustomATSS",
+    "CustomDeformableDETR",
     "CustomMaskRCNN",
     "CustomSingleStageDetector",
     "CustomTwoStageDetector",
diff --git a/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_deformable_detr_detector.py b/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_deformable_detr_detector.py
new file mode 100644
index 00000000000..6535d89c398
--- /dev/null
+++ b/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_deformable_detr_detector.py
@@ -0,0 +1,51 @@
+"""OTX Deformable DETR Class for mmdetection detectors."""
+
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+from mmdet.models.builder import DETECTORS
+from mmdet.models.detectors.deformable_detr import DeformableDETR
+
+from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
+    ActivationMapHook,
+    FeatureVectorHook,
+)
+from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
+
+
+@DETECTORS.register_module()
+class CustomDeformableDETR(DeformableDETR):
+    """Custom Deformable DETR with task adapt.
+
+    Deformable DETR does not support task adapt, so it just takes the task_adapt argument.
+ """ + + def __init__(self, *args, task_adapt=None, **kwargs): + super().__init__(*args, **kwargs) + self.task_adapt = task_adapt + + +if is_mmdeploy_enabled(): + from mmdeploy.core import FUNCTION_REWRITER + + @FUNCTION_REWRITER.register_rewriter( + "otx.algorithms.detection.adapters.mmdet.models.detectors.custom_deformable_detr_detector.CustomDeformableDETR.simple_test" + ) + def custom_deformable_detr__simple_test(ctx, self, img, img_metas, **kwargs): + """Function for custom_deformable_detr__simple_test.""" + height = int(img_metas[0]["img_shape"][0]) + width = int(img_metas[0]["img_shape"][1]) + img_metas[0]["batch_input_shape"] = (height, width) + img_metas[0]["img_shape"] = (height, width, 3) + feat = self.extract_feat(img) + outs = self.bbox_head(feat, img_metas) + bbox_results = self.bbox_head.get_bboxes(*outs, img_metas=img_metas, **kwargs) + + if ctx.cfg["dump_features"]: + feature_vector = FeatureVectorHook.func(feat) + cls_scores = outs[0] + saliency_map = ActivationMapHook.func(cls_scores) + return (*bbox_results, feature_vector, saliency_map) + + return bbox_results diff --git a/otx/algorithms/detection/adapters/mmdet/task.py b/otx/algorithms/detection/adapters/mmdet/task.py index 7b005794964..1a07013fd73 100644 --- a/otx/algorithms/detection/adapters/mmdet/task.py +++ b/otx/algorithms/detection/adapters/mmdet/task.py @@ -29,7 +29,7 @@ from mmdet import __version__ from mmdet.apis import single_gpu_test, train_detector from mmdet.datasets import build_dataloader, build_dataset, replace_ImageToTensor -from mmdet.models.detectors import TwoStageDetector +from mmdet.models.detectors import DETR, TwoStageDetector from mmdet.utils import collect_env from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import ( @@ -401,6 +401,8 @@ def hook(module, inp, outp): if isinstance(raw_model, TwoStageDetector): height, width, _ = mm_dataset[0]["img_metas"][0].data["img_shape"] saliency_hook = MaskRCNNRecordingForwardHook(feature_model, input_img_shape=(height, width)) + elif isinstance(raw_model, DETR): + saliency_hook = ActivationMapHook(feature_model) else: saliency_hook = DetClassProbabilityMapHook(feature_model) diff --git a/otx/algorithms/detection/configs/detection/resnet50_deformable-detr/data_pipeline.py b/otx/algorithms/detection/configs/detection/resnet50_deformable-detr/data_pipeline.py new file mode 100644 index 00000000000..4b577e21eb8 --- /dev/null +++ b/otx/algorithms/detection/configs/detection/resnet50_deformable-detr/data_pipeline.py @@ -0,0 +1,115 @@ +"""Data pipeline for Deformable DETR.""" +# dataset settings +dataset_type = "CocoDataset" +data_root = "data/coco/" +img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +# train_pipeline, NOTE the img_scale and the Pad's size_divisor is different +# from the default setting in mmdet. 
+train_pipeline = [ + dict(type="LoadImageFromFile"), + dict(type="LoadAnnotations", with_bbox=True), + dict(type="RandomFlip", flip_ratio=0.5), + dict( + type="AutoAugment", + policies=[ + [ + dict( + type="Resize", + img_scale=[ + (480, 1333), + (512, 1333), + (544, 1333), + (576, 1333), + (608, 1333), + (640, 1333), + (672, 1333), + (704, 1333), + (736, 1333), + (768, 1333), + (800, 1333), + ], + multiscale_mode="value", + keep_ratio=True, + ) + ], + [ + dict( + type="Resize", + # The radio of all image in train dataset < 7 + # follow the original impl + img_scale=[(400, 4200), (500, 4200), (600, 4200)], + multiscale_mode="value", + keep_ratio=True, + ), + dict(type="RandomCrop", crop_type="absolute_range", crop_size=(384, 600), allow_negative_crop=True), + dict( + type="Resize", + img_scale=[ + (480, 1333), + (512, 1333), + (544, 1333), + (576, 1333), + (608, 1333), + (640, 1333), + (672, 1333), + (704, 1333), + (736, 1333), + (768, 1333), + (800, 1333), + ], + multiscale_mode="value", + override=True, + keep_ratio=True, + ), + ], + ], + ), + dict(type="Normalize", **img_norm_cfg), + dict(type="Pad", size_divisor=1), + dict(type="DefaultFormatBundle"), + dict(type="Collect", keys=["img", "gt_bboxes", "gt_labels"]), +] +# test_pipeline, NOTE the Pad's size_divisor is different from the default +# setting (size_divisor=32). While there is little effect on the performance +# whether we use the default setting or use size_divisor=1. +test_pipeline = [ + dict(type="LoadImageFromFile"), + dict( + type="MultiScaleFlipAug", + img_scale=(1333, 800), + flip=False, + transforms=[ + dict(type="Resize", keep_ratio=True), + dict(type="RandomFlip"), + dict(type="Normalize", **img_norm_cfg), + dict(type="Pad", size_divisor=1), + dict(type="ImageToTensor", keys=["img"]), + dict(type="Collect", keys=["img"]), + ], + ), +] +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + filter_empty_gt=False, + ann_file=data_root + "annotations/instances_train2017.json", + img_prefix=data_root + "train2017/", + pipeline=train_pipeline, + ), + val=dict( + type=dataset_type, + ann_file=data_root + "annotations/instances_val2017.json", + img_prefix=data_root + "val2017/", + pipeline=test_pipeline, + ), + test=dict( + type=dataset_type, + ann_file=data_root + "annotations/instances_val2017.json", + img_prefix=data_root + "val2017/", + pipeline=test_pipeline, + ), +) +evaluation = dict(interval=1, metric="bbox") diff --git a/otx/algorithms/detection/configs/detection/resnet50_deformable-detr/deployment.py b/otx/algorithms/detection/configs/detection/resnet50_deformable-detr/deployment.py new file mode 100644 index 00000000000..76b4a6544f5 --- /dev/null +++ b/otx/algorithms/detection/configs/detection/resnet50_deformable-detr/deployment.py @@ -0,0 +1,12 @@ +"""MMDeploy config of Deformable DETR model for Detection Task.""" + +_base_ = ["../../base/deployments/base_detection_dynamic.py"] + +ir_config = dict( + output_names=["boxes", "labels"], + opset_version=16, +) + +backend_config = dict( + model_inputs=[dict(opt_shapes=dict(input=[-1, 3, 800, 1333]))], +) diff --git a/otx/algorithms/detection/configs/detection/resnet50_deformable-detr/model.py b/otx/algorithms/detection/configs/detection/resnet50_deformable-detr/model.py new file mode 100644 index 00000000000..bebbe0b9c80 --- /dev/null +++ b/otx/algorithms/detection/configs/detection/resnet50_deformable-detr/model.py @@ -0,0 +1,109 @@ +"""Model config for Deformable DETR.""" +model = dict( + type="CustomDeformableDETR", + 
backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + out_indices=(1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type="BN", requires_grad=False), + norm_eval=True, + style="pytorch", + init_cfg=dict(type="Pretrained", checkpoint="torchvision://resnet50"), + ), + neck=dict( + type="ChannelMapper", + in_channels=[512, 1024, 2048], + kernel_size=1, + out_channels=256, + act_cfg=None, + norm_cfg=dict(type="GN", num_groups=32), + num_outs=4, + ), + bbox_head=dict( + type="DeformableDETRHead", + num_query=300, + num_classes=80, + in_channels=2048, + sync_cls_avg_factor=True, + with_box_refine=True, + as_two_stage=True, + transformer=dict( + type="DeformableDetrTransformer", + encoder=dict( + type="DetrTransformerEncoder", + num_layers=6, + transformerlayers=dict( + type="BaseTransformerLayer", + attn_cfgs=dict(type="MultiScaleDeformableAttention", embed_dims=256), + feedforward_channels=1024, + ffn_dropout=0.1, + operation_order=("self_attn", "norm", "ffn", "norm"), + ), + ), + decoder=dict( + type="DeformableDetrTransformerDecoder", + num_layers=6, + return_intermediate=True, + transformerlayers=dict( + type="DetrTransformerDecoderLayer", + attn_cfgs=[ + dict(type="MultiheadAttention", embed_dims=256, num_heads=8, dropout=0.1), + dict(type="MultiScaleDeformableAttention", embed_dims=256), + ], + feedforward_channels=1024, + ffn_dropout=0.1, + operation_order=("self_attn", "norm", "cross_attn", "norm", "ffn", "norm"), + ), + ), + ), + positional_encoding=dict(type="SinePositionalEncoding", num_feats=128, normalize=True, offset=-0.5), + loss_cls=dict(type="FocalLoss", use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=2.0), + loss_bbox=dict(type="L1Loss", loss_weight=5.0), + loss_iou=dict(type="GIoULoss", loss_weight=2.0), + ), + # training and testing settings + train_cfg=dict( + assigner=dict( + type="HungarianAssigner", + cls_cost=dict(type="FocalLossCost", weight=2.0), + reg_cost=dict(type="BBoxL1Cost", weight=5.0, box_format="xywh"), + iou_cost=dict(type="IoUCost", iou_mode="giou", weight=2.0), + ) + ), + test_cfg=dict(max_per_img=100), +) +# optimizer +optimizer = dict( + type="AdamW", + lr=2e-4, + weight_decay=0.0001, + paramwise_cfg=dict( + custom_keys={ + "backbone": dict(lr_mult=0.1), + "sampling_offsets": dict(lr_mult=0.1), + "reference_points": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=0.1, norm_type=2)) +# learning policy +lr_config = dict(policy="step", step=[10]) +runner = dict(type="EpochRunnerWithCancel", max_epochs=12) +load_from = "https://download.openmmlab.com/mmdetection/v2.0/deformable_detr/\ +deformable_detr_twostage_refine_r50_16x2_50e_coco/\ +deformable_detr_twostage_refine_r50_16x2_50e_coco_20210419_220613-9d28ab72.pth" +resume_from = None + +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=100, + hooks=[ + dict(type="TextLoggerHook"), + ], +) +log_level = "INFO" +workflow = [("train", 1)] +task_adapt = dict(op="REPLACE", type="temp", efficient_mode=False, use_mpa_anchor=False) diff --git a/otx/algorithms/detection/configs/detection/resnet50_deformable-detr/template_experimental.yaml b/otx/algorithms/detection/configs/detection/resnet50_deformable-detr/template_experimental.yaml new file mode 100644 index 00000000000..f585c68c591 --- /dev/null +++ b/otx/algorithms/detection/configs/detection/resnet50_deformable-detr/template_experimental.yaml @@ -0,0 +1,64 @@ +# Description. 
+model_template_id: Custom_Object_Detection_Gen3_Deformable_DETR +name: Deformable_DETR +task_type: DETECTION +task_family: VISION +instantiation: "CLASS" +summary: Class-Incremental Object Detection for Deformable_DETR +application: ~ + +# Algo backend. +framework: OTXDetection v2.9.1 + +# Task implementations. +entrypoints: + base: otx.algorithms.detection.adapters.mmdet.task.MMDetectionTask + openvino: otx.algorithms.detection.adapters.openvino.task.OpenVINODetectionTask + nncf: otx.algorithms.detection.adapters.mmdet.nncf.task.DetectionNNCFTask + +# Capabilities. +capabilities: + - compute_representations + +# Hyperparameters. +hyper_parameters: + base_path: ../configuration.yaml + parameter_overrides: + learning_parameters: + batch_size: + default_value: 2 + auto_hpo_state: POSSIBLE + learning_rate: + default_value: 0.0002 + auto_hpo_state: POSSIBLE + learning_rate_warmup_iters: + default_value: 3 + num_iters: + default_value: 12 + nncf_optimization: + enable_quantization: + default_value: true + enable_pruning: + default_value: false + pruning_supported: + default_value: true + maximal_accuracy_degradation: + default_value: 1.0 + algo_backend: + train_type: + default_value: Incremental + +# Training resources. +max_nodes: 1 +training_targets: + - GPU + - CPU + +# Stats. +gigaflops: ??? +size: ??? +# # Inference options. Defined by OpenVINO capabilities, not Algo Backend or Platform. +# inference_targets: +# - CPU +# - GPU +# - VPU diff --git a/tests/integration/cli/detection/test_detection.py b/tests/integration/cli/detection/test_detection.py index 883aef94c94..18a2d65f230 100644 --- a/tests/integration/cli/detection/test_detection.py +++ b/tests/integration/cli/detection/test_detection.py @@ -68,10 +68,18 @@ templates = Registry("otx/algorithms/detection").filter(task_type="DETECTION").templates templates_ids = [template.model_template_id for template in templates] +experimental_template = parse_model_template( + "otx/algorithms/detection/configs/detection/resnet50_deformable-detr/template_experimental.yaml" +) +experimental_template_id = experimental_template.model_template_id + +templates_w_experimental = templates + [experimental_template] +templates_ids_w_experimental = templates_ids + [experimental_template_id] + class TestDetectionCLI: @e2e_pytest_component - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_train(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection" otx_train_testing(template, tmp_dir_path, otx_dir, args) @@ -90,26 +98,26 @@ def test_otx_resume(self, template, tmp_dir_path): otx_resume_testing(template, tmp_dir_path, otx_dir, args1) @e2e_pytest_component - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) @pytest.mark.parametrize("dump_features", [True, False]) def test_otx_export(self, template, tmp_dir_path, dump_features): tmp_dir_path = tmp_dir_path / "detection" otx_export_testing(template, tmp_dir_path, dump_features, check_ir_meta=True) @e2e_pytest_component - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_export_fp16(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection" otx_export_testing(template, tmp_dir_path, half_precision=True) 
@e2e_pytest_component - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_export_onnx(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection" otx_export_testing(template, tmp_dir_path, half_precision=False, is_onnx=True) @e2e_pytest_component - @pytest.mark.parametrize("template", templates, ids=templates_ids) + @pytest.mark.parametrize("template", templates_w_experimental, ids=templates_ids_w_experimental) def test_otx_eval(self, template, tmp_dir_path): tmp_dir_path = tmp_dir_path / "detection" otx_eval_testing(template, tmp_dir_path, otx_dir, args) diff --git a/tests/unit/algorithms/common/adapters/mmcv/ops/__init__.py b/tests/unit/algorithms/common/adapters/mmcv/ops/__init__.py new file mode 100644 index 00000000000..1344d3cacb2 --- /dev/null +++ b/tests/unit/algorithms/common/adapters/mmcv/ops/__init__.py @@ -0,0 +1,4 @@ +"""Test for otx.algorithms.common.adapters.mmcv.ops""" +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# diff --git a/tests/unit/algorithms/common/adapters/mmcv/ops/test_multi_scale_deformable_attn_pytorch.py b/tests/unit/algorithms/common/adapters/mmcv/ops/test_multi_scale_deformable_attn_pytorch.py new file mode 100644 index 00000000000..9743b20a502 --- /dev/null +++ b/tests/unit/algorithms/common/adapters/mmcv/ops/test_multi_scale_deformable_attn_pytorch.py @@ -0,0 +1,20 @@ +"""Test for otx.algorithms.common.adapters.mmcv.ops.multi_scale_deformable_attn_pytorch.""" +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import torch + +from otx.algorithms.common.adapters.mmcv.ops import multi_scale_deformable_attn_pytorch +from tests.test_suite.e2e_test_system import e2e_pytest_unit + + +@e2e_pytest_unit +def test_multi_scale_deformable_attn_pytorch(): + value = torch.randn([1, 22223, 8, 32]) + value_spatial_shapes = torch.tensor([[100, 167], [50, 84], [25, 42], [13, 21]]) + sampling_locations = torch.randn([1, 2223, 8, 4, 4, 2]) + attention_weights = torch.randn([1, 2223, 8, 4, 4]) + + out = multi_scale_deformable_attn_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights) + assert out.shape == torch.Size([1, 2223, 256]) diff --git a/tests/unit/algorithms/common/adapters/mmdeploy/ops/__init__.py b/tests/unit/algorithms/common/adapters/mmdeploy/ops/__init__.py new file mode 100644 index 00000000000..ce553f5423b --- /dev/null +++ b/tests/unit/algorithms/common/adapters/mmdeploy/ops/__init__.py @@ -0,0 +1,4 @@ +"""Test for otx.algorithms.common.adapters.mmdeploy.ops""" +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# diff --git a/tests/unit/algorithms/common/adapters/mmdeploy/ops/test_custom_ops.py b/tests/unit/algorithms/common/adapters/mmdeploy/ops/test_custom_ops.py new file mode 100644 index 00000000000..46b7ef11a3e --- /dev/null +++ b/tests/unit/algorithms/common/adapters/mmdeploy/ops/test_custom_ops.py @@ -0,0 +1,54 @@ +"""Test for otx.algorithms.common.adapters.mmdeploy.ops.custom_ops.""" +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import torch +from mmcv.utils import Config +from mmdeploy.core import SYMBOLIC_REWRITER + +from otx.algorithms.common.adapters.mmdeploy.ops.custom_ops import squeeze__default +from tests.test_suite.e2e_test_system import e2e_pytest_unit + + +@e2e_pytest_unit +def test_symbolic_registery(): + assert 
len(SYMBOLIC_REWRITER._registry._rewrite_records["squeeze"]) == 1 + + +class MockOps: + def op(self, *args, **kwargs): + return (args, kwargs) + + +@e2e_pytest_unit +def test_squeeze(mocker): + """Test squeeze__default function.""" + + class MockClass: + class _size: + def sizes(self): + return [1, 1, 1] + + size = _size() + + def type(self): + return self.size + + # Patching for squeeze op + mock_ctx = Config({"cfg": Config({"opset_version": 11})}) + mock_g = MockOps() + mock_self = MockClass() + mocker.patch("otx.algorithms.common.adapters.mmdeploy.ops.custom_ops.get_ir_config", return_value=mock_ctx.cfg) + op = squeeze__default(mock_ctx, mock_g, mock_self) + assert op[0][0] == "Squeeze" + assert op[1]["axes_i"] == [0, 1, 2] + + mock_ctx = Config({"cfg": Config({"opset_version": 13})}) + mock_g = MockOps() + mock_self = MockClass() + mocker.patch("otx.algorithms.common.adapters.mmdeploy.ops.custom_ops.get_ir_config", return_value=mock_ctx.cfg) + op = squeeze__default(mock_ctx, mock_g, mock_self) + assert op[0][0] == "Squeeze" + assert op[0][2][0][0] == "Constant" + assert torch.all(op[0][2][1]["value_t"] == torch.Tensor([0, 1, 2])) diff --git a/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/conftest.py b/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/conftest.py index c5ebd807d90..52b50f2722d 100644 --- a/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/conftest.py +++ b/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/conftest.py @@ -312,3 +312,86 @@ def fxt_cfg_custom_yolox(num_classes: int = 3): }, } return cfg + + +@pytest.fixture +def fxt_cfg_custom_deformable_detr(num_classes: int = 3): + return ConfigDict( + type="CustomDeformableDETR", + backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + out_indices=(1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type="BN", requires_grad=False), + norm_eval=True, + style="pytorch", + init_cfg=dict(type="Pretrained", checkpoint="torchvision://resnet50"), + ), + neck=dict( + type="ChannelMapper", + in_channels=[512, 1024, 2048], + kernel_size=1, + out_channels=256, + act_cfg=None, + norm_cfg=dict(type="GN", num_groups=32), + num_outs=4, + ), + bbox_head=dict( + type="DeformableDETRHead", + num_query=300, + num_classes=80, + in_channels=2048, + sync_cls_avg_factor=True, + with_box_refine=True, + as_two_stage=True, + transformer=dict( + type="DeformableDetrTransformer", + encoder=dict( + type="DetrTransformerEncoder", + num_layers=6, + transformerlayers=dict( + type="BaseTransformerLayer", + attn_cfgs=dict(type="MultiScaleDeformableAttention", embed_dims=256), + feedforward_channels=1024, + ffn_dropout=0.1, + operation_order=("self_attn", "norm", "ffn", "norm"), + ), + ), + decoder=dict( + type="DeformableDetrTransformerDecoder", + num_layers=6, + return_intermediate=True, + transformerlayers=dict( + type="DetrTransformerDecoderLayer", + attn_cfgs=[ + dict(type="MultiheadAttention", embed_dims=256, num_heads=8, dropout=0.1), + dict(type="MultiScaleDeformableAttention", embed_dims=256), + ], + feedforward_channels=1024, + ffn_dropout=0.1, + operation_order=("self_attn", "norm", "cross_attn", "norm", "ffn", "norm"), + ), + ), + ), + positional_encoding=dict(type="SinePositionalEncoding", num_feats=128, normalize=True, offset=-0.5), + loss_cls=dict(type="FocalLoss", use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=2.0), + loss_bbox=dict(type="L1Loss", loss_weight=5.0), + loss_iou=dict(type="GIoULoss", loss_weight=2.0), + ), + # training and testing settings + train_cfg=dict( + 
assigner=dict( + type="HungarianAssigner", + cls_cost=dict(type="FocalLossCost", weight=2.0), + reg_cost=dict(type="BBoxL1Cost", weight=5.0, box_format="xywh"), + iou_cost=dict(type="IoUCost", iou_mode="giou", weight=2.0), + ) + ), + test_cfg=dict(max_per_img=100), + task_adapt=dict( + src_classes=["person", "car"], + dst_classes=["tree", "car", "person"], + ), + ) diff --git a/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/test_custom_deformable_detr_detector.py b/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/test_custom_deformable_detr_detector.py new file mode 100644 index 00000000000..ef3f0e8145a --- /dev/null +++ b/tests/unit/algorithms/detection/adapters/mmdet/models/detectors/test_custom_deformable_detr_detector.py @@ -0,0 +1,19 @@ +"""Test for CustomDeformableDETR Detector.""" +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +from mmdet.models.builder import build_detector + +from otx.algorithms.detection.adapters.mmdet.models.detectors.custom_deformable_detr_detector import ( + CustomDeformableDETR, +) +from tests.test_suite.e2e_test_system import e2e_pytest_unit + + +class TestCustomDeformableDETR: + @e2e_pytest_unit + def test_custom_deformable_detr_build(self, fxt_cfg_custom_deformable_detr): + model = build_detector(fxt_cfg_custom_deformable_detr) + assert isinstance(model, CustomDeformableDETR) + assert model.task_adapt is not None
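
Note on the custom ops above (not part of the patch itself): importing otx.algorithms.common.adapters.mmcv.ops is what activates the workaround, because the module-level assignment at the bottom of multi_scale_deformable_attn_pytorch.py rebinds MMCV's multi_scale_deformable_attn_pytorch to the grid_sample-free version, and _custom_grid_sample is intended to be numerically equivalent to torch.nn.functional.grid_sample with bilinear interpolation and zero padding. The sketch below is a minimal sanity check under those assumptions; the private _custom_grid_sample name and the module paths come from this diff and may change.

    import torch
    import torch.nn.functional as F
    from mmcv.ops import multi_scale_deform_attn

    # Importing the OTX ops package applies the monkey patch defined in this diff.
    import otx.algorithms.common.adapters.mmcv.ops as otx_ops
    from otx.algorithms.common.adapters.mmcv.ops.multi_scale_deformable_attn_pytorch import _custom_grid_sample

    # 1) MMCV should now point at the patched, export-friendly implementation.
    assert multi_scale_deform_attn.multi_scale_deformable_attn_pytorch is otx_ops.multi_scale_deformable_attn_pytorch

    # 2) The custom bilinear sampler should agree with F.grid_sample (bilinear, zero padding)
    #    on random in-range sampling points, up to floating-point tolerance.
    im = torch.randn(2, 8, 25, 42)          # (N, C, H, W) feature map
    grid = torch.rand(2, 30, 4, 2) * 2 - 1  # (N, Hg, Wg, 2) normalized coordinates in [-1, 1]
    ours = _custom_grid_sample(im, grid, align_corners=False)
    ref = F.grid_sample(im, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
    assert torch.allclose(ours, ref, atol=1e-6)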