Skip to content

Commit

Permalink
Refactor MaskRCNN rotated detection (#3858)
Browse files Browse the repository at this point in the history
* Refactor MaskRCNN rotated detection

* update rotated ov model recipe

* update rotated det recipes
  • Loading branch information
eugene123tw authored Aug 20, 2024
1 parent 8e6f831 commit 4f696f1
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 104 deletions.
13 changes: 11 additions & 2 deletions src/otx/core/model/rotated_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from datumaro import Polygon
from torchvision import tv_tensors

from otx.algo.instance_segmentation.maskrcnn import MaskRCNN, MaskRCNNEfficientNet, MaskRCNNResNet50
from otx.core.data.entity.instance_segmentation import InstanceSegBatchPredEntity
from otx.core.model.instance_segmentation import OTXInstanceSegModel, OVInstanceSegmentationModel
from otx.core.model.instance_segmentation import OVInstanceSegmentationModel


class OTXRotatedDetModel(OTXInstanceSegModel):
class RotatedMaskRCNNModel(MaskRCNN):
"""Base class for the rotated detection models used in OTX."""

def predict_step(self, *args: torch.Any, **kwargs: torch.Any) -> InstanceSegBatchPredEntity:
Expand Down Expand Up @@ -93,6 +94,14 @@ def predict_step(self, *args: torch.Any, **kwargs: torch.Any) -> InstanceSegBatc
)


class RotatedMaskRCNNResNet50(RotatedMaskRCNNModel, MaskRCNNResNet50):
    """Rotated MaskRCNN model with ResNet50 backbone.

    Combines ``RotatedMaskRCNNModel`` with ``MaskRCNNResNet50`` through
    multiple inheritance and adds no members of its own. Because
    ``RotatedMaskRCNNModel`` is listed first among the bases, it precedes
    ``MaskRCNNResNet50`` in the method resolution order, so its
    ``predict_step`` override is the one used by this class.
    """


class RotatedMaskRCNNEfficientNet(RotatedMaskRCNNModel, MaskRCNNEfficientNet):
    """Rotated MaskRCNN model with EfficientNet backbone.

    Combines ``RotatedMaskRCNNModel`` with ``MaskRCNNEfficientNet`` through
    multiple inheritance and adds no members of its own. Because
    ``RotatedMaskRCNNModel`` is listed first among the bases, it precedes
    ``MaskRCNNEfficientNet`` in the method resolution order, so its
    ``predict_step`` override is the one used by this class.
    """


