Skip to content

Commit

Permalink
【Fix】Fix mmdet head (open-mmlab#290)
Browse files Browse the repository at this point in the history
* fix anchor head and base dense head

* fix base dense head bug

* fix base dense head bug

* fix pad

* add ssd model int8 and fp16 config

* fix a bug about basedensehead

* add config for yolov3 trt fp16 int8

Co-authored-by: maningsheng <[email protected]>
  • Loading branch information
VVsssssk and RunningLeon authored Dec 17, 2021
1 parent 3e8237d commit cde9abd
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 52 deletions.
12 changes: 12 additions & 0 deletions configs/mmdet/_base_/base_tensorrt-fp16_dynamic-300x300-512x512.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# TensorRT FP16 deployment settings with a variable input resolution.
# Inherits from `base_dynamic.py` and the shared TensorRT FP16 backend config.
_base_ = [
    './base_dynamic.py',
    '../../_base_/backends/tensorrt-fp16.py',
]

backend_config = {
    'common_config': {
        # 2**30 == 1 << 30 (1 GiB when interpreted as bytes).
        'max_workspace_size': 2**30,
    },
    'model_inputs': [
        {
            'input_shapes': {
                # Allowed shape range for the tensor named `input`:
                # smallest, preferred and largest shape respectively.
                'input': {
                    'min_shape': [1, 3, 300, 300],
                    'opt_shape': [1, 3, 300, 300],
                    'max_shape': [1, 3, 512, 512],
                },
            },
        },
    ],
}
12 changes: 12 additions & 0 deletions configs/mmdet/_base_/base_tensorrt-int8_dynamic-300x300-512x512.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# TensorRT INT8 deployment settings with a variable input resolution.
# Inherits from `base_dynamic.py` and the shared TensorRT INT8 backend config.
_base_ = [
    './base_dynamic.py',
    '../../_base_/backends/tensorrt-int8.py',
]

backend_config = {
    'common_config': {
        # 2**30 == 1 << 30 (1 GiB when interpreted as bytes).
        'max_workspace_size': 2**30,
    },
    'model_inputs': [
        {
            'input_shapes': {
                # Allowed shape range for the tensor named `input`:
                # smallest, preferred and largest shape respectively.
                'input': {
                    'min_shape': [1, 3, 300, 300],
                    'opt_shape': [1, 3, 300, 300],
                    'max_shape': [1, 3, 512, 512],
                },
            },
        },
    ],
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# TensorRT FP16 deployment settings with a variable input resolution
# between 160x160 and 608x608.
# NOTE(review): presumably the YOLOv3 config mentioned in the commit
# message -- the file name is not visible here, confirm before relying on it.
_base_ = [
    '../_base_/base_dynamic.py',
    '../../_base_/backends/tensorrt-fp16.py',
]

backend_config = {
    'common_config': {
        # 2**30 == 1 << 30 (1 GiB when interpreted as bytes).
        'max_workspace_size': 2**30,
    },
    'model_inputs': [
        {
            'input_shapes': {
                # Allowed shape range for the tensor named `input`; the
                # preferred shape coincides with the largest one.
                'input': {
                    'min_shape': [1, 3, 160, 160],
                    'opt_shape': [1, 3, 608, 608],
                    'max_shape': [1, 3, 608, 608],
                },
            },
        },
    ],
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Thin config that only inherits the shared TensorRT FP16 base config for
# dynamic inputs from 300x300 to 512x512; it adds no settings of its own.
_base_ = [
    '../_base_/base_tensorrt-fp16_dynamic-300x300-512x512.py',
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# TensorRT INT8 deployment settings with a variable input resolution
# between 160x160 and 608x608.
# NOTE(review): presumably the YOLOv3 config mentioned in the commit
# message -- the file name is not visible here, confirm before relying on it.
_base_ = [
    '../_base_/base_dynamic.py',
    '../../_base_/backends/tensorrt-int8.py',
]

backend_config = {
    'common_config': {
        # 2**30 == 1 << 30 (1 GiB when interpreted as bytes).
        'max_workspace_size': 2**30,
    },
    'model_inputs': [
        {
            'input_shapes': {
                # Allowed shape range for the tensor named `input`; the
                # preferred shape coincides with the largest one.
                'input': {
                    'min_shape': [1, 3, 160, 160],
                    'opt_shape': [1, 3, 608, 608],
                    'max_shape': [1, 3, 608, 608],
                },
            },
        },
    ],
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Thin config that only inherits the shared TensorRT INT8 base config for
# dynamic inputs from 300x300 to 512x512; it adds no settings of its own.
_base_ = [
    '../_base_/base_tensorrt-int8_dynamic-300x300-512x512.py',
]
6 changes: 3 additions & 3 deletions mmdeploy/codebase/mmdet/models/dense_heads/anchor_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ def anchor_head__get_bboxes__ncnn(ctx,
]
min_sizes = self.anchor_generator.base_sizes
max_sizes = min_sizes[1:] + \
img_metas['img_shape'][0:1].tolist()
img_height = img_metas['img_shape'][0].item()
img_width = img_metas['img_shape'][1].item()
img_metas[0]['img_shape'][0:1].tolist()
img_height = img_metas[0]['img_shape'][0].item()
img_width = img_metas[0]['img_shape'][1].item()

# if no reshape, concat will be error in ncnn.
mlvl_anchors = [
Expand Down
80 changes: 31 additions & 49 deletions mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def base_dense_head__get_bbox(ctx,
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes,
batch_mlvl_scores, batch_mlvl_centerness
"""
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
backend = get_backend(deploy_cfg)
Expand All @@ -68,33 +67,26 @@ def base_dense_head__get_bbox(ctx,

mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]

assert len(
img_metas
) == 1, 'Only support one input image while in exporting to ONNX'
img_shape = img_metas[0]['img_shape']

cfg = self.test_cfg
assert len(cls_scores) == len(bbox_preds) == len(mlvl_priors)
device = cls_scores[0].device
batch_size = cls_scores[0].shape[0]
# convert to tensor to keep tracing
nms_pre_tensor = torch.tensor(
cfg.get('nms_pre', -1), device=device, dtype=torch.long)
# e.g. Retina, FreeAnchor, etc.
if score_factors is None:
with_score_factors = False
mlvl_score_factor = [None for _ in range(num_levels)]
else:
# e.g. FCOS, PAA, ATSS, etc.
with_score_factors = True
mlvl_score_factor = [
score_factors[i].detach() for i in range(num_levels)
]
mlvl_score_factors = []
assert img_metas is not None
img_shape = img_metas[0]['img_shape']

assert len(cls_scores) == len(bbox_preds) == len(mlvl_priors)
batch_size = cls_scores[0].shape[0]
cfg = self.test_cfg
pre_topk = cfg.get('nms_pre', -1)

mlvl_batch_bboxes = []
mlvl_scores = []
mlvl_valid_bboxes = []
mlvl_valid_scores = []
mlvl_valid_priors = []

for cls_score, bbox_pred, score_factors, priors in zip(
mlvl_cls_scores, mlvl_bbox_preds, mlvl_score_factor, mlvl_priors):
Expand All @@ -108,7 +100,6 @@ def base_dense_head__get_bbox(ctx,
else:
scores = scores.softmax(-1)
nms_pre_score = scores

if with_score_factors:
score_factors = score_factors.permute(0, 2, 3,
1).reshape(batch_size,
Expand All @@ -117,57 +108,48 @@ def base_dense_head__get_bbox(ctx,
if not is_dynamic_flag:
priors = priors.data
priors = priors.expand(batch_size, -1, priors.size(-1))
# Get top-k predictions
from mmdet.core.export import get_k_for_topk
size = torch.tensor(bbox_pred.shape[1], device=device)
nms_pre = get_k_for_topk(nms_pre_tensor, size)

if nms_pre > 0:
if pre_topk > 0:
if with_score_factors:
nms_pre_score = (nms_pre_score * score_factors[..., None])
else:
nms_pre_score = nms_pre_score
if backend == Backend.TENSORRT:
priors = pad_with_value(priors, 1, nms_pre_tensor)
bbox_pred = pad_with_value(bbox_pred, 1, nms_pre_tensor)
nms_pre_score = pad_with_value(nms_pre_score, 1,
nms_pre_tensor, 0.)
priors = pad_with_value(priors, 1, pre_topk)
bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
scores = pad_with_value(scores, 1, pre_topk, 0.)
nms_pre_score = pad_with_value(nms_pre_score, 1, pre_topk, 0.)
if with_score_factors:
score_factors = pad_with_value(
score_factors.unsqueeze(2), 1, pre_topk, 0.)
else:
score_factors = score_factors.unsqueeze(2)
# Get maximum scores for foreground classes.
if self.use_sigmoid_cls:
max_scores, _ = nms_pre_score.max(-1)
else:
# remind that we set FG labels to [0, num_class-1]
# since mmdet v2.0
# BG cat_id: num_class
max_scores, _ = nms_pre_score[..., :-1].max(-1)
_, topk_inds = max_scores.topk(nms_pre)

_, topk_inds = max_scores.topk(pre_topk)
batch_inds = torch.arange(
batch_size,
device=bbox_pred.device).view(-1,
1).expand_as(topk_inds).long()
device=bbox_pred.device).view(-1, 1).expand_as(topk_inds)
priors = priors[batch_inds, topk_inds, :]
bbox_pred = bbox_pred[batch_inds, topk_inds, :]
scores = scores[batch_inds, topk_inds, :]

if with_score_factors:
score_factors = score_factors.unsqueeze(2)[batch_inds,
topk_inds, :]
score_factors = score_factors[batch_inds, topk_inds, :]

bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape)

mlvl_batch_bboxes.append(bboxes)
mlvl_scores.append(scores)
mlvl_valid_bboxes.append(bbox_pred)
mlvl_valid_scores.append(scores)
mlvl_valid_priors.append(priors)
if with_score_factors:
mlvl_score_factors.append(score_factors)

batch_bboxes = torch.cat(mlvl_batch_bboxes, dim=1)
batch_scores = torch.cat(mlvl_scores, dim=1)
batch_mlvl_bboxes_pred = torch.cat(mlvl_valid_bboxes, dim=1)
batch_scores = torch.cat(mlvl_valid_scores, dim=1)
batch_priors = torch.cat(mlvl_valid_priors, dim=1)
batch_bboxes = self.bbox_coder.decode(
batch_priors, batch_mlvl_bboxes_pred, max_shape=img_shape)
if with_score_factors:
batch_score_factors = torch.cat(mlvl_score_factors, dim=1)

# Replace multiclass_nms with ONNX::NonMaxSuppression in deployment

if not self.use_sigmoid_cls:
batch_scores = batch_scores[..., :self.num_classes]

Expand Down

0 comments on commit cde9abd

Please sign in to comment.