Fix item 1,2,3,5,6
AI-Casanova committed Nov 19, 2023
1 parent c5371a5 commit 581ab58
Showing 3 changed files with 38 additions and 18 deletions.
26 changes: 14 additions & 12 deletions modules/extras.py
@@ -61,24 +61,24 @@ def fail(message):
         return [*[gr.update() for _ in range(4)], message]

     kwargs["models"] = {
-        "model_a": sd_models.checkpoints_list[kwargs.get("primary_model_name", None)].filename,
-        "model_b": sd_models.checkpoints_list[kwargs.get("secondary_model_name", None)].filename,
+        "model_a": sd_models.get_closet_checkpoint_match(kwargs.get("primary_model_name", None)).filename,
+        "model_b": sd_models.get_closet_checkpoint_match(kwargs.get("secondary_model_name", None)).filename,
     }

     if kwargs.get("primary_model_name", None) in [None, 'None']:
         return fail("Failed: Merging requires a primary model.")
-    primary_model_info = sd_models.checkpoints_list[kwargs.get("primary_model_name", None)]
+    primary_model_info = sd_models.get_closet_checkpoint_match(kwargs.get("primary_model_name", None))
     if kwargs.get("secondary_model_name", None) in [None, 'None']:
         return fail("Failed: Merging requires a secondary model.")
-    secondary_model_info = sd_models.checkpoints_list[kwargs.get("secondary_model_name", None)]
+    secondary_model_info = sd_models.get_closet_checkpoint_match(kwargs.get("secondary_model_name", None))
     if kwargs.get("tertiary_model_name", None) in [None, 'None'] and kwargs.get("merge_mode", None) in TRIPLE_METHODS:
         return fail(f"Failed: Interpolation method ({kwargs.get('merge_mode', None)}) requires a tertiary model.")
-    tertiary_model_info = sd_models.checkpoints_list[kwargs.get("tertiary_model_name", None)] if kwargs.get("merge_mode", None) in TRIPLE_METHODS else None
+    tertiary_model_info = sd_models.get_closet_checkpoint_match(kwargs.get("tertiary_model_name", None)) if kwargs.get("merge_mode", None) in TRIPLE_METHODS else None

     del kwargs["primary_model_name"]
     del kwargs["secondary_model_name"]
     if kwargs.get("tertiary_model_name", None) is not None:
-        kwargs["models"] |= {"model_c": sd_models.checkpoints_list[kwargs.get("tertiary_model_name", None)].filename}
+        kwargs["models"] |= {"model_c": sd_models.get_closet_checkpoint_match(kwargs.get("tertiary_model_name", None)).filename}
     del kwargs["tertiary_model_name"]

     try:
@@ -113,20 +113,26 @@ def fail(message):
     kwargs.pop("beta_out_blocks", None)
     kwargs.pop("beta_preset", None)

-    if kwargs["device"] == "cuda":
+    if kwargs["device"] == "gpu":
         kwargs["device"] = devices.device
-        sd_models.unload_model_weights()
     elif kwargs["device"] == "shuffle":
         kwargs["device"] = torch.device("cpu")
         kwargs["work_device"] = devices.device
     else:
         kwargs["device"] = torch.device("cpu")
+    if kwargs.pop("unload", False):
+        sd_models.unload_model_weights()

     try:
         theta_0 = merge_models(**kwargs)
     except Exception as e:
         return fail(f"{e}")

+    try:
+        theta_0 = theta_0.to_dict()  # TensorDict -> Dict if necessary
+    except:
+        pass
+
     bake_in_vae_filename = sd_vae.vae_dict.get(kwargs.get("bake_in_vae", None), None)
     if bake_in_vae_filename is not None:
         shared.log.info(f"Merge: baking in VAE: {bake_in_vae_filename}")
@@ -177,10 +183,6 @@ def add_model_metadata(checkpoint_info):

     _, extension = os.path.splitext(output_modelname)

-    try:
-        theta_0 = theta_0.to_dict()
-    except:
-        pass

     if extension.lower() == ".safetensors":
         safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
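Note: the device dispatch above now distinguishes "cpu", "shuffle" (tensors stored on CPU, arithmetic done on the GPU via work_device), and "gpu", with the VRAM unload made an explicit opt-in instead of a side effect of picking CUDA. A minimal standalone sketch of the same dispatch; GPU, unload_model_weights, and resolve_merge_devices are hypothetical stand-ins for the repo's devices and sd_models helpers:

import torch

# Hypothetical stand-ins for modules.devices / modules.sd_models.
GPU = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def unload_model_weights():
    pass  # placeholder: the real helper frees the loaded checkpoint from VRAM

def resolve_merge_devices(kwargs: dict) -> dict:
    # Mirrors the dispatch above: choose storage and work devices for the merge.
    if kwargs["device"] == "gpu":
        kwargs["device"] = GPU  # merge entirely on the GPU
    elif kwargs["device"] == "shuffle":
        kwargs["device"] = torch.device("cpu")  # keep tensors on CPU...
        kwargs["work_device"] = GPU  # ...but run the arithmetic on the GPU
    else:
        kwargs["device"] = torch.device("cpu")  # pure CPU merge
    if kwargs.pop("unload", False):  # opt-in: free VRAM before merging
        unload_model_weights()
    return kwargs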
14 changes: 11 additions & 3 deletions modules/merging/merge.py
@@ -54,7 +54,7 @@ def prune_sd_model(model: Dict) -> Dict:
     for k in keys:
         if (
             not k.startswith("model.diffusion_model.")
-            and not k.startswith("first_stage_model.")
+            # and not k.startswith("first_stage_model.")
             and not k.startswith("cond_stage_model.")
         ):
             del model[k]
@@ -153,23 +153,31 @@ def un_prune_model(
     devices.torch_gc(force=True)
     log_vram("remove thetas")
     original_a = TensorDict.from_dict(read_state_dict(models["model_a"], device))
-    for key in tqdm(original_a.keys(), desc="un-prune model a"):
+    unpruned = 0
+    for key in original_a.keys():
         if KEY_POSITION_IDS in key:
             continue
         if "model" in key and key not in merged.keys():
             merged.update({key: original_a[key]})
+            unpruned += 1
             if precision == "fp16":
                 merged.update({key: merged[key].half()})
+    if unpruned != 0:
+        log.info(f"Merge: {unpruned} unmerged keys restored from Primary Model")
+    unpruned = 0
     del original_a
     devices.torch_gc(force=True)
     original_b = TensorDict.from_dict(read_state_dict(models["model_b"], device))
-    for key in tqdm(original_b.keys(), desc="un-prune model b"):
+    for key in original_b.keys():
         if KEY_POSITION_IDS in key:
             continue
         if "model" in key and key not in merged.keys():
             merged.update({key: original_b[key]})
+            unpruned += 1
             if precision == "fp16":
                 merged.update({key: merged[key].half()})
+    if unpruned != 0:
+        log.info(f"Merge: {unpruned} unmerged keys restored from Secondary Model")
     del original_b

     return fix_clip(merged)
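Note: the un-prune pass above now counts restored keys and emits one log line per source model instead of running a per-key tqdm bar. A rough sketch of the same restore-and-count pattern on plain dicts; restore_missing_keys, half_precision, and the "position_ids" check are illustrative stand-ins, not the repo's API:

from typing import Dict

import torch

def restore_missing_keys(merged: Dict[str, torch.Tensor],
                         original: Dict[str, torch.Tensor],
                         half_precision: bool,
                         source_name: str) -> int:
    # Copy any model key present in the source checkpoint but missing from the merge.
    restored = 0
    for key, tensor in original.items():
        if "position_ids" in key:  # assumed equivalent of KEY_POSITION_IDS
            continue
        if "model" in key and key not in merged:
            merged[key] = tensor.half() if half_precision else tensor
            restored += 1
    if restored != 0:
        print(f"Merge: {restored} unmerged keys restored from {source_name}")
    return restored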
16 changes: 13 additions & 3 deletions modules/ui_models.py
@@ -162,10 +162,11 @@ def sd_model_choices():
         with FormRow():
             precision = gr.Radio(choices=["fp16", "fp32"], value="fp16", label="Model precision")
         with FormRow():
-            device = gr.Radio(choices=["cpu", "shuffle", "cuda"], value="cpu", label="Merge Device")
+            device = gr.Radio(choices=["cpu", "shuffle", "gpu"], value="cpu", label="Merge Device")
+            unload = gr.Checkbox(label="Unload Current Model from VRAM", value=False, visible=False)
         with FormRow():
             bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None",
-                                      interactive=True, label="Bake in VAE")
+                                      interactive=True, label="Replace VAE")
             create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list,
                                   lambda: {"choices": ["None"] + list(sd_vae.vae_dict)},
                                   "modelmerger_refresh_bake_in_vae")
@@ -202,6 +203,7 @@ def modelmerger(dummy_component,
                     re_basin,
                     re_basin_iterations,
                     device,
+                    unload,
                     bake_in_vae):
         kwargs = {}
         for x in inspect.getfullargspec(modelmerger)[0]:
@@ -240,6 +242,13 @@ def show_help(mode):
         doc = getattr(merge_methods, mode).__doc__.replace("\n", "<br>")
         return gr.update(value=doc, visible=True)

+    def show_unload(device):
+        if device == "gpu":
+            return gr.update(visible=True)
+        else:
+            return gr.update(visible=False)
+
+
     def preset_visiblility(x):
         if len(x) == 2:
             return gr.Slider.update(value=0.5, visible=True)
@@ -262,7 +271,7 @@ def preset_choices(sdxl):
             return [gr.update(choices=["None"] + list(SDXL_BLOCK_WEIGHTS_PRESETS.keys())) for _ in range(2)]
         else:
             return [gr.update(choices=["None"] + list(BLOCK_WEIGHTS_PRESETS.keys())) for _ in range(2)]
-
+    device.change(fn=show_unload, inputs=device, outputs=unload)
     merge_mode.change(fn=show_help, inputs=merge_mode, outputs=merge_mode_docs)
     sdxl.change(fn=preset_choices, inputs=sdxl, outputs=[alpha_preset, beta_preset])
     alpha_preset.change(fn=preset_visiblility, inputs=alpha_preset, outputs=alpha_preset_lambda)
@@ -309,6 +318,7 @@ def preset_choices(sdxl):
                 re_basin,
                 re_basin_iterations,
                 device,
+                unload,
                 bake_in_vae,
             ],
             outputs=[
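Note: the unload checkbox is revealed only while "gpu" is selected, wired up through device.change. A minimal self-contained Gradio 3-style sketch of that visibility toggle; the components below are illustrative, not the repo's full merge UI:

import gradio as gr

def show_unload(device: str):
    # Reveal the unload checkbox only when merging on the GPU.
    return gr.update(visible=(device == "gpu"))

with gr.Blocks() as demo:
    device = gr.Radio(choices=["cpu", "shuffle", "gpu"], value="cpu", label="Merge Device")
    unload = gr.Checkbox(label="Unload Current Model from VRAM", value=False, visible=False)
    device.change(fn=show_unload, inputs=device, outputs=unload)

Calling demo.launch() then shows the checkbox appear and disappear as the radio selection changes.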
