diff --git a/api/onnx_web/chain/pipeline.py b/api/onnx_web/chain/pipeline.py index 60554071e..221c7cf47 100644 --- a/api/onnx_web/chain/pipeline.py +++ b/api/onnx_web/chain/pipeline.py @@ -5,7 +5,7 @@ from PIL import Image -from ..errors import RetryException +from ..errors import CancelledException, RetryException from ..output import save_image from ..params import ImageParams, Size, StageParams from ..server import ServerContext @@ -146,7 +146,7 @@ def __call__( kwargs.pop("params") # the stage must be split and tiled if any image is larger than the selected/max tile size - must_tile = "mask" in stage_kwargs or any( + must_tile = has_mask(stage_kwargs) or any( [ needs_tile( stage_pipe.max_tile, @@ -192,6 +192,10 @@ def stage_tile( save_image(server, f"last-tile-{j}.png", image) return tile_result + except CancelledException as err: + worker.retries = 0 + logger.exception("job was cancelled while tiling") + raise err except Exception: worker.retries = worker.retries - 1 logger.exception( @@ -234,6 +238,10 @@ def stage_tile( # does not like, so it throws stage_sources = stage_result break + except CancelledException as err: + worker.retries = 0 + logger.exception("job was cancelled during stage") + raise err except Exception: worker.retries = worker.retries - 1 logger.exception( @@ -264,3 +272,9 @@ def stage_tile( len(stage_sources), ) return stage_sources + + +MASK_KEYS = ["mask", "stage_mask", "tile_mask"] + +def has_mask(args: List[str]) -> bool: + return any([key in args for key in MASK_KEYS])