Skip to content

Commit

Permalink
fix(api): add model image size and version hint to extras file
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Apr 30, 2023
1 parent 2690eaf commit bc71583
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 36 deletions.
83 changes: 47 additions & 36 deletions api/onnx_web/convert/diffusion/diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,42 +52,51 @@


def get_model_version(
    source,
    map_location,
    size=None,
    version=None,
) -> Tuple[bool, Dict[str, Union[bool, int, str]]]:
    """
    Guess whether a checkpoint is a v2 model and build the pipeline options for it.

    The checkpoint is loaded with `load_tensor(source, map_location=map_location)`
    and inspected. If it cannot be loaded, the check is skipped on a best-effort
    basis: the result then relies on the `version` hint alone and the options
    only contain `extract_ema`.

    Args:
        source: passed through to `load_tensor` as the tensor to load.
        map_location: passed through to `load_tensor`.
        size: image size hint. When None, it is guessed from the checkpoint's
            `global_step` (512 when the step is 875000, otherwise 768).
        version: optional version hint; any value containing "v2" marks the
            model as v2 before the checkpoint is inspected.

    Returns:
        A tuple of `(v2, opts)`, where `v2` is the v2 flag and `opts` holds
        `extract_ema` plus, when the checkpoint could be inspected,
        `image_size`, `model_type`, `prediction_type`, and optionally
        `upcast_attention`.
    """
    v2 = version is not None and "v2" in version
    opts = {
        "extract_ema": True,
    }

    try:
        checkpoint = load_tensor(source, map_location=map_location)

        if "global_step" in checkpoint:
            global_step = checkpoint["global_step"]
        else:
            logger.debug("global_step key not found in model")
            global_step = None

        if size is None:
            # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
            # as it relies on a brittle global step parameter here
            size = 512 if global_step == 875000 else 768

        opts["image_size"] = size

        key_name = (
            "model.diffusion_model.input_blocks.2.1"
            ".transformer_blocks.0.attn2.to_k.weight"
        )
        # a last dimension of 1024 on this weight is treated as the v2 marker,
        # regardless of the version hint
        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
            v2 = True
            if size != 512:
                # v2.1 needs to upcast attention
                logger.debug("setting upcast_attention")
                opts["upcast_attention"] = True

        if v2 and size != 512:
            opts["model_type"] = "FrozenOpenCLIPEmbedder"
            opts["prediction_type"] = "v_prediction"
        else:
            opts["model_type"] = "FrozenCLIPEmbedder"
            opts["prediction_type"] = "epsilon"
    except Exception:
        # best-effort: fall back to the version hint, but keep the cause in the
        # log and let KeyboardInterrupt/SystemExit propagate
        logger.debug("unable to load tensor for version check", exc_info=True)

    return (v2, opts)

Expand Down Expand Up @@ -241,12 +250,14 @@ def convert_diffusion_diffusers(
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
"""
name = model.get("name")
source = source or model.get("source")
config = model.get("config", None)
image_size = model.get("image_size", None)
name = model.get("name")
pipe_type = model.get("pipeline", "txt2img")
single_vae = model.get("single_vae")
source = source or model.get("source")
replace_vae = model.get("vae")
pipe_type = model.get("pipeline", "txt2img")
version = model.get("version", None)

device = conversion.training_device
dtype = conversion.torch_dtype()
Expand Down Expand Up @@ -279,7 +290,7 @@ def convert_diffusion_diffusers(
return (False, dest_path)

pipe_class = available_pipelines.get(pipe_type)
v2, pipe_args = get_model_version(load_tensor(source, conversion.map_location))
v2, pipe_args = get_model_version(source, conversion.map_location, size=image_size, version=version)

if pipe_type == "inpaint":
pipe_args["num_in_channels"] = 9
Expand All @@ -299,7 +310,7 @@ def convert_diffusion_diffusers(
pipeline_class=pipe_class,
torch_dtype=dtype,
**pipe_args,
).to(device)
).to(device, torch_dtype=dtype)
else:
logger.warning("pipeline source not found or not recognized: %s", source)
raise ValueError(f"pipeline source not found or not recognized: {source}")
Expand Down
3 changes: 3 additions & 0 deletions api/onnx_web/convert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,9 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
except Exception as e:
logger.warning("error loading tensor: %s", e)

if checkpoint is None:
raise ValueError("error loading tensor")

if checkpoint is not None and "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]

Expand Down
9 changes: 9 additions & 0 deletions api/schemas/extras.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ $defs:
properties:
config:
type: string
image_size:
type: number
inversions:
type: array
items:
Expand All @@ -94,6 +96,13 @@ $defs:
]
vae:
type: string
version:
type: string
enum: [
v1,
v2,
v2.1,
]

upscaling_model:
allOf:
Expand Down

0 comments on commit bc71583

Please sign in to comment.