Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] fix variant-identification. #9253

Merged
merged 37 commits into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
6b379a9
fix variant-idenitification.
sayakpaul Aug 23, 2024
f155ec7
fix variant
sayakpaul Aug 23, 2024
3f36e59
Merge branch 'main' into variant-tests
sayakpaul Aug 23, 2024
91253e8
fix sharded variant checkpoint loading.
sayakpaul Aug 27, 2024
dd5941e
Merge branch 'main' into variant-tests
sayakpaul Aug 27, 2024
564b8b4
Apply suggestions from code review
sayakpaul Aug 27, 2024
fdd0435
Merge branch 'main' into variant-tests
sayakpaul Sep 4, 2024
d5cad9e
Merge branch 'main' into variant-tests
sayakpaul Sep 10, 2024
c0b1ceb
fixes.
sayakpaul Sep 10, 2024
247dd93
more fixes.
sayakpaul Sep 10, 2024
b024a6d
remove print.
sayakpaul Sep 10, 2024
fdfdc5f
Merge branch 'main' into variant-tests
sayakpaul Sep 11, 2024
dcf1852
Merge branch 'main' into variant-tests
yiyixuxu Sep 12, 2024
3a71ad9
fixes
sayakpaul Sep 13, 2024
ab91852
fixes
sayakpaul Sep 13, 2024
aa631c5
comments
sayakpaul Sep 13, 2024
453bfa5
fixes
sayakpaul Sep 13, 2024
11e4b71
Merge branch 'main' into variant-tests
sayakpaul Sep 13, 2024
dbdf0f9
apply suggestions.
sayakpaul Sep 14, 2024
671038a
hub_utils.py
sayakpaul Sep 14, 2024
57382f2
Merge branch 'main' into variant-tests
sayakpaul Sep 14, 2024
ea5ecdb
fix test
sayakpaul Sep 14, 2024
a510a9b
Merge branch 'main' into variant-tests
sayakpaul Sep 17, 2024
f583dad
Merge branch 'main' into variant-tests
sayakpaul Sep 18, 2024
dc0255a
updates
sayakpaul Sep 19, 2024
f2ab3de
Merge branch 'main' into variant-tests
sayakpaul Sep 21, 2024
10baa9d
Merge branch 'main' into variant-tests
sayakpaul Sep 23, 2024
25ac01f
fixes
sayakpaul Sep 23, 2024
bac62ac
Merge branch 'main' into variant-tests
sayakpaul Sep 24, 2024
b6794ed
Merge branch 'main' into variant-tests
sayakpaul Sep 25, 2024
fcb4e39
Merge branch 'main' into variant-tests
sayakpaul Sep 26, 2024
4c0c5d2
fixes
sayakpaul Sep 26, 2024
0b1c2a6
Merge branch 'main' into variant-tests
sayakpaul Sep 27, 2024
8ad6b23
Apply suggestions from code review
sayakpaul Sep 28, 2024
1190f7d
updates.
sayakpaul Sep 28, 2024
59cfefb
removep patch file.
sayakpaul Sep 28, 2024
d72f5c1
Merge branch 'main' into variant-tests
sayakpaul Sep 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,28 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
transformers_index_format = r"\d{5}-of-\d{5}"

if variant is not None:
# `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors`
# Examples:
sayakpaul marked this conversation as resolved.
Show resolved Hide resolved
# For transformers, `pytorch_model.fp16.bin` as well as `pytorch_model.fp16-00001-of-00004.bin`.
# For diffusers `diffusion_pytorch_model.fp16.bin` as well as `diffusion_pytorch_model-00001-of-00002.fp16.safetensors`
# These differences exist because `diffusers` delegates the process of loading sharded checkpoints
sayakpaul marked this conversation as resolved.
Show resolved Hide resolved
# to `accelerate`. However, `transformers` has custom code that takes care of it.
variant_file_re = re.compile(
rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
rf"({'|'.join(weight_prefixes)})"
rf"(?:"
rf"\.({variant}(?:-{transformers_index_format})?)"
rf"|"
rf"-({transformers_index_format})\.({variant})"
rf")"
rf"\.({'|'.join(weight_suffixs)})$"
)
# `text_encoder/pytorch_model.bin.index.fp16.json`
# Examples:
# For transformers, it will be `text_encoder/pytorch_model.bin.index.fp16.json`
# for diffusers, it will be `unet/diffusion_pytorch_model.safetensors.fp16.index.json`
variant_index_re = re.compile(
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
rf"({'|'.join(weight_prefixes)})"
rf"\.({'|'.join(weight_suffixs)})"
rf"(?:\.{variant}\.index|\.index\.{variant})"
rf"\.json$"
)

# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
Expand All @@ -184,12 +199,24 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi

# all variant filenames will be used by default
usable_filenames = set(variant_filenames)
transformers_weight_prefixes = ("pytorch_model", "model")

