Skip to content

Commit

Permalink
feat(api): add basic upscaling
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 16, 2023
1 parent 64fac4d commit 77cb84c
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 0 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Based on guides by:
- [Note about setup paths](#note-about-setup-paths)
- [Create a virtual environment](#create-a-virtual-environment)
- [Install pip packages](#install-pip-packages)
- [For upscaling and face correction](#for-upscaling-and-face-correction)
- [For AMD on Windows: Install ONNX DirectML](#for-amd-on-windows-install-onnx-directml)
- [For CPU on Linux: Install PyTorch CPU](#for-cpu-on-linux-install-pytorch-cpu)
- [For CPU on Windows: Install PyTorch CPU](#for-cpu-on-windows-install-pytorch-cpu)
Expand Down Expand Up @@ -190,6 +191,12 @@ sure you are not using `numpy>=1.24`.
[This SO question](https://stackoverflow.com/questions/74844262/how-to-solve-error-numpy-has-no-attribute-float-in-python)
has more details.

#### For upscaling and face correction

```shell
> pip install basicsr facexlib gfpgan realesrgan
```

#### For AMD on Windows: Install ONNX DirectML

If you are running on Windows, install the DirectML ONNX runtime as well:
Expand Down
7 changes: 7 additions & 0 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@
noise_source_uniform,
)

from .upscale import (
upscale_gfpgan,
upscale_resrgan,
)

import json
import numpy as np
import time
Expand Down Expand Up @@ -268,6 +273,8 @@ def run_txt2img_pipeline(model, provider, scheduler, prompt, negative_prompt, cf
negative_prompt=negative_prompt,
num_inference_steps=steps,
).images[0]

image = upscale_resrgan(image)
image.save(output)

print('saved txt2img output: %s' % (output))
Expand Down
64 changes: 64 additions & 0 deletions api/onnx_web/upscale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from gfpgan import GFPGANer
from os import path
from PIL import Image
from realesrgan import RealESRGANer

denoise_strength = 0.5
gfpgan_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
resrgan_url = [
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
fp32 = True
model_name = 'RealESRGAN_x4plus'
netscale = 4
outscale = 4
pre_pad = 0
tile = 0
tile_pad = 10


def upscale_resrgan(source_image: Image) -> Image:
model_path = path.join('weights', model_name + '.pth')
if not path.isfile(model_path):
ROOT_DIR = os.path.dirname(path.abspath(__file__))
for url in resrgan_url:
model_path = load_file_from_url(
url=url, model_dir=path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)

model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=4)

dni_weight = None
if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
wdn_model_path = model_path.replace(
'realesr-general-x4v3', 'realesr-general-wdn-x4v3')
model_path = [model_path, wdn_model_path]
dni_weight = [denoise_strength, 1 - denoise_strength]

upsampler = RealESRGANer(
scale=netscale,
model_path=model_path,
dni_weight=dni_weight,
model=model,
tile=tile,
tile_pad=tile_pad,
pre_pad=pre_pad,
half=fp32)

output, _ = upsampler.enhance(source_image, outscale=outscale)

return upscale_gfpgan(output, upsampler)


def upscale_gfpgan(source_image: Image, upsampler) -> Image:
face_enhancer = GFPGANer(
model_path=gfpgan_url,
upscale=outscale,
arch='clean',
channel_multiplier=2,
bg_upsampler=upsampler)

_, _, output = face_enhancer.enhance(source_image, has_aligned=False, only_center_face=False, paste_back=True)

return output

0 comments on commit 77cb84c

Please sign in to comment.