Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mass file lister as an attempt to tackle #14507 #14528

Merged
merged 2 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions modules/extra_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def parse_prompts(prompts):
return res, extra_data


def get_user_metadata(filename):
def get_user_metadata(filename, lister=None):
if filename is None:
return {}

Expand All @@ -215,7 +215,8 @@ def get_user_metadata(filename):

metadata = {}
try:
if os.path.isfile(metadata_filename):
exists = lister.exists(metadata_filename) if lister else os.path.exists(metadata_filename)
if exists:
with open(metadata_filename, "r", encoding="utf8") as file:
metadata = json.load(file)
except Exception as e:
Expand Down
20 changes: 13 additions & 7 deletions modules/ui_extra_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import urllib.parse
from pathlib import Path

from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks, util
from modules.images import read_info_from_image, save_image_with_geninfo
import gradio as gr
import json
Expand Down Expand Up @@ -107,13 +107,14 @@ def __init__(self, title):
self.allow_negative_prompt = False
self.metadata = {}
self.items = {}
self.lister = util.MassFileLister()

def refresh(self):
pass

def read_user_metadata(self, item):
filename = item.get("filename", None)
metadata = extra_networks.get_user_metadata(filename)
metadata = extra_networks.get_user_metadata(filename, lister=self.lister)

desc = metadata.get("description", None)
if desc is not None:
Expand All @@ -123,7 +124,7 @@ def read_user_metadata(self, item):

def link_preview(self, filename):
quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
mtime = os.path.getmtime(filename)
mtime, _ = self.lister.mctime(filename)
return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"

def search_terms_from_path(self, filename, possible_directories=None):
Expand All @@ -137,6 +138,8 @@ def search_terms_from_path(self, filename, possible_directories=None):
return ""

def create_html(self, tabname):
self.lister.reset()

items_html = ''

self.metadata = {}
Expand Down Expand Up @@ -282,10 +285,10 @@ def get_sort_keys(self, path):
List of default keys used for sorting in the UI.
"""
pth = Path(path)
stat = pth.stat()
mtime, ctime = self.lister.mctime(path)
return {
"date_created": int(stat.st_ctime or 0),
"date_modified": int(stat.st_mtime or 0),
"date_created": int(mtime),
"date_modified": int(ctime),
"name": pth.name.lower(),
"path": str(pth.parent).lower(),
}
Expand All @@ -298,7 +301,7 @@ def find_preview(self, path):
potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in allowed_preview_extensions()], [])

for file in potential_files:
if os.path.isfile(file):
if self.lister.exists(file):
return self.link_preview(file)

return None
Expand All @@ -308,6 +311,9 @@ def find_description(self, path):
Find and read a description file for a given path (without extension).
"""
for file in [f"{path}.txt", f"{path}.description.txt"]:
if not self.lister.exists(file):
continue

try:
with open(file, "r", encoding="utf-8", errors="replace") as f:
return f.read()
Expand Down
70 changes: 70 additions & 0 deletions modules/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,73 @@ def truncate_path(target_path, base_path=cwd):
except ValueError:
pass
return abs_target


class MassFileListerCachedDir:
"""A class that caches file metadata for a specific directory."""

def __init__(self, dirname):
self.files = None
self.files_cased = None
self.dirname = dirname

stats = ((x.name, x.stat(follow_symlinks=False)) for x in os.scandir(self.dirname))
files = [(n, s.st_mtime, s.st_ctime) for n, s in stats]
self.files = {x[0].lower(): x for x in files}
self.files_cased = {x[0]: x for x in files}


class MassFileLister:
"""A class that provides a way to check for the existence and mtime/ctile of files without doing more than one stat call per file."""

def __init__(self):
self.cached_dirs = {}

def find(self, path):
"""
Find the metadata for a file at the given path.

Returns:
tuple or None: A tuple of (name, mtime, ctime) if the file exists, or None if it does not.
"""

dirname, filename = os.path.split(path)

cached_dir = self.cached_dirs.get(dirname)
if cached_dir is None:
cached_dir = MassFileListerCachedDir(dirname)
self.cached_dirs[dirname] = cached_dir

stats = cached_dir.files_cased.get(filename)
if stats is not None:
return stats

stats = cached_dir.files.get(filename.lower())
if stats is None:
return None

try:
os_stats = os.stat(path, follow_symlinks=False)
return filename, os_stats.st_mtime, os_stats.st_ctime
except Exception:
return None

def exists(self, path):
"""Check if a file exists at the given path."""

return self.find(path) is not None

def mctime(self, path):
"""
Get the modification and creation times for a file at the given path.

Returns:
tuple: A tuple of (mtime, ctime) if the file exists, or (0, 0) if it does not.
"""

stats = self.find(path)
return (0, 0) if stats is None else stats[1:3]

def reset(self):
"""Clear the cache of all directories."""
self.cached_dirs.clear()