Skip to content

Commit

Permalink
feat(pipeline): add upscale image support to AI pipelines (#96)
Browse files Browse the repository at this point in the history
* add upscale image support using stabilityai/stable-diffusion-x4-upscaler model

* Add host port for upscale pipeline

* added model download to the dl_checkpoints.sh

* fix: fix some small bugs and improve formatting

This commit fixes some small bugs and improves the code formatting so
that it is more in line with the other pipelines.

* fix: ensure upscaling OpenAPI spec gets created

This commit ensures that the `gen_openapi.py` file also creates the
OpenAPI spec for the upscaling route. It also updates the Golang client
bindings.

* fix(worker): fix incorrect automatic client types

This commit ensures that the right client request and response types
are used.

---------

Co-authored-by: Elite Encoder <[email protected]>
Co-authored-by: Rick Staa <[email protected]>
  • Loading branch information
3 people authored Jun 10, 2024
1 parent 6bb526c commit 45687a7
Show file tree
Hide file tree
Showing 12 changed files with 599 additions and 35 deletions.
7 changes: 5 additions & 2 deletions runner/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
case "frame-interpolation":
raise NotImplementedError("frame-interpolation pipeline not implemented")
case "upscale":
raise NotImplementedError("upscale pipeline not implemented")
from app.pipelines.upscale import UpscalePipeline
return UpscalePipeline(model_id)
case _:
raise EnvironmentError(
f"{pipeline} is not a valid pipeline for model {model_id}"
Expand All @@ -69,7 +70,9 @@ def load_route(pipeline: str) -> any:
case "frame-interpolation":
raise NotImplementedError("frame-interpolation pipeline not implemented")
case "upscale":
raise NotImplementedError("upscale pipeline not implemented")
from app.routes import upscale

return upscale.router
case _:
raise EnvironmentError(f"{pipeline} is not a valid pipeline")

Expand Down
120 changes: 119 additions & 1 deletion runner/app/pipelines/upscale.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,123 @@
from app.pipelines.base import Pipeline
from app.pipelines.util import get_torch_device, get_model_dir, SafetyChecker, is_lightning_model, is_turbo_model

from diffusers import (
StableDiffusionUpscalePipeline
)
from safetensors.torch import load_file
from huggingface_hub import file_download, hf_hub_download
import torch
import PIL
from typing import List, Tuple, Optional
import logging
import os

from PIL import ImageFile
from PIL import Image
from io import BytesIO
import torch

ImageFile.LOAD_TRUNCATED_IMAGES = True

logger = logging.getLogger(__name__)


class UpscalePipeline(Pipeline):
pass
def __init__(self, model_id: str):
kwargs = {"cache_dir": get_model_dir()}

torch_device = get_torch_device()
folder_name = file_download.repo_folder_name(
repo_id=model_id, repo_type="model"
)
folder_path = os.path.join(get_model_dir(), folder_name)
has_fp16_variant = any(
".fp16.safetensors" in fname
for _, _, files in os.walk(folder_path)
for fname in files
)
if torch_device != "cpu" and has_fp16_variant:
logger.info("UpscalePipeline loading fp16 variant for %s", model_id)

kwargs["torch_dtype"] = torch.float16
kwargs["variant"] = "fp16"

self.model_id = model_id
self.ldm = StableDiffusionUpscalePipeline.from_pretrained(
model_id, **kwargs
).to(torch_device)

sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
deepcache_enabled = os.getenv("DEEPCACHE", "").strip().lower() == "true"
if sfast_enabled and deepcache_enabled:
logger.warning(
"Both 'SFAST' and 'DEEPCACHE' are enabled. This is not recommended "
"as it may lead to suboptimal performance. Please disable one of them."
)

if sfast_enabled:
logger.info(
"UpscalePipeline will be dynamically compiled with stable-fast "
"for %s",
model_id,
)
from app.pipelines.optim.sfast import compile_model

self.ldm = compile_model(self.ldm)

# Warm-up the pipeline.
# TODO: Not yet supported for UpscalePipeline.
if os.getenv("SFAST_WARMUP", "true").lower() == "true":
logger.warning(
"The 'SFAST_WARMUP' flag is not yet supported for the "
"UpscalePipeline and will be ignored. As a result the first "
"call may be slow if 'SFAST' is enabled."
)

if deepcache_enabled and not (
is_lightning_model(model_id) or is_turbo_model(model_id)
):
logger.info(
"UpscalePipeline will be optimized with DeepCache for %s",
model_id,
)
from app.pipelines.optim.deepcache import enable_deepcache

self.ldm = enable_deepcache(self.ldm)
elif deepcache_enabled:
logger.warning(
"DeepCache is not supported for Lightning or Turbo models. "
"TextToImagePipeline will NOT be optimized with DeepCache for %s",
model_id,
)

safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
self._safety_checker = SafetyChecker(device=safety_checker_device)

def __call__(
self, prompt: str, image: PIL.Image, **kwargs
) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
seed = kwargs.pop("seed", None)
safety_check = kwargs.pop("safety_check", True)

if seed is not None:
if isinstance(seed, int):
kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(
seed
)
elif isinstance(seed, list):
kwargs["generator"] = [
torch.Generator(get_torch_device()).manual_seed(s) for s in seed
]

output = self.ldm(prompt, image=image, **kwargs)

if safety_check:
_, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
else:
has_nsfw_concept = [None] * len(output.images)

return output.images, has_nsfw_concept

def __str__(self) -> str:
return f"UpscalePipeline model_id={self.model_id}"
87 changes: 87 additions & 0 deletions runner/app/routes/upscale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from fastapi import Depends, APIRouter, UploadFile, File, Form
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from app.pipelines.base import Pipeline
from app.dependencies import get_pipeline
from app.routes.util import image_to_data_url, ImageResponse, HTTPError, http_error
from PIL import Image
from typing import Annotated
import logging
import random
import os

from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

router = APIRouter()

logger = logging.getLogger(__name__)

responses = {400: {"model": HTTPError}, 500: {"model": HTTPError}}


# TODO: Make model_id and other None properties optional once Go codegen tool supports
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
@router.post("/upscale", response_model=ImageResponse, responses=responses)
@router.post(
"/upscale/",
response_model=ImageResponse,
responses=responses,
include_in_schema=False,
)
async def upscale(
prompt: Annotated[str, Form()],
image: Annotated[UploadFile, File()],
model_id: Annotated[str, Form()] = "",
safety_check: Annotated[bool, Form()] = True,
seed: Annotated[int, Form()] = None,
pipeline: Pipeline = Depends(get_pipeline),
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
auth_token = os.environ.get("AUTH_TOKEN")
if auth_token:
if not token or token.credentials != auth_token:
return JSONResponse(
status_code=401,
headers={"WWW-Authenticate": "Bearer"},
content=http_error("Invalid bearer token"),
)

if model_id != "" and model_id != pipeline.model_id:
return JSONResponse(
status_code=400,
content=http_error(
f"pipeline configured with {pipeline.model_id} but called with "
f"{model_id}"
),
)

seed = seed or random.randint(0, 2**32 - 1)

image = Image.open(image.file).convert("RGB")

try:
images, has_nsfw_concept = pipeline(
prompt=prompt,
image=image,
safety_check=safety_check,
seed=seed,
)
except Exception as e:
logger.error(f"UpscalePipeline error: {e}")
logger.exception(e)
return JSONResponse(
status_code=500, content=http_error("UpscalePipeline error")
)

seeds = [seed]

# TODO: Return None once Go codegen tool supports optional properties
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
output_images = [
{"url": image_to_data_url(img), "seed": sd, "nsfw": nsfw or False}
for img, sd, nsfw in zip(images, seeds, has_nsfw_concept)
]

return {"images": output_images}
3 changes: 3 additions & 0 deletions runner/dl_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ function download_alpha_models() {
huggingface-cli download ByteDance/SDXL-Lightning --include "*unet.safetensors" --cache-dir models
huggingface-cli download timbrooks/instruct-pix2pix --include "*.fp16.safetensors" "*.json" "*.txt" --cache-dir models

# Download upscale models
huggingface-cli download stabilityai/stable-diffusion-x4-upscaler --include "*.fp16.safetensors" --cache-dir models

printf "\nDownloading token-gated models...\n"

# Download image-to-video models (token-gated).
Expand Down
3 changes: 2 additions & 1 deletion runner/gen_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import yaml
from app.main import app, use_route_names_as_operation_ids
from app.routes import health, image_to_image, image_to_video, text_to_image
from app.routes import health, image_to_image, image_to_video, text_to_image, upscale
from fastapi.openapi.utils import get_openapi

# Specify Endpoints for OpenAPI schema generation.
Expand Down Expand Up @@ -65,6 +65,7 @@ def write_openapi(fname, entrypoint="runner"):
app.include_router(text_to_image.router)
app.include_router(image_to_image.router)
app.include_router(image_to_video.router)
app.include_router(upscale.router)

use_route_names_as_operation_ids(app)

Expand Down
96 changes: 96 additions & 0 deletions runner/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,69 @@
}
]
}
},
"/upscale": {
"post": {
"summary": "Upscale",
"operationId": "upscale",
"requestBody": {
"content": {
"multipart/form-data": {
"schema": {
"$ref": "#/components/schemas/Body_upscale_upscale_post"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ImageResponse"
}
}
}
},
"400": {
"description": "Bad Request",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPError"
}
}
}
},
"500": {
"description": "Internal Server Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPError"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
}
}
},
"security": [
{
"HTTPBearer": []
}
]
}
}
},
"components": {
Expand Down Expand Up @@ -346,6 +409,39 @@
],
"title": "Body_image_to_video_image_to_video_post"
},
"Body_upscale_upscale_post": {
"properties": {
"prompt": {
"type": "string",
"title": "Prompt"
},
"image": {
"type": "string",
"format": "binary",
"title": "Image"
},
"model_id": {
"type": "string",
"title": "Model Id",
"default": ""
},
"safety_check": {
"type": "boolean",
"title": "Safety Check",
"default": true
},
"seed": {
"type": "integer",
"title": "Seed"
}
},
"type": "object",
"required": [
"prompt",
"image"
],
"title": "Body_upscale_upscale_post"
},
"HTTPError": {
"properties": {
"detail": {
Expand Down
3 changes: 2 additions & 1 deletion runner/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ xformers==0.0.23
triton>=2.1.0
peft==0.11.1
deepcache==0.1.1
safetensors==0.4.3
safetensors==0.4.3
scipy==1.13.0
9 changes: 0 additions & 9 deletions runner/test.txt

This file was deleted.

1 change: 1 addition & 0 deletions worker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ var containerHostPorts = map[string]string{
"text-to-image": "8000",
"image-to-image": "8001",
"image-to-video": "8002",
"upscale": "8003",
}

type DockerManager struct {
Expand Down
Loading

0 comments on commit 45687a7

Please sign in to comment.