Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mergeback: Label addtion/deletion 1.2.4 --> 1.4.0 #2326

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,10 @@ def evaluate(
)

eval_results["MHAcc"] = total_acc
eval_results["avgClsAcc"] = total_acc_sl / self.hierarchical_info["num_multiclass_heads"]
if self.hierarchical_info["num_multiclass_heads"] > 0:
eval_results["avgClsAcc"] = total_acc_sl / self.hierarchical_info["num_multiclass_heads"]
else:
eval_results["avgClsAcc"] = total_acc_sl
eval_results["mAP"] = mAP_value
eval_results["accuracy"] = total_acc

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@
logger = get_logger()


def is_hierarchical_chkpt(chkpt: dict):
"""Detect whether previous checkpoint is hierarchical or not."""
for k, v in chkpt.items():
if "fc" in k:
return True
return False


@CLASSIFIERS.register_module()
class SAMImageClassifier(SAMClassifierMixin, ClsLossDynamicsTrackingMixin, ImageClassifier):
"""SAM-enabled ImageClassifier."""
Expand Down Expand Up @@ -193,11 +201,19 @@ def load_state_dict_pre_hook(module, state_dict, prefix, *args, **kwargs): # no
def load_state_dict_mixing_hook(
model, model_classes, chkpt_classes, chkpt_dict, prefix, *args, **kwargs
): # pylint: disable=unused-argument, too-many-branches, too-many-locals
"""Modify input state_dict according to class name matching before weight loading."""
"""Modify input state_dict according to class name matching before weight loading.

If previous training is hierarchical training,
then the current training should be hierarchical training. vice versa.

"""
backbone_type = type(model.backbone).__name__
if backbone_type not in ["OTXMobileNetV3", "OTXEfficientNet", "OTXEfficientNetV2"]:
return

if model.hierarchical != is_hierarchical_chkpt(chkpt_dict):
return

# Dst to src mapping index
model_classes = list(model_classes)
chkpt_classes = list(chkpt_classes)
Expand Down Expand Up @@ -249,13 +265,15 @@ def load_state_dict_mixing_hook(
continue

# Mix weights
chkpt_param = chkpt_dict[chkpt_name]
for module, c in enumerate(model2chkpt):
if c >= 0:
model_param[module].copy_(chkpt_param[c])
# NOTE: Label mix is not supported for H-label classification.
if not model.hierarchical:
chkpt_param = chkpt_dict[chkpt_name]
for module, c in enumerate(model2chkpt):
if c >= 0:
model_param[module].copy_(chkpt_param[c])

# Replace checkpoint weight by mixed weights
chkpt_dict[chkpt_name] = model_param
# Replace checkpoint weight by mixed weights
chkpt_dict[chkpt_name] = model_param

def extract_feat(self, img):
"""Directly extract features from the backbone + neck.
Expand Down
22 changes: 15 additions & 7 deletions src/otx/algorithms/classification/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
from otx.api.entities.inference_parameters import (
default_progress_callback as default_infer_progress_callback,
)
from otx.api.entities.label import LabelEntity
from otx.api.entities.label_schema import LabelGroup
from otx.api.entities.metadata import FloatMetadata, FloatType
from otx.api.entities.metrics import (
CurveMetric,
Expand Down Expand Up @@ -125,16 +127,22 @@ def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str]
if self._task_environment.model is not None:
self._load_model()

def _is_multi_label(self, label_groups: List[LabelGroup], all_labels: List[LabelEntity]):
"""Check whether the current training mode is multi-label or not."""
# NOTE: In the current Geti, multi-label should have `___` symbol for all group names.
find_multilabel_symbol = ["___" in getattr(i, "name", "") for i in label_groups]
return (
(len(label_groups) > 1) and (len(label_groups) == len(all_labels)) and (False not in find_multilabel_symbol)
)

def _set_train_mode(self):
self._multilabel = len(self._task_environment.label_schema.get_groups(False)) > 1 and len(
self._task_environment.label_schema.get_groups(False)
) == len(
self._task_environment.get_labels(include_empty=False)
) # noqa:E127
label_groups = self._task_environment.label_schema.get_groups(include_empty=False)
all_labels = self._task_environment.label_schema.get_labels(include_empty=False)

self._multilabel = self._is_multi_label(label_groups, all_labels)
if self._multilabel:
logger.info("Classification mode: multilabel")

if not self._multilabel and len(self._task_environment.label_schema.get_groups(False)) > 1:
elif len(label_groups) > 1:
logger.info("Classification mode: hierarchical")
self._hierarchical = True
self._hierarchical_info = get_hierarchical_info(self._task_environment.label_schema)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
{
"info": {},
"categories": {
"label": {
"labels": [
{
"name": "right",
"parent": "triangle",
"attributes": []
},
{
"name": "multi a",
"parent": "triangle",
"attributes": []
},
{
"name": "equilateral",
"parent": "triangle",
"attributes": []
},
{
"name": "square",
"parent": "rectangle",
"attributes": []
},
{
"name": "triangle",
"parent": "",
"attributes": []
},
{
"name": "non_square",
"parent": "rectangle",
"attributes": []
},
{
"name": "rectangle",
"parent": "",
"attributes": []
}
],
"label_groups": [
{
"name": "shape",
"group_type": "exclusive",
"labels": ["rectangle", "triangle"]
},
{
"name": "rectangle default",
"group_type": "exclusive",
"labels": ["non_square", "square"]
},
{
"name": "triangle default",
"group_type": "exclusive",
"labels": ["equilateral", "right"]
},
{
"name": "shape___multiple example___multi a",
"group_type": "exclusive",
"labels": ["multi a"]
}
],
"attributes": []
},
"mask": {
"colormap": [
{
"label_id": 0,
"r": 129,
"g": 64,
"b": 123
},
{
"label_id": 1,
"r": 91,
"g": 105,
"b": 255
},
{
"label_id": 2,
"r": 91,
"g": 105,
"b": 255
},
{
"label_id": 3,
"r": 255,
"g": 86,
"b": 98
},
{
"label_id": 4,
"r": 204,
"g": 148,
"b": 218
},
{
"label_id": 5,
"r": 0,
"g": 251,
"b": 87
},
{
"label_id": 6,
"r": 84,
"g": 143,
"b": 173
}
]
}
},
"items": [
{
"id": "a",
"annotations": [
{
"id": 0,
"type": "label",
"attributes": {},
"group": 0,
"label_id": 4
},
{
"id": 0,
"type": "label",
"attributes": {},
"group": 0,
"label_id": 5
},
{
"id": 0,
"type": "label",
"attributes": {},
"group": 0,
"label_id": 1
}
],
"image": {
"path": "a.jpg",
"size": [10, 5]
},
"media": {
"path": ""
}
},
{
"id": "b",
"annotations": [
{
"id": 0,
"type": "label",
"attributes": {},
"group": 0,
"label_id": 6
},
{
"id": 0,
"type": "label",
"attributes": {},
"group": 0,
"label_id": 5
},
{
"id": 0,
"type": "label",
"attributes": {},
"group": 0,
"label_id": 2
}
],
"image": {
"path": "b.jpg",
"size": [10, 5]
},
"media": {
"path": ""
}
}
]
}
Loading