fix bug when using --gfpgan-models-path #13718

Merged
25 changes: 20 additions & 5 deletions modules/gfpgan_model.py
@@ -9,6 +9,7 @@
 model_dir = "GFPGAN"
 user_path = None
 model_path = os.path.join(paths.models_path, model_dir)
+model_file_path = None
 model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
 have_gfpgan = False
 loaded_gfpgan_model = None
@@ -17,24 +18,32 @@
 def gfpgann():
     global loaded_gfpgan_model
     global model_path
+    global model_file_path
     if loaded_gfpgan_model is not None:
         loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
         return loaded_gfpgan_model

     if gfpgan_constructor is None:
         return None

-    models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
+    models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth'])

     if len(models) == 1 and models[0].startswith("http"):
         model_file = models[0]
     elif len(models) != 0:
-        latest_file = max(models, key=os.path.getctime)
+        gfp_models = []
+        for item in models:
+            if 'GFPGAN' in os.path.basename(item):
+                gfp_models.append(item)
+        latest_file = max(gfp_models, key=os.path.getctime)
         model_file = latest_file
     else:
         print("Unable to load gfpgan model!")
         return None

     if hasattr(facexlib.detection.retinaface, 'device'):
         facexlib.detection.retinaface.device = devices.device_gfpgan
+    model_file_path = model_file
     model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
     loaded_gfpgan_model = model
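For reference, the selection logic added to gfpgann() can be read as the standalone sketch below (the name pick_gfpgan_model and the candidates argument are illustrative, not part of the PR). Because load_models() is now called with ext_filter=['.pth'], it returns every .pth file under the models path, so the new loop narrows the list to files whose basename contains 'GFPGAN' before taking the most recently created one.

import os

def pick_gfpgan_model(candidates):
    # Sketch of the selection added above: keep only files whose basename
    # mentions GFPGAN, then prefer the newest by creation time.
    gfp_models = [item for item in candidates if 'GFPGAN' in os.path.basename(item)]
    if not gfp_models:
        return None
    return max(gfp_models, key=os.path.getctime)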

@@ -77,19 +86,25 @@ def setup_model(dirname):
     global user_path
     global have_gfpgan
     global gfpgan_constructor
+    global model_file_path

+    facexlib_path = model_path

+    if dirname is not None:
+        facexlib_path = dirname

     load_file_from_url_orig = gfpgan.utils.load_file_from_url
     facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
     facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url

     def my_load_file_from_url(**kwargs):
-        return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
+        return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path))

     def facex_load_file_from_url(**kwargs):
-        return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
+        return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None))

     def facex_load_file_from_url2(**kwargs):
-        return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
+        return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None))

     gfpgan.utils.load_file_from_url = my_load_file_from_url
     facexlib.detection.load_file_from_url = facex_load_file_from_url
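Taken together, the setup_model() changes split where downloads land: gfpgan.utils.load_file_from_url is redirected with model_dir=model_file_path (the path recorded in gfpgann() for the selected model), while facexlib's detection and parsing downloads are redirected with save_dir=facexlib_path, which is dirname when --gfpgan-models-path is supplied and the default model_path otherwise. A minimal sketch of that fallback, with a hypothetical helper name:

def resolve_facexlib_path(dirname, model_path):
    # Hypothetical helper (not in the PR) mirroring the facexlib_path logic above:
    # with --gfpgan-models-path, facexlib weights are saved under that directory;
    # without it, they fall back to the default GFPGAN model_path.
    return dirname if dirname is not None else model_path

Patching the library download helpers this way keeps the rest of the module untouched: every downstream call into gfpgan or facexlib picks up the overridden save locations automatically.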