Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for textual inversion #199

Merged
merged 17 commits into from
Feb 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ package-upload:
twine upload dist/*

lint-check:
black --check --preview onnx_web/
black --check onnx_web/
isort --check-only --skip __init__.py --filter-files onnx_web
flake8 onnx_web

Expand Down
26 changes: 26 additions & 0 deletions api/extras.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,32 @@
"name": "diffusion-unstable-ink-dream-v6",
"source": "civitai://5796",
"format": "safetensors"
},
{
"name": "stable-diffusion-onnx-v1-5",
"source": "runwayml/stable-diffusion-v1-5",
"inversions": [
{
"name": "line-art",
"source": "sd-concepts-library/line-art"
},
{
"name": "cubex",
"source": "sd-concepts-library/cubex"
},
{
"name": "birb",
"source": "sd-concepts-library/birb-style"
},
{
"name": "minecraft",
"source": "sd-concepts-library/minecraft-concept-art"
},
{
"name": "ugly-sonic",
"source": "sd-concepts-library/ugly-sonic"
}
]
}
],
"correction": [],
Expand Down
1 change: 1 addition & 0 deletions api/onnx_web/chain/blend_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def blend_img2img(
params.scheduler,
job.get_device(),
params.lpw,
params.inversion,
)
if params.lpw:
logger.debug("using LPW pipeline for img2img")
Expand Down
1 change: 1 addition & 0 deletions api/onnx_web/chain/blend_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
params.scheduler,
job.get_device(),
params.lpw,
params.inversion,
)

if params.lpw:
Expand Down
1 change: 1 addition & 0 deletions api/onnx_web/chain/source_txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def source_txt2img(
params.scheduler,
job.get_device(),
params.lpw,
params.inversion,
)

if params.lpw:
Expand Down
1 change: 1 addition & 0 deletions api/onnx_web/chain/upscale_outpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
params.scheduler,
job.get_device(),
params.lpw,
params.inversion,
)
if params.lpw:
logger.debug("using LPW pipeline for inpaint")
Expand Down
18 changes: 15 additions & 3 deletions api/onnx_web/convert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from yaml import safe_load

from .correction_gfpgan import convert_correction_gfpgan
from .diffusion_original import convert_diffusion_original
from .diffusion_stable import convert_diffusion_stable
from .diffusion.diffusers import convert_diffusion_diffusers
from .diffusion.original import convert_diffusion_original
from .diffusion.textual_inversion import convert_diffusion_textual_inversion
from .upscale_resrgan import convert_upscale_resrgan
from .utils import (
ConversionContext,
Expand Down Expand Up @@ -223,11 +224,22 @@ def convert_models(ctx: ConversionContext, args, models: Models):
source,
)
else:
convert_diffusion_stable(
convert_diffusion_diffusers(
ctx,
model,
source,
)

for inversion in model.get("inversions", []):
inversion_name = inversion["name"]
inversion_source = inversion["source"]
inversion_source = fetch_model(
ctx, f"{name}-inversion-{inversion_name}", inversion_source
)
convert_diffusion_textual_inversion(
ctx, inversion_name, model["source"], inversion_source
)

except Exception as e:
logger.error("error converting diffusion model %s: %s", name, e)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@
from onnx import load, save_model
from torch.onnx import export

from onnx_web.diffusion.load import optimize_pipeline

from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
from ...diffusion.load import optimize_pipeline
from ...diffusion.pipeline_onnx_stable_diffusion_upscale import (
OnnxStableDiffusionUpscalePipeline,
)
from .utils import ConversionContext
from ..utils import ConversionContext

logger = getLogger(__name__)

Expand Down Expand Up @@ -63,7 +62,7 @@ def onnx_export(


@torch.no_grad()
def convert_diffusion_stable(
def convert_diffusion_diffusers(
ctx: ConversionContext,
model: Dict,
source: str,
Expand Down
212 changes: 212 additions & 0 deletions api/onnx_web/convert/diffusion/lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
from logging import getLogger
from os import path
from sys import argv
from typing import List, Tuple

import onnx.checker
import torch
from numpy import ndarray
from onnx import ModelProto, TensorProto, helper, load, numpy_helper, save_model
from safetensors import safe_open

from ..utils import ConversionContext

# module-level logger, named after this module's import path
logger = getLogger(__name__)


###
# everything in this file is still super experimental and may not produce valid ONNX models
###


def load_lora(filename: str) -> ModelProto:
    """Load an ONNX model file.

    Despite the name, this module uses it for the base model as well as for
    the LoRA models, all of which are read as ONNX files.

    Args:
        filename: path to the ONNX model file.

    Returns:
        The model parsed by `onnx.load`.
    """
    return load(filename)


def blend_loras(
    base: ModelProto, weights: List[ModelProto], alphas: List[float]
) -> List[Tuple[TensorProto, ndarray]]:
    """Blend the initializers of several models into those of the base model.

    For each initializer in `base`, the initializer with the same name from
    every model in `weights` is multiplied by its alpha and added to a copy
    of the base values. The sum is divided by `1 + sum(alphas)`, so the
    result is a weighted average in which the base model has weight 1.

    Args:
        base: model whose initializers drive the blend.
        weights: models to blend in; matched to the base by initializer name.
        alphas: one blend factor per entry in `weights`. Extra entries on
            either side are ignored, because the two lists are zipped.

    Returns:
        One `(base initializer, blended values)` pair per base initializer,
        in the order they appear in the base graph.
    """
    total = 1 + sum(alphas)

    # index every model's initializers by name once, instead of scanning the
    # full initializer list again for each base initializer; setdefault keeps
    # the first entry for a repeated name, the same one a linear scan finds
    indexed = []
    for weight in weights:
        by_name = {}
        for initializer in weight.graph.initializer:
            by_name.setdefault(initializer.name, initializer)
        indexed.append(by_name)

    results = []

    for base_node in base.graph.initializer:
        logger.info("blending initializer node %s", base_node.name)
        # copy, so the in-place additions below never write through to the
        # array returned for the base initializer
        base_weights = numpy_helper.to_array(base_node).copy()

        for weight, by_name, alpha in zip(weights, indexed, alphas):
            weight_node = by_name.get(base_node.name)

            if weight_node is not None:
                base_weights += numpy_helper.to_array(weight_node) * alpha
            else:
                logger.warning(
                    "missing weights: %s in %s", base_node.name, weight.doc_string
                )

        results.append((base_node, base_weights / total))

    return results


def convert_diffusion_lora(context: ConversionContext, component: str):
    """Blend two fixed LoRA models into one component of the base model.

    Loads `{component}/model.onnx` from the base model directory and from the
    two LoRA directories listed below, blends their initializers with equal
    alphas, rebuilds the graph around the blended tensors, and saves it into
    the cache directory as `lora-{component}.onnx`, with all tensor data in
    the sibling file `lora-{component}.tensors`.

    Args:
        context: conversion context; only `cache_path` is read here.
        component: sub-directory of each model to blend; the script entry
            point in this file passes "unet" and "text_encoder".
    """
    lora_weights = [
        f"diffusion-lora-jack/{component}/model.onnx",
        f"diffusion-lora-taters/{component}/model.onnx",
    ]

    base = load_lora(f"stable-diffusion-onnx-v1-5/{component}/model.onnx")
    weights = [load_lora(f) for f in lora_weights]
    alphas = [1 / len(weights)] * len(weights)
    # log the file names rather than the loaded models: formatting a
    # ModelProto renders its whole contents, tensors included, as text
    logger.info("blending LoRAs with alphas: %s, %s", lora_weights, alphas)

    result = blend_loras(base, weights, alphas)
    logger.info("blended result keys: %s", len(result))

    # drop the loaded LoRA models before building the new tensors
    del weights
    del alphas

    tensors = []
    for node, tensor in result:
        logger.info("remaking tensor for %s", node.name)
        tensors.append(helper.make_tensor(node.name, node.data_type, node.dims, tensor))

    del result

    # same nodes, inputs and outputs as the base graph, with the blended
    # tensors substituted for the original initializers
    graph = helper.make_graph(
        base.graph.node,
        base.graph.name,
        base.graph.input,
        base.graph.output,
        tensors,
        base.graph.doc_string,
        base.graph.value_info,
        base.graph.sparse_initializer,
    )
    model = helper.make_model(graph)

    # replace whatever opset imports make_model added with a single entry
    # pinned to version 14
    del model.opset_import[:]
    opset = model.opset_import.add()
    opset.version = 14

    onnx_path = path.join(context.cache_path, f"lora-{component}.onnx")
    # `location` is resolved relative to the directory of the model file, so
    # pass the bare file name; the joined path is only used for the log line
    tensor_name = f"lora-{component}.tensors"
    tensor_path = path.join(context.cache_path, tensor_name)
    save_model(
        model,
        onnx_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=tensor_name,
    )
    logger.info(
        "saved model to %s and tensors to %s",
        onnx_path,
        tensor_path,
    )


def fix_key(key: str):
    """Return `key` with every "." replaced by "_".

    Used to turn a dotted initializer name into the underscore-separated
    form that LoRA tensor keys use, for example:

        lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight
    """
    pieces = key.split(".")
    return "_".join(pieces)


def _lora_delta(up_weight, down_weight):
    """Return the product of a LoRA up/down pair, shaped like the layer weight."""
    if len(up_weight.size()) == 2:
        return up_weight @ down_weight

    # 4-D weights: drop the two trailing dims, multiply as matrices, then
    # restore the dims; squeeze only removes size-1 dims, so this handles
    # 1x1 kernels only and the matmul raises for anything larger
    return (
        (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2))
        .unsqueeze(2)
        .unsqueeze(3)
    )


def merge_lora():
    """Merge LoRA weights from a safetensors file into an ONNX model.

    Reads the ONNX model path from `argv[1]` and the LoRA safetensors path
    from `argv[2]`. Every initializer of the ONNX model is either replaced by
    a copy with the matching LoRA delta added, or kept unchanged when no LoRA
    key matches or the merge fails. The resulting model is validated with
    `onnx.checker.check_model`; it is neither saved nor returned.

    Raises:
        IndexError: when fewer than two command-line arguments were given.
    """
    base_name = argv[1]
    lora_name = argv[2]

    base_model = load(base_name)
    lora_nodes = []

    # safe_open is a context manager; use it so the file is released
    with safe_open(lora_name, framework="pt") as lora_model:
        all_keys = list(lora_model.keys())
        known_keys = set(all_keys)
        down_keys = [key for key in all_keys if "lora_down" in key]

        for base_node in base_model.graph.initializer:
            base_key = fix_key(base_node.name)
            # reset for every initializer, so a result from an earlier
            # initializer can never be mistaken for this one's
            retensor = None

            for key in down_keys:
                prefix = key[: key.index("lora_down")]
                lora_key = prefix.replace("lora_unet_", "")
                # NOTE(review): the prefix still ends with "." while base_key
                # has none and usually ends in a parameter suffix, so confirm
                # against real model files that this comparison ever matches
                if not lora_key.startswith(base_key):
                    continue

                logger.info("down for key: %s %s", base_key, lora_key)

                up_key = key.replace("lora_down", "lora_up")
                alpha_key = prefix + "alpha"

                down_weight = lora_model.get_tensor(key).to(dtype=torch.float32)
                up_weight = lora_model.get_tensor(up_key).to(dtype=torch.float32)

                # rank of the LoRA pair; alpha falls back to it when the file
                # has no alpha tensor or stores zero, giving a scale of 1
                dim = down_weight.size()[0]
                alpha = dim
                if alpha_key in known_keys:
                    alpha = lora_model.get_tensor(alpha_key).item() or dim
                scale = alpha / dim

                np_vals = numpy_helper.to_array(base_node)

                try:
                    delta = _lora_delta(up_weight, down_weight).numpy() * scale
                    if delta.shape != np_vals.shape:
                        # refuse to let numpy broadcasting silently change
                        # the shape of the stored weight
                        raise ValueError(
                            f"shape mismatch: {np_vals.shape} vs {delta.shape}"
                        )

                    # NOTE(review): assumes the ONNX initializer uses the same
                    # axis order as the LoRA weights - TODO confirm
                    merged = (np_vals + delta).astype(np_vals.dtype)
                    retensor = numpy_helper.from_array(merged, base_node.name)
                    break
                except Exception:
                    # best effort: report the failure and try the next key
                    logger.exception(
                        "error merging %s into %s", key, base_node.name
                    )

            if retensor is None:
                logger.info("no lora found for key %s", base_key)
                retensor = base_node

            lora_nodes.append(retensor)

    logger.info(
        "merged initializers: %s of %s",
        len(lora_nodes),
        len(base_model.graph.initializer),
    )
    # initializers cannot be assigned in place, so swap the whole list
    del base_model.graph.initializer[:]
    base_model.graph.initializer.extend(lora_nodes)

    onnx.checker.check_model(base_model)


if __name__ == "__main__":
    # script entry point: build the conversion context from the environment,
    # then blend each model component in turn
    context = ConversionContext.from_environ()
    for component in ("unet", "text_encoder"):
        convert_diffusion_lora(context, component)
Loading