From 1ed2469b2ab2ce364d846ac75d575ca14cb5dfc6 Mon Sep 17 00:00:00 2001
From: staghado
Date: Wed, 18 Oct 2023 16:47:44 +0200
Subject: [PATCH 1/5] timm to pytorch conversion for vit model fix

---
 .../models/vit/convert_vit_timm_to_pytorch.py | 88 ++++++++-----------
 1 file changed, 35 insertions(+), 53 deletions(-)

diff --git a/src/transformers/models/vit/convert_vit_timm_to_pytorch.py b/src/transformers/models/vit/convert_vit_timm_to_pytorch.py
index b73c5f346dba57..c7df0356f2c984 100644
--- a/src/transformers/models/vit/convert_vit_timm_to_pytorch.py
+++ b/src/transformers/models/vit/convert_vit_timm_to_pytorch.py
@@ -60,13 +60,13 @@ def create_rename_keys(config, base_model=False):
     )

     if base_model:
-        # layernorm + pooler
+        # layernorm
         rename_keys.extend(
             [
                 ("norm.weight", "layernorm.weight"),
                 ("norm.bias", "layernorm.bias"),
-                ("pre_logits.fc.weight", "pooler.dense.weight"),
-                ("pre_logits.fc.bias", "pooler.dense.bias"),
+                # ("pre_logits.fc.weight", "pooler.dense.weight"),
+                # ("pre_logits.fc.bias", "pooler.dense.bias"),
             ]
         )

@@ -140,60 +140,39 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
     # define default ViT configuration
     config = ViTConfig()
     base_model = False
-    # dataset (ImageNet-21k only or also fine-tuned on ImageNet 2012), patch_size and image_size
-    if vit_name[-5:] == "in21k":
-        base_model = True
-        config.patch_size = int(vit_name[-12:-10])
-        config.image_size = int(vit_name[-9:-6])
-    else:
-        config.num_labels = 1000
+
+    # load original model from timm
+    timm_model = timm.create_model(vit_name, pretrained=True)
+    timm_model.eval()
+
+    # get patch size and image size from the patch embedding submodule
+    config.patch_size = timm_model.patch_embed.patch_size[0]
+    config.image_size = timm_model.patch_embed.img_size[0]
+
+    # retrieve architecture-specific parameters from the timm model
+    config.hidden_size = timm_model.embed_dim
+    config.intermediate_size = timm_model.blocks[0].mlp.fc1.out_features
+    config.num_hidden_layers = len(timm_model.blocks)
+    config.num_attention_heads = timm_model.blocks[0].attn.num_heads
+
+    # check whether the model has a classification head or not
+    if timm_model.num_classes != 0:
+        config.num_labels = timm_model.num_classes
         repo_id = "huggingface/label-files"
-        filename = "imagenet-1k-id2label.json"
+        # .__ceil__() avoids having to import math
+        filename = f"imagenet-{(config.num_labels / 1000).__ceil__()}k-id2label.json"
         id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
         id2label = {int(k): v for k, v in id2label.items()}
         config.id2label = id2label
         config.label2id = {v: k for k, v in id2label.items()}
-        config.patch_size = int(vit_name[-6:-4])
-        config.image_size = int(vit_name[-3:])
-    # size of the architecture
-    if "deit" in vit_name:
-        if vit_name[9:].startswith("tiny"):
-            config.hidden_size = 192
-            config.intermediate_size = 768
-            config.num_hidden_layers = 12
-            config.num_attention_heads = 3
-        elif vit_name[9:].startswith("small"):
-            config.hidden_size = 384
-            config.intermediate_size = 1536
-            config.num_hidden_layers = 12
-            config.num_attention_heads = 6
-        else:
-            pass
     else:
-        if vit_name[4:].startswith("small"):
-            config.hidden_size = 768
-            config.intermediate_size = 2304
-            config.num_hidden_layers = 8
-            config.num_attention_heads = 8
-        elif vit_name[4:].startswith("base"):
-            pass
-        elif vit_name[4:].startswith("large"):
-            config.hidden_size = 1024
-            config.intermediate_size = 4096
-            config.num_hidden_layers = 24
-            config.num_attention_heads = 16
-        elif vit_name[4:].startswith("huge"):
-            config.hidden_size = 1280
-            config.intermediate_size = 5120
-            config.num_hidden_layers = 32
-            config.num_attention_heads = 16
-
-    # load original model from timm
-    timm_model = timm.create_model(vit_name, pretrained=True)
-    timm_model.eval()
+        print(f"{vit_name} is going to be converted as a feature extractor only. This is not guaranteed to work.")
+        base_model = True

-    # load state_dict of original model, remove and rename some keys
+    # load state_dict of original model
     state_dict = timm_model.state_dict()
+
+    # remove and rename some keys in the state dict
     if base_model:
         remove_classification_head_(state_dict)
     rename_keys = create_rename_keys(config, base_model)
@@ -202,8 +181,9 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
     read_in_q_k_v(state_dict, config, base_model)

     # load HuggingFace model
-    if vit_name[-5:] == "in21k":
-        model = ViTModel(config).eval()
+    if base_model:
+        model = ViTModel(config, add_pooling_layer=False).eval()
+        # print(model.state_dict().keys())
     else:
         model = ViTForImageClassification(config).eval()
     model.load_state_dict(state_dict)
@@ -219,8 +199,10 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):

     if base_model:
         timm_pooled_output = timm_model.forward_features(pixel_values)
-        assert timm_pooled_output.shape == outputs.pooler_output.shape
-        assert torch.allclose(timm_pooled_output, outputs.pooler_output, atol=1e-3)
+        print(timm_pooled_output)
+        print(outputs.last_hidden_state)
+        assert timm_pooled_output.shape == outputs.last_hidden_state.shape
+        assert torch.allclose(timm_pooled_output, outputs.last_hidden_state, atol=1e-1)
     else:
         timm_logits = timm_model(pixel_values)
         assert timm_logits.shape == outputs.logits.shape

From 330eaf2e3d8ab8666f5c3eaa617c38f7fec93eca Mon Sep 17 00:00:00 2001
From: staghado
Date: Wed, 18 Oct 2023 17:10:46 +0200
Subject: [PATCH 2/5] remove unnecessary print statements

---
 src/transformers/models/vit/convert_vit_timm_to_pytorch.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/src/transformers/models/vit/convert_vit_timm_to_pytorch.py b/src/transformers/models/vit/convert_vit_timm_to_pytorch.py
index c7df0356f2c984..81e522b39ef61a 100644
--- a/src/transformers/models/vit/convert_vit_timm_to_pytorch.py
+++ b/src/transformers/models/vit/convert_vit_timm_to_pytorch.py
@@ -183,7 +183,6 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
     # load HuggingFace model
     if base_model:
         model = ViTModel(config, add_pooling_layer=False).eval()
-        # print(model.state_dict().keys())
     else:
         model = ViTForImageClassification(config).eval()
     model.load_state_dict(state_dict)
@@ -199,8 +198,6 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):

     if base_model:
         timm_pooled_output = timm_model.forward_features(pixel_values)
-        print(timm_pooled_output)
-        print(outputs.last_hidden_state)
         assert timm_pooled_output.shape == outputs.last_hidden_state.shape
         assert torch.allclose(timm_pooled_output, outputs.last_hidden_state, atol=1e-1)
     else:
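The two patches above replace the old name-based guessing of the architecture with values read directly off the instantiated timm model. A minimal sketch of that idea outside the conversion script follows; it assumes a standard timm ViT checkpoint, and vit_base_patch16_224 is used purely as an example name:

import timm
from transformers import ViTConfig

# instantiate the timm model; pretrained weights are not needed just to inspect the architecture
timm_model = timm.create_model("vit_base_patch16_224", pretrained=False).eval()

# derive the ViTConfig fields from the model itself instead of parsing its name
config = ViTConfig(
    patch_size=timm_model.patch_embed.patch_size[0],
    image_size=timm_model.patch_embed.img_size[0],
    hidden_size=timm_model.embed_dim,
    intermediate_size=timm_model.blocks[0].mlp.fc1.out_features,
    num_hidden_layers=len(timm_model.blocks),
    num_attention_heads=timm_model.blocks[0].attn.num_heads,
    num_labels=timm_model.num_classes,
)
print(config)
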
From d5ce0e82085de2b8a327d5b11ca4d3fcc6e76b9e Mon Sep 17 00:00:00 2001
From: staghado
Date: Wed, 25 Oct 2023 16:25:29 +0200
Subject: [PATCH 3/5] Detect non-supported ViTs in transformers & better handle id2label mapping

---
 .../models/vit/convert_vit_timm_to_pytorch.py | 48 ++++++++++++++-----
 1 file changed, 36 insertions(+), 12 deletions(-)

diff --git a/src/transformers/models/vit/convert_vit_timm_to_pytorch.py b/src/transformers/models/vit/convert_vit_timm_to_pytorch.py
index 81e522b39ef61a..ab82a9c5907adb 100644
--- a/src/transformers/models/vit/convert_vit_timm_to_pytorch.py
+++ b/src/transformers/models/vit/convert_vit_timm_to_pytorch.py
@@ -16,14 +16,13 @@

 import argparse
-import json
 from pathlib import Path

 import requests
 import timm
 import torch
-from huggingface_hub import hf_hub_download
 from PIL import Image
+from timm.data import ImageNetInfo, infer_imagenet_subset

 from transformers import DeiTImageProcessor, ViTConfig, ViTForImageClassification, ViTImageProcessor, ViTModel
 from transformers.utils import logging

@@ -65,8 +64,6 @@ def create_rename_keys(config, base_model=False):
             [
                 ("norm.weight", "layernorm.weight"),
                 ("norm.bias", "layernorm.bias"),
-                # ("pre_logits.fc.weight", "pooler.dense.weight"),
-                # ("pre_logits.fc.bias", "pooler.dense.bias"),
             ]
         )

@@ -145,6 +142,35 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
     timm_model = timm.create_model(vit_name, pretrained=True)
     timm_model.eval()

+    # detect unsupported ViT models in transformers
+    # fc_norm is present
+    if not isinstance(getattr(timm_model, "fc_norm", None), torch.nn.Identity):
+        raise ValueError(f"{vit_name} is not supported in transformers because of the presence of fc_norm.")
+
+    # use of global average pooling in combination (or without) class token
+    if getattr(timm_model, "global_pool", None) == "avg":
+        raise ValueError(f"{vit_name} is not supported in transformers because of use of global average pooling.")
+
+    # CLIP style vit with norm_pre layer present
+    if "clip" in vit_name and not isinstance(getattr(timm_model, "norm_pre", None), torch.nn.Identity):
+        raise ValueError(
+            f"{vit_name} is not supported in transformers because it's a CLIP style ViT with norm_pre layer."
+        )
+
+    # SigLIP style vit with attn_pool layer present
+    if "siglip" in vit_name and getattr(timm_model, "global_pool", None) == "map":
+        raise ValueError(
+            f"{vit_name} is not supported in transformers because it's a SigLIP style ViT with attn_pool."
+        )
+
+    # use of layer scale in ViT model blocks
+    if not isinstance(getattr(timm_model.blocks[0], "ls1", None), torch.nn.Identity) or not isinstance(
+        getattr(timm_model.blocks[0], "ls2", None), torch.nn.Identity
+    ):
+        raise ValueError(f"{vit_name} is not supported in transformers because it uses a layer scale in its blocks.")
+
+    # non-overlapping position and class token embedding (to be added)
+
     # get patch size and image size from the patch embedding submodule
     config.patch_size = timm_model.patch_embed.patch_size[0]
     config.image_size = timm_model.patch_embed.img_size[0]
@@ -158,15 +184,13 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
     # check whether the model has a classification head or not
     if timm_model.num_classes != 0:
         config.num_labels = timm_model.num_classes
-        repo_id = "huggingface/label-files"
-        # .__ceil__() avoids having to import math
-        filename = f"imagenet-{(config.num_labels / 1000).__ceil__()}k-id2label.json"
-        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
-        id2label = {int(k): v for k, v in id2label.items()}
-        config.id2label = id2label
-        config.label2id = {v: k for k, v in id2label.items()}
+        # infer ImageNet subset from timm model
+        imagenet_subset = infer_imagenet_subset(timm_model)
+        dataset_info = ImageNetInfo(imagenet_subset)
+        config.id2label = {i: dataset_info.index_to_label_name(i) for i in range(dataset_info.num_classes())}
+        config.label2id = {v: k for k, v in config.id2label.items()}
     else:
-        print(f"{vit_name} is going to be converted as a feature extractor only. This is not guaranteed to work.")
+        print(f"{vit_name} is going to be converted as a feature extractor only.")
         base_model = True

     # load state_dict of original model

From fcedbfde8345db1788e3afd9a61f171f98605e97 Mon Sep 17 00:00:00 2001
From: staghado
Date: Sun, 5 Nov 2023 18:20:12 +0100
Subject: [PATCH 4/5] detect non supported hybrid resnet-vit models in conversion script

---
 src/transformers/models/vit/convert_vit_timm_to_pytorch.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/transformers/models/vit/convert_vit_timm_to_pytorch.py b/src/transformers/models/vit/convert_vit_timm_to_pytorch.py
index ab82a9c5907adb..15c383473eda0e 100644
--- a/src/transformers/models/vit/convert_vit_timm_to_pytorch.py
+++ b/src/transformers/models/vit/convert_vit_timm_to_pytorch.py
@@ -169,6 +169,10 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
     ):
         raise ValueError(f"{vit_name} is not supported in transformers because it uses a layer scale in its blocks.")

+    # Hybrid ResNet-ViTs
+    if not isinstance(timm_model.patch_embed, timm.layers.PatchEmbed):
+        raise ValueError(f"{vit_name} is not supported in transformers because it is a hybrid ResNet-ViT.")
+
     # non-overlapping position and class token embedding (to be added)

     # get patch size and image size from the patch embedding submodule

From 64486191b066d8d5c21400e4b9f935127fd6b69f Mon Sep 17 00:00:00 2001
From: staghado
Date: Tue, 7 Nov 2023 10:10:28 +0100
Subject: [PATCH 5/5] remove check for overlap between cls token and pos embed

---
 src/transformers/models/vit/convert_vit_timm_to_pytorch.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/transformers/models/vit/convert_vit_timm_to_pytorch.py b/src/transformers/models/vit/convert_vit_timm_to_pytorch.py
index 15c383473eda0e..0ccd9b9f6685fe 100644
--- a/src/transformers/models/vit/convert_vit_timm_to_pytorch.py
+++ b/src/transformers/models/vit/convert_vit_timm_to_pytorch.py
@@ -173,8 +173,6 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
     if not isinstance(timm_model.patch_embed, timm.layers.PatchEmbed):
         raise ValueError(f"{vit_name} is not supported in transformers because it is a hybrid ResNet-ViT.")

-    # non-overlapping position and class token embedding (to be added)
-
     # get patch size and image size from the patch embedding submodule
     config.patch_size = timm_model.patch_embed.patch_size[0]
     config.image_size = timm_model.patch_embed.img_size[0]
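Once the whole series is applied, a converted checkpoint can be sanity-checked against timm along the same lines as the asserts in the script. A rough sketch follows; the checkpoint name vit_base_patch16_224 and the dump folder ./vit-base-converted are placeholders for whatever was passed to the conversion script, not values taken from the patches:

import timm
import torch
from transformers import ViTForImageClassification

# load the original timm model and the converted Hugging Face model
timm_model = timm.create_model("vit_base_patch16_224", pretrained=True).eval()
hf_model = ViTForImageClassification.from_pretrained("./vit-base-converted").eval()

# compare logits on the same (random) input tensor
pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    timm_logits = timm_model(pixel_values)
    hf_logits = hf_model(pixel_values).logits

assert timm_logits.shape == hf_logits.shape
assert torch.allclose(timm_logits, hf_logits, atol=1e-3)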