forked from blib-la/runpod-worker-comfy
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: wait until server is ready, wait until image generation is done…
…, upload to s3
- Loading branch information
1 parent
0d485ff
commit ecfec13
Showing
1 changed file
with
181 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,189 @@ | ||
import runpod | ||
from runpod.serverless.utils import rp_upload | ||
import json | ||
import urllib.request | ||
import urllib.parse | ||
import time | ||
import random | ||
import os | ||
import requests | ||
|
||
# Time to wait between API check attempts in milliseconds | ||
COMFY_API_AVAILABLE_INTERVAL_MS = 50 | ||
# Maximum number of API check attempts | ||
COMFY_API_AVAILABLE_MAX_RETRIES = 500 | ||
# Time to wait between poll attempts in milliseconds | ||
COMFY_POLLING_INTERVAL_MS = 250 | ||
# Maximum number of poll attempts | ||
COMFY_POLLING_MAX_RETRIES = 100 | ||
# Host where ComfyUI is running | ||
COMFY_HOST = "127.0.0.1:8188" | ||
# The path where ComfyUI stores it generated images | ||
COMFY_OUTPUT_PATH = "/comfyui/output" | ||
|
||
def hello_world(job): | ||
def check_server(url, retries=50, delay=500): | ||
""" | ||
Check if a server is reachable via HTTP GET request | ||
Args: | ||
- url (str): The URL to check | ||
- retries (int, optional): The number of times to attempt connecting to the server. Default is 50 | ||
- delay (int, optional): The time in milliseconds to wait between retries. Default is 500 | ||
Returns: | ||
bool: True if the server is reachable within the given number of retries, otherwise False | ||
""" | ||
|
||
for i in range(retries): | ||
try: | ||
response = requests.get(url) | ||
# If the response status code is 200, the server is up and running | ||
if response.status_code == 200: | ||
print(f'runpod-worker-comfy - API is reachable') | ||
return True | ||
except requests.RequestException as e: | ||
# If an exception occurs, the server may not be ready | ||
pass | ||
|
||
# Wait for the specified delay before retrying | ||
time.sleep(delay / 1000) | ||
|
||
print(f'runpod-worker-comfy - Failed to connect to server at {url} after {retries} attempts.') | ||
return False | ||
|
||
def queue_prompt(prompt): | ||
""" | ||
Queue a prompt to be processed by ComfyUI | ||
Args: | ||
prompt (dict): A dictionary containing the prompt to be processed | ||
Returns: | ||
dict: The JSON response from ComfyUI after processing the prompt | ||
""" | ||
data = json.dumps(prompt).encode('utf-8') | ||
req = urllib.request.Request(f"http://{COMFY_HOST}/prompt", data=data) | ||
return json.loads(urllib.request.urlopen(req).read()) | ||
|
||
def get_history(prompt_id): | ||
""" | ||
Retrieve the history of a given prompt using its ID | ||
Args: | ||
prompt_id (str): The ID of the prompt whose history is to be retrieved | ||
Returns: | ||
dict: The history of the prompt, containing all the processing steps and results | ||
""" | ||
with urllib.request.urlopen(f"http://{COMFY_HOST}/history/{prompt_id}") as response: | ||
return json.loads(response.read()) | ||
|
||
# TODO: Remove if not needed | ||
def get_image(filename, subfolder, folder_type): | ||
""" | ||
Retrieve an image generated by ComfyUI. | ||
Args: | ||
filename (str): The filename of the generated image. | ||
subfolder (str): The subfolder where the image is stored. | ||
folder_type (str): The type of folder where the image is stored. | ||
Returns: | ||
bytes: The image data. | ||
""" | ||
data = {"filename": filename, "subfolder": subfolder, "type": folder_type} | ||
url_values = urllib.parse.urlencode(data) | ||
with urllib.request.urlopen(f"http://{COMFY_HOST}/view?{url_values}") as response: | ||
return response.read() | ||
|
||
def handler(job): | ||
""" | ||
The main function that handles a job of generating an image. | ||
This function validates the input, sends a prompt to ComfyUI for processing, | ||
polls ComfyUI for result, and retrieves generated images. | ||
Args: | ||
job (dict): A dictionary containing job details and input parameters. | ||
Returns: | ||
dict: A dictionary containing either an error message or a success status with generated images. | ||
""" | ||
job_input = job["input"] | ||
greeting = job_input["greeting"] | ||
prompt_text = job_input.get("prompt") | ||
|
||
# Make sure that the ComfyUI API is available | ||
check_server(f"http://{COMFY_HOST}", COMFY_API_AVAILABLE_MAX_RETRIES, COMFY_API_AVAILABLE_INTERVAL_MS) | ||
|
||
# Validate input | ||
if prompt_text is None: | ||
return {"error": "Please provide the 'prompt'"} | ||
|
||
# Is JSON? | ||
if isinstance(prompt_text, dict): | ||
prompt = prompt_text | ||
# Is String? | ||
elif isinstance(prompt_text, str): | ||
try: | ||
prompt = json.loads(prompt_text) | ||
except json.JSONDecodeError: | ||
return {"error": "Invalid JSON format in 'prompt'"} | ||
else: | ||
return {"error": "'prompt' must be a JSON object or a JSON-encoded string"} | ||
|
||
# TODO: REMOVE | ||
# prompt["prompt"]["3"]["inputs"]["seed"] = random.randint(1, 10000000000) | ||
|
||
# Queue the prompt | ||
try: | ||
queued_prompt = queue_prompt(prompt) | ||
prompt_id = queued_prompt['prompt_id'] | ||
print(f'runpod-worker-comfy - queued prompt with ID {prompt_id}') | ||
except Exception as e: | ||
return {"error": f"Error queuing prompt: {str(e)}"} | ||
|
||
# Poll for completion | ||
print(f'runpod-worker-comfy - wait until image generation is complete') | ||
retries = 0 | ||
try: | ||
while retries < COMFY_POLLING_MAX_RETRIES: | ||
history = get_history(prompt_id) | ||
|
||
# Exit the loop if we have found the history | ||
if prompt_id in history and history[prompt_id].get('outputs'): | ||
break | ||
else: | ||
# Wait before trying again | ||
time.sleep(COMFY_POLLING_INTERVAL_MS / 1000) | ||
retries += 1 | ||
else: | ||
return {"error": "Max retries reached while waiting for image generation"} | ||
except Exception as e: | ||
return {"error": f"Error waiting for image generation: {str(e)}"} | ||
|
||
# Fetching generated images | ||
output_images = {} | ||
|
||
outputs = history[prompt_id].get("outputs") | ||
|
||
if not isinstance(greeting, str): | ||
return {"error": "Please provide a String"} | ||
for node_id, node_output in outputs.items(): | ||
if 'images' in node_output: | ||
images_output = [] | ||
for image in node_output['images']: | ||
output_images = image['filename'] | ||
# output_images[node_id] = image['filename'] | ||
# image_data = get_image(image['filename'], image['subfolder'], image['type']) | ||
# images_output.append(image_data) | ||
|
||
return f"Hello {greeting}" | ||
print(f'runpod-worker-comfy - image generation is done') | ||
|
||
if os.path.exists(f"{COMFY_OUTPUT_PATH}/{output_images}"): | ||
print("runpod-worker-comfy - the image exists in the output folder") | ||
image_url = rp_upload.upload_image(job['id'], f"{COMFY_OUTPUT_PATH}/{output_images}") | ||
return {"status": "success", "message": f'{image_url}'} | ||
else: | ||
print("runpod-worker-comfy - the image does not exist in the output folder") | ||
return {"status": "error", "message": f'the image does not exist in the specified output folder: {COMFY_OUTPUT_PATH}/{output_images}'} | ||
|
||
runpod.serverless.start({"handler": hello_world}) | ||
# Start the serverless function with the defined handler. | ||
runpod.serverless.start({"handler": handler}) |