diff --git a/api/Makefile b/api/Makefile index 523cc1281..236aacabd 100644 --- a/api/Makefile +++ b/api/Makefile @@ -29,7 +29,7 @@ package-upload: twine upload dist/* lint-check: - black --check --preview onnx_web/ + black --check onnx_web/ isort --check-only --skip __init__.py --filter-files onnx_web flake8 onnx_web diff --git a/api/extras.json b/api/extras.json index 3fa65e5b8..b942fe0a7 100644 --- a/api/extras.json +++ b/api/extras.json @@ -17,6 +17,32 @@ "name": "diffusion-unstable-ink-dream-v6", "source": "civitai://5796", "format": "safetensors" + }, + { + "name": "stable-diffusion-onnx-v1-5", + "source": "runwayml/stable-diffusion-v1-5", + "inversions": [ + { + "name": "line-art", + "source": "sd-concepts-library/line-art" + }, + { + "name": "cubex", + "source": "sd-concepts-library/cubex" + }, + { + "name": "birb", + "source": "sd-concepts-library/birb-style" + }, + { + "name": "minecraft", + "source": "sd-concepts-library/minecraft-concept-art" + }, + { + "name": "ugly-sonic", + "source": "sd-concepts-library/ugly-sonic" + } + ] } ], "correction": [], diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 0baa3153d..0ef9ef960 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -36,6 +36,7 @@ def blend_img2img( params.scheduler, job.get_device(), params.lpw, + params.inversion, ) if params.lpw: logger.debug("using LPW pipeline for img2img") diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index 7c3b32a73..16422cce3 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -75,6 +75,7 @@ def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]): params.scheduler, job.get_device(), params.lpw, + params.inversion, ) if params.lpw: diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 1fb4ee6ad..a5cbb07fe 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -42,6 +42,7 @@ def source_txt2img( params.scheduler, job.get_device(), params.lpw, + params.inversion, ) if params.lpw: diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 0f552bd55..7919e3251 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -79,6 +79,7 @@ def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]): params.scheduler, job.get_device(), params.lpw, + params.inversion, ) if params.lpw: logger.debug("using LPW pipeline for inpaint") diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 774a9e721..2a670511d 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -10,8 +10,9 @@ from yaml import safe_load from .correction_gfpgan import convert_correction_gfpgan -from .diffusion_original import convert_diffusion_original -from .diffusion_stable import convert_diffusion_stable +from .diffusion.diffusers import convert_diffusion_diffusers +from .diffusion.original import convert_diffusion_original +from .diffusion.textual_inversion import convert_diffusion_textual_inversion from .upscale_resrgan import convert_upscale_resrgan from .utils import ( ConversionContext, @@ -223,11 +224,22 @@ def convert_models(ctx: ConversionContext, args, models: Models): source, ) else: - convert_diffusion_stable( + convert_diffusion_diffusers( ctx, model, source, ) + + for inversion in model.get("inversions", []): + inversion_name = 
inversion["name"] + inversion_source = inversion["source"] + inversion_source = fetch_model( + ctx, f"{name}-inversion-{inversion_name}", inversion_source + ) + convert_diffusion_textual_inversion( + ctx, inversion_name, model["source"], inversion_source + ) + except Exception as e: logger.error("error converting diffusion model %s: %s", name, e) diff --git a/api/onnx_web/convert/diffusion_stable.py b/api/onnx_web/convert/diffusion/diffusers.py similarity index 98% rename from api/onnx_web/convert/diffusion_stable.py rename to api/onnx_web/convert/diffusion/diffusers.py index 998e2fb54..0a805dcbc 100644 --- a/api/onnx_web/convert/diffusion_stable.py +++ b/api/onnx_web/convert/diffusion/diffusers.py @@ -25,12 +25,11 @@ from onnx import load, save_model from torch.onnx import export -from onnx_web.diffusion.load import optimize_pipeline - -from ..diffusion.pipeline_onnx_stable_diffusion_upscale import ( +from ...diffusion.load import optimize_pipeline +from ...diffusion.pipeline_onnx_stable_diffusion_upscale import ( OnnxStableDiffusionUpscalePipeline, ) -from .utils import ConversionContext +from ..utils import ConversionContext logger = getLogger(__name__) @@ -63,7 +62,7 @@ def onnx_export( @torch.no_grad() -def convert_diffusion_stable( +def convert_diffusion_diffusers( ctx: ConversionContext, model: Dict, source: str, diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py new file mode 100644 index 000000000..57e85d79a --- /dev/null +++ b/api/onnx_web/convert/diffusion/lora.py @@ -0,0 +1,212 @@ +from logging import getLogger +from os import path +from sys import argv +from typing import List, Tuple + +import onnx.checker +import torch +from numpy import ndarray +from onnx import ModelProto, TensorProto, helper, load, numpy_helper, save_model +from safetensors import safe_open + +from ..utils import ConversionContext + +logger = getLogger(__name__) + + +### +# everything in this file is still super experimental and may not produce valid ONNX models +### + + +def load_lora(filename: str): + model = load(filename) + + for weight in model.graph.initializer: + # print(weight.name, numpy_helper.to_array(weight).shape) + pass + + return model + + +def blend_loras( + base: ModelProto, weights: List[ModelProto], alphas: List[float] +) -> List[Tuple[TensorProto, ndarray]]: + total = 1 + sum(alphas) + + results = [] + + for base_node in base.graph.initializer: + logger.info("blending initializer node %s", base_node.name) + base_weights = numpy_helper.to_array(base_node).copy() + + for weight, alpha in zip(weights, alphas): + weight_node = next( + iter([f for f in weight.graph.initializer if f.name == base_node.name]), + None, + ) + + if weight_node is not None: + base_weights += numpy_helper.to_array(weight_node) * alpha + else: + logger.warning( + "missing weights: %s in %s", base_node.name, weight.doc_string + ) + + results.append((base_node, base_weights / total)) + + return results + + +def convert_diffusion_lora(context: ConversionContext, component: str): + lora_weights = [ + f"diffusion-lora-jack/{component}/model.onnx", + f"diffusion-lora-taters/{component}/model.onnx", + ] + + base = load_lora(f"stable-diffusion-onnx-v1-5/{component}/model.onnx") + weights = [load_lora(f) for f in lora_weights] + alphas = [1 / len(weights)] * len(weights) + logger.info("blending LoRAs with alphas: %s, %s", weights, alphas) + + result = blend_loras(base, weights, alphas) + logger.info("blended result keys: %s", len(result)) + + del weights + del alphas + + tensors = [] 
+ for node, tensor in result: + logger.info("remaking tensor for %s", node.name) + tensors.append(helper.make_tensor(node.name, node.data_type, node.dims, tensor)) + + del result + + graph = helper.make_graph( + base.graph.node, + base.graph.name, + base.graph.input, + base.graph.output, + tensors, + base.graph.doc_string, + base.graph.value_info, + base.graph.sparse_initializer, + ) + model = helper.make_model(graph) + + del model.opset_import[:] + opset = model.opset_import.add() + opset.version = 14 + + onnx_path = path.join(context.cache_path, f"lora-{component}.onnx") + tensor_path = path.join(context.cache_path, f"lora-{component}.tensors") + save_model( + model, + onnx_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=tensor_path, + ) + logger.info( + "saved model to %s and tensors to %s", + onnx_path, + tensor_path, + ) + + +def fix_key(key: str): + # lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight + # lora, unet, up_block.3.attentions.2.transformer_blocks.0.attn2.to_out.0 + return key.replace(".", "_") + + +def merge_lora(): + base_name = argv[1] + lora_name = argv[2] + + base_model = load(base_name) + lora_model = safe_open(lora_name, framework="pt") + + lora_nodes = [] + for base_node in base_model.graph.initializer: + base_key = fix_key(base_node.name) + + for key in lora_model.keys(): + if "lora_down" in key: + lora_key = key[: key.index("lora_down")].replace("lora_unet_", "") + if lora_key.startswith(base_key): + print("down for key:", base_key, lora_key) + + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + down_weight = lora_model.get_tensor(key).to(dtype=torch.float32) + up_weight = lora_model.get_tensor(up_key).to(dtype=torch.float32) + + dim = down_weight.size()[0] + alpha = lora_model.get(alpha_key).numpy() or dim + + np_vals = numpy_helper.to_array(base_node) + print(np_vals.shape, up_weight.shape, down_weight.shape) + + squoze = ( + ( + up_weight.squeeze(3).squeeze(2) + @ down_weight.squeeze(3).squeeze(2) + ) + .unsqueeze(2) + .unsqueeze(3) + ) + print(squoze.shape) + + np_vals = np_vals + (alpha * squoze.numpy()) + + try: + if len(up_weight.size()) == 2: + squoze = up_weight @ down_weight + print(squoze.shape) + np_vals = np_vals + (squoze.numpy() * (alpha / dim)) + else: + squoze = ( + ( + up_weight.squeeze(3).squeeze(2) + @ down_weight.squeeze(3).squeeze(2) + ) + .unsqueeze(2) + .unsqueeze(3) + ) + print(squoze.shape) + np_vals = np_vals + (alpha * squoze.numpy()) + + # retensor = numpy_helper.from_array(np_vals, base_node.name) + retensor = helper.make_tensor( + base_node.name, + base_node.data_type, + base_node.dim, + np_vals, + raw=True, + ) + print(retensor) + + # TypeError: does not support assignment + lora_nodes.append(retensor) + + break + except Exception as e: + print(e) + + if retensor is None: + print("no lora found for key", base_key) + lora_nodes.append(base_node) + + print(len(lora_nodes), len(base_model.graph.initializer)) + del base_model.graph.initializer[:] + base_model.graph.initializer.extend(lora_nodes) + + onnx.checker.check_model(base_model) + + +if __name__ == "__main__": + context = ConversionContext.from_environ() + convert_diffusion_lora(context, "unet") + convert_diffusion_lora(context, "text_encoder") diff --git a/api/onnx_web/convert/diffusion_original.py b/api/onnx_web/convert/diffusion/original.py similarity index 74% rename from api/onnx_web/convert/diffusion_original.py rename to 
api/onnx_web/convert/diffusion/original.py index da00c45bb..c331ee741 100644 --- a/api/onnx_web/convert/diffusion_original.py +++ b/api/onnx_web/convert/diffusion/original.py @@ -53,13 +53,13 @@ CLIPVisionConfig, ) -from .diffusion_stable import convert_diffusion_stable -from .utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name +from ..utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name +from .diffusers import convert_diffusion_diffusers logger = getLogger(__name__) -class TrainingConfig(): +class TrainingConfig: """ From https://github.com/d8ahazard/sd_dreambooth_extension/blob/main/dreambooth/db_config.py """ @@ -184,7 +184,9 @@ def save(self, backup=False): backup_dir = os.path.join(models_path, "backups") if not os.path.exists(backup_dir): os.makedirs(backup_dir) - config_file = os.path.join(models_path, "backups", f"db_config_{self.revision}.json") + config_file = os.path.join( + models_path, "backups", f"db_config_{self.revision}.json" + ) with open(config_file, "w") as outfile: json.dump(self.__dict__, outfile, indent=4) @@ -238,7 +240,9 @@ def renew_resnet_paths(old_list, n_shave_prefix_segments=0): new_item = new_item.replace("emb_layers.1", "time_emb_proj") new_item = new_item.replace("skip_connection", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + new_item = shave_segments( + new_item, n_shave_prefix_segments=n_shave_prefix_segments + ) mapping.append({"old": old_item, "new": new_item}) @@ -253,7 +257,9 @@ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): for old_item in old_list: new_item = old_item new_item = new_item.replace("nin_shortcut", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + new_item = shave_segments( + new_item, n_shave_prefix_segments=n_shave_prefix_segments + ) mapping.append({"old": old_item, "new": new_item}) @@ -295,7 +301,9 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): new_item = new_item.replace("proj_out.weight", "proj_attn.weight") new_item = new_item.replace("proj_out.bias", "proj_attn.bias") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + new_item = shave_segments( + new_item, n_shave_prefix_segments=n_shave_prefix_segments + ) mapping.append({"old": old_item, "new": new_item}) @@ -303,7 +311,12 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): def assign_to_checkpoint( - paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None + paths, + checkpoint, + old_checkpoint, + attention_paths_to_split=None, + additional_replacements=None, + config=None, ): """ This does the final conversion step: take locally converted weights and apply a global renaming @@ -312,7 +325,9 @@ def assign_to_checkpoint( Assigns the weights to the new checkpoint. """ - assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + assert isinstance( + paths, list + ), "Paths should be a list of dicts containing 'old' and 'new' keys." # Splits the attention layers into three variables. 
if attention_paths_to_split is not None: @@ -324,7 +339,9 @@ def assign_to_checkpoint( num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 - old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + old_tensor = old_tensor.reshape( + (num_heads, 3 * channels // num_heads) + old_tensor.shape[1:] + ) query, key, value = old_tensor.split(channels // num_heads, dim=1) checkpoint[path_map["query"]] = query.reshape(target_shape) @@ -335,7 +352,10 @@ def assign_to_checkpoint( new_path = path["new"] # These have already been assigned - if attention_paths_to_split is not None and new_path in attention_paths_to_split: + if ( + attention_paths_to_split is not None + and new_path in attention_paths_to_split + ): continue # Global renaming happens here @@ -373,19 +393,29 @@ def create_unet_diffusers_config(original_config, image_size: int): unet_params = original_config.model.params.unet_config.params vae_params = original_config.model.params.first_stage_config.params.ddconfig - block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + block_out_channels = [ + unet_params.model_channels * mult for mult in unet_params.channel_mult + ] down_block_types = [] resolution = 1 for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + block_type = ( + "CrossAttnDownBlock2D" + if resolution in unet_params.attention_resolutions + else "DownBlock2D" + ) down_block_types.append(block_type) if i != len(block_out_channels) - 1: resolution *= 2 up_block_types = [] for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + block_type = ( + "CrossAttnUpBlock2D" + if resolution in unet_params.attention_resolutions + else "UpBlock2D" + ) up_block_types.append(block_type) resolution //= 2 @@ -393,7 +423,9 @@ def create_unet_diffusers_config(original_config, image_size: int): head_dim = unet_params.num_heads if "num_heads" in unet_params else None use_linear_projection = ( - unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + unet_params.use_linear_in_transformer + if "use_linear_in_transformer" in unet_params + else False ) if use_linear_projection: # stable diffusion 2-base-512 and 2-768 @@ -482,7 +514,9 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False for key in keys: if key.startswith("model.diffusion_model"): flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop( + flat_ema_key + ) else: print( "In this conversion only the non-EMA weights are extracted. 
If you want to instead extract the EMA" @@ -493,33 +527,53 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False if key.startswith(unet_key): unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) - new_checkpoint = {"time_embedding.linear_1.weight": unet_state_dict["time_embed.0.weight"], - "time_embedding.linear_1.bias": unet_state_dict["time_embed.0.bias"], - "time_embedding.linear_2.weight": unet_state_dict["time_embed.2.weight"], - "time_embedding.linear_2.bias": unet_state_dict["time_embed.2.bias"], - "conv_in.weight": unet_state_dict["input_blocks.0.0.weight"], - "conv_in.bias": unet_state_dict["input_blocks.0.0.bias"], - "conv_norm_out.weight": unet_state_dict["out.0.weight"], - "conv_norm_out.bias": unet_state_dict["out.0.bias"], - "conv_out.weight": unet_state_dict["out.2.weight"], - "conv_out.bias": unet_state_dict["out.2.bias"]} + new_checkpoint = { + "time_embedding.linear_1.weight": unet_state_dict["time_embed.0.weight"], + "time_embedding.linear_1.bias": unet_state_dict["time_embed.0.bias"], + "time_embedding.linear_2.weight": unet_state_dict["time_embed.2.weight"], + "time_embedding.linear_2.bias": unet_state_dict["time_embed.2.bias"], + "conv_in.weight": unet_state_dict["input_blocks.0.0.weight"], + "conv_in.bias": unet_state_dict["input_blocks.0.0.bias"], + "conv_norm_out.weight": unet_state_dict["out.0.weight"], + "conv_norm_out.bias": unet_state_dict["out.0.bias"], + "conv_out.weight": unet_state_dict["out.2.weight"], + "conv_out.bias": unet_state_dict["out.2.bias"], + } # Retrieves the keys for the input blocks only - num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + num_input_blocks = len( + { + ".".join(layer.split(".")[:2]) + for layer in unet_state_dict + if "input_blocks" in layer + } + ) input_blocks = { layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] for layer_id in range(num_input_blocks) } # Retrieves the keys for the middle blocks only - num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + num_middle_blocks = len( + { + ".".join(layer.split(".")[:2]) + for layer in unet_state_dict + if "middle_block" in layer + } + ) middle_blocks = { layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] for layer_id in range(num_middle_blocks) } # Retrieves the keys for the output blocks only - num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + num_output_blocks = len( + { + ".".join(layer.split(".")[:2]) + for layer in unet_state_dict + if "output_blocks" in layer + } + ) output_blocks = { layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] for layer_id in range(num_output_blocks) @@ -530,29 +584,45 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) resnets = [ - key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + key + for key in input_blocks[i] + if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key ] attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.weight" - ) - 
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.bias" - ) + new_checkpoint[ + f"down_blocks.{block_id}.downsamplers.0.conv.weight" + ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.weight") + new_checkpoint[ + f"down_blocks.{block_id}.downsamplers.0.conv.bias" + ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias") paths = renew_resnet_paths(resnets) - meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + meta_path = { + "old": f"input_blocks.{i}.0", + "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}", + } assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, ) if len(attentions): paths = renew_attention_paths(attentions) - meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + meta_path = { + "old": f"input_blocks.{i}.1", + "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}", + } assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, ) resnet_0 = middle_blocks[0] @@ -568,7 +638,11 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False attentions_paths = renew_attention_paths(attentions) meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} assign_to_checkpoint( - attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + attentions_paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, ) for i in range(num_output_blocks): @@ -586,25 +660,36 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False if len(output_block_list) > 1: resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] - attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + attentions = [ + key for key in output_blocks[i] if f"output_blocks.{i}.1" in key + ] resnet_0_paths = renew_resnet_paths(resnets) paths = renew_resnet_paths(resnets) - meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + meta_path = { + "old": f"output_blocks.{i}.0", + "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}", + } assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, ) output_block_list = {k: sorted(v) for k, v in output_block_list.items()} if ["conv.bias", "conv.weight"] in output_block_list.values(): - index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.weight" - ] - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.bias" - ] + index = list(output_block_list.values()).index( + ["conv.bias", "conv.weight"] + ) + new_checkpoint[ + f"up_blocks.{block_id}.upsamplers.0.conv.weight" + ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"] + new_checkpoint[ + 
f"up_blocks.{block_id}.upsamplers.0.conv.bias" + ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"] # Clear attentions as they have been attributed above. if len(attentions) == 2: @@ -617,13 +702,27 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", } assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, ) else: - resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + resnet_0_paths = renew_resnet_paths( + output_block_layers, n_shave_prefix_segments=1 + ) for path in resnet_0_paths: old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + new_path = ".".join( + [ + "up_blocks", + str(block_id), + "resnets", + str(layer_in_block_id), + path["new"], + ] + ) new_checkpoint[new_path] = unet_state_dict[old_path] @@ -645,49 +744,75 @@ def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True): else: vae_state_dict[key] = checkpoint.get(key) - new_checkpoint = {"encoder.conv_in.weight": vae_state_dict["encoder.conv_in.weight"], - "encoder.conv_in.bias": vae_state_dict["encoder.conv_in.bias"], - "encoder.conv_out.weight": vae_state_dict["encoder.conv_out.weight"], - "encoder.conv_out.bias": vae_state_dict["encoder.conv_out.bias"], - "encoder.conv_norm_out.weight": vae_state_dict["encoder.norm_out.weight"], - "encoder.conv_norm_out.bias": vae_state_dict["encoder.norm_out.bias"], - "decoder.conv_in.weight": vae_state_dict["decoder.conv_in.weight"], - "decoder.conv_in.bias": vae_state_dict["decoder.conv_in.bias"], - "decoder.conv_out.weight": vae_state_dict["decoder.conv_out.weight"], - "decoder.conv_out.bias": vae_state_dict["decoder.conv_out.bias"], - "decoder.conv_norm_out.weight": vae_state_dict["decoder.norm_out.weight"], - "decoder.conv_norm_out.bias": vae_state_dict["decoder.norm_out.bias"], - "quant_conv.weight": vae_state_dict["quant_conv.weight"], - "quant_conv.bias": vae_state_dict["quant_conv.bias"], - "post_quant_conv.weight": vae_state_dict["post_quant_conv.weight"], - "post_quant_conv.bias": vae_state_dict["post_quant_conv.bias"]} + new_checkpoint = { + "encoder.conv_in.weight": vae_state_dict["encoder.conv_in.weight"], + "encoder.conv_in.bias": vae_state_dict["encoder.conv_in.bias"], + "encoder.conv_out.weight": vae_state_dict["encoder.conv_out.weight"], + "encoder.conv_out.bias": vae_state_dict["encoder.conv_out.bias"], + "encoder.conv_norm_out.weight": vae_state_dict["encoder.norm_out.weight"], + "encoder.conv_norm_out.bias": vae_state_dict["encoder.norm_out.bias"], + "decoder.conv_in.weight": vae_state_dict["decoder.conv_in.weight"], + "decoder.conv_in.bias": vae_state_dict["decoder.conv_in.bias"], + "decoder.conv_out.weight": vae_state_dict["decoder.conv_out.weight"], + "decoder.conv_out.bias": vae_state_dict["decoder.conv_out.bias"], + "decoder.conv_norm_out.weight": vae_state_dict["decoder.norm_out.weight"], + "decoder.conv_norm_out.bias": vae_state_dict["decoder.norm_out.bias"], + "quant_conv.weight": vae_state_dict["quant_conv.weight"], + "quant_conv.bias": vae_state_dict["quant_conv.bias"], + "post_quant_conv.weight": vae_state_dict["post_quant_conv.weight"], + "post_quant_conv.bias": vae_state_dict["post_quant_conv.bias"], + } # Retrieves the keys for the encoder down 
blocks only - num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + num_down_blocks = len( + { + ".".join(layer.split(".")[:3]) + for layer in vae_state_dict + if "encoder.down" in layer + } + ) down_blocks = { - layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] + for layer_id in range(num_down_blocks) } # Retrieves the keys for the decoder up blocks only - num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + num_up_blocks = len( + { + ".".join(layer.split(".")[:3]) + for layer in vae_state_dict + if "decoder.up" in layer + } + ) up_blocks = { - layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] + for layer_id in range(num_up_blocks) } for i in range(num_down_blocks): - resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + resnets = [ + key + for key in down_blocks[i] + if f"down.{i}" in key and f"down.{i}.downsample" not in key + ] if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.weight" - ) - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.bias" - ) + new_checkpoint[ + f"encoder.down_blocks.{i}.downsamplers.0.conv.weight" + ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight") + new_checkpoint[ + f"encoder.down_blocks.{i}.downsamplers.0.conv.bias" + ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias") paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] num_mid_res_blocks = 2 @@ -696,31 +821,51 @@ def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True): paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] paths = renew_vae_attention_paths(mid_attentions) meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) conv_attn_to_linear(new_checkpoint) for i in range(num_up_blocks): block_id = num_up_blocks - 1 - i resnets = [ - key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + key + for key in up_blocks[block_id] + if f"up.{block_id}" in key and 
f"up.{block_id}.upsample" not in key ] if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.weight" - ] - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.bias" - ] + new_checkpoint[ + f"decoder.up_blocks.{i}.upsamplers.0.conv.weight" + ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"] + new_checkpoint[ + f"decoder.up_blocks.{i}.upsamplers.0.conv.bias" + ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"] paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] num_mid_res_blocks = 2 @@ -729,12 +874,24 @@ def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True): paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] paths = renew_vae_attention_paths(mid_attentions) meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) conv_attn_to_linear(new_checkpoint) return new_checkpoint @@ -769,14 +926,16 @@ def _copy_layers(hf_layers, pt_layers): for i, hf_layer in enumerate(hf_layers): if i != 0: i += i - pt_layer = pt_layers[i: i + 2] + pt_layer = pt_layers[i : i + 2] _copy_layer(hf_layer, pt_layer) hf_model = LDMBertModel(config).eval() # copy embeds hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight - hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight + hf_model.model.embed_positions.weight.data = ( + checkpoint.transformer.pos_emb.emb.weight + ) # copy layer norm _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) @@ -799,9 +958,13 @@ def convert_ldm_clip_checkpoint(checkpoint): for key in keys: if key.startswith("cond_stage_model.transformer"): if key.find("text_model") == -1: - text_model_dict["text_model." + key[len("cond_stage_model.transformer."):]] = checkpoint[key] + text_model_dict[ + "text_model." 
+ key[len("cond_stage_model.transformer.") :] + ] = checkpoint[key] else: - text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key] + text_model_dict[ + key[len("cond_stage_model.transformer.") :] + ] = checkpoint[key] text_model.load_state_dict(text_model_dict) @@ -809,12 +972,16 @@ def convert_ldm_clip_checkpoint(checkpoint): textenc_conversion_lst = [ - ('cond_stage_model.model.positional_embedding', - "text_model.embeddings.position_embedding.weight"), - ('cond_stage_model.model.token_embedding.weight', - "text_model.embeddings.token_embedding.weight"), - ('cond_stage_model.model.ln_final.weight', 'text_model.final_layer_norm.weight'), - ('cond_stage_model.model.ln_final.bias', 'text_model.final_layer_norm.bias') + ( + "cond_stage_model.model.positional_embedding", + "text_model.embeddings.position_embedding.weight", + ), + ( + "cond_stage_model.model.token_embedding.weight", + "text_model.embeddings.token_embedding.weight", + ), + ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"), + ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"), ] textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} @@ -827,8 +994,14 @@ def convert_ldm_clip_checkpoint(checkpoint): (".c_proj.", ".fc2."), (".attn", ".self_attn"), ("ln_final.", "transformer.text_model.final_layer_norm."), - ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), - ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), + ( + "token_embedding.weight", + "transformer.text_model.embeddings.token_embedding.weight", + ), + ( + "positional_embedding", + "transformer.text_model.embeddings.position_embedding.weight", + ), ] protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} textenc_pattern = re.compile("|".join(protected.keys())) @@ -844,7 +1017,9 @@ def convert_paint_by_example_checkpoint(checkpoint): for key in keys: if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[ + key + ] # load clip vision model.model.load_state_dict(text_model_dict) @@ -902,19 +1077,25 @@ def convert_paint_by_example_checkpoint(checkpoint): def convert_open_clip_checkpoint(checkpoint): - text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") + text_model = CLIPTextModel.from_pretrained( + "stabilityai/stable-diffusion-2", subfolder="text_encoder" + ) keys = list(checkpoint.keys()) text_model_dict = {} - if 'cond_stage_model.model.text_projection' in checkpoint: - d_model = int(checkpoint['cond_stage_model.model.text_projection'].shape[0]) + if "cond_stage_model.model.text_projection" in checkpoint: + d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0]) else: logger.debug("no projection shape found, setting to 1024") d_model = 1024 - text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") + text_model_dict[ + "text_model.embeddings.position_ids" + ] = text_model.text_model.embeddings.get_buffer("position_ids") for key in keys: - if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer + if ( + "resblocks.23" in key + ): # Diffusers drops the final layer and only uses the penultimate layer continue if key in textenc_conversion_map: 
text_model_dict[textenc_conversion_map[key]] = checkpoint[key] @@ -922,18 +1103,34 @@ def convert_open_clip_checkpoint(checkpoint): new_key = key[len("cond_stage_model.model.transformer.") :] if new_key.endswith(".in_proj_weight"): new_key = new_key[: -len(".in_proj_weight")] - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] - text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] - text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] + new_key = textenc_pattern.sub( + lambda m: protected[re.escape(m.group(0))], new_key + ) + text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][ + :d_model, : + ] + text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][ + d_model : d_model * 2, : + ] + text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][ + d_model * 2 :, : + ] elif new_key.endswith(".in_proj_bias"): new_key = new_key[: -len(".in_proj_bias")] - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + new_key = textenc_pattern.sub( + lambda m: protected[re.escape(m.group(0))], new_key + ) text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] - text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] - text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] + text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][ + d_model : d_model * 2 + ] + text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][ + d_model * 2 : + ] else: - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + new_key = textenc_pattern.sub( + lambda m: protected[re.escape(m.group(0))], new_key + ) text_model_dict[new_key] = checkpoint[key] @@ -992,7 +1189,15 @@ def download_model(db_config: TrainingConfig, token): siblings = repo_info.siblings - diffusion_dirs = ["text_encoder", "unet", "vae", "tokenizer", "scheduler", "feature_extractor", "safety_checker"] + diffusion_dirs = [ + "text_encoder", + "unet", + "vae", + "tokenizer", + "scheduler", + "feature_extractor", + "safety_checker", + ] config_file = None model_index = None model_files = [] @@ -1031,9 +1236,9 @@ def download_model(db_config: TrainingConfig, token): (x for x in model_files if "nonema" in x), next( (x for x in model_files if ".safetensors" in x), - model_files[0] if model_files else None - ) - ) + model_files[0] if model_files else None, + ), + ), ) files_to_fetch = None @@ -1061,7 +1266,7 @@ def download_model(db_config: TrainingConfig, token): filename=repo_file, repo_type="model", revision=repo_info.sha, - token=token + token=token, ) replace_symlinks(out, db_config.model_dir) dest = None @@ -1074,7 +1279,9 @@ def download_model(db_config: TrainingConfig, token): for diffusion_dir in diffusion_dirs: if diffusion_dir in out: out_model = db_config.pretrained_model_name_or_path - dest = os.path.join(db_config.pretrained_model_name_or_path, diffusion_dir) + dest = os.path.join( + db_config.pretrained_model_name_or_path, diffusion_dir + ) if not dest: if ".ckpt" in out or ".safetensors" in out: dest = os.path.join(db_config.model_dir, "src") @@ -1095,9 +1302,11 @@ def get_config_path( model_version: str = "v1", train_type: str = "default", config_base_name: str = "training", - prediction_type: str = "epsilon" + prediction_type: str = "epsilon", ): - train_type = f"{train_type}" if not prediction_type == "v_prediction" else 
f"{train_type}-v" + train_type = ( + f"{train_type}" if not prediction_type == "v_prediction" else f"{train_type}-v" + ) parts = os.path.join( os.path.dirname(os.path.realpath(__file__)), @@ -1106,21 +1315,20 @@ def get_config_path( "..", "models", "configs", - f"{model_version}-{config_base_name}-{train_type}.yaml" + f"{model_version}-{config_base_name}-{train_type}.yaml", ) return os.path.abspath(parts) -def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon", config_file=None): +def get_config_file( + train_unfrozen=False, v2=False, prediction_type="epsilon", config_file=None +): if config_file is not None: return config_file config_base_name = "training" - model_versions = { - "v1": "v1", - "v2": "v2" - } + model_versions = {"v1": "v1", "v2": "v2"} train_types = { "default": "default", "unfrozen": "unfrozen", @@ -1134,7 +1342,9 @@ def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon", c else: model_train_type = train_types["default"] - return get_config_path(model_version_name, model_train_type, config_base_name, prediction_type) + return get_config_path( + model_version_name, model_train_type, config_base_name, prediction_type + ) def extract_checkpoint( @@ -1182,8 +1392,9 @@ def extract_checkpoint( msg = None # Create empty config - db_config = TrainingConfig(ctx, model_name=new_model_name, scheduler=scheduler_type, - src=checkpoint_file) + db_config = TrainingConfig( + ctx, model_name=new_model_name, scheduler=scheduler_type, src=checkpoint_file + ) original_config_file = None @@ -1221,9 +1432,13 @@ def extract_checkpoint( else: prediction_type = "epsilon" - original_config_file = get_config_file(train_unfrozen, v2, prediction_type, config_file=config_file) + original_config_file = get_config_file( + train_unfrozen, v2, prediction_type, config_file=config_file + ) - logger.info(f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}") + logger.info( + f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}" + ) db_config.resolution = image_size db_config.lifetime_revision = revision db_config.epoch = epoch @@ -1233,12 +1448,18 @@ def extract_checkpoint( # Use existing YAML if present if checkpoint_file is not None: - config_check = checkpoint_file.replace(".ckpt", ".yaml") if ".ckpt" in checkpoint_file else checkpoint_file.replace(".safetensors", ".yaml") + config_check = ( + checkpoint_file.replace(".ckpt", ".yaml") + if ".ckpt" in checkpoint_file + else checkpoint_file.replace(".safetensors", ".yaml") + ) if os.path.exists(config_check): original_config_file = config_check if original_config_file is None or not os.path.exists(original_config_file): - logger.warning("unable to select a config file: %s" % (original_config_file)) + logger.warning( + "unable to select a config file: %s" % (original_config_file) + ) return logger.debug("trying to load: %s", original_config_file) @@ -1281,7 +1502,9 @@ def extract_checkpoint( # Convert the UNet2DConditionModel model. 
logger.info("converting UNet") - unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet_config = create_unet_diffusers_config( + original_config, image_size=image_size + ) unet_config["upcast_attention"] = upcast_attention unet = UNet2DConditionModel(**unet_config) @@ -1297,22 +1520,30 @@ def extract_checkpoint( vae_config = create_vae_diffusers_config(original_config, image_size=image_size) if vae_file is None: - converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + converted_vae_checkpoint = convert_ldm_vae_checkpoint( + checkpoint, vae_config + ) else: vae_file = os.path.join(ctx.model_path, vae_file) logger.debug("loading custom VAE: %s", vae_file) vae_checkpoint = load_tensor(vae_file, map_location=map_location) - converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, first_stage=False) + converted_vae_checkpoint = convert_ldm_vae_checkpoint( + vae_checkpoint, vae_config, first_stage=False + ) vae = AutoencoderKL(**vae_config) vae.load_state_dict(converted_vae_checkpoint) # Convert the text model. logger.info("converting text encoder") - text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] + text_model_type = original_config.model.params.cond_stage_config.target.split( + "." + )[-1] if text_model_type == "FrozenOpenCLIPEmbedder": text_model = convert_open_clip_checkpoint(checkpoint) - tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer") + tokenizer = CLIPTokenizer.from_pretrained( + "stabilityai/stable-diffusion-2", subfolder="tokenizer" + ) pipe = StableDiffusionPipeline( vae=vae, text_encoder=text_model, @@ -1326,7 +1557,9 @@ def extract_checkpoint( elif text_model_type == "PaintByExample": vision_model = convert_paint_by_example_checkpoint(checkpoint) tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") + feature_extractor = AutoFeatureExtractor.from_pretrained( + "CompVis/stable-diffusion-safety-checker" + ) pipe = PaintByExamplePipeline( vae=vae, image_encoder=vision_model, @@ -1338,8 +1571,12 @@ def extract_checkpoint( elif text_model_type == "FrozenCLIPEmbedder": text_model = convert_ldm_clip_checkpoint(checkpoint) tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") - feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") + safety_checker = StableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker" + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + "CompVis/stable-diffusion-safety-checker" + ) pipe = StableDiffusionPipeline( vae=vae, text_encoder=text_model, @@ -1347,16 +1584,24 @@ def extract_checkpoint( unet=unet, scheduler=scheduler, safety_checker=safety_checker, - feature_extractor=feature_extractor + feature_extractor=feature_extractor, ) else: text_config = create_ldm_bert_config(original_config) text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") - pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, - scheduler=scheduler) + pipe = LDMTextToImagePipeline( + vqvae=vae, + bert=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) 
except Exception: - logger.error("exception setting up output: %s", traceback.format_exception(*sys.exc_info())) + logger.error( + "exception setting up output: %s", + traceback.format_exception(*sys.exc_info()), + ) pipe = None if pipe is None or db_config is None: @@ -1371,12 +1616,18 @@ def extract_checkpoint( scheduler = db_config.scheduler required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"] if original_config_file is not None and os.path.exists(original_config_file): - logger.debug("copying original config: %s -> %s", original_config_file, db_config.model_dir) + logger.debug( + "copying original config: %s -> %s", + original_config_file, + db_config.model_dir, + ) shutil.copy(original_config_file, db_config.model_dir) basename = os.path.basename(original_config_file) new_ex_path = os.path.join(db_config.model_dir, basename) new_name = os.path.join(db_config.model_dir, f"{db_config.model_name}.yaml") - logger.debug("copying model config to new name: %s -> %s", new_ex_path, new_name) + logger.debug( + "copying model config to new name: %s -> %s", new_ex_path, new_name + ) if os.path.exists(new_name): os.remove(new_name) os.rename(new_ex_path, new_name) @@ -1407,7 +1658,9 @@ def convert_diffusion_original( source = source or model["source"] dest = os.path.join(ctx.model_path, name) - logger.info("converting original Diffusers checkpoint %s: %s -> %s", name, source, dest) + logger.info( + "converting original Diffusers checkpoint %s: %s -> %s", name, source, dest + ) if os.path.exists(dest): logger.info("ONNX pipeline already exists, skipping") @@ -1420,13 +1673,23 @@ def convert_diffusion_original( if os.path.exists(torch_path): logger.info("torch pipeline already exists, reusing: %s", torch_path) else: - logger.info("converting original Diffusers check to Torch model: %s -> %s", source, torch_path) - extract_checkpoint(ctx, torch_name, source, config_file=model.get("config"), vae_file=model.get("vae")) + logger.info( + "converting original Diffusers check to Torch model: %s -> %s", + source, + torch_path, + ) + extract_checkpoint( + ctx, + torch_name, + source, + config_file=model.get("config"), + vae_file=model.get("vae"), + ) logger.info("converted original Diffusers checkpoint to Torch model") # VAE has already been converted and will confuse HF repo lookup if "vae" in model: del model["vae"] - convert_diffusion_stable(ctx, model, working_name) + convert_diffusion_diffusers(ctx, model, working_name) logger.info("ONNX pipeline saved to %s", name) diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py new file mode 100644 index 000000000..cf8deb68c --- /dev/null +++ b/api/onnx_web/convert/diffusion/textual_inversion.py @@ -0,0 +1,88 @@ +from logging import getLogger +from os import makedirs, path + +import torch +from huggingface_hub.file_download import hf_hub_download +from torch.onnx import export +from transformers import CLIPTextModel, CLIPTokenizer + +from ..utils import ConversionContext + +logger = getLogger(__name__) + + +def convert_diffusion_textual_inversion( + context: ConversionContext, name: str, base_model: str, inversion: str +): + dest_path = path.join(context.model_path, f"inversion-{name}") + logger.info( + "converting Textual Inversion: %s + %s -> %s", base_model, inversion, dest_path + ) + + if path.exists(dest_path): + logger.info("ONNX model already exists, skipping.") + return + + makedirs(path.join(dest_path, "text_encoder"), exist_ok=True) + + embeds_file = 
hf_hub_download(repo_id=inversion, filename="learned_embeds.bin") + token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt") + + with open(token_file, "r") as f: + token = f.read() + + tokenizer = CLIPTokenizer.from_pretrained( + base_model, + subfolder="tokenizer", + ) + text_encoder = CLIPTextModel.from_pretrained( + base_model, + subfolder="text_encoder", + ) + + loaded_embeds = torch.load(embeds_file, map_location=context.map_location) + + # separate token and the embeds + trained_token = list(loaded_embeds.keys())[0] + embeds = loaded_embeds[trained_token] + + # cast to dtype of text_encoder + dtype = text_encoder.get_input_embeddings().weight.dtype + embeds.to(dtype) + + # add the token in tokenizer + num_added_tokens = tokenizer.add_tokens(token) + if num_added_tokens == 0: + raise ValueError( + f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer." + ) + + # resize the token embeddings + text_encoder.resize_token_embeddings(len(tokenizer)) + + # get the id for the token and assign the embeds + token_id = tokenizer.convert_tokens_to_ids(token) + text_encoder.get_input_embeddings().weight.data[token_id] = embeds + + # conversion stuff + text_input = tokenizer( + "A sample prompt", + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + export( + text_encoder, + # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files + (text_input.input_ids.to(dtype=torch.int32)), + f=path.join(dest_path, "text_encoder", "model.onnx"), + input_names=["input_ids"], + output_names=["last_hidden_state", "pooler_output"], + dynamic_axes={ + "input_ids": {0: "batch", 1: "sequence"}, + }, + do_constant_folding=True, + opset_version=context.opset, + ) diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py index 339787597..47c35fcdd 100644 --- a/api/onnx_web/diffusion/load.py +++ b/api/onnx_web/diffusion/load.py @@ -1,4 +1,5 @@ from logging import getLogger +from os import path from typing import Any, Optional, Tuple import numpy as np @@ -16,6 +17,7 @@ KDPM2AncestralDiscreteScheduler, KDPM2DiscreteScheduler, LMSDiscreteScheduler, + OnnxRuntimeModel, PNDMScheduler, StableDiffusionPipeline, ) @@ -138,8 +140,9 @@ def load_pipeline( scheduler_type: Any, device: DeviceParams, lpw: bool, + inversion: Optional[str], ): - pipe_key = (pipeline, model, device.device, device.provider, lpw) + pipe_key = (pipeline, model, device.device, device.provider, lpw, inversion) scheduler_key = (scheduler_type, model) cache_pipe = server.cache.get("diffusion", pipe_key) @@ -176,12 +179,23 @@ def load_pipeline( custom_pipeline = None logger.debug("loading new diffusion pipeline from %s", model) - scheduler = scheduler_type.from_pretrained( - model, - provider=device.ort_provider(), - sess_options=device.sess_options(), - subfolder="scheduler", - ) + components = { + "scheduler": scheduler_type.from_pretrained( + model, + provider=device.ort_provider(), + sess_options=device.sess_options(), + subfolder="scheduler", + ) + } + + if inversion is not None: + logger.debug("loading text encoder from %s", inversion) + components["text_encoder"] = OnnxRuntimeModel.from_pretrained( + path.join(inversion, "text_encoder"), + provider=device.ort_provider(), + sess_options=device.sess_options(), + ) + pipe = pipeline.from_pretrained( model, custom_pipeline=custom_pipeline, @@ -189,7 +203,7 @@ def load_pipeline( 
sess_options=device.sess_options(), revision="onnx", safety_checker=None, - scheduler=scheduler, + **components, ) if not server.show_progress: @@ -201,6 +215,6 @@ def load_pipeline( pipe = pipe.to(device.torch_str()) server.cache.set("diffusion", pipe_key, pipe) - server.cache.set("scheduler", scheduler_key, scheduler) + server.cache.set("scheduler", scheduler_key, components["scheduler"]) return pipe diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index c96af7e1c..ec988d660 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -36,6 +36,7 @@ def run_txt2img_pipeline( params.scheduler, job.get_device(), params.lpw, + params.inversion, ) progress = job.get_progress_callback() @@ -109,6 +110,7 @@ def run_img2img_pipeline( params.scheduler, job.get_device(), params.lpw, + params.inversion, ) progress = job.get_progress_callback() if params.lpw: diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index d793d262b..f60b3941f 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -19,6 +19,8 @@ def hash_value(sha, param: Param): if param is None: return + elif isinstance(param, bool): + sha.update(bytearray(pack("!B", param))) elif isinstance(param, float): sha.update(bytearray(pack("!f", param))) elif isinstance(param, int): @@ -73,8 +75,12 @@ def make_output_name( hash_value(sha, params.prompt) hash_value(sha, params.negative_prompt) hash_value(sha, params.cfg) - hash_value(sha, params.steps) hash_value(sha, params.seed) + hash_value(sha, params.steps) + hash_value(sha, params.lpw) + hash_value(sha, params.eta) + hash_value(sha, params.batch) + hash_value(sha, params.inversion) hash_value(sha, size.width) hash_value(sha, size.height) diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index fada5d23d..db23414aa 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -157,6 +157,7 @@ def __init__( lpw: bool = False, eta: float = 0.0, batch: int = 1, + inversion: Optional[str] = None, ) -> None: self.model = model self.scheduler = scheduler @@ -168,6 +169,7 @@ def __init__( self.lpw = lpw or False self.eta = eta self.batch = batch + self.inversion = inversion def tojson(self) -> Dict[str, Optional[Param]]: return { @@ -181,6 +183,7 @@ def tojson(self) -> Dict[str, Optional[Param]]: "lpw": self.lpw, "eta": self.eta, "batch": self.batch, + "inversion": self.inversion, } def with_args(self, **kwargs): @@ -195,6 +198,7 @@ def with_args(self, **kwargs): kwargs.get("lpw", self.lpw), kwargs.get("eta", self.eta), kwargs.get("batch", self.batch), + kwargs.get("inversion", self.inversion), ) diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index ee60ef843..fb95d8ef3 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -122,8 +122,9 @@ available_platforms: List[DeviceParams] = [] # loaded from model_path -diffusion_models: List[str] = [] correction_models: List[str] = [] +diffusion_models: List[str] = [] +inversion_models: List[str] = [] upscaling_models: List[str] = [] @@ -159,6 +160,11 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]: request.args, "scheduler", pipeline_schedulers, get_config_value("scheduler") ) + inversion = request.args.get("inversion", None) + inversion_path = None + if inversion is not None and inversion.strip() != "": + inversion_path = get_model_path(inversion) + # image params prompt = get_not_empty(request.args, "prompt", get_config_value("prompt")) negative_prompt = request.args.get("negativePrompt", None) @@ -239,6 +245,7 @@ def 
pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]: lpw=lpw, negative_prompt=negative_prompt, batch=batch, + inversion=inversion_path, ) size = Size(width, height) return (device, params, size) @@ -301,8 +308,9 @@ def get_model_name(model: str) -> str: def load_models(context: ServerContext) -> None: - global diffusion_models global correction_models + global diffusion_models + global inversion_models global upscaling_models diffusion_models = [ @@ -323,6 +331,12 @@ def load_models(context: ServerContext) -> None: correction_models = list(set(correction_models)) correction_models.sort() + inversion_models = [ + get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*")) + ] + inversion_models = list(set(inversion_models)) + inversion_models.sort() + upscaling_models = [ get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*")) ] @@ -496,8 +510,9 @@ def list_mask_filters(): def list_models(): return jsonify( { - "diffusion": diffusion_models, "correction": correction_models, + "diffusion": diffusion_models, + "inversion": inversion_models, "upscaling": upscaling_models, } ) diff --git a/api/params.json b/api/params.json index e45aefdc3..8303b47d3 100644 --- a/api/params.json +++ b/api/params.json @@ -60,6 +60,10 @@ "max": 1024, "step": 8 }, + "inversion": { + "default": "", + "keys": [] + }, "left": { "default": 0, "min": 0, diff --git a/api/schemas/extras.yaml b/api/schemas/extras.yaml index b0a0f9124..2cf471232 100644 --- a/api/schemas/extras.yaml +++ b/api/schemas/extras.yaml @@ -10,6 +10,15 @@ $defs: - type: number - type: string + textual_inversion: + type: object + required: [name, source] + properties: + name: + type: string + source: + type: string + base_model: type: object required: [name, source] @@ -37,6 +46,10 @@ $defs: properties: config: type: string + inversions: + type: array + items: + $ref: "#/$defs/textual_inversion" vae: type: string diff --git a/api/scripts/test-refs/img2img-sd-v1-5-256-pumpkin.png b/api/scripts/test-refs/img2img-sd-v1-5-256-pumpkin-0.png similarity index 100% rename from api/scripts/test-refs/img2img-sd-v1-5-256-pumpkin.png rename to api/scripts/test-refs/img2img-sd-v1-5-256-pumpkin-0.png diff --git a/api/scripts/test-refs/img2img-sd-v1-5-512-pumpkin.png b/api/scripts/test-refs/img2img-sd-v1-5-512-pumpkin-0.png similarity index 100% rename from api/scripts/test-refs/img2img-sd-v1-5-512-pumpkin.png rename to api/scripts/test-refs/img2img-sd-v1-5-512-pumpkin-0.png diff --git a/api/scripts/test-refs/inpaint-v1-512-black.png b/api/scripts/test-refs/inpaint-v1-512-black-0.png similarity index 100% rename from api/scripts/test-refs/inpaint-v1-512-black.png rename to api/scripts/test-refs/inpaint-v1-512-black-0.png diff --git a/api/scripts/test-refs/inpaint-v1-512-white.png b/api/scripts/test-refs/inpaint-v1-512-white-0.png similarity index 100% rename from api/scripts/test-refs/inpaint-v1-512-white.png rename to api/scripts/test-refs/inpaint-v1-512-white-0.png diff --git a/api/scripts/test-refs/outpaint-even-256.png b/api/scripts/test-refs/outpaint-even-256-0.png similarity index 100% rename from api/scripts/test-refs/outpaint-even-256.png rename to api/scripts/test-refs/outpaint-even-256-0.png diff --git a/api/scripts/test-refs/outpaint-horizontal-512.png b/api/scripts/test-refs/outpaint-horizontal-512-0.png similarity index 100% rename from api/scripts/test-refs/outpaint-horizontal-512.png rename to api/scripts/test-refs/outpaint-horizontal-512-0.png diff --git 
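The server changes above add an inversion query parameter, which is resolved to a model path before being stored on ImageParams, and an empty or missing value disables it. A hedged sketch of a client request (host, port, and model names are placeholders; the parameter names match the ones read in pipeline_from_request, and the "outputs" key matches the response handling in the test script below):

import requests

resp = requests.post(
    "http://127.0.0.1:5000/api/txt2img",
    params={
        "prompt": "a magical castle, line art",
        "model": "stable-diffusion-onnx-v1-5",
        "inversion": "inversion-line-art",  # empty string or omitted means no inversion
        "scheduler": "ddim",
        "seed": 0,
    },
)
print(resp.json().get("outputs"))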
a/api/scripts/test-refs/outpaint-vertical-512.png b/api/scripts/test-refs/outpaint-vertical-512-0.png similarity index 100% rename from api/scripts/test-refs/outpaint-vertical-512.png rename to api/scripts/test-refs/outpaint-vertical-512-0.png diff --git a/api/scripts/test-refs/txt2img-knollingcase-512-muffin.png b/api/scripts/test-refs/txt2img-knollingcase-512-muffin-0.png similarity index 100% rename from api/scripts/test-refs/txt2img-knollingcase-512-muffin.png rename to api/scripts/test-refs/txt2img-knollingcase-512-muffin-0.png diff --git a/api/scripts/test-refs/txt2img-openjourney-512-muffin.png b/api/scripts/test-refs/txt2img-openjourney-512-muffin-0.png similarity index 100% rename from api/scripts/test-refs/txt2img-openjourney-512-muffin.png rename to api/scripts/test-refs/txt2img-openjourney-512-muffin-0.png diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-256-muffin.png b/api/scripts/test-refs/txt2img-sd-v1-5-256-muffin-0.png similarity index 100% rename from api/scripts/test-refs/txt2img-sd-v1-5-256-muffin.png rename to api/scripts/test-refs/txt2img-sd-v1-5-256-muffin-0.png diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin.png b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-0.png similarity index 100% rename from api/scripts/test-refs/txt2img-sd-v1-5-512-muffin.png rename to api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-0.png diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-deis.png b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-deis-0.png similarity index 100% rename from api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-deis.png rename to api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-deis-0.png diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm.png b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm-0.png similarity index 100% rename from api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm.png rename to api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm-0.png diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun.png b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun-0.png similarity index 100% rename from api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun.png rename to api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun-0.png diff --git a/api/scripts/test-refs/txt2img-sd-v2-1-512-muffin.png b/api/scripts/test-refs/txt2img-sd-v2-1-512-muffin-0.png similarity index 100% rename from api/scripts/test-refs/txt2img-sd-v2-1-512-muffin.png rename to api/scripts/test-refs/txt2img-sd-v2-1-512-muffin-0.png diff --git a/api/scripts/test-refs/txt2img-sd-v2-1-768-muffin.png b/api/scripts/test-refs/txt2img-sd-v2-1-768-muffin-0.png similarity index 100% rename from api/scripts/test-refs/txt2img-sd-v2-1-768-muffin.png rename to api/scripts/test-refs/txt2img-sd-v2-1-768-muffin-0.png diff --git a/api/scripts/test-refs/upscale-resrgan-x2-1024-muffin.png b/api/scripts/test-refs/upscale-resrgan-x2-1024-muffin-0.png similarity index 100% rename from api/scripts/test-refs/upscale-resrgan-x2-1024-muffin.png rename to api/scripts/test-refs/upscale-resrgan-x2-1024-muffin-0.png diff --git a/api/scripts/test-refs/upscale-resrgan-x4-2048-muffin.png b/api/scripts/test-refs/upscale-resrgan-x4-2048-muffin-0.png similarity index 100% rename from api/scripts/test-refs/upscale-resrgan-x4-2048-muffin.png rename to api/scripts/test-refs/upscale-resrgan-x4-2048-muffin-0.png diff --git a/api/scripts/test-release.py b/api/scripts/test-release.py index ff31397c9..65572083f 100644 --- a/api/scripts/test-release.py +++ 
b/api/scripts/test-release.py @@ -25,6 +25,9 @@ logger = getLogger(__name__) +FAST_TEST = 20 +SLOW_TEST = 50 + def test_root() -> str: if len(sys.argv) > 1: @@ -42,7 +45,7 @@ def __init__( self, name: str, query: str, - max_attempts: int = 20, + max_attempts: int = FAST_TEST, mse_threshold: float = 0.001, source: Union[Image.Image, List[Image.Image]] = None, mask: Image.Image = None, @@ -95,23 +98,23 @@ def __init__( TestCase( "img2img-sd-v1-5-512-pumpkin", "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim", - source="txt2img-sd-v1-5-512-muffin", + source="txt2img-sd-v1-5-512-muffin-0", ), TestCase( "img2img-sd-v1-5-256-pumpkin", "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim", - source="txt2img-sd-v1-5-256-muffin", + source="txt2img-sd-v1-5-256-muffin-0", ), TestCase( "inpaint-v1-512-white", "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting", - source="txt2img-sd-v1-5-512-muffin", + source="txt2img-sd-v1-5-512-muffin-0", mask="mask-white", ), TestCase( "inpaint-v1-512-black", "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting", - source="txt2img-sd-v1-5-512-muffin", + source="txt2img-sd-v1-5-512-muffin-0", mask="mask-black", ), TestCase( @@ -120,8 +123,9 @@ def __init__( "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask" "&top=256&bottom=256&left=256&right=256" ), - source="txt2img-sd-v1-5-512-muffin", + source="txt2img-sd-v1-5-512-muffin-0", mask="mask-black", + max_attempts=SLOW_TEST, mse_threshold=0.025, ), TestCase( @@ -130,8 +134,9 @@ def __init__( "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask" "&top=512&bottom=512&left=0&right=0" ), - source="txt2img-sd-v1-5-512-muffin", + source="txt2img-sd-v1-5-512-muffin-0", mask="mask-black", + max_attempts=SLOW_TEST, mse_threshold=0.010, ), TestCase( @@ -140,27 +145,28 @@ def __init__( "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask" "&top=0&bottom=0&left=512&right=512" ), - source="txt2img-sd-v1-5-512-muffin", + source="txt2img-sd-v1-5-512-muffin-0", mask="mask-black", + max_attempts=SLOW_TEST, mse_threshold=0.010, ), TestCase( "upscale-resrgan-x2-1024-muffin", "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2", - source="txt2img-sd-v1-5-512-muffin", + source="txt2img-sd-v1-5-512-muffin-0", ), TestCase( "upscale-resrgan-x4-2048-muffin", "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x4-plus&scale=4&outscale=4", - source="txt2img-sd-v1-5-512-muffin", + source="txt2img-sd-v1-5-512-muffin-0", ), TestCase( "blend-512-muffin-black", "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2", mask="mask-black", source=[ - "txt2img-sd-v1-5-512-muffin", - "txt2img-sd-v2-1-512-muffin", + "txt2img-sd-v1-5-512-muffin-0", + "txt2img-sd-v2-1-512-muffin-0", ], ), TestCase( @@ -168,14 +174,14 @@ def __init__( "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2", mask="mask-white", source=[ - "txt2img-sd-v2-1-512-muffin", - "txt2img-sd-v1-5-512-muffin", + "txt2img-sd-v2-1-512-muffin-0", + "txt2img-sd-v1-5-512-muffin-0", ], ), ] -def generate_image(root: str, test: TestCase) -> Optional[str]: +def generate_images(root: str, test: TestCase) -> Optional[str]: 
files = {} if test.source is not None: if isinstance(test.source, list): @@ -211,9 +217,9 @@ def generate_image(root: str, test: TestCase) -> Optional[str]: resp = requests.post(f"{root}/api/{test.query}", files=files) if resp.status_code == 200: json = resp.json() - return json.get("output") + return json.get("outputs") else: - logger.warning("request failed: %s", resp.status_code) + logger.warning("request failed: %s: %s", resp.status_code, resp.text) return None @@ -227,14 +233,17 @@ def check_ready(root: str, key: str) -> bool: return False -def download_image(root: str, key: str) -> Image.Image: - resp = requests.get(f"{root}/output/{key}") - if resp.status_code == 200: - logger.debug("downloading image: %s", key) - return Image.open(BytesIO(resp.content)) - else: - logger.warning("request failed: %s", resp.status_code) - return None +def download_images(root: str, keys: List[str]) -> List[Image.Image]: + images = [] + for key in keys: + resp = requests.get(f"{root}/output/{key}") + if resp.status_code == 200: + logger.debug("downloading image: %s", key) + images.append(Image.open(BytesIO(resp.content))) + else: + logger.warning("request failed: %s", resp.status_code) + + return images def find_mse(result: Image.Image, ref: Image.Image) -> float: @@ -259,20 +268,19 @@ def find_mse(result: Image.Image, ref: Image.Image) -> float: def run_test( root: str, test: TestCase, - ref: Image.Image, ) -> bool: """ Generate an image, wait for it to be ready, and calculate the MSE from the reference. """ - key = generate_image(root, test) - if key is None: + keys = generate_images(root, test) + if keys is None: raise ValueError("could not generate") attempts = 0 while attempts < test.max_attempts: - if check_ready(root, key): - logger.debug("image is ready: %s", key) + if check_ready(root, keys[0]): + logger.debug("image is ready: %s", keys) break else: logger.debug("waiting for image to be ready") @@ -282,16 +290,25 @@ def run_test( if attempts == test.max_attempts: raise ValueError("image was not ready in time") - result = download_image(root, key) - result.save(test_path(path.join("test-results", f"{test.name}.png"))) - mse = find_mse(result, ref) + results = download_images(root, keys) - if mse < test.mse_threshold: - logger.info("MSE within threshold: %.4f < %.4f", mse, test.mse_threshold) - return True - else: - logger.warning("MSE above threshold: %.4f > %.4f", mse, test.mse_threshold) - return False + passed = True + for i in range(len(results)): + result = results[i] + result.save(test_path(path.join("test-results", f"{test.name}-{i}.png"))) + + ref_name = test_path(path.join("test-refs", f"{test.name}-{i}.png")) + ref = Image.open(ref_name) if path.exists(ref_name) else None + + mse = find_mse(result, ref) + + if mse < test.mse_threshold: + logger.info("MSE within threshold: %.4f < %.4f", mse, test.mse_threshold) + else: + logger.warning("MSE above threshold: %.4f > %.4f", mse, test.mse_threshold) + passed = False + + return passed def main(): @@ -303,9 +320,7 @@ def main(): for test in TEST_DATA: try: logger.info("starting test: %s", test.name) - ref_name = test_path(path.join("test-refs", f"{test.name}.png")) - ref = Image.open(ref_name) if path.exists(ref_name) else None - if run_test(root, test, ref): + if run_test(root, test): logger.info("test passed: %s", test.name) passed.append(test.name) else: diff --git a/common/scripts/copy-pr-to-gitlab.sh b/common/scripts/copy-pr-to-gitlab.sh new file mode 100644 index 000000000..269711916 --- /dev/null +++ 
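Since the API now returns a list under "outputs", each result has to be fetched by key; this is essentially what download_images above does, shown standalone as a hedged sketch (the host and keys are placeholders):

from io import BytesIO

import requests
from PIL import Image

root = "http://127.0.0.1:5000"
keys = ["output-1.png", "output-2.png"]  # as returned in the "outputs" field

# fetch each output image and save it with an index, like the updated test script
for i, key in enumerate(keys):
    resp = requests.get(f"{root}/output/{key}")
    resp.raise_for_status()
    Image.open(BytesIO(resp.content)).save(f"result-{i}.png")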
b/common/scripts/copy-pr-to-gitlab.sh @@ -0,0 +1,6 @@ +PR_USER=$1 +PR_BRANCH=$2 + +git remote add $1 git@github.com:$1/onnx-web.git +git fetch $1 +git push gitlab refs/remotes/$1/$2:refs/heads/$1-$2 diff --git a/gui/src/client.ts b/gui/src/client.ts index ab523de79..d16e70c92 100644 --- a/gui/src/client.ts +++ b/gui/src/client.ts @@ -32,6 +32,8 @@ export interface ModelParams { * Use the long prompt weighting pipeline. */ lpw: boolean; + + inversion: string; } /** @@ -183,6 +185,7 @@ export interface ReadyResponse { export interface ModelsResponse { diffusion: Array; correction: Array; + inversion: Array; upscaling: Array; } @@ -325,6 +328,7 @@ export function appendModelToURL(url: URL, params: ModelParams) { url.searchParams.append('upscaling', params.upscaling); url.searchParams.append('correction', params.correction); url.searchParams.append('lpw', String(params.lpw)); + url.searchParams.append('inversion', params.inversion); } /** diff --git a/gui/src/components/control/ModelControl.tsx b/gui/src/components/control/ModelControl.tsx index 7e0159197..41d157cfc 100644 --- a/gui/src/components/control/ModelControl.tsx +++ b/gui/src/components/control/ModelControl.tsx @@ -7,7 +7,7 @@ import { useStore } from 'zustand'; import { STALE_TIME } from '../../config.js'; import { ClientContext, StateContext } from '../../state.js'; -import { MODEL_LABELS, PLATFORM_LABELS } from '../../strings.js'; +import { INVERSION_LABELS, MODEL_LABELS, PLATFORM_LABELS } from '../../strings.js'; import { QueryList } from '../input/QueryList.js'; export function ModelControl() { @@ -54,6 +54,22 @@ export function ModelControl() { }); }} /> + result.inversion, + }} + showEmpty={true} + value={params.inversion} + onChange={(inversion) => { + setModel({ + inversion, + }); + }} + /> { value: string; query: QueryListComplete | QueryListFilter; + showEmpty?: boolean; onChange?: (value: string) => void; } @@ -28,17 +29,25 @@ export function hasFilter(query: QueryListComplete | QueryListFilter): que return Reflect.has(query, 'selector'); } -export function filterQuery(query: QueryListComplete | QueryListFilter): Array { +export function filterQuery(query: QueryListComplete | QueryListFilter, showEmpty: boolean): Array { if (hasFilter(query)) { const data = mustExist(query.result.data); - return (query as QueryListFilter).selector(data); + const selected = (query as QueryListFilter).selector(data); + if (showEmpty) { + return ['', ...selected]; + } + return selected; } else { - return mustExist(query.result.data); + const data = Array.from(mustExist(query.result.data)); + if (showEmpty) { + return ['', ...data]; + } + return data; } } export function QueryList(props: QueryListProps) { - const { labels, query, value } = props; + const { labels, query, showEmpty = false, value } = props; const { result } = query; function firstValidValue(): string { @@ -52,7 +61,7 @@ export function QueryList(props: QueryListProps) { // update state when previous selection was invalid: https://github.com/ssube/onnx-web/issues/120 useEffect(() => { if (result.status === 'success' && doesExist(result.data) && doesExist(props.onChange)) { - const data = filterQuery(query); + const data = filterQuery(query, showEmpty); if (data.includes(value) === false) { props.onChange(data[0]); } @@ -77,7 +86,7 @@ export function QueryList(props: QueryListProps) { // else: success const labelID = `query-list-${props.id}-labels`; - const data = filterQuery(query); + const data = filterQuery(query, showEmpty); return {props.name} diff --git a/gui/src/state.ts 
b/gui/src/state.ts
index baabf493f..cd8254d45 100644
--- a/gui/src/state.ts
+++ b/gui/src/state.ts
@@ -487,6 +487,7 @@ export function createStateSlices(server: ServerParams) {
       platform: server.platform.default,
       upscaling: server.upscaling.default,
       correction: server.correction.default,
+      inversion: server.inversion.default,
       lpw: false,
     },
     setModel(params) {
diff --git a/gui/src/strings.ts b/gui/src/strings.ts
index 86c4171b2..6e6dec3eb 100644
--- a/gui/src/strings.ts
+++ b/gui/src/strings.ts
@@ -32,6 +32,14 @@ export const MODEL_LABELS: Record<string, string> = {
   'diffusion-unstable-ink-dream-v6': 'Unstable Ink Dream v6',
 };
 
+export const INVERSION_LABELS: Record<string, string> = {
+  '': 'None',
+  'inversion-cubex': 'Cubex',
+  'inversion-birb': 'Birb Style',
+  'inversion-line-art': 'Line Art',
+  'inversion-minecraft': 'Minecraft Concept',
+};
+
 export const PLATFORM_LABELS: Record<string, string> = {
   amd: 'AMD GPU',
   // eslint-disable-next-line id-blacklist