Skip to content

Commit

Permalink
feat: add API parameter for upscale checkbox
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Dec 31, 2023
1 parent a5ecb59 commit 71fbc87
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 25 deletions.
50 changes: 26 additions & 24 deletions api/onnx_web/chain/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,31 +67,33 @@ def stage_upscale_correction(
**kwargs,
"upscale": upscale,
}

upscale_stage: Optional[PipelineStage] = None
if "bsrgan" in upscale.upscale_model:
bsrgan_params = StageParams(
tile_size=stage.tile_size,
outscale=upscale.outscale,
)
upscale_stage = (UpscaleBSRGANStage(), bsrgan_params, upscale_opts)
elif "esrgan" in upscale.upscale_model:
esrgan_params = StageParams(
tile_size=stage.tile_size,
outscale=upscale.outscale,
)
upscale_stage = (UpscaleRealESRGANStage(), esrgan_params, upscale_opts)
elif "stable-diffusion" in upscale.upscale_model:
mini_tile = min(SizeChart.mini, stage.tile_size)
sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
upscale_stage = (UpscaleStableDiffusionStage(), sd_params, upscale_opts)
elif "swinir" in upscale.upscale_model:
swinir_params = StageParams(
tile_size=stage.tile_size,
outscale=upscale.outscale,
)
upscale_stage = (UpscaleSwinIRStage(), swinir_params, upscale_opts)
else:
logger.warning("unknown upscaling model: %s", upscale.upscale_model)
if upscale.upscale:
if "bsrgan" in upscale.upscale_model:
bsrgan_params = StageParams(
tile_size=stage.tile_size,
outscale=upscale.outscale,
)
upscale_stage = (UpscaleBSRGANStage(), bsrgan_params, upscale_opts)
elif "esrgan" in upscale.upscale_model:
esrgan_params = StageParams(
tile_size=stage.tile_size,
outscale=upscale.outscale,
)
upscale_stage = (UpscaleRealESRGANStage(), esrgan_params, upscale_opts)
elif "stable-diffusion" in upscale.upscale_model:
mini_tile = min(SizeChart.mini, stage.tile_size)
sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
upscale_stage = (UpscaleStableDiffusionStage(), sd_params, upscale_opts)
elif "swinir" in upscale.upscale_model:
swinir_params = StageParams(
tile_size=stage.tile_size,
outscale=upscale.outscale,
)
upscale_stage = (UpscaleSwinIRStage(), swinir_params, upscale_opts)
else:
logger.warning("unknown upscaling model: %s", upscale.upscale_model)

correct_stage: Optional[PipelineStage] = None
if upscale.faces:
Expand Down
5 changes: 5 additions & 0 deletions api/onnx_web/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def __init__(
upscale_model: str,
correction_model: Optional[str] = None,
denoise: float = 0.5,
upscale=True,
faces=True,
face_outscale: int = 1,
face_strength: float = 0.5,
Expand All @@ -406,6 +407,7 @@ def __init__(
self.upscale_model = upscale_model
self.correction_model = correction_model
self.denoise = denoise
self.upscale = upscale
self.faces = faces
self.face_outscale = face_outscale
self.face_strength = face_strength
Expand All @@ -421,6 +423,7 @@ def rescale(self, scale: int):
self.upscale_model,
correction_model=self.correction_model,
denoise=self.denoise,
upscale=self.upscale,
faces=self.faces,
face_outscale=self.face_outscale,
face_strength=self.face_strength,
Expand All @@ -447,6 +450,7 @@ def tojson(self):
"upscale_model": self.upscale_model,
"correction_model": self.correction_model,
"denoise": self.denoise,
"upscale": self.upscale,
"faces": self.faces,
"face_outscale": self.face_outscale,
"face_strength": self.face_strength,
Expand All @@ -463,6 +467,7 @@ def with_args(self, **kwargs):
kwargs.get("upscale_model", self.upscale_model),
kwargs.get("correction_model", self.correction_model),
kwargs.get("denoise", self.denoise),
kwargs.get("upscale", self.upscale),
kwargs.get("faces", self.faces),
kwargs.get("face_outscale", self.face_outscale),
kwargs.get("face_strength", self.face_strength),
Expand Down
5 changes: 4 additions & 1 deletion api/onnx_web/server/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def build_upscale(
if data is None:
data = request.args

upscale = get_boolean(data, "upscale", False)
denoise = get_and_clamp_float(
data,
"denoise",
Expand All @@ -262,7 +263,8 @@ def build_upscale(
)
upscaling = get_from_list(data, "upscaling", get_upscaling_models())
correction = get_from_list(data, "correction", get_correction_models())
faces = get_not_empty(data, "faces", "false") == "true"

faces = get_boolean(data, "faces", False)
face_outscale = get_and_clamp_int(
data,
"faceOutscale",
Expand All @@ -283,6 +285,7 @@ def build_upscale(
upscaling,
correction_model=correction,
denoise=denoise,
upscale=upscale,
faces=faces,
face_outscale=face_outscale,
face_strength=face_strength,
Expand Down
1 change: 1 addition & 0 deletions gui/src/client/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ export function appendModelToURL(url: URL, params: ModelParams) {
* Append the upscale parameters to an existing URL.
*/
export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) {
url.searchParams.append('upscale', String(upscale.enabled));
url.searchParams.append('upscaleOrder', upscale.upscaleOrder);

if (upscale.enabled) {
Expand Down

0 comments on commit 71fbc87

Please sign in to comment.