Skip to content

Commit

Permalink
refactor: improve out of memory error handling
Browse files Browse the repository at this point in the history
This commit improves out-of-memory error handling by checking for the
native `torch.cuda.OutOfMemoryError` instead of matching on the error message text or a custom exception class.
  • Loading branch information
rickstaa committed Oct 14, 2024
1 parent c7b759a commit 53d76eb
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 19 deletions.
2 changes: 1 addition & 1 deletion runner/app/pipelines/segment_anything_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import PIL
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device
from app.pipelines.utils import get_torch_device, get_model_dir
from app.utils.errors import InferenceError
from PIL import ImageFile
from sam2.sam2_image_predictor import SAM2ImagePredictor
Expand Down
2 changes: 2 additions & 0 deletions runner/app/pipelines/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def __call__(

try:
outputs = self.ldm(prompt, image=image, **kwargs)
except torch.cuda.OutOfMemoryError as e:
raise e
except Exception as e:
raise InferenceError(original_exception=e)

Expand Down
6 changes: 4 additions & 2 deletions runner/app/routes/audio_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
from typing import Annotated

import torch
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.pipelines.utils.audio import AudioConversionError
Expand Down Expand Up @@ -42,15 +43,15 @@ def handle_pipeline_error(e: Exception) -> JSONResponse:
Returns:
A JSONResponse with the appropriate error message and status code.
"""
logger.error(f"AudioToText pipeline error: {str(e)}") # Log the detailed error
if "Soundfile is either not in the correct format or is malformed" in str(
e
) or isinstance(e, AudioConversionError):
status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE
error_message = "Unsupported audio format or malformed file."
elif "CUDA out of memory" in str(e) or isinstance(e, OutOfMemoryError):
elif isinstance(e, torch.cuda.OutOfMemoryError):
status_code = status.HTTP_400_BAD_REQUEST
error_message = "Out of memory error."
torch.cuda.empty_cache()
elif isinstance(e, InferenceError):
status_code = status.HTTP_400_BAD_REQUEST
error_message = str(e)
Expand Down Expand Up @@ -118,4 +119,5 @@ async def audio_to_text(
try:
return pipeline(audio=audio)
except Exception as e:
logger.error(f"AudioToText pipeline error: {str(e)}")
return handle_pipeline_error(e)
7 changes: 2 additions & 5 deletions runner/app/routes/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,7 @@ def handle_pipeline_error(e: Exception) -> JSONResponse:
Returns:
A JSONResponse with the appropriate error message and status code.
"""
logger.error(
f"ImageToImagePipeline pipeline error: {str(e)}"
) # Log the detailed error
logger.exception(e) # TODO: Check if needed.
if "CUDA out of memory" in str(e) or isinstance(e, OutOfMemoryError) or isinstance(torch.cuda.OutOfMemoryError): # TODO: simplify condition.
if isinstance(e, torch.cuda.OutOfMemoryError):
status_code = status.HTTP_400_BAD_REQUEST
error_message = "Out of memory error. Try reducing input image resolution."
torch.cuda.empty_cache()
Expand Down Expand Up @@ -215,6 +211,7 @@ async def image_to_image(
num_inference_steps=num_inference_steps,
)
except Exception as e:
logger.error(f"ImageToImagePipeline pipeline error: {str(e)}")
return handle_pipeline_error(e)
images.extend(imgs)
has_nsfw_concept.extend(nsfw_checks)
Expand Down
8 changes: 5 additions & 3 deletions runner/app/routes/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from typing import Annotated

from app.dependencies import get_pipeline
import torch
from app.pipelines.base import Pipeline
from app.routes.utils import HTTPError, VideoResponse, http_error, image_to_data_url
from app.utils.errors import InferenceError, OutOfMemoryError
from app.utils.errors import InferenceError
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
Expand All @@ -28,12 +29,12 @@ def handle_pipeline_error(e: Exception) -> JSONResponse:
Returns:
A JSONResponse with the appropriate error message and status code.
"""
logger.error(f"ImageToVideo pipeline error: {str(e)}") # Log the detailed error
if "CUDA out of memory" in str(e) or isinstance(e, OutOfMemoryError):
if isinstance(e, torch.cuda.OutOfMemoryError):
status_code = status.HTTP_400_BAD_REQUEST
error_message = (
"Out of memory error. Try reducing input or output video resolution."
)
torch.cuda.empty_cache()
elif isinstance(e, InferenceError):
status_code = status.HTTP_400_BAD_REQUEST
error_message = str(e)
Expand Down Expand Up @@ -181,6 +182,7 @@ async def image_to_video(
seed=seed,
)
except Exception as e:
logger.error(f"ImageToVideo pipeline error: {str(e)}")
return handle_pipeline_error(e)

output_frames = []
Expand Down
6 changes: 4 additions & 2 deletions runner/app/routes/segment_anything_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Annotated

import numpy as np
import torch
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.utils import HTTPError, MasksResponse, http_error, json_str_to_np_array
Expand All @@ -28,10 +29,10 @@ def handle_pipeline_error(e: Exception) -> JSONResponse:
Returns:
A JSONResponse with the appropriate error message and status code.
"""
logger.error(f"SegmentAnything2 pipeline error: {str(e)}") # Log the detailed error
if "CUDA out of memory" in str(e) or isinstance(e, OutOfMemoryError):
if isinstance(e, torch.cuda.OutOfMemoryError):
status_code = status.HTTP_400_BAD_REQUEST
error_message = "Out of memory error. Try reducing input image resolution."
torch.cuda.empty_cache()
elif isinstance(e, InferenceError):
status_code = status.HTTP_400_BAD_REQUEST
error_message = str(e)
Expand Down Expand Up @@ -192,6 +193,7 @@ async def segment_anything_2(
normalize_coords=normalize_coords,
)
except Exception as e:
logger.error(f"SegmentAnything2 pipeline error: {str(e)}")
return handle_pipeline_error(e)

# Return masks sorted by descending score as string.
Expand Down
6 changes: 3 additions & 3 deletions runner/app/routes/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url
from app.utils.errors import InferenceError, OutOfMemoryError
from app.utils.errors import InferenceError
from app.pipelines.utils.utils import LoraLoadingError
from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse
Expand All @@ -28,8 +28,7 @@ def handle_pipeline_error(e: Exception) -> JSONResponse:
Returns:
A JSONResponse with the appropriate error message and status code.
"""
logger.error(f"TextToImage pipeline error: {str(e)}") # Log the detailed error
if "CUDA out of memory" in str(e) or isinstance(e, OutOfMemoryError) or isinstance(e, torch.cuda.OutOfMemoryError): # TODO: Simplify.
if isinstance(e, torch.cuda.OutOfMemoryError):
status_code = status.HTTP_400_BAD_REQUEST
error_message = "Out of memory error. Try reducing output image resolution."
torch.cuda.empty_cache()
Expand Down Expand Up @@ -204,6 +203,7 @@ async def text_to_image(
try:
imgs, nsfw_check = pipeline(**kwargs)
except Exception as e:
logger.error(f"TextToImage pipeline error: {str(e)}")
return handle_pipeline_error(e)
images.extend(imgs)
has_nsfw_concept.extend(nsfw_check)
Expand Down
8 changes: 5 additions & 3 deletions runner/app/routes/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import random
from typing import Annotated

from app.utils.errors import InferenceError, OutOfMemoryError
import torch
from app.utils.errors import InferenceError
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url
Expand All @@ -28,10 +29,10 @@ def handle_pipeline_error(e: Exception) -> JSONResponse:
Returns:
A JSONResponse with the appropriate error message and status code.
"""
logger.error(f"TextToImage pipeline error: {str(e)}") # Log the detailed error
if "CUDA out of memory" in str(e) or isinstance(e, OutOfMemoryError):
if isinstance(e, torch.cuda.OutOfMemoryError):
status_code = status.HTTP_400_BAD_REQUEST
error_message = "Out of memory error. Try reducing input image resolution."
torch.cuda.empty_cache()
elif isinstance(e, InferenceError):
status_code = status.HTTP_400_BAD_REQUEST
error_message = str(e)
Expand Down Expand Up @@ -145,6 +146,7 @@ async def upscale(
seed=seed,
)
except Exception as e:
logger.error(f"TextToImage pipeline error: {str(e)}")
return handle_pipeline_error(e)

seeds = [seed]
Expand Down

0 comments on commit 53d76eb

Please sign in to comment.