Skip to content

Commit

Permalink
feat(api): use ONNX for Real ESRGAN v3 model
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 18, 2023
1 parent 3dde3b9 commit 2c9d96d
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 22 deletions.
37 changes: 19 additions & 18 deletions api/onnx_web/chain/upscale_resrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,6 @@ def load_resrgan(
if not path.isfile(model_path):
raise Exception("Real ESRGAN model not found at %s" % model_path)

if x4_v3_tag in model_file:
# the x4-v3 model needs a different network
model = SRVGGNetCompact(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_conv=32,
upscale=4,
act_type="prelu",
)
elif params.format == "onnx":
# use ONNX acceleration, if available
model = OnnxNet(
Expand All @@ -56,14 +46,25 @@ def load_resrgan(
sess_options=device.sess_options(),
)
elif params.format == "pth":
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=params.scale,
)
if x4_v3_tag in model_file:
# the x4-v3 model needs a different network
model = SRVGGNetCompact(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_conv=32,
upscale=4,
act_type="prelu",
)
else:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=params.scale,
)
else:
raise Exception("unknown platform %s" % params.format)

Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .onnx_net import OnnxImage, OnnxNet
from .onnx_net import OnnxTensor, OnnxNet
8 changes: 5 additions & 3 deletions api/onnx_web/onnx/onnx_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..utils import ServerContext


class OnnxImage:
class OnnxTensor:
def __init__(self, source) -> None:
self.source = source
self.data = self
Expand Down Expand Up @@ -57,8 +57,10 @@ def __init__(
def __call__(self, image: Any) -> Any:
input_name = self.session.get_inputs()[0].name
output_name = self.session.get_outputs()[0].name
output = self.session.run([output_name], {input_name: image.cpu().numpy()})[0]
return OnnxImage(output)
output = self.session.run([output_name], {
input_name: image.cpu().numpy()
})[0]
return OnnxTensor(output)

def eval(self) -> None:
pass
Expand Down

0 comments on commit 2c9d96d

Please sign in to comment.