[LoRA] fix: lora loading when using with a device_mapped model. #9449
Conversation
Does diffusers have multi-GPU tests? If yes, would it make sense to add a test there and check that after LoRA loading, no parameter was transferred to the meta device?
That is a TODO ;)
> That is a TODO ;)
I see. In that case, I have just some nits, otherwise I'd defer to Marc as I'm not an expert on device maps.
@BenjaminBossan yes, we do: https://github.com/search?q=repo%3Ahuggingface%2Fdiffusers%20require_torch_multi_gpu&type=code But not for the use case being described here. Will add them as part of this PR.
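As a sketch of what such a test might look like, using the `require_torch_multi_gpu` marker from the search above (the checkpoint and LoRA IDs are illustrative placeholders, and the assertion mirrors the meta-device check suggested earlier):

```python
import torch

from diffusers import DiffusionPipeline
from diffusers.utils.testing_utils import require_torch_multi_gpu


@require_torch_multi_gpu
def test_lora_loading_with_device_map():
    # Shard the pipeline's models across the available GPUs.
    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        device_map="balanced",
    )
    pipe.load_lora_weights("some/lora-repo-id")  # illustrative LoRA id

    # After LoRA loading, no parameter should have been moved to meta.
    for name, param in pipe.unet.named_parameters():
        assert param.device.type != "meta", f"{name} ended up on meta"
```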
Co-authored-by: Benjamin Bossan <[email protected]>
@SunMarc a gentle ping when you find a moment.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
LGTM! Just a few suggestions!
Co-authored-by: Marc Sun <[email protected]>
@yiyixuxu can you give this an initial look? Once we agree, I will work on adding testing, docs, etc.
@yiyixuxu a gentle ping for a first review as it touches …
@@ -398,9 +399,18 @@ def _optionally_disable_offloading(cls, _pipeline):
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False

def model_has_device_map(model):
After-effects of `make fix-copies`.
@@ -387,6 +387,11 @@ def to(self, *args, **kwargs):
    device = device or device_arg

def model_has_device_map(model):
@DN6 it would make sense to make this a separate utility instead of having to redefine it three times. WDYT?
Yup, you can add it as a util function inside pipeline_utils.
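A minimal sketch of what that shared helper could look like, assuming (as with accelerate-dispatched models) that device-mapped models carry an `hf_device_map` attribute:

```python
def model_has_device_map(model):
    # accelerate attaches `hf_device_map` to models it dispatches across
    # devices; its absence means the model is not device-mapped.
    return getattr(model, "hf_device_map", None) is not None
```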
@slow
@nightly
def test_calling_to_raises_error_device_mapped_components(self):
    if "Combined" in self.pipeline_class.__name__:
Because for connected pipelines, we don't support device mapping in the first place.
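So the guarded branch presumably just skips the test, along these lines (the skip message is illustrative):

```python
if "Combined" in self.pipeline_class.__name__:
    self.skipTest("Connected (Combined) pipelines don't support device mapping.")
```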
Thanks for working on this, LGTM.
Failing tests are unrelated.
* fix: lora loading when using with a device_mapped model.
* better attributing
* empty

Co-authored-by: Benjamin Bossan <[email protected]>

* Apply suggestions from code review

Co-authored-by: Marc Sun <[email protected]>

* minors
* better error messages.
* fix-copies
* add: tests, docs.
* add hardware note.
* quality
* Update docs/source/en/training/distributed_inference.md

Co-authored-by: Steven Liu <[email protected]>

* fixes
* skip properly.
* fixes

---------

Co-authored-by: Benjamin Bossan <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
What does this PR do?
Fixes LoRA loading behaviour when used with a model that is sharded across multiple devices.
Minimal code
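A minimal sketch of the scenario, assuming a pipeline sharded with a device map; the model and LoRA IDs below are illustrative:

```python
import torch

from diffusers import DiffusionPipeline

# Shard the pipeline's models across the available GPUs.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    device_map="balanced",
)

# This is the call this PR fixes for device-mapped models.
pipe.load_lora_weights("some/lora-repo-id")  # illustrative LoRA id

image = pipe("a photo of an astronaut riding a horse").images[0]
```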
Some internal discussions:
Cc: @philschmid for awareness as you were interested in this feature.
TODOs
Once I get a sanity review from Marc and Benjamin, I will request a review from Yiyi.