feat(api): backend support for multiple GPUs in diffusion pipelines
ssube committed Jan 21, 2023
1 parent 88c5113 commit a868c8c
Showing 1 changed file with 12 additions and 3 deletions.
api/onnx_web/pipeline.py: 12 additions & 3 deletions
@@ -6,7 +6,7 @@
     OnnxStableDiffusionInpaintPipeline,
 )
 from PIL import Image, ImageChops
-from typing import Any
+from typing import Any, Union
 
 import gc
 import numpy as np
@@ -44,7 +44,7 @@ def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
     return image_latents
 
 
-def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any):
+def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Union[str, None]):
     global last_pipeline_instance
     global last_pipeline_scheduler
     global last_pipeline_options
@@ -61,14 +61,23 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
             safety_checker=None,
             scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
         )
+
+        if device is not None:
+            pipe = pipe.to(device)
+
         last_pipeline_instance = pipe
         last_pipeline_options = options
         last_pipeline_scheduler = scheduler
 
     if last_pipeline_scheduler != scheduler:
         print('changing pipeline scheduler')
-        pipe.scheduler = scheduler.from_pretrained(
+        scheduler = scheduler.from_pretrained(
             model, subfolder='scheduler')
+
+        if device is not None:
+            scheduler = scheduler.to(device)
+
+        pipe.scheduler = scheduler
         last_pipeline_scheduler = scheduler
 
     print('running garbage collection during pipeline change')
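
As a rough usage sketch (not part of the commit itself), the new trailing device argument lets a caller pin a pipeline to a specific GPU, while passing None skips the .to(device) calls and keeps the previous placement behavior. The model path, execution provider, and scheduler class below are placeholder assumptions for illustration:

# Hypothetical call sites for the updated load_pipeline signature.
# The model path, provider, and scheduler are placeholders; only the
# trailing device argument is introduced by this commit.
from diffusers import DDIMScheduler, OnnxStableDiffusionPipeline

from onnx_web.pipeline import load_pipeline

# pin this worker's pipeline to the second GPU
pipe = load_pipeline(
    OnnxStableDiffusionPipeline,
    '../models/stable-diffusion-onnx-v1-5',  # placeholder model path
    'CUDAExecutionProvider',                 # placeholder ONNX provider
    DDIMScheduler,
    'cuda:1',
)

# device=None preserves the old behavior: no .to() call is made
pipe = load_pipeline(
    OnnxStableDiffusionPipeline,
    '../models/stable-diffusion-onnx-v1-5',
    'CUDAExecutionProvider',
    DDIMScheduler,
    None,
)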
