Skip to content

Commit

Permalink
fix(api): add theoretical support for 3x3 conv in LoRA
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Mar 16, 2023
1 parent 8e8e230 commit 315e5a3
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions api/onnx_web/convert/diffusion/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def blend_loras(
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
# blend for nn.Conv2d 1x1
logger.debug(
"blending weights for Conv node: %s, %s, %s",
"blending weights for Conv 1x1 node: %s, %s, %s",
down_weight.shape,
up_weight.shape,
alpha,
Expand All @@ -121,8 +121,17 @@ def blend_loras(
.unsqueeze(3)
)
np_weights = weights.numpy() * (alpha / dim)
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (3, 3):
# blend for nn.Conv2d 3x3
logger.debug(
"blending weights for Conv 3x3 node: %s, %s, %s",
down_weight.shape,
up_weight.shape,
alpha,
)
weights = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
np_weights = weights.numpy() * (alpha / dim)
else:
# TODO: add support for Conv2d 3x3
logger.warning(
"unknown LoRA node type at %s: %s",
base_key,
Expand Down

0 comments on commit 315e5a3

Please sign in to comment.