From d85f5da40214ebd836f34061235550dd3154ffbb Mon Sep 17 00:00:00 2001
From: "Kim, Sungchul"
Date: Wed, 23 Oct 2024 17:39:40 +0900
Subject: [PATCH] Fix applying model's hparams when loading model from
 checkpoint (#4057)

---
 CHANGELOG.md                     |  2 ++
 src/otx/engine/engine.py         | 40 ++++++++++++++++++++++++++++----
 tests/unit/engine/test_engine.py |  6 +----
 3 files changed, 38 insertions(+), 10 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e0331d198e8..c15028bc356 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -100,6 +100,8 @@ All notable changes to this project will be documented in this file.
   ()
 - Upgrade MAPI in 2.2
   ()
+- Fix applying model's hparams when loading model from checkpoint
+  ()
 
 ## \[v2.1.0\]
 
diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py
index 3c0addd547b..003b2f29112 100644
--- a/src/otx/engine/engine.py
+++ b/src/otx/engine/engine.py
@@ -367,8 +367,14 @@ def test(
         # NOTE, trainer.test takes only lightning based checkpoint.
         # So, it can't take the OTX1.x checkpoint.
         if checkpoint is not None and not is_ir_ckpt:
+            kwargs_user_input: dict[str, Any] = {}
+            if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
+                # to update user's custom infer_reference_info_root through cli for zero-shot learning
+                # TODO (sungchul): revisit for better solution
+                kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)
+
             model_cls = model.__class__
-            model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **model.hparams)
+            model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input)
 
         if model.label_info != self.datamodule.label_info:
             if (
@@ -462,8 +468,14 @@ def predict(
             datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")
 
         if checkpoint is not None and not is_ir_ckpt:
+            kwargs_user_input: dict[str, Any] = {}
+            if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
+                # to update user's custom infer_reference_info_root through cli for zero-shot learning
+                # TODO (sungchul): revisit for better solution
+                kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)
+
             model_cls = model.__class__
-            model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **model.hparams)
+            model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input)
 
         if model.label_info != self.datamodule.label_info:
             msg = (
@@ -574,11 +586,17 @@ def export(
             )
 
         if not is_ir_ckpt:
+            kwargs_user_input: dict[str, Any] = {}
+            if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
+                # to update user's custom infer_reference_info_root through cli for zero-shot learning
+                # TODO (sungchul): revisit for better solution
+                kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)
+
             model_cls = self.model.__class__
             self.model = model_cls.load_from_checkpoint(
                 checkpoint_path=checkpoint,
                 map_location="cpu",
-                **self.model.hparams,
+                **kwargs_user_input,
             )
             self.model.eval()
 
@@ -742,8 +760,14 @@ def explain(
             model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info)
 
         if checkpoint is not None and not is_ir_ckpt:
+            kwargs_user_input: dict[str, Any] = {}
+            if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
+                # to update user's custom infer_reference_info_root through cli for zero-shot learning
+                # TODO (sungchul): revisit for better solution
+                kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)
+
             model_cls = model.__class__
-            model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **model.hparams)
+            model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input)
 
         if model.label_info != self.datamodule.label_info:
             msg = (
@@ -845,11 +869,17 @@ def benchmark(
             )
 
         if not is_ir_ckpt:
+            kwargs_user_input: dict[str, Any] = {}
+            if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
+                # to update user's custom infer_reference_info_root through cli for zero-shot learning
+                # TODO (sungchul): revisit for better solution
+                kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)
+
             model_cls = self.model.__class__
             self.model = model_cls.load_from_checkpoint(
                 checkpoint_path=checkpoint,
                 map_location="cpu",
-                **self.model.hparams,
+                **kwargs_user_input,
             )
         elif isinstance(self.model, OVModel):
             msg = "To run benchmark on OV model, checkpoint must be specified."
diff --git a/tests/unit/engine/test_engine.py b/tests/unit/engine/test_engine.py
index 879987f19cc..3adcc5678d7 100644
--- a/tests/unit/engine/test_engine.py
+++ b/tests/unit/engine/test_engine.py
@@ -223,11 +223,7 @@ def test_exporting(self, fxt_engine, mocker) -> None:
         checkpoint = "path/to/checkpoint.ckpt"
         fxt_engine.checkpoint = checkpoint
         fxt_engine.export()
-        mock_load_from_checkpoint.assert_called_once_with(
-            checkpoint_path=checkpoint,
-            map_location="cpu",
-            **fxt_engine.model.hparams,
-        )
+        mock_load_from_checkpoint.assert_called_once_with(checkpoint_path=checkpoint, map_location="cpu")
        mock_export.assert_called_once_with(
             output_dir=Path(fxt_engine.work_dir),
             base_name="exported_model",
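
Background note (not part of the patch): Lightning's load_from_checkpoint already restores the hyperparameters that were recorded with save_hyperparameters() at training time, so re-passing **model.hparams is redundant and can clash with the constructor; the patch therefore forwards only explicit user overrides such as infer_reference_info_root. The minimal sketch below illustrates that behavior under these assumptions: ToyModel, toy.ckpt, and the tensor shapes are hypothetical, and the lightning package (pytorch_lightning behaves the same) is installed.

# Minimal sketch: hparams are restored from the checkpoint, overrides are
# passed explicitly. ToyModel and toy.ckpt are hypothetical, not OTX code.
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl


class ToyModel(pl.LightningModule):
    def __init__(self, hidden_dim: int = 16, infer_reference_info_root: str = "./refs") -> None:
        super().__init__()
        # Records both constructor arguments in self.hparams and in every checkpoint.
        self.save_hyperparameters()
        self.layer = nn.Linear(4, hidden_dim)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()  # dummy loss, just enough to make fit() run

    def configure_optimizers(self):
        return optim.SGD(self.parameters(), lr=0.01)


if __name__ == "__main__":
    model = ToyModel(hidden_dim=8)
    loader = DataLoader(TensorDataset(torch.randn(4, 4)), batch_size=2)
    trainer = pl.Trainer(max_epochs=1, limit_train_batches=1, logger=False, enable_checkpointing=False)
    trainer.fit(model, loader)
    trainer.save_checkpoint("toy.ckpt")

    # hidden_dim comes back from the checkpoint automatically; only the user
    # override is passed explicitly, mirroring kwargs_user_input in the patch.
    restored = ToyModel.load_from_checkpoint("toy.ckpt", infer_reference_info_root="/custom/refs")
    print(restored.hparams.hidden_dim)                 # 8, restored from the checkpoint
    print(restored.hparams.infer_reference_info_root)  # "/custom/refs", the explicit override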