Revert "[LoRA] fix: lora loading when using with a device_mapped mode… (
Browse files Browse the repository at this point in the history
#9823)

Revert "[LoRA] fix: lora loading when using with a device_mapped model. (#9449)"

This reverts commit 41e4779.
yiyixuxu authored and a-r-r-o-w committed Nov 1, 2024
1 parent 3ef6487 commit a91e8ed
Showing 22 changed files with 8 additions and 546 deletions.
2 changes: 0 additions & 2 deletions docs/source/en/training/distributed_inference.md

@@ -237,5 +237,3 @@ with torch.no_grad():
 ```
 
 By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs.
-
-This workflow is also compatible with LoRAs via [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. However, only LoRAs without text encoder components are currently supported in this workflow.
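For context, the sentence removed above referred to the model-sharding workflow documented on the same page: the largest component is loaded with a device map, the rest of the pipeline is assembled around it, and LoRA weights are then loaded on top. A minimal sketch of that workflow, assuming a multi-GPU machine (the LoRA repository id is illustrative):

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

# Shard only the largest component across the available GPUs.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)

# The removed sentence documented this call for device-mapped models;
# after this revert it is only supported without a device map.
pipeline.load_lora_weights("user/flux-lora-without-text-encoder")  # illustrative repo id
```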
12 changes: 1 addition & 11 deletions src/diffusers/loaders/lora_base.py

@@ -31,7 +31,6 @@
     delete_adapter_layers,
     deprecate,
     is_accelerate_available,
-    is_accelerate_version,
     is_peft_available,
     is_transformers_available,
     logging,
@@ -215,18 +214,9 @@ def _optionally_disable_offloading(cls, _pipeline):
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
 
-        def model_has_device_map(model):
-            if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
-                return False
-            return getattr(model, "hf_device_map", None) is not None
-
         if _pipeline is not None and _pipeline.hf_device_map is None:
             for _, component in _pipeline.components.items():
-                if (
-                    isinstance(component, nn.Module)
-                    and hasattr(component, "_hf_hook")
-                    and not model_has_device_map(component)
-                ):
+                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                     if not is_model_cpu_offload:
                         is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                     if not is_sequential_cpu_offload:
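The hunk above restores `_optionally_disable_offloading` to its pre-#9449 form: it inspects each pipeline component's accelerate hook and no longer exempts device-mapped components. A minimal sketch of the detection it performs, assuming accelerate is installed:

```python
import torch.nn as nn
from accelerate.hooks import AlignDevicesHook, CpuOffload


def detect_offloading(pipeline):
    """Return (model_cpu_offload, sequential_cpu_offload) flags for a pipeline."""
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    for _, component in pipeline.components.items():
        if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
            # enable_model_cpu_offload() attaches CpuOffload hooks, while
            # enable_sequential_cpu_offload() attaches AlignDevicesHook instances.
            if not is_model_cpu_offload:
                is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
            if not is_sequential_cpu_offload:
                is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
    return is_model_cpu_offload, is_sequential_cpu_offload
```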
12 changes: 1 addition & 11 deletions src/diffusers/loaders/unet.py

@@ -39,7 +39,6 @@
     get_adapter_name,
     get_peft_kwargs,
     is_accelerate_available,
-    is_accelerate_version,
     is_peft_version,
     is_torch_version,
     logging,
@@ -399,18 +398,9 @@ def _optionally_disable_offloading(cls, _pipeline):
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
 
-        def model_has_device_map(model):
-            if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
-                return False
-            return getattr(model, "hf_device_map", None) is not None
-
         if _pipeline is not None and _pipeline.hf_device_map is None:
             for _, component in _pipeline.components.items():
-                if (
-                    isinstance(component, nn.Module)
-                    and hasattr(component, "_hf_hook")
-                    and not model_has_device_map(component)
-                ):
+                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                     if not is_model_cpu_offload:
                         is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                     if not is_sequential_cpu_offload:
7 changes: 0 additions & 7 deletions src/diffusers/pipelines/pipeline_loading_utils.py

@@ -36,7 +36,6 @@
     deprecate,
     get_class_from_dynamic_module,
     is_accelerate_available,
-    is_accelerate_version,
     is_peft_available,
     is_transformers_available,
     logging,
@@ -948,9 +947,3 @@ def _get_ignore_patterns(
     )
 
     return ignore_patterns
-
-
-def model_has_device_map(model):
-    if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
-        return False
-    return getattr(model, "hf_device_map", None) is not None
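The deleted `model_has_device_map` helper keyed off the `hf_device_map` attribute that accelerate attaches to a model when its weights are dispatched across devices. A minimal sketch of what it detected (checkpoint id illustrative):

```python
import torch
from diffusers import UNet2DConditionModel

# Loading with a device map makes accelerate record the placement on the model.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative checkpoint
    subfolder="unet",
    device_map="auto",
    torch_dtype=torch.float16,
)
print(getattr(unet, "hf_device_map", None))  # e.g. {"": 0} or a per-module mapping
```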
31 changes: 0 additions & 31 deletions src/diffusers/pipelines/pipeline_utils.py

@@ -85,7 +85,6 @@
     _update_init_kwargs_with_connected_pipeline,
     load_sub_model,
     maybe_raise_or_warn,
-    model_has_device_map,
     variant_compatible_siblings,
     warn_deprecated_model_variant,
 )
@@ -407,16 +406,6 @@ def module_is_offloaded(module):
 
             return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
 
-        # device-mapped modules should not go through any device placements.
-        device_mapped_components = [
-            key for key, component in self.components.items() if model_has_device_map(component)
-        ]
-        if device_mapped_components:
-            raise ValueError(
-                "The following pipeline components have been found to use a device map: "
-                f"{device_mapped_components}. This is incompatible with explicitly setting the device using `to()`."
-            )
-
         # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
         pipeline_is_sequentially_offloaded = any(
             module_is_sequentially_offloaded(module) for _, module in self.components.items()
@@ -1013,16 +1002,6 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
                 The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
                 default to "cuda".
         """
-        # device-mapped modules should not go through any device placements.
-        device_mapped_components = [
-            key for key, component in self.components.items() if model_has_device_map(component)
-        ]
-        if device_mapped_components:
-            raise ValueError(
-                "The following pipeline components have been found to use a device map: "
-                f"{device_mapped_components}. This is incompatible with `enable_model_cpu_offload()`."
-            )
-
         is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
         if is_pipeline_device_mapped:
             raise ValueError(
@@ -1125,16 +1104,6 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
                 The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
                 default to "cuda".
         """
-        # device-mapped modules should not go through any device placements.
-        device_mapped_components = [
-            key for key, component in self.components.items() if model_has_device_map(component)
-        ]
-        if device_mapped_components:
-            raise ValueError(
-                "The following pipeline components have been found to use a device map: "
-                f"{device_mapped_components}. This is incompatible with `enable_sequential_cpu_offload()`."
-            )
-
         if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
             from accelerate import cpu_offload
         else:
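The three deleted guards all enforced the same rule: components carrying their own device map must not be moved by `to()`, `enable_model_cpu_offload()`, or `enable_sequential_cpu_offload()`. After this revert, only the pipeline-level `hf_device_map` check shown above remains. A minimal sketch of the conflict that surviving check still catches, assuming a multi-GPU machine:

```python
import torch
from diffusers import DiffusionPipeline

# A pipeline-level device map distributes components across the available GPUs.
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    device_map="balanced",
    torch_dtype=torch.float16,
)

try:
    pipeline.enable_model_cpu_offload()
except ValueError as err:
    # The surviving check refuses offloading on a device-mapped pipeline.
    print(err)
```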
5 changes: 0 additions & 5 deletions tests/pipelines/audioldm2/test_audioldm2.py

@@ -506,14 +506,9 @@ def test_to_dtype(self):
         model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))
 
     @unittest.skip("Test currently not supported.")
     def test_sequential_cpu_offload_forward_pass(self):
         pass
 
-    @unittest.skip("Test currently not supported.")
-    def test_calling_mco_raises_error_device_mapped_components(self):
-        pass
-
-
 @nightly
 class AudioLDM2PipelineSlowTests(unittest.TestCase):
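The stubs removed from the test suites (here and in the ControlNet files below) skipped mixin tests named after the guards: `mco` for `enable_model_cpu_offload()`, `sco` for `enable_sequential_cpu_offload()`, and `to` for `.to()`. A sketch of the shape such a mixin test plausibly had; the checkpoint id and exact assertion are assumptions:

```python
import unittest

from diffusers import DiffusionPipeline


class DeviceMappedComponentGuardTest(unittest.TestCase):
    def test_calling_mco_raises_error_device_mapped_components(self):
        # Load a pipeline whose components carry a device map (illustrative checkpoint).
        pipe = DiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-pipe",
            device_map="balanced",
        )
        # The reverted guard raised ValueError when offloading met a device map.
        with self.assertRaises(ValueError):
            pipe.enable_model_cpu_offload()
```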
24 changes: 0 additions & 24 deletions tests/pipelines/controlnet/test_controlnet.py

@@ -514,18 +514,6 @@ def test_inference_multiple_prompt_input(self):
 
         assert image.shape == (4, 64, 64, 3)
 
-    @unittest.skip("Test not supported.")
-    def test_calling_mco_raises_error_device_mapped_components(self):
-        pass
-
-    @unittest.skip("Test not supported.")
-    def test_calling_to_raises_error_device_mapped_components(self):
-        pass
-
-    @unittest.skip("Test not supported.")
-    def test_calling_sco_raises_error_device_mapped_components(self):
-        pass
-
 
 class StableDiffusionMultiControlNetOneModelPipelineFastTests(
     IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
@@ -709,18 +697,6 @@ def test_save_pretrained_raise_not_implemented_exception(self):
         except NotImplementedError:
             pass
 
-    @unittest.skip("Test not supported.")
-    def test_calling_mco_raises_error_device_mapped_components(self):
-        pass
-
-    @unittest.skip("Test not supported.")
-    def test_calling_to_raises_error_device_mapped_components(self):
-        pass
-
-    @unittest.skip("Test not supported.")
-    def test_calling_sco_raises_error_device_mapped_components(self):
-        pass
-
 
 @slow
 @require_torch_gpu
12 changes: 0 additions & 12 deletions tests/pipelines/controlnet/test_controlnet_img2img.py

@@ -389,18 +389,6 @@ def test_save_pretrained_raise_not_implemented_exception(self):
         except NotImplementedError:
             pass
 
-    @unittest.skip("Test not supported.")
-    def test_calling_mco_raises_error_device_mapped_components(self):
-        pass
-
-    @unittest.skip("Test not supported.")
-    def test_calling_to_raises_error_device_mapped_components(self):
-        pass
-
-    @unittest.skip("Test not supported.")
-    def test_calling_sco_raises_error_device_mapped_components(self):
-        pass
-
 
 @slow
 @require_torch_gpu
12 changes: 0 additions & 12 deletions tests/pipelines/controlnet/test_controlnet_inpaint.py

@@ -441,18 +441,6 @@ def test_save_pretrained_raise_not_implemented_exception(self):
         except NotImplementedError:
             pass
 
-    @unittest.skip("Test not supported.")
-    def test_calling_mco_raises_error_device_mapped_components(self):
-        pass
-
-    @unittest.skip("Test not supported.")
-    def test_calling_to_raises_error_device_mapped_components(self):
-        pass
-
-    @unittest.skip("Test not supported.")
-    def test_calling_sco_raises_error_device_mapped_components(self):
-        pass
-
 
 @slow
 @require_torch_gpu
24 changes: 0 additions & 24 deletions tests/pipelines/controlnet/test_controlnet_sdxl.py

@@ -683,18 +683,6 @@ def test_inference_batch_single_identical(self):
     def test_save_load_optional_components(self):
         return self._test_save_load_optional_components()
 
-    @unittest.skip("Test not supported.")
-    def test_calling_mco_raises_error_device_mapped_components(self):
-        pass
-
-    @unittest.skip("Test not supported.")
-    def test_calling_to_raises_error_device_mapped_components(self):
-        pass
-
-    @unittest.skip("Test not supported.")
-    def test_calling_sco_raises_error_device_mapped_components(self):
-        pass
-
 
 class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
     PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
@@ -899,18 +887,6 @@ def test_negative_conditions(self):
 
         self.assertTrue(np.abs(image_slice_without_neg_cond - image_slice_with_neg_cond).max() > 1e-2)
 
-    @unittest.skip("Test not supported.")
-    def test_calling_mco_raises_error_device_mapped_components(self):
-        pass
-
-    @unittest.skip("Test not supported.")
-    def test_calling_to_raises_error_device_mapped_components(self):
-        pass
-
-    @unittest.skip("Test not supported.")
-    def test_calling_sco_raises_error_device_mapped_components(self):
-        pass
-
 
 @slow
 @require_torch_gpu