apply lint

ssube committed Mar 16, 2023
1 parent 506cf9f commit 8e8e230
Showing 4 changed files with 72 additions and 24 deletions.
7 changes: 6 additions & 1 deletion api/onnx_web/convert/diffusion/lora.py
@@ -15,6 +15,7 @@
 from onnxruntime import InferenceSession, OrtValue, SessionOptions
 from safetensors.torch import load_file
 
+from ...server.context import ServerContext
 from ..utils import ConversionContext
 
 logger = getLogger(__name__)
@@ -55,6 +56,7 @@ def fix_node_name(key: str):
 
 
 def blend_loras(
+    context: ServerContext,
     base_name: str,
     lora_names: List[str],
     dest_type: Literal["text_encoder", "unet"],
@@ -236,6 +238,7 @@ def blend_loras(
 
 
 if __name__ == "__main__":
+    context = ConversionContext.from_environ()
     parser = ArgumentParser()
     parser.add_argument("--base", type=str)
     parser.add_argument("--dest", type=str)
@@ -251,7 +254,9 @@ def blend_loras(
         args.lora_weights,
     )
 
-    blend_model = blend_loras(args.base, args.lora_models, args.type, args.lora_weights)
+    blend_model = blend_loras(
+        context, args.base, args.lora_models, args.type, args.lora_weights
+    )
     if args.dest is None or args.dest == "" or args.dest == "ort":
         # convert to external data and save to memory
         (bare_model, external_data) = buffer_external_data_tensors(blend_model)
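This commit threads a context object through `blend_loras`, as the `__main__` hunk above shows. A minimal sketch of the updated call shape, assuming hypothetical model paths (only the signature is taken from this diff):

```python
# Hedged sketch of the blend_loras signature after this commit.
# All file paths are hypothetical placeholders.
from onnx_web.convert.diffusion.lora import blend_loras
from onnx_web.convert.utils import ConversionContext

# build a context from environment variables, as the __main__ block does
context = ConversionContext.from_environ()

blended = blend_loras(
    context,
    "/models/stable-diffusion/unet/model.onnx",  # hypothetical base model
    ["/models/lora/style.safetensors"],  # hypothetical LoRA file list
    "unet",  # dest_type is "text_encoder" or "unet"
    [1.0],  # one weight per LoRA
)
```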
36 changes: 24 additions & 12 deletions api/onnx_web/convert/diffusion/textual_inversion.py
@@ -6,7 +6,6 @@
 import torch
 from huggingface_hub.file_download import hf_hub_download
 from onnx import ModelProto, load_model, numpy_helper, save_model
-from torch.onnx import export
 from transformers import CLIPTokenizer
 
 from ...server.context import ServerContext
@@ -25,11 +24,15 @@ def blend_textual_inversions(
     inversion_weights: Optional[List[float]] = None,
     base_tokens: Optional[List[str]] = None,
 ) -> Tuple[ModelProto, CLIPTokenizer]:
-    # prev: text_encoder.get_input_embeddings().weight.dtype
-    dtype = np.float
+    dtype = np.float # TODO: fixed type, which one?
     embeds = {}
 
-    for name, format, weight, base_token in zip(inversion_names, inversion_formats, inversion_weights, base_tokens or inversion_names):
+    for name, format, weight, base_token in zip(
+        inversion_names,
+        inversion_formats,
+        inversion_weights,
+        base_tokens or inversion_names,
+    ):
         logger.info("blending Textual Inversion %s with weight of %s", name, weight)
         if format == "concept":
             embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
@@ -64,7 +67,7 @@
 
             for i in range(num_tokens):
                 token = f"{base_token or name}-{i}"
-                layer = trained_embeds[i,:].cpu().numpy().astype(dtype)
+                layer = trained_embeds[i, :].cpu().numpy().astype(dtype)
                 layer *= weight
                 if token in embeds:
                     embeds[token] += layer
@@ -74,7 +77,9 @@
             raise ValueError(f"unknown Textual Inversion format: {format}")
 
     # add the tokens to the tokenizer
-    logger.info("found embeddings for %s tokens: %s", len(embeds.keys()), embeds.keys())
+    logger.info(
+        "found embeddings for %s tokens: %s", len(embeds.keys()), embeds.keys()
+    )
     num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
     if num_added_tokens == 0:
         raise ValueError(
@@ -85,7 +90,11 @@
 
     # resize the token embeddings
     # text_encoder.resize_token_embeddings(len(tokenizer))
-    embedding_node = [n for n in text_encoder.graph.initializer if n.name == "text_model.embeddings.token_embedding.weight"][0]
+    embedding_node = [
+        n
+        for n in text_encoder.graph.initializer
+        if n.name == "text_model.embeddings.token_embedding.weight"
+    ][0]
     embedding_weights = numpy_helper.to_array(embedding_node)
 
     weights_dim = embedding_weights.shape[1]
@@ -94,15 +103,18 @@
 
     for token, weights in embeds.items():
         token_id = tokenizer.convert_tokens_to_ids(token)
-        logger.debug(
-            "embedding %s weights for token %s", weights.shape, token
-        )
+        logger.debug("embedding %s weights for token %s", weights.shape, token)
         embedding_weights[token_id] = weights
 
     # replace embedding_node
     for i in range(len(text_encoder.graph.initializer)):
-        if text_encoder.graph.initializer[i].name == "text_model.embeddings.token_embedding.weight":
-            new_initializer = numpy_helper.from_array(embedding_weights.astype(np.float32), embedding_node.name)
+        if (
+            text_encoder.graph.initializer[i].name
+            == "text_model.embeddings.token_embedding.weight"
+        ):
+            new_initializer = numpy_helper.from_array(
+                embedding_weights.astype(np.float32), embedding_node.name
+            )
             logger.debug("new initializer data type: %s", new_initializer.data_type)
             del text_encoder.graph.initializer[i]
             text_encoder.graph.initializer.insert(i, new_initializer)
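The last hunk shows the core technique in this file: pull the token-embedding initializer out of the ONNX graph as a NumPy array, update it, and swap a rebuilt tensor into the same slot. A standalone sketch of that pattern, assuming a hypothetical model path (the tensor name matches the diff):

```python
# Minimal sketch of in-place ONNX initializer replacement, as used above.
import numpy as np
from onnx import load_model, numpy_helper, save_model

model = load_model("text_encoder/model.onnx")  # hypothetical path
name = "text_model.embeddings.token_embedding.weight"

for i in range(len(model.graph.initializer)):
    node = model.graph.initializer[i]
    if node.name == name:
        weights = numpy_helper.to_array(node)
        # append one placeholder embedding row for a newly added token
        new_row = np.zeros((1, weights.shape[1]), dtype=np.float32)
        weights = np.concatenate([weights, new_row], axis=0)
        # rebuild the initializer and insert it at the same position
        new_node = numpy_helper.from_array(weights.astype(np.float32), name)
        del model.graph.initializer[i]
        model.graph.initializer.insert(i, new_node)
        break

save_model(model, "text_encoder/model.blended.onnx")  # hypothetical path
```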
42 changes: 34 additions & 8 deletions api/onnx_web/diffusers/load.py
@@ -221,7 +221,10 @@ def load_pipeline(
         inversion_names, inversion_weights = zip(*inversions)
         logger.debug("blending Textual Inversions from %s", inversion_names)
 
-        inversion_models = [path.join(server.model_path, "inversion", f"{name}.ckpt") for name in inversion_names]
+        inversion_models = [
+            path.join(server.model_path, "inversion", f"{name}.ckpt")
+            for name in inversion_names
+        ]
         text_encoder = load_model(path.join(model, "text_encoder", "model.onnx"))
         tokenizer = CLIPTokenizer.from_pretrained(
             model,
@@ -249,16 +252,33 @@
         # test LoRA blending
         if loras is not None and len(loras) > 0:
             lora_names, lora_weights = zip(*loras)
-            lora_models = [path.join(server.model_path, "lora", f"{name}.safetensors") for name in lora_names]
-            logger.info("blending base model %s with LoRA models: %s", model, lora_models)
+            lora_models = [
+                path.join(server.model_path, "lora", f"{name}.safetensors")
+                for name in lora_names
+            ]
+            logger.info(
+                "blending base model %s with LoRA models: %s", model, lora_models
+            )
 
             # blend and load text encoder
-            text_encoder = text_encoder or path.join(model, "text_encoder", "model.onnx")
-            blended_text_encoder = blend_loras(text_encoder, lora_models, "text_encoder", lora_weights=lora_weights)
-            (text_encoder, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
+            text_encoder = text_encoder or path.join(
+                model, "text_encoder", "model.onnx"
+            )
+            text_encoder = blend_loras(
+                server,
+                text_encoder,
+                lora_models,
+                "text_encoder",
+                lora_weights=lora_weights,
+            )
+            (text_encoder, text_encoder_data) = buffer_external_data_tensors(
+                text_encoder
+            )
             text_encoder_names, text_encoder_values = zip(*text_encoder_data)
             text_encoder_opts = SessionOptions()
-            text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
+            text_encoder_opts.add_external_initializers(
+                list(text_encoder_names), list(text_encoder_values)
+            )
             components["text_encoder"] = OnnxRuntimeModel(
                 OnnxRuntimeModel.load_model(
                     text_encoder.SerializeToString(),
@@ -268,7 +288,13 @@
             )
 
             # blend and load unet
-            blended_unet = blend_loras(path.join(model, "unet", "model.onnx"), lora_models, "unet", lora_weights=lora_weights)
+            blended_unet = blend_loras(
+                server,
+                path.join(model, "unet", "model.onnx"),
+                lora_models,
+                "unet",
+                lora_weights=lora_weights,
+            )
             (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
             unet_names, unet_values = zip(*unet_data)
             unet_opts = SessionOptions()
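Both blended models are loaded without a round-trip to disk: `buffer_external_data_tensors` moves the large tensors out of the serialized graph, and `SessionOptions.add_external_initializers` hands them back to onnxruntime at session creation. A hedged sketch of that load path, assuming `blended_unet` and `buffer_external_data_tensors` from the surrounding code and an illustrative provider list:

```python
# Sketch of loading a blended ONNX model entirely in memory.
# blended_unet and buffer_external_data_tensors are assumed from context.
from onnxruntime import InferenceSession, SessionOptions

(bare_model, external_data) = buffer_external_data_tensors(blended_unet)
names, values = zip(*external_data)

opts = SessionOptions()
# register the externalized tensors so the session can resolve them in memory
opts.add_external_initializers(list(names), list(values))

session = InferenceSession(
    bare_model.SerializeToString(),
    sess_options=opts,
    providers=["CPUExecutionProvider"],  # provider choice is an assumption
)
```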
11 changes: 8 additions & 3 deletions api/onnx_web/diffusers/utils.py
@@ -1,6 +1,6 @@
 from logging import getLogger
 from math import ceil
-from re import compile, Pattern
+from re import Pattern, compile
 from typing import List, Optional, Tuple
 
 import numpy as np
@@ -132,7 +132,9 @@ def expand_prompt(
     return prompt_embeds
 
 
-def get_tokens_from_prompt(prompt: str, pattern: Pattern[str]) -> Tuple[str, List[Tuple[str, float]]]:
+def get_tokens_from_prompt(
+    prompt: str, pattern: Pattern[str]
+) -> Tuple[str, List[Tuple[str, float]]]:
     """
     TODO: replace with Arpeggio
     """
@@ -145,7 +147,10 @@ def get_tokens_from_prompt(prompt: str, pattern: Pattern[str]) -> Tuple[str, List[Tuple[str, float]]]:
         name, weight = next_match.groups()
         tokens.append((name, float(weight)))
         # remove this match and look for another
-        remaining_prompt = remaining_prompt[:next_match.start()] + remaining_prompt[next_match.end():]
+        remaining_prompt = (
+            remaining_prompt[: next_match.start()]
+            + remaining_prompt[next_match.end() :]
+        )
         next_match = pattern.search(remaining_prompt)
 
     return (remaining_prompt, tokens)
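For reference, `get_tokens_from_prompt` removes every match of `pattern` from the prompt and collects `(name, weight)` pairs. A usage sketch; the `<lora:name:weight>` syntax is an assumed example, not a pattern confirmed by this diff:

```python
# Usage sketch for get_tokens_from_prompt; the pattern must expose two
# groups (name, weight). The <lora:...> syntax here is an assumption.
from re import compile

from onnx_web.diffusers.utils import get_tokens_from_prompt

lora_pattern = compile(r"<lora:([-\w]+):([.\d]+)>")

prompt = "a watercolor fox <lora:paint-style:0.8> <lora:fur-detail:0.5>"
remaining, tokens = get_tokens_from_prompt(prompt, lora_pattern)

print(remaining)  # "a watercolor fox  " with both tags stripped
print(tokens)     # [("paint-style", 0.8), ("fur-detail", 0.5)]
```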
