Skip to content

Commit

Permalink
in-mid-out LoRA Weights
Browse files Browse the repository at this point in the history
  • Loading branch information
AI-Casanova committed Dec 3, 2023
1 parent 099188a commit 323e2c1
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
7 changes: 5 additions & 2 deletions extensions-builtin/Lora/extra_networks_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ def activate(self, p, params_list):
names.append(params.positional[0])
te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0
te_multiplier = float(params.named.get("te", te_multiplier))
unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else te_multiplier
unet_multiplier = float(params.named.get("unet", unet_multiplier))
unet_multiplier = [float(params.positional[2]) if len(params.positional) > 2 else te_multiplier] * 3
unet_multiplier = [float(params.named.get("unet", unet_multiplier))] * 3
unet_multiplier[0] = float(params.named.get("in", unet_multiplier[0]))
unet_multiplier[1] = float(params.named.get("mid", unet_multiplier[1]))
unet_multiplier[2] = float(params.named.get("out", unet_multiplier[2]))
dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None
dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim
te_multipliers.append(te_multiplier)
Expand Down
10 changes: 8 additions & 2 deletions extensions-builtin/Lora/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(self, name, network_on_disk: NetworkOnDisk):
self.name = name
self.network_on_disk = network_on_disk
self.te_multiplier = 1.0
self.unet_multiplier = 1.0
self.unet_multiplier = [1.0] * 3
self.dyn_dim = None
self.modules = {}
self.mtime = None
Expand Down Expand Up @@ -112,8 +112,14 @@ def __init__(self, net: Network, weights: NetworkWeights):
def multiplier(self):
if 'transformer' in self.sd_key[:20]:
return self.network.te_multiplier
if "input_blocks" in self.sd_key:
return self.network.unet_multiplier[0]
if "middle_block" in self.sd_key:
return self.network.unet_multiplier[1]
if "output_blocks" in self.sd_key:
return self.network.unet_multiplier[2]
else:
return self.network.unet_multiplier
return self.network.unet_multiplier[0]

def calc_scale(self):
if self.scale is not None:
Expand Down

0 comments on commit 323e2c1

Please sign in to comment.