Skip to content

Commit

Permalink
fix(api): load upscaling model from models dir
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 16, 2023
1 parent 45d65d1 commit 806503c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
4 changes: 2 additions & 2 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,7 @@ def run_txt2img_pipeline(model, provider, scheduler, prompt, negative_prompt, cf
negative_prompt=negative_prompt,
num_inference_steps=steps,
).images[0]

image = upscale_resrgan(image)
image = upscale_resrgan(image, model_path)
image.save(output)

print('saved txt2img output: %s' % (output))
Expand All @@ -295,6 +294,7 @@ def run_img2img_pipeline(model, provider, scheduler, prompt, negative_prompt, cf
num_inference_steps=steps,
strength=strength,
).images[0]
image = upscale_resrgan(image, model_path)
image.save(output)

print('saved img2img output: %s' % (output))
Expand Down
22 changes: 14 additions & 8 deletions api/onnx_web/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
gfpgan_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
resrgan_url = [
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
fp32 = True
fp16 = False
model_name = 'RealESRGAN_x4plus'
netscale = 4
outscale = 4
Expand All @@ -20,13 +20,12 @@
tile_pad = 10


def upscale_resrgan(source_image: Image, faces=True) -> Image:
model_path = path.join('weights', model_name + '.pth')
def make_resrgan(model_path):
model_path = path.join(model_path, model_name + '.pth')
if not path.isfile(model_path):
ROOT_DIR = path.dirname(path.abspath(__file__))
for url in resrgan_url:
model_path = load_file_from_url(
url=url, model_dir=path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
url=url, model_dir=path.join(model_path, model_name), progress=True, file_name=None)

model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=4)
Expand All @@ -46,13 +45,19 @@ def upscale_resrgan(source_image: Image, faces=True) -> Image:
tile=tile,
tile_pad=tile_pad,
pre_pad=pre_pad,
half=fp32)
half=fp16)

return upsampler


def upscale_resrgan(source_image: Image, model_path: str, faces=True) -> Image:
image = np.array(source_image)
upsampler = make_resrgan(model_path)

output, _ = upsampler.enhance(image, outscale=outscale)

if faces:
output = upscale_gfpgan(output, upsampler)
output = upscale_gfpgan(output, upsampler)

return Image.fromarray(output, 'RGB')

Expand All @@ -65,6 +70,7 @@ def upscale_gfpgan(image, upsampler) -> Image:
channel_multiplier=2,
bg_upsampler=upsampler)

_, _, output = face_enhancer.enhance(image, has_aligned=False, only_center_face=False, paste_back=True)
_, _, output = face_enhancer.enhance(
image, has_aligned=False, only_center_face=False, paste_back=True)

return output

0 comments on commit 806503c

Please sign in to comment.