Skip to content

Commit

Permalink
fix(api): update codeformer patches for new lib
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Dec 19, 2023
1 parent 7ed30ee commit 3e1db70
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
13 changes: 9 additions & 4 deletions api/onnx_web/chain/correct_codeformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

logger = getLogger(__name__)

CORRECTION_MODEL = "correction-codeformer.pth"
DETECTION_MODEL = "retinaface_resnet50"


class CorrectCodeformerStage(BaseStage):
def run(
Expand All @@ -28,7 +31,10 @@ def run(
upscale: UpscaleParams,
**kwargs,
) -> StageResult:
# must be within the load function for patch to take effect
# adapted from https://github.com/kadirnar/codeformer-pip/blob/main/codeformer/app.py and
# https://pypi.org/project/codeformer-perceptor/

# import must be within the load function for patches to take effect
# TODO: rewrite and remove
from codeformer.basicsr.utils import img2tensor, tensor2img
from codeformer.basicsr.utils.registry import ARCH_REGISTRY
Expand All @@ -45,17 +51,16 @@ def run(
connect_list=["32", "64", "128", "256"],
).to(device.torch_str())

ckpt_path = path.join(server.cache_path, "correction-codeformer.pth")
ckpt_path = path.join(server.cache_path, CORRECTION_MODEL)
checkpoint = torch.load(ckpt_path)["params_ema"]
net.load_state_dict(checkpoint)
net.eval()

det_model = "retinaface_resnet50"
face_helper = FaceRestoreHelper(
upscale.face_outscale,
face_size=512,
crop_ratio=(1, 1),
det_model=det_model,
det_model=DETECTION_MODEL,
save_ext="png",
use_parse=True,
device=device.torch_str(),
Expand Down
7 changes: 3 additions & 4 deletions api/onnx_web/server/hacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,10 @@ def apply_patch_basicsr(server: ServerContext):
def apply_patch_codeformer(server: ServerContext):
logger.debug("patching CodeFormer module")
try:
import codeformer.basicsr.utils # download_util
import codeformer.facelib.utils.misc
import codeformer.basicsr.utils.download_util

codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
codeformer.facelib.utils.misc.load_file_from_url = partial(
codeformer.basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
codeformer.basicsr.utils.download_util.load_file_from_url = partial(
patch_cache_path, server
)
except ImportError:
Expand Down

0 comments on commit 3e1db70

Please sign in to comment.