Skip to content

Commit

Permalink
fix(runner): update img2img to do sequential processing of batch requ…
Browse files Browse the repository at this point in the history
…est (#95)

* update img2img to do sequential processing of batch request

* refactor(runner): improve consistency between I2I and T2I pipelines

This commit enhances the consistency between the I2I and T2I pipelines,
making them easier to compare.

---------

Co-authored-by: Rick Staa <[email protected]>
  • Loading branch information
ad-astra-video and rickstaa authored Jun 8, 2024
1 parent 5c08bf4 commit 0f2ead8
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 44 deletions.
77 changes: 34 additions & 43 deletions runner/app/routes/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,51 +62,42 @@ async def image_to_image(
),
)

if seed is None:
seed = random.randint(0, 2**32 - 1)
if num_images_per_prompt > 1:
seed = [
i for i in range(seed, seed + num_images_per_prompt)
]
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
seeds = [seed + i for i in range(num_images_per_prompt)]

img = Image.open(image.file).convert("RGB")
# If a list of seeds/generators is passed, diffusers wants a list of images
# https://github.com/huggingface/diffusers/blob/17808a091e2d5615c2ed8a63d7ae6f2baea11e1e/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L715
if isinstance(seed, list):
image = [img] * num_images_per_prompt
else:
image = img
image = Image.open(image.file).convert("RGB")

try:
images, has_nsfw_concept = pipeline(
prompt=prompt,
image=image,
strength=strength,
guidance_scale=guidance_scale,
image_guidance_scale=image_guidance_scale,
negative_prompt=negative_prompt,
safety_check=safety_check,
seed=seed,
num_images_per_prompt=num_images_per_prompt,
)
except Exception as e:
logger.error(f"ImageToImagePipeline error: {e}")
logger.exception(e)
return JSONResponse(
status_code=500, content=http_error("ImageToImagePipeline error")
)

seeds = seed
if not isinstance(seeds, list):
seeds = [seeds]
# TODO: Process one image at a time to avoid CUDA OEM errors. Can be removed again
# once LIV-243 and LIV-379 are resolved.
images = []
has_nsfw_concept = []
for seed in seeds:
try:
imgs, nsfw_checks = pipeline(
prompt=prompt,
image=image,
strength=strength,
guidance_scale=guidance_scale,
image_guidance_scale=image_guidance_scale,
negative_prompt=negative_prompt,
safety_check=safety_check,
seed=seed,
num_images_per_prompt=1,
)
images.extend(imgs)
has_nsfw_concept.extend(nsfw_checks)
except Exception as e:
logger.error(f"ImageToImagePipeline error: {e}")
logger.exception(e)
return JSONResponse(
status_code=500, content=http_error("ImageToImagePipeline error")
)

output_images = []
for img, sd, is_nsfw in zip(images, seeds, has_nsfw_concept):
# TODO: Return None once Go codegen tool supports optional properties
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
is_nsfw = is_nsfw or False
output_images.append(
{"url": image_to_data_url(img), "seed": sd, "nsfw": is_nsfw}
)
# TODO: Return None once Go codegen tool supports optional properties
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
output_images = [
{"url": image_to_data_url(img), "seed": sd, "nsfw": nsfw or False}
for img, sd, nsfw in zip(images, seeds, has_nsfw_concept)
]

return {"images": output_images}
2 changes: 1 addition & 1 deletion runner/app/routes/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def text_to_image(
params.num_images_per_prompt = 1
for seed in seeds:
try:
params.seed = [seed]
params.seed = seed
imgs, nsfw_check = pipeline(**params.model_dump())
images.extend(imgs)
has_nsfw_concept.extend(nsfw_check)
Expand Down

0 comments on commit 0f2ead8

Please sign in to comment.