diff --git a/src/otx/core/model/detection.py b/src/otx/core/model/detection.py
index 0246ebc0959..53d45a474d7 100644
--- a/src/otx/core/model/detection.py
+++ b/src/otx/core/model/detection.py
@@ -137,7 +137,7 @@ def _convert_pred_entity_to_compute_metric(
             ],
         }
 
-    def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
+    def on_load_checkpoint(self, ckpt: dict[str, Any]) -> None:
         """Load state_dict from checkpoint.
 
         For detection, it is need to update confidence threshold information when
@@ -148,7 +148,7 @@ def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
             and (best_confidence_threshold := hyper_parameters.get("best_confidence_threshold", None))
         ):
             self.hparams["best_confidence_threshold"] = best_confidence_threshold
-        super().load_state_dict(ckpt, *args, **kwargs)
+        super().on_load_checkpoint(ckpt)
 
     def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs) -> None:
         if key == "val":
@@ -539,7 +539,7 @@ def _create_model(self) -> Model:
 
         if model_adapter.model.has_rt_info(["model_info", "confidence_threshold"]):
             best_confidence_threshold = model_adapter.model.get_rt_info(["model_info", "confidence_threshold"]).value
-            self.hparams["best_confidence_threshold"] = best_confidence_threshold
+            self.hparams["best_confidence_threshold"] = float(best_confidence_threshold)
         else:
             msg = (
                 "Cannot get best_confidence_threshold from OpenVINO IR's rt_info. "
diff --git a/src/otx/core/model/instance_segmentation.py b/src/otx/core/model/instance_segmentation.py
index 58ca328bc95..8cea389d233 100644
--- a/src/otx/core/model/instance_segmentation.py
+++ b/src/otx/core/model/instance_segmentation.py
@@ -113,12 +113,12 @@ def _export_parameters(self) -> TaskLevelExportParameters:
         return super()._export_parameters.wrap(
             model_type="MaskRCNN",
             task_type="instance_segmentation",
-            confidence_threshold=self.hparams.get("best_confidence_threshold", 0.0),
+            confidence_threshold=self.hparams.get("best_confidence_threshold", None),
             iou_threshold=0.5,
             tile_config=self.tile_config if self.tile_config.enable_tiler else None,
         )
 
-    def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
+    def on_load_checkpoint(self, ckpt: dict[str, Any]) -> None:
         """Load state_dict from checkpoint.
 
         For detection, it is need to update confidence threshold information when
@@ -129,7 +129,7 @@ def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
             and (best_confidence_threshold := hyper_parameters.get("best_confidence_threshold", None))
         ):
             self.hparams["best_confidence_threshold"] = best_confidence_threshold
-        super().load_state_dict(ckpt, *args, **kwargs)
+        super().on_load_checkpoint(ckpt)
 
     def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs) -> None:
         if key == "val":
@@ -597,16 +597,16 @@ def _create_model(self) -> Model:
 
         if model_adapter.model.has_rt_info(["model_info", "confidence_threshold"]):
             best_confidence_threshold = model_adapter.model.get_rt_info(["model_info", "confidence_threshold"]).value
-            self.hparams["best_confidence_threshold"] = best_confidence_threshold
+            self.hparams["best_confidence_threshold"] = float(best_confidence_threshold)
         else:
             msg = (
                 "Cannot get best_confidence_threshold from OpenVINO IR's rt_info. "
                 "Please check whether this model is trained by OTX or not. "
                 "Without this information, it can produce a wrong F1 metric score. "
-                "At this time, it will be set as the default value = 0.0."
+                "At this time, it will be set as the default value = None."
             )
             log.warning(msg)
-            self.hparams["best_confidence_threshold"] = 0.0
+            self.hparams["best_confidence_threshold"] = None
 
         return Model.create_model(model_adapter, model_type=self.model_type, configuration=self.model_api_configuration)
@@ -729,6 +729,6 @@ def _convert_pred_entity_to_compute_metric(
         return {"preds": pred_info, "target": target_info}
 
     def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwargs) -> None:
-        best_confidence_threshold = self.hparams.get("best_confidence_threshold", 0.0)
+        best_confidence_threshold = self.hparams.get("best_confidence_threshold", None)
         compute_kwargs = {"best_confidence_threshold": best_confidence_threshold}
         return super()._log_metrics(meter, key, **compute_kwargs)
diff --git a/tests/unit/core/model/test_detection.py b/tests/unit/core/model/test_detection.py
index 7ef81129cf1..61fac037603 100644
--- a/tests/unit/core/model/test_detection.py
+++ b/tests/unit/core/model/test_detection.py
@@ -72,7 +72,7 @@ def test_configure_metric_with_ckpt(
         metric=FMeasureCallable,
     )
 
-    model.load_state_dict(mock_ckpt)
+    model.on_load_checkpoint(mock_ckpt)
 
     assert model.hparams["best_confidence_threshold"] == 0.35
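
Note on the load_state_dict -> on_load_checkpoint move: Lightning's on_load_checkpoint hook is handed the full checkpoint dict, including the "hyper_parameters" entry, whereas a load_state_dict override only sees the unpacked weights, so the hook is the natural place to restore best_confidence_threshold. A minimal sketch of the pattern (not OTX code; the class name and placeholder layer are illustrative, assuming lightning >= 2.0):

from __future__ import annotations

from typing import Any

from lightning import LightningModule
from torch import nn


class SketchDetectionModel(LightningModule):
    """Illustrative stand-in for the OTX model classes touched by the diff."""

    def __init__(self, best_confidence_threshold: float | None = None) -> None:
        super().__init__()
        self.save_hyperparameters()  # populates self.hparams used below
        self.layer = nn.Linear(1, 1)  # placeholder weights

    def on_load_checkpoint(self, ckpt: dict[str, Any]) -> None:
        # The hook receives the whole checkpoint dict, so checkpoint-level
        # metadata such as "hyper_parameters" is visible alongside "state_dict".
        if (hyper_parameters := ckpt.get("hyper_parameters", None)) and (
            best_confidence_threshold := hyper_parameters.get("best_confidence_threshold", None)
        ):
            self.hparams["best_confidence_threshold"] = best_confidence_threshold
        super().on_load_checkpoint(ckpt)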
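
Note on the float(...) cast and the None fallback in the OpenVINO branch: rt_info values round-tripped through IR XML typically come back as strings (e.g. "0.35"), so storing the raw .value would put a str into hparams; and None, unlike the previous 0.0 default, keeps "threshold unknown" distinguishable from a genuine zero threshold when the F1 metric is computed. A sketch of that read path outside OTX (illustrative model path, assuming openvino >= 2023.1):

import openvino as ov

core = ov.Core()
model = core.read_model("exported_model.xml")  # illustrative path

# rt_info deserialized from IR XML is typically a string (e.g. "0.35"),
# hence the explicit float() cast before storing or comparing it.
if model.has_rt_info(["model_info", "confidence_threshold"]):
    best_confidence_threshold = float(model.get_rt_info(["model_info", "confidence_threshold"]).value)
else:
    best_confidence_threshold = None  # "unknown", distinguishable from a real 0.0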