diff --git a/.release b/.release index e308a340c..535e1cca7 100644 --- a/.release +++ b/.release @@ -1 +1 @@ -v22.5.0 +v22.6.0 diff --git a/README.md b/README.md index bddda2d54..c95a930a7 100644 --- a/README.md +++ b/README.md @@ -47,11 +47,6 @@ The GUI allows you to set the training parameters and generate and run the requi - [ControlNet-LLLite](#controlnet-lllite) - [Sample image generation during training](#sample-image-generation-during-training-1) - [Change History](#change-history) - - [Jan 15, 2024 / 2024/1/15: v0.8.0](#jan-15-2024--2024115-v080) - - [Naming of LoRA](#naming-of-lora) - - [LoRAの名称について](#loraの名称について) - - [Sample image generation during training](#sample-image-generation-during-training-2) - - [Change History](#change-history-1) ## 🦒 Colab @@ -109,7 +104,7 @@ Please note that the CUDNN 8.6 DLLs needed for this process cannot be hosted on 1. Unzip the downloaded file and place the `cudnn_windows` folder in the root directory of the `kohya_ss` repository. -2. Run .\setup.bat and select the option to install cudann. +2. Run .\setup.bat and select the option to install cudnn. ### Linux and macOS @@ -122,7 +117,7 @@ To install the necessary dependencies on a Linux system, ensure that you fulfill apt install python3.10-venv ``` -- Install the cudaNN drivers by following the instructions provided in [this link](https://developer.nvidia.com/cuda-downloads?target_os=Linux&target_arch=x86_64). +- Install the cudNN drivers by following the instructions provided in [this link](https://developer.nvidia.com/cuda-downloads?target_os=Linux&target_arch=x86_64). - Make sure you have Python version 3.10.6 or higher (but lower than 3.11.0) installed on your system. @@ -484,78 +479,7 @@ save_file(state_dict, file) ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [documentation](./docs/train_lllite_README.md) for details. -<<<<<<< HEAD ### Sample image generation during training -======= - -## Change History - -### Jan 15, 2024 / 2024/1/15: v0.8.0 - -- Diffusers, Accelerate, Transformers and other related libraries have been updated. Please update the libraries with [Upgrade](#upgrade). - - Some model files (Text Encoder without position_id) based on the latest Transformers can be loaded. -- `torch.compile` is supported (experimental). PR [#1024](https://github.com/kohya-ss/sd-scripts/pull/1024) Thanks to p1atdev! - - This feature works only on Linux or WSL. - - Please specify `--torch_compile` option in each training script. - - You can select the backend with `--dynamo_backend` option. The default is `"inductor"`. `inductor` or `eager` seems to work. - - Please use `--spda` option instead of `--xformers` option. - - PyTorch 2.1 or later is recommended. - - Please see [PR](https://github.com/kohya-ss/sd-scripts/pull/1024) for details. -- The session name for wandb can be specified with `--wandb_run_name` option. PR [#1032](https://github.com/kohya-ss/sd-scripts/pull/1032) Thanks to hopl1t! -- IPEX library is updated. PR [#1030](https://github.com/kohya-ss/sd-scripts/pull/1030) Thanks to Disty0! -- Fixed a bug that Diffusers format model cannot be saved. 
- -- Diffusers、Accelerate、Transformers 等の関連ライブラリを更新しました。[Upgrade](#upgrade) を参照し更新をお願いします。 - - 最新の Transformers を前提とした一部のモデルファイル(Text Encoder が position_id を持たないもの)が読み込めるようになりました。 -- `torch.compile` がサポートされしました(実験的)。 PR [#1024](https://github.com/kohya-ss/sd-scripts/pull/1024) p1atdev 氏に感謝します。 - - Linux または WSL でのみ動作します。 - - 各学習スクリプトで `--torch_compile` オプションを指定してください。 - - `--dynamo_backend` オプションで使用される backend を選択できます。デフォルトは `"inductor"` です。 `inductor` または `eager` が動作するようです。 - - `--xformers` オプションとは互換性がありません。 代わりに `--spda` オプションを使用してください。 - - PyTorch 2.1以降を推奨します。 - - 詳細は [PR](https://github.com/kohya-ss/sd-scripts/pull/1024) をご覧ください。 -- wandb 保存時のセッション名が各学習スクリプトの `--wandb_run_name` オプションで指定できるようになりました。 PR [#1032](https://github.com/kohya-ss/sd-scripts/pull/1032) hopl1t 氏に感謝します。 -- IPEX ライブラリが更新されました。[PR #1030](https://github.com/kohya-ss/sd-scripts/pull/1030) Disty0 氏に感謝します。 -- Diffusers 形式でのモデル保存ができなくなっていた不具合を修正しました。 - - -Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates. -最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。 - -### Naming of LoRA - -The LoRA supported by `train_network.py` has been named to avoid confusion. The documentation has been updated. The following are the names of LoRA types in this repository. - -1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers) - - LoRA for Linear layers and Conv2d layers with 1x1 kernel - -2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers) - - In addition to 1., LoRA for Conv2d layers with 3x3 kernel - -LoRA-LierLa is the default LoRA type for `train_network.py` (without `conv_dim` network arg). LoRA-LierLa can be used with [our extension](https://github.com/kohya-ss/sd-webui-additional-networks) for AUTOMATIC1111's Web UI, or with the built-in LoRA feature of the Web UI. - -To use LoRA-C3Lier with Web UI, please use our extension. - -### LoRAの名称について - -`train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。 - -1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます) - - Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA - -2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます) - - 1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA - -LoRA-LierLa は[Web UI向け拡張](https://github.com/kohya-ss/sd-webui-additional-networks)、またはAUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。 - -LoRA-C3Lierを使いWeb UIで生成するには拡張を使用してください。 - -## Sample image generation during training ->>>>>>> 26d35794e3b858e7b5bd20d1e70547c378550b3d A prompt file might look like this, for example ``` @@ -579,6 +503,62 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b ## Change History +* 2024/01/27 (v22.6.0) +- Merge sd-scripts v0.8.3 code update + - Fixed a bug that the training crashes when `--fp8_base` is specified with `--save_state`. PR [#1079](https://github.com/kohya-ss/sd-scripts/pull/1079) Thanks to feffy380! + - `safetensors` is updated. Please see [Upgrade](#upgrade) and update the library. + - Fixed a bug that the training crashes when `network_multiplier` is specified with multi-GPU training. PR [#1084](https://github.com/kohya-ss/sd-scripts/pull/1084) Thanks to fireicewolf! + - Fixed a bug that the training crashes when training ControlNet-LLLite. + +- Merge sd-scripts v0.8.2 code update + - [Experimental] The `--fp8_base` option is added to the training scripts for LoRA etc. 
The base model (U-Net, and Text Encoder when training modules for Text Encoder) can be trained with fp8. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf! + - Please specify `--fp8_base` in `train_network.py` or `sdxl_train_network.py`. + - PyTorch 2.1 or later is required. + - If you use xformers with PyTorch 2.1, please see [xformers repository](https://github.com/facebookresearch/xformers) and install the appropriate version according to your CUDA version. + - The sample image generation during training consumes a lot of memory. It is recommended to turn it off. + + - [Experimental] The network multiplier can be specified for each dataset in the training scripts for LoRA etc. + - This is an experimental option and may be removed or changed in the future. + - For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate. + - Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate. + - Please specify `network_multiplier` in `[[datasets]]` in `.toml` file. + + - Some options are added to `networks/extract_lora_from_models.py` to reduce the memory usage. + - `--load_precision` option can be used to specify the precision when loading the model. If the model is saved in fp16, you can reduce the memory usage by specifying `--load_precision fp16` without losing precision. + - `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU. This option is available only for SDXL. + + - The gradient synchronization in LoRA training with multi-GPU is improved. PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) Thanks to KohakuBlueleaf! + + - The code for Intel IPEX support is improved. PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) Thanks to akx! + + - Fixed a bug in multi-GPU Textual Inversion training. + + - `.toml` example for network multiplier + + ```toml + [general] + [[datasets]] + resolution = 512 + batch_size = 8 + network_multiplier = 1.0 + + ... subset settings ... + + [[datasets]] + resolution = 512 + batch_size = 8 + network_multiplier = -1.0 + + ... subset settings ... + ``` + +- Merge sd-scripts v0.8.1 code update + + - Fixed a bug that the VRAM usage without Text Encoder training is larger than before in training scripts for LoRA etc (`train_network.py`, `sdxl_train_network.py`). + - Text Encoders were not moved to CPU. + + - Fixed typos. Thanks to akx! [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053) + * 2024/01/15 (v22.5.0) - Merged sd-scripts v0.8.0 updates - Diffusers, Accelerate, Transformers and other related libraries have been updated. Please update the libraries with [Upgrade](#upgrade). 
diff --git "a/README_\344\270\255\346\226\207\346\225\231\347\250\213.md" "b/README_\344\270\255\346\226\207\346\225\231\347\250\213.md" index c58afaf50..68a14a921 100644 --- "a/README_\344\270\255\346\226\207\346\225\231\347\250\213.md" +++ "b/README_\344\270\255\346\226\207\346\225\231\347\250\213.md" @@ -1,141 +1,127 @@ -嗨!我把日语 README 文件的主要内容翻译成中文如下: +SDXL已得到支持。sdxl分支已合并到main分支。当更新仓库时,请执行升级步骤。由于accelerate版本也已经升级,请重新运行accelerate config。 -## 关于这个仓库 +有关SDXL训练的信息,请参见[此处](./README.md#sdxl-training)(英文)。 -这个是用于Stable Diffusion模型训练、图像生成和其他脚本的仓库。 +## 关于本仓库 -[英文版 README](./README.md) <-- 更新信息在这里 +用于Stable Diffusion的训练、图像生成和其他脚本的仓库。 -GUI和PowerShell脚本等使其更易用的功能在[bmaltais的仓库](https://github.com/bmaltais/kohya_ss)(英语)中提供,一并参考。感谢bmaltais。 +[英文README](./README.md) <- 更新信息在这里 + +[bmaltais的仓库](https://github.com/bmaltais/kohya_ss)中提供了GUI和PowerShell脚本等使其更易于使用的功能(英文),也请一并参阅。衷心感谢bmaltais。 包含以下脚本: -* 支持DreamBooth、U-Net和文本编码器的训练 -* fine-tuning的支持 +* 支持DreamBooth、U-Net和Text Encoder的训练 +* 微调,同上 +* 支持LoRA的训练 * 图像生成 -* 模型转换(Stable Diffusion ckpt/safetensors 和 Diffusers之间的相互转换) - -## 使用方法 (中国用户只需要按照这个安装教程操作) -- 进入kohya_ss文件夹根目录下,点击 setup.bat 启动安装程序 *(需要科学上网) -- 根据界面上给出的英文选项: -Kohya_ss GUI setup menu: - -1. Install kohya_ss gui -2. (Optional) Install cudann files (avoid unless you really need it) -3. (Optional) Install specific bitsandbytes versions -4. (Optional) Manually configure accelerate -5. (Optional) Start Kohya_ss GUI in browser -6. Quit - -Enter your choice: 1 +* 模型转换(在Stable Diffision ckpt/safetensors与Diffusers之间转换) -1. Torch 1 (legacy, no longer supported. Will be removed in v21.9.x) -2. Torch 2 (recommended) -3. Cancel - -Enter your choice: 2 - -开始安装环境依赖,接着再出来的选项,按照下列选项操作: -```txt -- This machine -- No distributed training -- NO -- NO -- NO -- all -- bf16 -``` --------------------------------------------------------------------- -这里都选择完毕,即可关闭终端窗口,直接点击 gui.bat或者 kohya中文启动器.bat 即可运行kohya +## 使用方法 - -当仓库内和note.com有相关文章,请参考那里。(未来可能全部移到这里) - -* [关于训练,通用篇](./docs/train_README-zh.md): 数据准备和选项等 - * [数据集设置](./docs/config_README-ja.md) -* [DreamBooth训练指南](./docs/train_db_README-zh.md) -* [fine-tuning指南](./docs/fine_tune_README_ja.md) -* [LoRA训练指南](./docs/train_network_README-zh.md) -* [文本反转训练指南](./docs/train_ti_README-ja.md) +* [通用部分的训练信息](./docs/train_README-ja.md): 数据准备和选项等 +* [数据集设置](./docs/config_README-ja.md) +* [DreamBooth的训练信息](./docs/train_db_README-ja.md) +* [微调指南](./docs/fine_tune_README_ja.md) +* [LoRA的训练信息](./docs/train_network_README-ja.md) +* [Textual Inversion的训练信息](./docs/train_ti_README-ja.md) * [图像生成脚本](./docs/gen_img_README-ja.md) * note.com [模型转换脚本](https://note.com/kohya_ss/n/n374f316fe4ad) -## Windows环境所需程序 +## Windows上需要的程序 需要Python 3.10.6和Git。 - Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe -- git: https://git-scm.com/download/win +- git: https://git-scm.com/download/win -如果要在PowerShell中使用venv,需要按以下步骤更改安全设置: -(不仅仅是venv,使脚本可以执行。请注意。) +如果要在PowerShell中使用,请按以下步骤更改安全设置以使用venv。 +(不仅仅是venv,这使得脚本的执行成为可能,所以请注意。) -- 以管理员身份打开PowerShell -- 输入"Set-ExecutionPolicy Unrestricted",选择Y -- 关闭管理员PowerShell +- 以管理员身份打开PowerShell。 +- 输入“Set-ExecutionPolicy Unrestricted”,并回答Y。 +- 关闭管理员PowerShell。 ## 在Windows环境下安装 -下例中安装的是PyTorch 1.12.1/CUDA 11.6版。如果要使用CUDA 11.3或PyTorch 1.13,请适当修改。 +脚本已在PyTorch 2.0.1上通过测试。PyTorch 1.12.1也应该可以工作。 + +下例中,将安装PyTorch 2.0.1/CUDA 11.8版。如果使用CUDA 11.6版或PyTorch 1.12.1,请酌情更改。 -(如果只显示"python",请将下例中的"python"改为"py") +(注意,如果python -m venv~这行只显示“python”,请将其更改为py -m venv~。) -在普通(非管理员)PowerShell中依次执行以下命令: +如果使用PowerShell,请打开常规(非管理员)PowerShell并按顺序执行以下操作: 
```powershell -git clone https://github.com/kohya-ss/sd-scripts.git +git clone https://github.com/kohya-ss/sd-scripts.git cd sd-scripts python -m venv venv -.\venv\Scripts\activate +.\venv\Scripts\activate -pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 +pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 pip install --upgrade -r requirements.txt -pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl - -cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ -cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py -cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py +pip install xformers==0.0.20 accelerate config ``` -在命令提示符中: - -```bat -git clone https://github.com/kohya-ss/sd-scripts.git -cd sd-scripts +在命令提示符下也相同。 -python -m venv venv -.\venv\Scripts\activate +(注:由于 ``python -m venv venv`` 比 ``python -m venv --system-site-packages venv`` 更安全,已进行更改。如果global python中安装了package,后者会引发各种问题。) -pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 -pip install --upgrade -r requirements.txt -pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl +在accelerate config的提示中,请按以下方式回答。(如果以bf16学习,最后一个问题回答bf16。) -copy /y .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ -copy /y .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py -copy /y .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py +※从0.15.0开始,在日语环境中按方向键选择会崩溃(......)。请使用数字键0、1、2......进行选择。 -accelerate config +```txt +- This machine +- No distributed training +- NO +- NO +- NO +- all +- fp16 ``` -accelerate config的问题请按以下回答: -(如果要用bf16训练,最后一个问题选择bf16) +※有时可能会出现 ``ValueError: fp16 mixed precision requires a GPU`` 错误。在这种情况下,对第6个问题 ``(What GPU(s) (by id) should be used for training on this machine as a comma-separated list? 
[all]:`` +)回答“0”。(将使用id `0`的GPU)。 -``` -- 此计算机 -- 不进行分布式训练 -- 否 -- 否 -- 否 -- 所有 -- fp16 -``` +### 可选:``bitsandbytes``(8位优化器) + +`bitsandbytes`现在是可选的。在Linux上,可以通过pip正常安装(推荐0.41.1或更高版本)。 + +在Windows上,推荐0.35.0或0.41.1。 + +- `bitsandbytes` 0.35.0: 似乎是稳定的版本。可以使用AdamW8bit,但不能使用其他一些8位优化器和`full_bf16`学习时的选项。 +- `bitsandbytes` 0.41.1: 支持 Lion8bit、PagedAdamW8bit、PagedLion8bit。可以使用`full_bf16`。 -### PyTorch和xformers版本注意事项 +注意:`bitsandbytes` 从0.35.0到0.41.0之间的版本似乎存在问题。 https://github.com/TimDettmers/bitsandbytes/issues/659 -在其他版本中训练可能失败。如果没有特殊原因,请使用指定版本。 +请按以下步骤安装`bitsandbytes`。 + +### 使用0.35.0 + +以下是PowerShell的例子。在命令提示符中,请使用copy代替cp。 + +```powershell +cd sd-scripts +.\venv\Scripts\activate +pip install bitsandbytes==0.35.0 + +cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ +cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py +cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py +``` + +### 使用0.41.1 + +请从[此处](https://github.com/jllllll/bitsandbytes-windows-webui)或其他地方安装jllllll发布的Windows whl文件。 + +```powershell +python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui +``` ### 可选:使用Lion8bit diff --git a/XTI_hijack.py b/XTI_hijack.py index ec0849455..1dbc263ac 100644 --- a/XTI_hijack.py +++ b/XTI_hijack.py @@ -1,11 +1,7 @@ import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass +from library.ipex_interop import init_ipex + +init_ipex() from typing import Union, List, Optional, Dict, Any, Tuple from diffusers.models.unet_2d_condition import UNet2DConditionOutput diff --git a/config_README-zh.md b/config_README-zh.md new file mode 100644 index 000000000..1266e7b25 --- /dev/null +++ b/config_README-zh.md @@ -0,0 +1,289 @@ + + +paste.txt +For non-Japanese speakers: this README is provided only in Japanese in the current state. Sorry for inconvenience. We will provide English version in the near future. 
+ +`--dataset_config` 可以传递设置文件的说明。 + +## 概述 + +通过传递设置文件,用户可以进行更详细的设置。 + +* 可以设置多个数据集 + * 例如,可以为每个数据集设置 `resolution`,并混合它们进行训练。 + * 对于支持DreamBooth方法和fine tuning方法的训练方法,可以混合使用DreamBooth方式和fine tuning方式的数据集。 +* 可以为每个子集更改设置 + * 数据子集是指根据图像目录或元数据拆分的数据集。几个子集构成一个数据集。 + * 诸如 `keep_tokens` 和 `flip_aug` 之类的选项可以为每个子集设置。另一方面,诸如 `resolution` 和 `batch_size` 之类的选项可以为每个数据集设置,属于同一数据集的子集的值将是共享的。具体细节将在下文中说明。 + +设置文件的格式可以是JSON或TOML。考虑到描述的便利性,建议使用 [TOML](https://toml.io/zh/v1.0.0-rc.2)。下面将以TOML为前提进行说明。 + +下面是一个用TOML描述的设置文件的示例。 + +```toml +[general] +shuffle_caption = true +caption_extension = '.txt' +keep_tokens = 1 + +# 这是一个DreamBooth方式的数据集 +[[datasets]] +resolution = 512 +batch_size = 4 +keep_tokens = 2 + + [[datasets.subsets]] + image_dir = 'C:\hoge' + class_tokens = 'hoge girl' + # 此子集为keep_tokens = 2(使用属于的数据集的值) + + [[datasets.subsets]] + image_dir = 'C:\fuga' + class_tokens = 'fuga boy' + keep_tokens = 3 + + [[datasets.subsets]] + is_reg = true + image_dir = 'C:\reg' + class_tokens = 'human' + keep_tokens = 1 + +# 这是一个fine tuning方式的数据集 +[[datasets]] +resolution = [768, 768] +batch_size = 2 + + [[datasets.subsets]] + image_dir = 'C:\piyo' + metadata_file = 'C:\piyo\piyo_md.json' + # 此子集为keep_tokens = 1(使用general的值) +``` + +在此示例中,有3个目录作为DreamBooth方式的数据集以512x512(批量大小4)进行训练,有1个目录作为fine tuning方式的数据集以768x768(批量大小2)进行训练。 + +## 数据子集的设置 + +数据集和子集的相关设置分为几个可以注册的部分。 + +* `[general]` + * 适用于所有数据集或所有子集的选项指定部分。 + * 如果在数据集设置或子集设置中存在同名选项,则会优先使用数据集或子集的设置。 +* `[[datasets]]` + * `datasets`是数据集设置注册部分。这是指定适用于每个数据集的选项的部分。 + * 如果存在子集设置,则会优先使用子集设置。 +* `[[datasets.subsets]]` + * `datasets.subsets`是子集设置的注册部分。这是指定适用于每个子集的选项的部分。 + +下面是先前示例中图像目录和注册部分的对应关系图。 + +``` +C:\ +├─ hoge -> [[datasets.subsets]] No.1 ┐ ┐ +├─ fuga -> [[datasets.subsets]] No.2 |-> [[datasets]] No.1 |-> [general] +├─ reg -> [[datasets.subsets]] No.3 ┘ | +└─ piyo -> [[datasets.subsets]] No.4 --> [[datasets]] No.2 ┘ +``` + +每个图像目录对应一个 `[[datasets.subsets]]`。然后1个或多个 `[[datasets.subsets]]` 构成一个 `[[datasets]]`。所有的 `[[datasets]]`、`[[datasets.subsets]]` 属于 `[general]`。 + +根据注册部分,可以指定不同的选项,但如果指定了同名选项,则下级注册部分中的值将被优先使用。建议参考先前 keep_tokens 选项的处理,这将有助于理解。 + +此外,可指定的选项会根据训练方法支持的技术而变化。 + +* 仅DreamBooth方式的选项 +* 仅fine tuning方式的选项 +* 当可以使用caption dropout技术时的选项 + +在可同时使用DreamBooth方法和fine tuning方法的训练方法中,它们可以一起使用。 +需要注意的是,判断是DreamBooth方式还是fine tuning方式是按数据集进行的,因此同一数据集中不能混合存在DreamBooth方式子集和fine tuning方式子集。 +也就是说,如果要混合使用这两种方式,则需要设置不同方式的子集属于不同的数据集。 + +从程序行为上看,如果存在 `metadata_file` 选项,则判断为fine tuning方式的子集。 +因此,对于属于同一数据集的子集,要么全部都具有 `metadata_file` 选项,要么全部都不具有 `metadata_file` 选项,这两种情况之一。 + +下面解释可用的选项。对于与命令行参数名称相同的选项,基本上省略解释。请参阅其他自述文件。 + +### 所有训练方法通用的选项 + +不管训练方法如何,都可以指定的选项。 + +#### 数据集的选项 + +与数据集设置相关的选项。不能在 `datasets.subsets` 中描述。 + +| 选项名称 | 设置示例 | `[general]` | `[[datasets]]` | +| ---- | ---- | ---- | ---- | +| `batch_size` | `1` | o | o | +| `bucket_no_upscale` | `true` | o | o | +| `bucket_reso_steps` | `64` | o | o | +| `enable_bucket` | `true` | o | o | +| `max_bucket_reso` | `1024` | o | o | +| `min_bucket_reso` | `128` | o | o | +| `resolution` | `256`, `[512, 512]` | o | o | + +* `batch_size` + * 等同于命令行参数 `--train_batch_size`。 + +这些设置对每个数据集是固定的。 +也就是说,属于该数据集的子集将共享这些设置。 +例如,如果要准备不同分辨率的数据集,可以像上面的示例那样将它们定义为单独的数据集,以便可以为它们单独设置不同的分辨率。 + +#### 子集的选项 + +与子集设置相关的选项。 + +| 选项名称 | 设置示例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `color_aug` | `false` | o | o | o | +| `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o | +| `flip_aug` | `true` | o | o | o | +| `keep_tokens` | `2` | o | o | o | +| `num_repeats` | `10` | o | o | o | +| 
`random_crop` | `false` | o | o | o | +| `shuffle_caption` | `true` | o | o | o | +| `caption_prefix` | `“masterpiece, best quality, ”` | o | o | o | +| `caption_suffix` | `“, from side”` | o | o | o | + +* `num_repeats` + * 指定子集中图像的重复次数。在fine tuning中相当于 `--dataset_repeats`,但 `num_repeats` + 在任何训练方法中都可以指定。 +* `caption_prefix`, `caption_suffix` + * 指定在标题前后附加的字符串。包括这些字符串的状态下进行洗牌。当指定 `keep_tokens` 时请注意。 + +### 仅DreamBooth方式的选项 + +DreamBooth方式的选项仅存在于子集选项中。 + +#### 子集的选项 + +DreamBooth方式子集的设置相关选项。 + +| 选项名称 | 设置示例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `image_dir` | `‘C:\hoge’` | - | - | o(必需) | +| `caption_extension` | `".txt"` | o | o | o | +| `class_tokens` | `“sks girl”` | - | - | o | +| `is_reg` | `false` | - | - | o | + +首先需要注意的是,`image_dir` 必须指定图像文件直接放置在路径下。与以往的 DreamBooth 方法不同,它与放置在子目录中的图像不兼容。另外,即使使用像 `5_cat` 这样的文件夹名,也不会反映图像的重复次数和类名。如果要单独设置这些,需要使用 `num_repeats` 和 `class_tokens` 明确指定。 + +* `image_dir` + * 指定图像目录路径。必填选项。 + * 图像需要直接放置在目录下。 +* `class_tokens` + * 设置类标记。 + * 仅在不存在与图像对应的 caption 文件时才在训练中使用。是否使用是按图像判断的。如果没有指定 `class_tokens` 且也找不到 caption 文件,将报错。 +* `is_reg` + * 指定子集中的图像是否用于正则化。如果未指定,则视为 `false`,即视为非正则化图像。 + +### 仅fine tuning方式的选项 + +fine tuning方式的选项仅存在于子集选项中。 + +#### 子集的选项 + +fine tuning方式子集的设置相关选项。 + +| 选项名称 | 设置示例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | ---- | +| `image_dir` | `‘C:\hoge’` | - | - | o | +| `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o(必需) | + +* `image_dir` + * 指定图像目录路径。与 DreamBooth 方式不同,不是必填的,但建议设置。 + * 不需要指定的情况是在生成元数据时指定了 `--full_path`。 + * 图像需要直接放置在目录下。 +* `metadata_file` + * 指定子集使用的元数据文件路径。必填选项。 + * 与命令行参数 `--in_json` 等效。 + * 由于子集需要按元数据文件指定,因此避免跨目录创建一个元数据文件。强烈建议为每个图像目录准备元数据文件,并将它们作为单独的子集进行注册。 + +### 当可使用caption dropout技术时可以指定的选项 + +当可以使用 caption dropout 技术时的选项仅存在于子集选项中。 +无论是 DreamBooth 方式还是 fine tuning 方式,只要训练方法支持 caption dropout,就可以指定。 + +#### 子集的选项 + +可使用 caption dropout 的子集的设置相关选项。 + +| 选项名称 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | +| ---- | ---- | ---- | ---- | +| `caption_dropout_every_n_epochs` | o | o | o | +| `caption_dropout_rate` | o | o | o | +| `caption_tag_dropout_rate` | o | o | o | + +## 存在重复子集时的行为 + +对于 DreamBooth 方式的数据集,如果其中 `image_dir` 相同的子集将被视为重复。 +对于 fine tuning 方式的数据集,如果其中 `metadata_file` 相同的子集将被视为重复。 +如果数据集中存在重复的子集,则从第二个开始将被忽略。 + +另一方面,如果它们属于不同的数据集,则不会被视为重复。 +例如,像下面这样在不同的数据集中放置相同的 `image_dir` 的子集,不会被视为重复。 +这在想要以不同分辨率训练相同图像的情况下很有用。 + +```toml +# 即使存在于不同的数据集中,也不会被视为重复,两者都将用于训练 + +[[datasets]] +resolution = 512 + + [[datasets.subsets]] + image_dir = 'C:\hoge' + +[[datasets]] +resolution = 768 + + [[datasets.subsets]] + image_dir = 'C:\hoge' +``` + +## 与命令行参数的组合使用 + +设置文件中的某些选项与命令行参数选项的作用相同。 + +如果传递了设置文件,则会忽略下列命令行参数选项。 + +* `--train_data_dir` +* `--reg_data_dir` +* `--in_json` + +如果同时在命令行参数和设置文件中指定了下列命令行参数选项,则设置文件中的值将优先于命令行参数的值。除非另有说明,否则它们是同名选项。 + +| 命令行参数选项 | 优先的设置文件选项 | +| ---- | ---- | +| `--bucket_no_upscale` | | +| `--bucket_reso_steps` | | +| `--caption_dropout_every_n_epochs` | | +| `--caption_dropout_rate` | | +| `--caption_extension` | | +| `--caption_tag_dropout_rate` | | +| `--color_aug` | | +| `--dataset_repeats` | `num_repeats` | +| `--enable_bucket` | | +| `--face_crop_aug_range` | | +| `--flip_aug` | | +| `--keep_tokens` | | +| `--min_bucket_reso` | | +| `--random_crop`| | +| `--resolution` | | +| `--shuffle_caption` | | +| `--train_batch_size` | `batch_size` | + +## 错误指南 + +目前,正在使用外部库来检查设置文件是否正确,但维护不够充分,错误消息难以理解。 +计划在将来改进此问题。 + +作为次佳方案,列出一些常见错误和解决方法。 +如果确信设置正确但仍出现错误,或者完全不明白错误内容,可能是bug,请联系我们。 + 
+* `voluptuous.error.MultipleInvalid: required key not provided @ ...`:: 缺少必填选项的错误。可能忘记指定或错误输入了选项名。 + * `...`部分显示了错误发生位置。例如,`voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']` 意味着在第0个 `datasets` 的第0个 `subsets` 的设置中缺失 `image_dir`。 +* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`:值的格式错误。输入值的格式可能错误。可以参考本 README 中选项的「设置示例」部分。 +* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`:存在不支持的选项名。可能错误输入或误输入了选项名。 + + + + \ No newline at end of file diff --git a/fine_tune.py b/fine_tune.py index be61b3d16..982dc8aec 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -11,15 +11,10 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index be43847a6..a207ad5a1 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -66,15 +66,10 @@ import numpy as np import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass import torchvision from diffusers import ( AutoencoderKL, diff --git a/library/class_advanced_training.py b/library/class_advanced_training.py index fe1620445..df432209a 100644 --- a/library/class_advanced_training.py +++ b/library/class_advanced_training.py @@ -3,9 +3,10 @@ class AdvancedTraining: - def __init__(self, headless=False, finetuning: bool = False): + def __init__(self, headless=False, finetuning: bool = False, training_type: str = ""): self.headless = headless self.finetuning = finetuning + self.training_type = training_type def noise_offset_type_change(noise_offset_type): if noise_offset_type == 'Original': @@ -105,6 +106,14 @@ def full_options_update(full_fp16, full_bf16): ], value='75', ) + + with gr.Row(): + if training_type == "lora": + self.fp8_base = gr.Checkbox( + label='fp8 base training (experimental)', + info="U-Net and Text Encoder can be trained with fp8 (experimental)", + value=False, + ) self.full_fp16 = gr.Checkbox( label='Full fp16 training (experimental)', value=False, @@ -114,6 +123,7 @@ def full_options_update(full_fp16, full_bf16): value=False, info='Required bitsandbytes >= 0.36.0', ) + self.full_fp16.change( full_options_update, inputs=[self.full_fp16, self.full_bf16], diff --git a/library/common_gui.py b/library/common_gui.py index 9a1d69dd2..f39cf102c 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -816,6 +816,10 @@ def run_cmd_advanced_training(**kwargs): if gradient_checkpointing: run_cmd += ' --gradient_checkpointing' + fp8_base = kwargs.get('fp8_base') + if fp8_base: + run_cmd += ' --fp8_base' + full_fp16 = kwargs.get('full_fp16') if full_fp16: run_cmd += ' --full_fp16' diff --git a/library/config_util.py b/library/config_util.py index 47868f3ba..a98c2b90d 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -1,462 +1,499 @@ import argparse from dataclasses import ( - asdict, - dataclass, + asdict, + dataclass, ) import functools import random from textwrap import dedent, indent import json from pathlib import Path + # from toolz import curry from typing import ( - List, - Optional, - Sequence, - Tuple, - Union, + List, + 
Optional, + Sequence, + Tuple, + Union, ) import toml import voluptuous from voluptuous import ( - Any, - ExactSequence, - MultipleInvalid, - Object, - Required, - Schema, + Any, + ExactSequence, + MultipleInvalid, + Object, + Required, + Schema, ) from transformers import CLIPTokenizer from . import train_util from .train_util import ( - DreamBoothSubset, - FineTuningSubset, - ControlNetSubset, - DreamBoothDataset, - FineTuningDataset, - ControlNetDataset, - DatasetGroup, + DreamBoothSubset, + FineTuningSubset, + ControlNetSubset, + DreamBoothDataset, + FineTuningDataset, + ControlNetDataset, + DatasetGroup, ) def add_config_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") + parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") + # TODO: inherit Params class in Subset, Dataset + @dataclass class BaseSubsetParams: - image_dir: Optional[str] = None - num_repeats: int = 1 - shuffle_caption: bool = False - caption_separator: str = ',', - keep_tokens: int = 0 - keep_tokens_separator: str = None, - color_aug: bool = False - flip_aug: bool = False - face_crop_aug_range: Optional[Tuple[float, float]] = None - random_crop: bool = False - caption_prefix: Optional[str] = None - caption_suffix: Optional[str] = None - caption_dropout_rate: float = 0.0 - caption_dropout_every_n_epochs: int = 0 - caption_tag_dropout_rate: float = 0.0 - token_warmup_min: int = 1 - token_warmup_step: float = 0 + image_dir: Optional[str] = None + num_repeats: int = 1 + shuffle_caption: bool = False + caption_separator: str = (",",) + keep_tokens: int = 0 + keep_tokens_separator: str = (None,) + color_aug: bool = False + flip_aug: bool = False + face_crop_aug_range: Optional[Tuple[float, float]] = None + random_crop: bool = False + caption_prefix: Optional[str] = None + caption_suffix: Optional[str] = None + caption_dropout_rate: float = 0.0 + caption_dropout_every_n_epochs: int = 0 + caption_tag_dropout_rate: float = 0.0 + token_warmup_min: int = 1 + token_warmup_step: float = 0 + @dataclass class DreamBoothSubsetParams(BaseSubsetParams): - is_reg: bool = False - class_tokens: Optional[str] = None - caption_extension: str = ".caption" + is_reg: bool = False + class_tokens: Optional[str] = None + caption_extension: str = ".caption" + @dataclass class FineTuningSubsetParams(BaseSubsetParams): - metadata_file: Optional[str] = None + metadata_file: Optional[str] = None + @dataclass class ControlNetSubsetParams(BaseSubsetParams): - conditioning_data_dir: str = None - caption_extension: str = ".caption" + conditioning_data_dir: str = None + caption_extension: str = ".caption" + @dataclass class BaseDatasetParams: - tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None - max_token_length: int = None - resolution: Optional[Tuple[int, int]] = None - debug_dataset: bool = False + tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None + max_token_length: int = None + resolution: Optional[Tuple[int, int]] = None + network_multiplier: float = 1.0 + debug_dataset: bool = False + @dataclass class DreamBoothDatasetParams(BaseDatasetParams): - batch_size: int = 1 - enable_bucket: bool = False - min_bucket_reso: int = 256 - max_bucket_reso: int = 1024 - bucket_reso_steps: int = 64 - bucket_no_upscale: bool = False - prior_loss_weight: float = 1.0 + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int 
= 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + prior_loss_weight: float = 1.0 + @dataclass class FineTuningDatasetParams(BaseDatasetParams): - batch_size: int = 1 - enable_bucket: bool = False - min_bucket_reso: int = 256 - max_bucket_reso: int = 1024 - bucket_reso_steps: int = 64 - bucket_no_upscale: bool = False + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + @dataclass class ControlNetDatasetParams(BaseDatasetParams): - batch_size: int = 1 - enable_bucket: bool = False - min_bucket_reso: int = 256 - max_bucket_reso: int = 1024 - bucket_reso_steps: int = 64 - bucket_no_upscale: bool = False + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + @dataclass class SubsetBlueprint: - params: Union[DreamBoothSubsetParams, FineTuningSubsetParams] + params: Union[DreamBoothSubsetParams, FineTuningSubsetParams] + @dataclass class DatasetBlueprint: - is_dreambooth: bool - is_controlnet: bool - params: Union[DreamBoothDatasetParams, FineTuningDatasetParams] - subsets: Sequence[SubsetBlueprint] + is_dreambooth: bool + is_controlnet: bool + params: Union[DreamBoothDatasetParams, FineTuningDatasetParams] + subsets: Sequence[SubsetBlueprint] + @dataclass class DatasetGroupBlueprint: - datasets: Sequence[DatasetBlueprint] + datasets: Sequence[DatasetBlueprint] + + @dataclass class Blueprint: - dataset_group: DatasetGroupBlueprint + dataset_group: DatasetGroupBlueprint class ConfigSanitizer: - # @curry - @staticmethod - def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple: - Schema(ExactSequence([klass, klass]))(value) - return tuple(value) - - # @curry - @staticmethod - def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple: - Schema(Any(klass, ExactSequence([klass, klass])))(value) - try: - Schema(klass)(value) - return (value, value) - except: - return ConfigSanitizer.__validate_and_convert_twodim(klass, value) - - # subset schema - SUBSET_ASCENDABLE_SCHEMA = { - "color_aug": bool, - "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float), - "flip_aug": bool, - "num_repeats": int, - "random_crop": bool, - "shuffle_caption": bool, - "keep_tokens": int, - "keep_tokens_separator": str, - "token_warmup_min": int, - "token_warmup_step": Any(float,int), - "caption_prefix": str, - "caption_suffix": str, - } - # DO means DropOut - DO_SUBSET_ASCENDABLE_SCHEMA = { - "caption_dropout_every_n_epochs": int, - "caption_dropout_rate": Any(float, int), - "caption_tag_dropout_rate": Any(float, int), - } - # DB means DreamBooth - DB_SUBSET_ASCENDABLE_SCHEMA = { - "caption_extension": str, - "class_tokens": str, - } - DB_SUBSET_DISTINCT_SCHEMA = { - Required("image_dir"): str, - "is_reg": bool, - } - # FT means FineTuning - FT_SUBSET_DISTINCT_SCHEMA = { - Required("metadata_file"): str, - "image_dir": str, - } - CN_SUBSET_ASCENDABLE_SCHEMA = { - "caption_extension": str, - } - CN_SUBSET_DISTINCT_SCHEMA = { - Required("image_dir"): str, - Required("conditioning_data_dir"): str, - } - - # datasets schema - DATASET_ASCENDABLE_SCHEMA = { - "batch_size": int, - "bucket_no_upscale": bool, - "bucket_reso_steps": int, - "enable_bucket": bool, - "max_bucket_reso": int, - "min_bucket_reso": int, - "resolution": 
functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), - } - - # options handled by argparse but not handled by user config - ARGPARSE_SPECIFIC_SCHEMA = { - "debug_dataset": bool, - "max_token_length": Any(None, int), - "prior_loss_weight": Any(float, int), - } - # for handling default None value of argparse - ARGPARSE_NULLABLE_OPTNAMES = [ - "face_crop_aug_range", - "resolution", - ] - # prepare map because option name may differ among argparse and user config - ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { - "train_batch_size": "batch_size", - "dataset_repeats": "num_repeats", - } - - def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: - assert support_dreambooth or support_finetuning or support_controlnet, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" - - self.db_subset_schema = self.__merge_dict( - self.SUBSET_ASCENDABLE_SCHEMA, - self.DB_SUBSET_DISTINCT_SCHEMA, - self.DB_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) - - self.ft_subset_schema = self.__merge_dict( - self.SUBSET_ASCENDABLE_SCHEMA, - self.FT_SUBSET_DISTINCT_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) - - self.cn_subset_schema = self.__merge_dict( - self.SUBSET_ASCENDABLE_SCHEMA, - self.CN_SUBSET_DISTINCT_SCHEMA, - self.CN_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) - - self.db_dataset_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.DB_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - {"subsets": [self.db_subset_schema]}, - ) - - self.ft_dataset_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - {"subsets": [self.ft_subset_schema]}, - ) - - self.cn_dataset_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.CN_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - {"subsets": [self.cn_subset_schema]}, - ) - - if support_dreambooth and support_finetuning: - def validate_flex_dataset(dataset_config: dict): - subsets_config = dataset_config.get("subsets", []) - - if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]): - return Schema(self.cn_dataset_schema)(dataset_config) - # check dataset meets FT style - # NOTE: all FT subsets should have "metadata_file" - elif all(["metadata_file" in subset for subset in subsets_config]): - return Schema(self.ft_dataset_schema)(dataset_config) - # check dataset meets DB style - # NOTE: all DB subsets should have no "metadata_file" - elif all(["metadata_file" not in subset for subset in subsets_config]): - return Schema(self.db_dataset_schema)(dataset_config) - else: - raise voluptuous.Invalid("DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. 
/ DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。") - - self.dataset_schema = validate_flex_dataset - elif support_dreambooth: - self.dataset_schema = self.db_dataset_schema - elif support_finetuning: - self.dataset_schema = self.ft_dataset_schema - elif support_controlnet: - self.dataset_schema = self.cn_dataset_schema - - self.general_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {}, - self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {}, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) - - self.user_config_validator = Schema({ - "general": self.general_schema, - "datasets": [self.dataset_schema], - }) - - self.argparse_schema = self.__merge_dict( - self.general_schema, - self.ARGPARSE_SPECIFIC_SCHEMA, - {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES}, - {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()}, - ) - - self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA) - - def sanitize_user_config(self, user_config: dict) -> dict: - try: - return self.user_config_validator(user_config) - except MultipleInvalid: - # TODO: エラー発生時のメッセージをわかりやすくする - print("Invalid user config / ユーザ設定の形式が正しくないようです") - raise - - # NOTE: In nature, argument parser result is not needed to be sanitize - # However this will help us to detect program bug - def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace: - try: - return self.argparse_config_validator(argparse_namespace) - except MultipleInvalid: - # XXX: this should be a bug - print("Invalid cmdline parsed arguments. This should be a bug. 
/ コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") - raise - - # NOTE: value would be overwritten by latter dict if there is already the same key - @staticmethod - def __merge_dict(*dict_list: dict) -> dict: - merged = {} - for schema in dict_list: - # merged |= schema - for k, v in schema.items(): - merged[k] = v - return merged + # @curry + @staticmethod + def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple: + Schema(ExactSequence([klass, klass]))(value) + return tuple(value) + + # @curry + @staticmethod + def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple: + Schema(Any(klass, ExactSequence([klass, klass])))(value) + try: + Schema(klass)(value) + return (value, value) + except: + return ConfigSanitizer.__validate_and_convert_twodim(klass, value) + + # subset schema + SUBSET_ASCENDABLE_SCHEMA = { + "color_aug": bool, + "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float), + "flip_aug": bool, + "num_repeats": int, + "random_crop": bool, + "shuffle_caption": bool, + "keep_tokens": int, + "keep_tokens_separator": str, + "token_warmup_min": int, + "token_warmup_step": Any(float, int), + "caption_prefix": str, + "caption_suffix": str, + } + # DO means DropOut + DO_SUBSET_ASCENDABLE_SCHEMA = { + "caption_dropout_every_n_epochs": int, + "caption_dropout_rate": Any(float, int), + "caption_tag_dropout_rate": Any(float, int), + } + # DB means DreamBooth + DB_SUBSET_ASCENDABLE_SCHEMA = { + "caption_extension": str, + "class_tokens": str, + } + DB_SUBSET_DISTINCT_SCHEMA = { + Required("image_dir"): str, + "is_reg": bool, + } + # FT means FineTuning + FT_SUBSET_DISTINCT_SCHEMA = { + Required("metadata_file"): str, + "image_dir": str, + } + CN_SUBSET_ASCENDABLE_SCHEMA = { + "caption_extension": str, + } + CN_SUBSET_DISTINCT_SCHEMA = { + Required("image_dir"): str, + Required("conditioning_data_dir"): str, + } + + # datasets schema + DATASET_ASCENDABLE_SCHEMA = { + "batch_size": int, + "bucket_no_upscale": bool, + "bucket_reso_steps": int, + "enable_bucket": bool, + "max_bucket_reso": int, + "min_bucket_reso": int, + "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), + "network_multiplier": float, + } + + # options handled by argparse but not handled by user config + ARGPARSE_SPECIFIC_SCHEMA = { + "debug_dataset": bool, + "max_token_length": Any(None, int), + "prior_loss_weight": Any(float, int), + } + # for handling default None value of argparse + ARGPARSE_NULLABLE_OPTNAMES = [ + "face_crop_aug_range", + "resolution", + ] + # prepare map because option name may differ among argparse and user config + ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { + "train_batch_size": "batch_size", + "dataset_repeats": "num_repeats", + } + + def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: + assert ( + support_dreambooth or support_finetuning or support_controlnet + ), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. 
/ DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" + + self.db_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_DISTINCT_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.ft_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.FT_SUBSET_DISTINCT_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.cn_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.CN_SUBSET_DISTINCT_SCHEMA, + self.CN_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.db_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.db_subset_schema]}, + ) + + self.ft_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.ft_subset_schema]}, + ) + + self.cn_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.CN_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.cn_subset_schema]}, + ) + + if support_dreambooth and support_finetuning: + + def validate_flex_dataset(dataset_config: dict): + subsets_config = dataset_config.get("subsets", []) + + if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]): + return Schema(self.cn_dataset_schema)(dataset_config) + # check dataset meets FT style + # NOTE: all FT subsets should have "metadata_file" + elif all(["metadata_file" in subset for subset in subsets_config]): + return Schema(self.ft_dataset_schema)(dataset_config) + # check dataset meets DB style + # NOTE: all DB subsets should have no "metadata_file" + elif all(["metadata_file" not in subset for subset in subsets_config]): + return Schema(self.db_dataset_schema)(dataset_config) + else: + raise voluptuous.Invalid( + "DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. 
/ DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。" + ) + + self.dataset_schema = validate_flex_dataset + elif support_dreambooth: + self.dataset_schema = self.db_dataset_schema + elif support_finetuning: + self.dataset_schema = self.ft_dataset_schema + elif support_controlnet: + self.dataset_schema = self.cn_dataset_schema + + self.general_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {}, + self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {}, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.user_config_validator = Schema( + { + "general": self.general_schema, + "datasets": [self.dataset_schema], + } + ) + + self.argparse_schema = self.__merge_dict( + self.general_schema, + self.ARGPARSE_SPECIFIC_SCHEMA, + {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES}, + {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()}, + ) + + self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA) + + def sanitize_user_config(self, user_config: dict) -> dict: + try: + return self.user_config_validator(user_config) + except MultipleInvalid: + # TODO: エラー発生時のメッセージをわかりやすくする + print("Invalid user config / ユーザ設定の形式が正しくないようです") + raise + + # NOTE: In nature, argument parser result is not needed to be sanitize + # However this will help us to detect program bug + def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace: + try: + return self.argparse_config_validator(argparse_namespace) + except MultipleInvalid: + # XXX: this should be a bug + print("Invalid cmdline parsed arguments. This should be a bug. 
/ コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") + raise + + # NOTE: value would be overwritten by latter dict if there is already the same key + @staticmethod + def __merge_dict(*dict_list: dict) -> dict: + merged = {} + for schema in dict_list: + # merged |= schema + for k, v in schema.items(): + merged[k] = v + return merged class BlueprintGenerator: - BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = { - } - - def __init__(self, sanitizer: ConfigSanitizer): - self.sanitizer = sanitizer - - # runtime_params is for parameters which is only configurable on runtime, such as tokenizer - def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint: - sanitized_user_config = self.sanitizer.sanitize_user_config(user_config) - sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace) - - # convert argparse namespace to dict like config - # NOTE: it is ok to have extra entries in dict - optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME - argparse_config = {optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()} - - general_config = sanitized_user_config.get("general", {}) - - dataset_blueprints = [] - for dataset_config in sanitized_user_config.get("datasets", []): - # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets - subsets = dataset_config.get("subsets", []) - is_dreambooth = all(["metadata_file" not in subset for subset in subsets]) - is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets]) - if is_controlnet: - subset_params_klass = ControlNetSubsetParams - dataset_params_klass = ControlNetDatasetParams - elif is_dreambooth: - subset_params_klass = DreamBoothSubsetParams - dataset_params_klass = DreamBoothDatasetParams - else: - subset_params_klass = FineTuningSubsetParams - dataset_params_klass = FineTuningDatasetParams - - subset_blueprints = [] - for subset_config in subsets: - params = self.generate_params_by_fallbacks(subset_params_klass, - [subset_config, dataset_config, general_config, argparse_config, runtime_params]) - subset_blueprints.append(SubsetBlueprint(params)) - - params = self.generate_params_by_fallbacks(dataset_params_klass, - [dataset_config, general_config, argparse_config, runtime_params]) - dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints)) - - dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints) - - return Blueprint(dataset_group_blueprint) - - @staticmethod - def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]): - name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME - search_value = BlueprintGenerator.search_value - default_params = asdict(param_klass()) - param_names = default_params.keys() - - params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names} - - return param_klass(**params) - - @staticmethod - def search_value(key: str, fallbacks: Sequence[dict], default_value = None): - for cand in fallbacks: - value = cand.get(key) - if value is not None: - return value - - return default_value + BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {} + + def __init__(self, sanitizer: ConfigSanitizer): + self.sanitizer = sanitizer + + # runtime_params is for parameters which is only configurable on runtime, such as tokenizer + def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint: 
+ sanitized_user_config = self.sanitizer.sanitize_user_config(user_config) + sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace) + + # convert argparse namespace to dict like config + # NOTE: it is ok to have extra entries in dict + optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME + argparse_config = { + optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items() + } + + general_config = sanitized_user_config.get("general", {}) + + dataset_blueprints = [] + for dataset_config in sanitized_user_config.get("datasets", []): + # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets + subsets = dataset_config.get("subsets", []) + is_dreambooth = all(["metadata_file" not in subset for subset in subsets]) + is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets]) + if is_controlnet: + subset_params_klass = ControlNetSubsetParams + dataset_params_klass = ControlNetDatasetParams + elif is_dreambooth: + subset_params_klass = DreamBoothSubsetParams + dataset_params_klass = DreamBoothDatasetParams + else: + subset_params_klass = FineTuningSubsetParams + dataset_params_klass = FineTuningDatasetParams + + subset_blueprints = [] + for subset_config in subsets: + params = self.generate_params_by_fallbacks( + subset_params_klass, [subset_config, dataset_config, general_config, argparse_config, runtime_params] + ) + subset_blueprints.append(SubsetBlueprint(params)) + + params = self.generate_params_by_fallbacks( + dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params] + ) + dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints)) + + dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints) + + return Blueprint(dataset_group_blueprint) + + @staticmethod + def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]): + name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME + search_value = BlueprintGenerator.search_value + default_params = asdict(param_klass()) + param_names = default_params.keys() + + params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names} + + return param_klass(**params) + + @staticmethod + def search_value(key: str, fallbacks: Sequence[dict], default_value=None): + for cand in fallbacks: + value = cand.get(key) + if value is not None: + return value + + return default_value def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): - datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] - - for dataset_blueprint in dataset_group_blueprint.datasets: - if dataset_blueprint.is_controlnet: - subset_klass = ControlNetSubset - dataset_klass = ControlNetDataset - elif dataset_blueprint.is_dreambooth: - subset_klass = DreamBoothSubset - dataset_klass = DreamBoothDataset - else: - subset_klass = FineTuningSubset - dataset_klass = FineTuningDataset - - subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) - datasets.append(dataset) - - # print info - info = "" - for i, dataset in enumerate(datasets): - is_dreambooth = isinstance(dataset, DreamBoothDataset) - is_controlnet = isinstance(dataset, ControlNetDataset) - info += dedent(f"""\ + datasets: List[Union[DreamBoothDataset, 
FineTuningDataset, ControlNetDataset]] = [] + + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset + else: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + datasets.append(dataset) + + # print info + info = "" + for i, dataset in enumerate(datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent( + f"""\ [Dataset {i}] batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} enable_bucket: {dataset.enable_bucket} - """) - - if dataset.enable_bucket: - info += indent(dedent(f"""\ + network_multiplier: {dataset.network_multiplier} + """ + ) + + if dataset.enable_bucket: + info += indent( + dedent( + f"""\ min_bucket_reso: {dataset.min_bucket_reso} max_bucket_reso: {dataset.max_bucket_reso} bucket_reso_steps: {dataset.bucket_reso_steps} bucket_no_upscale: {dataset.bucket_no_upscale} - \n"""), " ") - else: - info += "\n" + \n""" + ), + " ", + ) + else: + info += "\n" - for j, subset in enumerate(dataset.subsets): - info += indent(dedent(f"""\ + for j, subset in enumerate(dataset.subsets): + info += indent( + dedent( + f"""\ [Subset {j} of Dataset {i}] image_dir: "{subset.image_dir}" image_count: {subset.img_count} @@ -475,147 +512,176 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu random_crop: {subset.random_crop} token_warmup_min: {subset.token_warmup_min}, token_warmup_step: {subset.token_warmup_step}, - """), " ") - - if is_dreambooth: - info += indent(dedent(f"""\ + """ + ), + " ", + ) + + if is_dreambooth: + info += indent( + dedent( + f"""\ is_reg: {subset.is_reg} class_tokens: {subset.class_tokens} caption_extension: {subset.caption_extension} - \n"""), " ") - elif not is_controlnet: - info += indent(dedent(f"""\ + \n""" + ), + " ", + ) + elif not is_controlnet: + info += indent( + dedent( + f"""\ metadata_file: {subset.metadata_file} - \n"""), " ") + \n""" + ), + " ", + ) - print(info) + print(info) - # make buckets first because it determines the length of dataset - # and set the same seed for all datasets - seed = random.randint(0, 2**31) # actual seed is seed + epoch_no - for i, dataset in enumerate(datasets): - print(f"[Dataset {i}]") - dataset.make_buckets() - dataset.set_seed(seed) + # make buckets first because it determines the length of dataset + # and set the same seed for all datasets + seed = random.randint(0, 2**31) # actual seed is seed + epoch_no + for i, dataset in enumerate(datasets): + print(f"[Dataset {i}]") + dataset.make_buckets() + dataset.set_seed(seed) - return DatasetGroup(datasets) + return DatasetGroup(datasets) def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): - def extract_dreambooth_params(name: str) -> Tuple[int, str]: - tokens = name.split('_') - try: - n_repeats = int(tokens[0]) - except ValueError as e: - print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}") - return 0, "" - caption_by_folder = '_'.join(tokens[1:]) - return n_repeats, caption_by_folder - - def generate(base_dir: 
Optional[str], is_reg: bool): - if base_dir is None: - return [] - - base_dir: Path = Path(base_dir) - if not base_dir.is_dir(): - return [] + def extract_dreambooth_params(name: str) -> Tuple[int, str]: + tokens = name.split("_") + try: + n_repeats = int(tokens[0]) + except ValueError as e: + print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}") + return 0, "" + caption_by_folder = "_".join(tokens[1:]) + return n_repeats, caption_by_folder + + def generate(base_dir: Optional[str], is_reg: bool): + if base_dir is None: + return [] + + base_dir: Path = Path(base_dir) + if not base_dir.is_dir(): + return [] + + subsets_config = [] + for subdir in base_dir.iterdir(): + if not subdir.is_dir(): + continue + + num_repeats, class_tokens = extract_dreambooth_params(subdir.name) + if num_repeats < 1: + continue + + subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens} + subsets_config.append(subset_config) + + return subsets_config subsets_config = [] - for subdir in base_dir.iterdir(): - if not subdir.is_dir(): - continue - - num_repeats, class_tokens = extract_dreambooth_params(subdir.name) - if num_repeats < 1: - continue - - subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens} - subsets_config.append(subset_config) + subsets_config += generate(train_data_dir, False) + subsets_config += generate(reg_data_dir, True) return subsets_config - subsets_config = [] - subsets_config += generate(train_data_dir, False) - subsets_config += generate(reg_data_dir, True) - return subsets_config +def generate_controlnet_subsets_config_by_subdirs( + train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt" +): + def generate(base_dir: Optional[str]): + if base_dir is None: + return [] + base_dir: Path = Path(base_dir) + if not base_dir.is_dir(): + return [] -def generate_controlnet_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"): - def generate(base_dir: Optional[str]): - if base_dir is None: - return [] + subsets_config = [] + subset_config = { + "image_dir": train_data_dir, + "conditioning_data_dir": conditioning_data_dir, + "caption_extension": caption_extension, + "num_repeats": 1, + } + subsets_config.append(subset_config) - base_dir: Path = Path(base_dir) - if not base_dir.is_dir(): - return [] + return subsets_config subsets_config = [] - subset_config = {"image_dir": train_data_dir, "conditioning_data_dir": conditioning_data_dir, "caption_extension": caption_extension, "num_repeats": 1} - subsets_config.append(subset_config) + subsets_config += generate(train_data_dir) return subsets_config - subsets_config = [] - subsets_config += generate(train_data_dir) - return subsets_config +def load_user_config(file: str) -> dict: + file: Path = Path(file) + if not file.is_file(): + raise ValueError(f"file not found / ファイルが見つかりません: {file}") + + if file.name.lower().endswith(".json"): + try: + with open(file, "r") as f: + config = json.load(f) + except Exception: + print( + f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) + raise + elif file.name.lower().endswith(".toml"): + try: + config = toml.load(file) + except Exception: + print( + f"Error on parsing TOML config file. Please check the format. 
/ TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) + raise + else: + raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") + return config -def load_user_config(file: str) -> dict: - file: Path = Path(file) - if not file.is_file(): - raise ValueError(f"file not found / ファイルが見つかりません: {file}") - - if file.name.lower().endswith('.json'): - try: - with open(file, 'r') as f: - config = json.load(f) - except Exception: - print(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") - raise - elif file.name.lower().endswith('.toml'): - try: - config = toml.load(file) - except Exception: - print(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") - raise - else: - raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") - - return config # for config test if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--support_dreambooth", action="store_true") - parser.add_argument("--support_finetuning", action="store_true") - parser.add_argument("--support_controlnet", action="store_true") - parser.add_argument("--support_dropout", action="store_true") - parser.add_argument("dataset_config") - config_args, remain = parser.parse_known_args() - - parser = argparse.ArgumentParser() - train_util.add_dataset_arguments(parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout) - train_util.add_training_arguments(parser, config_args.support_dreambooth) - argparse_namespace = parser.parse_args(remain) - train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) + parser = argparse.ArgumentParser() + parser.add_argument("--support_dreambooth", action="store_true") + parser.add_argument("--support_finetuning", action="store_true") + parser.add_argument("--support_controlnet", action="store_true") + parser.add_argument("--support_dropout", action="store_true") + parser.add_argument("dataset_config") + config_args, remain = parser.parse_known_args() + + parser = argparse.ArgumentParser() + train_util.add_dataset_arguments( + parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout + ) + train_util.add_training_arguments(parser, config_args.support_dreambooth) + argparse_namespace = parser.parse_args(remain) + train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) - print("[argparse_namespace]") - print(vars(argparse_namespace)) + print("[argparse_namespace]") + print(vars(argparse_namespace)) - user_config = load_user_config(config_args.dataset_config) + user_config = load_user_config(config_args.dataset_config) - print("\n[user_config]") - print(user_config) + print("\n[user_config]") + print(user_config) - sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout) - sanitized_user_config = sanitizer.sanitize_user_config(user_config) + sanitizer = ConfigSanitizer( + config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout + ) + sanitized_user_config = sanitizer.sanitize_user_config(user_config) - print("\n[sanitized_user_config]") - print(sanitized_user_config) + print("\n[sanitized_user_config]") + print(sanitized_user_config) - blueprint = BlueprintGenerator(sanitizer).generate(user_config, 
argparse_namespace) + blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) - print("\n[blueprint]") - print(blueprint) + print("\n[blueprint]") + print(blueprint) diff --git a/library/ipex_interop.py b/library/ipex_interop.py new file mode 100644 index 000000000..6fe320c57 --- /dev/null +++ b/library/ipex_interop.py @@ -0,0 +1,24 @@ +import torch + + +def init_ipex(): + """ + Try to import `intel_extension_for_pytorch`, and apply + the hijacks using `library.ipex.ipex_init`. + + If IPEX is not installed, this function does nothing. + """ + try: + import intel_extension_for_pytorch as ipex # noqa + except ImportError: + return + + try: + from library.ipex import ipex_init + + if torch.xpu.is_available(): + is_initialized, error_message = ipex_init() + if not is_initialized: + print("failed to initialize ipex:", error_message) + except Exception as e: + print("failed to initialize ipex:", e) diff --git a/library/model_util.py b/library/model_util.py index 1f40ce324..4361b4994 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -5,15 +5,9 @@ import os import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - - ipex_init() -except Exception: - pass +init_ipex() import diffusers from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel diff --git a/library/original_unet.py b/library/original_unet.py index 00997e7c0..030c5c9ec 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -1262,9 +1262,9 @@ def set_use_memory_efficient_attention(self, xformers, mem_eff): for attn in self.attentions: attn.set_use_memory_efficient_attention(xformers, mem_eff) - def set_use_sdpa(self, spda): + def set_use_sdpa(self, sdpa): for attn in self.attentions: - attn.set_use_sdpa(spda) + attn.set_use_sdpa(sdpa) def forward( self, diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index e03ee4056..03b182566 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -923,7 +923,11 @@ def __call__( if up1 is not None: uncond_pool = up1 - dtype = self.unet.dtype + unet_dtype = self.unet.dtype + dtype = unet_dtype + if hasattr(dtype, "itemsize") and dtype.itemsize == 1: # fp8 + dtype = torch.float16 + self.unet.to(dtype) # 4. 
Preprocess image and mask if isinstance(image, PIL.Image.Image): @@ -1028,6 +1032,7 @@ def __call__( if is_cancelled_callback is not None and is_cancelled_callback(): return None + self.unet.to(unet_dtype) return latents def latents_to_image(self, latents): diff --git a/library/train_util.py b/library/train_util.py index ff161feab..ba428e508 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -558,6 +558,7 @@ def __init__( tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]], max_token_length: int, resolution: Optional[Tuple[int, int]], + network_multiplier: float, debug_dataset: bool, ) -> None: super().__init__() @@ -567,6 +568,7 @@ def __init__( self.max_token_length = max_token_length # width/height is used when enable_bucket==False self.width, self.height = (None, None) if resolution is None else resolution + self.network_multiplier = network_multiplier self.debug_dataset = debug_dataset self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = [] @@ -1106,7 +1108,9 @@ def __getitem__(self, index): for image_key in bucket[image_index : image_index + bucket_batch_size]: image_info = self.image_data[image_key] subset = self.image_to_subset[image_key] - loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) + loss_weights.append( + self.prior_loss_weight if image_info.is_reg else 1.0 + ) # in case of fine tuning, is_reg is always False flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance @@ -1272,6 +1276,8 @@ def __getitem__(self, index): example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw]) example["flippeds"] = flippeds + example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions)) + if self.debug_dataset: example["image_keys"] = bucket[image_index : image_index + self.batch_size] return example @@ -1346,15 +1352,16 @@ def __init__( tokenizer, max_token_length, resolution, + network_multiplier: float, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, - debug_dataset, + debug_dataset: bool, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" @@ -1520,14 +1527,15 @@ def __init__( tokenizer, max_token_length, resolution, + network_multiplier: float, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, - debug_dataset, + debug_dataset: bool, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) self.batch_size = batch_size @@ -1724,14 +1732,15 @@ def __init__( tokenizer, max_token_length, resolution, + network_multiplier: float, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, - debug_dataset, + debug_dataset: float, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) db_subsets = [] for subset in subsets: @@ -1765,6 +1774,7 @@ def __init__( tokenizer, max_token_length, resolution, + network_multiplier, enable_bucket, min_bucket_reso, max_bucket_reso, @@ -2039,6 +2049,8 @@ 
def debug_dataset(train_dataset, show_input_ids=False): print( f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}' ) + if "network_multipliers" in example: + print(f"network multiplier: {example['network_multipliers'][j]}") if show_input_ids: print(f"input ids: {iid}") @@ -2105,8 +2117,8 @@ def glob_images_pathlib(dir_path, recursive): class MinimalDataset(BaseDataset): - def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False): - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False): + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) self.num_train_images = 0 # update in subclass self.num_reg_images = 0 # update in subclass @@ -2850,14 +2862,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: ) parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う") parser.add_argument( - "--dynamo_backend", - type=str, - default="inductor", + "--dynamo_backend", + type=str, + default="inductor", # available backends: # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5 # https://pytorch.org/docs/stable/torch.compiler.html - choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], - help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)" + choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], + help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)", ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") parser.add_argument( @@ -2904,6 +2916,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する" ) # TODO move to SDXL training, because it is not supported by SD1/2 + parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う") parser.add_argument( "--ddp_timeout", type=int, @@ -3886,7 +3899,7 @@ def prepare_accelerator(args: argparse.Namespace): os.environ["WANDB_DIR"] = logging_dir if args.wandb_api_key is not None: wandb.login(key=args.wandb_api_key) - + # torch.compile のオプション。 NO の場合は torch.compile は使わない dynamo_backend = "NO" if args.torch_compile: diff --git a/lora_gui.py b/lora_gui.py index 1e403ff32..413d92174 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -83,6 +83,7 @@ def save_configuration( caption_extension, enable_bucket, gradient_checkpointing, + fp8_base, full_fp16, no_token_padding, stop_text_encoder_training, @@ -245,6 +246,7 @@ def open_configuration( caption_extension, enable_bucket, gradient_checkpointing, + fp8_base, full_fp16, no_token_padding, stop_text_encoder_training, @@ -438,6 +440,7 @@ def train_model( caption_extension, enable_bucket, gradient_checkpointing, + fp8_base, full_fp16, no_token_padding, stop_text_encoder_training_pct, @@ -1048,6 +1051,7 @@ def train_model( color_aug=color_aug, shuffle_caption=shuffle_caption, 
gradient_checkpointing=gradient_checkpointing, + fp8_base=fp8_base, full_fp16=full_fp16, xformers=xformers, # use_8bit_adam=use_8bit_adam, @@ -1802,7 +1806,7 @@ def update_LoRA_settings( placeholder="(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2", info="Specify the alpha of each block when expanding LoRA to Conv2d 3x3. Specify 25 numbers. If omitted, the value of conv_alpha is used.", ) - advanced_training = AdvancedTraining(headless=headless) + advanced_training = AdvancedTraining(headless=headless, training_type="lora") advanced_training.color_aug.change( color_aug_changed, inputs=[advanced_training.color_aug], @@ -1909,6 +1913,7 @@ def update_LoRA_settings( basic_training.caption_extension, basic_training.enable_bucket, advanced_training.gradient_checkpointing, + advanced_training.fp8_base, advanced_training.full_fp16, advanced_training.no_token_padding, basic_training.stop_text_encoder_training, diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index 6357df55d..b9027adba 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -43,6 +43,9 @@ def svd( clamp_quantile=0.99, min_diff=0.01, no_metadata=False, + load_precision=None, + load_original_model_to=None, + load_tuned_model_to=None, ): def str_to_dtype(p): if p == "float": @@ -57,28 +60,51 @@ def str_to_dtype(p): if v_parameterization is None: v_parameterization = v2 + load_dtype = str_to_dtype(load_precision) if load_precision else None save_dtype = str_to_dtype(save_precision) + work_device = "cpu" # load models if not sdxl: print(f"loading original SD model : {model_org}") text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org) text_encoders_o = [text_encoder_o] + if load_dtype is not None: + text_encoder_o = text_encoder_o.to(load_dtype) + unet_o = unet_o.to(load_dtype) + print(f"loading tuned SD model : {model_tuned}") text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned) text_encoders_t = [text_encoder_t] + if load_dtype is not None: + text_encoder_t = text_encoder_t.to(load_dtype) + unet_t = unet_t.to(load_dtype) + model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization) else: + device_org = load_original_model_to if load_original_model_to else "cpu" + device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu" + print(f"loading original SDXL model : {model_org}") text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu" + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org ) text_encoders_o = [text_encoder_o1, text_encoder_o2] + if load_dtype is not None: + text_encoder_o1 = text_encoder_o1.to(load_dtype) + text_encoder_o2 = text_encoder_o2.to(load_dtype) + unet_o = unet_o.to(load_dtype) + print(f"loading original SDXL model : {model_tuned}") text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu" + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned ) text_encoders_t = [text_encoder_t1, text_encoder_t2] + if load_dtype is not None: + text_encoder_t1 = text_encoder_t1.to(load_dtype) + text_encoder_t2 = text_encoder_t2.to(load_dtype) + unet_t = unet_t.to(load_dtype) + model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0 # create LoRA network to 
extract weights: Use dim (rank) as alpha @@ -100,38 +126,54 @@ def str_to_dtype(p): lora_name = lora_o.lora_name module_o = lora_o.org_module module_t = lora_t.org_module - diff = module_t.weight - module_o.weight + diff = module_t.weight.to(work_device) - module_o.weight.to(work_device) + + # clear weight to save memory + module_o.weight = None + module_t.weight = None # Text Encoder might be same if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff: text_encoder_different = True print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}") - diff = diff.float() diffs[lora_name] = diff + # clear target Text Encoder to save memory + for text_encoder in text_encoders_t: + del text_encoder + if not text_encoder_different: print("Text encoder is same. Extract U-Net only.") lora_network_o.text_encoder_loras = [] - diffs = {} + diffs = {} # clear diffs for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)): lora_name = lora_o.lora_name module_o = lora_o.org_module module_t = lora_t.org_module - diff = module_t.weight - module_o.weight - diff = diff.float() + diff = module_t.weight.to(work_device) - module_o.weight.to(work_device) - if args.device: - diff = diff.to(args.device) + # clear weight to save memory + module_o.weight = None + module_t.weight = None diffs[lora_name] = diff + # clear LoRA network, target U-Net to save memory + del lora_network_o + del lora_network_t + del unet_t + # make LoRA with svd print("calculating by svd") lora_weights = {} with torch.no_grad(): for lora_name, mat in tqdm(list(diffs.items())): + if args.device: + mat = mat.to(args.device) + mat = mat.to(torch.float) # calc by float + # if conv_dim is None, diffs do not include LoRAs for conv2d-3x3 conv2d = len(mat.size()) == 4 kernel_size = None if not conv2d else mat.size()[2:4] @@ -171,8 +213,8 @@ def str_to_dtype(p): U = U.reshape(out_dim, rank, 1, 1) Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1]) - U = U.to("cpu").contiguous() - Vh = Vh.to("cpu").contiguous() + U = U.to(work_device, dtype=save_dtype).contiguous() + Vh = Vh.to(work_device, dtype=save_dtype).contiguous() lora_weights[lora_name] = (U, Vh) @@ -230,6 +272,13 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む" ) + parser.add_argument( + "--load_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる" + ) parser.add_argument( "--save_precision", type=str, @@ -285,6 +334,18 @@ def setup_parser() -> argparse.ArgumentParser: help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", ) + parser.add_argument( + "--load_original_model_to", + type=str, + default=None, + help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効", + ) + parser.add_argument( + "--load_tuned_model_to", + type=str, + default=None, + help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効", + ) return parser diff --git a/requirements.txt b/requirements.txt index c2be38700..0cf9d11f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,7 +34,7 @@ opencv-python==4.7.0.68 prodigyopt==1.0 
pytorch-lightning==1.9.0 rich==13.4.1 -safetensors==0.3.1 +safetensors==0.4.2 timm==0.6.12 tk==0.1.0 toml==0.10.2 diff --git a/requirements_windows_torch2.txt b/requirements_windows_torch2.txt index a093afbff..6b7c5edb8 100644 --- a/requirements_windows_torch2.txt +++ b/requirements_windows_torch2.txt @@ -1,5 +1,6 @@ -torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 # no_verify -xformers==0.0.21 +# torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 # no_verify +torch==2.1.2+cu118 torchvision==0.16.2+cu118 torchaudio==2.1.2+cu118 --index-url https://download.pytorch.org/whl/cu118 +xformers==0.0.23.post1+cu118 --index-url https://download.pytorch.org/whl/cu118 bitsandbytes==0.41.1 # no_verify # https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl # no_verify tensorboard==2.14.1 tensorflow==2.14.0 diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index ab5399842..0db9e340e 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -18,15 +18,10 @@ import numpy as np import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass import torchvision from diffusers import ( AutoencoderKL, diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 45b9edd65..15a70678f 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -9,13 +9,11 @@ from einops import repeat import numpy as np import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass + +from library.ipex_interop import init_ipex + +init_ipex() + from tqdm import tqdm from transformers import CLIPTokenizer from diffusers import EulerDiscreteScheduler diff --git a/sdxl_train.py b/sdxl_train.py index b4ce2770e..a3f6f3a17 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -11,15 +11,10 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import sdxl_model_util diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 4436dd3cd..7a88feb84 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -14,13 +14,11 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass + +from library.ipex_interop import init_ipex + +init_ipex() + from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed import accelerate diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 6ae5377ba..b94bf5c1b 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -11,13 +11,11 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass + +from library.ipex_interop import init_ipex + 
+init_ipex() + from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from diffusers import DDPMScheduler, ControlNetModel diff --git a/sdxl_train_network.py b/sdxl_train_network.py index a35779d00..5d363280d 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,15 +1,10 @@ import argparse import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from library import sdxl_model_util, sdxl_train_util, train_util import train_network @@ -95,8 +90,8 @@ def cache_text_encoder_outputs_if_needed( unet.to(org_unet_device) else: # Text Encoderから毎回出力を取得するので、GPUに乗せておく - text_encoders[0].to(accelerator.device) - text_encoders[1].to(accelerator.device) + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index f8a1d7bce..df3937135 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -3,13 +3,9 @@ import regex import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass +from library.ipex_interop import init_ipex + +init_ipex() import open_clip from library import sdxl_model_util, sdxl_train_util, train_util diff --git a/setup/setup_windows.py b/setup/setup_windows.py index f11e3a9e4..4dc1c0329 100644 --- a/setup/setup_windows.py +++ b/setup/setup_windows.py @@ -14,7 +14,7 @@ RESET_COLOR = "\033[0m" -def cudann_install(): +def cudnn_install(): cudnn_src = os.path.join( os.path.dirname(os.path.realpath(__file__)), "..\cudnn_windows" ) @@ -144,7 +144,7 @@ def main_menu(): while True: print("\nKohya_ss GUI setup menu:\n") print("1. Install kohya_ss gui") - print("2. (Optional) Install cudann files (avoid unless you really need it)") + print("2. (Optional) Install cudnn files (avoid unless you really need it)") print("3. (Optional) Install specific bitsandbytes versions") print("4. (Optional) Manually configure accelerate") print("5. (Optional) Start Kohya_ss GUI in browser") @@ -156,7 +156,7 @@ def main_menu(): if choice == "1": install_kohya_ss_torch2() elif choice == "2": - cudann_install() + cudnn_install() elif choice == "3": while True: print("1. 
(Optional) Force installation of bitsandbytes 0.35.0") diff --git a/train_controlnet.py b/train_controlnet.py index cc0eaab7a..7b0b2bbfe 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -12,15 +12,10 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from diffusers import DDPMScheduler, ControlNetModel diff --git a/train_db.py b/train_db.py index 14d9dff13..888cad25e 100644 --- a/train_db.py +++ b/train_db.py @@ -12,15 +12,10 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler diff --git a/train_network.py b/train_network.py index a75299cda..8d102ae8f 100644 --- a/train_network.py +++ b/train_network.py @@ -14,15 +14,10 @@ import torch from torch.nn.parallel import DistributedDataParallel as DDP -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import model_util @@ -117,7 +112,7 @@ def cache_text_encoder_outputs_if_needed( self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype ): for t_enc in text_encoders: - t_enc.to(accelerator.device) + t_enc.to(accelerator.device, dtype=weight_dtype) def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): input_ids = batch["input_ids"].to(accelerator.device) @@ -278,6 +273,7 @@ def train(self, args): accelerator.wait_for_everyone() # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される + # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu self.cache_text_encoder_outputs_if_needed( args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype ) @@ -309,6 +305,7 @@ def train(self, args): ) if network is None: return + network_has_multiplier = hasattr(network, "set_multiplier") if hasattr(network, "prepare_network"): network.prepare_network(args) @@ -389,17 +386,33 @@ def train(self, args): accelerator.print("enable full bf16 training.") network.to(weight_dtype) + unet_weight_dtype = te_weight_dtype = weight_dtype + # Experimental Feature: Put base model into fp8 to save vram + if args.fp8_base: + assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。" + assert ( + args.mixed_precision != "no" + ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" + accelerator.print("enable fp8 training.") + unet_weight_dtype = torch.float8_e4m3fn + te_weight_dtype = torch.float8_e4m3fn + unet.requires_grad_(False) - unet.to(dtype=weight_dtype) + unet.to(dtype=unet_weight_dtype) for t_enc in text_encoders: t_enc.requires_grad_(False) - # acceleratorがなんかよろしくやってくれるらしい - # TODO めちゃくちゃ冗長なのでコードを整理する + # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 + if t_enc.device.type != "cpu": + 
t_enc.to(dtype=te_weight_dtype) + # nn.Embedding not support FP8 + t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + + # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if train_unet: unet = accelerator.prepare(unet) else: - unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator + unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator if train_text_encoder: if len(text_encoders) > 1: text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders] @@ -407,8 +420,8 @@ def train(self, args): text_encoder = accelerator.prepare(text_encoder) text_encoders = [text_encoder] else: - for t_enc in text_encoders: - t_enc.to(accelerator.device, dtype=weight_dtype) + pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set + network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler) if args.gradient_checkpointing: @@ -421,9 +434,6 @@ def train(self, args): if train_text_encoder: t_enc.text_model.embeddings.requires_grad_(True) - # set top parameter requires_grad = True for gradient checkpointing works - if not train_text_encoder: # train U-Net only - unet.parameters().__next__().requires_grad_(True) else: unet.eval() for t_enc in text_encoders: @@ -685,7 +695,7 @@ def train(self, args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( @@ -754,7 +764,17 @@ def remove_model(old_ckpt_name): accelerator.print("NaN found in latents, replacing with zeros") latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * self.vae_scale_factor - b_size = latents.shape[0] + + # get multiplier for each sample + if network_has_multiplier: + multipliers = batch["network_multipliers"] + # if all multipliers are same, use single multiplier + if torch.all(multipliers == multipliers[0]): + multipliers = multipliers[0].item() + else: + raise NotImplementedError("multipliers for each sample is not supported yet") + # print(f"set multiplier: {multipliers}") + accelerator.unwrap_model(network).set_multiplier(multipliers) with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning @@ -778,10 +798,24 @@ def remove_model(old_ckpt_name): args, noise_scheduler, latents ) + # ensure the hidden state will require grad + if args.gradient_checkpointing: + for x in noisy_latents: + x.requires_grad_(True) + for t in text_encoder_conds: + t.requires_grad_(True) + # Predict the noise residual with accelerator.autocast(): noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype + args, + accelerator, + unet, + noisy_latents.requires_grad_(train_unet), + timesteps, + text_encoder_conds, + batch, + weight_dtype, ) if args.v_parameterization: @@ -808,10 +842,11 @@ def remove_model(old_ckpt_name): loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし accelerator.backward(loss) - self.all_reduce_network(accelerator, network) # sync DDP grad manually - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = 
accelerator.unwrap_model(network).get_trainable_params() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + if accelerator.sync_gradients: + self.all_reduce_network(accelerator, network) # sync DDP grad manually + if args.max_grad_norm != 0.0: + params_to_clip = accelerator.unwrap_model(network).get_trainable_params() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 0e3912b1d..441c1e00b 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -8,15 +8,10 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from transformers import CLIPTokenizer @@ -505,7 +500,7 @@ def train(self, args): if accelerator.is_main_process: init_kwargs = {} if args.wandb_run_name: - init_kwargs['wandb'] = {'name': args.wandb_run_name} + init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( @@ -730,14 +725,13 @@ def remove_model(old_ckpt_name): is_main_process = accelerator.is_main_process if is_main_process: text_encoder = accelerator.unwrap_model(text_encoder) + updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() accelerator.end_training() if args.save_state and is_main_process: train_util.save_state_on_train_end(args, accelerator) - updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() - if is_main_process: ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 71b43549d..7046a4808 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -8,13 +8,11 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass + +from library.ipex_interop import init_ipex + +init_ipex() + from accelerate.utils import set_seed import diffusers from diffusers import DDPMScheduler
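
The `--fp8_base` path added in `train_network.py` above stores the frozen base model weights in `torch.float8_e4m3fn` while the LoRA network and the compute dtype stay in fp16/bf16, and it keeps `nn.Embedding` layers out of fp8 because embeddings do not support that dtype. The following is a minimal sketch of that casting step only, assuming torch>=2.1 and a CLIP-style text encoder exposing `text_model.embeddings`; the function name and arguments are illustrative and not part of the patch:

```python
import torch
import torch.nn as nn


def cast_base_model_to_fp8(unet: nn.Module, text_encoder: nn.Module, weight_dtype: torch.dtype):
    """Freeze the base model and store its weights in fp8 (sketch of the --fp8_base path)."""
    # fp8 storage needs torch>=2.1; compute still runs in fp16/bf16 under autocast
    assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0"

    unet.requires_grad_(False)
    unet.to(dtype=torch.float8_e4m3fn)

    text_encoder.requires_grad_(False)
    device = next(text_encoder.parameters()).device
    if device.type != "cpu":  # on cpu the model stays fp32, so no cast is applied
        text_encoder.to(dtype=torch.float8_e4m3fn)
        # nn.Embedding does not support fp8, so keep the embeddings in the compute dtype
        text_encoder.text_model.embeddings.to(dtype=weight_dtype)
```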