diff --git a/src/anomalib/models/image/efficient_ad/lightning_model.py b/src/anomalib/models/image/efficient_ad/lightning_model.py index 0f938d6038..9d1e23a06c 100644 --- a/src/anomalib/models/image/efficient_ad/lightning_model.py +++ b/src/anomalib/models/image/efficient_ad/lightning_model.py @@ -63,7 +63,7 @@ def __init__( self, imagenet_dir: Path | str = "./datasets/imagenette", teacher_out_channels: int = 384, - model_size: EfficientAdModelSize = EfficientAdModelSize.S, + model_size: EfficientAdModelSize | str = EfficientAdModelSize.S, lr: float = 0.0001, weight_decay: float = 0.00001, padding: bool = False, @@ -72,24 +72,27 @@ def __init__( super().__init__() self.imagenet_dir = Path(imagenet_dir) - self.model_size = model_size + if not isinstance(model_size, EfficientAdModelSize): + model_size = EfficientAdModelSize(model_size) + self.model_size: EfficientAdModelSize = model_size self.model: EfficientAdModel = EfficientAdModel( teacher_out_channels=teacher_out_channels, model_size=model_size, padding=padding, pad_maps=pad_maps, ) - self.batch_size = 1 # imagenet dataloader batch_size is 1 according to the paper - self.lr = lr - self.weight_decay = weight_decay + self.batch_size: int = 1 # imagenet dataloader batch_size is 1 according to the paper + self.lr: float = lr + self.weight_decay: float = weight_decay def prepare_pretrained_model(self) -> None: """Prepare the pretrained teacher model.""" pretrained_models_dir = Path("./pre_trained/") if not (pretrained_models_dir / "efficientad_pretrained_weights").is_dir(): download_and_extract(pretrained_models_dir, WEIGHTS_DOWNLOAD_INFO) + model_size_str = self.model_size.value if isinstance(self.model_size, EfficientAdModelSize) else self.model_size teacher_path = ( - pretrained_models_dir / "efficientad_pretrained_weights" / f"pretrained_teacher_{self.model_size.value}.pth" + pretrained_models_dir / "efficientad_pretrained_weights" / f"pretrained_teacher_{model_size_str}.pth" ) logger.info(f"Load pretrained teacher model from {teacher_path}") self.model.teacher.load_state_dict(torch.load(teacher_path, map_location=torch.device(self.device)))