Skip to content

Commit

Permalink
feat(api): add CodeFormer to automatic upscale
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 5, 2023
1 parent 9d0609f commit 0a9f108
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
17 changes: 9 additions & 8 deletions api/onnx_web/chain/correct_codeformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,22 @@

logger = getLogger(__name__)

pretrain_model_url = (
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
)

device = "cpu"


def correct_codeformer(
_job: JobContext,
job: JobContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
source: Image.Image,
*,
source_image: Image.Image = None,
**kwargs,
) -> Image.Image:
pipe = CodeFormer().to(device)
device = job.get_device()
# TODO: terrible names, fix
image = source or source_image

return pipe(source_image)
pipe = CodeFormer().to(device.torch_device())
return pipe(image)
10 changes: 9 additions & 1 deletion api/onnx_web/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .chain import (
ChainPipeline,
correct_codeformer,
correct_gfpgan,
upscale_resrgan,
upscale_stable_diffusion,
Expand Down Expand Up @@ -40,9 +41,16 @@ def run_upscale_correction(
mini_tile = min(SizeChart.mini, stage.tile_size)
stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
chain.append((upscale_stable_diffusion, stage, None))
else:
logger.warn("unknown upscaling model: %s", upscale.upscale_model)

if upscale.faces:
stage = StageParams(tile_size=stage.tile_size, outscale=1)
chain.append((correct_gfpgan, stage, None))
if "codeformer" in upscale.correction_model:
chain.append((correct_codeformer, stage, None))
elif "gfpgan" in upscale.correction_model:
chain.append((correct_gfpgan, stage, None))
else:
logger.warn("unknown correction model: %s", upscale.correction_model)

return chain(job, server, params, image, prompt=params.prompt, upscale=upscale)

0 comments on commit 0a9f108

Please sign in to comment.