Skip to content

Commit

Permalink
fix(api): better handling for errors while converting checkpoint to t…
Browse files Browse the repository at this point in the history
…orch (#165)
  • Loading branch information
ssube committed Apr 9, 2023
1 parent 89ebbb8 commit ff9ce03
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
19 changes: 11 additions & 8 deletions api/onnx_web/convert/diffusion/original.py
Original file line number Diff line number Diff line change
Expand Up @@ -1412,7 +1412,7 @@ def extract_checkpoint(
checkpoint = load_tensor(checkpoint_file, map_location=map_location)
if checkpoint is None:
logger.warning("unable to load tensor")
return
return False

rev_keys = ["db_global_step", "global_step"]
epoch_keys = ["db_epoch", "epoch"]
Expand Down Expand Up @@ -1469,7 +1469,7 @@ def extract_checkpoint(
logger.warning(
"unable to select a config file: %s" % (original_config_file)
)
return
return False

logger.debug("trying to load: %s", original_config_file)
original_config = load_yaml(original_config_file)
Expand Down Expand Up @@ -1614,7 +1614,7 @@ def extract_checkpoint(

if pipe is None or db_config is None:
logger.error("pipeline or config is not set, unable to continue")
return
return False
else:
logger.info("saving diffusion model")
pipe.save_pretrained(db_config.pretrained_model_name_or_path)
Expand Down Expand Up @@ -1681,21 +1681,24 @@ def convert_diffusion_original(
model_index = os.path.join(working_name, "model_index.json")

if os.path.exists(torch_path) and os.path.exists(model_index):
logger.info("torch pipeline already exists, reusing: %s", torch_path)
logger.info("Torch model already exists, reusing: %s", torch_path)
else:
logger.info(
"converting original Diffusers check to Torch model: %s -> %s",
"converting checkpoint to Torch model: %s -> %s",
source,
torch_path,
)
extract_checkpoint(
if extract_checkpoint(
ctx,
torch_name,
source,
config_file=model.get("config"),
vae_file=model.get("vae"),
)
logger.info("converted original Diffusers checkpoint to Torch model")
):
logger.info("converted checkpoint to Torch model")
else:
logger.error("unable to convert checkpoint to Torch model")
raise ValueError("unable to convert checkpoint to Torch model")

# VAE has already been converted and will confuse HF repo lookup
if "vae" in model:
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/convert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
logger.warning("error loading pickle tensor: %s", e)
elif extension in ["onnx", "pt"]:
logger.warning(
"tensor has ONNX extension, falling back to PyTorch: %s", extension
"tensor has ONNX extension, attempting to use PyTorch anyways: %s", extension
)
try:
checkpoint = load_torch(name, map_location=map_location)
Expand Down

0 comments on commit ff9ce03

Please sign in to comment.