v21.8.6 #1332

Merged
merged 70 commits on Aug 5, 2023

Commits (70)
8a073ee
vram leak fix
pamparamm Jul 17, 2023
d131bde
Support for bitsandbytes 0.39.1 with Paged Optimizer(AdamW8bit and Li…
sdbds Jul 22, 2023
bb167f9
init unet with empty weights
Isotr0py Jul 23, 2023
eec6aad
fix safetensors error: device invalid
Isotr0py Jul 23, 2023
4849ea8
Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into sdx…
bmaltais Jul 23, 2023
fcf087c
Fix finetuner gui code issues
bmaltais Jul 24, 2023
e83ee21
format by black
kohya-ss Jul 24, 2023
2b969e9
support sdxl
kohya-ss Jul 24, 2023
12f7ca8
Merge branch 'sdxl' of https://github.com/kohya-ss/sd-scripts into sd…
bmaltais Jul 24, 2023
b78c0e2
remove unused func
kohya-ss Jul 25, 2023
101d263
Merge branch 'sdxl' of https://github.com/kohya-ss/sd-scripts into sd…
bmaltais Jul 25, 2023
a9ec90c
Update extract LoRA and add sdpa
bmaltais Jul 26, 2023
334a551
increased max norms to 10
M4X1K02 Jul 27, 2023
32aaa40
appended relevant infor to info text
M4X1K02 Jul 27, 2023
50544b7
fix pipeline dtype
Isotr0py Jul 27, 2023
96a52d9
add dtype to u-net loading
Isotr0py Jul 27, 2023
a6a66dd
Merge pull request #1272 from M4X1K02/max_norm_adjust
bmaltais Jul 28, 2023
7de217a
Merge branch 'master' into dev2
bmaltais Jul 28, 2023
f4941ca
Merge branch 'dev2' into sdxl-dev
bmaltais Jul 28, 2023
272dd99
Merge branch 'sdxl' into sdxl
Isotr0py Jul 28, 2023
4a1b92d
Update README.md
Noyii Jul 28, 2023
315fbc1
refactor model loading to catch error
Isotr0py Jul 28, 2023
fdb58b0
fix mismatch dtype
Isotr0py Jul 28, 2023
1199eac
fix typo
Isotr0py Jul 28, 2023
38b59a9
Merge branch 'main' into dev
kohya-ss Jul 29, 2023
3a7326a
Merge pull request #693 from kohya-ss/dev
kohya-ss Jul 29, 2023
1e4512b
support ckpt without position id in sd v1 #687
kohya-ss Jul 29, 2023
fb1054b
Merge pull request #694 from kohya-ss/dev
kohya-ss Jul 29, 2023
cf80210
Merge pull request #688 from Noyii/main
kohya-ss Jul 29, 2023
4072f72
Merge branch 'main' into sdxl
kohya-ss Jul 29, 2023
d9180c0
fix typos for _load_state_dict
Isotr0py Jul 29, 2023
814191b
Merge branch 'sdxl' of https://github.com/kohya-ss/sd-scripts into dev2
bmaltais Jul 29, 2023
6f6bf88
Merge branch 'sdxl-dev' into dev2
bmaltais Jul 29, 2023
c55606e
Updated to bitsandbytes version for linux and windows
bmaltais Jul 29, 2023
72f9fdc
Allow DB training on SDXL models.
bmaltais Jul 29, 2023
e20b6ac
Merge pull request #676 from Isotr0py/sdxl
kohya-ss Jul 30, 2023
9ec7025
Add Paged/ adam8bit/lion8bit for Sdxl bitsandbytes 0.39.1 cuda118 on …
sdbds Jul 30, 2023
54a4aa2
Merge pull request #658 from pamparamm/cache_latents_leak_fix
kohya-ss Jul 30, 2023
e6034b7
move releasing cache outside of the loop
kohya-ss Jul 30, 2023
b62185b
change method name, add comments
kohya-ss Jul 30, 2023
a296654
refactor optimizer selection for bnb
kohya-ss Jul 30, 2023
2a4ae88
format by black
kohya-ss Jul 30, 2023
0eacadf
fix ControlNet not working
kohya-ss Jul 30, 2023
8856c19
fix batch generation not working
kohya-ss Jul 30, 2023
496c3f2
arbitrary args for diffusers lr scheduler
kohya-ss Jul 30, 2023
f61996b
remove dependency for albumenations
kohya-ss Jul 30, 2023
7e474d2
fix recorded seed in highres fix
kohya-ss Jul 30, 2023
0636399
add adding v-pred like loss for noise pred
kohya-ss Jul 30, 2023
89aae3e
fix vae crashes in large reso
kohya-ss Jul 31, 2023
ebcb6ef
Update requirements and others
bmaltais Aug 1, 2023
db80c5a
format by black
kohya-ss Aug 3, 2023
6b1cf6c
fix ControlNet with regional LoRA, add shuffle cap
kohya-ss Aug 3, 2023
cf68328
fix ControlNet with regional LoRA
kohya-ss Aug 3, 2023
c6d52fd
Add workaround for clip's bug for pooled output
kohya-ss Aug 3, 2023
9d7619d
remove debug print
kohya-ss Aug 3, 2023
f3be995
remove debug print
kohya-ss Aug 3, 2023
9d85509
make bitsandbytes optional
kohya-ss Aug 4, 2023
f4935dd
Merge pull request #714 from kohya-ss/dev
kohya-ss Aug 4, 2023
78bae16
Updates
bmaltais Aug 4, 2023
90d2160
Merge branch 'sdxl' of https://github.com/kohya-ss/sd-scripts into dev2
bmaltais Aug 4, 2023
8b46fe3
Update release
bmaltais Aug 4, 2023
25d8cd4
fix sdxl_gen_img not working
kohya-ss Aug 5, 2023
2dfa26c
Merge pull request #716 from kohya-ss/dev
kohya-ss Aug 5, 2023
7c55ca3
Merge branch 'sdxl' of https://github.com/kohya-ss/sd-scripts into dev2
bmaltais Aug 5, 2023
e5f9772
fix training textencoder in sdxl not working
kohya-ss Aug 5, 2023
cd54af0
Merge pull request #720 from kohya-ss/dev
kohya-ss Aug 5, 2023
08eb3ad
Merge branch 'sdxl' of https://github.com/kohya-ss/sd-scripts into dev2
bmaltais Aug 5, 2023
3e64386
Updates
bmaltais Aug 5, 2023
5eef71e
Update presets
bmaltais Aug 5, 2023
5042975
Update readme
bmaltais Aug 5, 2023

Files changed
2 changes: 1 addition & 1 deletion .release
@@ -1 +1 @@
v21.8.5
v21.8.6
10 changes: 10 additions & 0 deletions README-ja.md
@@ -125,6 +125,16 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsa

When upgrading, update this repository with `pip install .` and upgrade the other packages as needed.

### Optional: Using PagedAdamW8bit and PagedLion8bit

To use PagedAdamW8bit and PagedLion8bit, `bitsandbytes` must be upgraded to 0.39.0 or later. Uninstall `bitsandbytes` and, on Windows, install a Windows wheel, for example from [here](https://github.com/jllllll/bitsandbytes-windows-webui). The steps look like this:

```powershell
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```

When upgrading, update this repository with `pip install .` and upgrade the other packages as needed.
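Once the newer `bitsandbytes` is installed, the paged optimizers are selected through the existing `--optimizer_type` argument of the training scripts. A minimal sketch follows; all paths and hyperparameters are placeholders, not recommendations:

```powershell
# Illustrative only: select the paged 8-bit optimizer via --optimizer_type
accelerate launch --num_cpu_threads_per_process=2 train_network.py `
  --pretrained_model_name_or_path="D:\models\model.safetensors" `
  --train_data_dir="D:\dataset" --output_dir="D:\output" `
  --network_module=networks.lora --optimizer_type="PagedAdamW8bit" `
  --learning_rate=1e-4
```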

## Upgrading

When there is a new release, you can update with the following commands.
42 changes: 27 additions & 15 deletions README.md
@@ -72,6 +72,17 @@ First SDXL Tutorial: [First Ever SDXL Training With Kohya LoRA - Stable Diffusio

The feature of SDXL training is now available in sdxl branch as an experimental feature.

Aug 4, 2023: The feature will be merged into the main branch soon. Following are the changes from the previous version.

- `bitsandbytes` is now optional. Please install it if you want to use it. The installation instructions are in a later section.
- `albumentations` is no longer required.
- An issue with the pooled output for Textual Inversion training is fixed.
- `--v_pred_like_loss ratio` option is added. This option adds a v-prediction-like loss during SDXL training. `0.1` means that 10% of the v-prediction-like loss is added to the regular loss. The default value is None (disabled); an illustrative command using the new options is shown after this list.
- In v-prediction, the loss is higher in the early timesteps (near the noise). This option can be used to increase the loss in the early timesteps.
- Arbitrary options can be used for Diffusers' schedulers. For example `--lr_scheduler_args "lr_end=1e-8"`.
- `sdxl_gen_img.py` supports batch size > 1.
- Fix ControlNet to work with attention couple and regional LoRA in `gen_img_diffusers.py`.
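As a rough illustration of how the new options above fit together (paths and values are placeholders, not recommendations), an SDXL training invocation might look like this:

```powershell
# Illustrative only: add 10% of a v-prediction-like loss and pass an extra
# argument (lr_end) through to the Diffusers polynomial scheduler
accelerate launch --num_cpu_threads_per_process=2 sdxl_train.py `
  --pretrained_model_name_or_path="D:\models\sd_xl_base_1.0.safetensors" `
  --train_data_dir="D:\dataset" --output_dir="D:\output" `
  --v_pred_like_loss=0.1 `
  --lr_scheduler="polynomial" --lr_scheduler_args "lr_end=1e-8"
```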

Summary of the feature:

- `tools/cache_latents.py` is added. This script can be used to cache the latents to disk in advance.
@@ -115,12 +126,17 @@ Summary of the feature:
#### Tips for SDXL training

- The default resolution of SDXL is 1024x1024.
- The fine-tuning can be done with 24GB GPU memory with the batch size of 1. For 24GB GPU, the following options are recommended:
- The fine-tuning can be done with 24GB GPU memory with a batch size of 1. The following options are recommended __for fine-tuning with 24GB GPU memory__:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use Adafactor optimizer. RMSprop 8bit or Adagrad 8bit may work. AdamW 8bit doesn't seem to work.
- The LoRA training can be done with 12GB GPU memory.
- The LoRA training can be done with 8GB GPU memory (10GB recommended). To reduce GPU memory usage, the following options are recommended:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use one of 8bit optimizers or Adafactor optimizer.
- Use lower dim (-8 for 8GB GPU).
- `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, the result of training them can be unexpected. A sketch of the low-memory setups above appears after this list.
- PyTorch 2 seems to use slightly less GPU memory than PyTorch 1.
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training.
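To make the memory-saving tips above concrete, here is a hedged sketch of both setups; every path, dim, and precision value below is a placeholder to adapt to your own data:

```powershell
# ~24GB fine-tuning: U-Net only, gradient checkpointing, cached latents and
# text encoder outputs, Adafactor optimizer
accelerate launch --num_cpu_threads_per_process=2 sdxl_train.py `
  --pretrained_model_name_or_path="D:\models\sd_xl_base_1.0.safetensors" `
  --train_data_dir="D:\dataset" --output_dir="D:\output" --resolution="1024,1024" `
  --train_batch_size=1 --gradient_checkpointing --cache_latents --cache_text_encoder_outputs `
  --optimizer_type="Adafactor" --mixed_precision="bf16"

# ~8-10GB LoRA: U-Net only, low network dim, 8-bit optimizer
accelerate launch --num_cpu_threads_per_process=2 sdxl_train_network.py `
  --pretrained_model_name_or_path="D:\models\sd_xl_base_1.0.safetensors" `
  --train_data_dir="D:\dataset" --output_dir="D:\output" --resolution="1024,1024" `
  --network_module=networks.lora --network_dim=8 --network_train_unet_only `
  --gradient_checkpointing --cache_latents --cache_text_encoder_outputs `
  --optimizer_type="AdamW8bit" --mixed_precision="fp16"
```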
@@ -480,16 +496,12 @@ If you come across a `FileNotFoundError`, it is likely due to an installation is

## Change History

* 2023/07/27 (v21.8.4)
- Relocate LR number of cycles and LR power options
- Add missing LR number of cycles and LR power to Dreambooth and TI scripts
- Fix issue with conv_block_dims and conv_block_alphas
- Fix 0 noise offset issue
- Implement Stop training button on LoRA and other training tabs
- Update LyCORIS network release to fix an issue with the LoCon extraction.

* 2023/07/18 (v21.8.3)
- Update to latest sd-scripts sdxl code base
- Fix typo: https://github.com/bmaltais/kohya_ss/issues/1205
- Add min and max resolution parameter for buckets
- Add colab notebook from https://github.com/camenduru/kohya_ss-colab
* 2023/08/05 (v21.8.6)
- Merge latest sd-scripts updates.
- Allow DB training on SDXL models. Unsupported, but it appears to work.
- Fix fine-tuning latent caching issue when processing SDXL models in fp16
- Add SDXL merge lora support. You can now merge LoRAs into an SDXL checkpoint (see the sketch after this list).
- Add SDPA CrossAttention option to trainers.
- Merge latest kohya_ss sd-scripts code
- Fix Dreambooth support for SDXL training
- Update to latest bitsandbytes release. New optional install option for bitsandbytes versions.
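For the new SDXL LoRA merge support, the corresponding sd-scripts utility can be run along these lines; the script location and file names are assumptions for illustration:

```powershell
# Illustrative: merge one LoRA into an SDXL checkpoint at full strength (ratio 1.0)
./venv/Scripts/python.exe networks/sdxl_merge_lora.py `
  --sd_model "D:\models\sd_xl_base_1.0.safetensors" `
  --models "D:\loras\my_lora.safetensors" --ratios 1.0 `
  --save_to "D:\models\sdxl_merged.safetensors" --save_precision fp16
```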
Binary file added bitsandbytes_windows/libbitsandbytes_cuda118.dll
Binary file not shown.
2 changes: 1 addition & 1 deletion bitsandbytes_windows/main.py
@@ -163,4 +163,4 @@ def get_binary_name():

binary_name = get_binary_name()

return binary_name
return binary_name
2 changes: 2 additions & 0 deletions docs/train_README-ja.md
@@ -609,10 +609,12 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
- AdamW : [torch.optim.AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html)
- Same as when no option was specified in past versions
- AdamW8bit : arguments are the same as above
- PagedAdamW8bit : arguments are the same as above
- Same as specifying --use_8bit_adam in past versions
- Lion : https://github.com/lucidrains/lion-pytorch
- Same as specifying --use_lion_optimizer in past versions
- Lion8bit : arguments are the same as above
- PagedLion8bit : arguments are the same as above
- SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
- SGDNesterov8bit : arguments are the same as above
- DAdaptation(DAdaptAdamPreprint) : https://github.com/facebookresearch/dadaptation
5 changes: 4 additions & 1 deletion docs/train_README-zh.md
@@ -546,9 +546,12 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
-- Specify the optimizer type. You can specify:
- AdamW : [torch.optim.AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html)
- Same as when no option was specified in past versions
- AdamW8bit : same as above
- AdamW8bit : arguments are the same as above
- PagedAdamW8bit : arguments are the same as above
- Same as specifying --use_8bit_adam in past versions
- Lion : https://github.com/lucidrains/lion-pytorch
- Lion8bit : arguments are the same as above
- PagedLion8bit : arguments are the same as above
- Same as specifying --use_lion_optimizer in past versions
- SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
- SGDNesterov8bit : arguments are the same as above
29 changes: 22 additions & 7 deletions dreambooth_gui.py
@@ -79,6 +79,7 @@ def save_configuration(
enable_bucket,
gradient_checkpointing,
full_fp16,
full_bf16,
no_token_padding,
stop_text_encoder_training,
min_bucket_reso,
@@ -192,6 +193,7 @@ def open_configuration(
enable_bucket,
gradient_checkpointing,
full_fp16,
full_bf16,
no_token_padding,
stop_text_encoder_training,
min_bucket_reso,
@@ -304,6 +306,7 @@ def train_model(
enable_bucket,
gradient_checkpointing,
full_fp16,
full_bf16,
no_token_padding,
stop_text_encoder_training_pct,
min_bucket_reso,
@@ -410,12 +413,12 @@
):
return

if sdxl:
output_message(
msg='TI training is not compatible with an SDXL model.',
headless=headless_bool,
)
return
# if sdxl:
# output_message(
# msg='Dreambooth training is not compatible with SDXL models yet..',
# headless=headless_bool,
# )
# return

# if optimizer == 'Adafactor' and lr_warmup != '0':
# output_message(
@@ -520,7 +523,13 @@ def train_model(
lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
log.info(f'lr_warmup_steps = {lr_warmup_steps}')

run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_db.py"'
# run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_db.py"'
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process}'
if sdxl:
run_cmd += f' "./sdxl_train.py"'
else:
run_cmd += f' "./train_db.py"'

if v2:
run_cmd += ' --v2'
if v_parameterization:
@@ -551,6 +560,8 @@ def train_model(
# run_cmd += f' --resume={resume}'
if not float(prior_loss_weight) == 1.0:
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
if full_bf16:
run_cmd += ' --full_bf16'
if not vae == '':
run_cmd += f' --vae="{vae}"'
if not output_name == '':
@@ -696,6 +707,9 @@ def dreambooth_tab(
lr_scheduler_value='cosine',
lr_warmup_value='10',
)
full_bf16 = gr.Checkbox(
label='Full bf16', value = False
)
with gr.Accordion('Advanced Configuration', open=False):
advanced_training = AdvancedTraining(headless=headless)
advanced_training.color_aug.change(
Expand Down Expand Up @@ -765,6 +779,7 @@ def dreambooth_tab(
basic_training.enable_bucket,
advanced_training.gradient_checkpointing,
advanced_training.full_fp16,
full_bf16,
advanced_training.no_token_padding,
basic_training.stop_text_encoder_training,
basic_training.min_bucket_reso,
11 changes: 11 additions & 0 deletions examples/finetune_latent.ps1
@@ -0,0 +1,11 @@
# Command 1: merge_captions_to_metadata.py
$captionExtension = "--caption_extension=.txt"
$sourceDir1 = "d:\test\1_1960-1969"
$targetFile1 = "d:\test\1_1960-1969/meta_cap.json"

# Command 2: prepare_buckets_latents.py
$targetLatentFile = "d:\test\1_1960-1969/meta_lat.json"
$modelFile = "E:\models\sdxl\sd_xl_base_0.9.safetensors"

./venv/Scripts/python.exe finetune/merge_captions_to_metadata.py $captionExtension $sourceDir1 $targetFile1 --full_path
./venv/Scripts/python.exe finetune/prepare_buckets_latents.py $sourceDir1 $targetFile1 $targetLatentFile $modelFile --batch_size=4 --max_resolution=1024,1024 --min_bucket_reso=64 --max_bucket_reso=2048 --mixed_precision=bf16 --full_path
17 changes: 14 additions & 3 deletions finetune_gui.py
@@ -423,8 +423,13 @@ def train_model(
run_cmd += f' --full_path'

log.info(run_cmd)

executor.execute_command(run_cmd=run_cmd)

if not print_only_bool:
# Run the command
if os.name == 'posix':
os.system(run_cmd)
else:
subprocess.run(run_cmd)

# create images buckets
if generate_image_buckets:
@@ -442,12 +447,18 @@
# run_cmd += f' --flip_aug'
if full_path:
run_cmd += f' --full_path'
if sdxl_no_half_vae:
log.info('Using mixed_precision = no because no half vae is selected...')
run_cmd += f' --mixed_precision="no"'

log.info(run_cmd)

if not print_only_bool:
# Run the command
executor.execute_command(run_cmd=run_cmd)
if os.name == 'posix':
os.system(run_cmd)
else:
subprocess.run(run_cmd)

image_num = len(
[
28 changes: 23 additions & 5 deletions gen_img_diffusers.py
@@ -937,18 +937,24 @@ def __call__(
if self.control_nets:
guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)

if reginonal_network:
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
# last subprompt and negative prompt
text_emb_last = []
for j in range(batch_size):
text_emb_last.append(text_embeddings[(j + 1) * num_sub_and_neg_prompts - 2])
text_emb_last.append(text_embeddings[(j + 1) * num_sub_and_neg_prompts - 1])
text_emb_last = torch.stack(text_emb_last)
else:
text_emb_last = text_embeddings

for i, t in enumerate(tqdm(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
if self.control_nets and self.control_net_enabled:
if reginonal_network:
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt
else:
text_emb_last = text_embeddings
noise_pred = original_control_net.call_unet_and_control_net(
i,
num_latent_input,
@@ -958,6 +964,7 @@
i / len(timesteps),
latent_model_input,
t,
text_embeddings,
text_emb_last,
).sample
else:
@@ -2746,6 +2753,10 @@ def resize_images(imgs, size):
print(f"iteration {gen_iter+1}/{args.n_iter}")
iter_seed = random.randint(0, 0x7FFFFFFF)

# shuffle prompt list
if args.shuffle_prompts:
random.shuffle(prompt_list)

# batch processing function
def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
batch_size = len(batch)
@@ -2963,6 +2974,8 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate(
zip(images, prompts, negative_prompts, seeds, clip_prompts)
):
if highres_fix:
seed -= 1 # record original seed
metadata = PngInfo()
metadata.add_text("prompt", prompt)
metadata.add_text("seed", str(seed))
@@ -3319,6 +3332,11 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)",
)
parser.add_argument(
"--shuffle_prompts",
action="store_true",
help="shuffle prompts in iteration / 繰り返し内のプロンプトをシャッフルする",
)
parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する")
parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する")
parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する")
12 changes: 9 additions & 3 deletions library/class_advanced_training.py
@@ -19,8 +19,12 @@ def noise_offset_type_change(noise_offset_type):
self.no_token_padding = gr.Checkbox(
label='No token padding', value=False
)
self.gradient_accumulation_steps = gr.Number(
label='Gradient accumulate steps', value='1'
self.gradient_accumulation_steps = gr.Slider(
label='Gradient accumulate steps',
info='Number of updates steps to accumulate before performing a backward/update pass',
value='1',
minimum=1, maximum=120,
step=1
)
self.weighted_captions = gr.Checkbox(
label='Weighted captions', value=False
@@ -101,13 +105,15 @@ def noise_offset_type_change(noise_offset_type):
# use_8bit_adam = gr.Checkbox(
# label='Use 8bit adam', value=False, visible=False
# )
self.xformers = gr.Checkbox(label='Use xformers', value=True)
# self.xformers = gr.Checkbox(label='Use xformers', value=True, info='Use xformers for CrossAttention')
self.xformers = gr.Dropdown(label='CrossAttention', choices=["none", "sdpa", "xformers"], value='xformers')
self.color_aug = gr.Checkbox(label='Color augmentation', value=False)
self.flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
self.min_snr_gamma = gr.Slider(
label='Min SNR gamma', value=0, minimum=0, maximum=20, step=1
)
with gr.Row():
# self.sdpa = gr.Checkbox(label='Use sdpa', value=False, info='Use sdpa for CrossAttention')
self.bucket_no_upscale = gr.Checkbox(
label="Don't upscale bucket resolution", value=True
)
2 changes: 2 additions & 0 deletions library/class_basic_training.py
@@ -107,6 +107,8 @@ def __init__(
'DAdaptSGD',
'Lion',
'Lion8bit',
"PagedAdamW8bit",
"PagedLion8bit",
'Prodigy',
'SGDNesterov',
'SGDNesterov8bit',
3 changes: 3 additions & 0 deletions library/class_dreambooth_gui.py
@@ -36,6 +36,9 @@ def __init__(
lr_scheduler_value='cosine',
lr_warmup_value='10',
)
self.full_bf16 = gr.Checkbox(
label='Full bf16', value = False
)
with gr.Accordion('Advanced Configuration', open=False):
self.advanced_training = AdvancedTraining(headless=headless)
self.advanced_training.color_aug.change(