-
Notifications
You must be signed in to change notification settings - Fork 26
/
upscale_resrgan.py
120 lines (100 loc) · 3.73 KB
/
upscale_resrgan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from logging import getLogger
from os import path
from typing import Optional
from PIL import Image
from ..onnx import OnnxRRDBNet
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
from .base import BaseStage
from .result import StageResult
# Model tag of the Real-ESRGAN x4 v3 model. When denoise != 1, the stage's
# loader derives a sibling "<tag>-wdn" model path from this tag.
TAG_X4_V3: str = "real-esrgan-x4-v3"

# Module-scoped logger, named after this module.
logger = getLogger(__name__)
class UpscaleRealESRGANStage(BaseStage):
    """Pipeline stage that upscales images with a Real-ESRGAN model via ONNX."""

    def load(
        self, server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
    ):
        """Return a Real-ESRGAN upsampler, reusing a cached one when possible.

        Args:
            server: provides the model directory, the model cache and the
                enabled optimizations.
            params: upscale settings (model name, format, scale, padding,
                denoise).
            device: supplies the ONNX provider/session options and the torch
                device string.
            tile: tile size handed to the upsampler (defaults to 0).

        Returns:
            A ``RealESRGANer`` subclass instance wrapping an ``OnnxRRDBNet``.

        Raises:
            FileNotFoundError: if the model file is missing and no cached
                upsampler exists.
        """
        # must be within load function for patches to take effect
        # TODO: rewrite and remove
        from realesrgan import RealESRGANer

        class RealESRGANWrapper(RealESRGANer):
            # Deliberately does not call RealESRGANer.__init__; it only sets
            # the attributes assigned below (presumably so the parent does not
            # try to load a torch checkpoint itself -- confirm).
            # NOTE(review): model_path, dni_weight and gpu_id are accepted for
            # signature compatibility but never read.
            def __init__(
                self,
                scale,
                model_path,
                dni_weight=None,
                model=None,
                tile=0,
                tile_pad=10,
                pre_pad=10,
                half=False,
                device=None,
                gpu_id=None,
            ):
                self.scale = scale
                self.tile_size = tile
                self.tile_pad = tile_pad
                self.pre_pad = pre_pad
                self.mod_scale = None
                self.half = half
                self.model = model
                self.device = device

        model_file = "%s.%s" % (params.upscale_model, params.format)
        model_path = path.join(server.model_path, model_file)

        # Every per-call setting baked into the upsampler must be part of the
        # key; otherwise a later call with a different tile size, scale or
        # padding would silently reuse an upsampler built for the first caller.
        cache_key = (
            model_path,
            params.format,
            tile,
            params.scale,
            params.tile_pad,
            params.pre_pad,
        )
        cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)
        if cache_pipe is not None:
            logger.info("reusing existing Real ESRGAN pipeline")
            return cache_pipe

        if not path.isfile(model_path):
            raise FileNotFoundError("Real ESRGAN model not found at %s" % model_path)

        # TODO: swap for regular RRDBNet after rewriting wrapper
        model = OnnxRRDBNet(
            server,
            model_file,
            provider=device.ort_provider(),
            sess_options=device.sess_options(),
        )

        # NOTE(review): the denoise blend below only affects the log message.
        # The wrapper is constructed with model_path=None and ignores
        # dni_weight, so the "-wdn" model is never loaded or mixed in.
        dni_weight = None
        if params.upscale_model == TAG_X4_V3 and params.denoise != 1:
            wdn_model_path = model_path.replace(TAG_X4_V3, "%s-wdn" % TAG_X4_V3)
            model_path = [model_path, wdn_model_path]
            dni_weight = [params.denoise, 1 - params.denoise]

        logger.debug("loading Real ESRGAN upscale model from %s", model_path)

        upsampler = RealESRGANWrapper(
            scale=params.scale,
            model_path=None,
            dni_weight=dni_weight,
            model=model,
            tile=tile,
            tile_pad=params.tile_pad,
            pre_pad=params.pre_pad,
            half=("torch-fp16" in server.optimizations),
            device=device.torch_str(),
        )

        server.cache.set(ModelTypes.upscaling, cache_key, upsampler)
        run_gc([device])

        return upsampler

    def run(
        self,
        worker: WorkerContext,
        server: ServerContext,
        stage: StageParams,
        _params: ImageParams,
        sources: StageResult,
        *,
        upscale: UpscaleParams,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> StageResult:
        """Upscale every source image and return the results as arrays.

        The upsampler is built (or reused) with the stage's tile size and the
        worker's device. ``_params`` and ``stage_source`` are accepted for
        interface compatibility but not used here.
        """
        logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
        upsampler = self.load(
            server, upscale, worker.get_device(), tile=stage.tile_size
        )

        outputs = []
        for source in sources.as_numpy():
            # enhance() returns (image, mode); only the image is kept.
            output, _ = upsampler.enhance(source, outscale=upscale.outscale)
            logger.info("final output image size: %s", output.shape)
            outputs.append(output)

        return StageResult(arrays=outputs)