Skip to content

Commit

Permalink
fix(api): include model scale
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 17, 2023
1 parent 556d5b8 commit dba6113
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions api/onnx_web/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,30 @@
'diffusers': [
# v1.x
('stable-diffusion-onnx-v1-5', 'runwayml/stable-diffusion-v1-5'),
('stable-diffusion-onnx-v1-inpainting', 'runwayml/stable-diffusion-inpainting'),
('stable-diffusion-onnx-v1-inpainting',
'runwayml/stable-diffusion-inpainting'),
# v2.x
('stable-diffusion-onnx-v2-1', 'stabilityai/stable-diffusion-2-1'),
('stable-diffusion-onnx-v2-inpainting', 'stabilityai/stable-diffusion-2-inpainting'),
('stable-diffusion-onnx-v2-inpainting',
'stabilityai/stable-diffusion-2-inpainting'),
],
'gfpgan': [
('GFPGANv1.3', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'),
('correction-gfpgan-v1-3',
'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'),
],
'real_esrgan': [
('RealESRGAN_x4plus', 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'),
('upscaling-real-esrgan-x4-plus',
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth', 4),
],
}

model_path = environ.get('ONNX_WEB_MODEL_PATH',
path.join('..', 'models'))


training_device = 'cuda' if torch.cuda.is_available() else 'cpu'


@torch.no_grad()
def convert_real_esrgan(name: str, url: str, opset: int):
def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
dest_path = path.join(model_path, name)
dest_onnx = path.join(model_path, name + '.onnx')
print('converting Real ESRGAN model: %s -> %s' % (name, dest_path))
Expand All @@ -53,7 +55,7 @@ def convert_real_esrgan(name: str, url: str, opset: int):

print('loading and training model')
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=4)
num_block=23, num_grow_ch=32, scale=scale)
model.load_state_dict(torch.load(dest_path)['params_ema'])
model.to(training_device).train(False)
model.eval()
Expand Down

0 comments on commit dba6113

Please sign in to comment.