Skip to content

Commit

Permalink
fix(api): collate CNet after unloading UNet
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed May 15, 2023
1 parent 01ab864 commit b6a4cba
Showing 1 changed file with 34 additions and 22 deletions.
56 changes: 34 additions & 22 deletions api/onnx_web/convert/diffusion/diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,11 @@ def convert_diffusion_diffusers_cnet(
del pipe_cnet
run_gc()

return cnet_path


@torch.no_grad()
def collate_cnet(cnet_path):
logger.debug("collating CNet external tensors")
cnet_model_path = str(cnet_path.absolute().as_posix())
cnet_dir = path.dirname(cnet_model_path)
Expand Down Expand Up @@ -292,7 +297,7 @@ def convert_diffusion_diffusers(

cnet_only = False
if path.exists(dest_path) and path.exists(model_index):
if not path.exists(model_cnet):
if not single_vae and not path.exists(model_cnet):
logger.info(
"ONNX model was converted without a ControlNet UNet, converting one"
)
Expand Down Expand Up @@ -424,9 +429,10 @@ def convert_diffusion_diffusers(
v2=v2,
)

cnet_path = None
if not single_vae or not conversion.control:
# if converting only the CNet, the rest of the model has already been converted
convert_diffusion_diffusers_cnet(
cnet_path = convert_diffusion_diffusers_cnet(
conversion,
source,
device,
Expand All @@ -445,28 +451,34 @@ def convert_diffusion_diffusers(
del pipeline.unet
run_gc()

if cnet_path is not None:
collate_cnet(cnet_path)

if cnet_only:
logger.info("done converting CNet")
return (True, dest_path)
else:
logger.debug("collating UNet external tensors")
unet_model_path = str(unet_path.absolute().as_posix())
unet_dir = path.dirname(unet_model_path)
unet = load_model(unet_model_path)

# clean up existing tensor files
rmtree(unet_dir)
mkdir(unet_dir)

# collate external tensor files into one
save_model(
unet,
unet_model_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=ONNX_WEIGHTS,
convert_attribute=False,
)

logger.debug("collating UNet external tensors")
unet_model_path = str(unet_path.absolute().as_posix())
unet_dir = path.dirname(unet_model_path)
unet = load_model(unet_model_path)

# clean up existing tensor files
rmtree(unet_dir)
mkdir(unet_dir)

# collate external tensor files into one
save_model(
unet,
unet_model_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=ONNX_WEIGHTS,
convert_attribute=False,
)

del unet
run_gc()

# VAE
if replace_vae is not None:
Expand Down Expand Up @@ -540,7 +552,7 @@ def convert_diffusion_diffusers(
vae_decoder.forward = vae_encoder.decode

vae_path = output_path / "vae_decoder" / ONNX_MODEL
logger.info("exporting VAE encoder to %s", vae_path)
logger.info("exporting VAE decoder to %s", vae_path)
onnx_export(
vae_decoder,
model_args=(
Expand Down

0 comments on commit b6a4cba

Please sign in to comment.