def sum_weights(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return the sum of two weight tensors, tolerating a 1x1-kernel rank mismatch.

    When both tensors have the same rank they are added directly. When the
    ranks differ and the higher-rank tensor ends in a ``(1, 1)`` kernel, the
    lower-rank tensor gets two singleton axes inserted at positions 2 and 3
    before the addition, so that it lines up with the higher-rank one.

    Args:
        a: First weight tensor.
        b: Second weight tensor.

    Returns:
        A new array holding ``a + b``. Neither input is modified.

    Note:
        The result already includes both operands, so callers should assign
        it (``x = sum_weights(x, y)``) rather than accumulate it with ``+=``,
        which would count the existing weights twice.
    """
    logger.trace("summing weights with shapes: %s + %s", a.shape, b.shape)

    # Same rank: nothing to reconcile. Expanding here would turn an
    # (out, in, 1, 1) operand into a 6-D array and broadcast the sum into
    # the wrong shape instead of adding element-wise.
    if a.ndim == b.ndim:
        return a + b

    # The higher-rank tensor decides the kernel size.
    if a.ndim > b.ndim:
        hr, lr = a, b
    else:
        hr, lr = b, a

    kernel = hr.shape[-2:]
    if kernel == (1, 1):
        # TODO: generalize; this assumes the lower-rank tensor is 2-D and only
        # lacks the trailing 1x1 kernel axes.
        lr = np.expand_dims(lr, axis=(2, 3))

    # Any other mismatch is left to NumPy broadcasting, as before.
    return hr + lr
np_weights.shape) + blended[base_key] = sum_weights(blended[base_key], np_weights) else: blended[base_key] = np_weights