Skip to content

Commit

Permalink
add experimental LoRA blender
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 22, 2023
1 parent 43761b5 commit 7e966f7
Showing 1 changed file with 94 additions and 0 deletions.
94 changes: 94 additions & 0 deletions api/onnx_web/convert/diffusion/lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from numpy import ndarray
from onnx import TensorProto, helper, load, numpy_helper, ModelProto, save_model
from typing import Dict, List, Tuple
from logging import getLogger


logger = getLogger(__name__)


def load_lora(filename: str):
model = load(filename)

for weight in model.graph.initializer:
# print(weight.name, numpy_helper.to_array(weight).shape)
pass

return model


def blend_loras(base: ModelProto, weights: List[ModelProto], alphas: List[float]) -> List[Tuple[TensorProto, ndarray]]:
total = 1 + sum(alphas)

results = []

for base_node in base.graph.initializer:
logger.info("blending initializer node %s", base_node.name)
base_weights = numpy_helper.to_array(base_node).copy()

for weight, alpha in zip(weights, alphas):
weight_node = next(iter([f for f in weight.graph.initializer if f.name == base_node.name]), None)

if weight_node is not None:
base_weights += numpy_helper.to_array(weight_node) * alpha
else:
logger.warning("missing weights: %s in %s", base_node.name, weight.doc_string)

results.append((base_node, base_weights / total))

return results


def convert_loras(part: str):
lora_weights = [
f"diffusion-lora-jack/{part}/model.onnx",
f"diffusion-lora-taters/{part}/model.onnx",
]

base = load_lora(f"stable-diffusion-onnx-v1-5/{part}/model.onnx")
weights = [load_lora(f) for f in lora_weights]
alphas = [1 / len(weights)] * len(weights)
logger.info("blending LoRAs with alphas: %s, %s", weights, alphas)

result = blend_loras(base, weights, alphas)
logger.info("blended result keys: %s", len(result))

del weights
del alphas

tensors = []
for node, tensor in result:
logger.info("remaking tensor for %s", node.name)
tensors.append(helper.make_tensor(node.name, node.data_type, node.dims, tensor))

del result

graph = helper.make_graph(
base.graph.node,
base.graph.name,
base.graph.input,
base.graph.output,
tensors,
base.graph.doc_string,
base.graph.value_info,
base.graph.sparse_initializer,
)
model = helper.make_model(graph)

del model.opset_import[:]
opset = model.opset_import.add()
opset.version = 14

save_model(
model,
f"/tmp/lora-{part}.onnx",
save_as_external_data=True,
all_tensors_to_one_file=True,
location=f"/tmp/lora-{part}.tensors",
)
logger.info("saved model to %s and tensors to %s", f"/tmp/lora-{part}.onnx", f"/tmp/lora-{part}.tensors")


if __name__ == "__main__":
convert_loras("unet")
convert_loras("text_encoder")

0 comments on commit 7e966f7

Please sign in to comment.