diff --git a/configs/maskformer/maskformer_r50_mstrain_16x1_75e_coco.py b/configs/maskformer/maskformer_r50_mstrain_16x1_75e_coco.py index c9d92450570..46b3c135dd8 100644 --- a/configs/maskformer/maskformer_r50_mstrain_16x1_75e_coco.py +++ b/configs/maskformer/maskformer_r50_mstrain_16x1_75e_coco.py @@ -1,7 +1,9 @@ _base_ = [ '../_base_/datasets/coco_panoptic.py', '../_base_/default_runtime.py' ] - +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes model = dict( type='MaskFormer', backbone=dict( @@ -19,8 +21,8 @@ in_channels=[256, 512, 1024, 2048], # pass to pixel_decoder inside feat_channels=256, out_channels=256, - num_things_classes=80, - num_stuff_classes=53, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, num_queries=100, pixel_decoder=dict( type='TransformerEncoderPixelDecoder', @@ -87,11 +89,10 @@ init_cfg=None), loss_cls=dict( type='CrossEntropyLoss', - bg_cls_weight=0.1, use_sigmoid=False, loss_weight=1.0, reduction='mean', - class_weight=1.0), + class_weight=[1.0] * num_classes + [0.1]), loss_mask=dict( type='FocalLoss', use_sigmoid=True, @@ -107,6 +108,12 @@ naive_dice=True, eps=1.0, loss_weight=1.0)), + panoptic_fusion_head=dict( + type='MaskFormerFusionHead', + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None), train_cfg=dict( assigner=dict( type='MaskHungarianAssigner', @@ -116,8 +123,19 @@ dice_cost=dict( type='DiceCost', weight=1.0, pred_act=True, eps=1.0)), sampler=dict(type='MaskPseudoSampler')), - test_cfg=dict(object_mask_thr=0.8, iou_thr=0.8), - # pretrained=None, + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=False, + # max_per_image is for instance segmentation. + max_per_image=100, + object_mask_thr=0.8, + iou_thr=0.8, + # In MaskFormer's panoptic postprocessing, + # it will not filter masks whose score is smaller than 0.5 . + filter_low_score=False), init_cfg=None) # dataset settings diff --git a/mmdet/core/mask/__init__.py b/mmdet/core/mask/__init__.py index 2083af20251..644a9b1d9b4 100644 --- a/mmdet/core/mask/__init__.py +++ b/mmdet/core/mask/__init__.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. from .mask_target import mask_target from .structures import BaseInstanceMasks, BitmapMasks, PolygonMasks -from .utils import encode_mask_results, split_combined_polys +from .utils import encode_mask_results, mask2bbox, split_combined_polys __all__ = [ 'split_combined_polys', 'mask_target', 'BaseInstanceMasks', 'BitmapMasks', - 'PolygonMasks', 'encode_mask_results' + 'PolygonMasks', 'encode_mask_results', 'mask2bbox' ] diff --git a/mmdet/core/mask/utils.py b/mmdet/core/mask/utils.py index 8e95f72b364..90544b34f49 100644 --- a/mmdet/core/mask/utils.py +++ b/mmdet/core/mask/utils.py @@ -2,6 +2,7 @@ import mmcv import numpy as np import pycocotools.mask as mask_util +import torch def split_combined_polys(polys, poly_lens, polys_per_mask): @@ -62,3 +63,27 @@ def encode_mask_results(mask_results): return encoded_mask_results, cls_mask_scores else: return encoded_mask_results + + +def mask2bbox(masks): + """Obtain tight bounding boxes of binary masks. + + Args: + masks (Tensor): Binary mask of shape (n, h, w). + + Returns: + Tensor: Bboxe with shape (n, 4) of \ + positive region in binary mask. + """ + N = masks.shape[0] + bboxes = masks.new_zeros((N, 4), dtype=torch.float32) + x_any = torch.any(masks, dim=1) + y_any = torch.any(masks, dim=2) + for i in range(N): + x = torch.where(x_any[i, :])[0] + y = torch.where(y_any[i, :])[0] + if len(x) > 0 and len(y) > 0: + bboxes[i, :] = bboxes.new_tensor( + [x[0], y[0], x[-1] + 1, y[-1] + 1]) + + return bboxes diff --git a/mmdet/models/dense_heads/maskformer_head.py b/mmdet/models/dense_heads/maskformer_head.py index 7d7a644c7e1..cb26d8a9096 100644 --- a/mmdet/models/dense_heads/maskformer_head.py +++ b/mmdet/models/dense_heads/maskformer_head.py @@ -8,7 +8,6 @@ from mmcv.runner import force_fp32 from mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean -from mmdet.core.evaluation import INSTANCE_OFFSET from mmdet.models.utils import preprocess_panoptic_gt from ..builder import HEADS, build_loss from .anchor_free_head import AnchorFreeHead @@ -64,10 +63,9 @@ def __init__(self, positional_encoding=None, loss_cls=dict( type='CrossEntropyLoss', - bg_cls_weight=0.1, use_sigmoid=False, loss_weight=1.0, - class_weight=1.0), + class_weight=[1.0] * 133 + [0.1]), loss_mask=dict( type='FocalLoss', use_sigmoid=True, @@ -125,25 +123,7 @@ def __init__(self, sampler_cfg = dict(type='MaskPseudoSampler') self.sampler = build_sampler(sampler_cfg, context=self) - self.bg_cls_weight = 0 - class_weight = loss_cls.get('class_weight', None) - if class_weight is not None and (self.__class__ is MaskFormerHead): - assert isinstance(class_weight, float), 'Expected ' \ - 'class_weight to have type float. Found ' \ - f'{type(class_weight)}.' - # NOTE following the official MaskFormerHead repo, bg_cls_weight - # means relative classification weight of the VOID class. - bg_cls_weight = loss_cls.get('bg_cls_weight', class_weight) - assert isinstance(bg_cls_weight, float), 'Expected ' \ - 'bg_cls_weight to have type float. Found ' \ - f'{type(bg_cls_weight)}.' - class_weight = torch.ones(self.num_classes + 1) * class_weight - # set VOID class as the last indice - class_weight[self.num_classes] = bg_cls_weight - loss_cls.update({'class_weight': class_weight}) - if 'bg_cls_weight' in loss_cls: - loss_cls.pop('bg_cls_weight') - self.bg_cls_weight = bg_cls_weight + self.class_weight = loss_cls.class_weight self.loss_cls = build_loss(loss_cls) self.loss_mask = build_loss(loss_mask) self.loss_dice = build_loss(loss_dice) @@ -304,7 +284,8 @@ def loss(self, all_cls_scores, all_mask_preds, gt_labels_list, Args: all_cls_scores (Tensor): Classification scores for all decoder layers with shape (num_decoder, batch_size, num_queries, - cls_out_channels). + cls_out_channels). Note `cls_out_channels` should includes + background. all_mask_preds (Tensor): Mask scores for all decoder layers with shape (num_decoder, batch_size, num_queries, h, w). gt_labels_list (list[Tensor]): Ground truth class indices for each @@ -347,7 +328,8 @@ def loss_single(self, cls_scores, mask_preds, gt_labels_list, Args: cls_scores (Tensor): Mask score logits from a single decoder layer for all images. Shape (batch_size, num_queries, - cls_out_channels). + cls_out_channels). Note `cls_out_channels` should includes + background. mask_preds (Tensor): Mask logits for a pixel decoder for all images. Shape (batch_size, num_queries, h, w). gt_labels_list (list[Tensor]): Ground truth class indices for each @@ -385,8 +367,7 @@ def loss_single(self, cls_scores, mask_preds, gt_labels_list, labels = labels.flatten(0, 1) label_weights = label_weights.flatten(0, 1) - class_weight = cls_scores.new_ones(self.num_classes + 1) - class_weight[-1] = self.bg_cls_weight + class_weight = cls_scores.new_tensor(self.class_weight) loss_cls = self.loss_cls( cls_scores, labels, @@ -544,30 +525,22 @@ def forward_train(self, return losses - def simple_test(self, feats, img_metas, rescale=False): - """Test segment without test-time aumengtation. - - Only the output of last decoder layers was used. + def simple_test(self, feats, img_metas, **kwargs): + """Test without augmentaton. Args: feats (list[Tensor]): Multi-level features from the upstream network, each is a 4D-tensor. img_metas (list[dict]): List of image information. - rescale (bool, optional): If True, return boxes in - original image space. Default False. Returns: - list[dict[str, np.array]]: semantic segmentation results\ - and panoptic segmentation results for each image. - - .. code-block:: none + tuple: A tuple contains two tensors. - [ - { - 'pan_results': , # shape = [h, w] - }, - ... - ] + - mask_cls_results (Tensor): Mask classification logits,\ + shape (batch_size, num_queries, cls_out_channels). + Note `cls_out_channels` should includes background. + - mask_pred_results (Tensor): Mask logits, shape \ + (batch_size, num_queries, h, w). """ all_cls_scores, all_mask_preds = self(feats, img_metas) mask_cls_results = all_cls_scores[-1] @@ -581,84 +554,4 @@ def simple_test(self, feats, img_metas, rescale=False): mode='bilinear', align_corners=False) - results = [] - for mask_cls_result, mask_pred_result, meta in zip( - mask_cls_results, mask_pred_results, img_metas): - # remove padding - img_height, img_width = meta['img_shape'][:2] - mask_pred_result = mask_pred_result[:, :img_height, :img_width] - - if rescale: - # return result in original resolution - ori_height, ori_width = meta['ori_shape'][:2] - mask_pred_result = F.interpolate(mask_pred_result.unsqueeze(1), - size=(ori_height, ori_width), - mode='bilinear', - align_corners=False)\ - .squeeze(1) - - mask = self.post_process(mask_cls_result, mask_pred_result) - results.append(mask) - - return results - - def post_process(self, mask_cls, mask_pred): - """Panoptic segmengation inference. - - This implementation is modified from `MaskFormer - `_. - - Args: - mask_cls (Tensor): Classfication outputs for a image. - shape = (num_queries, cls_out_channels). - mask_pred (Tensor): Mask outputs for a image. - shape = (num_queries, h, w). - - Returns: - Tensor: panoptic segment result of shape (h, w),\ - each element in Tensor means: - segment_id = _cls + instance_id * INSTANCE_OFFSET. - """ - object_mask_thr = self.test_cfg.get('object_mask_thr', 0.8) - iou_thr = self.test_cfg.get('iou_thr', 0.8) - - scores, labels = F.softmax(mask_cls, dim=-1).max(-1) - mask_pred = mask_pred.sigmoid() - - keep = labels.ne(self.num_classes) & (scores > object_mask_thr) - cur_scores = scores[keep] - cur_classes = labels[keep] - cur_masks = mask_pred[keep] - - cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks - - h, w = cur_masks.shape[-2:] - panoptic_seg = torch.full((h, w), - self.num_classes, - dtype=torch.int32, - device=cur_masks.device) - if cur_masks.shape[0] == 0: - # We didn't detect any mask :( - pass - else: - cur_mask_ids = cur_prob_masks.argmax(0) - instance_id = 1 - for k in range(cur_classes.shape[0]): - pred_class = int(cur_classes[k].item()) - isthing = pred_class < self.num_things_classes - mask = cur_mask_ids == k - mask_area = mask.sum().item() - original_area = (cur_masks[k] >= 0.5).sum().item() - if mask_area > 0 and original_area > 0: - if mask_area / original_area < iou_thr: - continue - - if not isthing: - # different stuff regions of same class will be - # merged here, and stuff share the instance_id 0. - panoptic_seg[mask] = pred_class - else: - panoptic_seg[mask] = ( - pred_class + instance_id * INSTANCE_OFFSET) - instance_id += 1 - return panoptic_seg + return mask_cls_results, mask_pred_results diff --git a/mmdet/models/detectors/maskformer.py b/mmdet/models/detectors/maskformer.py index f7257d2547d..b626e070813 100644 --- a/mmdet/models/detectors/maskformer.py +++ b/mmdet/models/detectors/maskformer.py @@ -2,7 +2,7 @@ import mmcv import numpy as np -from mmdet.core import INSTANCE_OFFSET +from mmdet.core import INSTANCE_OFFSET, bbox2result from mmdet.core.visualization import imshow_det_bboxes from ..builder import DETECTORS, build_backbone, build_head, build_neck from .single_stage import SingleStageDetector @@ -18,6 +18,7 @@ def __init__(self, backbone, neck=None, panoptic_head=None, + panoptic_fusion_head=None, train_cfg=None, test_cfg=None, init_cfg=None): @@ -25,9 +26,15 @@ def __init__(self, self.backbone = build_backbone(backbone) if neck is not None: self.neck = build_neck(neck) - panoptic_head.update(train_cfg=train_cfg) - panoptic_head.update(test_cfg=test_cfg) - self.panoptic_head = build_head(panoptic_head) + + panoptic_head_ = panoptic_head.deepcopy() + panoptic_head_.update(train_cfg=train_cfg) + panoptic_head_.update(test_cfg=test_cfg) + self.panoptic_head = build_head(panoptic_head_) + + panoptic_fusion_head_ = panoptic_fusion_head.deepcopy() + panoptic_fusion_head_.update(test_cfg=test_cfg) + self.panoptic_fusion_head = build_head(panoptic_fusion_head_) self.num_things_classes = self.panoptic_head.num_things_classes self.num_stuff_classes = self.panoptic_head.num_stuff_classes @@ -96,16 +103,53 @@ def forward_train(self, return losses - def simple_test(self, img, img_metas, **kwargs): - """Test without augmentation.""" - feat = self.extract_feat(img) - mask_results = self.panoptic_head.simple_test(feat, img_metas, - **kwargs) + def simple_test(self, imgs, img_metas, **kwargs): + """Test without augmentation. + + Args: + imgs (Tensor): A batch of images. + img_metas (list[dict]): List of image information. - results = [] - for mask in mask_results: - result = {'pan_results': mask.detach().cpu().numpy()} - results.append(result) + Returns: + list[dict[str, np.array | tuple]]: Semantic segmentation \ + results and panoptic segmentation results for each \ + image. + + .. code-block:: none + + [ + { + 'pan_results': np.array, # shape = [h, w] + 'ins_results': tuple[list], + # semantic segmentation results are not supported yet + 'sem_results': np.array + }, + ... + ] + """ + feats = self.extract_feat(imgs) + mask_cls_results, mask_pred_results = self.panoptic_head.simple_test( + feats, img_metas, **kwargs) + results = self.panoptic_fusion_head.simple_test( + mask_cls_results, mask_pred_results, img_metas, **kwargs) + for i in range(len(results)): + if 'pan_results' in results[i]: + results[i]['pan_results'] = results[i]['pan_results'].detach( + ).cpu().numpy() + + if 'ins_results' in results[i]: + labels_per_image, bboxes, mask_pred_binary = results[i][ + 'ins_results'] + bbox_results = bbox2result(bboxes, labels_per_image, + self.num_things_classes) + mask_results = [[] for _ in range(self.num_things_classes)] + for j, label in enumerate(labels_per_image): + mask = mask_pred_binary[j].detach().cpu().numpy() + mask_results[label].append(mask) + results[i]['ins_results'] = bbox_results, mask_results + + assert 'sem_results' not in results[i], 'segmantic segmentation '\ + 'results are not supported yet.' return results diff --git a/mmdet/models/seg_heads/panoptic_fusion_heads/__init__.py b/mmdet/models/seg_heads/panoptic_fusion_heads/__init__.py index d14a33c317a..41625a61d6d 100644 --- a/mmdet/models/seg_heads/panoptic_fusion_heads/__init__.py +++ b/mmdet/models/seg_heads/panoptic_fusion_heads/__init__.py @@ -2,3 +2,4 @@ from .base_panoptic_fusion_head import \ BasePanopticFusionHead # noqa: F401,F403 from .heuristic_fusion_head import HeuristicFusionHead # noqa: F401,F403 +from .maskformer_fusion_head import MaskFormerFusionHead # noqa: F401,F403 diff --git a/mmdet/models/seg_heads/panoptic_fusion_heads/maskformer_fusion_head.py b/mmdet/models/seg_heads/panoptic_fusion_heads/maskformer_fusion_head.py new file mode 100644 index 00000000000..5b59ce4deae --- /dev/null +++ b/mmdet/models/seg_heads/panoptic_fusion_heads/maskformer_fusion_head.py @@ -0,0 +1,241 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn.functional as F + +from mmdet.core.evaluation.panoptic_utils import INSTANCE_OFFSET +from mmdet.core.mask import mask2bbox +from mmdet.models.builder import HEADS +from .base_panoptic_fusion_head import BasePanopticFusionHead + + +@HEADS.register_module() +class MaskFormerFusionHead(BasePanopticFusionHead): + + def __init__(self, + num_things_classes=80, + num_stuff_classes=53, + test_cfg=None, + loss_panoptic=None, + init_cfg=None, + **kwargs): + super().__init__(num_things_classes, num_stuff_classes, test_cfg, + loss_panoptic, init_cfg, **kwargs) + + def forward_train(self, **kwargs): + """MaskFormerFusionHead has no training loss.""" + return dict() + + def panoptic_postprocess(self, mask_cls, mask_pred): + """Panoptic segmengation inference. + + Args: + mask_cls (Tensor): Classfication outputs of shape + (num_queries, cls_out_channels) for a image. + Note `cls_out_channels` should includes + background. + mask_pred (Tensor): Mask outputs of shape + (num_queries, h, w) for a image. + + Returns: + Tensor: Panoptic segment result of shape \ + (h, w), each element in Tensor means: \ + ``segment_id = _cls + instance_id * INSTANCE_OFFSET``. + """ + object_mask_thr = self.test_cfg.get('object_mask_thr', 0.8) + iou_thr = self.test_cfg.get('iou_thr', 0.8) + filter_low_score = self.test_cfg.get('filter_low_score', False) + + scores, labels = F.softmax(mask_cls, dim=-1).max(-1) + mask_pred = mask_pred.sigmoid() + + keep = labels.ne(self.num_classes) & (scores > object_mask_thr) + cur_scores = scores[keep] + cur_classes = labels[keep] + cur_masks = mask_pred[keep] + + cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks + + h, w = cur_masks.shape[-2:] + panoptic_seg = torch.full((h, w), + self.num_classes, + dtype=torch.int32, + device=cur_masks.device) + if cur_masks.shape[0] == 0: + # We didn't detect any mask :( + pass + else: + cur_mask_ids = cur_prob_masks.argmax(0) + instance_id = 1 + for k in range(cur_classes.shape[0]): + pred_class = int(cur_classes[k].item()) + isthing = pred_class < self.num_things_classes + mask = cur_mask_ids == k + mask_area = mask.sum().item() + original_area = (cur_masks[k] >= 0.5).sum().item() + + if filter_low_score: + mask = mask & (cur_masks[k] >= 0.5) + + if mask_area > 0 and original_area > 0: + if mask_area / original_area < iou_thr: + continue + + if not isthing: + # different stuff regions of same class will be + # merged here, and stuff share the instance_id 0. + panoptic_seg[mask] = pred_class + else: + panoptic_seg[mask] = ( + pred_class + instance_id * INSTANCE_OFFSET) + instance_id += 1 + + return panoptic_seg + + def semantic_postprocess(self, mask_cls, mask_pred): + """Semantic segmengation postprocess. + + Args: + mask_cls (Tensor): Classfication outputs of shape + (num_queries, cls_out_channels) for a image. + Note `cls_out_channels` should includes + background. + mask_pred (Tensor): Mask outputs of shape + (num_queries, h, w) for a image. + + Returns: + Tensor: Semantic segment result of shape \ + (cls_out_channels, h, w). + """ + # TODO add semantic segmentation result + raise NotImplementedError + + def instance_postprocess(self, mask_cls, mask_pred): + """Instance segmengation postprocess. + + Args: + mask_cls (Tensor): Classfication outputs of shape + (num_queries, cls_out_channels) for a image. + Note `cls_out_channels` should includes + background. + mask_pred (Tensor): Mask outputs of shape + (num_queries, h, w) for a image. + + Returns: + tuple[Tensor]: Instance segmentation results. + + - labels_per_image (Tensor): Predicted labels,\ + shape (n, ). + - bboxes (Tensor): Bboxes and scores with shape (n, 5) of \ + positive region in binary mask, the last column is scores. + - mask_pred_binary (Tensor): Instance masks of \ + shape (n, h, w). + """ + max_per_image = self.test_cfg.get('max_per_image', 100) + num_queries = mask_cls.shape[0] + # shape (num_queries, num_class) + scores = F.softmax(mask_cls, dim=-1)[:, :-1] + # shape (num_queries * num_class, ) + labels = torch.arange(self.num_classes, device=mask_cls.device).\ + unsqueeze(0).repeat(num_queries, 1).flatten(0, 1) + scores_per_image, top_indices = scores.flatten(0, 1).topk( + max_per_image, sorted=False) + labels_per_image = labels[top_indices] + + query_indices = top_indices // self.num_classes + mask_pred = mask_pred[query_indices] + + # extract things + is_thing = labels_per_image < self.num_things_classes + scores_per_image = scores_per_image[is_thing] + labels_per_image = labels_per_image[is_thing] + mask_pred = mask_pred[is_thing] + + mask_pred_binary = (mask_pred > 0).float() + mask_scores_per_image = (mask_pred.sigmoid() * + mask_pred_binary).flatten(1).sum(1) / ( + mask_pred_binary.flatten(1).sum(1) + 1e-6) + det_scores = scores_per_image * mask_scores_per_image + mask_pred_binary = mask_pred_binary.bool() + bboxes = mask2bbox(mask_pred_binary) + bboxes = torch.cat([bboxes, det_scores[:, None]], dim=-1) + + return labels_per_image, bboxes, mask_pred_binary + + def simple_test(self, + mask_cls_results, + mask_pred_results, + img_metas, + rescale=False, + **kwargs): + """Test segment without test-time aumengtation. + + Only the output of last decoder layers was used. + + Args: + mask_cls_results (Tensor): Mask classification logits, + shape (batch_size, num_queries, cls_out_channels). + Note `cls_out_channels` should includes background. + mask_pred_results (Tensor): Mask logits, shape + (batch_size, num_queries, h, w). + img_metas (list[dict]): List of image information. + rescale (bool, optional): If True, return boxes in + original image space. Default False. + + Returns: + list[dict[str, Tensor | tuple[Tensor]]]: Semantic segmentation \ + results and panoptic segmentation results for each \ + image. + + .. code-block:: none + + [ + { + 'pan_results': Tensor, # shape = [h, w] + 'ins_results': tuple[Tensor], + # semantic segmentation results are not supported yet + 'sem_results': Tensor + }, + ... + ] + """ + panoptic_on = self.test_cfg.get('panoptic_on', True) + semantic_on = self.test_cfg.get('semantic_on', False) + instance_on = self.test_cfg.get('instance_on', False) + assert not semantic_on, 'segmantic segmentation '\ + 'results are not supported yet.' + + results = [] + for mask_cls_result, mask_pred_result, meta in zip( + mask_cls_results, mask_pred_results, img_metas): + # remove padding + img_height, img_width = meta['img_shape'][:2] + mask_pred_result = mask_pred_result[:, :img_height, :img_width] + + if rescale: + # return result in original resolution + ori_height, ori_width = meta['ori_shape'][:2] + mask_pred_result = F.interpolate( + mask_pred_result[:, None], + size=(ori_height, ori_width), + mode='bilinear', + align_corners=False)[:, 0] + + result = dict() + if panoptic_on: + pan_results = self.panoptic_postprocess( + mask_cls_result, mask_pred_result) + result['pan_results'] = pan_results + + if instance_on: + ins_results = self.instance_postprocess( + mask_cls_result, mask_pred_result) + result['ins_results'] = ins_results + + if semantic_on: + sem_results = self.semantic_postprocess( + mask_cls_result, mask_pred_result) + result['sem_results'] = sem_results + + results.append(result) + + return results diff --git a/tests/test_models/test_dense_heads/test_maskformer_head.py b/tests/test_models/test_dense_heads/test_maskformer_head.py index e70f09afe3f..f9cf3b2326f 100644 --- a/tests/test_models/test_dense_heads/test_maskformer_head.py +++ b/tests/test_models/test_dense_heads/test_maskformer_head.py @@ -23,15 +23,17 @@ def test_maskformer_head_loss(): torch.rand((2, 64 * 2**i, 4 * 2**(3 - i), 5 * 2**(3 - i))) for i in range(4) ] - + num_things_classes = 80 + num_stuff_classes = 53 + num_classes = num_things_classes + num_stuff_classes config = ConfigDict( dict( type='MaskFormerHead', in_channels=[base_channels * 2**i for i in range(4)], feat_channels=base_channels, out_channels=base_channels, - num_things_classes=80, - num_stuff_classes=53, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, num_queries=100, pixel_decoder=dict( type='TransformerEncoderPixelDecoder', @@ -102,11 +104,10 @@ def test_maskformer_head_loss(): init_cfg=None), loss_cls=dict( type='CrossEntropyLoss', - bg_cls_weight=0.1, use_sigmoid=False, loss_weight=1.0, reduction='mean', - class_weight=1.0), + class_weight=[1.0] * num_classes + [0.1]), loss_mask=dict( type='FocalLoss', use_sigmoid=True, diff --git a/tests/test_models/test_seg_heads/test_maskformer_fusion_head.py b/tests/test_models/test_seg_heads/test_maskformer_fusion_head.py new file mode 100644 index 00000000000..8d5131f9a60 --- /dev/null +++ b/tests/test_models/test_seg_heads/test_maskformer_fusion_head.py @@ -0,0 +1,53 @@ +import pytest +import torch +from mmcv import ConfigDict + +from mmdet.models.seg_heads.panoptic_fusion_heads import MaskFormerFusionHead + + +def test_maskformer_fusion_head(): + img_metas = [ + { + 'batch_input_shape': (128, 160), + 'img_shape': (126, 160, 3), + 'ori_shape': (63, 80, 3), + 'pad_shape': (128, 160, 3) + }, + ] + num_things_classes = 80 + num_stuff_classes = 53 + num_classes = num_things_classes + num_stuff_classes + config = ConfigDict( + type='MaskFormerFusionHead', + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + test_cfg=dict( + panoptic_on=True, + semantic_on=False, + instance_on=True, + max_per_image=100, + object_mask_thr=0.8, + iou_thr=0.8, + filter_low_score=False), + init_cfg=None) + + self = MaskFormerFusionHead(**config) + + # test forward_train + assert self.forward_train() == dict() + + mask_cls_results = torch.rand((1, 100, num_classes + 1)) + mask_pred_results = torch.rand((1, 100, 128, 160)) + + # test panoptic_postprocess and instance_postprocess + results = self.simple_test(mask_cls_results, mask_pred_results, img_metas) + assert 'ins_results' in results[0] and 'pan_results' in results[0] + + # test semantic_postprocess + config.test_cfg.semantic_on = True + with pytest.raises(AssertionError): + self.simple_test(mask_cls_results, mask_pred_results, img_metas) + + with pytest.raises(NotImplementedError): + self.semantic_postprocess(mask_cls_results, mask_pred_results) diff --git a/tests/test_utils/test_masks.py b/tests/test_utils/test_masks.py index 7061046a377..3e08cb4e9c0 100644 --- a/tests/test_utils/test_masks.py +++ b/tests/test_utils/test_masks.py @@ -3,7 +3,7 @@ import pytest import torch -from mmdet.core import BitmapMasks, PolygonMasks +from mmdet.core import BitmapMasks, PolygonMasks, mask2bbox def dummy_raw_bitmap_masks(size): @@ -687,3 +687,27 @@ def test_polygon_mask_iter(): polygon_masks = PolygonMasks(raw_masks, 28, 28) for i, polygon_mask in enumerate(polygon_masks): assert np.equal(polygon_mask, raw_masks[i]).all() + + +def test_mask2bbox(): + # no instance + masks = torch.zeros((1, 20, 15), dtype=torch.bool).float() + bboxes_empty_gt = torch.tensor([[0, 0, 0, 0]]) + bboxes = mask2bbox(masks) + assert torch.allclose(bboxes_empty_gt.float(), bboxes) + + # the entire mask is an instance + bboxes_full_gt = torch.tensor([[0, 0, 15, 20]]).float() + masks = torch.ones((1, 20, 15), dtype=torch.bool) + bboxes = mask2bbox(masks) + assert torch.allclose(bboxes_full_gt, bboxes) + + # a pentagon-shaped instance + bboxes_gt = torch.tensor([[2, 2, 7, 6]]).float() + masks = torch.zeros((1, 20, 15), dtype=torch.bool) + masks[0, 2, 4] = True + masks[0, 3, 3:6] = True + masks[0, 4, 2:7] = True + masks[0, 5, 2:7] = True + bboxes = mask2bbox(masks) + assert torch.allclose(bboxes_gt, bboxes)