Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LoRA] fix: lora loading when using with a device_mapped model. #9449

Merged
merged 29 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
dc1aee2
fix: lora loading when using with a device_mapped model.
sayakpaul Sep 17, 2024
949a929
better attibutung
sayakpaul Sep 17, 2024
64b3ad1
empty
sayakpaul Sep 17, 2024
6d03c12
Merge branch 'main' into lora-device-map
sayakpaul Sep 22, 2024
d4bd94b
Merge branch 'main' into lora-device-map
sayakpaul Sep 24, 2024
5479198
Apply suggestions from code review
sayakpaul Sep 24, 2024
2846549
Merge branch 'main' into lora-device-map
sayakpaul Sep 27, 2024
1ed0eb0
Merge branch 'main' into lora-device-map
sayakpaul Sep 28, 2024
d2d59c3
Merge branch 'main' into lora-device-map
sayakpaul Oct 2, 2024
5f3cae2
Merge branch 'main' into lora-device-map
sayakpaul Oct 6, 2024
8f670e2
Merge branch 'main' into lora-device-map
sayakpaul Oct 8, 2024
e42ec19
Merge branch 'main' into lora-device-map
sayakpaul Oct 10, 2024
f63b04c
Merge branch 'main' into lora-device-map
sayakpaul Oct 15, 2024
eefda54
Merge branch 'main' into lora-device-map
sayakpaul Oct 19, 2024
ea727a3
minors
sayakpaul Oct 19, 2024
71989e3
better error messages.
sayakpaul Oct 19, 2024
f62afac
fix-copies
sayakpaul Oct 19, 2024
2334f78
add: tests, docs.
sayakpaul Oct 19, 2024
5ea1173
add hardware note.
sayakpaul Oct 19, 2024
f64751e
Merge branch 'main' into lora-device-map
sayakpaul Oct 19, 2024
c0dee87
quality
sayakpaul Oct 19, 2024
4b6124a
Merge branch 'main' into lora-device-map
sayakpaul Oct 22, 2024
fe2cca8
Update docs/source/en/training/distributed_inference.md
sayakpaul Oct 23, 2024
2db5d48
Merge branch 'main' into lora-device-map
sayakpaul Oct 23, 2024
61903c8
Merge branch 'main' into lora-device-map
sayakpaul Oct 31, 2024
03377b7
fixes
sayakpaul Oct 31, 2024
0bd40cb
skip properly.
sayakpaul Oct 31, 2024
a61b754
fixes
sayakpaul Oct 31, 2024
ccd8d2a
resolve conflicts.
sayakpaul Oct 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/en/training/distributed_inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,5 @@ with torch.no_grad():
```

By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs.

This workflow is also compatible with LoRAs via [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. However, only LoRAs without text encoder components are currently supported in this workflow.
12 changes: 11 additions & 1 deletion src/diffusers/loaders/lora_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
delete_adapter_layers,
deprecate,
is_accelerate_available,
is_accelerate_version,
is_peft_available,
is_transformers_available,
logging,
Expand Down Expand Up @@ -214,9 +215,18 @@ def _optionally_disable_offloading(cls, _pipeline):
is_model_cpu_offload = False
is_sequential_cpu_offload = False

def model_has_device_map(model):
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
return False
return getattr(model, "hf_device_map", None) is not None

if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if (
isinstance(component, nn.Module)
and hasattr(component, "_hf_hook")
and not model_has_device_map(component)
):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
Expand Down
12 changes: 11 additions & 1 deletion src/diffusers/loaders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
is_accelerate_version,
is_peft_version,
is_torch_version,
logging,
Expand Down Expand Up @@ -398,9 +399,18 @@ def _optionally_disable_offloading(cls, _pipeline):
is_model_cpu_offload = False
is_sequential_cpu_offload = False

def model_has_device_map(model):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After-effects of make fix-copies.

if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
return False
return getattr(model, "hf_device_map", None) is not None

if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if (
isinstance(component, nn.Module)
and hasattr(component, "_hf_hook")
and not model_has_device_map(component)
):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
Expand Down
7 changes: 7 additions & 0 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
deprecate,
get_class_from_dynamic_module,
is_accelerate_available,
is_accelerate_version,
is_peft_available,
is_transformers_available,
logging,
Expand Down Expand Up @@ -947,3 +948,9 @@ def _get_ignore_patterns(
)

return ignore_patterns


def model_has_device_map(model):
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
return False
return getattr(model, "hf_device_map", None) is not None
31 changes: 31 additions & 0 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
_update_init_kwargs_with_connected_pipeline,
load_sub_model,
maybe_raise_or_warn,
model_has_device_map,
variant_compatible_siblings,
warn_deprecated_model_variant,
)
Expand Down Expand Up @@ -406,6 +407,16 @@ def module_is_offloaded(module):

return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)

# device-mapped modules should not go through any device placements.
device_mapped_components = [
key for key, component in self.components.items() if model_has_device_map(component)
]
if device_mapped_components:
raise ValueError(
"The following pipeline components have been found to use a device map: "
f"{device_mapped_components}. This is incompatible with explicitly setting the device using `to()`."
)

# .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
pipeline_is_sequentially_offloaded = any(
module_is_sequentially_offloaded(module) for _, module in self.components.items()
Expand Down Expand Up @@ -1002,6 +1013,16 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
default to "cuda".
"""
# device-mapped modules should not go through any device placements.
device_mapped_components = [
key for key, component in self.components.items() if model_has_device_map(component)
]
if device_mapped_components:
raise ValueError(
"The following pipeline components have been found to use a device map: "
f"{device_mapped_components}. This is incompatible with `enable_model_cpu_offload()`."
)

is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
raise ValueError(
Expand Down Expand Up @@ -1104,6 +1125,16 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
default to "cuda".
"""
# device-mapped modules should not go through any device placements.
device_mapped_components = [
key for key, component in self.components.items() if model_has_device_map(component)
]
if device_mapped_components:
raise ValueError(
"The following pipeline components have been found to use a device map: "
f"{device_mapped_components}. This is incompatible with `enable_sequential_cpu_offload()`."
)

if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
from accelerate import cpu_offload
else:
Expand Down
5 changes: 5 additions & 0 deletions tests/pipelines/audioldm2/test_audioldm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,9 +506,14 @@ def test_to_dtype(self):
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))

@unittest.skip("Test currently not supported.")
def test_sequential_cpu_offload_forward_pass(self):
pass

@unittest.skip("Test currently not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass


@nightly
class AudioLDM2PipelineSlowTests(unittest.TestCase):
Expand Down
24 changes: 24 additions & 0 deletions tests/pipelines/controlnet/test_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,18 @@ def test_inference_multiple_prompt_input(self):

assert image.shape == (4, 64, 64, 3)

@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass

@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass

@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass


class StableDiffusionMultiControlNetOneModelPipelineFastTests(
IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
Expand Down Expand Up @@ -697,6 +709,18 @@ def test_save_pretrained_raise_not_implemented_exception(self):
except NotImplementedError:
pass

@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass

@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass

@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass


@slow
@require_torch_gpu
Expand Down
12 changes: 12 additions & 0 deletions tests/pipelines/controlnet/test_controlnet_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,18 @@ def test_save_pretrained_raise_not_implemented_exception(self):
except NotImplementedError:
pass

@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass

@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass

@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass


@slow
@require_torch_gpu
Expand Down
12 changes: 12 additions & 0 deletions tests/pipelines/controlnet/test_controlnet_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,18 @@ def test_save_pretrained_raise_not_implemented_exception(self):
except NotImplementedError:
pass

@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass

@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass

@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass


@slow
@require_torch_gpu
Expand Down
24 changes: 24 additions & 0 deletions tests/pipelines/controlnet/test_controlnet_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,18 @@ def test_inference_batch_single_identical(self):
def test_save_load_optional_components(self):
return self._test_save_load_optional_components()

@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass

@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass

@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass


class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
Expand Down Expand Up @@ -887,6 +899,18 @@ def test_negative_conditions(self):

self.assertTrue(np.abs(image_slice_without_neg_cond - image_slice_with_neg_cond).max() > 1e-2)

@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass

@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass

@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass


@slow
@require_torch_gpu
Expand Down
Loading
Loading