From 4bf68759d72b089082ddd490b110f91f0c4e5de9 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 16 Jan 2023 15:11:40 -0600 Subject: [PATCH] fix(api): report accurate sizes --- api/onnx_web/serve.py | 6 +++--- api/onnx_web/upscale.py | 11 ++++++++--- api/onnx_web/utils.py | 3 +++ 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 68b91f21a..c8162771a 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -324,7 +324,7 @@ def img2img(): return jsonify({ 'output': output, 'params': params.tojson(), - 'size': size.tojson(), + 'size': upscale.resize(size).tojson(), }) @@ -345,7 +345,7 @@ def txt2img(): return jsonify({ 'output': output, 'params': params.tojson(), - 'size': size.tojson(), + 'size': upscale.resize(size).tojson(), }) @@ -399,7 +399,7 @@ def inpaint(): return jsonify({ 'output': output, 'params': params.tojson(), - 'size': size.tojson(), + 'size': upscale.resize(size.with_border(expand)).tojson(), }) diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py index 02c212a8c..06abd16d5 100644 --- a/api/onnx_web/upscale.py +++ b/api/onnx_web/upscale.py @@ -10,7 +10,8 @@ import torch from .utils import ( - ServerContext + ServerContext, + Size, ) # TODO: these should all be params or config @@ -49,7 +50,7 @@ def size(self): class ONNXNet(): ''' - Provides the RRDBNet interface but using ONNX. + Provides the RRDBNet interface using an ONNX session for DirectML acceleration. ''' def __init__(self, ctx: ServerContext, model: str, provider='DmlExecutionProvider') -> None: @@ -102,6 +103,9 @@ def __init__( self.platform = platform self.half = half + def resize(self, size: Size) -> Size: + return Size(size.width * self.scale * self.outscale, size.height * self.scale * self.outscale) + def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0): model_file = '%s.%s' % (params.upscale_model, params.platform) @@ -125,9 +129,10 @@ def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0): model_path = [model_path, wdn_model_path] dni_weight = [params.denoise, 1 - params.denoise] + # TODO: shouldn't need the PTH file upsampler = RealESRGANer( scale=params.scale, - model_path=model_path, + model_path=path.join(ctx.model_path, '%s.pth' % params.upscale_model), dni_weight=dni_weight, model=model, tile=tile, diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index 86c428be3..05be707f4 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -76,6 +76,9 @@ def tojson(self) -> Dict[str, int]: 'width': self.width, } + def with_border(self, border: Border): + return Size(border.left + self.width + border.right, border.top + self.height + border.right) + def get_and_clamp_float(args: Any, key: str, default_value: float, max_value: float, min_value=0.0) -> float: return min(max(float(args.get(key, default_value)), min_value), max_value)