Skip to content

Commit

Permalink
Rename SAMImageClassifier -> CustomImageClassifier (#2384)
Browse files Browse the repository at this point in the history
* Rename SAMImageClassifier -> CustomImageClassifier

Signed-off-by: Songki Choi <[email protected]>

* Fix module name sam_classifier -> custom_image_classifier

Signed-off-by: Songki Choi <[email protected]>

* Fix pre-commit

Signed-off-by: Songki Choi <[email protected]>

---------

Signed-off-by: Songki Choi <[email protected]>
  • Loading branch information
goodsong81 authored Jul 24, 2023
1 parent 6551637 commit 5053957
Show file tree
Hide file tree
Showing 28 changed files with 112 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ def _configure_dataloader(cfg):
"MPAHierarchicalClsDataset",
"ClsTVDataset",
]
WEIGHT_MIX_CLASSIFIER = ["SAMImageClassifier"]
WEIGHT_MIX_CLASSIFIER = ["CustomImageClassifier"]


class IncrClassificationConfigurer(ClassificationConfigurer):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,10 @@
"""OTX Algorithms - Classification Models."""

# Copyright (C) 2022 Intel Corporation
# Copyright (C) 2022-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from .classifiers import BYOL, SAMImageClassifier, SemiSLClassifier, SupConClassifier
from .classifiers import BYOL, CustomImageClassifier, SemiSLClassifier, SupConClassifier
from .heads import (
ClsHead,
ConstrastiveHead,
Expand All @@ -41,7 +31,7 @@

__all__ = [
"BYOL",
"SAMImageClassifier",
"CustomImageClassifier",
"SemiSLClassifier",
"SupConClassifier",
"CustomLinearClsHead",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,13 @@
"""OTX Algorithms - Classification Classifiers."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from .byol import BYOL
from .sam_classifier import SAMImageClassifier
from .custom_image_classifier import CustomImageClassifier
from .semisl_classifier import SemiSLClassifier
from .semisl_multilabel_classifier import SemiSLMultilabelClassifier
from .supcon_classifier import SupConClassifier

__all__ = ["BYOL", "SAMImageClassifier", "SemiSLClassifier", "SemiSLMultilabelClassifier", "SupConClassifier"]
__all__ = ["BYOL", "CustomImageClassifier", "SemiSLClassifier", "SemiSLMultilabelClassifier", "SupConClassifier"]
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Module for defining SAMClassifier for classification task."""
# Copyright (C) 2022 Intel Corporation
# Copyright (C) 2022-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
from functools import partial
Expand All @@ -17,7 +17,7 @@


@CLASSIFIERS.register_module()
class SAMImageClassifier(SAMClassifierMixin, ClsLossDynamicsTrackingMixin, ImageClassifier):
class CustomImageClassifier(SAMClassifierMixin, ClsLossDynamicsTrackingMixin, ImageClassifier):
"""SAM-enabled ImageClassifier."""

def __init__(self, task_adapt=None, **kwargs):
Expand Down Expand Up @@ -283,7 +283,7 @@ def extract_feat(self, img):
)

@FUNCTION_REWRITER.register_rewriter(
"otx.algorithms.classification.adapters.mmcls.models.classifiers.SAMImageClassifier.extract_feat"
"otx.algorithms.classification.adapters.mmcls.models.classifiers.CustomImageClassifier.extract_feat"
)
def sam_image_classifier__extract_feat(ctx, self, img): # pylint: disable=unused-argument
"""Feature extraction function for SAMClassifier with mmdeploy."""
Expand All @@ -298,7 +298,7 @@ def sam_image_classifier__extract_feat(ctx, self, img): # pylint: disable=unuse
return feat, backbone_feat

@FUNCTION_REWRITER.register_rewriter(
"otx.algorithms.classification.adapters.mmcls.models.classifiers.SAMImageClassifier.simple_test"
"otx.algorithms.classification.adapters.mmcls.models.classifiers.CustomImageClassifier.simple_test"
)
def sam_image_classifier__simple_test(ctx, self, img, img_metas): # pylint: disable=unused-argument
"""Simple test function used for inference for SAMClassifier with mmdeploy."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Module for defining a semi-supervised classifier using mmcls."""
# Copyright (C) 2022 Intel Corporation
# Copyright (C) 2022-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

Expand All @@ -8,13 +8,13 @@

from otx.algorithms.common.utils.logger import get_logger

from .sam_classifier import SAMImageClassifier
from .custom_image_classifier import CustomImageClassifier

logger = get_logger()


@CLASSIFIERS.register_module()
class SemiSLClassifier(SAMImageClassifier):
class SemiSLClassifier(CustomImageClassifier):
"""Semi-SL Classifier.
This classifier supports unlabeled data by overriding forward_train
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
"""Module for defining a semi-supervised multi-label classifier using mmcls."""
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0
#

from mmcls.models.builder import CLASSIFIERS

from otx.algorithms.common.utils.logger import get_logger

from .sam_classifier import SAMImageClassifier
from .custom_image_classifier import CustomImageClassifier

logger = get_logger()


@CLASSIFIERS.register_module()
class SemiSLMultilabelClassifier(SAMImageClassifier):
class SemiSLMultilabelClassifier(CustomImageClassifier):
"""Semi-SL Multilabel Classifier which supports unlabeled data by overriding forward_train."""

def forward_train(self, img, gt_label, **kwargs):
Expand Down
5 changes: 4 additions & 1 deletion src/otx/algorithms/classification/configs/deit_tiny/model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""deit-tiny for multi-class config."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

# pylint: disable=invalid-name

_base_ = ["../../../../recipes/stages/classification/incremental.yaml", "../base/models/deit.py"]
ckpt_url = "https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.pth"

model = dict(
type="SAMImageClassifier",
type="CustomImageClassifier",
task="classification",
backbone=dict(arch="deit-tiny", init_cfg=dict(type="Pretrained", checkpoint=ckpt_url, prefix="backbone")),
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""deit-tiny for hierarchical config."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

# pylint: disable=invalid-name

_base_ = ["../../../../recipes/stages/classification/incremental.yaml", "../base/models/deit.py"]
ckpt_url = "https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.pth"

model = dict(
type="SAMImageClassifier",
type="CustomImageClassifier",
task="classification",
backbone=dict(arch="deit-tiny", init_cfg=dict(type="Pretrained", checkpoint=ckpt_url, prefix="backbone")),
head=dict(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""deit-tiny for multi-label config."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

# pylint: disable=invalid-name

_base_ = ["../../../../recipes/stages/classification/multilabel/incremental.yaml", "../base/models/deit.py"]
ckpt_url = "https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.pth"

model = dict(
type="SAMImageClassifier",
type="CustomImageClassifier",
task="classification",
backbone=dict(arch="deit-tiny", init_cfg=dict(type="Pretrained", checkpoint=ckpt_url, prefix="backbone")),
head=dict(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""EfficientNet-B0 for multi-class config."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

# pylint: disable=invalid-name

_base_ = ["../../../../recipes/stages/classification/incremental.yaml", "../base/models/efficientnet.py"]

model = dict(
type="SAMImageClassifier",
type="CustomImageClassifier",
task="classification",
backbone=dict(
version="b0",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""EfficientNet-B0 for hierarchical config."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

# pylint: disable=invalid-name

_base_ = ["../../../../recipes/stages/classification/incremental.yaml", "../base/models/efficientnet.py"]

model = dict(
type="SAMImageClassifier",
type="CustomImageClassifier",
task="classification",
backbone=dict(
version="b0",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""EfficientNet-B0 for multi-label config."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

# pylint: disable=invalid-name

_base_ = ["../../../../recipes/stages/classification/multilabel/incremental.yaml", "../base/models/efficientnet.py"]

model = dict(
type="SAMImageClassifier",
type="CustomImageClassifier",
task="classification",
backbone=dict(
version="b0",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""EfficientNet-V2 for multi-class config."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

# pylint: disable=invalid-name

_base_ = ["../../../../recipes/stages/classification/incremental.yaml", "../base/models/efficientnet_v2.py"]

model = dict(
type="SAMImageClassifier",
type="CustomImageClassifier",
task="classification",
backbone=dict(
version="s_21k",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""EfficientNet-V2 for hierarchical config."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

# pylint: disable=invalid-name

_base_ = ["../../../../recipes/stages/classification/incremental.yaml", "../base/models/efficientnet_v2.py"]

model = dict(
type="SAMImageClassifier",
type="CustomImageClassifier",
task="classification",
backbone=dict(version="s_21k"),
head=dict(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""EfficientNet-V2 for multi-label config."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

# pylint: disable=invalid-name

_base_ = ["../../../../recipes/stages/classification/multilabel/incremental.yaml", "../base/models/efficientnet_v2.py"]

model = dict(
type="SAMImageClassifier",
type="CustomImageClassifier",
task="classification",
backbone=dict(
version="s_21k",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""MobileNet-V3-large-075 for multi-class config."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

# pylint: disable=invalid-name

_base_ = ["../../../../recipes/stages/classification/incremental.yaml", "../base/models/mobilenet_v3.py"]

model = dict(
type="SAMImageClassifier",
type="CustomImageClassifier",
task="classification",
backbone=dict(
mode="large",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""MobileNet-V3-large-075 for hierarchical config."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

# pylint: disable=invalid-name

_base_ = ["../../../../recipes/stages/classification/incremental.yaml", "../base/models/mobilenet_v3.py"]

model = dict(
type="SAMImageClassifier",
type="CustomImageClassifier",
task="classification",
backbone=dict(
mode="large",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""MobileNet-V3-large-075 for multi-label config."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

# pylint: disable=invalid-name

_base_ = ["../../../../recipes/stages/classification/multilabel/incremental.yaml", "../base/models/mobilenet_v3.py"]

model = dict(
type="SAMImageClassifier",
type="CustomImageClassifier",
task="classification",
backbone=dict(
mode="large",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""MobileNet-V3-large-1 for multi-class config."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

# pylint: disable=invalid-name

_base_ = ["../../../../recipes/stages/classification/incremental.yaml", "../base/models/mobilenet_v3.py"]

model = dict(
type="SAMImageClassifier",
type="CustomImageClassifier",
task="classification",
backbone=dict(mode="large"),
head=dict(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""MobileNet-V3-large-1 for hierarchical config."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

# pylint: disable=invalid-name

_base_ = ["../../../../recipes/stages/classification/incremental.yaml", "../base/models/mobilenet_v3.py"]

model = dict(
type="SAMImageClassifier",
type="CustomImageClassifier",
task="classification",
backbone=dict(mode="large"),
head=dict(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""MobileNet-V3-large-1 for multi-label config."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

# pylint: disable=invalid-name

_base_ = ["../../../../recipes/stages/classification/multilabel/incremental.yaml", "../base/models/mobilenet_v3.py"]

model = dict(
type="SAMImageClassifier",
type="CustomImageClassifier",
task="classification",
backbone=dict(mode="large"),
head=dict(
Expand Down
Loading

0 comments on commit 5053957

Please sign in to comment.