Skip to content

Commit

Permalink
fix(api): improve summing of mismatched weights
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Apr 21, 2023
1 parent f0109d3 commit 3e8f4b3
Showing 1 changed file with 23 additions and 16 deletions.
39 changes: 23 additions & 16 deletions api/onnx_web/convert/diffusion/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,25 @@
logger = getLogger(__name__)


def sum_weights(a: np.ndarray, b: np.ndarray) -> np.ndarray:
logger.trace("summing weights with shapes: %s + %s", a.shape, b.shape)

# get the kernel size from the tensor with the higher rank
if len(a.shape) > len(b.shape):
kernel = a.shape[-2:]
hr = a
lr = b
else:
kernel = b.shape[-2:]
hr = b
lr = a

if kernel == (1, 1):
lr = np.expand_dims(lr, axis=(2, 3))

return hr + lr


def buffer_external_data_tensors(
model: ModelProto,
) -> Tuple[ModelProto, List[Tuple[str, OrtValue]]]:
Expand Down Expand Up @@ -150,14 +169,8 @@ def blend_loras(

np_weights *= lora_weight
if base_key in blended:
blended_weights = blended[base_key]
logger.trace("summing LoHA weights: %s + %s", blended_weights.shape, np_weights.shape)

if blended_weights.shape != np_weights.shape and kernel == (1, 1):
logger.debug("expanding mismatched weights for 1x1 kernel")
blended[base_key] = np.expand_dims(blended_weights, axis=(2, 3))

blended[base_key] += np_weights
logger.trace("summing LoHA weights: %s + %s", blended[base_key].shape, np_weights.shape)
blended[base_key] += sum_weights(blended[base_key], np_weights)
else:
blended[base_key] = np_weights
elif ".lora_down" in key and lora_prefix in key:
Expand Down Expand Up @@ -265,14 +278,8 @@ def blend_loras(

np_weights *= lora_weight
if base_key in blended:
blended_weights = blended[base_key]
logger.trace("summing weights: %s + %s", blended_weights.shape, np_weights.shape)

if blended_weights.shape != np_weights.shape and kernel == (1, 1):
logger.debug("expanding mismatched weights for 1x1 kernel")
blended[base_key] = np.expand_dims(blended_weights, axis=(2, 3))

blended[base_key] += np_weights
logger.trace("summing weights: %s + %s", blended[base_key].shape, np_weights.shape)
blended[base_key] = sum_weights(blended[base_key], np_weights)
else:
blended[base_key] = np_weights

Expand Down

0 comments on commit 3e8f4b3

Please sign in to comment.