From 2c1e669bd868fa103f17823456a35304f534f2bb Mon Sep 17 00:00:00 2001
From: Won-Kyu Park
Date: Wed, 8 Nov 2023 19:35:10 +0900
Subject: [PATCH] add min_diff, clamp_quantile args

based on
https://github.com/bmaltais/kohya_ss/pull/1332
https://github.com/bmaltais/kohya_ss/commit/a9ec90c40a1b390586edfba4f21a4acf6b6c09ad
---
 networks/extract_lora_from_models.py | 24 ++++++++++++++++++------
 1 file changed, 18 insertions(+), 6 deletions(-)

diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py
index e500185db..6c45d9770 100644
--- a/networks/extract_lora_from_models.py
+++ b/networks/extract_lora_from_models.py
@@ -13,8 +13,8 @@
 import lora
 
 
-CLAMP_QUANTILE = 0.99
-MIN_DIFF = 1e-1
+#CLAMP_QUANTILE = 0.99
+#MIN_DIFF = 1e-1
 
 
 def save_to_file(file_name, model, state_dict, dtype):
@@ -29,7 +29,7 @@ def save_to_file(file_name, model, state_dict, dtype):
         torch.save(model, file_name)
 
 
-def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None, conv_dim=None, v_parameterization=None, device=None, save_precision=None, no_metadata=False):
+def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None, conv_dim=None, v_parameterization=None, device=None, save_precision=None, clamp_quantile=0.99, min_diff=0.01, no_metadata=False):
     def str_to_dtype(p):
         if p == "float":
             return torch.float
@@ -91,9 +91,9 @@ def str_to_dtype(p):
         diff = module_t.weight - module_o.weight
 
         # Text Encoder might be same
-        if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF:
+        if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
             text_encoder_different = True
-            print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}")
+            print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
 
         diff = diff.float()
         diffs[lora_name] = diff
@@ -149,7 +149,7 @@ def str_to_dtype(p):
         Vh = Vh[:rank, :]
 
         dist = torch.cat([U.flatten(), Vh.flatten()])
-        hi_val = torch.quantile(dist, CLAMP_QUANTILE)
+        hi_val = torch.quantile(dist, clamp_quantile)
         low_val = -hi_val
 
         U = U.clamp(low_val, hi_val)
@@ -252,6 +252,18 @@ def setup_parser() -> argparse.ArgumentParser:
         help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
     )
     parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
+    parser.add_argument(
+        "--clamp_quantile",
+        type=float,
+        default=0.99,
+        help="Quantile clamping value, float, (0-1). Default = 0.99",
+    )
+    parser.add_argument(
+        "--min_diff",
+        type=float,
+        default=0.01,
+        help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01",
+    )
     parser.add_argument(
         "--no_metadata",
         action="store_true",