def is_transformers_file(filename):
return filename.split("/")[-1].startswith(transformers_weight_prefixes)

def convert_to_variant(filename):
if "index" in filename:
variant_filename = filename.replace("index", f"index.{variant}")
if is_transformers_file(filename):
variant_filename = filename.replace("index", f"index.{variant}")
else:
variant_filename = filename.replace("index", f"{variant}.index")
sayakpaul marked this conversation as resolved.
Show resolved Hide resolved
elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
if is_transformers_file(filename):
variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
else:
variant_filename = f"{filename.split('-')[0]}-{'-'.join(filename.split('-')[1:])}"
variant_ext = variant_filename.split(".")[-1]
variant_filename = variant_filename.replace(variant_ext, f"{variant}.{variant_ext}")
else:
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
return variant_filename
Expand Down
39 changes: 26 additions & 13 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
DEPRECATED_REVISION_ARGS,
BaseOutput,
PushToHubMixin,
deprecate,
is_accelerate_available,
is_accelerate_version,
is_torch_npu_available,
Expand Down Expand Up @@ -722,6 +721,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
)
else:
cached_folder = pretrained_model_name_or_path
filenames = []
Copy link
Collaborator

@yiyixuxu yiyixuxu Aug 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think maybe we should just update the _identify_model_variants function using variant_compatible_siblings

and it is still not able to load variants with sharded checkpoints at the pipeline level

i.e. we should be able to load the fp16 variant in the transformer folder too, but currently it is not loaded

import torch
from diffusers import  AutoPipelineForText2Image
repo = "fal/AuraFlow" # sharded checkpoint with variant
pipe = AutoPipelineForText2Image.from_pretrained(
        repo,
        variant="fp16",
        torch_dtype=torch.float16,
        )
print(pipe.dtype)

you get


A mixture of fp16 and non-fp16 filenames will be loaded.
Loaded fp16 filenames:
[vae/diffusion_pytorch_model.fp16.safetensors, text_encoder/model.fp16.safetensors]
Loaded non-fp16 filenames:
[transformer/diffusion_pytorch_model-00002-of-00003.safetensors, transformer/diffusion_pytorch_model-00003-of-00003.safetensors, transformer/diffusion_pytorch_model-00001-of-00003.safetensors
If this behavior is not expected, please check your folder structure.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @DN6 @a-r-r-o-w here too

for _, _, files in os.walk(cached_folder):
for file in files:
filenames.append(os.path.basename(file))

model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
if len(variant_filenames) == 0 and variant is not None:
error_message = (
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
f" Available ones are: {model_filenames}."
)
raise ValueError(error_message)

config_dict = cls.load_config(cached_folder)

Expand Down Expand Up @@ -1239,6 +1250,15 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
model_info_call_error = e # save error to reraise it if model is not cached locally

if not local_files_only:
filenames = {sibling.rfilename for sibling in info.siblings}
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
if len(variant_filenames) == 0 and variant is not None:
error_message = (
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
f" Available ones are: {model_filenames}."
)
raise ValueError(error_message)

config_file = hf_hub_download(
pretrained_model_name,
cls.config_name,
Expand All @@ -1255,9 +1275,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
# retrieve all folder_names that contain relevant files
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]

filenames = {sibling.rfilename for sibling in info.siblings}
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)

Comment on lines -1270 to -1272
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was moved up to raise error earlier in code.

diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines")

Expand All @@ -1279,15 +1296,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
)

if len(variant_filenames) == 0 and variant is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's not remove this error in download

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not an error, though. It's a deprecation. Do we exactly want to keep it that way? If so, we will have to remove it anyway because the deprecation is supposed to expire after "0.24.0" version.

Instead, we are erroring out now from from_pretrained():

model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah got it. I think this should be resolved now.

WDYT about catching these errors without having to download the actual files and leveraging model_info() (in case we're querying the Hub) or regular string matching (in case it's local)? Currently, we're still calling download() in case we don't have the model files cached. I think many errors can be caught and warnings can be thrown without having to do that.

This could live in a future PR.

