diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index b6b441469..f014016c3 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -208,6 +208,23 @@ def convert_models(ctx: ConversionContext, args, models: Models): except Exception: logger.exception("error fetching source %s", name) + if args.networks and "networks" in models: + for network in models.get("networks"): + name = network["name"] + + if name in args.skip: + logger.info("skipping network: %s", name) + else: + network_format = source_format(network) + network_type = network["type"] + source = network["source"] + + try: + dest = fetch_model(ctx, name, source, dest=path.join(ctx.model_path, network_type), format=network_format) + logger.info("finished downloading network: %s -> %s", source, dest) + except Exception: + logger.exception("error fetching network %s", name) + if args.diffusion and "diffusion" in models: for model in models.get("diffusion"): model = tuple_to_diffusion(model) diff --git a/api/onnx_web/convert/diffusion/diffusers.py b/api/onnx_web/convert/diffusion/diffusers.py index 91f40141d..64b678303 100644 --- a/api/onnx_web/convert/diffusion/diffusers.py +++ b/api/onnx_web/convert/diffusion/diffusers.py @@ -13,7 +13,7 @@ from os import mkdir, path from pathlib import Path from shutil import rmtree -from typing import Dict +from typing import Dict, Tuple import torch from diffusers import ( diff --git a/api/schemas/extras.yaml b/api/schemas/extras.yaml index 8bee08d7f..fbc2bb109 100644 --- a/api/schemas/extras.yaml +++ b/api/schemas/extras.yaml @@ -10,7 +10,20 @@ $defs: - type: number - type: string - textual_inversion: + lora_network: + type: object + required: [name, source] + properties: + name: + type: string + source: + type: string + label: + type: string + weight: + type: number + + textual_inversion_network: type: object required: [name, source] properties: @@ -25,6 +38,8 @@ $defs: type: string token: type: string + weight: + type: number base_model: type: object @@ -58,7 +73,11 @@ $defs: inversions: type: array items: - $ref: "#/$defs/textual_inversion" + $ref: "#/$defs/textual_inversion_network" + loras: + type: array + items: + $ref: "#/$defs/lora_network" vae: type: string @@ -82,6 +101,21 @@ $defs: source: type: string + source_network: + type: object + required: [name, source, type] + properties: + format: + type: string + enum: [ckpt, safetensors] + name: + type: string + source: + type: string + type: + type: string + enum: [inversion, lora] + translation: type: object additionalProperties: False @@ -106,12 +140,9 @@ properties: oneOf: - $ref: "#/$defs/legacy_tuple" - $ref: "#/$defs/correction_model" - upscaling: + networks: type: array - items: - oneOf: - - $ref: "#/$defs/legacy_tuple" - - $ref: "#/$defs/upscaling_model" + items: "#/$defs/source_network" sources: type: array items: @@ -123,4 +154,10 @@ properties: additionalProperties: False patternProperties: "^\\w\\w$": - $ref: "#/$defs/translation" \ No newline at end of file + $ref: "#/$defs/translation" + upscaling: + type: array + items: + oneOf: + - $ref: "#/$defs/legacy_tuple" + - $ref: "#/$defs/upscaling_model" \ No newline at end of file