
Commit

Update extract LoRA and add sdpa
bmaltais committed Jul 26, 2023
1 parent 101d263 commit a9ec90c
Showing 7 changed files with 87 additions and 14 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -494,4 +494,6 @@ If you come across a `FileNotFoundError`, it is likely due to an installation is
- Add missing LR number of cycles and LR power to Dreambooth and TI scripts
- Fix issue with conv_block_dims and conv_block_alphas
- Fix 0 noise offset issue
- Implement Stop training button on LoRA
- Implement Stop training button on LoRA
- Add support for extracting LoRA from SDXL finetuned models
- Add support for the PagedAdamW8bit and PagedLion8bit optimizers. These require a newer version of bitsandbytes, so success on some systems may vary. I had to uninstall all my NVIDIA drivers and other CUDA toolkit installs, delete all CUDA variable references, and re-install CUDA Toolkit v11.8.0 to get things to work... so not super easy.
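For reference, the paged optimizers mentioned above come from bitsandbytes; the snippet below is a minimal sketch of constructing one directly, assuming bitsandbytes 0.39.0 or newer on a CUDA machine and a hypothetical toy model standing in for the real network:

import torch
import bitsandbytes as bnb

# Toy module standing in for the network being trained (illustration only)
model = torch.nn.Linear(320, 320).cuda()

# Paged 8-bit optimizer state can be paged out of GPU memory under pressure
optimizer = bnb.optim.PagedAdamW8bit(model.parameters(), lr=1e-4)

loss = model(torch.randn(4, 320).cuda()).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()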

11 changes: 11 additions & 0 deletions examples/finetune_latent.ps1
@@ -0,0 +1,11 @@
# Command 1: merge_captions_to_metadata.py
$captionExtension = "--caption_extension=.txt"
$sourceDir1 = "d:\test\1_1960-1969"
$targetFile1 = "d:\test\1_1960-1969/meta_cap.json"

# Command 2: prepare_buckets_latents.py
$targetLatentFile = "d:\test\1_1960-1969/meta_lat.json"
$modelFile = "E:\models\sdxl\sd_xl_base_0.9.safetensors"

./venv/Scripts/python.exe finetune/merge_captions_to_metadata.py $captionExtension $sourceDir1 $targetFile1 --full_path
./venv/Scripts/python.exe finetune/prepare_buckets_latents.py $sourceDir1 $targetFile1 $targetLatentFile $modelFile --batch_size=4 --max_resolution=1024,1024 --min_bucket_reso=64 --max_bucket_reso=2048 --mixed_precision=bf16 --full_path
5 changes: 5 additions & 0 deletions finetune_gui.py
@@ -89,6 +89,7 @@ def save_configuration(
caption_extension,
# use_8bit_adam,
xformers,
sdpa,
clip_skip,
save_state,
resume,
@@ -209,6 +210,7 @@ def open_configuration(
caption_extension,
# use_8bit_adam,
xformers,
sdpa,
clip_skip,
save_state,
resume,
@@ -326,6 +328,7 @@ def train_model(
caption_extension,
# use_8bit_adam,
xformers,
sdpa,
clip_skip,
save_state,
resume,
@@ -575,6 +578,7 @@ def train_model(
gradient_checkpointing=gradient_checkpointing,
full_fp16=full_fp16,
xformers=xformers,
sdpa=sdpa,
# use_8bit_adam=use_8bit_adam,
keep_tokens=keep_tokens,
persistent_data_loader_workers=persistent_data_loader_workers,
@@ -866,6 +870,7 @@ def finetune_tab(headless=False):
source_model.save_model_as,
basic_training.caption_extension,
advanced_training.xformers,
advanced_training.sdpa,
advanced_training.clip_skip,
advanced_training.save_state,
advanced_training.resume,
3 changes: 2 additions & 1 deletion library/class_advanced_training.py
@@ -101,13 +101,14 @@ def noise_offset_type_change(noise_offset_type):
# use_8bit_adam = gr.Checkbox(
# label='Use 8bit adam', value=False, visible=False
# )
self.xformers = gr.Checkbox(label='Use xformers', value=True)
self.xformers = gr.Checkbox(label='Use xformers', value=True, info='Use xformers for CrossAttention')
self.color_aug = gr.Checkbox(label='Color augmentation', value=False)
self.flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
self.min_snr_gamma = gr.Slider(
label='Min SNR gamma', value=0, minimum=0, maximum=20, step=1
)
with gr.Row():
self.sdpa = gr.Checkbox(label='Use sdpa', value=False, info='Use sdpa for CrossAttention')
self.bucket_no_upscale = gr.Checkbox(
label="Don't upscale bucket resolution", value=True
)
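For context on the new checkbox: the sdpa option maps to --sdpa, which switches CrossAttention to PyTorch's native scaled dot-product attention. A minimal sketch of that primitive, assuming PyTorch 2.0+ and purely illustrative tensor shapes:

import torch
import torch.nn.functional as F

# (batch, heads, tokens, head_dim) -- shapes chosen only for illustration
q = torch.randn(2, 8, 77, 64)
k = torch.randn(2, 8, 77, 64)
v = torch.randn(2, 8, 77, 64)

# PyTorch dispatches to a fused kernel (Flash / memory-efficient / math) when available
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 8, 77, 64])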
19 changes: 19 additions & 0 deletions library/common_gui.py
@@ -735,6 +735,10 @@ def run_cmd_advanced_training(**kwargs):
xformers = kwargs.get('xformers')
if xformers:
run_cmd += ' --xformers'

sdpa = kwargs.get('sdpa')
if sdpa:
run_cmd += ' --sdpa'

persistent_data_loader_workers = kwargs.get('persistent_data_loader_workers')
if persistent_data_loader_workers:
@@ -856,3 +860,18 @@ def check_duplicate_filenames(folder_path, image_extension = ['.gif', '.png', '.
print(f"Current file: {full_path}")
else:
filenames[filename] = full_path

def is_file_writable(file_path):
if not os.path.exists(file_path):
# print(f"File '{file_path}' does not exist.")
return True

try:
log.warning(f"File '{file_path}' already exist... it will be overwritten...")
# Check if the file can be opened in write mode (which implies it's not open by another process)
with open(file_path, 'a'):
pass
return True
except IOError:
log.warning(f"File '{file_path}' can't be written to...")
return False
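A small usage sketch for the new helper, with a hypothetical output path; the extract LoRA tab below uses it the same way to bail out before launching the extraction script:

from library.common_gui import is_file_writable

save_to = 'd:/test/extracted_lora.safetensors'  # hypothetical output path

# Abort early rather than failing mid-extraction if the target can't be overwritten
if not is_file_writable(save_to):
    raise SystemExit(1)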
35 changes: 30 additions & 5 deletions library/extract_lora_gui.py
@@ -4,8 +4,8 @@
import os
from .common_gui import (
get_saveasfilename_path,
get_any_file_path,
get_file_path,
is_file_writable
)

from library.custom_logging import setup_logging
@@ -27,25 +27,31 @@ def extract_lora(
save_precision,
dim,
v2,
sdxl,
conv_dim,
clamp_quantile,
min_diff,
device,
):
# Check that the required model files were provided
if model_tuned == '':
msgbox('Invalid finetuned model file')
log.info('Invalid finetuned model file')
return

if model_org == '':
msgbox('Invalid base model file')
log.info('Invalid base model file')
return

# Check if the source models exist
if not os.path.isfile(model_tuned):
msgbox('The provided finetuned model is not a file')
log.info('The provided finetuned model is not a file')
return

if not os.path.isfile(model_org):
msgbox('The provided base model is not a file')
log.info('The provided base model is not a file')
return

if not is_file_writable(save_to):
return

run_cmd = (
@@ -61,6 +67,10 @@ def extract_lora(
run_cmd += f' --conv_dim {conv_dim}'
if v2:
run_cmd += f' --v2'
if sdxl:
run_cmd += f' --sdxl'
run_cmd += f' --clamp_quantile {clamp_quantile}'
run_cmd += f' --min_diff {min_diff}'

log.info(run_cmd)

@@ -160,7 +170,19 @@ def gradio_extract_lora_tab(headless=False):
step=1,
interactive=True,
)
clamp_quantile = gr.Number(
label='Clamp Quantile',
value=1,
interactive=True,
)
min_diff = gr.Number(
label='Minimum difference',
value=0.01,
interactive=True,
)
with gr.Row():
v2 = gr.Checkbox(label='v2', value=False, interactive=True)
sdxl = gr.Checkbox(label='SDXL', value=False, interactive=True)
device = gr.Dropdown(
label='Device',
choices=[
@@ -182,7 +204,10 @@
save_precision,
dim,
v2,
sdxl,
conv_dim,
clamp_quantile,
min_diff,
device,
],
show_progress=False,
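Putting the GUI changes together, the command assembled for an SDXL extraction ends up looking roughly like the sketch below; the paths are hypothetical, and the GUI builds an equivalent string before launching it:

# Rough reconstruction of the run_cmd string, with hypothetical paths
run_cmd = (
    'python networks/extract_lora_from_models.py'
    ' --save_precision fp16'
    ' --save_to "d:/loras/extracted.safetensors"'
    ' --model_org "e:/models/sdxl/sd_xl_base_0.9.safetensors"'
    ' --model_tuned "e:/models/sdxl/my_finetune.safetensors"'
    ' --dim 128'
    ' --device cuda'
    ' --sdxl'                # new in this commit
    ' --clamp_quantile 1'    # new in this commit
    ' --min_diff 0.01'       # new in this commit
)
print(run_cmd)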
24 changes: 17 additions & 7 deletions networks/extract_lora_from_models.py
@@ -12,10 +12,8 @@
import library.sdxl_model_util as sdxl_model_util
import lora


CLAMP_QUANTILE = 0.99
MIN_DIFF = 1e-4

# CLAMP_QUANTILE = 1
# MIN_DIFF = 1e-2

def save_to_file(file_name, model, state_dict, dtype):
if dtype is not None:
@@ -91,9 +89,9 @@ def str_to_dtype(p):
diff = module_t.weight - module_o.weight

# Text Encoder might be same
if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF:
if not text_encoder_different and torch.max(torch.abs(diff)) > args.min_diff:
text_encoder_different = True
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}")
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {args.min_diff}")

diff = diff.float()
diffs[lora_name] = diff
@@ -149,7 +147,7 @@ def str_to_dtype(p):
Vh = Vh[:rank, :]

dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
hi_val = torch.quantile(dist, args.clamp_quantile)
low_val = -hi_val

U = U.clamp(low_val, hi_val)
@@ -243,6 +241,18 @@ def setup_parser() -> argparse.ArgumentParser:
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
)
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument(
"--clamp_quantile",
type=float,
default=1,
help="Quantile clamping value, float, (0-1). Defailt = 1",

Check warning on line 248 in networks/extract_lora_from_models.py

View workflow job for this annotation

GitHub Actions / build

"Defailt" should be "Default".
)
parser.add_argument(
"--min_diff",
type=float,
default=0.01,
help="Minimum difference between the finetuned model and the base model required to consider them different enough to extract, float, (0-1). Default = 0.01",
)

return parser

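As a worked illustration of what the new --clamp_quantile flag controls, here is a toy version of the step shown in the hunk above: factor a weight difference down to the requested rank with SVD, then clamp outliers at the chosen quantile (random data, not the script itself):

import torch

rank, clamp_quantile = 8, 0.99
diff = torch.randn(320, 320)  # toy stand-in for module_t.weight - module_o.weight

# Low-rank factorization of the weight difference
U, S, Vh = torch.linalg.svd(diff)
U = U[:, :rank] @ torch.diag(S[:rank])
Vh = Vh[:rank, :]

# Clamp extreme values at the requested quantile, as the extraction script does
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, clamp_quantile)
U = U.clamp(-hi_val, hi_val)
Vh = Vh.clamp(-hi_val, hi_val)
print(U.shape, Vh.shape)  # torch.Size([320, 8]) torch.Size([8, 320])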
