diff --git a/src/transformers/models/vit/convert_vit_timm_to_pytorch.py b/src/transformers/models/vit/convert_vit_timm_to_pytorch.py index b73c5f346dba57..0ccd9b9f6685fe 100644 --- a/src/transformers/models/vit/convert_vit_timm_to_pytorch.py +++ b/src/transformers/models/vit/convert_vit_timm_to_pytorch.py @@ -16,14 +16,13 @@ import argparse -import json from pathlib import Path import requests import timm import torch -from huggingface_hub import hf_hub_download from PIL import Image +from timm.data import ImageNetInfo, infer_imagenet_subset from transformers import DeiTImageProcessor, ViTConfig, ViTForImageClassification, ViTImageProcessor, ViTModel from transformers.utils import logging @@ -60,13 +59,11 @@ def create_rename_keys(config, base_model=False): ) if base_model: - # layernorm + pooler + # layernorm rename_keys.extend( [ ("norm.weight", "layernorm.weight"), ("norm.bias", "layernorm.bias"), - ("pre_logits.fc.weight", "pooler.dense.weight"), - ("pre_logits.fc.bias", "pooler.dense.bias"), ] ) @@ -140,60 +137,68 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path): # define default ViT configuration config = ViTConfig() base_model = False - # dataset (ImageNet-21k only or also fine-tuned on ImageNet 2012), patch_size and image_size - if vit_name[-5:] == "in21k": - base_model = True - config.patch_size = int(vit_name[-12:-10]) - config.image_size = int(vit_name[-9:-6]) - else: - config.num_labels = 1000 - repo_id = "huggingface/label-files" - filename = "imagenet-1k-id2label.json" - id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) - id2label = {int(k): v for k, v in id2label.items()} - config.id2label = id2label - config.label2id = {v: k for k, v in id2label.items()} - config.patch_size = int(vit_name[-6:-4]) - config.image_size = int(vit_name[-3:]) - # size of the architecture - if "deit" in vit_name: - if vit_name[9:].startswith("tiny"): - config.hidden_size = 192 - config.intermediate_size = 768 - config.num_hidden_layers = 12 - config.num_attention_heads = 3 - elif vit_name[9:].startswith("small"): - config.hidden_size = 384 - config.intermediate_size = 1536 - config.num_hidden_layers = 12 - config.num_attention_heads = 6 - else: - pass - else: - if vit_name[4:].startswith("small"): - config.hidden_size = 768 - config.intermediate_size = 2304 - config.num_hidden_layers = 8 - config.num_attention_heads = 8 - elif vit_name[4:].startswith("base"): - pass - elif vit_name[4:].startswith("large"): - config.hidden_size = 1024 - config.intermediate_size = 4096 - config.num_hidden_layers = 24 - config.num_attention_heads = 16 - elif vit_name[4:].startswith("huge"): - config.hidden_size = 1280 - config.intermediate_size = 5120 - config.num_hidden_layers = 32 - config.num_attention_heads = 16 # load original model from timm timm_model = timm.create_model(vit_name, pretrained=True) timm_model.eval() - # load state_dict of original model, remove and rename some keys + # detect unsupported ViT models in transformers + # fc_norm is present + if not isinstance(getattr(timm_model, "fc_norm", None), torch.nn.Identity): + raise ValueError(f"{vit_name} is not supported in transformers because of the presence of fc_norm.") + + # use of global average pooling in combination (or without) class token + if getattr(timm_model, "global_pool", None) == "avg": + raise ValueError(f"{vit_name} is not supported in transformers because of use of global average pooling.") + + # CLIP style vit with norm_pre layer present + if "clip" in vit_name and not isinstance(getattr(timm_model, "norm_pre", None), torch.nn.Identity): + raise ValueError( + f"{vit_name} is not supported in transformers because it's a CLIP style ViT with norm_pre layer." + ) + + # SigLIP style vit with attn_pool layer present + if "siglip" in vit_name and getattr(timm_model, "global_pool", None) == "map": + raise ValueError( + f"{vit_name} is not supported in transformers because it's a SigLIP style ViT with attn_pool." + ) + + # use of layer scale in ViT model blocks + if not isinstance(getattr(timm_model.blocks[0], "ls1", None), torch.nn.Identity) or not isinstance( + getattr(timm_model.blocks[0], "ls2", None), torch.nn.Identity + ): + raise ValueError(f"{vit_name} is not supported in transformers because it uses a layer scale in its blocks.") + + # Hybrid ResNet-ViTs + if not isinstance(timm_model.patch_embed, timm.layers.PatchEmbed): + raise ValueError(f"{vit_name} is not supported in transformers because it is a hybrid ResNet-ViT.") + + # get patch size and image size from the patch embedding submodule + config.patch_size = timm_model.patch_embed.patch_size[0] + config.image_size = timm_model.patch_embed.img_size[0] + + # retrieve architecture-specific parameters from the timm model + config.hidden_size = timm_model.embed_dim + config.intermediate_size = timm_model.blocks[0].mlp.fc1.out_features + config.num_hidden_layers = len(timm_model.blocks) + config.num_attention_heads = timm_model.blocks[0].attn.num_heads + + # check whether the model has a classification head or not + if timm_model.num_classes != 0: + config.num_labels = timm_model.num_classes + # infer ImageNet subset from timm model + imagenet_subset = infer_imagenet_subset(timm_model) + dataset_info = ImageNetInfo(imagenet_subset) + config.id2label = {i: dataset_info.index_to_label_name(i) for i in range(dataset_info.num_classes())} + config.label2id = {v: k for k, v in config.id2label.items()} + else: + print(f"{vit_name} is going to be converted as a feature extractor only.") + base_model = True + + # load state_dict of original model state_dict = timm_model.state_dict() + + # remove and rename some keys in the state dict if base_model: remove_classification_head_(state_dict) rename_keys = create_rename_keys(config, base_model) @@ -202,8 +207,8 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path): read_in_q_k_v(state_dict, config, base_model) # load HuggingFace model - if vit_name[-5:] == "in21k": - model = ViTModel(config).eval() + if base_model: + model = ViTModel(config, add_pooling_layer=False).eval() else: model = ViTForImageClassification(config).eval() model.load_state_dict(state_dict) @@ -219,8 +224,8 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path): if base_model: timm_pooled_output = timm_model.forward_features(pixel_values) - assert timm_pooled_output.shape == outputs.pooler_output.shape - assert torch.allclose(timm_pooled_output, outputs.pooler_output, atol=1e-3) + assert timm_pooled_output.shape == outputs.last_hidden_state.shape + assert torch.allclose(timm_pooled_output, outputs.last_hidden_state, atol=1e-1) else: timm_logits = timm_model(pixel_values) assert timm_logits.shape == outputs.logits.shape