Skip to content

Commit

Permalink
add min_diff, clamp_quantile args
Browse files Browse the repository at this point in the history
  • Loading branch information
wkpark committed Nov 9, 2023
1 parent e20e9f6 commit 2c1e669
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions networks/extract_lora_from_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
import lora


CLAMP_QUANTILE = 0.99
MIN_DIFF = 1e-1
#CLAMP_QUANTILE = 0.99
#MIN_DIFF = 1e-1


def save_to_file(file_name, model, state_dict, dtype):
Expand All @@ -29,7 +29,7 @@ def save_to_file(file_name, model, state_dict, dtype):
torch.save(model, file_name)


def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None, conv_dim=None, v_parameterization=None, device=None, save_precision=None, no_metadata=False):
def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None, conv_dim=None, v_parameterization=None, device=None, save_precision=None, clamp_quantile=0.99, min_diff=0.01, no_metadata=False):
def str_to_dtype(p):
if p == "float":
return torch.float
Expand Down Expand Up @@ -91,9 +91,9 @@ def str_to_dtype(p):
diff = module_t.weight - module_o.weight

# Text Encoder might be same
if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF:
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
text_encoder_different = True
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}")
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")

diff = diff.float()
diffs[lora_name] = diff
Expand Down Expand Up @@ -149,7 +149,7 @@ def str_to_dtype(p):
Vh = Vh[:rank, :]

dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
hi_val = torch.quantile(dist, clamp_quantile)
low_val = -hi_val

U = U.clamp(low_val, hi_val)
Expand Down Expand Up @@ -252,6 +252,18 @@ def setup_parser() -> argparse.ArgumentParser:
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
)
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument(
"--clamp_quantile",
type=float,
default=0.99,
help="Quantile clamping value, float, (0-1). Default = 0.99",
)
parser.add_argument(
"--min_diff",
type=float,
default=0.01,
help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01",
)
parser.add_argument(
"--no_metadata",
action="store_true",
Expand Down

0 comments on commit 2c1e669

Please sign in to comment.