Skip to content

Commit

Permalink
fix(api): report accurate sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 16, 2023
1 parent d406cd4 commit 4bf6875
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
6 changes: 3 additions & 3 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def img2img():
return jsonify({
'output': output,
'params': params.tojson(),
'size': size.tojson(),
'size': upscale.resize(size).tojson(),
})


Expand All @@ -345,7 +345,7 @@ def txt2img():
return jsonify({
'output': output,
'params': params.tojson(),
'size': size.tojson(),
'size': upscale.resize(size).tojson(),
})


Expand Down Expand Up @@ -399,7 +399,7 @@ def inpaint():
return jsonify({
'output': output,
'params': params.tojson(),
'size': size.tojson(),
'size': upscale.resize(size.with_border(expand)).tojson(),
})


Expand Down
11 changes: 8 additions & 3 deletions api/onnx_web/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import torch

from .utils import (
ServerContext
ServerContext,
Size,
)

# TODO: these should all be params or config
Expand Down Expand Up @@ -49,7 +50,7 @@ def size(self):

class ONNXNet():
'''
Provides the RRDBNet interface but using ONNX.
Provides the RRDBNet interface using an ONNX session for DirectML acceleration.
'''

def __init__(self, ctx: ServerContext, model: str, provider='DmlExecutionProvider') -> None:
Expand Down Expand Up @@ -102,6 +103,9 @@ def __init__(
self.platform = platform
self.half = half

def resize(self, size: Size) -> Size:
return Size(size.width * self.scale * self.outscale, size.height * self.scale * self.outscale)


def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
model_file = '%s.%s' % (params.upscale_model, params.platform)
Expand All @@ -125,9 +129,10 @@ def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
model_path = [model_path, wdn_model_path]
dni_weight = [params.denoise, 1 - params.denoise]

# TODO: shouldn't need the PTH file
upsampler = RealESRGANer(
scale=params.scale,
model_path=model_path,
model_path=path.join(ctx.model_path, '%s.pth' % params.upscale_model),
dni_weight=dni_weight,
model=model,
tile=tile,
Expand Down
3 changes: 3 additions & 0 deletions api/onnx_web/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def tojson(self) -> Dict[str, int]:
'width': self.width,
}

def with_border(self, border: Border):
return Size(border.left + self.width + border.right, border.top + self.height + border.right)


def get_and_clamp_float(args: Any, key: str, default_value: float, max_value: float, min_value=0.0) -> float:
return min(max(float(args.get(key, default_value)), min_value), max_value)
Expand Down

0 comments on commit 4bf6875

Please sign in to comment.