-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(pipeline): add upscale image support to AI pipelines (#96)
* add upscale image support using stabilityai/stable-diffusion-x4-upscaler model * Add host port for upscale pipeline * added model download to the dl_checkpoints.sh * fix: fix some small bugs and improve formatting This commit fixes some small bugs and improves the code formatting so that it is more in line with the other pipelines. * fix: ensure upscaling OpenAPI spec gets created This commit ensures that the `gen_openapi.py` file also creates the OpenAPI spec for the upscaling route. It also updates the Golang client bindings. * fix(worker): fix incorrect automatic client types This commit ensures that the right client request and response types are used. --------- Co-authored-by: Elite Encoder <[email protected]> Co-authored-by: Rick Staa <[email protected]>
- Loading branch information
1 parent
6bb526c
commit 45687a7
Showing
12 changed files
with
599 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,123 @@ | ||
from app.pipelines.base import Pipeline | ||
from app.pipelines.util import get_torch_device, get_model_dir, SafetyChecker, is_lightning_model, is_turbo_model | ||
|
||
from diffusers import ( | ||
StableDiffusionUpscalePipeline | ||
) | ||
from safetensors.torch import load_file | ||
from huggingface_hub import file_download, hf_hub_download | ||
import torch | ||
import PIL | ||
from typing import List, Tuple, Optional | ||
import logging | ||
import os | ||
|
||
from PIL import ImageFile | ||
from PIL import Image | ||
from io import BytesIO | ||
import torch | ||
|
||
ImageFile.LOAD_TRUNCATED_IMAGES = True | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class UpscalePipeline(Pipeline):
    """Image upscaling pipeline wrapping ``StableDiffusionUpscalePipeline``.

    The wrapped diffusers pipeline is stored on ``self.ldm``. Optional
    optimizations are toggled through the ``SFAST`` and ``DEEPCACHE``
    environment variables, and generated images can be screened with a
    ``SafetyChecker`` whose device is read from ``SAFETY_CHECKER_DEVICE``.
    """

    def __init__(self, model_id: str):
        """Load the upscaler model and apply the configured optimizations.

        Args:
            model_id: Hugging Face model repository id to load.
        """
        kwargs = {"cache_dir": get_model_dir()}

        torch_device = get_torch_device()
        folder_name = file_download.repo_folder_name(
            repo_id=model_id, repo_type="model"
        )
        folder_path = os.path.join(get_model_dir(), folder_name)
        # Look through the locally cached model folder for any fp16 weights.
        has_fp16_variant = any(
            ".fp16.safetensors" in fname
            for _, _, files in os.walk(folder_path)
            for fname in files
        )
        if torch_device != "cpu" and has_fp16_variant:
            logger.info("UpscalePipeline loading fp16 variant for %s", model_id)

            kwargs["torch_dtype"] = torch.float16
            kwargs["variant"] = "fp16"

        self.model_id = model_id
        self.ldm = StableDiffusionUpscalePipeline.from_pretrained(
            model_id, **kwargs
        ).to(torch_device)

        sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
        deepcache_enabled = os.getenv("DEEPCACHE", "").strip().lower() == "true"
        if sfast_enabled and deepcache_enabled:
            logger.warning(
                "Both 'SFAST' and 'DEEPCACHE' are enabled. This is not recommended "
                "as it may lead to suboptimal performance. Please disable one of them."
            )

        if sfast_enabled:
            logger.info(
                "UpscalePipeline will be dynamically compiled with stable-fast "
                "for %s",
                model_id,
            )
            # Imported lazily so the optimization module is only needed when
            # the feature is switched on.
            from app.pipelines.optim.sfast import compile_model

            self.ldm = compile_model(self.ldm)

            # Warm-up the pipeline.
            # TODO: Not yet supported for UpscalePipeline.
            if os.getenv("SFAST_WARMUP", "true").lower() == "true":
                logger.warning(
                    "The 'SFAST_WARMUP' flag is not yet supported for the "
                    "UpscalePipeline and will be ignored. As a result the first "
                    "call may be slow if 'SFAST' is enabled."
                )

        if deepcache_enabled and not (
            is_lightning_model(model_id) or is_turbo_model(model_id)
        ):
            logger.info(
                "UpscalePipeline will be optimized with DeepCache for %s",
                model_id,
            )
            from app.pipelines.optim.deepcache import enable_deepcache

            self.ldm = enable_deepcache(self.ldm)
        elif deepcache_enabled:
            logger.warning(
                "DeepCache is not supported for Lightning or Turbo models. "
                "UpscalePipeline will NOT be optimized with DeepCache for %s",
                model_id,
            )

        safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
        self._safety_checker = SafetyChecker(device=safety_checker_device)

    def __call__(
        self, prompt: str, image: PIL.Image.Image, **kwargs
    ) -> Tuple[List[PIL.Image.Image], List[Optional[bool]]]:
        """Upscale ``image`` guided by ``prompt``.

        Args:
            prompt: Text prompt passed to the wrapped pipeline.
            image: Input image passed to the wrapped pipeline.
            **kwargs: Extra arguments forwarded to the wrapped pipeline.
                ``seed`` (an int or a list of ints) is converted into a
                ``generator`` argument; a seed of any other type is ignored.
                ``safety_check`` (default ``True``) controls NSFW screening.
                Neither key is forwarded as-is.

        Returns:
            A tuple of the generated images and one NSFW flag per image.
            The flags are ``None`` when ``safety_check`` is disabled.
        """
        seed = kwargs.pop("seed", None)
        safety_check = kwargs.pop("safety_check", True)

        if seed is not None:
            if isinstance(seed, int):
                kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(
                    seed
                )
            elif isinstance(seed, list):
                kwargs["generator"] = [
                    torch.Generator(get_torch_device()).manual_seed(s) for s in seed
                ]

        output = self.ldm(prompt, image=image, **kwargs)

        if safety_check:
            _, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
        else:
            has_nsfw_concept = [None] * len(output.images)

        return output.images, has_nsfw_concept

    def __str__(self) -> str:
        """Return a short description that includes the loaded model id."""
        return f"UpscalePipeline model_id={self.model_id}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from fastapi import Depends, APIRouter, UploadFile, File, Form | ||
from fastapi.responses import JSONResponse | ||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | ||
from app.pipelines.base import Pipeline | ||
from app.dependencies import get_pipeline | ||
from app.routes.util import image_to_data_url, ImageResponse, HTTPError, http_error | ||
from PIL import Image | ||
from typing import Annotated | ||
import logging | ||
import random | ||
import os | ||
|
||
from PIL import ImageFile | ||
|
||
ImageFile.LOAD_TRUNCATED_IMAGES = True | ||
|
||
router = APIRouter() | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
responses = {400: {"model": HTTPError}, 500: {"model": HTTPError}} | ||
|
||
|
||
# TODO: Make model_id and other None properties optional once Go codegen tool supports | ||
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 | ||
@router.post("/upscale", response_model=ImageResponse, responses=responses)
@router.post(
    "/upscale/",
    response_model=ImageResponse,
    responses=responses,
    include_in_schema=False,
)
async def upscale(
    prompt: Annotated[str, Form()],
    image: Annotated[UploadFile, File()],
    model_id: Annotated[str, Form()] = "",
    safety_check: Annotated[bool, Form()] = True,
    seed: Annotated[int, Form()] = None,
    pipeline: Pipeline = Depends(get_pipeline),
    token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
    """Upscale an uploaded image with the configured pipeline.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        image: Uploaded image file to upscale.
        model_id: Optional model id; when non-empty it must match the model
            the pipeline was configured with.
        safety_check: Whether the pipeline should run its NSFW check.
        seed: Seed for generation. A random seed is drawn when omitted.
        pipeline: Pipeline instance provided by ``get_pipeline``.
        token: Bearer credentials, checked only when ``AUTH_TOKEN`` is set.

    Returns:
        A dict with an ``images`` list (data URL, seed and NSFW flag per
        image) on success, or a ``JSONResponse`` carrying an error payload
        with status 400, 401 or 500.
    """
    auth_token = os.environ.get("AUTH_TOKEN")
    if auth_token:
        if not token or token.credentials != auth_token:
            return JSONResponse(
                status_code=401,
                headers={"WWW-Authenticate": "Bearer"},
                content=http_error("Invalid bearer token"),
            )

    if model_id != "" and model_id != pipeline.model_id:
        return JSONResponse(
            status_code=400,
            content=http_error(
                f"pipeline configured with {pipeline.model_id} but called with "
                f"{model_id}"
            ),
        )

    # Compare against None explicitly: `seed or ...` would silently replace a
    # caller-supplied seed of 0 with a random value.
    if seed is None:
        seed = random.randint(0, 2**32 - 1)

    # An unreadable upload is a client error, so report it as a 400 instead of
    # letting the exception surface as an unhandled server error.
    try:
        input_image = Image.open(image.file).convert("RGB")
    except (OSError, ValueError) as e:
        logger.error("UpscalePipeline invalid input image: %s", e)
        return JSONResponse(
            status_code=400, content=http_error("Invalid input image")
        )

    try:
        images, has_nsfw_concept = pipeline(
            prompt=prompt,
            image=input_image,
            safety_check=safety_check,
            seed=seed,
        )
    except Exception as e:
        logger.error("UpscalePipeline error: %s", e)
        logger.exception(e)
        return JSONResponse(
            status_code=500, content=http_error("UpscalePipeline error")
        )

    seeds = [seed]

    # TODO: Return None once Go codegen tool supports optional properties
    # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
    output_images = [
        {"url": image_to_data_url(img), "seed": sd, "nsfw": nsfw or False}
        for img, sd, nsfw in zip(images, seeds, has_nsfw_concept)
    ]

    return {"images": output_images}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.