# LORA #213

## LoRA magic

There is a way to assemble LoRA + Diffusers on the fly. I will be glad if you figure out how to throw something like this in ONNX.
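For reference, a minimal sketch of what "on the fly" looks like on the Diffusers/PyTorch side; the model id, LoRA file name, and prompt are placeholders, and this assumes a diffusers version that provides `load_lora_weights`:

```python
# a minimal sketch of on-the-fly LoRA use in diffusers (PyTorch side);
# the model id, LoRA file, and prompt are placeholders, not from this thread
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("./loras/example.safetensors")  # hypothetical local file
image = pipe("a photo of an astronaut riding a horse").images[0]
image.save("out.png")
```

The question in this issue is how to get the same effect for ONNX models, where the weights are baked into a static graph.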
So I've been looking into this, and in theory it is possible: https://github.com/ssube/onnx-web/blob/main/api/onnx_web/convert/diffusion/lora.py#L123. The difficulty seems to be in building a valid ONNX protobuf of the new model. That code is not hooked up yet because I haven't been able to make it reliably emit valid models, and I'm not sure whether that is a bug in the merge code or in the model checker. Adding too many initializers to the graph at once can cause a segfault in the checker.
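For anyone trying to reproduce this, a minimal sketch of the pattern being debugged: rewrite one initializer in place and re-run the checker. The model path, initializer name, and delta are illustrative, not taken from onnx-web:

```python
# sketch: rewrite one initializer and re-check the model; the path and
# tensor name are hypothetical
import numpy as np
import onnx
from onnx import numpy_helper

model = onnx.load("unet/model.onnx")  # hypothetical path

for idx, init in enumerate(model.graph.initializer):
    if init.name == "onnx::MatMul_1234":  # hypothetical initializer name
        weights = numpy_helper.to_array(init)
        blended = weights + 0.1 * np.ones_like(weights)  # stand-in for a LoRA delta
        # remove and re-insert to keep initializer order stable; appending many
        # new initializers at once is where the segfault showed up
        del model.graph.initializer[idx]
        model.graph.initializer.insert(idx, numpy_helper.from_array(blended, init.name))
        break

onnx.checker.check_model(model)  # raises if the protobuf is invalid
```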
I've made some progress and done some research that may or may not be worth writing down, and I see two issues so far:
Some of those node names:
Operator type sanity check script:

```python
import onnx
import onnx.numpy_helper
import torch
import torch.nn as nn
import torch.onnx

# make a net with a single Conv2d and export it
conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=4, stride=1, padding=1, dilation=1, groups=1, bias=True)
dummy_input = torch.randn(10, 3, 224, 224)
torch.onnx.export(conv, dummy_input, f="/tmp/onnx-conv2d.pb")

# load it back
conv2d_onnx = onnx.load("/tmp/onnx-conv2d.pb")
print("conv2d init:", [(n.name, len(n.raw_data)) for n in conv2d_onnx.graph.initializer])
print("conv2d node:", [(n.name, n.input) for n in conv2d_onnx.graph.node])
print("conv2d output:", conv2d_onnx.graph.output)

# make a net with a single Linear and export it
# (the 4D input is fine: Linear applies to the last dimension)
linear = nn.Linear(224, 6720, bias=True)
torch.onnx.export(linear, dummy_input, f="/tmp/onnx-linear.pb")

# load it back
lin_onnx = onnx.load("/tmp/onnx-linear.pb")
print("linear init:", [(n.name, len(n.raw_data)) for n in lin_onnx.graph.initializer])
print("linear node:", [(n.name, n.input) for n in lin_onnx.graph.node])
print("linear output:", lin_onnx.graph.output)

# initializer shapes for both
print("conv2d shapes:", [(n.name, onnx.numpy_helper.to_array(n).shape) for n in conv2d_onnx.graph.initializer])
print("linear shapes:", [(n.name, onnx.numpy_helper.to_array(n).shape) for n in lin_onnx.graph.initializer])
```

Operator script output:
tl;dr:
Looking at this further, with a little bit of progress: it looks like ORT does offer a way to load external data from memory, so blending and even converting models without ever writing them to disk should be possible. For most of the large base models, saving them will still make sense, but that should save some SSD wear for LoRAs. I was able to blend some models by looking up the MatMul nodes and write the result out as a valid ONNX model, at least valid enough to load and run inference, but the output comes out as random colored spots. Rather than guess at what the nodes mean, I'm writing a script to diff the ONNX models and working backwards from there.

diff script:

```python
from logging import getLogger, basicConfig, DEBUG
from sys import argv, stdout

from onnx import load_model, ModelProto
from onnx.numpy_helper import to_array

basicConfig(stream=stdout, level=DEBUG)
logger = getLogger(__name__)


def diff_models(ref_model: ModelProto, cmp_model: ModelProto):
    if len(ref_model.graph.initializer) != len(cmp_model.graph.initializer):
        logger.warning("different number of initializers: %s vs %s", len(ref_model.graph.initializer), len(cmp_model.graph.initializer))
    else:
        for (ref_init, cmp_init) in zip(ref_model.graph.initializer, cmp_model.graph.initializer):
            if ref_init.name != cmp_init.name:
                logger.info("different node names: %s vs %s", ref_init.name, cmp_init.name)
            elif ref_init.data_location != cmp_init.data_location:
                logger.info("different data locations: %s vs %s", ref_init.data_location, cmp_init.data_location)
            elif ref_init.data_type != cmp_init.data_type:
                logger.info("different data types: %s vs %s", ref_init.data_type, cmp_init.data_type)
            elif len(ref_init.raw_data) != len(cmp_init.raw_data):
                logger.info("different raw data size: %s vs %s", len(ref_init.raw_data), len(cmp_init.raw_data))
            elif len(ref_init.raw_data) > 0 and len(cmp_init.raw_data) > 0:
                ref_data = to_array(ref_init)
                cmp_data = to_array(cmp_init)
                data_diff = ref_data - cmp_data
                # compare on abs() so negative deltas are not missed
                if abs(data_diff).max() > 0:
                    logger.info("raw data differs: %s", data_diff)
            else:
                logger.info("initializers are identical in all checked fields: %s", ref_init.name)


if __name__ == "__main__":
    ref_path = argv[1]
    cmp_paths = argv[2:]

    logger.info("loading reference model from %s", ref_path)
    ref_model = load_model(ref_path)

    for cmp_path in cmp_paths:
        logger.info("loading comparison model from %s", cmp_path)
        cmp_model = load_model(cmp_path)
        diff_models(ref_model, cmp_model)
```

My initial comparison of two text_encoders is that all of the changes to initializers are in those
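On the in-memory loading mentioned above, a minimal sketch: `InferenceSession` accepts serialized bytes instead of a file path, so a blended model never has to touch the disk. Newer ORT builds also expose `SessionOptions.add_external_initializers` for supplying external data from memory, but treat that as version-dependent; the path here is a stand-in:

```python
# sketch: hand ORT the serialized bytes of a blended ModelProto instead of a
# file on disk; assumes the model has no external data and fits in memory
import onnx
import onnxruntime as ort

model = onnx.load("unet/model.onnx")  # stand-in for a freshly blended ModelProto
session = ort.InferenceSession(
    model.SerializeToString(), providers=["CPUExecutionProvider"]
)
print([i.name for i in session.get_inputs()])
```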
I want to say one thing: even with Diffusers, LoRA models sometimes produce strange artifacts.
I think I figured out the problem with the artifacts, or at least the cause: the script I have so far is converting the
The rest of the diff script:

```python
from logging import getLogger, basicConfig, DEBUG
from sys import argv, stdout

from onnx import load_model, ModelProto
from onnx.numpy_helper import to_array

basicConfig(stream=stdout, level=DEBUG)
logger = getLogger(__name__)


def diff_models(ref_model: ModelProto, cmp_model: ModelProto):
    # compare initializers field by field
    if len(ref_model.graph.initializer) != len(cmp_model.graph.initializer):
        logger.warning("different number of initializers: %s vs %s", len(ref_model.graph.initializer), len(cmp_model.graph.initializer))
    else:
        for (ref_init, cmp_init) in zip(ref_model.graph.initializer, cmp_model.graph.initializer):
            if ref_init.name != cmp_init.name:
                logger.info("different node names: %s vs %s", ref_init.name, cmp_init.name)
            elif ref_init.data_location != cmp_init.data_location:
                logger.info("different data locations: %s vs %s", ref_init.data_location, cmp_init.data_location)
            elif ref_init.data_type != cmp_init.data_type:
                logger.info("different data types: %s vs %s", ref_init.data_type, cmp_init.data_type)
            elif len(ref_init.raw_data) != len(cmp_init.raw_data):
                logger.info("different raw data size: %s vs %s", len(ref_init.raw_data), len(cmp_init.raw_data))
            elif len(ref_init.raw_data) > 0 and len(cmp_init.raw_data) > 0:
                ref_data = to_array(ref_init)
                cmp_data = to_array(cmp_init)
                data_diff = ref_data - cmp_data
                # compare on abs() so negative deltas are not missed
                if abs(data_diff).max() > 0:
                    logger.info("raw data differs for %s: %s\n%s", ref_init.name, data_diff.max(), data_diff)
            else:
                logger.info("initializers are identical in all checked fields: %s", ref_init.name)

    # compare nodes field by field
    if len(ref_model.graph.node) != len(cmp_model.graph.node):
        logger.warning("different number of nodes: %s vs %s", len(ref_model.graph.node), len(cmp_model.graph.node))
    else:
        for (ref_node, cmp_node) in zip(ref_model.graph.node, cmp_model.graph.node):
            if ref_node.name != cmp_node.name:
                logger.info("different node names: %s vs %s", ref_node.name, cmp_node.name)
            elif ref_node.input != cmp_node.input:
                logger.info("different inputs: %s vs %s", ref_node.input, cmp_node.input)
            elif ref_node.output != cmp_node.output:
                logger.info("different outputs: %s vs %s", ref_node.output, cmp_node.output)
            elif ref_node.op_type != cmp_node.op_type:
                logger.info("different op type: %s vs %s", ref_node.op_type, cmp_node.op_type)
            else:
                # was ref_init.name before, which logged the wrong name here
                logger.info("nodes are identical in all checked fields: %s", ref_node.name)


if __name__ == "__main__":
    ref_path = argv[1]
    cmp_paths = argv[2:]

    logger.info("loading reference model from %s", ref_path)
    ref_model = load_model(ref_path)

    for cmp_path in cmp_paths:
        logger.info("loading comparison model from %s", cmp_path)
        cmp_model = load_model(cmp_path)
        diff_models(ref_model, cmp_model)
```
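Usage is a reference model followed by any number of comparison models, e.g. `python diff_models.py base/text_encoder/model.onnx blended/text_encoder/model.onnx` (the script name and file paths here are illustrative).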
I have something that seems to be working, #243, with a few caveats:
What it does support so far:
All of the good stuff is in https://github.com/ssube/onnx-web/blob/feat/213-lora/api/onnx_web/convert/diffusion/lora.py. Important parts, for my own reference and for anyone else who finds them useful:
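The heart of the blending is the usual kohya-style weight composition. A sketch of that math as I understand it (my paraphrase, not the code from lora.py, and the scale handling is simplified):

```python
# sketch of kohya-style LoRA composition for a Linear/MatMul weight:
#   W' = W + scale * (alpha / rank) * (up @ down)
# Conv kernels need to be reshaped to 2D before this and restored
# afterwards; that detail is omitted here.
import numpy as np

def blend_weight(base: np.ndarray, up: np.ndarray, down: np.ndarray,
                 alpha: float, scale: float = 1.0) -> np.ndarray:
    rank = down.shape[0]  # lora_down has shape (rank, in_features)
    return base + scale * (alpha / rank) * (up @ down)
```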
I have this working pretty well for LoRAs produced by the sd-scripts repo and for most Textual Inversions, but it doesn't support the cloneofsimo LoRAs yet (#206). Let me know how this works for you, @ForserX, and if you or @Amblyopius have any info about getting the other networks working (hypernetworks, etc.), I would be very interested. I'm going to release what I have, along with the new ORT optimization stuff with 6/8GB support (#241), as v0.9.
I hope to see it before the end of the month, while I'm dying at work...
So, I lost my "connection with the world" for a little while...
Regarding hypernetworks: I have only seen the auto111 implementation, and to me, as someone far removed from Python, that code is a personal hell in which the only thing I understand is "None".
No need to merge. You can load the base model with:

```python
base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
```

The nodes need to have the right names, and running some of the more aggressive ORT optimization scripts will break that. The logic is pretty much normal, same as yours or sd-scripts, up until https://github.com/ssube/onnx-web/blob/main/api/onnx_web/convert/diffusion/lora.py#L170
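To illustrate the naming point, a hedged sketch of matching a kohya-style LoRA key to a MatMul node by normalizing the separators; this is an approximation of the idea, not the exact lookup in lora.py, and the key prefix and normalization rules are assumptions:

```python
# illustration: kohya keys use "_" where ONNX/diffusers names use "." or "/";
# aggressive ORT optimizations rename or fuse nodes, which breaks this lookup
from typing import Optional
from onnx import ModelProto, NodeProto

def normalize(name: str) -> str:
    return name.replace(".", "_").replace("/", "_").lower()

def find_matmul_node(model: ModelProto, lora_key: str) -> Optional[NodeProto]:
    # e.g. "lora_unet_mid_block_attentions_0_proj_out.lora_down.weight"
    target = normalize(lora_key.split(".", 1)[0].replace("lora_unet_", ""))
    for node in model.graph.node:
        if node.op_type == "MatMul" and target in normalize(node.name):
            return node
    return None
```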
Good job! I tested it myself - it works great! |