Skip to content

Commit

Permalink
feat: added unit tests for everthing, refactored the code to make it …
Browse files Browse the repository at this point in the history
…better testable, added test images
  • Loading branch information
TimPietrusky committed Nov 17, 2023
1 parent dc92334 commit a7492ec
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 38 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@
venv
.env
data
models
models
simulated_uploaded
__pycache__
Empty file added src/__init__.py
Empty file.
117 changes: 80 additions & 37 deletions src/rp_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
COMFY_POLLING_MAX_RETRIES = 100
# Host where ComfyUI is running
COMFY_HOST = "127.0.0.1:8188"
# The path where ComfyUI stores the generated images
COMFY_OUTPUT_PATH = "/comfyui/output"


def check_server(url, retries=50, delay=500):
"""
Expand All @@ -38,6 +35,7 @@ def check_server(url, retries=50, delay=500):
for i in range(retries):
try:
response = requests.get(url)

# If the response status code is 200, the server is up and running
if response.status_code == 200:
print(f"runpod-worker-comfy - API is reachable")
Expand Down Expand Up @@ -87,10 +85,82 @@ def get_history(prompt_id):
def base64_encode(img_path):
"""
Returns base64 encoded image.
Args:
img_path (str): The path to the image
Returns:
str: The base64 encoded image
"""
with open(img_path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return encoded_string.decode("utf-8")
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
return f"data:image/png;base64,{encoded_string}"

def process_output_images(outputs, job_id):
"""
This function takes the "outputs" from image generation and the job ID,
then determines the correct way to return the image, either as a direct URL
to an AWS S3 bucket or as a base64 encoded string, depending on the
environment configuration.
Args:
outputs (dict): A dictionary containing the outputs from image generation,
typically includes node IDs and their respective output data.
job_id (str): The unique identifier for the job.
Returns:
dict: A dictionary with the status ('success' or 'error') and the message,
which is either the URL to the image in the AWS S3 bucket or a base64
encoded string of the image. In case of error, the message details the issue.
The function works as follows:
- It first determines the output path for the images from an environment variable,
defaulting to "/comfyui/output" if not set.
- It then iterates through the outputs to find the filenames of the generated images.
- After confirming the existence of the image in the output folder, it checks if the
AWS S3 bucket is configured via the BUCKET_ENDPOINT_URL environment variable.
- If AWS S3 is configured, it uploads the image to the bucket and returns the URL.
- If AWS S3 is not configured, it encodes the image in base64 and returns the string.
- If the image file does not exist in the output folder, it returns an error status
with a message indicating the missing image file.
"""

# The path where ComfyUI stores the generated images
COMFY_OUTPUT_PATH = os.environ.get('COMFY_OUTPUT_PATH', "/comfyui/output")

output_images = {}

for node_id, node_output in outputs.items():
if "images" in node_output:
for image in node_output["images"]:
output_images = image["filename"]

print(f"runpod-worker-comfy - image generation is done")

# expected image output folder
local_image_path = f"{COMFY_OUTPUT_PATH}/{output_images}"

# The image is in the output folder
if os.path.exists(local_image_path):
print("runpod-worker-comfy - the image exists in the output folder")

if os.environ.get('BUCKET_ENDPOINT_URL', False):
# URL to image in AWS S3
image = rp_upload.upload_image(job_id, local_image_path)
else:
# base64 image
image = base64_encode(local_image_path)

return {
"status": "success",
"message": image,
}
else:
print("runpod-worker-comfy - the image does not exist in the output folder")
return {
"status": "error",
"message": f"the image does not exist in the specified output folder: {local_image_path}",
}


def handler(job):
Expand Down Expand Up @@ -158,37 +228,10 @@ def handler(job):
except Exception as e:
return {"error": f"Error waiting for image generation: {str(e)}"}

# Fetching generated images
output_images = {}

outputs = history[prompt_id].get("outputs")

for node_id, node_output in outputs.items():
if "images" in node_output:
for image in node_output["images"]:
output_images = image["filename"]

print(f"runpod-worker-comfy - image generation is done")

# expected image output folder
local_image_path = f"{COMFY_OUTPUT_PATH}/{output_images}"
# The image is in the output folder
if os.path.exists(local_image_path):
print("runpod-worker-comfy - the image exists in the output folder")
image_url = rp_upload.upload_image(job["id"], local_image_path)
return_base64 = "simulated_uploaded/" in image_url
return_output = f"{image_url}" if not return_base64 else base64_encode(local_image_path)
return {
"status": "success",
"message": return_output,
}
else:
print("runpod-worker-comfy - the image does not exist in the output folder")
return {
"status": "error",
"message": f"the image does not exist in the specified output folder: {local_image_path}",
}
# Get the generated image and return it as URL in an AWS bucket or as base64
process_output_images(history[prompt_id].get("outputs"), job[id])


# Start the handler
runpod.serverless.start({"handler": handler})
# Start the handler only if this script is run directly
if __name__ == "__main__":
runpod.serverless.start({"handler": handler})
Binary file added test_resources/images/ComfyUI_00001_.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file added tests/__init__.py
Empty file.
130 changes: 130 additions & 0 deletions tests/test_rp_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import unittest
from unittest.mock import patch, MagicMock, mock_open, Mock
import sys
import os
import json

# Make sure that "src" is known and can be used to import rp_handler.py
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src')))
from src import rp_handler

# Local folder for test resources
RUNPOD_WORKER_COMFY_TEST_RESOURCES_IMAGES = "./test_resources/images"

class TestRunpodWorkerComfy(unittest.TestCase):
@patch('rp_handler.requests.get')
def test_check_server_server_up(self, mock_requests):
mock_response = MagicMock()
mock_response.status_code = 200
mock_requests.return_value = mock_response

result = rp_handler.check_server('http://127.0.0.1:8188', 1, 50)
self.assertTrue(result)

@patch('rp_handler.requests.get')
def test_check_server_server_down(self, mock_requests):
mock_requests.get.side_effect = rp_handler.requests.RequestException()
result = rp_handler.check_server('http://127.0.0.1:8188', 1, 50)
self.assertFalse(result)

@patch('rp_handler.urllib.request.urlopen')
def test_queue_prompt(self, mock_urlopen):
mock_response = MagicMock()
mock_response.read.return_value = json.dumps({"prompt_id": "123"}).encode()
mock_urlopen.return_value = mock_response
result = rp_handler.queue_prompt({"prompt": "test"})
self.assertEqual(result, {"prompt_id": "123"})

@patch('rp_handler.urllib.request.urlopen')
def test_get_history(self, mock_urlopen):
# Mock response data as a JSON string
mock_response_data = json.dumps({"key": "value"}).encode('utf-8')

# Define a mock response function for `read`
def mock_read():
return mock_response_data

# Create a mock response object
mock_response = Mock()
mock_response.read = mock_read

# Mock the __enter__ and __exit__ methods to support the context manager
mock_response.__enter__ = lambda s: s
mock_response.__exit__ = Mock()

# Set the return value of the urlopen mock
mock_urlopen.return_value = mock_response

# Call the function under test
result = rp_handler.get_history("123")

# Assertions
self.assertEqual(result, {"key": "value"})
mock_urlopen.assert_called_with("http://127.0.0.1:8188/history/123")

@patch('builtins.open', new_callable=mock_open, read_data=b'test')
def test_base64_encode(self, mock_file):
result = rp_handler.base64_encode("dummy_path")
self.assertTrue(result.startswith("data:image/png;base64,"))

@patch('rp_handler.os.path.exists')
@patch('rp_handler.rp_upload.upload_image')
@patch.dict(os.environ, {'COMFY_OUTPUT_PATH': RUNPOD_WORKER_COMFY_TEST_RESOURCES_IMAGES})
def test_bucket_endpoint_not_configured(self, mock_upload_image, mock_exists):
mock_exists.return_value = True
mock_upload_image.return_value = 'simulated_uploaded/image.png'

outputs = {'node_id': {'images': [{'filename': 'ComfyUI_00001_.png'}]}}
job_id = '123'

result = rp_handler.process_output_images(outputs, job_id)

self.assertEqual(result['status'], 'success')
self.assertTrue(result['message'].startswith("data:image/png;base64,"))

@patch('rp_handler.os.path.exists')
@patch('rp_handler.rp_upload.upload_image')
@patch.dict(os.environ, {'COMFY_OUTPUT_PATH': RUNPOD_WORKER_COMFY_TEST_RESOURCES_IMAGES, 'BUCKET_ENDPOINT_URL': 'http://example.com'})
def test_bucket_endpoint_configured(self, mock_upload_image, mock_exists):
# Mock the os.path.exists to return True, simulating that the image exists
mock_exists.return_value = True

# Mock the rp_upload.upload_image to return a simulated URL
mock_upload_image.return_value = 'http://example.com/uploaded/image.png'

# Define the outputs and job_id for the test
outputs = {'node_id': {'images': [{'filename': 'ComfyUI_00001_.png'}]}}
job_id = '123'

# Call the function under test
result = rp_handler.process_output_images(outputs, job_id)

# Assertions
self.assertEqual(result['status'], 'success')
self.assertEqual(result['message'], 'http://example.com/uploaded/image.png')
mock_upload_image.assert_called_once_with(job_id, './test_resources/images/ComfyUI_00001_.png')


@patch('rp_handler.os.path.exists')
@patch('rp_handler.rp_upload.upload_image')
@patch.dict(os.environ, {
'COMFY_OUTPUT_PATH': RUNPOD_WORKER_COMFY_TEST_RESOURCES_IMAGES,
'BUCKET_ENDPOINT_URL': 'http://example.com',
'BUCKET_ACCESS_KEY_ID': '',
'BUCKET_SECRET_ACCESS_KEY': ''
})
def test_bucket_image_upload_fails_env_vars_wrong_or_missing(self, mock_upload_image, mock_exists):
# Simulate the file existing in the output path
mock_exists.return_value = True

# When AWS credentials are wrong or missing, upload_image should return 'simulated_uploaded/...'
mock_upload_image.return_value = 'simulated_uploaded/image.png'

outputs = {'node_id': {'images': [{'filename': 'ComfyUI_00001_.png'}]}}
job_id = '123'

result = rp_handler.process_output_images(outputs, job_id)

# Check if the image was saved to the 'simulated_uploaded' directory
self.assertIn('simulated_uploaded', result['message'])
self.assertEqual(result['status'], 'success')

0 comments on commit a7492ec

Please sign in to comment.