Skip to content
This repository has been archived by the owner on Apr 17, 2023. It is now read-only.

Commit

Permalink
Integrate C-IL for instance segmentation (#19)
Browse files Browse the repository at this point in the history
* Update CHANGELOG.md

Signed-off-by: Songki Choi <[email protected]>

* Add instance-segmentation features EfficientnetB2B_maskrcnn

* add cross_focal_loss

* add resnet50 ins-seg

* add grad_clip for instance seg

Co-authored-by: Songki Choi <[email protected]>
  • Loading branch information
harimkang and goodsong81 authored Jun 29, 2022
1 parent 21d3a6c commit c2e826f
Show file tree
Hide file tree
Showing 16 changed files with 755 additions and 50 deletions.
1 change: 1 addition & 0 deletions mpa/det/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,6 @@
import mpa.modules.models.heads.custom_retina_head
import mpa.modules.models.heads.custom_ssd_head
import mpa.modules.models.heads.custom_vfnet_head
import mpa.modules.models.heads.custom_roi_head
import mpa.modules.models.losses.cross_focal_loss
import mpa.modules.models.losses.l2sp_loss
65 changes: 45 additions & 20 deletions mpa/det/inferrer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
# SPDX-License-Identifier: Apache-2.0
#

# import numpy as np
# import os.path as osp
from mmcv.parallel import MMDataParallel
import torch
from mmcv.parallel import MMDataParallel, is_module_wrapper
from mmcv.runner import load_checkpoint, wrap_fp16_model

from mmdet.apis import single_gpu_test
from mmdet.datasets import build_dataloader, build_dataset, replace_ImageToTensor
from mmdet.models import build_detector
from mmdet.parallel import MMDataCPU

from mpa.registry import STAGES
from .stage import DetectionStage
Expand All @@ -32,6 +31,7 @@ def run(self, model_cfg, model_ckpt, data_cfg, **kwargs):
"""
self._init_logger()
mode = kwargs.get('mode', 'train')
eval = kwargs.get('eval', False)
if mode not in self.mode:
return {}

Expand All @@ -42,7 +42,7 @@ def run(self, model_cfg, model_ckpt, data_cfg, **kwargs):

# mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))

outputs = self.infer(cfg)
outputs = self.infer(cfg, eval=eval)

# Save outputs
# output_file_path = osp.join(cfg.work_dir, 'infer_result.npy')
Expand All @@ -65,7 +65,7 @@ def default(self, obj):
print(json_dump)
"""

def infer(self, cfg):
def infer(self, cfg, dump_features=False, eval=False):
samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
if samples_per_gpu > 1:
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
Expand Down Expand Up @@ -128,27 +128,52 @@ def infer(self, cfg):
if cfg.get('load_from', None):
load_checkpoint(model, cfg.load_from, map_location='cpu')

# Inference
model = MMDataParallel(model, device_ids=[0])
if torch.cuda.is_available():
eval_model = MMDataParallel(model.cuda(cfg.gpu_ids[0]),
device_ids=cfg.gpu_ids)
else:
eval_model = MMDataCPU(model)

# InferenceProgressCallback (Time Monitor enable into Infer task)
DetectionStage.set_inference_progress_callback(model, cfg)

detections = single_gpu_test(model, data_loader)

eval_cfg = cfg.evaluation.copy()
eval_cfg.pop('interval', None)
eval_cfg.pop('save_best', None)

metric = dataset.evaluate(
detections,
logger='silent',
**eval_cfg
)[cfg.evaluation.metric]
# detections = single_gpu_test(model, data_loader)
eval_predictions = []
feature_vectors = []

def dump_features_hook(mod, inp, out):
with torch.no_grad():
feature_map = out[-1]
feature_vector = torch.nn.functional.adaptive_avg_pool2d(feature_map, (1, 1))
assert feature_vector.size(0) == 1
feature_vectors.append(feature_vector.view(-1).detach().cpu().numpy())

def dummy_dump_features_hook(mod, inp, out):
feature_vectors.append(None)

hook = dump_features_hook if dump_features else dummy_dump_features_hook
# Use a single gpu for testing. Set in both mm_val_dataloader and eval_model
if is_module_wrapper(model):
model = model.module
with model.backbone.register_forward_hook(hook):
for data in data_loader:
with torch.no_grad():
result = eval_model(return_loss=False, rescale=True, **data)
eval_predictions.extend(result)

for key in [
'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
'rule', 'dynamic_intervals'
]:
cfg.evaluation.pop(key, None)

metric = None
if eval:
metric = dataset.evaluate(eval_predictions, **cfg.evaluation)[cfg.evaluation.metric]

outputs = dict(
classes=target_classes,
detections=detections,
detections=eval_predictions,
metric=metric,
)
return outputs
57 changes: 30 additions & 27 deletions mpa/det/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,20 +73,6 @@ def configure_model(self, cfg, training, **kwargs):
cfg.model.arch_type = cfg.model.type
cfg.model.type = super_type

if not training:
# BBox head for pseudo label output
if 'roi_head' in cfg.model:
# For Faster-RCNNs
bbox_head_cfg = cfg.model.roi_head.bbox_head
else:
# For other architectures
bbox_head_cfg = cfg.model.bbox_head

if bbox_head_cfg.type in ['Shared2FCBBoxHead', 'PseudoShared2FCBBoxHead']:
bbox_head_cfg.type = 'PseudoShared2FCBBoxHead'
elif bbox_head_cfg.type in ['SSDHead', 'PseudoSSDHead']:
bbox_head_cfg.type = 'PseudoSSDHead'

# OMZ-plugin
if cfg.model.backbone.type == 'OmzBackboneDet':
ir_path = kwargs.get('ir_path')
Expand Down Expand Up @@ -144,7 +130,8 @@ def configure_task(self, cfg, training, **kwargs):
self.configure_task_data_pipeline(cfg, model_classes, data_classes)

# Evaluation dataset
self.configure_task_eval_dataset(cfg, model_classes)
if cfg.get('task', 'detection') == 'detection':
self.configure_task_eval_dataset(cfg, model_classes)

# Training hook for task adaptation
self.configure_task_adapt_hook(cfg, org_model_classes, model_classes)
Expand All @@ -165,8 +152,9 @@ def configure_task_classes(self, cfg, task_adapt_type, task_adapt_op):
# Model classes
if task_adapt_op == 'REPLACE':
if len(data_classes) == 0:
raise ValueError('Data classes should contain at least one class!')
model_classes = data_classes.copy()
model_classes = org_model_classes.copy()
else:
model_classes = data_classes.copy()
elif task_adapt_op == 'MERGE':
model_classes = org_model_classes + [cls for cls in data_classes if cls not in org_model_classes]
else:
Expand All @@ -181,12 +169,22 @@ def configure_task_classes(self, cfg, task_adapt_type, task_adapt_op):
)

