Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LoRA] fix: lora loading when using with a device_mapped model. #9449

Merged
merged 29 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
dc1aee2
fix: lora loading when using with a device_mapped model.
sayakpaul Sep 17, 2024
949a929
better attibutung
sayakpaul Sep 17, 2024
64b3ad1
empty
sayakpaul Sep 17, 2024
6d03c12
Merge branch 'main' into lora-device-map
sayakpaul Sep 22, 2024
d4bd94b
Merge branch 'main' into lora-device-map
sayakpaul Sep 24, 2024
5479198
Apply suggestions from code review
sayakpaul Sep 24, 2024
2846549
Merge branch 'main' into lora-device-map
sayakpaul Sep 27, 2024
1ed0eb0
Merge branch 'main' into lora-device-map
sayakpaul Sep 28, 2024
d2d59c3
Merge branch 'main' into lora-device-map
sayakpaul Oct 2, 2024
5f3cae2
Merge branch 'main' into lora-device-map
sayakpaul Oct 6, 2024
8f670e2
Merge branch 'main' into lora-device-map
sayakpaul Oct 8, 2024
e42ec19
Merge branch 'main' into lora-device-map
sayakpaul Oct 10, 2024
f63b04c
Merge branch 'main' into lora-device-map
sayakpaul Oct 15, 2024
eefda54
Merge branch 'main' into lora-device-map
sayakpaul Oct 19, 2024
ea727a3
minors
sayakpaul Oct 19, 2024
71989e3
better error messages.
sayakpaul Oct 19, 2024
f62afac
fix-copies
sayakpaul Oct 19, 2024
2334f78
add: tests, docs.
sayakpaul Oct 19, 2024
5ea1173
add hardware note.
sayakpaul Oct 19, 2024
f64751e
Merge branch 'main' into lora-device-map
sayakpaul Oct 19, 2024
c0dee87
quality
sayakpaul Oct 19, 2024
4b6124a
Merge branch 'main' into lora-device-map
sayakpaul Oct 22, 2024
fe2cca8
Update docs/source/en/training/distributed_inference.md
sayakpaul Oct 23, 2024
2db5d48
Merge branch 'main' into lora-device-map
sayakpaul Oct 23, 2024
61903c8
Merge branch 'main' into lora-device-map
sayakpaul Oct 31, 2024
03377b7
fixes
sayakpaul Oct 31, 2024
0bd40cb
skip properly.
sayakpaul Oct 31, 2024
a61b754
fixes
sayakpaul Oct 31, 2024
ccd8d2a
resolve conflicts.
sayakpaul Oct 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/en/training/distributed_inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,5 @@ with torch.no_grad():
```

By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs.

This workflow is also compatible when working with LoRAs via `load_lora_weights()`. However, note that only LoRAs not involving any text encoder components are supported in this workflow at the moment.
sayakpaul marked this conversation as resolved.
Show resolved Hide resolved
sayakpaul marked this conversation as resolved.
Show resolved Hide resolved
12 changes: 11 additions & 1 deletion src/diffusers/loaders/lora_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
delete_adapter_layers,
deprecate,
is_accelerate_available,
is_accelerate_version,
is_peft_available,
is_transformers_available,
logging,
Expand Down Expand Up @@ -214,9 +215,18 @@ def _optionally_disable_offloading(cls, _pipeline):
is_model_cpu_offload = False
is_sequential_cpu_offload = False

def model_has_device_map(model):
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
return False
return getattr(model, "hf_device_map", None) is not None

if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if (
isinstance(component, nn.Module)
and hasattr(component, "_hf_hook")
and not model_has_device_map(component)
):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
Expand Down
12 changes: 11 additions & 1 deletion src/diffusers/loaders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
is_accelerate_version,
is_peft_version,
is_torch_version,
logging,
Expand Down Expand Up @@ -398,9 +399,18 @@ def _optionally_disable_offloading(cls, _pipeline):
is_model_cpu_offload = False
is_sequential_cpu_offload = False

def model_has_device_map(model):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After-effects of make fix-copies.

if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
return False
return getattr(model, "hf_device_map", None) is not None

if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if (
isinstance(component, nn.Module)
and hasattr(component, "_hf_hook")
and not model_has_device_map(component)
):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
Expand Down
47 changes: 47 additions & 0 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,11 @@ def to(self, *args, **kwargs):

device = device or device_arg

def model_has_device_map(model):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DN6 it would make sense to make this a separate utility instead of having redefine three times. WDYT?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, you can add as a util function inside pipeline_utils.

if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
return False
return getattr(model, "hf_device_map", None) is not None

# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
def module_is_sequentially_offloaded(module):
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
Expand All @@ -406,6 +411,16 @@ def module_is_offloaded(module):

return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)

# device-mapped modules should not go through any device placements.
device_mapped_components = [
key for key, component in self.components.items() if model_has_device_map(component)
]
if device_mapped_components:
raise ValueError(
"The following pipeline components have been found to use a device map: "
f"{device_mapped_components}. This is incompatible with explicitly setting the device using `to()`."
)

# .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
pipeline_is_sequentially_offloaded = any(
module_is_sequentially_offloaded(module) for _, module in self.components.items()
Expand Down Expand Up @@ -1002,6 +1017,22 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
default to "cuda".
"""

def model_has_device_map(model):
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
return False
return getattr(model, "hf_device_map", None) is not None

# device-mapped modules should not go through any device placements.
device_mapped_components = [
key for key, component in self.components.items() if model_has_device_map(component)
]
if device_mapped_components:
raise ValueError(
"The following pipeline components have been found to use a device map: "
f"{device_mapped_components}. This is incompatible with `enable_model_cpu_offload()`."
)

is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
raise ValueError(
Expand Down Expand Up @@ -1104,6 +1135,22 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
default to "cuda".
"""

def model_has_device_map(model):
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
return False
return getattr(model, "hf_device_map", None) is not None

# device-mapped modules should not go through any device placements.
device_mapped_components = [
key for key, component in self.components.items() if model_has_device_map(component)
]
if device_mapped_components:
raise ValueError(
"The following pipeline components have been found to use a device map: "
f"{device_mapped_components}. This is incompatible with `enable_sequential_cpu_offload()`."
)

if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
from accelerate import cpu_offload
else:
Expand Down
5 changes: 5 additions & 0 deletions tests/pipelines/audioldm2/test_audioldm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,9 +506,14 @@ def test_to_dtype(self):
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))

@unittest.skip("Test currently not supported.")
def test_sequential_cpu_offload_forward_pass(self):
pass

@unittest.skip("Test currently not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass


@nightly
class AudioLDM2PipelineSlowTests(unittest.TestCase):
Expand Down
86 changes: 86 additions & 0 deletions tests/pipelines/flux/test_pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel

from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance,
require_torch_gpu,
require_torch_multi_gpu,
slow,
torch_device,
)
Expand Down Expand Up @@ -249,3 +251,87 @@ def test_flux_inference(self):
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())

assert max_diff < 1e-4

@require_torch_multi_gpu
@torch.no_grad()
def test_flux_component_sharding(self):
"""
internal note: test was run on `audace`.
"""

ckpt_id = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
prompt = "a photo of a cat with tiger-like look"

pipeline = FluxPipeline.from_pretrained(
ckpt_id,
transformer=None,
vae=None,
device_map="balanced",
max_memory={0: "16GB", 1: "16GB"},
torch_dtype=dtype,
)
prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
prompt=prompt, prompt_2=None, max_sequence_length=512
)

del pipeline.text_encoder
del pipeline.text_encoder_2
del pipeline.tokenizer
del pipeline.tokenizer_2
del pipeline

gc.collect()
torch.cuda.empty_cache()

transformer = FluxTransformer2DModel.from_pretrained(
ckpt_id, subfolder="transformer", device_map="auto", max_memory={0: "16GB", 1: "16GB"}, torch_dtype=dtype
)
pipeline = FluxPipeline.from_pretrained(
ckpt_id,
text_encoder=None,
text_encoder_2=None,
tokenizer=None,
tokenizer_2=None,
vae=None,
transformer=transformer,
torch_dtype=dtype,
)

height, width = 768, 1360
# No need to wrap it up under `torch.no_grad()` as pipeline call method
# is already wrapped under that.
latents = pipeline(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
num_inference_steps=10,
guidance_scale=3.5,
height=height,
width=width,
output_type="latent",
generator=torch.manual_seed(0),
).images
latent_slice = latents[0, :3, :3].flatten().float().cpu().numpy()
expected_slice = np.array([-0.377, -0.3008, -0.5117, -0.252, 0.0615, -0.3477, -0.1309, -0.1914, 0.1533])

assert numpy_cosine_similarity_distance(latent_slice, expected_slice) < 1e-4

del pipeline.transformer
del pipeline

gc.collect()
torch.cuda.empty_cache()

vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype).to(torch_device)
vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor

image = vae.decode(latents, return_dict=False)[0]
image = image_processor.postprocess(image, output_type="np")
image_slice = image[0, :3, :3, -1].flatten()
expected_slice = np.array([0.127, 0.1113, 0.1055, 0.1172, 0.1172, 0.1074, 0.1191, 0.1191, 0.1152])

assert numpy_cosine_similarity_distance(image_slice, expected_slice) < 1e-4
4 changes: 4 additions & 0 deletions tests/pipelines/musicldm/test_musicldm.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,10 @@ def test_to_dtype(self):
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))

@unittest.skip("Test currently not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass


@nightly
@require_torch_gpu
Expand Down
102 changes: 102 additions & 0 deletions tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,24 @@
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import IPAdapterMixin
from diffusers.models.adapter import MultiAdapter
from diffusers.models.attention_processor import AttnProcessor
from diffusers.models.controlnet_xs import UNetControlNetXSModel
from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet
from diffusers.models.unets.unet_motion_model import UNetMotionModel
from diffusers.pipelines.controlnet import MultiControlNetModel
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import logging
from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
from diffusers.utils.testing_utils import (
CaptureLogger,
nightly,
require_torch,
require_torch_multi_gpu,
skip_mps,
slow,
torch_device,
)

Expand All @@ -59,6 +64,10 @@
from ..others.test_utils import TOKEN, USER, is_staging_test


if is_accelerate_available():
from accelerate.utils import compute_module_sizes


def to_np(tensor):
if isinstance(tensor, torch.Tensor):
tensor = tensor.detach().cpu().numpy()
Expand Down Expand Up @@ -1907,6 +1916,99 @@ def test_StableDiffusionMixin_component(self):
)
)

@require_torch_multi_gpu
@slow
@nightly
def test_calling_to_raises_error_device_mapped_components(self):
if "Combined" in self.pipeline_class.__name__:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because for connected pipelines, we don't support device mapping in the first place.

return

# TODO (sayakpaul): skip these for now. revisit later.
components = self.get_dummy_components()
if any(isinstance(component, (MultiControlNetModel, MultiAdapter)) for component in components):
return

pipe = self.pipeline_class(**components)
max_model_size = max(
compute_module_sizes(module)[""]
for _, module in pipe.components.items()
if isinstance(module, torch.nn.Module)
)
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
max_memory = {0: max_model_size, 1: max_model_size}
loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map="balanced", max_memory=max_memory)

with self.assertRaises(ValueError) as err_context:
loaded_pipe.to(torch_device)

self.assertTrue(
"The following pipeline components have been found" in str(err_context.exception)
and "This is incompatible with explicitly setting the device using `to()`" in str(err_context.exception)
)

@require_torch_multi_gpu
@slow
@nightly
def test_calling_mco_raises_error_device_mapped_components(self):
if "Combined" in self.pipeline_class.__name__:
return

# TODO (sayakpaul): skip these for now. revisit later.
components = self.get_dummy_components()
if any(isinstance(component, (MultiControlNetModel, MultiAdapter)) for component in components):
return

pipe = self.pipeline_class(**components)
max_model_size = max(
compute_module_sizes(module)[""]
for _, module in pipe.components.items()
if isinstance(module, torch.nn.Module)
)
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
max_memory = {0: max_model_size, 1: max_model_size}
loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map="balanced", max_memory=max_memory)

with self.assertRaises(ValueError) as err_context:
loaded_pipe.enable_model_cpu_offload()

self.assertTrue(
"The following pipeline components have been found" in str(err_context.exception)
and "This is incompatible with `enable_model_cpu_offload()`" in str(err_context.exception)
)

@require_torch_multi_gpu
@slow
@nightly
def test_calling_sco_raises_error_device_mapped_components(self):
if "Combined" in self.pipeline_class.__name__:
return

# TODO (sayakpaul): skip these for now. revisit later.
components = self.get_dummy_components()
if any(isinstance(component, (MultiControlNetModel, MultiAdapter)) for component in components):
return

pipe = self.pipeline_class(**components)
max_model_size = max(
compute_module_sizes(module)[""]
for _, module in pipe.components.items()
if isinstance(module, torch.nn.Module)
)
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
max_memory = {0: max_model_size, 1: max_model_size}
loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map="balanced", max_memory=max_memory)

with self.assertRaises(ValueError) as err_context:
loaded_pipe.enable_sequential_cpu_offload()

self.assertTrue(
"The following pipeline components have been found" in str(err_context.exception)
and "This is incompatible with `enable_sequential_cpu_offload()`" in str(err_context.exception)
)


@is_staging_test
class PipelinePushToHubTester(unittest.TestCase):
Expand Down
Loading