Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(runner): add InferenceError to all pipelines #188

Merged
merged 15 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion runner/app/pipelines/audio_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device
from app.pipelines.utils.audio import AudioConverter
from app.utils.errors import InferenceError
from fastapi import File, UploadFile
from huggingface_hub import file_download
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
Expand Down Expand Up @@ -76,7 +77,12 @@ def __call__(self, audio: UploadFile, **kwargs) -> List[File]:
converted_bytes = audio_converter.convert(audio, "mp3")
audio_converter.write_bytes_to_file(converted_bytes, audio)

return self.tm(audio.file.read(), **kwargs)
try:
outputs = self.tm(audio.file.read(), **kwargs)
except Exception as e:
raise InferenceError(original_exception=e)

return outputs

def __str__(self) -> str:
return f"AudioToTextPipeline model_id={self.model_id}"
12 changes: 8 additions & 4 deletions runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
is_lightning_model,
is_turbo_model,
)
from app.utils.errors import InferenceError
from diffusers import (
AutoPipelineForImage2Image,
EulerAncestralDiscreteScheduler,
Expand Down Expand Up @@ -233,14 +234,17 @@ def __call__(
# Default to 8step
kwargs["num_inference_steps"] = 8

output = self.ldm(prompt, image=image, **kwargs)
try:
outputs = self.ldm(prompt, image=image, **kwargs)
except Exception as e:
raise InferenceError(original_exception=e)

if safety_check:
_, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
_, has_nsfw_concept = self._safety_checker.check_nsfw_images(outputs.images)
else:
has_nsfw_concept = [None] * len(output.images)
has_nsfw_concept = [None] * len(outputs.images)

return output.images, has_nsfw_concept
return outputs.images, has_nsfw_concept

def __str__(self) -> str:
return f"ImageToImagePipeline model_id={self.model_id}"
8 changes: 7 additions & 1 deletion runner/app/pipelines/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import SafetyChecker, get_model_dir, get_torch_device
from app.utils.errors import InferenceError
from diffusers import StableVideoDiffusionPipeline
from huggingface_hub import file_download
from PIL import ImageFile
Expand Down Expand Up @@ -135,7 +136,12 @@ def __call__(
else:
has_nsfw_concept = [None]

return self.ldm(image, **kwargs).frames, has_nsfw_concept
try:
outputs = self.ldm(image, **kwargs)
except Exception as e:
raise InferenceError(original_exception=e)

return outputs.frames, has_nsfw_concept

def __str__(self) -> str:
return f"ImageToVideoPipeline model_id={self.model_id}"
2 changes: 1 addition & 1 deletion runner/app/pipelines/optim/sfast.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def compile_model(pipe):
except ImportError:
logger.info("xformers not installed, skip")
try:
import triton # noqa: F401
import triton # noqa: F401

config.enable_triton = True
except ImportError:
Expand Down
4 changes: 2 additions & 2 deletions runner/app/pipelines/segment_anything_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import PIL
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_torch_device, get_model_dir
from app.routes.util import InferenceError
from app.pipelines.utils import get_model_dir, get_torch_device
from app.utils.errors import InferenceError
from PIL import ImageFile
from sam2.sam2_image_predictor import SAM2ImagePredictor

Expand Down
12 changes: 8 additions & 4 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
is_turbo_model,
split_prompt,
)
from app.utils.errors import InferenceError
from diffusers import (
AutoPipelineForText2Image,
EulerDiscreteScheduler,
Expand Down Expand Up @@ -274,14 +275,17 @@ def __call__(
)
kwargs.update(neg_prompts)

output = self.ldm(prompt=prompt, **kwargs)
try:
outputs = self.ldm(prompt=prompt, **kwargs)
except Exception as e:
raise InferenceError(original_exception=e)

if safety_check:
_, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
_, has_nsfw_concept = self._safety_checker.check_nsfw_images(outputs.images)
else:
has_nsfw_concept = [None] * len(output.images)
has_nsfw_concept = [None] * len(outputs.images)

return output.images, has_nsfw_concept
return outputs.images, has_nsfw_concept

def __str__(self) -> str:
return f"TextToImagePipeline model_id={self.model_id}"
14 changes: 10 additions & 4 deletions runner/app/pipelines/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
is_lightning_model,
is_turbo_model,
)
from app.utils.errors import InferenceError
from diffusers import StableDiffusionUpscalePipeline
from huggingface_hub import file_download
from PIL import ImageFile
Expand Down Expand Up @@ -114,14 +115,19 @@ def __call__(
):
del kwargs["num_inference_steps"]

output = self.ldm(prompt, image=image, **kwargs)
try:
outputs = self.ldm(prompt, image=image, **kwargs)
except torch.cuda.OutOfMemoryError as e:
raise e
except Exception as e:
raise InferenceError(original_exception=e)

if safety_check:
_, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
_, has_nsfw_concept = self._safety_checker.check_nsfw_images(outputs.images)
else:
has_nsfw_concept = [None] * len(output.images)
has_nsfw_concept = [None] * len(outputs.images)

return output.images, has_nsfw_concept
return outputs.images, has_nsfw_concept

def __str__(self) -> str:
return f"UpscalePipeline model_id={self.model_id}"
2 changes: 2 additions & 0 deletions runner/app/pipelines/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

from app.pipelines.utils.utils import (
LoraLoader,
LoraLoadingError,
SafetyChecker,
get_model_dir,
get_model_path,
get_torch_device,
is_lightning_model,
is_turbo_model,
is_numeric,
split_prompt,
validate_torch_device,
)
2 changes: 1 addition & 1 deletion runner/app/pipelines/utils/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class AudioConversionError(Exception):
"""Raised when an audio file cannot be converted."""

def __init__(self, message="Audio conversion failed."):
def __init__(self, message="Audio conversion failed"):
self.message = message
super().__init__(self.message)

Expand Down
67 changes: 35 additions & 32 deletions runner/app/routes/audio_to_text.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import logging
import os
from typing import Annotated
from typing import Annotated, Dict, Tuple, Union

import torch
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.pipelines.utils.audio import AudioConversionError
from app.routes.util import HTTPError, TextResponse, file_exceeds_max_size, http_error
from app.routes.utils import (
HTTPError,
TextResponse,
file_exceeds_max_size,
http_error,
handle_pipeline_exception,
)
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
Expand All @@ -14,6 +20,20 @@

logger = logging.getLogger(__name__)

# Pipeline specific error handling configuration.
AUDIO_FORMAT_ERROR_MESSAGE = "Unsupported audio format or malformed file."
PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = {
# Specific error types.
"AudioConversionError": (
AUDIO_FORMAT_ERROR_MESSAGE,
status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
),
"Soundfile is either not in the correct format or is malformed": (
AUDIO_FORMAT_ERROR_MESSAGE,
status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
),
}

RESPONSES = {
status.HTTP_200_OK: {
"content": {
Expand All @@ -27,35 +47,11 @@
status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
status.HTTP_413_REQUEST_ENTITY_TOO_LARGE: {"model": HTTPError},
status.HTTP_415_UNSUPPORTED_MEDIA_TYPE: {"model": HTTPError},
status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
}


def handle_pipeline_error(e: Exception) -> JSONResponse:
"""Handles exceptions raised during audio processing.

Args:
e: The exception raised during audio processing.

Returns:
A JSONResponse with the appropriate error message and status code.
"""
logger.error(f"Audio processing error: {str(e)}") # Log the detailed error
if "Soundfile is either not in the correct format or is malformed" in str(
e
) or isinstance(e, AudioConversionError):
status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE
error_message = "Unsupported audio format or malformed file."
else:
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
error_message = "Internal server error during audio processing."

return JSONResponse(
status_code=status_code,
content=http_error(error_message),
)


@router.post(
"/audio-to-text",
response_model=TextResponse,
Expand Down Expand Up @@ -89,25 +85,32 @@ async def audio_to_text(
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
headers={"WWW-Authenticate": "Bearer"},
content=http_error("Invalid bearer token"),
content=http_error("Invalid bearer token."),
)

if model_id != "" and model_id != pipeline.model_id:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(
f"pipeline configured with {pipeline.model_id} but called with "
f"{model_id}"
f"{model_id}."
),
)

if file_exceeds_max_size(audio, 50 * 1024 * 1024):
return JSONResponse(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
content=http_error("File size exceeds limit"),
content=http_error("File size exceeds limit."),
)

try:
return pipeline(audio=audio)
except Exception as e:
return handle_pipeline_error(e)
if isinstance(e, torch.cuda.OutOfMemoryError):
torch.cuda.empty_cache()
logger.error(f"AudioToText pipeline error: {e}")
return handle_pipeline_exception(
e,
default_error_message="Audio-to-text pipeline error.",
custom_error_config=PIPELINE_ERROR_CONFIG,
)
44 changes: 26 additions & 18 deletions runner/app/routes/image_to_image.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import logging
import os
import random
from typing import Annotated
from typing import Annotated, Dict, Tuple, Union

import torch
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.pipelines.utils.utils import LoraLoadingError
from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url
from app.routes.utils import (
HTTPError,
ImageResponse,
http_error,
image_to_data_url,
handle_pipeline_exception,
)
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
Expand All @@ -20,6 +25,15 @@
logger = logging.getLogger(__name__)


# Pipeline specific error handling configuration.
PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = {
# Specific error types.
"OutOfMemoryError": (
"Out of memory error. Try reducing input image resolution.",
status.HTTP_500_INTERNAL_SERVER_ERROR,
)
}

RESPONSES = {
status.HTTP_200_OK: {
"content": {
Expand Down Expand Up @@ -144,15 +158,15 @@ async def image_to_image(
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
headers={"WWW-Authenticate": "Bearer"},
content=http_error("Invalid bearer token"),
content=http_error("Invalid bearer token."),
)

if model_id != "" and model_id != pipeline.model_id:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(
f"pipeline configured with {pipeline.model_id} but called with "
f"{model_id}"
f"{model_id}."
),
)

Expand Down Expand Up @@ -180,23 +194,17 @@ async def image_to_image(
num_images_per_prompt=1,
num_inference_steps=num_inference_steps,
)
images.extend(imgs)
has_nsfw_concept.extend(nsfw_checks)
except LoraLoadingError as e:
logger.error(f"ImageToImagePipeline error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(str(e)),
)
except Exception as e:
if isinstance(e, torch.cuda.OutOfMemoryError):
torch.cuda.empty_cache()
logger.error(f"ImageToImagePipeline error: {e}")
logger.exception(e)
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=http_error("ImageToImagePipeline error"),
logger.error(f"ImageToImagePipeline pipeline error: {e}")
return handle_pipeline_exception(
e,
default_error_message="Image-to-image pipeline error.",
custom_error_config=PIPELINE_ERROR_CONFIG,
)
images.extend(imgs)
has_nsfw_concept.extend(nsfw_checks)

# TODO: Return None once Go codegen tool supports optional properties
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
Expand Down
Loading