Skip to content

Commit

Permalink
Change load_state_dict to on_load_checkpoint (#3443)
Browse files Browse the repository at this point in the history
  • Loading branch information
jaegukhyun authored May 6, 2024
1 parent c21e26e commit ad1a5ec
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
6 changes: 3 additions & 3 deletions src/otx/core/model/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _convert_pred_entity_to_compute_metric(
],
}

def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
def on_load_checkpoint(self, ckpt: dict[str, Any]) -> None:
"""Load state_dict from checkpoint.
For detection, it is need to update confidence threshold information when
Expand All @@ -148,7 +148,7 @@ def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
and (best_confidence_threshold := hyper_parameters.get("best_confidence_threshold", None))
):
self.hparams["best_confidence_threshold"] = best_confidence_threshold
super().load_state_dict(ckpt, *args, **kwargs)
super().on_load_checkpoint(ckpt)

def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs) -> None:
if key == "val":
Expand Down Expand Up @@ -539,7 +539,7 @@ def _create_model(self) -> Model:

if model_adapter.model.has_rt_info(["model_info", "confidence_threshold"]):
best_confidence_threshold = model_adapter.model.get_rt_info(["model_info", "confidence_threshold"]).value
self.hparams["best_confidence_threshold"] = best_confidence_threshold
self.hparams["best_confidence_threshold"] = float(best_confidence_threshold)
else:
msg = (
"Cannot get best_confidence_threshold from OpenVINO IR's rt_info. "
Expand Down
14 changes: 7 additions & 7 deletions src/otx/core/model/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,12 @@ def _export_parameters(self) -> TaskLevelExportParameters:
return super()._export_parameters.wrap(
model_type="MaskRCNN",
task_type="instance_segmentation",
confidence_threshold=self.hparams.get("best_confidence_threshold", 0.0),
confidence_threshold=self.hparams.get("best_confidence_threshold", None),
iou_threshold=0.5,
tile_config=self.tile_config if self.tile_config.enable_tiler else None,
)

def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
def on_load_checkpoint(self, ckpt: dict[str, Any]) -> None:
"""Load state_dict from checkpoint.
For detection, it is need to update confidence threshold information when
Expand All @@ -129,7 +129,7 @@ def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
and (best_confidence_threshold := hyper_parameters.get("best_confidence_threshold", None))
):
self.hparams["best_confidence_threshold"] = best_confidence_threshold
super().load_state_dict(ckpt, *args, **kwargs)
super().on_load_checkpoint(ckpt)

def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs) -> None:
if key == "val":
Expand Down Expand Up @@ -597,16 +597,16 @@ def _create_model(self) -> Model:

if model_adapter.model.has_rt_info(["model_info", "confidence_threshold"]):
best_confidence_threshold = model_adapter.model.get_rt_info(["model_info", "confidence_threshold"]).value
self.hparams["best_confidence_threshold"] = best_confidence_threshold
self.hparams["best_confidence_threshold"] = float(best_confidence_threshold)
else:
msg = (
"Cannot get best_confidence_threshold from OpenVINO IR's rt_info. "
"Please check whether this model is trained by OTX or not. "
"Without this information, it can produce a wrong F1 metric score. "
"At this time, it will be set as the default value = 0.0."
"At this time, it will be set as the default value = None."
)
log.warning(msg)
self.hparams["best_confidence_threshold"] = 0.0
self.hparams["best_confidence_threshold"] = None

return Model.create_model(model_adapter, model_type=self.model_type, configuration=self.model_api_configuration)

Expand Down Expand Up @@ -729,6 +729,6 @@ def _convert_pred_entity_to_compute_metric(
return {"preds": pred_info, "target": target_info}

def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs) -> None:
best_confidence_threshold = self.hparams.get("best_confidence_threshold", 0.0)
best_confidence_threshold = self.hparams.get("best_confidence_threshold", None)
compute_kwargs = {"best_confidence_threshold": best_confidence_threshold}
return super()._log_metrics(meter, key, **compute_kwargs)
2 changes: 1 addition & 1 deletion tests/unit/core/model/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_configure_metric_with_ckpt(
metric=FMeasureCallable,
)

model.load_state_dict(mock_ckpt)
model.on_load_checkpoint(mock_ckpt)

assert model.hparams["best_confidence_threshold"] == 0.35

Expand Down

0 comments on commit ad1a5ec

Please sign in to comment.