Fix regression with FLUX diffusers LoRA models #6997

Merged · 2 commits · Oct 1, 2024
6 changes: 3 additions & 3 deletions invokeai/app/invocations/flux_denoise.py
@@ -30,7 +30,7 @@
     pack,
     unpack,
 )
-from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
@@ -209,7 +209,7 @@ def _run_diffusion(
                LoRAPatcher.apply_lora_patches(
                    model=transformer,
                    patches=self._lora_iterator(context),
-                   prefix=FLUX_KOHYA_TRANFORMER_PREFIX,
+                   prefix=FLUX_LORA_TRANSFORMER_PREFIX,
                    cached_weights=cached_weights,
                )
            )
@@ -220,7 +220,7 @@ def _run_diffusion(
                LoRAPatcher.apply_lora_sidecar_patches(
                    model=transformer,
                    patches=self._lora_iterator(context),
-                   prefix=FLUX_KOHYA_TRANFORMER_PREFIX,
+                   prefix=FLUX_LORA_TRANSFORMER_PREFIX,
                    dtype=inference_dtype,
                )
            )
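Why the prefix matters: apply_lora_patches receives every layer of the LoRA model, and the prefix routes transformer layers to the transformer, while the text-encoder invocation (next file) filters on the CLIP prefix. A minimal sketch of that routing, assuming the layers are a plain dict keyed by prefixed names — the helper select_layers_for_model and the key names below are illustrative, not InvokeAI's actual LoRAPatcher internals:

def select_layers_for_model(layers: dict[str, str], prefix: str) -> dict[str, str]:
    # Keep only the layers intended for this model, stripping the routing
    # prefix so the remainder can be matched against module names.
    return {k[len(prefix):]: v for k, v in layers.items() if k.startswith(prefix)}


mixed_lora_layers = {
    "lora_transformer-double_blocks_0_img_attn_qkv": "transformer layer",
    "lora_clip-text_model_encoder_layers_0_mlp_fc1": "CLIP layer",
}

# Only the transformer layer survives, with its prefix removed.
assert select_layers_for_model(mixed_lora_layers, "lora_transformer-") == {
    "double_blocks_0_img_attn_qkv": "transformer layer"
}

The regression being fixed is this routing going wrong: the diffusers conversion path (below) produced unprefixed keys, so once the invocations filtered on a prefix, no diffusers-format layer could match.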
4 changes: 2 additions & 2 deletions invokeai/app/invocations/flux_text_encoder.py
@@ -10,7 +10,7 @@
 from invokeai.app.invocations.primitives import FluxConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.modules.conditioner import HFEncoder
-from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_CLIP_PREFIX
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
@@ -101,7 +101,7 @@ def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
                LoRAPatcher.apply_lora_patches(
                    model=clip_text_encoder,
                    patches=self._clip_lora_iterator(context),
-                   prefix=FLUX_KOHYA_CLIP_PREFIX,
+                   prefix=FLUX_LORA_CLIP_PREFIX,
                    cached_weights=cached_weights,
                )
            )
@@ -2,6 +2,7 @@
 
 import torch
 
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
 from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
 from invokeai.backend.lora.layers.lora_layer import LoRALayer
@@ -189,7 +190,9 @@ def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None:
     # Assert that all keys were processed.
     assert len(grouped_state_dict) == 0
 
-    return LoRAModelRaw(layers=layers)
+    layers_with_prefix = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}
+
+    return LoRAModelRaw(layers=layers_with_prefix)
 
 
 def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
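This hunk is the heart of the fix: the diffusers converter previously returned its layers under bare names, while the invocations above now look them up under FLUX_LORA_TRANSFORMER_PREFIX. A short standalone sketch of the transformation — the layer names are hypothetical; only the prefixing pattern comes from the diff:

FLUX_LORA_TRANSFORMER_PREFIX = "lora_transformer-"

# Bare keys, as the converter builds them before the fix (names are made up).
layers = {
    "double_blocks.0.img_attn.qkv": "layer-a",
    "single_blocks.3.linear1": "layer-b",
}

# The fix: re-key every layer under the shared transformer prefix.
layers_with_prefix = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}

assert set(layers_with_prefix) == {
    "lora_transformer-double_blocks.0.img_attn.qkv",
    "lora_transformer-single_blocks.3.linear1",
}

Every key gets the transformer prefix here because a diffusers-format FLUX LoRA patches only the transformer; the Kohya path below is the one that must split keys between the transformer and the CLIP text encoder.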
invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py
@@ -3,6 +3,7 @@
 
 import torch
 
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
 from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
@@ -23,11 +24,6 @@
 FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*"
 
 
-# Prefixes used to distinguish between transformer and CLIP text encoder keys in the InvokeAI LoRA format.
-FLUX_KOHYA_TRANFORMER_PREFIX = "lora_transformer-"
-FLUX_KOHYA_CLIP_PREFIX = "lora_clip-"
-
-
 def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
     """Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
 
@@ -67,9 +63,9 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
     # Create LoRA layers.
     layers: dict[str, AnyLoRALayer] = {}
     for layer_key, layer_state_dict in transformer_grouped_sd.items():
-        layers[FLUX_KOHYA_TRANFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+        layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
     for layer_key, layer_state_dict in clip_grouped_sd.items():
-        layers[FLUX_KOHYA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+        layers[FLUX_LORA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
 
     # Create and return the LoRAModelRaw.
     return LoRAModelRaw(layers=layers)
3 changes: 3 additions & 0 deletions invokeai/backend/lora/conversions/flux_lora_constants.py
@@ -0,0 +1,3 @@
+# Prefixes used to distinguish between transformer and CLIP text encoder keys in the FLUX InvokeAI LoRA format.
+FLUX_LORA_TRANSFORMER_PREFIX = "lora_transformer-"
+FLUX_LORA_CLIP_PREFIX = "lora_clip-"
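The new module is deliberately tiny: both conversion paths and both invocations import from it, so the two key namespaces cannot drift apart again (the rename also retires the "TRANFORMER" misspelling in the old constant name). A consumer-side sketch over hypothetical keys, using the real import path introduced by this PR:

from invokeai.backend.lora.conversions.flux_lora_constants import (
    FLUX_LORA_CLIP_PREFIX,
    FLUX_LORA_TRANSFORMER_PREFIX,
)

# A Kohya-format FLUX LoRA can carry both kinds of layers; consumers
# split them by prefix (the key names here are illustrative).
keys = [
    "lora_transformer-double_blocks_0_img_attn_qkv",
    "lora_clip-text_model_encoder_layers_0_mlp_fc1",
]
transformer_keys = [k for k in keys if k.startswith(FLUX_LORA_TRANSFORMER_PREFIX)]
clip_keys = [k for k in keys if k.startswith(FLUX_LORA_CLIP_PREFIX)]
assert len(transformer_keys) == len(clip_keys) == 1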
@@ -5,6 +5,7 @@
     is_state_dict_likely_in_flux_diffusers_format,
     lora_model_from_flux_diffusers_state_dict,
 )
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
     state_dict_keys as flux_diffusers_state_dict_keys,
 )
@@ -50,6 +51,7 @@ def test_lora_model_from_flux_diffusers_state_dict():
     concatenated_weights = ["to_k", "to_v", "proj_mlp", "add_k_proj", "add_v_proj"]
     expected_lora_layers = {k for k in expected_lora_layers if not any(w in k for w in concatenated_weights)}
     assert len(model.layers) == len(expected_lora_layers)
+    assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k in model.layers.keys())
 
 
 def test_lora_model_from_flux_diffusers_state_dict_extra_keys_error():
@@ -5,12 +5,11 @@
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.util import params
 from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
-    FLUX_KOHYA_CLIP_PREFIX,
-    FLUX_KOHYA_TRANFORMER_PREFIX,
     _convert_flux_transformer_kohya_state_dict_to_invoke_format,
     is_state_dict_likely_in_flux_kohya_format,
     lora_model_from_flux_kohya_state_dict,
 )
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_TRANSFORMER_PREFIX
 from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
     state_dict_keys as flux_diffusers_state_dict_keys,
 )
@@ -95,8 +94,8 @@ def test_lora_model_from_flux_kohya_state_dict(sd_keys: list[str]):
     expected_layer_keys: set[str] = set()
     for k in sd_keys:
         # Replace prefixes.
-        k = k.replace("lora_unet_", FLUX_KOHYA_TRANFORMER_PREFIX)
-        k = k.replace("lora_te1_", FLUX_KOHYA_CLIP_PREFIX)
+        k = k.replace("lora_unet_", FLUX_LORA_TRANSFORMER_PREFIX)
+        k = k.replace("lora_te1_", FLUX_LORA_CLIP_PREFIX)
         # Remove suffixes.
         k = k.replace(".lora_up.weight", "")
         k = k.replace(".lora_down.weight", "")