Skip to content

Commit

Permalink
fix(api): add extra models to convert script
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 21, 2023
1 parent b1e7ab0 commit e083411
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 32 deletions.
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@ already listed in the `convert.py` script, including:
- https://huggingface.co/runwayml/stable-diffusion-inpainting
- https://huggingface.co/stabilityai/stable-diffusion-2-1
- https://huggingface.co/stabilityai/stable-diffusion-2-inpainting
- https://huggingface.co/Aybeeceedee/knollingcase
- https://huggingface.co/prompthero/openjourney

You will need at least one of the base models for txt2img and img2img mode. If you want to use inpainting, you will
also need one of the inpainting models. The upscaling and face correction models are downloaded from Github by the
Expand All @@ -273,15 +275,15 @@ paste it into the prompt.
Run the provided conversion script from the `api/` directory:

```shell
> python -m onnx_web.convert --diffusion --correction --upscaling
```

Models that have already been downloaded and converted will be skipped, so it should be safe to run this script after
every update. Some additional, more specialized models are available using the `--extras` flag.

The conversion script has a few other options, which can be printed using `python -m onnx_web.convert --help`. If you
are using CUDA on Nvidia hardware, using the `--half` option may make things faster.


This will take a little while to convert each model. Stable Diffusion v1.4 is about 6GB, v1.5 is at least 10GB or so.
You can skip certain models by including a `--skip name` argument if you want to save time or disk space. For example,
using `--skip stable-diffusion-onnx-v2-inpainting --skip stable-diffusion-onnx-v2-1` will not download the Stable
Expand Down
83 changes: 55 additions & 28 deletions api/onnx_web/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
from shutil import copyfile, rmtree
from sys import exit
from torch.onnx import export
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union

import torch

sources: Dict[str, List[Tuple[str, str]]] = {
'diffusers': [
Models = Dict[str, List[Tuple[str, str, Union[int, None]]]]

# recommended models
base_models: Models = {
'diffusion': [
# v1.x
('stable-diffusion-onnx-v1-5', 'runwayml/stable-diffusion-v1-5'),
('stable-diffusion-onnx-v1-inpainting',
Expand All @@ -23,11 +26,11 @@
('stable-diffusion-onnx-v2-inpainting',
'stabilityai/stable-diffusion-2-inpainting'),
],
'gfpgan': [
'correction': [
('correction-gfpgan-v1-3',
'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', 4),
],
'real_esrgan': [
'upscaling': [
('upscaling-real-esrgan-x2-plus',
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', 2),
('upscaling-real-esrgan-x4-plus',
Expand All @@ -37,6 +40,16 @@
],
}

# other neat models
extra_models: Models = {
'diffusion': [
('diffusion-knollingcase', 'Aybeeceedee/knollingcase'),
('diffusion-openjourney', 'prompthero/openjourney'),
],
'correction': [],
'upscaling': [],
}

model_path = environ.get('ONNX_WEB_MODEL_PATH',
path.join('..', 'models'))
training_device = 'cuda' if torch.cuda.is_available() else 'cpu'
Expand Down Expand Up @@ -175,7 +188,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool):
'''
dtype = torch.float16 if half else torch.float32
dest_path = path.join(model_path, name)
print('converting Diffusers model: %s -> %s' % (name, dest_path))

# diffusers go into a directory rather than .onnx file
print('converting Diffusers model: %s -> %s/' % (name, dest_path))

if path.isdir(dest_path):
print('ONNX model already exists, skipping.')
Expand Down Expand Up @@ -354,6 +369,8 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool):
requires_safety_checker=safety_checker is not None,
)

print('exporting ONNX model')

onnx_pipeline.save_pretrained(output_path)
print("ONNX pipeline saved to", output_path)

Expand All @@ -364,15 +381,39 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool):
print("ONNX pipeline is loadable")


def load_models(args, models: Models):
if args.diffusion:
for source in models.get('diffusion'):
if source[0] in args.skip:
print('Skipping model: %s' % source[0])
else:
convert_diffuser(*source, args.opset, args.half)

if args.upscaling:
for source in models.get('upscaling'):
if source[0] in args.skip:
print('Skipping model: %s' % source[0])
else:
convert_real_esrgan(*source, args.opset)

if args.correction:
for source in models.get('correction'):
if source[0] in args.skip:
print('Skipping model: %s' % source[0])
else:
convert_gfpgan(*source, args.opset)


def main() -> int:
parser = ArgumentParser(
prog='onnx-web model converter',
description='convert checkpoint models to ONNX')

# model groups
parser.add_argument('--diffusers', action='store_true', default=False)
parser.add_argument('--gfpgan', action='store_true', default=False)
parser.add_argument('--resrgan', action='store_true', default=False)
parser.add_argument('--correction', action='store_true', default=False)
parser.add_argument('--diffusion', action='store_true', default=False)
parser.add_argument('--extras', action='store_true', default=False)
parser.add_argument('--upscaling', action='store_true', default=False)
parser.add_argument('--skip', nargs='*', type=str, default=[])

# export options
Expand All @@ -392,26 +433,12 @@ def main() -> int:
args = parser.parse_args()
print(args)

if args.diffusers:
for source in sources.get('diffusers'):
if source[0] in args.skip:
print('Skipping model: %s' % source[0])
else:
convert_diffuser(*source, args.opset, args.half)
print('Converting base models.')
load_models(args, base_models)

if args.resrgan:
for source in sources.get('real_esrgan'):
if source[0] in args.skip:
print('Skipping model: %s' % source[0])
else:
convert_real_esrgan(*source, args.opset)

if args.gfpgan:
for source in sources.get('gfpgan'):
if source[0] in args.skip:
print('Skipping model: %s' % source[0])
else:
convert_gfpgan(*source, args.opset)
if args.extras:
print('Converting extra models.')
load_models(args, extra_models)

return 0

Expand Down

0 comments on commit e083411

Please sign in to comment.