feat(api): allow offloading individual models to CPU (#330)
ssube committed Apr 24, 2023
1 parent 7caaa9e commit cad87b9
Showing 2 changed files with 45 additions and 5 deletions.
42 changes: 38 additions & 4 deletions api/onnx_web/diffusers/load.py
@@ -210,7 +210,7 @@ def load_pipeline(
         components["text_encoder"] = OnnxRuntimeModel(
             OnnxRuntimeModel.load_model(
                 text_encoder.SerializeToString(),
-                provider=device.ort_provider(),
+                provider=device.ort_provider("text-encoder"),
                 sess_options=device.sess_options(),
             )
         )
@@ -244,7 +244,7 @@ def load_pipeline(
         components["text_encoder"] = OnnxRuntimeModel(
             OnnxRuntimeModel.load_model(
                 text_encoder.SerializeToString(),
-                provider=device.ort_provider(),
+                provider=device.ort_provider("text-encoder"),
                 sess_options=text_encoder_opts,
             )
         )
@@ -264,7 +264,7 @@ def load_pipeline(
         components["unet"] = OnnxRuntimeModel(
             OnnxRuntimeModel.load_model(
                 unet_model.SerializeToString(),
-                provider=device.ort_provider(),
+                provider=device.ort_provider("unet"),
                 sess_options=unet_opts,
             )
         )
@@ -276,10 +276,44 @@ def load_pipeline(
         components["unet"] = OnnxRuntimeModel(
             OnnxRuntimeModel.load_model(
                 unet,
-                provider=device.ort_provider(),
+                provider=device.ort_provider("unet"),
                 sess_options=device.sess_options(),
             )
         )
+
+    # one or more VAE models need to be loaded
+    vae = path.join(model, "vae", ONNX_MODEL)
+    vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL)
+    vae_encoder = path.join(model, "vae_encoder", ONNX_MODEL)
+
+    if path.exists(vae):
+        logger.debug("loading VAE from %s", vae)
+        components["vae"] = OnnxRuntimeModel(
+            OnnxRuntimeModel.load_model(
+                vae,
+                provider=device.ort_provider("vae"),
+                sess_options=device.sess_options(),
+            )
+        )
+    elif path.exists(vae_decoder) and path.exists(vae_encoder):
+        logger.debug("loading VAE decoder from %s", vae_decoder)
+        components["vae_decoder"] = OnnxRuntimeModel(
+            OnnxRuntimeModel.load_model(
+                vae_decoder,
+                provider=device.ort_provider("vae"),
+                sess_options=device.sess_options(),
+            )
+        )
+
+        logger.debug("loading VAE encoder from %s", vae_encoder)
+        components["vae_encoder"] = OnnxRuntimeModel(
+            OnnxRuntimeModel.load_model(
+                vae_encoder,
+                provider=device.ort_provider("vae"),
+                sess_options=device.sess_options(),
+            )
+        )
+
 
     pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
     logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
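Each component in load.py now passes its model type into ort_provider, so every ONNX session can land on its own execution provider. The helper below is a condensed sketch of that repeated pattern, not code from the repository: load_component is a hypothetical name, and it assumes diffusers' OnnxRuntimeModel plus a DeviceParams-like device object as used above.

from diffusers import OnnxRuntimeModel

def load_component(device, model_path: str, model_type: str) -> OnnxRuntimeModel:
    # device.ort_provider(model_type) returns either a provider name or a
    # (name, options) tuple; an "onnx-cpu-<model_type>" optimization pins
    # this component to CPUExecutionProvider regardless of the default
    return OnnxRuntimeModel(
        OnnxRuntimeModel.load_model(
            model_path,
            provider=device.ort_provider(model_type),
            sess_options=device.sess_options(),
        )
    )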
8 changes: 7 additions & 1 deletion api/onnx_web/params.py
@@ -108,7 +108,13 @@ def __init__(
     def __str__(self) -> str:
         return "%s - %s (%s)" % (self.device, self.provider, self.options)
 
-    def ort_provider(self) -> Union[str, Tuple[str, Any]]:
+    def ort_provider(self, model_type: Optional[str] = None) -> Union[str, Tuple[str, Any]]:
+        if model_type is not None:
+            # check if model has been pinned to CPU
+            # TODO: check whether the CPU device is allowed
+            if f"onnx-cpu-{model_type}" in self.optimizations:
+                return "CPUExecutionProvider"
+
         if self.options is None:
             return self.provider
         else:
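The new model_type parameter is the whole offloading mechanism: when an optimization token of the form onnx-cpu-<model_type> is present, ort_provider short-circuits to CPUExecutionProvider for that model only, while everything else keeps the device's configured provider. Below is a minimal, self-contained sketch of that behavior; FakeDevice is a hypothetical stand-in for DeviceParams, modeling only the fields the method reads.

from typing import Any, List, Optional, Tuple, Union

class FakeDevice:
    # stand-in for onnx_web.params.DeviceParams; only provider, options,
    # and optimizations are modeled here
    def __init__(self, provider: str, optimizations: List[str], options: Optional[Any] = None):
        self.provider = provider
        self.optimizations = optimizations
        self.options = options

    def ort_provider(self, model_type: Optional[str] = None) -> Union[str, Tuple[str, Any]]:
        if model_type is not None:
            # models pinned with an onnx-cpu-<type> optimization run on CPU
            if f"onnx-cpu-{model_type}" in self.optimizations:
                return "CPUExecutionProvider"

        if self.options is None:
            return self.provider
        return (self.provider, self.options)

# pin only the VAE to CPU; the UNet keeps the configured provider
device = FakeDevice("CUDAExecutionProvider", ["onnx-cpu-vae"])
assert device.ort_provider("vae") == "CPUExecutionProvider"
assert device.ort_provider("unet") == "CUDAExecutionProvider"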
