Skip to content

Commit

Permalink
Better convert. (#384)
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil authored Nov 17, 2023
1 parent 7faab77 commit 1799438
Showing 1 changed file with 115 additions and 126 deletions.
241 changes: 115 additions & 126 deletions bindings/python/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@
import os
import shutil
from collections import defaultdict
from inspect import signature
from tempfile import TemporaryDirectory
from typing import Dict, List, Optional, Set, Tuple

import torch

from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
from huggingface_hub.file_download import repo_folder_name
from safetensors.torch import load_file, save_file
from transformers import AutoConfig
from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete


COMMIT_DESCRIPTION = """
Expand All @@ -34,20 +32,78 @@

ConversionResult = Tuple[List["CommitOperationAdd"], List[Tuple[str, "Exception"]]]

def _remove_duplicate_names(
state_dict: Dict[str, torch.Tensor],
*,
preferred_names: List[str] = None,
discard_names: List[str] = None,
) -> Dict[str, List[str]]:
if preferred_names is None:
preferred_names = []
preferred_names = set(preferred_names)
if discard_names is None:
discard_names = []
discard_names = set(discard_names)

shareds = _find_shared_tensors(state_dict)
to_remove = defaultdict(list)
for shared in shareds:
complete_names = set(
[name for name in shared if _is_complete(state_dict[name])]
)
if not complete_names:
if len(shared) == 1:
# Force contiguous
name = list(shared)[0]
state_dict[name] = state_dict[name].clone()
complete_names = {name}
else:
raise RuntimeError(
f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
)

class AlreadyExists(Exception):
pass
keep_name = sorted(list(complete_names))[0]

# Mecanism to preferentially select keys to keep
# coming from the on-disk file to allow
# loading models saved with a different choice
# of keep_name
preferred = complete_names.difference(discard_names)
if preferred:
keep_name = sorted(list(preferred))[0]

if preferred_names:
preferred = preferred_names.intersection(complete_names)
if preferred:
keep_name = sorted(list(preferred))[0]
for name in sorted(shared):
if name != keep_name:
to_remove[keep_name].append(name)
return to_remove

def get_discard_names(model_id: str, revision: Optional[str], folder: str, token: Optional[str]) -> List[str]:
try:
import transformers
import json

config_filename = hf_hub_download(
model_id, revision=revision, filename="config.json", token=token, cache_dir=folder
)
with open(config_filename, "r") as f:
config = json.load(f)
architecture = config["architectures"][0]

class_ = getattr(transformers, architecture)

def shared_pointers(tensors):
ptrs = defaultdict(list)
for k, v in tensors.items():
ptrs[v.data_ptr()].append(k)
failing = []
for ptr, names in ptrs.items():
if len(names) > 1:
failing.append(names)
return failing
# Name for this varible depends on transformers version.
discard_names = getattr(class_, "_tied_weights_keys", [])

except Exception as e:
discard_names = []
return discard_names

class AlreadyExists(Exception):
pass


def check_file_size(sf_filename: str, pt_filename: str):
Expand All @@ -70,8 +126,8 @@ def rename(pt_filename: str) -> str:
return local


def convert_multi(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder)
def convert_multi(model_id: str, *, revision=Optional[str], folder: str, token: Optional[str], discard_names: List[str]) -> ConversionResult:
filename = hf_hub_download(repo_id=model_id, revision=revision, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder)
with open(filename, "r") as f:
data = json.load(f)

Expand All @@ -82,7 +138,7 @@ def convert_multi(model_id: str, folder: str, token: Optional[str]) -> Conversio

sf_filename = rename(pt_filename)
sf_filename = os.path.join(folder, sf_filename)
convert_file(pt_filename, sf_filename)
convert_file(pt_filename, sf_filename, discard_names=discard_names)
local_filenames.append(sf_filename)

index = os.path.join(folder, "model.safetensors.index.json")
Expand All @@ -101,12 +157,12 @@ def convert_multi(model_id: str, folder: str, token: Optional[str]) -> Conversio
return operations, errors


def convert_single(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
def convert_single(model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]) -> ConversionResult:
pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token, cache_dir=folder)

sf_name = "model.safetensors"
sf_filename = os.path.join(folder, sf_name)
convert_file(pt_filename, sf_filename)
convert_file(pt_filename, sf_filename, discard_names)
operations = [CommitOperationAdd(path_in_repo=sf_name, path_or_fileobj=sf_filename)]
errors: List[Tuple[str, "Exception"]] = []
return operations, errors
Expand All @@ -115,21 +171,25 @@ def convert_single(model_id: str, folder: str, token: Optional[str]) -> Conversi
def convert_file(
pt_filename: str,
sf_filename: str,
discard_names: List[str],
):
loaded = torch.load(pt_filename, map_location="cpu")
if "state_dict" in loaded:
loaded = loaded["state_dict"]
shared = shared_pointers(loaded)
for shared_weights in shared:
for name in shared_weights[1:]:
loaded.pop(name)

# For tensors to be contiguous
to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)

metadata = {"format": "pt"}
for kept_name, to_remove_group in to_removes.items():
for to_remove in to_remove_group:
if to_remove not in metadata:
metadata[to_remove] = kept_name
del loaded[to_remove]
# Force tensors to be contiguous
loaded = {k: v.contiguous() for k, v in loaded.items()}

dirname = os.path.dirname(sf_filename)
os.makedirs(dirname, exist_ok=True)
save_file(loaded, sf_filename, metadata={"format": "pt"})
save_file(loaded, sf_filename, metadata=metadata)
check_file_size(sf_filename, pt_filename)
reloaded = load_file(sf_filename)
for k in loaded:
Expand All @@ -155,79 +215,10 @@ def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]])
return "\n".join(errors)


def check_final_model(model_id: str, folder: str, token: Optional[str]):
config = hf_hub_download(repo_id=model_id, filename="config.json", token=token, cache_dir=folder)
shutil.copy(config, os.path.join(folder, "config.json"))
config = AutoConfig.from_pretrained(folder)

import transformers

class_ = getattr(transformers, config.architectures[0])
(pt_model, pt_infos) = class_.from_pretrained(folder, output_loading_info=True)
(sf_model, sf_infos) = class_.from_pretrained(folder, output_loading_info=True)

if pt_infos != sf_infos:
error_string = create_diff(pt_infos, sf_infos)
raise ValueError(f"Different infos when reloading the model: {error_string}")

pt_params = pt_model.state_dict()
sf_params = sf_model.state_dict()

pt_shared = shared_pointers(pt_params)
sf_shared = shared_pointers(sf_params)
if pt_shared != sf_shared:
raise RuntimeError("The reconstructed model is wrong, shared tensors are different {shared_pt} != {shared_tf}")

sig = signature(pt_model.forward)
input_ids = torch.arange(10).unsqueeze(0)
pixel_values = torch.randn(1, 3, 224, 224)
input_values = torch.arange(1000).float().unsqueeze(0)
# Hardcoded for whisper basically
input_features = torch.zeros((1, 80, 3000))
kwargs = {}
if "input_ids" in sig.parameters:
kwargs["input_ids"] = input_ids
if "input_features" in sig.parameters:
kwargs["input_features"] = input_features
if "decoder_input_ids" in sig.parameters:
kwargs["decoder_input_ids"] = input_ids
if "pixel_values" in sig.parameters:
kwargs["pixel_values"] = pixel_values
if "input_values" in sig.parameters:
kwargs["input_values"] = input_values
if "bbox" in sig.parameters:
kwargs["bbox"] = torch.zeros((1, 10, 4)).long()
if "image" in sig.parameters:
kwargs["image"] = pixel_values

if torch.cuda.is_available():
pt_model = pt_model.cuda()
sf_model = sf_model.cuda()
kwargs = {k: v.cuda() for k, v in kwargs.items()}

try:
pt_logits = pt_model(**kwargs)[0]
except Exception as e:
try:
# Musicgen special exception.
decoder_input_ids = torch.ones((input_ids.shape[0] * pt_model.decoder.num_codebooks, 1), dtype=torch.long)
if torch.cuda.is_available():
decoder_input_ids = decoder_input_ids.cuda()

kwargs["decoder_input_ids"] = decoder_input_ids
pt_logits = pt_model(**kwargs)[0]
except Exception:
raise e
sf_logits = sf_model(**kwargs)[0]

torch.testing.assert_close(sf_logits, pt_logits)
print(f"Model {model_id} is ok !")


def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision=Optional[str]) -> Optional["Discussion"]:
try:
main_commit = api.list_repo_commits(model_id)[0].commit_id
discussions = api.get_repo_discussions(repo_id=model_id)
main_commit = api.list_repo_commits(model_id, revision=revision)[0].commit_id
discussions = api.get_repo_discussions(repo_id=model_id, revision=revision)
except Exception:
return None
for discussion in discussions:
Expand All @@ -239,15 +230,15 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discuss
return None


def convert_generic(model_id: str, folder: str, filenames: Set[str], token: Optional[str]) -> ConversionResult:
def convert_generic(model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str]) -> ConversionResult:
operations = []
errors = []

extensions = set([".bin", ".ckpt"])
for filename in filenames:
prefix, ext = os.path.splitext(filename)
if ext in extensions:
pt_filename = hf_hub_download(model_id, filename=filename, token=token, cache_dir=folder)
pt_filename = hf_hub_download(model_id, revision=revision, filename=filename, token=token, cache_dir=folder)
dirname, raw_filename = os.path.split(filename)
if raw_filename == "pytorch_model.bin":
# XXX: This is a special case to handle `transformers` and the
Expand All @@ -257,25 +248,25 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str], token: Opti
sf_in_repo = f"{prefix}.safetensors"
sf_filename = os.path.join(folder, sf_in_repo)
try:
convert_file(pt_filename, sf_filename)
convert_file(pt_filename, sf_filename, discard_names=[])
operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
except Exception as e:
errors.append((pt_filename, e))
return operations, errors


def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
def convert(api: "HfApi", model_id: str, revision: Optional[str] = None, force: bool = False) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
pr_title = "Adding `safetensors` variant of this model"
info = api.model_info(model_id)
info = api.model_info(model_id, revision=revision)
filenames = set(s.rfilename for s in info.siblings)

with TemporaryDirectory() as d:
with TemporaryDirectory(prefix=os.getenv("HF_HOME", "") + "/") as d:
folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
os.makedirs(folder)
new_pr = None
try:
operations = None
pr = previous_pr(api, model_id, pr_title)
pr = previous_pr(api, model_id, pr_title, revision=revision)

library_name = getattr(info, "library_name", None)
if any(filename.endswith(".safetensors") for filename in filenames) and not force:
Expand All @@ -285,19 +276,21 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn
new_pr = pr
raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
elif library_name == "transformers":

discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=api.token)
if "pytorch_model.bin" in filenames:
operations, errors = convert_single(model_id, folder, token=api.token)
operations, errors = convert_single(model_id, revision=revision, folder=folder, token=api.token, discard_names = discard_names)
elif "pytorch_model.bin.index.json" in filenames:
operations, errors = convert_multi(model_id, folder, token=api.token)
operations, errors = convert_multi(model_id, revision=revision, folder=folder, token=api.token, discard_names = discard_names)
else:
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
check_final_model(model_id, folder, token=api.token)
else:
operations, errors = convert_generic(model_id, folder, filenames, token=api.token)
operations, errors = convert_generic(model_id, revision=revision, folder=folder, filenames=filenames, token=api.token)

if operations:
new_pr = api.create_commit(
repo_id=model_id,
revision=revision,
operations=operations,
commit_message=pr_title,
commit_description=COMMIT_DESCRIPTION,
Expand All @@ -324,6 +317,11 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn
type=str,
help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
)
parser.add_argument(
"--revision",
type=str,
help="The revision to convert",
)
parser.add_argument(
"--force",
action="store_true",
Expand All @@ -346,26 +344,17 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn
" Continue [Y/n] ?"
)
if txt.lower() in {"", "y"}:
try:
commit_info, errors = convert(api, model_id, force=args.force)
string = f"""
commit_info, errors = convert(api, model_id, revision=args.revision, force=args.force)
string = f"""
### Success 🔥
Yay! This model was successfully converted and a PR was open using your token, here:
[{commit_info.pr_url}]({commit_info.pr_url})
"""
if errors:
string += "\nErrors during conversion:\n"
string += "\n".join(
f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors
)
print(string)
except Exception as e:
print(
f"""
### Error 😢😢😢
{e}
"""
"""
if errors:
string += "\nErrors during conversion:\n"
string += "\n".join(
f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors
)
print(string)
else:
print(f"Answer was `{txt}` aborting.")

0 comments on commit 1799438

Please sign in to comment.