diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index dba7cd4e2..ed39d7bc8 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -29,7 +29,7 @@ def save_to_file(file_name, model, state_dict, dtype): torch.save(model, file_name) -def svd(args): +def svd(v2=None, model_org=None, model_tuned=None, sdxl=None, dim=4, conv_dim=None, v_parameterization=None, save_to=None, device=None, save_precision=None, no_metadata=False): def str_to_dtype(p): if p == "float": return torch.float @@ -39,44 +39,44 @@ def str_to_dtype(p): return torch.bfloat16 return None - assert args.v2 != args.sdxl or ( - not args.v2 and not args.sdxl + assert v2 != sdxl or ( + not v2 and not sdxl ), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません" - if args.v_parameterization is None: - args.v_parameterization = args.v2 + if v_parameterization is None: + v_parameterization = v2 - save_dtype = str_to_dtype(args.save_precision) + save_dtype = str_to_dtype(save_precision) # load models - if not args.sdxl: - print(f"loading original SD model : {args.model_org}") - text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) + if not sdxl: + print(f"loading original SD model : {model_org}") + text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org) text_encoders_o = [text_encoder_o] - print(f"loading tuned SD model : {args.model_tuned}") - text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) + print(f"loading tuned SD model : {model_tuned}") + text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned) text_encoders_t = [text_encoder_t] - model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization) + model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization) else: - print(f"loading original SDXL model : {args.model_org}") + print(f"loading original SDXL model : {model_org}") text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_org, "cpu" + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu" ) text_encoders_o = [text_encoder_o1, text_encoder_o2] - print(f"loading original SDXL model : {args.model_tuned}") + print(f"loading original SDXL model : {model_tuned}") text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_tuned, "cpu" + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu" ) text_encoders_t = [text_encoder_t1, text_encoder_t2] model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0 # create LoRA network to extract weights: Use dim (rank) as alpha - if args.conv_dim is None: + if conv_dim is None: kwargs = {} else: - kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim} + kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim} - lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_o, unet_o, **kwargs) - lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_t, unet_t, **kwargs) + lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs) + lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs) assert len(lora_network_o.text_encoder_loras) == len( lora_network_t.text_encoder_loras ), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) " @@ -120,16 +120,16 @@ def str_to_dtype(p): lora_weights = {} with torch.no_grad(): for lora_name, mat in tqdm(list(diffs.items())): - # if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3 + # if conv_dim is None, diffs do not include LoRAs for conv2d-3x3 conv2d = len(mat.size()) == 4 kernel_size = None if not conv2d else mat.size()[2:4] conv2d_3x3 = conv2d and kernel_size != (1, 1) - rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim + rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim out_dim, in_dim = mat.size()[0:2] - if args.device: - mat = mat.to(args.device) + if device: + mat = mat.to(device) # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim @@ -178,34 +178,34 @@ def str_to_dtype(p): info = lora_network_save.load_state_dict(lora_sd) print(f"Loading extracted LoRA weights: {info}") - dir_name = os.path.dirname(args.save_to) + dir_name = os.path.dirname(save_to) if dir_name and not os.path.exists(dir_name): os.makedirs(dir_name, exist_ok=True) # minimum metadata net_kwargs = {} - if args.conv_dim is not None: - net_kwargs["conv_dim"] = args.conv_dim - net_kwargs["conv_alpha"] = args.conv_dim + if conv_dim is not None: + net_kwargs["conv_dim"] = str(conv_dim) + net_kwargs["conv_alpha"] = str(float(conv_dim)) metadata = { - "ss_v2": str(args.v2), + "ss_v2": str(v2), "ss_base_model_version": model_version, "ss_network_module": "networks.lora", - "ss_network_dim": str(args.dim), - "ss_network_alpha": str(args.dim), + "ss_network_dim": str(dim), + "ss_network_alpha": str(float(dim)), "ss_network_args": json.dumps(net_kwargs), } - if not args.no_metadata: - title = os.path.splitext(os.path.basename(args.save_to))[0] + if not no_metadata: + title = os.path.splitext(os.path.basename(save_to))[0] sai_metadata = sai_model_spec.build_metadata( - None, args.v2, args.v_parameterization, args.sdxl, True, False, time.time(), title=title + None, v2, v_parameterization, sdxl, True, False, time.time(), title=title ) metadata.update(sai_metadata) - lora_network_save.save_weights(args.save_to, save_dtype, metadata) - print(f"LoRA weights are saved to: {args.save_to}") + lora_network_save.save_weights(save_to, save_dtype, metadata) + print(f"LoRA weights are saved to: {save_to}") def setup_parser() -> argparse.ArgumentParser: @@ -213,7 +213,7 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む") parser.add_argument( "--v_parameterization", - type=bool, + action="store_true", default=None, help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)", ) @@ -231,16 +231,18 @@ def setup_parser() -> argparse.ArgumentParser: "--model_org", type=str, default=None, + required=True, help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors", ) parser.add_argument( "--model_tuned", type=str, default=None, + required=True, help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors", ) parser.add_argument( - "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + "--save_to", type=str, default=None, required=True, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" ) parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)") parser.add_argument( @@ -264,4 +266,4 @@ def setup_parser() -> argparse.ArgumentParser: parser = setup_parser() args = parser.parse_args() - svd(args) + svd(**vars(args))