Skip to content

Commit

Permalink
task manager added
Browse files Browse the repository at this point in the history
based on https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/modules_forge/main_thread.py

 * classified
 * use non blocking asyncio.sleep
 * this way, gc.collect() will work as intended.
  • Loading branch information
wkpark committed Sep 28, 2024
1 parent e8a6c28 commit 5b41914
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 7 deletions.
8 changes: 6 additions & 2 deletions modules/call_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import html
import time

from modules import shared, progress, errors, devices, fifo_lock, profiling
from modules import shared, progress, errors, devices, fifo_lock, profiling, manager

queue_lock = fifo_lock.FIFOLock()

Expand Down Expand Up @@ -34,7 +34,7 @@ def f(*args, **kwargs):
progress.start_task(id_task)

try:
res = func(*args, **kwargs)
res = manager.task.run_and_wait_result(func, *args, **kwargs)
progress.record_results(id_task, res)
finally:
progress.finish_task(id_task)
Expand Down Expand Up @@ -73,6 +73,10 @@ def f(*args, extra_outputs_array=extra_outputs, **kwargs):
try:
res = list(func(*args, **kwargs))
except Exception as e:
if manager.task.last_exception is not None:
e = manager.task.last_exception
else:
pass
# When printing out our debug argument list,
# do not print out more than a 100 KB of text
max_debug_str_len = 131072
Expand Down
6 changes: 6 additions & 0 deletions modules/launch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,11 +463,17 @@ def configure_for_tests():
def start():
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {shlex.join(sys.argv[1:])}")
import webui

from modules import manager

if '--nowebui' in sys.argv:
webui.api_only()
else:
webui.webui()

manager.task.main_loop()
return


def dump_sysinfo():
from modules import sysinfo
Expand Down
81 changes: 81 additions & 0 deletions modules/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#
# based on forge's work from https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/modules_forge/main_thread.py
#
# Original author comment:
# This file is the main thread that handles all gradio calls for major t2i or i2i processing.
# Other gradio calls (like those from extensions) are not influenced.
# By using one single thread to process all major calls, model moving is significantly faster.
#
# 2024/09/28 classified,

import asyncio
import random
import string
import time
import threading
import traceback

from collections import OrderedDict


class Task:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)


class TaskManager:
last_exception = None
pending_tasks = []
finished_tasks = OrderedDict()
lock = None

def __init__(self):
self.lock = threading.Lock()

def work(self, task):
try:
task.result = task.func(*task.args, **task.kwargs)
except Exception as e:
traceback.print_exc()
print(e)
task.exception = e
self.last_exception = e


def main_loop(self):
loop = asyncio.get_event_loop()
while True:
loop.run_until_complete(asyncio.sleep(0.01))
if len(self.pending_tasks) > 0:
with self.lock:
task = self.pending_tasks.pop(0)

self.work(task)

self.finished_tasks[task.task_id] = task


def push_task(self, func, *args, **kwargs):
if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"):
task_id = args[0]
else:
task_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=7))
task = Task(task_id=task_id, func=func, args=args, kwargs=kwargs, result=None, exception=None)
self.pending_tasks.append(task)

return task.task_id


def run_and_wait_result(self, func, *args, **kwargs):
current_id = self.push_task(func, *args, **kwargs)

loop = asyncio.new_event_loop()
while True:
loop.run_until_complete(asyncio.sleep(0.01))
if current_id in self.finished_tasks:
finished = self.finished_tasks.pop(current_id)

return finished.result


task = TaskManager()
20 changes: 15 additions & 5 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from modules import timer
from modules import initialize_util
from modules import initialize
from threading import Thread

startup_timer = timer.startup_timer
startup_timer.record("launcher")
Expand All @@ -14,6 +15,8 @@

initialize.check_versions()

initialize.initialize()


def create_api(app):
from modules.api.api import Api
Expand All @@ -23,12 +26,10 @@ def create_api(app):
return api


def api_only():
def _api_only():
from fastapi import FastAPI
from modules.shared_cmd_options import cmd_opts

initialize.initialize()

app = FastAPI()
initialize_util.setup_middleware(app)
api = create_api(app)
Expand All @@ -45,11 +46,10 @@ def api_only():
)


def webui():
def _webui():
from modules.shared_cmd_options import cmd_opts

launch_api = cmd_opts.api
initialize.initialize()

from modules import shared, ui_tempdir, script_callbacks, ui, progress, ui_extra_networks

Expand Down Expand Up @@ -153,10 +153,20 @@ def webui():
initialize.initialize_rest(reload_script_modules=True)


def api_only():
Thread(target=_api_only, daemon=True).start()


def webui():
Thread(target=_webui, daemon=True).start()

if __name__ == "__main__":
from modules.shared_cmd_options import cmd_opts
from modules import manager

if cmd_opts.nowebui:
api_only()
else:
webui()

manager.task.main_loop()

0 comments on commit 5b41914

Please sign in to comment.