-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(api): add chain pipeline stage result type
- Loading branch information
Showing
31 changed files
with
434 additions
and
385 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,49 +1,2 @@ | ||
from .base import ChainPipeline, PipelineStage, StageParams | ||
from .blend_denoise import BlendDenoiseStage | ||
from .blend_img2img import BlendImg2ImgStage | ||
from .blend_grid import BlendGridStage | ||
from .blend_linear import BlendLinearStage | ||
from .blend_mask import BlendMaskStage | ||
from .correct_codeformer import CorrectCodeformerStage | ||
from .correct_gfpgan import CorrectGFPGANStage | ||
from .persist_disk import PersistDiskStage | ||
from .persist_s3 import PersistS3Stage | ||
from .reduce_crop import ReduceCropStage | ||
from .reduce_thumbnail import ReduceThumbnailStage | ||
from .source_noise import SourceNoiseStage | ||
from .source_s3 import SourceS3Stage | ||
from .source_txt2img import SourceTxt2ImgStage | ||
from .source_url import SourceURLStage | ||
from .upscale_bsrgan import UpscaleBSRGANStage | ||
from .upscale_highres import UpscaleHighresStage | ||
from .upscale_outpaint import UpscaleOutpaintStage | ||
from .upscale_resrgan import UpscaleRealESRGANStage | ||
from .upscale_simple import UpscaleSimpleStage | ||
from .upscale_stable_diffusion import UpscaleStableDiffusionStage | ||
from .upscale_swinir import UpscaleSwinIRStage | ||
|
||
CHAIN_STAGES = { | ||
"blend-denoise": BlendDenoiseStage, | ||
"blend-img2img": BlendImg2ImgStage, | ||
"blend-inpaint": UpscaleOutpaintStage, | ||
"blend-grid": BlendGridStage, | ||
"blend-linear": BlendLinearStage, | ||
"blend-mask": BlendMaskStage, | ||
"correct-codeformer": CorrectCodeformerStage, | ||
"correct-gfpgan": CorrectGFPGANStage, | ||
"persist-disk": PersistDiskStage, | ||
"persist-s3": PersistS3Stage, | ||
"reduce-crop": ReduceCropStage, | ||
"reduce-thumbnail": ReduceThumbnailStage, | ||
"source-noise": SourceNoiseStage, | ||
"source-s3": SourceS3Stage, | ||
"source-txt2img": SourceTxt2ImgStage, | ||
"source-url": SourceURLStage, | ||
"upscale-bsrgan": UpscaleBSRGANStage, | ||
"upscale-highres": UpscaleHighresStage, | ||
"upscale-outpaint": UpscaleOutpaintStage, | ||
"upscale-resrgan": UpscaleRealESRGANStage, | ||
"upscale-simple": UpscaleSimpleStage, | ||
"upscale-stable-diffusion": UpscaleStableDiffusionStage, | ||
"upscale-swinir": UpscaleSwinIRStage, | ||
} | ||
from .pipeline import ChainPipeline, PipelineStage, StageParams | ||
from .stages import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,283 +1,39 @@ | ||
from datetime import timedelta | ||
from logging import getLogger | ||
from time import monotonic | ||
from typing import Any, List, Optional, Tuple | ||
from typing import List, Optional | ||
|
||
from PIL import Image | ||
|
||
from ..errors import RetryException | ||
from ..output import save_image | ||
from ..params import ImageParams, Size, StageParams | ||
from ..server import ServerContext | ||
from ..utils import is_debug, run_gc | ||
from ..worker import ProgressCallback, WorkerContext | ||
from .stage import BaseStage | ||
from .tile import needs_tile, process_tile_order | ||
from .result import StageResult | ||
from ..params import ImageParams, Size, SizeChart, StageParams | ||
from ..server.context import ServerContext | ||
from ..worker.context import WorkerContext | ||
|
||
logger = getLogger(__name__) | ||
|
||
|
||
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]] | ||
|
||
|
||
class ChainProgress: | ||
def __init__(self, parent: ProgressCallback, start=0) -> None: | ||
self.parent = parent | ||
self.step = start | ||
self.total = 0 | ||
|
||
def __call__(self, step: int, timestep: int, latents: Any) -> None: | ||
if step < self.step: | ||
# accumulate on resets | ||
self.total += self.step | ||
|
||
self.step = step | ||
self.parent(self.get_total(), timestep, latents) | ||
|
||
def get_total(self) -> int: | ||
return self.step + self.total | ||
|
||
@classmethod | ||
def from_progress(cls, parent: ProgressCallback): | ||
start = parent.step if hasattr(parent, "step") else 0 | ||
return ChainProgress(parent, start=start) | ||
|
||
|
||
class ChainPipeline: | ||
""" | ||
Run many stages in series, passing the image results from each to the next, and processing | ||
tiles as needed. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
stages: Optional[List[PipelineStage]] = None, | ||
): | ||
""" | ||
Create a new pipeline that will run the given stages. | ||
""" | ||
self.stages = list(stages or []) | ||
|
||
def append(self, stage: Optional[PipelineStage]): | ||
""" | ||
Append an additional stage to this pipeline. | ||
This requires an already-assembled `PipelineStage`. Use `ChainPipeline.stage` if you want the pipeline to | ||
assemble the stage from loose arguments. | ||
""" | ||
if stage is not None: | ||
self.stages.append(stage) | ||
class BaseStage: | ||
max_tile = SizeChart.auto | ||
|
||
def run( | ||
self, | ||
worker: WorkerContext, | ||
server: ServerContext, | ||
params: ImageParams, | ||
sources: List[Image.Image], | ||
callback: Optional[ProgressCallback], | ||
**kwargs | ||
) -> List[Image.Image]: | ||
return self( | ||
worker, server, params, sources=sources, callback=callback, **kwargs | ||
) | ||
|
||
def stage(self, callback: BaseStage, params: StageParams, **kwargs): | ||
self.stages.append((callback, params, kwargs)) | ||
return self | ||
|
||
def steps(self, params: ImageParams, size: Size): | ||
steps = 0 | ||
for callback, _params, kwargs in self.stages: | ||
steps += callback.steps(kwargs.get("params", params), size) | ||
|
||
return steps | ||
|
||
def outputs(self, params: ImageParams, sources: int): | ||
outputs = sources | ||
for callback, _params, kwargs in self.stages: | ||
outputs = callback.outputs(kwargs.get("params", params), outputs) | ||
|
||
return outputs | ||
|
||
def __call__( | ||
_worker: WorkerContext, | ||
_server: ServerContext, | ||
_stage: StageParams, | ||
_params: ImageParams, | ||
_sources: List[Image.Image], | ||
*args, | ||
stage_source: Optional[Image.Image] = None, | ||
**kwargs, | ||
) -> StageResult: | ||
raise NotImplementedError() # noqa | ||
|
||
def steps( | ||
self, | ||
worker: WorkerContext, | ||
server: ServerContext, | ||
params: ImageParams, | ||
sources: List[Image.Image], | ||
callback: Optional[ProgressCallback] = None, | ||
**pipeline_kwargs | ||
) -> List[Image.Image]: | ||
""" | ||
DEPRECATED: use `run` instead | ||
""" | ||
if callback is None: | ||
callback = worker.get_progress_callback() | ||
else: | ||
callback = ChainProgress.from_progress(callback) | ||
|
||
start = monotonic() | ||
|
||
if len(sources) > 0: | ||
logger.info( | ||
"running pipeline on %s source images", | ||
len(sources), | ||
) | ||
else: | ||
logger.info("running pipeline without source images") | ||
_params: ImageParams, | ||
_size: Size, | ||
) -> int: | ||
return 1 # noqa | ||
|
||
stage_sources = sources | ||
for stage_pipe, stage_params, stage_kwargs in self.stages: | ||
name = stage_params.name or stage_pipe.__class__.__name__ | ||
kwargs = stage_kwargs or {} | ||
kwargs = {**pipeline_kwargs, **kwargs} | ||
logger.debug( | ||
"running stage %s with %s source images, parameters: %s", | ||
name, | ||
len(stage_sources) - stage_sources.count(None), | ||
kwargs.keys(), | ||
) | ||
|
||
per_stage_params = params | ||
if "params" in kwargs: | ||
per_stage_params = kwargs["params"] | ||
kwargs.pop("params") | ||
|
||
# the stage must be split and tiled if any image is larger than the selected/max tile size | ||
must_tile = any( | ||
[ | ||
needs_tile( | ||
stage_pipe.max_tile, | ||
stage_params.tile_size, | ||
size=kwargs.get("size", None), | ||
source=source, | ||
) | ||
for source in stage_sources | ||
] | ||
) | ||
|
||
tile = stage_params.tile_size | ||
if stage_pipe.max_tile > 0: | ||
tile = min(stage_pipe.max_tile, stage_params.tile_size) | ||
|
||
if stage_sources or must_tile: | ||
stage_outputs = [] | ||
for source in stage_sources: | ||
logger.info( | ||
"image contains sources or is larger than tile size of %s, tiling stage", | ||
tile, | ||
) | ||
|
||
extra_tiles = [] | ||
|
||
def stage_tile( | ||
source_tile: Image.Image, | ||
tile_mask: Image.Image, | ||
dims: Tuple[int, int, int], | ||
) -> Image.Image: | ||
for _i in range(worker.retries): | ||
try: | ||
output_tile = stage_pipe.run( | ||
worker, | ||
server, | ||
stage_params, | ||
per_stage_params, | ||
[source_tile], | ||
tile_mask=tile_mask, | ||
callback=callback, | ||
dims=dims, | ||
**kwargs, | ||
) | ||
|
||
if len(output_tile) > 1: | ||
while len(extra_tiles) < len(output_tile): | ||
extra_tiles.append([]) | ||
|
||
for tile, layer in zip(output_tile, extra_tiles): | ||
layer.append((tile, dims)) | ||
|
||
if is_debug(): | ||
save_image(server, "last-tile.png", output_tile[0]) | ||
|
||
return output_tile[0] | ||
except Exception: | ||
worker.retries = worker.retries - 1 | ||
logger.exception( | ||
"error while running stage pipeline for tile, %s retries left", | ||
worker.retries, | ||
) | ||
server.cache.clear() | ||
run_gc([worker.get_device()]) | ||
|
||
raise RetryException("exhausted retries on tile") | ||
|
||
output = process_tile_order( | ||
stage_params.tile_order, | ||
source, | ||
tile, | ||
stage_params.outscale, | ||
[stage_tile], | ||
**kwargs, | ||
) | ||
|
||
stage_outputs.append(output) | ||
|
||
if len(extra_tiles) > 1: | ||
for layer in extra_tiles: | ||
layer_output = Image.new("RGB", output.size) | ||
for layer_tile, dims in layer: | ||
layer_output.paste(layer_tile, (dims[0], dims[1])) | ||
|
||
stage_outputs.append(layer_output) | ||
|
||
stage_sources = stage_outputs | ||
else: | ||
logger.debug( | ||
"image does not contain sources and is within tile size of %s, running stage", | ||
tile, | ||
) | ||
for i in range(worker.retries): | ||
try: | ||
stage_outputs = stage_pipe.run( | ||
worker, | ||
server, | ||
stage_params, | ||
per_stage_params, | ||
stage_sources, | ||
callback=callback, | ||
dims=(0, 0, tile), | ||
**kwargs, | ||
) | ||
# doing this on the same line as stage_pipe.run can leave sources as None, which the pipeline | ||
# does not like, so it throws | ||
stage_sources = stage_outputs | ||
break | ||
except Exception: | ||
worker.retries = worker.retries - 1 | ||
logger.exception( | ||
"error while running stage pipeline, %s retries left", | ||
worker.retries, | ||
) | ||
server.cache.clear() | ||
run_gc([worker.get_device()]) | ||
|
||
if worker.retries <= 0: | ||
raise RetryException("exhausted retries on stage") | ||
|
||
logger.debug( | ||
"finished stage %s with %s results", | ||
name, | ||
len(stage_sources), | ||
) | ||
|
||
if is_debug(): | ||
save_image(server, "last-stage.png", stage_sources[0]) | ||
|
||
end = monotonic() | ||
duration = timedelta(seconds=(end - start)) | ||
logger.info( | ||
"finished pipeline in %s with %s results", | ||
duration, | ||
len(stage_sources), | ||
) | ||
return stage_sources | ||
def outputs( | ||
self, | ||
_params: ImageParams, | ||
sources: int, | ||
) -> int: | ||
return sources |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.