class OVRotatedDetectionModel(OVInstanceSegmentationModel):
"""Rotated Detection model compatible for OpenVINO IR Inference.
Expand Down
91 changes: 0 additions & 91 deletions src/otx/recipe/_base_/data/rotated_detection.yaml

This file was deleted.

19 changes: 15 additions & 4 deletions src/otx/recipe/rotated_detection/maskrcnn_efficientnetb2b.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
model:
class_path: otx.algo.instance_segmentation.maskrcnn.MaskRCNNEfficientNet
class_path: otx.core.model.rotated_detection.RotatedMaskRCNNEfficientNet
init_args:
label_info: 80

Expand Down Expand Up @@ -28,29 +28,40 @@ engine:

callback_monitor: val/map_50

data: ../_base_/data/rotated_detection.yaml
data: ../_base_/data/instance_segmentation.yaml
overrides:
task: ROTATED_DETECTION
max_epochs: 100
data:
train_subset:
batch_size: 4
num_workers: 8
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
size_divisor: 32
- class_path: torchvision.transforms.v2.Normalize
init_args:
std: [1.0, 1.0, 1.0]
sampler:
class_path: otx.algo.samplers.balanced_sampler.BalancedSampler

val_subset:
batch_size: 1
num_workers: 4
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
size_divisor: 32
- class_path: torchvision.transforms.v2.Normalize
init_args:
std: [1.0, 1.0, 1.0]

test_subset:
batch_size: 1
num_workers: 4
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
size_divisor: 32
- class_path: torchvision.transforms.v2.Normalize
init_args:
std: [1.0, 1.0, 1.0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
model:
class_path: otx.core.model.rotated_detection.RotatedMaskRCNNEfficientNet
init_args:
label_info: 80

optimizer:
class_path: torch.optim.SGD
init_args:
lr: 0.007
momentum: 0.9
weight_decay: 0.001

scheduler:
class_path: otx.core.schedulers.LinearWarmupSchedulerCallable
init_args:
num_warmup_steps: 100
main_scheduler_callable:
class_path: lightning.pytorch.cli.ReduceLROnPlateau
init_args:
mode: max
factor: 0.1
patience: 9
monitor: val/map_50

engine:
task: ROTATED_DETECTION
device: auto

callback_monitor: val/map_50

data: ../_base_/data/instance_segmentation.yaml
overrides:
task: ROTATED_DETECTION
max_epochs: 100
data:
input_size:
- 512
- 512
tile_config:
enable_tiler: true
enable_adaptive_tiling: true

train_subset:
batch_size: 4
num_workers: 8
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
keep_ratio: false
scale: $(input_size)
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
pad_to_square: false
size_divisor: 32
- class_path: torchvision.transforms.v2.Normalize
init_args:
std: [1.0, 1.0, 1.0]
sampler:
class_path: otx.algo.samplers.balanced_sampler.BalancedSampler

val_subset:
num_workers: 4
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
keep_ratio: false
scale: $(input_size)
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
pad_to_square: false
size_divisor: 32
- class_path: torchvision.transforms.v2.Normalize
init_args:
std: [1.0, 1.0, 1.0]

test_subset:
num_workers: 4
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
keep_ratio: false
scale: $(input_size)
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
pad_to_square: false
size_divisor: 32
- class_path: torchvision.transforms.v2.Normalize
init_args:
std: [1.0, 1.0, 1.0]
25 changes: 20 additions & 5 deletions src/otx/recipe/rotated_detection/maskrcnn_r50.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
model:
class_path: otx.algo.instance_segmentation.maskrcnn.MaskRCNNResNet50
class_path: otx.core.model.rotated_detection.RotatedMaskRCNNResNet50
init_args:
label_info: 80

Expand All @@ -19,7 +19,7 @@ model:
init_args:
mode: max
factor: 0.1
patience: 9
patience: 4
monitor: val/map_50

engine:
Expand All @@ -28,15 +28,30 @@ engine:

callback_monitor: val/map_50

data: ../_base_/data/rotated_detection.yaml
data: ../_base_/data/instance_segmentation.yaml
overrides:
task: ROTATED_DETECTION
max_epochs: 100
gradient_clip_val: 35.0
data:
train_subset:
batch_size: 4
num_workers: 8
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
size_divisor: 32

val_subset:
batch_size: 1
num_workers: 4
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
size_divisor: 32

test_subset:
batch_size: 1
num_workers: 4
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
size_divisor: 32
79 changes: 79 additions & 0 deletions src/otx/recipe/rotated_detection/maskrcnn_r50_tile.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
model:
class_path: otx.core.model.rotated_detection.RotatedMaskRCNNResNet50
init_args:
label_info: 80

optimizer:
class_path: torch.optim.SGD
init_args:
lr: 0.007
momentum: 0.9
weight_decay: 0.001

scheduler:
class_path: otx.core.schedulers.LinearWarmupSchedulerCallable
init_args:
num_warmup_steps: 100
main_scheduler_callable:
class_path: lightning.pytorch.cli.ReduceLROnPlateau
init_args:
mode: max
factor: 0.1
patience: 4
monitor: val/map_50

engine:
task: ROTATED_DETECTION
device: auto

callback_monitor: val/map_50

data: ../_base_/data/instance_segmentation.yaml
overrides:
task: ROTATED_DETECTION
max_epochs: 100
gradient_clip_val: 35.0
data:
input_size:
- 512
- 512
tile_config:
enable_tiler: true
enable_adaptive_tiling: true

train_subset:
batch_size: 4
num_workers: 8
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
keep_ratio: false
scale: $(input_size)
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
pad_to_square: false
size_divisor: 32

val_subset:
num_workers: 4
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
keep_ratio: false
scale: $(input_size)
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
pad_to_square: false
size_divisor: 32

test_subset:
num_workers: 4
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
keep_ratio: false
scale: $(input_size)
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
pad_to_square: false
size_divisor: 32
Loading

0 comments on commit 4f696f1

Please sign in to comment.