Skip to content

Commit

Permalink
Merge pull request #15943 from eatmoreapple/update-lora-load
Browse files Browse the repository at this point in the history
feat: lora partial update precede full update
  • Loading branch information
AUTOMATIC1111 authored Jun 8, 2024
2 parents 46bcfbe + 10f8d0f commit 6de733c
Showing 1 changed file with 30 additions and 10 deletions.
40 changes: 30 additions & 10 deletions extensions-builtin/Lora/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,16 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No

loaded_networks.clear()

unavailable_networks = []
for name in names:
if name.lower() in forbidden_network_aliases and available_networks.get(name) is None:
unavailable_networks.append(name)
elif available_network_aliases.get(name) is None:
unavailable_networks.append(name)

if unavailable_networks:
update_available_networks_by_names(unavailable_networks)

networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
if any(x is None for x in networks_on_disk):
list_available_networks()
Expand Down Expand Up @@ -566,22 +576,16 @@ def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)


def list_available_networks():
available_networks.clear()
available_network_aliases.clear()
forbidden_network_aliases.clear()
available_network_hash_lookup.clear()
forbidden_network_aliases.update({"none": 1, "Addams": 1})

os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

def process_network_files(names: list[str] | None = None):
candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
for filename in candidates:
if os.path.isdir(filename):
continue

name = os.path.splitext(os.path.basename(filename))[0]
# if names is provided, only load networks with names in the list
if names and name not in names:
continue
try:
entry = network.NetworkOnDisk(name, filename)
except OSError: # should catch FileNotFoundError and PermissionError etc.
Expand All @@ -597,6 +601,22 @@ def list_available_networks():
available_network_aliases[entry.alias] = entry


def update_available_networks_by_names(names: list[str]):
process_network_files(names)


def list_available_networks():
available_networks.clear()
available_network_aliases.clear()
forbidden_network_aliases.clear()
available_network_hash_lookup.clear()
forbidden_network_aliases.update({"none": 1, "Addams": 1})

os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

process_network_files()


re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")


Expand Down

0 comments on commit 6de733c

Please sign in to comment.