Skip to content

Commit

Permalink
Cleanup and optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
AI-Casanova committed Nov 18, 2023
1 parent d2d54af commit 026790f
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions modules/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,8 @@ def fail(message):
assert len(alpha) == 26 or len(
alpha) == 20, "Alpha Block Weights are wrong length (26 or 20 for SDXL) falling back"
kwargs["alpha"] = alpha
except:
except AssertionError as e:
shared.log.info(e)
kwargs["alpha"] = kwargs.get("alpha_preset", kwargs["alpha"])
finally:
kwargs.pop("alpha_base", None)
Expand All @@ -303,7 +304,8 @@ def fail(message):
assert len(beta) == 26 or len(
beta) == 20, "Beta Block Weights are wrong length (26 or 20 for SDXL) falling back"
kwargs["beta"] = beta
except:
except AssertionError as e:
shared.log.info(e)
kwargs["beta"] = kwargs.get("beta_preset", kwargs["beta"])
finally:
kwargs.pop("beta_base", None)
Expand All @@ -313,15 +315,14 @@ def fail(message):
kwargs.pop("beta_preset", None)

if kwargs["device"] == "cuda":
kwargs["device"] == devices.device
kwargs["device"] = devices.device
sd_models.unload_model_weights()
elif kwargs["device"] == "shuffle":
kwargs["device"] = torch.device("cpu")
kwargs["work_device"] = devices.device
else:
kwargs["device"] = torch.device("cpu")


try:
theta_0 = merge_models(**kwargs)
except Exception as e:
Expand Down Expand Up @@ -390,6 +391,7 @@ def add_model_metadata(checkpoint_info):
created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
if created_model:
created_model.calculate_shorthash()
devices.torch_gc(force=True)
shared.log.info(f"Model merge saved: {output_modelname}.")
shared.state.textinfo = "Checkpoint saved"
shared.state.end()
Expand Down

0 comments on commit 026790f

Please sign in to comment.