Skip to content

Commit

Permalink
Merge branch 'main' into Add-Matryoshka-Diffusion-Models
Browse files Browse the repository at this point in the history
  • Loading branch information
tolgacangoz committed Sep 28, 2024
2 parents f2f2f9c + bd4df28 commit 9d85d0d
Show file tree
Hide file tree
Showing 10 changed files with 603 additions and 219 deletions.
265 changes: 181 additions & 84 deletions src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Large diffs are not rendered by default.

65 changes: 65 additions & 0 deletions src/diffusers/models/model_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
WEIGHTS_INDEX_NAME,
_add_variant,
_get_model_file,
deprecate,
is_accelerate_available,
is_torch_version,
logging,
Expand Down Expand Up @@ -228,3 +229,67 @@ def _fetch_index_file(
index_file = None

return index_file


def _fetch_index_file_legacy(
    is_local,
    pretrained_model_name_or_path,
    subfolder,
    use_safetensors,
    cache_dir,
    variant,
    force_download,
    proxies,
    local_files_only,
    token,
    revision,
    user_agent,
    commit_hash,
):
    """
    Look up a sharded-checkpoint index file that uses the deprecated variant naming scheme.

    In the deprecated scheme the variant is inserted after the second dot-separated component of the index file
    name (`SAFE_WEIGHTS_INDEX_NAME` or `WEIGHTS_INDEX_NAME`, depending on `use_safetensors`). When such a file is
    found, a deprecation notice is emitted through `deprecate`.

    Args:
        is_local: Whether `pretrained_model_name_or_path` is a local directory. If so, the file system is checked
            directly; otherwise the file is resolved through `_get_model_file`.
        pretrained_model_name_or_path: Local directory or repository identifier to look in.
        subfolder: Optional subfolder that holds the weights; `None` or `""` means the root.
        use_safetensors: Selects the safetensors index name instead of the default weights index name.
        variant: Variant string to splice into the index file name. `None` means there is nothing to look up.
        cache_dir, force_download, proxies, local_files_only, token, revision, user_agent, commit_hash:
            Forwarded unchanged to `_get_model_file` for the non-local lookup.

    Returns:
        A `pathlib.Path` to the legacy index file, or `None` when `variant` is `None` or no such file is found.
    """
    # The legacy naming only concerns variant checkpoints. Bailing out early also guarantees that every code path
    # returns a value (previously the non-local branch left `index_file` unbound when `variant` was `None`).
    if variant is None:
        return None

    # Splice the variant into the bare file name rather than into the full path, so that dots in directory names
    # (e.g. "../model", ".cache", "sd-3.5/unet") cannot shift the insertion point.
    weights_index_name = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
    name_parts = weights_index_name.split(".")
    legacy_index_name = ".".join(name_parts[:2] + [variant] + name_parts[2:])

    deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."

    if is_local:
        index_file = Path(pretrained_model_name_or_path, subfolder or "", legacy_index_name)
        if not os.path.exists(index_file):
            return None
        deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
        return index_file

    try:
        index_file = _get_model_file(
            pretrained_model_name_or_path,
            # The subfolder is already part of the requested file name, hence `subfolder=None` below.
            weights_name=Path(subfolder or "", legacy_index_name).as_posix(),
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=None,
            user_agent=user_agent,
            commit_hash=commit_hash,
        )
    except (EntryNotFoundError, EnvironmentError):
        # A missing legacy index file is the expected outcome for up-to-date repositories.
        return None

    index_file = Path(index_file)
    deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
    return index_file
44 changes: 24 additions & 20 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from .model_loading_utils import (
_determine_device_map,
_fetch_index_file,
_fetch_index_file_legacy,
_load_state_dict_into_model,
load_model_dict_into_meta,
load_state_dict,
Expand Down Expand Up @@ -309,11 +310,9 @@ def save_pretrained(

weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weight_name_split = weights_name.split(".")
if len(weight_name_split) in [2, 3]:
weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
else:
raise ValueError(f"Invalid {weights_name} provided.")
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)

os.makedirs(save_directory, exist_ok=True)

Expand Down Expand Up @@ -624,21 +623,26 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
is_sharded = False
index_file = None
is_local = os.path.isdir(pretrained_model_name_or_path)
index_file = _fetch_index_file(
is_local=is_local,
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder or "",
use_safetensors=use_safetensors,
cache_dir=cache_dir,
variant=variant,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
user_agent=user_agent,
commit_hash=commit_hash,
)
index_file_kwargs = {
"is_local": is_local,
"pretrained_model_name_or_path": pretrained_model_name_or_path,
"subfolder": subfolder or "",
"use_safetensors": use_safetensors,
"cache_dir": cache_dir,
"variant": variant,
"force_download": force_download,
"proxies": proxies,
"local_files_only": local_files_only,
"token": token,
"revision": revision,
"user_agent": user_agent,
"commit_hash": commit_hash,
}
index_file = _fetch_index_file(**index_file_kwargs)
# In case the index file was not found we still have to consider the legacy format.
# this becomes applicable when the variant is not None.
if variant is not None and (index_file is None or not os.path.exists(index_file)):
index_file = _fetch_index_file_legacy(**index_file_kwargs)
if index_file is not None and index_file.is_file():
is_sharded = True

Expand Down
49 changes: 37 additions & 12 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,14 @@
DEPRECATED_REVISION_ARGS,
BaseOutput,
PushToHubMixin,
deprecate,
is_accelerate_available,
is_accelerate_version,
is_torch_npu_available,
is_torch_version,
logging,
numpy_to_pil,
)
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
from ..utils.torch_utils import is_compiled_module


Expand Down Expand Up @@ -735,6 +734,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
else:
cached_folder = pretrained_model_name_or_path

# The variant filenames can have the legacy sharding checkpoint format that we check and throw
# a warning if detected.
if variant is not None and _check_legacy_sharding_variant_format(folder=cached_folder, variant=variant):
warn_msg = (
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
"Please check your files carefully:\n\n"
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
"If you find any files in the deprecated format:\n"
"1. Remove all existing checkpoint files for this variant.\n"
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
"This will ensure you're using the most up-to-date and compatible checkpoint format."
)
logger.warning(warn_msg)

config_dict = cls.load_config(cached_folder)

# pop out "_ignore_files" as it is only needed for download
Expand All @@ -745,6 +759,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
# with variant being `"fp16"`.
model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)
if len(model_variants) == 0 and variant is not None:
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
raise ValueError(error_message)

