
Fix sharding when no device_map is passed #8531

Merged: 9 commits merged into huggingface:main on Jun 18, 2024

Conversation

@SunMarc (Member) commented Jun 13, 2024

What does this PR do?

This PR fixes loading of sharded checkpoints when no device_map is passed. Currently, the following doesn't work:

from diffusers import UNet2DConditionModel
import torch

unet = UNet2DConditionModel.from_pretrained(
    "sayakpaul/sdxl-unet-sharded", torch_dtype=torch.float16
)

You can find more details here.
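
For reference, a hedged illustration (not code from the PR itself) of the workaround that did work before this fix: passing an explicit device_map routes the sharded load through the path that already handled sharding, as the tests below also confirm.

from diffusers import UNet2DConditionModel
import torch

# Illustration only: with device_map="auto" (requires accelerate),
# the sharded checkpoint loads via the already-handled dispatch path.
unet = UNet2DConditionModel.from_pretrained(
    "sayakpaul/sdxl-unet-sharded", torch_dtype=torch.float16, device_map="auto"
)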

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@SunMarc (Member, Author) commented Jun 13, 2024

There is still a path where sharding is not handled: when low_cpu_mem_usage=False. I see that low_cpu_mem_usage is set to True by default; is that the case for most models? cc @sayakpaul
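
A minimal sketch of the still-unhandled path (an assumed repro using the same sharded checkpoint as above, not code from the PR):

from diffusers import UNet2DConditionModel
import torch

# Assumed repro: with low_cpu_mem_usage=False, loading takes the plain
# state-dict path, which this PR does not cover for sharded checkpoints.
unet = UNet2DConditionModel.from_pretrained(
    "sayakpaul/sdxl-unet-sharded", torch_dtype=torch.float16, low_cpu_mem_usage=False
)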

@SunMarc SunMarc requested review from sayakpaul and yiyixuxu and removed request for yiyixuxu June 13, 2024 12:51
@yiyixuxu (Collaborator) left a comment

thank you!
very nice tests too :)

is it possible to explain device_map=None in the docstring for device_map too?

@SunMarc (Member, Author) commented Jun 14, 2024

is it possible to explain device_map=None in the docstring for device_map too?

Done!
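
Roughly, the added docstring note reads along these lines (a paraphrase for illustration, not the exact wording merged in the PR):

device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
    A map that specifies where each submodule should go. If `None` (the
    default), the whole model is loaded onto a single device and no
    device dispatching is performed.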

@@ -872,6 +872,39 @@ def test_model_parallelism(self):

@require_torch_gpu
def test_sharded_checkpoints(self):
A Member commented:

This test is already here:

def test_sharded_checkpoints(self):

Is it different?

A Collaborator replied:

He renamed this test to test_sharded_checkpoints_device_map because that test loads with the device_map='auto' flag; this one is a new test covering the default value of device_map.
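
In other words, the split looks roughly like this (the method names and the @require_torch_gpu decorator follow the PR; the bodies are simplified stand-ins, with self.model_class and ckpt_dir as hypothetical fixtures, not the actual test code):

@require_torch_gpu
def test_sharded_checkpoints(self):
    # New test: exercises the default device_map=None when loading a
    # sharded checkpoint.
    model = self.model_class.from_pretrained(ckpt_dir)

@require_torch_gpu
def test_sharded_checkpoints_device_map(self):
    # Renamed from test_sharded_checkpoints: loads the same sharded
    # checkpoint with the device_map="auto" flag.
    model = self.model_class.from_pretrained(ckpt_dir, device_map="auto")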

A Member replied:

Thanks for explaining.

@SunMarc (Member, Author) replied:

Yes, I renamed the tests since it makes more sense this way.

@sayakpaul (Member) left a comment

Thanks so much, Marc. I think there's some confusion in the tests, as they already exist on main. Am I missing something?

@sayakpaul (Member) commented

Alright then! Let’s merge this.

@sayakpaul (Member) left a comment

Thanks!

@yiyixuxu yiyixuxu merged commit 96399c3 into huggingface:main Jun 18, 2024
14 of 15 checks passed
yiyixuxu pushed a commit that referenced this pull request Jun 20, 2024
* Fix sharding when no device_map is passed

* style

* add tests

* align

* add docstring

* format

---------

Co-authored-by: Sayak Paul <[email protected]>