fix(api): download pretrained models from HF correctly (#371)
ssube committed May 4, 2023
1 parent b6692f0 commit d66bf9e
Showing 2 changed files with 34 additions and 22 deletions.
48 changes: 26 additions & 22 deletions api/onnx_web/convert/__main__.py
@@ -207,7 +207,7 @@ def fetch_model(
     format: Optional[str] = None,
     hf_hub_fetch: bool = False,
     hf_hub_filename: Optional[str] = None,
-) -> str:
+) -> Tuple[str, bool]:
     cache_path = dest or conversion.cache_path
     cache_name = path.join(cache_path, name)
 
@@ -223,7 +223,7 @@ def fetch_model(
 
     if path.exists(cache_name):
         logger.debug("model already exists in cache, skipping fetch")
-        return cache_name
+        return cache_name, False
 
     for proto in model_sources:
         api_name, api_root = model_sources.get(proto)
@@ -232,33 +232,36 @@ def fetch_model(
         logger.info(
             "downloading model from %s: %s -> %s", api_name, api_source, cache_name
         )
-        return download_progress([(api_source, cache_name)])
+        return download_progress([(api_source, cache_name)]), False
 
     if source.startswith(model_source_huggingface):
         hub_source = remove_prefix(source, model_source_huggingface)
         logger.info("downloading model from Huggingface Hub: %s", hub_source)
         # from_pretrained has a bunch of useful logic that snapshot_download by itself does not
         if hf_hub_fetch:
-            return hf_hub_download(
-                repo_id=hub_source,
-                filename=hf_hub_filename,
-                cache_dir=cache_path,
-                force_filename=f"{name}.bin",
+            return (
+                hf_hub_download(
+                    repo_id=hub_source,
+                    filename=hf_hub_filename,
+                    cache_dir=cache_path,
+                    force_filename=f"{name}.bin",
+                ),
+                False,
             )
         else:
-            return hub_source
+            return hub_source, True
     elif source.startswith("https://"):
         logger.info("downloading model from: %s", source)
-        return download_progress([(source, cache_name)])
+        return download_progress([(source, cache_name)]), False
     elif source.startswith("http://"):
         logger.warning("downloading model from insecure source: %s", source)
-        return download_progress([(source, cache_name)])
+        return download_progress([(source, cache_name)]), False
     elif source.startswith(path.sep) or source.startswith("."):
         logger.info("using local model: %s", source)
-        return source
+        return source, False
     else:
         logger.info("unknown model location, using path as provided: %s", source)
-        return source
+        return source, False
 
 
 def convert_models(conversion: ConversionContext, args, models: Models):
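The hf_hub_fetch branch above only re-wraps the existing hf_hub_download call in a tuple; the download itself is unchanged. For reference, a minimal standalone sketch of that call, assuming huggingface_hub is installed; the repo id, filename, and cache directory are placeholders (the force_filename argument used in the diff was still accepted by huggingface_hub at the time of this commit but has since been removed, so it is omitted here):

from huggingface_hub import hf_hub_download

# fetch a single file from a Hub repo into a local cache directory;
# the return value is the local path to the downloaded file
local_path = hf_hub_download(
    repo_id="some-org/some-model",   # placeholder repo id
    filename="learned_embeds.bin",   # placeholder file within the repo
    cache_dir="/tmp/hf-cache",       # placeholder cache directory
)
print(local_path)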
@@ -280,7 +283,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
                 if "dest" in model:
                     dest_path = path.join(conversion.model_path, model["dest"])
 
-                dest = fetch_model(
+                dest, hf = fetch_model(
                     conversion, name, source, format=model_format, dest=dest_path
                 )
                 logger.info("finished downloading source: %s -> %s", source, dest)
@@ -302,7 +305,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
 
             try:
                 if network_type == "control":
-                    dest = fetch_model(
+                    dest, hf = fetch_model(
                         conversion,
                         name,
                         source,
@@ -315,7 +318,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
                         dest,
                     )
                 if network_type == "inversion" and network_model == "concept":
-                    dest = fetch_model(
+                    dest, hf = fetch_model(
                         conversion,
                         name,
                         source,
@@ -325,7 +328,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
                         hf_hub_filename="learned_embeds.bin",
                     )
                 else:
-                    dest = fetch_model(
+                    dest, hf = fetch_model(
                         conversion,
                         name,
                         source,
@@ -349,7 +352,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
             model_format = source_format(model)
 
             try:
-                source = fetch_model(
+                source, hf = fetch_model(
                     conversion, name, model["source"], format=model_format
                 )
 
@@ -358,6 +361,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
                     model,
                     source,
                     model_format,
+                    hf=hf,
                 )
 
                 # make sure blending only happens once, not every run
@@ -389,7 +393,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
                         inversion_name = inversion["name"]
                         inversion_source = inversion["source"]
                         inversion_format = inversion.get("format", None)
-                        inversion_source = fetch_model(
+                        inversion_source, hf = fetch_model(
                             conversion,
                             inversion_name,
                             inversion_source,
@@ -430,7 +434,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
                         # load models if not loaded yet
                         lora_name = lora["name"]
                         lora_source = lora["source"]
-                        lora_source = fetch_model(
+                        lora_source, hf = fetch_model(
                             conversion,
                             f"{name}-lora-{lora_name}",
                             lora_source,
@@ -489,7 +493,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
             model_format = source_format(model)
 
             try:
-                source = fetch_model(
+                source, hf = fetch_model(
                     conversion, name, model["source"], format=model_format
                 )
                 model_type = model.get("model", "resrgan")
@@ -521,7 +525,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
             else:
                 model_format = source_format(model)
             try:
-                source = fetch_model(
+                source, hf = fetch_model(
                     conversion, name, model["source"], format=model_format
                 )
                 model_type = model.get("model", "gfpgan")
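Taken together, the __main__.py changes give fetch_model a new contract: every call site now receives a (path, hf) pair instead of a bare path, where the boolean records whether the source is a Hugging Face Hub reference that has not been materialized on disk. A minimal sketch of that contract, using hypothetical simplified names and only two of the real function's branches:

from typing import Tuple

HF_PREFIX = "huggingface://"  # stands in for model_source_huggingface


def fetch_model_sketch(source: str, cache_name: str) -> Tuple[str, bool]:
    if source.startswith(HF_PREFIX):
        # bare Hub repo id: return it as-is and flag it True, so the
        # converter knows to download it later with from_pretrained
        return source[len(HF_PREFIX):], True
    # every other branch in the real function resolves to a local path
    return cache_name, False


# call sites unpack both values, as convert_models now does
dest, hf = fetch_model_sketch("huggingface://some-org/some-model", "/tmp/cache/model")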
8 changes: 8 additions & 0 deletions api/onnx_web/convert/diffusion/diffusers.py
@@ -247,6 +247,7 @@ def convert_diffusion_diffusers(
     model: Dict,
     source: str,
     format: str,
+    hf: bool = False,
 ) -> Tuple[bool, str]:
     """
     From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
@@ -316,6 +317,13 @@ def convert_diffusion_diffusers(
             pipeline_class=pipe_class,
             **pipe_args,
         ).to(device, torch_dtype=dtype)
+    elif hf:
+        logger.debug("downloading pretrained model from Huggingface hub: %s", source)
+        pipeline = pipe_class.from_pretrained(
+            source,
+            torch_dtype=dtype,
+            use_auth_token=conversion.token,
+        ).to(device)
     else:
         logger.warning("pipeline source not found or not recognized: %s", source)
         raise ValueError(f"pipeline source not found or not recognized: {source}")
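The diffusers.py change closes the loop: when fetch_model reports hf=True, the source string is a Hub repo id, so the pipeline is downloaded with from_pretrained instead of being treated as a local checkpoint. A standalone sketch of what that elif hf: branch does, assuming diffusers and torch are installed; DiffusionPipeline stands in for pipe_class, and the repo id and token are placeholders:

import torch
from diffusers import DiffusionPipeline

repo_id = "some-org/some-model"  # placeholder Hub repo id

# from_pretrained handles the Hub download, caching, and auth that a raw
# snapshot_download would leave to the caller
pipeline = DiffusionPipeline.from_pretrained(
    repo_id,
    torch_dtype=torch.float32,   # the diff passes the converter's dtype here
    use_auth_token=None,         # replace with a HF token string for gated repos
).to("cpu")

Note that use_auth_token was the accepted keyword in the diffusers releases current at this commit; newer releases rename it to token.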
