From 10f8d0f84216e3642e960ea7118a5acc8a79546f Mon Sep 17 00:00:00 2001 From: eatmoreapple Date: Tue, 4 Jun 2024 15:02:13 +0800 Subject: [PATCH] feat: lora partial update precede full update. --- extensions-builtin/Lora/networks.py | 40 +++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 42b14dc239d..18809364b61 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -260,6 +260,16 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No loaded_networks.clear() + unavailable_networks = [] + for name in names: + if name.lower() in forbidden_network_aliases and available_networks.get(name) is None: + unavailable_networks.append(name) + elif available_network_aliases.get(name) is None: + unavailable_networks.append(name) + + if unavailable_networks: + update_available_networks_by_names(unavailable_networks) + networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names] if any(x is None for x in networks_on_disk): list_available_networks() @@ -566,22 +576,16 @@ def network_MultiheadAttention_load_state_dict(self, *args, **kwargs): return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs) -def list_available_networks(): - available_networks.clear() - available_network_aliases.clear() - forbidden_network_aliases.clear() - available_network_hash_lookup.clear() - forbidden_network_aliases.update({"none": 1, "Addams": 1}) - - os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) - +def process_network_files(names: list[str] | None = None): candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"])) candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"])) for filename in candidates: if os.path.isdir(filename): continue - name = os.path.splitext(os.path.basename(filename))[0] + # if names is provided, only load networks with names in the list + if names and name not in names: + continue try: entry = network.NetworkOnDisk(name, filename) except OSError: # should catch FileNotFoundError and PermissionError etc. @@ -597,6 +601,22 @@ def list_available_networks(): available_network_aliases[entry.alias] = entry +def update_available_networks_by_names(names: list[str]): + process_network_files(names) + + +def list_available_networks(): + available_networks.clear() + available_network_aliases.clear() + forbidden_network_aliases.clear() + available_network_hash_lookup.clear() + forbidden_network_aliases.update({"none": 1, "Addams": 1}) + + os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) + + process_network_files() + + re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")