diff --git a/thingsvision/core/extraction/extractors.py b/thingsvision/core/extraction/extractors.py index 1e49f7d..d83f6fe 100644 --- a/thingsvision/core/extraction/extractors.py +++ b/thingsvision/core/extraction/extractors.py @@ -353,7 +353,8 @@ def __init__( def _load_vissl_state_dict(self, model_url: str, unique_model_filename: str): """ Downloads the model in vissl format, converts it to torchvision format and - saves it under output_model_filepath. + caches it under the unique_model_filename. Therefore, this file_name should be unique + per url. Otherwise, the wrong cached variant is loaded. """ model = load_state_dict_from_url(model_url, map_location=torch.device("cpu"), @@ -394,7 +395,7 @@ def load_model_from_source(self) -> None: if self.model_name in SSLExtractor.MODELS: # unique model id name for all models - unique_model_filename = f'thingsvision_ssl_v0_{self.model_name}' + unique_model_filename = f'thingsvision_ssl_v0_{self.model_name}.pth' # defines how the model should be loaded model_config = SSLExtractor.MODELS[self.model_name]