From 9d45c0e6caabf5c34129789d7e31cdd7c473ff0c Mon Sep 17 00:00:00 2001 From: Manuel Schmid <9307310+mashb1t@users.noreply.github.com> Date: Tue, 23 Jul 2024 18:16:23 +0200 Subject: [PATCH] feat: sort enhance images (#62) * feat: add checkbox, config and handling for saving only the final enhanced image * feat: sort output of enhance feature --- args_manager.py | 3 +++ modules/async_worker.py | 22 ++++++++++++++++------ readme.md | 1 + webui.py | 22 ++++++++++++++++++++++ 4 files changed, 42 insertions(+), 6 deletions(-) diff --git a/args_manager.py b/args_manager.py index dea851dc4..bb622c23a 100644 --- a/args_manager.py +++ b/args_manager.py @@ -28,6 +28,9 @@ args_parser.parser.add_argument("--disable-preset-download", action='store_true', help="Disables downloading models for presets", default=False) +args_parser.parser.add_argument("--disable-enhance-output-sorting", action='store_true', + help="Disables enhance output sorting for final image gallery.") + args_parser.parser.add_argument("--enable-auto-describe-image", action='store_true', help="Enables automatic description of uov and enhance image when prompt is empty", default=False) diff --git a/modules/async_worker.py b/modules/async_worker.py index e6050b806..25cfef408 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -9,7 +9,7 @@ class AsyncTask: def __init__(self, args): - from modules.flags import Performance, MetadataScheme, ip_list, controlnet_image_count + from modules.flags import Performance, MetadataScheme, ip_list, controlnet_image_count, disabled from modules.util import get_enabled_loras from modules.config import default_max_lora_number import args_manager @@ -155,7 +155,9 @@ def __init__(self, args): enhance_inpaint_erode_or_dilate, enhance_mask_invert ]) - + self.should_enhance = self.enhance_checkbox and (self.enhance_uov_method != disabled.casefold() or len(self.enhance_ctrls) > 0) + self.images_to_enhance_count = 0 + self.enhance_stats = {} async_tasks = [] @@ -1276,8 +1278,8 @@ def callback(step, x0, x, total_steps, y): int(current_progress + async_task.callback_steps), f'Sampling step {step + 1}/{total_steps}, image {current_task_id + 1}/{total_count} ...', y)]) - show_intermediate_results = len(tasks) > 1 or should_enhance - persist_image = not should_enhance or not async_task.save_final_enhanced_image_only + show_intermediate_results = len(tasks) > 1 or async_task.should_enhance + persist_image = not async_task.should_enhance or not async_task.save_final_enhanced_image_only for current_task_id, task in enumerate(tasks): progressbar(async_task, current_progress, f'Preparing task {current_task_id + 1}/{async_task.image_number} ...') @@ -1309,7 +1311,7 @@ def callback(step, x0, x, total_steps, y): execution_time = time.perf_counter() - execution_start_time print(f'Generating and saving time: {execution_time:.2f} seconds') - if not should_enhance: + if not async_task.should_enhance: print(f'[Enhance] Skipping, preconditions aren\'t met') stop_processing(async_task, processing_start_time) return @@ -1325,6 +1327,7 @@ def callback(step, x0, x, total_steps, y): enhance_uov_before = async_task.enhance_uov_processing_order == flags.enhancement_uov_before enhance_uov_after = async_task.enhance_uov_processing_order == flags.enhancement_uov_after total_count = len(images_to_enhance) * active_enhance_tabs + async_task.images_to_enhance_count = len(images_to_enhance) base_progress = current_progress current_task_id = -1 @@ -1332,7 +1335,8 @@ def callback(step, x0, x, total_steps, y): done_steps_inpainting = 0 enhance_steps, _, _, _ = apply_overrides(async_task, async_task.original_steps, height, width) exception_result = None - for img in images_to_enhance: + for index, img in enumerate(images_to_enhance): + async_task.enhance_stats[index] = 0 enhancement_image_start_time = time.perf_counter() last_enhance_prompt = async_task.prompt @@ -1346,6 +1350,8 @@ def callback(step, x0, x, total_steps, y): current_task_id, denoising_strength, done_steps_inpainting, done_steps_upscaling, enhance_steps, async_task.prompt, async_task.negative_prompt, final_scheduler_name, height, img, preparation_steps, switch, tiled, total_count, use_expansion, use_style, use_synthetic_refiner, width, persist_image) + async_task.enhance_stats[index] += 1 + if exception_result == 'continue': continue elif exception_result == 'break': @@ -1389,6 +1395,7 @@ def callback(step, x0, x, total_steps, y): async_task.yields.append(['preview', (current_progress, 'Loading ...', mask)]) yield_result(async_task, mask, current_progress, async_task.black_out_nsfw, False, async_task.disable_intermediate_results) + async_task.enhance_stats[index] += 1 print(f'[Enhance] {dino_detection_count} boxes detected') print(f'[Enhance] {sam_detection_count} segments detected in boxes') @@ -1408,6 +1415,7 @@ def callback(step, x0, x, total_steps, y): enhance_prompt, enhance_negative_prompt, final_scheduler_name, goals_enhance, height, img, mask, preparation_steps, enhance_steps, switch, tiled, total_count, use_expansion, use_style, use_synthetic_refiner, width, persist_image=persist_image) + async_task.enhance_stats[index] += 1 if (should_process_enhance_uov and async_task.enhance_uov_processing_order == flags.enhancement_uov_after and async_task.enhance_uov_prompt_type == flags.enhancement_uov_prompt_type_last_filled): @@ -1444,6 +1452,8 @@ def callback(step, x0, x, total_steps, y): last_enhance_prompt, last_enhance_negative_prompt, final_scheduler_name, height, img, preparation_steps, switch, tiled, total_count, use_expansion, use_style, use_synthetic_refiner, width, persist_image) + async_task.enhance_stats[index] += 1 + if exception_result == 'continue': continue elif exception_result == 'break': diff --git a/readme.md b/readme.md index ca3dd66e0..cedccd28d 100644 --- a/readme.md +++ b/readme.md @@ -598,6 +598,7 @@ entry_with_update.py [-h] [--listen [IP]] [--port PORT] [--disable-offload-from-vram] [--theme THEME] [--disable-image-log] [--disable-analytics] [--disable-metadata] [--disable-preset-download] + [--disable-enhance-output-sorting] [--enable-auto-describe-image] [--always-download-new-model] [--rebuild-hash-cache [CPU_NUM_THREADS]] diff --git a/webui.py b/webui.py index 467594ddc..b02546223 100644 --- a/webui.py +++ b/webui.py @@ -73,6 +73,9 @@ def generate_clicked(task: worker.AsyncTask): gr.update(visible=True, value=product), \ gr.update(visible=False) if flag == 'finish': + if not args_manager.args.disable_enhance_output_sorting: + product = sort_enhance_images(product, task) + yield gr.update(visible=False), \ gr.update(visible=False), \ gr.update(visible=False), \ @@ -90,6 +93,25 @@ def generate_clicked(task: worker.AsyncTask): return +def sort_enhance_images(images, task): + if not task.should_enhance or len(images) <= task.images_to_enhance_count: + return images + + sorted_images = [] + walk_index = task.images_to_enhance_count + + for index, enhanced_img in enumerate(images[:task.images_to_enhance_count]): + sorted_images.append(enhanced_img) + if index not in task.enhance_stats: + continue + target_index = walk_index + task.enhance_stats[index] + if walk_index < len(images) and target_index <= len(images): + sorted_images += images[walk_index:target_index] + walk_index += task.enhance_stats[index] + + return sorted_images + + def inpaint_mode_change(mode, inpaint_engine_version): assert mode in modules.flags.inpaint_options