
[Core] fix variant-identification. #9253

Merged
merged 37 commits into from
Sep 28, 2024

Conversation

@sayakpaul (Member) commented Aug 23, 2024

What does this PR do?

See: https://huggingface.slack.com/archives/C065E480NN9/p1724387504059169

Some in-line comments.

@sayakpaul sayakpaul changed the title [Core] fix variant-idenitification. [Core] fix variant-identification. Aug 23, 2024
pipe.save_pretrained(tmpdir)

with self.assertRaises(ValueError) as error:
_ = self.pipeline_class.from_pretrained(tmpdir, variant=variant)
sayakpaul (Member Author):

With the fixes from this PR, this would have failed, rightfully complaining:

ValueError: The deprecation tuple ('no variant default', '0.24.0', "You are trying to load the model files of the `variant=fp16`, but no such modeling files are available.The default model files: {'model.safetensors', 'diffusion_pytorch_model.safetensors'} will be loaded instead. Make sure to not load from `variant=fp16`if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variantmodeling files is deprecated.") should be removed since diffusers' version 0.31.0.dev0 is >= 0.24.0

We didn't have it because we never tested it. But we should be all good now.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -655,7 +655,7 @@ def test_local_save_load_index(self):
out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="np").images

with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
pipe.save_pretrained(tmpdirname, variant=variant, safe_serialization=use_safe)
sayakpaul (Member Author):

This should have been serialized with variant and safe_serialization; otherwise the test seems wrong to me.

@@ -722,6 +721,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
)
else:
cached_folder = pretrained_model_name_or_path
filenames = []
@yiyixuxu (Collaborator) commented Aug 23, 2024:

I think maybe we should just update the _identify_model_variants function using variant_compatible_siblings.

Also, it is still not able to load variants with sharded checkpoints from the pipeline level,

i.e. we should be able to load the fp16 variant in the transformer folder too, but currently it is not loaded:

import torch
from diffusers import AutoPipelineForText2Image

repo = "fal/AuraFlow"  # sharded checkpoint with variant
pipe = AutoPipelineForText2Image.from_pretrained(
    repo,
    variant="fp16",
    torch_dtype=torch.float16,
)
print(pipe.dtype)

you get


A mixture of fp16 and non-fp16 filenames will be loaded.
Loaded fp16 filenames:
[vae/diffusion_pytorch_model.fp16.safetensors, text_encoder/model.fp16.safetensors]
Loaded non-fp16 filenames:
[transformer/diffusion_pytorch_model-00002-of-00003.safetensors, transformer/diffusion_pytorch_model-00003-of-00003.safetensors, transformer/diffusion_pytorch_model-00001-of-00003.safetensors]
If this behavior is not expected, please check your folder structure.
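The gap can be sketched with filename patterns. This is a hypothetical illustration, not the actual variant_compatible_siblings implementation; the pattern constants and the matches_variant helper below are made up for this sketch. A matcher that only knows the non-sharded naming scheme misses sharded variant files entirely:

```python
import re

# Illustrative patterns (assumptions, not the real diffusers regexes).
# Non-sharded variants look like `model.fp16.safetensors`; sharded variants
# in the layout discussed here look like `model.fp16-00001-of-00003.safetensors`.
NON_SHARDED = r"^(.+?)\.{v}\.(safetensors|bin)$"
SHARDED = r"^(.+?)\.{v}-\d{{5}}-of-\d{{5}}\.(safetensors|bin)$"

def matches_variant(filename: str, variant: str, include_sharded: bool = True) -> bool:
    base = filename.rsplit("/", 1)[-1]
    patterns = [NON_SHARDED.format(v=re.escape(variant))]
    if include_sharded:
        patterns.append(SHARDED.format(v=re.escape(variant)))
    return any(re.match(p, base) for p in patterns)
```

Without sharded support (include_sharded=False), the transformer's fp16 shards are not recognised, so the non-fp16 shards get loaded instead, which is exactly the mixture the warning above reports.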

yiyixuxu (Collaborator):

cc @DN6 @a-r-r-o-w here too

Comment on lines -1258 to -1272
filenames = {sibling.rfilename for sibling in info.siblings}
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)

sayakpaul (Member Author):

This was moved up to raise the error earlier in the code.

@@ -551,6 +551,29 @@ def test_download_variant_partly(self):
assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
assert not any(f.endswith(other_format) for f in files)

def test_download_variants_with_sharded_checkpoints(self):
sayakpaul (Member Author):

LMK if someone has a better idea to test it out.

@sayakpaul (Member Author):

@yiyixuxu I have made a couple of changes. LMK what you think.

@sayakpaul (Member Author) commented Sep 10, 2024:

Tests run: pytest tests/ -k "sharded"

Will run LoRA and other important tests too.

@sayakpaul (Member Author):

@yiyixuxu I think this is ready for another review.

src/diffusers/pipelines/pipeline_utils.py (outdated, resolved)
src/diffusers/models/modeling_utils.py (outdated, resolved)
@sayakpaul (Member Author):

@yiyixuxu

There's another issue we need to settle on before #9253 (comment).

Since we decided to not touch

def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:

as we never supported the legacy variant sharding checkpoint format at the pipeline level, here is what is happening as a consequence:

safetensors_variant_filenames={'text_encoder/model.fp16-00003-of-00004.safetensors', 'safety_checker/model.fp16.safetensors', 'text_encoder/model.fp16-00004-of-00004.safetensors', 'text_encoder/model.fp16-00002-of-00004.safetensors', 'text_encoder/model.fp16-00001-of-00004.safetensors', 'vae/diffusion_pytorch_model.fp16.safetensors'}

safetensors_model_filenames={'text_encoder/model.fp16-00003-of-00004.safetensors', 'safety_checker/model.fp16.safetensors', 'text_encoder/model.fp16-00002-of-00004.safetensors', 'text_encoder/model.fp16-00004-of-00004.safetensors', 'unet/diffusion_pytorch_model-00002-of-00002.safetensors', 'text_encoder/model.fp16-00001-of-00004.safetensors', 'unet/diffusion_pytorch_model-00001-of-00002.safetensors', 'vae/diffusion_pytorch_model.fp16.safetensors'}

This is what we get when "hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds" is loaded via DiffusionPipeline.from_pretrained() with variant="fp16", and it is why the safetensors_model_filenames != safetensors_variant_filenames check fails. What we could do is this (with slight modifications made to _check_legacy_sharding_variant_format()):

if (
    len(safetensors_variant_filenames) > 0
    and safetensors_model_filenames != safetensors_variant_filenames
-    and not is_sharded
+    and not _check_legacy_sharding_variant_format(filenames)
):
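A minimal sketch of what such a check could look like, assuming the legacy layout places the variant tag before the shard counter (as in the model.fp16-00003-of-00004.safetensors filenames quoted above). The function name and signature here are illustrative, not the actual _check_legacy_sharding_variant_format:

```python
import re

def has_legacy_sharding_variant_format(filenames, variant):
    # Legacy sharded variants embed the variant tag before the shard counter,
    # e.g. `text_encoder/model.fp16-00003-of-00004.safetensors` (assumption
    # based on the filenames quoted in this thread).
    legacy = re.compile(rf"\.{re.escape(variant)}-\d{{5}}-of-\d{{5}}\.")
    return any(legacy.search(f) for f in filenames)
```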

But note that this will still not allow us to parse just the legacy variant shard checkpoints from the pipeline-level. In order for us to only parse the variant (legacy) files we need to adjust variant_compatible_siblings(), IMO.

It's not a problem for the model-level because of how it's handled here:

def _get_checkpoint_shard_files(

LMK if I am missing something here.

@yiyixuxu (Collaborator) commented Sep 24, 2024:

@sayakpaul
Since we have never actually supported loading that checkpoint from the pipeline level and are already deprecating it (for loading from the model level), we do not need to start supporting it now. A check and a warning asking users to save the checkpoint in the correct format is sufficient. The check fails as expected:

But note that this will still not allow us to parse just the legacy variant shard checkpoints from the pipeline-level.

@sayakpaul (Member Author):

@yiyixuxu sorry about the back and forth but I think it's necessary we get this right.

Our pipeline-level warning is here (as decided earlier):

if variant is not None and _check_legacy_sharding_variant_format(cached_folder, variant):

It will warn as long as we're hitting:

cached_folder = pretrained_model_name_or_path

But not for

cached_folder = cls.download(

Why?
Because we're not downloading the (legacy) variant sharded checkpoint files as variant_compatible_siblings() is not returning them:

model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)

We could maybe check if the filenames here

model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
correspond to the legacy format and throw a warning from there? Would that be okay?

I am bringing this up because we'd want to ensure unified behaviour for both local and remote loading.

@yiyixuxu (Collaborator):

This warning here will cover both cases; it is not specific to cached_folder = pretrained_model_name_or_path:

        if not os.path.isdir(pretrained_model_name_or_path):
            ...
            cached_folder = cls.download(
                pretrained_model_name_or_path,
            )
        else:
            cached_folder = pretrained_model_name_or_path

        # The variant filenames can have the legacy sharding checkpoint format that we check and throw
        # a warning if detected.
        if variant is not None and _check_legacy_sharding_variant_format(cached_folder, variant):
            warn_msg = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
            logger.warning(warn_msg)

@sayakpaul (Member Author):

I don't think so.

cls.download() will NOT download the (legacy) variant sharded checkpoint files because of what I mentioned in #9253 (comment). To confirm that I printed the files that are getting downloaded and here's the log:

model_filenames={'vae/diffusion_pytorch_model.fp16.safetensors', 'text_encoder/pytorch_model.fp16-00003-of-00004.bin', 'unet/diffusion_pytorch_model-00001-of-00002.safetensors', 'text_encoder/pytorch_model.fp16-00001-of-00004.bin', 'text_encoder/pytorch_model.fp16-00002-of-00004.bin', 'text_encoder/pytorch_model.fp16-00004-of-00004.bin', 'unet/diffusion_pytorch_model.safetensors.index.json', 'text_encoder/model.fp16-00004-of-00004.safetensors', 'vae/diffusion_pytorch_model.fp16.bin', 'text_encoder/model.fp16-00003-of-00004.safetensors', 'text_encoder/model.safetensors.index.fp16.json', 'unet/diffusion_pytorch_model-00002-of-00002.safetensors', 'text_encoder/pytorch_model.bin.index.fp16.json', 'text_encoder/model.fp16-00001-of-00004.safetensors', 'vae/diffusion_flax_model.msgpack', 'text_encoder/model.fp16-00002-of-00004.safetensors', 'safety_checker/model.fp16.safetensors', 'safety_checker/pytorch_model.fp16.bin'}
variant_filenames={'vae/diffusion_pytorch_model.fp16.safetensors', 'text_encoder/pytorch_model.fp16-00003-of-00004.bin', 'text_encoder/pytorch_model.fp16-00001-of-00004.bin', 'text_encoder/pytorch_model.fp16-00002-of-00004.bin', 'text_encoder/pytorch_model.fp16-00004-of-00004.bin', 'text_encoder/model.fp16-00004-of-00004.safetensors', 'vae/diffusion_pytorch_model.fp16.bin', 'text_encoder/model.fp16-00003-of-00004.safetensors', 'text_encoder/model.safetensors.index.fp16.json', 'text_encoder/pytorch_model.bin.index.fp16.json', 'text_encoder/model.fp16-00001-of-00004.safetensors', 'text_encoder/model.fp16-00002-of-00004.safetensors', 'safety_checker/model.fp16.safetensors', 'safety_checker/pytorch_model.fp16.bin'}

Notice that in model_filenames we're not picking up the (legacy) variants associated with the UNet, because

def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
is unable to match them with regex. We use model_filenames to craft our allow_patterns:
allow_patterns = list(model_filenames)

We can further confirm this by printing the contents of each subfolder from here:

cached_folder = snapshot_download(

unet = ['diffusion_pytorch_model-00001-of-00002.safetensors', 'config.json', 'diffusion_pytorch_model-00002-of-00002.safetensors', 'diffusion_pytorch_model.safetensors.index.json']
text_encoder = ['model.fp16-00003-of-00004.safetensors', 'model.safetensors.index.fp16.json', 'config.json', 'model.fp16-00002-of-00004.safetensors', 'model.fp16-00001-of-00004.safetensors', 'model.fp16-00004-of-00004.safetensors']
vae = ['config.json', 'diffusion_pytorch_model.fp16.safetensors']

Hopefully, my concern is clear now.
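The download gap described here can be reproduced locally with plain pattern filtering. The file names below are taken from the logs in this thread, but the fnmatch-based filtering itself is only an illustrative stand-in for how allow_patterns restricts what snapshot_download fetches:

```python
import fnmatch

# allow_patterns is crafted from model_filenames, so any file that the
# variant matcher failed to return is simply never requested from the Hub.
repo_files = [
    "unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",  # legacy variant shard (missed by the matcher)
    "unet/diffusion_pytorch_model-00001-of-00002.safetensors",       # non-variant shard (picked)
    "vae/diffusion_pytorch_model.fp16.safetensors",                  # plain variant file (picked)
]
allow_patterns = [
    "unet/diffusion_pytorch_model-00001-of-00002.safetensors",
    "vae/diffusion_pytorch_model.fp16.safetensors",
]
downloaded = [f for f in repo_files if any(fnmatch.fnmatch(f, p) for p in allow_patterns)]
# `downloaded` lacks the legacy fp16 UNet shard, consistent with the
# subfolder listings printed earlier in this comment.
```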

@yiyixuxu (Collaborator):

I see!! Let's add a warning inside download too, then, and make sure _check_legacy_sharding_variant_format accepts either files or a folder.

@sayakpaul (Member Author):

@yiyixuxu done.

@yiyixuxu (Collaborator) left a comment:

thanks!

tests/models/test_modeling_common.py (outdated, resolved)
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolde", "unet"),
]
)
def test_variant_sharded_ckpt_loads_from_hub(self, repo_id, subfolder):
yiyixuxu (Collaborator):

This is a nice test! Let's also add to the @parameterized to test non-variant (if not already tested) and device_map.

@sayakpaul (Member Author) commented Sep 28, 2024:

1. Added parameterized to have subfolder and variant testing in all the sharding tests here: https://github.com/huggingface/diffusers/blob/main/tests/models/unets/test_models_unet_2d_condition.py
2. Modified this test to have non-variant checkpoints as well.

Ran everything with pytest tests/models/ -k "sharded" and it was green.

Commit: 1190f7d

@sayakpaul
Copy link
Member Author

Shipping this since I have an approval.

@sayakpaul sayakpaul merged commit 1154243 into main Sep 28, 2024
18 checks passed
@sayakpaul sayakpaul deleted the variant-tests branch September 28, 2024 04:27
leisuzz pushed a commit to leisuzz/diffusers that referenced this pull request Oct 11, 2024
* fix variant-idenitification.

* fix variant

* fix sharded variant checkpoint loading.

* Apply suggestions from code review

* fixes.

* more fixes.

* remove print.

* fixes

* fixes

* comments

* fixes

* apply suggestions.

* hub_utils.py

* fix test

* updates

* fixes

* fixes

* Apply suggestions from code review

Co-authored-by: YiYi Xu <[email protected]>

* updates.

* removep patch file.

---------

Co-authored-by: YiYi Xu <[email protected]>
@DN6 mentioned this pull request Nov 5, 2024