Merge pull request #5586 from wywywywy/ldsr-improvements
LDSR improvements - cache / optimization / opt_channelslast
AUTOMATIC1111 authored Dec 10, 2022
2 parents 0a81dd5 + 1581d5a commit 685f963
Showing 2 changed files with 35 additions and 15 deletions.
49 changes: 34 additions & 15 deletions extensions-builtin/LDSR/ldsr_model_arch.py
@@ -11,25 +11,41 @@
 
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.util import instantiate_from_config, ismap
+from modules import shared, sd_hijack
 
 warnings.filterwarnings("ignore", category=UserWarning)
 
+cached_ldsr_model: torch.nn.Module = None
+
 
 # Create LDSR Class
 class LDSR:
     def load_model_from_config(self, half_attention):
-        print(f"Loading model from {self.modelPath}")
-        pl_sd = torch.load(self.modelPath, map_location="cpu")
-        sd = pl_sd["state_dict"]
-        config = OmegaConf.load(self.yamlPath)
-        config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
-        model = instantiate_from_config(config.model)
-        model.load_state_dict(sd, strict=False)
-        model.cuda()
-        if half_attention:
-            model = model.half()
-
-        model.eval()
+        global cached_ldsr_model
+
+        if shared.opts.ldsr_cached and cached_ldsr_model is not None:
+            print(f"Loading model from cache")
+            model: torch.nn.Module = cached_ldsr_model
+        else:
+            print(f"Loading model from {self.modelPath}")
+            pl_sd = torch.load(self.modelPath, map_location="cpu")
+            sd = pl_sd["state_dict"]
+            config = OmegaConf.load(self.yamlPath)
+            config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
+            model: torch.nn.Module = instantiate_from_config(config.model)
+            model.load_state_dict(sd, strict=False)
+            model = model.to(shared.device)
+            if half_attention:
+                model = model.half()
+            if shared.cmd_opts.opt_channelslast:
+                model = model.to(memory_format=torch.channels_last)
+
+            sd_hijack.model_hijack.hijack(model) # apply optimization
+            model.eval()
+
+            if shared.opts.ldsr_cached:
+                cached_ldsr_model = model
+
         return {"model": model}
 
     def __init__(self, model_path, yaml_path):
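This hunk bundles the PR's three changes: a module-level cache so repeated upscales skip reloading the checkpoint, placement via shared.device instead of a hardcoded .cuda() call, and an optional channels_last memory layout behind the webui's opt_channelslast command-line option. A minimal sketch of the same cache-or-load pattern, with placeholder names (get_model, cached_model, the toy Sequential) that are illustrative rather than from the codebase:

import torch

cached_model: torch.nn.Module = None  # module-level cache, survives between calls


def get_model(path, device="cpu", use_cache=True, channels_last=False):
    """Load and prepare a model once; return the cached copy afterwards."""
    global cached_model
    if use_cache and cached_model is not None:
        return cached_model                        # skip the expensive reload
    model = torch.nn.Sequential(                   # placeholder architecture, stands in
        torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()  # for instantiate_from_config(...)
    )
    state = torch.load(path, map_location="cpu")   # load weights on CPU first
    model.load_state_dict(state, strict=False)
    model = model.to(device)                       # CPU, CUDA or MPS alike
    if channels_last:
        model = model.to(memory_format=torch.channels_last)  # NHWC layout for convs
    model.eval()
    if use_cache:
        cached_model = model                       # keep it resident for next time
    return model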
@@ -94,7 +110,8 @@ def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
         down_sample_method = 'Lanczos'
 
         gc.collect()
-        torch.cuda.empty_cache()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
         im_og = image
         width_og, height_og = im_og.size
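This hunk and the matching one further down guard the CUDA cache release so CPU-only setups never touch the CUDA runtime. Note that torch.cuda.is_available is a function and needs to be called. As a sketch, the guarded cleanup could live in one helper (free_memory is a hypothetical name, not part of the PR):

import gc

import torch


def free_memory():
    """Drop Python garbage, then release cached GPU blocks when CUDA exists."""
    gc.collect()
    if torch.cuda.is_available():  # is_available() is a call, not a flag
        torch.cuda.empty_cache()   # only meaningful when CUDA is present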
@@ -131,7 +148,9 @@ def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
 
         del model
         gc.collect()
-        torch.cuda.empty_cache()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
         return a


@@ -146,7 +165,7 @@ def get_cond(selected_path):
     c = rearrange(c, '1 c h w -> 1 h w c')
     c = 2. * c - 1.
 
-    c = c.to(torch.device("cuda"))
+    c = c.to(shared.device)
     example["LR_image"] = c
     example["image"] = c_up
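Here a hardcoded torch.device("cuda") becomes the webui-wide shared.device, so the conditioning tensor follows whatever backend the user runs on. A sketch of the idea, with pick_device standing in for the webui's actual device-selection logic:

import torch


def pick_device():
    """Choose the best available backend, mirroring what shared.device provides."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()
c = torch.rand(1, 64, 64, 3)  # stand-in for the conditioning tensor
c = c.to(device)              # works everywhere; c.to("cuda") would not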
1 change: 1 addition & 0 deletions extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -59,6 +59,7 @@ def on_ui_settings():
     import gradio as gr
 
     shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
+    shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")))
 
 
 script_callbacks.on_ui_settings(on_ui_settings)
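The new ldsr_cached checkbox defaults to off, since caching keeps the whole model resident in RAM or VRAM; the first hunk above reads it back as shared.opts.ldsr_cached. Unpacking the added add_option call, with comments labeling each argument (a reading of the existing line, not new behavior):

shared.opts.add_option(
    "ldsr_cached",                           # key, read back as shared.opts.ldsr_cached
    shared.OptionInfo(
        False,                               # default value: caching is opt-in
        "Cache LDSR model in memory",        # label shown in the settings UI
        gr.Checkbox,                         # gradio component used to render it
        {"interactive": True},               # arguments passed to the component
        section=('upscaling', "Upscaling"),  # settings tab and group
    ),
)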
