Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…on-webui into dev
  • Loading branch information
MisterSeajay committed Mar 13, 2024
2 parents 42043fc + 3e0146f commit 9adc0d4
Show file tree
Hide file tree
Showing 13 changed files with 136 additions and 76 deletions.
8 changes: 7 additions & 1 deletion extensions-builtin/Lora/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,12 @@ def __init__(self, net: Network, weights: NetworkWeights):

if hasattr(self.sd_module, 'weight'):
self.shape = self.sd_module.weight.shape
elif isinstance(self.sd_module, nn.MultiheadAttention):
# For now, only self-attn use Pytorch's MHA
# So assume all qkvo proj have same shape
self.shape = self.sd_module.out_proj.weight.shape
else:
self.shape = None

self.ops = None
self.extra_kwargs = {}
Expand Down Expand Up @@ -146,7 +152,7 @@ def __init__(self, net: Network, weights: NetworkWeights):
self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
self.scale = weights.w["scale"].item() if "scale" in weights.w else None

self.dora_scale = weights.w["dora_scale"] if "dora_scale" in weights.w else None
self.dora_scale = weights.w.get("dora_scale", None)
self.dora_mean_dim = tuple(i for i in range(len(self.shape)) if i != 1)

def multiplier(self):
Expand Down
9 changes: 6 additions & 3 deletions extensions-builtin/Lora/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,9 +429,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
try:
with torch.no_grad():
updown_q, _ = module_q.calc_updown(self.in_proj_weight)
updown_k, _ = module_k.calc_updown(self.in_proj_weight)
updown_v, _ = module_v.calc_updown(self.in_proj_weight)
# Send "real" orig_weight into MHA's lora module
qw, kw, vw = self.in_proj_weight.chunk(3, 0)
updown_q, _ = module_q.calc_updown(qw)
updown_k, _ = module_k.calc_updown(kw)
updown_v, _ = module_v.calc_updown(vw)
del qw, kw, vw
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)

Expand Down
51 changes: 42 additions & 9 deletions html/extra-networks-pane.html
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,76 @@
id="{tabname}_{extra_networks_tabname}_extra_search"
class="extra-network-control--search-text"
type="search"
placeholder="Filter files"
placeholder="Search"
>
</div>

<small>Sort: </small>
<div
id="{tabname}_{extra_networks_tabname}_extra_sort"
class="extra-network-control--sort"
data-sortmode="{data_sortmode}"
data-sortkey="{data_sortkey}"
id="{tabname}_{extra_networks_tabname}_extra_sort_path"
class="extra-network-control--sort{sort_path_active}"
data-sortkey="default"
title="Sort by path"
onclick="extraNetworksControlSortOnClick(event, '{tabname}', '{extra_networks_tabname}');"
>
<i class="extra-network-control--sort-icon"></i>
<i class="extra-network-control--icon extra-network-control--sort-icon"></i>
</div>
<div
id="{tabname}_{extra_networks_tabname}_extra_sort_name"
class="extra-network-control--sort{sort_name_active}"
data-sortkey="name"
title="Sort by name"
onclick="extraNetworksControlSortOnClick(event, '{tabname}', '{extra_networks_tabname}');"
>
<i class="extra-network-control--icon extra-network-control--sort-icon"></i>
</div>
<div
id="{tabname}_{extra_networks_tabname}_extra_sort_date_created"
class="extra-network-control--sort{sort_date_created_active}"
data-sortkey="date_created"
title="Sort by date created"
onclick="extraNetworksControlSortOnClick(event, '{tabname}', '{extra_networks_tabname}');"
>
<i class="extra-network-control--icon extra-network-control--sort-icon"></i>
</div>
<div
id="{tabname}_{extra_networks_tabname}_extra_sort_date_modified"
class="extra-network-control--sort{sort_date_modified_active}"
data-sortkey="date_modified"
title="Sort by date modified"
onclick="extraNetworksControlSortOnClick(event, '{tabname}', '{extra_networks_tabname}');"
>
<i class="extra-network-control--icon extra-network-control--sort-icon"></i>
</div>