# 3. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
Expand Down Expand Up @@ -1251,6 +1268,22 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
model_info_call_error = e # save error to reraise it if model is not cached locally

if not local_files_only:
filenames = {sibling.rfilename for sibling in info.siblings}
if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
warn_msg = (
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
"Please check your files carefully:\n\n"
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
"If you find any files in the deprecated format:\n"
"1. Remove all existing checkpoint files for this variant.\n"
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
"This will ensure you're using the most up-to-date and compatible checkpoint format."
)
logger.warning(warn_msg)

model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)

config_file = hf_hub_download(
pretrained_model_name,
cls.config_name,
Expand All @@ -1267,9 +1300,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
# retrieve all folder_names that contain relevant files
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]

filenames = {sibling.rfilename for sibling in info.siblings}
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)

diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines")

Expand All @@ -1292,13 +1322,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
)

if len(variant_filenames) == 0 and variant is not None:
deprecation_message = (
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`"
"if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant"
"modeling files is deprecated."
)
deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False)
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
raise ValueError(error_message)

# remove ignored filenames
model_filenames = set(model_filenames) - set(ignore_filenames)
Expand Down
16 changes: 14 additions & 2 deletions src/diffusers/utils/hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,7 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str]
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    """
    Insert `variant` into a weights file name, directly before its last dot-separated component.

    Example: `diffusion_pytorch_model.safetensors` with `variant="fp16"` becomes
    `diffusion_pytorch_model.fp16.safetensors`.

    Args:
        weights_name: File name to modify.
        variant: Variant string to insert. When `None`, `weights_name` is returned unchanged.

    Returns:
        The file name with the variant inserted exactly once.
    """
    if variant is not None:
        splits = weights_name.split(".")
        # Single insertion point for every file name (index files included); the variant must not be
        # spliced in a second time by the former `split_index`-based logic.
        splits = splits[:-1] + [variant] + splits[-1:]
        weights_name = ".".join(splits)

    return weights_name
Expand Down Expand Up @@ -502,6 +501,19 @@ def _get_checkpoint_shard_files(
return cached_folder, sharded_metadata


def _check_legacy_sharding_variant_format(
    folder: Optional[str] = None, filenames: Optional[List[str]] = None, variant: Optional[str] = None
) -> bool:
    """
    Report whether any file name follows the deprecated sharded-variant pattern `*-XXXXX-of-XXXXX.<variant>.<ext>`.

    Args:
        folder: Directory that is walked recursively to collect file names. Mutually exclusive with `filenames`.
        filenames: File names (or repository-relative paths) to inspect. Mutually exclusive with `folder`.
        variant: Variant string to look for. It is matched literally.

    Returns:
        `True` if at least one file name matches the deprecated pattern, `False` otherwise (including when
        `variant` is `None` or there is nothing to inspect).

    Raises:
        ValueError: If both `filenames` and `folder` are provided.
    """
    if filenames and folder:
        raise ValueError("Both `filenames` and `folder` cannot be provided.")
    # Without a variant there is no variant-specific file name to look for.
    if variant is None:
        return False
    if not filenames:
        # Nothing to inspect at all: avoid calling `os.walk(None)`.
        if folder is None:
            return False
        filenames = [name for _, _, files in os.walk(folder) for name in files]
    transformers_index_format = r"\d{5}-of-\d{5}"
    # Escape the variant so that characters such as "." or "+" are not interpreted as regex syntax.
    variant_file_re = re.compile(rf".*-{transformers_index_format}\.{re.escape(variant)}\.[a-z]+$")
    return any(variant_file_re.match(f) is not None for f in filenames)


class PushToHubMixin:
"""
A Mixin to push a model, scheduler, or pipeline to the Hugging Face Hub.
Expand Down
Loading

0 comments on commit 9d85d0d

Please sign in to comment.