Skip to content

Commit

Permalink
fix(api): only run SDXL LoRA node matching on XL models
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Sep 5, 2023
1 parent 5d0d904 commit ea9023c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
9 changes: 5 additions & 4 deletions api/onnx_web/convert/diffusion/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def blend_loras(
loras: List[Tuple[str, float]],
model_type: Literal["text_encoder", "unet"],
model_index: Optional[int] = None,
xl: Optional[bool] = False,
):
# always load to CPU for blending
device = torch.device("cpu")
Expand Down Expand Up @@ -394,14 +395,14 @@ def blend_loras(
blended[base_key] = np_weights

# rewrite node names for XL
nodes = list(base_model.graph.node)
blended = fix_xl_names(blended, nodes)
if xl:
nodes = list(base_model.graph.node)
blended = fix_xl_names(blended, nodes)

logger.trace(
"updating %s of %s initializers, %s missed",
"updating %s of %s initializers",
len(blended.keys()),
len(base_model.graph.initializer),
len(nodes),
)

fixed_initializer_names = [
Expand Down
3 changes: 3 additions & 0 deletions api/onnx_web/diffusers/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def load_pipeline(
list(zip(lora_models, lora_weights)),
"text_encoder",
1 if params.is_xl() else None,
params.is_xl(),
)
(text_encoder, text_encoder_data) = buffer_external_data_tensors(
text_encoder
Expand Down Expand Up @@ -284,6 +285,7 @@ def load_pipeline(
list(zip(lora_models, lora_weights)),
"text_encoder",
2,
params.is_xl()
)
(text_encoder2, text_encoder2_data) = buffer_external_data_tensors(
text_encoder2
Expand Down Expand Up @@ -311,6 +313,7 @@ def load_pipeline(
unet,
list(zip(lora_models, lora_weights)),
"unet",
xl=params.is_xl(),
)
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
unet_names, unet_values = zip(*unet_data)
Expand Down

0 comments on commit ea9023c

Please sign in to comment.