From 21b8d4f25dad1e8b019dec8a9da9526afa00e079 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Wed, 5 Jun 2024 23:23:03 +0200 Subject: [PATCH] refactor(runner): improve consistency between I2I and T2I pipelines This commit enhances the consistency between the I2I and T2I pipelines, making them easier to compare. --- runner/app/routes/image_to_image.py | 64 ++++++++++++----------------- runner/app/routes/text_to_image.py | 2 +- 2 files changed, 27 insertions(+), 39 deletions(-) diff --git a/runner/app/routes/image_to_image.py b/runner/app/routes/image_to_image.py index 319e85bb..df8ac336 100644 --- a/runner/app/routes/image_to_image.py +++ b/runner/app/routes/image_to_image.py @@ -62,24 +62,20 @@ async def image_to_image( ), ) - seeds = [] - if seed is None: - seeds = [random.randint(0, 2**32 - 1)] - if num_images_per_prompt > 1: - seeds = [ - i for i in range(seeds[0], seeds[0] + num_images_per_prompt) - ] + seed = seed if seed is not None else random.randint(0, 2**32 - 1) + seeds = [seed + i for i in range(num_images_per_prompt)] - img = Image.open(image.file).convert("RGB") - - try: - images = [] - has_nsfw_concept = [] - - for seed in seeds: - image_out, nsfw = pipeline( + image = Image.open(image.file).convert("RGB") + + # TODO: Process one image at a time to avoid CUDA OOM errors. Can be removed again + # once LIV-243 and LIV-379 are resolved. 
+ images = [] + has_nsfw_concept = [] + for seed in seeds: + try: + imgs, nsfw_checks = pipeline( prompt=prompt, - image=img, + image=image, strength=strength, guidance_scale=guidance_scale, image_guidance_scale=image_guidance_scale, @@ -88,28 +84,20 @@ async def image_to_image( seed=seed, num_images_per_prompt=1, ) - - images.extend(image_out) - has_nsfw_concept.extend(nsfw) - - except Exception as e: - logger.error(f"ImageToImagePipeline error: {e}") - logger.exception(e) - return JSONResponse( - status_code=500, content=http_error("ImageToImagePipeline error") - ) - - seeds = seed - if not isinstance(seeds, list): - seeds = [seeds] + images.extend(imgs) + has_nsfw_concept.extend(nsfw_checks) + except Exception as e: + logger.error(f"ImageToImagePipeline error: {e}") + logger.exception(e) + return JSONResponse( + status_code=500, content=http_error("ImageToImagePipeline error") + ) - output_images = [] - for img, sd, is_nsfw in zip(images, seeds, has_nsfw_concept): - # TODO: Return None once Go codegen tool supports optional properties - # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 - is_nsfw = is_nsfw or False - output_images.append( - {"url": image_to_data_url(img), "seed": sd, "nsfw": is_nsfw} - ) + # TODO: Return None once Go codegen tool supports optional properties + # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 + output_images = [ + {"url": image_to_data_url(img), "seed": sd, "nsfw": nsfw or False} + for img, sd, nsfw in zip(images, seeds, has_nsfw_concept) + ] return {"images": output_images} diff --git a/runner/app/routes/text_to_image.py b/runner/app/routes/text_to_image.py index 51a28f71..3f52e36d 100644 --- a/runner/app/routes/text_to_image.py +++ b/runner/app/routes/text_to_image.py @@ -67,7 +67,7 @@ async def text_to_image( params.num_images_per_prompt = 1 for seed in seeds: try: - params.seed = [seed] + params.seed = seed imgs, nsfw_check = pipeline(**params.model_dump()) images.extend(imgs) 
has_nsfw_concept.extend(nsfw_check)