From cfd3d8e3a018f5bd00719518401ae2cb34258108 Mon Sep 17 00:00:00 2001 From: Samet Akcay Date: Tue, 20 Aug 2024 12:47:15 +0100 Subject: [PATCH] Refactor Lightning's `trainer.model` to `trainer.lightning_module` (#2255) Refactor trainer.model to trainer.lightning_module --- src/anomalib/callbacks/checkpoint.py | 4 ++-- src/anomalib/data/base/datamodule.py | 8 ++++---- src/anomalib/engine/engine.py | 2 +- tests/unit/metrics/test_adaptive_threshold.py | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/anomalib/callbacks/checkpoint.py b/src/anomalib/callbacks/checkpoint.py index d4af9dfa8e..8947124364 100644 --- a/src/anomalib/callbacks/checkpoint.py +++ b/src/anomalib/callbacks/checkpoint.py @@ -35,7 +35,7 @@ def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool: Overrides the parent method to allow saving during both the ``FITTING`` and ``VALIDATING`` states, and to allow saving when the global step and last_global_step_saved are both 0 (only for zero-/few-shot models). """ - is_zero_or_few_shot = trainer.model.learning_type in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT] + is_zero_or_few_shot = trainer.lightning_module.learning_type in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT] return ( bool(trainer.fast_dev_run) # disable checkpointing with fast_dev_run or trainer.state.fn not in [TrainerFn.FITTING, TrainerFn.VALIDATING] # don't save anything during non-fit @@ -52,7 +52,7 @@ def _should_save_on_train_epoch_end(self, trainer: Trainer) -> bool: if self._save_on_train_epoch_end is not None: return self._save_on_train_epoch_end - if trainer.model.learning_type in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT]: + if trainer.lightning_module.learning_type in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT]: return False return super()._should_save_on_train_epoch_end(trainer) diff --git a/src/anomalib/data/base/datamodule.py b/src/anomalib/data/base/datamodule.py index 23ab10f882..cb95ca8171 100644 --- a/src/anomalib/data/base/datamodule.py +++ b/src/anomalib/data/base/datamodule.py @@ -270,8 +270,8 @@ def train_transform(self) -> Transform: """ if self._train_transform: return self._train_transform - if getattr(self, "trainer", None) and self.trainer.model and self.trainer.model.transform: - return self.trainer.model.transform + if getattr(self, "trainer", None) and self.trainer.lightning_module and self.trainer.lightning_module.transform: + return self.trainer.lightning_module.transform if self.image_size: return Resize(self.image_size, antialias=True) return None @@ -284,8 +284,8 @@ def eval_transform(self) -> Transform: """ if self._eval_transform: return self._eval_transform - if getattr(self, "trainer", None) and self.trainer.model and self.trainer.model.transform: - return self.trainer.model.transform + if getattr(self, "trainer", None) and self.trainer.lightning_module and self.trainer.lightning_module.transform: + return self.trainer.lightning_module.transform if self.image_size: return Resize(self.image_size, antialias=True) return None diff --git a/src/anomalib/engine/engine.py b/src/anomalib/engine/engine.py index 485ab8e66e..8648cf30ae 100644 --- a/src/anomalib/engine/engine.py +++ b/src/anomalib/engine/engine.py @@ -183,7 +183,7 @@ def model(self) -> AnomalyModule: Returns: AnomalyModule: Anomaly model. """ - if not self.trainer.model: + if not self.trainer.lightning_module: msg = "Trainer does not have a model assigned yet." raise UnassignedError(msg) return self.trainer.lightning_module diff --git a/tests/unit/metrics/test_adaptive_threshold.py b/tests/unit/metrics/test_adaptive_threshold.py index 5163720d66..1eadab4e4d 100644 --- a/tests/unit/metrics/test_adaptive_threshold.py +++ b/tests/unit/metrics/test_adaptive_threshold.py @@ -55,5 +55,5 @@ def test_manual_threshold() -> None: devices=1, ) engine.fit(model=model, datamodule=datamodule) - assert engine.trainer.model.image_metrics.F1Score.threshold == image_threshold - assert engine.trainer.model.pixel_metrics.F1Score.threshold == pixel_threshold + assert engine.trainer.lightning_module.image_metrics.F1Score.threshold == image_threshold + assert engine.trainer.lightning_module.pixel_metrics.F1Score.threshold == pixel_threshold