sayakpaul marked this conversation as resolved.
Show resolved Hide resolved
deprecation_message = (
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`"
"if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant"
"modeling files is deprecated."
)
deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False)

sayakpaul marked this conversation as resolved.
Show resolved Hide resolved
# remove ignored filenames
model_filenames = set(model_filenames) - set(ignore_filenames)
variant_filenames = set(variant_filenames) - set(ignore_filenames)
Expand Down Expand Up @@ -1356,6 +1364,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
)
expected_components, _ = cls._get_signature_keys(pipeline_class)
passed_components = [k for k in expected_components if k in kwargs]
is_sharded = any("index.json" in f and f != "model_index.json" for f in filenames)

if (
use_safetensors
Expand All @@ -1380,9 +1389,13 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:

safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
# `not is_sharded` because sharded checkpoints with a variant
sayakpaul marked this conversation as resolved.
Show resolved Hide resolved
# ("fp16") for example may have lesser shards actually. Consider
# https://huggingface.co/fal/AuraFlow/tree/main/transformer, for example.
sayakpaul marked this conversation as resolved.
Show resolved Hide resolved
if (
len(safetensors_variant_filenames) > 0
and safetensors_model_filenames != safetensors_variant_filenames
and not is_sharded
):
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
Expand Down
25 changes: 24 additions & 1 deletion tests/pipelines/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,29 @@ def test_download_variant_partly(self):
assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
assert not any(f.endswith(other_format) for f in files)

def test_download_variants_with_sharded_checkpoints(self):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LMK if someone has a better idea to test it out.

# Here we test for downloading of "variant" files belonging to the `unet` and
# the `text_encoder`. Their checkpoints can be sharded.
for use_safetensors in [True, False]:
for variant in ["fp16", None]:
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds",
safety_checker=None,
cache_dir=tmpdirname,
variant=variant,
use_safetensors=use_safetensors,
)

all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]

# Check for `model_ext` and `variant`.
model_ext = ".safetensors" if use_safetensors else ".bin"
unexpected_ext = ".bin" if use_safetensors else ".safetensors"
assert not any(f.endswith(unexpected_ext) for f in files)
assert all(variant in f for f in files if f.endswith(model_ext) and variant is not None)

def test_download_safetensors_only_variant_exists_for_model(self):
variant = None
use_safetensors = True
Expand Down Expand Up @@ -655,7 +678,7 @@ def test_local_save_load_index(self):
out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="np").images

with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
pipe.save_pretrained(tmpdirname, variant=variant, safe_serialization=use_safe)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should have been serialized with variant and safe_serialization otherwise the test seems wrong to me.

pipe_2 = StableDiffusionPipeline.from_pretrained(
tmpdirname, safe_serialization=use_safe, variant=variant
)
Expand Down
68 changes: 68 additions & 0 deletions tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1824,6 +1824,74 @@ def callback_increase_guidance(pipe, i, t, callback_kwargs):
# accounts for models that modify the number of inference steps based on strength
assert pipe.guidance_scale == (inputs["guidance_scale"] + pipe.num_timesteps)

def test_serialization_with_variants(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
model_components = [
component_name for component_name, component in pipe.components.items() if isinstance(component, nn.Module)
]
variant = "fp16"

with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)

with open(f"{tmpdir}/model_index.json", "r") as f:
config = json.load(f)

for subfolder in os.listdir(tmpdir):
if not os.path.isfile(subfolder) and subfolder in model_components:
folder_path = os.path.join(tmpdir, subfolder)
is_folder = os.path.isdir(folder_path) and subfolder in config
assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))

def test_loading_with_variants(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
variant = "fp16"

def is_nan(tensor):
if tensor.ndimension() == 0:
has_nan = torch.isnan(tensor).item()
else:
has_nan = torch.isnan(tensor).any()
return has_nan

with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, variant=variant)

model_components_pipe = {
component_name: component
for component_name, component in pipe.components.items()
if isinstance(component, nn.Module)
}
model_components_pipe_loaded = {
component_name: component
for component_name, component in pipe_loaded.components.items()
if isinstance(component, nn.Module)
}
for component_name in model_components_pipe:
pipe_component = model_components_pipe[component_name]
pipe_loaded_component = model_components_pipe_loaded[component_name]
for p1, p2 in zip(pipe_component.parameters(), pipe_loaded_component.parameters()):
# nan check for luminanext (mps).
if not (is_nan(p1) and is_nan(p2)):
self.assertTrue(torch.equal(p1, p2))

def test_loading_with_incorrect_variants_raises_error(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
variant = "fp16"

with tempfile.TemporaryDirectory() as tmpdir:
# Don't save with variants.
pipe.save_pretrained(tmpdir, safe_serialization=False)

with self.assertRaises(ValueError) as error:
_ = self.pipeline_class.from_pretrained(tmpdir, variant=variant)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would have failed with the fixes from this PR rightfully complaining:

ValueError: The deprecation tuple ('no variant default', '0.24.0', "You are trying to load the model files of the `variant=fp16`, but no such modeling files are available.The default model files: {'model.safetensors', 'diffusion_pytorch_model.safetensors'} will be loaded instead. Make sure to not load from `variant=fp16`if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variantmodeling files is deprecated.") should be removed since diffusers' version 0.31.0.dev0 is >= 0.24.0

We didn't have it because we never tested it. But we should be all good now.


assert f"You are trying to load the model files of the `variant={variant}`" in str(error.exception)

def test_StableDiffusionMixin_component(self):
"""Any pipeline that have LDMFuncMixin should have vae and unet components."""
if not issubclass(self.pipeline_class, StableDiffusionMixin):
Expand Down
Loading