diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 2268b0f7ef9..183f8bd7c13 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -117,6 +117,12 @@ def __init__(self, net: Network, weights: NetworkWeights): if hasattr(self.sd_module, 'weight'): self.shape = self.sd_module.weight.shape + elif isinstance(self.sd_module, nn.MultiheadAttention): + # For now, only self-attn use Pytorch's MHA + # So assume all qkvo proj have same shape + self.shape = self.sd_module.out_proj.weight.shape + else: + self.shape = None self.ops = None self.extra_kwargs = {} @@ -146,7 +152,7 @@ def __init__(self, net: Network, weights: NetworkWeights): self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None self.scale = weights.w["scale"].item() if "scale" in weights.w else None - self.dora_scale = weights.w["dora_scale"] if "dora_scale" in weights.w else None + self.dora_scale = weights.w.get("dora_scale", None) self.dora_mean_dim = tuple(i for i in range(len(self.shape)) if i != 1) def multiplier(self): diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 04bd19117a5..42b14dc239d 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -429,9 +429,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: try: with torch.no_grad(): - updown_q, _ = module_q.calc_updown(self.in_proj_weight) - updown_k, _ = module_k.calc_updown(self.in_proj_weight) - updown_v, _ = module_v.calc_updown(self.in_proj_weight) + # Send "real" orig_weight into MHA's lora module + qw, kw, vw = self.in_proj_weight.chunk(3, 0) + updown_q, _ = module_q.calc_updown(qw) + updown_k, _ = module_k.calc_updown(kw) + updown_v, _ = module_v.calc_updown(vw) + del qw, kw, vw updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight) diff --git a/html/extra-networks-pane.html b/html/extra-networks-pane.html index ff8a73ad214..9a67baea90f 100644 --- a/html/extra-networks-pane.html +++ b/html/extra-networks-pane.html @@ -5,19 +5,49 @@ id="{tabname}_{extra_networks_tabname}_extra_search" class="extra-network-control--search-text" type="search" - placeholder="Filter files" + placeholder="Search" > + + Sort:
- +
+
+ +
+
+ +
+
+ +
+ +
- +
+ + +
- +
- +
{pane_content} diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 4d891b2452e..be5f0f304cf 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -39,12 +39,12 @@ function setupExtraNetworksForTab(tabname) { // tabname_full = {tabname}_{extra_networks_tabname} var tabname_full = elem.id; var search = gradioApp().querySelector("#" + tabname_full + "_extra_search"); - var sort_mode = gradioApp().querySelector("#" + tabname_full + "_extra_sort"); var sort_dir = gradioApp().querySelector("#" + tabname_full + "_extra_sort_dir"); var refresh = gradioApp().querySelector("#" + tabname_full + "_extra_refresh"); + var currentSort = ''; // If any of the buttons above don't exist, we want to skip this iteration of the loop. - if (!search || !sort_mode || !sort_dir || !refresh) { + if (!search || !sort_dir || !refresh) { return; // `return` is equivalent of `continue` but for forEach loops. } @@ -52,7 +52,7 @@ function setupExtraNetworksForTab(tabname) { var searchTerm = search.value.toLowerCase(); gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) { var searchOnly = elem.querySelector('.search_only'); - var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms'), function(t) { + var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms, .description'), function(t) { return t.textContent.toLowerCase(); }).join(" "); @@ -74,19 +74,20 @@ function setupExtraNetworksForTab(tabname) { var cards = gradioApp().querySelectorAll('#' + tabname_full + ' div.card'); var parent = gradioApp().querySelector('#' + tabname_full + "_cards"); var reverse = sort_dir.dataset.sortdir == "Descending"; - var sortKey = sort_mode.dataset.sortmode.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; - sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); - var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length; + var activeSearchElem = gradioApp().querySelector('#' + tabname_full + "_controls .extra-network-control--sort.extra-network-control--enabled"); + var sortKey = activeSearchElem ? activeSearchElem.dataset.sortkey : "default"; + var sortKeyDataField = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); + var sortKeyStore = sortKey + "-" + sort_dir.dataset.sortdir + "-" + cards.length; - if (sortKeyStore == sort_mode.dataset.sortkey && !force) { + if (sortKeyStore == currentSort && !force) { return; } - sort_mode.dataset.sortkey = sortKeyStore; + currentSort = sortKeyStore; var sortedCards = Array.from(cards); sortedCards.sort(function(cardA, cardB) { - var a = cardA.dataset[sortKey]; - var b = cardB.dataset[sortKey]; + var a = cardA.dataset[sortKeyDataField]; + var b = cardB.dataset[sortKeyDataField]; if (!isNaN(a) && !isNaN(b)) { return parseInt(a) - parseInt(b); } @@ -395,36 +396,17 @@ function extraNetworksTreeOnClick(event, tabname, extra_networks_tabname) { } function extraNetworksControlSortOnClick(event, tabname, extra_networks_tabname) { - /** - * Handles `onclick` events for the Sort Mode button. - * - * Modifies the data attributes of the Sort Mode button to cycle between - * various sorting modes. - * - * @param event The generated event. - * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. - * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. - */ - var curr_mode = event.currentTarget.dataset.sortmode; - var el_sort_dir = gradioApp().querySelector("#" + tabname + "_" + extra_networks_tabname + "_extra_sort_dir"); - var sort_dir = el_sort_dir.dataset.sortdir; - if (curr_mode == "path") { - event.currentTarget.dataset.sortmode = "name"; - event.currentTarget.dataset.sortkey = "sortName-" + sort_dir + "-640"; - event.currentTarget.setAttribute("title", "Sort by filename"); - } else if (curr_mode == "name") { - event.currentTarget.dataset.sortmode = "date_created"; - event.currentTarget.dataset.sortkey = "sortDate_created-" + sort_dir + "-640"; - event.currentTarget.setAttribute("title", "Sort by date created"); - } else if (curr_mode == "date_created") { - event.currentTarget.dataset.sortmode = "date_modified"; - event.currentTarget.dataset.sortkey = "sortDate_modified-" + sort_dir + "-640"; - event.currentTarget.setAttribute("title", "Sort by date modified"); - } else { - event.currentTarget.dataset.sortmode = "path"; - event.currentTarget.dataset.sortkey = "sortPath-" + sort_dir + "-640"; - event.currentTarget.setAttribute("title", "Sort by path"); - } + /** Handles `onclick` events for Sort Mode buttons. */ + + var self = event.currentTarget; + var parent = event.currentTarget.parentElement; + + parent.querySelectorAll('.extra-network-control--sort').forEach(function(x) { + x.classList.remove('extra-network-control--enabled'); + }); + + self.classList.add('extra-network-control--enabled'); + applyExtraNetworkSort(tabname + "_" + extra_networks_tabname); } diff --git a/modules/cmd_args.py b/modules/cmd_args.py index bf355303151..016a33d1057 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -124,3 +124,4 @@ parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui") parser.add_argument("--unix-filenames-sanitization", action='store_true', help="allow any symbols except '/' in filenames. May conflict with your browser and file system") parser.add_argument("--filenames-max-length", type=int, default=128, help='maximal length of filenames of saved images. If you override it, it can conflict with your file system') +parser.add_argument("--no-prompt-history", action='store_true', help="disable read prompt from last generation feature; settings this argument will not create '--data_path/params.txt' file") diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py index db1866449ad..a1cbfb17d7b 100644 --- a/modules/infotext_utils.py +++ b/modules/infotext_utils.py @@ -462,7 +462,7 @@ def get_override_settings(params, *, skip_fields=None): def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname): def paste_func(prompt): - if not prompt and not shared.cmd_opts.hide_ui_dir_config: + if not prompt and not shared.cmd_opts.hide_ui_dir_config and not shared.cmd_opts.no_prompt_history: filename = os.path.join(data_path, "params.txt") try: with open(filename, "r", encoding="utf8") as file: diff --git a/modules/options.py b/modules/options.py index 35ccade25be..2a78a825ee2 100644 --- a/modules/options.py +++ b/modules/options.py @@ -240,6 +240,9 @@ def dumpjson(self): item_categories = {} for item in self.data_labels.values(): + if item.section[0] is None: + continue + category = categories.mapping.get(item.category_id) category = "Uncategorized" if category is None else category.label if category not in item_categories: diff --git a/modules/processing.py b/modules/processing.py index 93493f80e66..86194b0576b 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -904,7 +904,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: # infotext could be modified by that callback # Example: a wildcard processed by process_batch sets an extra model # strength, which is saved as "Model Strength: 1.0" in the infotext - if n == 0: + if n == 0 and not cmd_opts.no_prompt_history: with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file: processed = Processed(p, []) file.write(processed.infotext(p, 0)) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 0de17af3db2..94ff973fb84 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -13,8 +13,8 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: for embedder in self.conditioner.embedders: embedder.ucg_rate = 0.0 - width = getattr(batch, 'width', 1024) - height = getattr(batch, 'height', 1024) + width = getattr(batch, 'width', 1024) or 1024 + height = getattr(batch, 'height', 1024) or 1024 is_negative_prompt = getattr(batch, 'is_negative_prompt', False) aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 9a1cf913fe5..f4627ce8dbc 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -569,18 +569,16 @@ def create_html(self, tabname, *, empty=False): if "user_metadata" not in item: self.read_user_metadata(item) - data_sortdir = shared.opts.extra_networks_card_order - data_sortmode = shared.opts.extra_networks_card_order_field.lower().replace("sort", "").replace(" ", "_").rstrip("_").strip() - data_sortkey = f"{data_sortmode}-{data_sortdir}-{len(self.items)}" - show_tree = shared.opts.extra_networks_tree_view_default_enabled page_params = { "tabname": tabname, "extra_networks_tabname": self.extra_networks_tabname, - "data_sortmode": data_sortmode, - "data_sortkey": data_sortkey, - "data_sortdir": data_sortdir, + "data_sortdir": shared.opts.extra_networks_card_order, + "sort_path_active": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Path' else '', + "sort_name_active": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Name' else '', + "sort_date_created_active": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Date Created' else '', + "sort_date_modified_active": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Date Modified' else '', "tree_view_btn_extra_class": "extra-network-control--enabled" if show_tree else "", "items_html": self.create_card_view_html(tabname, none_message="Loading..." if empty else None), "extra_networks_tree_view_default_width": shared.opts.extra_networks_tree_view_default_width, diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 2ca937fd117..fde093700b8 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -133,8 +133,10 @@ def write_user_metadata(self, name, metadata): filename = item.get("filename", None) basename, ext = os.path.splitext(filename) - with open(basename + '.json', "w", encoding="utf8") as file: + metadata_path = basename + '.json' + with open(metadata_path, "w", encoding="utf8") as file: json.dump(metadata, file, indent=4, ensure_ascii=False) + self.page.lister.update_file_entry(metadata_path) def save_user_metadata(self, name, desc, notes): user_metadata = self.get_user_metadata(name) @@ -185,7 +187,8 @@ def save_preview(self, index, gallery, name): geninfo, items = images.read_info_from_image(image) images.save_image_with_geninfo(image, geninfo, item["local_preview"]) - + self.page.lister.update_file_entry(item["local_preview"]) + item['preview'] = self.page.find_preview(item["local_preview"]) return self.get_card_html(name), '' def setup_ui(self, gallery): @@ -200,6 +203,3 @@ def setup_ui(self, gallery): inputs=[self.edit_name_input], outputs=[] ) - - - diff --git a/modules/util.py b/modules/util.py index 8d1aea44f5c..cb690e734f8 100644 --- a/modules/util.py +++ b/modules/util.py @@ -81,6 +81,17 @@ def __init__(self, dirname): self.files = {x[0].lower(): x for x in files} self.files_cased = {x[0]: x for x in files} + def update_entry(self, filename): + """Add a file to the cache""" + file_path = os.path.join(self.dirname, filename) + try: + stat = os.stat(file_path) + entry = (filename, stat.st_mtime, stat.st_ctime) + self.files[filename.lower()] = entry + self.files_cased[filename] = entry + except FileNotFoundError as e: + print(f'MassFileListerCachedDir.add_entry: "{file_path}" {e}') + class MassFileLister: """A class that provides a way to check for the existence and mtime/ctile of files without doing more than one stat call per file.""" @@ -136,3 +147,9 @@ def mctime(self, path): def reset(self): """Clear the cache of all directories.""" self.cached_dirs.clear() + + def update_file_entry(self, path): + """Update the cache for a specific directory.""" + dirname, filename = os.path.split(path) + if cached_dir := self.cached_dirs.get(dirname): + cached_dir.update_entry(filename) diff --git a/style.css b/style.css index 49978a771a5..29eae41277e 100644 --- a/style.css +++ b/style.css @@ -1,6 +1,6 @@ /* temporary fix to load default gradio font in frontend instead of backend */ -@import url('webui-assets/css/sourcesanspro.css'); +@import url('/webui-assets/css/sourcesanspro.css'); /* temporary fix to hide gradio crop tool until it's fixed https://github.com/gradio-app/gradio/issues/3810 */ @@ -1272,7 +1272,7 @@ body.resizing .resize-handle { .extra-network-control { position: relative; - display: grid; + display: flex; width: 100%; padding: 0 !important; margin-top: 0 !important; @@ -1289,6 +1289,12 @@ body.resizing .resize-handle { align-items: start; } +.extra-network-control small{ + color: var(--input-placeholder-color); + line-height: 2.2rem; + margin: 0 0.5rem 0 0.75rem; +} + .extra-network-tree .tree-list--tree {} /* Remove auto indentation from tree. Will be overridden later. */ @@ -1436,6 +1442,12 @@ body.resizing .resize-handle { line-height: 1rem; } + +.extra-network-control .extra-network-control--search .extra-network-control--search-text::placeholder { + color: var(--input-placeholder-color); +} + + /* clear button (x on right side) styling */ .extra-network-control .extra-network-control--search .extra-network-control--search-text::-webkit-search-cancel-button { -webkit-appearance: none; @@ -1468,19 +1480,19 @@ body.resizing .resize-handle { background-color: var(--input-placeholder-color); } -.extra-network-control .extra-network-control--sort[data-sortmode="path"] .extra-network-control--sort-icon { +.extra-network-control .extra-network-control--sort[data-sortkey="default"] .extra-network-control--sort-icon { mask-image: url('data:image/svg+xml,'); } -.extra-network-control .extra-network-control--sort[data-sortmode="name"] .extra-network-control--sort-icon { +.extra-network-control .extra-network-control--sort[data-sortkey="name"] .extra-network-control--sort-icon { mask-image: url('data:image/svg+xml,'); } -.extra-network-control .extra-network-control--sort[data-sortmode="date_created"] .extra-network-control--sort-icon { +.extra-network-control .extra-network-control--sort[data-sortkey="date_created"] .extra-network-control--sort-icon { mask-image: url('data:image/svg+xml,'); } -.extra-network-control .extra-network-control--sort[data-sortmode="date_modified"] .extra-network-control--sort-icon { +.extra-network-control .extra-network-control--sort[data-sortkey="date_modified"] .extra-network-control--sort-icon { mask-image: url('data:image/svg+xml,'); } @@ -1530,13 +1542,18 @@ body.resizing .resize-handle { } .extra-network-control .extra-network-control--enabled { - background-color: rgba(0, 0, 0, 0.15); + background-color: rgba(0, 0, 0, 0.1); + border-radius: 0.25rem; } .dark .extra-network-control .extra-network-control--enabled { background-color: rgba(255, 255, 255, 0.15); } +.extra-network-control .extra-network-control--enabled .extra-network-control--icon{ + background-color: var(--button-secondary-text-color); +} + /* ==== REFRESH ICON ACTIONS ==== */ .extra-network-control .extra-network-control--refresh { padding: 0.25rem;