<small> </small>
<div
id="{tabname}_{extra_networks_tabname}_extra_sort_dir"
class="extra-network-control--sort-dir"
data-sortdir="{data_sortdir}"
title="Sort ascending"
onclick="extraNetworksControlSortDirOnClick(event, '{tabname}', '{extra_networks_tabname}');"
>
<i class="extra-network-control--sort-dir-icon"></i>
<i class="extra-network-control--icon extra-network-control--sort-dir-icon"></i>
</div>


<small> </small>
<div
id="{tabname}_{extra_networks_tabname}_extra_tree_view"
class="extra-network-control--tree-view {tree_view_btn_extra_class}"
title="Enable Tree View"
onclick="extraNetworksControlTreeViewOnClick(event, '{tabname}', '{extra_networks_tabname}');"
>
<i class="extra-network-control--tree-view-icon"></i>
<i class="extra-network-control--icon extra-network-control--tree-view-icon"></i>
</div>
<div
id="{tabname}_{extra_networks_tabname}_extra_refresh"
class="extra-network-control--refresh"
title="Refresh page"
onclick="extraNetworksControlRefreshOnClick(event, '{tabname}', '{extra_networks_tabname}');"
>
<i class="extra-network-control--refresh-icon"></i>
<i class="extra-network-control--icon extra-network-control--refresh-icon"></i>
</div>
</div>
{pane_content}
Expand Down
62 changes: 22 additions & 40 deletions javascript/extraNetworks.js
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,20 @@ function setupExtraNetworksForTab(tabname) {
// tabname_full = {tabname}_{extra_networks_tabname}
var tabname_full = elem.id;
var search = gradioApp().querySelector("#" + tabname_full + "_extra_search");
var sort_mode = gradioApp().querySelector("#" + tabname_full + "_extra_sort");
var sort_dir = gradioApp().querySelector("#" + tabname_full + "_extra_sort_dir");
var refresh = gradioApp().querySelector("#" + tabname_full + "_extra_refresh");
var currentSort = '';

// If any of the buttons above don't exist, we want to skip this iteration of the loop.
if (!search || !sort_mode || !sort_dir || !refresh) {
if (!search || !sort_dir || !refresh) {
return; // `return` is equivalent of `continue` but for forEach loops.
}

var applyFilter = function(force) {
var searchTerm = search.value.toLowerCase();
gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) {
var searchOnly = elem.querySelector('.search_only');
var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms'), function(t) {
var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms, .description'), function(t) {
return t.textContent.toLowerCase();
}).join(" ");

Expand All @@ -74,19 +74,20 @@ function setupExtraNetworksForTab(tabname) {
var cards = gradioApp().querySelectorAll('#' + tabname_full + ' div.card');
var parent = gradioApp().querySelector('#' + tabname_full + "_cards");
var reverse = sort_dir.dataset.sortdir == "Descending";
var sortKey = sort_mode.dataset.sortmode.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name";
sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1);
var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length;
var activeSearchElem = gradioApp().querySelector('#' + tabname_full + "_controls .extra-network-control--sort.extra-network-control--enabled");
var sortKey = activeSearchElem ? activeSearchElem.dataset.sortkey : "default";
var sortKeyDataField = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1);
var sortKeyStore = sortKey + "-" + sort_dir.dataset.sortdir + "-" + cards.length;