# Model architecture
head_names = ('mask_head', 'bbox_head', 'segm_head')
num_classes = len(model_classes)
if 'roi_head' in cfg.model:
# For Faster-RCNNs
cfg.model.roi_head.bbox_head.num_classes = len(model_classes)
for head_name in head_names:
if head_name in cfg.model.roi_head:
if isinstance(cfg.model.roi_head[head_name], list):
for head in cfg.model.roi_head[head_name]:
head.num_classes = num_classes
else:
cfg.model.roi_head[head_name].num_classes = num_classes
else:
# For other architectures (including SSD)
cfg.model.bbox_head.num_classes = len(model_classes)
for head_name in head_names:
if head_name in cfg.model:
cfg.model[head_name].num_classes = num_classes

return org_model_classes, model_classes, data_classes

Expand Down Expand Up @@ -227,31 +225,36 @@ def configure_task_adapt_hook(self, cfg, org_model_classes, model_classes):
update_or_add_custom_hook(cfg, task_adapt_hook)

def configure_task_cls_incr(self, cfg, task_adapt_type, org_model_classes, model_classes):
if task_adapt_type == 'mpa' and cfg.model.bbox_head.type != 'PseudoSSDHead':
if cfg.get('task', 'detection') == 'detection':
bbox_head = cfg.model.bbox_head
else:
bbox_head = cfg.model.roi_head.bbox_head
if task_adapt_type == 'mpa':
tr_data_cfg = self.get_train_data_cfg(cfg)
if tr_data_cfg.type != 'MPADetDataset':
tr_data_cfg.img_ids_dict = self.get_img_ids_for_incr(cfg, org_model_classes, model_classes)
tr_data_cfg.org_type = tr_data_cfg.type
tr_data_cfg.type = 'DetIncrCocoDataset'
alpha, gamma = 0.25, 2.0
if cfg.model.bbox_head.type in ['SSDHead', 'CustomSSDHead']:
if bbox_head.type in ['SSDHead', 'CustomSSDHead']:
gamma = 1 if cfg['task_adapt'].get('efficient_mode', False) else 2
cfg.model.bbox_head.type = 'CustomSSDHead'
cfg.model.bbox_head.loss_cls = ConfigDict(
bbox_head.type = 'CustomSSDHead'
bbox_head.loss_cls = ConfigDict(
type='FocalLoss',
loss_weight=1.0,
gamma=gamma,
reduction='none',
)
elif cfg.model.bbox_head.type in ['ATSSHead']:
elif bbox_head.type in ['ATSSHead']:
gamma = 3 if cfg['task_adapt'].get('efficient_mode', False) else 4.5
cfg.model.bbox_head.loss_cls.gamma = gamma
elif cfg.model.bbox_head.type in ['VFNetHead', 'CustomVFNetHead']:
bbox_head.loss_cls.gamma = gamma
elif bbox_head.type in ['VFNetHead', 'CustomVFNetHead']:
alpha = 0.75
gamma = 1 if cfg['task_adapt'].get('efficient_mode', False) else 2

# Ignore Mode
if cfg.get('ignore', False):
cfg.model.bbox_head.loss_cls = ConfigDict(
bbox_head.loss_cls = ConfigDict(
type='CrossSigmoidFocalLoss',
use_sigmoid=True,
num_classes=len(model_classes),
Expand Down
1 change: 1 addition & 0 deletions mpa/modules/models/detectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
from . import unbiased_teacher
from . import custom_vfnet_detector
from . import custom_atss_detector
from . import custom_maskrcnn_detector
74 changes: 74 additions & 0 deletions mpa/modules/models/detectors/custom_maskrcnn_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

import functools
from mmdet.models.builder import DETECTORS
from mmdet.models.detectors.mask_rcnn import MaskRCNN
from .sam_detector_mixin import SAMDetectorMixin
from .l2sp_detector_mixin import L2SPDetectorMixin
from mpa.modules.utils.task_adapt import map_class_names
from mpa.utils.logger import get_logger

# Module-level logger; used by CustomMaskRCNN.load_state_dict_pre_hook to
# report the class mapping and any parameters it skips.
logger = get_logger()


@DETECTORS.register_module()
class CustomMaskRCNN(SAMDetectorMixin, L2SPDetectorMixin, MaskRCNN):
    """Mask R-CNN variant that can remap class-dependent weights on load.

    When ``task_adapt`` is supplied, a load-state-dict pre-hook is installed
    that rewrites the class-dependent bbox-head entries of an incoming
    checkpoint so that their rows follow this model's class order.
    """

    def __init__(self, *args, task_adapt=None, **kwargs):
        """Build the detector and optionally install the class-remapping hook.

        Args:
            *args: Forwarded unchanged to the parent constructor.
            task_adapt: Optional mapping; when truthy it must provide
                ``'dst_classes'`` (classes of this model) and
                ``'src_classes'`` (classes of the checkpoint).
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(*args, **kwargs)

        if not task_adapt:
            return

        # Bind the model and both class lists now; the remaining arguments
        # (state dict, prefix, ...) are supplied when the hook is invoked.
        self._register_load_state_dict_pre_hook(
            functools.partial(
                self.load_state_dict_pre_hook,
                self,
                task_adapt['dst_classes'],
                task_adapt['src_classes'],
            )
        )

    @staticmethod
    def load_state_dict_pre_hook(model, model_classes, chkpt_classes, chkpt_dict, prefix, *args, **kwargs):
        """Rewrite class-dependent checkpoint entries to match the model's classes.

        For each known class-dependent bbox-head parameter present in both the
        model and the checkpoint, the model's current tensor is cloned, the rows
        of every matched class are overwritten with the checkpoint's rows, and
        the result replaces the checkpoint entry in ``chkpt_dict`` (in place).

        Args:
            model: Module whose ``state_dict()`` supplies the target tensors.
            model_classes: Iterable of class names used by the model.
            chkpt_classes: Iterable of class names used by the checkpoint.
            chkpt_dict: Checkpoint state dict; modified in place.
            prefix: Key prefix prepended to parameter names in ``chkpt_dict``.
            *args: Ignored.
            **kwargs: Ignored.
        """
        logger.info(f'----------------- CustomMaskRCNN.load_state_dict_pre_hook() called w/ prefix: {prefix}')

        def class_rows(index, width):
            # Rows [index * width, (index + 1) * width) belong to class `index`.
            return slice(index * width, (index + 1) * width)

        # Dst to src mapping index.
        # NOTE(review): map_class_names presumably yields, per model class, the
        # matching checkpoint index or a negative value when unmatched - confirm.
        target_state = model.state_dict()
        model_classes = list(model_classes)
        chkpt_classes = list(chkpt_classes)
        model2chkpt = map_class_names(model_classes, chkpt_classes)
        logger.info(f'{chkpt_classes} -> {model_classes} ({model2chkpt})')

        num_model_classes = len(model_classes)
        num_chkpt_classes = len(chkpt_classes)

        # Class-relevant params paired with their row-stride (rows per class).
        strided_params = (
            ('roi_head.bbox_head.fc_cls.weight', 1),
            ('roi_head.bbox_head.fc_cls.bias', 1),
            ('roi_head.bbox_head.fc_reg.weight', 4),  # 4 rows (bbox) for each class
            ('roi_head.bbox_head.fc_reg.bias', 4),
        )

        for model_name, stride in strided_params:
            chkpt_name = prefix + model_name
            if model_name not in target_state or chkpt_name not in chkpt_dict:
                logger.info(f'Skipping weight copy: {chkpt_name}')
                continue

            # Start from the model's own values so unmatched classes keep them.
            mixed = target_state[model_name].clone()
            loaded = chkpt_dict[chkpt_name]

            for dst_idx, src_idx in enumerate(model2chkpt):
                if src_idx < 0:
                    continue
                # Copy only the rows of classes found in the checkpoint.
                mixed[class_rows(dst_idx, stride)].copy_(loaded[class_rows(src_idx, stride)])

            # Extra rows past the per-class rows hold the background (BG) class,
            # which sits right after the last class on both sides.
            if mixed.shape[0] > num_model_classes * stride:
                mixed[class_rows(num_model_classes, stride)].copy_(
                    loaded[class_rows(num_chkpt_classes, stride)])

            # Replace checkpoint weight by mixed weights
            chkpt_dict[chkpt_name] = mixed
9 changes: 6 additions & 3 deletions mpa/modules/models/heads/cross_dataset_detector_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,15 @@ def vfnet_to_atss_targets(self,
bbox_weights = torch.cat(bbox_weights_list)
return labels_list, label_weights, bbox_targets_list, bbox_weights, valid_label_mask

def get_valid_label_mask(self, img_metas, all_labels):
def get_valid_label_mask(self, img_metas, all_labels, use_bg=False):
num_classes = self.num_classes + 1 if use_bg else self.num_classes
valid_label_mask = []
for i, meta in enumerate(img_metas):
mask = torch.Tensor([1 for _ in range(self.num_classes)])
if 'ignored_labels' in meta and meta['ignored_labels']:
mask = torch.Tensor([1 for _ in range(num_classes)])
if 'ignored_labels' in meta and len(meta['ignored_labels']) > 0:
mask[meta['ignored_labels']] = 0
if use_bg:
mask[self.num_classes] = 0
mask = mask.repeat(len(all_labels[i]), 1)
mask = mask.cuda() if torch.cuda.is_available() else mask
valid_label_mask.append(mask)
Expand Down
Loading

0 comments on commit c2e826f

Please sign in to comment.