From d580a59f91d2e26b3f7bbf64b0bef0842f160897 Mon Sep 17 00:00:00 2001 From: Lrh Date: Thu, 28 Sep 2023 18:18:01 +0800 Subject: [PATCH] support batch inference --- .../mmdet/models/dense_heads/condinst_head.py | 47 +++++-------------- 1 file changed, 13 insertions(+), 34 deletions(-) diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py index fc2e00e22c..c2ef0423e8 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py @@ -3,7 +3,6 @@ import torch from mmdet.models.utils import aligned_bilinear -from mmdet.utils import InstanceList from mmengine.config import ConfigDict from torch import Tensor @@ -117,9 +116,6 @@ def condinst_mask_head__forward(self, x: tuple, param_preds = positive_infos['param_preds'] points = positive_infos['points'] strides = positive_infos['strides'] - param_preds = torch.stack(param_preds, dim=0) - points = torch.stack(points, dim=0) - strides = torch.stack(strides, dim=0) batch_size = points.shape[0] num_insts = points.shape[1] @@ -135,52 +131,35 @@ def condinst_mask_head__forward(self, x: tuple, rel_coordinates = (centers - locations).permute(0, 1, 3, 2).float() rel_coordinates /= (strides[:, :, None, None] * self.size_of_interest) rel_coords = rel_coordinates.reshape(batch_size, -1, 2, hw[0], hw[1]) - mask_head_inputs = torch.cat([rel_coords, mask_feats], dim=1) - mask_head_inputs = mask_head_inputs.reshape(batch_size, -1, hw[0], hw[1]) + mask_head_inputs = torch.cat([rel_coords, mask_feats], dim=2) # TODO: change following code to support batch inference weights, biases = _parse_dynamic_params(self, param_preds) - mask_preds = _dynamic_conv_forward(mask_feats, weights, biases) + mask_preds = _dynamic_conv_forward(mask_head_inputs, weights, biases) mask_preds = mask_preds.reshape(batch_size, num_insts, hw[0], hw[1]) - mask_preds = [ - aligned_bilinear( - mask_preds[i].unsqueeze(0), - int(self.mask_feat_stride / self.mask_out_stride), - ).squeeze(0) for i in range(batch_size) - ] - + mask_preds = aligned_bilinear(mask_preds, + int(self.mask_feat_stride / self.mask_out_stride)) return (mask_preds, ) @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.dense_heads.CondInstMaskHead.predict_by_feat') def condinst_mask_head__predict_by_feat(self, - mask_preds: List[Tensor], - results_list: InstanceList, + mask_preds: Tensor, + results_list: Dict[str, torch.Tensor], batch_img_metas: List[dict], rescale: bool = True, **kwargs): - assert len(mask_preds) == len(results_list) == len(batch_img_metas) cfg = self.test_cfg - dets = [results.dets.unsqueeze(0) for results in results_list] - labels = [results.labels.unsqueeze(0) for results in results_list] - img_hw = [img_meta['img_shape'][:2] for img_meta in batch_img_metas] - - mask_preds = [ - mask_preds[i].sigmoid().unsqueeze(0) for i in range(len(mask_preds)) - ] - mask_preds = [ - aligned_bilinear(mask_preds[i], self.mask_out_stride) - for i in range(len(mask_preds)) - ] - mask_preds = [ - mask_preds[i][:, :, :img_hw[i][0], :img_hw[i][1]] - for i in range(len(mask_preds)) - ] + dets = results_list['dets'] + labels = results_list['labels'] + img_hw = batch_img_metas[0]['img_shape'][:2] - masks = [mask_preds[i] > cfg.mask_thr for i in range(len(mask_preds))] - masks = [masks[i].float() for i in range(len(masks))] + mask_preds = mask_preds.sigmoid() + mask_preds = aligned_bilinear(mask_preds, self.mask_out_stride) + mask_preds = mask_preds[:, :, :img_hw[0], :img_hw[1]] + masks = (mask_preds > cfg.mask_thr).float() return dets, labels, masks