Merge pull request #137 from invoke-ai/Typo-fix-for-float32
Fix
RyanJDick committed May 30, 2024
2 parents 57ef42f + 8c9fb0a commit 38242a8
Showing 1 changed file with 14 additions and 5 deletions.
@@ -6,19 +6,19 @@
 import torch
 from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
 
 
 # fmt: off
 # HACK(ryand): Import order matters, because invokeai contains circular imports.
 from invokeai.backend.model_manager import BaseModelType
 from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_patcher import ModelPatcher
 # fmt: on
 
 from invoke_training._shared.stable_diffusion.model_loading_utils import PipelineVersionEnum, load_pipeline
 
 
 # TODO(ryand): Consolidate multiple implementations of this function across the project.
 def str_to_dtype(dtype_str: Literal["float32", "float16", "bfloat16"]):
-    if dtype_str == "float23":
+    if dtype_str == "float32":
         return torch.float32
     elif dtype_str == "float16":
         return torch.float16
@@ -136,7 +136,7 @@ def merge_lora_into_sd_model(
 
 def parse_lora_model_arg(lora_model_arg: str) -> tuple[str, float]:
     """Parse a --lora-model argument into a tuple of the model path and weight."""
-    parts = lora_model_arg.split(":")
+    parts = lora_model_arg.split("::")
     if len(parts) == 1:
         return parts[0], 1.0
     elif len(parts) == 2:
@@ -171,8 +171,8 @@ def main():
         type=str,
         nargs="+",
         help="The path(s) to one or more LoRA models to merge into the base model. Model weights can be appended to "
-        "the path, separated by a colon (':'). E.g. 'path/to/lora_model:0.5'. The weight is optional and defaults to "
-        "1.0.",
+        "the path, separated by a double colon ('::'). E.g. 'path/to/lora_model::0.5'. The weight is optional and "
+        "defaults to 1.0.",
         required=True,
     )
     parser.add_argument(
@@ -192,6 +192,15 @@ def main():
 
     logging.basicConfig(level=logging.INFO)
     logger = logging.getLogger(__name__)
+
+    # Log the parsed arguments
+    logger.info(f"Base model: {args.base_model}")
+    logger.info(f"Base model variant: {args.base_model_variant}")
+    logger.info(f"Base model type: {args.base_model_type}")
+    logger.info(f"LoRA models: {args.lora_model}")
+    logger.info(f"Output directory: {args.output}")
+    logger.info(f"Save dtype: {args.save_dtype}")
+
     merge_lora_into_sd_model(
         logger=logger,
         base_model=args.base_model,
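For reference, a minimal, runnable sketch of the '--lora-model' argument format after this change, assuming the parse_lora_model_arg shown in the diff above; the len(parts) == 2 branch and the error handling are filled in as assumptions, and the example paths are illustrative only:

def parse_lora_model_arg(lora_model_arg: str) -> tuple[str, float]:
    """Parse a --lora-model argument into a tuple of the model path and weight."""
    parts = lora_model_arg.split("::")
    if len(parts) == 1:
        # No weight appended: default to 1.0.
        return parts[0], 1.0
    elif len(parts) == 2:
        # Assumption: the part after '::' is the weight, parsed as a float.
        return parts[0], float(parts[1])
    # Assumed error handling for malformed arguments.
    raise ValueError(f"Unexpected format for --lora-model argument: '{lora_model_arg}'")


# Usage (illustrative paths):
assert parse_lora_model_arg("path/to/lora_model") == ("path/to/lora_model", 1.0)
assert parse_lora_model_arg("path/to/lora_model::0.5") == ("path/to/lora_model", 0.5)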