if (sortKeyStore == sort_mode.dataset.sortkey && !force) {
if (sortKeyStore == currentSort && !force) {
return;
}
sort_mode.dataset.sortkey = sortKeyStore;
currentSort = sortKeyStore;

var sortedCards = Array.from(cards);
sortedCards.sort(function(cardA, cardB) {
var a = cardA.dataset[sortKey];
var b = cardB.dataset[sortKey];
var a = cardA.dataset[sortKeyDataField];
var b = cardB.dataset[sortKeyDataField];
if (!isNaN(a) && !isNaN(b)) {
return parseInt(a) - parseInt(b);
}
Expand Down Expand Up @@ -395,36 +396,17 @@ function extraNetworksTreeOnClick(event, tabname, extra_networks_tabname) {
}

function extraNetworksControlSortOnClick(event, tabname, extra_networks_tabname) {
/**
* Handles `onclick` events for the Sort Mode button.
*
* Modifies the data attributes of the Sort Mode button to cycle between
* various sorting modes.
*
* @param event The generated event.
* @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.
* @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.
*/
var curr_mode = event.currentTarget.dataset.sortmode;
var el_sort_dir = gradioApp().querySelector("#" + tabname + "_" + extra_networks_tabname + "_extra_sort_dir");
var sort_dir = el_sort_dir.dataset.sortdir;
if (curr_mode == "path") {
event.currentTarget.dataset.sortmode = "name";
event.currentTarget.dataset.sortkey = "sortName-" + sort_dir + "-640";
event.currentTarget.setAttribute("title", "Sort by filename");
} else if (curr_mode == "name") {
event.currentTarget.dataset.sortmode = "date_created";
event.currentTarget.dataset.sortkey = "sortDate_created-" + sort_dir + "-640";
event.currentTarget.setAttribute("title", "Sort by date created");
} else if (curr_mode == "date_created") {
event.currentTarget.dataset.sortmode = "date_modified";
event.currentTarget.dataset.sortkey = "sortDate_modified-" + sort_dir + "-640";
event.currentTarget.setAttribute("title", "Sort by date modified");
} else {
event.currentTarget.dataset.sortmode = "path";
event.currentTarget.dataset.sortkey = "sortPath-" + sort_dir + "-640";
event.currentTarget.setAttribute("title", "Sort by path");
}
/** Handles `onclick` events for Sort Mode buttons. */

var self = event.currentTarget;
var parent = event.currentTarget.parentElement;

parent.querySelectorAll('.extra-network-control--sort').forEach(function(x) {
x.classList.remove('extra-network-control--enabled');
});

self.classList.add('extra-network-control--enabled');

applyExtraNetworkSort(tabname + "_" + extra_networks_tabname);
}

Expand Down
1 change: 1 addition & 0 deletions modules/cmd_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,4 @@
parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui")
parser.add_argument("--unix-filenames-sanitization", action='store_true', help="allow any symbols except '/' in filenames. May conflict with your browser and file system")
parser.add_argument("--filenames-max-length", type=int, default=128, help='maximal length of filenames of saved images. If you override it, it can conflict with your file system')
parser.add_argument("--no-prompt-history", action='store_true', help="disable read prompt from last generation feature; settings this argument will not create '--data_path/params.txt' file")
2 changes: 1 addition & 1 deletion modules/infotext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def get_override_settings(params, *, skip_fields=None):

def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
if not prompt and not shared.cmd_opts.hide_ui_dir_config and not shared.cmd_opts.no_prompt_history:
filename = os.path.join(data_path, "params.txt")
try:
with open(filename, "r", encoding="utf8") as file:
Expand Down
3 changes: 3 additions & 0 deletions modules/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ def dumpjson(self):

item_categories = {}
for item in self.data_labels.values():
if item.section[0] is None:
continue

category = categories.mapping.get(item.category_id)
category = "Uncategorized" if category is None else category.label
if category not in item_categories:
Expand Down
2 changes: 1 addition & 1 deletion modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
# infotext could be modified by that callback
# Example: a wildcard processed by process_batch sets an extra model
# strength, which is saved as "Model Strength: 1.0" in the infotext
if n == 0:
if n == 0 and not cmd_opts.no_prompt_history:
with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
processed = Processed(p, [])
file.write(processed.infotext(p, 0))
Expand Down
4 changes: 2 additions & 2 deletions modules/sd_models_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
for embedder in self.conditioner.embedders:
embedder.ucg_rate = 0.0

width = getattr(batch, 'width', 1024)
height = getattr(batch, 'height', 1024)
width = getattr(batch, 'width', 1024) or 1024
height = getattr(batch, 'height', 1024) or 1024
is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score

Expand Down
12 changes: 5 additions & 7 deletions modules/ui_extra_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,18 +569,16 @@ def create_html(self, tabname, *, empty=False):
if "user_metadata" not in item:
self.read_user_metadata(item)

data_sortdir = shared.opts.extra_networks_card_order
data_sortmode = shared.opts.extra_networks_card_order_field.lower().replace("sort", "").replace(" ", "_").rstrip("_").strip()
data_sortkey = f"{data_sortmode}-{data_sortdir}-{len(self.items)}"

show_tree = shared.opts.extra_networks_tree_view_default_enabled

page_params = {
"tabname": tabname,
"extra_networks_tabname": self.extra_networks_tabname,
"data_sortmode": data_sortmode,
"data_sortkey": data_sortkey,
"data_sortdir": data_sortdir,
"data_sortdir": shared.opts.extra_networks_card_order,
"sort_path_active": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Path' else '',
"sort_name_active": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Name' else '',
"sort_date_created_active": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Date Created' else '',
"sort_date_modified_active": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Date Modified' else '',
"tree_view_btn_extra_class": "extra-network-control--enabled" if show_tree else "",
"items_html": self.create_card_view_html(tabname, none_message="Loading..." if empty else None),
"extra_networks_tree_view_default_width": shared.opts.extra_networks_tree_view_default_width,
Expand Down
10 changes: 5 additions & 5 deletions modules/ui_extra_networks_user_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,10 @@ def write_user_metadata(self, name, metadata):
filename = item.get("filename", None)
basename, ext = os.path.splitext(filename)

with open(basename + '.json', "w", encoding="utf8") as file:
metadata_path = basename + '.json'
with open(metadata_path, "w", encoding="utf8") as file:
json.dump(metadata, file, indent=4, ensure_ascii=False)
self.page.lister.update_file_entry(metadata_path)

def save_user_metadata(self, name, desc, notes):
user_metadata = self.get_user_metadata(name)
Expand Down Expand Up @@ -185,7 +187,8 @@ def save_preview(self, index, gallery, name):
geninfo, items = images.read_info_from_image(image)

images.save_image_with_geninfo(image, geninfo, item["local_preview"])

self.page.lister.update_file_entry(item["local_preview"])
item['preview'] = self.page.find_preview(item["local_preview"])
return self.get_card_html(name), ''

def setup_ui(self, gallery):
Expand All @@ -200,6 +203,3 @@ def setup_ui(self, gallery):
inputs=[self.edit_name_input],
outputs=[]
)



17 changes: 17 additions & 0 deletions modules/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@ def __init__(self, dirname):
self.files = {x[0].lower(): x for x in files}
self.files_cased = {x[0]: x for x in files}

def update_entry(self, filename):
"""Add a file to the cache"""
file_path = os.path.join(self.dirname, filename)
try:
stat = os.stat(file_path)
entry = (filename, stat.st_mtime, stat.st_ctime)
self.files[filename.lower()] = entry
self.files_cased[filename] = entry
except FileNotFoundError as e:
print(f'MassFileListerCachedDir.add_entry: "{file_path}" {e}')


class MassFileLister:
"""A class that provides a way to check for the existence and mtime/ctile of files without doing more than one stat call per file."""
Expand Down Expand Up @@ -136,3 +147,9 @@ def mctime(self, path):
def reset(self):
"""Clear the cache of all directories."""
self.cached_dirs.clear()

def update_file_entry(self, path):
"""Update the cache for a specific directory."""
dirname, filename = os.path.split(path)
if cached_dir := self.cached_dirs.get(dirname):
cached_dir.update_entry(filename)
Loading

0 comments on commit 9adc0d4

Please sign in to comment.