Skip to content

Commit

Permalink
Merge pull request #179 from ViCCo-Group/fix/ssl-models-caching
Browse files Browse the repository at this point in the history
Removes separate VISSL caching and adds file_name to torch.hub.load_state_dict_from_url everywhere
  • Loading branch information
LukasMut authored Sep 11, 2024
2 parents 200e047 + c087a2a commit da9d717
Showing 1 changed file with 49 additions and 25 deletions.
74 changes: 49 additions & 25 deletions thingsvision/core/extraction/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,14 +234,14 @@ class SSLExtractor(PyTorchExtractor):
"type": "vissl",
},
"barlowtwins-rn50": {
"repository": "facebookresearch/barlowtwins:main",
"arch": "resnet50",
"type": "hub",
"type": "checkpoint_url",
"checkpoint_url": "https://dl.fbaipublicfiles.com/barlowtwins/ljng/resnet50.pth"
},
"vicreg-rn50": {
"repository": "facebookresearch/vicreg:main",
"arch": "resnet50",
"type": "hub",
"type": "checkpoint_url",
"checkpoint_url": "https://dl.fbaipublicfiles.com/vicreg/resnet50.pth"
},
"dino-vit-small-p16": {
"repository": "facebookresearch/dino:main",
Expand Down Expand Up @@ -350,15 +350,15 @@ def __init__(
device=device,
)

def _download_and_save_model(self, model_url: str,
output_model_filepath: str, unique_model_id: str):
def _load_vissl_state_dict(self, model_url: str, unique_model_filename: str):
"""
Downloads the model in vissl format, converts it to torchvision format and
saves it under output_model_filepath.
caches it under unique_model_filename. This file name must therefore be unique
per URL; otherwise, the wrong cached variant is loaded.
"""
model = load_state_dict_from_url(model_url,
map_location=torch.device("cpu"),
file_name=f'{unique_model_id}.pt')
file_name=unique_model_filename)

# get the model trunk to rename
if "classy_state_dict" in model.keys():
Expand All @@ -369,11 +369,10 @@ def _download_and_save_model(self, model_url: str,
model_trunk = model

converted_model = self._replace_module_prefix(model_trunk, "_feature_blocks.")
torch.save(converted_model, output_model_filepath)
return converted_model

def _replace_module_prefix(
self, state_dict: Dict[str, Any], prefix: str, replace_with: str = ""
self, state_dict: Dict[str, Any], prefix: str, replace_with: str = ""
):
"""
Remove prefixes in a state_dict needed when loading models that are not VISSL
Expand All @@ -394,25 +393,25 @@ def load_model_from_source(self) -> None:
Otherwise, loads it from the cache directory.
"""
if self.model_name in SSLExtractor.MODELS:

# unique model id name for all models
unique_model_filename = f'thingsvision_ssl_v0_{self.model_name}.pth'

# defines how the model should be loaded
model_config = SSLExtractor.MODELS[self.model_name]

# VISSL MODELS
if model_config["type"] == "vissl":
cache_dir = os.path.join(get_torch_home(), "vissl")
model_filepath = os.path.join(cache_dir, self.model_name + ".torch")
if not os.path.exists(model_filepath):
os.makedirs(cache_dir, exist_ok=True)
model_state_dict = self._download_and_save_model(
model_url=model_config["url"],
output_model_filepath=model_filepath,
unique_model_id=f'thingsvision_vissl_{self.model_name}'
)
else:
model_state_dict = torch.load(
model_filepath, map_location=torch.device("cpu")
)
model_state_dict = self._load_vissl_state_dict(
model_url=model_config["url"],
unique_model_filename=unique_model_filename
)
self.model = getattr(torchvision.models, model_config["arch"])()
if model_config["arch"] == "resnet50":
self.model.fc = torch.nn.Identity()
self.model.load_state_dict(model_state_dict, strict=True)

# HUB MODELS
elif model_config["type"] == "hub":
if self.model_name.startswith("dino-vit"):
if self.model_name == "dino-vit-tiny-p8":
Expand All @@ -430,7 +429,10 @@ def load_model_from_source(self) -> None:
else:
raise ValueError(f"\n{self.model_name} is not available.\n")
state_dict = torch.hub.load_state_dict_from_url(
model_config["checkpoint_url"]
model_config["checkpoint_url"],
map_location=torch.device("cpu"),
# This is used to cache the file
file_name=unique_model_filename
)
model.load_state_dict(state_dict, strict=True)
self.model = model
Expand All @@ -444,7 +446,9 @@ def load_model_from_source(self) -> None:
else:
raise ValueError(f"\n{self.model_name} is not available.\n")
state_dict = torch.hub.load_state_dict_from_url(
model_config["checkpoint_url"]
model_config["checkpoint_url"],
map_location=torch.device("cpu"),
file_name=unique_model_filename
)
checkpoint_model = state_dict["model"]
# interpolate position embedding
Expand All @@ -457,6 +461,26 @@ def load_model_from_source(self) -> None:
)
if model_config["arch"] == "resnet50":
self.model.fc = torch.nn.Identity()

# MODELS FROM CHECKPOINT URL
elif model_config["type"] == "checkpoint_url":

# load architecture
self.model = getattr(torchvision.models, model_config["arch"])()
if model_config["arch"] == "resnet50":
self.model.fc = torch.nn.Identity()

# load and cache state_dict
state_dict = torch.hub.load_state_dict_from_url(
model_config["checkpoint_url"],
map_location=torch.device("cpu"),
# IMPORTANT: this must be unique because it is used as the cache file name
file_name=unique_model_filename
)

# load state dict to model
self.model.load_state_dict(state_dict, strict=True)

else:
type = model_config["type"]
raise ValueError(f"\nUnknown model type: {type}.\n")
Expand Down

0 comments on commit da9d717

Please sign in to comment.