Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable DINO to OTX - Step 1. Enable Deformable DETR to OTX #2249

Merged
merged 6 commits into from
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ All notable changes to this project will be documented in this file.
- Support encrypted dataset training (<https://github.com/openvinotoolkit/training_extensions/pull/2209>)
- Add custom max iou assigner to prevent CPU OOM when large annotations are used (<https://github.com/openvinotoolkit/training_extensions/pull/2228>)
- Auto train type detection for Semi-SL, Self-SL and Incremental: "--train-type" now is optional (<https://github.com/openvinotoolkit/training_extensions/pull/2195>)
- Add new object detector Deformable DETR (<https://github.com/openvinotoolkit/training_extensions/pull/2249>)

### Enhancements

Expand Down
2 changes: 2 additions & 0 deletions otx/algorithms/common/adapters/mmcv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)
from .nncf.hooks import CompressionHook
from .nncf.runners import AccuracyAwareRunner
from .ops import multi_scale_deformable_attn_pytorch
from .runner import EpochRunnerWithCancel, IterBasedRunnerWithCancel

__all__ = [
Expand All @@ -57,4 +58,5 @@
"CompressionHook",
"AccuracyAwareRunner",
"TwoCropTransformHook",
"multi_scale_deformable_attn_pytorch",
]
8 changes: 8 additions & 0 deletions otx/algorithms/common/adapters/mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Initial file for mmcv ops."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from .multi_scale_deformable_attn_pytorch import multi_scale_deformable_attn_pytorch

__all__ = ["multi_scale_deformable_attn_pytorch"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
"""Custom patch of multi_scale_deformable_attn_pytorch for openvino export."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

import torch
import torch.nn.functional as F
from mmcv.ops import multi_scale_deform_attn


def multi_scale_deformable_attn_pytorch(
    value: torch.Tensor,
    value_spatial_shapes: torch.Tensor,
    sampling_locations: torch.Tensor,
    attention_weights: torch.Tensor,
) -> torch.Tensor:
    """Multi-scale deformable attention using an export-friendly grid sampler.

    The original implementation in mmcv.ops uses torch.nn.functional.grid_sample,
    which raises errors during inference with an OpenVINO exported model.
    This version samples through _custom_grid_sample instead.

    Args:
        value (torch.Tensor): Features of all levels concatenated along dim 1,
            unpacked here as (bs, sum of H*W over levels, num_heads, embed_dims).
        value_spatial_shapes (torch.Tensor): One (H, W) pair per feature level.
        sampling_locations (torch.Tensor): Sampling points, unpacked here as
            (bs, num_queries, num_heads, num_levels, num_points, 2). They are mapped
            with ``2 * loc - 1`` before sampling, so they are presumably normalized
            to [0, 1] -- confirm against the caller.
        attention_weights (torch.Tensor): Weights reshaped here from
            (bs, num_queries, num_heads, num_levels, num_points).

    Returns:
        torch.Tensor: Attended features, shape (bs, num_queries, num_heads * embed_dims).
    """
    batch_size, _, num_heads, embed_dims = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape

    # Cut the concatenated features back into one chunk per level.
    level_sizes = [H_ * W_ for H_, W_ in value_spatial_shapes]
    per_level_values = value.split(level_sizes, dim=1)

    # The grid sampler works with coordinates in [-1, 1].
    grids = 2 * sampling_locations - 1

    sampled_per_level = []
    for level, (H_, W_) in enumerate(value_spatial_shapes):
        # bs, H_*W_, num_heads, embed_dims ->
        # bs, H_*W_, num_heads*embed_dims ->
        # bs, num_heads*embed_dims, H_*W_ ->
        # bs*num_heads, embed_dims, H_, W_
        level_value = per_level_values[level].flatten(2).transpose(1, 2)
        level_value = level_value.reshape(batch_size * num_heads, embed_dims, H_, W_)

        # bs, num_queries, num_heads, num_points, 2 ->
        # bs, num_heads, num_queries, num_points, 2 ->
        # bs*num_heads, num_queries, num_points, 2
        level_grid = grids[:, :, :, level].transpose(1, 2).flatten(0, 1)

        # bs*num_heads, embed_dims, num_queries, num_points
        sampled = _custom_grid_sample(level_value, level_grid, align_corners=False)
        sampled_per_level.append(sampled)

    # (bs, num_queries, num_heads, num_levels, num_points) ->
    # (bs, num_heads, num_queries, num_levels, num_points) ->
    # (bs*num_heads, 1, num_queries, num_levels*num_points)
    weights = attention_weights.transpose(1, 2).reshape(
        batch_size * num_heads, 1, num_queries, num_levels * num_points
    )

    # Weighted sum over every (level, point) sample of each query.
    stacked = torch.stack(sampled_per_level, dim=-2).flatten(-2)
    attended = (stacked * weights).sum(-1).view(batch_size, num_heads * embed_dims, num_queries)
    return attended.transpose(1, 2).contiguous()


def _custom_grid_sample(im: torch.Tensor, grid: torch.Tensor, align_corners: bool = False) -> torch.Tensor:
    """Bilinear grid sampling with zero padding, built from basic tensor ops.

    This function is almost the same as mmcv.ops.point_sample.bilinear_grid_sample;
    it differs in using reshape instead of view on the grid coordinates and on the
    final result.

    Args:
        im (torch.Tensor): Input feature map, shape (N, C, H, W)
        grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
        align_corners (bool): If set to True, the extrema (-1 and 1) are
            considered as referring to the center points of the input's
            corner pixels. If set to False, they are instead considered as
            referring to the corner points of the input's corner pixels,
            making the sampling more resolution agnostic.

    Returns:
        torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
    """
    batch, channels, in_h, in_w = im.shape
    grid_batch, out_h, out_w, _ = grid.shape
    assert batch == grid_batch

    px = grid[:, :, :, 0]
    py = grid[:, :, :, 1]

    # Convert normalized [-1, 1] coordinates into (fractional) pixel coordinates.
    if align_corners:
        px = ((px + 1) / 2) * (in_w - 1)
        py = ((py + 1) / 2) * (in_h - 1)
    else:
        px = ((px + 1) * in_w - 1) / 2
        py = ((py + 1) * in_h - 1) / 2

    px = px.reshape(batch, -1)
    py = py.reshape(batch, -1)

    # Integer corners of the pixel cell surrounding every sampling point.
    left = torch.floor(px).long()
    top = torch.floor(py).long()
    right = left + 1
    bottom = top + 1

    # Bilinear weights: each corner is weighted by the area of the opposite sub-cell.
    w_top_left = ((right - px) * (bottom - py)).unsqueeze(1)
    w_bottom_left = ((right - px) * (py - top)).unsqueeze(1)
    w_top_right = ((px - left) * (bottom - py)).unsqueeze(1)
    w_bottom_right = ((px - left) * (py - top)).unsqueeze(1)

    # A one-pixel zero border reproduces grid_sample's default zero padding.
    padded = F.pad(im, pad=[1, 1, 1, 1], mode="constant", value=0)
    padded_h, padded_w = in_h + 2, in_w + 2
    # Shift the corner indices to account for the added border.
    left, right, top, bottom = left + 1, right + 1, top + 1, bottom + 1

    def _clip(index, upper):
        # Clip indices into [0, upper] with torch.where, as the original code does.
        index = torch.where(index < 0, torch.tensor(0), index)
        return torch.where(index > upper, torch.tensor(upper), index)

    left = _clip(left, padded_w - 1)
    right = _clip(right, padded_w - 1)
    top = _clip(top, padded_h - 1)
    bottom = _clip(bottom, padded_h - 1)

    flat = padded.view(batch, channels, -1)

    def _take(col, row):
        # Gather the pixel at (row, col) of the flattened padded map for every channel.
        index = (col + row * padded_w).unsqueeze(1).expand(-1, channels, -1)
        return torch.gather(flat, 2, index)

    val_top_left = _take(left, top)
    val_bottom_left = _take(left, bottom)
    val_top_right = _take(right, top)
    val_bottom_right = _take(right, bottom)

    blended = (
        val_top_left * w_top_left
        + val_bottom_left * w_bottom_left
        + val_top_right * w_top_right
        + val_bottom_right * w_bottom_right
    )
    return blended.reshape(batch, channels, out_h, out_w)


# Monkey patch: rebind the attribute on mmcv's multi_scale_deform_attn module so that
# lookups through that module get the implementation defined in this file.
# This runs as a side effect of importing this module.
multi_scale_deform_attn.multi_scale_deformable_attn_pytorch = multi_scale_deformable_attn_pytorch
4 changes: 3 additions & 1 deletion otx/algorithms/common/adapters/mmdeploy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Adapters for mmdeploy."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
# SPDX-License-Identifier: MIT

from .ops import squeeze__default
from .utils.mmdeploy import is_mmdeploy_enabled

__all__ = [
"squeeze__default",
"is_mmdeploy_enabled",
]
8 changes: 8 additions & 0 deletions otx/algorithms/common/adapters/mmdeploy/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Initial file for mmdeploy ops."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from .custom_ops import squeeze__default

__all__ = ["squeeze__default"]
39 changes: 39 additions & 0 deletions otx/algorithms/common/adapters/mmdeploy/ops/custom_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Custom patch of mmdeploy ops for openvino export."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

import torch
from mmdeploy.core import SYMBOLIC_REWRITER
from mmdeploy.utils import get_ir_config
from torch.onnx import symbolic_helper

# Drop every symbolic previously registered for "squeeze" so that the one
# registered below is the only record left for that op.
SYMBOLIC_REWRITER._registry._rewrite_records["squeeze"] = []


@SYMBOLIC_REWRITER.register_symbolic("squeeze", is_pytorch=True)
def squeeze__default(ctx, g, self, dim=None):
    """Register default symbolic function for `squeeze`.

    squeeze might be exported with an IF node in ONNX, which is not supported by
    many backends.

    mmdeploy 0.x does not support the opset13 version of squeeze, so this function
    is a custom patch that adds it.

    Once mmdeploy 1.x is adopted, this function is no longer needed.

    Args:
        ctx: Rewriter context; ``ctx.cfg`` is read to find the target opset version.
        g: Graph object used to emit ONNX ops through ``g.op``.
        self: The value being squeezed.
        dim: Axis to squeeze. When None, every axis whose size is 1 is squeezed.

    Returns:
        The emitted ONNX ``Squeeze`` node.
    """
    if dim is None:
        # No axis given: squeeze every axis whose size is 1.
        dims = [axis for axis, size in enumerate(self.type().sizes()) if size == 1]
    else:
        dims = [symbolic_helper._get_const(dim, "i", "dim")]

    opset_version = get_ir_config(ctx.cfg).get("opset_version", 11)
    if opset_version >= 13:
        # From opset 13 the axes are passed as a tensor input, not as an attribute.
        axes = g.op("Constant", value_t=torch.tensor(dims, dtype=torch.long))
        return g.op("Squeeze", self, axes)

    return g.op("Squeeze", self, axes_i=dims)
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#

from .custom_atss_detector import CustomATSS
from .custom_deformable_detr_detector import CustomDeformableDETR
from .custom_maskrcnn_detector import CustomMaskRCNN
from .custom_maskrcnn_tile_optimized import CustomMaskRCNNTileOptimized
from .custom_single_stage_detector import CustomSingleStageDetector
Expand All @@ -16,6 +17,7 @@

__all__ = [
"CustomATSS",
"CustomDeformableDETR",
"CustomMaskRCNN",
"CustomSingleStageDetector",
"CustomTwoStageDetector",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""OTX Deformable DETR Class for mmdetection detectors."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from mmdet.models.builder import DETECTORS
from mmdet.models.detectors.deformable_detr import DeformableDETR

from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
ActivationMapHook,
FeatureVectorHook,
)
from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled


@DETECTORS.register_module()
class CustomDeformableDETR(DeformableDETR):
    """Custom Deformable DETR with task adapt.

    Deformable DETR does not support task adapt, so it just takes the task_adapt
    argument and stores it on the instance.
    """

    def __init__(self, *args, task_adapt=None, **kwargs):
        """Build the detector and keep ``task_adapt`` as an attribute.

        Args:
            *args: Positional arguments forwarded to ``DeformableDETR``.
            task_adapt: Task-adapt configuration; stored as ``self.task_adapt``
                and not forwarded to the parent constructor.
            **kwargs: Keyword arguments forwarded to ``DeformableDETR``.
        """
        super().__init__(*args, **kwargs)
        self.task_adapt = task_adapt


if is_mmdeploy_enabled():
    from mmdeploy.core import FUNCTION_REWRITER

    @FUNCTION_REWRITER.register_rewriter(
        "otx.algorithms.detection.adapters.mmdet.models.detectors.custom_deformable_detr_detector.CustomDeformableDETR.simple_test"
    )
    def custom_deformable_detr__simple_test(ctx, self, img, img_metas, **kwargs):
        """Rewritten ``simple_test`` of CustomDeformableDETR, registered with mmdeploy.

        Args:
            ctx: Rewriter context; ``ctx.cfg["dump_features"]`` selects the output form.
            self: The detector instance.
            img: Input image tensor passed to ``extract_feat``.
            img_metas: Image meta information; the first entry is updated in place.
            **kwargs: Extra keyword arguments forwarded to ``bbox_head.get_bboxes``.

        Returns:
            The ``get_bboxes`` results, followed by the feature vector and the
            saliency map when ``dump_features`` is enabled.
        """
        # Turn the shape entries into plain ints and write them back into the
        # first meta entry, which is passed on to the bbox head below.
        meta = img_metas[0]
        height, width = int(meta["img_shape"][0]), int(meta["img_shape"][1])
        meta["batch_input_shape"] = (height, width)
        meta["img_shape"] = (height, width, 3)

        feat = self.extract_feat(img)
        outs = self.bbox_head(feat, img_metas)
        bbox_results = self.bbox_head.get_bboxes(*outs, img_metas=img_metas, **kwargs)

        if not ctx.cfg["dump_features"]:
            return bbox_results

        # Also return explanation outputs: a feature vector from the extracted
        # features and a saliency map from the first bbox-head output.
        feature_vector = FeatureVectorHook.func(feat)
        saliency_map = ActivationMapHook.func(outs[0])
        return (*bbox_results, feature_vector, saliency_map)
4 changes: 3 additions & 1 deletion otx/algorithms/detection/adapters/mmdet/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from mmdet import __version__
from mmdet.apis import single_gpu_test, train_detector
from mmdet.datasets import build_dataloader, build_dataset, replace_ImageToTensor
from mmdet.models.detectors import TwoStageDetector
from mmdet.models.detectors import DETR, TwoStageDetector
from mmdet.utils import collect_env

from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
Expand Down Expand Up @@ -401,6 +401,8 @@ def hook(module, inp, outp):
if isinstance(raw_model, TwoStageDetector):
height, width, _ = mm_dataset[0]["img_metas"][0].data["img_shape"]
saliency_hook = MaskRCNNRecordingForwardHook(feature_model, input_img_shape=(height, width))
elif isinstance(raw_model, DETR):
saliency_hook = ActivationMapHook(feature_model)
else:
saliency_hook = DetClassProbabilityMapHook(feature_model)

Expand Down
Loading