Skip to content

Commit

Permalink
Fix for ssd (#2197)
Browse files Browse the repository at this point in the history
* Fix for ssd

* Add unit test

* Update CHANGELOG.md

* Modify tiliing unit test

---------

Signed-off-by: Songki Choi <[email protected]>
Co-authored-by: Songki Choi <[email protected]>
Co-authored-by: Eunwoo Shin <[email protected]>
  • Loading branch information
3 people authored May 30, 2023
1 parent cf815e3 commit 8a3efaa
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 52 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ All notable changes to this project will be documented in this file.
- Enhance exportable code file structure, video inference and default value for demo (<https://github.com/openvinotoolkit/training_extensions/pull/2051>).
- Speedup OpenVINO inference in image classificaiton, semantic segmentation, object detection and instance segmentation tasks (<https://github.com/openvinotoolkit/training_extensions/pull/2105>).
- Refactoring of ONNX export functionality (<https://github.com/openvinotoolkit/training_extensions/pull/2155>).
- SSD detector Optimization(<https://github.com/openvinotoolkit/training_extensions/pull/2197>)

### Bug fixes

Expand Down
34 changes: 14 additions & 20 deletions otx/algorithms/detection/adapters/mmdet/configurer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
)
from otx.algorithms.common.utils.logger import get_logger
from otx.algorithms.detection.adapters.mmdet.utils import (
cluster_anchors,
patch_datasets,
patch_evaluation,
should_cluster_anchors,
)

logger = get_logger()
Expand All @@ -53,6 +55,7 @@ def __init__(self):
def configure(
self,
cfg,
train_dataset,
model_ckpt,
data_cfg,
training=True,
Expand All @@ -70,7 +73,7 @@ def configure(
self.configure_ckpt(cfg, model_ckpt)
self.configure_data(cfg, training, data_cfg)
self.configure_regularization(cfg, training)
self.configure_task(cfg, training)
self.configure_task(cfg, train_dataset, training)
self.configure_hook(cfg)
self.configure_samples_per_gpu(cfg, subset)
self.configure_fp16_optimizer(cfg)
Expand Down Expand Up @@ -235,7 +238,7 @@ def configure_regularization(self, cfg, training): # noqa: C901
if "weight_decay" in cfg.optimizer:
cfg.optimizer.weight_decay = 0.0

def configure_task(self, cfg, training):
def configure_task(self, cfg, train_dataset, training):
"""Patch config to support training algorithm."""
if "task_adapt" in cfg:
logger.info(f"task config!!!!: training={training}")
Expand All @@ -245,9 +248,8 @@ def configure_task(self, cfg, training):

if self.data_classes != self.model_classes:
self.configure_task_data_pipeline(cfg)
# TODO[JAEGUK]: configure_anchor is not working
if cfg["task_adapt"].get("use_mpa_anchor", False):
self.configure_anchor(cfg)
self.configure_anchor(cfg, train_dataset)
if self.task_adapt_type == "mpa":
self.configure_bbox_head(cfg)
self.configure_ema(cfg)
Expand Down Expand Up @@ -329,12 +331,14 @@ def configure_task_data_pipeline(self, cfg):
pipeline_cfg.insert(i + 1, class_adapt_cfg)
break

def configure_anchor(self, cfg):
def configure_anchor(self, cfg, train_dataset):
"""Patch anchor settings for single stage detector."""
if cfg.model.type in ["SingleStageDetector", "CustomSingleStageDetector"]:
anchor_cfg = cfg.model.bbox_head.anchor_generator
if anchor_cfg.type == "SSDAnchorGeneratorClustered":
cfg.model.bbox_head.anchor_generator.pop("input_size", None)
if should_cluster_anchors(cfg) and train_dataset is not None:
cluster_anchors(cfg, train_dataset)

def configure_bbox_head(self, cfg):
"""Patch bbox head in detector for class incremental learning.
Expand All @@ -347,16 +351,7 @@ def configure_bbox_head(self, cfg):
bbox_head = cfg.model.roi_head.bbox_head

alpha, gamma = 0.25, 2.0
if bbox_head.type in ["SSDHead", "CustomSSDHead"]:
gamma = 1 if cfg["task_adapt"].get("efficient_mode", False) else 2
bbox_head.type = "CustomSSDHead"
bbox_head.loss_cls = ConfigDict(
type="FocalLoss",
loss_weight=1.0,
gamma=gamma,
reduction="none",
)
elif bbox_head.type in ["ATSSHead"]:
if bbox_head.type in ["ATSSHead"]:
gamma = 3 if cfg["task_adapt"].get("efficient_mode", False) else 4.5
bbox_head.loss_cls.gamma = gamma
elif bbox_head.type in ["VFNetHead", "CustomVFNetHead"]:
Expand Down Expand Up @@ -661,9 +656,9 @@ def _configure_dataloader(cfg):
class IncrDetectionConfigurer(DetectionConfigurer):
"""Patch config to support incremental learning for object detection."""

def configure_task(self, cfg, training):
def configure_task(self, cfg, train_dataset, training):
"""Patch config to support incremental learning."""
super().configure_task(cfg, training)
super().configure_task(cfg, train_dataset, training)
if "task_adapt" in cfg and self.task_adapt_type == "mpa":
self.configure_task_adapt_hook(cfg)

Expand Down Expand Up @@ -700,7 +695,7 @@ def configure_data(self, cfg, training, data_cfg):
cfg.data.unlabeled.pipeline = cfg.data.train.pipeline.copy()
self.configure_unlabeled_dataloader(cfg)

def configure_task(self, cfg, training):
def configure_task(self, cfg, train_dataset, training):
"""Patch config to support training algorithm."""
logger.info(f"Semi-SL task config!!!!: training={training}")
if "task_adapt" in cfg:
Expand All @@ -710,9 +705,8 @@ def configure_task(self, cfg, training):

if self.data_classes != self.model_classes:
self.configure_task_data_pipeline(cfg)
# TODO[JAEGUK]: configure_anchor is not working
if cfg["task_adapt"].get("use_mpa_anchor", False):
self.configure_anchor(cfg)
self.configure_anchor(cfg, train_dataset)
if self.task_adapt_type == "mpa":
self.configure_bbox_head(cfg)
else:
Expand Down
38 changes: 18 additions & 20 deletions otx/algorithms/detection/adapters/mmdet/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,7 @@ def build_model(
return model

# pylint: disable=too-many-arguments
def configure(
self,
training=True,
subset="train",
ir_options=None,
):
def configure(self, training=True, subset="train", ir_options=None, train_dataset=None):
"""Patch mmcv configs for OTX detection settings."""

# deepcopy all configs to make sure
Expand All @@ -202,8 +197,21 @@ def configure(
else:
configurer = DetectionConfigurer()
cfg = configurer.configure(
recipe_cfg, self._model_ckpt, data_cfg, training, subset, ir_options, data_classes, model_classes
recipe_cfg,
train_dataset,
self._model_ckpt,
data_cfg,
training,
subset,
ir_options,
data_classes,
model_classes,
)
if should_cluster_anchors(self._recipe_cfg):
if train_dataset is not None:
self._anchors = cfg.model.bbox_head.anchor_generator
elif self._anchors is not None:
self._update_anchors(cfg.model.bbox_head.anchor_generator, self._anchors)
self._config = cfg
return cfg

Expand Down Expand Up @@ -231,7 +239,7 @@ def _train_model(

self._init_task(dataset)

cfg = self.configure(True, "train", None)
cfg = self.configure(True, "train", None, get_dataset(dataset, Subset.TRAINING))
logger.info("train!")

timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
Expand All @@ -246,14 +254,6 @@ def _train_model(
# Data
datasets = [build_dataset(cfg.data.train)]

# TODO. This should be moved to configurer
# TODO. Anchor clustering should be checked
# if hasattr(cfg, "hparams"):
# if cfg.hparams.get("adaptive_anchor", False):
# num_ratios = cfg.hparams.get("num_anchor_ratios", 5)
# proposal_ratio = extract_anchor_ratio(datasets[0], num_ratios)
# self.configure_anchor(cfg, proposal_ratio)

# Target classes
if "task_adapt" in cfg:
target_classes = cfg.task_adapt.get("final", [])
Expand All @@ -271,8 +271,6 @@ def _train_model(
mmdet_version=__version__ + get_git_hash()[:7],
CLASSES=target_classes,
)
# if "proposal_ratio" in locals():
# cfg.checkpoint_config.meta.update({"anchor_ratio": proposal_ratio})

# Model
model = self.build_model(cfg, fp16=cfg.get("fp16", False))
Expand Down Expand Up @@ -658,9 +656,9 @@ def save_model(self, output_model: ModelEntity):
"confidence_threshold": self.confidence_threshold,
"VERSION": 1,
}
if self._recipe_cfg is not None and should_cluster_anchors(self._recipe_cfg):
if self.config is not None and should_cluster_anchors(self.config):
modelinfo["anchors"] = {}
self._update_anchors(modelinfo["anchors"], self._recipe_cfg.model.bbox_head.anchor_generator)
self._update_anchors(modelinfo["anchors"], self.config.model.bbox_head.anchor_generator)

torch.save(modelinfo, buffer)
output_model.set_data("weights.pth", buffer.getvalue())
Expand Down
2 changes: 2 additions & 0 deletions otx/algorithms/detection/adapters/mmdet/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
patch_tiling,
prepare_for_training,
set_hyperparams,
should_cluster_anchors,
)

__all__ = [
Expand All @@ -29,4 +30,5 @@
"patch_input_preprocessing",
"patch_input_shape",
"patch_ir_scale_factor",
"should_cluster_anchors",
]
1 change: 1 addition & 0 deletions otx/recipes/stages/detection/incremental.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
type="mpa",
op="REPLACE",
efficient_mode=False,
use_mpa_anchor=True,
)

runner = dict(max_epochs=30)
Expand Down
36 changes: 24 additions & 12 deletions tests/unit/algorithms/detection/adapters/mmdet/test_configurer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
import tempfile
from mmcv.utils import ConfigDict

from otx.api.entities.model_template import TaskType
from otx.algorithms.common.adapters.mmcv.utils.config_utils import MPAConfig
from otx.algorithms.detection.adapters.mmdet.configurer import (
DetectionConfigurer,
IncrDetectionConfigurer,
SemiSLDetectionConfigurer,
)
from tests.test_suite.e2e_test_system import e2e_pytest_unit
from tests.unit.algorithms.detection.test_helpers import DEFAULT_DET_TEMPLATE_DIR
from tests.unit.algorithms.detection.test_helpers import (
DEFAULT_DET_TEMPLATE_DIR,
generate_det_dataset,
)


class TestDetectionConfigurer:
Expand All @@ -21,6 +25,7 @@ def setup(self) -> None:
self.configurer = DetectionConfigurer()
self.model_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_DET_TEMPLATE_DIR, "model.py"))
self.data_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_DET_TEMPLATE_DIR, "data_pipeline.py"))
self.det_dataset, self.det_labels = generate_det_dataset(TaskType.DETECTION, 100)

@e2e_pytest_unit
def test_configure(self, mocker):
Expand All @@ -37,13 +42,13 @@ def test_configure(self, mocker):

model_cfg = copy.deepcopy(self.model_cfg)
data_cfg = copy.deepcopy(self.data_cfg)
returned_value = self.configurer.configure(model_cfg, "", data_cfg, True)
returned_value = self.configurer.configure(model_cfg, self.det_dataset, "", data_cfg, True)
mock_cfg_base.assert_called_once_with(model_cfg, data_cfg, None, None)
mock_cfg_device.assert_called_once_with(model_cfg, True)
mock_cfg_model.assert_called_once_with(model_cfg, None)
mock_cfg_ckpt.assert_called_once_with(model_cfg, "")
mock_cfg_regularization.assert_called_once_with(model_cfg, True)
mock_cfg_task.assert_called_once_with(model_cfg, True)
mock_cfg_task.assert_called_once_with(model_cfg, self.det_dataset, True)
mock_cfg_hook.assert_called_once_with(model_cfg)
mock_cfg_gpu.assert_called_once_with(model_cfg, "train")
mock_cfg_fp16_optimizer.assert_called_once_with(model_cfg)
Expand Down Expand Up @@ -161,19 +166,24 @@ def test_configure_data(self, mocker):

@e2e_pytest_unit
def test_configure_task(self, mocker):
ssd_dir = os.path.join("otx/algorithms/detection/configs/detection", "mobilenetv2_ssd")
ssd_cfg = MPAConfig.fromfile(os.path.join(ssd_dir, "model.py"))
ssd_cfg.task_adapt = {"type": "mpa", "op": "REPLACE", "use_mpa_anchor": True}
model_cfg = copy.deepcopy(ssd_cfg)
self.configurer.configure_task(model_cfg, self.det_dataset, True)
assert model_cfg.model.bbox_head.anchor_generator != ssd_cfg.model.bbox_head.anchor_generator

model_cfg = copy.deepcopy(self.model_cfg)
model_cfg.task_adapt = {"type": "mpa", "op": "REPLACE", "use_mpa_anchor": True}
self.configurer.configure_task(model_cfg, True)

model_cfg.model.bbox_head.type = "ATSSHead"
self.configurer.configure_task(model_cfg, True)
self.configurer.configure_task(model_cfg, self.det_dataset, True)

model_cfg.model.bbox_head.type = "VFNetHead"
self.configurer.configure_task(model_cfg, True)
self.configurer.configure_task(model_cfg, self.det_dataset, True)

model_cfg.model.bbox_head.type = "YOLOXHead"
model_cfg.data.train.type = "MultiImageMixDataset"
self.configurer.configure_task(model_cfg, True)
self.configurer.configure_task(model_cfg, self.det_dataset, True)

def mock_configure_classes(*args, **kwargs):
return True
Expand All @@ -182,7 +192,7 @@ def mock_configure_classes(*args, **kwargs):
self.configurer.model_classes = []
self.configurer.data_classes = ["red", "green"]
self.configurer.configure_classes = mock_configure_classes
self.configurer.configure_task(model_cfg, True)
self.configurer.configure_task(model_cfg, self.det_dataset, True)

@e2e_pytest_unit
def test_configure_hook(self):
Expand Down Expand Up @@ -244,12 +254,13 @@ def setup(self) -> None:
self.configurer = IncrDetectionConfigurer()
self.model_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_DET_TEMPLATE_DIR, "model.py"))
self.data_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_DET_TEMPLATE_DIR, "data_pipeline.py"))
self.det_dataset, self.det_labels = generate_det_dataset(TaskType.DETECTION, 100)

def test_configure_task(self, mocker):
mocker.patch.object(DetectionConfigurer, "configure_task")
self.model_cfg.task_adapt = {}
self.configurer.task_adapt_type = "mpa"
self.configurer.configure_task(self.model_cfg, True)
self.configurer.configure_task(self.model_cfg, self.det_dataset, True)
assert self.model_cfg.custom_hooks[1].type == "TaskAdaptHook"
assert self.model_cfg.custom_hooks[1].sampler_flag is False

Expand All @@ -260,6 +271,7 @@ def setup(self) -> None:
self.configurer = SemiSLDetectionConfigurer()
self.model_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_DET_TEMPLATE_DIR, "model.py"))
self.data_cfg = MPAConfig.fromfile(os.path.join(DEFAULT_DET_TEMPLATE_DIR, "data_pipeline.py"))
self.det_dataset, self.det_labels = generate_det_dataset(TaskType.DETECTION, 100)

def test_configure_data(self, mocker):
mocker.patch.object(DetectionConfigurer, "configure_data")
Expand All @@ -272,7 +284,7 @@ def test_configure_data(self, mocker):

def test_configure_task(self):
self.model_cfg.task_adapt = {"type": "mpa", "op": "REPLACE", "use_mpa_anchor": True}
self.configurer.configure_task(self.model_cfg, True)
self.configurer.configure_task(self.model_cfg, self.det_dataset, True)

self.model_cfg.task_adapt = {"type": "not_mpa", "op": "REPLACE", "use_mpa_anchor": True}
self.configurer.configure_task(self.model_cfg, True)
self.configurer.configure_task(self.model_cfg, self.det_dataset, True)
51 changes: 51 additions & 0 deletions tests/unit/algorithms/detection/adapters/mmdet/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,54 @@ def test_explain(self, mocker):
explain_predicted_classes=True,
)
outputs = self.det_task.explain(self.det_dataset, explain_parameters)

@e2e_pytest_unit
def test_anchor_clustering(self, mocker):

ssd_dir = os.path.join("otx/algorithms/detection/configs/detection", "mobilenetv2_ssd")
ssd_cfg = MPAConfig.fromfile(os.path.join(ssd_dir, "model.py"))
model_template = parse_model_template(os.path.join(ssd_dir, "template.yaml"))
hyper_parameters = create(model_template.hyper_parameters.data)
hyper_parameters.learning_parameters.auto_num_workers = True
task_env = init_environment(hyper_parameters, model_template, task_type=TaskType.DETECTION)

det_task = MMDetectionTask(task_env)

def _mock_train_detector_det(*args, **kwargs):
with open(os.path.join(self.det_task._output_path, "latest.pth"), "wb") as f:
torch.save({"dummy": torch.randn(1, 3, 3, 3)}, f)

mocker.patch(
"otx.algorithms.detection.adapters.mmdet.task.build_dataset",
return_value=MockDataset(self.det_dataset, "det"),
)
mocker.patch(
"otx.algorithms.detection.adapters.mmdet.task.build_dataloader",
return_value=MockDataLoader(self.det_dataset),
)
mocker.patch(
"otx.algorithms.detection.adapters.mmdet.task.patch_data_pipeline",
return_value=True,
)
mocker.patch(
"otx.algorithms.detection.adapters.mmdet.task.train_detector",
side_effect=_mock_train_detector_det,
)

det_task._train_model(self.det_dataset)
assert ssd_cfg.model.bbox_head.anchor_generator != det_task.config.model.bbox_head.anchor_generator

mocker.patch(
"otx.algorithms.detection.adapters.mmdet.task.single_gpu_test",
return_value=[
np.array([np.array([[0, 0, 1, 1, 0.1]]), np.array([[0, 0, 1, 1, 0.2]]), np.array([[0, 0, 1, 1, 0.7]])])
]
* 100,
)
mocker.patch(
"otx.algorithms.detection.adapters.mmdet.task.FeatureVectorHook",
return_value=nullcontext(),
)
inference_parameters = InferenceParameters(is_evaluation=True)
det_task._infer_model(self.det_dataset, inference_parameters)
assert ssd_cfg.model.bbox_head.anchor_generator != det_task.config.model.bbox_head.anchor_generator
Loading

0 comments on commit 8a3efaa

Please sign in to comment.