From 3e023952dfb5217c679d248e51ffeac487afdd7f Mon Sep 17 00:00:00 2001
From: bmaltais
Date: Sat, 9 Mar 2024 09:30:20 -0500
Subject: [PATCH] Dev pure (#2039)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Remove library folder, update setup_windows
* Cleanup sd-scripts files
* set sd-scripts release via file
* More cleanup
* Point to sd-scripts python files
* Fix blip captioning working directory
* Fix typo
* Fix issue with verify lora
* Improve git cloning of sd-scripts repo
* Integrate cloning in linux setup script
* Remove unnecessary folder
* Update linux setup
* Removing unnecessary code
* Fix validation issue under linux
* Update tensorboard version linux
* Updating linux requirements
* Update README
* Added submodule for sd-scripts
* Update for proper submodule handling
* Simplify submodule init
* Update comment
* Improve WSL2 LD_LIBRARY_PATH handling
* Update tensorboard handling of logging_dir
* chore(docker): update TensorFlow and TensorBoard versions
- Upgrade TensorFlow version in Dockerfile and requirements to 2.15.0.post1
- Update TensorBoard version in requirements to 2.15.2
Signed-off-by: 陳鈞
* chore(docker): Add a mount for `sd-scripts` in the Dockerfile during pip install
- Add a mount step for `sd-scripts` in the Dockerfile during pip install.
Signed-off-by: 陳鈞
* doc(docker): Update readme about docker
* Fix issue with TE LR
* Fix issue with TE LR
* Update sd-scripts to Feb 27 2024 commit
* Update version of bnb for windows
* Update windows setup
* Update windows setup
* Update bnb linux to match windows
* remove height:auto
* update UI
* unify Quick Pick + custom model selection into gr.Dropdown()
* add list_dirs(), list_files()
* use gr.Accordion()
* always show model + training image folder + output name + minimal default settings
* do not show hidden files
* clean up list_*()
* use PYTHONPATH env variable
* add "tool" classname to gr.Button() for sd-webui compatible
* update *_caption_gui.py using create_refresh_buttons()
* update Utilities->Group Images
* update Utilities->Convert model
* update dataset gui
* update Dreambooth/LoRA folders preparation
* update *lora*_gui UI using list_files(), create_refresh_button()
* more fixes for PYTHONPATH
* Minor edits to basic caption and common gui
* Optimising some code
* Fix path issue
* Fix issues with list_images_dir
* Update finetuning image folder label
* Minor updates to UI

---------

Signed-off-by: 陳鈞
Co-authored-by: 陳鈞
Co-authored-by: Won-Kyu Park
---
 .augmentignore | 15 +
 .gitignore | 3 +-
 .gitmodules | 3 +
 .release | 2 +-
 .vscode/settings.json | 3 +-
 Dockerfile | 3 +-
 README-ja.md | 154 -
 README.md | 354 +-
 ...55\346\226\207\346\225\231\347\250\213.md" | 174 -
 XTI_hijack.py | 204 -
 config_README-zh.md | 289 -
 converted_markdown.md | 1726 ------
 fine_tune.py | 505 --
 fine_tune_README.md | 504 --
 finetune/blip/blip.py | 244 -
 finetune/blip/med.py | 955 ----
 finetune/blip/med_config.json | 22 -
 finetune/blip/vit.py | 305 -
 finetune/clean_captions_and_tags.py | 194 -
 finetune/hypernetwork_nai.py | 96 -
 finetune/make_captions.py | 210 -
 finetune/make_captions_by_git.py | 183 -
 finetune/merge_captions_to_metadata.py | 80 -
 finetune/merge_dd_tags_to_metadata.py | 75 -
 finetune/prepare_buckets_latents.py | 265 -
 finetune/tag_images_by_wd14_tagger.py | 386 --
 gen_img.py | 3326 -----
 gen_img_diffusers.py | 3866 -------
 gui.sh | 13 +
 kohya_gui/basic_caption_gui.py | 63 +-
 kohya_gui/blip_caption_gui.py | 112 +-
 kohya_gui/class_advanced_training.py | 54 +-
kohya_gui/class_basic_training.py | 9 - kohya_gui/class_configuration_file.py | 64 +- kohya_gui/class_folders.py | 118 +- kohya_gui/class_lora_tab.py | 15 +- kohya_gui/class_sample_images.py | 7 +- kohya_gui/class_source_model.py | 166 +- kohya_gui/common_gui.py | 520 +- kohya_gui/convert_lcm_gui.py | 70 +- kohya_gui/convert_model_gui.py | 64 +- kohya_gui/dataset_balancing_gui.py | 31 +- kohya_gui/dreambooth_folder_creation_gui.py | 86 +- kohya_gui/dreambooth_gui.py | 249 +- kohya_gui/extract_lora_from_dylora_gui.py | 77 +- kohya_gui/extract_lora_gui.py | 116 +- kohya_gui/extract_lycoris_locon_gui.py | 117 +- kohya_gui/finetune_gui.py | 365 +- kohya_gui/git_caption_gui.py | 37 +- kohya_gui/group_images_gui.py | 54 +- kohya_gui/lora_gui.py | 112 +- kohya_gui/manual_caption_gui.py | 44 +- kohya_gui/merge_lora_gui.py | 157 +- kohya_gui/merge_lycoris_gui.py | 82 +- kohya_gui/resize_lora_gui.py | 62 +- kohya_gui/svd_merge_lora_gui.py | 137 +- kohya_gui/tensorboard_gui.py | 13 +- kohya_gui/textual_inversion_gui.py | 120 +- kohya_gui/utilities.py | 8 +- kohya_gui/verify_lora_gui.py | 39 +- kohya_gui/wd14_caption_gui.py | 39 +- library/__init__.py | 0 library/attention_processors.py | 227 - library/config_util.py | 689 --- library/custom_train_functions.py | 532 -- library/device_utils.py | 84 - library/huggingface_util.py | 84 - library/hypernetwork.py | 223 - library/ipex/__init__.py | 179 - library/ipex/attention.py | 177 - library/ipex/diffusers.py | 312 - library/ipex/gradscaler.py | 183 - library/ipex/hijacks.py | 298 - library/lpw_stable_diffusion.py | 1233 ---- library/lpw_stable_diffusion_orig.py | 1254 ---- library/model_util.py | 1356 ----- library/original_unet.py | 1919 ------- library/sai_model_spec.py | 309 - library/sdxl_lpw_stable_diffusion.py | 1347 ----- library/sdxl_model_util.py | 577 -- library/sdxl_original_unet.py | 1284 ----- library/sdxl_train_util.py | 373 -- library/slicing_vae.py | 682 --- library/train_util.py | 5064 ----------------- library/utils.py | 266 - networks/check_lora_weights.py | 48 - networks/control_net_lllite.py | 449 -- networks/control_net_lllite_for_train.py | 505 -- networks/dylora.py | 478 -- networks/extract_lora_from_dylora.py | 128 - networks/extract_lora_from_models.py | 360 -- networks/lora.py | 1253 ---- networks/lora_diffusers.py | 616 -- networks/lora_fa.py | 1244 ---- networks/lora_interrogator.py | 146 - networks/merge_lora.py | 360 -- networks/merge_lora_old.py | 190 - networks/oft.py | 433 -- networks/resize_lora.py | 411 -- networks/sdxl_merge_lora.py | 351 -- networks/svd_merge_lora.py | 262 - requirements.txt | 2 +- requirements_linux.txt | 6 +- requirements_linux_docker.txt | 4 +- requirements_windows_torch2.txt | 8 +- sd-scripts | 1 + sdxl_gen_img.py | 3210 ----------- sdxl_minimal_inference.py | 329 -- sdxl_train.py | 792 --- sdxl_train_control_net_lllite.py | 616 -- sdxl_train_control_net_lllite_old.py | 584 -- sdxl_train_network.py | 184 - sdxl_train_textual_inversion.py | 138 - setup/setup_common.py | 94 + setup/setup_linux.py | 6 + setup/setup_windows.py | 54 +- setup/validate_requirements.py | 2 + style.css | 1 - test/config/dreambooth-Prodigy-SDXL.json | 91 + tools/blip2-for-sd/README.md | 33 - tools/blip2-for-sd/caption_processor.py | 105 - tools/blip2-for-sd/main.py | 89 - tools/blip2-for-sd/requirements.txt | 28 - tools/cache_latents.py | 197 - tools/cache_text_encoder_outputs.py | 194 - tools/canny.py | 34 - tools/convert_diffusers20_original_sd.md | 46 - tools/convert_diffusers20_original_sd.py | 163 - 
tools/detect_face_rotate.py | 250 - tools/latent_upscaler.py | 354 -- tools/merge_models.py | 171 - tools/original_control_net.py | 353 -- tools/resize_images_to_resolution.py | 141 - tools/show_metadata.py | 23 - train_controlnet.py | 620 -- train_db.py | 504 -- train_db_README.md | 309 - train_network.py | 1058 ---- train_network_README.md | 189 - train_network_appl_weights_README.md | 39 - train_textual_inversion.py | 805 --- train_textual_inversion_XTI.py | 712 --- train_ti_README.md | 61 - 143 files changed, 2430 insertions(+), 53422 deletions(-) create mode 100644 .augmentignore create mode 100644 .gitmodules delete mode 100644 README-ja.md delete mode 100644 "README_\344\270\255\346\226\207\346\225\231\347\250\213.md" delete mode 100644 XTI_hijack.py delete mode 100644 config_README-zh.md delete mode 100644 converted_markdown.md delete mode 100644 fine_tune.py delete mode 100644 fine_tune_README.md delete mode 100644 finetune/blip/blip.py delete mode 100644 finetune/blip/med.py delete mode 100644 finetune/blip/med_config.json delete mode 100644 finetune/blip/vit.py delete mode 100644 finetune/clean_captions_and_tags.py delete mode 100644 finetune/hypernetwork_nai.py delete mode 100644 finetune/make_captions.py delete mode 100644 finetune/make_captions_by_git.py delete mode 100644 finetune/merge_captions_to_metadata.py delete mode 100644 finetune/merge_dd_tags_to_metadata.py delete mode 100644 finetune/prepare_buckets_latents.py delete mode 100644 finetune/tag_images_by_wd14_tagger.py delete mode 100644 gen_img.py delete mode 100644 gen_img_diffusers.py delete mode 100644 library/__init__.py delete mode 100644 library/attention_processors.py delete mode 100644 library/config_util.py delete mode 100644 library/custom_train_functions.py delete mode 100644 library/device_utils.py delete mode 100644 library/huggingface_util.py delete mode 100644 library/hypernetwork.py delete mode 100644 library/ipex/__init__.py delete mode 100644 library/ipex/attention.py delete mode 100644 library/ipex/diffusers.py delete mode 100644 library/ipex/gradscaler.py delete mode 100644 library/ipex/hijacks.py delete mode 100644 library/lpw_stable_diffusion.py delete mode 100644 library/lpw_stable_diffusion_orig.py delete mode 100644 library/model_util.py delete mode 100644 library/original_unet.py delete mode 100644 library/sai_model_spec.py delete mode 100644 library/sdxl_lpw_stable_diffusion.py delete mode 100644 library/sdxl_model_util.py delete mode 100644 library/sdxl_original_unet.py delete mode 100644 library/sdxl_train_util.py delete mode 100644 library/slicing_vae.py delete mode 100644 library/train_util.py delete mode 100644 library/utils.py delete mode 100644 networks/check_lora_weights.py delete mode 100644 networks/control_net_lllite.py delete mode 100644 networks/control_net_lllite_for_train.py delete mode 100644 networks/dylora.py delete mode 100644 networks/extract_lora_from_dylora.py delete mode 100644 networks/extract_lora_from_models.py delete mode 100644 networks/lora.py delete mode 100644 networks/lora_diffusers.py delete mode 100644 networks/lora_fa.py delete mode 100644 networks/lora_interrogator.py delete mode 100644 networks/merge_lora.py delete mode 100644 networks/merge_lora_old.py delete mode 100644 networks/oft.py delete mode 100644 networks/resize_lora.py delete mode 100644 networks/sdxl_merge_lora.py delete mode 100644 networks/svd_merge_lora.py create mode 160000 sd-scripts delete mode 100755 sdxl_gen_img.py delete mode 100644 sdxl_minimal_inference.py delete mode 100644 
sdxl_train.py delete mode 100644 sdxl_train_control_net_lllite.py delete mode 100644 sdxl_train_control_net_lllite_old.py delete mode 100644 sdxl_train_network.py delete mode 100644 sdxl_train_textual_inversion.py create mode 100644 test/config/dreambooth-Prodigy-SDXL.json delete mode 100644 tools/blip2-for-sd/README.md delete mode 100644 tools/blip2-for-sd/caption_processor.py delete mode 100644 tools/blip2-for-sd/main.py delete mode 100644 tools/blip2-for-sd/requirements.txt delete mode 100644 tools/cache_latents.py delete mode 100644 tools/cache_text_encoder_outputs.py delete mode 100644 tools/canny.py delete mode 100644 tools/convert_diffusers20_original_sd.md delete mode 100644 tools/convert_diffusers20_original_sd.py delete mode 100644 tools/detect_face_rotate.py delete mode 100644 tools/latent_upscaler.py delete mode 100644 tools/merge_models.py delete mode 100644 tools/original_control_net.py delete mode 100644 tools/resize_images_to_resolution.py delete mode 100644 tools/show_metadata.py delete mode 100644 train_controlnet.py delete mode 100644 train_db.py delete mode 100644 train_db_README.md delete mode 100644 train_network.py delete mode 100644 train_network_README.md delete mode 100644 train_network_appl_weights_README.md delete mode 100644 train_textual_inversion.py delete mode 100644 train_textual_inversion_XTI.py delete mode 100644 train_ti_README.md diff --git a/.augmentignore b/.augmentignore new file mode 100644 index 000000000..1b7f5dde8 --- /dev/null +++ b/.augmentignore @@ -0,0 +1,15 @@ +.env +.cache +.vscode +__pycache__ +bitsandbytes_windows +cudnn_windows +data +dataset +docs +examples +outputs +SmilingWolf +test +v2_inference +venv \ No newline at end of file diff --git a/.gitignore b/.gitignore index 99734cc19..827ef6a8a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # Python venv +venv2 __pycache__ *.egg-info build @@ -44,4 +45,4 @@ requirements_tmp_for_setup.txt 0.13.3 *.npz -presets/*/user_presets/* \ No newline at end of file +presets/*/user_presets/* diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..dade70d80 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "sd-scripts"] + path = sd-scripts + url = https://github.com/kohya-ss/sd-scripts.git \ No newline at end of file diff --git a/.release b/.release index 751f4c9f3..698f8983e 100644 --- a/.release +++ b/.release @@ -1 +1 @@ -v22.7.0 +v23.0.0 diff --git a/.vscode/settings.json b/.vscode/settings.json index bea2e8fdf..d412e4162 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,5 +1,6 @@ { "python.linting.enabled": true, "python.formatting.provider": "yapf", - "DockerRun.DisableDockerrc": true + "DockerRun.DisableDockerrc": true, + "augment.enableAutomaticCompletions": false } \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index dc2df9c07..562a97391 100644 --- a/Dockerfile +++ b/Dockerfile @@ -30,7 +30,7 @@ RUN --mount=type=cache,id=pip-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/r torch==2.1.2 torchvision==0.16.2 \ xformers==0.0.23.post1 \ # Why [and-cuda]: https://github.com/tensorflow/tensorflow/issues/61468#issuecomment-1759462485 - tensorflow[and-cuda]==2.14.0 \ + tensorflow[and-cuda]==2.15.0.post1 \ ninja \ pip setuptools wheel @@ -39,6 +39,7 @@ RUN --mount=type=cache,id=pip-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/r --mount=source=requirements_linux_docker.txt,target=requirements_linux_docker.txt \ --mount=source=requirements.txt,target=requirements.txt \ --mount=source=setup/docker_setup.py,target=setup.py 
\ + --mount=source=sd-scripts,target=sd-scripts,rw \ pip install -r requirements_linux_docker.txt -r requirements.txt # Replace pillow with pillow-simd (Only for x86) diff --git a/README-ja.md b/README-ja.md deleted file mode 100644 index 29c33a659..000000000 --- a/README-ja.md +++ /dev/null @@ -1,154 +0,0 @@ -SDXLがサポートされました。sdxlブランチはmainブランチにマージされました。リポジトリを更新したときにはUpgradeの手順を実行してください。また accelerate のバージョンが上がっていますので、accelerate config を再度実行してください。 - -SDXL学習については[こちら](./README.md#sdxl-training)をご覧ください(英語です)。 - -## リポジトリについて -Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。 - -[README in English](./README.md) ←更新情報はこちらにあります - -GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています(英語です)のであわせてご覧ください。bmaltais氏に感謝します。 - -以下のスクリプトがあります。 - -* DreamBooth、U-NetおよびText Encoderの学習をサポート -* fine-tuning、同上 -* LoRAの学習をサポート -* 画像生成 -* モデル変換(Stable Diffision ckpt/safetensorsとDiffusersの相互変換) - -## 使用法について - -* [学習について、共通編](./docs/train_README-ja.md) : データ整備やオプションなど - * [データセット設定](./docs/config_README-ja.md) -* [DreamBoothの学習について](./docs/train_db_README-ja.md) -* [fine-tuningのガイド](./docs/fine_tune_README_ja.md): -* [LoRAの学習について](./docs/train_network_README-ja.md) -* [Textual Inversionの学習について](./docs/train_ti_README-ja.md) -* [画像生成スクリプト](./docs/gen_img_README-ja.md) -* note.com [モデル変換スクリプト](https://note.com/kohya_ss/n/n374f316fe4ad) - -## Windowsでの動作に必要なプログラム - -Python 3.10.6およびGitが必要です。 - -- Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe -- git: https://git-scm.com/download/win - -PowerShellを使う場合、venvを使えるようにするためには以下の手順でセキュリティ設定を変更してください。 -(venvに限らずスクリプトの実行が可能になりますので注意してください。) - -- PowerShellを管理者として開きます。 -- 「Set-ExecutionPolicy Unrestricted」と入力し、Yと答えます。 -- 管理者のPowerShellを閉じます。 - -## Windows環境でのインストール - -スクリプトはPyTorch 2.0.1でテストしています。PyTorch 1.12.1でも動作すると思われます。 - -以下の例ではPyTorchは2.0.1/CUDA 11.8版をインストールします。CUDA 11.6版やPyTorch 1.12.1を使う場合は適宜書き換えください。 - -(なお、python -m venv~の行で「python」とだけ表示された場合、py -m venv~のようにpythonをpyに変更してください。) - -PowerShellを使う場合、通常の(管理者ではない)PowerShellを開き以下を順に実行します。 - -```powershell -git clone https://github.com/kohya-ss/sd-scripts.git -cd sd-scripts - -python -m venv venv -.\venv\Scripts\activate - -pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 -pip install --upgrade -r requirements.txt -pip install xformers==0.0.20 - -accelerate config -``` - -コマンドプロンプトでも同一です。 - -(注:``python -m venv venv`` のほうが ``python -m venv --system-site-packages venv`` より安全そうなため書き換えました。globalなpythonにパッケージがインストールしてあると、後者だといろいろと問題が起きます。) - -accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。) - -※0.15.0から日本語環境では選択のためにカーソルキーを押すと落ちます(……)。数字キーの0、1、2……で選択できますので、そちらを使ってください。 - -```txt -- This machine -- No distributed training -- NO -- NO -- NO -- all -- fp16 -``` - -※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問( -``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? 
[all]:``)に「0」と答えてください。(id `0`のGPUが使われます。) - -### オプション:`bitsandbytes`(8bit optimizer)を使う - -`bitsandbytes`はオプションになりました。Linuxでは通常通りpipでインストールできます(0.41.1または以降のバージョンを推奨)。 - -Windowsでは0.35.0または0.41.1を推奨します。 - -- `bitsandbytes` 0.35.0: 安定しているとみられるバージョンです。AdamW8bitは使用できますが、他のいくつかの8bit optimizer、学習時の`full_bf16`オプションは使用できません。 -- `bitsandbytes` 0.41.1: Lion8bit、PagedAdamW8bit、PagedLion8bitをサポートします。`full_bf16`が使用できます。 - -注:`bitsandbytes` 0.35.0から0.41.0までのバージョンには問題があるようです。 https://github.com/TimDettmers/bitsandbytes/issues/659 - -以下の手順に従い、`bitsandbytes`をインストールしてください。 - -### 0.35.0を使う場合 - -PowerShellの例です。コマンドプロンプトではcpの代わりにcopyを使ってください。 - -```powershell -cd sd-scripts -.\venv\Scripts\activate -pip install bitsandbytes==0.35.0 - -cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ -cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py -cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py -``` - -### 0.41.1を使う場合 - -jllllll氏の配布されている[こちら](https://github.com/jllllll/bitsandbytes-windows-webui) または他の場所から、Windows用のwhlファイルをインストールしてください。 - -```powershell -python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui -``` - -## アップグレード - -新しいリリースがあった場合、以下のコマンドで更新できます。 - -```powershell -cd sd-scripts -git pull -.\venv\Scripts\activate -pip install --use-pep517 --upgrade -r requirements.txt -``` - -コマンドが成功すれば新しいバージョンが使用できます。 - -## 謝意 - -LoRAの実装は[cloneofsimo氏のリポジトリ](https://github.com/cloneofsimo/lora)を基にしたものです。感謝申し上げます。 - -Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora) が最初にリリースし、KohakuBlueleaf氏が [LoCon](https://github.com/KohakuBlueleaf/LoCon) でその有効性を明らかにしたものです。KohakuBlueleaf氏に深く感謝します。 - -## ライセンス - -スクリプトのライセンスはASL 2.0ですが(Diffusersおよびcloneofsimo氏のリポジトリ由来のものも同様)、一部他のライセンスのコードを含みます。 - -[Memory Efficient Attention Pytorch](https://github.com/lucidrains/memory-efficient-attention-pytorch): MIT - -[bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT - -[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause - - diff --git a/README.md b/README.md index 7c7e862aa..ca91eea9a 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Kohya's GUI -This repository mostly provides a Windows-focused Gradio GUI for [Kohya's Stable Diffusion trainers](https://github.com/kohya-ss/sd-scripts)... but support for Linux OS is also provided through community contributions. Macos is not great at the moment. +This repository mostly provides a Gradio GUI for [Kohya's Stable Diffusion trainers](https://github.com/kohya-ss/sd-scripts)... but support for Linux OS is also provided through community contributions. Macos is not great at the moment... but might work if the wind blow in the right direction... The GUI allows you to set the training parameters and generate and run the required CLI commands to train the model. 
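
For reference, the CLI command that the GUI generates and runs is typically an `accelerate launch` invocation of one of the sd-scripts entry points. A minimal sketch, assuming a LoRA run via `train_network.py` from the `sd-scripts` submodule (every path and value below is an illustrative placeholder, not something the GUI necessarily emits verbatim):

```shell
# Hypothetical example of the kind of command the GUI assembles for a LoRA training job.
accelerate launch --num_cpu_threads_per_process=2 "./sd-scripts/train_network.py" \
  --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
  --train_data_dir="./dataset/img" \
  --output_dir="./outputs" \
  --output_name="my_lora" \
  --network_module=networks.lora \
  --resolution="512,512" \
  --train_batch_size=1 \
  --learning_rate=1e-4 \
  --max_train_steps=1600 \
  --mixed_precision="fp16" \
  --save_model_as=safetensors
```
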
@@ -12,11 +12,11 @@ The GUI allows you to set the training parameters and generate and run the requi - [Installation](#installation) - [Windows](#windows) - [Windows Pre-requirements](#windows-pre-requirements) - - [Setup](#setup) - - [Optional: CUDNN 8.6](#optional-cudnn-86) + - [Setup Windows](#setup-windows) + - [Optional: CUDNN 8.9.6.50](#optional-cudnn-89650) - [Linux and macOS](#linux-and-macos) - [Linux Pre-requirements](#linux-pre-requirements) - - [Setup](#setup-1) + - [Setup Linux](#setup-linux) - [Install Location](#install-location) - [Runpod](#runpod) - [Manual installation](#manual-installation) @@ -30,27 +30,18 @@ The GUI allows you to set the training parameters and generate and run the requi - [Starting GUI Service](#starting-gui-service) - [Launching the GUI on Windows](#launching-the-gui-on-windows) - [Launching the GUI on Linux and macOS](#launching-the-gui-on-linux-and-macos) - - [Dreambooth](#dreambooth) - - [Finetune](#finetune) - - [Train Network](#train-network) - [LoRA](#lora) - [Sample image generation during training](#sample-image-generation-during-training) - [Troubleshooting](#troubleshooting) - [Page File Limit](#page-file-limit) - [No module called tkinter](#no-module-called-tkinter) - - [FileNotFoundError](#filenotfounderror) - [SDXL training](#sdxl-training) - - [Training scripts for SDXL](#training-scripts-for-sdxl) - - [Utility scripts for SDXL](#utility-scripts-for-sdxl) - - [Tips for SDXL training](#tips-for-sdxl-training) - - [Format of Textual Inversion embeddings for SDXL](#format-of-textual-inversion-embeddings-for-sdxl) - - [ControlNet-LLLite](#controlnet-lllite) - - [Sample image generation during training](#sample-image-generation-during-training-1) - [Change History](#change-history) + - [2024/03/02 (v22.7.0)](#20240302-v2270) ## 🦒 Colab -This Colab notebook was not created or maintained by me; however, it appears to function effectively. The source can be found at: https://github.com/camenduru/kohya_ss-colab. +This Colab notebook was not created or maintained by me; however, it appears to function effectively. The source can be found at: . I would like to express my gratitude to camendutu for their valuable contribution. If you encounter any issues with the Colab notebook, please report them on their repository. @@ -73,38 +64,37 @@ To install the necessary dependencies on a Windows system, follow these steps: 3. Install the [Visual Studio 2015, 2017, 2019, and 2022 redistributable](https://aka.ms/vs/17/release/vc_redist.x64.exe). -#### Setup +#### Setup Windows To set up the project, follow these steps: 1. Open a terminal and navigate to the desired installation directory. 2. Clone the repository by running the following command: + ```shell git clone https://github.com/bmaltais/kohya_ss.git ``` 3. Change into the `kohya_ss` directory: + ```shell cd kohya_ss ``` 4. Run the setup script by executing the following command: + ```shell .\setup.bat ``` During the accelerate config step use the default values as proposed during the configuration unless you know your hardware demand otherwise. The amount of VRAM on your GPU does not have an impact on the values used. -#### Optional: CUDNN 8.6 - -The following steps are optional but can improve the learning speed for owners of NVIDIA 30X0/40X0 GPUs. These steps enable larger training batch sizes and faster training speeds. +#### Optional: CUDNN 8.9.6.50 -Please note that the CUDNN 8.6 DLLs needed for this process cannot be hosted on GitHub due to file size limitations. 
You can download them [here](https://github.com/bmaltais/python-library/raw/main/cudnn_windows.zip) to boost sample generation speed (almost 50% on a 4090 GPU). After downloading the ZIP file, follow the installation steps below: +The following steps are optional but will improve the learning speed for owners of NVIDIA 30X0/40X0 GPUs. These steps enable larger training batch sizes and faster training speeds. -1. Unzip the downloaded file and place the `cudnn_windows` folder in the root directory of the `kohya_ss` repository. - -2. Run .\setup.bat and select the option to install cudnn. +1. Run .\setup.bat and select `2. (Optional) Install cudnn files (if you want to use latest supported cudnn version)` ### Linux and macOS @@ -113,6 +103,7 @@ Please note that the CUDNN 8.6 DLLs needed for this process cannot be hosted on To install the necessary dependencies on a Linux system, ensure that you fulfill the following requirements: - Ensure that `venv` support is pre-installed. You can install it on Ubuntu 22.04 using the command: + ```shell apt install python3.10-venv ``` @@ -122,32 +113,37 @@ To install the necessary dependencies on a Linux system, ensure that you fulfill - Make sure you have Python version 3.10.6 or higher (but lower than 3.11.0) installed on your system. - If you are using WSL2, set the `LD_LIBRARY_PATH` environment variable by executing the following command: + ```shell export LD_LIBRARY_PATH=/usr/lib/wsl/lib/ ``` -#### Setup +#### Setup Linux To set up the project on Linux or macOS, perform the following steps: 1. Open a terminal and navigate to the desired installation directory. 2. Clone the repository by running the following command: + ```shell git clone https://github.com/bmaltais/kohya_ss.git ``` 3. Change into the `kohya_ss` directory: + ```shell cd kohya_ss ``` 4. If you encounter permission issues, make the `setup.sh` script executable by running the following command: + ```shell chmod +x ./setup.sh ``` 5. Run the setup script by executing the following command: + ```shell ./setup.sh ``` @@ -173,18 +169,21 @@ To install the necessary components for Runpod and run kohya_ss, follow these st 2. SSH into the Runpod. 3. Clone the repository by running the following command: + ```shell cd /workspace git clone https://github.com/bmaltais/kohya_ss.git ``` 4. Run the setup script: + ```shell cd kohya_ss ./setup-runpod.sh ``` 5. Run the gui with: + ```shell ./gui.sh --share --headless ``` @@ -201,13 +200,12 @@ To install the necessary components for Runpod and run kohya_ss, follow these st To run from a pre-built Runpod template you can: -1. Open the Runpod template by clicking on https://runpod.io/gsc?template=ya6013lj5a&ref=w18gds2n +1. Open the Runpod template by clicking on 2. Deploy the template on the desired host 3. Once deployed connect to the Runpod on HTTP 3010 to connect to kohya_ss GUI. You can also connect to auto1111 on HTTP 3000. - ### Docker #### Local docker build @@ -219,11 +217,9 @@ If you prefer to use Docker, follow the instructions below: 2. Open your OS shell (Command Prompt or Terminal) and run the following commands: ```bash - git clone https://github.com/bmaltais/kohya_ss.git + git clone --recursive https://github.com/bmaltais/kohya_ss.git cd kohya_ss - docker compose create - docker compose build - docker compose run --service-ports kohya-ss-gui + docker compose up -d --build ``` Note: The initial run may take up to 20 minutes to complete. 
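
Since sd-scripts is now tracked as a Git submodule, a clone made without `--recursive` leaves the `sd-scripts` directory empty. A minimal sketch of how to populate or refresh it using standard Git commands (not a project-specific script):

```shell
# Fetch the sd-scripts submodule in an existing clone
cd kohya_ss
git submodule update --init --recursive

# After pulling a new release, sync the submodule to the commit pinned by the repo
git pull
git submodule update --init --recursive
```
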
@@ -233,7 +229,7 @@ If you prefer to use Docker, follow the instructions below: - All training data must be placed in the `dataset` subdirectory, as the Docker container cannot access files from other directories. - The file picker feature is not functional. You need to manually set the folder path and config file path. - Dialogs may not work as expected, and it is recommended to use unique file names to avoid conflicts. - - There is no built-in auto-update support. To update the system, you must run update scripts outside of Docker and rebuild using `docker compose build`. + - This dockerfile has been designed to be easily disposable. You can discard the container at any time and docker build it with a new version of the code. To update the system, run update scripts outside of Docker and rebuild using `docker compose down && docker compose up -d --build`. If you are running Linux, an alternative Docker container port with fewer limitations is available [here](https://github.com/P2Enjoy/kohya_ss-docker). @@ -241,8 +237,8 @@ If you prefer to use Docker, follow the instructions below: You may want to use the following Dockerfile repos to build the images: - - Standalone Kohya_ss template: https://github.com/ashleykleynhans/kohya-docker - - Auto1111 + Kohya_ss GUI template: https://github.com/ashleykleynhans/stable-diffusion-docker +- Standalone Kohya_ss template: +- Auto1111 + Kohya_ss GUI template: ## Upgrading @@ -253,11 +249,13 @@ To upgrade your installation to a new version, follow the instructions below. If a new release becomes available, you can upgrade your repository by running the following commands from the root directory of the project: 1. Pull the latest changes from the repository: + ```powershell git pull ``` 2. Run the setup script: + ```powershell .\setup.bat ``` @@ -271,11 +269,13 @@ To upgrade your installation on Linux or macOS, follow these steps: directory of the project. 2. Pull the latest changes from the repository: + ```bash git pull ``` 3. Refresh and update everything: + ```bash ./setup.sh ``` @@ -316,39 +316,17 @@ To launch the GUI on Linux or macOS, run the `gui.sh` script located in the root gui.sh --listen 127.0.0.1 --server_port 7860 --inbrowser --share ``` -## Dreambooth - -For specific instructions on using the Dreambooth solution, please refer to the [Dreambooth README](https://github.com/bmaltais/kohya_ss/blob/master/train_db_README.md). - -## Finetune - -For specific instructions on using the Finetune solution, please refer to the [Finetune README](https://github.com/bmaltais/kohya_ss/blob/master/fine_tune_README.md). - -## Train Network - -For specific instructions on training a network, please refer to the [Train network README](https://github.com/bmaltais/kohya_ss/blob/master/train_network_README.md). - ## LoRA To train a LoRA, you can currently use the `train_network.py` code. You can create a LoRA network by using the all-in-one GUI. Once you have created the LoRA network, you can generate images using auto1111 by installing [this extension](https://github.com/kohya-ss/sd-webui-additional-networks). -The following are the names of LoRA types used in this repository: - -1. LoRA-LierLa: LoRA for Linear layers and Conv2d layers with a 1x1 kernel. - -2. LoRA-C3Lier: LoRA for Conv2d layers with a 3x3 kernel, in addition to LoRA-LierLa. - -LoRA-LierLa is the default LoRA type for `train_network.py` (without `conv_dim` network argument). You can use LoRA-LierLa with our extension for AUTOMATIC1111's Web UI or the built-in LoRA feature of the Web UI. 
- -To use LoRA-C3Lier with the Web UI, please use our extension. - ## Sample image generation during training A prompt file might look like this, for example: -``` +```txt # prompt 1 masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy, bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28 @@ -379,272 +357,12 @@ If you encounter an X error related to the page file, you may need to increase t If you encounter an error indicating that the module `tkinter` is not found, try reinstalling Python 3.10 on your system. -### FileNotFoundError - -If you come across a `FileNotFoundError`, it is likely due to an installation issue. Make sure you do not have any locally installed Python modules that could conflict with the ones installed in the virtual environment. You can uninstall them by following these steps: - -1. Open a new PowerShell terminal and ensure that no virtual environment is active. - -2. Run the following commands to create a backup file of your locally installed pip packages and then uninstall them: - ```powershell - pip freeze > uninstall.txt - pip uninstall -r uninstall.txt - ``` - - After uninstalling the local packages, redo the installation steps within the `kohya_ss` virtual environment. - - ## SDXL training The documentation in this section will be moved to a separate document later. -### Training scripts for SDXL - -- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset. - - `--full_bf16` option is added. Thanks to KohakuBlueleaf! - - This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage. - - The full bfloat16 training might be unstable. Please use it at your own risk. - - The different learning rates for each U-Net block are now supported in sdxl_train.py. Specify with `--block_lr` option. Specify 23 values separated by commas like `--block_lr 1e-3,1e-3 ... 1e-3`. - - 23 values correspond to `0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out`. -- `prepare_buckets_latents.py` now supports SDXL fine-tuning. - -- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`. - -- Both scripts has following additional options: - - `--cache_text_encoder_outputs` and `--cache_text_encoder_outputs_to_disk`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions. - - `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs. - -- `--weighted_captions` option is not supported yet for both scripts. - -- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`. - - `--cache_text_encoder_outputs` is not supported. - - There are two options for captions: - 1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens. - 2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored. - - See below for the format of the embeddings. 
- -- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000. - -### Utility scripts for SDXL - -- `tools/cache_latents.py` is added. This script can be used to cache the latents to disk in advance. - - The options are almost the same as `sdxl_train.py'. See the help message for the usage. - - Please launch the script as follows: - `accelerate launch --num_cpu_threads_per_process 1 tools/cache_latents.py ...` - - This script should work with multi-GPU, but it is not tested in my environment. - -- `tools/cache_text_encoder_outputs.py` is added. This script can be used to cache the text encoder outputs to disk in advance. - - The options are almost the same as `cache_latents.py` and `sdxl_train.py`. See the help message for the usage. - -- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA, Textual Inversion and ControlNet-LLLite. See the help message for the usage. - -### Tips for SDXL training - -- The default resolution of SDXL is 1024x1024. -- The fine-tuning can be done with 24GB GPU memory with the batch size of 1. For 24GB GPU, the following options are recommended __for the fine-tuning with 24GB GPU memory__: - - Train U-Net only. - - Use gradient checkpointing. - - Use `--cache_text_encoder_outputs` option and caching latents. - - Use Adafactor optimizer. RMSprop 8bit or Adagrad 8bit may work. AdamW 8bit doesn't seem to work. -- The LoRA training can be done with 8GB GPU memory (10GB recommended). For reducing the GPU memory usage, the following options are recommended: - - Train U-Net only. - - Use gradient checkpointing. - - Use `--cache_text_encoder_outputs` option and caching latents. - - Use one of 8bit optimizers or Adafactor optimizer. - - Use lower dim (4 to 8 for 8GB GPU). -- `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, the result of the training will be unexpected. -- PyTorch 2 seems to use slightly less GPU memory than PyTorch 1. -- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training. - -Example of the optimizer settings for Adafactor with the fixed learning rate: - -```toml -optimizer_type = "adafactor" -optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ] -lr_scheduler = "constant_with_warmup" -lr_warmup_steps = 100 -learning_rate = 4e-7 # SDXL original learning rate -``` - -### Format of Textual Inversion embeddings for SDXL - -```python -from safetensors.torch import save_file - -state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768} -save_file(state_dict, file) -``` - -### ControlNet-LLLite - -ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [documentation](./docs/train_lllite_README.md) for details. 
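
The low-VRAM recommendations in the "Tips for SDXL training" list above map roughly onto the following flags. This is only a sketch assuming `sdxl_train_network.py` with the memory-saving options mentioned there; the model, dataset, and output arguments are omitted, and values such as the network dim of 4 are just the suggested starting points:

```shell
# Illustrative SDXL LoRA launch following the memory-saving tips above
# (add the usual model/dataset/output arguments for a real run).
accelerate launch --num_cpu_threads_per_process=1 "./sd-scripts/sdxl_train_network.py" \
  --network_module=networks.lora \
  --network_dim=4 \
  --network_train_unet_only \
  --gradient_checkpointing \
  --cache_latents \
  --cache_text_encoder_outputs \
  --optimizer_type="AdamW8bit" \
  --bucket_reso_steps=32 \
  --mixed_precision="bf16"
```
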
- -### Sample image generation during training - A prompt file might look like this, for example - -``` -# prompt 1 -masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28 - -# prompt 2 -masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40 -``` - - Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used. - - * `--n` Negative prompt up to the next option. - * `--w` Specifies the width of the generated image. - * `--h` Specifies the height of the generated image. - * `--d` Specifies the seed of the generated image. - * `--l` Specifies the CFG scale of the generated image. - * `--s` Specifies the number of steps in the generation. - - The prompt weighting such as `( )` and `[ ]` are working. +## Change History +### 2024/03/02 (v22.7.0) -## Change History -* 2024/03/02 (v22.7.0) - Major code refactoring thanks to @wkpark , This will make updating sd-script cleaner by keeping sd-scripts files separate from the GUI files. -* 2024/02/17 (v22.6.2) -- Fix issue with Lora Extract GUI - - Fix syntax issue where parameter lora_network_weights is actually called network_weights -- Merge sd-scripts v0.8.4 code update - - Fixed a bug that the VRAM usage without Text Encoder training is larger than before in training scripts for LoRA etc (`train_network.py`, `sdxl_train_network.py`). - - Text Encoders were not moved to CPU. - - Fixed typos. Thanks to akx! [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053) - - The log output has been improved. PR [#905](https://github.com/kohya-ss/sd-scripts/pull/905) Thanks to shirayu! - - The log is formatted by default. The `rich` library is required. Please see [Upgrade](#upgrade) and update the library. - - If `rich` is not installed, the log output will be the same as before. - - The following options are available in each training script: - - `--console_log_simple` option can be used to switch to the previous log output. - - `--console_log_level` option can be used to specify the log level. The default is `INFO`. - - `--console_log_file` option can be used to output the log to a file. The default is `None` (output to the console). - - The sample image generation during multi-GPU training is now done with multiple GPUs. PR [#1061](https://github.com/kohya-ss/sd-scripts/pull/1061) Thanks to DKnight54! - - The support for mps devices is improved. PR [#1054](https://github.com/kohya-ss/sd-scripts/pull/1054) Thanks to akx! If mps device exists instead of CUDA, the mps device is used automatically. - - The `--new_conv_rank` option to specify the new rank of Conv2d is added to `networks/resize_lora.py`. PR [#1102](https://github.com/kohya-ss/sd-scripts/pull/1102) Thanks to mgz-dev! - - An option `--highvram` to disable the optimization for environments with little VRAM is added to the training scripts. If you specify it when there is enough VRAM, the operation will be faster. - - Currently, only the cache part of latents is optimized. - - The IPEX support is improved. PR [#1086](https://github.com/kohya-ss/sd-scripts/pull/1086) Thanks to Disty0! - - Fixed a bug that `svd_merge_lora.py` crashes in some cases. 
PR [#1087](https://github.com/kohya-ss/sd-scripts/pull/1087) Thanks to mgz-dev! - - DyLoRA is fixed to work with SDXL. PR [#1126](https://github.com/kohya-ss/sd-scripts/pull/1126) Thanks to tamlog06! - - The common image generation script `gen_img.py` for SD 1/2 and SDXL is added. The basic functions are the same as the scripts for SD 1/2 and SDXL, but some new features are added. - - External scripts to generate prompts can be supported. It can be called with `--from_module` option. (The documentation will be added later) - - The normalization method after prompt weighting can be specified with `--emb_normalize_mode` option. `original` is the original method, `abs` is the normalization with the average of the absolute values, `none` is no normalization. - - Gradual Latent Hires fix is added to each generation script. See [here](./docs/gen_img_README-ja.md#about-gradual-latent) for details. - -* 2024/02/15 (v22.6.1) -- Add support for multi-gpu parameters in the GUI under the "Parameters > Advanced" tab. -- Significant rewrite of how parameters are created in the code. I hope I did not break anything in the process... Will make the code easier to update. -- Update TW locallisation -- Update gradio module version to latest 3.x - -* 2024/01/27 (v22.6.0) -- Merge sd-scripts v0.8.3 code update - - Fixed a bug that the training crashes when `--fp8_base` is specified with `--save_state`. PR [#1079](https://github.com/kohya-ss/sd-scripts/pull/1079) Thanks to feffy380! - - `safetensors` is updated. Please see [Upgrade](#upgrade) and update the library. - - Fixed a bug that the training crashes when `network_multiplier` is specified with multi-GPU training. PR [#1084](https://github.com/kohya-ss/sd-scripts/pull/1084) Thanks to fireicewolf! - - Fixed a bug that the training crashes when training ControlNet-LLLite. - -- Merge sd-scripts v0.8.2 code update - - [Experimental] The `--fp8_base` option is added to the training scripts for LoRA etc. The base model (U-Net, and Text Encoder when training modules for Text Encoder) can be trained with fp8. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf! - - Please specify `--fp8_base` in `train_network.py` or `sdxl_train_network.py`. - - PyTorch 2.1 or later is required. - - If you use xformers with PyTorch 2.1, please see [xformers repository](https://github.com/facebookresearch/xformers) and install the appropriate version according to your CUDA version. - - The sample image generation during training consumes a lot of memory. It is recommended to turn it off. - - - [Experimental] The network multiplier can be specified for each dataset in the training scripts for LoRA etc. - - This is an experimental option and may be removed or changed in the future. - - For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate. - - Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate. - - Please specify `network_multiplier` in `[[datasets]]` in `.toml` file. - - - Some options are added to `networks/extract_lora_from_models.py` to reduce the memory usage. - - `--load_precision` option can be used to specify the precision when loading the model. If the model is saved in fp16, you can reduce the memory usage by specifying `--load_precision fp16` without losing precision. 
- - `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU. This option is available only for SDXL. - - - The gradient synchronization in LoRA training with multi-GPU is improved. PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) Thanks to KohakuBlueleaf! - - - The code for Intel IPEX support is improved. PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) Thanks to akx! - - - Fixed a bug in multi-GPU Textual Inversion training. - - - `.toml` example for network multiplier - - ```toml - [general] - [[datasets]] - resolution = 512 - batch_size = 8 - network_multiplier = 1.0 - - ... subset settings ... - - [[datasets]] - resolution = 512 - batch_size = 8 - network_multiplier = -1.0 - - ... subset settings ... - ``` - -- Merge sd-scripts v0.8.1 code update - - - Fixed a bug that the VRAM usage without Text Encoder training is larger than before in training scripts for LoRA etc (`train_network.py`, `sdxl_train_network.py`). - - Text Encoders were not moved to CPU. - - - Fixed typos. Thanks to akx! [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053) - -* 2024/01/15 (v22.5.0) -- Merged sd-scripts v0.8.0 updates - - Diffusers, Accelerate, Transformers and other related libraries have been updated. Please update the libraries with [Upgrade](#upgrade). - - Some model files (Text Encoder without position_id) based on the latest Transformers can be loaded. - - `torch.compile` is supported (experimental). PR [#1024](https://github.com/kohya-ss/sd-scripts/pull/1024) Thanks to p1atdev! - - This feature works only on Linux or WSL. - - Please specify `--torch_compile` option in each training script. - - You can select the backend with `--dynamo_backend` option. The default is `"inductor"`. `inductor` or `eager` seems to work. - - Please use `--spda` option instead of `--xformers` option. - - PyTorch 2.1 or later is recommended. - - Please see [PR](https://github.com/kohya-ss/sd-scripts/pull/1024) for details. - - The session name for wandb can be specified with `--wandb_run_name` option. PR [#1032](https://github.com/kohya-ss/sd-scripts/pull/1032) Thanks to hopl1t! - - IPEX library is updated. PR [#1030](https://github.com/kohya-ss/sd-scripts/pull/1030) Thanks to Disty0! - - Fixed a bug that Diffusers format model cannot be saved. -- Fix LoRA config display after load that would sometime hide some of the feilds - -* 2024/01/02 (v22.4.1) -- Minor bug fixed and enhancements. - -* 2023/12/28 (v22.4.0) -- Fixed to work `tools/convert_diffusers20_original_sd.py`. Thanks to Disty0! PR [#1016](https://github.com/kohya-ss/sd-scripts/pull/1016) -- The issues in multi-GPU training are fixed. Thanks to Isotr0py! PR [#989](https://github.com/kohya-ss/sd-scripts/pull/989) and [#1000](https://github.com/kohya-ss/sd-scripts/pull/1000) - - `--ddp_gradient_as_bucket_view` and `--ddp_bucket_view`options are added to `sdxl_train.py`. Please specify these options for multi-GPU training. -- IPEX support is updated. Thanks to Disty0! -- Fixed the bug that the size of the bucket becomes less than `min_bucket_reso`. Thanks to Cauldrath! PR [#1008](https://github.com/kohya-ss/sd-scripts/pull/1008) -- `--sample_at_first` option is added to each training script. This option is useful to generate images at the first step, before training. 
Thanks to shirayu! PR [#907](https://github.com/kohya-ss/sd-scripts/pull/907) -- `--ss` option is added to the sampling prompt in training. You can specify the scheduler for the sampling like `--ss euler_a`. Thanks to shirayu! PR [#906](https://github.com/kohya-ss/sd-scripts/pull/906) -- `keep_tokens_separator` is added to the dataset config. This option is useful to keep (prevent from shuffling) the tokens in the captions. See [#975](https://github.com/kohya-ss/sd-scripts/pull/975) for details. Thanks to Linaqruf! - - You can specify the separator with an option like `--keep_tokens_separator "|||"` or with `keep_tokens_separator: "|||"` in `.toml`. The tokens before `|||` are not shuffled. -- Attention processor hook is added. See [#961](https://github.com/kohya-ss/sd-scripts/pull/961) for details. Thanks to rockerBOO! -- The optimizer `PagedAdamW` is added. Thanks to xzuyn! PR [#955](https://github.com/kohya-ss/sd-scripts/pull/955) -- NaN replacement in SDXL VAE is sped up. Thanks to liubo0902! PR [#1009](https://github.com/kohya-ss/sd-scripts/pull/1009) -- Fixed the path error in `finetune/make_captions.py`. Thanks to CjangCjengh! PR [#986](https://github.com/kohya-ss/sd-scripts/pull/986) - -* 2023/12/20 (v22.3.1) -- Add goto button to manual caption utility -- Add missing options for various LyCORIS training algorithms -- Refactor how feilds are shown or hidden -- Made max value for network and convolution rank 512 except for LyCORIS/LoKr. - -* 2023/12/06 (v22.3.0) -- Merge sd-scripts updates: - - `finetune\tag_images_by_wd14_tagger.py` now supports the separator other than `,` with `--caption_separator` option. Thanks to KohakuBlueleaf! PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913) - - Min SNR Gamma with V-predicition (SD 2.1) is fixed. Thanks to feffy380! PR[#934](https://github.com/kohya-ss/sd-scripts/pull/934) - - See [#673](https://github.com/kohya-ss/sd-scripts/issues/673) for details. - - `--min_diff` and `--clamp_quantile` options are added to `networks/extract_lora_from_models.py`. Thanks to wkpark! PR [#936](https://github.com/kohya-ss/sd-scripts/pull/936) - - The default values are same as the previous version. - - Deep Shrink hires fix is supported in `sdxl_gen_img.py` and `gen_img_diffusers.py`. - - `--ds_timesteps_1` and `--ds_timesteps_2` options denote the timesteps of the Deep Shrink for the first and second stages. - - `--ds_depth_1` and `--ds_depth_2` options denote the depth (block index) of the Deep Shrink for the first and second stages. - - `--ds_ratio` option denotes the ratio of the Deep Shrink. `0.5` means the half of the original latent size for the Deep Shrink. - - `--dst1`, `--dst2`, `--dsd1`, `--dsd2` and `--dsr` prompt options are also available. 
- - Add GLoRA support -- \ No newline at end of file diff --git "a/README_\344\270\255\346\226\207\346\225\231\347\250\213.md" "b/README_\344\270\255\346\226\207\346\225\231\347\250\213.md" deleted file mode 100644 index 68a14a921..000000000 --- "a/README_\344\270\255\346\226\207\346\225\231\347\250\213.md" +++ /dev/null @@ -1,174 +0,0 @@ -SDXL已得到支持。sdxl分支已合并到main分支。当更新仓库时,请执行升级步骤。由于accelerate版本也已经升级,请重新运行accelerate config。 - -有关SDXL训练的信息,请参见[此处](./README.md#sdxl-training)(英文)。 - -## 关于本仓库 - -用于Stable Diffusion的训练、图像生成和其他脚本的仓库。 - -[英文README](./README.md) <- 更新信息在这里 - -[bmaltais的仓库](https://github.com/bmaltais/kohya_ss)中提供了GUI和PowerShell脚本等使其更易于使用的功能(英文),也请一并参阅。衷心感谢bmaltais。 - -包含以下脚本: - -* 支持DreamBooth、U-Net和Text Encoder的训练 -* 微调,同上 -* 支持LoRA的训练 -* 图像生成 -* 模型转换(在Stable Diffision ckpt/safetensors与Diffusers之间转换) - -## 使用方法 - -* [通用部分的训练信息](./docs/train_README-ja.md): 数据准备和选项等 -* [数据集设置](./docs/config_README-ja.md) -* [DreamBooth的训练信息](./docs/train_db_README-ja.md) -* [微调指南](./docs/fine_tune_README_ja.md) -* [LoRA的训练信息](./docs/train_network_README-ja.md) -* [Textual Inversion的训练信息](./docs/train_ti_README-ja.md) -* [图像生成脚本](./docs/gen_img_README-ja.md) -* note.com [模型转换脚本](https://note.com/kohya_ss/n/n374f316fe4ad) - -## Windows上需要的程序 - -需要Python 3.10.6和Git。 - -- Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe -- git: https://git-scm.com/download/win - -如果要在PowerShell中使用,请按以下步骤更改安全设置以使用venv。 -(不仅仅是venv,这使得脚本的执行成为可能,所以请注意。) - -- 以管理员身份打开PowerShell。 -- 输入“Set-ExecutionPolicy Unrestricted”,并回答Y。 -- 关闭管理员PowerShell。 - -## 在Windows环境下安装 - -脚本已在PyTorch 2.0.1上通过测试。PyTorch 1.12.1也应该可以工作。 - -下例中,将安装PyTorch 2.0.1/CUDA 11.8版。如果使用CUDA 11.6版或PyTorch 1.12.1,请酌情更改。 - -(注意,如果python -m venv~这行只显示“python”,请将其更改为py -m venv~。) - -如果使用PowerShell,请打开常规(非管理员)PowerShell并按顺序执行以下操作: - -```powershell -git clone https://github.com/kohya-ss/sd-scripts.git -cd sd-scripts - -python -m venv venv -.\venv\Scripts\activate - -pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 -pip install --upgrade -r requirements.txt -pip install xformers==0.0.20 - -accelerate config -``` - -在命令提示符下也相同。 - -(注:由于 ``python -m venv venv`` 比 ``python -m venv --system-site-packages venv`` 更安全,已进行更改。如果global python中安装了package,后者会引发各种问题。) - -在accelerate config的提示中,请按以下方式回答。(如果以bf16学习,最后一个问题回答bf16。) - -※从0.15.0开始,在日语环境中按方向键选择会崩溃(......)。请使用数字键0、1、2......进行选择。 - -```txt -- This machine -- No distributed training -- NO -- NO -- NO -- all -- fp16 -``` - -※有时可能会出现 ``ValueError: fp16 mixed precision requires a GPU`` 错误。在这种情况下,对第6个问题 ``(What GPU(s) (by id) should be used for training on this machine as a comma-separated list? 
[all]:`` -)回答“0”。(将使用id `0`的GPU)。 - -### 可选:``bitsandbytes``(8位优化器) - -`bitsandbytes`现在是可选的。在Linux上,可以通过pip正常安装(推荐0.41.1或更高版本)。 - -在Windows上,推荐0.35.0或0.41.1。 - -- `bitsandbytes` 0.35.0: 似乎是稳定的版本。可以使用AdamW8bit,但不能使用其他一些8位优化器和`full_bf16`学习时的选项。 -- `bitsandbytes` 0.41.1: 支持 Lion8bit、PagedAdamW8bit、PagedLion8bit。可以使用`full_bf16`。 - -注意:`bitsandbytes` 从0.35.0到0.41.0之间的版本似乎存在问题。 https://github.com/TimDettmers/bitsandbytes/issues/659 - -请按以下步骤安装`bitsandbytes`。 - -### 使用0.35.0 - -以下是PowerShell的例子。在命令提示符中,请使用copy代替cp。 - -```powershell -cd sd-scripts -.\venv\Scripts\activate -pip install bitsandbytes==0.35.0 - -cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ -cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py -cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py -``` - -### 使用0.41.1 - -请从[此处](https://github.com/jllllll/bitsandbytes-windows-webui)或其他地方安装jllllll发布的Windows whl文件。 - -```powershell -python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui -``` - - -### 可选:使用Lion8bit - -如果要使用Lion8bit,需要将`bitsandbytes`升级到0.38.0以上。首先卸载`bitsandbytes`,然后在Windows中安装适合Windows的whl文件,例如[这里的](https://github.com/jllllll/bitsandbytes-windows-webui)。例如: - -```powershell -pip install https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.38.1-py3-none-any.whl -``` - -升级时用`pip install .`更新这个仓库,并视情况升级其他包。 - -### 可选:使用PagedAdamW8bit和PagedLion8bit - -如果要使用PagedAdamW8bit和PagedLion8bit,需要将`bitsandbytes`升级到0.39.0以上。首先卸载`bitsandbytes`,然后在Windows中安装适合Windows的whl文件,例如[这里的](https://github.com/jllllll/bitsandbytes-windows-webui)。例如: - -```powershell -pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl -``` - -升级时用`pip install .`更新这个仓库,并视情况升级其他包。 - -## 升级 - -如果有新版本,可以用以下命令更新: - -```powershell -cd sd-scripts -git pull -.\venv\Scripts\activate -pip install --use-pep517 --upgrade -r requirements.txt -``` - -如果命令成功,就可以使用新版本了。 - -## 致谢 - -LoRA实现基于[cloneofsimo的仓库](https://github.com/cloneofsimo/lora)。表示感谢。 - -将Conv2d 3x3扩展到所有层起初由 [cloneofsimo](https://github.com/cloneofsimo/lora) 发布, [KohakuBlueleaf](https://github.com/KohakuBlueleaf/LoCon) 证明了其有效性。深深感谢 KohakuBlueleaf。 - -## 许可 - -脚本遵循 ASL 2.0 许可,但包含其他许可的代码部分(Diffusers和cloneofsimo的仓库)。 - -[Memory Efficient Attention Pytorch](https://github.com/lucidrains/memory-efficient-attention-pytorch): MIT - -[bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT - -[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause diff --git a/XTI_hijack.py b/XTI_hijack.py deleted file mode 100644 index 93bc1c0b1..000000000 --- a/XTI_hijack.py +++ /dev/null @@ -1,204 +0,0 @@ -import torch -from library.device_utils import init_ipex -init_ipex() - -from typing import Union, List, Optional, Dict, Any, Tuple -from diffusers.models.unet_2d_condition import UNet2DConditionOutput - -from library.original_unet import SampleOutput - - -def unet_forward_XTI( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - class_labels: Optional[torch.Tensor] = None, - return_dict: bool = True, -) -> Union[Dict, Tuple]: - r""" - Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) 
encoder hidden states - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a dict instead of a plain tuple. - - Returns: - `SampleOutput` or `tuple`: - `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. - """ - # By default samples have to be AT least a multiple of the overall upsampling factor. - # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). - # However, the upsampling interpolation output size can be forced to fit any upsampling size - # on the fly if necessary. - # デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある - # ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する - # 多分画質が悪くなるので、64で割り切れるようにしておくのが良い - default_overall_up_factor = 2**self.num_upsamplers - - # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` - # 64で割り切れないときはupsamplerにサイズを伝える - forward_upsample_size = False - upsample_size = None - - if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - # logger.info("Forward upsample size to force interpolation output size.") - forward_upsample_size = True - - # 1. time - timesteps = timestep - timesteps = self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理 - - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - # timestepsは重みを含まないので常にfloat32のテンソルを返す - # しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある - # time_projでキャストしておけばいいんじゃね? - t_emb = t_emb.to(dtype=self.dtype) - emb = self.time_embedding(t_emb) - - # 2. pre-process - sample = self.conv_in(sample) - - # 3. down - down_block_res_samples = (sample,) - down_i = 0 - for downsample_block in self.down_blocks: - # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 - # まあこちらのほうがわかりやすいかもしれない - if downsample_block.has_cross_attention: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states[down_i : down_i + 2], - ) - down_i += 2 - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - - down_block_res_samples += res_samples - - # 4. mid - sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6]) - - # 5. up - up_i = 7 - for i, upsample_block in enumerate(self.up_blocks): - is_final_block = i == len(self.up_blocks) - 1 - - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection - - # if we have not reached the final block and need to forward the upsample size, we do it here - # 前述のように最後のブロック以外ではupsample_sizeを伝える - if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] - - if upsample_block.has_cross_attention: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states[up_i : up_i + 3], - upsample_size=upsample_size, - ) - up_i += 3 - else: - sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size - ) - - # 6. 
post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if not return_dict: - return (sample,) - - return SampleOutput(sample=sample) - - -def downblock_forward_XTI( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None -): - output_states = () - i = 0 - - for resnet, attn in zip(self.resnets, self.attentions): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i] - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample - - output_states += (hidden_states,) - i += 1 - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -def upblock_forward_XTI( - self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - upsample_size=None, -): - i = 0 - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i] - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample - - i += 1 - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states diff --git a/config_README-zh.md b/config_README-zh.md deleted file mode 100644 index 1266e7b25..000000000 --- a/config_README-zh.md +++ /dev/null @@ -1,289 +0,0 @@ - - -paste.txt -For non-Japanese speakers: this README is provided only in Japanese in the current state. Sorry for inconvenience. We will provide English version in the near future. 
- -`--dataset_config` 可以传递设置文件的说明。 - -## 概述 - -通过传递设置文件,用户可以进行更详细的设置。 - -* 可以设置多个数据集 - * 例如,可以为每个数据集设置 `resolution`,并混合它们进行训练。 - * 对于支持DreamBooth方法和fine tuning方法的训练方法,可以混合使用DreamBooth方式和fine tuning方式的数据集。 -* 可以为每个子集更改设置 - * 数据子集是指根据图像目录或元数据拆分的数据集。几个子集构成一个数据集。 - * 诸如 `keep_tokens` 和 `flip_aug` 之类的选项可以为每个子集设置。另一方面,诸如 `resolution` 和 `batch_size` 之类的选项可以为每个数据集设置,属于同一数据集的子集的值将是共享的。具体细节将在下文中说明。 - -设置文件的格式可以是JSON或TOML。考虑到描述的便利性,建议使用 [TOML](https://toml.io/zh/v1.0.0-rc.2)。下面将以TOML为前提进行说明。 - -下面是一个用TOML描述的设置文件的示例。 - -```toml -[general] -shuffle_caption = true -caption_extension = '.txt' -keep_tokens = 1 - -# 这是一个DreamBooth方式的数据集 -[[datasets]] -resolution = 512 -batch_size = 4 -keep_tokens = 2 - - [[datasets.subsets]] - image_dir = 'C:\hoge' - class_tokens = 'hoge girl' - # 此子集为keep_tokens = 2(使用属于的数据集的值) - - [[datasets.subsets]] - image_dir = 'C:\fuga' - class_tokens = 'fuga boy' - keep_tokens = 3 - - [[datasets.subsets]] - is_reg = true - image_dir = 'C:\reg' - class_tokens = 'human' - keep_tokens = 1 - -# 这是一个fine tuning方式的数据集 -[[datasets]] -resolution = [768, 768] -batch_size = 2 - - [[datasets.subsets]] - image_dir = 'C:\piyo' - metadata_file = 'C:\piyo\piyo_md.json' - # 此子集为keep_tokens = 1(使用general的值) -``` - -在此示例中,有3个目录作为DreamBooth方式的数据集以512x512(批量大小4)进行训练,有1个目录作为fine tuning方式的数据集以768x768(批量大小2)进行训练。 - -## 数据子集的设置 - -数据集和子集的相关设置分为几个可以注册的部分。 - -* `[general]` - * 适用于所有数据集或所有子集的选项指定部分。 - * 如果在数据集设置或子集设置中存在同名选项,则会优先使用数据集或子集的设置。 -* `[[datasets]]` - * `datasets`是数据集设置注册部分。这是指定适用于每个数据集的选项的部分。 - * 如果存在子集设置,则会优先使用子集设置。 -* `[[datasets.subsets]]` - * `datasets.subsets`是子集设置的注册部分。这是指定适用于每个子集的选项的部分。 - -下面是先前示例中图像目录和注册部分的对应关系图。 - -``` -C:\ -├─ hoge -> [[datasets.subsets]] No.1 ┐ ┐ -├─ fuga -> [[datasets.subsets]] No.2 |-> [[datasets]] No.1 |-> [general] -├─ reg -> [[datasets.subsets]] No.3 ┘ | -└─ piyo -> [[datasets.subsets]] No.4 --> [[datasets]] No.2 ┘ -``` - -每个图像目录对应一个 `[[datasets.subsets]]`。然后1个或多个 `[[datasets.subsets]]` 构成一个 `[[datasets]]`。所有的 `[[datasets]]`、`[[datasets.subsets]]` 属于 `[general]`。 - -根据注册部分,可以指定不同的选项,但如果指定了同名选项,则下级注册部分中的值将被优先使用。建议参考先前 keep_tokens 选项的处理,这将有助于理解。 - -此外,可指定的选项会根据训练方法支持的技术而变化。 - -* 仅DreamBooth方式的选项 -* 仅fine tuning方式的选项 -* 当可以使用caption dropout技术时的选项 - -在可同时使用DreamBooth方法和fine tuning方法的训练方法中,它们可以一起使用。 -需要注意的是,判断是DreamBooth方式还是fine tuning方式是按数据集进行的,因此同一数据集中不能混合存在DreamBooth方式子集和fine tuning方式子集。 -也就是说,如果要混合使用这两种方式,则需要设置不同方式的子集属于不同的数据集。 - -从程序行为上看,如果存在 `metadata_file` 选项,则判断为fine tuning方式的子集。 -因此,对于属于同一数据集的子集,要么全部都具有 `metadata_file` 选项,要么全部都不具有 `metadata_file` 选项,这两种情况之一。 - -下面解释可用的选项。对于与命令行参数名称相同的选项,基本上省略解释。请参阅其他自述文件。 - -### 所有训练方法通用的选项 - -不管训练方法如何,都可以指定的选项。 - -#### 数据集的选项 - -与数据集设置相关的选项。不能在 `datasets.subsets` 中描述。 - -| 选项名称 | 设置示例 | `[general]` | `[[datasets]]` | -| ---- | ---- | ---- | ---- | -| `batch_size` | `1` | o | o | -| `bucket_no_upscale` | `true` | o | o | -| `bucket_reso_steps` | `64` | o | o | -| `enable_bucket` | `true` | o | o | -| `max_bucket_reso` | `1024` | o | o | -| `min_bucket_reso` | `128` | o | o | -| `resolution` | `256`, `[512, 512]` | o | o | - -* `batch_size` - * 等同于命令行参数 `--train_batch_size`。 - -这些设置对每个数据集是固定的。 -也就是说,属于该数据集的子集将共享这些设置。 -例如,如果要准备不同分辨率的数据集,可以像上面的示例那样将它们定义为单独的数据集,以便可以为它们单独设置不同的分辨率。 - -#### 子集的选项 - -与子集设置相关的选项。 - -| 选项名称 | 设置示例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | -| ---- | ---- | ---- | ---- | ---- | -| `color_aug` | `false` | o | o | o | -| `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o | -| `flip_aug` | `true` | o | o | o | -| `keep_tokens` | `2` | o | o | o | -| `num_repeats` | `10` | o | o | o | -| 
`random_crop` | `false` | o | o | o | -| `shuffle_caption` | `true` | o | o | o | -| `caption_prefix` | `“masterpiece, best quality, ”` | o | o | o | -| `caption_suffix` | `“, from side”` | o | o | o | - -* `num_repeats` - * 指定子集中图像的重复次数。在fine tuning中相当于 `--dataset_repeats`,但 `num_repeats` - 在任何训练方法中都可以指定。 -* `caption_prefix`, `caption_suffix` - * 指定在标题前后附加的字符串。包括这些字符串的状态下进行洗牌。当指定 `keep_tokens` 时请注意。 - -### 仅DreamBooth方式的选项 - -DreamBooth方式的选项仅存在于子集选项中。 - -#### 子集的选项 - -DreamBooth方式子集的设置相关选项。 - -| 选项名称 | 设置示例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | -| ---- | ---- | ---- | ---- | ---- | -| `image_dir` | `‘C:\hoge’` | - | - | o(必需) | -| `caption_extension` | `".txt"` | o | o | o | -| `class_tokens` | `“sks girl”` | - | - | o | -| `is_reg` | `false` | - | - | o | - -首先需要注意的是,`image_dir` 必须指定图像文件直接放置在路径下。与以往的 DreamBooth 方法不同,它与放置在子目录中的图像不兼容。另外,即使使用像 `5_cat` 这样的文件夹名,也不会反映图像的重复次数和类名。如果要单独设置这些,需要使用 `num_repeats` 和 `class_tokens` 明确指定。 - -* `image_dir` - * 指定图像目录路径。必填选项。 - * 图像需要直接放置在目录下。 -* `class_tokens` - * 设置类标记。 - * 仅在不存在与图像对应的 caption 文件时才在训练中使用。是否使用是按图像判断的。如果没有指定 `class_tokens` 且也找不到 caption 文件,将报错。 -* `is_reg` - * 指定子集中的图像是否用于正则化。如果未指定,则视为 `false`,即视为非正则化图像。 - -### 仅fine tuning方式的选项 - -fine tuning方式的选项仅存在于子集选项中。 - -#### 子集的选项 - -fine tuning方式子集的设置相关选项。 - -| 选项名称 | 设置示例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | -| ---- | ---- | ---- | ---- | ---- | -| `image_dir` | `‘C:\hoge’` | - | - | o | -| `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o(必需) | - -* `image_dir` - * 指定图像目录路径。与 DreamBooth 方式不同,不是必填的,但建议设置。 - * 不需要指定的情况是在生成元数据时指定了 `--full_path`。 - * 图像需要直接放置在目录下。 -* `metadata_file` - * 指定子集使用的元数据文件路径。必填选项。 - * 与命令行参数 `--in_json` 等效。 - * 由于子集需要按元数据文件指定,因此避免跨目录创建一个元数据文件。强烈建议为每个图像目录准备元数据文件,并将它们作为单独的子集进行注册。 - -### 当可使用caption dropout技术时可以指定的选项 - -当可以使用 caption dropout 技术时的选项仅存在于子集选项中。 -无论是 DreamBooth 方式还是 fine tuning 方式,只要训练方法支持 caption dropout,就可以指定。 - -#### 子集的选项 - -可使用 caption dropout 的子集的设置相关选项。 - -| 选项名称 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` | -| ---- | ---- | ---- | ---- | -| `caption_dropout_every_n_epochs` | o | o | o | -| `caption_dropout_rate` | o | o | o | -| `caption_tag_dropout_rate` | o | o | o | - -## 存在重复子集时的行为 - -对于 DreamBooth 方式的数据集,如果其中 `image_dir` 相同的子集将被视为重复。 -对于 fine tuning 方式的数据集,如果其中 `metadata_file` 相同的子集将被视为重复。 -如果数据集中存在重复的子集,则从第二个开始将被忽略。 - -另一方面,如果它们属于不同的数据集,则不会被视为重复。 -例如,像下面这样在不同的数据集中放置相同的 `image_dir` 的子集,不会被视为重复。 -这在想要以不同分辨率训练相同图像的情况下很有用。 - -```toml -# 即使存在于不同的数据集中,也不会被视为重复,两者都将用于训练 - -[[datasets]] -resolution = 512 - - [[datasets.subsets]] - image_dir = 'C:\hoge' - -[[datasets]] -resolution = 768 - - [[datasets.subsets]] - image_dir = 'C:\hoge' -``` - -## 与命令行参数的组合使用 - -设置文件中的某些选项与命令行参数选项的作用相同。 - -如果传递了设置文件,则会忽略下列命令行参数选项。 - -* `--train_data_dir` -* `--reg_data_dir` -* `--in_json` - -如果同时在命令行参数和设置文件中指定了下列命令行参数选项,则设置文件中的值将优先于命令行参数的值。除非另有说明,否则它们是同名选项。 - -| 命令行参数选项 | 优先的设置文件选项 | -| ---- | ---- | -| `--bucket_no_upscale` | | -| `--bucket_reso_steps` | | -| `--caption_dropout_every_n_epochs` | | -| `--caption_dropout_rate` | | -| `--caption_extension` | | -| `--caption_tag_dropout_rate` | | -| `--color_aug` | | -| `--dataset_repeats` | `num_repeats` | -| `--enable_bucket` | | -| `--face_crop_aug_range` | | -| `--flip_aug` | | -| `--keep_tokens` | | -| `--min_bucket_reso` | | -| `--random_crop`| | -| `--resolution` | | -| `--shuffle_caption` | | -| `--train_batch_size` | `batch_size` | - -## 错误指南 - -目前,正在使用外部库来检查设置文件是否正确,但维护不够充分,错误消息难以理解。 -计划在将来改进此问题。 - -作为次佳方案,列出一些常见错误和解决方法。 -如果确信设置正确但仍出现错误,或者完全不明白错误内容,可能是bug,请联系我们。 - 
-* `voluptuous.error.MultipleInvalid: required key not provided @ ...`:: 缺少必填选项的错误。可能忘记指定或错误输入了选项名。 - * `...`部分显示了错误发生位置。例如,`voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']` 意味着在第0个 `datasets` 的第0个 `subsets` 的设置中缺失 `image_dir`。 -* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`:值的格式错误。输入值的格式可能错误。可以参考本 README 中选项的「设置示例」部分。 -* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`:存在不支持的选项名。可能错误输入或误输入了选项名。 - - - - \ No newline at end of file diff --git a/converted_markdown.md b/converted_markdown.md deleted file mode 100644 index 684af19ad..000000000 --- a/converted_markdown.md +++ /dev/null @@ -1,1726 +0,0 @@ -[ 読者になる ](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://blog.hatena.ne.jp/hoshikat/hoshikat.hatenablog.com/subscribe?utm_medium%3Dbutton%26utm_source%3Dblogs_topright_button%26utm_campaign%3Dsubscribe_blog) - -# [人工知能と親しくなるブログ](https://hoshikat-hatenablog- -com.translate.goog/?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp) - -## 人工知能に関するトピックを取り上げるブログです - -[ 2023-05-26 ](https://hoshikat-hatenablog- -com.translate.goog/archive/2023/05/26?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en- -US&_x_tr_pto=wapp) - -# [誰でもわかるStable Diffusion Kohya_ssを使ったLoRA学習設定を徹底解説](https://hoshikat- -hatenablog- -com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en- -US&_x_tr_pto=wapp) - -前回の記事では、Stable Diffusionモデルを追加学習するためのWebUI環境「kohya_ss」の導入法について解説しました。 - -今回は、LoRAのしくみを大まかに説明し、その後にkohya_ssを使ったLoRA学習設定について解説していきます。 - -※今回の記事は非常に長いです! - - - -**この記事では「各設定の意味」のみ解説しています。** - -「学習画像の用意のしかた」とか「画像にどうキャプションをつけるか」とか「どう学習を実行するか」は解説していません。学習の実行法についてはまた別の記事で解説したいと思います。 - - - - * [LoRAの仕組みを知ろう](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#LoRAの仕組みを知ろう) - * [「モデル」とは](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#モデルとは) - * [LoRAは小さいニューラルネットを追加する](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#LoRAは小さいニューラルネットを追加する) - * [小さいニューラルネットの構造](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#小さいニューラルネットの構造) - * [LoRA学習対象1:U-Net](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#LoRA学習対象1U-Net) - * [RoLA学習対象2:テキストエンコーダー](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#RoLA学習対象2テキストエンコーダー) - * [kohya_ssを立ち上げてみよう](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#kohya_ssを立ち上げてみよう) - * [LoRA学習の各設定](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#LoRA学習の各設定) - * [LoRA設定のセーブ、ロード](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#LoRA設定のセーブロード) - * [Source modelタブ: 学習に使うベースモデルの設定](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Source-modelタブ学習に使うベースモデルの設定) - * [Pretrained model name or 
path](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Pretrained-model-name-or-path) - * [Model Quick Pick](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Model-Quick-Pick) - * [Save trained model as](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Save-trained-model-as) - * [v2](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#v2) - * [v_parameterization](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#v_parameterization) - * [Foldersタブ: 学習画像の場所とLoRA出力先の設定](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Foldersタブ学習画像の場所とLoRA出力先の設定) - * [Image folder](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Image-folder) - * [Output folder](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Output-folder) - * [Regularisation folder](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Regularisation-folder) - * [Logging folder](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Logging-folder) - * [Model output name](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Model-output-name) - * [Training comment](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Training-comment) - * [Training parametersタブ: 学習の詳細設定](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Training-parametersタブ学習の詳細設定) - * [LoRA type](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#LoRA-type) - * [LoRA network weights](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#LoRA-network-weights) - * [DIM from weights](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#DIM-from-weights) - * [Train batch size](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Train-batch-size) - * [Epoch](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Epoch) - * [Save every N epochs](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Save-every-N-epochs) - * [Caption Extension](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Caption-Extension) - * [Mixed precision](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Mixed-precision) - * [Save 
precision](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Save-precision) - * [Number of CPU threads per core](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Number-of-CPU-threads-per-core) - * [Seed](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Seed) - * [Cache latents](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Cache-latents) - * [Cache latents to disk](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Cache-latents-to-disk) - * [Learning rate:](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Learning-rate) - * [LR Scheduler:](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#LR-Scheduler) - * [LR warmup](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#LR-warmup) - * [Optimizer](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Optimizer) - * [Optimizer extra arguments](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Optimizer-extra-arguments) - * [Text Encoder learning rate](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Text-Encoder-learning-rate) - * [Unet learning rate](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Unet-learning-rate) - * [Network Rank(Dimension)](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Network-RankDimension) - * [Network alpha:](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Network-alpha) - * [Max resolution](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Max-resolution) - * [Stop text encoder training](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Stop-text-encoder-training) - * [Enable buckets](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Enable-buckets) - * [Advanced Configuration](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Advanced-Configuration) - * [Weights、Blocks、Conv](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#WeightsBlocksConv) - * [Weights: Down LR weights/Mid LR weights/Up LR weights](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Weights-Down-LR-weightsMid-LR-weightsUp-LR-weights) - * [Weights: Blocks LR zero 
threshold](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Weights-Blocks-LR-zero-threshold) - * [Blocks: Block dims, Block alphas](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Blocks-Block-dims-Block-alphas) - * [Conv: Conv dims, Conv, alphas](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Conv-Conv-dims-Conv-alphas) - * [No token padding](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#No-token-padding) - * [Gradient accumulation steps](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Gradient-accumulation-steps) - * [Weighted captions](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Weighted-captions) - * [Prior loss weight](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Prior-loss-weight) - * [LR number of cycles](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#LR-number-of-cycles) - * [LR power](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#LR-power) - * [Additional parameters](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Additional-parameters) - * [Save every N steps](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Save-every-N-steps) - * [Save last N steps](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Save-last-N-steps) - * [Keep n tokens](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Keep-n-tokens) - * [Clip skip](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Clip-skip) - * [Max Token Length](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Max-Token-Length) - * [Full fp16 training (experimental)](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Full-fp16-training-experimental) - * [Gradient checkpointing](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Gradient-checkpointing) - * [Shuffle caption](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Shuffle-caption) - * [Persistent data loader](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Persistent-data-loader) - * [Memory efficient attention](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Memory-efficient-attention) - * [Use 
xformers](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Use-xformers) - * [Color augmentation](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Color-augmentation) - * [Flip augmentation](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Flip-augmentation) - * [Min SNR gamma](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Min-SNR-gamma) - * [Don't upscale bucket resolution](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Dont-upscale-bucket-resolution) - * [Bucket resolution steps](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Bucket-resolution-steps) - * [Random crop instead of center crop](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Random-crop-instead-of-center-crop) - * [Noise offset type](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Noise-offset-type) - * [Noise offset](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Noise-offset) - * [Adaptive noise scale](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Adaptive-noise-scale) - * [Multires noise iterations](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Multires-noise-iterations) - * [Multires noise discount](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Multires-noise-discount) - * [Dropout caption every n epochs](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Dropout-caption-every-n-epochs) - * [Rate of caption dropout](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Rate-of-caption-dropout) - * [VAE batch size](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#VAE-batch-size) - * [Save training state](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Save-training-state) - * [Resume from saved training state](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Resume-from-saved-training-state) - * [Max train epoch](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Max-train-epoch) - * [Max num workers for DataLoader](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Max-num-workers-for-DataLoader) - * [WANDB API Key](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#WANDB-API-Key) - 
* [WANDB Logging](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#WANDB-Logging) - * [Sample images config](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Sample-images-config) - * [Sample every n steps](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Sample-every-n-steps) - * [Sample every n epochs](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Sample-every-n-epochs) - * [Sample sampler](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Sample-sampler) - * [Sample prompts](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#Sample-prompts) - * [まとめ](https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#まとめ) - - - -### LoRAの仕組みを知ろう - -kohya_ssの各設定の意味を知るには、LoRAがどういうメ[カニ](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25AB%25A5%25CB)ズムで追加学習をするのか知っておく必要があります。 - -追加学習の対象である「モデル」とは何なのかも合わせて説明します。 - - - -#### 「モデル」とは - -Stable Diffusionは「 **モデル** 」と呼ばれるモジュールを読み込んで使います。モデルとはいわば「脳みそ」で、その正体は「 -**[ニューラルネット](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25CB%25A5%25E5%25A1%25BC%25A5%25E9%25A5%25EB%25A5%25CD%25A5%25C3%25A5%25C8)のウェイト情報**」です。 - -[ニューラルネット](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25CB%25A5%25E5%25A1%25BC%25A5%25E9%25A5%25EB%25A5%25CD%25A5%25C3%25A5%25C8)はたくさんの「 -**[ニューロン](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25CB%25A5%25E5%25A1%25BC%25A5%25ED%25A5%25F3)** -」からできていて、[ニューロン](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25CB%25A5%25E5%25A1%25BC%25A5%25ED%25A5%25F3)のまとまりが何層もの「 -**レイヤー** -」を形作っています。あるレイヤーの[ニューロン](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25CB%25A5%25E5%25A1%25BC%25A5%25ED%25A5%25F3)は違うレイヤーの[ニューロン](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25CB%25A5%25E5%25A1%25BC%25A5%25ED%25A5%25F3)と線でつながっていて、そのつながりの強さを表すのが「 -**ウェイト** 」です。膨大な絵の情報を保持しているのは、この「ウェイト」なのです。 - - - -#### -LoRAは小さい[ニューラルネット](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25CB%25A5%25E5%25A1%25BC%25A5%25E9%25A5%25EB%25A5%25CD%25A5%25C3%25A5%25C8)を追加する - -LoRAは「追加学習」の一種ですが、追加学習とは[ニューラルネット](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25CB%25A5%25E5%25A1%25BC%25A5%25E9%25A5%25EB%25A5%25CD%25A5%25C3%25A5%25C8)をバージョンアップすることです。 - -その方法はいろいろありますが、まず思いつくのは下の図のようにモデル全部を学習しなおす方法です。 - - - -![](https://cdn-ak.f.st- -hatena.com/images/fotolife/h/hoshikat/20230512/20230512000033.png) - -追加学習でモデルを鍛えなおす - -「DreamBooth」という追加学習法がこの方法を使っています。 - -この方法だと、もし追加学習データを公開したい場合、追加学習で新しくなったモデルを丸ごと配布する必要があります。 - -モデルのサイズは通常2G~5Gバイトあり、配布はなかなか大変です。 - 
-In contrast, LoRA training leaves the base model untouched and instead creates a new "small neural network" at each location it wants to train. The additional training is performed only on these small neural networks.
-
-When you want to distribute a LoRA, you only need to distribute these small neural networks, so the data size stays small.
-
-![](https://cdn-ak.f.st-hatena.com/images/fotolife/h/hoshikat/20230512/20230512005959.png)
-
-LoRA training targets the small neural networks
-
-#### Structure of the small neural network
-
-The small neural network that LoRA adds consists of three layers. The number of neurons in the "input layer" on the left and the "output layer" on the right is the same as the number of neurons in the input and output layers of the target network. The number of neurons in the middle layer is called the "rank" (or dimension), and you can choose this number freely when training.
-
-![](https://cdn-ak.f.st-hatena.com/images/fotolife/h/hoshikat/20230512/20230512011058.png)
-
-Structure of the small neural network
-
-So where are these small neural networks added?
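Before looking at where these networks are attached, here is a minimal sketch of the three-layer structure just described. It is illustrative only, not kohya_ss's or sd-scripts' actual implementation: the class and argument names are made up, and the `alpha / rank` scaling simply mirrors the Network alpha behaviour explained later in this article.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Wraps a frozen Linear layer and adds a small trainable bypass:
    in_features -> rank -> out_features (the middle layer is the "rank")."""

    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 1.0):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)          # the original weights stay fixed
        self.down = nn.Linear(base.in_features, rank, bias=False)
        self.up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.normal_(self.down.weight, std=1.0 / rank)
        nn.init.zeros_(self.up.weight)           # starts as a no-op
        self.scale = alpha / rank                # "Network alpha / Network rank"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # original output + scaled output of the small three-layer network
        return self.base(x) + self.up(self.down(x)) * self.scale

# Only self.down and self.up (the small network) would be trained and saved.
wrapped = LoRALinear(nn.Linear(768, 768), rank=8, alpha=1.0)
print(wrapped(torch.randn(1, 77, 768)).shape)    # torch.Size([1, 77, 768])
```

Only the two small matrices are stored in the LoRA file, which is why the result is so much smaller than a full model.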
-#### LoRA training target 1: U-Net
-
-The figure below shows the "U-Net", the mechanism at the heart of Stable Diffusion.
-
-The U-Net is divided into "Down" (the left half), "Mid" (the bottom), and "Up" (the right half).
-
-It is made up of 25 blocks in total: 12 Down blocks, 1 Mid block, and 12 Up blocks.
-
-The parts marked with red arrows (the orange blocks) are the LoRA training targets. In other words, small neural networks are added to these red-arrow blocks.
-
-![](https://cdn-ak.f.st-hatena.com/images/fotolife/h/hoshikat/20230512/20230512011350.png)
-
-The red-arrow blocks are the "Attention blocks" that LoRA trains
-
-The orange blocks handle "text processing", that is, the processing that reflects the text given as a prompt in the generated image.
-
-Looking at these blocks in more detail, they perform the following processing.
-
-![](https://cdn-ak.f.st-hatena.com/images/fotolife/h/hoshikat/20230512/20230512011718.png)
-
-A small neural network is added at each of the red arrows
-
-There are several red arrows here as well, and a separate small neural network is added for every one of these red-arrow operations.
-
-Kohya_ss simply calls the networks added here "UNet".
-
-#### LoRA training target 2: the text encoder
-
-The U-Net is not the only place where LoRA adds neural networks.
-
-The "Cross Attention" block in the figure above receives text information from a module called the "text encoder". The text encoder's role is to convert the prompt, which is text data, into a sequence of numbers (a vector).
-
-There is only one text encoder, and it is shared by all the Attention blocks inside the U-Net. Within Stable Diffusion the text encoder is normally treated as a finished component and is not a target of model training, but LoRA's additional training also trains it.
-
-Because a text encoder updated by LoRA is used by every Attention block, the neural networks added here have a very large influence on the final image.
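As a rough illustration of that "prompt to vector" step, the sketch below encodes a prompt with the CLIP text model typically paired with Stable Diffusion v1.x. The checkpoint name and the Hugging Face `transformers` classes are assumptions for the v1 case, not something this article or kohya_ss specifies; the point is simply that the output is a (batch, tokens, features) tensor that every cross-attention block consumes.

```python
from transformers import CLIPTokenizer, CLIPTextModel

# Assumed checkpoint: the text encoder commonly used with SD v1.x models.
name = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(name)
text_encoder = CLIPTextModel.from_pretrained(name)

tokens = tokenizer(
    "black cat",
    padding="max_length",
    max_length=tokenizer.model_max_length,  # 75 tokens plus start/end markers
    return_tensors="pt",
)
hidden = text_encoder(tokens.input_ids).last_hidden_state
print(hidden.shape)  # torch.Size([1, 77, 768]) for this encoder
```

A LoRA that touches the text encoder changes how this tensor is produced, which is why its effect shows up throughout the generated image.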
-ここに追加される[ニューラルネット](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25CB%25A5%25E5%25A1%25BC%25A5%25E9%25A5%25EB%25A5%25CD%25A5%25C3%25A5%25C8)の事をKohya_ssでは「Text -Encoder」と呼んでいます。 - - - -### kohya_ssを立ち上げてみよう - -LoRA学習のしくみを見たので、いよいよkohya_ssを使ってみましょう。 - -kohya_ssフォルダ内にある「[gui](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/gui).bat」をダブルクリックすると、[コマンドプロンプト](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25B3%25A5%25DE%25A5%25F3%25A5%25C9%25A5%25D7%25A5%25ED%25A5%25F3%25A5%25D7%25A5%25C8)(黒背景の文字だけのウィンドウ)が立ち上がります。しばらくするとそこにURLが表示されるので、それをウェブブラウザのURL欄に入力してリターンを押すとkohya_ssの画面がブラウザ上に表示されます。 - - - -kohya_ssを立ち上げると、UIの上部にタブがいくつか現れます。この中の「Dreambooth -LoRA」を選びましょう。これがLoRA学習のためのタブです。 - -![](https://cdn-ak.f.st- -hatena.com/images/fotolife/h/hoshikat/20230518/20230518004131.png) - -左から2番目のタブがLoRA学習設定 - - - -### LoRA学習の各設定 - -「Dreambooth LoRA」タブを選ぶと、たくさんの設定が出てきます。それらをここで解説します。 - - - -#### LoRA設定のセーブ、ロード - -一番上にあるのは「コンフィグファイル」です。ここでLoRA設定をコンフィグファイルとしてセーブ、ロードすることができます。 - -![](https://cdn-ak.f.st- -hatena.com/images/fotolife/h/hoshikat/20230518/20230518004415.png) - -設定はコンフィグファイルとしてセーブ、ロードできる - -設定をコンフィグファイルに保存しておけば後でそのコンフィグファイルをロードして設定を復元できるので、お気に入りの設定はなるべく保存しておきましょう。 - - - - - -次に、4つのタブがあります。最初の3つについてそれぞれ詳しく見ていきます。 - -(「Tools」タブはLoRA学習時には使わないので説明は省略します。) - -![](https://cdn-ak.f.st- -hatena.com/images/fotolife/h/hoshikat/20230518/20230518004914.png) - -モデル選択、学習画像フォルダ、詳細設定、ツールの各タブ - - - -#### Source modelタブ: 学習に使うベースモデルの設定 - - - -##### Pretrained model name or path - -ここにベースモデルの場所を指定します。最初に説明した通り、LoRAは既存のモデルに小さな[ニューラルネット](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25CB%25A5%25E5%25A1%25BC%25A5%25E9%25A5%25EB%25A5%25CD%25A5%25C3%25A5%25C8)を追加します。つまり、LoRA学習とは「ベースモデル+α」の「+α」の部分を作る作業です。 - -LoRA学習はベースモデルの特徴に大きな影響を受けるので、 - - * 学習する画像と相性のいいベースモデル - * 画像生成時に使う(と想定される)モデルと相性のいいベースモデル - -を選ぶ必要があります。例えば学習画像が実写のような画像なら、実写生成が得意なモデルを選ぶといいでしょう。学習画像が2次元調でも実写調の画像生成を想定しているなら、2次元調と3次元調の混合モデルを選ぶべきかもしれません。 - -なお、学習後にできたLoRAファイルは「追加された[ニューラルネット](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25CB%25A5%25E5%25A1%25BC%25A5%25E9%25A5%25EB%25A5%25CD%25A5%25C3%25A5%25C8)」のデータのみで、ベースモデルのデータは入っていません。そのため、完成したLoRAファイルを使って画像を生成するときは、ここで指定したベースモデルだけでなく、どんなモデルとも一緒に使うことができます。 - - - -##### Model Quick Pick - -ここでモデルを選ぶと、学習実行時にそのモデルが自動的にネット経由でダウンロードされ、ベースモデルとして使用されます。Pretrained model -name or pathの指定は無視されます。 - -もしモデルがPCに保存されていない場合や、どのモデルを使っていいのか分からない場合は、ここで選択できるモデルを選んで使いましょう。 - -「runwayml/stable-diffusion-v1-5」が使われることが多いようです。 - -自分で用意したモデルを使いたい場合はcustomにしましょう。 - - - -##### Save trained model as - -学習済みのLoRAファイルをどのファイル形式で保存するかを指定できます。 - -ckptはかつてStable -Diffusionで使われていた主流の形式でしたが、この形式にはセキュリティ上問題があったため、safetensorsというより安全なファイル形式が生まれました。現在ではsafetensorsが主流となっています。 - -特別な理由がない限りsafetensorsを選びましょう。 - - - -##### v2 - -Stable -Diffusionのモデルは「バージョン1系」と「バージョン2系」の2つのバージョンがあり、これらはデータ構造がそれぞれ違います。バージョン2系は(2023年5月時点で)まだ普及しておらず、ほとんどの有名モデルは「バージョン1系」です。 - -しかし、「バージョン2系」のモデルをベースモデルとして使う場合はこのオプションをオンにしましょう。 - -デフォルトはオフです。 - - - -##### v_parameterization - -v-parameterizationとは「バージョン2系」モデルで導入された手法で、従来よりも少ないサンプリングステップで安定して画像を生成するためのトリックです。 - -「バージョン1系」のモデルを使用するときはこのオプションはオフで構いませんが、お使いのベースモデルがv- -parameterizationを導入していることが分かっている場合はここをオンにしてください。これをオンにするときは必ずv2もオンにしましょう。 - -デフォルトはオフです。 - - - -#### Foldersタブ: 学習画像の場所とLoRA出力先の設定 - -##### 
Image folder - -学習画像を含むフォルダ(「10_cat」のような名前のフォルダ)がある場所を指定します。 - -「画像がある場所」ではありません!「画像を含むフォルダがある場所」を指定しましょう。 - - - -##### Output folder - -完成後のLoRAファイルの出力先を指定します。学習の途中経過のLoRAを出力する場合(後述)も、ここで指定した出力先に出力されます。 - - - -##### Regularisation folder - -LoRA学習では、学習画像の特徴が意図しない単語に強く結びつきすぎてしまい、その単語を入れるたびに学習画像に似た画像しか生成しなくなる、ということがしばしば起こります。 - -そこで「[正則化](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25C0%25B5%25C2%25A7%25B2%25BD)画像」という「学習画像っぽくない」画像を一緒に学習させることで、特定の単語に学習対象が強く結びついてしまうのを防ぐことができます。 - -[正則化](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25C0%25B5%25C2%25A7%25B2%25BD)画像の使用は必須ではありませんが、もし[正則化](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25C0%25B5%25C2%25A7%25B2%25BD)画像を学習に使う場合は、ここで[正則化](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25C0%25B5%25C2%25A7%25B2%25BD)画像を含んだフォルダの場所を指定します。 - - - -##### Logging folder - -学習時の処理をログとして出力して保存したい場合、その場所を指定します。 - -ここで指定した名前のフォルダが作業フォルダ内に作成され、さらにその中に学習日時を表す名前のフォルダができます。ログはそこに保存されます。 - -なお、ログファイルは「tensorboard」というツールまたは「WandB」というオンラインサービス(後述)でグラフ化できます。 - - - -##### Model output name - -完成したLoRAファイルの名前をここで指定します。拡張子をつける必要はありません。 - -「〇〇_ver1.0」(〇〇は学習対象の名前)のようにバージョン番号付きの名前にすると分かりやすいでしょう。 - -なお、名前には日本語を使わないようにしましょう。 - - - -##### Training comment - -完成したLoRAファイルには「[メタデータ](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25E1%25A5%25BF%25A5%25C7%25A1%25BC%25A5%25BF)」としてテキストを埋め込むことができます。もし埋め込みたいテキストがある場合はここに記述します。 - -なお、[メタデータ](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25E1%25A5%25BF%25A5%25C7%25A1%25BC%25A5%25BF)はStable -Diffusion WebUIのLoRA選択画面でⓘマークをクリックすると見ることができます。 - - - -#### Training parametersタブ: 学習の詳細設定 - -このタブではLoRA学習のためのほぼすべてのパラメータを設定します。 - - - -##### LoRA type - -![](https://cdn-ak.f.st- -hatena.com/images/fotolife/h/hoshikat/20230518/20230518010422.png) - -LoRA学習のタイプを指定します。上で解説したLoRAは「スタンダード」タイプです。「DyLoRA」は指定したRank以下の複数のランクを同時に学習するので、最適なランクを選びたいときに便利です。LoHaは高効率なLoRA、LoConは学習をU- -NetのResブロックまで広げたものです。 - -最初はStandardタイプで問題ありません。学習がうまくいかないときはほかのタイプを選んでみましょう。 - - - -##### LoRA network weights - -既に学習済みのLoRAファイルを使ってさらに追加学習をしたいときは、ここでLoRAファイルを指定します。 - -ここで指定したLoRAは学習開始時に読み込まれ、このLoRAの状態から学習がスタートします。学習後のLoRAはまた別のファイルとして保存されるので、ここで指定したLoRAファイルが上書きされることはありません。 - - - -##### DIM from weights - -これはLoRA network weightsで追加学習を行うとき限定のオプションです。 - -上の図にある通り、LoRAは小さな[ニューラルネット](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25CB%25A5%25E5%25A1%25BC%25A5%25E9%25A5%25EB%25A5%25CD%25A5%25C3%25A5%25C8)を追加しますが、その中間層の[ニューロン](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25CB%25A5%25E5%25A1%25BC%25A5%25ED%25A5%25F3)数(ランク数)はNetwork -Rank(後述)で自由に設定することができます。 - -しかし、このオプションをオンにすると、作成するLoRAのランク数がLoRA network -weightsで指定したLoRAと同じランク数に設定されます。ここをオンにしたときはNetwork Rankの指定は無視されます。 - -例えば追加学習に使うLoRAのランク数が32の時、作成するLoRAのランク数も32に設定されます。 - -デフォルトはオフです。 - - - -##### Train batch size - -バッチサイズを指定します。バッチとは「いっぺんに読み込む画像の枚数」です。バッチサイズ2なら、一度に2枚の画像を同時に学習します。違う絵を複数同時に学習すると個々の絵に対するチューニング精度は落ちますが、複数の絵の特徴を包括的にとらえる学習になるので、最終的な仕上がりはかえって良くなる可能性があります。 - -(特定の絵にチューニングしすぎると応用の利かないLoRAになってしまいます。) - -複数の絵を一度に学習するのでバッチサイズを上げれば上げるほど学習時間が短くなりますが、チューニング精度が下がるうえウェイト変更数も減るので、場合によっては学習不足になる可能性があります。 - 
-(バッチサイズを上げるときは学習率(Learning -rate、後述します)も上げた方がいいという報告もあります。例えばバッチサイズ2なら学習率を2倍にする、といった感じです。) - -また、バッチサイズを上げるほどメモリを多く消費します。お使いのPCのVRAMのサイズに合わせて決めましょう。 - -VRAMが6GBあればバッチサイズ2もかろうじて可能でしょう。 - -デフォルトは1です。 - -※バッチごとに同時に読み込む画像はすべて同じサイズでなければならないので、学習画像のサイズがバラバラだと、ここで指定したバッチ数よりも少ない枚数しか同時処理しないことあります。 - - - -##### Epoch - -1エポックは「1セットの学習」です。 - -例えば50枚の画像をそれぞれ10回ずつ読み込んで学習したいとします。この場合、1エポックは50x10=500回の学習です。2エポックならこれを2回繰り返すので、500x2=1000回の学習になります。 - -指定されたエポック数の学習が終わった後に、LoRAファイルが作成され、指定の場所に保存されます。 - -LoRAの場合、2~3エポックの学習でも十分効果を得られます。 - - - -##### Save every N epochs - -ここで指定したエポック数ごとに、途中経過をLoRAファイルとして保存することができます。 - -例えば「Epoch」で10と指定し、「Save every N -epochs」を2に指定すると、2エポックごと(2、4、6、8エポック終了時)に指定フォルダにLoRAファイルが保存されます。 - -途中経過のLoRA作成が不要の場合は、ここの数値を「Epoch」と同じ数値にしましょう。 - - - -##### Caption Extension - -もし画像ごとにキャプションファイルを用意している場合、そのキャプションファイルの拡張子をここで指定します。 - -ここが空欄の場合、拡張子は「.caption」になります。もしキャプションファイルの拡張子が「.txt」の時は、ここに「.txt」と指定しておきましょう。 - -キャプションファイルがない場合は、無視してかまいません。 - - - -##### Mixed precision - -![](https://cdn-ak.f.st- -hatena.com/images/fotolife/h/hoshikat/20230518/20230518015736.png) - -学習時のウェイトデータの混合精度のタイプを指定します。 - -本来ウェイトデータは32ビット単位(no選択の場合)ですが、必要に応じて16ビット単位のデータも混ぜて学習するとかなりのメモリ節約、スピードアップにつながります。fp16は精度を半分にした[データ形式](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25C7%25A1%25BC%25A5%25BF%25B7%25C1%25BC%25B0)、bf16は32ビットデータと同じ数値の幅を取り扱えるよう工夫した[データ形式](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25C7%25A1%25BC%25A5%25BF%25B7%25C1%25BC%25B0)です。 - -fp16で十分精度の高いLoRAを得られます。 - - - -##### Save precision - -![](https://cdn-ak.f.st- -hatena.com/images/fotolife/h/hoshikat/20230518/20230518021841.png) - -LoRAファイルに保存するウェイトデータのタイプを指定します。 - -floatは32ビット、fp16とbf16は16ビット単位です。下の二つの方がファイルサイズが小さくなります。 - -デフォルトはfp16です。 - - - -##### Number of CPU threads per core - -学習時のCPUコアごとのスレッドの数です。基本的に数値が大きいほど効率が上がりますが、スペックに応じて設定を調節する必要があります。 - -デフォルトは2です。 - - - -##### Seed - -学習時には「どういう順番で画像を読み込むか」や「学習画像にノイズをどれくらい乗せるか(詳細は省略)」など、ランダムな処理がいくつもあります。 - -Seedはそのランダム処理の手順を決めるためのIDのようなもので、同じSeedを指定すれば毎回同じランダム手順が使われるので学習結果を再現しやすくなります。 - -ただ、このSeedを使わないランダム処理(例えば画像をランダムに切り抜く処理など)もあるので、同じSeedを指定しても必ず同じ学習結果が得られるとは限りません。 - -デフォルトは空欄です。指定しなければ学習実行時にSeedが適当に設定されます。 - -結果をなるべく再現したいなら適当に(1234とか)数字を設定しておいて損はありません。 - - - -##### Cache latents - -学習画像はVRAMに読み込まれ、U-Netに入る前にLatentという状態に「圧縮」されて小さくなり、この状態でVRAM内で学習されます。通常、画像は読み込まれるたびに毎回「圧縮」されますが、Cache -latentsにチェックを入れると、「圧縮」済みの画像をメインメモリに保持するよう指定できます。 - -メインメモリに保持するとVRAMのスペース節約になり、スピードも上がりますが、「圧縮」前の画像加工ができなくなるので、flip_aug以外のaugmentation(後述)が使えなくなります。また、画像を毎回ランダムな範囲で切り抜くrandom -crop(後述)も使えなくなります。 - -デフォルトはオンです。 - - - -##### Cache latents to disk - -Cache latentsオプションと似ていますが、ここにチェックを入れると、圧縮画像データを一時ファイルとしてディスクに保存するよう指定できます。 - -kohya_ssを再起動した後もこの一時ファイルを再利用できるので、同じデータで何度もLoRA学習をしたい場合はこのオプションをオンにすると学習効率が上がります。 - -ただし、これをオンにするとflip_aug以外のaugmentationとrandom cropが使えなくなります。 - -デフォルトはオフです。 - - - -##### Learning rate: - -学習率を指定します。「学習」とは、与えられた絵とそっくりな絵を作れるように[ニューラルネット](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25CB%25A5%25E5%25A1%25BC%25A5%25E9%25A5%25EB%25A5%25CD%25A5%25C3%25A5%25C8)内の配線の太さ(ウェイト)を変えていくことですが、毎回絵が与えられるごとにゴッソリ配線を変えてしまうと、与えられた絵のみにチューニングしすぎて、他の絵がまったく描けなくなってしまいます。 - -これを避けるため、毎回、与えられた絵をちょっとだけ取り込むように、ちょっとだけウェイトを変えます。この「ちょっとだけ」の量を決めるのが「学習率」(Learning -rate)です。 - -デフォルト値は0.0001です。 - - - -##### LR Scheduler: - -![](https://cdn-ak.f.st- -hatena.com/images/fotolife/h/hoshikat/20230519/20230519204207.png) - 
-学習の途中で学習率(Learning rate)を変えることができます。スケジューラーとは「どういうふうに学習率を変えていくかの設定」です。 - - * adafactor:[オプティマ](https://translate.google.com/website?sl=auto&tl=en&hl=en-US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25AA%25A5%25D7%25A5%25C6%25A5%25A3%25A5%25DE)イザー(後述)をAdafactorに設定する場合はこれを選択する。VRAM節約のため状況に応じて学習率を自動調節しながら学習 - * constant:学習率は最初から最後まで変わらない - * constant_with_warmup:最初は学習率0から始めてウォームアップ中にLearning rate設定値に向けてだんだん増やし、本学習の時はLearning rate設定値を使う - * [cosine](https://translate.google.com/website?sl=auto&tl=en&hl=en-US&client=webapp&u=https://d.hatena.ne.jp/keyword/cosine):波(コ[サインカーブ](https://translate.google.com/website?sl=auto&tl=en&hl=en-US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25B5%25A5%25A4%25A5%25F3%25A5%25AB%25A1%25BC%25A5%25D6))を描きながら学習率をだんだん0に向けて減らす - * [cosine](https://translate.google.com/website?sl=auto&tl=en&hl=en-US&client=webapp&u=https://d.hatena.ne.jp/keyword/cosine)_with_restarts:[cosine](https://translate.google.com/website?sl=auto&tl=en&hl=en-US&client=webapp&u=https://d.hatena.ne.jp/keyword/cosine)を何度も繰り返す(LR number of cyclesの説明も見てください) - * linear:最初はLearning rate設定値で始め、0に向けて一直線に減らす - * polynomial:挙動はlinearと同じ、減らし方が少し複雑(LR powerの説明も見てください) - -学習率をLearning rate設定値に固定したいならconstantにしてください。 - -デフォルトは[cosine](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/cosine)です。 - - - -##### LR warmup - -スケジューラーでconstant_with_warmupを選択した場合、ウォームアップをどれくらいの回数行うかをここで設定します。 - -ここで指定する数値は全体のステップ数のパーセントです。 - -例えば、50枚の画像をバッチサイズ1で10回学習、これを2エポック行うとき、総ステップ数は50x10x2=1000です。もしLR -warmupを10に設定すると、総ステップ1000のうち最初の10%、つまり100ステップがウォームアップになります。 - -スケジューラーがconstant_with_warmupでないならここは無視して構いません。 - -デフォルトは10です。 - - - -##### Optimizer - -![](https://cdn-ak.f.st- -hatena.com/images/fotolife/h/hoshikat/20230526/20230526012657.png) - -[オプティマ](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25AA%25A5%25D7%25A5%25C6%25A5%25A3%25A5%25DE)イザーとは「学習中に[ニューラルネット](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25CB%25A5%25E5%25A1%25BC%25A5%25E9%25A5%25EB%25A5%25CD%25A5%25C3%25A5%25C8)のウェイトをどうアップデートするか」の設定です。賢く学習するためにいろいろな手法が提案されていますが、LoRA学習で最もよく使われるのは「AdamW」(32ビット)または「AdamW8bit」です。AdamW8bitはVRAMの使用量も低く、精度も十分なので迷ったらこれを使いましょう。 - -その他、Adam手法を取り入れつつ学習の進み具合に応じて学習率を適切に調節する「Adafactor」もよく使われるようです(Adafactorを使う場合はLearning -rate設定は無視されます)。 - -「DAdapt」は学習率を調節する[オプティマ](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25AA%25A5%25D7%25A5%25C6%25A5%25A3%25A5%25DE)イザー、「Lion」は比較的新しい[オプティマ](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25AA%25A5%25D7%25A5%25C6%25A5%25A3%25A5%25DE)イザーですがまだ十分検証されていません。「SGDNesterov」は学習精度は良いものの速度が下がるという報告があります。 - -デフォルトはAdamW8bitです。基本的にこのままで問題ありません。 - - - -##### Optimizer extra arguments - -指定した[オプティマ](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25AA%25A5%25D7%25A5%25C6%25A5%25A3%25A5%25DE)イザーに対してさらに細かく設定したい場合は、ここでコマンドを書きます。 - -通常は空欄のままで構いません。 - - - -##### Text Encoder learning rate - -テキスト[エンコーダー](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25A8%25A5%25F3%25A5%25B3%25A1%25BC%25A5%25C0%25A1%25BC)に対する学習率を設定します。最初のほうで書いた通り、テキスト[エンコーダー](https://translate.google.com/website?sl=auto&tl=en&hl=en- 
-US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25A8%25A5%25F3%25A5%25B3%25A1%25BC%25A5%25C0%25A1%25BC)の追加学習の影響はU- -Net全体に及びます。 - -そのため、通常はU-Netの各ブロックに対する学習率(Unet learning rate)よりも低くしておきます。 - -デフォルト値は0.00005(5e-5)です。 - -ここで数値を指定した場合、Learning rateの値よりもこちらが優先されます。 - - - -##### Unet learning rate - -U-Netに対する学習率を設定します。U-Netの中にある各Attentionブロック(設定によっては他のブロックも)に追加学習を行うときの学習率です。 - -デフォルト値は0.0001です。 - -ここで数値を指定した場合、Learning rateの値よりもこちらが優先されます。 - - - -##### Network Rank(Dimension) - -記事の上の方で説明した「追加する小さな[ニューラルネット](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25CB%25A5%25E5%25A1%25BC%25A5%25E9%25A5%25EB%25A5%25CD%25A5%25C3%25A5%25C8)」の中間層の[ニューロン](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25CB%25A5%25E5%25A1%25BC%25A5%25ED%25A5%25F3)の数を指定します(詳細は上の図を見てください)。 - -[ニューロン](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25CB%25A5%25E5%25A1%25BC%25A5%25ED%25A5%25F3)の数が多いほど学習情報を多く保持できますが、学習対象以外の余計な情報まで学習してしまう可能性が高くなり、LoRAのファイルサイズも大きくなります。 - -一般的に最大128程度で設定することが多いですが、32で十分という報告もあります。 - -試験的にLoRAを作る場合は2~8あたりから始めるのがいいかもしれません。 - -デフォルトは8です。 - - - -##### Network alpha: - -これは、LoRA保存時にウェイトが0に丸め込まれてしまうのを防ぐための便宜上の処置として導入されたものです。 - -LoRAはその構造上、[ニューラルネット](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25CB%25A5%25E5%25A1%25BC%25A5%25E9%25A5%25EB%25A5%25CD%25A5%25C3%25A5%25C8)のウェイト値が小さくなりがちで、小さくなりすぎるとゼロ(つまりなにも学習していないのと同じ)と見分けがつかなくなってしまう恐れがあります。そこで、実際の(保存される)ウェイト値は大きく保ちつつ、学習時には常にウェイトを一定の割合で弱めてウェイト値を小さく見せかける、というテクニックが提案されました。この「ウェイトを弱める割合」を決めるのがNetwork -alphaです。 - -**Network -alpha値が小さければ小さいほど、保存されるLoRAの[ニューラルネット](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25CB%25A5%25E5%25A1%25BC%25A5%25E9%25A5%25EB%25A5%25CD%25A5%25C3%25A5%25C8)のウェイト値が大きくなります。** - - - -使用時にウェイトがどれだけ弱まるか(使用強度)は「Network_Alpha/Network_Rank」で計算され(ほぼ0~1の値)、Network -Rank数と深く関係しています。 - -学習後のLoRAの精度がいまいちな場合、ウェイトデータが小さすぎて0に潰れてしまっている可能性があります。そんな時はNetwork -Alpha値を下げてみる(=保存ウェイト値を大きくする)とよいでしょう。 - -デフォルトは1(つまり保存ウェイト値をなるべく最大にする)です。 - -Network AlphaとNetwork Rankが同じ値の場合、効果はオフになります。 - -※Network Alpha値がNetwork Rank値を超えてはいけません。超える数字を指定することは可能ですが、高確率で意図しないLoRAになります。 - -また、Network Alphaを設定するときは、学習率への影響を考える必要があります。 - -例えばAlphaが16、Rankが32の場合、ウェイトの使用強度は16/32 = 0.5になり、つまり学習率が「Learning -Rate」設定値のさらに半分の効力しか持たないことになります。 - -AlphaとRankが同じ数字であれば使用強度は1になり、学習率に何の影響も与えません。 - - - -##### Max resolution - -学習画像の最大解像度を「幅、高さ」の順で指定します。もし学習画像がここで指定した解像度を超える場合、この解像度まで縮小されます。 - -デフォルトは「512,512」です。多くのモデルがこのサイズの画像を使っているので、LoRA学習の時もこのサイズの画像を使うのが無難です。 - - - -##### Stop text encoder training - -テキスト[エンコーダー](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25A8%25A5%25F3%25A5%25B3%25A1%25BC%25A5%25C0%25A1%25BC)の学習は途中でストップすることができます。上で書いた通り、テキスト[エンコーダー](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25A8%25A5%25F3%25A5%25B3%25A1%25BC%25A5%25C0%25A1%25BC)のアップデートは全体に大きな影響を及ぼすので[過学習](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25B2%25E1%25B3%25D8%25BD%25AC)(学習画像にチューニングしすぎて他の画像が描けなくなる)に陥りやすく、ほどほどのところで学習をストップさせるのも[過学習](https://translate.google.com/website?sl=auto&tl=en&hl=en- 
-US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25B2%25E1%25B3%25D8%25BD%25AC)を防ぐ一つの手です。 - -ここで指定した数値は全学習ステップのパーセントです。学習がこのパーセントに達したら、テキスト[エンコーダー](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25A8%25A5%25F3%25A5%25B3%25A1%25BC%25A5%25C0%25A1%25BC)は学習をストップします。 - -例えば、総ステップ数が1000だった場合、ここで80と指定したら、学習進行度が80%の時、つまり1000x0.8=800ステップの時点でテキスト[エンコーダー](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25A8%25A5%25F3%25A5%25B3%25A1%25BC%25A5%25C0%25A1%25BC)の学習が終了します。 - -U-Netの学習は残り200ステップで引き続き行われます。 - -ここが0の場合、テキスト[エンコーダー](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25A8%25A5%25F3%25A5%25B3%25A1%25BC%25A5%25C0%25A1%25BC)の学習は最後までストップしません。 - - - -##### Enable buckets - -「[bucket](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/bucket)」とはその名の通り「バケツ」(入れ物)です。LoRAで使う学習画像はサイズが統一されていなくてもかまわないのですが、違うサイズの画像を同時に学習することはできません。そのため、学習前に画像をサイズに応じて「バケツ」に振り分ける必要があります。似たサイズの画像は同じバケツに入れ、違うサイズの画像は別のバケツに入れていきます。 - -デフォルトはオンです。 - -もし学習画像のサイズがすべて同じならこのオプションはオフにして構いませんが、オンのままでも影響はありません。 - -※もし学習画像のサイズが統一されていない時にEnable bucketsをオフにすると、学習画像は拡大、縮小されてサイズが同じ大きさに揃えられます。 - -拡大、縮小は画像の[アスペクト比](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25A2%25A5%25B9%25A5%25DA%25A5%25AF%25A5%25C8%25C8%25E6)を保ったまま行われます。[アスペクト比](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25A2%25A5%25B9%25A5%25DA%25A5%25AF%25A5%25C8%25C8%25E6)が基準サイズと同じでない場合、拡大縮小後の画像のタテかヨコが基準サイズからはみ出すことがあります。例えば、基準サイズが512x512([アスペクト比](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25A2%25A5%25B9%25A5%25DA%25A5%25AF%25A5%25C8%25C8%25E6)1)で、画像サイズが1536x1024([アスペクト比](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25A2%25A5%25B9%25A5%25DA%25A5%25AF%25A5%25C8%25C8%25E6)1.5)の場合、画像は縮小されて768x512([アスペクト比](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25A2%25A5%25B9%25A5%25DA%25A5%25AF%25A5%25C8%25C8%25E6)1.5のまま)になります。 - - - -#### Advanced Configuration - -ここより後は、「Advanced Configuration」セクションにあるオプションです。 - - - -##### Weights、Blocks、Conv - -![](https://cdn-ak.f.st- -hatena.com/images/fotolife/h/hoshikat/20230519/20230519233453.png) - -これらはU-Net内の各ブロックの「学習の重み付け」と「ランク」の設定です。それぞれのタブを選択すると、対応する設定画面が表示されます。 - -※これらの設定は上級者向けです。こだわりがないならすべて空欄のままで構いません。 - - - -##### Weights: Down LR weights/Mid LR weights/Up LR weights - -U-Netの構造図からわかる通り、U-Netは12個のINブロック、1個のMIDブロック、12個のOUTブロックの計25個のブロックからできています。 - -それぞれのブロックの学習率のウェイト(重み)を変えたい場合、ここで個別に設定することができます。 - -ここでいうウェイトとは0~1の数値で表される「学習の強さ」で、0の場合は「まったく学習しない」、1の場合は「Learning -rateで設定した学習率で学習」という感じで学習の強さを変えることができます。 - -ウェイトを0.5にした場合、Learning rateの半分の学習率になります。 - -「Down LR weights」は12個のINブロックのそれぞれのウェイトを指定します。 - -「Mid LR weights」はMIDブロックのウェイトを指定します。 - -「Up LR weights」は12個のOUTブロックのそれぞれのウェイトを指定します。 - - - -##### Weights: Blocks LR zero threshold - -「LoRAは[ニューラルネット](https://translate.google.com/website?sl=auto&tl=en&hl=en- -US&client=webapp&u=https://d.hatena.ne.jp/keyword/%25A5%25CB%25A5%25E5%25A1%25BC%25A5%25E9%25A5%25EB%25A5%25CD%25A5%25C3%25A5%25C8)を追加する」と説明しましたが、ウェイトが小さすぎる(つまりほとんど学習していない)[ニューラルネット](https://translate.google.com/website?sl=auto&tl=en&hl=en- 
##### Weights: Blocks LR zero threshold

As explained earlier, "LoRA adds small neural networks", but there is no point adding a network whose weight is too small (that is, one that barely trains at all). So you can specify "do not add neural networks to blocks whose weight is too small".

No neural network is added to blocks whose weight does not exceed the value set here. For example, if you specify 0.1 here, no neural network is added to blocks whose weight is set to 0.1 or less (note that the cutoff includes the specified value itself!).

The default is blank; when blank it is 0 (do nothing).

##### Blocks: Block dims, Block alphas

Here you can set a different rank (dim) value and alpha value for each of the 25 blocks IN0–11, MID, and OUT0–11.

For rank and alpha values, see the explanations of Network Rank and Network alpha.

Blocks with a larger rank are expected to be able to hold more information.

You must always specify 25 numbers for this parameter, but since LoRA targets the Attention blocks, the settings for the blocks that have no Attention block, namely IN0, IN3, IN6, IN9, IN10, IN11, OUT0, OUT1, and OUT2 (the 1st, 4th, 7th, 10th, 11th, 12th, 14th, 15th, and 16th numbers), are ignored during training.

Note: this is an advanced setting. If you have no particular preference, leave it blank. If it is not specified, the "Network Rank (Dimension)" and "Network Alpha" values are applied to all blocks.

##### Conv: Conv dims, Conv alphas

The Attention blocks that LoRA trains contain a neural network called "Conv", which is also updated by the additional training (see the diagram of the Attention layer structure near the top of this article). This is the operation known as "convolution", and the "filter" used there is 1x1 in size.

For more about convolution, see [this article](https://github.com/kohya-ss/sd-scripts/pull/121).

On the other hand, some of the non-Attention blocks (Res and Down blocks) and parts of the Attention blocks in OUT perform convolution with a 3x3 filter. Those are not originally LoRA training targets, but by specifying this parameter, the 3x3 convolutions in the Res blocks can also be made training targets.

Since there are more training targets, more precise LoRA training may be possible.

The setting format is the same as "Blocks: Block dims, Block alphas".

A 3x3 Conv exists in all 25 layers.

Note: this is an advanced setting. If you have no particular preference, leave it blank. A sketch of how these 25-value lists line up with the blocks follows below.
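A short sketch of how the 25-value lists line up with the block order described above. This is only an illustration of the ordering; the values are placeholders, not recommendations, and the variable names are not taken from kohya_ss.

```python
# The 25 values are given in the order IN0..IN11, MID, OUT0..OUT11.
block_names = [f"IN{i}" for i in range(12)] + ["MID"] + [f"OUT{i}" for i in range(12)]

block_dims   = [4] * 12 + [16] + [4] * 12   # example: give the MID block a larger rank
block_alphas = [1] * 25                     # alpha per block
conv_dims    = [4] * 25                     # optional 3x3 conv rank per block
conv_alphas  = [1] * 25

for name, dim, alpha in zip(block_names, block_dims, block_alphas):
    print(f"{name}: dim={dim}, alpha={alpha}")
```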
##### No token padding

The captions attached to training images are processed in units of 75 tokens (a "token" can basically be thought of as a word).

If a caption is shorter than 75 tokens, as many end-of-text tokens as needed are appended after it to bring it to 75 tokens. This is called "padding".

Here you can specify that token padding should not be performed.

The default is off. You can normally leave it off.

##### Gradient accumulation steps

Changing the weights (that is, "learning") is normally done once per batch, but it is also possible to accumulate several batches and learn from them all at once. This option specifies how many batches to accumulate before learning.

It has an effect similar to (but not the same as!) increasing the batch size.

For example, with a batch size of 4, four images are loaded per batch, so learning happens once every four images. If Gradient accumulation steps is set to 2, learning happens once every two batches, so in effect once every eight images. This works similarly to (but not the same as!) a batch size of 8.

Raising this value reduces the number of learning updates, so processing becomes faster, but it consumes more memory.

The default is 1.

##### Weighted captions

The most popular Stable Diffusion environment at the moment is the "Stable Diffusion WebUI", which has its own prompt notation. For example, if you want to strongly emphasize "black" in the prompt "black cat", you wrap the word in parentheses and add ":number" after it, like "(black:1.2) cat"; the word is then emphasized by that multiplier.

This option makes that notation usable in the training image captions as well.

If you want to write elaborate captions, it may be worth trying.

The default is off.

##### Prior loss weight

Prior loss weight determines how much importance is given to "regularization images" (see the description of the Regularisation folder above) during training.

If this value is low, the regularization images are treated as less important, and the resulting LoRA expresses the characteristics of the training images more strongly.

This setting has no meaning if you do not use regularization images.

It is a value between 0 and 1, and the default is 1 (regularization images are also given full weight).

##### LR number of cycles

When "Cosine with restart" or "Polynomial" is selected as the scheduler, this option specifies how many cycles the scheduler runs during training.

If this value is 2 or more, the scheduler runs multiple times within a single training run.

With both Cosine with restart and Polynomial, the learning rate gradually decreases toward 0 as training progresses; with 2 or more cycles, the learning rate is reset and restarted each time it reaches 0.

The figure below ([source](https://github.com/kohya-ss/sd-scripts/pull/121)) shows examples of how the learning rate changes with Cosine with restart (purple) and Polynomial (light green).

In the purple example the number of cycles is set to 4. In the light green example the number of cycles is 1.

The specified number of cycles is executed within the fixed number of training steps, so the more cycles there are, the more sharply the learning rate changes.

The default is blank; when blank it is 1.

![](https://cdn-ak.f.st-hatena.com/images/fotolife/h/hoshikat/20230525/20230525001355.png)

Example of learning-rate behavior:
Cosine with restart with "LR number of cycle = 4" (purple)
Polynomial with "LR power = 2" (light green)

##### LR power

This option applies when the scheduler is set to Polynomial: the larger this number, the steeper the initial drop in the learning rate (the slope of the light green line in the figure above becomes steeper).

When power is 1, the curve has the same shape as the linear scheduler.

If the number is too large, the learning rate sticks close to 0 and training becomes insufficient, so be careful.

The default is blank; when blank it is 1 (that is, the same as the linear scheduler). A small sketch of both schedulers follows below.
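The shapes of these two schedulers can be sketched with a couple of small functions. This is a simplified illustration that ignores warmup and any minimum learning rate; it is not the exact scheduler implementation used during training.

```python
import math

def polynomial_lr(step, max_steps, base_lr, power=1.0):
    # power = 1.0 reproduces the linear scheduler; larger powers drop faster at first
    return base_lr * (1.0 - step / max_steps) ** power

def cosine_with_restarts_lr(step, max_steps, base_lr, num_cycles=1):
    # the learning rate falls to 0 `num_cycles` times over the run, restarting each time
    progress = (step / max_steps) * num_cycles % 1.0
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

# e.g. halfway through training with 4 cycles, the cosine schedule has just restarted
print(cosine_with_restarts_lr(step=500, max_steps=1000, base_lr=1e-4, num_cycles=4))
```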
##### Additional parameters

If you want to tweak training parameters that are not shown in the kohya_ss GUI, enter them here as command-line options.

Normally you can leave this blank.

##### Save every N steps

Each time the number of steps specified here finishes, a LoRA file is created and saved.

For example, if the total number of training steps is 1000 and you specify 200 here, LoRA files are saved at the end of steps 200, 400, 600, and 800.

See also "Save every N epochs" for saving intermediate LoRAs.

The default is 0 (no intermediate LoRAs are saved).

##### Save last N steps

This option applies when Save every N steps is used to save LoRAs during training.

If you want to keep only recent LoRA files and discard older ones, you can set here how many recent steps' worth of LoRA files to keep.

For example, suppose the total number of training steps is 600 and Save every N steps is set to save every 100 steps. LoRA files are then saved at steps 100, 200, 300, 400, and 500. If Save last N steps is set to 300, only the LoRA files from the most recent 300 steps are kept: at step 500, any LoRA older than step 200 (= 500 - 300), that is, the LoRA from step 100, is deleted.

The default is 0.

##### Keep n tokens

If your training images have captions, the comma-separated words in those captions can be shuffled randomly (see the Shuffle caption option for details). However, if there are words you always want to keep at the beginning, this option lets you say "keep the first N words fixed at the front".

The number of leading words specified here is always kept at the front.

The default is 0. If the Shuffle caption option is off, this option does nothing.

Note: a "word" here means a piece of comma-separated text. No matter how many words the separated text contains, it counts as "one word".

For "black cat, eating, sitting", "black cat" is one word.

##### Clip skip

The text encoder uses a mechanism called "CLIP", which is made up of 12 similar layers.

Text (tokens) is normally converted into numeric sequences (vectors) by passing through these 12 layers, and the vectors coming out of the last layer are sent to the U-Net's Attention blocks.

However, the model independently developed by the "Novel AI" service, commonly known as the "Novel AI model", adopted its own specification of using the vectors output by the second-to-last layer instead of the last layer. Models derived from the Novel AI model do the same. Therefore you need to specify which CLIP layer the base model used for training takes its vectors from.

"Clip skip" specifies this layer number, counted from the end.

If you set it to 2, the output vectors of the second-to-last layer are sent to the Attention blocks. If it is 1, the output vectors of the last layer are used.

If the base model is a Novel AI model (or a mix of one), 2 is probably better. Otherwise 1 is fine. A rough sketch of what this means follows below.
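In terms of the Hugging Face `transformers` CLIP text model, "Clip skip" roughly corresponds to picking an earlier entry from the returned hidden states. This is a sketch of the idea only; the actual training code may differ in details such as how the final layer norm is handled.

```python
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

tokens = tokenizer("black cat", return_tensors="pt")
out = text_encoder(**tokens, output_hidden_states=True)

clip_skip = 2                            # 1 = last layer, 2 = second-to-last layer, ...
hidden = out.hidden_states[-clip_skip]
# The selected hidden state is usually still passed through the final layer norm
hidden = text_encoder.text_model.final_layer_norm(hidden)
```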
##### Max Token Length

![](https://cdn-ak.f.st-hatena.com/images/fotolife/h/hoshikat/20230520/20230520021639.png)

Specifies the maximum token length of the captions.

A "token" here is not a word count: the number of tokens is roughly the same as, up to about 1.5 times, the number of words. Note that a comma also counts as one token.

You will rarely use captions longer than 75 tokens, but if your caption sentences seem long, specify a larger number here.

##### Full fp16 training (experimental)

When the "Mixed precision" option described above is turned on (fp16 or bf16), a mixture of 32-bit and 16-bit data is used during training; when this option is turned on, all weight data is kept in 16 bits (fp16 format). This saves memory, but because some data loses half its precision, training accuracy may also drop.

The default is off. Leave it off unless you really need to save memory.

##### Gradient checkpointing

Normally, during training, the weights of an enormous number of neural networks are corrected and updated all at once every time an image is loaded. By doing these updates "a little at a time" instead of "all at once", memory can be saved, at the cost of some extra computation time.

This option specifies that the weight calculations should be done a little at a time. Turning it on or off has no effect on the LoRA's training results.

The default is off.

##### Shuffle caption

If the training images have captions, most captions are written as comma-separated words, such as "black cat, eating, sitting". The Shuffle caption option randomly reorders these comma-separated words each time.

In general, words nearer the front of a caption are given more weight. Consequently, if the word order is fixed, the later words may not be learned well, and the earlier words may form unintended associations with the training images. By reshuffling the word order every time an image is loaded, this bias can hopefully be corrected.

This option has no meaning if the captions are written as sentences rather than comma-separated words.

The default is off.

Note: a "word" here means a piece of comma-separated text. No matter how many words the separated text contains, it counts as "one word".

For "black cat, eating, sitting", "black cat" is one word.

##### Persistent data loader

The data needed for training is discarded and reloaded after every epoch. This option keeps it instead of discarding it. Turning it on speeds up the start of each new epoch, but consumes memory to hold the data.

The default is off.

##### Memory efficient attention

Checking this reduces VRAM usage while processing the Attention blocks. It is slower than the "xformers" option below. Turn it on if your VRAM capacity is small.

The default is off.

##### Use xformers

Using the Python library "xformers" reduces VRAM usage for Attention block processing in exchange for a slight drop in speed. Turn it on if your VRAM capacity is small.

The default is on.

##### Color augmentation

"Augmentation" means artificially increasing the variety of images. By slightly altering the training images each time, their variety is increased in a pseudo manner.

When Color augmentation is on, the hue of the image is randomly shifted slightly each time. A LoRA trained this way can be expected to show a slightly wider range of color tones.

It cannot be used when the Cache latents option is on.

The default is off.

##### Flip augmentation

When this option is on, images are randomly flipped horizontally. Since it lets the model learn both left and right views, it is useful when you want to train **left-right symmetric** people or objects.

The default is off.

##### Min SNR gamma

In LoRA training, noise of various strengths is added to the training images (the details are omitted here), but depending on how strong the added noise is, the training moves closer to or farther from the training target, which can make training unstable. Min SNR gamma was introduced to correct for this. In particular, when training on images with little added noise, the prediction can jump far from the target, so this setting suppresses such jumps.

The details are complicated and omitted here, but this value can be set from 0 to 20, and the default is 0.

According to the paper that proposed this method, the optimal value is 5.

It is unclear how effective it really is, but if you are unhappy with the training results it is worth trying various values. The weighting it applies is sketched below.
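As described in the Min-SNR paper, the idea is to cap each timestep's loss weight by its signal-to-noise ratio. The sketch below shows that weighting for noise prediction; the exact implementation in the training scripts may differ in its details.

```python
import torch

def min_snr_weight(snr: torch.Tensor, gamma: float = 5.0) -> torch.Tensor:
    # Per-timestep loss weight: min(SNR, gamma) / SNR
    return torch.minimum(snr, torch.full_like(snr, gamma)) / snr

snr = torch.tensor([0.1, 1.0, 5.0, 50.0])
print(min_snr_weight(snr))
# -> tensor([1.0000, 1.0000, 1.0000, 0.1000])
# Low-noise (high-SNR) timesteps are down-weighted, damping the large jumps described above.
```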
##### Don't upscale bucket resolution

The bucket size is set by default to 256–1024 pixels (or, if a maximum resolution is specified with the Max resolution option, that takes precedence). Any image whose height or width falls outside this range is enlarged or shrunk (keeping its aspect ratio) so that it fits within the specified range.

However, when this option is on, the bucket size range is ignored and buckets are prepared automatically according to the sizes of the training images, so all training images are loaded without being enlarged or shrunk. Even then, parts of an image may still be cropped to match the Bucket resolution steps (described below).

The default is on.

##### Bucket resolution steps

When using buckets, this specifies the resolution interval between buckets.

For example, if you specify 64 here, each training image is sorted into a separate bucket every 64 pixels according to its size. This sorting is done for height and width separately.

If an image's size does not exactly match a bucket's size, the protruding part is cropped off.

For example, with a maximum resolution of 512 pixels and a bucket step size of 64 pixels, the buckets are 512, 448, 384, and so on; a 500-pixel image is put into the 448-pixel bucket, and the extra 52 pixels are cropped off to match the size.

The default is 64 pixels.

Note: if this number is made too small, the buckets become too finely divided, and in the worst case you end up with "one bucket per image".

Note that each batch always loads images from the same bucket, so if a bucket contains too few images, the number of images per batch may unintentionally drop.

##### Random crop instead of center crop

As described above, images of awkward sizes are partially cropped after being sorted into buckets so that their sizes match, but normally the crop is made so that the center of the image is preserved.

When this option is on, which part of the picture is cropped is decided randomly. Turn it on when you want training to cover areas outside the center of the images.

Note: this option cannot be used when the cache latents option is on.

##### Noise offset type

This option specifies which method is used to add extra noise to the training images. Noise is always added to the images during training (the details are omitted here), and it is preferable for this noise to be "hard to predict", so adding even more noise makes it harder to predict.

The default is Original. Multires adds noise in a somewhat more complex way.

##### Noise offset

This option applies when "Original" is selected as the Noise offset type. If you enter a value greater than 0 here, extra noise is added. The value ranges from 0 to 1: 0 adds no noise at all, and 1 adds strong noise.

There are reports that adding about 0.1 of noise makes a LoRA's colors more vivid (clearer light and dark). The default is 0.

##### Adaptive noise scale

Used in combination with the Noise offset option. Specifying a value here further adjusts the amount of extra noise set by Noise offset, amplifying or attenuating it. The amount of amplification (or attenuation) is adjusted automatically depending on how much noise is currently on the image. The value ranges from -1 to 1: a positive value increases the amount of added noise, and a negative value decreases it.

The default is 0. A sketch of how the basic noise offset is applied follows below.
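A simplified sketch of the noise offset idea: one random offset is drawn per image and per channel and broadcast over the whole image, pushing it uniformly brighter or darker. This is an illustration of the technique as commonly described, not the exact code used here; the tensor shapes are assumptions.

```python
import torch

def add_noise_offset(noise: torch.Tensor, noise_offset: float) -> torch.Tensor:
    # noise: per-pixel latent noise of shape (batch, channels, height, width)
    if noise_offset > 0:
        b, c, _, _ = noise.shape
        # one offset per image and channel, broadcast over height and width
        noise = noise + noise_offset * torch.randn(b, c, 1, 1, device=noise.device)
    return noise

latent_noise = torch.randn(4, 4, 64, 64)
latent_noise = add_noise_offset(latent_noise, noise_offset=0.1)
```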
##### Multires noise iterations

This option applies when "Multires" is selected as the Noise offset type. If you enter a value greater than 0 here, extra noise is added.

Multires creates noise at various resolutions and adds them together to build the final extra noise. Here you specify how many of these "various resolutions" to create.

The default is 0; at 0 no extra noise is added. If you want to use it, setting it to 6 is recommended.

##### Multires noise discount

Used in combination with the Multires noise iterations option. It is a value for weakening the amount of noise at each resolution. A value from 0 to 1; the smaller the number, the more the noise is weakened. The amount of weakening differs by resolution, with low-resolution noise being weakened the most.

The default is 0; when 0, it is set to 0.3 at use time. 0.8 is usually recommended. If you have relatively few training images, lowering it to around 0.3 seems to work well.

##### Dropout caption every n epochs

Normally images and captions are trained as pairs, but it is possible to train only "images with no caption", skipping the captions, every certain number of epochs.

This option lets you specify "do not use captions every N epochs" (caption dropout).

For example, if you specify 2 here, image training without captions is performed every 2 epochs (at epochs 2, 4, 6, and so on).

Training on images without captions is expected to teach the LoRA more holistic image characteristics. It can also help avoid tying image features too strongly to particular words. However, if captions are used too rarely, the LoRA may end up not responding to prompts at all, so be careful.

The default is 0, which means no caption dropout is performed.

##### Rate of caption dropout

Similar to Dropout caption every n epochs above, but it lets a certain proportion of the entire training process be trained as "images with no caption", without using captions.

Here you can set the proportion of caption-less images. 0 means "always use captions during training", and 1 means "never use captions during training".

Which images are trained as "caption-less images" is decided randomly.

For example, in a LoRA training run of 1 epoch where 20 images are each read 50 times, the total number of image-training passes is 20 x 50 x 1 = 1000. If Rate of caption dropout is set to 0.1, then 1000 x 0.1 = 100 of those passes are trained as "images with no caption".

The default is 0, which trains all images with their captions.

##### VAE batch size

Turning on the Cache latents option keeps the "compressed" image data in main memory; VAE batch size sets how many of these "compressed" images are processed as a group. Since the number of images set by Batch size is trained at once, it is usual to match the VAE batch size to it.

The default is 0, in which case it is set to the same value as Batch size.

##### Save training state

LoRA training takes a long time when there are many training images, repeats, and epochs.

Turning on this option lets you interrupt training partway through and resume it later from where you left off.

The intermediate training data is saved in a folder called "last-state".

##### Resume from saved training state

If you want to resume an interrupted training run, specify the location of the "last-state" folder here.

To resume training, the intermediate training state data must have been saved.

##### Max train epoch

Specifies the maximum number of epochs for training. Specifying the number of epochs with the Epoch option is the basic approach, but training always ends once the number of epochs specified here is reached.

The default is blank. You can leave it blank.

##### Max num workers for DataLoader

This option specifies the number of CPU processes used to load data for training. Raising this number enables subprocesses and speeds up data loading, but raising it too much can actually make things less efficient.

Note that no matter how large a number you specify, it will not exceed the number of concurrently executable threads of the CPU in use.

The default is 0, which loads data only in the CPU's main process.
##### WANDB API Key

There is a machine-learning service called "[WandB](https://wandb.ai/site)" (Weights & Biases). It displays training progress in graphs to help find optimal settings, and records and shares training logs online; kohya_ss can now also use this service.

However, you need an account for it. After creating an account, you can obtain an "API key" from https://app.wandb.ai/authorize. If you enter the obtained API key here, you are automatically logged in during training and can link with the WandB service.

Details about WandB are omitted here, but anyone aiming to become a "LoRA craftsman" should give it a try.

##### WANDB Logging

Here you can specify whether to record training-status logs using the WandB service.

The default is off; when off, logs are recorded in the format of a tool called "tensorboard".

#### Sample images config

If you want to check during training what image generation with the LoRA looks like, enter the image generation prompts here.

That said, LoRA training is relatively short, so there may not be much need for image-generation tests.

##### Sample every n steps

Specifies at how many steps during training you want images to be generated. For example, specifying 100 generates images every 100 steps.

The default is 0; at 0 no images are generated.

##### Sample every n epochs

Specifies at how many epochs during training you want images to be generated. For example, specifying 2 generates images every 2 epochs.

The default is 0; at 0 no images are generated.

##### Sample sampler

Specifies the sampler used for image generation. Many of the samplers that can be specified here are the same as those provided in the Stable Diffusion Web UI, so refer to Web UI documentation sites for details.

The default is euler_a.

##### Sample prompts

Enter the prompt here.

However, you can enter not only the prompt but also other settings. To enter other settings, specify them with two minus signs plus a letter, such as "--n". For example, to put "white, dog" in the negative prompt, write "--n white, dog".

The most commonly used settings are:

--n: negative prompt

--w: image width

--h: image height

--d: seed

--l: CFG Scale

--s: number of steps

The default is blank. When blank, an example entry is shown faintly in the field, so use it as a reference; a sample entry is also shown below.
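For example, a single prompt entry combining these settings might look like the line below. The prompt text and values are only an illustration, not recommended settings.

```
masterpiece, best quality, black cat, sitting --n low quality, worst quality --w 512 --h 512 --d 1 --l 7.5 --s 28
```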
### Summary

This article explained how LoRA, one form of additional training for Stable Diffusion, works, and went through each setting of kohya_ss, a tool for LoRA training.

There are a great many items to configure, which can be confusing, so first run a light training session with the recommended settings and then change the settings little by little based on the training results.

Use the explanations here as a reference and aim to create even more accurate LoRAs.
diff --git a/fine_tune.py b/fine_tune.py
deleted file mode 100644
index 875a91951..000000000
--- a/fine_tune.py
+++ /dev/null
@@ -1,505 +0,0 @@
-# training with captions
-# XXX dropped option: hypernetwork training
-
-import argparse
-import math
-import os
-from multiprocessing import Value
-import toml
-
-from tqdm import tqdm
-
-import torch
-from library.device_utils import init_ipex, clean_memory_on_device
-init_ipex()
-
-from accelerate.utils import set_seed
-from diffusers import DDPMScheduler
-
-from library.utils import setup_logging, add_logging_arguments
-
-setup_logging()
-import logging
-
-logger = logging.getLogger(__name__) - -import library.train_util as train_util -import library.config_util as config_util -from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, -) -import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import ( - apply_snr_weight, - get_weighted_text_embeddings, - prepare_scheduler_for_custom_training, - scale_v_prediction_loss_like_noise_prediction, - apply_debiased_estimation, -) - - -def train(args): - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) - setup_logging(args, reset=True) - - cache_latents = args.cache_latents - - if args.seed is not None: - set_seed(args.seed) # 乱数系列を初期化する - - tokenizer = train_util.load_tokenizer(args) - - # データセットを準備する - if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True)) - if args.dataset_config is not None: - logger.info(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) - ) - else: - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group) - return - if len(train_dataset_group) == 0: - logger.error( - "No data found. Please verify the metadata file and train_data_dir option. 
/ 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" - ) - return - - if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - # acceleratorを準備する - logger.info("prepare accelerator") - accelerator = train_util.prepare_accelerator(args) - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) - - # モデルを読み込む - text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator) - - # verify load/save model formats - if load_stable_diffusion_format: - src_stable_diffusion_ckpt = args.pretrained_model_name_or_path - src_diffusers_model_path = None - else: - src_stable_diffusion_ckpt = None - src_diffusers_model_path = args.pretrained_model_name_or_path - - if args.save_model_as is None: - save_stable_diffusion_format = load_stable_diffusion_format - use_safetensors = args.use_safetensors - else: - save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors" - use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) - - # Diffusers版のxformers使用フラグを設定する関数 - def set_diffusers_xformers_flag(model, valid): - # model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう - # pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`) - # U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか - # 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^) - - # Recursively walk through all the children. - # Any children which exposes the set_use_memory_efficient_attention_xformers method - # gets the message - def fn_recursive_set_mem_eff(module: torch.nn.Module): - if hasattr(module, "set_use_memory_efficient_attention_xformers"): - module.set_use_memory_efficient_attention_xformers(valid) - - for child in module.children(): - fn_recursive_set_mem_eff(child) - - fn_recursive_set_mem_eff(model) - - # モデルに xformers とか memory efficient attention を組み込む - if args.diffusers_xformers: - accelerator.print("Use xformers by Diffusers") - set_diffusers_xformers_flag(unet, True) - else: - # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある - accelerator.print("Disable Diffusers' xformers") - set_diffusers_xformers_flag(unet, False) - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") - clean_memory_on_device(accelerator.device) - - accelerator.wait_for_everyone() - - # 学習を準備する:モデルを適切な状態にする - training_models = [] - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - training_models.append(unet) - - if args.train_text_encoder: - accelerator.print("enable text encoder training") - if args.gradient_checkpointing: - text_encoder.gradient_checkpointing_enable() - training_models.append(text_encoder) - else: - text_encoder.to(accelerator.device, dtype=weight_dtype) - text_encoder.requires_grad_(False) # text encoderは学習しない - if args.gradient_checkpointing: - text_encoder.gradient_checkpointing_enable() - text_encoder.train() # required for gradient_checkpointing - else: - text_encoder.eval() - - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, 
dtype=weight_dtype) - - for m in training_models: - m.requires_grad_(True) - - trainable_params = [] - if args.learning_rate_te is None or not args.train_text_encoder: - for m in training_models: - trainable_params.extend(m.parameters()) - else: - trainable_params = [ - {"params": list(unet.parameters()), "lr": args.learning_rate}, - {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te}, - ] - - # 学習に必要なクラスを準備する - accelerator.print("prepare optimizer, data loader etc.") - _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps - ) - accelerator.print( - f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" - ) - - # データセット側にも学習ステップを送信 - train_dataset_group.set_max_train_steps(args.max_train_steps) - - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - - # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする - if args.full_fp16: - assert ( - args.mixed_precision == "fp16" - ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - accelerator.print("enable full fp16 training.") - unet.to(weight_dtype) - text_encoder.to(weight_dtype) - - # acceleratorがなんかよろしくやってくれるらしい - if args.train_text_encoder: - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler - ) - else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - accelerator.print("running training / 学習開始") - accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") - accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - accelerator.print( - f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" - ) - accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - accelerator.print(f" total optimization 
steps / 学習ステップ数: {args.max_train_steps}") - - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 - - noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False - ) - prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) - if args.zero_terminal_snr: - custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) - - if accelerator.is_main_process: - init_kwargs = {} - if args.wandb_run_name: - init_kwargs["wandb"] = {"name": args.wandb_run_name} - if args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) - - # For --sample_at_first - train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) - - loss_recorder = train_util.LossRecorder() - for epoch in range(num_train_epochs): - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch + 1 - - for m in training_models: - m.train() - - for step, batch in enumerate(train_dataloader): - current_step.value = global_step - with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - - with torch.set_grad_enabled(args.train_text_encoder): - # Get the text embedding for conditioning - if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss: - # do not mean over batch dimension for snr weight or scale v-pred loss - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - - 
loss = loss.mean() # mean over batch dimension - else: - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") - - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = [] - for m in training_models: - params_to_clip.extend(m.parameters()) - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet - ) - - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - train_util.save_sd_model_on_epoch_end_or_stepwise( - args, - False, - accelerator, - src_path, - save_stable_diffusion_format, - use_safetensors, - save_dtype, - epoch, - num_train_epochs, - global_step, - accelerator.unwrap_model(text_encoder), - accelerator.unwrap_model(unet), - vae, - ) - - current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: - logs = {"loss": current_loss} - train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) - accelerator.log(logs, step=global_step) - - loss_recorder.add(epoch=epoch, step=step, loss=current_loss) - avr_loss: float = loss_recorder.moving_average - logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if global_step >= args.max_train_steps: - break - - if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) - - accelerator.wait_for_everyone() - - if args.save_every_n_epochs is not None: - if accelerator.is_main_process: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - train_util.save_sd_model_on_epoch_end_or_stepwise( - args, - True, - accelerator, - src_path, - save_stable_diffusion_format, - use_safetensors, - save_dtype, - epoch, - num_train_epochs, - global_step, - accelerator.unwrap_model(text_encoder), - accelerator.unwrap_model(unet), - vae, - ) - - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) - - is_main_process = accelerator.is_main_process - if is_main_process: - unet = accelerator.unwrap_model(unet) - text_encoder = accelerator.unwrap_model(text_encoder) - - accelerator.end_training() - - if args.save_state and is_main_process: - train_util.save_state_on_train_end(args, accelerator) - - del accelerator # この後メモリを使うのでこれは消す - - if is_main_process: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - train_util.save_sd_model_on_train_end( - args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae - ) - logger.info("model saved.") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - add_logging_arguments(parser) - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, False, True, True) - 
train_util.add_training_arguments(parser, False) - train_util.add_sd_saving_arguments(parser) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) - custom_train_functions.add_custom_train_arguments(parser) - - parser.add_argument( - "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" - ) - parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") - parser.add_argument( - "--learning_rate_te", - type=float, - default=None, - help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ", - ) - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - args = train_util.read_config_from_file(args, parser) - - train(args) diff --git a/fine_tune_README.md b/fine_tune_README.md deleted file mode 100644 index 696360a90..000000000 --- a/fine_tune_README.md +++ /dev/null @@ -1,504 +0,0 @@ -# Fine tuning - -It is a fine tuning that corresponds to NovelAI's proposed learning method, automatic captioning, tagging, Windows + VRAM 12GB (for v1.4/1.5) environment, etc. - -## Overview - -Fine tuning of U-Net of Stable Diffusion using Diffusers. It corresponds to the following improvements in NovelAI's article (For Aspect Ratio Bucketing, I referred to NovelAI's code, but the final code is all original). - -* Use the output of the penultimate layer instead of the last layer of CLIP (Text Encoder). -* Learning at non-square resolutions (Aspect Ratio Bucketing). -* Extend token length from 75 to 225. -* Captioning with BLIP (automatic creation of captions), automatic tagging with DeepDanbooru or WD14Tagger. -* Also supports Hypernetwork learning. -* Supports Stable Diffusion v2.0 (base and 768/v). -* By acquiring the output of VAE in advance and saving it to disk, we aim to save memory and speed up learning. - -Text Encoder is not trained by default. For fine tuning of the whole model, it seems common to learn only U-Net (NovelAI seems to be the same). Text Encoder can also be learned as an option. - -## Additional features - -### Change CLIP output - -CLIP (Text Encoder) converts the text into features in order to reflect the prompt in the image. Stable diffusion uses the output of the last layer of CLIP, but you can change it to use the output of the penultimate layer. According to NovelAI, this will reflect prompts more accurately. -It is also possible to use the output of the last layer as is. -*Stable Diffusion 2.0 uses the penultimate layer by default. Do not specify the clip_skip option. - -### Training in non-square resolutions - -Stable Diffusion is trained at 512\*512, but also at resolutions such as 256\*1024 and 384\*640. It is expected that this will reduce the cropped portion and learn the relationship between prompts and images more correctly. -The learning resolution is adjusted vertically and horizontally in units of 64 pixels within a range that does not exceed the resolution area (= memory usage) given as a parameter. - -In machine learning, it is common to unify all input sizes, but there are no particular restrictions, and in fact it is okay as long as they are unified within the same batch. NovelAI's bucketing seems to refer to classifying training data in advance for each learning resolution according to the aspect ratio. And by creating a batch with the images in each bucket, the image size of the batch is unified. 
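As a rough illustration of the bucketing idea just described (a simplified sketch only: it keeps the aspect ratio, respects a resolution-area budget, and snaps to 64-pixel steps; the actual bucket selection in the training scripts is more involved):

```python
import math

def pick_bucket(width, height, max_area=512 * 512, step=64):
    """Shrink (never enlarge) to fit the resolution area, then snap down to 64-pixel steps."""
    scale = min(1.0, math.sqrt(max_area / (width * height)))
    w, h = width * scale, height * scale
    return int(w // step) * step, int(h // step) * step

print(pick_bucket(1536, 1024))  # -> (576, 384): aspect ratio ~1.5 kept within the 512*512 area
```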
### Extending token length from 75 to 225

Stable Diffusion has a maximum of 75 tokens (77 tokens including the start and end tokens), but we extend it to 225 tokens.
However, the maximum length that CLIP accepts is 75 tokens, so in the case of 225 tokens we simply divide the text into thirds, call CLIP for each part, and then concatenate the results.

*I'm not sure if this is the preferred implementation. It seems to be working for now. Especially in 2.0, there is no implementation that can be used as a reference, so I have implemented it independently.

*Automatic1111's Web UI seems to divide the text with commas in mind, but in my case it's a simple division.

## Environmental arrangement

See the [README](./README-en.md) in this repository.

## Preparing teacher data

Prepare the image data you want to train on and put it in any folder. No prior preparation such as resizing is required.
However, for images that are smaller than the training resolution, it is recommended to enlarge them while maintaining quality using super-resolution.

Multiple teacher data folders are also supported. Preprocessing is executed for each folder.

For example, store images like this:

![Teacher data folder screenshot](https://user-images.githubusercontent.com/52813779/208907739-8e89d5fa-6ca8-4b60-8927-f484d2a9ae04.png)

## Automatic captioning

Skip this if you just want to train on tags without captions.

Also, when preparing captions manually, put them in the same directory as the teacher data images, with the same file name and the extension .caption, etc. Each file should be a text file with only one line.

### Captioning with BLIP

The latest version no longer requires BLIP downloads, weight downloads, or an additional virtual environment. It works as-is.

Run make_captions.py in the finetune folder.

```shell
python finetune\make_captions.py --batch_size <batch size> <teacher data folder>
```

If the batch size is 8 and the training data is placed in the parent folder train_data, it looks like this:

```shell
python finetune\make_captions.py --batch_size 8 ..\train_data
```

A caption file is created in the same directory as each teacher data image, with the same file name and the extension .caption.

Increase or decrease batch_size according to the VRAM capacity of the GPU. Bigger is faster (12GB of VRAM can probably take a little more).
You can specify the maximum length of the caption with the max_length option. The default is 75. It may be worth making it longer if the model is trained with a token length of 225.
You can change the caption extension with the caption_extension option. The default is .caption (.txt conflicts with DeepDanbooru, described later).

If there are multiple teacher data folders, run the script for each folder.

Note that the inference is random, so the results will change each time you run it. If you want to fix them, specify a random number seed such as "--seed 42" with the --seed option.

For other options, please refer to the help with --help (there seems to be no documentation for the meaning of the parameters, so you have to look at the source).

A caption file is generated with the extension .caption by default.
- -![Folder where caption is generated](https://user-images.githubusercontent.com/52813779/208908845-48a9d36c-f6ee-4dae-af71-9ab462d1459e.png) - -For example, with captions like: - -![captions and images](https://user-images.githubusercontent.com/52813779/208908947-af936957-5d73-4339-b6c8-945a52857373.png) - -## Tagged by DeepDanbooru - -If you do not want to tag the danbooru tag itself, please proceed to "Preprocessing of caption and tag information". - -Tagging is done with DeepDanbooru or WD14Tagger. WD14Tagger seems to be more accurate. If you want to tag with WD14Tagger, skip to the next chapter. - -### Environmental arrangement - -Clone DeepDanbooru https://github.com/KichangKim/DeepDanbooru into your working folder, or download the zip and extract it. I unzipped it. -Also, download deepdanbooru-v3-20211112-sgd-e28.zip from Assets of "DeepDanbooru Pretrained Model v3-20211112-sgd-e28" on the DeepDanbooru Releases page https://github.com/KichangKim/DeepDanbooru/releases and extract it to the DeepDanbooru folder. - -Download from below. Click to open Assets and download from there. - -![DeepDanbooru download page](https://user-images.githubusercontent.com/52813779/208909417-10e597df-7085-41ee-bd06-3e856a1339df.png) - -Make a directory structure like this - -![DeepDanbooru directory structure](https://user-images.githubusercontent.com/52813779/208909486-38935d8b-8dc6-43f1-84d3-fef99bc471aa.png) - -Install the necessary libraries for the Diffusers environment. Go to the DeepDanbooru folder and install it (I think it's actually just adding tensorflow-io). - -```shell -pip install -r requirements.txt -``` - -Next, install DeepDanbooru itself. - -```shell -pip install . -``` - -This completes the preparation of the environment for tagging. - -### Implementing tagging - -Go to DeepDanbooru's folder and run deepdanbooru to tag. - -```shell -deepdanbooru evaluate --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt -``` - -If you put the training data in the parent folder train_data, it will be as follows. - -```shell -deepdanbooru evaluate ../train_data --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt -``` - -A tag file is created in the same directory as the teacher data image with the same file name and extension .txt. It is slow because it is processed one by one. - -If there are multiple teacher data folders, execute for each folder. - -It is generated as follows. - -![DeepDanbooru generated files](https://user-images.githubusercontent.com/52813779/208909855-d21b9c98-f2d3-4283-8238-5b0e5aad6691.png) - -A tag is attached like this (great amount of information...). - -![Deep Danbooru tag and image](https://user-images.githubusercontent.com/52813779/208909908-a7920174-266e-48d5-aaef-940aba709519.png) - -## Tagging with WD14Tagger - -This procedure uses WD14Tagger instead of DeepDanbooru. - -Use the tagger used in Mr. Automatic1111's WebUI. I referred to the information on this github page (https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger). - -The modules required for the initial environment maintenance have already been installed. Weights are automatically downloaded from Hugging Face. - -### Implementing tagging - -Run the script to do the tagging. - -```shell -python tag_images_by_wd14_tagger.py --batch_size -``` - -If you put the training data in the parent folder train_data, it will be as follows. 
- -```shell -python tag_images_by_wd14_tagger.py --batch_size 4 ..\train_data -``` - -The model file will be automatically downloaded to the wd14_tagger_model folder on first launch (folder can be changed in options). It will be as follows. - -![downloaded file](https://user-images.githubusercontent.com/52813779/208910447-f7eb0582-90d6-49d3-a666-2b508c7d1842.png) - -A tag file is created in the same directory as the teacher data image with the same file name and extension .txt. - -![generated tag file](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png) - -![tags and images](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png) - -With the thresh option, you can specify the number of confidences of the determined tag to attach the tag. The default is 0.35, same as the WD14Tagger sample. Lower values give more tags, but less accuracy. -Increase or decrease batch_size according to the VRAM capacity of the GPU. Bigger is faster (I think 12GB of VRAM can be a little more). You can change the tag file extension with the caption_extension option. Default is .txt. -You can specify the folder where the model is saved with the model_dir option. -Also, if you specify the force_download option, the model will be re-downloaded even if there is a save destination folder. - -If there are multiple teacher data folders, execute for each folder. - -## Preprocessing caption and tag information - -Combine captions and tags into a single file as metadata for easy processing from scripts. - -### Caption preprocessing - -To put captions into the metadata, run the following in your working folder (if you don't use captions for learning, you don't need to run this) (it's actually a single line, and so on). - -```shell -python merge_captions_to_metadata.py ---in_json - -``` - -The metadata file name is an arbitrary name. -If the training data is train_data, there is no metadata file to read, and the metadata file is meta_cap.json, it will be as follows. - -```shell -python merge_captions_to_metadata.py train_data meta_cap.json -``` - -You can specify the caption extension with the caption_extension option. - -If there are multiple teacher data folders, please specify the full_path argument (metadata will have full path information). Then run it for each folder. - -```shell -python merge_captions_to_metadata.py --full_path - train_data1 meta_cap1.json -python merge_captions_to_metadata.py --full_path --in_json meta_cap1.json - train_data2 meta_cap2.json -``` - -If in_json is omitted, if there is a write destination metadata file, it will be read from there and overwritten there. - -__*It is safe to rewrite the in_json option and the write destination each time and write to a separate metadata file. __ - -### Tag preprocessing - -Similarly, tags are also collected in metadata (no need to do this if tags are not used for learning). - -```shell -python merge_dd_tags_to_metadata.py - --in_json - -``` - -With the same directory structure as above, when reading meta_cap.json and writing to meta_cap_dd.json, it will be as follows. - -```shell -python merge_dd_tags_to_metadata.py train_data --in_json meta_cap.json meta_cap_dd.json -``` - -If you have multiple teacher data folders, please specify the full_path argument. Then run it for each folder. 
- -```shell -python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap2.json - train_data1 meta_cap_dd1.json -python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap_dd1.json - train_data2 meta_cap_dd2.json -``` - -If in_json is omitted, if there is a write destination metadata file, it will be read from there and overwritten there. - -__*It is safe to rewrite the in_json option and the write destination each time and write to a separate metadata file. __ - -### Cleaning captions and tags - -Up to this point, captions and DeepDanbooru tags have been put together in the metadata file. However, captions with automatic captioning are subtle due to spelling variations (*), and tags include underscores and ratings (in the case of DeepDanbooru), so the editor's replacement function etc. You should use it to clean your captions and tags. - -*For example, when learning a girl in an anime picture, there are variations in captions such as girl/girls/woman/women. Also, it may be more appropriate to simply use "girl" for things like "anime girl". - -A script for cleaning is provided, so please edit the contents of the script according to the situation and use it. - -(It is no longer necessary to specify the teacher data folder. All data in the metadata will be cleaned.) - -```shell -python clean_captions_and_tags.py -``` - -Please note that --in_json is not included. For example: - -```shell -python clean_captions_and_tags.py meta_cap_dd.json meta_clean.json -``` - -Preprocessing of captions and tags is now complete. - -## Get latents in advance - -In order to speed up the learning, we acquire the latent representation of the image in advance and save it to disk. At the same time, bucketing (classifying the training data according to the aspect ratio) is performed. - -In your working folder, type: - -```shell -python prepare_buckets_latents.py - - - --batch_size - --max_resolution - --mixed_precision -``` - -If the model is model.ckpt, batch size 4, training resolution is 512\*512, precision is no (float32), read metadata from meta_clean.json and write to meta_lat.json: - -```shell -python prepare_buckets_latents.py - train_data meta_clean.json meta_lat.json model.ckpt - --batch_size 4 --max_resolution 512,512 --mixed_precision no -``` - -Latents are saved in numpy npz format in the teacher data folder. - -Specify the --v2 option when loading a Stable Diffusion 2.0 model (--v_parameterization is not required). - -You can specify the minimum resolution size with the --min_bucket_reso option and the maximum size with the --max_bucket_reso option. The defaults are 256 and 1024 respectively. For example, specifying a minimum size of 384 will not use resolutions such as 256\*1024 or 320\*768. -If you increase the resolution to something like 768\*768, you should specify something like 1280 for the maximum size. - -If you specify the --flip_aug option, it will perform horizontal flip augmentation (data augmentation). You can artificially double the amount of data, but if you specify it when the data is not left-right symmetrical (for example, character appearance, hairstyle, etc.), learning will not go well. -(This is a simple implementation that acquires the latents for the flipped image and saves the \*\_flip.npz file. No options are required for fine_tune.py. If there is a file with \_flip, Randomly load a file without - -The batch size may be increased a little more even with 12GB of VRAM. -The resolution is a number divisible by 64, and is specified by "width, height". 
The resolution is directly linked to the memory size during fine tuning. 512,512 seems to be the limit with VRAM 12GB (*). 16GB may be raised to 512,704 or 512,768. Even with 256, 256, etc., it seems to be difficult with 8GB of VRAM (because parameters and optimizers require a certain amount of memory regardless of resolution). - -*There was also a report that learning batch size 1 worked with 12GB VRAM and 640,640. - -The result of bucketing is displayed as follows. - -![bucketing result](https://user-images.githubusercontent.com/52813779/208911419-71c00fbb-2ce6-49d5-89b5-b78d7715e441.png) - -If you have multiple teacher data folders, please specify the full_path argument. Then run it for each folder. - -```shell -python prepare_buckets_latents.py --full_path - train_data1 meta_clean.json meta_lat1.json model.ckpt - --batch_size 4 --max_resolution 512,512 --mixed_precision no - -python prepare_buckets_latents.py --full_path - train_data2 meta_lat1.json meta_lat2.json model.ckpt - --batch_size 4 --max_resolution 512,512 --mixed_precision no\ -``` - -It is possible to make the read source and write destination the same, but separate is safer. - -__*It is safe to rewrite the argument each time and write it to a separate metadata file. __ - -## Run training - -For example: Below are the settings for saving memory. - -```shell -accelerate launch --num_cpu_threads_per_process 8 fine_tune.py - --pretrained_model_name_or_path=model.ckpt - --in_json meta_lat.json - --train_data_dir=train_data - --output_dir=fine_tuned - --shuffle_caption - --train_batch_size=1 --learning_rate=5e-6 --max_train_steps=10000 - --use_8bit_adam --xformers --gradient_checkpointing - --mixed_precision=bf16 - --save_every_n_epochs=4 -``` - -It seems to be good to specify the number of CPU cores for num_cpu_threads_per_process of accelerate. - -Specify the model to be trained in pretrained_model_name_or_path (Stable Diffusion checkpoint or Diffusers model). Stable Diffusion checkpoint supports .ckpt and .safetensors (automatically determined by extension). - -Specifies the metadata file when caching latent to in_json. - -Specify the training data folder for train_data_dir and the output destination folder for the trained model for output_dir. - -If shuffle_caption is specified, captions and tags are shuffled and learned in units separated by commas (this is the method used in Waifu Diffusion v1.3). -(You can keep some of the leading tokens fixed without shuffling. See keep_tokens for other options.) - -Specify the batch size in train_batch_size. Specify 1 or 2 for VRAM 12GB. The number that can be specified also changes depending on the resolution. -The actual amount of data used for training is "batch size x number of steps". When increasing the batch size, the number of steps can be decreased accordingly. - -Specify the learning rate in learning_rate. For example Waifu Diffusion v1.3 seems to be 5e-6. -Specify the number of steps in max_train_steps. - -Specify use_8bit_adam to use the 8-bit Adam Optimizer. It saves memory and speeds up, but accuracy may decrease. - -Specifying xformers replaces CrossAttention to save memory and speed up. -* As of 11/9, xformers will cause an error in float32 learning, so please use bf16/fp16 or use memory-saving CrossAttention with mem_eff_attn instead (speed is inferior to xformers). - -Enable intermediate saving of gradients in gradient_checkpointing. It's slower, but uses less memory. - -Specifies whether to use mixed precision with mixed_precision. 
Specifying "fp16" or "bf16" saves memory, but accuracy is inferior. -"fp16" and "bf16" use almost the same amount of memory, and it is said that bf16 has better learning results (I didn't feel much difference in the range I tried). -If "no" is specified, it will not be used (it will be float32). - -* It seems that an error will occur when reading checkpoints learned with bf16 with Mr. AUTOMATIC1111's Web UI. This seems to be because the data type bfloat16 causes an error in the Web UI model safety checker. Save in fp16 or float32 format with the save_precision option. Or it seems to be good to store it in safetensors format. - -Specifying save_every_n_epochs will save the model being trained every time that many epochs have passed. - -### Supports Stable Diffusion 2.0 - -Specify the --v2 option when using Hugging Face's stable-diffusion-2-base, and specify both --v2 and --v_parameterization options when using stable-diffusion-2 or 768-v-ema.ckpt please. - -### Increase accuracy and speed when memory is available - -First, removing gradient_checkpointing will speed it up. However, the batch size that can be set is reduced, so please set while looking at the balance between accuracy and speed. - -Increasing the batch size increases speed and accuracy. Increase the speed while checking the speed per data within the range where the memory is sufficient (the speed may actually decrease when the memory is at the limit). - -### Change CLIP output used - -Specifying 2 for the clip_skip option uses the output of the next-to-last layer. If 1 or option is omitted, the last layer is used. -The learned model should be able to be inferred by Automatic1111's web UI. - -*SD2.0 uses the second layer from the back by default, so please do not specify it when learning SD2.0. - -If the model being trained was originally trained to use the second layer, 2 is a good value. - -If you were using the last layer instead, the entire model would have been trained on that assumption. Therefore, if you train again using the second layer, you may need a certain number of teacher data and longer learning to obtain the desired learning result. - -### Extending Token Length - -You can learn by extending the token length by specifying 150 or 225 for max_token_length. -The learned model should be able to be inferred by Automatic1111's web UI. - -As with clip_skip, learning with a length different from the learning state of the model may require a certain amount of teacher data and a longer learning time. - -### Save learning log - -Specify the log save destination folder in the logging_dir option. Logs in TensorBoard format are saved. - -For example, if you specify --logging_dir=logs, a logs folder will be created in your working folder, and logs will be saved in the date/time folder. -Also, if you specify the --log_prefix option, the specified string will be added before the date and time. Use "--logging_dir=logs --log_prefix=fine_tune_style1" for identification. - -To check the log with TensorBoard, open another command prompt and enter the following in the working folder (I think tensorboard is installed when Diffusers is installed, but if it is not installed, pip install Please put it in tensorboard). - -```shell -tensorboard --logdir=logs -``` - -### Learning Hypernetworks - -It will be explained in another article. 
- -### Learning with fp16 gradient (experimental feature) - -The full_fp16 option will change the gradient from normal float32 to float16 (fp16) and learn (it seems to be full fp16 learning instead of mixed precision). As a result, it seems that the SD1.x 512*512 size can be learned with a VRAM usage of less than 8GB, and the SD2.x 512*512 size can be learned with a VRAM usage of less than 12GB. - -Specify fp16 in advance in accelerate config and optionally set mixed_precision="fp16" (does not work with bf16). - -To minimize memory usage, use the xformers, use_8bit_adam, gradient_checkpointing options and set train_batch_size to 1. -(If you can afford it, increasing the train_batch_size step by step should improve the accuracy a little.) - -It is realized by patching the PyTorch source (confirmed with PyTorch 1.12.1 and 1.13.0). The accuracy will drop considerably, and the probability of learning failure on the way will also increase. The setting of the learning rate and the number of steps seems to be severe. Please be aware of them and use them at your own risk. - -### Other Options - -#### keep_tokens - -If a number is specified, the specified number of tokens (comma-separated strings) from the beginning of the caption are fixed without being shuffled. - -If there are both captions and tags, the prompts during learning will be concatenated like "caption, tag 1, tag 2...", so if you set "--keep_tokens=1", the caption will always be at the beginning during learning. will come. - -#### dataset_repeats - -If the number of data sets is extremely small, the epoch will end soon (it will take some time at the epoch break), so please specify a numerical value and multiply the data by some to make the epoch longer. - -#### train_text_encoder - -Text Encoder is also a learning target. Slightly increased memory usage. - -In normal fine tuning, the Text Encoder is not targeted for training (probably because U-Net is trained to follow the output of the Text Encoder), but if the number of training data is small, the Text Encoder is trained like DreamBooth. also seems to be valid. - -#### save_precision - -The data format when saving checkpoints can be specified from float, fp16, and bf16 (if not specified, it is the same as the data format during learning). It saves disk space, but the model produces different results. Also, if you specify float or fp16, you should be able to read it on Mr. 1111's Web UI. - -*For VAE, the data format of the original checkpoint will remain, so the model size may not be reduced to a little over 2GB even with fp16. - -#### save_model_as - -Specify the save format of the model. Specify one of ckpt, safetensors, diffusers, diffusers_safetensors. - -When reading Stable Diffusion format (ckpt or safetensors) and saving in Diffusers format, missing information is supplemented by dropping v1.5 or v2.1 information from Hugging Face. - -#### use_safetensors - -This option saves checkpoints in safetensors format. The save format will be the default (same format as loaded). - -#### save_state and resume - -The save_state option saves the learning state of the optimizer, etc. in addition to the checkpoint in the folder when saving midway and at the final save. This avoids a decrease in accuracy when learning is resumed after being interrupted (since the optimizer optimizes while having a state, if the state is reset, the optimization must be performed again from the initial state. not). Note that the number of steps is not saved due to Accelerate specifications. 
- -When starting the script, you can resume by specifying the folder where the state is saved with the resume option. - -Please note that the learning state will be about 5 GB per save, so please be careful of the disk capacity. - -#### gradient_accumulation_steps - -Updates the gradient in batches for the specified number of steps. Has a similar effect to increasing the batch size, but consumes slightly more memory. - -*The Accelerate specification does not support multiple learning models, so if you set Text Encoder as the learning target and specify a value of 2 or more for this option, an error may occur. - -#### lr_scheduler / lr_warmup_steps - -You can choose the learning rate scheduler from linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup with the lr_scheduler option. Default is constant. - -With lr_warmup_steps, you can specify the number of steps to warm up the scheduler (gradually changing the learning rate). Please do your own research for details. - -#### diffusers_xformers - -Uses Diffusers' xformers feature rather than the script's own xformers replacement feature. Hypernetwork learning is no longer possible. diff --git a/finetune/blip/blip.py b/finetune/blip/blip.py deleted file mode 100644 index 7d192cb26..000000000 --- a/finetune/blip/blip.py +++ /dev/null @@ -1,244 +0,0 @@ -''' - * Copyright (c) 2022, salesforce.com, inc. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause - * By Junnan Li -''' -import warnings -warnings.filterwarnings("ignore") - -# from models.vit import VisionTransformer, interpolate_pos_embed -# from models.med import BertConfig, BertModel, BertLMHeadModel -from blip.vit import VisionTransformer, interpolate_pos_embed -from blip.med import BertConfig, BertModel, BertLMHeadModel -from transformers import BertTokenizer - -import torch -from torch import nn -import torch.nn.functional as F - -import os -from urllib.parse import urlparse -from timm.models.hub import download_cached_file -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -class BLIP_Base(nn.Module): - def __init__(self, - med_config = 'configs/med_config.json', - image_size = 224, - vit = 'base', - vit_grad_ckpt = False, - vit_ckpt_layer = 0, - ): - """ - Args: - med_config (str): path for the mixture of encoder-decoder model's configuration file - image_size (int): input image size - vit (str): model size of vision transformer - """ - super().__init__() - - self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) - self.tokenizer = init_tokenizer() - med_config = BertConfig.from_json_file(med_config) - med_config.encoder_width = vision_width - self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) - - - def forward(self, image, caption, mode): - - assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal" - text = self.tokenizer(caption, return_tensors="pt").to(image.device) - - if mode=='image': - # return image features - image_embeds = self.visual_encoder(image) - return image_embeds - - elif mode=='text': - # return text features - text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, - return_dict = True, mode = 'text') - return text_output.last_hidden_state - - elif mode=='multimodal': - # return multimodel features - image_embeds = 
self.visual_encoder(image) - image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) - - text.input_ids[:,0] = self.tokenizer.enc_token_id - output = self.text_encoder(text.input_ids, - attention_mask = text.attention_mask, - encoder_hidden_states = image_embeds, - encoder_attention_mask = image_atts, - return_dict = True, - ) - return output.last_hidden_state - - - -class BLIP_Decoder(nn.Module): - def __init__(self, - med_config = 'configs/med_config.json', - image_size = 384, - vit = 'base', - vit_grad_ckpt = False, - vit_ckpt_layer = 0, - prompt = 'a picture of ', - ): - """ - Args: - med_config (str): path for the mixture of encoder-decoder model's configuration file - image_size (int): input image size - vit (str): model size of vision transformer - """ - super().__init__() - - self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) - self.tokenizer = init_tokenizer() - med_config = BertConfig.from_json_file(med_config) - med_config.encoder_width = vision_width - self.text_decoder = BertLMHeadModel(config=med_config) - - self.prompt = prompt - self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1 - - - def forward(self, image, caption): - - image_embeds = self.visual_encoder(image) - image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) - - text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device) - - text.input_ids[:,0] = self.tokenizer.bos_token_id - - decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100) - decoder_targets[:,:self.prompt_length] = -100 - - decoder_output = self.text_decoder(text.input_ids, - attention_mask = text.attention_mask, - encoder_hidden_states = image_embeds, - encoder_attention_mask = image_atts, - labels = decoder_targets, - return_dict = True, - ) - loss_lm = decoder_output.loss - - return loss_lm - - def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0): - image_embeds = self.visual_encoder(image) - - if not sample: - image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) - - image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) - model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts} - - prompt = [self.prompt] * image.size(0) - input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device) - input_ids[:,0] = self.tokenizer.bos_token_id - input_ids = input_ids[:, :-1] - - if sample: - #nucleus sampling - outputs = self.text_decoder.generate(input_ids=input_ids, - max_length=max_length, - min_length=min_length, - do_sample=True, - top_p=top_p, - num_return_sequences=1, - eos_token_id=self.tokenizer.sep_token_id, - pad_token_id=self.tokenizer.pad_token_id, - repetition_penalty=1.1, - **model_kwargs) - else: - #beam search - outputs = self.text_decoder.generate(input_ids=input_ids, - max_length=max_length, - min_length=min_length, - num_beams=num_beams, - eos_token_id=self.tokenizer.sep_token_id, - pad_token_id=self.tokenizer.pad_token_id, - repetition_penalty=repetition_penalty, - **model_kwargs) - - captions = [] - for output in outputs: - caption = self.tokenizer.decode(output, skip_special_tokens=True) - captions.append(caption[len(self.prompt):]) - return captions - - -def blip_decoder(pretrained='',**kwargs): - model = BLIP_Decoder(**kwargs) - if pretrained: - model,msg = 
load_checkpoint(model,pretrained) - assert(len(msg.missing_keys)==0) - return model - -def blip_feature_extractor(pretrained='',**kwargs): - model = BLIP_Base(**kwargs) - if pretrained: - model,msg = load_checkpoint(model,pretrained) - assert(len(msg.missing_keys)==0) - return model - -def init_tokenizer(): - tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') - tokenizer.add_special_tokens({'bos_token':'[DEC]'}) - tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']}) - tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] - return tokenizer - - -def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0): - - assert vit in ['base', 'large'], "vit parameter must be base or large" - if vit=='base': - vision_width = 768 - visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12, - num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, - drop_path_rate=0 or drop_path_rate - ) - elif vit=='large': - vision_width = 1024 - visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24, - num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, - drop_path_rate=0.1 or drop_path_rate - ) - return visual_encoder, vision_width - -def is_url(url_or_filename): - parsed = urlparse(url_or_filename) - return parsed.scheme in ("http", "https") - -def load_checkpoint(model,url_or_filename): - if is_url(url_or_filename): - cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) - checkpoint = torch.load(cached_file, map_location='cpu') - elif os.path.isfile(url_or_filename): - checkpoint = torch.load(url_or_filename, map_location='cpu') - else: - raise RuntimeError('checkpoint url or path is invalid') - - state_dict = checkpoint['model'] - - state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) - if 'visual_encoder_m.pos_embed' in model.state_dict().keys(): - state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], - model.visual_encoder_m) - for key in model.state_dict().keys(): - if key in state_dict.keys(): - if state_dict[key].shape!=model.state_dict()[key].shape: - del state_dict[key] - - msg = model.load_state_dict(state_dict,strict=False) - logger.info('load checkpoint from %s'%url_or_filename) - return model,msg - diff --git a/finetune/blip/med.py b/finetune/blip/med.py deleted file mode 100644 index 7b00a3545..000000000 --- a/finetune/blip/med.py +++ /dev/null @@ -1,955 +0,0 @@ -''' - * Copyright (c) 2022, salesforce.com, inc. - * All rights reserved. 
- * SPDX-License-Identifier: BSD-3-Clause - * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause - * By Junnan Li - * Based on huggingface code base - * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert -''' - -import math -import os -import warnings -from dataclasses import dataclass -from typing import Optional, Tuple - -import torch -from torch import Tensor, device, dtype, nn -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss -import torch.nn.functional as F - -from transformers.activations import ACT2FN -from transformers.file_utils import ( - ModelOutput, -) -from transformers.modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, - BaseModelOutputWithPoolingAndCrossAttentions, - CausalLMOutputWithCrossAttentions, - MaskedLMOutput, - MultipleChoiceModelOutput, - NextSentencePredictorOutput, - QuestionAnsweringModelOutput, - SequenceClassifierOutput, - TokenClassifierOutput, -) -from transformers.modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) -from transformers.utils import logging -from transformers.models.bert.configuration_bert import BertConfig - - -logger = logging.get_logger(__name__) - - -class BertEmbeddings(nn.Module): - """Construct the embeddings from word and position embeddings.""" - - def __init__(self, config): - super().__init__() - self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) - self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) - - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") - - self.config = config - - def forward( - self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 - ): - if input_ids is not None: - input_shape = input_ids.size() - else: - input_shape = inputs_embeds.size()[:-1] - - seq_length = input_shape[1] - - if position_ids is None: - position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - embeddings = inputs_embeds - - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings) - return embeddings - - -class BertSelfAttention(nn.Module): - def __init__(self, config, is_cross_attention): - super().__init__() - self.config = config - if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config.hidden_size, config.num_attention_heads) - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / 
config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = nn.Linear(config.hidden_size, self.all_head_size) - if is_cross_attention: - self.key = nn.Linear(config.encoder_width, self.all_head_size) - self.value = nn.Linear(config.encoder_width, self.all_head_size) - else: - self.key = nn.Linear(config.hidden_size, self.all_head_size) - self.value = nn.Linear(config.hidden_size, self.all_head_size) - - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) - self.save_attention = False - - def save_attn_gradients(self, attn_gradients): - self.attn_gradients = attn_gradients - - def get_attn_gradients(self): - return self.attn_gradients - - def save_attention_map(self, attention_map): - self.attention_map = attention_map - - def get_attention_map(self): - return self.attention_map - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - ): - mixed_query_layer = self.query(hidden_states) - - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. 
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - seq_length = hidden_states.size()[1] - position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in BertModel forward() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - attention_probs = nn.Softmax(dim=-1)(attention_scores) - - if is_cross_attention and self.save_attention: - self.save_attention_map(attention_probs) - attention_probs.register_hook(self.save_attn_gradients) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. 
- attention_probs_dropped = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs_dropped = attention_probs_dropped * head_mask - - context_layer = torch.matmul(attention_probs_dropped, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - outputs = outputs + (past_key_value,) - return outputs - - -class BertSelfOutput(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class BertAttention(nn.Module): - def __init__(self, config, is_cross_attention=False): - super().__init__() - self.self = BertSelfAttention(config, is_cross_attention) - self.output = BertSelfOutput(config) - self.pruned_heads = set() - - def prune_heads(self, heads): - if len(heads) == 0: - return - heads, index = find_pruneable_heads_and_indices( - heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads - ) - - # Prune linear layers - self.self.query = prune_linear_layer(self.self.query, index) - self.self.key = prune_linear_layer(self.self.key, index) - self.self.value = prune_linear_layer(self.self.value, index) - self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) - - # Update hyper params and store pruned heads - self.self.num_attention_heads = self.self.num_attention_heads - len(heads) - self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads - self.pruned_heads = self.pruned_heads.union(heads) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - ): - self_outputs = self.self( - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - return outputs - - -class BertIntermediate(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.intermediate_size) - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = ACT2FN[config.hidden_act] - else: - self.intermediate_act_fn = config.hidden_act - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - return hidden_states - - -class BertOutput(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = 
self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class BertLayer(nn.Module): - def __init__(self, config, layer_num): - super().__init__() - self.config = config - self.chunk_size_feed_forward = config.chunk_size_feed_forward - self.seq_len_dim = 1 - self.attention = BertAttention(config) - self.layer_num = layer_num - if self.config.add_cross_attention: - self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention) - self.intermediate = BertIntermediate(config) - self.output = BertOutput(config) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - mode=None, - ): - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - self_attention_outputs = self.attention( - hidden_states, - attention_mask, - head_mask, - output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, - ) - attention_output = self_attention_outputs[0] - - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - - if mode=='multimodal': - assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers" - - cross_attention_outputs = self.crossattention( - attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - output_attentions=output_attentions, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - layer_output = apply_chunking_to_forward( - self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output - ) - outputs = (layer_output,) + outputs - - outputs = outputs + (present_key_value,) - - return outputs - - def feed_forward_chunk(self, attention_output): - intermediate_output = self.intermediate(attention_output) - layer_output = self.output(intermediate_output, attention_output) - return layer_output - - -class BertEncoder(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)]) - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - mode='multimodal', - ): - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - - next_decoder_cache = () if use_cache else None - - for i in range(self.config.num_hidden_layers): - layer_module = self.layer[i] - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - if use_cache: - logger.warn( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - mode=mode, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - mode=mode, - ) - - hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - -class BertPooler(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output - - -class BertPredictionHeadTransform(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - if isinstance(config.hidden_act, str): - self.transform_act_fn = ACT2FN[config.hidden_act] - else: - self.transform_act_fn = config.hidden_act - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - return hidden_states - - -class BertLMPredictionHead(nn.Module): - def __init__(self, config): - super().__init__() - self.transform = BertPredictionHeadTransform(config) - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def forward(self, hidden_states): - hidden_states = self.transform(hidden_states) - hidden_states = self.decoder(hidden_states) - return hidden_states - - -class BertOnlyMLMHead(nn.Module): - def __init__(self, config): - super().__init__() - self.predictions = BertLMPredictionHead(config) - - def forward(self, sequence_output): - prediction_scores = self.predictions(sequence_output) - return prediction_scores - - -class BertPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. 
- """ - - config_class = BertConfig - base_model_prefix = "bert" - _keys_to_ignore_on_load_missing = [r"position_ids"] - - def _init_weights(self, module): - """ Initialize the weights """ - if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - - -class BertModel(BertPreTrainedModel): - """ - The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of - cross-attention is added between the self-attention layers, following the architecture described in `Attention is - all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, - Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. - argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an - input to the forward pass. - """ - - def __init__(self, config, add_pooling_layer=True): - super().__init__(config) - self.config = config - - self.embeddings = BertEmbeddings(config) - - self.encoder = BertEncoder(config) - - self.pooler = BertPooler(config) if add_pooling_layer else None - - self.init_weights() - - - def get_input_embeddings(self): - return self.embeddings.word_embeddings - - def set_input_embeddings(self, value): - self.embeddings.word_embeddings = value - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) - - - def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor: - """ - Makes broadcastable attention and causal masks so that future and masked tokens are ignored. - - Arguments: - attention_mask (:obj:`torch.Tensor`): - Mask with ones indicating tokens to attend to, zeros for tokens to ignore. - input_shape (:obj:`Tuple[int]`): - The shape of the input to the model. - device: (:obj:`torch.device`): - The device of the input to the model. - - Returns: - :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. - """ - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
- if attention_mask.dim() == 3: - extended_attention_mask = attention_mask[:, None, :, :] - elif attention_mask.dim() == 2: - # Provided a padding mask of dimensions [batch_size, seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] - if is_decoder: - batch_size, seq_length = input_shape - - seq_ids = torch.arange(seq_length, device=device) - causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] - # in case past_key_values are used we need to add a prefix ones mask to the causal mask - # causal and attention masks must have same type with pytorch version < 1.3 - causal_mask = causal_mask.to(attention_mask.dtype) - - if causal_mask.shape[1] < attention_mask.shape[1]: - prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] - causal_mask = torch.cat( - [ - torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype), - causal_mask, - ], - axis=-1, - ) - - extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] - else: - extended_attention_mask = attention_mask[:, None, None, :] - else: - raise ValueError( - "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( - input_shape, attention_mask.shape - ) - ) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - return extended_attention_mask - - def forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - is_decoder=False, - mode='multimodal', - ): - r""" - encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 
- If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` - (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` - instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. - use_cache (:obj:`bool`, `optional`): - If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up - decoding (see :obj:`past_key_values`). - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if is_decoder: - use_cache = use_cache if use_cache is not None else self.config.use_cache - else: - use_cache = False - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - batch_size, seq_length = input_shape - device = input_ids.device - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - batch_size, seq_length = input_shape - device = inputs_embeds.device - elif encoder_embeds is not None: - input_shape = encoder_embeds.size()[:-1] - batch_size, seq_length = input_shape - device = encoder_embeds.device - else: - raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds") - - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, - device, is_decoder) - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if encoder_hidden_states is not None: - if type(encoder_hidden_states) == list: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() - else: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - - if type(encoder_attention_mask) == list: - encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] - elif encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - if encoder_embeds is None: - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - ) - else: - embedding_output = encoder_embeds - - encoder_outputs = self.encoder( - embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - mode=mode, - ) - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] - - return BaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, - ) - - - -class BertLMHeadModel(BertPreTrainedModel): - - _keys_to_ignore_on_load_unexpected = [r"pooler"] - _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] - - def __init__(self, config): - super().__init__(config) - - self.bert = BertModel(config, add_pooling_layer=False) - self.cls = BertOnlyMLMHead(config) - - self.init_weights() - - def get_output_embeddings(self): - return self.cls.predictions.decoder - - def set_output_embeddings(self, new_embeddings): - self.cls.predictions.decoder = new_embeddings - - def forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - 
return_logits=False, - is_decoder=True, - reduction='mean', - mode='multimodal', - ): - r""" - encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in - ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are - ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` - past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` - (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` - instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. - use_cache (:obj:`bool`, `optional`): - If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up - decoding (see :obj:`past_key_values`). 
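[Editor's note, for reference only: the `labels` description above refers to a standard shift-by-one language-modeling loss, which the removed code computes a few lines further down with `CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)`. A minimal, self-contained sketch of that computation follows; names are illustrative and label smoothing is left out.]

```python
import torch
import torch.nn.functional as F

def lm_loss(prediction_scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """prediction_scores: (batch, seq, vocab); labels: (batch, seq), -100 = ignored position."""
    shifted = prediction_scores[:, :-1, :].contiguous()  # position t predicts token t+1
    targets = labels[:, 1:].contiguous()                 # drop the first label to stay aligned
    return F.cross_entropy(shifted.view(-1, shifted.size(-1)), targets.view(-1), ignore_index=-100)

scores = torch.randn(2, 5, 100)
labels = torch.randint(0, 100, (2, 5))
print(lm_loss(scores, labels))
```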
- Returns: - Example:: - >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig - >>> import torch - >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') - >>> config = BertConfig.from_pretrained("bert-base-cased") - >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") - >>> outputs = model(**inputs) - >>> prediction_logits = outputs.logits - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if labels is not None: - use_cache = False - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - is_decoder=is_decoder, - mode=mode, - ) - - sequence_output = outputs[0] - prediction_scores = self.cls(sequence_output) - - if return_logits: - return prediction_scores[:, :-1, :].contiguous() - - lm_loss = None - if labels is not None: - # we are doing next-token prediction; shift prediction scores and input ids by one - shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() - labels = labels[:, 1:].contiguous() - loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) - lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) - if reduction=='none': - lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((lm_loss,) + output) if lm_loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=lm_loss, - logits=prediction_scores, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): - input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = input_ids.new_ones(input_shape) - - # cut decoder_input_ids if past is used - if past is not None: - input_ids = input_ids[:, -1:] - - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "past_key_values": past, - "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), - "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), - "is_decoder": True, - } - - def _reorder_cache(self, past, beam_idx): - reordered_past = () - for layer_past in past: - reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) - return reordered_past diff --git a/finetune/blip/med_config.json b/finetune/blip/med_config.json deleted file mode 100644 index dc12b99cf..000000000 --- a/finetune/blip/med_config.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "architectures": [ - "BertModel" - ], - "attention_probs_dropout_prob": 0.1, - "hidden_act": "gelu", - "hidden_dropout_prob": 0.1, - "hidden_size": 768, - "initializer_range": 0.02, - "intermediate_size": 3072, - "layer_norm_eps": 1e-12, - "max_position_embeddings": 512, - "model_type": "bert", - "num_attention_heads": 12, - 
"num_hidden_layers": 12, - "pad_token_id": 0, - "type_vocab_size": 2, - "vocab_size": 30524, - "encoder_width": 768, - "add_cross_attention": true - } - \ No newline at end of file diff --git a/finetune/blip/vit.py b/finetune/blip/vit.py deleted file mode 100644 index cec3d8e08..000000000 --- a/finetune/blip/vit.py +++ /dev/null @@ -1,305 +0,0 @@ -''' - * Copyright (c) 2022, salesforce.com, inc. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause - * By Junnan Li - * Based on timm code base - * https://github.com/rwightman/pytorch-image-models/tree/master/timm -''' - -import torch -import torch.nn as nn -import torch.nn.functional as F -from functools import partial - -from timm.models.vision_transformer import _cfg, PatchEmbed -from timm.models.registry import register_model -from timm.models.layers import trunc_normal_, DropPath -from timm.models.helpers import named_apply, adapt_input_conv - -from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper - -class Mlp(nn.Module): - """ MLP as used in Vision Transformer, MLP-Mixer and related networks - """ - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -class Attention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights - self.scale = qk_scale or head_dim ** -0.5 - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - self.attn_gradients = None - self.attention_map = None - - def save_attn_gradients(self, attn_gradients): - self.attn_gradients = attn_gradients - - def get_attn_gradients(self): - return self.attn_gradients - - def save_attention_map(self, attention_map): - self.attention_map = attention_map - - def get_attention_map(self): - return self.attention_map - - def forward(self, x, register_hook=False): - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - if register_hook: - self.save_attention_map(attn) - attn.register_hook(self.save_attn_gradients) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class Block(nn.Module): - - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False): - super().__init__() - self.norm1 = norm_layer(dim) - self.attn = Attention( - dim, num_heads=num_heads, qkv_bias=qkv_bias, 
qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if use_grad_checkpointing: - self.attn = checkpoint_wrapper(self.attn) - self.mlp = checkpoint_wrapper(self.mlp) - - def forward(self, x, register_hook=False): - x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) - x = x + self.drop_path(self.mlp(self.norm2(x))) - return x - - -class VisionTransformer(nn.Module): - """ Vision Transformer - A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - - https://arxiv.org/abs/2010.11929 - """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, - use_grad_checkpointing=False, ckpt_layer=0): - """ - Args: - img_size (int, tuple): input image size - patch_size (int, tuple): patch size - in_chans (int): number of input channels - num_classes (int): number of classes for classification head - embed_dim (int): embedding dimension - depth (int): depth of transformer - num_heads (int): number of attention heads - mlp_ratio (int): ratio of mlp hidden dim to embedding dim - qkv_bias (bool): enable bias for qkv if True - qk_scale (float): override default qk scale of head_dim ** -0.5 if set - representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set - drop_rate (float): dropout rate - attn_drop_rate (float): attention dropout rate - drop_path_rate (float): stochastic depth rate - norm_layer: (nn.Module): normalization layer - """ - super().__init__() - self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) - - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) - - num_patches = self.patch_embed.num_patches - - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) - self.pos_drop = nn.Dropout(p=drop_rate) - - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule - self.blocks = nn.ModuleList([ - Block( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, - use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer) - ) - for i in range(depth)]) - self.norm = norm_layer(embed_dim) - - trunc_normal_(self.pos_embed, std=.02) - trunc_normal_(self.cls_token, std=.02) - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'pos_embed', 'cls_token'} - - def forward(self, x, register_blk=-1): - B = x.shape[0] - x = 
self.patch_embed(x) - - cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks - x = torch.cat((cls_tokens, x), dim=1) - - x = x + self.pos_embed[:,:x.size(1),:] - x = self.pos_drop(x) - - for i,blk in enumerate(self.blocks): - x = blk(x, register_blk==i) - x = self.norm(x) - - return x - - @torch.jit.ignore() - def load_pretrained(self, checkpoint_path, prefix=''): - _load_weights(self, checkpoint_path, prefix) - - -@torch.no_grad() -def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): - """ Load weights from .npz checkpoints for official Google Brain Flax implementation - """ - import numpy as np - - def _n2p(w, t=True): - if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: - w = w.flatten() - if t: - if w.ndim == 4: - w = w.transpose([3, 2, 0, 1]) - elif w.ndim == 3: - w = w.transpose([2, 0, 1]) - elif w.ndim == 2: - w = w.transpose([1, 0]) - return torch.from_numpy(w) - - w = np.load(checkpoint_path) - if not prefix and 'opt/target/embedding/kernel' in w: - prefix = 'opt/target/' - - if hasattr(model.patch_embed, 'backbone'): - # hybrid - backbone = model.patch_embed.backbone - stem_only = not hasattr(backbone, 'stem') - stem = backbone if stem_only else backbone.stem - stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) - stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) - stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) - if not stem_only: - for i, stage in enumerate(backbone.stages): - for j, block in enumerate(stage.blocks): - bp = f'{prefix}block{i + 1}/unit{j + 1}/' - for r in range(3): - getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) - getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) - getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) - if block.downsample is not None: - block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) - block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) - block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) - embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) - else: - embed_conv_w = adapt_input_conv( - model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) - model.patch_embed.proj.weight.copy_(embed_conv_w) - model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) - model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) - pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) - if pos_embed_w.shape != model.pos_embed.shape: - pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights - pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) - model.pos_embed.copy_(pos_embed_w) - model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) - model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) -# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: -# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) -# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) -# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: -# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) -# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) - for i, block in enumerate(model.blocks.children()): 
- block_prefix = f'{prefix}Transformer/encoderblock_{i}/' - mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' - block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) - block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) - block.attn.qkv.weight.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) - block.attn.qkv.bias.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) - block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) - block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) - for r in range(2): - getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) - getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) - block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) - block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) - - -def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): - # interpolate position embedding - embedding_size = pos_embed_checkpoint.shape[-1] - num_patches = visual_encoder.patch_embed.num_patches - num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches - # height (== width) for the checkpoint position embedding - orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) - # height (== width) for the new position embedding - new_size = int(num_patches ** 0.5) - - if orig_size!=new_size: - # class_token and dist_token are kept unchanged - extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] - # only the position tokens are interpolated - pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] - pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) - pos_tokens = torch.nn.functional.interpolate( - pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) - pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) - new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) - print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2)) - - return new_pos_embed - else: - return pos_embed_checkpoint \ No newline at end of file diff --git a/finetune/clean_captions_and_tags.py b/finetune/clean_captions_and_tags.py deleted file mode 100644 index 5aeb17425..000000000 --- a/finetune/clean_captions_and_tags.py +++ /dev/null @@ -1,194 +0,0 @@ -# このスクリプトのライセンスは、Apache License 2.0とします -# (c) 2022 Kohya S. 
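[Editor's note, for reference only: the `interpolate_pos_embed` helper deleted just above resizes a pretrained ViT's patch position embeddings when the fine-tuning resolution differs from the pretraining one. A standalone sketch of that resizing step is shown below; the function name, grid sizes, and shapes are illustrative, while the keep-the-class-token / bicubic-interpolate-the-rest logic mirrors the removed code.]

```python
import torch
import torch.nn.functional as F

def resize_patch_pos_embed(pos_embed: torch.Tensor, new_grid: int, num_extra_tokens: int = 1) -> torch.Tensor:
    """pos_embed: (1, num_extra_tokens + old_grid**2, dim) -> (1, num_extra_tokens + new_grid**2, dim)."""
    dim = pos_embed.shape[-1]
    old_grid = int((pos_embed.shape[1] - num_extra_tokens) ** 0.5)
    extra = pos_embed[:, :num_extra_tokens]            # cls (and any dist) tokens are kept unchanged
    patches = pos_embed[:, num_extra_tokens:]
    patches = patches.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    patches = F.interpolate(patches, size=(new_grid, new_grid), mode="bicubic", align_corners=False)
    patches = patches.permute(0, 2, 3, 1).flatten(1, 2)
    return torch.cat((extra, patches), dim=1)

# 224px / patch 16 = 14x14 grid, resized for 384px / patch 16 = 24x24
print(resize_patch_pos_embed(torch.randn(1, 1 + 14 * 14, 768), new_grid=24).shape)  # (1, 577, 768)
```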
@kohya_ss - -import argparse -import glob -import os -import json -import re - -from tqdm import tqdm -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ') -PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ') -PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ') -PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ') - -# 複数人がいるとき、複数の髪色や目の色が定義されていれば削除する -PATTERNS_REMOVE_IN_MULTI = [ - PATTERN_HAIR_LENGTH, - PATTERN_HAIR_CUT, - re.compile(r', [\w\-]+ eyes, '), - re.compile(r', ([\w\-]+ sleeves|sleeveless), '), - # 複数の髪型定義がある場合は削除する - re.compile( - r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '), -] - - -def clean_tags(image_key, tags): - # replace '_' to ' ' - tags = tags.replace('^_^', '^@@@^') - tags = tags.replace('_', ' ') - tags = tags.replace('^@@@^', '^_^') - - # remove rating: deepdanbooruのみ - tokens = tags.split(", rating") - if len(tokens) == 1: - # WD14 taggerのときはこちらになるのでメッセージは出さない - # logger.info("no rating:") - # logger.info(f"{image_key} {tags}") - pass - else: - if len(tokens) > 2: - logger.info("multiple ratings:") - logger.info(f"{image_key} {tags}") - tags = tokens[0] - - tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策 - - # 複数の人物がいる場合は髪色等のタグを削除する - if 'girls' in tags or 'boys' in tags: - for pat in PATTERNS_REMOVE_IN_MULTI: - found = pat.findall(tags) - if len(found) > 1: # 二つ以上、タグがある - tags = pat.sub("", tags) - - # 髪の特殊対応 - srch_hair_len = PATTERN_HAIR_LENGTH.search(tags) # 髪の長さタグは例外なので避けておく(全員が同じ髪の長さの場合) - if srch_hair_len: - org = srch_hair_len.group() - tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags) - - found = PATTERN_HAIR.findall(tags) - if len(found) > 1: - tags = PATTERN_HAIR.sub("", tags) - - if srch_hair_len: - tags = tags.replace(", @@@, ", org) # 戻す - - # white shirtとshirtみたいな重複タグの削除 - found = PATTERN_WORD.findall(tags) - for word in found: - if re.search(f", ((\w+) )+{word}, ", tags): - tags = tags.replace(f", {word}, ", "") - - tags = tags.replace(", , ", ", ") - assert tags.startswith(", ") and tags.endswith(", ") - tags = tags[2:-2] - return tags - - -# 上から順に検索、置換される -# ('置換元文字列', '置換後文字列') -CAPTION_REPLACEMENTS = [ - ('anime anime', 'anime'), - ('young ', ''), - ('anime girl', 'girl'), - ('cartoon female', 'girl'), - ('cartoon lady', 'girl'), - ('cartoon character', 'girl'), # a or ~s - ('cartoon woman', 'girl'), - ('cartoon women', 'girls'), - ('cartoon girl', 'girl'), - ('anime female', 'girl'), - ('anime lady', 'girl'), - ('anime character', 'girl'), # a or ~s - ('anime woman', 'girl'), - ('anime women', 'girls'), - ('lady', 'girl'), - ('female', 'girl'), - ('woman', 'girl'), - ('women', 'girls'), - ('people', 'girls'), - ('person', 'girl'), - ('a cartoon figure', 'a figure'), - ('a cartoon image', 'an image'), - ('a cartoon picture', 'a picture'), - ('an anime cartoon image', 'an image'), - ('a cartoon anime drawing', 'a drawing'), - ('a cartoon drawing', 'a drawing'), - ('girl girl', 'girl'), -] - - -def clean_caption(caption): - for rf, rt in CAPTION_REPLACEMENTS: - replaced = True - while replaced: - bef = caption - caption = caption.replace(rf, rt) - replaced = bef != caption - return caption - - -def main(args): - if os.path.exists(args.in_json): - logger.info(f"loading existing metadata: {args.in_json}") - with open(args.in_json, "rt", encoding='utf-8') as f: - metadata = json.load(f) - else: - logger.error("no 
metadata / メタデータファイルがありません") - return - - logger.info("cleaning captions and tags.") - image_keys = list(metadata.keys()) - for image_key in tqdm(image_keys): - tags = metadata[image_key].get('tags') - if tags is None: - logger.error(f"image does not have tags / メタデータにタグがありません: {image_key}") - else: - org = tags - tags = clean_tags(image_key, tags) - metadata[image_key]['tags'] = tags - if args.debug and org != tags: - logger.info("FROM: " + org) - logger.info("TO: " + tags) - - caption = metadata[image_key].get('caption') - if caption is None: - logger.error(f"image does not have caption / メタデータにキャプションがありません: {image_key}") - else: - org = caption - caption = clean_caption(caption) - metadata[image_key]['caption'] = caption - if args.debug and org != caption: - logger.info("FROM: " + org) - logger.info("TO: " + caption) - - # metadataを書き出して終わり - logger.info(f"writing metadata: {args.out_json}") - with open(args.out_json, "wt", encoding='utf-8') as f: - json.dump(metadata, f, indent=2) - logger.info("done!") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - # parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") - parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") - parser.add_argument("--debug", action="store_true", help="debug mode") - - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args, unknown = parser.parse_known_args() - if len(unknown) == 1: - logger.warning("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.") - logger.warning("All captions and tags in the metadata are processed.") - logger.warning("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。") - logger.warning("メタデータ内のすべてのキャプションとタグが処理されます。") - args.in_json = args.out_json - args.out_json = unknown[0] - elif len(unknown) > 0: - raise ValueError(f"error: unrecognized arguments: {unknown}") - - main(args) diff --git a/finetune/hypernetwork_nai.py b/finetune/hypernetwork_nai.py deleted file mode 100644 index dcaaa714a..000000000 --- a/finetune/hypernetwork_nai.py +++ /dev/null @@ -1,96 +0,0 @@ -# NAI compatible - -import torch - - -class HypernetworkModule(torch.nn.Module): - def __init__(self, dim, multiplier=1.0): - super().__init__() - - linear1 = torch.nn.Linear(dim, dim * 2) - linear2 = torch.nn.Linear(dim * 2, dim) - linear1.weight.data.normal_(mean=0.0, std=0.01) - linear1.bias.data.zero_() - linear2.weight.data.normal_(mean=0.0, std=0.01) - linear2.bias.data.zero_() - linears = [linear1, linear2] - - self.linear = torch.nn.Sequential(*linears) - self.multiplier = multiplier - - def forward(self, x): - return x + self.linear(x) * self.multiplier - - -class Hypernetwork(torch.nn.Module): - enable_sizes = [320, 640, 768, 1280] - # return self.modules[Hypernetwork.enable_sizes.index(size)] - - def __init__(self, multiplier=1.0) -> None: - super().__init__() - self.modules = [] - for size in Hypernetwork.enable_sizes: - self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier))) - self.register_module(f"{size}_0", self.modules[-1][0]) - self.register_module(f"{size}_1", self.modules[-1][1]) - - def apply_to_stable_diffusion(self, text_encoder, vae, unet): - blocks = unet.input_blocks + 
[unet.middle_block] + unet.output_blocks - for block in blocks: - for subblk in block: - if 'SpatialTransformer' in str(type(subblk)): - for tf_block in subblk.transformer_blocks: - for attn in [tf_block.attn1, tf_block.attn2]: - size = attn.context_dim - if size in Hypernetwork.enable_sizes: - attn.hypernetwork = self - else: - attn.hypernetwork = None - - def apply_to_diffusers(self, text_encoder, vae, unet): - blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks - for block in blocks: - if hasattr(block, 'attentions'): - for subblk in block.attentions: - if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)): # 0.6.0 and 0.7~ - for tf_block in subblk.transformer_blocks: - for attn in [tf_block.attn1, tf_block.attn2]: - size = attn.to_k.in_features - if size in Hypernetwork.enable_sizes: - attn.hypernetwork = self - else: - attn.hypernetwork = None - return True # TODO error checking - - def forward(self, x, context): - size = context.shape[-1] - assert size in Hypernetwork.enable_sizes - module = self.modules[Hypernetwork.enable_sizes.index(size)] - return module[0].forward(context), module[1].forward(context) - - def load_from_state_dict(self, state_dict): - # old ver to new ver - changes = { - 'linear1.bias': 'linear.0.bias', - 'linear1.weight': 'linear.0.weight', - 'linear2.bias': 'linear.1.bias', - 'linear2.weight': 'linear.1.weight', - } - for key_from, key_to in changes.items(): - if key_from in state_dict: - state_dict[key_to] = state_dict[key_from] - del state_dict[key_from] - - for size, sd in state_dict.items(): - if type(size) == int: - self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True) - self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True) - return True - - def get_state_dict(self): - state_dict = {} - for i, size in enumerate(Hypernetwork.enable_sizes): - sd0 = self.modules[i][0].state_dict() - sd1 = self.modules[i][1].state_dict() - state_dict[size] = [sd0, sd1] - return state_dict diff --git a/finetune/make_captions.py b/finetune/make_captions.py deleted file mode 100644 index be781995e..000000000 --- a/finetune/make_captions.py +++ /dev/null @@ -1,210 +0,0 @@ -import argparse -import glob -import os -import json -import random -import sys - -from pathlib import Path -from PIL import Image -from tqdm import tqdm -import numpy as np - -import torch -from library.device_utils import init_ipex, get_preferred_device -init_ipex() - -from torchvision import transforms -from torchvision.transforms.functional import InterpolationMode -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) # sys.path.append(os.path.dirname(__file__)) -from blip.blip import blip_decoder, is_url -import library.train_util as train_util -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -DEVICE = get_preferred_device() - - -IMAGE_SIZE = 384 - -# 正方形でいいのか? 
という気がするがソースがそうなので -IMAGE_TRANSFORM = transforms.Compose( - [ - transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC), - transforms.ToTensor(), - transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), - ] -) - - -# 共通化したいが微妙に処理が異なる…… -class ImageLoadingTransformDataset(torch.utils.data.Dataset): - def __init__(self, image_paths): - self.images = image_paths - - def __len__(self): - return len(self.images) - - def __getitem__(self, idx): - img_path = self.images[idx] - - try: - image = Image.open(img_path).convert("RGB") - # convert to tensor temporarily so dataloader will accept it - tensor = IMAGE_TRANSFORM(image) - except Exception as e: - logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") - return None - - return (tensor, img_path) - - -def collate_fn_remove_corrupted(batch): - """Collate function that allows to remove corrupted examples in the - dataloader. It expects that the dataloader returns 'None' when that occurs. - The 'None's in the batch are removed. - """ - # Filter out all the Nones (corrupted examples) - batch = list(filter(lambda x: x is not None, batch)) - return batch - - -def main(args): - # fix the seed for reproducibility - seed = args.seed # + utils.get_rank() - torch.manual_seed(seed) - np.random.seed(seed) - random.seed(seed) - - if not os.path.exists("blip"): - args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path - - cwd = os.getcwd() - logger.info(f"Current Working Directory is: {cwd}") - os.chdir("finetune") - if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights): - args.caption_weights = os.path.join("..", args.caption_weights) - - logger.info(f"load images from {args.train_data_dir}") - train_data_dir_path = Path(args.train_data_dir) - image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - logger.info(f"found {len(image_paths)} images.") - - logger.info(f"loading BLIP caption: {args.caption_weights}") - model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json") - model.eval() - model = model.to(DEVICE) - logger.info("BLIP loaded") - - # captioningする - def run_batch(path_imgs): - imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE) - - with torch.no_grad(): - if args.beam_search: - captions = model.generate( - imgs, sample=False, num_beams=args.num_beams, max_length=args.max_length, min_length=args.min_length - ) - else: - captions = model.generate( - imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length - ) - - for (image_path, _), caption in zip(path_imgs, captions): - with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: - f.write(caption + "\n") - if args.debug: - logger.info(f'{image_path} {caption}') - - # 読み込みの高速化のためにDataLoaderを使うオプション - if args.max_data_loader_n_workers is not None: - dataset = ImageLoadingTransformDataset(image_paths) - data = torch.utils.data.DataLoader( - dataset, - batch_size=args.batch_size, - shuffle=False, - num_workers=args.max_data_loader_n_workers, - collate_fn=collate_fn_remove_corrupted, - drop_last=False, - ) - else: - data = [[(None, ip)] for ip in image_paths] - - b_imgs = [] - for data_entry in tqdm(data, smoothing=0.0): - for data in data_entry: - if data is None: - continue - - img_tensor, image_path = data - if img_tensor is None: - try: - raw_image = Image.open(image_path) - if 
raw_image.mode != "RGB": - raw_image = raw_image.convert("RGB") - img_tensor = IMAGE_TRANSFORM(raw_image) - except Exception as e: - logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") - continue - - b_imgs.append((image_path, img_tensor)) - if len(b_imgs) >= args.batch_size: - run_batch(b_imgs) - b_imgs.clear() - if len(b_imgs) > 0: - run_batch(b_imgs) - - logger.info("done!") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument( - "--caption_weights", - type=str, - default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth", - help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)", - ) - parser.add_argument( - "--caption_extention", - type=str, - default=None, - help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)", - ) - parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") - parser.add_argument( - "--beam_search", - action="store_true", - help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)", - ) - parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") - parser.add_argument( - "--max_data_loader_n_workers", - type=int, - default=None, - help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", - ) - parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)") - parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p") - parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長") - parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長") - parser.add_argument("--seed", default=42, type=int, help="seed for reproducibility / 再現性を確保するための乱数seed") - parser.add_argument("--debug", action="store_true", help="debug mode") - parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - - # スペルミスしていたオプションを復元する - if args.caption_extention is not None: - args.caption_extension = args.caption_extention - - main(args) diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py deleted file mode 100644 index edeebadf3..000000000 --- a/finetune/make_captions_by_git.py +++ /dev/null @@ -1,183 +0,0 @@ -import argparse -import os -import re - -from pathlib import Path -from PIL import Image -from tqdm import tqdm - -import torch -from library.device_utils import init_ipex, get_preferred_device -init_ipex() - -from transformers import AutoProcessor, AutoModelForCausalLM -from transformers.generation.utils import GenerationMixin - -import library.train_util as train_util -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - -PATTERN_REPLACE = [ - re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( 
?(is )?(on|in) (the |her |their |him )?\w+)?'), - re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'), - re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"), - re.compile(r"with the number \d+ on (it|\w+ \w+)"), - re.compile(r'with the words "'), - re.compile(r"word \w+ on it"), - re.compile(r"that says the word \w+ on it"), - re.compile("that says'the word \"( on it)?"), -] - -# 誤検知しまくりの with the word xxxx を消す - - -def remove_words(captions, debug): - removed_caps = [] - for caption in captions: - cap = caption - for pat in PATTERN_REPLACE: - cap = pat.sub("", cap) - if debug and cap != caption: - logger.info(caption) - logger.info(cap) - removed_caps.append(cap) - return removed_caps - - -def collate_fn_remove_corrupted(batch): - """Collate function that allows to remove corrupted examples in the - dataloader. It expects that the dataloader returns 'None' when that occurs. - The 'None's in the batch are removed. - """ - # Filter out all the Nones (corrupted examples) - batch = list(filter(lambda x: x is not None, batch)) - return batch - - -def main(args): - r""" - transformers 4.30.2で、バッチサイズ>1でも動くようになったので、以下コメントアウト - - # GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用 - org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation - curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように - - # input_idsがバッチサイズと同じ件数である必要がある:バッチサイズはこの関数から参照できないので外から渡す - # ここより上で置き換えようとするとすごく大変 - def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs): - input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs) - if input_ids.size()[0] != curr_batch_size[0]: - input_ids = input_ids.repeat(curr_batch_size[0], 1) - return input_ids - - GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch - """ - - logger.info(f"load images from {args.train_data_dir}") - train_data_dir_path = Path(args.train_data_dir) - image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - logger.info(f"found {len(image_paths)} images.") - - # できればcacheに依存せず明示的にダウンロードしたい - logger.info(f"loading GIT: {args.model_id}") - git_processor = AutoProcessor.from_pretrained(args.model_id) - git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE) - logger.info("GIT loaded") - - # captioningする - def run_batch(path_imgs): - imgs = [im for _, im in path_imgs] - - # curr_batch_size[0] = len(path_imgs) - inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式 - generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length) - captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True) - - if args.remove_words: - captions = remove_words(captions, args.debug) - - for (image_path, _), caption in zip(path_imgs, captions): - with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: - f.write(caption + "\n") - if args.debug: - logger.info(f"{image_path} {caption}") - - # 読み込みの高速化のためにDataLoaderを使うオプション - if args.max_data_loader_n_workers is not None: - dataset = train_util.ImageLoadingDataset(image_paths) - data = torch.utils.data.DataLoader( - dataset, - batch_size=args.batch_size, - shuffle=False, - num_workers=args.max_data_loader_n_workers, - collate_fn=collate_fn_remove_corrupted, - drop_last=False, - ) - else: - data = [[(None, ip)] for ip in image_paths] - - b_imgs = [] - for data_entry in tqdm(data, smoothing=0.0): - 
for data in data_entry: - if data is None: - continue - - image, image_path = data - if image is None: - try: - image = Image.open(image_path) - if image.mode != "RGB": - image = image.convert("RGB") - except Exception as e: - logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") - continue - - b_imgs.append((image_path, image)) - if len(b_imgs) >= args.batch_size: - run_batch(b_imgs) - b_imgs.clear() - - if len(b_imgs) > 0: - run_batch(b_imgs) - - logger.info("done!") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") - parser.add_argument( - "--model_id", - type=str, - default="microsoft/git-large-textcaps", - help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID", - ) - parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") - parser.add_argument( - "--max_data_loader_n_workers", - type=int, - default=None, - help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", - ) - parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長") - parser.add_argument( - "--remove_words", - action="store_true", - help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する", - ) - parser.add_argument("--debug", action="store_true", help="debug mode") - parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - main(args) diff --git a/finetune/merge_captions_to_metadata.py b/finetune/merge_captions_to_metadata.py deleted file mode 100644 index 60765b863..000000000 --- a/finetune/merge_captions_to_metadata.py +++ /dev/null @@ -1,80 +0,0 @@ -import argparse -import json -from pathlib import Path -from typing import List -from tqdm import tqdm -import library.train_util as train_util -import os -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -def main(args): - assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" - - train_data_dir_path = Path(args.train_data_dir) - image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - logger.info(f"found {len(image_paths)} images.") - - if args.in_json is None and Path(args.out_json).is_file(): - args.in_json = args.out_json - - if args.in_json is not None: - logger.info(f"loading existing metadata: {args.in_json}") - metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) - logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") - else: - logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") - metadata = {} - - logger.info("merge caption texts to metadata json.") - for image_path in tqdm(image_paths): - caption_path = image_path.with_suffix(args.caption_extension) - caption = caption_path.read_text(encoding='utf-8').strip() - - if not os.path.exists(caption_path): - caption_path = os.path.join(image_path, args.caption_extension) - - image_key = str(image_path) 
if args.full_path else image_path.stem - if image_key not in metadata: - metadata[image_key] = {} - - metadata[image_key]['caption'] = caption - if args.debug: - logger.info(f"{image_key} {caption}") - - # metadataを書き出して終わり - logger.info(f"writing metadata: {args.out_json}") - Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') - logger.info("done!") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") - parser.add_argument("--in_json", type=str, - help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") - parser.add_argument("--caption_extention", type=str, default=None, - help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)") - parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子") - parser.add_argument("--full_path", action="store_true", - help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") - parser.add_argument("--recursive", action="store_true", - help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") - parser.add_argument("--debug", action="store_true", help="debug mode") - - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args = parser.parse_args() - - # スペルミスしていたオプションを復元する - if args.caption_extention is not None: - args.caption_extension = args.caption_extention - - main(args) diff --git a/finetune/merge_dd_tags_to_metadata.py b/finetune/merge_dd_tags_to_metadata.py deleted file mode 100644 index 9ef8f14b0..000000000 --- a/finetune/merge_dd_tags_to_metadata.py +++ /dev/null @@ -1,75 +0,0 @@ -import argparse -import json -from pathlib import Path -from typing import List -from tqdm import tqdm -import library.train_util as train_util -import os -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -def main(args): - assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" - - train_data_dir_path = Path(args.train_data_dir) - image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - logger.info(f"found {len(image_paths)} images.") - - if args.in_json is None and Path(args.out_json).is_file(): - args.in_json = args.out_json - - if args.in_json is not None: - logger.info(f"loading existing metadata: {args.in_json}") - metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) - logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") - else: - logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") - metadata = {} - - logger.info("merge tags to metadata json.") - for image_path in tqdm(image_paths): - tags_path = image_path.with_suffix(args.caption_extension) - tags = tags_path.read_text(encoding='utf-8').strip() - - if not os.path.exists(tags_path): - tags_path = os.path.join(image_path, args.caption_extension) - - image_key = str(image_path) if args.full_path else image_path.stem - if image_key not in metadata: - metadata[image_key] = {} - - 
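[Editor's note, for reference only: the two removed merge scripts (captions above, tags below) follow the same pattern, reading each image's sidecar text file and folding it into a single metadata JSON keyed by image stem or full path. A condensed sketch of that pattern follows; the directory name, the PNG-only glob, and the function name are illustrative, whereas the real scripts use train_util.glob_images_pathlib and optional --in_json / --recursive handling.]

```python
import json
from pathlib import Path

def merge_sidecars(train_data_dir: str, out_json: str, extension: str, field: str, full_path: bool = False) -> None:
    out = Path(out_json)
    # like the removed scripts, reuse an existing out_json as the starting metadata
    metadata = json.loads(out.read_text(encoding="utf-8")) if out.is_file() else {}
    for image_path in sorted(Path(train_data_dir).rglob("*.png")):
        sidecar = image_path.with_suffix(extension)
        if not sidecar.is_file():
            continue
        key = str(image_path) if full_path else image_path.stem
        metadata.setdefault(key, {})[field] = sidecar.read_text(encoding="utf-8").strip()
    out.write_text(json.dumps(metadata, indent=2), encoding="utf-8")

# merge_sidecars("train_images", "meta.json", ".caption", "caption")
# merge_sidecars("train_images", "meta.json", ".txt", "tags", full_path=True)
```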
metadata[image_key]['tags'] = tags - if args.debug: - logger.info(f"{image_key} {tags}") - - # metadataを書き出して終わり - logger.info(f"writing metadata: {args.out_json}") - Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') - - logger.info("done!") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") - parser.add_argument("--in_json", type=str, - help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") - parser.add_argument("--full_path", action="store_true", - help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") - parser.add_argument("--recursive", action="store_true", - help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") - parser.add_argument("--caption_extension", type=str, default=".txt", - help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子") - parser.add_argument("--debug", action="store_true", help="debug mode, print tags") - - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args = parser.parse_args() - main(args) diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py deleted file mode 100644 index 0389da388..000000000 --- a/finetune/prepare_buckets_latents.py +++ /dev/null @@ -1,265 +0,0 @@ -import argparse -import os -import json - -from pathlib import Path -from typing import List -from tqdm import tqdm -import numpy as np -from PIL import Image -import cv2 - -import torch -from library.device_utils import init_ipex, get_preferred_device -init_ipex() - -from torchvision import transforms - -import library.model_util as model_util -import library.train_util as train_util -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -DEVICE = get_preferred_device() - -IMAGE_TRANSFORMS = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] -) - - -def collate_fn_remove_corrupted(batch): - """Collate function that allows to remove corrupted examples in the - dataloader. It expects that the dataloader returns 'None' when that occurs. - The 'None's in the batch are removed. - """ - # Filter out all the Nones (corrupted examples) - batch = list(filter(lambda x: x is not None, batch)) - return batch - - -def get_npz_filename(data_dir, image_key, is_full_path, recursive): - if is_full_path: - base_name = os.path.splitext(os.path.basename(image_key))[0] - relative_path = os.path.relpath(os.path.dirname(image_key), data_dir) - else: - base_name = image_key - relative_path = "" - - if recursive and relative_path: - return os.path.join(data_dir, relative_path, base_name) + ".npz" - else: - return os.path.join(data_dir, base_name) + ".npz" - - -def main(args): - # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります" - if args.bucket_reso_steps % 8 > 0: - logger.warning(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") - if args.bucket_reso_steps % 32 > 0: - logger.warning( - f"WARNING: bucket_reso_steps is not divisible by 32. 
It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません" - ) - - train_data_dir_path = Path(args.train_data_dir) - image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)] - logger.info(f"found {len(image_paths)} images.") - - if os.path.exists(args.in_json): - logger.info(f"loading existing metadata: {args.in_json}") - with open(args.in_json, "rt", encoding="utf-8") as f: - metadata = json.load(f) - else: - logger.error(f"no metadata / メタデータファイルがありません: {args.in_json}") - return - - weight_dtype = torch.float32 - if args.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - vae = model_util.load_vae(args.model_name_or_path, weight_dtype) - vae.eval() - vae.to(DEVICE, dtype=weight_dtype) - - # bucketのサイズを計算する - max_reso = tuple([int(t) for t in args.max_resolution.split(",")]) - assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" - - bucket_manager = train_util.BucketManager( - args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps - ) - if not args.bucket_no_upscale: - bucket_manager.make_buckets() - else: - logger.warning( - "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます" - ) - - # 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する - img_ar_errors = [] - - def process_batch(is_last): - for bucket in bucket_manager.buckets: - if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size: - train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False) - bucket.clear() - - # 読み込みの高速化のためにDataLoaderを使うオプション - if args.max_data_loader_n_workers is not None: - dataset = train_util.ImageLoadingDataset(image_paths) - data = torch.utils.data.DataLoader( - dataset, - batch_size=1, - shuffle=False, - num_workers=args.max_data_loader_n_workers, - collate_fn=collate_fn_remove_corrupted, - drop_last=False, - ) - else: - data = [[(None, ip)] for ip in image_paths] - - bucket_counts = {} - for data_entry in tqdm(data, smoothing=0.0): - if data_entry[0] is None: - continue - - img_tensor, image_path = data_entry[0] - if img_tensor is not None: - image = transforms.functional.to_pil_image(img_tensor) - else: - try: - image = Image.open(image_path) - if image.mode != "RGB": - image = image.convert("RGB") - except Exception as e: - logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") - continue - - image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] - if image_key not in metadata: - metadata[image_key] = {} - - # 本当はこのあとの部分もDataSetに持っていけば高速化できるがいろいろ大変 - - reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height) - img_ar_errors.append(abs(ar_error)) - bucket_counts[reso] = bucket_counts.get(reso, 0) + 1 - - # メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て - metadata[image_key]["train_resolution"] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8) - - if not args.bucket_no_upscale: - # upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する - assert ( - resized_size[0] == reso[0] or resized_size[1] == reso[1] - ), f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}" - assert ( - resized_size[0] >= reso[0] and resized_size[1] >= 
reso[1] - ), f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}" - - assert ( - resized_size[0] >= reso[0] and resized_size[1] >= reso[1] - ), f"internal error resized size is small: {resized_size}, {reso}" - - # 既に存在するファイルがあればshape等を確認して同じならskipする - npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive) - if args.skip_existing: - if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug): - continue - - # バッチへ追加 - image_info = train_util.ImageInfo(image_key, 1, "", False, image_path) - image_info.latents_npz = npz_file_name - image_info.bucket_reso = reso - image_info.resized_size = resized_size - image_info.image = image - bucket_manager.add_image(reso, image_info) - - # バッチを推論するか判定して推論する - process_batch(False) - - # 残りを処理する - process_batch(True) - - bucket_manager.sort() - for i, reso in enumerate(bucket_manager.resos): - count = bucket_counts.get(reso, 0) - if count > 0: - logger.info(f"bucket {i} {reso}: {count}") - img_ar_errors = np.array(img_ar_errors) - logger.info(f"mean ar error: {np.mean(img_ar_errors)}") - - # metadataを書き出して終わり - logger.info(f"writing metadata: {args.out_json}") - with open(args.out_json, "wt", encoding="utf-8") as f: - json.dump(metadata, f, indent=2) - logger.info("done!") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") - parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") - parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル") - parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)") - parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") - parser.add_argument( - "--max_data_loader_n_workers", - type=int, - default=None, - help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", - ) - parser.add_argument( - "--max_resolution", - type=str, - default="512,512", - help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)", - ) - parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") - parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") - parser.add_argument( - "--bucket_reso_steps", - type=int, - default=64, - help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します", - ) - parser.add_argument( - "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" - ) - parser.add_argument( - "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度" - ) - parser.add_argument( - "--full_path", - action="store_true", - help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)", - ) - parser.add_argument( - "--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する" - ) - parser.add_argument( - 
"--skip_existing", - action="store_true", - help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", - ) - parser.add_argument( - "--recursive", - action="store_true", - help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す", - ) - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - main(args) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py deleted file mode 100644 index b56d921a3..000000000 --- a/finetune/tag_images_by_wd14_tagger.py +++ /dev/null @@ -1,386 +0,0 @@ -import argparse -import csv -import os -from pathlib import Path - -import cv2 -import numpy as np -import torch -from huggingface_hub import hf_hub_download -from PIL import Image -from tqdm import tqdm - -import library.train_util as train_util -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -# from wd14 tagger -IMAGE_SIZE = 448 - -# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2 -DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" -FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"] -FILES_ONNX = ["model.onnx"] -SUB_DIR = "variables" -SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"] -CSV_FILE = FILES[-1] - - -def preprocess_image(image): - image = np.array(image) - image = image[:, :, ::-1] # RGB->BGR - - # pad to square - size = max(image.shape[0:2]) - pad_x = size - image.shape[1] - pad_y = size - image.shape[0] - pad_l = pad_x // 2 - pad_t = pad_y // 2 - image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255) - - interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 - image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) - - image = image.astype(np.float32) - return image - - -class ImageLoadingPrepDataset(torch.utils.data.Dataset): - def __init__(self, image_paths): - self.images = image_paths - - def __len__(self): - return len(self.images) - - def __getitem__(self, idx): - img_path = str(self.images[idx]) - - try: - image = Image.open(img_path).convert("RGB") - image = preprocess_image(image) - tensor = torch.tensor(image) - except Exception as e: - logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") - return None - - return (tensor, img_path) - - -def collate_fn_remove_corrupted(batch): - """Collate function that allows to remove corrupted examples in the - dataloader. It expects that the dataloader returns 'None' when that occurs. - The 'None's in the batch are removed. - """ - # Filter out all the Nones (corrupted examples) - batch = list(filter(lambda x: x is not None, batch)) - return batch - - -def main(args): - # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする - # depreacatedの警告が出るけどなくなったらその時 - # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 - if not os.path.exists(args.model_dir) or args.force_download: - logger.info(f"downloading wd14 tagger model from hf_hub. 
id: {args.repo_id}") - files = FILES - if args.onnx: - files += FILES_ONNX - for file in files: - hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) - for file in SUB_DIR_FILES: - hf_hub_download( - args.repo_id, - file, - subfolder=SUB_DIR, - cache_dir=os.path.join(args.model_dir, SUB_DIR), - force_download=True, - force_filename=file, - ) - else: - logger.info("using existing wd14 tagger model") - - # 画像を読み込む - if args.onnx: - import onnx - import onnxruntime as ort - - onnx_path = f"{args.model_dir}/model.onnx" - logger.info("Running wd14 tagger with onnx") - logger.info(f"loading onnx model: {onnx_path}") - - if not os.path.exists(onnx_path): - raise Exception( - f"onnx model not found: {onnx_path}, please redownload the model with --force_download" - + " / onnxモデルが見つかりませんでした。--force_downloadで再ダウンロードしてください" - ) - - model = onnx.load(onnx_path) - input_name = model.graph.input[0].name - try: - batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value - except: - batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param - - if args.batch_size != batch_size and type(batch_size) != str: - # some rebatch model may use 'N' as dynamic axes - logger.warning( - f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}" - ) - args.batch_size = batch_size - - del model - - ort_sess = ort.InferenceSession( - onnx_path, - providers=["CUDAExecutionProvider"] - if "CUDAExecutionProvider" in ort.get_available_providers() - else ["CPUExecutionProvider"], - ) - else: - from tensorflow.keras.models import load_model - - model = load_model(f"{args.model_dir}") - - # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") - # 依存ライブラリを増やしたくないので自力で読むよ - - with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f: - reader = csv.reader(f) - l = [row for row in reader] - header = l[0] # tag_id,name,category,count - rows = l[1:] - assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}" - - general_tags = [row[1] for row in rows[1:] if row[2] == "0"] - character_tags = [row[1] for row in rows[1:] if row[2] == "4"] - - # 画像を読み込む - - train_data_dir_path = Path(args.train_data_dir) - image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) - logger.info(f"found {len(image_paths)} images.") - - tag_freq = {} - - caption_separator = args.caption_separator - stripped_caption_separator = caption_separator.strip() - undesired_tags = set(args.undesired_tags.split(stripped_caption_separator)) - - def run_batch(path_imgs): - imgs = np.array([im for _, im in path_imgs]) - - if args.onnx: - if len(imgs) < args.batch_size: - imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0) - probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy - probs = probs[: len(path_imgs)] - else: - probs = model(imgs, training=False) - probs = probs.numpy() - - for (image_path, _), prob in zip(path_imgs, probs): - # 最初の4つはratingなので無視する - # # First 4 labels are actually ratings: pick one with argmax - # ratings_names = label_names[:4] - # rating_index = ratings_names["probs"].argmax() - # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]] - - # それ以降はタグなのでconfidenceがthresholdより高いものを追加する - # Everything else is tags: pick any where prediction confidence > threshold - combined_tags = [] - 
general_tag_text = "" - character_tag_text = "" - for i, p in enumerate(prob[4:]): - if i < len(general_tags) and p >= args.general_threshold: - tag_name = general_tags[i] - if args.remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^ - tag_name = tag_name.replace("_", " ") - - if tag_name not in undesired_tags: - tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 - general_tag_text += caption_separator + tag_name - combined_tags.append(tag_name) - elif i >= len(general_tags) and p >= args.character_threshold: - tag_name = character_tags[i - len(general_tags)] - if args.remove_underscore and len(tag_name) > 3: - tag_name = tag_name.replace("_", " ") - - if tag_name not in undesired_tags: - tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 - character_tag_text += caption_separator + tag_name - combined_tags.append(tag_name) - - # 先頭のカンマを取る - if len(general_tag_text) > 0: - general_tag_text = general_tag_text[len(caption_separator) :] - if len(character_tag_text) > 0: - character_tag_text = character_tag_text[len(caption_separator) :] - - caption_file = os.path.splitext(image_path)[0] + args.caption_extension - - tag_text = caption_separator.join(combined_tags) - - if args.append_tags: - # Check if file exists - if os.path.exists(caption_file): - with open(caption_file, "rt", encoding="utf-8") as f: - # Read file and remove new lines - existing_content = f.read().strip("\n") # Remove newlines - - # Split the content into tags and store them in a list - existing_tags = [tag.strip() for tag in existing_content.split(stripped_caption_separator) if tag.strip()] - - # Check and remove repeating tags in tag_text - new_tags = [tag for tag in combined_tags if tag not in existing_tags] - - # Create new tag_text - tag_text = caption_separator.join(existing_tags + new_tags) - - with open(caption_file, "wt", encoding="utf-8") as f: - f.write(tag_text + "\n") - if args.debug: - logger.info("") - logger.info(f"{image_path}:") - logger.info(f"\tCharacter tags: {character_tag_text}") - logger.info(f"\tGeneral tags: {general_tag_text}") - - # 読み込みの高速化のためにDataLoaderを使うオプション - if args.max_data_loader_n_workers is not None: - dataset = ImageLoadingPrepDataset(image_paths) - data = torch.utils.data.DataLoader( - dataset, - batch_size=args.batch_size, - shuffle=False, - num_workers=args.max_data_loader_n_workers, - collate_fn=collate_fn_remove_corrupted, - drop_last=False, - ) - else: - data = [[(None, ip)] for ip in image_paths] - - b_imgs = [] - for data_entry in tqdm(data, smoothing=0.0): - for data in data_entry: - if data is None: - continue - - image, image_path = data - if image is not None: - image = image.detach().numpy() - else: - try: - image = Image.open(image_path) - if image.mode != "RGB": - image = image.convert("RGB") - image = preprocess_image(image) - except Exception as e: - logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") - continue - b_imgs.append((image_path, image)) - - if len(b_imgs) >= args.batch_size: - b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string - run_batch(b_imgs) - b_imgs.clear() - - if len(b_imgs) > 0: - b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string - run_batch(b_imgs) - - if args.frequency_tags: - sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True) - print("Tag frequencies:") - for tag, freq in sorted_tags: - print(f"{tag}: {freq}") - - logger.info("done!") - - -def setup_parser() -> 
argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument( - "--repo_id", - type=str, - default=DEFAULT_WD14_TAGGER_REPO, - help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID", - ) - parser.add_argument( - "--model_dir", - type=str, - default="wd14_tagger_model", - help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ", - ) - parser.add_argument( - "--force_download", action="store_true", help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします" - ) - parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") - parser.add_argument( - "--max_data_loader_n_workers", - type=int, - default=None, - help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", - ) - parser.add_argument( - "--caption_extention", - type=str, - default=None, - help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)", - ) - parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") - parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値") - parser.add_argument( - "--general_threshold", - type=float, - default=None, - help="threshold of confidence to add a tag for general category, same as --thresh if omitted / generalカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ", - ) - parser.add_argument( - "--character_threshold", - type=float, - default=None, - help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ", - ) - parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") - parser.add_argument( - "--remove_underscore", - action="store_true", - help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える", - ) - parser.add_argument("--debug", action="store_true", help="debug mode") - parser.add_argument( - "--undesired_tags", - type=str, - default="", - help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト", - ) - parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する") - parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する") - parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する") - parser.add_argument( - "--caption_separator", - type=str, - default=", ", - help="Separator for captions, include space if needed / キャプションの区切り文字、必要ならスペースを含めてください", - ) - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - - # スペルミスしていたオプションを復元する - if args.caption_extention is not None: - args.caption_extension = args.caption_extention - - if args.general_threshold is None: - args.general_threshold = args.thresh - if args.character_threshold is None: - args.character_threshold = args.thresh - - main(args) diff --git a/gen_img.py b/gen_img.py deleted file mode 100644 index daf88d2a1..000000000 --- a/gen_img.py +++ /dev/null @@ -1,3326 +0,0 @@ -import itertools -import json -from typing import Any, List, NamedTuple, 
Optional, Tuple, Union, Callable -import glob -import importlib -import importlib.util -import sys -import inspect -import time -import zipfile -from diffusers.utils import deprecate -from diffusers.configuration_utils import FrozenDict -import argparse -import math -import os -import random -import re - -import diffusers -import numpy as np -import torch - -from library.device_utils import init_ipex, clean_memory, get_preferred_device - -init_ipex() - -import torchvision -from diffusers import ( - AutoencoderKL, - DDPMScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - DPMSolverSinglestepScheduler, - LMSDiscreteScheduler, - PNDMScheduler, - DDIMScheduler, - EulerDiscreteScheduler, - HeunDiscreteScheduler, - KDPM2DiscreteScheduler, - KDPM2AncestralDiscreteScheduler, - # UNet2DConditionModel, - StableDiffusionPipeline, -) -from einops import rearrange -from tqdm import tqdm -from torchvision import transforms -from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor -import PIL -from PIL import Image -from PIL.PngImagePlugin import PngInfo - -import library.model_util as model_util -import library.train_util as train_util -import library.sdxl_model_util as sdxl_model_util -import library.sdxl_train_util as sdxl_train_util -from networks.lora import LoRANetwork -import tools.original_control_net as original_control_net -from tools.original_control_net import ControlNetInfo -from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel -from library.sdxl_original_unet import InferSdxlUNet2DConditionModel -from library.original_unet import FlashAttentionFunction -from networks.control_net_lllite import ControlNetLLLite -from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL - -# scheduler: -SCHEDULER_LINEAR_START = 0.00085 -SCHEDULER_LINEAR_END = 0.0120 -SCHEDULER_TIMESTEPS = 1000 -SCHEDLER_SCHEDULE = "scaled_linear" - -# その他の設定 -LATENT_CHANNELS = 4 -DOWNSAMPLING_FACTOR = 8 - -CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" - -# region モジュール入れ替え部 -""" -高速化のためのモジュール入れ替え -""" - - -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): - if mem_eff_attn: - print("Enable memory efficient attention for U-Net") - - # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い - unet.set_use_memory_efficient_attention(False, True) - elif xformers: - print("Enable xformers for U-Net") - try: - import xformers.ops - except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") - - unet.set_use_memory_efficient_attention(True, False) - elif sdpa: - print("Enable SDPA for U-Net") - unet.set_use_memory_efficient_attention(False, False) - unet.set_use_sdpa(True) - - -# TODO common train_util.py -def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa): - if mem_eff_attn: - replace_vae_attn_to_memory_efficient() - elif xformers: - # replace_vae_attn_to_xformers() # 解像度によってxformersがエラーを出す? 
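# (The custom replacement above is kept commented out because xformers
# reportedly errors at some resolutions with that path; the call below relies
# on diffusers' own xformers attention hook for the VAE instead.)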
- vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う - elif sdpa: - replace_vae_attn_to_sdpa() - - -def replace_vae_attn_to_memory_efficient(): - print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") - flash_func = FlashAttentionFunction - - def forward_flash_attn(self, hidden_states, **kwargs): - q_bucket_size = 512 - k_bucket_size = 1024 - - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.to_q(hidden_states) - key_proj = self.to_k(hidden_states) - value_proj = self.to_v(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) - ) - - out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) - - out = rearrange(out, "b h n d -> b n (h d)") - - # compute next hidden_states - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - def forward_flash_attn_0_14(self, hidden_states, **kwargs): - if not hasattr(self, "to_q"): - self.to_q = self.query - self.to_k = self.key - self.to_v = self.value - self.to_out = [self.proj_attn, torch.nn.Identity()] - self.heads = self.num_heads - return forward_flash_attn(self, hidden_states, **kwargs) - - if diffusers.__version__ < "0.15.0": - diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14 - else: - diffusers.models.attention_processor.Attention.forward = forward_flash_attn - - -def replace_vae_attn_to_xformers(): - print("VAE: Attention.forward has been replaced to xformers") - import xformers.ops - - def forward_xformers(self, hidden_states, **kwargs): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.to_q(hidden_states) - key_proj = self.to_k(hidden_states) - value_proj = self.to_v(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) - ) - - query_proj = query_proj.contiguous() - key_proj = key_proj.contiguous() - value_proj = value_proj.contiguous() - out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) - - out = rearrange(out, "b h n d -> b n (h d)") - - # compute next hidden_states - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - def forward_xformers_0_14(self, hidden_states, **kwargs): - if not hasattr(self, "to_q"): - self.to_q = self.query - self.to_k = self.key - self.to_v = self.value - self.to_out = [self.proj_attn, torch.nn.Identity()] - self.heads = self.num_heads - return 
forward_xformers(self, hidden_states, **kwargs) - - if diffusers.__version__ < "0.15.0": - diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14 - else: - diffusers.models.attention_processor.Attention.forward = forward_xformers - - -def replace_vae_attn_to_sdpa(): - print("VAE: Attention.forward has been replaced to sdpa") - - def forward_sdpa(self, hidden_states, **kwargs): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.to_q(hidden_states) - key_proj = self.to_k(hidden_states) - value_proj = self.to_v(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj) - ) - - out = torch.nn.functional.scaled_dot_product_attention( - query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - out = rearrange(out, "b n h d -> b n (h d)") - - # compute next hidden_states - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - def forward_sdpa_0_14(self, hidden_states, **kwargs): - if not hasattr(self, "to_q"): - self.to_q = self.query - self.to_k = self.key - self.to_v = self.value - self.to_out = [self.proj_attn, torch.nn.Identity()] - self.heads = self.num_heads - return forward_sdpa(self, hidden_states, **kwargs) - - if diffusers.__version__ < "0.15.0": - diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14 - else: - diffusers.models.attention_processor.Attention.forward = forward_sdpa - - -# endregion - -# region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正 -# https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py -# Pipelineだけ独立して使えないのと機能追加するのとでコピーして修正 - - -class PipelineLike: - def __init__( - self, - is_sdxl, - device, - vae: AutoencoderKL, - text_encoders: List[CLIPTextModel], - tokenizers: List[CLIPTokenizer], - unet: InferSdxlUNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - clip_skip: int, - ): - super().__init__() - self.is_sdxl = is_sdxl - self.device = device - self.clip_skip = clip_skip - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - self.vae = vae - self.text_encoders = text_encoders - self.tokenizers = tokenizers - self.unet: Union[InferUNet2DConditionModel, InferSdxlUNet2DConditionModel] = unet - self.scheduler = scheduler - self.safety_checker = None - - self.clip_vision_model: CLIPVisionModelWithProjection = None - self.clip_vision_processor: CLIPImageProcessor = None - self.clip_vision_strength = 0.0 - - # Textual Inversion - self.token_replacements_list = [] - for _ in range(len(self.text_encoders)): - self.token_replacements_list.append({}) - - # ControlNet - self.control_nets: List[ControlNetInfo] = [] # only for SD 1.5 - self.control_net_lllites: List[ControlNetLLLite] = [] - self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない - - self.gradual_latent: GradualLatent = None - - # Textual Inversion - def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids): - self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids - - def set_enable_control_net(self, en: bool): - self.control_net_enabled = en - - def get_token_replacer(self, tokenizer): - tokenizer_index = self.tokenizers.index(tokenizer) - token_replacements = self.token_replacements_list[tokenizer_index] - - def replace_tokens(tokens): - # print("replace_tokens", tokens, "=>", token_replacements) - if isinstance(tokens, torch.Tensor): - tokens = tokens.tolist() - - new_tokens = [] - for token in tokens: - if token in token_replacements: - replacement = token_replacements[token] - new_tokens.extend(replacement) - else: - new_tokens.append(token) - return new_tokens - - return replace_tokens - - def set_control_nets(self, ctrl_nets): - self.control_nets = ctrl_nets - - def set_control_net_lllites(self, ctrl_net_lllites): - self.control_net_lllites = ctrl_net_lllites - - def set_gradual_latent(self, gradual_latent): - if gradual_latent is None: - print("gradual_latent is disabled") - self.gradual_latent = None - else: - print(f"gradual_latent is enabled: {gradual_latent}") - self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = 
None, - mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, - height: int = 1024, - width: int = 1024, - original_height: int = None, - original_width: int = None, - original_height_negative: int = None, - original_width_negative: int = None, - crop_top: int = 0, - crop_left: int = 0, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_scale: float = None, - strength: float = 0.8, - # num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - vae_batch_size: float = None, - return_latents: bool = False, - # return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: Optional[int] = 1, - img2img_noise=None, - clip_guide_images=None, - emb_normalize_mode: str = "original", - **kwargs, - ): - # TODO support secondary prompt - num_images_per_prompt = 1 # fixed because already prompt is repeated - - if isinstance(prompt, str): - batch_size = 1 - prompt = [prompt] - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - regional_network = " AND " in prompt[0] - - vae_batch_size = ( - batch_size - if vae_batch_size is None - else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) - ) - - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." - ) - - # get prompt text embeddings - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - if not do_classifier_free_guidance and negative_scale is not None: - print(f"negative_scale is ignored if guidance scalle <= 1.0") - negative_scale = None - - # get unconditional embeddings for classifier free guidance - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - - tes_text_embs = [] - tes_uncond_embs = [] - tes_real_uncond_embs = [] - - for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): - token_replacer = self.get_token_replacer(tokenizer) - - # use last text_pool, because it is from text encoder 2 - text_embeddings, text_pool, uncond_embeddings, uncond_pool, _ = get_weighted_text_embeddings( - self.is_sdxl, - tokenizer, - text_encoder, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - token_replacer=token_replacer, - device=self.device, - emb_normalize_mode=emb_normalize_mode, - **kwargs, - ) - tes_text_embs.append(text_embeddings) - tes_uncond_embs.append(uncond_embeddings) - - if negative_scale is not None: - _, real_uncond_embeddings, _ = get_weighted_text_embeddings( - self.is_sdxl, - token_replacer, - prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 - uncond_prompt=[""] * batch_size, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - token_replacer=token_replacer, - device=self.device, - emb_normalize_mode=emb_normalize_mode, - **kwargs, - ) - tes_real_uncond_embs.append(real_uncond_embeddings) - - # concat text encoder outputs - text_embeddings = tes_text_embs[0] - uncond_embeddings = tes_uncond_embs[0] - for i in range(1, len(tes_text_embs)): - text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048 - if do_classifier_free_guidance: - uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 - - if do_classifier_free_guidance: - if negative_scale is None: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - else: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) - - if self.control_net_lllites: - # ControlNetのhintにguide imageを流用する。ControlNetの場合はControlNet側で行う - if isinstance(clip_guide_images, PIL.Image.Image): - clip_guide_images = [clip_guide_images] - if isinstance(clip_guide_images[0], PIL.Image.Image): - clip_guide_images = [preprocess_image(im) for im in clip_guide_images] - clip_guide_images = torch.cat(clip_guide_images) - if isinstance(clip_guide_images, list): - clip_guide_images = torch.stack(clip_guide_images) - - clip_guide_images = clip_guide_images.to(self.device, dtype=text_embeddings.dtype) - - # create size embs - if original_height is None: - original_height = height - if original_width is None: - original_width = width - if original_height_negative is None: - original_height_negative = original_height - if original_width_negative is None: - original_width_negative = original_width - if crop_top is None: - crop_top = 0 - if crop_left is None: - crop_left = 0 - if self.is_sdxl: - emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) - uc_emb1 = sdxl_train_util.get_timestep_embedding( - torch.FloatTensor([original_height_negative, original_width_negative]).unsqueeze(0), 256 - ) - emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) - emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256) - c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) - uc_vector = torch.cat([uc_emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) - - if regional_network: - # use last pool 
for conditioning - num_sub_prompts = len(text_pool) // batch_size - text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt - - if init_image is not None and self.clip_vision_model is not None: - print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") - vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device) - pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype) - - clip_vision_embeddings = self.clip_vision_model( - pixel_values=pixel_values, output_hidden_states=True, return_dict=True - ) - clip_vision_embeddings = clip_vision_embeddings.image_embeds - - if len(clip_vision_embeddings) == 1 and batch_size > 1: - clip_vision_embeddings = clip_vision_embeddings.repeat((batch_size, 1)) - - clip_vision_embeddings = clip_vision_embeddings * self.clip_vision_strength - assert clip_vision_embeddings.shape == text_pool.shape, f"{clip_vision_embeddings.shape} != {text_pool.shape}" - text_pool = clip_vision_embeddings # replace: same as ComfyUI (?) - - c_vector = torch.cat([text_pool, c_vector], dim=1) - if do_classifier_free_guidance: - uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) - vector_embeddings = torch.cat([uc_vector, c_vector]) - else: - vector_embeddings = c_vector - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps, self.device) - - latents_dtype = text_embeddings.dtype - init_latents_orig = None - mask = None - - if init_image is None: - # get the initial random noise unless the user supplied it - - # Unlike in other pipelines, latents need to be generated in the target device - # for 1-to-1 results reproducibility with the CompVis implementation. - # However this currently doesn't work in `mps`. 
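# Concretely, the noise below has shape
# (batch_size * num_images_per_prompt, unet.in_channels, height // 8, width // 8),
# i.e. one latent per image at 1/8 spatial resolution; on mps it is drawn on
# the CPU (randn is unavailable there) and then moved to the device.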
- latents_shape = ( - batch_size * num_images_per_prompt, - self.unet.in_channels, - height // 8, - width // 8, - ) - - if latents is None: - if self.device.type == "mps": - # randn does not exist on mps - latents = torch.randn( - latents_shape, - generator=generator, - device="cpu", - dtype=latents_dtype, - ).to(self.device) - else: - latents = torch.randn( - latents_shape, - generator=generator, - device=self.device, - dtype=latents_dtype, - ) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) - - timesteps = self.scheduler.timesteps.to(self.device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - else: - # image to tensor - if isinstance(init_image, PIL.Image.Image): - init_image = [init_image] - if isinstance(init_image[0], PIL.Image.Image): - init_image = [preprocess_image(im) for im in init_image] - init_image = torch.cat(init_image) - if isinstance(init_image, list): - init_image = torch.stack(init_image) - - # mask image to tensor - if mask_image is not None: - if isinstance(mask_image, PIL.Image.Image): - mask_image = [mask_image] - if isinstance(mask_image[0], PIL.Image.Image): - mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint - - # encode the init image into latents and scale the latents - init_image = init_image.to(device=self.device, dtype=latents_dtype) - if init_image.size()[-2:] == (height // 8, width // 8): - init_latents = init_image - else: - if vae_batch_size >= batch_size: - init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() - init_latents = [] - for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): - init_latent_dist = self.vae.encode( - (init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0)).to( - self.vae.dtype - ) - ).latent_dist - init_latents.append(init_latent_dist.sample(generator=generator)) - init_latents = torch.cat(init_latents) - - init_latents = (sdxl_model_util.VAE_SCALE_FACTOR if self.is_sdxl else 0.18215) * init_latents - - if len(init_latents) == 1: - init_latents = init_latents.repeat((batch_size, 1, 1, 1)) - init_latents_orig = init_latents - - # preprocess mask - if mask_image is not None: - mask = mask_image.to(device=self.device, dtype=latents_dtype) - if len(mask) == 1: - mask = mask.repeat((batch_size, 1, 1, 1)) - - # check sizes - if not mask.shape == init_latents.shape: - raise ValueError("The mask and init_image should be the same size!") - - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) - - # add noise to latents using the timesteps - latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) - - t_start = max(num_inference_steps - init_timestep + offset, 0) - timesteps = self.scheduler.timesteps[t_start:].to(self.device) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only 
used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 - - if self.control_nets: - guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) - each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) - - if self.control_net_lllites: - # guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) - if self.control_net_enabled: - for control_net, _ in self.control_net_lllites: - with torch.no_grad(): - control_net.set_cond_image(clip_guide_images) - else: - for control_net, _ in self.control_net_lllites: - control_net.set_cond_image(None) - - each_control_net_enabled = [self.control_net_enabled] * len(self.control_net_lllites) - - enable_gradual_latent = False - if self.gradual_latent: - if not hasattr(self.scheduler, "set_gradual_latent_params"): - print("gradual_latent is not supported for this scheduler. Ignoring.") - print(self.scheduler.__class__.__name__) - else: - enable_gradual_latent = True - step_elapsed = 1000 - current_ratio = self.gradual_latent.ratio - - # first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする - height, width = latents.shape[-2:] - org_dtype = latents.dtype - if org_dtype == torch.bfloat16: - latents = latents.float() - latents = torch.nn.functional.interpolate( - latents, scale_factor=current_ratio, mode="bicubic", align_corners=False - ).to(org_dtype) - - # apply unsharp mask / アンシャープマスクを適用する - if self.gradual_latent.gaussian_blur_ksize: - latents = self.gradual_latent.apply_unshark_mask(latents) - - for i, t in enumerate(tqdm(timesteps)): - resized_size = None - if enable_gradual_latent: - # gradually upscale the latents / latentsを徐々にアップスケールする - if ( - t < self.gradual_latent.start_timesteps - and current_ratio < 1.0 - and step_elapsed >= self.gradual_latent.every_n_steps - ): - current_ratio = min(current_ratio + self.gradual_latent.ratio_step, 1.0) - # make divisible by 8 because size of latents must be divisible at bottom of UNet - h = int(height * current_ratio) // 8 * 8 - w = int(width * current_ratio) // 8 * 8 - resized_size = (h, w) - self.scheduler.set_gradual_latent_params(resized_size, self.gradual_latent) - step_elapsed = 0 - else: - self.scheduler.set_gradual_latent_params(None, None) - step_elapsed += 1 - - # expand the latents if we are doing classifier free guidance - latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # disable ControlNet-LLLite if ratio is set. 
ControlNet is disabled in ControlNetInfo - if self.control_net_lllites: - for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_net_lllites, each_control_net_enabled)): - if not enabled or ratio >= 1.0: - continue - if ratio < i / len(timesteps): - print(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") - control_net.set_cond_image(None) - each_control_net_enabled[j] = False - - # predict the noise residual - if self.control_nets and self.control_net_enabled: - if regional_network: - num_sub_and_neg_prompts = len(text_embeddings) // batch_size - text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt - else: - text_emb_last = text_embeddings - - noise_pred = original_control_net.call_unet_and_control_net( - i, - num_latent_input, - self.unet, - self.control_nets, - guided_hints, - i / len(timesteps), - latent_model_input, - t, - text_embeddings, - text_emb_last, - ).sample - elif self.is_sdxl: - noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) - else: - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - - # perform guidance - if do_classifier_free_guidance: - if negative_scale is None: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - else: - noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk( - num_latent_input - ) # uncond is real uncond - noise_pred = ( - noise_pred_uncond - + guidance_scale * (noise_pred_text - noise_pred_uncond) - - negative_scale * (noise_pred_negative - noise_pred_uncond) - ) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - if mask is not None: - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t])) - latents = (init_latents_proper * mask) + (latents * (1 - mask)) - - # call the callback, if provided - if i % callback_steps == 0: - if callback is not None: - callback(i, t, latents) - if is_cancelled_callback is not None and is_cancelled_callback(): - return None - - if return_latents: - return latents - - latents = 1 / (sdxl_model_util.VAE_SCALE_FACTOR if self.is_sdxl else 0.18215) * latents - if vae_batch_size >= batch_size: - image = self.vae.decode(latents.to(self.vae.dtype)).sample - else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() - images = [] - for i in tqdm(range(0, batch_size, vae_batch_size)): - images.append( - self.vae.decode( - (latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).to(self.vae.dtype) - ).sample - ) - image = torch.cat(images) - - image = (image / 2 + 0.5).clamp(0, 1) - - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - if output_type == "pil": - # image = self.numpy_to_pil(image) - image = (image * 255).round().astype("uint8") - image = [Image.fromarray(im) for im in image] - - return image - - # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) - - -re_attention = re.compile( - r""" -\\\(| -\\\)| -\\\[| -\\]| -\\\\| -\\| -\(| -\[| -:([+-]?[.\d]+)\)| -\)| -]| -[^\\()\[\]:]+| -: -""", - re.X, -) 
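# For reference, the pattern above splits a prompt into text runs, bracket and
# escape tokens, and ":weight)" tokens whose captured group holds the number.
# For example (assuming re_attention as compiled above):
#
#     for m in re_attention.finditer(r"a (cat:1.3) [dog]"):
#         print(repr(m.group(0)), m.group(1))
#
# yields 'a ', '(', 'cat', ':1.3)' (group '1.3'), ' ', '[', 'dog', ']' in
# order; parse_prompt_attention below folds these fragments into
# (text, weight) pairs.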
- - -def parse_prompt_attention(text): - """ - Parses a string with attention tokens and returns a list of pairs: text and its associated weight. - Accepted tokens are: - (abc) - increases attention to abc by a multiplier of 1.1 - (abc:3.12) - increases attention to abc by a multiplier of 3.12 - [abc] - decreases attention to abc by a multiplier of 1.1 - \( - literal character '(' - \[ - literal character '[' - \) - literal character ')' - \] - literal character ']' - \\ - literal character '\' - anything else - just text - >>> parse_prompt_attention('normal text') - [['normal text', 1.0]] - >>> parse_prompt_attention('an (important) word') - [['an ', 1.0], ['important', 1.1], [' word', 1.0]] - >>> parse_prompt_attention('(unbalanced') - [['unbalanced', 1.1]] - >>> parse_prompt_attention('\(literal\]') - [['(literal]', 1.0]] - >>> parse_prompt_attention('(unnecessary)(parens)') - [['unnecessaryparens', 1.1]] - >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') - [['a ', 1.0], - ['house', 1.5730000000000004], - [' ', 1.1], - ['on', 1.0], - [' a ', 1.1], - ['hill', 0.55], - [', sun, ', 1.1], - ['sky', 1.4641000000000006], - ['.', 1.1]] - """ - - res = [] - round_brackets = [] - square_brackets = [] - - round_bracket_multiplier = 1.1 - square_bracket_multiplier = 1 / 1.1 - - def multiply_range(start_position, multiplier): - for p in range(start_position, len(res)): - res[p][1] *= multiplier - - # keep break as separate token - text = text.replace("BREAK", "\\BREAK\\") - - for m in re_attention.finditer(text): - text = m.group(0) - weight = m.group(1) - - if text.startswith("\\"): - res.append([text[1:], 1.0]) - elif text == "(": - round_brackets.append(len(res)) - elif text == "[": - square_brackets.append(len(res)) - elif weight is not None and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), float(weight)) - elif text == ")" and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == "]" and len(square_brackets) > 0: - multiply_range(square_brackets.pop(), square_bracket_multiplier) - else: - res.append([text, 1.0]) - - for pos in round_brackets: - multiply_range(pos, round_bracket_multiplier) - - for pos in square_brackets: - multiply_range(pos, square_bracket_multiplier) - - if len(res) == 0: - res = [["", 1.0]] - - # merge runs of identical weights - i = 0 - while i + 1 < len(res): - if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK": - res[i][0] += res[i + 1][0] - res.pop(i + 1) - else: - i += 1 - - return res - - -def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: List[str], max_length: int): - r""" - Tokenize a list of prompts and return its tokens with weights of each token. - No padding, starting or ending token is included. 
- """ - tokens = [] - weights = [] - truncated = False - - for text in prompt: - texts_and_weights = parse_prompt_attention(text) - text_token = [] - text_weight = [] - for word, weight in texts_and_weights: - if word.strip() == "BREAK": - # pad until next multiple of tokenizer's max token length - pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length) - print(f"BREAK pad_len: {pad_len}") - for i in range(pad_len): - # v2のときEOSをつけるべきかどうかわからないぜ - # if i == 0: - # text_token.append(tokenizer.eos_token_id) - # else: - text_token.append(tokenizer.pad_token_id) - text_weight.append(1.0) - continue - - # tokenize and discard the starting and the ending token - token = tokenizer(word).input_ids[1:-1] - - token = token_replacer(token) # for Textual Inversion - - text_token += token - # copy the weight by length of token - text_weight += [weight] * len(token) - # stop if the text is too long (longer than truncation limit) - if len(text_token) > max_length: - truncated = True - break - # truncate - if len(text_token) > max_length: - truncated = True - text_token = text_token[:max_length] - text_weight = text_weight[:max_length] - tokens.append(text_token) - weights.append(text_weight) - if truncated: - print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") - return tokens, weights - - -def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): - r""" - Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. - """ - max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) - weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length - for i in range(len(tokens)): - tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) - if no_boseos_middle: - weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) - else: - w = [] - if len(weights[i]) == 0: - w = [1.0] * weights_length - else: - for j in range(max_embeddings_multiples): - w.append(1.0) # weight for starting token in this chunk - w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] - w.append(1.0) # weight for ending token in this chunk - w += [1.0] * (weights_length - len(w)) - weights[i] = w[:] - - return tokens, weights - - -def get_unweighted_text_embeddings( - is_sdxl: bool, - text_encoder: CLIPTextModel, - text_input: torch.Tensor, - chunk_length: int, - clip_skip: int, - eos: int, - pad: int, - no_boseos_middle: Optional[bool] = True, -): - """ - When the length of tokens is a multiple of the capacity of the text encoder, - it should be split into chunks and sent to the text encoder individually. 
- """ - max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) - if max_embeddings_multiples > 1: - text_embeddings = [] - pool = None - for i in range(max_embeddings_multiples): - # extract the i-th chunk - text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() - - # cover the head and the tail by the starting and the ending tokens - text_input_chunk[:, 0] = text_input[0, 0] - if pad == eos: # v1 - text_input_chunk[:, -1] = text_input[0, -1] - else: # v2 - for j in range(len(text_input_chunk)): - if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある - text_input_chunk[j, -1] = eos - if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD - text_input_chunk[j, 1] = eos - - # in sdxl, value of clip_skip is same for Text Encoder 1 and 2 - enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) - text_embedding = enc_out["hidden_states"][-clip_skip] - if not is_sdxl: # SD 1.5 requires final_layer_norm - text_embedding = text_encoder.text_model.final_layer_norm(text_embedding) - if pool is None: - pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided - if pool is not None: - pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos) - - if no_boseos_middle: - if i == 0: - # discard the ending token - text_embedding = text_embedding[:, :-1] - elif i == max_embeddings_multiples - 1: - # discard the starting token - text_embedding = text_embedding[:, 1:] - else: - # discard both starting and ending tokens - text_embedding = text_embedding[:, 1:-1] - - text_embeddings.append(text_embedding) - text_embeddings = torch.concat(text_embeddings, axis=1) - else: - enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) - text_embeddings = enc_out["hidden_states"][-clip_skip] - if not is_sdxl: # SD 1.5 requires final_layer_norm - text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings) - pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this - if pool is not None: - pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input, eos) - return text_embeddings, pool - - -def get_weighted_text_embeddings( - is_sdxl: bool, - tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModel, - prompt: Union[str, List[str]], - uncond_prompt: Optional[Union[str, List[str]]] = None, - max_embeddings_multiples: Optional[int] = 1, - no_boseos_middle: Optional[bool] = False, - skip_parsing: Optional[bool] = False, - skip_weighting: Optional[bool] = False, - clip_skip: int = 1, - token_replacer=None, - device=None, - emb_normalize_mode: Optional[str] = "original", # "original", "abs", "none" - **kwargs, -): - max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - if isinstance(prompt, str): - prompt = [prompt] - - # split the prompts with "AND". 
each prompt must have the same number of splits - new_prompts = [] - for p in prompt: - new_prompts.extend(p.split(" AND ")) - prompt = new_prompts - - if not skip_parsing: - prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, token_replacer, prompt, max_length - 2) - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens, uncond_weights = get_prompts_with_weights(tokenizer, token_replacer, uncond_prompt, max_length - 2) - else: - prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids] - prompt_weights = [[1.0] * len(token) for token in prompt_tokens] - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens = [token[1:-1] for token in tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids] - uncond_weights = [[1.0] * len(token) for token in uncond_tokens] - - # round up the longest length of tokens to a multiple of (model_max_length - 2) - max_length = max([len(token) for token in prompt_tokens]) - if uncond_prompt is not None: - max_length = max(max_length, max([len(token) for token in uncond_tokens])) - - max_embeddings_multiples = min( - max_embeddings_multiples, - (max_length - 1) // (tokenizer.model_max_length - 2) + 1, - ) - max_embeddings_multiples = max(1, max_embeddings_multiples) - max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - - # pad the length of tokens and weights - bos = tokenizer.bos_token_id - eos = tokenizer.eos_token_id - pad = tokenizer.pad_token_id - prompt_tokens, prompt_weights = pad_tokens_and_weights( - prompt_tokens, - prompt_weights, - max_length, - bos, - eos, - pad, - no_boseos_middle=no_boseos_middle, - chunk_length=tokenizer.model_max_length, - ) - prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) - if uncond_prompt is not None: - uncond_tokens, uncond_weights = pad_tokens_and_weights( - uncond_tokens, - uncond_weights, - max_length, - bos, - eos, - pad, - no_boseos_middle=no_boseos_middle, - chunk_length=tokenizer.model_max_length, - ) - uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device) - - # get the embeddings - text_embeddings, text_pool = get_unweighted_text_embeddings( - is_sdxl, - text_encoder, - prompt_tokens, - tokenizer.model_max_length, - clip_skip, - eos, - pad, - no_boseos_middle=no_boseos_middle, - ) - - prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device) - if uncond_prompt is not None: - uncond_embeddings, uncond_pool = get_unweighted_text_embeddings( - is_sdxl, - text_encoder, - uncond_tokens, - tokenizer.model_max_length, - clip_skip, - eos, - pad, - no_boseos_middle=no_boseos_middle, - ) - uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=device) - - # assign weights to the prompts and normalize in the sense of mean - # TODO: should we normalize by chunk or in a whole (current implementation)? 
-    # → normalizing over the whole embedding is probably fine
-
-    if (not skip_parsing) and (not skip_weighting):
-        if emb_normalize_mode == "abs":
-            previous_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype)
-            text_embeddings *= prompt_weights.unsqueeze(-1)
-            current_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype)
-            text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
-            if uncond_prompt is not None:
-                previous_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
-                uncond_embeddings *= uncond_weights.unsqueeze(-1)
-                current_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
-                uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
-
-        elif emb_normalize_mode == "none":
-            text_embeddings *= prompt_weights.unsqueeze(-1)
-            if uncond_prompt is not None:
-                uncond_embeddings *= uncond_weights.unsqueeze(-1)
-
-        else:  # "original"
-            previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
-            text_embeddings *= prompt_weights.unsqueeze(-1)
-            current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
-            text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
-            if uncond_prompt is not None:
-                previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
-                uncond_embeddings *= uncond_weights.unsqueeze(-1)
-                current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
-                uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
-
-    if uncond_prompt is not None:
-        return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens
-    return text_embeddings, text_pool, None, None, prompt_tokens
-
-
-def preprocess_image(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.0 * image - 1.0
-
-
-def preprocess_mask(mask):
-    mask = mask.convert("L")
-    w, h = mask.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR)  # was LANCZOS
-    mask = np.array(mask).astype(np.float32) / 255.0
-    mask = np.tile(mask, (4, 1, 1))
-    mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension (the transpose itself is a no-op)
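-    # the mask is kept at latent resolution (1/8 of the image) and tiled to 4 channels so it
-    # broadcasts over the latent tensor; the next line inverts it so that white areas are
-    # repainted and black areas are preserved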
-    mask = 1 - mask  # repaint white, keep black
-    mask = torch.from_numpy(mask)
-    return mask
-
-
-# regular expression for dynamic prompt:
-# starts and ends with "{" and "}"
-# contains at least one variant divided by "|"
-# optional fragments divided by "$$" at start
-# if the first fragment is "E" or "e", enumerate all variants
-# if the second fragment is a number or two numbers, repeat the variants in the range
-# if the third fragment is a string, use it as a separator

-RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}")
-
-
-def handle_dynamic_prompt_variants(prompt, repeat_count):
-    founds = list(RE_DYNAMIC_PROMPT.finditer(prompt))
-    if not founds:
-        return [prompt]
-
-    # make each replacement for each variant
-    enumerating = False
-    replacers = []
-    for found in founds:
-        # if "e$$" is found, enumerate all variants
-        found_enumerating = found.group(2) is not None
-        enumerating = enumerating or found_enumerating
-
-        separator = ", " if found.group(6) is None else found.group(6)
-        variants = found.group(7).split("|")
-
-        # parse count range
-        count_range = found.group(4)
-        if count_range is None:
-            count_range = [1, 1]
-        else:
-            count_range = count_range.split("-")
-            if len(count_range) == 1:
-                count_range = [int(count_range[0]), int(count_range[0])]
-            elif len(count_range) == 2:
-                count_range = [int(count_range[0]), int(count_range[1])]
-            else:
-                print(f"invalid count range: {count_range}")
-                count_range = [1, 1]
-        if count_range[0] > count_range[1]:
-            count_range = [count_range[1], count_range[0]]
-        if count_range[0] < 0:
-            count_range[0] = 0
-        if count_range[1] > len(variants):
-            count_range[1] = len(variants)
-
-        if found_enumerating:
-            # make function to enumerate all combinations
-            def make_replacer_enum(vari, cr, sep):
-                def replacer():
-                    values = []
-                    for count in range(cr[0], cr[1] + 1):
-                        for comb in itertools.combinations(vari, count):
-                            values.append(sep.join(comb))
-                    return values
-
-                return replacer
-
-            replacers.append(make_replacer_enum(variants, count_range, separator))
-        else:
-            # make function to choose random combinations
-            def make_replacer_single(vari, cr, sep):
-                def replacer():
-                    count = random.randint(cr[0], cr[1])
-                    comb = random.sample(vari, count)
-                    return [sep.join(comb)]
-
-                return replacer
-
-            replacers.append(make_replacer_single(variants, count_range, separator))
-
-    # make each prompt
-    if not enumerating:
-        # if not enumerating, repeat the prompt, replace each variant randomly
-        prompts = []
-        for _ in range(repeat_count):
-            current = prompt
-            for found, replacer in zip(founds, replacers):
-                current = current.replace(found.group(0), replacer()[0], 1)
-            prompts.append(current)
-    else:
-        # if enumerating, iterate all combinations for previous prompts
-        prompts = [prompt]
-
-        for found, replacer in zip(founds, replacers):
-            if found.group(2) is not None:
-                # make all combinations for existing prompts
-                new_prompts = []
-                for current in prompts:
-                    replacements = replacer()
-                    for replacement in replacements:
-                        new_prompts.append(current.replace(found.group(0), replacement, 1))
-                prompts = new_prompts
-
-        for found, replacer in zip(founds, replacers):
-            # make random selection for existing prompts
-            if found.group(2) is None:
-                for i in range(len(prompts)):
-                    prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1)
-
-    return prompts
-
-
-# endregion
-
-# def load_clip_l14_336(dtype):
-#     print(f"loading CLIP: {CLIP_ID_L14_336}")
-#     text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336,
torch_dtype=dtype) -# return text_encoder - - -class BatchDataBase(NamedTuple): - # バッチ分割が必要ないデータ - step: int - prompt: str - negative_prompt: str - seed: int - init_image: Any - mask_image: Any - clip_prompt: str - guide_image: Any - raw_prompt: str - - -class BatchDataExt(NamedTuple): - # バッチ分割が必要なデータ - width: int - height: int - original_width: int - original_height: int - original_width_negative: int - original_height_negative: int - crop_left: int - crop_top: int - steps: int - scale: float - negative_scale: float - strength: float - network_muls: Tuple[float] - num_sub_prompts: int - - -class BatchData(NamedTuple): - return_latents: bool - base: BatchDataBase - ext: BatchDataExt - - -class ListPrompter: - def __init__(self, prompts: List[str]): - self.prompts = prompts - self.index = 0 - - def shuffle(self): - random.shuffle(self.prompts) - - def __len__(self): - return len(self.prompts) - - def __call__(self, *args, **kwargs): - if self.index >= len(self.prompts): - self.index = 0 # reset - return None - - prompt = self.prompts[self.index] - self.index += 1 - return prompt - - -def main(args): - if args.fp16: - dtype = torch.float16 - elif args.bf16: - dtype = torch.bfloat16 - else: - dtype = torch.float32 - - highres_fix = args.highres_fix_scale is not None - # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" - - if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") - if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") - - # モデルを読み込む - if not os.path.exists(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う - files = glob.glob(args.ckpt) - if len(files) == 1: - args.ckpt = files[0] - - name_or_path = os.readlink(args.ckpt) if os.path.islink(args.ckpt) else args.ckpt - use_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers - - # SDXLかどうかを判定する - is_sdxl = args.sdxl - if not is_sdxl and not args.v1 and not args.v2: # どれも指定されていない場合は自動で判定する - if use_stable_diffusion_format: - # if file size > 5.5GB, sdxl - is_sdxl = os.path.getsize(name_or_path) > 5.5 * 1024**3 - else: - # if `text_encoder_2` subdirectory exists, sdxl - is_sdxl = os.path.isdir(os.path.join(name_or_path, "text_encoder_2")) - print(f"SDXL: {is_sdxl}") - - if is_sdxl: - if args.clip_skip is None: - args.clip_skip = 2 - - (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( - args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype - ) - unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) - text_encoders = [text_encoder1, text_encoder2] - else: - if args.clip_skip is None: - args.clip_skip = 2 if args.v2 else 1 - - if use_stable_diffusion_format: - print("load StableDiffusion checkpoint") - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) - else: - print("load Diffusers pretrained models") - loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) - text_encoder = loading_pipe.text_encoder - vae = loading_pipe.vae - unet = loading_pipe.unet - tokenizer = loading_pipe.tokenizer - del loading_pipe - - # Diffusers U-Net to original U-Net - original_unet = UNet2DConditionModel( - unet.config.sample_size, - unet.config.attention_head_dim, - unet.config.cross_attention_dim, - unet.config.use_linear_projection, 
- unet.config.upcast_attention, - ) - original_unet.load_state_dict(unet.state_dict()) - unet = original_unet - unet: InferUNet2DConditionModel = InferUNet2DConditionModel(unet) - text_encoders = [text_encoder] - - # VAEを読み込む - if args.vae is not None: - vae = model_util.load_vae(args.vae, dtype) - print("additional VAE loaded") - - # xformers、Hypernetwork対応 - if not args.diffusers_xformers: - mem_eff = not (args.xformers or args.sdpa) - replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa) - replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) - - # tokenizerを読み込む - print("loading tokenizer") - if is_sdxl: - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - tokenizers = [tokenizer1, tokenizer2] - else: - if use_stable_diffusion_format: - tokenizer = train_util.load_tokenizer(args) - tokenizers = [tokenizer] - - # schedulerを用意する - sched_init_args = {} - has_steps_offset = True - has_clip_sample = True - scheduler_num_noises_per_step = 1 - - if args.sampler == "ddim": - scheduler_cls = DDIMScheduler - scheduler_module = diffusers.schedulers.scheduling_ddim - elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある - scheduler_cls = DDPMScheduler - scheduler_module = diffusers.schedulers.scheduling_ddpm - elif args.sampler == "pndm": - scheduler_cls = PNDMScheduler - scheduler_module = diffusers.schedulers.scheduling_pndm - has_clip_sample = False - elif args.sampler == "lms" or args.sampler == "k_lms": - scheduler_cls = LMSDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_lms_discrete - has_clip_sample = False - elif args.sampler == "euler" or args.sampler == "k_euler": - scheduler_cls = EulerDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_euler_discrete - has_clip_sample = False - elif args.sampler == "euler_a" or args.sampler == "k_euler_a": - scheduler_cls = EulerAncestralDiscreteSchedulerGL - scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete - has_clip_sample = False - elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": - scheduler_cls = DPMSolverMultistepScheduler - sched_init_args["algorithm_type"] = args.sampler - scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep - has_clip_sample = False - elif args.sampler == "dpmsingle": - scheduler_cls = DPMSolverSinglestepScheduler - scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep - has_clip_sample = False - has_steps_offset = False - elif args.sampler == "heun": - scheduler_cls = HeunDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_heun_discrete - has_clip_sample = False - elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2": - scheduler_cls = KDPM2DiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete - has_clip_sample = False - elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a": - scheduler_cls = KDPM2AncestralDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete - scheduler_num_noises_per_step = 2 - has_clip_sample = False - - if args.v_parameterization: - sched_init_args["prediction_type"] = "v_prediction" - - # 警告を出さないようにする - if has_steps_offset: - sched_init_args["steps_offset"] = 1 - if has_clip_sample: - sched_init_args["clip_sample"] = False - - # samplerの乱数をあらかじめ指定するための処理 - - # replace randn - class NoiseManager: - def __init__(self): - self.sampler_noises = None - self.sampler_noise_index = 0 - - def reset_sampler_noises(self, noises): - self.sampler_noise_index 
= 0 - self.sampler_noises = noises - - def randn(self, shape, device=None, dtype=None, layout=None, generator=None): - # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) - if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): - noise = self.sampler_noises[self.sampler_noise_index] - if shape != noise.shape: - noise = None - else: - noise = None - - if noise == None: - print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") - noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) - - self.sampler_noise_index += 1 - return noise - - class TorchRandReplacer: - def __init__(self, noise_manager): - self.noise_manager = noise_manager - - def __getattr__(self, item): - if item == "randn": - return self.noise_manager.randn - if hasattr(torch, item): - return getattr(torch, item) - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) - - noise_manager = NoiseManager() - if scheduler_module is not None: - scheduler_module.torch = TorchRandReplacer(noise_manager) - - scheduler = scheduler_cls( - num_train_timesteps=SCHEDULER_TIMESTEPS, - beta_start=SCHEDULER_LINEAR_START, - beta_end=SCHEDULER_LINEAR_END, - beta_schedule=SCHEDLER_SCHEDULE, - **sched_init_args, - ) - - # ↓以下は結局PipeでFalseに設定されるので意味がなかった - # # clip_sample=Trueにする - # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - # print("set clip_sample to True") - # scheduler.config.clip_sample = True - - # deviceを決定する - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない - - # custom pipelineをコピったやつを生成する - if args.vae_slices: - from library.slicing_vae import SlicingAutoencoderKL - - sli_vae = SlicingAutoencoderKL( - act_fn="silu", - block_out_channels=(128, 256, 512, 512), - down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], - in_channels=3, - latent_channels=4, - layers_per_block=2, - norm_num_groups=32, - out_channels=3, - sample_size=512, - up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], - num_slices=args.vae_slices, - ) - sli_vae.load_state_dict(vae.state_dict()) # vaeのパラメータをコピーする - vae = sli_vae - del sli_vae - - vae_dtype = dtype - if args.no_half_vae: - print("set vae_dtype to float32") - vae_dtype = torch.float32 - vae.to(vae_dtype).to(device) - vae.eval() - - for text_encoder in text_encoders: - text_encoder.to(dtype).to(device) - text_encoder.eval() - unet.to(dtype).to(device) - unet.eval() - - # networkを組み込む - if args.network_module: - networks = [] - network_default_muls = [] - network_pre_calc = args.network_pre_calc - - # merge関連の引数を統合する - if args.network_merge: - network_merge = len(args.network_module) # all networks are merged - elif args.network_merge_n_models: - network_merge = args.network_merge_n_models - else: - network_merge = 0 - print(f"network_merge: {network_merge}") - - for i, network_module in enumerate(args.network_module): - print("import network module:", network_module) - imported_module = importlib.import_module(network_module) - - network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] - - net_kwargs = {} - if args.network_args and i < len(args.network_args): - network_args = args.network_args[i] - # TODO escape special chars - network_args = network_args.split(";") - for net_arg in network_args: - key, value = net_arg.split("=") - net_kwargs[key] = value 
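- # each --network_args entry is a semicolon-separated list of key=value pairs; the resulting
- # net_kwargs dict is passed through to the module's create_network_from_weights() call below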
- - if args.network_weights is None or len(args.network_weights) <= i: - raise ValueError("No weight. Weight is required.") - - network_weight = args.network_weights[i] - print("load network weights from:", network_weight) - - if model_util.is_safetensors(network_weight) and args.network_show_meta: - from safetensors.torch import safe_open - - with safe_open(network_weight, framework="pt") as f: - metadata = f.metadata() - if metadata is not None: - print(f"metadata for: {network_weight}: {metadata}") - - network, weights_sd = imported_module.create_network_from_weights( - network_mul, network_weight, vae, text_encoders, unet, for_inference=True, **net_kwargs - ) - if network is None: - return - - mergeable = network.is_mergeable() - if network_merge and not mergeable: - print("network is not mergiable. ignore merge option.") - - if not mergeable or i >= network_merge: - # not merging - network.apply_to(text_encoders, unet) - info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい - print(f"weights are loaded: {info}") - - if args.opt_channels_last: - network.to(memory_format=torch.channels_last) - network.to(dtype).to(device) - - if network_pre_calc: - print("backup original weights") - network.backup_weights() - - networks.append(network) - network_default_muls.append(network_mul) - else: - network.merge_to(text_encoders, unet, weights_sd, dtype, device) - - else: - networks = [] - - # upscalerの指定があれば取得する - upscaler = None - if args.highres_fix_upscaler: - print("import upscaler module:", args.highres_fix_upscaler) - imported_module = importlib.import_module(args.highres_fix_upscaler) - - us_kwargs = {} - if args.highres_fix_upscaler_args: - for net_arg in args.highres_fix_upscaler_args.split(";"): - key, value = net_arg.split("=") - us_kwargs[key] = value - - print("create upscaler") - upscaler = imported_module.create_upscaler(**us_kwargs) - upscaler.to(dtype).to(device) - - # ControlNetの処理 - control_nets: List[ControlNetInfo] = [] - if args.control_net_models: - for i, model in enumerate(args.control_net_models): - prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] - weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] - ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] - - ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) - prep = original_control_net.load_preprocess(prep_type) - control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) - - control_net_lllites: List[Tuple[ControlNetLLLite, float]] = [] - if args.control_net_lllite_models: - for i, model_file in enumerate(args.control_net_lllite_models): - print(f"loading ControlNet-LLLite: {model_file}") - - from safetensors.torch import load_file - - state_dict = load_file(model_file) - mlp_dim = None - cond_emb_dim = None - for key, value in state_dict.items(): - if mlp_dim is None and "down.0.weight" in key: - mlp_dim = value.shape[0] - elif cond_emb_dim is None and "conditioning1.0" in key: - cond_emb_dim = value.shape[0] * 2 - if mlp_dim is not None and cond_emb_dim is not None: - break - assert mlp_dim is not None and cond_emb_dim is not None, f"invalid control net: {model_file}" - - multiplier = ( - 1.0 - if not args.control_net_multipliers or len(args.control_net_multipliers) <= i - else args.control_net_multipliers[i] - ) - ratio = 1.0 if not 
args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] - - control_net_lllite = ControlNetLLLite(unet, cond_emb_dim, mlp_dim, multiplier=multiplier) - control_net_lllite.apply_to() - control_net_lllite.load_state_dict(state_dict) - control_net_lllite.to(dtype).to(device) - control_net_lllite.set_batch_cond_only(False, False) - control_net_lllites.append((control_net_lllite, ratio)) - assert ( - len(control_nets) == 0 or len(control_net_lllites) == 0 - ), "ControlNet and ControlNet-LLLite cannot be used at the same time" - - if args.opt_channels_last: - print(f"set optimizing: channels last") - for text_encoder in text_encoders: - text_encoder.to(memory_format=torch.channels_last) - vae.to(memory_format=torch.channels_last) - unet.to(memory_format=torch.channels_last) - if networks: - for network in networks: - network.to(memory_format=torch.channels_last) - - for cn in control_nets: - cn.to(memory_format=torch.channels_last) - - for cn in control_net_lllites: - cn.to(memory_format=torch.channels_last) - - pipe = PipelineLike( - is_sdxl, - device, - vae, - text_encoders, - tokenizers, - unet, - scheduler, - args.clip_skip, - ) - pipe.set_control_nets(control_nets) - pipe.set_control_net_lllites(control_net_lllites) - print("pipeline is ready.") - - if args.diffusers_xformers: - pipe.enable_xformers_memory_efficient_attention() - - # Deep Shrink - if args.ds_depth_1 is not None: - unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) - - # Gradual Latent - if args.gradual_latent_timesteps is not None: - if args.gradual_latent_unsharp_params: - us_params = args.gradual_latent_unsharp_params.split(",") - us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]] - us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3])) - us_ksize = int(us_ksize) - else: - us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None - - gradual_latent = GradualLatent( - args.gradual_latent_ratio, - args.gradual_latent_timesteps, - args.gradual_latent_every_n_steps, - args.gradual_latent_ratio_step, - args.gradual_latent_s_noise, - us_ksize, - us_sigma, - us_strength, - us_target_x, - ) - pipe.set_gradual_latent(gradual_latent) - - # Textual Inversionを処理する - if args.textual_inversion_embeddings: - token_ids_embeds1 = [] - token_ids_embeds2 = [] - for embeds_file in args.textual_inversion_embeddings: - if model_util.is_safetensors(embeds_file): - from safetensors.torch import load_file - - data = load_file(embeds_file) - else: - data = torch.load(embeds_file, map_location="cpu") - - if "string_to_param" in data: - data = data["string_to_param"] - if is_sdxl: - - embeds1 = data["clip_l"] # text encoder 1 - embeds2 = data["clip_g"] # text encoder 2 - else: - embeds1 = next(iter(data.values())) - embeds2 = None - - num_vectors_per_token = embeds1.size()[0] - token_string = os.path.splitext(os.path.basename(embeds_file))[0] - - token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] - - # add new word to tokenizer, count is num_vectors_per_token - num_added_tokens1 = tokenizers[0].add_tokens(token_strings) - num_added_tokens2 = tokenizers[1].add_tokens(token_strings) if is_sdxl else 0 - assert num_added_tokens1 == num_vectors_per_token and ( - num_added_tokens2 == 0 or num_added_tokens2 == num_vectors_per_token - ), ( - f"tokenizer has same word to token string (filename): {embeds_file}" - + f" / 指定した名前(ファイル名)のトークンが既に存在します: {embeds_file}" - ) - - 
token_ids1 = tokenizers[0].convert_tokens_to_ids(token_strings) - token_ids2 = tokenizers[1].convert_tokens_to_ids(token_strings) if is_sdxl else None - print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") - assert ( - min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 - ), f"token ids1 is not ordered" - assert not is_sdxl or ( - min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1 - ), f"token ids2 is not ordered" - assert len(tokenizers[0]) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizers[0])}" - assert ( - not is_sdxl or len(tokenizers[1]) - 1 == token_ids2[-1] - ), f"token ids 2 is not end of tokenize: {len(tokenizers[1])}" - - if num_vectors_per_token > 1: - pipe.add_token_replacement(0, token_ids1[0], token_ids1) # hoge -> hoge, hogea, hogeb, ... - if is_sdxl: - pipe.add_token_replacement(1, token_ids2[0], token_ids2) - - token_ids_embeds1.append((token_ids1, embeds1)) - if is_sdxl: - token_ids_embeds2.append((token_ids2, embeds2)) - - text_encoders[0].resize_token_embeddings(len(tokenizers[0])) - token_embeds1 = text_encoders[0].get_input_embeddings().weight.data - for token_ids, embeds in token_ids_embeds1: - for token_id, embed in zip(token_ids, embeds): - token_embeds1[token_id] = embed - - if is_sdxl: - text_encoders[1].resize_token_embeddings(len(tokenizers[1])) - token_embeds2 = text_encoders[1].get_input_embeddings().weight.data - for token_ids, embeds in token_ids_embeds2: - for token_id, embed in zip(token_ids, embeds): - token_embeds2[token_id] = embed - - # promptを取得する - prompt_list = None - if args.from_file is not None: - print(f"reading prompts from {args.from_file}") - with open(args.from_file, "r", encoding="utf-8") as f: - prompt_list = f.read().splitlines() - prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] - prompter = ListPrompter(prompt_list) - - elif args.from_module is not None: - - def load_module_from_path(module_name, file_path): - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ImportError(f"Module '{module_name}' cannot be loaded from '{file_path}'") - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - return module - - print(f"reading prompts from module: {args.from_module}") - prompt_module = load_module_from_path("prompt_module", args.from_module) - - prompter = prompt_module.get_prompter(args, pipe, networks) - - elif args.prompt is not None: - prompter = ListPrompter([args.prompt]) - - else: - prompter = None # interactive mode - - if args.interactive: - args.n_iter = 1 - - # img2imgの前処理、画像の読み込みなど - def load_images(path): - if os.path.isfile(path): - paths = [path] - else: - paths = ( - glob.glob(os.path.join(path, "*.png")) - + glob.glob(os.path.join(path, "*.jpg")) - + glob.glob(os.path.join(path, "*.jpeg")) - + glob.glob(os.path.join(path, "*.webp")) - ) - paths.sort() - - images = [] - for p in paths: - image = Image.open(p) - if image.mode != "RGB": - print(f"convert image to RGB from {image.mode}: {p}") - image = image.convert("RGB") - images.append(image) - - return images - - def resize_images(imgs, size): - resized = [] - for img in imgs: - r_img = img.resize(size, Image.Resampling.LANCZOS) - if hasattr(img, "filename"): # filename属性がない場合があるらしい - r_img.filename = img.filename - resized.append(r_img) - return resized - - if 
args.image_path is not None: - print(f"load image for img2img: {args.image_path}") - init_images = load_images(args.image_path) - assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" - print(f"loaded {len(init_images)} images for img2img") - - # CLIP Vision - if args.clip_vision_strength is not None: - print(f"load CLIP Vision model: {CLIP_VISION_MODEL}") - vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280) - vision_model.to(device, dtype) - processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL) - - pipe.clip_vision_model = vision_model - pipe.clip_vision_processor = processor - pipe.clip_vision_strength = args.clip_vision_strength - print(f"CLIP Vision model loaded.") - - else: - init_images = None - - if args.mask_path is not None: - print(f"load mask for inpainting: {args.mask_path}") - mask_images = load_images(args.mask_path) - assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" - print(f"loaded {len(mask_images)} mask images for inpainting") - else: - mask_images = None - - # promptがないとき、画像のPngInfoから取得する - if init_images is not None and prompter is None and not args.interactive: - print("get prompts from images' metadata") - prompt_list = [] - for img in init_images: - if "prompt" in img.text: - prompt = img.text["prompt"] - if "negative-prompt" in img.text: - prompt += " --n " + img.text["negative-prompt"] - prompt_list.append(prompt) - prompter = ListPrompter(prompt_list) - - # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) - l = [] - for im in init_images: - l.extend([im] * args.images_per_prompt) - init_images = l - - if mask_images is not None: - l = [] - for im in mask_images: - l.extend([im] * args.images_per_prompt) - mask_images = l - - # 画像サイズにオプション指定があるときはリサイズする - if args.W is not None and args.H is not None: - # highres fix を考慮に入れる - w, h = args.W, args.H - if highres_fix: - w = int(w * args.highres_fix_scale + 0.5) - h = int(h * args.highres_fix_scale + 0.5) - - if init_images is not None: - print(f"resize img2img source images to {w}*{h}") - init_images = resize_images(init_images, (w, h)) - if mask_images is not None: - print(f"resize img2img mask images to {w}*{h}") - mask_images = resize_images(mask_images, (w, h)) - - regional_network = False - if networks and mask_images: - # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 - regional_network = True - print("use mask as region") - - size = None - for i, network in enumerate(networks): - if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes: - np_mask = np.array(mask_images[0]) - - if args.network_regional_mask_max_color_codes: - # カラーコードでマスクを指定する - ch0 = (i + 1) & 1 - ch1 = ((i + 1) >> 1) & 1 - ch2 = ((i + 1) >> 2) & 1 - np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2) - np_mask = np_mask.astype(np.uint8) * 255 - else: - np_mask = np_mask[:, :, i] - size = np_mask.shape - else: - np_mask = np.full(size, 255, dtype=np.uint8) - mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0) - network.set_region(i, i == len(networks) - 1, mask) - mask_images = None - - prev_image = None # for VGG16 guided - if args.guide_image_path is not None: - print(f"load image for ControlNet guidance: {args.guide_image_path}") - guide_images = [] - for p in args.guide_image_path: - guide_images.extend(load_images(p)) - - print(f"loaded {len(guide_images)} guide images for guidance") - if len(guide_images) == 0: - print( - f"No guide image, use previous generated image. 
/ ガイド画像がありません。直前に生成した画像を使います: {args.image_path}" - ) - guide_images = None - else: - guide_images = None - - # 新しい乱数生成器を作成する - if args.seed is not None: - if prompt_list and len(prompt_list) == 1 and args.images_per_prompt == 1: - # 引数のseedをそのまま使う - def fixed_seed(*args, **kwargs): - return args.seed - - seed_random = SimpleNamespace(randint=fixed_seed) - else: - seed_random = random.Random(args.seed) - else: - seed_random = random.Random() - - # デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み) - if args.W is None: - args.W = 1024 if is_sdxl else 512 - if args.H is None: - args.H = 1024 if is_sdxl else 512 - - # 画像生成のループ - os.makedirs(args.outdir, exist_ok=True) - max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples - - for gen_iter in range(args.n_iter): - print(f"iteration {gen_iter+1}/{args.n_iter}") - if args.iter_same_seed: - iter_seed = seed_random.randint(0, 2**32 - 1) - else: - iter_seed = None - - # shuffle prompt list - if args.shuffle_prompts: - prompter.shuffle() - - # バッチ処理の関数 - def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): - batch_size = len(batch) - - # highres_fixの処理 - if highres_fix and not highres_1st: - # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す - is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling - - print("process 1st stage") - batch_1st = [] - for _, base, ext in batch: - - def scale_and_round(x): - if x is None: - return None - return int(x * args.highres_fix_scale + 0.5) - - width_1st = scale_and_round(ext.width) - height_1st = scale_and_round(ext.height) - width_1st = width_1st - width_1st % 32 - height_1st = height_1st - height_1st % 32 - - original_width_1st = scale_and_round(ext.original_width) - original_height_1st = scale_and_round(ext.original_height) - original_width_negative_1st = scale_and_round(ext.original_width_negative) - original_height_negative_1st = scale_and_round(ext.original_height_negative) - crop_left_1st = scale_and_round(ext.crop_left) - crop_top_1st = scale_and_round(ext.crop_top) - - strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength - - ext_1st = BatchDataExt( - width_1st, - height_1st, - original_width_1st, - original_height_1st, - original_width_negative_1st, - original_height_negative_1st, - crop_left_1st, - crop_top_1st, - args.highres_fix_steps, - ext.scale, - ext.negative_scale, - strength_1st, - ext.network_muls, - ext.num_sub_prompts, - ) - batch_1st.append(BatchData(is_1st_latent, base, ext_1st)) - - pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする - images_1st = process_batch(batch_1st, True, True) - - # 2nd stageのバッチを作成して以下処理する - print("process 2nd stage") - width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height - - if upscaler: - # upscalerを使って画像を拡大する - lowreso_imgs = None if is_1st_latent else images_1st - lowreso_latents = None if not is_1st_latent else images_1st - - # 戻り値はPIL.Image.Imageかtorch.Tensorのlatents - batch_size = len(images_1st) - vae_batch_size = ( - batch_size - if args.vae_batch_size is None - else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size) - ) - vae_batch_size = int(vae_batch_size) - images_1st = upscaler.upscale( - vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size - ) - - elif args.highres_fix_latents_upscaling: - # latentを拡大する - org_dtype = images_1st.dtype - if images_1st.dtype == torch.bfloat16: - images_1st = 
images_1st.to(torch.float) # interpolateがbf16をサポートしていない - images_1st = torch.nn.functional.interpolate( - images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear" - ) # , antialias=True) - images_1st = images_1st.to(org_dtype) - - else: - # 画像をLANCZOSで拡大する - images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st] - - batch_2nd = [] - for i, (bd, image) in enumerate(zip(batch, images_1st)): - bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext) - batch_2nd.append(bd_2nd) - batch = batch_2nd - - if args.highres_fix_disable_control_net: - pipe.set_enable_control_net(False) # オプション指定時、2nd stageではControlNetを無効にする - - # このバッチの情報を取り出す - ( - return_latents, - (step_first, _, _, _, init_image, mask_image, _, guide_image, _), - ( - width, - height, - original_width, - original_height, - original_width_negative, - original_height_negative, - crop_left, - crop_top, - steps, - scale, - negative_scale, - strength, - network_muls, - num_sub_prompts, - ), - ) = batch[0] - noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) - - prompts = [] - negative_prompts = [] - raw_prompts = [] - start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - noises = [ - torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - for _ in range(steps * scheduler_num_noises_per_step) - ] - seeds = [] - clip_prompts = [] - - if init_image is not None: # img2img? - i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - init_images = [] - - if mask_image is not None: - mask_images = [] - else: - mask_images = None - else: - i2i_noises = None - init_images = None - mask_images = None - - if guide_image is not None: # CLIP image guided? 
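- # collect guidance images (CLIP image guide or ControlNet hint) for every entry in this batch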
- guide_images = [] - else: - guide_images = None - - # バッチ内の位置に関わらず同じ乱数を使うためにここで乱数を生成しておく。あわせてimage/maskがbatch内で同一かチェックする - all_images_are_same = True - all_masks_are_same = True - all_guide_images_are_same = True - for i, ( - _, - (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), - _, - ) in enumerate(batch): - prompts.append(prompt) - negative_prompts.append(negative_prompt) - seeds.append(seed) - clip_prompts.append(clip_prompt) - raw_prompts.append(raw_prompt) - - if init_image is not None: - init_images.append(init_image) - if i > 0 and all_images_are_same: - all_images_are_same = init_images[-2] is init_image - - if mask_image is not None: - mask_images.append(mask_image) - if i > 0 and all_masks_are_same: - all_masks_are_same = mask_images[-2] is mask_image - - if guide_image is not None: - if type(guide_image) is list: - guide_images.extend(guide_image) - all_guide_images_are_same = False - else: - guide_images.append(guide_image) - if i > 0 and all_guide_images_are_same: - all_guide_images_are_same = guide_images[-2] is guide_image - - # make start code - torch.manual_seed(seed) - start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype) - - # make each noises - for j in range(steps * scheduler_num_noises_per_step): - noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) - - if i2i_noises is not None: # img2img noise - i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype) - - noise_manager.reset_sampler_noises(noises) - - # すべての画像が同じなら1枚だけpipeに渡すことでpipe側で処理を高速化する - if init_images is not None and all_images_are_same: - init_images = init_images[0] - if mask_images is not None and all_masks_are_same: - mask_images = mask_images[0] - if guide_images is not None and all_guide_images_are_same: - guide_images = guide_images[0] - - # ControlNet使用時はguide imageをリサイズする - if control_nets or control_net_lllites: - # TODO resampleのメソッド - guide_images = guide_images if type(guide_images) == list else [guide_images] - guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images] - if len(guide_images) == 1: - guide_images = guide_images[0] - - # generate - if networks: - # 追加ネットワークの処理 - shared = {} - for n, m in zip(networks, network_muls if network_muls else network_default_muls): - n.set_multiplier(m) - if regional_network: - # TODO バッチから ds_ratio を取り出すべき - n.set_current_generation(batch_size, num_sub_prompts, width, height, shared, unet.ds_ratio) - - if not regional_network and network_pre_calc: - for n in networks: - n.restore_weights() - for n in networks: - n.pre_calculation() - print("pre-calculation... 
done") - - images = pipe( - prompts, - negative_prompts, - init_images, - mask_images, - height, - width, - original_height, - original_width, - original_height_negative, - original_width_negative, - crop_top, - crop_left, - steps, - scale, - negative_scale, - strength, - latents=start_code, - output_type="pil", - max_embeddings_multiples=max_embeddings_multiples, - img2img_noise=i2i_noises, - vae_batch_size=args.vae_batch_size, - return_latents=return_latents, - clip_prompts=clip_prompts, - clip_guide_images=guide_images, - emb_normalize_mode=args.emb_normalize_mode, - ) - if highres_1st and not args.highres_fix_save_1st: # return images or latents - return images - - # save image - highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" - ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( - zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) - ): - if highres_fix: - seed -= 1 # record original seed - metadata = PngInfo() - metadata.add_text("prompt", prompt) - metadata.add_text("seed", str(seed)) - metadata.add_text("sampler", args.sampler) - metadata.add_text("steps", str(steps)) - metadata.add_text("scale", str(scale)) - if negative_prompt is not None: - metadata.add_text("negative-prompt", negative_prompt) - if negative_scale is not None: - metadata.add_text("negative-scale", str(negative_scale)) - if clip_prompt is not None: - metadata.add_text("clip-prompt", clip_prompt) - if raw_prompt is not None: - metadata.add_text("raw-prompt", raw_prompt) - if is_sdxl: - metadata.add_text("original-height", str(original_height)) - metadata.add_text("original-width", str(original_width)) - metadata.add_text("original-height-negative", str(original_height_negative)) - metadata.add_text("original-width-negative", str(original_width_negative)) - metadata.add_text("crop-top", str(crop_top)) - metadata.add_text("crop-left", str(crop_left)) - - if args.use_original_file_name and init_images is not None: - if type(init_images) is list: - fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" - else: - fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" - elif args.sequential_file_name: - fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" - else: - fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" - - image.save(os.path.join(args.outdir, fln), pnginfo=metadata) - - if not args.no_preview and not highres_1st and args.interactive: - try: - import cv2 - - for prompt, image in zip(prompts, images): - cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ - cv2.waitKey() - cv2.destroyAllWindows() - except ImportError: - print( - "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" - ) - - return images - - # 画像生成のプロンプトが一周するまでのループ - prompt_index = 0 - global_step = 0 - batch_data = [] - while True: - if args.interactive: - # interactive - valid = False - while not valid: - print("\nType prompt:") - try: - raw_prompt = input() - except EOFError: - break - - valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0 - if not valid: # EOF, end app - break - else: - raw_prompt = prompter(args, pipe, seed_random, iter_seed, prompt_index, global_step) - if raw_prompt is None: - break - - # sd-dynamic-prompts like variants: - # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) - raw_prompts = 
handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) - - # repeat prompt - for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): - raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] - - if pi == 0 or len(raw_prompts) > 1: - # parse prompt: if prompt is not changed, skip parsing - width = args.W - height = args.H - original_width = args.original_width - original_height = args.original_height - original_width_negative = args.original_width_negative - original_height_negative = args.original_height_negative - crop_top = args.crop_top - crop_left = args.crop_left - scale = args.scale - negative_scale = args.negative_scale - steps = args.steps - seed = None - seeds = None - strength = 0.8 if args.strength is None else args.strength - negative_prompt = "" - clip_prompt = None - network_muls = None - - # Deep Shrink - ds_depth_1 = None # means no override - ds_timesteps_1 = args.ds_timesteps_1 - ds_depth_2 = args.ds_depth_2 - ds_timesteps_2 = args.ds_timesteps_2 - ds_ratio = args.ds_ratio - - # Gradual Latent - gl_timesteps = None # means no override - gl_ratio = args.gradual_latent_ratio - gl_every_n_steps = args.gradual_latent_every_n_steps - gl_ratio_step = args.gradual_latent_ratio_step - gl_s_noise = args.gradual_latent_s_noise - gl_unsharp_params = args.gradual_latent_unsharp_params - - prompt_args = raw_prompt.strip().split(" --") - prompt = prompt_args[0] - length = len(prompter) if hasattr(prompter, "__len__") else 0 - print(f"prompt {prompt_index+1}/{length}: {prompt}") - - for parg in prompt_args[1:]: - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - print(f"width: {width}") - continue - - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - print(f"height: {height}") - continue - - m = re.match(r"ow (\d+)", parg, re.IGNORECASE) - if m: - original_width = int(m.group(1)) - print(f"original width: {original_width}") - continue - - m = re.match(r"oh (\d+)", parg, re.IGNORECASE) - if m: - original_height = int(m.group(1)) - print(f"original height: {original_height}") - continue - - m = re.match(r"nw (\d+)", parg, re.IGNORECASE) - if m: - original_width_negative = int(m.group(1)) - print(f"original width negative: {original_width_negative}") - continue - - m = re.match(r"nh (\d+)", parg, re.IGNORECASE) - if m: - original_height_negative = int(m.group(1)) - print(f"original height negative: {original_height_negative}") - continue - - m = re.match(r"ct (\d+)", parg, re.IGNORECASE) - if m: - crop_top = int(m.group(1)) - print(f"crop top: {crop_top}") - continue - - m = re.match(r"cl (\d+)", parg, re.IGNORECASE) - if m: - crop_left = int(m.group(1)) - print(f"crop left: {crop_left}") - continue - - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - steps = max(1, min(1000, int(m.group(1)))) - print(f"steps: {steps}") - continue - - m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) - if m: # seed - seeds = [int(d) for d in m.group(1).split(",")] - print(f"seeds: {seeds}") - continue - - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - print(f"scale: {scale}") - continue - - m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) - if m: # negative scale - if m.group(1).lower() == "none": - negative_scale = None - else: - negative_scale = float(m.group(1)) - print(f"negative scale: {negative_scale}") - continue - - m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) - if m: # strength - strength 
= float(m.group(1)) - print(f"strength: {strength}") - continue - - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - print(f"negative prompt: {negative_prompt}") - continue - - m = re.match(r"c (.+)", parg, re.IGNORECASE) - if m: # clip prompt - clip_prompt = m.group(1) - print(f"clip prompt: {clip_prompt}") - continue - - m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # network multiplies - network_muls = [float(v) for v in m.group(1).split(",")] - while len(network_muls) < len(networks): - network_muls.append(network_muls[-1]) - print(f"network mul: {network_muls}") - continue - - # Deep Shrink - m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink depth 1 - ds_depth_1 = int(m.group(1)) - print(f"deep shrink depth 1: {ds_depth_1}") - continue - - m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink timesteps 1 - ds_timesteps_1 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink timesteps 1: {ds_timesteps_1}") - continue - - m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink depth 2 - ds_depth_2 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink depth 2: {ds_depth_2}") - continue - - m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink timesteps 2 - ds_timesteps_2 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink timesteps 2: {ds_timesteps_2}") - continue - - m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink ratio - ds_ratio = float(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - print(f"deep shrink ratio: {ds_ratio}") - continue - - # Gradual Latent - m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent timesteps - gl_timesteps = int(m.group(1)) - print(f"gradual latent timesteps: {gl_timesteps}") - continue - - m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio - gl_ratio = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio: {ds_ratio}") - continue - - m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent every n steps - gl_every_n_steps = int(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent every n steps: {gl_every_n_steps}") - continue - - m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio step - gl_ratio_step = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio step: {gl_ratio_step}") - continue - - m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent s noise - gl_s_noise = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent s noise: {gl_s_noise}") - continue - - m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # gradual latent unsharp params - gl_unsharp_params = m.group(1) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent unsharp params: {gl_unsharp_params}") - continue - - except ValueError as ex: - print(f"Exception in parsing / 
解析エラー: {parg}") - print(ex) - - # override Deep Shrink - if ds_depth_1 is not None: - if ds_depth_1 < 0: - ds_depth_1 = args.ds_depth_1 or 3 - unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) - - # override Gradual Latent - if gl_timesteps is not None: - if gl_timesteps < 0: - gl_timesteps = args.gradual_latent_timesteps or 650 - if gl_unsharp_params is not None: - unsharp_params = gl_unsharp_params.split(",") - us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] - us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) - us_ksize = int(us_ksize) - else: - us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None - gradual_latent = GradualLatent( - gl_ratio, - gl_timesteps, - gl_every_n_steps, - gl_ratio_step, - gl_s_noise, - us_ksize, - us_sigma, - us_strength, - us_target_x, - ) - pipe.set_gradual_latent(gradual_latent) - - # prepare seed - if seeds is not None: # given in prompt - # num_images_per_promptが多い場合は足りなくなるので、足りない分は前のを使う - if len(seeds) > 0: - seed = seeds.pop(0) - else: - if args.iter_same_seed: - seed = iter_seed - else: - seed = None # 前のを消す - - if seed is None: - seed = seed_random.randint(0, 2**32 - 1) - if args.interactive: - print(f"seed: {seed}") - - # prepare init image, guide image and mask - init_image = mask_image = guide_image = None - - # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する - if init_images is not None: - init_image = init_images[global_step % len(init_images)] - - # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する - # 32単位に丸めたやつにresizeされるので踏襲する - if not highres_fix: - width, height = init_image.size - width = width - width % 32 - height = height - height % 32 - if width != init_image.size[0] or height != init_image.size[1]: - print( - f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" - ) - - if mask_images is not None: - mask_image = mask_images[global_step % len(mask_images)] - - if guide_images is not None: - if control_nets or control_net_lllites: # 複数件の場合あり - c = max(len(control_nets), len(control_net_lllites)) - p = global_step % (len(guide_images) // c) - guide_image = guide_images[p * c : p * c + c] - else: - guide_image = guide_images[global_step % len(guide_images)] - - if regional_network: - num_sub_prompts = len(prompt.split(" AND ")) - assert ( - len(networks) <= num_sub_prompts - ), "Number of networks must be less than or equal to number of sub prompts." - else: - num_sub_prompts = None - - b1 = BatchData( - False, - BatchDataBase( - global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt - ), - BatchDataExt( - width, - height, - original_width, - original_height, - original_width_negative, - original_height_negative, - crop_left, - crop_top, - steps, - scale, - negative_scale, - strength, - tuple(network_muls) if network_muls else None, - num_sub_prompts, - ), - ) - if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? 
- process_batch(batch_data, highres_fix) - batch_data.clear() - - batch_data.append(b1) - if len(batch_data) == args.batch_size: - prev_image = process_batch(batch_data, highres_fix)[0] - batch_data.clear() - - global_step += 1 - - prompt_index += 1 - - if len(batch_data) > 0: - process_batch(batch_data, highres_fix) - batch_data.clear() - - print("done!") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - parser.add_argument( - "--sdxl", action="store_true", help="load Stable Diffusion XL model / Stable Diffusion XLのモデルを読み込む" - ) - parser.add_argument( - "--v1", action="store_true", help="load Stable Diffusion v1.x model / Stable Diffusion 1.xのモデルを読み込む" - ) - parser.add_argument( - "--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む" - ) - parser.add_argument( - "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする" - ) - - parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") - parser.add_argument( - "--from_file", - type=str, - default=None, - help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む", - ) - parser.add_argument( - "--from_module", - type=str, - default=None, - help="if specified, load prompts from this module / 指定時はプロンプトをモジュールから読み込む", - ) - parser.add_argument( - "--prompter_module_args", type=str, default=None, help="args for prompter module / prompterモジュールの引数" - ) - parser.add_argument( - "--interactive", - action="store_true", - help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)", - ) - parser.add_argument( - "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" - ) - parser.add_argument( - "--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgまたはinpaintを行う元画像" - ) - parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク") - parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") - parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") - parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") - parser.add_argument( - "--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする" - ) - parser.add_argument( - "--use_original_file_name", - action="store_true", - help="prepend original file name in img2img / img2imgで元画像のファイル名を生成画像のファイル名の先頭に付ける", - ) - # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", ) - parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数") - parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") - parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") - parser.add_argument( - "--original_height", - type=int, - default=None, - help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値", - ) - parser.add_argument( - "--original_width", - type=int, - default=None, - help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値", - ) - parser.add_argument( - "--original_height_negative", - type=int, - default=None, - help="original height for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal heightの値", - ) - 
parser.add_argument( - "--original_width_negative", - type=int, - default=None, - help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値", - ) - parser.add_argument( - "--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値" - ) - parser.add_argument( - "--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値" - ) - parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") - parser.add_argument( - "--vae_batch_size", - type=float, - default=None, - help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率", - ) - parser.add_argument( - "--vae_slices", - type=int, - default=None, - help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨", - ) - parser.add_argument( - "--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない" - ) - parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") - parser.add_argument( - "--sampler", - type=str, - default="ddim", - choices=[ - "ddim", - "pndm", - "lms", - "euler", - "euler_a", - "heun", - "dpm_2", - "dpm_2_a", - "dpmsolver", - "dpmsolver++", - "dpmsingle", - "k_lms", - "k_euler", - "k_euler_a", - "k_dpm_2", - "k_dpm_2_a", - ], - help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類", - ) - parser.add_argument( - "--scale", - type=float, - default=7.5, - help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", - ) - parser.add_argument( - "--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ" - ) - parser.add_argument( - "--vae", - type=str, - default=None, - help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", - ) - parser.add_argument( - "--tokenizer_cache_dir", - type=str, - default=None, - help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", - ) - # parser.add_argument("--replace_clip_l14_336", action='store_true', - # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") - parser.add_argument( - "--seed", - type=int, - default=None, - help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed", - ) - parser.add_argument( - "--iter_same_seed", - action="store_true", - help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)", - ) - parser.add_argument( - "--shuffle_prompts", - action="store_true", - help="shuffle prompts in iteration / 繰り返し内のプロンプトをシャッフルする", - ) - parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する") - parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") - parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する") - parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa") - parser.add_argument( - "--diffusers_xformers", - action="store_true", - help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", - ) - parser.add_argument( - "--opt_channels_last", - action="store_true", - help="set channels last option to 
model / モデルにchannels lastを指定し最適化する", - ) - parser.add_argument( - "--network_module", - type=str, - default=None, - nargs="*", - help="additional network module to use / 追加ネットワークを使う時そのモジュール名", - ) - parser.add_argument( - "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" - ) - parser.add_argument( - "--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率" - ) - parser.add_argument( - "--network_args", - type=str, - default=None, - nargs="*", - help="additional arguments for network (key=value) / ネットワークへの追加の引数", - ) - parser.add_argument( - "--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する" - ) - parser.add_argument( - "--network_merge_n_models", - type=int, - default=None, - help="merge this number of networks / この数だけネットワークをマージする", - ) - parser.add_argument( - "--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする" - ) - parser.add_argument( - "--network_pre_calc", - action="store_true", - help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する", - ) - parser.add_argument( - "--network_regional_mask_max_color_codes", - type=int, - default=None, - help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)", - ) - parser.add_argument( - "--textual_inversion_embeddings", - type=str, - default=None, - nargs="*", - help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", - ) - parser.add_argument( - "--clip_skip", - type=int, - default=None, - help="layer number from bottom to use in CLIP, default is 1 for SD1/2, 2 for SDXL " - + "/ CLIPの後ろからn層目の出力を使う(デフォルトはSD1/2の場合1、SDXLの場合2)", - ) - parser.add_argument( - "--max_embeddings_multiples", - type=int, - default=None, - help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる", - ) - parser.add_argument( - "--emb_normalize_mode", - type=str, - default="original", - choices=["original", "none", "abs"], - help="embedding normalization mode / embeddingの正規化モード", - ) - parser.add_argument( - "--guide_image_path", type=str, default=None, nargs="*", help="image to ControlNet / ControlNetでガイドに使う画像" - ) - parser.add_argument( - "--highres_fix_scale", - type=float, - default=None, - help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", - ) - parser.add_argument( - "--highres_fix_steps", - type=int, - default=28, - help="1st stage steps for highres fix / highres fixの最初のステージのステップ数", - ) - parser.add_argument( - "--highres_fix_strength", - type=float, - default=None, - help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", - ) - parser.add_argument( - "--highres_fix_save_1st", - action="store_true", - help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する", - ) - parser.add_argument( - "--highres_fix_latents_upscaling", - action="store_true", - help="use latents upscaling for highres fix / highres fixでlatentで拡大する", - ) - parser.add_argument( - "--highres_fix_upscaler", - type=str, - default=None, - help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名", - ) - parser.add_argument( - "--highres_fix_upscaler_args", - type=str, - default=None, - help="additional arguments for upscaler (key=value) / upscalerへの追加の引数", - ) - parser.add_argument( - "--highres_fix_disable_control_net", - 
action="store_true", - help="disable ControlNet for highres fix / highres fixでControlNetを使わない", - ) - - parser.add_argument( - "--negative_scale", - type=float, - default=None, - help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する", - ) - - parser.add_argument( - "--control_net_lllite_models", - type=str, - default=None, - nargs="*", - help="ControlNet models to use / 使用するControlNetのモデル名", - ) - parser.add_argument( - "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" - ) - parser.add_argument( - "--control_net_preps", - type=str, - default=None, - nargs="*", - help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名", - ) - parser.add_argument( - "--control_net_multipliers", type=float, default=None, nargs="*", help="ControlNet multiplier / ControlNetの適用率" - ) - parser.add_argument( - "--control_net_ratios", - type=float, - default=None, - nargs="*", - help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", - ) - parser.add_argument( - "--clip_vision_strength", - type=float, - default=None, - help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する", - ) - - # Deep Shrink - parser.add_argument( - "--ds_depth_1", - type=int, - default=None, - help="Enable Deep Shrink with this depth 1, valid values are 0 to 8 / Deep Shrinkをこのdepthで有効にする", - ) - parser.add_argument( - "--ds_timesteps_1", - type=int, - default=650, - help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps", - ) - parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2") - parser.add_argument( - "--ds_timesteps_2", - type=int, - default=650, - help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps", - ) - parser.add_argument( - "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" - ) - - # gradual latent - parser.add_argument( - "--gradual_latent_timesteps", - type=int, - default=None, - help="enable Gradual Latent hires fix and apply upscaling from this timesteps / Gradual Latent hires fixをこのtimestepsで有効にし、このtimestepsからアップスケーリングを適用する", - ) - parser.add_argument( - "--gradual_latent_ratio", - type=float, - default=0.5, - help=" this size ratio, 0.5 means 1/2 / Gradual Latent hires fixをこのサイズ比率で有効にする、0.5は1/2を意味する", - ) - parser.add_argument( - "--gradual_latent_ratio_step", - type=float, - default=0.125, - help="step to increase ratio for Gradual Latent / Gradual Latentのratioをどのくらいずつ上げるか", - ) - parser.add_argument( - "--gradual_latent_every_n_steps", - type=int, - default=3, - help="steps to increase size of latents every this steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる", - ) - parser.add_argument( - "--gradual_latent_s_noise", - type=float, - default=1.0, - help="s_noise for Gradual Latent / Gradual Latentのs_noise", - ) - parser.add_argument( - "--gradual_latent_unsharp_params", - type=str, - default=None, - help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" - + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. 
`3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", - ) - - # # parser.add_argument( - # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" - # ) - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - main(args) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py deleted file mode 100644 index 2c5f84a93..000000000 --- a/gen_img_diffusers.py +++ /dev/null @@ -1,3866 +0,0 @@ -""" -VGG( - (features): Sequential( - (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - (1): ReLU(inplace=True) - (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - (3): ReLU(inplace=True) - (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) - (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - (6): ReLU(inplace=True) - (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - (8): ReLU(inplace=True) - (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) - (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - (11): ReLU(inplace=True) - (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - (13): ReLU(inplace=True) - (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - (15): ReLU(inplace=True) - (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) - (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - (18): ReLU(inplace=True) - (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - (20): ReLU(inplace=True) - (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - (22): ReLU(inplace=True) - (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) - (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - (25): ReLU(inplace=True) - (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - (27): ReLU(inplace=True) - (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - (29): ReLU(inplace=True) - (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) - ) - (avgpool): AdaptiveAvgPool2d(output_size=(7, 7)) - (classifier): Sequential( - (0): Linear(in_features=25088, out_features=4096, bias=True) - (1): ReLU(inplace=True) - (2): Dropout(p=0.5, inplace=False) - (3): Linear(in_features=4096, out_features=4096, bias=True) - (4): ReLU(inplace=True) - (5): Dropout(p=0.5, inplace=False) - (6): Linear(in_features=4096, out_features=1000, bias=True) - ) -) -""" - -import itertools -import json -from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable -import glob -import importlib -import inspect -import time -import zipfile -from diffusers.utils import deprecate -from diffusers.configuration_utils import FrozenDict -import argparse -import math -import os -import random -import re - -import diffusers -import numpy as np - -import torch -from library.device_utils import init_ipex, clean_memory, get_preferred_device -init_ipex() - -import torchvision -from diffusers import ( - AutoencoderKL, - DDPMScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - DPMSolverSinglestepScheduler, - LMSDiscreteScheduler, - PNDMScheduler, - DDIMScheduler, - EulerDiscreteScheduler, - HeunDiscreteScheduler, - KDPM2DiscreteScheduler, - KDPM2AncestralDiscreteScheduler, - # UNet2DConditionModel, - 
StableDiffusionPipeline, -) -from einops import rearrange -from tqdm import tqdm -from torchvision import transforms -from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPTextConfig -import PIL -from PIL import Image -from PIL.PngImagePlugin import PngInfo - -import library.model_util as model_util -import library.train_util as train_util -from networks.lora import LoRANetwork -import tools.original_control_net as original_control_net -from tools.original_control_net import ControlNetInfo -from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel -from library.original_unet import FlashAttentionFunction -from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL - -from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI -from library.utils import setup_logging, add_logging_arguments - -setup_logging() -import logging - -logger = logging.getLogger(__name__) - -# scheduler: -SCHEDULER_LINEAR_START = 0.00085 -SCHEDULER_LINEAR_END = 0.0120 -SCHEDULER_TIMESTEPS = 1000 -SCHEDLER_SCHEDULE = "scaled_linear" - -# その他の設定 -LATENT_CHANNELS = 4 -DOWNSAMPLING_FACTOR = 8 - -# CLIP_ID_L14_336 = "openai/clip-vit-large-patch14-336" - -# CLIP guided SD関連 -CLIP_MODEL_PATH = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" -FEATURE_EXTRACTOR_SIZE = (224, 224) -FEATURE_EXTRACTOR_IMAGE_MEAN = [0.48145466, 0.4578275, 0.40821073] -FEATURE_EXTRACTOR_IMAGE_STD = [0.26862954, 0.26130258, 0.27577711] - -VGG16_IMAGE_MEAN = [0.485, 0.456, 0.406] -VGG16_IMAGE_STD = [0.229, 0.224, 0.225] -VGG16_INPUT_RESIZE_DIV = 4 - -# CLIP特徴量の取得時にcutoutを使うか:使う場合にはソースを書き換えてください -NUM_CUTOUTS = 4 -USE_CUTOUTS = False - -# region モジュール入れ替え部 -""" -高速化のためのモジュール入れ替え -""" - - -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): - if mem_eff_attn: - logger.info("Enable memory efficient attention for U-Net") - - # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い - unet.set_use_memory_efficient_attention(False, True) - elif xformers: - logger.info("Enable xformers for U-Net") - try: - import xformers.ops - except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") - - unet.set_use_memory_efficient_attention(True, False) - elif sdpa: - logger.info("Enable SDPA for U-Net") - unet.set_use_memory_efficient_attention(False, False) - unet.set_use_sdpa(True) - - -# TODO common train_util.py -def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa): - if mem_eff_attn: - replace_vae_attn_to_memory_efficient() - elif xformers: - replace_vae_attn_to_xformers() - elif sdpa: - replace_vae_attn_to_sdpa() - - -def replace_vae_attn_to_memory_efficient(): - logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)") - flash_func = FlashAttentionFunction - - def forward_flash_attn(self, hidden_states, **kwargs): - q_bucket_size = 512 - k_bucket_size = 1024 - - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.to_q(hidden_states) - key_proj = self.to_k(hidden_states) - value_proj = self.to_v(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) - ) - - out = flash_func.apply(query_proj, key_proj, value_proj, None, False, 
q_bucket_size, k_bucket_size) - - out = rearrange(out, "b h n d -> b n (h d)") - - # compute next hidden_states - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - def forward_flash_attn_0_14(self, hidden_states, **kwargs): - if not hasattr(self, "to_q"): - self.to_q = self.query - self.to_k = self.key - self.to_v = self.value - self.to_out = [self.proj_attn, torch.nn.Identity()] - self.heads = self.num_heads - return forward_flash_attn(self, hidden_states, **kwargs) - - if diffusers.__version__ < "0.15.0": - diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14 - else: - diffusers.models.attention_processor.Attention.forward = forward_flash_attn - - -def replace_vae_attn_to_xformers(): - logger.info("VAE: Attention.forward has been replaced to xformers") - import xformers.ops - - def forward_xformers(self, hidden_states, **kwargs): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.to_q(hidden_states) - key_proj = self.to_k(hidden_states) - value_proj = self.to_v(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) - ) - - query_proj = query_proj.contiguous() - key_proj = key_proj.contiguous() - value_proj = value_proj.contiguous() - out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) - - out = rearrange(out, "b h n d -> b n (h d)") - - # compute next hidden_states - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - def forward_xformers_0_14(self, hidden_states, **kwargs): - if not hasattr(self, "to_q"): - self.to_q = self.query - self.to_k = self.key - self.to_v = self.value - self.to_out = [self.proj_attn, torch.nn.Identity()] - self.heads = self.num_heads - return forward_xformers(self, hidden_states, **kwargs) - - if diffusers.__version__ < "0.15.0": - diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14 - else: - diffusers.models.attention_processor.Attention.forward = forward_xformers - - -def replace_vae_attn_to_sdpa(): - logger.info("VAE: Attention.forward has been replaced to sdpa") - - def forward_sdpa(self, hidden_states, **kwargs): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.to_q(hidden_states) - key_proj = self.to_k(hidden_states) - value_proj = self.to_v(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj) - ) - - out = torch.nn.functional.scaled_dot_product_attention( - 
query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - out = rearrange(out, "b n h d -> b n (h d)") - - # compute next hidden_states - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - def forward_sdpa_0_14(self, hidden_states, **kwargs): - if not hasattr(self, "to_q"): - self.to_q = self.query - self.to_k = self.key - self.to_v = self.value - self.to_out = [self.proj_attn, torch.nn.Identity()] - self.heads = self.num_heads - return forward_sdpa(self, hidden_states, **kwargs) - - if diffusers.__version__ < "0.15.0": - diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14 - else: - diffusers.models.attention_processor.Attention.forward = forward_sdpa - - -# endregion - -# region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正 -# https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py -# Pipelineだけ独立して使えないのと機能追加するのとでコピーして修正 - - -class PipelineLike: - r""" - Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing - weighting in prompt. - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
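The `replace_vae_attn_to_sdpa` hook above swaps the VAE attention forward for PyTorch 2's fused `scaled_dot_product_attention`. A self-contained sketch of that call in the standard `(batch, heads, tokens, dim)` layout; the shapes below are made up for illustration:

```python
# Minimal sketch of the fused attention call used by forward_sdpa above,
# in the standard (batch, heads, tokens, dim) layout; shapes are made up.
import torch
import torch.nn.functional as F

batch, heads, tokens, dim = 1, 1, 64 * 64, 512
q = torch.randn(batch, heads, tokens, dim)
k = torch.randn(batch, heads, tokens, dim)
v = torch.randn(batch, heads, tokens, dim)

# No mask, no dropout, non-causal -- the same arguments passed in forward_sdpa.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([1, 1, 4096, 512])
```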
- """ - - def __init__( - self, - device, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: InferUNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - clip_skip: int, - clip_model: CLIPModel, - clip_guidance_scale: float, - clip_image_guidance_scale: float, - vgg16_model: torchvision.models.VGG, - vgg16_guidance_scale: float, - vgg16_layer_no: int, - # safety_checker: StableDiffusionSafetyChecker, - # feature_extractor: CLIPFeatureExtractor, - ): - super().__init__() - self.device = device - self.clip_skip = clip_skip - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - self.vae = vae - self.text_encoder = text_encoder - self.tokenizer = tokenizer - self.unet = unet - self.scheduler = scheduler - self.safety_checker = None - - # Textual Inversion - self.token_replacements = {} - - # XTI - self.token_replacements_XTI = {} - - # CLIP guidance - self.clip_guidance_scale = clip_guidance_scale - self.clip_image_guidance_scale = clip_image_guidance_scale - self.clip_model = clip_model - self.normalize = transforms.Normalize(mean=FEATURE_EXTRACTOR_IMAGE_MEAN, std=FEATURE_EXTRACTOR_IMAGE_STD) - self.make_cutouts = MakeCutouts(FEATURE_EXTRACTOR_SIZE) - - # VGG16 guidance - self.vgg16_guidance_scale = vgg16_guidance_scale - if self.vgg16_guidance_scale > 0.0: - return_layers = {f"{vgg16_layer_no}": "feat"} - self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter( - vgg16_model.features, return_layers=return_layers - ) - self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD) - - # ControlNet - self.control_nets: List[ControlNetInfo] = [] - self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない - - self.gradual_latent: GradualLatent = None - - # Textual Inversion - def add_token_replacement(self, target_token_id, rep_token_ids): - self.token_replacements[target_token_id] = rep_token_ids - - def set_enable_control_net(self, en: bool): - self.control_net_enabled = en - - def replace_token(self, tokens, layer=None): - new_tokens = [] - for token in tokens: - if token in self.token_replacements: - replacer_ = self.token_replacements[token] - if layer: - replacer = [] - for r in replacer_: - if r in self.token_replacements_XTI: - replacer.append(self.token_replacements_XTI[r][layer]) - else: - replacer = replacer_ - new_tokens.extend(replacer) - else: - new_tokens.append(token) - return new_tokens - - def add_token_replacement_XTI(self, target_token_id, rep_token_ids): - self.token_replacements_XTI[target_token_id] = rep_token_ids - - def set_control_nets(self, ctrl_nets): - self.control_nets = ctrl_nets - - def set_gradual_latent(self, gradual_latent): - if gradual_latent is None: - print("gradual_latent is disabled") - self.gradual_latent = None - else: - print(f"gradual_latent is enabled: {gradual_latent}") - self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) - - # region xformersとか使う部分:独自に書き換えるので関係なし - - def enable_xformers_memory_efficient_attention(self): - r""" - Enable memory efficient attention as implemented in xformers. - When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference - time. Speed up at training time is not guaranteed. - Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention - is used. - """ - self.unet.set_use_memory_efficient_attention_xformers(True) - - def disable_xformers_memory_efficient_attention(self): - r""" - Disable memory efficient attention as implemented in xformers. 
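`PipelineLike.replace_token` above expands a Textual Inversion placeholder token id into the ids of its learned embedding vectors before encoding. A minimal sketch of that expansion; the token ids are made up:

```python
# Sketch of the Textual Inversion token replacement performed by
# PipelineLike.replace_token above; the token ids below are made up.
token_replacements = {49408: [49408, 49409, 49410]}  # placeholder id -> per-vector ids

def replace_token(tokens):
    new_tokens = []
    for token in tokens:
        new_tokens.extend(token_replacements.get(token, [token]))
    return new_tokens

print(replace_token([320, 49408, 525]))  # [320, 49408, 49409, 49410, 525]
```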
- """ - self.unet.set_use_memory_efficient_attention_xformers(False) - - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - - def enable_sequential_cpu_offload(self): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - """ - # accelerateが必要になるのでとりあえず省略 - raise NotImplementedError("cpu_offload is omitted.") - # if is_accelerate_available(): - # from accelerate import cpu_offload - # else: - # raise ImportError("Please install accelerate via `pip install accelerate`") - - # device = self.device - - # for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: - # if cpu_offloaded_model is not None: - # cpu_offload(cpu_offloaded_model, device) - - # endregion - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, - mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, - height: int = 512, - width: int = 512, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_scale: float = None, - strength: float = 0.8, - # num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - vae_batch_size: float = None, - return_latents: bool = False, - # return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: Optional[int] = 1, - img2img_noise=None, - clip_prompts=None, - clip_guide_images=None, - networks: Optional[List[LoRANetwork]] = None, - **kwargs, - ): - r""" - Function invoked when calling the pipeline for generation. - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. 
Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - init_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. - mask_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be - replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a - PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should - contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. - height (`int`, *optional*, defaults to 512): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. - `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The - number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added - noise will be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. 
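The `strength` parameter documented above controls how much of the denoising schedule is actually run for img2img. A sketch of the mapping from `strength` to step counts, mirroring the `init_timestep` / `t_start` computation further down in `__call__`:

```python
# Sketch of how `strength` maps to the number of img2img denoising steps,
# mirroring the init_timestep / t_start computation later in __call__
# (offset is the scheduler's steps_offset, typically 0 or 1).
def img2img_steps(num_inference_steps: int, strength: float, offset: int = 0) -> int:
    init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)
    t_start = max(num_inference_steps - init_timestep + offset, 0)
    return num_inference_steps - t_start  # steps that are actually run

print(img2img_steps(50, 0.8))  # 40 -- strength 0.8 keeps 80% of the schedule
print(img2img_steps(50, 1.0))  # 50 -- strength 1.0 effectively ignores the init image
```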
- callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - is_cancelled_callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. If the function returns - `True`, the inference will be cancelled. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - `None` if cancelled by `is_cancelled_callback`, - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - num_images_per_prompt = 1 # fixed - - if isinstance(prompt, str): - batch_size = 1 - prompt = [prompt] - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - reginonal_network = " AND " in prompt[0] - - vae_batch_size = ( - batch_size - if vae_batch_size is None - else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) - ) - - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." - ) - - # get prompt text embeddings - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - if not do_classifier_free_guidance and negative_scale is not None: - logger.warning(f"negative_scale is ignored if guidance scalle <= 1.0") - negative_scale = None - - # get unconditional embeddings for classifier free guidance - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - - if not self.token_replacements_XTI: - text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - **kwargs, - ) - - if negative_scale is not None: - _, real_uncond_embeddings, _ = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 - uncond_prompt=[""] * batch_size, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - **kwargs, - ) - - if self.token_replacements_XTI: - text_embeddings_concat = [] - for layer in [ - "IN01", - "IN02", - "IN04", - "IN05", - "IN07", - "IN08", - "MID", - "OUT03", - "OUT04", - "OUT05", - "OUT06", - "OUT07", - "OUT08", - "OUT09", - "OUT10", - "OUT11", - ]: - text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - layer=layer, - **kwargs, - ) - if do_classifier_free_guidance: - if negative_scale is None: - text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings])) - else: - text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])) - text_embeddings = torch.stack(text_embeddings_concat) - else: - if do_classifier_free_guidance: - if negative_scale is None: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - else: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) - - # CLIP guidanceで使用するembeddingsを取得する - if self.clip_guidance_scale > 0: - clip_text_input = prompt_tokens - if clip_text_input.shape[1] > self.tokenizer.model_max_length: - # TODO 75文字を超えたら警告を出す? 
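The conditional and unconditional embeddings stacked above are later split back apart with `chunk(num_latent_input)` and recombined by classifier-free guidance. A numeric sketch of that combination, including the optional `negative_scale` branch:

```python
# Numeric sketch of the classifier-free guidance combination applied later
# in the denoising loop, including the optional negative_scale branch.
import torch

guidance_scale, negative_scale = 7.5, 5.0
noise_pred_uncond = torch.zeros(1, 4, 64, 64)
noise_pred_text = torch.ones(1, 4, 64, 64)
noise_pred_negative = torch.full((1, 4, 64, 64), -1.0)

# standard CFG: eps = eps_uncond + w * (eps_text - eps_uncond)
cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# with a separate negative prompt scale, its direction is subtracted as well
cfg_neg = (
    noise_pred_uncond
    + guidance_scale * (noise_pred_text - noise_pred_uncond)
    - negative_scale * (noise_pred_negative - noise_pred_uncond)
)
print(cfg.mean().item(), cfg_neg.mean().item())  # 7.5 12.5
```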
- logger.info(f"trim text input {clip_text_input.shape}") - clip_text_input = torch.cat( - [clip_text_input[:, : self.tokenizer.model_max_length - 1], clip_text_input[:, -1].unsqueeze(1)], dim=1 - ) - logger.info(f"trimmed {clip_text_input.shape}") - - for i, clip_prompt in enumerate(clip_prompts): - if clip_prompt is not None: # clip_promptがあれば上書きする - clip_text_input[i] = self.tokenizer( - clip_prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ).input_ids.to(self.device) - - text_embeddings_clip = self.clip_model.get_text_features(clip_text_input) - text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK - - if ( - self.clip_image_guidance_scale > 0 - or self.vgg16_guidance_scale > 0 - and clip_guide_images is not None - or self.control_nets - ): - if isinstance(clip_guide_images, PIL.Image.Image): - clip_guide_images = [clip_guide_images] - - if self.clip_image_guidance_scale > 0: - clip_guide_images = [preprocess_guide_image(im) for im in clip_guide_images] - clip_guide_images = torch.cat(clip_guide_images, dim=0) - - clip_guide_images = self.normalize(clip_guide_images).to(self.device).to(text_embeddings.dtype) - image_embeddings_clip = self.clip_model.get_image_features(clip_guide_images) - image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) - if len(image_embeddings_clip) == 1: - image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1)) - elif self.vgg16_guidance_scale > 0: - size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に(小さいか?) - clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images] - clip_guide_images = torch.cat(clip_guide_images, dim=0) - - clip_guide_images = self.vgg16_normalize(clip_guide_images).to(self.device).to(text_embeddings.dtype) - image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)["feat"] - if len(image_embeddings_vgg16) == 1: - image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1)) - else: - # ControlNetのhintにguide imageを流用する - # 前処理はControlNet側で行う - pass - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps, self.device) - - latents_dtype = text_embeddings.dtype - init_latents_orig = None - mask = None - - if init_image is None: - # get the initial random noise unless the user supplied it - - # Unlike in other pipelines, latents need to be generated in the target device - # for 1-to-1 results reproducibility with the CompVis implementation. - # However this currently doesn't work in `mps`. 
- latents_shape = ( - batch_size * num_images_per_prompt, - self.unet.in_channels, - height // 8, - width // 8, - ) - - if latents is None: - if self.device.type == "mps": - # randn does not exist on mps - latents = torch.randn( - latents_shape, - generator=generator, - device="cpu", - dtype=latents_dtype, - ).to(self.device) - else: - latents = torch.randn( - latents_shape, - generator=generator, - device=self.device, - dtype=latents_dtype, - ) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) - - timesteps = self.scheduler.timesteps.to(self.device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - else: - # image to tensor - if isinstance(init_image, PIL.Image.Image): - init_image = [init_image] - if isinstance(init_image[0], PIL.Image.Image): - init_image = [preprocess_image(im) for im in init_image] - init_image = torch.cat(init_image) - if isinstance(init_image, list): - init_image = torch.stack(init_image) - - # mask image to tensor - if mask_image is not None: - if isinstance(mask_image, PIL.Image.Image): - mask_image = [mask_image] - if isinstance(mask_image[0], PIL.Image.Image): - mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint - - # encode the init image into latents and scale the latents - init_image = init_image.to(device=self.device, dtype=latents_dtype) - if init_image.size()[-2:] == (height // 8, width // 8): - init_latents = init_image - else: - if vae_batch_size >= batch_size: - init_latent_dist = self.vae.encode(init_image).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - else: - clean_memory() - init_latents = [] - for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): - init_latent_dist = self.vae.encode( - init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0) - ).latent_dist - init_latents.append(init_latent_dist.sample(generator=generator)) - init_latents = torch.cat(init_latents) - - init_latents = 0.18215 * init_latents - - if len(init_latents) == 1: - init_latents = init_latents.repeat((batch_size, 1, 1, 1)) - init_latents_orig = init_latents - - # preprocess mask - if mask_image is not None: - mask = mask_image.to(device=self.device, dtype=latents_dtype) - if len(mask) == 1: - mask = mask.repeat((batch_size, 1, 1, 1)) - - # check sizes - if not mask.shape == init_latents.shape: - raise ValueError("The mask and init_image should be the same size!") - - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) - - # add noise to latents using the timesteps - latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) - - t_start = max(num_inference_steps - init_timestep + offset, 0) - timesteps = self.scheduler.timesteps[t_start:].to(self.device) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 - - if self.control_nets: - guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) - - if reginonal_network: - num_sub_and_neg_prompts = len(text_embeddings) // batch_size - # last subprompt and negative prompt - text_emb_last = [] - for j in range(batch_size): - text_emb_last.append(text_embeddings[(j + 1) * num_sub_and_neg_prompts - 2]) - text_emb_last.append(text_embeddings[(j + 1) * num_sub_and_neg_prompts - 1]) - text_emb_last = torch.stack(text_emb_last) - else: - text_emb_last = text_embeddings - - enable_gradual_latent = False - if self.gradual_latent: - if not hasattr(self.scheduler, "set_gradual_latent_params"): - print("gradual_latent is not supported for this scheduler. Ignoring.") - print(self.scheduler.__class__.__name__) - else: - enable_gradual_latent = True - step_elapsed = 1000 - current_ratio = self.gradual_latent.ratio - - # first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする - height, width = latents.shape[-2:] - org_dtype = latents.dtype - if org_dtype == torch.bfloat16: - latents = latents.float() - latents = torch.nn.functional.interpolate( - latents, scale_factor=current_ratio, mode="bicubic", align_corners=False - ).to(org_dtype) - - # apply unsharp mask / アンシャープマスクを適用する - if self.gradual_latent.gaussian_blur_ksize: - latents = self.gradual_latent.apply_unshark_mask(latents) - - for i, t in enumerate(tqdm(timesteps)): - resized_size = None - if enable_gradual_latent: - # gradually upscale the latents / latentsを徐々にアップスケールする - if ( - t < self.gradual_latent.start_timesteps - and current_ratio < 1.0 - and step_elapsed >= self.gradual_latent.every_n_steps - ): - current_ratio = min(current_ratio + self.gradual_latent.ratio_step, 1.0) - # make divisible by 8 because size of latents must be divisible at bottom of UNet - h = int(height * current_ratio) // 8 * 8 - w = int(width * current_ratio) // 8 * 8 - resized_size = (h, w) - self.scheduler.set_gradual_latent_params(resized_size, self.gradual_latent) - step_elapsed = 0 - else: - self.scheduler.set_gradual_latent_params(None, None) - step_elapsed += 1 - - # expand the latents if we are doing classifier free guidance - latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - if self.control_nets and self.control_net_enabled: - noise_pred = original_control_net.call_unet_and_control_net( - i, - num_latent_input, - self.unet, - self.control_nets, - guided_hints, - i / len(timesteps), - latent_model_input, - t, - text_embeddings, - text_emb_last, - ).sample - else: - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - - # perform guidance - if do_classifier_free_guidance: - if negative_scale is None: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - else: - noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk( - num_latent_input - ) # 
uncond is real uncond - noise_pred = ( - noise_pred_uncond - + guidance_scale * (noise_pred_text - noise_pred_uncond) - - negative_scale * (noise_pred_negative - noise_pred_uncond) - ) - - # perform clip guidance - if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0: - text_embeddings_for_guidance = ( - text_embeddings.chunk(num_latent_input)[1] if do_classifier_free_guidance else text_embeddings - ) - - if self.clip_guidance_scale > 0: - noise_pred, latents = self.cond_fn( - latents, - t, - i, - text_embeddings_for_guidance, - noise_pred, - text_embeddings_clip, - self.clip_guidance_scale, - NUM_CUTOUTS, - USE_CUTOUTS, - ) - if self.clip_image_guidance_scale > 0 and clip_guide_images is not None: - noise_pred, latents = self.cond_fn( - latents, - t, - i, - text_embeddings_for_guidance, - noise_pred, - image_embeddings_clip, - self.clip_image_guidance_scale, - NUM_CUTOUTS, - USE_CUTOUTS, - ) - if self.vgg16_guidance_scale > 0 and clip_guide_images is not None: - noise_pred, latents = self.cond_fn_vgg16( - latents, t, i, text_embeddings_for_guidance, noise_pred, image_embeddings_vgg16, self.vgg16_guidance_scale - ) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - if mask is not None: - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t])) - latents = (init_latents_proper * mask) + (latents * (1 - mask)) - - # call the callback, if provided - if i % callback_steps == 0: - if callback is not None: - callback(i, t, latents) - if is_cancelled_callback is not None and is_cancelled_callback(): - return None - - if return_latents: - return (latents, False) - - latents = 1 / 0.18215 * latents - if vae_batch_size >= batch_size: - image = self.vae.decode(latents).sample - else: - clean_memory() - images = [] - for i in tqdm(range(0, batch_size, vae_batch_size)): - images.append( - self.vae.decode(latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample - ) - image = torch.cat(images) - - image = (image / 2 + 0.5).clamp(0, 1) - - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker( - images=image, - clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype), - ) - else: - has_nsfw_concept = None - - if output_type == "pil": - # image = self.numpy_to_pil(image) - image = (image * 255).round().astype("uint8") - image = [Image.fromarray(im) for im in image] - - # if not return_dict: - return (image, has_nsfw_concept) - - # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) - - def text2img( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - height: int = 512, - width: int = 512, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, 
torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, - **kwargs, - ): - r""" - Function for text-to-image generation. - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - height (`int`, *optional*, defaults to 512): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. 
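A hypothetical call to the `text2img` wrapper documented above, assuming `pipe` is an already-assembled `PipelineLike` with its models loaded and moved to the target device:

```python
# Hypothetical usage of the text2img wrapper above; `pipe` is assumed to be
# an already-assembled PipelineLike. __call__ returns (images, has_nsfw_concept).
images, _ = pipe.text2img(
    prompt="a photo of a red bicycle leaning against a brick wall",
    negative_prompt="lowres, blurry",
    height=512,
    width=512,
    num_inference_steps=30,
    guidance_scale=7.5,
    output_type="pil",
)
images[0].save("bicycle.png")
```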
- """ - return self.__call__( - prompt=prompt, - negative_prompt=negative_prompt, - height=height, - width=width, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - latents=latents, - max_embeddings_multiples=max_embeddings_multiples, - output_type=output_type, - return_dict=return_dict, - callback=callback, - callback_steps=callback_steps, - **kwargs, - ) - - def img2img( - self, - init_image: Union[torch.FloatTensor, PIL.Image.Image], - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, - **kwargs, - ): - r""" - Function for image-to-image generation. - Args: - init_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. - `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The - number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added - noise will be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This parameter will be modulated by `strength`. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. 
- output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - return self.__call__( - prompt=prompt, - negative_prompt=negative_prompt, - init_image=init_image, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - strength=strength, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - max_embeddings_multiples=max_embeddings_multiples, - output_type=output_type, - return_dict=return_dict, - callback=callback, - callback_steps=callback_steps, - **kwargs, - ) - - def inpaint( - self, - init_image: Union[torch.FloatTensor, PIL.Image.Image], - mask_image: Union[torch.FloatTensor, PIL.Image.Image], - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, - **kwargs, - ): - r""" - Function for inpaint. - Args: - init_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. This is the image whose masked region will be inpainted. - mask_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be - replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a - PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should - contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). 
- strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` - is 1, the denoising process will be run on the masked area for the full number of iterations specified - in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more - noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. - num_inference_steps (`int`, *optional*, defaults to 50): - The reference number of denoising steps. More denoising steps usually lead to a higher quality image at - the expense of slower inference. This parameter will be modulated by `strength`, as explained above. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. 
- """ - return self.__call__( - prompt=prompt, - negative_prompt=negative_prompt, - init_image=init_image, - mask_image=mask_image, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - strength=strength, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - max_embeddings_multiples=max_embeddings_multiples, - output_type=output_type, - return_dict=return_dict, - callback=callback, - callback_steps=callback_steps, - **kwargs, - ) - - # CLIP guidance StableDiffusion - # copy from https://github.com/huggingface/diffusers/blob/main/examples/community/clip_guided_stable_diffusion.py - - # バッチを分解して1件ずつ処理する - def cond_fn( - self, - latents, - timestep, - index, - text_embeddings, - noise_pred_original, - guide_embeddings_clip, - clip_guidance_scale, - num_cutouts, - use_cutouts=True, - ): - if len(latents) == 1: - return self.cond_fn1( - latents, - timestep, - index, - text_embeddings, - noise_pred_original, - guide_embeddings_clip, - clip_guidance_scale, - num_cutouts, - use_cutouts, - ) - - noise_pred = [] - cond_latents = [] - for i in range(len(latents)): - lat1 = latents[i].unsqueeze(0) - tem1 = text_embeddings[i].unsqueeze(0) - npo1 = noise_pred_original[i].unsqueeze(0) - gem1 = guide_embeddings_clip[i].unsqueeze(0) - npr1, cla1 = self.cond_fn1(lat1, timestep, index, tem1, npo1, gem1, clip_guidance_scale, num_cutouts, use_cutouts) - noise_pred.append(npr1) - cond_latents.append(cla1) - - noise_pred = torch.cat(noise_pred) - cond_latents = torch.cat(cond_latents) - return noise_pred, cond_latents - - @torch.enable_grad() - def cond_fn1( - self, - latents, - timestep, - index, - text_embeddings, - noise_pred_original, - guide_embeddings_clip, - clip_guidance_scale, - num_cutouts, - use_cutouts=True, - ): - latents = latents.detach().requires_grad_() - - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[index] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latents / ((sigma**2 + 1) ** 0.5) - else: - latent_model_input = latents - - # predict the noise residual - noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample - - if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): - alpha_prod_t = self.scheduler.alphas_cumprod[timestep] - beta_prod_t = 1 - alpha_prod_t - # compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) - - fac = torch.sqrt(beta_prod_t) - sample = pred_original_sample * (fac) + latents * (1 - fac) - elif isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[index] - sample = latents - sigma * noise_pred - else: - raise ValueError(f"scheduler type {type(self.scheduler)} not supported") - - sample = 1 / 0.18215 * sample - image = self.vae.decode(sample).sample - image = (image / 2 + 0.5).clamp(0, 1) - - if use_cutouts: - image = self.make_cutouts(image, num_cutouts) - else: - image = transforms.Resize(FEATURE_EXTRACTOR_SIZE)(image) - image = self.normalize(image).to(latents.dtype) - - image_embeddings_clip = self.clip_model.get_image_features(image) - image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) - - if use_cutouts: - dists = spherical_dist_loss(image_embeddings_clip, guide_embeddings_clip) - dists = 
dists.view([num_cutouts, sample.shape[0], -1]) - loss = dists.sum(2).mean(0).sum() * clip_guidance_scale - else: - # バッチサイズが複数だと正しく動くかわからない - loss = spherical_dist_loss(image_embeddings_clip, guide_embeddings_clip).mean() * clip_guidance_scale - - grads = -torch.autograd.grad(loss, latents)[0] - - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = latents.detach() + grads * (sigma**2) - noise_pred = noise_pred_original - else: - noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads - return noise_pred, latents - - # バッチを分解して一件ずつ処理する - def cond_fn_vgg16(self, latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale): - if len(latents) == 1: - return self.cond_fn_vgg16_b1( - latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale - ) - - noise_pred = [] - cond_latents = [] - for i in range(len(latents)): - lat1 = latents[i].unsqueeze(0) - tem1 = text_embeddings[i].unsqueeze(0) - npo1 = noise_pred_original[i].unsqueeze(0) - gem1 = guide_embeddings[i].unsqueeze(0) - npr1, cla1 = self.cond_fn_vgg16_b1(lat1, timestep, index, tem1, npo1, gem1, guidance_scale) - noise_pred.append(npr1) - cond_latents.append(cla1) - - noise_pred = torch.cat(noise_pred) - cond_latents = torch.cat(cond_latents) - return noise_pred, cond_latents - - # 1件だけ処理する - @torch.enable_grad() - def cond_fn_vgg16_b1(self, latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale): - latents = latents.detach().requires_grad_() - - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[index] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latents / ((sigma**2 + 1) ** 0.5) - else: - latent_model_input = latents - - # predict the noise residual - noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample - - if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): - alpha_prod_t = self.scheduler.alphas_cumprod[timestep] - beta_prod_t = 1 - alpha_prod_t - # compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) - - fac = torch.sqrt(beta_prod_t) - sample = pred_original_sample * (fac) + latents * (1 - fac) - elif isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[index] - sample = latents - sigma * noise_pred - else: - raise ValueError(f"scheduler type {type(self.scheduler)} not supported") - - sample = 1 / 0.18215 * sample - image = self.vae.decode(sample).sample - image = (image / 2 + 0.5).clamp(0, 1) - image = transforms.Resize((image.shape[-2] // VGG16_INPUT_RESIZE_DIV, image.shape[-1] // VGG16_INPUT_RESIZE_DIV))(image) - image = self.vgg16_normalize(image).to(latents.dtype) - - image_embeddings = self.vgg16_feat_model(image)["feat"] - - # バッチサイズが複数だと正しく動くかわからない - loss = ( - (image_embeddings - guide_embeddings) ** 2 - ).mean() * guidance_scale # MSE style transferでコンテンツの損失はMSEなので - - grads = -torch.autograd.grad(loss, latents)[0] - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = latents.detach() + grads * (sigma**2) - noise_pred = noise_pred_original - else: - noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads - return noise_pred, latents - - -class MakeCutouts(torch.nn.Module): - def __init__(self, 
cut_size, cut_power=1.0): - super().__init__() - - self.cut_size = cut_size - self.cut_power = cut_power - - def forward(self, pixel_values, num_cutouts): - sideY, sideX = pixel_values.shape[2:4] - max_size = min(sideX, sideY) - min_size = min(sideX, sideY, self.cut_size) - cutouts = [] - for _ in range(num_cutouts): - size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size) - offsetx = torch.randint(0, sideX - size + 1, ()) - offsety = torch.randint(0, sideY - size + 1, ()) - cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size] - cutouts.append(torch.nn.functional.adaptive_avg_pool2d(cutout, self.cut_size)) - return torch.cat(cutouts) - - -def spherical_dist_loss(x, y): - x = torch.nn.functional.normalize(x, dim=-1) - y = torch.nn.functional.normalize(y, dim=-1) - return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) - - -re_attention = re.compile( - r""" -\\\(| -\\\)| -\\\[| -\\]| -\\\\| -\\| -\(| -\[| -:([+-]?[.\d]+)\)| -\)| -]| -[^\\()\[\]:]+| -: -""", - re.X, -) - - -def parse_prompt_attention(text): - """ - Parses a string with attention tokens and returns a list of pairs: text and its associated weight. - Accepted tokens are: - (abc) - increases attention to abc by a multiplier of 1.1 - (abc:3.12) - increases attention to abc by a multiplier of 3.12 - [abc] - decreases attention to abc by a multiplier of 1.1 - \( - literal character '(' - \[ - literal character '[' - \) - literal character ')' - \] - literal character ']' - \\ - literal character '\' - anything else - just text - >>> parse_prompt_attention('normal text') - [['normal text', 1.0]] - >>> parse_prompt_attention('an (important) word') - [['an ', 1.0], ['important', 1.1], [' word', 1.0]] - >>> parse_prompt_attention('(unbalanced') - [['unbalanced', 1.1]] - >>> parse_prompt_attention('\(literal\]') - [['(literal]', 1.0]] - >>> parse_prompt_attention('(unnecessary)(parens)') - [['unnecessaryparens', 1.1]] - >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') - [['a ', 1.0], - ['house', 1.5730000000000004], - [' ', 1.1], - ['on', 1.0], - [' a ', 1.1], - ['hill', 0.55], - [', sun, ', 1.1], - ['sky', 1.4641000000000006], - ['.', 1.1]] - """ - - res = [] - round_brackets = [] - square_brackets = [] - - round_bracket_multiplier = 1.1 - square_bracket_multiplier = 1 / 1.1 - - def multiply_range(start_position, multiplier): - for p in range(start_position, len(res)): - res[p][1] *= multiplier - - # keep break as separate token - text = text.replace("BREAK", "\\BREAK\\") - - for m in re_attention.finditer(text): - text = m.group(0) - weight = m.group(1) - - if text.startswith("\\"): - res.append([text[1:], 1.0]) - elif text == "(": - round_brackets.append(len(res)) - elif text == "[": - square_brackets.append(len(res)) - elif weight is not None and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), float(weight)) - elif text == ")" and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == "]" and len(square_brackets) > 0: - multiply_range(square_brackets.pop(), square_bracket_multiplier) - else: - res.append([text, 1.0]) - - for pos in round_brackets: - multiply_range(pos, round_bracket_multiplier) - - for pos in square_brackets: - multiply_range(pos, square_bracket_multiplier) - - if len(res) == 0: - res = [["", 1.0]] - - # merge runs of identical weights - i = 0 - while i + 1 < len(res): - if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != 
"BREAK": - res[i][0] += res[i + 1][0] - res.pop(i + 1) - else: - i += 1 - - return res - - -def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int, layer=None): - r""" - Tokenize a list of prompts and return its tokens with weights of each token. - No padding, starting or ending token is included. - """ - tokens = [] - weights = [] - truncated = False - - for text in prompt: - texts_and_weights = parse_prompt_attention(text) - text_token = [] - text_weight = [] - for word, weight in texts_and_weights: - if word.strip() == "BREAK": - # pad until next multiple of tokenizer's max token length - pad_len = pipe.tokenizer.model_max_length - (len(text_token) % pipe.tokenizer.model_max_length) - logger.info(f"BREAK pad_len: {pad_len}") - for i in range(pad_len): - # v2のときEOSをつけるべきかどうかわからないぜ - # if i == 0: - # text_token.append(pipe.tokenizer.eos_token_id) - # else: - text_token.append(pipe.tokenizer.pad_token_id) - text_weight.append(1.0) - continue - - # tokenize and discard the starting and the ending token - token = pipe.tokenizer(word).input_ids[1:-1] - - token = pipe.replace_token(token, layer=layer) - - text_token += token - # copy the weight by length of token - text_weight += [weight] * len(token) - # stop if the text is too long (longer than truncation limit) - if len(text_token) > max_length: - truncated = True - break - # truncate - if len(text_token) > max_length: - truncated = True - text_token = text_token[:max_length] - text_weight = text_weight[:max_length] - tokens.append(text_token) - weights.append(text_weight) - if truncated: - logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") - return tokens, weights - - -def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): - r""" - Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. - """ - max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) - weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length - for i in range(len(tokens)): - tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) - if no_boseos_middle: - weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) - else: - w = [] - if len(weights[i]) == 0: - w = [1.0] * weights_length - else: - for j in range(max_embeddings_multiples): - w.append(1.0) # weight for starting token in this chunk - w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] - w.append(1.0) # weight for ending token in this chunk - w += [1.0] * (weights_length - len(w)) - weights[i] = w[:] - - return tokens, weights - - -def get_unweighted_text_embeddings( - pipe: PipelineLike, - text_input: torch.Tensor, - chunk_length: int, - clip_skip: int, - eos: int, - pad: int, - no_boseos_middle: Optional[bool] = True, -): - """ - When the length of tokens is a multiple of the capacity of the text encoder, - it should be split into chunks and sent to the text encoder individually. 
- """ - max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) - if max_embeddings_multiples > 1: - text_embeddings = [] - for i in range(max_embeddings_multiples): - # extract the i-th chunk - text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() - - # cover the head and the tail by the starting and the ending tokens - text_input_chunk[:, 0] = text_input[0, 0] - if pad == eos: # v1 - text_input_chunk[:, -1] = text_input[0, -1] - else: # v2 - for j in range(len(text_input_chunk)): - if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある - text_input_chunk[j, -1] = eos - if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD - text_input_chunk[j, 1] = eos - - if clip_skip is None or clip_skip == 1: - text_embedding = pipe.text_encoder(text_input_chunk)[0] - else: - enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) - text_embedding = enc_out["hidden_states"][-clip_skip] - text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) - - if no_boseos_middle: - if i == 0: - # discard the ending token - text_embedding = text_embedding[:, :-1] - elif i == max_embeddings_multiples - 1: - # discard the starting token - text_embedding = text_embedding[:, 1:] - else: - # discard both starting and ending tokens - text_embedding = text_embedding[:, 1:-1] - - text_embeddings.append(text_embedding) - text_embeddings = torch.concat(text_embeddings, axis=1) - else: - if clip_skip is None or clip_skip == 1: - text_embeddings = pipe.text_encoder(text_input)[0] - else: - enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True) - text_embeddings = enc_out["hidden_states"][-clip_skip] - text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings) - return text_embeddings - - -def get_weighted_text_embeddings( - pipe: PipelineLike, - prompt: Union[str, List[str]], - uncond_prompt: Optional[Union[str, List[str]]] = None, - max_embeddings_multiples: Optional[int] = 1, - no_boseos_middle: Optional[bool] = False, - skip_parsing: Optional[bool] = False, - skip_weighting: Optional[bool] = False, - clip_skip=None, - layer=None, - **kwargs, -): - r""" - Prompts can be assigned with local weights using brackets. For example, - prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', - and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. - Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. - Args: - pipe (`DiffusionPipeline`): - Pipe to provide access to the tokenizer and the text encoder. - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - uncond_prompt (`str` or `List[str]`): - The unconditional prompt or prompts for guide the image generation. If unconditional prompt - is provided, the embeddings of prompt and uncond_prompt are concatenated. - max_embeddings_multiples (`int`, *optional*, defaults to `1`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - no_boseos_middle (`bool`, *optional*, defaults to `False`): - If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and - ending token in each of the chunk in the middle. - skip_parsing (`bool`, *optional*, defaults to `False`): - Skip the parsing of brackets. 
- skip_weighting (`bool`, *optional*, defaults to `False`): - Skip the weighting. When the parsing is skipped, it is forced True. - """ - max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - if isinstance(prompt, str): - prompt = [prompt] - - # split the prompts with "AND". each prompt must have the same number of splits - new_prompts = [] - for p in prompt: - new_prompts.extend(p.split(" AND ")) - prompt = new_prompts - - if not skip_parsing: - prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer) - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2, layer=layer) - else: - prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids] - prompt_weights = [[1.0] * len(token) for token in prompt_tokens] - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens = [ - token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids - ] - uncond_weights = [[1.0] * len(token) for token in uncond_tokens] - - # round up the longest length of tokens to a multiple of (model_max_length - 2) - max_length = max([len(token) for token in prompt_tokens]) - if uncond_prompt is not None: - max_length = max(max_length, max([len(token) for token in uncond_tokens])) - - max_embeddings_multiples = min( - max_embeddings_multiples, - (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1, - ) - max_embeddings_multiples = max(1, max_embeddings_multiples) - max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - - # pad the length of tokens and weights - bos = pipe.tokenizer.bos_token_id - eos = pipe.tokenizer.eos_token_id - pad = pipe.tokenizer.pad_token_id - prompt_tokens, prompt_weights = pad_tokens_and_weights( - prompt_tokens, - prompt_weights, - max_length, - bos, - eos, - pad, - no_boseos_middle=no_boseos_middle, - chunk_length=pipe.tokenizer.model_max_length, - ) - prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) - if uncond_prompt is not None: - uncond_tokens, uncond_weights = pad_tokens_and_weights( - uncond_tokens, - uncond_weights, - max_length, - bos, - eos, - pad, - no_boseos_middle=no_boseos_middle, - chunk_length=pipe.tokenizer.model_max_length, - ) - uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) - - # get the embeddings - text_embeddings = get_unweighted_text_embeddings( - pipe, - prompt_tokens, - pipe.tokenizer.model_max_length, - clip_skip, - eos, - pad, - no_boseos_middle=no_boseos_middle, - ) - prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) - if uncond_prompt is not None: - uncond_embeddings = get_unweighted_text_embeddings( - pipe, - uncond_tokens, - pipe.tokenizer.model_max_length, - clip_skip, - eos, - pad, - no_boseos_middle=no_boseos_middle, - ) - uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) - - # assign weights to the prompts and normalize in the sense of mean - # TODO: should we normalize by chunk or in a whole (current implementation)? 
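- # The block below applies the per-token prompt weights and then renormalizes: for embeddings of shape
- # (batch, n_tokens, dim) and weights of shape (batch, n_tokens), each token vector is multiplied by its
- # weight, and the whole tensor is then rescaled by previous_mean / current_mean so that its overall mean
- # is left unchanged by the weighting.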
- # → normalizing over the whole embedding should be fine
- if (not skip_parsing) and (not skip_weighting):
- previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
- text_embeddings *= prompt_weights.unsqueeze(-1)
- current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
- text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
- if uncond_prompt is not None:
- previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
- uncond_embeddings *= uncond_weights.unsqueeze(-1)
- current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
- uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
-
- if uncond_prompt is not None:
- return text_embeddings, uncond_embeddings, prompt_tokens
- return text_embeddings, None, prompt_tokens
-
-
-def preprocess_guide_image(image):
- image = image.resize(FEATURE_EXTRACTOR_SIZE, resample=Image.NEAREST) # keep consistent with cond_fn
- image = np.array(image).astype(np.float32) / 255.0
- image = image[None].transpose(0, 3, 1, 2) # nchw
- image = torch.from_numpy(image)
- return image # 0 to 1
-
-
-# VGG16 accepts inputs of any size, so resize the guide image as appropriate
-def preprocess_vgg16_guide_image(image, size):
- image = image.resize(size, resample=Image.NEAREST) # keep consistent with cond_fn
- image = np.array(image).astype(np.float32) / 255.0
- image = image[None].transpose(0, 3, 1, 2) # nchw
- image = torch.from_numpy(image)
- return image # 0 to 1
-
-
-def preprocess_image(image):
- w, h = image.size
- w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
- image = image.resize((w, h), resample=PIL.Image.LANCZOS)
- image = np.array(image).astype(np.float32) / 255.0
- image = image[None].transpose(0, 3, 1, 2)
- image = torch.from_numpy(image)
- return 2.0 * image - 1.0
-
-
-def preprocess_mask(mask):
- mask = mask.convert("L")
- w, h = mask.size
- w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
- mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS)
- mask = np.array(mask).astype(np.float32) / 255.0
- mask = np.tile(mask, (4, 1, 1))
- mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
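- # np.tile copies the single-channel mask across the 4 latent channels and mask[None] adds a batch axis,
- # so mask now has shape (1, 4, h // 8, w // 8); transpose(0, 1, 2, 3) is an identity permutation and leaves
- # the layout unchanged. The next line inverts the mask so white pixels are repainted and black pixels are kept.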
- mask = 1 - mask # repaint white, keep black - mask = torch.from_numpy(mask) - return mask - - -# regular expression for dynamic prompt: -# starts and ends with "{" and "}" -# contains at least one variant divided by "|" -# optional framgments divided by "$$" at start -# if the first fragment is "E" or "e", enumerate all variants -# if the second fragment is a number or two numbers, repeat the variants in the range -# if the third fragment is a string, use it as a separator - -RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}") - - -def handle_dynamic_prompt_variants(prompt, repeat_count): - founds = list(RE_DYNAMIC_PROMPT.finditer(prompt)) - if not founds: - return [prompt] - - # make each replacement for each variant - enumerating = False - replacers = [] - for found in founds: - # if "e$$" is found, enumerate all variants - found_enumerating = found.group(2) is not None - enumerating = enumerating or found_enumerating - - separator = ", " if found.group(6) is None else found.group(6) - variants = found.group(7).split("|") - - # parse count range - count_range = found.group(4) - if count_range is None: - count_range = [1, 1] - else: - count_range = count_range.split("-") - if len(count_range) == 1: - count_range = [int(count_range[0]), int(count_range[0])] - elif len(count_range) == 2: - count_range = [int(count_range[0]), int(count_range[1])] - else: - logger.warning(f"invalid count range: {count_range}") - count_range = [1, 1] - if count_range[0] > count_range[1]: - count_range = [count_range[1], count_range[0]] - if count_range[0] < 0: - count_range[0] = 0 - if count_range[1] > len(variants): - count_range[1] = len(variants) - - if found_enumerating: - # make function to enumerate all combinations - def make_replacer_enum(vari, cr, sep): - def replacer(): - values = [] - for count in range(cr[0], cr[1] + 1): - for comb in itertools.combinations(vari, count): - values.append(sep.join(comb)) - return values - - return replacer - - replacers.append(make_replacer_enum(variants, count_range, separator)) - else: - # make function to choose random combinations - def make_replacer_single(vari, cr, sep): - def replacer(): - count = random.randint(cr[0], cr[1]) - comb = random.sample(vari, count) - return [sep.join(comb)] - - return replacer - - replacers.append(make_replacer_single(variants, count_range, separator)) - - # make each prompt - if not enumerating: - # if not enumerating, repeat the prompt, replace each variant randomly - prompts = [] - for _ in range(repeat_count): - current = prompt - for found, replacer in zip(founds, replacers): - current = current.replace(found.group(0), replacer()[0], 1) - prompts.append(current) - else: - # if enumerating, iterate all combinations for previous prompts - prompts = [prompt] - - for found, replacer in zip(founds, replacers): - if found.group(2) is not None: - # make all combinations for existing prompts - new_prompts = [] - for current in prompts: - replecements = replacer() - for replecement in replecements: - new_prompts.append(current.replace(found.group(0), replecement, 1)) - prompts = new_prompts - - for found, replacer in zip(founds, replacers): - # make random selection for existing prompts - if found.group(2) is None: - for i in range(len(prompts)): - prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1) - - return prompts - - -# endregion - - -# def load_clip_l14_336(dtype): -# logger.info(f"loading CLIP: {CLIP_ID_L14_336}") -# text_encoder = 
CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) -# return text_encoder - - -class BatchDataBase(NamedTuple): - # バッチ分割が必要ないデータ - step: int - prompt: str - negative_prompt: str - seed: int - init_image: Any - mask_image: Any - clip_prompt: str - guide_image: Any - raw_prompt: str - - -class BatchDataExt(NamedTuple): - # バッチ分割が必要なデータ - width: int - height: int - steps: int - scale: float - negative_scale: float - strength: float - network_muls: Tuple[float] - num_sub_prompts: int - - -class BatchData(NamedTuple): - return_latents: bool - base: BatchDataBase - ext: BatchDataExt - - -def main(args): - if args.fp16: - dtype = torch.float16 - elif args.bf16: - dtype = torch.bfloat16 - else: - dtype = torch.float32 - - highres_fix = args.highres_fix_scale is not None - # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" - - if args.v_parameterization and not args.v2: - logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") - if args.v2 and args.clip_skip is not None: - logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") - - # モデルを読み込む - if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う - files = glob.glob(args.ckpt) - if len(files) == 1: - args.ckpt = files[0] - - use_stable_diffusion_format = os.path.isfile(args.ckpt) - if use_stable_diffusion_format: - logger.info("load StableDiffusion checkpoint") - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) - else: - logger.info("load Diffusers pretrained models") - loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) - text_encoder = loading_pipe.text_encoder - vae = loading_pipe.vae - unet = loading_pipe.unet - tokenizer = loading_pipe.tokenizer - del loading_pipe - - # Diffusers U-Net to original U-Net - original_unet = UNet2DConditionModel( - unet.config.sample_size, - unet.config.attention_head_dim, - unet.config.cross_attention_dim, - unet.config.use_linear_projection, - unet.config.upcast_attention, - ) - original_unet.load_state_dict(unet.state_dict()) - unet = original_unet - unet: InferUNet2DConditionModel = InferUNet2DConditionModel(unet) - - # VAEを読み込む - if args.vae is not None: - vae = model_util.load_vae(args.vae, dtype) - logger.info("additional VAE loaded") - - # # 置換するCLIPを読み込む - # if args.replace_clip_l14_336: - # text_encoder = load_clip_l14_336(dtype) - # logger.info(f"large clip {CLIP_ID_L14_336} is loaded") - - if args.clip_guidance_scale > 0.0 or args.clip_image_guidance_scale: - logger.info("prepare clip model") - clip_model = CLIPModel.from_pretrained(CLIP_MODEL_PATH, torch_dtype=dtype) - else: - clip_model = None - - if args.vgg16_guidance_scale > 0.0: - logger.info("prepare resnet model") - vgg16_model = torchvision.models.vgg16(torchvision.models.VGG16_Weights.IMAGENET1K_V1) - else: - vgg16_model = None - - # xformers、Hypernetwork対応 - if not args.diffusers_xformers: - mem_eff = not (args.xformers or args.sdpa) - replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa) - replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) - - # tokenizerを読み込む - logger.info("loading tokenizer") - if use_stable_diffusion_format: - tokenizer = train_util.load_tokenizer(args) - - # schedulerを用意する - sched_init_args = {} - scheduler_num_noises_per_step = 1 - if args.sampler == "ddim": - scheduler_cls = DDIMScheduler - scheduler_module = 
diffusers.schedulers.scheduling_ddim - elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある - scheduler_cls = DDPMScheduler - scheduler_module = diffusers.schedulers.scheduling_ddpm - elif args.sampler == "pndm": - scheduler_cls = PNDMScheduler - scheduler_module = diffusers.schedulers.scheduling_pndm - elif args.sampler == "lms" or args.sampler == "k_lms": - scheduler_cls = LMSDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_lms_discrete - elif args.sampler == "euler" or args.sampler == "k_euler": - scheduler_cls = EulerDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_euler_discrete - elif args.sampler == "euler_a" or args.sampler == "k_euler_a": - scheduler_cls = EulerAncestralDiscreteSchedulerGL - scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete - elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": - scheduler_cls = DPMSolverMultistepScheduler - sched_init_args["algorithm_type"] = args.sampler - scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep - elif args.sampler == "dpmsingle": - scheduler_cls = DPMSolverSinglestepScheduler - scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep - elif args.sampler == "heun": - scheduler_cls = HeunDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_heun_discrete - elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2": - scheduler_cls = KDPM2DiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete - elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a": - scheduler_cls = KDPM2AncestralDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete - scheduler_num_noises_per_step = 2 - - if args.v_parameterization: - sched_init_args["prediction_type"] = "v_prediction" - - # samplerの乱数をあらかじめ指定するための処理 - - # replace randn - class NoiseManager: - def __init__(self): - self.sampler_noises = None - self.sampler_noise_index = 0 - - def reset_sampler_noises(self, noises): - self.sampler_noise_index = 0 - self.sampler_noises = noises - - def randn(self, shape, device=None, dtype=None, layout=None, generator=None): - # logger.info(f"replacing {shape} {len(self.sampler_noises)} {self.sampler_noise_index}") - if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): - noise = self.sampler_noises[self.sampler_noise_index] - if shape != noise.shape: - noise = None - else: - noise = None - - if noise == None: - logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}") - noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) - - self.sampler_noise_index += 1 - return noise - - class TorchRandReplacer: - def __init__(self, noise_manager): - self.noise_manager = noise_manager - - def __getattr__(self, item): - if item == "randn": - return self.noise_manager.randn - if hasattr(torch, item): - return getattr(torch, item) - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) - - noise_manager = NoiseManager() - if scheduler_module is not None: - scheduler_module.torch = TorchRandReplacer(noise_manager) - - scheduler = scheduler_cls( - num_train_timesteps=SCHEDULER_TIMESTEPS, - beta_start=SCHEDULER_LINEAR_START, - beta_end=SCHEDULER_LINEAR_END, - beta_schedule=SCHEDLER_SCHEDULE, - **sched_init_args, - ) - - # clip_sample=Trueにする - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - logger.info("set 
clip_sample to True") - scheduler.config.clip_sample = True - - # deviceを決定する - device = get_preferred_device() - - # custom pipelineをコピったやつを生成する - if args.vae_slices: - from library.slicing_vae import SlicingAutoencoderKL - - sli_vae = SlicingAutoencoderKL( - act_fn="silu", - block_out_channels=(128, 256, 512, 512), - down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], - in_channels=3, - latent_channels=4, - layers_per_block=2, - norm_num_groups=32, - out_channels=3, - sample_size=512, - up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], - num_slices=args.vae_slices, - ) - sli_vae.load_state_dict(vae.state_dict()) # vaeのパラメータをコピーする - vae = sli_vae - del sli_vae - vae.to(dtype).to(device) - vae.eval() - - text_encoder.to(dtype).to(device) - unet.to(dtype).to(device) - - text_encoder.eval() - unet.eval() - - if clip_model is not None: - clip_model.to(dtype).to(device) - clip_model.eval() - if vgg16_model is not None: - vgg16_model.to(dtype).to(device) - vgg16_model.eval() - - # networkを組み込む - if args.network_module: - networks = [] - network_default_muls = [] - network_pre_calc = args.network_pre_calc - - # merge関連の引数を統合する - if args.network_merge: - network_merge = len(args.network_module) # all networks are merged - elif args.network_merge_n_models: - network_merge = args.network_merge_n_models - else: - network_merge = 0 - - for i, network_module in enumerate(args.network_module): - logger.info(f"import network module: {network_module}") - imported_module = importlib.import_module(network_module) - - network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] - - net_kwargs = {} - if args.network_args and i < len(args.network_args): - network_args = args.network_args[i] - # TODO escape special chars - network_args = network_args.split(";") - for net_arg in network_args: - key, value = net_arg.split("=") - net_kwargs[key] = value - - if args.network_weights is None or len(args.network_weights) <= i: - raise ValueError("No weight. Weight is required.") - - network_weight = args.network_weights[i] - logger.info(f"load network weights from: {network_weight}") - - if model_util.is_safetensors(network_weight) and args.network_show_meta: - from safetensors.torch import safe_open - - with safe_open(network_weight, framework="pt") as f: - metadata = f.metadata() - if metadata is not None: - logger.info(f"metadata for: {network_weight}: {metadata}") - - network, weights_sd = imported_module.create_network_from_weights( - network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs - ) - if network is None: - return - - mergeable = network.is_mergeable() - if network_merge and not mergeable: - logger.warning("network is not mergiable. 
ignore merge option.") - - if not mergeable or i >= network_merge: - # not merging - network.apply_to(text_encoder, unet) - info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい - logger.info(f"weights are loaded: {info}") - - if args.opt_channels_last: - network.to(memory_format=torch.channels_last) - network.to(dtype).to(device) - - if network_pre_calc: - logger.info("backup original weights") - network.backup_weights() - - networks.append(network) - network_default_muls.append(network_mul) - else: - network.merge_to(text_encoder, unet, weights_sd, dtype, device) - - else: - networks = [] - - # upscalerの指定があれば取得する - upscaler = None - if args.highres_fix_upscaler: - logger.info(f"import upscaler module {args.highres_fix_upscaler}") - imported_module = importlib.import_module(args.highres_fix_upscaler) - - us_kwargs = {} - if args.highres_fix_upscaler_args: - for net_arg in args.highres_fix_upscaler_args.split(";"): - key, value = net_arg.split("=") - us_kwargs[key] = value - - logger.info("create upscaler") - upscaler = imported_module.create_upscaler(**us_kwargs) - upscaler.to(dtype).to(device) - - # ControlNetの処理 - control_nets: List[ControlNetInfo] = [] - if args.control_net_models: - for i, model in enumerate(args.control_net_models): - prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] - weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] - ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] - - ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) - prep = original_control_net.load_preprocess(prep_type) - control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) - - if args.opt_channels_last: - logger.info(f"set optimizing: channels last") - text_encoder.to(memory_format=torch.channels_last) - vae.to(memory_format=torch.channels_last) - unet.to(memory_format=torch.channels_last) - if clip_model is not None: - clip_model.to(memory_format=torch.channels_last) - if networks: - for network in networks: - network.to(memory_format=torch.channels_last) - if vgg16_model is not None: - vgg16_model.to(memory_format=torch.channels_last) - - for cn in control_nets: - cn.unet.to(memory_format=torch.channels_last) - cn.net.to(memory_format=torch.channels_last) - - pipe = PipelineLike( - device, - vae, - text_encoder, - tokenizer, - unet, - scheduler, - args.clip_skip, - clip_model, - args.clip_guidance_scale, - args.clip_image_guidance_scale, - vgg16_model, - args.vgg16_guidance_scale, - args.vgg16_guidance_layer, - ) - pipe.set_control_nets(control_nets) - logger.info("pipeline is ready.") - - if args.diffusers_xformers: - pipe.enable_xformers_memory_efficient_attention() - - # Deep Shrink - if args.ds_depth_1 is not None: - unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) - - # Gradual Latent - if args.gradual_latent_timesteps is not None: - if args.gradual_latent_unsharp_params: - us_params = args.gradual_latent_unsharp_params.split(",") - us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]] - us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3])) - us_ksize = int(us_ksize) - else: - us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None - - gradual_latent = GradualLatent( - args.gradual_latent_ratio, - 
args.gradual_latent_timesteps, - args.gradual_latent_every_n_steps, - args.gradual_latent_ratio_step, - args.gradual_latent_s_noise, - us_ksize, - us_sigma, - us_strength, - us_target_x, - ) - pipe.set_gradual_latent(gradual_latent) - - # Extended Textual Inversion および Textual Inversionを処理する - if args.XTI_embeddings: - diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI - diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI - diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI - - if args.textual_inversion_embeddings: - token_ids_embeds = [] - for embeds_file in args.textual_inversion_embeddings: - if model_util.is_safetensors(embeds_file): - from safetensors.torch import load_file - - data = load_file(embeds_file) - else: - data = torch.load(embeds_file, map_location="cpu") - - if "string_to_param" in data: - data = data["string_to_param"] - embeds = next(iter(data.values())) - - if type(embeds) != torch.Tensor: - raise ValueError( - f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}" - ) - - num_vectors_per_token = embeds.size()[0] - token_string = os.path.splitext(os.path.basename(embeds_file))[0] - token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] - - # add new word to tokenizer, count is num_vectors_per_token - num_added_tokens = tokenizer.add_tokens(token_strings) - assert ( - num_added_tokens == num_vectors_per_token - ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" - - token_ids = tokenizer.convert_tokens_to_ids(token_strings) - logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}") - assert ( - min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1 - ), f"token ids is not ordered" - assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" - - if num_vectors_per_token > 1: - pipe.add_token_replacement(token_ids[0], token_ids) - - token_ids_embeds.append((token_ids, embeds)) - - text_encoder.resize_token_embeddings(len(tokenizer)) - token_embeds = text_encoder.get_input_embeddings().weight.data - for token_ids, embeds in token_ids_embeds: - for token_id, embed in zip(token_ids, embeds): - token_embeds[token_id] = embed - - if args.XTI_embeddings: - XTI_layers = [ - "IN01", - "IN02", - "IN04", - "IN05", - "IN07", - "IN08", - "MID", - "OUT03", - "OUT04", - "OUT05", - "OUT06", - "OUT07", - "OUT08", - "OUT09", - "OUT10", - "OUT11", - ] - token_ids_embeds_XTI = [] - for embeds_file in args.XTI_embeddings: - if model_util.is_safetensors(embeds_file): - from safetensors.torch import load_file - - data = load_file(embeds_file) - else: - data = torch.load(embeds_file, map_location="cpu") - if set(data.keys()) != set(XTI_layers): - raise ValueError("NOT XTI") - embeds = torch.concat(list(data.values())) - num_vectors_per_token = data["MID"].size()[0] - - token_string = os.path.splitext(os.path.basename(embeds_file))[0] - token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] - - # add new word to tokenizer, count is num_vectors_per_token - num_added_tokens = tokenizer.add_tokens(token_strings) - assert ( - num_added_tokens == num_vectors_per_token - ), f"tokenizer has same word to token string (filename). 
please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" - - token_ids = tokenizer.convert_tokens_to_ids(token_strings) - logger.info(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}") - - # if num_vectors_per_token > 1: - pipe.add_token_replacement(token_ids[0], token_ids) - - token_strings_XTI = [] - for layer_name in XTI_layers: - token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings] - tokenizer.add_tokens(token_strings_XTI) - token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI) - token_ids_embeds_XTI.append((token_ids_XTI, embeds)) - for t in token_ids: - t_XTI_dic = {} - for i, layer_name in enumerate(XTI_layers): - t_XTI_dic[layer_name] = t + (i + 1) * num_added_tokens - pipe.add_token_replacement_XTI(t, t_XTI_dic) - - text_encoder.resize_token_embeddings(len(tokenizer)) - token_embeds = text_encoder.get_input_embeddings().weight.data - for token_ids, embeds in token_ids_embeds_XTI: - for token_id, embed in zip(token_ids, embeds): - token_embeds[token_id] = embed - - # promptを取得する - if args.from_file is not None: - logger.info(f"reading prompts from {args.from_file}") - with open(args.from_file, "r", encoding="utf-8") as f: - prompt_list = f.read().splitlines() - prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] - elif args.prompt is not None: - prompt_list = [args.prompt] - else: - prompt_list = [] - - if args.interactive: - args.n_iter = 1 - - # img2imgの前処理、画像の読み込みなど - def load_images(path): - if os.path.isfile(path): - paths = [path] - else: - paths = ( - glob.glob(os.path.join(path, "*.png")) - + glob.glob(os.path.join(path, "*.jpg")) - + glob.glob(os.path.join(path, "*.jpeg")) - + glob.glob(os.path.join(path, "*.webp")) - ) - paths.sort() - - images = [] - for p in paths: - image = Image.open(p) - if image.mode != "RGB": - logger.info(f"convert image to RGB from {image.mode}: {p}") - image = image.convert("RGB") - images.append(image) - - return images - - def resize_images(imgs, size): - resized = [] - for img in imgs: - r_img = img.resize(size, Image.Resampling.LANCZOS) - if hasattr(img, "filename"): # filename属性がない場合があるらしい - r_img.filename = img.filename - resized.append(r_img) - return resized - - if args.image_path is not None: - logger.info(f"load image for img2img: {args.image_path}") - init_images = load_images(args.image_path) - assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" - logger.info(f"loaded {len(init_images)} images for img2img") - else: - init_images = None - - if args.mask_path is not None: - logger.info(f"load mask for inpainting: {args.mask_path}") - mask_images = load_images(args.mask_path) - assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" - logger.info(f"loaded {len(mask_images)} mask images for inpainting") - else: - mask_images = None - - # promptがないとき、画像のPngInfoから取得する - if init_images is not None and len(prompt_list) == 0 and not args.interactive: - logger.info("get prompts from images' meta data") - for img in init_images: - if "prompt" in img.text: - prompt = img.text["prompt"] - if "negative-prompt" in img.text: - prompt += " --n " + img.text["negative-prompt"] - prompt_list.append(prompt) - - # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) - l = [] - for im in init_images: - l.extend([im] * args.images_per_prompt) - init_images = l - - if mask_images is not None: - l = [] - for im in mask_images: - l.extend([im] * args.images_per_prompt) - mask_images = l - - # 画像サイズにオプション指定があるときはリサイズする - if args.W is not 
None and args.H is not None: - # highres fix を考慮に入れる - w, h = args.W, args.H - if highres_fix: - w = int(w * args.highres_fix_scale + 0.5) - h = int(h * args.highres_fix_scale + 0.5) - - if init_images is not None: - logger.info(f"resize img2img source images to {w}*{h}") - init_images = resize_images(init_images, (w, h)) - if mask_images is not None: - logger.info(f"resize img2img mask images to {w}*{h}") - mask_images = resize_images(mask_images, (w, h)) - - regional_network = False - if networks and mask_images: - # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 - regional_network = True - logger.info("use mask as region") - - size = None - for i, network in enumerate(networks): - if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes: - np_mask = np.array(mask_images[0]) - - if args.network_regional_mask_max_color_codes: - # カラーコードでマスクを指定する - ch0 = (i + 1) & 1 - ch1 = ((i + 1) >> 1) & 1 - ch2 = ((i + 1) >> 2) & 1 - np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2) - np_mask = np_mask.astype(np.uint8) * 255 - else: - np_mask = np_mask[:, :, i] - size = np_mask.shape - else: - np_mask = np.full(size, 255, dtype=np.uint8) - mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0) - network.set_region(i, i == len(networks) - 1, mask) - mask_images = None - - prev_image = None # for VGG16 guided - if args.guide_image_path is not None: - logger.info(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}") - guide_images = [] - for p in args.guide_image_path: - guide_images.extend(load_images(p)) - - logger.info(f"loaded {len(guide_images)} guide images for guidance") - if len(guide_images) == 0: - logger.info( - f"No guide image, use previous generated image. 
/ ガイド画像がありません。直前に生成した画像を使います: {args.image_path}" - ) - guide_images = None - else: - guide_images = None - - # seed指定時はseedを決めておく - if args.seed is not None: - # dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう - random.seed(args.seed) - predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] - if len(predefined_seeds) == 1: - predefined_seeds[0] = args.seed - else: - predefined_seeds = None - - # デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み) - if args.W is None: - args.W = 512 - if args.H is None: - args.H = 512 - - # 画像生成のループ - os.makedirs(args.outdir, exist_ok=True) - max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples - - for gen_iter in range(args.n_iter): - logger.info(f"iteration {gen_iter+1}/{args.n_iter}") - iter_seed = random.randint(0, 0x7FFFFFFF) - - # shuffle prompt list - if args.shuffle_prompts: - random.shuffle(prompt_list) - - # バッチ処理の関数 - def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): - batch_size = len(batch) - - # highres_fixの処理 - if highres_fix and not highres_1st: - # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す - is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling - - logger.info("process 1st stage") - batch_1st = [] - for _, base, ext in batch: - width_1st = int(ext.width * args.highres_fix_scale + 0.5) - height_1st = int(ext.height * args.highres_fix_scale + 0.5) - width_1st = width_1st - width_1st % 32 - height_1st = height_1st - height_1st % 32 - - strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength - - ext_1st = BatchDataExt( - width_1st, - height_1st, - args.highres_fix_steps, - ext.scale, - ext.negative_scale, - strength_1st, - ext.network_muls, - ext.num_sub_prompts, - ) - batch_1st.append(BatchData(is_1st_latent, base, ext_1st)) - - pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする - images_1st = process_batch(batch_1st, True, True) - - # 2nd stageのバッチを作成して以下処理する - logger.info("process 2nd stage") - width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height - - if upscaler: - # upscalerを使って画像を拡大する - lowreso_imgs = None if is_1st_latent else images_1st - lowreso_latents = None if not is_1st_latent else images_1st - - # 戻り値はPIL.Image.Imageかtorch.Tensorのlatents - batch_size = len(images_1st) - vae_batch_size = ( - batch_size - if args.vae_batch_size is None - else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size) - ) - vae_batch_size = int(vae_batch_size) - images_1st = upscaler.upscale( - vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size - ) - - elif args.highres_fix_latents_upscaling: - # latentを拡大する - org_dtype = images_1st.dtype - if images_1st.dtype == torch.bfloat16: - images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない - images_1st = torch.nn.functional.interpolate( - images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear" - ) # , antialias=True) - images_1st = images_1st.to(org_dtype) - - else: - # 画像をLANCZOSで拡大する - images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st] - - batch_2nd = [] - for i, (bd, image) in enumerate(zip(batch, images_1st)): - bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext) - batch_2nd.append(bd_2nd) - batch = batch_2nd - - if 
args.highres_fix_disable_control_net: - pipe.set_enable_control_net(False) # オプション指定時、2nd stageではControlNetを無効にする - - # このバッチの情報を取り出す - ( - return_latents, - (step_first, _, _, _, init_image, mask_image, _, guide_image, _), - (width, height, steps, scale, negative_scale, strength, network_muls, num_sub_prompts), - ) = batch[0] - noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) - - prompts = [] - negative_prompts = [] - raw_prompts = [] - start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - noises = [ - torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - for _ in range(steps * scheduler_num_noises_per_step) - ] - seeds = [] - clip_prompts = [] - - if init_image is not None: # img2img? - i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - init_images = [] - - if mask_image is not None: - mask_images = [] - else: - mask_images = None - else: - i2i_noises = None - init_images = None - mask_images = None - - if guide_image is not None: # CLIP image guided? - guide_images = [] - else: - guide_images = None - - # バッチ内の位置に関わらず同じ乱数を使うためにここで乱数を生成しておく。あわせてimage/maskがbatch内で同一かチェックする - all_images_are_same = True - all_masks_are_same = True - all_guide_images_are_same = True - for i, ( - _, - (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), - _, - ) in enumerate(batch): - prompts.append(prompt) - negative_prompts.append(negative_prompt) - seeds.append(seed) - clip_prompts.append(clip_prompt) - raw_prompts.append(raw_prompt) - - if init_image is not None: - init_images.append(init_image) - if i > 0 and all_images_are_same: - all_images_are_same = init_images[-2] is init_image - - if mask_image is not None: - mask_images.append(mask_image) - if i > 0 and all_masks_are_same: - all_masks_are_same = mask_images[-2] is mask_image - - if guide_image is not None: - if type(guide_image) is list: - guide_images.extend(guide_image) - all_guide_images_are_same = False - else: - guide_images.append(guide_image) - if i > 0 and all_guide_images_are_same: - all_guide_images_are_same = guide_images[-2] is guide_image - - # make start code - torch.manual_seed(seed) - start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype) - - # make each noises - for j in range(steps * scheduler_num_noises_per_step): - noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) - - if i2i_noises is not None: # img2img noise - i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype) - - noise_manager.reset_sampler_noises(noises) - - # すべての画像が同じなら1枚だけpipeに渡すことでpipe側で処理を高速化する - if init_images is not None and all_images_are_same: - init_images = init_images[0] - if mask_images is not None and all_masks_are_same: - mask_images = mask_images[0] - if guide_images is not None and all_guide_images_are_same: - guide_images = guide_images[0] - - # ControlNet使用時はguide imageをリサイズする - if control_nets: - # TODO resampleのメソッド - guide_images = guide_images if type(guide_images) == list else [guide_images] - guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images] - if len(guide_images) == 1: - guide_images = guide_images[0] - - # generate - if networks: - # 追加ネットワークの処理 - shared = {} - for n, m in zip(networks, network_muls if network_muls else network_default_muls): - n.set_multiplier(m) - if regional_network: - n.set_current_generation(batch_size, num_sub_prompts, width, height, shared) - - if not regional_network 
and network_pre_calc: - for n in networks: - n.restore_weights() - for n in networks: - n.pre_calculation() - logger.info("pre-calculation... done") - - images = pipe( - prompts, - negative_prompts, - init_images, - mask_images, - height, - width, - steps, - scale, - negative_scale, - strength, - latents=start_code, - output_type="pil", - max_embeddings_multiples=max_embeddings_multiples, - img2img_noise=i2i_noises, - vae_batch_size=args.vae_batch_size, - return_latents=return_latents, - clip_prompts=clip_prompts, - clip_guide_images=guide_images, - )[0] - if highres_1st and not args.highres_fix_save_1st: # return images or latents - return images - - # save image - highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" - ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( - zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) - ): - if highres_fix: - seed -= 1 # record original seed - metadata = PngInfo() - metadata.add_text("prompt", prompt) - metadata.add_text("seed", str(seed)) - metadata.add_text("sampler", args.sampler) - metadata.add_text("steps", str(steps)) - metadata.add_text("scale", str(scale)) - if negative_prompt is not None: - metadata.add_text("negative-prompt", negative_prompt) - if negative_scale is not None: - metadata.add_text("negative-scale", str(negative_scale)) - if clip_prompt is not None: - metadata.add_text("clip-prompt", clip_prompt) - if raw_prompt is not None: - metadata.add_text("raw-prompt", raw_prompt) - - if args.use_original_file_name and init_images is not None: - if type(init_images) is list: - fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" - else: - fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" - elif args.sequential_file_name: - fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" - else: - fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" - - image.save(os.path.join(args.outdir, fln), pnginfo=metadata) - - if not args.no_preview and not highres_1st and args.interactive: - try: - import cv2 - - for prompt, image in zip(prompts, images): - cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ - cv2.waitKey() - cv2.destroyAllWindows() - except ImportError: - logger.info( - "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" - ) - - return images - - # 画像生成のプロンプトが一周するまでのループ - prompt_index = 0 - global_step = 0 - batch_data = [] - while args.interactive or prompt_index < len(prompt_list): - if len(prompt_list) == 0: - # interactive - valid = False - while not valid: - logger.info("") - logger.info("Type prompt:") - try: - raw_prompt = input() - except EOFError: - break - - valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0 - if not valid: # EOF, end app - break - else: - raw_prompt = prompt_list[prompt_index] - - # sd-dynamic-prompts like variants: - # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) - raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) - - # repeat prompt - for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): - raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] - - if pi == 0 or len(raw_prompts) > 1: - # parse prompt: if prompt is not changed, skip parsing - width = args.W - height = args.H - scale = args.scale - negative_scale = 
args.negative_scale - steps = args.steps - seed = None - seeds = None - strength = 0.8 if args.strength is None else args.strength - negative_prompt = "" - clip_prompt = None - network_muls = None - - # Deep Shrink - ds_depth_1 = None # means no override - ds_timesteps_1 = args.ds_timesteps_1 - ds_depth_2 = args.ds_depth_2 - ds_timesteps_2 = args.ds_timesteps_2 - ds_ratio = args.ds_ratio - - # Gradual Latent - gl_timesteps = None # means no override - gl_ratio = args.gradual_latent_ratio - gl_every_n_steps = args.gradual_latent_every_n_steps - gl_ratio_step = args.gradual_latent_ratio_step - gl_s_noise = args.gradual_latent_s_noise - gl_unsharp_params = args.gradual_latent_unsharp_params - - prompt_args = raw_prompt.strip().split(" --") - prompt = prompt_args[0] - logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") - - for parg in prompt_args[1:]: - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - logger.info(f"width: {width}") - continue - - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - logger.info(f"height: {height}") - continue - - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - steps = max(1, min(1000, int(m.group(1)))) - logger.info(f"steps: {steps}") - continue - - m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) - if m: # seed - seeds = [int(d) for d in m.group(1).split(",")] - logger.info(f"seeds: {seeds}") - continue - - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - logger.info(f"scale: {scale}") - continue - - m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) - if m: # negative scale - if m.group(1).lower() == "none": - negative_scale = None - else: - negative_scale = float(m.group(1)) - logger.info(f"negative scale: {negative_scale}") - continue - - m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) - if m: # strength - strength = float(m.group(1)) - logger.info(f"strength: {strength}") - continue - - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - logger.info(f"negative prompt: {negative_prompt}") - continue - - m = re.match(r"c (.+)", parg, re.IGNORECASE) - if m: # clip prompt - clip_prompt = m.group(1) - logger.info(f"clip prompt: {clip_prompt}") - continue - - m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # network multiplies - network_muls = [float(v) for v in m.group(1).split(",")] - while len(network_muls) < len(networks): - network_muls.append(network_muls[-1]) - logger.info(f"network mul: {network_muls}") - continue - - # Deep Shrink - m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink depth 1 - ds_depth_1 = int(m.group(1)) - logger.info(f"deep shrink depth 1: {ds_depth_1}") - continue - - m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink timesteps 1 - ds_timesteps_1 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") - continue - - m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink depth 2 - ds_depth_2 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink depth 2: {ds_depth_2}") - continue - - m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink timesteps 2 - ds_timesteps_2 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means 
override - logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") - continue - - m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink ratio - ds_ratio = float(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink ratio: {ds_ratio}") - continue - - # Gradual Latent - m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent timesteps - gl_timesteps = int(m.group(1)) - print(f"gradual latent timesteps: {gl_timesteps}") - continue - - m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio - gl_ratio = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio: {ds_ratio}") - continue - - m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent every n steps - gl_every_n_steps = int(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent every n steps: {gl_every_n_steps}") - continue - - m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio step - gl_ratio_step = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio step: {gl_ratio_step}") - continue - - m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent s noise - gl_s_noise = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent s noise: {gl_s_noise}") - continue - - m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # gradual latent unsharp params - gl_unsharp_params = m.group(1) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent unsharp params: {gl_unsharp_params}") - continue - - except ValueError as ex: - logger.info(f"Exception in parsing / 解析エラー: {parg}") - logger.info(ex) - - # override Deep Shrink - if ds_depth_1 is not None: - if ds_depth_1 < 0: - ds_depth_1 = args.ds_depth_1 or 3 - unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) - - # override Gradual Latent - if gl_timesteps is not None: - if gl_timesteps < 0: - gl_timesteps = args.gradual_latent_timesteps or 650 - if gl_unsharp_params is not None: - unsharp_params = gl_unsharp_params.split(",") - us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] - print(unsharp_params) - us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) - us_ksize = int(us_ksize) - else: - us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None - gradual_latent = GradualLatent( - gl_ratio, - gl_timesteps, - gl_every_n_steps, - gl_ratio_step, - gl_s_noise, - us_ksize, - us_sigma, - us_strength, - us_target_x, - ) - pipe.set_gradual_latent(gradual_latent) - - # prepare seed - if seeds is not None: # given in prompt - # 数が足りないなら前のをそのまま使う - if len(seeds) > 0: - seed = seeds.pop(0) - else: - if predefined_seeds is not None: - if len(predefined_seeds) > 0: - seed = predefined_seeds.pop(0) - else: - logger.info("predefined seeds are exhausted") - seed = None - elif args.iter_same_seed: - seed = iter_seed - else: - seed = None # 前のを消す - - if seed is None: - seed = random.randint(0, 0x7FFFFFFF) - if args.interactive: - logger.info(f"seed: {seed}") - - # prepare init image, guide image and mask - init_image = mask_image = guide_image = None - - # 
同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する - if init_images is not None: - init_image = init_images[global_step % len(init_images)] - - # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する - # 32単位に丸めたやつにresizeされるので踏襲する - if not highres_fix: - width, height = init_image.size - width = width - width % 32 - height = height - height % 32 - if width != init_image.size[0] or height != init_image.size[1]: - logger.info( - f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" - ) - - if mask_images is not None: - mask_image = mask_images[global_step % len(mask_images)] - - if guide_images is not None: - if control_nets: # 複数件の場合あり - c = len(control_nets) - p = global_step % (len(guide_images) // c) - guide_image = guide_images[p * c : p * c + c] - else: - guide_image = guide_images[global_step % len(guide_images)] - elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0: - if prev_image is None: - logger.info("Generate 1st image without guide image.") - else: - logger.info("Use previous image as guide image.") - guide_image = prev_image - - if regional_network: - num_sub_prompts = len(prompt.split(" AND ")) - assert ( - len(networks) <= num_sub_prompts - ), "Number of networks must be less than or equal to number of sub prompts." - else: - num_sub_prompts = None - - b1 = BatchData( - False, - BatchDataBase( - global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt - ), - BatchDataExt( - width, - height, - steps, - scale, - negative_scale, - strength, - tuple(network_muls) if network_muls else None, - num_sub_prompts, - ), - ) - if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? 
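# A minimal, self-contained sketch of the per-prompt option syntax handled in the removed
# generation loop above: the raw prompt is split on " --" and each trailing fragment is
# matched against a short switch ("w", "h", "d", "n", ...). This is a simplified
# illustration, not the script's actual parser, and it covers only a few of the switches.
import re


def parse_prompt_options(raw_prompt: str) -> dict:
    """Split 'prompt text --w 640 --h 512 --d 1,2 --n lowres' into prompt + options."""
    fragments = raw_prompt.strip().split(" --")
    opts = {"prompt": fragments[0].strip()}
    for parg in fragments[1:]:
        if m := re.match(r"w (\d+)", parg, re.IGNORECASE):
            opts["width"] = int(m.group(1))
        elif m := re.match(r"h (\d+)", parg, re.IGNORECASE):
            opts["height"] = int(m.group(1))
        elif m := re.match(r"d ([\d,]+)", parg, re.IGNORECASE):
            opts["seeds"] = [int(d) for d in m.group(1).split(",")]
        elif m := re.match(r"n (.+)", parg, re.IGNORECASE):
            opts["negative_prompt"] = m.group(1)
    return opts


if __name__ == "__main__":
    # Example: {'prompt': '1girl, masterpiece', 'width': 640, 'height': 512, 'seeds': [42], 'negative_prompt': 'lowres'}
    print(parse_prompt_options("1girl, masterpiece --w 640 --h 512 --d 42 --n lowres"))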
- process_batch(batch_data, highres_fix) - batch_data.clear() - - batch_data.append(b1) - if len(batch_data) == args.batch_size: - prev_image = process_batch(batch_data, highres_fix)[0] - batch_data.clear() - - global_step += 1 - - prompt_index += 1 - - if len(batch_data) > 0: - process_batch(batch_data, highres_fix) - batch_data.clear() - - logger.info("done!") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - add_logging_arguments(parser) - - parser.add_argument( - "--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む" - ) - parser.add_argument( - "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする" - ) - parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") - parser.add_argument( - "--from_file", - type=str, - default=None, - help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む", - ) - parser.add_argument( - "--interactive", - action="store_true", - help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)", - ) - parser.add_argument( - "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" - ) - parser.add_argument( - "--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgまたはinpaintを行う元画像" - ) - parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク") - parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") - parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") - parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") - parser.add_argument( - "--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする" - ) - parser.add_argument( - "--use_original_file_name", - action="store_true", - help="prepend original file name in img2img / img2imgで元画像のファイル名を生成画像のファイル名の先頭に付ける", - ) - # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", ) - parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数") - parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") - parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") - parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") - parser.add_argument( - "--vae_batch_size", - type=float, - default=None, - help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率", - ) - parser.add_argument( - "--vae_slices", - type=int, - default=None, - help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 
16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨", - ) - parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") - parser.add_argument( - "--sampler", - type=str, - default="ddim", - choices=[ - "ddim", - "pndm", - "lms", - "euler", - "euler_a", - "heun", - "dpm_2", - "dpm_2_a", - "dpmsolver", - "dpmsolver++", - "dpmsingle", - "k_lms", - "k_euler", - "k_euler_a", - "k_dpm_2", - "k_dpm_2_a", - ], - help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類", - ) - parser.add_argument( - "--scale", - type=float, - default=7.5, - help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", - ) - parser.add_argument( - "--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ" - ) - parser.add_argument( - "--vae", - type=str, - default=None, - help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", - ) - parser.add_argument( - "--tokenizer_cache_dir", - type=str, - default=None, - help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", - ) - # parser.add_argument("--replace_clip_l14_336", action='store_true', - # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") - parser.add_argument( - "--seed", - type=int, - default=None, - help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed", - ) - parser.add_argument( - "--iter_same_seed", - action="store_true", - help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)", - ) - parser.add_argument( - "--shuffle_prompts", - action="store_true", - help="shuffle prompts in iteration / 繰り返し内のプロンプトをシャッフルする", - ) - parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する") - parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") - parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する") - parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa") - parser.add_argument( - "--diffusers_xformers", - action="store_true", - help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", - ) - parser.add_argument( - "--opt_channels_last", - action="store_true", - help="set channels last option to model / モデルにchannels lastを指定し最適化する", - ) - parser.add_argument( - "--network_module", - type=str, - default=None, - nargs="*", - help="additional network module to use / 追加ネットワークを使う時そのモジュール名", - ) - parser.add_argument( - "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" - ) - parser.add_argument( - "--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率" - ) - parser.add_argument( - "--network_args", - type=str, - default=None, - nargs="*", - help="additional arguments for network (key=value) / ネットワークへの追加の引数", - ) - parser.add_argument( - "--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する" - ) - parser.add_argument( - "--network_merge_n_models", - type=int, - default=None, - help="merge this number of networks / この数だけネットワークをマージする", - ) - parser.add_argument( - "--network_merge", action="store_true", help="merge network weights to original model 
/ ネットワークの重みをマージする" - ) - parser.add_argument( - "--network_pre_calc", - action="store_true", - help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する", - ) - parser.add_argument( - "--network_regional_mask_max_color_codes", - type=int, - default=None, - help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)", - ) - parser.add_argument( - "--textual_inversion_embeddings", - type=str, - default=None, - nargs="*", - help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", - ) - parser.add_argument( - "--XTI_embeddings", - type=str, - default=None, - nargs="*", - help="Embeddings files of Extended Textual Inversion / Extended Textual Inversionのembeddings", - ) - parser.add_argument( - "--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う" - ) - parser.add_argument( - "--max_embeddings_multiples", - type=int, - default=None, - help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる", - ) - parser.add_argument( - "--clip_guidance_scale", - type=float, - default=0.0, - help="enable CLIP guided SD, scale for guidance (DDIM, PNDM, LMS samplers only) / CLIP guided SDを有効にしてこのscaleを適用する(サンプラーはDDIM、PNDM、LMSのみ)", - ) - parser.add_argument( - "--clip_image_guidance_scale", - type=float, - default=0.0, - help="enable CLIP guided SD by image, scale for guidance / 画像によるCLIP guided SDを有効にしてこのscaleを適用する", - ) - parser.add_argument( - "--vgg16_guidance_scale", - type=float, - default=0.0, - help="enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する", - ) - parser.add_argument( - "--vgg16_guidance_layer", - type=int, - default=20, - help="layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)", - ) - parser.add_argument( - "--guide_image_path", type=str, default=None, nargs="*", help="image to CLIP guidance / CLIP guided SDでガイドに使う画像" - ) - parser.add_argument( - "--highres_fix_scale", - type=float, - default=None, - help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", - ) - parser.add_argument( - "--highres_fix_steps", - type=int, - default=28, - help="1st stage steps for highres fix / highres fixの最初のステージのステップ数", - ) - parser.add_argument( - "--highres_fix_strength", - type=float, - default=None, - help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", - ) - parser.add_argument( - "--highres_fix_save_1st", - action="store_true", - help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する", - ) - parser.add_argument( - "--highres_fix_latents_upscaling", - action="store_true", - help="use latents upscaling for highres fix / highres fixでlatentで拡大する", - ) - parser.add_argument( - "--highres_fix_upscaler", - type=str, - default=None, - help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名", - ) - parser.add_argument( - "--highres_fix_upscaler_args", - type=str, - default=None, - help="additional arguments for upscaler (key=value) / upscalerへの追加の引数", - ) - parser.add_argument( - "--highres_fix_disable_control_net", - action="store_true", - help="disable ControlNet for highres fix / highres fixでControlNetを使わない", - ) - - parser.add_argument( - "--negative_scale", - type=float, - default=None, - help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する", - ) - - parser.add_argument( - 
"--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" - ) - parser.add_argument( - "--control_net_preps", - type=str, - default=None, - nargs="*", - help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名", - ) - parser.add_argument("--control_net_weights", type=float, default=None, nargs="*", help="ControlNet weights / ControlNetの重み") - parser.add_argument( - "--control_net_ratios", - type=float, - default=None, - nargs="*", - help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", - ) - # parser.add_argument( - # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" - # ) - - # Deep Shrink - parser.add_argument( - "--ds_depth_1", - type=int, - default=None, - help="Enable Deep Shrink with this depth 1, valid values are 0 to 3 / Deep Shrinkをこのdepthで有効にする", - ) - parser.add_argument( - "--ds_timesteps_1", - type=int, - default=650, - help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps", - ) - parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2") - parser.add_argument( - "--ds_timesteps_2", - type=int, - default=650, - help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps", - ) - parser.add_argument( - "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" - ) - - # gradual latent - parser.add_argument( - "--gradual_latent_timesteps", - type=int, - default=None, - help="enable Gradual Latent hires fix and apply upscaling from this timesteps / Gradual Latent hires fixをこのtimestepsで有効にし、このtimestepsからアップスケーリングを適用する", - ) - parser.add_argument( - "--gradual_latent_ratio", - type=float, - default=0.5, - help=" this size ratio, 0.5 means 1/2 / Gradual Latent hires fixをこのサイズ比率で有効にする、0.5は1/2を意味する", - ) - parser.add_argument( - "--gradual_latent_ratio_step", - type=float, - default=0.125, - help="step to increase ratio for Gradual Latent / Gradual Latentのratioをどのくらいずつ上げるか", - ) - parser.add_argument( - "--gradual_latent_every_n_steps", - type=int, - default=3, - help="steps to increase size of latents every this steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる", - ) - parser.add_argument( - "--gradual_latent_s_noise", - type=float, - default=1.0, - help="s_noise for Gradual Latent / Gradual Latentのs_noise", - ) - parser.add_argument( - "--gradual_latent_unsharp_params", - type=str, - default=None, - help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" - + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. 
`3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", - ) - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - setup_logging(args, reset=True) - main(args) diff --git a/gui.sh b/gui.sh index 022335125..50704dbc8 100755 --- a/gui.sh +++ b/gui.sh @@ -10,6 +10,19 @@ env_var_exists() { fi } +# Define the directory path for WSL2 +lib_path="/usr/lib/wsl/lib/" + +# Check if the directory exists +if [ -d "$lib_path" ]; then + # Check if LD_LIBRARY_PATH is already set + if [ -z "${LD_LIBRARY_PATH}" ]; then + # LD_LIBRARY_PATH is not set, set it to the lib_path + export LD_LIBRARY_PATH="$lib_path" + # echo "LD_LIBRARY_PATH set to: $LD_LIBRARY_PATH" + fi +fi + # Need RUNPOD to have a default value before first access RUNPOD=false if env_var_exists RUNPOD_POD_ID || env_var_exists RUNPOD_API_KEY; then diff --git a/kohya_gui/basic_caption_gui.py b/kohya_gui/basic_caption_gui.py index 23a3c8713..ad5ee6497 100644 --- a/kohya_gui/basic_caption_gui.py +++ b/kohya_gui/basic_caption_gui.py @@ -1,7 +1,7 @@ import gradio as gr from easygui import msgbox import subprocess -from .common_gui import get_folder_path, add_pre_postfix, find_replace, scriptdir +from .common_gui import get_folder_path, add_pre_postfix, find_replace, scriptdir, list_dirs import os import sys @@ -52,7 +52,7 @@ def caption_images( log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{env.get('PYTHONPATH', '')}" + env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/tools{os.pathsep}{env.get('PYTHONPATH', '')}" # Run the command based on the operating system subprocess.run(run_cmd, shell=True, env=env) @@ -86,64 +86,101 @@ def caption_images( # Gradio UI -def gradio_basic_caption_gui_tab(headless=False): +def gradio_basic_caption_gui_tab(headless=False, default_images_dir=None): + from .common_gui import create_refresh_button + + # Set default images directory if not provided + default_images_dir = default_images_dir if default_images_dir is not None else os.path.join(scriptdir, "data") + current_images_dir = default_images_dir + + # Function to list directories + def list_images_dirs(path): + # Allows list_images_dirs to modify current_images_dir outside of this function + nonlocal current_images_dir + current_images_dir = path + return list(list_dirs(path)) + + # Gradio tab for basic captioning with gr.Tab('Basic Captioning'): + # Markdown description gr.Markdown( 'This utility allows you to create simple caption files for each image in a folder.' 
) - with gr.Row(): - images_dir = gr.Textbox( - label='Image folder to caption', - placeholder='Directory containing the images to caption', + # Group and row for image folder selection + with gr.Group(), gr.Row(): + # Dropdown for image folder + images_dir = gr.Dropdown( + label='Image folder to caption (containing the images to caption)', + choices=list_images_dirs(default_images_dir), + value="", interactive=True, + allow_custom_value=True, ) + # Refresh button for image folder + create_refresh_button(images_dir, lambda: None, lambda: {"choices": list_images_dirs(current_images_dir)},"open_folder_small") + # Button to open folder folder_button = gr.Button( - '📂', elem_id='open_folder_small', visible=(not headless) + '📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not headless) ) + # Event handler for button click folder_button.click( get_folder_path, outputs=images_dir, show_progress=False, ) + # Textbox for caption file extension caption_ext = gr.Textbox( label='Caption file extension', placeholder='Extension for caption file (e.g., .caption, .txt)', value='.txt', interactive=True, ) + # Checkbox to overwrite existing captions overwrite = gr.Checkbox( label='Overwrite existing captions in folder', interactive=True, value=False, ) + # Row for caption prefix and text with gr.Row(): + # Textbox for caption prefix prefix = gr.Textbox( label='Prefix to add to caption', placeholder='(Optional)', interactive=True, ) + # Textbox for caption text caption_text = gr.Textbox( label='Caption text', placeholder='e.g., "by some artist". Leave empty if you only want to add a prefix or postfix.', interactive=True, + lines=2, ) + # Textbox for caption postfix postfix = gr.Textbox( label='Postfix to add to caption', placeholder='(Optional)', interactive=True, ) - with gr.Row(): + # Group and row for find and replace text + with gr.Group(), gr.Row(): + # Textbox for find text find_text = gr.Textbox( label='Find text', placeholder='e.g., "by some artist". Leave empty if you only want to add a prefix or postfix.', interactive=True, + lines=2, ) + # Textbox for replace text replace_text = gr.Textbox( label='Replacement text', placeholder='e.g., "by some artist". 
Leave empty if you want to replace with nothing.', interactive=True, + lines=2, ) + # Button to caption images caption_button = gr.Button('Caption images') + # Event handler for button click caption_button.click( caption_images, inputs=[ @@ -158,3 +195,11 @@ def gradio_basic_caption_gui_tab(headless=False): ], show_progress=False, ) + + # Event handler for dynamic update of dropdown choices + images_dir.change( + fn=lambda path: gr.Dropdown().update(choices=list_images_dirs(path)), + inputs=images_dir, + outputs=images_dir, + show_progress=False, + ) diff --git a/kohya_gui/blip_caption_gui.py b/kohya_gui/blip_caption_gui.py index 983ade63b..225c7fe27 100644 --- a/kohya_gui/blip_caption_gui.py +++ b/kohya_gui/blip_caption_gui.py @@ -3,7 +3,7 @@ import subprocess import os import sys -from .common_gui import get_folder_path, add_pre_postfix, scriptdir +from .common_gui import get_folder_path, add_pre_postfix, scriptdir, list_dirs from .custom_logging import setup_logging # Set up logging @@ -25,27 +25,27 @@ def caption_images( postfix, ): # Check if the image folder is provided - if train_data_dir == '': - msgbox('Image folder is missing...') + if train_data_dir == "": + msgbox("Image folder is missing...") return # Check if the caption file extension is provided - if caption_file_ext == '': - msgbox('Please provide an extension for the caption files.') + if caption_file_ext == "": + msgbox("Please provide an extension for the caption files.") return - log.info(f'Captioning files in {train_data_dir}...') + log.info(f"Captioning files in {train_data_dir}...") # Construct the command to run - run_cmd = fr'{PYTHON} "{scriptdir}/finetune/make_captions.py"' + run_cmd = rf'{PYTHON} "{scriptdir}/sd-scripts/finetune/make_captions.py"' run_cmd += f' --batch_size="{int(batch_size)}"' run_cmd += f' --num_beams="{int(num_beams)}"' run_cmd += f' --top_p="{top_p}"' run_cmd += f' --max_length="{int(max_length)}"' run_cmd += f' --min_length="{int(min_length)}"' if beam_search: - run_cmd += f' --beam_search' - if caption_file_ext != '': + run_cmd += f" --beam_search" + if caption_file_ext != "": run_cmd += f' --caption_extension="{caption_file_ext}"' run_cmd += f' "{train_data_dir}"' run_cmd += f' --caption_weights="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth"' @@ -53,10 +53,12 @@ def caption_images( log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{env.get('PYTHONPATH', '')}" + env["PYTHONPATH"] = ( + rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + ) - # Run the command - subprocess.run(run_cmd, shell=True, env=env) + # Run the command in the sd-scripts folder context + subprocess.run(run_cmd, shell=True, env=env, cwd=rf"{scriptdir}/sd-scripts") # Add prefix and postfix add_pre_postfix( @@ -66,7 +68,7 @@ def caption_images( postfix=postfix, ) - log.info('...captioning done') + log.info("...captioning done") ### @@ -74,19 +76,44 @@ def caption_images( ### -def gradio_blip_caption_gui_tab(headless=False): - with gr.Tab('BLIP Captioning'): +def gradio_blip_caption_gui_tab(headless=False, default_train_dir=None): + from .common_gui import create_refresh_button + + default_train_dir = ( + default_train_dir + if default_train_dir is not None + else os.path.join(scriptdir, "data") + ) + current_train_dir = default_train_dir + + def list_train_dirs(path): + nonlocal current_train_dir + current_train_dir = path + return list(list_dirs(path)) + + with gr.Tab("BLIP Captioning"): 
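# Sketch of the environment handling this PR applies to sd-scripts entry points: both the
# repository root and the sd-scripts submodule are prepended to PYTHONPATH, and the process
# is started with the submodule as its working directory so relative paths resolve. The
# helper name, the bare "python" executable and the example paths are assumptions for
# illustration only, not part of the patch.
import os
import subprocess


def run_sd_script(scriptdir: str, rel_script: str, extra_args: list) -> None:
    env = os.environ.copy()
    env["PYTHONPATH"] = (
        rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
    )
    cmd = ["python", os.path.join(scriptdir, rel_script)] + list(extra_args)
    # Run from the sd-scripts checkout so its relative imports and file paths resolve.
    subprocess.run(cmd, env=env, cwd=os.path.join(scriptdir, "sd-scripts"))


# Hypothetical usage:
# run_sd_script("/opt/kohya_ss", "sd-scripts/finetune/make_captions.py",
#               ["--batch_size", "1", "--caption_extension", ".txt", "/data/images"])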
gr.Markdown( - 'This utility uses BLIP to caption files for each image in a folder.' + "This utility uses BLIP to caption files for each image in a folder." ) - with gr.Row(): - train_data_dir = gr.Textbox( - label='Image folder to caption', - placeholder='Directory containing the images to caption', + with gr.Group(), gr.Row(): + train_data_dir = gr.Dropdown( + label="Image folder to caption (containing the images to caption)", + choices=list_train_dirs(default_train_dir), + value="", interactive=True, + allow_custom_value=True, + ) + create_refresh_button( + train_data_dir, + lambda: None, + lambda: {"choices": list_train_dirs(current_train_dir)}, + "open_folder_small", ) button_train_data_dir_input = gr.Button( - '📂', elem_id='open_folder_small', visible=(not headless) + "📂", + elem_id="open_folder_small", + elem_classes=["tool"], + visible=(not headless), ) button_train_data_dir_input.click( get_folder_path, @@ -95,44 +122,36 @@ def gradio_blip_caption_gui_tab(headless=False): ) with gr.Row(): caption_file_ext = gr.Textbox( - label='Caption file extension', - placeholder='Extension for caption file, e.g., .caption, .txt', - value='.txt', + label="Caption file extension", + placeholder="Extension for caption file, e.g., .caption, .txt", + value=".txt", interactive=True, ) prefix = gr.Textbox( - label='Prefix to add to BLIP caption', - placeholder='(Optional)', + label="Prefix to add to BLIP caption", + placeholder="(Optional)", interactive=True, ) postfix = gr.Textbox( - label='Postfix to add to BLIP caption', - placeholder='(Optional)', + label="Postfix to add to BLIP caption", + placeholder="(Optional)", interactive=True, ) - batch_size = gr.Number( - value=1, label='Batch size', interactive=True - ) + batch_size = gr.Number(value=1, label="Batch size", interactive=True) with gr.Row(): beam_search = gr.Checkbox( - label='Use beam search', interactive=True, value=True - ) - num_beams = gr.Number( - value=1, label='Number of beams', interactive=True - ) - top_p = gr.Number(value=0.9, label='Top p', interactive=True) - max_length = gr.Number( - value=75, label='Max length', interactive=True - ) - min_length = gr.Number( - value=5, label='Min length', interactive=True + label="Use beam search", interactive=True, value=True ) + num_beams = gr.Number(value=1, label="Number of beams", interactive=True) + top_p = gr.Number(value=0.9, label="Top p", interactive=True) + max_length = gr.Number(value=75, label="Max length", interactive=True) + min_length = gr.Number(value=5, label="Min length", interactive=True) - caption_button = gr.Button('Caption images') + caption_button = gr.Button("Caption images") caption_button.click( caption_images, @@ -150,3 +169,10 @@ def gradio_blip_caption_gui_tab(headless=False): ], show_progress=False, ) + + train_data_dir.change( + fn=lambda path: gr.Dropdown().update(choices=list_train_dirs(path)), + inputs=train_data_dir, + outputs=train_data_dir, + show_progress=False, + ) diff --git a/kohya_gui/class_advanced_training.py b/kohya_gui/class_advanced_training.py index 5c876c4ff..b8d819321 100644 --- a/kohya_gui/class_advanced_training.py +++ b/kohya_gui/class_advanced_training.py @@ -1,13 +1,17 @@ import gradio as gr -from .common_gui import get_folder_path, get_any_file_path +import os +from .common_gui import get_folder_path, get_any_file_path, scriptdir, list_files, list_dirs, create_refresh_button class AdvancedTraining: - def __init__(self, headless=False, finetuning: bool = False, training_type: str = ""): + def __init__(self, headless=False, finetuning: 
bool = False, training_type: str = "", default_vae_dir: str = "", default_output_dir: str = ""): self.headless = headless self.finetuning = finetuning self.training_type = training_type + current_vae_dir = default_vae_dir if default_vae_dir != "" else os.path.join(scriptdir, "vae") + current_state_dir = default_output_dir if default_output_dir != "" else os.path.join(scriptdir, "outputs") + def noise_offset_type_change(noise_offset_type): if noise_offset_type == 'Original': return ( @@ -36,15 +40,23 @@ def noise_offset_type_change(noise_offset_type): self.weighted_captions = gr.Checkbox( label='Weighted captions', value=False ) - with gr.Row(visible=not finetuning): + with gr.Group(), gr.Row(visible=not finetuning): self.prior_loss_weight = gr.Number( label='Prior loss weight', value=1.0 ) - self.vae = gr.Textbox( - label='VAE', - placeholder='(Optional) path to checkpoint of vae to replace for training', + + def list_vae_files(path): + current_vae_dir = path + return list(list_files(path, exts=[".ckpt", ".safetensors"], all=True)) + + self.vae = gr.Dropdown( + label='VAE (Optional. path to checkpoint of vae to replace for training)', interactive=True, + choices=list_vae_files(current_vae_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(self.vae, lambda: None, lambda: {"choices": list_vae_files(current_vae_dir)},"open_folder_small") self.vae_button = gr.Button( '📂', elem_id='open_folder_small', visible=(not headless) ) @@ -54,6 +66,13 @@ def noise_offset_type_change(noise_offset_type): show_progress=False, ) + self.vae.change( + fn=lambda path: gr.Dropdown().update(choices=list_vae_files(path)), + inputs=self.vae, + outputs=self.vae, + show_progress=False, + ) + with gr.Row(): self.additional_parameters = gr.Textbox( label='Additional parameters', @@ -274,14 +293,23 @@ def full_options_update(full_fp16, full_bf16): self.vae_batch_size = gr.Slider( label='VAE batch size', minimum=0, maximum=32, value=0, step=1 ) - with gr.Row(): + with gr.Group(), gr.Row(): self.save_state = gr.Checkbox( label='Save training state', value=False ) - self.resume = gr.Textbox( - label='Resume from saved training state', - placeholder='path to "last-state" state folder to resume from', + + def list_state_dirs(path): + current_state_dir = path + return list(list_dirs(path)) + + self.resume = gr.Dropdown( + label='Resume from saved training state (path to "last-state" state folder)', + choices=list_state_dirs(current_state_dir), + value="", + interactive=True, + allow_custom_value=True, ) + create_refresh_button(self.resume, lambda: None, lambda: {"choices": list_state_dirs(current_state_dir)}, "open_folder_small") self.resume_button = gr.Button( '📂', elem_id='open_folder_small', visible=(not headless) ) @@ -290,6 +318,12 @@ def full_options_update(full_fp16, full_bf16): outputs=self.resume, show_progress=False, ) + self.resume.change( + fn=lambda path: gr.Dropdown().update(choices=list_state_dirs(path)), + inputs=self.resume, + outputs=self.resume, + show_progress=False, + ) # self.max_train_epochs = gr.Textbox( # label='Max train epoch', # placeholder='(Optional) Override number of epoch', diff --git a/kohya_gui/class_basic_training.py b/kohya_gui/class_basic_training.py index 6ce55e174..700313bf8 100644 --- a/kohya_gui/class_basic_training.py +++ b/kohya_gui/class_basic_training.py @@ -53,15 +53,6 @@ def __init__( ], value="fp16", ) - self.save_precision = gr.Dropdown( - label="Save precision", - choices=[ - "float", - "fp16", - "bf16", - ], - value="fp16", - ) 
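# Sketch of the Dropdown-plus-refresh pattern these GUI hunks introduce: a path dropdown
# whose choices are re-listed whenever its value changes, so the user can drill into
# sub-folders or paste a custom path. list_dirs/create_refresh_button come from
# kohya_gui.common_gui in the patch; the stub below is a simplified stand-in, and the
# gr.Dropdown().update() idiom mirrors the Gradio 3.x API used elsewhere in this diff.
import os

import gradio as gr


def list_dirs_stub(path):
    # Simplified stand-in for common_gui.list_dirs: immediate sub-directories only.
    if not path or not os.path.isdir(path):
        return []
    return sorted(d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d)))


with gr.Blocks() as demo:
    folder = gr.Dropdown(
        label="Folder",
        choices=list_dirs_stub(os.getcwd()),
        value="",
        interactive=True,
        allow_custom_value=True,  # allows pasting an arbitrary path
    )
    # Re-list choices whenever the user types or picks a new path.
    folder.change(
        fn=lambda path: gr.Dropdown().update(choices=list_dirs_stub(path)),
        inputs=folder,
        outputs=folder,
        show_progress=False,
    )

# demo.launch()  # uncomment to try the sketch locally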
self.num_cpu_threads_per_process = gr.Slider( minimum=1, maximum=os.cpu_count(), diff --git a/kohya_gui/class_configuration_file.py b/kohya_gui/class_configuration_file.py index 66b5512cc..890264c11 100644 --- a/kohya_gui/class_configuration_file.py +++ b/kohya_gui/class_configuration_file.py @@ -1,30 +1,60 @@ import gradio as gr +import os + +from .common_gui import list_files class ConfigurationFile: - def __init__(self, headless=False): + def __init__(self, headless=False, output_dir: gr.Dropdown = None): + from .common_gui import create_refresh_button + self.headless = headless - with gr.Accordion('Configuration file', open=False): + self.output_dir = None + + def update_configs(output_dir): + self.output_dir = output_dir + return gr.Dropdown().update(choices=list(list_files(output_dir, exts=[".json"], all=True))) + + def list_configs(path): + self.output_dir = path + return list(list_files(path, exts=[".json"], all=True)) + + with gr.Group(): with gr.Row(): + self.config_file_name = gr.Dropdown( + label='Load/Save Config file', + choices=list_configs(self.output_dir), + value="", + interactive=True, + allow_custom_value=True, + ) + create_refresh_button(self.config_file_name, lambda: None, lambda: {"choices": list_configs(self.output_dir)}, "open_folder_small") self.button_open_config = gr.Button( - 'Open 📂', - elem_id='open_folder', + '📂', + elem_id='open_folder_small', + elem_classes=['tool'], visible=(not self.headless), ) self.button_save_config = gr.Button( - 'Save 💾', - elem_id='open_folder', - ) - self.button_save_as_config = gr.Button( - 'Save as... 💾', - elem_id='open_folder', - visible=(not self.headless), - ) - self.config_file_name = gr.Textbox( - label='', - placeholder="type the configuration file path or use the 'Open' button above to select it...", - interactive=True, + '💾', + elem_id='open_folder_small', + elem_classes=['tool'], ) self.button_load_config = gr.Button( - 'Load 💾', elem_id='open_folder' + '↩️ ', + elem_id='open_folder_small', + elem_classes=['tool'], ) + + self.config_file_name.change( + fn=lambda path: gr.Dropdown().update(choices=list_configs(path)), + inputs=self.config_file_name, + outputs=self.config_file_name, + show_progress=False, + ) + + output_dir.change( + fn=update_configs, + inputs=output_dir, + outputs=self.config_file_name, + ) diff --git a/kohya_gui/class_folders.py b/kohya_gui/class_folders.py index d108201a5..1b9e4aae4 100644 --- a/kohya_gui/class_folders.py +++ b/kohya_gui/class_folders.py @@ -1,30 +1,71 @@ import gradio as gr -from .common_gui import get_folder_path +import os +from .common_gui import get_folder_path, scriptdir, list_dirs class Folders: - def __init__(self, finetune=False, headless=False): + def __init__(self, finetune=False, train_data_dir: gr.Dropdown = None, data_dir=None, output_dir=None, logging_dir=None, headless=False): + from .common_gui import create_refresh_button + self.headless = headless + default_data_dir = data_dir if data_dir is not None else os.path.join(scriptdir, "data") + default_output_dir = output_dir if output_dir is not None else os.path.join(scriptdir, "outputs") + default_logging_dir = logging_dir if logging_dir is not None else os.path.join(scriptdir, "logs") + default_reg_data_dir = default_data_dir + + self.current_data_dir = default_data_dir + self.current_output_dir = default_output_dir + self.current_logging_dir = default_logging_dir + + + if default_data_dir is not None and default_data_dir.strip() != "" and not os.path.exists(default_data_dir): + os.makedirs(default_data_dir, 
exist_ok=True) + if default_output_dir is not None and default_output_dir.strip() != "" and not os.path.exists(default_output_dir): + os.makedirs(default_output_dir, exist_ok=True) + if default_logging_dir is not None and default_logging_dir.strip() != "" and not os.path.exists(default_logging_dir): + os.makedirs(default_logging_dir, exist_ok=True) + + def list_data_dirs(path): + self.current_data_dir = path + return list(list_dirs(path)) + + def list_output_dirs(path): + self.current_output_dir = path + return list(list_dirs(path)) + + def list_logging_dirs(path): + self.current_logging_dir = path + return list(list_dirs(path)) + with gr.Row(): - self.train_data_dir = gr.Textbox( - label='Image folder', - placeholder='Folder where the training folders containing the images are located', + self.output_dir = gr.Dropdown( + label=f'Output folder to output trained model', + choices=list_output_dirs(default_output_dir), + value="", + interactive=True, + allow_custom_value=True, ) - self.train_data_dir_folder = gr.Button( - '📂', elem_id='open_folder_small', visible=(not self.headless) + create_refresh_button(self.output_dir, lambda: None, lambda: {"choices": list_output_dirs(self.current_output_dir)}, "open_folder_small") + self.output_dir_folder = gr.Button( + '📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless) ) - self.train_data_dir_folder.click( + self.output_dir_folder.click( get_folder_path, - outputs=self.train_data_dir, + outputs=self.output_dir, show_progress=False, ) - self.reg_data_dir = gr.Textbox( - label='Regularisation folder' if not finetune else 'Train config folder', - placeholder='(Optional) Folder where where the regularization folders containing the images are located' if not finetune else "folder where the training configuration files will be saved", + + self.reg_data_dir = gr.Dropdown( + label='Regularisation folder (Optional. containing reqularization images)' if not finetune else 'Train config folder (Optional. where config files will be saved)', + choices=list_data_dirs(default_reg_data_dir), + value="", + interactive=True, + allow_custom_value=True, ) + create_refresh_button(self.reg_data_dir, lambda: None, lambda: {"choices": list_data_dirs(self.current_data_dir)}, "open_folder_small") self.reg_data_dir_folder = gr.Button( - '📂', elem_id='open_folder_small', visible=(not self.headless) + '📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless) ) self.reg_data_dir_folder.click( get_folder_path, @@ -32,39 +73,38 @@ def __init__(self, finetune=False, headless=False): show_progress=False, ) with gr.Row(): - self.output_dir = gr.Textbox( - label='Output folder', - placeholder='Folder to output trained model', - ) - self.output_dir_folder = gr.Button( - '📂', elem_id='open_folder_small', visible=(not self.headless) - ) - self.output_dir_folder.click( - get_folder_path, - outputs=self.output_dir, - show_progress=False, - ) - self.logging_dir = gr.Textbox( - label='Logging folder', - placeholder='Optional: enable logging and output TensorBoard log to this folder', + self.logging_dir = gr.Dropdown( + label='Logging folder (Optional. 
to enable logging and output Tensorboard log)', + choices=list_logging_dirs(default_logging_dir), + value="", + interactive=True, + allow_custom_value=True, ) + create_refresh_button(self.logging_dir, lambda: None, lambda: {"choices": list_logging_dirs(self.current_logging_dir)}, "open_folder_small") self.logging_dir_folder = gr.Button( - '📂', elem_id='open_folder_small', visible=(not self.headless) + '📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless) ) self.logging_dir_folder.click( get_folder_path, outputs=self.logging_dir, show_progress=False, ) - with gr.Row(): - self.output_name = gr.Textbox( - label='Model output name', - placeholder='(Name of the model to output)', - value='last', - interactive=True, + + self.output_dir.change( + fn=lambda path: gr.Dropdown().update(choices=list_output_dirs(path)), + inputs=self.output_dir, + outputs=self.output_dir, + show_progress=False, ) - self.training_comment = gr.Textbox( - label='Training comment', - placeholder='(Optional) Add training comment to be included in metadata', - interactive=True, + self.reg_data_dir.change( + fn=lambda path: gr.Dropdown().update(choices=list_data_dirs(path)), + inputs=self.reg_data_dir, + outputs=self.reg_data_dir, + show_progress=False, + ) + self.logging_dir.change( + fn=lambda path: gr.Dropdown().update(choices=list_logging_dirs(path)), + inputs=self.logging_dir, + outputs=self.logging_dir, + show_progress=False, ) diff --git a/kohya_gui/class_lora_tab.py b/kohya_gui/class_lora_tab.py index 2f33d9860..bcc1d16e1 100644 --- a/kohya_gui/class_lora_tab.py +++ b/kohya_gui/class_lora_tab.py @@ -17,12 +17,11 @@ class LoRATools: - def __init__(self, folders='', headless: bool = False): + def __init__(self, train_data_dir=None, reg_data_dir=None, output_dir=None, logging_dir=None, headless: bool = False): self.headless = headless - self.folders = folders gr.Markdown( - 'This section provide LoRA tools to help setup your dataset...' + 'This section provide various LoRA tools...' 
) gradio_extract_dylora_tab(headless=headless) gradio_convert_lcm_tab(headless=headless) @@ -33,13 +32,3 @@ def __init__(self, folders='', headless: bool = False): gradio_svd_merge_lora_tab(headless=headless) gradio_resize_lora_tab(headless=headless) gradio_verify_lora_tab(headless=headless) - if folders: - with gr.Tab('Dataset Preparation'): - gradio_dreambooth_folder_creation_tab( - train_data_dir_input=folders.train_data_dir, - reg_data_dir_input=folders.reg_data_dir, - output_dir_input=folders.output_dir, - logging_dir_input=folders.logging_dir, - headless=headless, - ) - gradio_dataset_balancing_tab(headless=headless) diff --git a/kohya_gui/class_sample_images.py b/kohya_gui/class_sample_images.py index f7ad970d9..8088973aa 100644 --- a/kohya_gui/class_sample_images.py +++ b/kohya_gui/class_sample_images.py @@ -1,4 +1,3 @@ -import tempfile import os import gradio as gr from easygui import msgbox @@ -33,7 +32,7 @@ def run_cmd_sample( run_cmd = '' - if sample_every_n_epochs == 0 and sample_every_n_steps == 0: + if sample_every_n_epochs == sample_every_n_steps == 0: return run_cmd # Create the prompt file and get its path @@ -45,10 +44,10 @@ def run_cmd_sample( run_cmd += f' --sample_sampler={sample_sampler}' run_cmd += f' --sample_prompts="{sample_prompts_path}"' - if not sample_every_n_epochs == 0: + if sample_every_n_epochs != 0: run_cmd += f' --sample_every_n_epochs="{sample_every_n_epochs}"' - if not sample_every_n_steps == 0: + if sample_every_n_steps != 0: run_cmd += f' --sample_every_n_steps="{sample_every_n_steps}"' return run_cmd diff --git a/kohya_gui/class_source_model.py b/kohya_gui/class_source_model.py index 041ed647d..9ba55d82c 100644 --- a/kohya_gui/class_source_model.py +++ b/kohya_gui/class_source_model.py @@ -1,8 +1,13 @@ import gradio as gr +import os + from .common_gui import ( get_any_file_path, get_folder_path, set_pretrained_model_name_or_path_input, + scriptdir, + list_dirs, + list_files, ) folder_symbol = '\U0001f4c2' # 📂 @@ -21,47 +26,71 @@ def __init__( 'diffusers_safetensors', 'safetensors', ], + save_precision_choices=[ + "float", + "fp16", + "bf16", + ], headless=False, + default_data_dir=None, + finetuning=False, ): self.headless = headless self.save_model_as_choices = save_model_as_choices + self.finetuning = finetuning + + default_models = [ + 'stabilityai/stable-diffusion-xl-base-1.0', + 'stabilityai/stable-diffusion-xl-refiner-1.0', + 'stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned', + 'stabilityai/stable-diffusion-2-1-base', + 'stabilityai/stable-diffusion-2-base', + 'stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned', + 'stabilityai/stable-diffusion-2-1', + 'stabilityai/stable-diffusion-2', + 'runwayml/stable-diffusion-v1-5', + 'CompVis/stable-diffusion-v1-4', + ] + + from .common_gui import create_refresh_button + + default_data_dir = default_data_dir if default_data_dir is not None else os.path.join(scriptdir, "outputs") + default_train_dir = default_data_dir if default_data_dir is not None else os.path.join(scriptdir, "data") + model_checkpoints = list(list_files(default_data_dir, exts=[".ckpt", ".safetensors"], all=True)) + self.current_data_dir = default_data_dir + self.current_train_dir = default_train_dir + + def list_models(path): + self.current_data_dir = path if os.path.isdir(path) else os.path.dirname(path) + return default_models + list(list_files(path, exts=[".ckpt", ".safetensors"], all=True)) + + def list_train_dirs(path): + self.current_train_dir = path if os.path.isdir(path) else 
os.path.dirname(path) + return list(list_dirs(path)) + + if default_data_dir is not None and default_data_dir.strip() != "" and not os.path.exists(default_data_dir): + os.makedirs(default_data_dir, exist_ok=True) - with gr.Tab('Source model'): + with gr.Column(), gr.Group(): # Define the input elements with gr.Row(): - self.model_list = gr.Dropdown( - label='Model Quick Pick', - choices=[ - 'custom', - 'stabilityai/stable-diffusion-xl-base-1.0', - 'stabilityai/stable-diffusion-xl-refiner-1.0', - 'stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned', - 'stabilityai/stable-diffusion-2-1-base', - 'stabilityai/stable-diffusion-2-base', - 'stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned', - 'stabilityai/stable-diffusion-2-1', - 'stabilityai/stable-diffusion-2', - 'runwayml/stable-diffusion-v1-5', - 'CompVis/stable-diffusion-v1-4', - ], - value='runwayml/stable-diffusion-v1-5', - ) - self.save_model_as = gr.Dropdown( - label='Save trained model as', - choices=save_model_as_choices, - value='safetensors', - ) - with gr.Row(): - self.pretrained_model_name_or_path = gr.Textbox( + with gr.Column(), gr.Row(): + self.model_list = gr.Textbox(visible=False, value="") + self.pretrained_model_name_or_path = gr.Dropdown( label='Pretrained model name or path', - placeholder='enter the path to custom model or name of pretrained model', + choices=default_models + model_checkpoints, value='runwayml/stable-diffusion-v1-5', - visible=(False and not headless), + allow_custom_value=True, + visible=True, + min_width=100, ) + create_refresh_button(self.pretrained_model_name_or_path, lambda: None, lambda: {"choices": list_models(self.current_data_dir)},"open_folder_small") + self.pretrained_model_name_or_path_file = gr.Button( document_symbol, elem_id='open_folder_small', - visible=(False and not headless), + elem_classes=['tool'], + visible=(not headless), ) self.pretrained_model_name_or_path_file.click( get_any_file_path, @@ -72,7 +101,8 @@ def __init__( self.pretrained_model_name_or_path_folder = gr.Button( folder_symbol, elem_id='open_folder_small', - visible=(False and not headless), + elem_classes=['tool'], + visible=(not headless), ) self.pretrained_model_name_or_path_folder.click( get_folder_path, @@ -80,34 +110,80 @@ def __init__( outputs=self.pretrained_model_name_or_path, show_progress=False, ) + + with gr.Column(), gr.Row(): + self.train_data_dir = gr.Dropdown( + label='Image folder (containing training images subfolders)' if not finetuning else 'Image folder (containing training images)', + choices=list_train_dirs(default_train_dir), + value="", + interactive=True, + allow_custom_value=True, + ) + create_refresh_button(self.train_data_dir, lambda: None, lambda: {"choices": list_train_dirs(self.current_train_dir)}, "open_folder_small") + self.train_data_dir_folder = gr.Button( + '📂', elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless) + ) + self.train_data_dir_folder.click( + get_folder_path, + outputs=self.train_data_dir, + show_progress=False, + ) + + with gr.Row(): + with gr.Column(): + with gr.Row(): + self.v2 = gr.Checkbox(label='v2', value=False, visible=False, min_width=60) + self.v_parameterization = gr.Checkbox( + label='v_parameterization', value=False, visible=False, min_width=130, + ) + self.sdxl_checkbox = gr.Checkbox( + label='SDXL', value=False, visible=False, min_width=60, + ) + with gr.Column(): + gr.Box(visible=False) + + with gr.Row(): + self.output_name = gr.Textbox( + label='Trained Model output name', + placeholder='(Name of 
the model to output)', + value='last', + interactive=True, + ) + self.training_comment = gr.Textbox( + label='Training comment', + placeholder='(Optional) Add training comment to be included in metadata', + interactive=True, + ) + with gr.Row(): - self.v2 = gr.Checkbox(label='v2', value=False, visible=False) - self.v_parameterization = gr.Checkbox( - label='v_parameterization', value=False, visible=False + self.save_model_as = gr.Radio( + save_model_as_choices, + label="Save trained model as", + value="safetensors", ) - self.sdxl_checkbox = gr.Checkbox( - label='SDXL Model', value=False, visible=False + self.save_precision = gr.Radio( + save_precision_choices, + label="Save precision", + value="fp16", ) - self.model_list.change( - set_pretrained_model_name_or_path_input, + self.pretrained_model_name_or_path.change( + fn=lambda path: set_pretrained_model_name_or_path_input(path, refresh_method=list_models), inputs=[ - self.model_list, self.pretrained_model_name_or_path, - self.pretrained_model_name_or_path_file, - self.pretrained_model_name_or_path_folder, - self.v2, - self.v_parameterization, - self.sdxl_checkbox, ], outputs=[ - self.model_list, self.pretrained_model_name_or_path, - self.pretrained_model_name_or_path_file, - self.pretrained_model_name_or_path_folder, self.v2, self.v_parameterization, self.sdxl_checkbox, ], show_progress=False, ) + + self.train_data_dir.change( + fn=lambda path: gr.Dropdown().update(choices=list_train_dirs(path)), + inputs=self.train_data_dir, + outputs=self.train_data_dir, + show_progress=False, + ) diff --git a/kohya_gui/common_gui.py b/kohya_gui/common_gui.py index 89f0e2749..003c3aa7e 100644 --- a/kohya_gui/common_gui.py +++ b/kohya_gui/common_gui.py @@ -21,6 +21,9 @@ scriptdir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) +# insert sd-scripts path into PYTHONPATH +sys.path.insert(0, os.path.join(scriptdir, "sd-scripts")) + # define a list of substrings to search for v2 base models V2_BASE_MODELS = [ "stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned", @@ -90,6 +93,129 @@ def output_message(msg="", title="", headless=False): msgbox(msg=msg, title=title) +def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + refresh_components = ( + refresh_component + if isinstance(refresh_component, list) + else [refresh_component] + ) + + label = None + for comp in refresh_components: + label = getattr(comp, "label", None) + if label is not None: + break + + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + for comp in refresh_components: + setattr(comp, k, v) + + return ( + [gr.update(**(args or {})) for _ in refresh_components] + if len(refresh_components) > 1 + else gr.update(**(args or {})) + ) + + refresh_button = gr.Button( + value=refresh_symbol, elem_id=elem_id, elem_classes=["tool"] + ) + refresh_button.click(fn=refresh, inputs=[], outputs=refresh_components) + return refresh_button + + +def list_dirs(path): + if path is None or path == "None" or path == "": + return + + if not os.path.exists(path): + path = os.path.dirname(path) + if not os.path.exists(path): + return + + if not os.path.isdir(path): + path = os.path.dirname(path) + + def natural_sort_key(s, regex=re.compile("([0-9]+)")): + return [ + int(text) if text.isdigit() else text.lower() for text in regex.split(s) + ] + + subdirs = [ + (item, os.path.join(path, item)) + for item in os.listdir(path) + if os.path.isdir(os.path.join(path, item)) + 
] + subdirs = [ + filename + for item, filename in subdirs + if item[0] != "." and item not in ["__pycache__"] + ] + subdirs = sorted(subdirs, key=natural_sort_key) + if os.path.dirname(path) != "": + dirs = [os.path.dirname(path), path] + subdirs + else: + dirs = [path] + subdirs + + if os.sep == "\\": + dirs = [d.replace("\\", "/") for d in dirs] + for d in dirs: + yield d + + +def list_files(path, exts=None, all=False): + if path is None or path == "None" or path == "": + return + + if not os.path.exists(path): + path = os.path.dirname(path) + if not os.path.exists(path): + return + + if not os.path.isdir(path): + path = os.path.dirname(path) + + files = [ + (item, os.path.join(path, item)) + for item in os.listdir(path) + if all or os.path.isfile(os.path.join(path, item)) + ] + files = [ + filename + for item, filename in files + if item[0] != "." and item not in ["__pycache__"] + ] + exts = set(exts) if exts is not None else None + + def natural_sort_key(s, regex=re.compile("([0-9]+)")): + return [ + int(text) if text.isdigit() else text.lower() for text in regex.split(s) + ] + + files = sorted(files, key=natural_sort_key) + if os.path.dirname(path) != "": + files = [os.path.dirname(path), path] + files + else: + files = [path] + files + + if os.sep == "\\": + files = [d.replace("\\", "/") for d in files] + + for filename in files: + if exts is not None: + if os.path.isdir(filename): + yield filename + _, ext = os.path.splitext(filename) + if ext.lower() not in exts: + continue + yield filename + else: + yield filename + + def update_my_data(my_data): # Update the optimizer based on the use_8bit_adam flag use_8bit_adam = my_data.get("use_8bit_adam", False) @@ -424,7 +550,7 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name): f"Saving v2-inference-v.yaml as {output_dir}/{file_name}.yaml" ) shutil.copy( - fr"{scriptdir}/v2_inference/v2-inference-v.yaml", + rf"{scriptdir}/v2_inference/v2-inference-v.yaml", f"{output_dir}/{file_name}.yaml", ) elif v2: @@ -432,123 +558,85 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name): f"Saving v2-inference.yaml as {output_dir}/{file_name}.yaml" ) shutil.copy( - fr"{scriptdir}/v2_inference/v2-inference.yaml", + rf"{scriptdir}/v2_inference/v2-inference.yaml", f"{output_dir}/{file_name}.yaml", ) def set_pretrained_model_name_or_path_input( - model_list, - pretrained_model_name_or_path, - pretrained_model_name_or_path_file, - pretrained_model_name_or_path_folder, - v2, - v_parameterization, - sdxl, + pretrained_model_name_or_path, refresh_method=None ): - # Check if the given model_list is in the list of SDXL models - if str(model_list) in SDXL_MODELS: + # Check if the given pretrained_model_name_or_path is in the list of SDXL models + if pretrained_model_name_or_path in SDXL_MODELS: log.info("SDXL model selected. 
Setting sdxl parameters") - v2 = gr.Checkbox(value=False, visible=False) - v_parameterization = gr.Checkbox(value=False, visible=False) - sdxl = gr.Checkbox(value=True, visible=False) - pretrained_model_name_or_path = gr.Textbox( - value=str(model_list), visible=False - ) - pretrained_model_name_or_path_file = gr.Button(visible=False) - pretrained_model_name_or_path_folder = gr.Button(visible=False) + v2 = gr.Checkbox.update(value=False, visible=False) + v_parameterization = gr.Checkbox.update(value=False, visible=False) + sdxl = gr.Checkbox.update(value=True, visible=False) return ( - model_list, - pretrained_model_name_or_path, - pretrained_model_name_or_path_file, - pretrained_model_name_or_path_folder, + gr.Dropdown().update(), v2, v_parameterization, sdxl, ) - # Check if the given model_list is in the list of V2 base models - if str(model_list) in V2_BASE_MODELS: + # Check if the given pretrained_model_name_or_path is in the list of V2 base models + if pretrained_model_name_or_path in V2_BASE_MODELS: log.info("SD v2 base model selected. Setting --v2 parameter") - v2 = gr.Checkbox(value=True, visible=False) - v_parameterization = gr.Checkbox(value=False, visible=False) - sdxl = gr.Checkbox(value=False, visible=False) - pretrained_model_name_or_path = gr.Textbox( - value=str(model_list), visible=False - ) - pretrained_model_name_or_path_file = gr.Button(visible=False) - pretrained_model_name_or_path_folder = gr.Button(visible=False) + v2 = gr.Checkbox.update(value=True, visible=False) + v_parameterization = gr.Checkbox.update(value=False, visible=False) + sdxl = gr.Checkbox.update(value=False, visible=False) return ( - model_list, - pretrained_model_name_or_path, - pretrained_model_name_or_path_file, - pretrained_model_name_or_path_folder, + gr.Dropdown().update(), v2, v_parameterization, sdxl, ) - # Check if the given model_list is in the list of V parameterization models - if str(model_list) in V_PARAMETERIZATION_MODELS: + # Check if the given pretrained_model_name_or_path is in the list of V parameterization models + if pretrained_model_name_or_path in V_PARAMETERIZATION_MODELS: log.info( "SD v2 model selected. 
Setting --v2 and --v_parameterization parameters" ) - v2 = gr.Checkbox(value=True, visible=False) - v_parameterization = gr.Checkbox(value=True, visible=False) - sdxl = gr.Checkbox(value=False, visible=False) - pretrained_model_name_or_path = gr.Textbox( - value=str(model_list), visible=False - ) - pretrained_model_name_or_path_file = gr.Button(visible=False) - pretrained_model_name_or_path_folder = gr.Button(visible=False) + v2 = gr.Checkbox.update(value=True, visible=False) + v_parameterization = gr.Checkbox.update(value=True, visible=False) + sdxl = gr.Checkbox.update(value=False, visible=False) return ( - model_list, - pretrained_model_name_or_path, - pretrained_model_name_or_path_file, - pretrained_model_name_or_path_folder, + gr.Dropdown().update(), v2, v_parameterization, sdxl, ) - # Check if the given model_list is in the list of V1 models - if str(model_list) in V1_MODELS: - log.info(f"{model_list} model selected.") - v2 = gr.Checkbox(value=False, visible=False) - v_parameterization = gr.Checkbox(value=False, visible=False) - sdxl = gr.Checkbox(value=False, visible=False) - pretrained_model_name_or_path = gr.Textbox( - value=str(model_list), visible=False - ) - pretrained_model_name_or_path_file = gr.Button(visible=False) - pretrained_model_name_or_path_folder = gr.Button(visible=False) + # Check if the given pretrained_model_name_or_path is in the list of V1 models + if pretrained_model_name_or_path in V1_MODELS: + log.info(f"{pretrained_model_name_or_path} model selected.") + v2 = gr.Checkbox.update(value=False, visible=False) + v_parameterization = gr.Checkbox.update(value=False, visible=False) + sdxl = gr.Checkbox.update(value=False, visible=False) return ( - model_list, - pretrained_model_name_or_path, - pretrained_model_name_or_path_file, - pretrained_model_name_or_path_folder, + gr.Dropdown().update(), v2, v_parameterization, sdxl, ) # Check if the model_list is set to 'custom' - if model_list == "custom": - v2 = gr.Checkbox(visible=True) - v_parameterization = gr.Checkbox(visible=True) - sdxl = gr.Checkbox(visible=True) - pretrained_model_name_or_path = gr.Textbox(visible=True) - pretrained_model_name_or_path_file = gr.Button(visible=True) - pretrained_model_name_or_path_folder = gr.Button(visible=True) - return ( - model_list, - pretrained_model_name_or_path, - pretrained_model_name_or_path_file, - pretrained_model_name_or_path_folder, - v2, - v_parameterization, - sdxl, + v2 = gr.Checkbox.update(visible=True) + v_parameterization = gr.Checkbox.update(visible=True) + sdxl = gr.Checkbox.update(visible=True) + + if refresh_method is not None: + args = dict( + choices=refresh_method(pretrained_model_name_or_path), ) + else: + args = {} + return ( + gr.Dropdown().update(**args), + v2, + v_parameterization, + sdxl, + ) ### @@ -556,157 +644,84 @@ def set_pretrained_model_name_or_path_input( ### -def get_pretrained_model_name_or_path_file(model_list, pretrained_model_name_or_path): - pretrained_model_name_or_path = get_any_file_path(pretrained_model_name_or_path) - # set_model_list(model_list, pretrained_model_name_or_path) +# def get_pretrained_model_name_or_path_file(model_list, pretrained_model_name_or_path): +# pretrained_model_name_or_path = get_any_file_path(pretrained_model_name_or_path) +# # set_model_list(model_list, pretrained_model_name_or_path) def get_int_or_default(kwargs, key, default_value=0): value = kwargs.get(key, default_value) if isinstance(value, int): return value - elif isinstance(value, str): - return int(value) - elif isinstance(value, float): - return 
int(value) else: - log.info( - f"{key} is not an int, float or a string, setting value to {default_value}" - ) - return default_value + try: + return int(value) + except ValueError: + log.info( + f"{key} is not an int, float or a string, setting value to {default_value}" + ) + return default_value def get_float_or_default(kwargs, key, default_value=0.0): + # Try to retrieve the value for the specified key from the kwargs. + # Use the provided default_value if the key does not exist. value = kwargs.get(key, default_value) - if isinstance(value, float): - return value - elif isinstance(value, int): - return float(value) - elif isinstance(value, str): + + try: + # Try to convert the value to a float. This should work for int, float, + # and strings that represent a valid floating-point number. return float(value) - else: + except ValueError: + # If the conversion fails (for example, the value is a string that cannot + # be converted to a float), log the issue and return the provided default_value. log.info( - f"{key} is not an int, float or a string, setting value to {default_value}" + f"{key} is not an int, float or a valid string for conversion, setting value to {default_value}" ) return default_value def get_str_or_default(kwargs, key, default_value=""): value = kwargs.get(key, default_value) + + # Check if the retrieved value is already a string. if isinstance(value, str): return value - elif isinstance(value, int): - return str(value) - elif isinstance(value, str): - return str(value) else: - return default_value - - -# def run_cmd_training(**kwargs): -# run_cmd = "" - -# lr_scheduler = kwargs.get("lr_scheduler", "") -# if lr_scheduler: -# run_cmd += f' --lr_scheduler="{lr_scheduler}"' - -# lr_warmup_steps = kwargs.get("lr_warmup_steps", "") -# if lr_warmup_steps: -# if lr_scheduler == "constant": -# log.info("Can't use LR warmup with LR Scheduler constant... 
ignoring...") -# else: -# run_cmd += f' --lr_warmup_steps="{lr_warmup_steps}"' - -# train_batch_size = kwargs.get("train_batch_size", "") -# if train_batch_size: -# run_cmd += f' --train_batch_size="{train_batch_size}"' - -# max_train_steps = kwargs.get("max_train_steps", "") -# if max_train_steps: -# run_cmd += f' --max_train_steps="{max_train_steps}"' - -# save_every_n_epochs = kwargs.get("save_every_n_epochs") -# if save_every_n_epochs: -# run_cmd += f' --save_every_n_epochs="{int(save_every_n_epochs)}"' - -# mixed_precision = kwargs.get("mixed_precision", "") -# if mixed_precision: -# run_cmd += f' --mixed_precision="{mixed_precision}"' - -# save_precision = kwargs.get("save_precision", "") -# if save_precision: -# run_cmd += f' --save_precision="{save_precision}"' - -# seed = kwargs.get("seed", "") -# if seed != "": -# run_cmd += f' --seed="{seed}"' - -# caption_extension = kwargs.get("caption_extension", "") -# if caption_extension: -# run_cmd += f' --caption_extension="{caption_extension}"' - -# cache_latents = kwargs.get("cache_latents") -# if cache_latents: -# run_cmd += " --cache_latents" - -# cache_latents_to_disk = kwargs.get("cache_latents_to_disk") -# if cache_latents_to_disk: -# run_cmd += " --cache_latents_to_disk" - -# optimizer_type = kwargs.get("optimizer", "AdamW") -# run_cmd += f' --optimizer_type="{optimizer_type}"' - -# optimizer_args = kwargs.get("optimizer_args", "") -# if optimizer_args != "": -# run_cmd += f" --optimizer_args {optimizer_args}" - -# lr_scheduler_args = kwargs.get("lr_scheduler_args", "") -# if lr_scheduler_args != "": -# run_cmd += f" --lr_scheduler_args {lr_scheduler_args}" - -# max_grad_norm = kwargs.get("max_grad_norm", "") -# if max_grad_norm != "": -# run_cmd += f' --max_grad_norm="{max_grad_norm}"' - -# return run_cmd + # If the value is not a string (e.g., int, float, or any other type), + # convert it to a string and return the converted value. 
+ return str(value) def run_cmd_advanced_training(**kwargs): run_cmd = "" - additional_parameters = kwargs.get("additional_parameters") - if additional_parameters: - run_cmd += f" {additional_parameters}" + if "additional_parameters" in kwargs: + run_cmd += f' {kwargs["additional_parameters"]}' - block_lr = kwargs.get("block_lr") - if block_lr: - run_cmd += f' --block_lr="{block_lr}"' + if "block_lr" in kwargs: + run_cmd += f' --block_lr="{kwargs["block_lr"]}"' - bucket_no_upscale = kwargs.get("bucket_no_upscale") - if bucket_no_upscale: + if kwargs.get("bucket_no_upscale"): run_cmd += " --bucket_no_upscale" - bucket_reso_steps = kwargs.get("bucket_reso_steps") - if bucket_reso_steps: - run_cmd += f" --bucket_reso_steps={int(bucket_reso_steps)}" + if "bucket_reso_steps" in kwargs: + run_cmd += f' --bucket_reso_steps={int(kwargs["bucket_reso_steps"])}' - cache_latents = kwargs.get("cache_latents") - if cache_latents: + if kwargs.get("cache_latents"): run_cmd += " --cache_latents" - cache_latents_to_disk = kwargs.get("cache_latents_to_disk") - if cache_latents_to_disk: + if kwargs.get("cache_latents_to_disk"): run_cmd += " --cache_latents_to_disk" - cache_text_encoder_outputs = kwargs.get("cache_text_encoder_outputs") - if cache_text_encoder_outputs: + if kwargs.get("cache_text_encoder_outputs"): run_cmd += " --cache_text_encoder_outputs" - caption_dropout_every_n_epochs = kwargs.get("caption_dropout_every_n_epochs") - if caption_dropout_every_n_epochs and int(caption_dropout_every_n_epochs) > 0: - run_cmd += ( - f' --caption_dropout_every_n_epochs="{int(caption_dropout_every_n_epochs)}"' - ) + if ( + "caption_dropout_every_n_epochs" in kwargs + and int(kwargs["caption_dropout_every_n_epochs"]) > 0 + ): + run_cmd += f' --caption_dropout_every_n_epochs="{int(kwargs["caption_dropout_every_n_epochs"])}"' caption_dropout_rate = kwargs.get("caption_dropout_rate") if caption_dropout_rate and float(caption_dropout_rate) > 0: @@ -738,12 +753,14 @@ def run_cmd_advanced_training(**kwargs): ): # Only if lora_network_weights is true run_cmd += f" --dim_from_weights" - enable_bucket = kwargs.get("enable_bucket") - if enable_bucket: - min_bucket_reso = kwargs.get("min_bucket_reso") - max_bucket_reso = kwargs.get("max_bucket_reso") - if min_bucket_reso and max_bucket_reso: - run_cmd += f" --enable_bucket --min_bucket_reso={min_bucket_reso} --max_bucket_reso={max_bucket_reso}" + # Check if enable_bucket is true and both min_bucket_reso and max_bucket_reso are provided as part of the kwargs + if ( + kwargs.get("enable_bucket") + and "min_bucket_reso" in kwargs + and "max_bucket_reso" in kwargs + ): + # Append the enable_bucket flag and min/max bucket resolution values to the run_cmd string + run_cmd += f' --enable_bucket --min_bucket_reso={kwargs["min_bucket_reso"]} --max_bucket_reso={kwargs["max_bucket_reso"]}' in_json = kwargs.get("in_json") if in_json: @@ -765,44 +782,49 @@ def run_cmd_advanced_training(**kwargs): if full_fp16: run_cmd += " --full_fp16" - gradient_accumulation_steps = kwargs.get("gradient_accumulation_steps") - if gradient_accumulation_steps and int(gradient_accumulation_steps) > 1: - run_cmd += f" --gradient_accumulation_steps={int(gradient_accumulation_steps)}" + if ( + "gradient_accumulation_steps" in kwargs + and int(kwargs["gradient_accumulation_steps"]) > 1 + ): + run_cmd += f" --gradient_accumulation_steps={int(kwargs['gradient_accumulation_steps'])}" - gradient_checkpointing = kwargs.get("gradient_checkpointing") - if gradient_checkpointing: + if 
kwargs.get("gradient_checkpointing"): run_cmd += " --gradient_checkpointing" - keep_tokens = kwargs.get("keep_tokens") - if keep_tokens and int(keep_tokens) > 0: - run_cmd += f' --keep_tokens="{int(keep_tokens)}"' + if "keep_tokens" in kwargs and int(kwargs["keep_tokens"]) > 0: + run_cmd += f' --keep_tokens="{int(kwargs["keep_tokens"])}"' - learning_rate = kwargs.get("learning_rate") - if learning_rate: - run_cmd += f' --learning_rate="{learning_rate}"' + if "learning_rate" in kwargs: + run_cmd += f' --learning_rate="{kwargs["learning_rate"]}"' - learning_rate_te = kwargs.get("learning_rate_te") - if learning_rate_te: - run_cmd += f' --learning_rate_te="{learning_rate_te}"' + if "learning_rate_te" in kwargs: + if kwargs["learning_rate_te"] == 0: + run_cmd += f' --learning_rate_te="0"' + else: + run_cmd += f' --learning_rate_te="{kwargs["learning_rate_te"]}"' - learning_rate_te1 = kwargs.get("learning_rate_te1") - if learning_rate_te1: - run_cmd += f' --learning_rate_te1="{learning_rate_te1}"' + if "learning_rate_te1" in kwargs: + if kwargs["learning_rate_te1"] == 0: + run_cmd += f' --learning_rate_te1="0"' + else: + run_cmd += f' --learning_rate_te1="{kwargs["learning_rate_te1"]}"' - learning_rate_te2 = kwargs.get("learning_rate_te2") - if learning_rate_te2: - run_cmd += f' --learning_rate_te2="{learning_rate_te2}"' + if "learning_rate_te2" in kwargs: + if kwargs["learning_rate_te2"] == 0: + run_cmd += f' --learning_rate_te2="0"' + else: + run_cmd += f' --learning_rate_te2="{kwargs["learning_rate_te2"]}"' logging_dir = kwargs.get("logging_dir") if logging_dir: if logging_dir.startswith('"') and logging_dir.endswith('"'): logging_dir = logging_dir[1:-1] if os.path.exists(logging_dir): - run_cmd += fr' --logging_dir="{logging_dir}"' + run_cmd += rf' --logging_dir="{logging_dir}"' lora_network_weights = kwargs.get("lora_network_weights") if lora_network_weights: - run_cmd += f' --network_weights="{lora_network_weights}"' # Yes, the parameter is now called network_weights instead of lora_network_weights + run_cmd += f' --network_weights="{lora_network_weights}"' # Yes, the parameter is now called network_weights instead of lora_network_weights lr_scheduler = kwargs.get("lr_scheduler") if lr_scheduler: @@ -919,29 +941,26 @@ def run_cmd_advanced_training(**kwargs): if no_token_padding: run_cmd += " --no_token_padding" - noise_offset_type = kwargs.get("noise_offset_type") - if noise_offset_type and noise_offset_type == "Original": - noise_offset = kwargs.get("noise_offset") - if noise_offset and float(noise_offset) > 0: - run_cmd += f" --noise_offset={float(noise_offset)}" + if "noise_offset_type" in kwargs: + noise_offset_type = kwargs["noise_offset_type"] - adaptive_noise_scale = kwargs.get("adaptive_noise_scale") - if ( - adaptive_noise_scale - and float(adaptive_noise_scale) != 0 - and float(noise_offset) > 0 - ): - run_cmd += f" --adaptive_noise_scale={float(adaptive_noise_scale)}" - elif noise_offset_type and noise_offset_type == "Multires": - multires_noise_iterations = kwargs.get("multires_noise_iterations") - if int(multires_noise_iterations) > 0: - run_cmd += ( - f' --multires_noise_iterations="{int(multires_noise_iterations)}"' - ) + if kwargs["noise_offset_type"] == "Original": + noise_offset = float(kwargs.get("noise_offset", 0)) + if noise_offset: + run_cmd += f" --noise_offset={noise_offset}" + + adaptive_noise_scale = float(kwargs.get("adaptive_noise_scale", 0)) + if adaptive_noise_scale != 0 and noise_offset > 0: + run_cmd += f" --adaptive_noise_scale={adaptive_noise_scale}" + + 
elif noise_offset_type == "Multires": - multires_noise_iterations = kwargs.get("multires_noise_iterations") - if int(multires_noise_iterations) > 0: - run_cmd += ( - f' --multires_noise_iterations="{int(multires_noise_iterations)}"' - ) + elif noise_offset_type == "Multires": + multires_noise_iterations = int(kwargs.get("multires_noise_iterations", 0)) + if multires_noise_iterations > 0: + run_cmd += f' --multires_noise_iterations="{multires_noise_iterations}"' - multires_noise_discount = kwargs.get("multires_noise_discount") - if multires_noise_discount and float(multires_noise_discount) > 0: - run_cmd += f' --multires_noise_discount="{float(multires_noise_discount)}"' + multires_noise_discount = float(kwargs.get("multires_noise_discount", 0)) + if multires_noise_discount > 0: + run_cmd += f' --multires_noise_discount="{multires_noise_discount}"' num_machines = kwargs.get("num_machines") if num_machines and int(num_machines) > 1: @@ -968,7 +987,7 @@ def run_cmd_advanced_training(**kwargs): if output_dir.startswith('"') and output_dir.endswith('"'): output_dir = output_dir[1:-1] if os.path.exists(output_dir): - run_cmd += fr' --output_dir="{output_dir}"' + run_cmd += rf' --output_dir="{output_dir}"' output_name = kwargs.get("output_name") if output_name and not output_name == "": @@ -980,7 +999,9 @@ def run_cmd_advanced_training(**kwargs): pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path") if pretrained_model_name_or_path: - run_cmd += f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"' + run_cmd += ( + rf' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"' + ) prior_loss_weight = kwargs.get("prior_loss_weight") if prior_loss_weight and not float(prior_loss_weight) == 1.0: @@ -995,7 +1016,7 @@ def run_cmd_advanced_training(**kwargs): if reg_data_dir.startswith('"') and reg_data_dir.endswith('"'): reg_data_dir = reg_data_dir[1:-1] if os.path.isdir(reg_data_dir): - run_cmd += fr' --reg_data_dir="{reg_data_dir}"' + run_cmd += rf' --reg_data_dir="{reg_data_dir}"' resume = kwargs.get("resume") if resume: @@ -1066,7 +1087,7 @@ def run_cmd_advanced_training(**kwargs): if train_data_dir.startswith('"') and train_data_dir.endswith('"'): train_data_dir = train_data_dir[1:-1] if os.path.exists(train_data_dir): - run_cmd += fr' --train_data_dir="{train_data_dir}"' + run_cmd += rf' --train_data_dir="{train_data_dir}"' train_text_encoder = kwargs.get("train_text_encoder") if train_text_encoder: @@ -1094,7 +1115,10 @@ def run_cmd_advanced_training(**kwargs): vae = kwargs.get("vae") if vae and not vae == "": - run_cmd += f' --vae="{vae}"' + if not os.path.exists(vae): + vae = os.path.join("models", "VAE", vae).replace(os.sep, "/") + if os.path.exists(vae): + run_cmd += f' --vae="{vae}"' vae_batch_size = kwargs.get("vae_batch_size") if vae_batch_size and int(vae_batch_size) > 0: diff --git a/kohya_gui/convert_lcm_gui.py b/kohya_gui/convert_lcm_gui.py index c48f17358..768ec1892 100644 --- a/kohya_gui/convert_lcm_gui.py +++ b/kohya_gui/convert_lcm_gui.py @@ -6,6 +6,8 @@ get_saveasfilename_path, get_file_path, scriptdir, + list_files, + create_refresh_button, ) from .custom_logging import setup_logging @@ -27,11 +29,29 @@ def convert_lcm( model_type ): run_cmd = fr'{PYTHON} "{scriptdir}/tools/lcm_convert.py"' + + # Check if source model exists + if not os.path.isfile(model_path): + msgbox('The provided model to convert is not a file') + return + + if os.path.dirname(name) == "": + # only filename given. prepend dir + name = os.path.join(os.path.dirname(model_path), name) + if os.path.isdir(name): + # only dir name given. set default lcm name + name = os.path.join(name, "lcm.safetensors") + if os.path.normpath(model_path) == os.path.normpath(name): + # same path. 
silently ignore but rename output + path, ext = os.path.splitext(name) + name = f"{path}_lcm{ext}" + + # Construct the command to run the script - run_cmd += f' --name "{name}"' - run_cmd += f' --model "{model_path}"' run_cmd += f" --lora-scale {lora_scale}" - + run_cmd += f' --model "{model_path}"' + run_cmd += f' --name "{name}"' + if model_type == "SDXL": run_cmd += f" --sdxl" if model_type == "SSD-1B": @@ -40,7 +60,7 @@ def convert_lcm( log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{env.get('PYTHONPATH', '')}" + env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" # Run the command subprocess.run(run_cmd, shell=True, env=env) @@ -50,6 +70,17 @@ def convert_lcm( def gradio_convert_lcm_tab(headless=False): + current_model_dir = os.path.join(scriptdir, "outputs") + current_save_dir = os.path.join(scriptdir, "outputs") + + def list_models(path): + current_model_dir = path + return list(list_files(path, exts=[".safetensors"], all=True)) + + def list_save_to(path): + current_save_dir = path + return list(list_files(path, exts=[".safetensors"], all=True)) + with gr.Tab("Convert to LCM"): gr.Markdown("This utility convert a model to an LCM model.") lora_ext = gr.Textbox(value="*.safetensors", visible=False) lora_ext_name = gr.Textbox(value="LoRA model types", visible=False) model_ext = gr.Textbox(value="*.safetensors", visible=False) model_ext_name = gr.Textbox(value="Model types", visible=False) - with gr.Row(): - model_path = gr.Textbox( + with gr.Group(), gr.Row(): + model_path = gr.Dropdown( label="Stable Diffusion model to convert to LCM", interactive=True, + choices=list_models(current_model_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(model_path, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small") button_model_path_file = gr.Button( folder_symbol, elem_id="open_folder_small", + elem_classes=['tool'], visible=(not headless), ) button_model_path_file.click( @@ -74,14 +110,18 @@ def gradio_convert_lcm_tab(headless=False): show_progress=False, ) - name = gr.Textbox( + name = gr.Dropdown( label="Name of the new LCM model", - placeholder="Path to the LCM file to create", interactive=True, + choices=list_save_to(current_save_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(name, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small") button_name = gr.Button( folder_symbol, elem_id="open_folder_small", + elem_classes=['tool'], visible=(not headless), ) button_name.click( @@ -90,6 +130,18 @@ def gradio_convert_lcm_tab(headless=False): outputs=name, show_progress=False, ) + model_path.change( + fn=lambda path: gr.Dropdown().update(choices=list_models(path)), + inputs=model_path, + outputs=model_path, + show_progress=False, + ) + name.change( + fn=lambda path: gr.Dropdown().update(choices=list_save_to(path)), + inputs=name, + outputs=name, + show_progress=False, + ) with gr.Row(): lora_scale = gr.Slider( @@ -102,7 +154,7 @@ def gradio_convert_lcm_tab(headless=False): ) # with gr.Row(): # no_half = gr.Checkbox(label="Convert the new LCM model to FP32", value=False) - model_type = gr.Dropdown( + model_type = gr.Radio( label="Model type", choices=["SD15", "SDXL", "SD-1B"], value="SD15" ) diff --git a/kohya_gui/convert_model_gui.py b/kohya_gui/convert_model_gui.py index f4a6a0397..3d689c64b 100644 --- a/kohya_gui/convert_model_gui.py +++ b/kohya_gui/convert_model_gui.py @@ -4,7 +4,7 @@ 
import os import shutil import sys -from .common_gui import get_folder_path, get_file_path, scriptdir +from .common_gui import get_folder_path, get_file_path, scriptdir, list_files, list_dirs from .custom_logging import setup_logging @@ -49,7 +49,7 @@ def convert_model( msgbox('The provided target folder does not exist') return - run_cmd = fr'{PYTHON} "{scriptdir}tools/convert_diffusers20_original_sd.py"' + run_cmd = fr'{PYTHON} "{scriptdir}/sd-scripts/tools/convert_diffusers20_original_sd.py"' v1_models = [ 'runwayml/stable-diffusion-v1-5', @@ -104,7 +104,7 @@ def convert_model( log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{env.get('PYTHONPATH', '')}" + env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" # Run the command subprocess.run(run_cmd, shell=True, env=env) @@ -172,6 +172,21 @@ def convert_model( def gradio_convert_model_tab(headless=False): + from .common_gui import create_refresh_button + + default_source_model = os.path.join(scriptdir, "outputs") + default_target_folder = os.path.join(scriptdir, "outputs") + current_source_model = default_source_model + current_target_folder = default_target_folder + + def list_source_model(path): + current_source_model = path + return list(list_files(path, exts=[".ckpt", ".safetensors"], all=True)) + + def list_target_folder(path): + current_target_folder = path + return list(list_dirs(path)) + with gr.Tab('Convert model'): gr.Markdown( 'This utility can be used to convert from one stable diffusion model format to another.' @@ -180,15 +195,20 @@ def gradio_convert_model_tab(headless=False): model_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False) model_ext_name = gr.Textbox(value='Model types', visible=False) - with gr.Row(): - source_model_input = gr.Textbox( - label='Source model', - placeholder='path to source model folder of file to convert...', + with gr.Group(), gr.Row(): + with gr.Column(), gr.Row(): + source_model_input = gr.Dropdown( + label='Source model (path to source model folder of file to convert...)', interactive=True, + choices=list_source_model(default_source_model), + value="", + allow_custom_value=True, ) + create_refresh_button(source_model_input, lambda: None, lambda: {"choices": list_source_model(current_source_model)}, "open_folder_small") button_source_model_dir = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not headless), ) button_source_model_dir.click( @@ -200,6 +220,7 @@ def gradio_convert_model_tab(headless=False): button_source_model_file = gr.Button( document_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not headless), ) button_source_model_file.click( @@ -209,6 +230,13 @@ def gradio_convert_model_tab(headless=False): show_progress=False, ) + source_model_input.change( + fn=lambda path: gr.Dropdown().update(choices=list_source_model(path)), + inputs=source_model_input, + outputs=source_model_input, + show_progress=False, + ) + with gr.Column(), gr.Row(): source_model_type = gr.Dropdown( label='Source model type', choices=[ @@ -220,15 +248,20 @@ def gradio_convert_model_tab(headless=False): 'CompVis/stable-diffusion-v1-4', ], ) - with gr.Row(): - target_model_folder_input = gr.Textbox( - label='Target model folder', - placeholder='path to target model folder of file name to create...', + with gr.Group(), gr.Row(): + with gr.Column(), gr.Row(): + target_model_folder_input = gr.Dropdown( + label='Target model folder (path to target 
model folder of file name to create...)', interactive=True, + choices=list_target_folder(default_target_folder), + value="", + allow_custom_value=True, ) + create_refresh_button(target_model_folder_input, lambda: None, lambda: {"choices": list_target_folder(current_target_folder)},"open_folder_small") button_target_model_folder = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not headless), ) button_target_model_folder.click( @@ -237,11 +270,20 @@ def gradio_convert_model_tab(headless=False): show_progress=False, ) + target_model_folder_input.change( + fn=lambda path: gr.Dropdown().update(choices=list_target_folder(path)), + inputs=target_model_folder_input, + outputs=target_model_folder_input, + show_progress=False, + ) + + with gr.Column(), gr.Row(): target_model_name_input = gr.Textbox( label='Target model name', placeholder='target model name...', interactive=True, ) + with gr.Row(): target_model_type = gr.Dropdown( label='Target model type', choices=[ diff --git a/kohya_gui/dataset_balancing_gui.py b/kohya_gui/dataset_balancing_gui.py index f2d1c0533..a20e92cbe 100644 --- a/kohya_gui/dataset_balancing_gui.py +++ b/kohya_gui/dataset_balancing_gui.py @@ -2,7 +2,7 @@ import re import gradio as gr from easygui import msgbox, boolbox -from .common_gui import get_folder_path +from .common_gui import get_folder_path, scriptdir, list_dirs, create_refresh_button from .custom_logging import setup_logging @@ -110,6 +110,9 @@ def warning(insecure): def gradio_dataset_balancing_tab(headless=False): + + current_dataset_dir = os.path.join(scriptdir, "data") + with gr.Tab('Dreambooth/LoRA Dataset balancing'): gr.Markdown( 'This utility will ensure that each concept folder in the dataset folder is used equally during the training process of the dreambooth machine learning model, regardless of the number of images in each folder. It will do this by renaming the concept folders to indicate the number of times they should be repeated during training.' @@ -117,15 +120,22 @@ def gradio_dataset_balancing_tab(headless=False): gr.Markdown( 'WARNING! The use of this utility on the wrong folder can lead to unexpected folder renaming!!!' 
) - with gr.Row(): - select_dataset_folder_input = gr.Textbox( - label='Dataset folder', - placeholder='Folder containing the concepts folders to balance...', + with gr.Group(), gr.Row(): + + def list_dataset_dirs(path): + current_dataset_dir = path + return list(list_dirs(path)) + + select_dataset_folder_input = gr.Dropdown( + label='Dataset folder (folder containing the concepts folders to balance...)', interactive=True, + choices=list_dataset_dirs(current_dataset_dir), + value="", + allow_custom_value=True, ) - + create_refresh_button(select_dataset_folder_input, lambda: None, lambda: {"choices": list_dataset_dirs(current_dataset_dir)}, "open_folder_small") select_dataset_folder_button = gr.Button( - '📂', elem_id='open_folder_small', visible=(not headless) + '📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless) ) select_dataset_folder_button.click( get_folder_path, @@ -138,6 +148,13 @@ def gradio_dataset_balancing_tab(headless=False): interactive=True, label='Training steps per concept per epoch', ) + select_dataset_folder_input.change( + fn=lambda path: gr.Dropdown().update(choices=list_dataset_dirs(path)), + inputs=select_dataset_folder_input, + outputs=select_dataset_folder_input, + show_progress=False, + ) + with gr.Accordion('Advanced options', open=False): insecure = gr.Checkbox( value=False, diff --git a/kohya_gui/dreambooth_folder_creation_gui.py b/kohya_gui/dreambooth_folder_creation_gui.py index 66a243d3b..2a5b35731 100644 --- a/kohya_gui/dreambooth_folder_creation_gui.py +++ b/kohya_gui/dreambooth_folder_creation_gui.py @@ -1,6 +1,6 @@ import gradio as gr from easygui import diropenbox, msgbox -from .common_gui import get_folder_path +from .common_gui import get_folder_path, scriptdir, list_dirs, create_refresh_button import shutil import os @@ -119,12 +119,17 @@ def dreambooth_folder_preparation( def gradio_dreambooth_folder_creation_tab( - train_data_dir_input=gr.Textbox(), - reg_data_dir_input=gr.Textbox(), - output_dir_input=gr.Textbox(), - logging_dir_input=gr.Textbox(), + train_data_dir_input=gr.Dropdown(), + reg_data_dir_input=gr.Dropdown(), + output_dir_input=gr.Dropdown(), + logging_dir_input=gr.Dropdown(), headless=False, ): + + current_train_data_dir = os.path.join(scriptdir, "data") + current_reg_data_dir = os.path.join(scriptdir, "data") + current_train_output_dir = os.path.join(scriptdir, "data") + with gr.Tab('Dreambooth/LoRA Folder preparation'): gr.Markdown( 'This utility will create the necessary folder structure for the training images and optional regularization images needed for the kohys_ss Dreambooth/LoRA method to function correctly.' 
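The folder-creation hunks that follow repeat the same dropdown-plus-refresh pattern used above for the dataset balancing tab. As a minimal standalone sketch of that pattern (illustrative only; it assumes Gradio 3.x and the list_dirs()/create_refresh_button() helpers this patch adds to kohya_gui/common_gui.py, and the folder_dropdown() wrapper itself is hypothetical, not part of the patch):

import os
import gradio as gr
from kohya_gui.common_gui import scriptdir, list_dirs, create_refresh_button

def folder_dropdown(label, default_dir=None):
    # Sketch of the pattern these hunks wire up by hand: a Dropdown seeded with
    # list_dirs(), a refresh button, and a .change() handler that re-lists choices.
    default_dir = default_dir or os.path.join(scriptdir, "data")
    state = {"dir": default_dir}  # remember the last browsed path for the refresh button

    def list_current(path):
        state["dir"] = path
        return list(list_dirs(path))

    dropdown = gr.Dropdown(
        label=label,
        choices=list_current(default_dir),
        value="",
        interactive=True,
        allow_custom_value=True,
    )
    create_refresh_button(
        dropdown,
        lambda: None,
        lambda: {"choices": list_current(state["dir"])},
        "open_folder_small",
    )
    # Re-populate the choices whenever the user types or picks a new path
    dropdown.change(
        fn=lambda path: gr.Dropdown().update(choices=list_current(path)),
        inputs=dropdown,
        outputs=dropdown,
        show_progress=False,
    )
    return dropdown

Called inside a gr.Blocks() layout, this mirrors what each hunk sets up individually for its training, regularisation, output and logging folder fields.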
@@ -140,14 +145,22 @@ def gradio_dreambooth_folder_creation_tab( placeholder='Eg: person', interactive=True, ) - with gr.Row(): - util_training_images_dir_input = gr.Textbox( - label='Training images', - placeholder='Directory containing the training images', + with gr.Group(), gr.Row(): + + def list_train_data_dirs(path): + current_train_data_dir = path + return list(list_dirs(path)) + + util_training_images_dir_input = gr.Dropdown( + label='Training images (directory containing the training images)', interactive=True, + choices=list_train_data_dirs(current_train_data_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(util_training_images_dir_input, lambda: None, lambda: {"choices": list_train_data_dirs(current_train_data_dir)}, "open_folder_small") button_util_training_images_dir_input = gr.Button( - '📂', elem_id='open_folder_small', visible=(not headless) + '📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless) ) button_util_training_images_dir_input.click( get_folder_path, @@ -160,14 +173,28 @@ def gradio_dreambooth_folder_creation_tab( interactive=True, elem_id='number_input', ) - with gr.Row(): - util_regularization_images_dir_input = gr.Textbox( - label='Regularisation images', - placeholder='(Optional) Directory containing the regularisation images', + util_training_images_dir_input.change( + fn=lambda path: gr.Dropdown().update(choices=list_train_data_dirs(path)), + inputs=util_training_images_dir_input, + outputs=util_training_images_dir_input, + show_progress=False, + ) + + with gr.Group(), gr.Row(): + def list_reg_data_dirs(path): + current_reg_data_dir = path + return list(list_dirs(path)) + + util_regularization_images_dir_input = gr.Dropdown( + label='Regularisation images (Optional. directory containing the regularisation images)', interactive=True, + choices=list_reg_data_dirs(current_reg_data_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(util_regularization_images_dir_input, lambda: None, lambda: {"choices": list_reg_data_dirs(current_reg_data_dir)}, "open_folder_small") button_util_regularization_images_dir_input = gr.Button( - '📂', elem_id='open_folder_small', visible=(not headless) + '📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless) ) button_util_regularization_images_dir_input.click( get_folder_path, @@ -180,18 +207,37 @@ def gradio_dreambooth_folder_creation_tab( interactive=True, elem_id='number_input', ) - with gr.Row(): - util_training_dir_output = gr.Textbox( - label='Destination training directory', - placeholder='Directory where formatted training and regularisation folders will be placed', + util_regularization_images_dir_input.change( + fn=lambda path: gr.Dropdown().update(choices=list_reg_data_dirs(path)), + inputs=util_regularization_images_dir_input, + outputs=util_regularization_images_dir_input, + show_progress=False, + ) + with gr.Group(), gr.Row(): + def list_train_output_dirs(path): + current_train_output_dir = path + return list(list_dirs(path)) + + util_training_dir_output = gr.Dropdown( + label='Destination training directory (where formatted training and regularisation folders will be placed)', interactive=True, + choices=list_train_output_dirs(current_train_output_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(util_training_dir_output, lambda: None, lambda: {"choices": list_train_output_dirs(current_train_output_dir)}, "open_folder_small") button_util_training_dir_output = gr.Button( - '📂', elem_id='open_folder_small', 
visible=(not headless) + '📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless) ) button_util_training_dir_output.click( get_folder_path, outputs=util_training_dir_output ) + util_training_dir_output.change( + fn=lambda path: gr.Dropdown().update(choices=list_train_output_dirs(path)), + inputs=util_training_dir_output, + outputs=util_training_dir_output, + show_progress=False, + ) button_prepare_training_data = gr.Button('Prepare training data') button_prepare_training_data.click( dreambooth_folder_preparation, diff --git a/kohya_gui/dreambooth_gui.py b/kohya_gui/dreambooth_gui.py index b002ee199..14d74f700 100644 --- a/kohya_gui/dreambooth_gui.py +++ b/kohya_gui/dreambooth_gui.py @@ -41,10 +41,9 @@ gradio_dreambooth_folder_creation_tab, ) from .dataset_balancing_gui import gradio_dataset_balancing_tab -from .utilities import utilities_tab from .class_sample_images import SampleImages, run_cmd_sample -from .custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = setup_logging() @@ -52,6 +51,7 @@ # Setup command executor executor = CommandExecutor() +PYTHON = sys.executable def save_configuration( save_as, @@ -561,91 +561,98 @@ def train_model( ) if sdxl: - run_cmd += fr' "{scriptdir}/sdxl_train.py"' + run_cmd += fr' "{scriptdir}/sd-scripts/sdxl_train.py"' else: - run_cmd += fr' "{scriptdir}/train_db.py"' + run_cmd += fr' "{scriptdir}/sd-scripts/train_db.py"' + + # Initialize a dictionary with always-included keyword arguments + kwargs_for_training = { + "adaptive_noise_scale": adaptive_noise_scale, + "additional_parameters": additional_parameters, + "bucket_no_upscale": bucket_no_upscale, + "bucket_reso_steps": bucket_reso_steps, + "cache_latents": cache_latents, + "cache_latents_to_disk": cache_latents_to_disk, + "caption_dropout_every_n_epochs": caption_dropout_every_n_epochs, + "caption_dropout_rate": caption_dropout_rate, + "caption_extension": caption_extension, + "clip_skip": clip_skip, + "color_aug": color_aug, + "enable_bucket": enable_bucket, + "epoch": epoch, + "flip_aug": flip_aug, + "full_bf16": full_bf16, + "full_fp16": full_fp16, + "gradient_accumulation_steps": gradient_accumulation_steps, + "gradient_checkpointing": gradient_checkpointing, + "keep_tokens": keep_tokens, + "learning_rate": learning_rate, + "logging_dir": logging_dir, + "lr_scheduler": lr_scheduler, + "lr_scheduler_args": lr_scheduler_args, + "lr_scheduler_num_cycles": lr_scheduler_num_cycles, + "lr_scheduler_power": lr_scheduler_power, + "lr_warmup_steps": lr_warmup_steps, + "max_bucket_reso": max_bucket_reso, + "max_data_loader_n_workers": max_data_loader_n_workers, + "max_resolution": max_resolution, + "max_timestep": max_timestep, + "max_token_length": max_token_length, + "max_train_epochs": max_train_epochs, + "max_train_steps": max_train_steps, + "mem_eff_attn": mem_eff_attn, + "min_bucket_reso": min_bucket_reso, + "min_snr_gamma": min_snr_gamma, + "min_timestep": min_timestep, + "mixed_precision": mixed_precision, + "multires_noise_discount": multires_noise_discount, + "multires_noise_iterations": multires_noise_iterations, + "no_token_padding": no_token_padding, + "noise_offset": noise_offset, + "noise_offset_type": noise_offset_type, + "optimizer": optimizer, + "optimizer_args": optimizer_args, + "output_dir": output_dir, + "output_name": output_name, + "persistent_data_loader_workers": persistent_data_loader_workers, + "pretrained_model_name_or_path": pretrained_model_name_or_path, + "prior_loss_weight": prior_loss_weight, + 
"random_crop": random_crop, + "reg_data_dir": reg_data_dir, + "resume": resume, + "save_every_n_epochs": save_every_n_epochs, + "save_every_n_steps": save_every_n_steps, + "save_last_n_steps": save_last_n_steps, + "save_last_n_steps_state": save_last_n_steps_state, + "save_model_as": save_model_as, + "save_precision": save_precision, + "save_state": save_state, + "scale_v_pred_loss_like_noise_pred": scale_v_pred_loss_like_noise_pred, + "seed": seed, + "shuffle_caption": shuffle_caption, + "stop_text_encoder_training": stop_text_encoder_training, + "train_batch_size": train_batch_size, + "train_data_dir": train_data_dir, + "use_wandb": use_wandb, + "v2": v2, + "v_parameterization": v_parameterization, + "v_pred_like_loss": v_pred_like_loss, + "vae": vae, + "vae_batch_size": vae_batch_size, + "wandb_api_key": wandb_api_key, + "weighted_captions": weighted_captions, + "xformers": xformers, + } + + # Conditionally include specific keyword arguments based on sdxl + if sdxl: + kwargs_for_training["learning_rate_te1"] = learning_rate_te1 + kwargs_for_training["learning_rate_te2"] = learning_rate_te2 + else: + kwargs_for_training["learning_rate_te"] = learning_rate_te - - run_cmd += run_cmd_advanced_training( - adaptive_noise_scale=adaptive_noise_scale, - additional_parameters=additional_parameters, - bucket_no_upscale=bucket_no_upscale, - bucket_reso_steps=bucket_reso_steps, - cache_latents=cache_latents, - cache_latents_to_disk=cache_latents_to_disk, - caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, - caption_dropout_rate=caption_dropout_rate, - caption_extension=caption_extension, - clip_skip=clip_skip, - color_aug=color_aug, - enable_bucket=enable_bucket, - epoch=epoch, - flip_aug=flip_aug, - full_bf16=full_bf16, - full_fp16=full_fp16, - gradient_accumulation_steps=gradient_accumulation_steps, - gradient_checkpointing=gradient_checkpointing, - keep_tokens=keep_tokens, - learning_rate=learning_rate, - learning_rate_te1=learning_rate_te1 if sdxl else None, - learning_rate_te2=learning_rate_te2 if sdxl else None, - learning_rate_te=learning_rate_te if not sdxl else None, - logging_dir=logging_dir, - lr_scheduler=lr_scheduler, - lr_scheduler_args=lr_scheduler_args, - lr_scheduler_num_cycles=lr_scheduler_num_cycles, - lr_scheduler_power=lr_scheduler_power, - lr_warmup_steps=lr_warmup_steps, - max_bucket_reso=max_bucket_reso, - max_data_loader_n_workers=max_data_loader_n_workers, - max_resolution=max_resolution, - max_timestep=max_timestep, - max_token_length=max_token_length, - max_train_epochs=max_train_epochs, - max_train_steps=max_train_steps, - mem_eff_attn=mem_eff_attn, - min_bucket_reso=min_bucket_reso, - min_snr_gamma=min_snr_gamma, - min_timestep=min_timestep, - mixed_precision=mixed_precision, - multires_noise_discount=multires_noise_discount, - multires_noise_iterations=multires_noise_iterations, - no_token_padding=no_token_padding, - noise_offset=noise_offset, - noise_offset_type=noise_offset_type, - optimizer=optimizer, - optimizer_args=optimizer_args, - output_dir=output_dir, - output_name=output_name, - persistent_data_loader_workers=persistent_data_loader_workers, - pretrained_model_name_or_path=pretrained_model_name_or_path, - prior_loss_weight=prior_loss_weight, - random_crop=random_crop, - reg_data_dir=reg_data_dir, - resume=resume, - save_every_n_epochs=save_every_n_epochs, - save_every_n_steps=save_every_n_steps, - save_last_n_steps=save_last_n_steps, - save_last_n_steps_state=save_last_n_steps_state, - save_model_as=save_model_as, - save_precision=save_precision, 
- save_state=save_state, - scale_v_pred_loss_like_noise_pred=scale_v_pred_loss_like_noise_pred, - seed=seed, - shuffle_caption=shuffle_caption, - stop_text_encoder_training=stop_text_encoder_training, - train_batch_size=train_batch_size, - train_data_dir=train_data_dir, - use_wandb=use_wandb, - v2=v2, - v_parameterization=v_parameterization, - v_pred_like_loss=v_pred_like_loss, - vae=vae, - vae_batch_size=vae_batch_size, - wandb_api_key=wandb_api_key, - weighted_captions=weighted_captions, - xformers=xformers, - ) + # Pass the dynamically constructed keyword arguments to the function + run_cmd += run_cmd_advanced_training(**kwargs_for_training) run_cmd += run_cmd_sample( sample_every_n_steps, @@ -666,7 +673,8 @@ def train_model( # Saving config file for model current_datetime = datetime.now() formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") - file_path = os.path.join(output_dir, f"{output_name}_{formatted_datetime}.json") + config_dir = os.path.dirname(os.path.dirname(train_data_dir)) + file_path = os.path.join(config_dir, f"{output_name}_{formatted_datetime}.json") log.info(f"Saving training config to {file_path}...") @@ -678,9 +686,12 @@ def train_model( log.info(run_cmd) + env = os.environ.copy() + env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + # Run the command - executor.execute_command(run_cmd=run_cmd) + executor.execute_command(run_cmd=run_cmd, env=env) # check if output_dir/last is a folder... therefore it is a diffuser model last_dir = pathlib.Path(f"{output_dir}/{output_name}") @@ -701,18 +712,16 @@ def dreambooth_tab( dummy_db_false = gr.Label(value=False, visible=False) dummy_headless = gr.Label(value=headless, visible=False) - with gr.Tab("Training"): + with gr.Tab("Training"), gr.Column(variant="compact"): gr.Markdown("Train a custom model using kohya dreambooth python code...") - # Setup Configuration Files Gradio - config = ConfigurationFile(headless) + with gr.Column(): + source_model = SourceModel(headless=headless) - source_model = SourceModel(headless=headless) - - with gr.Tab("Folders"): + with gr.Accordion("Folders", open=False), gr.Group(): folders = Folders(headless=headless) - with gr.Tab("Parameters"): - with gr.Tab("Basic", elem_id="basic_tab"): + with gr.Accordion("Parameters", open=False), gr.Column(): + with gr.Group(elem_id="basic_tab"): basic_training = BasicTraining( learning_rate_value="1e-5", lr_scheduler_value="cosine", @@ -724,7 +733,7 @@ def dreambooth_tab( # # Add SDXL Parameters # sdxl_params = SDXLParameters(source_model.sdxl_checkbox, show_sdxl_cache_text_encoder_outputs=False) - with gr.Tab("Advanced", elem_id="advanced_tab"): + with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"): advanced_training = AdvancedTraining(headless=headless) advanced_training.color_aug.change( color_aug_changed, @@ -732,15 +741,15 @@ def dreambooth_tab( outputs=[basic_training.cache_latents], ) - with gr.Tab("Samples", elem_id="samples_tab"): + with gr.Accordion("Samples", open=False, elem_id="samples_tab"): sample = SampleImages() - with gr.Tab("Dataset Preparation"): + with gr.Accordion("Dataset Preparation", open=False): gr.Markdown( "This section provide Dreambooth tools to help setup your dataset..." 
) gradio_dreambooth_folder_creation_tab( - train_data_dir_input=folders.train_data_dir, + train_data_dir_input=source_model.train_data_dir, reg_data_dir_input=folders.reg_data_dir, output_dir_input=folders.output_dir, logging_dir_input=folders.logging_dir, @@ -748,18 +757,24 @@ def dreambooth_tab( ) gradio_dataset_balancing_tab(headless=headless) - with gr.Row(): - button_run = gr.Button("Start training", variant="primary") + # Setup Configuration Files Gradio + with gr.Accordion("Configuration", open=False): + config = ConfigurationFile(headless=headless, output_dir=folders.output_dir) + + with gr.Column(), gr.Group(): + with gr.Row(): + button_run = gr.Button("Start training", variant="primary") - button_stop_training = gr.Button("Stop training") + button_stop_training = gr.Button("Stop training") - button_print = gr.Button("Print training command") + button_print = gr.Button("Print training command") # Setup gradio tensorboard buttons - ( - button_start_tensorboard, - button_stop_tensorboard, - ) = gradio_tensorboard() + with gr.Column(), gr.Group(): + ( + button_start_tensorboard, + button_stop_tensorboard, + ) = gradio_tensorboard() button_start_tensorboard.click( start_tensorboard, @@ -778,7 +793,7 @@ def dreambooth_tab( source_model.v_parameterization, source_model.sdxl_checkbox, folders.logging_dir, - folders.train_data_dir, + source_model.train_data_dir, folders.reg_data_dir, folders.output_dir, basic_training.max_resolution, @@ -792,7 +807,7 @@ def dreambooth_tab( basic_training.epoch, basic_training.save_every_n_epochs, basic_training.mixed_precision, - basic_training.save_precision, + source_model.save_precision, basic_training.seed, basic_training.num_cpu_threads_per_process, basic_training.cache_latents, @@ -820,7 +835,7 @@ def dreambooth_tab( advanced_training.num_machines, advanced_training.multi_gpu, advanced_training.gpu_ids, - folders.output_name, + source_model.output_name, advanced_training.max_token_length, basic_training.max_train_epochs, basic_training.max_train_steps, @@ -885,12 +900,12 @@ def dreambooth_tab( show_progress=False, ) - config.button_save_as_config.click( - save_configuration, - inputs=[dummy_db_true, config.config_file_name] + settings_list, - outputs=[config.config_file_name], - show_progress=False, - ) + #config.button_save_as_config.click( + # save_configuration, + # inputs=[dummy_db_true, config.config_file_name] + settings_list, + # outputs=[config.config_file_name], + # show_progress=False, + #) button_run.click( train_model, @@ -907,7 +922,7 @@ def dreambooth_tab( ) return ( - folders.train_data_dir, + source_model.train_data_dir, folders.reg_data_dir, folders.output_dir, folders.logging_dir, diff --git a/kohya_gui/extract_lora_from_dylora_gui.py b/kohya_gui/extract_lora_from_dylora_gui.py index 3bccdb758..363ba4d1d 100644 --- a/kohya_gui/extract_lora_from_dylora_gui.py +++ b/kohya_gui/extract_lora_from_dylora_gui.py @@ -7,6 +7,8 @@ get_saveasfilename_path, get_file_path, scriptdir, + list_files, + create_refresh_button, ) from .custom_logging import setup_logging @@ -37,17 +39,28 @@ def extract_dylora( msgbox('The provided DyLoRA model is not a file') return + if os.path.dirname(save_to) == "": + # only filename given. prepend dir + save_to = os.path.join(os.path.dirname(model), save_to) + if os.path.isdir(save_to): + # only dir name given. set default lora name + save_to = os.path.join(save_to, "lora.safetensors") + if os.path.normpath(model) == os.path.normpath(save_to): + # same path. 
silently ignore but rename output + path, ext = os.path.splitext(save_to) + save_to = f"{path}_tmp{ext}" + run_cmd = ( - fr'{PYTHON} "{scriptdir}/networks/extract_lora_from_dylora.py"' + fr'{PYTHON} "{scriptdir}/sd-scripts/networks/extract_lora_from_dylora.py"' ) - run_cmd += f' --save_to "{save_to}"' - run_cmd += f' --model "{model}"' + run_cmd += fr' --save_to "{save_to}"' + run_cmd += fr' --model "{model}"' run_cmd += f' --unit {unit}' log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{env.get('PYTHONPATH', '')}" + env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" # Run the command subprocess.run(run_cmd, shell=True, env=env) @@ -61,6 +74,9 @@ def extract_dylora( def gradio_extract_dylora_tab(headless=False): + current_model_dir = os.path.join(scriptdir, "outputs") + current_save_dir = os.path.join(scriptdir, "outputs") + with gr.Tab('Extract DyLoRA'): gr.Markdown( 'This utility can extract a DyLoRA network from a finetuned model.' @@ -68,15 +84,27 @@ def gradio_extract_dylora_tab(headless=False): lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False) lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) - with gr.Row(): - model = gr.Textbox( - label='DyLoRA model', - placeholder='Path to the DyLoRA model to extract from', + def list_models(path): + current_model_dir = path + return list(list_files(path, exts=[".ckpt", ".safetensors"], all=True)) + + def list_save_to(path): + current_save_dir = path + return list(list_files(path, exts=[".pt", ".safetensors"], all=True)) + + with gr.Group(), gr.Row(): + model = gr.Dropdown( + label='DyLoRA model (path to the DyLoRA model to extract from)', interactive=True, + choices=list_models(current_model_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(model, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small") button_model_file = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not headless), ) button_model_file.click( @@ -86,22 +114,14 @@ def gradio_extract_dylora_tab(headless=False): show_progress=False, ) - save_to = gr.Textbox( - label='Save to', - placeholder='path where to save the extracted LoRA model...', + save_to = gr.Dropdown( + label='Save to (path where to save the extracted LoRA model...)', interactive=True, + choices=list_save_to(current_save_dir), + value="", + allow_custom_value=True, ) - button_save_to = gr.Button( - folder_symbol, - elem_id='open_folder_small', - visible=(not headless), - ) - button_save_to.click( - get_saveasfilename_path, - inputs=[save_to, lora_ext, lora_ext_name], - outputs=save_to, - show_progress=False, - ) + create_refresh_button(save_to, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small") unit = gr.Slider( minimum=1, maximum=256, @@ -111,6 +131,19 @@ def gradio_extract_dylora_tab(headless=False): interactive=True, ) + model.change( + fn=lambda path: gr.Dropdown().update(choices=list_models(path)), + inputs=model, + outputs=model, + show_progress=False, + ) + save_to.change( + fn=lambda path: gr.Dropdown().update(choices=list_save_to(path)), + inputs=save_to, + outputs=save_to, + show_progress=False, + ) + extract_button = gr.Button('Extract LoRA model') extract_button.click( diff --git a/kohya_gui/extract_lora_gui.py b/kohya_gui/extract_lora_gui.py index a3d13d41e..f57632bfc 100644 --- a/kohya_gui/extract_lora_gui.py +++ b/kohya_gui/extract_lora_gui.py 
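The DyLoRA tab above is the first of several tabs in this patch that replace a plain path Textbox with a gr.Dropdown populated from list_files(), a create_refresh_button() helper, and a .change handler that re-scans whatever path the user typed. A minimal, self-contained sketch of that wiring is shown below; the file scanning is simplified here and does not use the repo's common_gui helpers:

import os
import gradio as gr

def model_picker(default_dir="outputs"):
    current_dir = default_dir  # last folder that was scanned

    def scan(path):
        # Remember the folder so the refresh button can reuse it.
        nonlocal current_dir
        folder = os.path.dirname(path) if os.path.isfile(path) else path
        if os.path.isdir(folder):
            current_dir = folder
        if not os.path.isdir(current_dir):
            return []
        return sorted(
            f for f in os.listdir(current_dir)
            if f.endswith((".ckpt", ".safetensors"))
        )

    with gr.Blocks() as demo:
        model = gr.Dropdown(
            label="Model",
            choices=scan(default_dir),
            value="",
            allow_custom_value=True,  # the user may paste an arbitrary path
        )
        refresh = gr.Button("🔄", elem_classes=["tool"])
        # Re-scan the remembered folder on demand, and whenever the value changes.
        refresh.click(lambda: gr.update(choices=scan(current_dir)),
                      outputs=model, show_progress=False)
        model.change(lambda path: gr.update(choices=scan(path)),
                     inputs=model, outputs=model, show_progress=False)
    return demo

Note that the sketch declares the tracked folder nonlocal, as manual_caption_gui.py does further down; the list_* helpers added to the extract and merge tabs rebind a plain local name instead, so their current_*_dir values never change after startup and the refresh buttons keep listing the initial default folder.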
@@ -8,6 +8,8 @@ get_file_path, is_file_writable, scriptdir, + list_files, + create_refresh_button, ) from .custom_logging import setup_logging @@ -57,17 +59,28 @@ def extract_lora( log.info('The provided base model is not a file') return + if os.path.dirname(save_to) == "": + # only filename given. prepend dir + save_to = os.path.join(os.path.dirname(model_tuned), save_to) + if os.path.isdir(save_to): + # only dir name given. set default lora name + save_to = os.path.join(save_to, "lora.safetensors") + if os.path.normpath(model_tuned) == os.path.normpath(save_to): + # same path. silently ignore but rename output + path, ext = os.path.splitext(save_to) + save_to = f"{path}_tmp{ext}" + if not is_file_writable(save_to): return run_cmd = ( - fr'{PYTHON} "{scriptdir}/networks/extract_lora_from_models.py"' + fr'{PYTHON} "{scriptdir}/sd-scripts/networks/extract_lora_from_models.py"' ) run_cmd += f' --load_precision {load_precision}' run_cmd += f' --save_precision {save_precision}' - run_cmd += f' --save_to "{save_to}"' - run_cmd += f' --model_org "{model_org}"' - run_cmd += f' --model_tuned "{model_tuned}"' + run_cmd += fr' --save_to "{save_to}"' + run_cmd += fr' --model_org "{model_org}"' + run_cmd += fr' --model_tuned "{model_tuned}"' run_cmd += f' --dim {dim}' run_cmd += f' --device {device}' if conv_dim > 0: @@ -85,7 +98,7 @@ def extract_lora( log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{env.get('PYTHONPATH', '')}" + env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" # Run the command subprocess.run(run_cmd, shell=True, env=env) @@ -97,11 +110,26 @@ def extract_lora( def gradio_extract_lora_tab(headless=False): + current_model_dir = os.path.join(scriptdir, "outputs") + current_model_org_dir = os.path.join(scriptdir, "outputs") + current_save_dir = os.path.join(scriptdir, "outputs") + + def list_models(path): + current_model_dir = path + return list(list_files(path, exts=[".ckpt", ".safetensors"], all=True)) + + def list_org_models(path): + current_model_org_dir = path + return list(list_files(path, exts=[".ckpt", ".safetensors"], all=True)) + + def list_save_to(path): + current_save_dir = path + return list(list_files(path, exts=[".pt", ".safetensors"], all=True)) + def change_sdxl(sdxl): - return gr.Dropdown(visible=sdxl), gr.Dropdown(visible=sdxl) - - - + return gr.Dropdown().update(visible=sdxl), gr.Dropdown().update(visible=sdxl) + + with gr.Tab('Extract LoRA'): gr.Markdown( 'This utility can extract a LoRA network from a finetuned model.' 
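The output-path checks added to extract_dylora above, and repeated in extract_lora in this file and in extract_lycoris_locon below, always resolve the user's "Save to" value to a safe file path: a bare filename is placed next to the source model, a bare directory gets a default file name, and a name that would overwrite the input is renamed with a suffix. Factored into a standalone helper, the logic amounts to the following sketch (the helper name is illustrative; the patch repeats the three checks inline in each tab):

import os

def normalize_save_path(save_to: str, source_model: str,
                        default_name: str = "lora.safetensors") -> str:
    if os.path.dirname(save_to) == "":
        # Only a filename was given: save next to the source model.
        save_to = os.path.join(os.path.dirname(source_model), save_to)
    if os.path.isdir(save_to):
        # Only a directory was given: fall back to a default file name.
        save_to = os.path.join(save_to, default_name)
    if os.path.normpath(source_model) == os.path.normpath(save_to):
        # The output would overwrite the input: keep it, but rename with a suffix.
        base, ext = os.path.splitext(save_to)
        save_to = f"{base}_tmp{ext}"
    return save_to

# normalize_save_path("my_lora.safetensors", "/models/dylora.safetensors")
#   -> "/models/my_lora.safetensors"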
@@ -111,15 +139,19 @@ def change_sdxl(sdxl): model_ext = gr.Textbox(value='*.ckpt *.safetensors', visible=False) model_ext_name = gr.Textbox(value='Model types', visible=False) - with gr.Row(): - model_tuned = gr.Textbox( - label='Finetuned model', - placeholder='Path to the finetuned model to extract', + with gr.Group(), gr.Row(): + model_tuned = gr.Dropdown( + label='Finetuned model (path to the finetuned model to extract)', interactive=True, + choices=list_models(current_model_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(model_tuned, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small") button_model_tuned_file = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not headless), ) button_model_tuned_file.click( @@ -128,7 +160,7 @@ def change_sdxl(sdxl): outputs=model_tuned, show_progress=False, ) - load_tuned_model_to = gr.Dropdown( + load_tuned_model_to = gr.Radio( label='Load finetuned model to', choices=['cpu', 'cuda', 'cuda:0'], value='cpu', @@ -136,15 +168,18 @@ def change_sdxl(sdxl): info="only for SDXL", visible=False, ) - with gr.Row(): - model_org = gr.Textbox( - label='Stable Diffusion base model', - placeholder='Stable Diffusion original model: ckpt or safetensors file', + model_org = gr.Dropdown( + label='Stable Diffusion base model (original model: ckpt or safetensors file)', interactive=True, + choices=list_org_models(current_model_org_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(model_org, lambda: None, lambda: {"choices": list_org_models(current_model_org_dir)}, "open_folder_small") button_model_org_file = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not headless), ) button_model_org_file.click( @@ -161,16 +196,20 @@ def change_sdxl(sdxl): info="only for SDXL", visible=False, ) - with gr.Row(): - save_to = gr.Textbox( - label='Save to', - placeholder='path where to save the extracted LoRA model...', + with gr.Group(), gr.Row(): + save_to = gr.Dropdown( + label='Save to (path where to save the extracted LoRA model...)', interactive=True, + choices=list_save_to(current_save_dir), + value="", + allow_custom_value=True, scale=2, ) + create_refresh_button(save_to, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small") button_save_to = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not headless), ) button_save_to.click( @@ -179,18 +218,37 @@ def change_sdxl(sdxl): outputs=save_to, show_progress=False, ) - save_precision = gr.Dropdown( + save_precision = gr.Radio( label='Save precision', choices=['fp16', 'bf16', 'float'], value='fp16', interactive=True, scale=1, ) - load_precision = gr.Dropdown( + load_precision = gr.Radio( label='Load precision', choices=['fp16', 'bf16', 'float'], value='fp16', interactive=True, scale=1, ) + + model_tuned.change( + fn=lambda path: gr.Dropdown().update(choices=list_models(path)), + inputs=model_tuned, + outputs=model_tuned, + show_progress=False, + ) + model_org.change( + fn=lambda path: gr.Dropdown().update(choices=list_org_model(path)), + inputs=model_org, + outputs=model_org, + show_progress=False, + ) + save_to.change( + fn=lambda path: gr.Dropdown().update(choices=list_save_to(path)), + inputs=save_to, + outputs=save_to, + show_progress=False, + ) with gr.Row(): dim = gr.Slider( minimum=4, @@ -210,18 +268,24 @@ def change_sdxl(sdxl): ) clamp_quantile = gr.Number( label='Clamp Quantile', - value=1, + 
value=0.99, + minimum=0, + maximum=1, + step=0.001, interactive=True, ) min_diff = gr.Number( label='Minimum difference', value=0.01, + minimum=0, + maximum=1, + step=0.001, interactive=True, ) with gr.Row(): v2 = gr.Checkbox(label='v2', value=False, interactive=True) sdxl = gr.Checkbox(label='SDXL', value=False, interactive=True) - device = gr.Dropdown( + device = gr.Radio( label='Device', choices=[ 'cpu', diff --git a/kohya_gui/extract_lycoris_locon_gui.py b/kohya_gui/extract_lycoris_locon_gui.py index f8eafb0ba..66c4a69e8 100644 --- a/kohya_gui/extract_lycoris_locon_gui.py +++ b/kohya_gui/extract_lycoris_locon_gui.py @@ -8,6 +8,8 @@ get_any_file_path, get_file_path, scriptdir, + list_files, + create_refresh_button, ) from .custom_logging import setup_logging @@ -61,7 +63,18 @@ def extract_lycoris_locon( msgbox("The provided base model is not a file") return - run_cmd = fr'{PYTHON} "{scriptdir}/tools/lycoris_locon_extract.py"' + if os.path.dirname(output_name) == "": + # only filename given. prepend dir + output_name = os.path.join(os.path.dirname(db_model), output_name) + if os.path.isdir(output_name): + # only dir name given. set default lora name + output_name = os.path.join(output_name, "lora.safetensors") + if os.path.normpath(db_model) == os.path.normpath(output_name): + # same path. silently ignore but rename output + path, ext = os.path.splitext(output_name) + output_name = f"{path}_tmp{ext}" + + run_cmd = fr'{PYTHON} "{scriptdir}/sd-scripts/tools/lycoris_locon_extract.py"' if is_sdxl: run_cmd += f" --is_sdxl" if is_v2: @@ -86,14 +99,14 @@ def extract_lycoris_locon( run_cmd += f" --sparsity {sparsity}" if disable_cp: run_cmd += f" --disable_cp" - run_cmd += f' "{base_model}"' - run_cmd += f' "{db_model}"' - run_cmd += f' "{output_name}"' + run_cmd += fr' "{base_model}"' + run_cmd += fr' "{db_model}"' + run_cmd += fr' "{output_name}"' log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{env.get('PYTHONPATH', '')}" + env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" # Run the command subprocess.run(run_cmd, shell=True, env=env) @@ -126,13 +139,30 @@ def update_mode(mode): # Iterate through the possible modes for m in modes: # Add a visibility update for each mode, setting it to True if the input mode matches the current mode in the loop - updates.append(gr.Row(visible=(mode == m))) + updates.append(gr.Row().update(visible=(mode == m))) # Return the visibility updates as a tuple return tuple(updates) def gradio_extract_lycoris_locon_tab(headless=False): + + current_model_dir = os.path.join(scriptdir, "outputs") + current_base_model_dir = os.path.join(scriptdir, "outputs") + current_save_dir = os.path.join(scriptdir, "outputs") + + def list_models(path): + current_model_dir = path + return list(list_files(path, exts=[".ckpt", ".safetensors"], all=True)) + + def list_base_models(path): + current_model_org_dir = path + return list(list_files(path, exts=[".ckpt", ".safetensors"], all=True)) + + def list_save_to(path): + current_save_dir = path + return list(list_files(path, exts=[".safetensors"], all=True)) + with gr.Tab("Extract LyCORIS LoCON"): gr.Markdown( "This utility can extract a LyCORIS LoCon network from a finetuned model." 
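The update_mode() helper patched above controls which parameter row is visible for the selected extraction mode by returning one visibility update per row. A reduced sketch of the same wiring, using the generic gr.update() form (the component names here are illustrative):

import gradio as gr

MODES = ["fixed", "full", "quantile", "ratio", "threshold"]

def update_mode(selected):
    # One update per row, visible only when it matches the selected mode.
    return tuple(gr.update(visible=(selected == m)) for m in MODES)

with gr.Blocks() as demo:
    mode = gr.Radio(label="Mode", choices=MODES, value="fixed")
    rows = []
    for m in MODES:
        with gr.Row(visible=(m == "fixed")) as row:
            gr.Slider(label=f"{m} parameter", minimum=0, maximum=1, value=0.5)
        rows.append(row)
    # The listener returns len(MODES) updates, matching the outputs list.
    mode.change(update_mode, inputs=mode, outputs=rows)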
@@ -144,15 +174,19 @@ def gradio_extract_lycoris_locon_tab(headless=False): model_ext = gr.Textbox(value="*.safetensors *.ckpt", visible=False) model_ext_name = gr.Textbox(value="Model types", visible=False) - with gr.Row(): - db_model = gr.Textbox( - label="Finetuned model", - placeholder="Path to the finetuned model to extract", + with gr.Group(), gr.Row(): + db_model = gr.Dropdown( + label="Finetuned model (path to the finetuned model to extract)", interactive=True, + choices=list_models(current_model_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(db_model, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small") button_db_model_file = gr.Button( folder_symbol, elem_id="open_folder_small", + elem_classes=['tool'], visible=(not headless), ) button_db_model_file.click( @@ -162,14 +196,17 @@ def gradio_extract_lycoris_locon_tab(headless=False): show_progress=False, ) - base_model = gr.Textbox( - label="Stable Diffusion base model", - placeholder="Stable Diffusion original model: ckpt or safetensors file", - interactive=True, + base_model = gr.Dropdown( + label="Stable Diffusion base model (original model: ckpt or safetensors file)", + choices=list_base_models(current_base_model_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(base_model, lambda: None, lambda: {"choices": list_base_models(current_base_model_dir)}, "open_folder_small") button_base_model_file = gr.Button( folder_symbol, elem_id="open_folder_small", + elem_classes=['tool'], visible=(not headless), ) button_base_model_file.click( @@ -178,15 +215,20 @@ def gradio_extract_lycoris_locon_tab(headless=False): outputs=base_model, show_progress=False, ) - with gr.Row(): - output_name = gr.Textbox( - label="Save to", - placeholder="path where to save the extracted LoRA model...", + with gr.Group(), gr.Row(): + output_name = gr.Dropdown( + label="Save to (path where to save the extracted LoRA model...)", interactive=True, + choices=list_save_to(current_save_dir), + value="", + allow_custom_value=True, + scale=2, ) + create_refresh_button(output_name, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small") button_output_name = gr.Button( folder_symbol, elem_id="open_folder_small", + elem_classes=['tool'], visible=(not headless), ) button_output_name.click( @@ -195,7 +237,7 @@ def gradio_extract_lycoris_locon_tab(headless=False): outputs=output_name, show_progress=False, ) - device = gr.Dropdown( + device = gr.Radio( label="Device", choices=[ "cpu", @@ -203,17 +245,38 @@ def gradio_extract_lycoris_locon_tab(headless=False): ], value="cuda", interactive=True, + scale=2, ) - is_sdxl = gr.Checkbox(label="is SDXL", value=False, interactive=True) + db_model.change( + fn=lambda path: gr.Dropdown().update(choices=list_models(path)), + inputs=db_model, + outputs=db_model, + show_progress=False, + ) + base_model.change( + fn=lambda path: gr.Dropdown().update(choices=list_base_model(path)), + inputs=base_model, + outputs=base_model, + show_progress=False, + ) + output_name.change( + fn=lambda path: gr.Dropdown().update(choices=list_save_to(path)), + inputs=output_name, + outputs=output_name, + show_progress=False, + ) - is_v2 = gr.Checkbox(label="is v2", value=False, interactive=True) - mode = gr.Dropdown( - label="Mode", - choices=["fixed", "full", "quantile", "ratio", "threshold"], - value="fixed", - interactive=True, - ) + is_sdxl = gr.Checkbox(label="is SDXL", value=False, interactive=True, scale=1) + + is_v2 = gr.Checkbox(label="is v2", 
value=False, interactive=True, scale=1) + with gr.Row(): + mode = gr.Radio( + label="Mode", + choices=["fixed", "full", "quantile", "ratio", "threshold"], + value="fixed", + interactive=True, + ) with gr.Row(visible=True) as fixed: linear_dim = gr.Slider( minimum=1, diff --git a/kohya_gui/finetune_gui.py b/kohya_gui/finetune_gui.py index fd429e7a9..01fdd1957 100644 --- a/kohya_gui/finetune_gui.py +++ b/kohya_gui/finetune_gui.py @@ -15,6 +15,7 @@ color_aug_changed, update_my_data, check_if_model_exist, + output_message, SaveConfigFile, save_to_file, scriptdir, @@ -435,6 +436,13 @@ def train_model( headless_bool = True if headless.get("label") == "True" else False + if output_dir == "": + output_message(msg="Output folder path is missing", headless=headless_bool) + return + + if train_dir is None or train_dir.strip() == "": + train_dir = output_dir + if not print_only_bool and check_if_model_exist(output_name, output_dir, save_model_as, headless_bool): return @@ -461,33 +469,45 @@ def train_model( if train_dir != "" and not os.path.exists(train_dir): os.mkdir(train_dir) - run_cmd = fr'{PYTHON} "{scriptdir}/finetune/merge_captions_to_metadata.py"' + run_cmd = fr'{PYTHON} "{scriptdir}/sd-scripts/finetune/merge_captions_to_metadata.py"' if caption_extension == "": run_cmd += f' --caption_extension=".caption"' else: run_cmd += f" --caption_extension={caption_extension}" - run_cmd += f' "{image_folder}"' - run_cmd += f' "{train_dir}/{caption_metadata_filename}"' + run_cmd += fr' "{image_folder}"' + run_cmd += fr' "{train_dir}/{caption_metadata_filename}"' if full_path: run_cmd += f" --full_path" log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{env.get('PYTHONPATH', '')}" + env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" if not print_only_bool: # Run the command subprocess.run(run_cmd, shell=True, env=env) + # check pretrained_model_name_or_path + + if not os.path.exists(pretrained_model_name_or_path): + try: + from modules import sd_models + + info = sd_models.get_closet_checkpoint_match(pretrained_model_name_or_path) + if info is not None: + pretrained_model_name_or_path = info.filename + + except Exception: + pass # create images buckets if generate_image_buckets: - run_cmd = fr'{PYTHON} "{scriptdir}/finetune/prepare_buckets_latents.py"' - run_cmd += f' "{image_folder}"' - run_cmd += f' "{train_dir}/{caption_metadata_filename}"' - run_cmd += f' "{train_dir}/{latent_metadata_filename}"' - run_cmd += f' "{pretrained_model_name_or_path}"' + run_cmd = fr'{PYTHON} "{scriptdir}/sd-scripts/finetune/prepare_buckets_latents.py"' + run_cmd += fr' "{image_folder}"' + run_cmd += fr' "{train_dir}/{caption_metadata_filename}"' + run_cmd += fr' "{train_dir}/{latent_metadata_filename}"' + run_cmd += fr' "{pretrained_model_name_or_path}"' run_cmd += f" --batch_size={batch_size}" run_cmd += f" --max_resolution={max_resolution}" run_cmd += f" --min_bucket_reso={min_bucket_reso}" @@ -504,7 +524,7 @@ def train_model( log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{env.get('PYTHONPATH', '')}" + env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" if not print_only_bool: # Run the command @@ -554,98 +574,104 @@ def train_model( ) if sdxl_checkbox: - run_cmd += fr' "{scriptdir}/sdxl_train.py"' + run_cmd += fr' "{scriptdir}/sd-scripts/sdxl_train.py"' else: - run_cmd += fr' "{scriptdir}/fine_tune.py"' + run_cmd += fr' 
"{scriptdir}/sd-scripts/fine_tune.py"' in_json = ( - f"{train_dir}/{latent_metadata_filename}" + fr"{train_dir}/{latent_metadata_filename}" if use_latent_files == "Yes" - else f"{train_dir}/{caption_metadata_filename}" + else fr"{train_dir}/{caption_metadata_filename}" ) cache_text_encoder_outputs = sdxl_checkbox and sdxl_cache_text_encoder_outputs no_half_vae = sdxl_checkbox and sdxl_no_half_vae - run_cmd += run_cmd_advanced_training( - adaptive_noise_scale=adaptive_noise_scale, - additional_parameters=additional_parameters, - block_lr=block_lr, - bucket_no_upscale=bucket_no_upscale, - bucket_reso_steps=bucket_reso_steps, - cache_latents=cache_latents, - cache_latents_to_disk=cache_latents_to_disk, - cache_text_encoder_outputs=cache_text_encoder_outputs - if sdxl_checkbox - else None, - caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, - caption_dropout_rate=caption_dropout_rate, - caption_extension=caption_extension, - clip_skip=clip_skip, - color_aug=color_aug, - dataset_repeats=dataset_repeats, - enable_bucket=True, - flip_aug=flip_aug, - full_bf16=full_bf16, - full_fp16=full_fp16, - gradient_accumulation_steps=gradient_accumulation_steps, - gradient_checkpointing=gradient_checkpointing, - in_json=in_json, - keep_tokens=keep_tokens, - learning_rate=learning_rate, - learning_rate_te1=learning_rate_te1 if sdxl_checkbox else None, - learning_rate_te2=learning_rate_te2 if sdxl_checkbox else None, - learning_rate_te=learning_rate_te if not sdxl_checkbox else None, - logging_dir=logging_dir, - lr_scheduler=lr_scheduler, - lr_scheduler_args=lr_scheduler_args, - lr_warmup_steps=lr_warmup_steps, - max_bucket_reso=max_bucket_reso, - max_data_loader_n_workers=max_data_loader_n_workers, - max_resolution=max_resolution, - max_timestep=max_timestep, - max_token_length=max_token_length, - max_train_epochs=max_train_epochs, - max_train_steps=max_train_steps, - mem_eff_attn=mem_eff_attn, - min_bucket_reso=min_bucket_reso, - min_snr_gamma=min_snr_gamma, - min_timestep=min_timestep, - mixed_precision=mixed_precision, - multires_noise_discount=multires_noise_discount, - multires_noise_iterations=multires_noise_iterations, - no_half_vae=no_half_vae if sdxl_checkbox else None, - noise_offset=noise_offset, - noise_offset_type=noise_offset_type, - optimizer=optimizer, - optimizer_args=optimizer_args, - output_dir=output_dir, - output_name=output_name, - persistent_data_loader_workers=persistent_data_loader_workers, - pretrained_model_name_or_path=pretrained_model_name_or_path, - random_crop=random_crop, - resume=resume, - save_every_n_epochs=save_every_n_epochs, - save_every_n_steps=save_every_n_steps, - save_last_n_steps=save_last_n_steps, - save_last_n_steps_state=save_last_n_steps_state, - save_model_as=save_model_as, - save_precision=save_precision, - save_state=save_state, - scale_v_pred_loss_like_noise_pred=scale_v_pred_loss_like_noise_pred, - seed=seed, - shuffle_caption=shuffle_caption, - train_batch_size=train_batch_size, - train_data_dir=image_folder, - train_text_encoder=train_text_encoder, - use_wandb=use_wandb, - v2=v2, - v_parameterization=v_parameterization, - v_pred_like_loss=v_pred_like_loss, - vae_batch_size=vae_batch_size, - wandb_api_key=wandb_api_key, - weighted_captions=weighted_captions, - xformers=xformers, - ) + # Initialize a dictionary with always-included keyword arguments + kwargs_for_training = { + "adaptive_noise_scale": adaptive_noise_scale, + "additional_parameters": additional_parameters, + "block_lr": block_lr, + "bucket_no_upscale": bucket_no_upscale, + 
"bucket_reso_steps": bucket_reso_steps, + "cache_latents": cache_latents, + "cache_latents_to_disk": cache_latents_to_disk, + "caption_dropout_every_n_epochs": caption_dropout_every_n_epochs, + "caption_dropout_rate": caption_dropout_rate, + "caption_extension": caption_extension, + "clip_skip": clip_skip, + "color_aug": color_aug, + "dataset_repeats": dataset_repeats, + "enable_bucket": True, + "flip_aug": flip_aug, + "full_bf16": full_bf16, + "full_fp16": full_fp16, + "gradient_accumulation_steps": gradient_accumulation_steps, + "gradient_checkpointing": gradient_checkpointing, + "in_json": in_json, + "keep_tokens": keep_tokens, + "learning_rate": learning_rate, + "logging_dir": logging_dir, + "lr_scheduler": lr_scheduler, + "lr_scheduler_args": lr_scheduler_args, + "lr_warmup_steps": lr_warmup_steps, + "max_bucket_reso": max_bucket_reso, + "max_data_loader_n_workers": max_data_loader_n_workers, + "max_resolution": max_resolution, + "max_timestep": max_timestep, + "max_token_length": max_token_length, + "max_train_epochs": max_train_epochs, + "max_train_steps": max_train_steps, + "mem_eff_attn": mem_eff_attn, + "min_bucket_reso": min_bucket_reso, + "min_snr_gamma": min_snr_gamma, + "min_timestep": min_timestep, + "mixed_precision": mixed_precision, + "multires_noise_discount": multires_noise_discount, + "multires_noise_iterations": multires_noise_iterations, + "noise_offset": noise_offset, + "noise_offset_type": noise_offset_type, + "optimizer": optimizer, + "optimizer_args": optimizer_args, + "output_dir": output_dir, + "output_name": output_name, + "persistent_data_loader_workers": persistent_data_loader_workers, + "pretrained_model_name_or_path": pretrained_model_name_or_path, + "random_crop": random_crop, + "resume": resume, + "save_every_n_epochs": save_every_n_epochs, + "save_every_n_steps": save_every_n_steps, + "save_last_n_steps": save_last_n_steps, + "save_last_n_steps_state": save_last_n_steps_state, + "save_model_as": save_model_as, + "save_precision": save_precision, + "save_state": save_state, + "scale_v_pred_loss_like_noise_pred": scale_v_pred_loss_like_noise_pred, + "seed": seed, + "shuffle_caption": shuffle_caption, + "train_batch_size": train_batch_size, + "train_data_dir": image_folder, + "train_text_encoder": train_text_encoder, + "use_wandb": use_wandb, + "v2": v2, + "v_parameterization": v_parameterization, + "v_pred_like_loss": v_pred_like_loss, + "vae_batch_size": vae_batch_size, + "wandb_api_key": wandb_api_key, + "weighted_captions": weighted_captions, + "xformers": xformers, + } + + # Conditionally include specific keyword arguments based on sdxl_checkbox + if sdxl_checkbox: + kwargs_for_training["cache_text_encoder_outputs"] = cache_text_encoder_outputs + kwargs_for_training["learning_rate_te1"] = learning_rate_te1 + kwargs_for_training["learning_rate_te2"] = learning_rate_te2 + kwargs_for_training["no_half_vae"] = no_half_vae + else: + kwargs_for_training["learning_rate_te"] = learning_rate_te + + # Pass the dynamically constructed keyword arguments to the function + run_cmd += run_cmd_advanced_training(**kwargs_for_training) run_cmd += run_cmd_sample( sample_every_n_steps, @@ -666,7 +692,8 @@ def train_model( # Saving config file for model current_datetime = datetime.now() formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") - file_path = os.path.join(output_dir, f"{output_name}_{formatted_datetime}.json") + config_dir = os.path.dirname(os.path.dirname(train_dir)) + file_path = os.path.join(config_dir, f"{output_name}_{formatted_datetime}.json") 
log.info(f"Saving training config to {file_path}...") @@ -679,7 +706,7 @@ def train_model( log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{env.get('PYTHONPATH', '')}" + env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" # Run the command executor.execute_command(run_cmd=run_cmd, env=env) @@ -696,62 +723,21 @@ def finetune_tab(headless=False): dummy_db_true = gr.Label(value=True, visible=False) dummy_db_false = gr.Label(value=False, visible=False) dummy_headless = gr.Label(value=headless, visible=False) - with gr.Tab("Training"): + with gr.Tab("Training"), gr.Column(variant="compact"): gr.Markdown("Train a custom model using kohya finetune python code...") - # Setup Configuration Files Gradio - config = ConfigurationFile(headless) - - source_model = SourceModel(headless=headless) + with gr.Column(): + source_model = SourceModel(headless=headless, finetuning=True) + image_folder = source_model.train_data_dir + output_name = source_model.output_name - with gr.Tab("Folders"): - folders = Folders(headless=headless, finetune=True) - image_folder = folders.train_data_dir - train_dir = folders.reg_data_dir + with gr.Accordion("Folders", open=False), gr.Group(): + folders = Folders(headless=headless, train_data_dir=source_model.train_data_dir, finetune=True) output_dir = folders.output_dir logging_dir = folders.logging_dir - output_name = folders.output_name + train_dir = folders.reg_data_dir - with gr.Tab("Dataset preparation"): - with gr.Row(): - max_resolution = gr.Textbox( - label="Resolution (width,height)", value="512,512" - ) - min_bucket_reso = gr.Textbox(label="Min bucket resolution", value="256") - max_bucket_reso = gr.Textbox( - label="Max bucket resolution", value="1024" - ) - batch_size = gr.Textbox(label="Batch size", value="1") - with gr.Row(): - create_caption = gr.Checkbox( - label="Generate caption metadata", value=True - ) - create_buckets = gr.Checkbox( - label="Generate image buckets metadata", value=True - ) - use_latent_files = gr.Dropdown( - label="Use latent files", - choices=[ - "No", - "Yes", - ], - value="Yes", - ) - with gr.Accordion("Advanced parameters", open=False): - with gr.Row(): - caption_metadata_filename = gr.Textbox( - label="Caption metadata filename", - value="meta_cap.json", - ) - latent_metadata_filename = gr.Textbox( - label="Latent metadata filename", value="meta_lat.json" - ) - with gr.Row(): - full_path = gr.Checkbox(label="Use full path", value=True) - weighted_captions = gr.Checkbox( - label="Weighted captions", value=False - ) - with gr.Tab("Parameters"): + with gr.Accordion("Parameters", open=False), gr.Column(): def list_presets(path): json_files = [] @@ -775,7 +761,7 @@ def list_presets(path): elem_id="myDropdown", ) - with gr.Tab("Basic", elem_id="basic_tab"): + with gr.Group(elem_id="basic_tab"): basic_training = BasicTraining( learning_rate_value="1e-5", finetuning=True, @@ -791,13 +777,13 @@ def list_presets(path): label="Train text encoder", value=True ) - with gr.Tab("Advanced", elem_id="advanced_tab"): + with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"): with gr.Row(): gradient_accumulation_steps = gr.Number( - label="Gradient accumulate steps", value="1" + label="Gradient accumulate steps", value="1", ) block_lr = gr.Textbox( - label="Block LR", + label="Block LR (SDXL)", placeholder="(Optional)", info="Specify the different learning rates for each U-Net block. Specify 23 values separated by commas like 1e-3,1e-3 ... 
1e-3", ) @@ -810,21 +796,68 @@ def list_presets(path): ], # Not applicable to fine_tune.py ) - with gr.Tab("Samples", elem_id="samples_tab"): + with gr.Accordion("Samples", open=False, elem_id="samples_tab"): sample = SampleImages() - with gr.Row(): - button_run = gr.Button("Start training", variant="primary") + with gr.Accordion("Dataset Preparation", open=False): + with gr.Row(): + max_resolution = gr.Textbox( + label="Resolution (width,height)", value="512,512" + ) + min_bucket_reso = gr.Textbox(label="Min bucket resolution", value="256") + max_bucket_reso = gr.Textbox( + label="Max bucket resolution", value="1024" + ) + batch_size = gr.Textbox(label="Batch size", value="1") + with gr.Row(): + create_caption = gr.Checkbox( + label="Generate caption metadata", value=True + ) + create_buckets = gr.Checkbox( + label="Generate image buckets metadata", value=True + ) + use_latent_files = gr.Dropdown( + label="Use latent files", + choices=[ + "No", + "Yes", + ], + value="Yes", + ) + with gr.Accordion("Advanced parameters", open=False): + with gr.Row(): + caption_metadata_filename = gr.Textbox( + label="Caption metadata filename", + value="meta_cap.json", + ) + latent_metadata_filename = gr.Textbox( + label="Latent metadata filename", value="meta_lat.json" + ) + with gr.Row(): + full_path = gr.Checkbox(label="Use full path", value=True) + weighted_captions = gr.Checkbox( + label="Weighted captions", value=False + ) + + # Setup Configuration Files Gradio + with gr.Accordion("Configuration", open=False): + config = ConfigurationFile(headless=headless, output_dir=train_dir) + - button_stop_training = gr.Button("Stop training") + with gr.Column(), gr.Group(): + with gr.Row(): + button_run = gr.Button("Start training", variant="primary") + + button_stop_training = gr.Button("Stop training") - button_print = gr.Button("Print training command") + button_print = gr.Button("Print training command") # Setup gradio tensorboard buttons - ( - button_start_tensorboard, - button_stop_tensorboard, - ) = gradio_tensorboard() + with gr.Column(), gr.Group(): + ( + button_start_tensorboard, + button_stop_tensorboard, + ) = gradio_tensorboard() button_start_tensorboard.click( start_tensorboard, @@ -861,7 +894,7 @@ def list_presets(path): basic_training.epoch, basic_training.save_every_n_epochs, basic_training.mixed_precision, - basic_training.save_precision, + source_model.save_precision, basic_training.seed, basic_training.num_cpu_threads_per_process, basic_training.learning_rate_te, @@ -994,12 +1027,12 @@ def list_presets(path): show_progress=False, ) - config.button_save_as_config.click( - save_configuration, - inputs=[dummy_db_true, config.config_file_name] + settings_list, - outputs=[config.config_file_name], - show_progress=False, - ) + #config.button_save_as_config.click( + # save_configuration, + # inputs=[dummy_db_true, config.config_file_name] + settings_list, + # outputs=[config.config_file_name], + # show_progress=False, + #) with gr.Tab("Guides"): gr.Markdown("This section provide Various Finetuning guides and information...") diff --git a/kohya_gui/git_caption_gui.py b/kohya_gui/git_caption_gui.py index a791991fe..40fe8196f 100644 --- a/kohya_gui/git_caption_gui.py +++ b/kohya_gui/git_caption_gui.py @@ -3,7 +3,7 @@ import subprocess import os import sys -from .common_gui import get_folder_path, add_pre_postfix, scriptdir +from .common_gui import get_folder_path, add_pre_postfix, scriptdir, list_dirs from .custom_logging import setup_logging @@ -33,7 +33,7 @@ def caption_images( return log.info(f'GIT 
captioning files in {train_data_dir}...') - run_cmd = fr'{PYTHON} "{scriptdir}/finetune/make_captions_by_git.py"' + run_cmd = fr'{PYTHON} "{scriptdir}/sd-scripts/finetune/make_captions_by_git.py"' if not model_id == '': run_cmd += f' --model_id="{model_id}"' run_cmd += f' --batch_size="{int(batch_size)}"' @@ -48,7 +48,7 @@ def caption_images( log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{env.get('PYTHONPATH', '')}" + env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" # Run the command subprocess.run(run_cmd, shell=True, env=env) @@ -69,19 +69,31 @@ def caption_images( ### -def gradio_git_caption_gui_tab(headless=False): +def gradio_git_caption_gui_tab(headless=False, default_train_dir=None): + from .common_gui import create_refresh_button + + default_train_dir = default_train_dir if default_train_dir is not None else os.path.join(scriptdir, "data") + current_train_dir = default_train_dir + + def list_train_dirs(path): + current_train_dir = path + return list(list_dirs(path)) + with gr.Tab('GIT Captioning'): gr.Markdown( 'This utility will use GIT to caption files for each images in a folder.' ) - with gr.Row(): - train_data_dir = gr.Textbox( - label='Image folder to caption', - placeholder='Directory containing the images to caption', + with gr.Group(), gr.Row(): + train_data_dir = gr.Dropdown( + label='Image folder to caption (containing the images to caption)', + choices=list_train_dirs(default_train_dir), + value="", interactive=True, + allow_custom_value=True, ) + create_refresh_button(train_data_dir, lambda: None, lambda: {"choices": list_train_dir(current_train_dir)},"open_folder_small") button_train_data_dir_input = gr.Button( - '📂', elem_id='open_folder_small', visible=(not headless) + '📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless) ) button_train_data_dir_input.click( get_folder_path, @@ -141,3 +153,10 @@ def gradio_git_caption_gui_tab(headless=False): ], show_progress=False, ) + + train_data_dir.change( + fn=lambda path: gr.Dropdown().update(choices=list_train_dirs(path)), + inputs=train_data_dir, + outputs=train_data_dir, + show_progress=False, + ) diff --git a/kohya_gui/group_images_gui.py b/kohya_gui/group_images_gui.py index bdd24ebbd..cb38ded69 100644 --- a/kohya_gui/group_images_gui.py +++ b/kohya_gui/group_images_gui.py @@ -1,7 +1,7 @@ import gradio as gr from easygui import msgbox import subprocess -from .common_gui import get_folder_path, scriptdir +from .common_gui import get_folder_path, scriptdir, list_dirs import os import sys @@ -48,7 +48,7 @@ def group_images( log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{env.get('PYTHONPATH', '')}" + env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" # Run the command subprocess.run(run_cmd, shell=True, env=env) @@ -57,19 +57,35 @@ def group_images( def gradio_group_images_gui_tab(headless=False): + from .common_gui import create_refresh_button + + current_input_folder = os.path.join(scriptdir, "data") + current_output_folder = os.path.join(scriptdir, "data") + + def list_input_dirs(path): + current_input_folder = path + return list(list_dirs(path)) + + def list_output_dirs(path): + current_output_folder = path + return list(list_dirs(path)) + with gr.Tab('Group Images'): gr.Markdown( 'This utility will group images in a folder based on their aspect ratio.' 
) - with gr.Row(): - input_folder = gr.Textbox( - label='Input folder', - placeholder='Directory containing the images to group', + with gr.Group(), gr.Row(): + input_folder = gr.Dropdown( + label='Input folder (containing the images to group)', interactive=True, + choices=list_input_dirs(current_input_folder), + value="", + allow_custom_value=True, ) + create_refresh_button(input_folder, lambda: None, lambda: {"choices": list_input_dirs(current_input_dir)},"open_folder_small") button_input_folder = gr.Button( - '📂', elem_id='open_folder_small', visible=(not headless) + '📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless) ) button_input_folder.click( get_folder_path, @@ -77,19 +93,35 @@ def gradio_group_images_gui_tab(headless=False): show_progress=False, ) - output_folder = gr.Textbox( - label='Output folder', - placeholder='Directory where the grouped images will be stored', + output_folder = gr.Dropdown( + label='Output folder (where the grouped images will be stored)', interactive=True, + choices=list_output_dirs(current_output_folder), + value="", + allow_custom_value=True, ) + create_refresh_button(output_folder, lambda: None, lambda: {"choices": list_output_dirs(current_output_dir)},"open_folder_small") button_output_folder = gr.Button( - '📂', elem_id='open_folder_small', visible=(not headless) + '📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless) ) button_output_folder.click( get_folder_path, outputs=output_folder, show_progress=False, ) + + input_folder.change( + fn=lambda path: gr.Dropdown().update(choices=list_input_dirs(path)), + inputs=input_folder, + outputs=input_folder, + show_progress=False, + ) + output_folder.change( + fn=lambda path: gr.Dropdown().update(choices=list_output_dirs(path)), + inputs=output_folder, + outputs=output_folder, + show_progress=False, + ) with gr.Row(): group_size = gr.Slider( label='Group size', diff --git a/kohya_gui/lora_gui.py b/kohya_gui/lora_gui.py index 0a6859fdf..39d3fc1d0 100644 --- a/kohya_gui/lora_gui.py +++ b/kohya_gui/lora_gui.py @@ -53,6 +53,7 @@ document_symbol = "\U0001F4C4" # 📄 + presets_dir = fr'{scriptdir}/presets' def save_configuration( @@ -743,9 +744,9 @@ def train_model( ) if sdxl: - run_cmd += fr' "{scriptdir}/sdxl_train_network.py"' + run_cmd += fr' "{scriptdir}/sd-scripts/sdxl_train_network.py"' else: - run_cmd += fr' "{scriptdir}/train_network.py"' + run_cmd += fr' "{scriptdir}/sd-scripts/train_network.py"' if LoRA_type == "LyCORIS/Diag-OFT": network_module = "lycoris.kohya" @@ -1016,7 +1017,8 @@ def train_model( # Saving config file for model current_datetime = datetime.now() formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") - file_path = os.path.join(output_dir, f"{output_name}_{formatted_datetime}.json") + config_dir = os.path.dirname(os.path.dirname(train_data_dir)) + file_path = os.path.join(config_dir, f"{output_name}_{formatted_datetime}.json") log.info(f"Saving training config to {file_path}...") @@ -1027,8 +1029,11 @@ def train_model( ) log.info(run_cmd) + env = os.environ.copy() + env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + # Run the command - executor.execute_command(run_cmd=run_cmd) + executor.execute_command(run_cmd=run_cmd, env=env) # # check if output_dir/last is a folder... 
therefore it is a diffuser model # last_dir = pathlib.Path(f'{output_dir}/{output_name}') @@ -1041,36 +1046,34 @@ def train_model( def lora_tab( - train_data_dir_input=gr.Textbox(), - reg_data_dir_input=gr.Textbox(), - output_dir_input=gr.Textbox(), - logging_dir_input=gr.Textbox(), + train_data_dir_input=gr.Dropdown(), + reg_data_dir_input=gr.Dropdown(), + output_dir_input=gr.Dropdown(), + logging_dir_input=gr.Dropdown(), headless=False, ): dummy_db_true = gr.Label(value=True, visible=False) dummy_db_false = gr.Label(value=False, visible=False) dummy_headless = gr.Label(value=headless, visible=False) - with gr.Tab("Training"): + with gr.Tab("Training"), gr.Column(variant="compact"): gr.Markdown( "Train a custom model using kohya train network LoRA python code..." ) - # Setup Configuration Files Gradio - config = ConfigurationFile(headless) - - source_model = SourceModel( - save_model_as_choices=[ - "ckpt", - "safetensors", - ], - headless=headless, - ) + with gr.Column(): + source_model = SourceModel( + save_model_as_choices=[ + "ckpt", + "safetensors", + ], + headless=headless, + ) - with gr.Tab("Folders"): + with gr.Accordion("Folders", open=False), gr.Group(): folders = Folders(headless=headless) - with gr.Tab("Parameters"): + with gr.Accordion("Parameters", open=False), gr.Column(): def list_presets(path): json_files = [] @@ -1093,12 +1096,12 @@ def list_presets(path): training_preset = gr.Dropdown( label="Presets", - choices=list_presets(f"{presets_dir}/lora"), + choices=list_presets(fr"{presets_dir}/lora"), elem_id="myDropdown", value="none" ) - with gr.Tab("Basic", elem_id="basic_tab"): + with gr.Group(elem_id="basic_tab"): with gr.Row(): LoRA_type = gr.Dropdown( label="LoRA type", @@ -1143,6 +1146,7 @@ def list_presets(path): lora_network_weights_file = gr.Button( document_symbol, elem_id="open_folder_small", + elem_classes=["tool"], visible=(not headless), ) lora_network_weights_file.click( @@ -1681,7 +1685,7 @@ def update_LoRA_settings( return tuple(results) - with gr.Tab("Advanced", elem_id="advanced_tab"): + with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"): # with gr.Accordion('Advanced Configuration', open=False): with gr.Row(visible=True) as kohya_advanced_lora: with gr.Tab(label="Weights"): @@ -1739,7 +1743,7 @@ def update_LoRA_settings( outputs=[basic_training.cache_latents], ) - with gr.Tab("Samples", elem_id="samples_tab"): + with gr.Accordion("Samples", open=False, elem_id="samples_tab"): sample = SampleImages() LoRA_type.change( @@ -1777,12 +1781,12 @@ def update_LoRA_settings( ], ) - with gr.Tab("Dataset Preparation"): + with gr.Accordion("Dataset Preparation", open=False): gr.Markdown( "This section provide Dreambooth tools to help setup your dataset..." 
) gradio_dreambooth_folder_creation_tab( - train_data_dir_input=folders.train_data_dir, + train_data_dir_input=source_model.train_data_dir, reg_data_dir_input=folders.reg_data_dir, output_dir_input=folders.output_dir, logging_dir_input=folders.logging_dir, @@ -1790,18 +1794,26 @@ def update_LoRA_settings( ) gradio_dataset_balancing_tab(headless=headless) - with gr.Row(): - button_run = gr.Button("Start training", variant="primary") - button_stop_training = gr.Button("Stop training") + # Setup Configuration Files Gradio + with gr.Accordion('Configuration', open=False): + config = ConfigurationFile(headless=headless, output_dir=folders.output_dir) + + + with gr.Column(), gr.Group(): + with gr.Row(): + button_run = gr.Button("Start training", variant="primary") - button_print = gr.Button("Print training command") + button_stop_training = gr.Button("Stop training") + + button_print = gr.Button("Print training command") # Setup gradio tensorboard buttons - ( - button_start_tensorboard, - button_stop_tensorboard, - ) = gradio_tensorboard() + with gr.Column(), gr.Group(): + ( + button_start_tensorboard, + button_stop_tensorboard, + ) = gradio_tensorboard() button_start_tensorboard.click( start_tensorboard, @@ -1820,7 +1832,7 @@ def update_LoRA_settings( source_model.v_parameterization, source_model.sdxl_checkbox, folders.logging_dir, - folders.train_data_dir, + source_model.train_data_dir, folders.reg_data_dir, folders.output_dir, basic_training.max_resolution, @@ -1831,7 +1843,7 @@ def update_LoRA_settings( basic_training.epoch, basic_training.save_every_n_epochs, basic_training.mixed_precision, - basic_training.save_precision, + source_model.save_precision, basic_training.seed, basic_training.num_cpu_threads_per_process, basic_training.cache_latents, @@ -1865,14 +1877,14 @@ def update_LoRA_settings( advanced_training.gpu_ids, advanced_training.gradient_accumulation_steps, advanced_training.mem_eff_attn, - folders.output_name, + source_model.output_name, source_model.model_list, advanced_training.max_token_length, basic_training.max_train_epochs, basic_training.max_train_steps, advanced_training.max_data_loader_n_workers, network_alpha, - folders.training_comment, + source_model.training_comment, advanced_training.keep_tokens, basic_training.lr_scheduler_num_cycles, basic_training.lr_scheduler_power, @@ -1980,12 +1992,12 @@ def update_LoRA_settings( show_progress=False, ) - config.button_save_as_config.click( - save_configuration, - inputs=[dummy_db_true, config.config_file_name] + settings_list, - outputs=[config.config_file_name], - show_progress=False, - ) + #config.button_save_as_config.click( + # save_configuration, + # inputs=[dummy_db_true, config.config_file_name] + settings_list, + # outputs=[config.config_file_name], + # show_progress=False, + #) button_run.click( train_model, @@ -2002,7 +2014,13 @@ def update_LoRA_settings( ) with gr.Tab("Tools"): - lora_tools = LoRATools(folders=folders, headless=headless) + lora_tools = LoRATools( + train_data_dir=source_model.train_data_dir, + reg_data_dir=folders.reg_data_dir, + output_dir=folders.output_dir, + logging_dir=folders.logging_dir, + headless=headless + ) with gr.Tab("Guides"): gr.Markdown("This section provide Various LoRA guides and information...") @@ -2014,7 +2032,7 @@ def update_LoRA_settings( gr.Markdown(guides_top_level) return ( - folders.train_data_dir, + source_model.train_data_dir, folders.reg_data_dir, folders.output_dir, folders.logging_dir, diff --git a/kohya_gui/manual_caption_gui.py b/kohya_gui/manual_caption_gui.py 
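Every tab that shells out to a training or tooling script now points run_cmd at the sd-scripts submodule and prepends both the repo root and the submodule to PYTHONPATH before spawning the process (the lora_gui executor call above and the subprocess.run calls in the utility tabs follow the same pattern). A minimal sketch of that launch sequence, assuming scriptdir points at the repo root and PYTHON at the interpreter the GUI already resolves:

import os
import subprocess

scriptdir = os.path.dirname(os.path.abspath(__file__))  # stand-in for the GUI's scriptdir
PYTHON = "python"  # stand-in for the interpreter the GUI resolves per platform

def launch_sd_script(script_rel_path: str, extra_args: str = "") -> None:
    run_cmd = fr'{PYTHON} "{scriptdir}/sd-scripts/{script_rel_path}" {extra_args}'
    env = os.environ.copy()
    # Make both the GUI package and the sd-scripts submodule importable
    # from the child process.
    env["PYTHONPATH"] = (
        fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}"
        fr"{env.get('PYTHONPATH', '')}"
    )
    subprocess.run(run_cmd, shell=True, env=env)

# e.g. launch_sd_script("networks/extract_lora_from_dylora.py", "--unit 4")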
index c7d108b12..c031da91c 100644 --- a/kohya_gui/manual_caption_gui.py +++ b/kohya_gui/manual_caption_gui.py @@ -1,6 +1,6 @@ import gradio as gr from easygui import msgbox, boolbox -from .common_gui import get_folder_path +from .common_gui import get_folder_path, scriptdir, list_dirs from math import ceil import os import re @@ -263,7 +263,19 @@ def update_images( # Gradio UI -def gradio_manual_caption_gui_tab(headless=False): +def gradio_manual_caption_gui_tab(headless=False, default_images_dir=None): + from .common_gui import create_refresh_button + + default_images_dir = default_images_dir if default_images_dir is not None else os.path.join(scriptdir, "data") + current_images_dir = default_images_dir + + # Function to list directories + def list_images_dirs(path): + # Allows list_images_dirs to modify current_images_dir outside of this function + nonlocal current_images_dir + current_images_dir = path + return list(list_dirs(path)) + with gr.Tab('Manual Captioning'): gr.Markdown( 'This utility allows quick captioning and tagging of images.' @@ -271,21 +283,24 @@ def gradio_manual_caption_gui_tab(headless=False): page = gr.Number(-1, visible=False) max_page = gr.Number(1, visible=False) loaded_images_dir = gr.Text(visible=False) - with gr.Row(): - images_dir = gr.Textbox( - label='Image folder to caption', - placeholder='Directory containing the images to caption', + with gr.Group(), gr.Row(): + images_dir = gr.Dropdown( + label='Image folder to caption (containing the images to caption)', + choices=list_images_dirs(default_images_dir), + value="", interactive=True, + allow_custom_value=True, ) + create_refresh_button(images_dir, lambda: None, lambda: {"choices": list_images_dirs(current_images_dir)},"open_folder_small") folder_button = gr.Button( - '📂', elem_id='open_folder_small', visible=(not headless) + '📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless) ) folder_button.click( get_folder_path, outputs=images_dir, show_progress=False, ) - load_images_button = gr.Button('Load 💾', elem_id='open_folder') + load_images_button = gr.Button('Load', elem_id='open_folder') caption_ext = gr.Textbox( label='Caption file extension', placeholder='Extension for caption file. 
eg: .caption, .txt', @@ -296,14 +311,21 @@ def gradio_manual_caption_gui_tab(headless=False): label='Autosave', info='Options', value=True, interactive=True ) + images_dir.change( + fn=lambda path: gr.Dropdown().update(choices=list_images_dirs(path)), + inputs=images_dir, + outputs=images_dir, + show_progress=False, + ) + # Caption Section - with gr.Row(): + with gr.Group(), gr.Row(): quick_tags_text = gr.Textbox( label='Quick Tags', placeholder='Comma separated list of tags', interactive=True, ) - import_tags_button = gr.Button('Import 📄', elem_id='open_folder') + import_tags_button = gr.Button('Import', elem_id='open_folder') ignore_load_tags_word_count = gr.Slider( minimum=1, maximum=100, @@ -364,7 +386,7 @@ def render_pagination(): [], label='Tags', interactive=True ) save_button = gr.Button( - '💾', elem_id='open_folder_small', visible=False + '💾', elem_id='open_folder_small', elem_classes=['tool'], visible=False ) save_buttons.append(save_button) diff --git a/kohya_gui/merge_lora_gui.py b/kohya_gui/merge_lora_gui.py index c1140bcc8..c4a7f0a70 100644 --- a/kohya_gui/merge_lora_gui.py +++ b/kohya_gui/merge_lora_gui.py @@ -8,7 +8,7 @@ from easygui import msgbox # Local module imports -from .common_gui import get_saveasfilename_path, get_file_path, scriptdir +from .common_gui import get_saveasfilename_path, get_file_path, scriptdir, list_files, create_refresh_button from .custom_logging import setup_logging # Set up logging @@ -57,6 +57,37 @@ def load_inputs_from_json(self, file_path): return inputs def build_tab(self): + current_sd_model_dir = os.path.join(scriptdir, "outputs") + current_save_dir = os.path.join(scriptdir, "outputs") + current_a_model_dir = current_sd_model_dir + current_b_model_dir = current_sd_model_dir + current_c_model_dir = current_sd_model_dir + current_d_model_dir = current_sd_model_dir + + def list_sd_models(path): + current_sd_model_dir = path + return list(list_files(path, exts=[".ckpt", ".safetensors"], all=True)) + + def list_a_models(path): + current_a_model_dir = path + return list(list_files(path, exts=[".pt", ".safetensors"], all=True)) + + def list_b_models(path): + current_b_model_dir = path + return list(list_files(path, exts=[".pt", ".safetensors"], all=True)) + + def list_c_models(path): + current_c_model_dir = path + return list(list_files(path, exts=[".pt", ".safetensors"], all=True)) + + def list_d_models(path): + current_d_model_dir = path + return list(list_files(path, exts=[".pt", ".safetensors"], all=True)) + + def list_save_to(path): + current_save_dir = path + return list(list_files(path, exts=[".ckpt", ".safetensors"], all=True)) + with gr.Tab('Merge LoRA'): gr.Markdown( 'This utility can merge up to 4 LoRA together or alternatively merge up to 4 LoRA into a SD checkpoint.' @@ -67,16 +98,19 @@ def build_tab(self): ckpt_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False) ckpt_ext_name = gr.Textbox(value='SD model types', visible=False) - with gr.Row(): - sd_model = gr.Textbox( - label='SD Model', - placeholder='(Optional) Stable Diffusion model', + with gr.Group(), gr.Row(): + sd_model = gr.Dropdown( + label='SD Model (Optional. 
Stable Diffusion model path, if you want to merge it with LoRA files)', interactive=True, - info='Provide a SD file path IF you want to merge it with LoRA files', + choices=list_sd_models(current_sd_model_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(sd_model, lambda: None, lambda: {"choices": list_sd_models(current_sd_model_dir)}, "open_folder_small") sd_model_file = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not self.headless), ) sd_model_file.click( @@ -87,15 +121,26 @@ def build_tab(self): ) sdxl_model = gr.Checkbox(label='SDXL model', value=False) - with gr.Row(): - lora_a_model = gr.Textbox( - label='LoRA model "A"', - placeholder='Path to the LoRA A model', + sd_model.change( + fn=lambda path: gr.Dropdown().update(choices=list_sd_models(path)), + inputs=sd_model, + outputs=sd_model, + show_progress=False, + ) + + with gr.Group(), gr.Row(): + lora_a_model = gr.Dropdown( + label='LoRA model "A" (path to the LoRA A model)', interactive=True, + choices=list_a_models(current_a_model_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(lora_a_model, lambda: None, lambda: {"choices": list_a_models(current_a_model_dir)}, "open_folder_small") button_lora_a_model_file = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not self.headless), ) button_lora_a_model_file.click( @@ -105,14 +150,18 @@ def build_tab(self): show_progress=False, ) - lora_b_model = gr.Textbox( - label='LoRA model "B"', - placeholder='Path to the LoRA B model', + lora_b_model = gr.Dropdown( + label='LoRA model "B" (path to the LoRA B model)', interactive=True, + choices=list_b_models(current_b_model_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(lora_b_model, lambda: None, lambda: {"choices": list_b_models(current_b_model_dir)}, "open_folder_small") button_lora_b_model_file = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not self.headless), ) button_lora_b_model_file.click( @@ -122,6 +171,19 @@ def build_tab(self): show_progress=False, ) + lora_a_model.change( + fn=lambda path: gr.Dropdown().update(choices=list_a_models(path)), + inputs=lora_a_model, + outputs=lora_a_model, + show_progress=False, + ) + lora_b_model.change( + fn=lambda path: gr.Dropdown().update(choices=list_b_models(path)), + inputs=lora_b_model, + outputs=lora_b_model, + show_progress=False, + ) + with gr.Row(): ratio_a = gr.Slider( label='Model A merge ratio (eg: 0.5 mean 50%)', @@ -141,15 +203,19 @@ def build_tab(self): interactive=True, ) - with gr.Row(): - lora_c_model = gr.Textbox( - label='LoRA model "C"', - placeholder='Path to the LoRA C model', + with gr.Group(), gr.Row(): + lora_c_model = gr.Dropdown( + label='LoRA model "C" (path to the LoRA C model)', interactive=True, + choices=list_c_models(current_c_model_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(lora_c_model, lambda: None, lambda: {"choices": list_c_models(current_c_model_dir)}, "open_folder_small") button_lora_c_model_file = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not self.headless), ) button_lora_c_model_file.click( @@ -159,14 +225,18 @@ def build_tab(self): show_progress=False, ) - lora_d_model = gr.Textbox( - label='LoRA model "D"', - placeholder='Path to the LoRA D model', + lora_d_model = gr.Dropdown( + label='LoRA model "D" (path to the LoRA D model)', interactive=True, + 
choices=list_d_models(current_d_model_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(lora_d_model, lambda: None, lambda: {"choices": list_d_models(current_d_model_dir)}, "open_folder_small") button_lora_d_model_file = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not self.headless), ) button_lora_d_model_file.click( @@ -175,6 +245,18 @@ def build_tab(self): outputs=lora_d_model, show_progress=False, ) + lora_c_model.change( + fn=lambda path: gr.Dropdown().update(choices=list_c_models(path)), + inputs=lora_c_model, + outputs=lora_c_model, + show_progress=False, + ) + lora_d_model.change( + fn=lambda path: gr.Dropdown().update(choices=list_d_models(path)), + inputs=lora_d_model, + outputs=lora_d_model, + show_progress=False, + ) with gr.Row(): ratio_c = gr.Slider( @@ -195,15 +277,19 @@ def build_tab(self): interactive=True, ) - with gr.Row(): - save_to = gr.Textbox( - label='Save to', - placeholder='path for the file to save...', + with gr.Group(), gr.Row(): + save_to = gr.Dropdown( + label='Save to (path for the file to save...)', interactive=True, + choices=list_save_to(current_d_model_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(save_to, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small") button_save_to = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not self.headless), ) button_save_to.click( @@ -212,19 +298,26 @@ def build_tab(self): outputs=save_to, show_progress=False, ) - precision = gr.Dropdown( + precision = gr.Radio( label='Merge precision', choices=['fp16', 'bf16', 'float'], value='float', interactive=True, ) - save_precision = gr.Dropdown( + save_precision = gr.Radio( label='Save precision', choices=['fp16', 'bf16', 'float'], value='fp16', interactive=True, ) + save_to.change( + fn=lambda path: gr.Dropdown().update(choices=list_save_to(path)), + inputs=save_to, + outputs=save_to, + show_progress=False, + ) + merge_button = gr.Button('Merge model') merge_button.click( @@ -286,19 +379,19 @@ def merge_lora( return if not sdxl_model: - run_cmd = fr'{PYTHON} "{scriptdir}/networks/merge_lora.py"' + run_cmd = fr'{PYTHON} "{scriptdir}/sd-scripts/networks/merge_lora.py"' else: run_cmd = ( - fr'{PYTHON} "{scriptdir}/networks/sdxl_merge_lora.py"' + fr'{PYTHON} "{scriptdir}/sd-scripts/networks/sdxl_merge_lora.py"' ) if sd_model: - run_cmd += f' --sd_model "{sd_model}"' + run_cmd += fr' --sd_model "{sd_model}"' run_cmd += f' --save_precision {save_precision}' run_cmd += f' --precision {precision}' - run_cmd += f' --save_to "{save_to}"' + run_cmd += fr' --save_to "{save_to}"' # Create a space-separated string of non-empty models (from the second element onwards), enclosed in double quotes - models_cmd = ' '.join([f'"{model}"' for model in lora_models if model]) + models_cmd = ' '.join([fr'"{model}"' for model in lora_models if model]) # Create a space-separated string of non-zero ratios corresponding to non-empty LoRa models valid_ratios = [ @@ -313,7 +406,7 @@ def merge_lora( log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{env.get('PYTHONPATH', '')}" + env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" # Run the command subprocess.run(run_cmd, shell=True, env=env) diff --git a/kohya_gui/merge_lycoris_gui.py b/kohya_gui/merge_lycoris_gui.py index a55b93562..99713b975 100644 --- a/kohya_gui/merge_lycoris_gui.py +++ 
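For readers skimming these hunks, here is a minimal, self-contained sketch (not code from this repository) of the dropdown-with-refresh pattern the merge/resize tabs now use: a gr.Dropdown whose choices come from listing a directory, a refresh button, and a .change handler that re-lists whatever path the user picked or typed. The helper list_model_files, the 🔄 button, the start_dir default and the gr.update(...) idiom are illustrative stand-ins for the repository's list_files(), create_refresh_button() and gr.Dropdown().update(...) calls shown above; behaviour may differ slightly across Gradio versions.

# Illustrative sketch only -- approximates the pattern introduced in the hunks
# above using plain gradio + the standard library instead of the repo helpers.
import glob
import os

import gradio as gr


def list_model_files(path, exts=(".safetensors", ".ckpt")):
    """Return files under `path` whose extension is in `exts` (assumed helper)."""
    if not os.path.isdir(path):
        return []
    return sorted(
        f for f in glob.glob(os.path.join(path, "*"))
        if os.path.splitext(f)[1] in exts
    )


with gr.Blocks() as demo:
    start_dir = os.path.join(os.getcwd(), "outputs")

    with gr.Row():
        model = gr.Dropdown(
            label="SD Model",
            choices=list_model_files(start_dir),
            value="",
            interactive=True,
            allow_custom_value=True,  # free-typed paths remain valid values
        )
        refresh = gr.Button("🔄")  # stands in for create_refresh_button()

    def relist(path):
        # Accept either a typed directory or a picked file path.
        folder = path if os.path.isdir(path) else (os.path.dirname(path) or start_dir)
        return gr.update(choices=list_model_files(folder))

    # Re-list when the value changes; the refresh button re-lists the default dir.
    model.change(fn=relist, inputs=model, outputs=model, show_progress=False)
    refresh.click(fn=lambda: gr.update(choices=list_model_files(start_dir)), outputs=model)

if __name__ == "__main__":
    demo.launch()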
b/kohya_gui/merge_lycoris_gui.py @@ -7,6 +7,8 @@ get_saveasfilename_path, get_file_path, scriptdir, + list_files, + create_refresh_button, ) from .custom_logging import setup_logging @@ -35,9 +37,9 @@ def merge_lycoris( log.info('Merge model...') run_cmd = fr'{PYTHON} "{scriptdir}/tools/merge_lycoris.py"' - run_cmd += f' "{base_model}"' - run_cmd += f' "{lycoris_model}"' - run_cmd += f' "{output_name}"' + run_cmd += fr' "{base_model}"' + run_cmd += fr' "{lycoris_model}"' + run_cmd += fr' "{output_name}"' run_cmd += f' --weight {weight}' run_cmd += f' --device {device}' run_cmd += f' --dtype {dtype}' @@ -49,7 +51,7 @@ def merge_lycoris( log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{env.get('PYTHONPATH', '')}" + env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" # Run the command subprocess.run(run_cmd, shell=True, env=env) @@ -63,6 +65,22 @@ def merge_lycoris( def gradio_merge_lycoris_tab(headless=False): + current_model_dir = os.path.join(scriptdir, "outputs") + current_lycoris_dir = current_model_dir + current_save_dir = current_model_dir + + def list_models(path): + current_model_dir = path + return list(list_files(path, exts=[".ckpt", ".safetensors"], all=True)) + + def list_lycoris_model(path): + current_lycoris_dir = path + return list(list_files(path, exts=[".pt", ".safetensors"], all=True)) + + def list_save_to(path): + current_save_dir = path + return list(list_files(path, exts=[".ckpt", ".safetensors"], all=True)) + with gr.Tab('Merge LyCORIS'): gr.Markdown( 'This utility can merge a LyCORIS model into a SD checkpoint.' @@ -73,16 +91,20 @@ def gradio_merge_lycoris_tab(headless=False): ckpt_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False) ckpt_ext_name = gr.Textbox(value='SD model types', visible=False) - with gr.Row(): - base_model = gr.Textbox( - label='SD Model', - placeholder='(Optional) Stable Diffusion base model', + with gr.Group(), gr.Row(): + base_model = gr.Dropdown( + label='SD Model (Optional Stable Diffusion base model)', interactive=True, info='Provide a SD file path that you want to merge with the LyCORIS file', + choices=list_models(current_save_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(base_model, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small") base_model_file = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not headless), ) base_model_file.click( @@ -92,15 +114,14 @@ def gradio_merge_lycoris_tab(headless=False): show_progress=False, ) - with gr.Row(): - lycoris_model = gr.Textbox( - label='LyCORIS model', - placeholder='Path to the LyCORIS model', + lycoris_model = gr.Dropdown( + label='LyCORIS model (path to the LyCORIS model)', interactive=True, ) button_lycoris_model_file = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not headless), ) button_lycoris_model_file.click( @@ -110,6 +131,19 @@ def gradio_merge_lycoris_tab(headless=False): show_progress=False, ) + base_model.change( + fn=lambda path: gr.Dropdown().update(choices=list_models(path)), + inputs=base_model, + outputs=base_model, + show_progress=False, + ) + lycoris_model.change( + fn=lambda path: gr.Dropdown().update(choices=list_lycoris_models(path)), + inputs=lycoris_model, + outputs=lycoris_model, + show_progress=False, + ) + with gr.Row(): weight = gr.Slider( label='Model A merge ratio (eg: 0.5 mean 50%)', @@ -120,15 +154,19 @@ def 
gradio_merge_lycoris_tab(headless=False): interactive=True, ) - with gr.Row(): - output_name = gr.Textbox( - label='Save to', - placeholder='path for the checkpoint file to save...', + with gr.Group(), gr.Row(): + output_name = gr.Dropdown( + label='Save to (path for the checkpoint file to save...)', interactive=True, + choices=list_save_to(current_save_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(output_name, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small") button_output_name = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not headless), ) button_output_name.click( @@ -137,7 +175,7 @@ def gradio_merge_lycoris_tab(headless=False): outputs=output_name, show_progress=False, ) - dtype = gr.Dropdown( + dtype = gr.Radio( label='Save dtype', choices=[ 'float', @@ -151,7 +189,7 @@ def gradio_merge_lycoris_tab(headless=False): interactive=True, ) - device = gr.Dropdown( + device = gr.Radio( label='Device', choices=[ 'cpu', @@ -161,6 +199,14 @@ def gradio_merge_lycoris_tab(headless=False): interactive=True, ) + output_name.change( + fn=lambda path: gr.Dropdown().update(choices=list_save_to(path)), + inputs=output_name, + outputs=output_name, + show_progress=False, + ) + + with gr.Row(): is_sdxl = gr.Checkbox(label='is sdxl', value=False, interactive=True) is_v2 = gr.Checkbox(label='is v2', value=False, interactive=True) diff --git a/kohya_gui/resize_lora_gui.py b/kohya_gui/resize_lora_gui.py index ff64be6c3..d729531c3 100644 --- a/kohya_gui/resize_lora_gui.py +++ b/kohya_gui/resize_lora_gui.py @@ -3,7 +3,7 @@ import subprocess import os import sys -from .common_gui import get_saveasfilename_path, get_file_path, scriptdir +from .common_gui import get_saveasfilename_path, get_file_path, scriptdir, list_files, create_refresh_button from .custom_logging import setup_logging @@ -59,10 +59,10 @@ def resize_lora( if device == '': device = 'cuda' - run_cmd = fr'{PYTHON} "{scriptdir}/networks/resize_lora.py"' + run_cmd = fr'{PYTHON} "{scriptdir}/sd-scripts/networks/resize_lora.py"' run_cmd += f' --save_precision {save_precision}' - run_cmd += f' --save_to "{save_to}"' - run_cmd += f' --model "{model}"' + run_cmd += fr' --save_to "{save_to}"' + run_cmd += fr' --model "{model}"' run_cmd += f' --new_rank {new_rank}' run_cmd += f' --device {device}' if not dynamic_method == 'None': @@ -74,7 +74,7 @@ def resize_lora( log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{env.get('PYTHONPATH', '')}" + env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" # Run the command subprocess.run(run_cmd, shell=True, env=env) @@ -88,21 +88,36 @@ def resize_lora( def gradio_resize_lora_tab(headless=False): + current_model_dir = os.path.join(scriptdir, "outputs") + current_save_dir = os.path.join(scriptdir, "outputs") + + def list_models(path): + current_model_dir = path + return list(list_files(path, exts=[".ckpt", ".safetensors"], all=True)) + + def list_save_to(path): + current_save_dir = path + return list(list_files(path, exts=[".pt", ".safetensors"], all=True)) + with gr.Tab('Resize LoRA'): gr.Markdown('This utility can resize a LoRA.') lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False) lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) - with gr.Row(): - model = gr.Textbox( - label='Source LoRA', - placeholder='Path to the LoRA to resize', + with gr.Group(), gr.Row(): + model = gr.Dropdown( + 
label='Source LoRA (path to the LoRA to resize)', interactive=True, + choices=list_models(current_model_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(model, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small") button_lora_a_model_file = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not headless), ) button_lora_a_model_file.click( @@ -111,15 +126,18 @@ def gradio_resize_lora_tab(headless=False): outputs=model, show_progress=False, ) - with gr.Row(): - save_to = gr.Textbox( - label='Save to', - placeholder='path for the LoRA file to save...', + save_to = gr.Dropdown( + label='Save to (path for the LoRA file to save...)', interactive=True, + choices=list_save_to(current_save_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(save_to, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small") button_save_to = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not headless), ) button_save_to.click( @@ -128,6 +146,18 @@ def gradio_resize_lora_tab(headless=False): outputs=save_to, show_progress=False, ) + model.change( + fn=lambda path: gr.Dropdown().update(choices=list_models(path)), + inputs=model, + outputs=model, + show_progress=False, + ) + save_to.change( + fn=lambda path: gr.Dropdown().update(choices=list_save_to(path)), + inputs=save_to, + outputs=save_to, + show_progress=False, + ) with gr.Row(): new_rank = gr.Slider( label='Desired LoRA rank', @@ -137,7 +167,7 @@ def gradio_resize_lora_tab(headless=False): value=4, interactive=True, ) - dynamic_method = gr.Dropdown( + dynamic_method = gr.Radio( choices=['None', 'sv_ratio', 'sv_fro', 'sv_cumulative'], value='sv_fro', label='Dynamic method', @@ -152,13 +182,13 @@ def gradio_resize_lora_tab(headless=False): with gr.Row(): verbose = gr.Checkbox(label='Verbose', value=True) - save_precision = gr.Dropdown( + save_precision = gr.Radio( label='Save precision', choices=['fp16', 'bf16', 'float'], value='fp16', interactive=True, ) - device = gr.Dropdown( + device = gr.Radio( label='Device', choices=[ 'cpu', diff --git a/kohya_gui/svd_merge_lora_gui.py b/kohya_gui/svd_merge_lora_gui.py index cd5aa9878..311869781 100644 --- a/kohya_gui/svd_merge_lora_gui.py +++ b/kohya_gui/svd_merge_lora_gui.py @@ -8,6 +8,8 @@ get_any_file_path, get_file_path, scriptdir, + list_files, + create_refresh_button, ) from .custom_logging import setup_logging @@ -51,10 +53,10 @@ def svd_merge_lora( ratio_c /= total_ratio ratio_d /= total_ratio - run_cmd = fr'{PYTHON} "{scriptdir}/networks/svd_merge_lora.py"' + run_cmd = fr'{PYTHON} "{scriptdir}/sd-scripts/networks/svd_merge_lora.py"' run_cmd += f' --save_precision {save_precision}' run_cmd += f' --precision {precision}' - run_cmd += f' --save_to "{save_to}"' + run_cmd += fr' --save_to "{save_to}"' run_cmd_models = ' --models' run_cmd_ratios = ' --ratios' @@ -63,25 +65,25 @@ def svd_merge_lora( if not os.path.isfile(lora_a_model): msgbox('The provided model A is not a file') return - run_cmd_models += f' "{lora_a_model}"' + run_cmd_models += fr' "{lora_a_model}"' run_cmd_ratios += f' {ratio_a}' if lora_b_model: if not os.path.isfile(lora_b_model): msgbox('The provided model B is not a file') return - run_cmd_models += f' "{lora_b_model}"' + run_cmd_models += fr' "{lora_b_model}"' run_cmd_ratios += f' {ratio_b}' if lora_c_model: if not os.path.isfile(lora_c_model): msgbox('The provided model C is not a file') return - run_cmd_models += f' 
"{lora_c_model}"' + run_cmd_models += fr' "{lora_c_model}"' run_cmd_ratios += f' {ratio_c}' if lora_d_model: if not os.path.isfile(lora_d_model): msgbox('The provided model D is not a file') return - run_cmd_models += f' "{lora_d_model}"' + run_cmd_models += fr' "{lora_d_model}"' run_cmd_ratios += f' {ratio_d}' run_cmd += run_cmd_models @@ -93,7 +95,7 @@ def svd_merge_lora( log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{env.get('PYTHONPATH', '')}" + env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" # Run the command subprocess.run(run_cmd, shell=True, env=env) @@ -105,6 +107,32 @@ def svd_merge_lora( def gradio_svd_merge_lora_tab(headless=False): + current_save_dir = os.path.join(scriptdir, "outputs") + current_a_model_dir = current_save_dir + current_b_model_dir = current_save_dir + current_c_model_dir = current_save_dir + current_d_model_dir = current_save_dir + + def list_a_models(path): + current_a_model_dir = path + return list(list_files(path, exts=[".pt", ".safetensors"], all=True)) + + def list_b_models(path): + current_b_model_dir = path + return list(list_files(path, exts=[".pt", ".safetensors"], all=True)) + + def list_c_models(path): + current_c_model_dir = path + return list(list_files(path, exts=[".pt", ".safetensors"], all=True)) + + def list_d_models(path): + current_d_model_dir = path + return list(list_files(path, exts=[".pt", ".safetensors"], all=True)) + + def list_save_to(path): + current_save_dir = path + return list(list_files(path, exts=[".pt", ".safetensors"], all=True)) + with gr.Tab('Merge LoRA (SVD)'): gr.Markdown( 'This utility can merge two LoRA networks together into a new LoRA.' @@ -113,15 +141,19 @@ def gradio_svd_merge_lora_tab(headless=False): lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False) lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) - with gr.Row(): - lora_a_model = gr.Textbox( - label='LoRA model "A"', - placeholder='Path to the LoRA A model', + with gr.Group(), gr.Row(): + lora_a_model = gr.Dropdown( + label='LoRA model "A" (path to the LoRA A model)', interactive=True, + choices=list_a_models(current_a_model_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(lora_a_model, lambda: None, lambda: {"choices": list_a_models(current_a_model_dir)}, "open_folder_small") button_lora_a_model_file = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not headless), ) button_lora_a_model_file.click( @@ -131,14 +163,18 @@ def gradio_svd_merge_lora_tab(headless=False): show_progress=False, ) - lora_b_model = gr.Textbox( - label='LoRA model "B"', - placeholder='Path to the LoRA B model', + lora_b_model = gr.Dropdown( + label='LoRA model "B" (path to the LoRA B model)', interactive=True, + choices=list_b_models(current_b_model_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(lora_b_model, lambda: None, lambda: {"choices": list_b_models(current_b_model_dir)}, "open_folder_small") button_lora_b_model_file = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not headless), ) button_lora_b_model_file.click( @@ -147,6 +183,18 @@ def gradio_svd_merge_lora_tab(headless=False): outputs=lora_b_model, show_progress=False, ) + lora_a_model.change( + fn=lambda path: gr.Dropdown().update(choices=list_a_models(path)), + inputs=lora_a_model, + outputs=lora_a_model, + show_progress=False, + ) + lora_b_model.change( + fn=lambda 
path: gr.Dropdown().update(choices=list_b_models(path)), + inputs=lora_b_model, + outputs=lora_b_model, + show_progress=False, + ) with gr.Row(): ratio_a = gr.Slider( label='Merge ratio model A', @@ -164,15 +212,19 @@ def gradio_svd_merge_lora_tab(headless=False): value=0.25, interactive=True, ) - with gr.Row(): - lora_c_model = gr.Textbox( - label='LoRA model "C"', - placeholder='Path to the LoRA C model', + with gr.Group(), gr.Row(): + lora_c_model = gr.Dropdown( + label='LoRA model "C" (path to the LoRA C model)', interactive=True, + choices=list_c_models(current_c_model_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(lora_c_model, lambda: None, lambda: {"choices": list_c_models(current_c_model_dir)}, "open_folder_small") button_lora_c_model_file = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not headless), ) button_lora_c_model_file.click( @@ -182,14 +234,18 @@ def gradio_svd_merge_lora_tab(headless=False): show_progress=False, ) - lora_d_model = gr.Textbox( - label='LoRA model "D"', - placeholder='Path to the LoRA D model', + lora_d_model = gr.Dropdown( + label='LoRA model "D" (path to the LoRA D model)', interactive=True, + choices=list_d_models(current_d_model_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(lora_d_model, lambda: None, lambda: {"choices": list_d_models(current_d_model_dir)}, "open_folder_small") button_lora_d_model_file = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not headless), ) button_lora_d_model_file.click( @@ -198,6 +254,19 @@ def gradio_svd_merge_lora_tab(headless=False): outputs=lora_d_model, show_progress=False, ) + + lora_c_model.change( + fn=lambda path: gr.Dropdown().update(choices=list_c_models(path)), + inputs=lora_c_model, + outputs=lora_c_model, + show_progress=False, + ) + lora_d_model.change( + fn=lambda path: gr.Dropdown().update(choices=list_d_models(path)), + inputs=lora_d_model, + outputs=lora_d_model, + show_progress=False, + ) with gr.Row(): ratio_c = gr.Slider( label='Merge ratio model C', @@ -233,15 +302,19 @@ def gradio_svd_merge_lora_tab(headless=False): interactive=True, ) - with gr.Row(): - save_to = gr.Textbox( - label='Save to', - placeholder='path for the new LoRA file to save...', + with gr.Group(), gr.Row(): + save_to = gr.Dropdown( + label='Save to (path for the new LoRA file to save...)', interactive=True, + choices=list_save_to(current_d_model_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(save_to, lambda: None, lambda: {"choices": list_save_to(current_save_dir)}, "open_folder_small") button_save_to = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not headless), ) button_save_to.click( @@ -250,20 +323,26 @@ def gradio_svd_merge_lora_tab(headless=False): outputs=save_to, show_progress=False, ) - with gr.Row(): - precision = gr.Dropdown( + save_to.change( + fn=lambda path: gr.Dropdown().update(choices=list_save_to(path)), + inputs=save_to, + outputs=save_to, + show_progress=False, + ) + with gr.Group(), gr.Row(): + precision = gr.Radio( label='Merge precision', choices=['fp16', 'bf16', 'float'], value='float', interactive=True, ) - save_precision = gr.Dropdown( + save_precision = gr.Radio( label='Save precision', choices=['fp16', 'bf16', 'float'], value='float', interactive=True, ) - device = gr.Dropdown( + device = gr.Radio( label='Device', choices=[ 'cpu', diff --git a/kohya_gui/tensorboard_gui.py b/kohya_gui/tensorboard_gui.py 
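Two changes recur in every tool shown above: script paths now point into the sd-scripts submodule, and PYTHONPATH is extended with both the repo root and the submodule before the command is spawned. A minimal sketch of that environment handling follows; deriving scriptdir from __file__ and invoking merge_lora.py with --help are placeholders for illustration, not the GUI's real argument handling.

# Illustrative sketch only -- mirrors the PYTHONPATH handling used in the hunks above.
import os
import subprocess
import sys

# Stand-in for the GUI's `scriptdir`; in the repo it is the project root.
scriptdir = os.path.dirname(os.path.abspath(__file__))

# Prepend both the repo root and the sd-scripts submodule so the child process
# can import modules from either location.
env = os.environ.copy()
env["PYTHONPATH"] = (
    f"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)

# Placeholder command; the GUI builds comparable strings for merge/resize/verify.
run_cmd = f'{sys.executable} "{scriptdir}/sd-scripts/networks/merge_lora.py" --help'
subprocess.run(run_cmd, shell=True, env=env)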
index cf8dcf110..7f011408b 100644 --- a/kohya_gui/tensorboard_gui.py +++ b/kohya_gui/tensorboard_gui.py @@ -22,11 +22,12 @@ def start_tensorboard(headless, logging_dir, wait_time=5): # Read the TENSORBOARD_PORT from the environment, or use the default tensorboard_port = os.environ.get('TENSORBOARD_PORT', DEFAULT_TENSORBOARD_PORT) - - if not os.listdir(logging_dir): - log.info('Error: log folder is empty') - msgbox(msg='Error: log folder is empty') - return + + # Check if logging directory exists and is not empty; if not, warn the user and exit + if not os.path.exists(logging_dir) or not os.listdir(logging_dir): + log.error('Error: logging folder does not exist or does not contain logs.') + msgbox(msg='Error: logging folder does not exist or does not contain logs.') + return # Exit the function with an error code run_cmd = [ TENSORBOARD, @@ -74,7 +75,7 @@ def stop_tensorboard(): except Exception as e: log.error('Failed to stop Tensorboard:', e) else: - log.info('Tensorboard is not running...') + log.warning('Tensorboard is not running...') def gradio_tensorboard(): diff --git a/kohya_gui/textual_inversion_gui.py b/kohya_gui/textual_inversion_gui.py index b8ee8846f..734fb5692 100644 --- a/kohya_gui/textual_inversion_gui.py +++ b/kohya_gui/textual_inversion_gui.py @@ -23,6 +23,8 @@ SaveConfigFile, save_to_file, scriptdir, + list_files, + create_refresh_button, ) from .class_configuration_file import ConfigurationFile from .class_source_model import SourceModel @@ -556,9 +558,9 @@ def train_model( ) if sdxl: - run_cmd += fr' "{scriptdir}/sdxl_train_textual_inversion.py"' + run_cmd += fr' "{scriptdir}/sd-scripts/sdxl_train_textual_inversion.py"' else: - run_cmd += fr' "{scriptdir}/train_textual_inversion.py"' + run_cmd += fr' "{scriptdir}/sd-scripts/train_textual_inversion.py"' run_cmd += run_cmd_advanced_training( adaptive_noise_scale=adaptive_noise_scale, @@ -665,7 +667,8 @@ def train_model( # Saving config file for model current_datetime = datetime.now() formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") - file_path = os.path.join(output_dir, f"{output_name}_{formatted_datetime}.json") + config_dir = os.path.dirname(os.path.dirname(train_data_dir)) + file_path = os.path.join(config_dir, f"{output_name}_{formatted_datetime}.json") log.info(f"Saving training config to {file_path}...") @@ -677,12 +680,15 @@ def train_model( log.info(run_cmd) + env = os.environ.copy() + env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + # Run the command - executor.execute_command(run_cmd=run_cmd) + executor.execute_command(run_cmd=run_cmd, env=env) # check if output_dir/last is a folder... 
therefore it is a diffuser model - last_dir = pathlib.Path(f"{output_dir}/{output_name}") + last_dir = pathlib.Path(fr"{output_dir}/{output_name}") if not last_dir.is_dir(): # Copy inference model for v2 if required @@ -691,37 +697,48 @@ def train_model( def ti_tab( headless=False, + default_output_dir=None, ): dummy_db_true = gr.Label(value=True, visible=False) dummy_db_false = gr.Label(value=False, visible=False) dummy_headless = gr.Label(value=headless, visible=False) - with gr.Tab("Training"): + current_embedding_dir = default_output_dir if default_output_dir is not None and default_output_dir != "" else os.path.join(scriptdir, "outputs") + + with gr.Tab("Training"), gr.Column(variant="compact"): gr.Markdown("Train a TI using kohya textual inversion python code...") - # Setup Configuration Files Gradio - config = ConfigurationFile(headless) - - source_model = SourceModel( - save_model_as_choices=[ - "ckpt", - "safetensors", - ], - headless=headless, - ) + with gr.Column(): + source_model = SourceModel( + save_model_as_choices=[ + "ckpt", + "safetensors", + ], + headless=headless, + ) - with gr.Tab("Folders"): + with gr.Accordion("Folders", open=False), gr.Group(): folders = Folders(headless=headless) - with gr.Tab("Parameters"): - with gr.Tab("Basic", elem_id="basic_tab"): + with gr.Accordion("Parameters", open=False), gr.Column(): + with gr.Group(elem_id="basic_tab"): with gr.Row(): - weights = gr.Textbox( - label='Resume TI training', - placeholder='(Optional) Path to existing TI embedding file to keep training', + + def list_embedding_files(path): + current_embedding_dir = path + return list(list_files(path, exts=[".pt", ".ckpt", ".safetensors" ], all=True)) + + weights = gr.Dropdown( + label='Resume TI training (Optional. Path to existing TI embedding file to keep training)', + choices=list_embedding_files(current_embedding_dir), + value="", + interactive=True, + allow_custom_value=True, ) + create_refresh_button(weights, lambda: None, lambda: {"choices": list_embedding_files(current_embedding_dir)}, "open_folder_small") weights_file_input = gr.Button( "📂", elem_id="open_folder_small", + elem_classes=['tool'], visible=(not headless), ) weights_file_input.click( @@ -729,6 +746,13 @@ def ti_tab( outputs=weights, show_progress=False, ) + weights.change( + fn=lambda path: gr.Dropdown().update(choices=list_embedding_files(path)), + inputs=weights, + outputs=weights, + show_progress=False, + ) + with gr.Row(): token_string = gr.Textbox( label="Token string", @@ -771,7 +795,7 @@ def ti_tab( show_sdxl_cache_text_encoder_outputs=False, ) - with gr.Tab("Advanced", elem_id="advanced_tab"): + with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"): advanced_training = AdvancedTraining(headless=headless) advanced_training.color_aug.change( color_aug_changed, @@ -779,15 +803,15 @@ def ti_tab( outputs=[basic_training.cache_latents], ) - with gr.Tab("Samples", elem_id="samples_tab"): + with gr.Accordion("Samples", open=False, elem_id="samples_tab"): sample = SampleImages() - with gr.Tab("Dataset Preparation"): + with gr.Accordion("Dataset Preparation", open=False): gr.Markdown( "This section provide Dreambooth tools to help setup your dataset..." 
) gradio_dreambooth_folder_creation_tab( - train_data_dir_input=folders.train_data_dir, + train_data_dir_input=source_model.train_data_dir, reg_data_dir_input=folders.reg_data_dir, output_dir_input=folders.output_dir, logging_dir_input=folders.logging_dir, @@ -795,18 +819,26 @@ def ti_tab( ) gradio_dataset_balancing_tab(headless=headless) - with gr.Row(): - button_run = gr.Button("Start training", variant="primary") - button_stop_training = gr.Button("Stop training") + # Setup Configuration Files Gradio + with gr.Accordion("Configuration", open=False): + config = ConfigurationFile(headless=headless, output_dir=folders.output_dir) + + + with gr.Column(), gr.Group(): + with gr.Row(): + button_run = gr.Button("Start training", variant="primary") + + button_stop_training = gr.Button("Stop training") - button_print = gr.Button("Print training command") + button_print = gr.Button("Print training command") # Setup gradio tensorboard buttons - ( - button_start_tensorboard, - button_stop_tensorboard, - ) = gradio_tensorboard() + with gr.Column(), gr.Group(): + ( + button_start_tensorboard, + button_stop_tensorboard, + ) = gradio_tensorboard() button_start_tensorboard.click( start_tensorboard, @@ -825,7 +857,7 @@ def ti_tab( source_model.v_parameterization, source_model.sdxl_checkbox, folders.logging_dir, - folders.train_data_dir, + source_model.train_data_dir, folders.reg_data_dir, folders.output_dir, basic_training.max_resolution, @@ -836,7 +868,7 @@ def ti_tab( basic_training.epoch, basic_training.save_every_n_epochs, basic_training.mixed_precision, - basic_training.save_precision, + source_model.save_precision, basic_training.seed, basic_training.num_cpu_threads_per_process, basic_training.cache_latents, @@ -863,7 +895,7 @@ def ti_tab( advanced_training.multi_gpu, advanced_training.gpu_ids, advanced_training.vae, - folders.output_name, + source_model.output_name, advanced_training.max_token_length, basic_training.max_train_epochs, advanced_training.max_data_loader_n_workers, @@ -933,12 +965,12 @@ def ti_tab( show_progress=False, ) - config.button_save_as_config.click( - save_configuration, - inputs=[dummy_db_true, config.config_file_name] + settings_list, - outputs=[config.config_file_name], - show_progress=False, - ) + #config.button_save_as_config.click( + # save_configuration, + # inputs=[dummy_db_true, config.config_file_name] + settings_list, + # outputs=[config.config_file_name], + # show_progress=False, + #) button_run.click( train_model, @@ -955,7 +987,7 @@ def ti_tab( ) return ( - folders.train_data_dir, + source_model.train_data_dir, folders.reg_data_dir, folders.output_dir, folders.logging_dir, diff --git a/kohya_gui/utilities.py b/kohya_gui/utilities.py index 887de8be9..3125e45b9 100644 --- a/kohya_gui/utilities.py +++ b/kohya_gui/utilities.py @@ -16,10 +16,10 @@ def utilities_tab( - train_data_dir_input=gr.Textbox(), - reg_data_dir_input=gr.Textbox(), - output_dir_input=gr.Textbox(), - logging_dir_input=gr.Textbox(), + train_data_dir_input=gr.Dropdown(), + reg_data_dir_input=gr.Dropdown(), + output_dir_input=gr.Dropdown(), + logging_dir_input=gr.Dropdown(), enable_copy_info_button=bool(False), enable_dreambooth_tab=True, headless=False diff --git a/kohya_gui/verify_lora_gui.py b/kohya_gui/verify_lora_gui.py index bde13b491..af9b2ba02 100644 --- a/kohya_gui/verify_lora_gui.py +++ b/kohya_gui/verify_lora_gui.py @@ -8,6 +8,8 @@ get_any_file_path, get_file_path, scriptdir, + list_files, + create_refresh_button, ) from .custom_logging import setup_logging @@ -36,16 +38,12 @@ def 
verify_lora( msgbox('The provided model A is not a file') return - run_cmd = [ - PYTHON, - fr'"{scriptdir}/networks/check_lora_weights.py"', - f'{lora_model}', - ] - - log.info(' '.join(run_cmd)) + run_cmd = fr'{PYTHON} "{scriptdir}/sd-scripts/networks/check_lora_weights.py" "{lora_model}"' + + log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{env.get('PYTHONPATH', '')}" + env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" # Run the command process = subprocess.Popen( @@ -62,6 +60,12 @@ def verify_lora( def gradio_verify_lora_tab(headless=False): + current_model_dir = os.path.join(scriptdir, "outputs") + + def list_models(path): + current_model_dir = path + return list(list_files(path, exts=[".pt", ".safetensors"], all=True)) + with gr.Tab('Verify LoRA'): gr.Markdown( 'This utility can verify a LoRA network to make sure it is properly trained.' @@ -70,15 +74,19 @@ def gradio_verify_lora_tab(headless=False): lora_ext = gr.Textbox(value='*.pt *.safetensors', visible=False) lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) - with gr.Row(): - lora_model = gr.Textbox( - label='LoRA model', - placeholder='Path to the LoRA model to verify', + with gr.Group(), gr.Row(): + lora_model = gr.Dropdown( + label='LoRA model (path to the LoRA model to verify)', interactive=True, + choices=list_models(current_model_dir), + value="", + allow_custom_value=True, ) + create_refresh_button(lora_model, lambda: None, lambda: {"choices": list_models(current_model_dir)}, "open_folder_small") button_lora_model_file = gr.Button( folder_symbol, elem_id='open_folder_small', + elem_classes=['tool'], visible=(not headless), ) button_lora_model_file.click( @@ -89,6 +97,13 @@ def gradio_verify_lora_tab(headless=False): ) verify_button = gr.Button('Verify', variant='primary') + lora_model.change( + fn=lambda path: gr.Dropdown().update(choices=list_models(path)), + inputs=lora_model, + outputs=lora_model, + show_progress=False, + ) + lora_model_verif_output = gr.Textbox( label='Output', placeholder='Verification output', diff --git a/kohya_gui/wd14_caption_gui.py b/kohya_gui/wd14_caption_gui.py index 53cc020b4..2c2c34ece 100644 --- a/kohya_gui/wd14_caption_gui.py +++ b/kohya_gui/wd14_caption_gui.py @@ -1,7 +1,7 @@ import gradio as gr from easygui import msgbox import subprocess -from .common_gui import get_folder_path, add_pre_postfix, scriptdir +from .common_gui import get_folder_path, add_pre_postfix, scriptdir, list_dirs import os from .custom_logging import setup_logging @@ -40,7 +40,7 @@ def caption_images( return log.info(f'Captioning files in {train_data_dir}...') - run_cmd = fr'accelerate launch "{scriptdir}/finetune/tag_images_by_wd14_tagger.py"' + run_cmd = fr'accelerate launch "{scriptdir}/sd-scripts/finetune/tag_images_by_wd14_tagger.py"' run_cmd += f' --batch_size={int(batch_size)}' run_cmd += f' --general_threshold={general_threshold}' run_cmd += f' --character_threshold={character_threshold}' @@ -68,12 +68,12 @@ def caption_images( if not undesired_tags == '': run_cmd += f' --undesired_tags="{undesired_tags}"' - run_cmd += f' "{train_data_dir}"' + run_cmd += fr' "{train_data_dir}"' log.info(run_cmd) env = os.environ.copy() - env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{env.get('PYTHONPATH', '')}" + env['PYTHONPATH'] = fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" # Run the command subprocess.run(run_cmd, shell=True, env=env) @@ -94,7 +94,16 @@ def 
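The tabs above also introduce small per-tab helpers such as list_models(path) and list_train_dirs(path) that remember the last directory they listed. Below is a minimal standalone sketch of that idea; make_tab, the hidden-directory filter and the print call are purely illustrative, and the nonlocal declaration is included because in Python a nested function must declare the enclosing variable nonlocal for a current_* rebinding to persist between calls.

# Illustrative sketch only -- the directory-tracking helper pattern, standalone.
import os


def make_tab(scriptdir):
    current_train_dir = os.path.join(scriptdir, "data")

    def list_train_dirs(path):
        # `nonlocal` lets the rebinding survive after this call; a bare
        # assignment would only create a local variable inside the helper.
        nonlocal current_train_dir
        current_train_dir = path
        if not os.path.isdir(path):
            return []
        # Skip hidden directories (illustrative filter).
        return sorted(
            d for d in os.listdir(path)
            if os.path.isdir(os.path.join(path, d)) and not d.startswith(".")
        )

    # A refresh button or .change handler would call the helper again with
    # `current_train_dir`, so the dropdown always reflects the last browsed path.
    return list_train_dirs(current_train_dir)


if __name__ == "__main__":
    print(make_tab(os.getcwd()))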
caption_images( ### -def gradio_wd14_caption_gui_tab(headless=False): +def gradio_wd14_caption_gui_tab(headless=False, default_train_dir=None): + from .common_gui import create_refresh_button + + default_train_dir = default_train_dir if default_train_dir is not None else os.path.join(scriptdir, "data") + current_train_dir = default_train_dir + + def list_train_dirs(path): + current_train_dir = path + return list(list_dirs(path)) + with gr.Tab('WD14 Captioning'): gr.Markdown( 'This utility will use WD14 to caption files for each images in a folder.' @@ -102,14 +111,17 @@ def gradio_wd14_caption_gui_tab(headless=False): # Input Settings # with gr.Section('Input Settings'): - with gr.Row(): - train_data_dir = gr.Textbox( - label='Image folder to caption', - placeholder='Directory containing the images to caption', + with gr.Group(), gr.Row(): + train_data_dir = gr.Dropdown( + label='Image folder to caption (containing the images to caption)', + choices=list_train_dirs(default_train_dir), + value="", interactive=True, + allow_custom_value=True, ) + create_refresh_button(train_data_dir, lambda: None, lambda: {"choices": list_train_dir(current_train_dir)},"open_folder_small") button_train_data_dir_input = gr.Button( - '📂', elem_id='open_folder_small', visible=(not headless) + '📂', elem_id='open_folder_small', elem_classes=['tool'], visible=(not headless) ) button_train_data_dir_input.click( get_folder_path, @@ -259,3 +271,10 @@ def gradio_wd14_caption_gui_tab(headless=False): ], show_progress=False, ) + + train_data_dir.change( + fn=lambda path: gr.Dropdown().update(choices=list_train_dirs(path)), + inputs=train_data_dir, + outputs=train_data_dir, + show_progress=False, + ) diff --git a/library/__init__.py b/library/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/library/attention_processors.py b/library/attention_processors.py deleted file mode 100644 index 310c2cb1c..000000000 --- a/library/attention_processors.py +++ /dev/null @@ -1,227 +0,0 @@ -import math -from typing import Any -from einops import rearrange -import torch -from diffusers.models.attention_processor import Attention - - -# flash attention forwards and backwards - -# https://arxiv.org/abs/2205.14135 - -EPSILON = 1e-6 - - -class FlashAttentionFunction(torch.autograd.function.Function): - @staticmethod - @torch.no_grad() - def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): - """Algorithm 2 in the paper""" - - device = q.device - dtype = q.dtype - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - o = torch.zeros_like(q) - all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) - all_row_maxes = torch.full( - (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device - ) - - scale = q.shape[-1] ** -0.5 - - if mask is None: - mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) - else: - mask = rearrange(mask, "b n -> b 1 1 n") - mask = mask.split(q_bucket_size, dim=-1) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - mask, - all_row_sums.split(q_bucket_size, dim=-2), - all_row_maxes.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = ( - torch.einsum("... 
i d, ... j d -> ... i j", qc, kc) * scale - ) - - if row_mask is not None: - attn_weights.masked_fill_(~row_mask, max_neg_value) - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones( - (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device - ).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) - attn_weights -= block_row_maxes - exp_weights = torch.exp(attn_weights) - - if row_mask is not None: - exp_weights.masked_fill_(~row_mask, 0.0) - - block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp( - min=EPSILON - ) - - new_row_maxes = torch.maximum(block_row_maxes, row_maxes) - - exp_values = torch.einsum( - "... i j, ... j d -> ... i d", exp_weights, vc - ) - - exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) - exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) - - new_row_sums = ( - exp_row_max_diff * row_sums - + exp_block_row_max_diff * block_row_sums - ) - - oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_( - (exp_block_row_max_diff / new_row_sums) * exp_values - ) - - row_maxes.copy_(new_row_maxes) - row_sums.copy_(new_row_sums) - - ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) - ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) - - return o - - @staticmethod - @torch.no_grad() - def backward(ctx, do): - """Algorithm 4 in the paper""" - - causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args - q, k, v, o, l, m = ctx.saved_tensors - - device = q.device - - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - do.split(q_bucket_size, dim=-2), - mask, - l.split(q_bucket_size, dim=-2), - m.split(q_bucket_size, dim=-2), - dq.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - dk.split(k_bucket_size, dim=-2), - dv.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = ( - torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale - ) - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones( - (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device - ).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - exp_attn_weights = torch.exp(attn_weights - mc) - - if row_mask is not None: - exp_attn_weights.masked_fill_(~row_mask, 0.0) - - p = exp_attn_weights / lc - - dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc) - dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc) - - D = (doc * oc).sum(dim=-1, keepdims=True) - ds = p * scale * (dp - D) - - dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc) - dk_chunk = torch.einsum("... i j, ... i d -> ... 
j d", ds, qc) - - dqc.add_(dq_chunk) - dkc.add_(dk_chunk) - dvc.add_(dv_chunk) - - return dq, dk, dv, None, None, None, None - - -class FlashAttnProcessor: - def __call__( - self, - attn: Attention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ) -> Any: - q_bucket_size = 512 - k_bucket_size = 1024 - - h = attn.heads - q = attn.to_q(hidden_states) - - encoder_hidden_states = ( - encoder_hidden_states - if encoder_hidden_states is not None - else hidden_states - ) - encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype) - - if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None: - context_k, context_v = attn.hypernetwork.forward( - hidden_states, encoder_hidden_states - ) - context_k = context_k.to(hidden_states.dtype) - context_v = context_v.to(hidden_states.dtype) - else: - context_k = encoder_hidden_states - context_v = encoder_hidden_states - - k = attn.to_k(context_k) - v = attn.to_v(context_v) - del encoder_hidden_states, hidden_states - - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) - - out = FlashAttentionFunction.apply( - q, k, v, attention_mask, False, q_bucket_size, k_bucket_size - ) - - out = rearrange(out, "b h n d -> b n (h d)") - - out = attn.to_out[0](out) - out = attn.to_out[1](out) - return out diff --git a/library/config_util.py b/library/config_util.py deleted file mode 100644 index fc4b36175..000000000 --- a/library/config_util.py +++ /dev/null @@ -1,689 +0,0 @@ -import argparse -from dataclasses import ( - asdict, - dataclass, -) -import functools -import random -from textwrap import dedent, indent -import json -from pathlib import Path - -# from toolz import curry -from typing import ( - List, - Optional, - Sequence, - Tuple, - Union, -) - -import toml -import voluptuous -from voluptuous import ( - Any, - ExactSequence, - MultipleInvalid, - Object, - Required, - Schema, -) -from transformers import CLIPTokenizer - -from . 
import train_util -from .train_util import ( - DreamBoothSubset, - FineTuningSubset, - ControlNetSubset, - DreamBoothDataset, - FineTuningDataset, - ControlNetDataset, - DatasetGroup, -) -from .utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -def add_config_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") - - -# TODO: inherit Params class in Subset, Dataset - - -@dataclass -class BaseSubsetParams: - image_dir: Optional[str] = None - num_repeats: int = 1 - shuffle_caption: bool = False - caption_separator: str = (",",) - keep_tokens: int = 0 - keep_tokens_separator: str = (None,) - color_aug: bool = False - flip_aug: bool = False - face_crop_aug_range: Optional[Tuple[float, float]] = None - random_crop: bool = False - caption_prefix: Optional[str] = None - caption_suffix: Optional[str] = None - caption_dropout_rate: float = 0.0 - caption_dropout_every_n_epochs: int = 0 - caption_tag_dropout_rate: float = 0.0 - token_warmup_min: int = 1 - token_warmup_step: float = 0 - - -@dataclass -class DreamBoothSubsetParams(BaseSubsetParams): - is_reg: bool = False - class_tokens: Optional[str] = None - caption_extension: str = ".caption" - - -@dataclass -class FineTuningSubsetParams(BaseSubsetParams): - metadata_file: Optional[str] = None - - -@dataclass -class ControlNetSubsetParams(BaseSubsetParams): - conditioning_data_dir: str = None - caption_extension: str = ".caption" - - -@dataclass -class BaseDatasetParams: - tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None - max_token_length: int = None - resolution: Optional[Tuple[int, int]] = None - network_multiplier: float = 1.0 - debug_dataset: bool = False - - -@dataclass -class DreamBoothDatasetParams(BaseDatasetParams): - batch_size: int = 1 - enable_bucket: bool = False - min_bucket_reso: int = 256 - max_bucket_reso: int = 1024 - bucket_reso_steps: int = 64 - bucket_no_upscale: bool = False - prior_loss_weight: float = 1.0 - - -@dataclass -class FineTuningDatasetParams(BaseDatasetParams): - batch_size: int = 1 - enable_bucket: bool = False - min_bucket_reso: int = 256 - max_bucket_reso: int = 1024 - bucket_reso_steps: int = 64 - bucket_no_upscale: bool = False - - -@dataclass -class ControlNetDatasetParams(BaseDatasetParams): - batch_size: int = 1 - enable_bucket: bool = False - min_bucket_reso: int = 256 - max_bucket_reso: int = 1024 - bucket_reso_steps: int = 64 - bucket_no_upscale: bool = False - - -@dataclass -class SubsetBlueprint: - params: Union[DreamBoothSubsetParams, FineTuningSubsetParams] - - -@dataclass -class DatasetBlueprint: - is_dreambooth: bool - is_controlnet: bool - params: Union[DreamBoothDatasetParams, FineTuningDatasetParams] - subsets: Sequence[SubsetBlueprint] - - -@dataclass -class DatasetGroupBlueprint: - datasets: Sequence[DatasetBlueprint] - - -@dataclass -class Blueprint: - dataset_group: DatasetGroupBlueprint - - -class ConfigSanitizer: - # @curry - @staticmethod - def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple: - Schema(ExactSequence([klass, klass]))(value) - return tuple(value) - - # @curry - @staticmethod - def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple: - Schema(Any(klass, ExactSequence([klass, klass])))(value) - try: - Schema(klass)(value) - return (value, value) - except: - return ConfigSanitizer.__validate_and_convert_twodim(klass, value) - - # subset schema - 
SUBSET_ASCENDABLE_SCHEMA = { - "color_aug": bool, - "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float), - "flip_aug": bool, - "num_repeats": int, - "random_crop": bool, - "shuffle_caption": bool, - "keep_tokens": int, - "keep_tokens_separator": str, - "token_warmup_min": int, - "token_warmup_step": Any(float, int), - "caption_prefix": str, - "caption_suffix": str, - } - # DO means DropOut - DO_SUBSET_ASCENDABLE_SCHEMA = { - "caption_dropout_every_n_epochs": int, - "caption_dropout_rate": Any(float, int), - "caption_tag_dropout_rate": Any(float, int), - } - # DB means DreamBooth - DB_SUBSET_ASCENDABLE_SCHEMA = { - "caption_extension": str, - "class_tokens": str, - } - DB_SUBSET_DISTINCT_SCHEMA = { - Required("image_dir"): str, - "is_reg": bool, - } - # FT means FineTuning - FT_SUBSET_DISTINCT_SCHEMA = { - Required("metadata_file"): str, - "image_dir": str, - } - CN_SUBSET_ASCENDABLE_SCHEMA = { - "caption_extension": str, - } - CN_SUBSET_DISTINCT_SCHEMA = { - Required("image_dir"): str, - Required("conditioning_data_dir"): str, - } - - # datasets schema - DATASET_ASCENDABLE_SCHEMA = { - "batch_size": int, - "bucket_no_upscale": bool, - "bucket_reso_steps": int, - "enable_bucket": bool, - "max_bucket_reso": int, - "min_bucket_reso": int, - "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), - "network_multiplier": float, - } - - # options handled by argparse but not handled by user config - ARGPARSE_SPECIFIC_SCHEMA = { - "debug_dataset": bool, - "max_token_length": Any(None, int), - "prior_loss_weight": Any(float, int), - } - # for handling default None value of argparse - ARGPARSE_NULLABLE_OPTNAMES = [ - "face_crop_aug_range", - "resolution", - ] - # prepare map because option name may differ among argparse and user config - ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { - "train_batch_size": "batch_size", - "dataset_repeats": "num_repeats", - } - - def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: - assert ( - support_dreambooth or support_finetuning or support_controlnet - ), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. 
/ DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" - - self.db_subset_schema = self.__merge_dict( - self.SUBSET_ASCENDABLE_SCHEMA, - self.DB_SUBSET_DISTINCT_SCHEMA, - self.DB_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) - - self.ft_subset_schema = self.__merge_dict( - self.SUBSET_ASCENDABLE_SCHEMA, - self.FT_SUBSET_DISTINCT_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) - - self.cn_subset_schema = self.__merge_dict( - self.SUBSET_ASCENDABLE_SCHEMA, - self.CN_SUBSET_DISTINCT_SCHEMA, - self.CN_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) - - self.db_dataset_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.DB_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - {"subsets": [self.db_subset_schema]}, - ) - - self.ft_dataset_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - {"subsets": [self.ft_subset_schema]}, - ) - - self.cn_dataset_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.CN_SUBSET_ASCENDABLE_SCHEMA, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - {"subsets": [self.cn_subset_schema]}, - ) - - if support_dreambooth and support_finetuning: - - def validate_flex_dataset(dataset_config: dict): - subsets_config = dataset_config.get("subsets", []) - - if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]): - return Schema(self.cn_dataset_schema)(dataset_config) - # check dataset meets FT style - # NOTE: all FT subsets should have "metadata_file" - elif all(["metadata_file" in subset for subset in subsets_config]): - return Schema(self.ft_dataset_schema)(dataset_config) - # check dataset meets DB style - # NOTE: all DB subsets should have no "metadata_file" - elif all(["metadata_file" not in subset for subset in subsets_config]): - return Schema(self.db_dataset_schema)(dataset_config) - else: - raise voluptuous.Invalid( - "DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. 
/ DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。" - ) - - self.dataset_schema = validate_flex_dataset - elif support_dreambooth: - self.dataset_schema = self.db_dataset_schema - elif support_finetuning: - self.dataset_schema = self.ft_dataset_schema - elif support_controlnet: - self.dataset_schema = self.cn_dataset_schema - - self.general_schema = self.__merge_dict( - self.DATASET_ASCENDABLE_SCHEMA, - self.SUBSET_ASCENDABLE_SCHEMA, - self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {}, - self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {}, - self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, - ) - - self.user_config_validator = Schema( - { - "general": self.general_schema, - "datasets": [self.dataset_schema], - } - ) - - self.argparse_schema = self.__merge_dict( - self.general_schema, - self.ARGPARSE_SPECIFIC_SCHEMA, - {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES}, - {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()}, - ) - - self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA) - - def sanitize_user_config(self, user_config: dict) -> dict: - try: - return self.user_config_validator(user_config) - except MultipleInvalid: - # TODO: エラー発生時のメッセージをわかりやすくする - logger.error("Invalid user config / ユーザ設定の形式が正しくないようです") - raise - - # NOTE: In nature, argument parser result is not needed to be sanitize - # However this will help us to detect program bug - def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace: - try: - return self.argparse_config_validator(argparse_namespace) - except MultipleInvalid: - # XXX: this should be a bug - logger.error("Invalid cmdline parsed arguments. This should be a bug. 
/ コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") - raise - - # NOTE: value would be overwritten by latter dict if there is already the same key - @staticmethod - def __merge_dict(*dict_list: dict) -> dict: - merged = {} - for schema in dict_list: - # merged |= schema - for k, v in schema.items(): - merged[k] = v - return merged - - -class BlueprintGenerator: - BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {} - - def __init__(self, sanitizer: ConfigSanitizer): - self.sanitizer = sanitizer - - # runtime_params is for parameters which is only configurable on runtime, such as tokenizer - def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint: - sanitized_user_config = self.sanitizer.sanitize_user_config(user_config) - sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace) - - # convert argparse namespace to dict like config - # NOTE: it is ok to have extra entries in dict - optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME - argparse_config = { - optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items() - } - - general_config = sanitized_user_config.get("general", {}) - - dataset_blueprints = [] - for dataset_config in sanitized_user_config.get("datasets", []): - # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets - subsets = dataset_config.get("subsets", []) - is_dreambooth = all(["metadata_file" not in subset for subset in subsets]) - is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets]) - if is_controlnet: - subset_params_klass = ControlNetSubsetParams - dataset_params_klass = ControlNetDatasetParams - elif is_dreambooth: - subset_params_klass = DreamBoothSubsetParams - dataset_params_klass = DreamBoothDatasetParams - else: - subset_params_klass = FineTuningSubsetParams - dataset_params_klass = FineTuningDatasetParams - - subset_blueprints = [] - for subset_config in subsets: - params = self.generate_params_by_fallbacks( - subset_params_klass, [subset_config, dataset_config, general_config, argparse_config, runtime_params] - ) - subset_blueprints.append(SubsetBlueprint(params)) - - params = self.generate_params_by_fallbacks( - dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params] - ) - dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints)) - - dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints) - - return Blueprint(dataset_group_blueprint) - - @staticmethod - def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]): - name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME - search_value = BlueprintGenerator.search_value - default_params = asdict(param_klass()) - param_names = default_params.keys() - - params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names} - - return param_klass(**params) - - @staticmethod - def search_value(key: str, fallbacks: Sequence[dict], default_value=None): - for cand in fallbacks: - value = cand.get(key) - if value is not None: - return value - - return default_value - - -def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): - datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] - - for dataset_blueprint in dataset_group_blueprint.datasets: - if dataset_blueprint.is_controlnet: - subset_klass = ControlNetSubset - 
dataset_klass = ControlNetDataset - elif dataset_blueprint.is_dreambooth: - subset_klass = DreamBoothSubset - dataset_klass = DreamBoothDataset - else: - subset_klass = FineTuningSubset - dataset_klass = FineTuningDataset - - subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) - datasets.append(dataset) - - # print info - info = "" - for i, dataset in enumerate(datasets): - is_dreambooth = isinstance(dataset, DreamBoothDataset) - is_controlnet = isinstance(dataset, ControlNetDataset) - info += dedent( - f"""\ - [Dataset {i}] - batch_size: {dataset.batch_size} - resolution: {(dataset.width, dataset.height)} - enable_bucket: {dataset.enable_bucket} - network_multiplier: {dataset.network_multiplier} - """ - ) - - if dataset.enable_bucket: - info += indent( - dedent( - f"""\ - min_bucket_reso: {dataset.min_bucket_reso} - max_bucket_reso: {dataset.max_bucket_reso} - bucket_reso_steps: {dataset.bucket_reso_steps} - bucket_no_upscale: {dataset.bucket_no_upscale} - \n""" - ), - " ", - ) - else: - info += "\n" - - for j, subset in enumerate(dataset.subsets): - info += indent( - dedent( - f"""\ - [Subset {j} of Dataset {i}] - image_dir: "{subset.image_dir}" - image_count: {subset.img_count} - num_repeats: {subset.num_repeats} - shuffle_caption: {subset.shuffle_caption} - keep_tokens: {subset.keep_tokens} - keep_tokens_separator: {subset.keep_tokens_separator} - caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} - caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} - caption_prefix: {subset.caption_prefix} - caption_suffix: {subset.caption_suffix} - color_aug: {subset.color_aug} - flip_aug: {subset.flip_aug} - face_crop_aug_range: {subset.face_crop_aug_range} - random_crop: {subset.random_crop} - token_warmup_min: {subset.token_warmup_min}, - token_warmup_step: {subset.token_warmup_step}, - """ - ), - " ", - ) - - if is_dreambooth: - info += indent( - dedent( - f"""\ - is_reg: {subset.is_reg} - class_tokens: {subset.class_tokens} - caption_extension: {subset.caption_extension} - \n""" - ), - " ", - ) - elif not is_controlnet: - info += indent( - dedent( - f"""\ - metadata_file: {subset.metadata_file} - \n""" - ), - " ", - ) - - logger.info(f'{info}') - - # make buckets first because it determines the length of dataset - # and set the same seed for all datasets - seed = random.randint(0, 2**31) # actual seed is seed + epoch_no - for i, dataset in enumerate(datasets): - logger.info(f"[Dataset {i}]") - dataset.make_buckets() - dataset.set_seed(seed) - - return DatasetGroup(datasets) - - -def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): - def extract_dreambooth_params(name: str) -> Tuple[int, str]: - tokens = name.split("_") - try: - n_repeats = int(tokens[0]) - except ValueError as e: - logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}") - return 0, "" - caption_by_folder = "_".join(tokens[1:]) - return n_repeats, caption_by_folder - - def generate(base_dir: Optional[str], is_reg: bool): - if base_dir is None: - return [] - - base_dir: Path = Path(base_dir) - if not base_dir.is_dir(): - return [] - - subsets_config = [] - for subdir in base_dir.iterdir(): - if not subdir.is_dir(): - continue - - num_repeats, class_tokens = extract_dreambooth_params(subdir.name) - if num_repeats < 
1: - continue - - subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens} - subsets_config.append(subset_config) - - return subsets_config - - subsets_config = [] - subsets_config += generate(train_data_dir, False) - subsets_config += generate(reg_data_dir, True) - - return subsets_config - - -def generate_controlnet_subsets_config_by_subdirs( - train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt" -): - def generate(base_dir: Optional[str]): - if base_dir is None: - return [] - - base_dir: Path = Path(base_dir) - if not base_dir.is_dir(): - return [] - - subsets_config = [] - subset_config = { - "image_dir": train_data_dir, - "conditioning_data_dir": conditioning_data_dir, - "caption_extension": caption_extension, - "num_repeats": 1, - } - subsets_config.append(subset_config) - - return subsets_config - - subsets_config = [] - subsets_config += generate(train_data_dir) - - return subsets_config - - -def load_user_config(file: str) -> dict: - file: Path = Path(file) - if not file.is_file(): - raise ValueError(f"file not found / ファイルが見つかりません: {file}") - - if file.name.lower().endswith(".json"): - try: - with open(file, "r") as f: - config = json.load(f) - except Exception: - logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") - raise - elif file.name.lower().endswith(".toml"): - try: - config = toml.load(file) - except Exception: - logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") - raise - else: - raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") - - return config - - -# for config test -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--support_dreambooth", action="store_true") - parser.add_argument("--support_finetuning", action="store_true") - parser.add_argument("--support_controlnet", action="store_true") - parser.add_argument("--support_dropout", action="store_true") - parser.add_argument("dataset_config") - config_args, remain = parser.parse_known_args() - - parser = argparse.ArgumentParser() - train_util.add_dataset_arguments( - parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout - ) - train_util.add_training_arguments(parser, config_args.support_dreambooth) - argparse_namespace = parser.parse_args(remain) - train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) - - logger.info("[argparse_namespace]") - logger.info(f'{vars(argparse_namespace)}') - - user_config = load_user_config(config_args.dataset_config) - - logger.info("") - logger.info("[user_config]") - logger.info(f'{user_config}') - - sanitizer = ConfigSanitizer( - config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout - ) - sanitized_user_config = sanitizer.sanitize_user_config(user_config) - - logger.info("") - logger.info("[sanitized_user_config]") - logger.info(f'{sanitized_user_config}') - - blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) - - logger.info("") - logger.info("[blueprint]") - logger.info(f'{blueprint}') diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py deleted file mode 100644 index 9713ede6d..000000000 --- a/library/custom_train_functions.py +++ 
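A minimal usage sketch of the loader above, assuming the functions from this (pre-removal) module are in scope; the file name and the nested config keys are illustrative placeholders only.

import json, os, tempfile

cfg = {"datasets": [{"subsets": [{"image_dir": "train/20_sks dog", "num_repeats": 20}]}]}
path = os.path.join(tempfile.mkdtemp(), "dataset_config.json")
with open(path, "w") as f:
    json.dump(cfg, f)

loaded = load_user_config(path)   # takes the json.load branch above for a *.json path
assert loaded["datasets"][0]["subsets"][0]["num_repeats"] == 20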
/dev/null @@ -1,532 +0,0 @@ -import torch -import argparse -import random -import re -from typing import List, Optional, Union -from .utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -def prepare_scheduler_for_custom_training(noise_scheduler, device): - if hasattr(noise_scheduler, "all_snr"): - return - - alphas_cumprod = noise_scheduler.alphas_cumprod - sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) - sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) - alpha = sqrt_alphas_cumprod - sigma = sqrt_one_minus_alphas_cumprod - all_snr = (alpha / sigma) ** 2 - - noise_scheduler.all_snr = all_snr.to(device) - - -def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): - # fix beta: zero terminal SNR - logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891") - - def enforce_zero_terminal_snr(betas): - # Convert betas to alphas_bar_sqrt - alphas = 1 - betas - alphas_bar = alphas.cumprod(0) - alphas_bar_sqrt = alphas_bar.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - # Shift so last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - # Scale so first timestep is back to old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 - alphas = alphas_bar[1:] / alphas_bar[:-1] - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - return betas - - betas = noise_scheduler.betas - betas = enforce_zero_terminal_snr(betas) - alphas = 1.0 - betas - alphas_cumprod = torch.cumprod(alphas, dim=0) - - # logger.info(f"original: {noise_scheduler.betas}") - # logger.info(f"fixed: {betas}") - - noise_scheduler.betas = betas - noise_scheduler.alphas = alphas - noise_scheduler.alphas_cumprod = alphas_cumprod - - -def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False): - snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) - min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma)) - if v_prediction: - snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device) - else: - snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device) - loss = loss * snr_weight - return loss - - -def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler): - scale = get_snr_scale(timesteps, noise_scheduler) - loss = loss * scale - return loss - - -def get_snr_scale(timesteps, noise_scheduler): - snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size - snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 - scale = snr_t / (snr_t + 1) - # # show debug info - # logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}") - return scale - - -def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss): - scale = get_snr_scale(timesteps, noise_scheduler) - # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") - loss = loss + loss / scale * v_pred_like_loss - return loss - -def apply_debiased_estimation(loss, timesteps, noise_scheduler): - snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size - snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 - weight = 1/torch.sqrt(snr_t) - loss = 
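A rough numerical illustration of the min-SNR-gamma weighting defined above; the linear beta schedule here is a simplified stand-in for the scheduler's real one.

import torch

betas = torch.linspace(1e-4, 2e-2, 1000)                   # toy schedule, not the real one
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
all_snr = alphas_cumprod / (1.0 - alphas_cumprod)           # equals (alpha / sigma) ** 2 above

gamma = 5.0
timesteps = torch.tensor([10, 500, 990])
snr = all_snr[timesteps]
weight_eps = torch.minimum(snr, torch.full_like(snr, gamma)) / snr        # epsilon-prediction
weight_v = torch.minimum(snr, torch.full_like(snr, gamma)) / (snr + 1)    # v-prediction
# For epsilon-prediction, high-SNR (early) timesteps are down-weighted toward gamma / snr
# while low-SNR timesteps keep weight 1; the v-prediction variant divides by snr + 1 as above.
print(weight_eps, weight_v)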
weight * loss - return loss - -# TODO train_utilと分散しているのでどちらかに寄せる - - -def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True): - parser.add_argument( - "--min_snr_gamma", - type=float, - default=None, - help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨", - ) - parser.add_argument( - "--scale_v_pred_loss_like_noise_pred", - action="store_true", - help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする", - ) - parser.add_argument( - "--v_pred_like_loss", - type=float, - default=None, - help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する", - ) - parser.add_argument( - "--debiased_estimation_loss", - action="store_true", - help="debiased estimation loss / debiased estimation loss", - ) - if support_weighted_captions: - parser.add_argument( - "--weighted_captions", - action="store_true", - default=False, - help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意", - ) - - -re_attention = re.compile( - r""" -\\\(| -\\\)| -\\\[| -\\]| -\\\\| -\\| -\(| -\[| -:([+-]?[.\d]+)\)| -\)| -]| -[^\\()\[\]:]+| -: -""", - re.X, -) - - -def parse_prompt_attention(text): - """ - Parses a string with attention tokens and returns a list of pairs: text and its associated weight. - Accepted tokens are: - (abc) - increases attention to abc by a multiplier of 1.1 - (abc:3.12) - increases attention to abc by a multiplier of 3.12 - [abc] - decreases attention to abc by a multiplier of 1.1 - \( - literal character '(' - \[ - literal character '[' - \) - literal character ')' - \] - literal character ']' - \\ - literal character '\' - anything else - just text - >>> parse_prompt_attention('normal text') - [['normal text', 1.0]] - >>> parse_prompt_attention('an (important) word') - [['an ', 1.0], ['important', 1.1], [' word', 1.0]] - >>> parse_prompt_attention('(unbalanced') - [['unbalanced', 1.1]] - >>> parse_prompt_attention('\(literal\]') - [['(literal]', 1.0]] - >>> parse_prompt_attention('(unnecessary)(parens)') - [['unnecessaryparens', 1.1]] - >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') - [['a ', 1.0], - ['house', 1.5730000000000004], - [' ', 1.1], - ['on', 1.0], - [' a ', 1.1], - ['hill', 0.55], - [', sun, ', 1.1], - ['sky', 1.4641000000000006], - ['.', 1.1]] - """ - - res = [] - round_brackets = [] - square_brackets = [] - - round_bracket_multiplier = 1.1 - square_bracket_multiplier = 1 / 1.1 - - def multiply_range(start_position, multiplier): - for p in range(start_position, len(res)): - res[p][1] *= multiplier - - for m in re_attention.finditer(text): - text = m.group(0) - weight = m.group(1) - - if text.startswith("\\"): - res.append([text[1:], 1.0]) - elif text == "(": - round_brackets.append(len(res)) - elif text == "[": - square_brackets.append(len(res)) - elif weight is not None and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), float(weight)) - elif text == ")" and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == "]" and len(square_brackets) > 0: - multiply_range(square_brackets.pop(), square_bracket_multiplier) - else: - res.append([text, 1.0]) - - 
for pos in round_brackets: - multiply_range(pos, round_bracket_multiplier) - - for pos in square_brackets: - multiply_range(pos, square_bracket_multiplier) - - if len(res) == 0: - res = [["", 1.0]] - - # merge runs of identical weights - i = 0 - while i + 1 < len(res): - if res[i][1] == res[i + 1][1]: - res[i][0] += res[i + 1][0] - res.pop(i + 1) - else: - i += 1 - - return res - - -def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int): - r""" - Tokenize a list of prompts and return its tokens with weights of each token. - - No padding, starting or ending token is included. - """ - tokens = [] - weights = [] - truncated = False - for text in prompt: - texts_and_weights = parse_prompt_attention(text) - text_token = [] - text_weight = [] - for word, weight in texts_and_weights: - # tokenize and discard the starting and the ending token - token = tokenizer(word).input_ids[1:-1] - text_token += token - # copy the weight by length of token - text_weight += [weight] * len(token) - # stop if the text is too long (longer than truncation limit) - if len(text_token) > max_length: - truncated = True - break - # truncate - if len(text_token) > max_length: - truncated = True - text_token = text_token[:max_length] - text_weight = text_weight[:max_length] - tokens.append(text_token) - weights.append(text_weight) - if truncated: - logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") - return tokens, weights - - -def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77): - r""" - Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. - """ - max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) - weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length - for i in range(len(tokens)): - tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) - if no_boseos_middle: - weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) - else: - w = [] - if len(weights[i]) == 0: - w = [1.0] * weights_length - else: - for j in range(max_embeddings_multiples): - w.append(1.0) # weight for starting token in this chunk - w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] - w.append(1.0) # weight for ending token in this chunk - w += [1.0] * (weights_length - len(w)) - weights[i] = w[:] - - return tokens, weights - - -def get_unweighted_text_embeddings( - tokenizer, - text_encoder, - text_input: torch.Tensor, - chunk_length: int, - clip_skip: int, - eos: int, - pad: int, - no_boseos_middle: Optional[bool] = True, -): - """ - When the length of tokens is a multiple of the capacity of the text encoder, - it should be split into chunks and sent to the text encoder individually. 
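A toy sketch of the padding arithmetic in pad_tokens_and_weights above, using made-up token ids in place of a real tokenizer; the BOS/EOS values are CLIP-style placeholder ids.

chunk_length = 77
max_embeddings_multiples = 3
max_length = (chunk_length - 2) * max_embeddings_multiples + 2     # 227
tokens = [[101, 102, 103]]            # hypothetical ids for one short prompt
weights = [[1.0, 1.1, 1.1]]
bos, eos = 49406, 49407               # placeholder BOS/EOS ids

padded_tokens, padded_weights = pad_tokens_and_weights(
    tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=chunk_length
)
assert len(padded_tokens[0]) == max_length == len(padded_weights[0])
assert padded_tokens[0][0] == bos and padded_tokens[0][-1] == eos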
- """ - max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) - if max_embeddings_multiples > 1: - text_embeddings = [] - for i in range(max_embeddings_multiples): - # extract the i-th chunk - text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() - - # cover the head and the tail by the starting and the ending tokens - text_input_chunk[:, 0] = text_input[0, 0] - if pad == eos: # v1 - text_input_chunk[:, -1] = text_input[0, -1] - else: # v2 - for j in range(len(text_input_chunk)): - if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある - text_input_chunk[j, -1] = eos - if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD - text_input_chunk[j, 1] = eos - - if clip_skip is None or clip_skip == 1: - text_embedding = text_encoder(text_input_chunk)[0] - else: - enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) - text_embedding = enc_out["hidden_states"][-clip_skip] - text_embedding = text_encoder.text_model.final_layer_norm(text_embedding) - - if no_boseos_middle: - if i == 0: - # discard the ending token - text_embedding = text_embedding[:, :-1] - elif i == max_embeddings_multiples - 1: - # discard the starting token - text_embedding = text_embedding[:, 1:] - else: - # discard both starting and ending tokens - text_embedding = text_embedding[:, 1:-1] - - text_embeddings.append(text_embedding) - text_embeddings = torch.concat(text_embeddings, axis=1) - else: - if clip_skip is None or clip_skip == 1: - text_embeddings = text_encoder(text_input)[0] - else: - enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) - text_embeddings = enc_out["hidden_states"][-clip_skip] - text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings) - return text_embeddings - - -def get_weighted_text_embeddings( - tokenizer, - text_encoder, - prompt: Union[str, List[str]], - device, - max_embeddings_multiples: Optional[int] = 3, - no_boseos_middle: Optional[bool] = False, - clip_skip=None, -): - r""" - Prompts can be assigned with local weights using brackets. For example, - prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', - and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. - - Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - no_boseos_middle (`bool`, *optional*, defaults to `False`): - If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and - ending token in each of the chunk in the middle. - skip_parsing (`bool`, *optional*, defaults to `False`): - Skip the parsing of brackets. - skip_weighting (`bool`, *optional*, defaults to `False`): - Skip the weighting. When the parsing is skipped, it is forced True. 
- """ - max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - if isinstance(prompt, str): - prompt = [prompt] - - prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2) - - # round up the longest length of tokens to a multiple of (model_max_length - 2) - max_length = max([len(token) for token in prompt_tokens]) - - max_embeddings_multiples = min( - max_embeddings_multiples, - (max_length - 1) // (tokenizer.model_max_length - 2) + 1, - ) - max_embeddings_multiples = max(1, max_embeddings_multiples) - max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - - # pad the length of tokens and weights - bos = tokenizer.bos_token_id - eos = tokenizer.eos_token_id - pad = tokenizer.pad_token_id - prompt_tokens, prompt_weights = pad_tokens_and_weights( - prompt_tokens, - prompt_weights, - max_length, - bos, - eos, - no_boseos_middle=no_boseos_middle, - chunk_length=tokenizer.model_max_length, - ) - prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) - - # get the embeddings - text_embeddings = get_unweighted_text_embeddings( - tokenizer, - text_encoder, - prompt_tokens, - tokenizer.model_max_length, - clip_skip, - eos, - pad, - no_boseos_middle=no_boseos_middle, - ) - prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device) - - # assign weights to the prompts and normalize in the sense of mean - previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1) - current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - - return text_embeddings - - -# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2 -def pyramid_noise_like(noise, device, iterations=6, discount=0.4): - b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant! 
- u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device) - for i in range(iterations): - r = random.random() * 2 + 2 # Rather than always going 2x, - wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i))) - noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i - if wn == 1 or hn == 1: - break # Lowest resolution is 1x1 - return noise / noise.std() # Scaled back to roughly unit variance - - -# https://www.crosslabs.org//blog/diffusion-with-offset-noise -def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): - if noise_offset is None: - return noise - if adaptive_noise_scale is not None: - # latent shape: (batch_size, channels, height, width) - # abs mean value for each channel - latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True)) - - # multiply adaptive noise scale to the mean value and add it to the noise offset - noise_offset = noise_offset + adaptive_noise_scale * latent_mean - noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative - - noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) - return noise - - -""" -########################################## -# Perlin Noise -def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3): - delta = (res[0] / shape[0], res[1] / shape[1]) - d = (shape[0] // res[0], shape[1] // res[1]) - - grid = ( - torch.stack( - torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)), - dim=-1, - ) - % 1 - ) - angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device) - gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) - - tile_grads = ( - lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]] - .repeat_interleave(d[0], 0) - .repeat_interleave(d[1], 1) - ) - dot = lambda grad, shift: ( - torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1) - * grad[: shape[0], : shape[1]] - ).sum(dim=-1) - - n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) - n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) - n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) - n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) - t = fade(grid[: shape[0], : shape[1]]) - return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) - - -def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5): - noise = torch.zeros(shape, device=device) - frequency = 1 - amplitude = 1 - for _ in range(octaves): - noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1])) - frequency *= 2 - amplitude *= persistence - return noise - - -def perlin_noise(noise, device, octaves): - _, c, w, h = noise.shape - perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves) - noise_perlin = [] - for _ in range(c): - noise_perlin.append(perlin()) - noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h) - noise += noise_perlin # broadcast for each batch - return noise / noise.std() # Scaled back to roughly unit variance -""" diff --git a/library/device_utils.py b/library/device_utils.py deleted file mode 100644 index 8823c5d9a..000000000 --- a/library/device_utils.py +++ /dev/null @@ -1,84 +0,0 @@ -import functools -import gc - -import torch - -try: - HAS_CUDA = torch.cuda.is_available() -except Exception: - HAS_CUDA = False - -try: - HAS_MPS = 
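A toy sketch of the offset-noise idea handled above: a per-sample, per-channel constant is added to the noise so the model can also learn global brightness shifts.

import torch

torch.manual_seed(0)
latents = torch.randn(2, 4, 8, 8)                          # (batch, channels, height, width)
noise = torch.randn_like(latents)
noise_offset = 0.05
offset = noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1))
noisy = noise + offset                                     # broadcasts over H and W, as above
assert noisy.shape == noise.shape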
torch.backends.mps.is_available() -except Exception: - HAS_MPS = False - -try: - import intel_extension_for_pytorch as ipex # noqa - - HAS_XPU = torch.xpu.is_available() -except Exception: - HAS_XPU = False - - -def clean_memory(): - gc.collect() - if HAS_CUDA: - torch.cuda.empty_cache() - if HAS_XPU: - torch.xpu.empty_cache() - if HAS_MPS: - torch.mps.empty_cache() - - -def clean_memory_on_device(device: torch.device): - r""" - Clean memory on the specified device, will be called from training scripts. - """ - gc.collect() - - # device may "cuda" or "cuda:0", so we need to check the type of device - if device.type == "cuda": - torch.cuda.empty_cache() - if device.type == "xpu": - torch.xpu.empty_cache() - if device.type == "mps": - torch.mps.empty_cache() - - -@functools.lru_cache(maxsize=None) -def get_preferred_device() -> torch.device: - r""" - Do not call this function from training scripts. Use accelerator.device instead. - """ - if HAS_CUDA: - device = torch.device("cuda") - elif HAS_XPU: - device = torch.device("xpu") - elif HAS_MPS: - device = torch.device("mps") - else: - device = torch.device("cpu") - print(f"get_preferred_device() -> {device}") - return device - - -def init_ipex(): - """ - Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`. - - This function should run right after importing torch and before doing anything else. - - If IPEX is not available, this function does nothing. - """ - try: - if HAS_XPU: - from library.ipex import ipex_init - - is_initialized, error_message = ipex_init() - if not is_initialized: - print("failed to initialize ipex:", error_message) - else: - return - except Exception as e: - print("failed to initialize ipex:", e) diff --git a/library/huggingface_util.py b/library/huggingface_util.py deleted file mode 100644 index 57b19d982..000000000 --- a/library/huggingface_util.py +++ /dev/null @@ -1,84 +0,0 @@ -from typing import Union, BinaryIO -from huggingface_hub import HfApi -from pathlib import Path -import argparse -import os -from library.utils import fire_in_thread -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None): - api = HfApi( - token=token, - ) - try: - api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type) - return True - except: - return False - - -def upload( - args: argparse.Namespace, - src: Union[str, Path, bytes, BinaryIO], - dest_suffix: str = "", - force_sync_upload: bool = False, -): - repo_id = args.huggingface_repo_id - repo_type = args.huggingface_repo_type - token = args.huggingface_token - path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None - private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public" - api = HfApi(token=token) - if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token): - try: - api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) - except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので - logger.error("===========================================") - logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}") - logger.error("===========================================") - - is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir()) - - def uploader(): - try: - if is_folder: - api.upload_folder( - repo_id=repo_id, - 
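A usage sketch for the helpers above; the import path is the one these files lived at before this removal.

from library.device_utils import clean_memory_on_device, get_preferred_device

device = get_preferred_device()        # cuda > xpu > mps > cpu, in that order
clean_memory_on_device(device)         # gc.collect() plus the matching backend empty_cache()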
repo_type=repo_type, - folder_path=src, - path_in_repo=path_in_repo, - ) - else: - api.upload_file( - repo_id=repo_id, - repo_type=repo_type, - path_or_fileobj=src, - path_in_repo=path_in_repo, - ) - except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので - logger.error("===========================================") - logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}") - logger.error("===========================================") - - if args.async_upload and not force_sync_upload: - fire_in_thread(uploader) - else: - uploader() - - -def list_dir( - repo_id: str, - subfolder: str, - repo_type: str, - revision: str = "main", - token: str = None, -): - api = HfApi( - token=token, - ) - repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type) - file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)] - return file_list diff --git a/library/hypernetwork.py b/library/hypernetwork.py deleted file mode 100644 index fbd3fb24e..000000000 --- a/library/hypernetwork.py +++ /dev/null @@ -1,223 +0,0 @@ -import torch -import torch.nn.functional as F -from diffusers.models.attention_processor import ( - Attention, - AttnProcessor2_0, - SlicedAttnProcessor, - XFormersAttnProcessor -) - -try: - import xformers.ops -except: - xformers = None - - -loaded_networks = [] - - -def apply_single_hypernetwork( - hypernetwork, hidden_states, encoder_hidden_states -): - context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states) - return context_k, context_v - - -def apply_hypernetworks(context_k, context_v, layer=None): - if len(loaded_networks) == 0: - return context_v, context_v - for hypernetwork in loaded_networks: - context_k, context_v = hypernetwork.forward(context_k, context_v) - - context_k = context_k.to(dtype=context_k.dtype) - context_v = context_v.to(dtype=context_k.dtype) - - return context_k, context_v - - - -def xformers_forward( - self: XFormersAttnProcessor, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor = None, - attention_mask: torch.Tensor = None, -): - batch_size, sequence_length, _ = ( - hidden_states.shape - if encoder_hidden_states is None - else encoder_hidden_states.shape - ) - - attention_mask = attn.prepare_attention_mask( - attention_mask, sequence_length, batch_size - ) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states) - - key = attn.to_k(context_k) - value = attn.to_v(context_v) - - query = attn.head_to_batch_dim(query).contiguous() - key = attn.head_to_batch_dim(key).contiguous() - value = attn.head_to_batch_dim(value).contiguous() - - hidden_states = xformers.ops.memory_efficient_attention( - query, - key, - value, - attn_bias=attention_mask, - op=self.attention_op, - scale=attn.scale, - ) - hidden_states = hidden_states.to(query.dtype) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - return hidden_states - - -def sliced_attn_forward( - self: SlicedAttnProcessor, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor = None, - attention_mask: torch.Tensor = None, -): - batch_size, sequence_length, _ = ( - 
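A conceptual sketch of the hook above: each loaded hypernetwork maps the cross-attention context into separate key/value contexts before attn.to_k / attn.to_v run. ToyHypernetwork is a hypothetical stand-in; real hypernetworks learn these mappings.

import torch

class ToyHypernetwork:
    def forward(self, context_k, context_v):
        return context_k * 1.05, context_v * 0.95          # placeholder transforms

context = torch.randn(2, 77, 768)                          # toy encoder hidden states
context_k, context_v = context, context
for hypernetwork in [ToyHypernetwork()]:                   # stands in for loaded_networks
    context_k, context_v = hypernetwork.forward(context_k, context_v)
# attn.to_k(context_k) / attn.to_v(context_v) would then produce the modified keys/values.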
hidden_states.shape - if encoder_hidden_states is None - else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask( - attention_mask, sequence_length, batch_size - ) - - query = attn.to_q(hidden_states) - dim = query.shape[-1] - query = attn.head_to_batch_dim(query) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states) - - key = attn.to_k(context_k) - value = attn.to_v(context_v) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - batch_size_attention, query_tokens, _ = query.shape - hidden_states = torch.zeros( - (batch_size_attention, query_tokens, dim // attn.heads), - device=query.device, - dtype=query.dtype, - ) - - for i in range(batch_size_attention // self.slice_size): - start_idx = i * self.slice_size - end_idx = (i + 1) * self.slice_size - - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - attn_mask_slice = ( - attention_mask[start_idx:end_idx] if attention_mask is not None else None - ) - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - -def v2_0_forward( - self: AttnProcessor2_0, - attn: Attention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, -): - batch_size, sequence_length, _ = ( - hidden_states.shape - if encoder_hidden_states is None - else encoder_hidden_states.shape - ) - inner_dim = hidden_states.shape[-1] - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask( - attention_mask, sequence_length, batch_size - ) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view( - batch_size, attn.heads, -1, attention_mask.shape[-1] - ) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states) - - key = attn.to_k(context_k) - value = attn.to_v(context_v) - - head_dim = inner_dim // attn.heads - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - return hidden_states - - -def replace_attentions_for_hypernetwork(): - import 
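A shape sketch for the SDPA path above (assuming torch 2.x): hidden states of shape (batch, seq, inner_dim) are viewed as (batch, heads, seq, head_dim) before scaled_dot_product_attention and folded back afterwards.

import torch
import torch.nn.functional as F

batch, seq, heads, head_dim = 2, 16, 8, 40
hidden = torch.randn(batch, seq, heads * head_dim)
q = hidden.view(batch, -1, heads, head_dim).transpose(1, 2)
k, v = q.clone(), q.clone()
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)
assert out.shape == hidden.shape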
diffusers.models.attention_processor - - diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = ( - xformers_forward - ) - diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = ( - sliced_attn_forward - ) - diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py deleted file mode 100644 index 972a3bf63..000000000 --- a/library/ipex/__init__.py +++ /dev/null @@ -1,179 +0,0 @@ -import os -import sys -import contextlib -import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -from .hijacks import ipex_hijacks - -# pylint: disable=protected-access, missing-function-docstring, line-too-long - -def ipex_init(): # pylint: disable=too-many-statements - try: - if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked: - return True, "Skipping IPEX hijack" - else: - # Replace cuda with xpu: - torch.cuda.current_device = torch.xpu.current_device - torch.cuda.current_stream = torch.xpu.current_stream - torch.cuda.device = torch.xpu.device - torch.cuda.device_count = torch.xpu.device_count - torch.cuda.device_of = torch.xpu.device_of - torch.cuda.get_device_name = torch.xpu.get_device_name - torch.cuda.get_device_properties = torch.xpu.get_device_properties - torch.cuda.init = torch.xpu.init - torch.cuda.is_available = torch.xpu.is_available - torch.cuda.is_initialized = torch.xpu.is_initialized - torch.cuda.is_current_stream_capturing = lambda: False - torch.cuda.set_device = torch.xpu.set_device - torch.cuda.stream = torch.xpu.stream - torch.cuda.synchronize = torch.xpu.synchronize - torch.cuda.Event = torch.xpu.Event - torch.cuda.Stream = torch.xpu.Stream - torch.cuda.FloatTensor = torch.xpu.FloatTensor - torch.Tensor.cuda = torch.Tensor.xpu - torch.Tensor.is_cuda = torch.Tensor.is_xpu - torch.UntypedStorage.cuda = torch.UntypedStorage.xpu - torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock - torch.cuda._initialized = torch.xpu.lazy_init._initialized - torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker - torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls - torch.cuda._tls = torch.xpu.lazy_init._tls - torch.cuda.threading = torch.xpu.lazy_init.threading - torch.cuda.traceback = torch.xpu.lazy_init.traceback - torch.cuda.Optional = torch.xpu.Optional - torch.cuda.__cached__ = torch.xpu.__cached__ - torch.cuda.__loader__ = torch.xpu.__loader__ - torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage - torch.cuda.Tuple = torch.xpu.Tuple - torch.cuda.streams = torch.xpu.streams - torch.cuda._lazy_new = torch.xpu._lazy_new - torch.cuda.FloatStorage = torch.xpu.FloatStorage - torch.cuda.Any = torch.xpu.Any - torch.cuda.__doc__ = torch.xpu.__doc__ - torch.cuda.default_generators = torch.xpu.default_generators - torch.cuda.HalfTensor = torch.xpu.HalfTensor - torch.cuda._get_device_index = torch.xpu._get_device_index - torch.cuda.__path__ = torch.xpu.__path__ - torch.cuda.Device = torch.xpu.Device - torch.cuda.IntTensor = torch.xpu.IntTensor - torch.cuda.ByteStorage = torch.xpu.ByteStorage - torch.cuda.set_stream = torch.xpu.set_stream - torch.cuda.BoolStorage = torch.xpu.BoolStorage - torch.cuda.os = torch.xpu.os - torch.cuda.torch = torch.xpu.torch - torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage - torch.cuda.Union = torch.xpu.Union - torch.cuda.DoubleTensor = torch.xpu.DoubleTensor - torch.cuda.ShortTensor = torch.xpu.ShortTensor - 
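A guarded conceptual sketch of the hijack above: code written against torch.cuda keeps working on Intel XPU because the torch.cuda attributes are rebound to their torch.xpu counterparts; the guard makes it a no-op where IPEX/XPU is not installed.

import torch

if hasattr(torch, "xpu") and torch.xpu.is_available():
    torch.cuda.is_available = torch.xpu.is_available
    torch.cuda.device_count = torch.xpu.device_count
    torch.cuda.current_device = torch.xpu.current_device
    # ...and so on across the rest of the torch.cuda surface, as enumerated above.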
torch.cuda.LongTensor = torch.xpu.LongTensor - torch.cuda.IntStorage = torch.xpu.IntStorage - torch.cuda.LongStorage = torch.xpu.LongStorage - torch.cuda.__annotations__ = torch.xpu.__annotations__ - torch.cuda.__package__ = torch.xpu.__package__ - torch.cuda.__builtins__ = torch.xpu.__builtins__ - torch.cuda.CharTensor = torch.xpu.CharTensor - torch.cuda.List = torch.xpu.List - torch.cuda._lazy_init = torch.xpu._lazy_init - torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor - torch.cuda.DoubleStorage = torch.xpu.DoubleStorage - torch.cuda.ByteTensor = torch.xpu.ByteTensor - torch.cuda.StreamContext = torch.xpu.StreamContext - torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage - torch.cuda.ShortStorage = torch.xpu.ShortStorage - torch.cuda._lazy_call = torch.xpu._lazy_call - torch.cuda.HalfStorage = torch.xpu.HalfStorage - torch.cuda.random = torch.xpu.random - torch.cuda._device = torch.xpu._device - torch.cuda.classproperty = torch.xpu.classproperty - torch.cuda.__name__ = torch.xpu.__name__ - torch.cuda._device_t = torch.xpu._device_t - torch.cuda.warnings = torch.xpu.warnings - torch.cuda.__spec__ = torch.xpu.__spec__ - torch.cuda.BoolTensor = torch.xpu.BoolTensor - torch.cuda.CharStorage = torch.xpu.CharStorage - torch.cuda.__file__ = torch.xpu.__file__ - torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork - # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing - - # Memory: - torch.cuda.memory = torch.xpu.memory - if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): - torch.xpu.empty_cache = lambda: None - torch.cuda.empty_cache = torch.xpu.empty_cache - torch.cuda.memory_stats = torch.xpu.memory_stats - torch.cuda.memory_summary = torch.xpu.memory_summary - torch.cuda.memory_snapshot = torch.xpu.memory_snapshot - torch.cuda.memory_allocated = torch.xpu.memory_allocated - torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated - torch.cuda.memory_reserved = torch.xpu.memory_reserved - torch.cuda.memory_cached = torch.xpu.memory_reserved - torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved - torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved - torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats - torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats - torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats - torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict - torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats - - # RNG: - torch.cuda.get_rng_state = torch.xpu.get_rng_state - torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all - torch.cuda.set_rng_state = torch.xpu.set_rng_state - torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all - torch.cuda.manual_seed = torch.xpu.manual_seed - torch.cuda.manual_seed_all = torch.xpu.manual_seed_all - torch.cuda.seed = torch.xpu.seed - torch.cuda.seed_all = torch.xpu.seed_all - torch.cuda.initial_seed = torch.xpu.initial_seed - - # AMP: - torch.cuda.amp = torch.xpu.amp - torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled - torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype - - if not hasattr(torch.cuda.amp, "common"): - torch.cuda.amp.common = contextlib.nullcontext() - torch.cuda.amp.common.amp_definitely_not_available = lambda: False - - try: - torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler - except Exception: # pylint: disable=broad-exception-caught - try: - from .gradscaler import 
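A small sketch of the WSL2 detection used above; under WSL2 the uname string contains "WSL2", and xpu.empty_cache is stubbed out because cache release is unreliable there.

import os
import sys

def running_under_wsl2() -> bool:
    return "linux" in sys.platform and "WSL2" in os.popen("uname -a").read()

if running_under_wsl2():
    pass   # e.g. replace torch.xpu.empty_cache with a no-op lambda, as done above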
gradscaler_init # pylint: disable=import-outside-toplevel, import-error - gradscaler_init() - torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler - except Exception: # pylint: disable=broad-exception-caught - torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler - - # C - torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream - ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count - ipex._C._DeviceProperties.major = 2023 - ipex._C._DeviceProperties.minor = 2 - - # Fix functions with ipex: - torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] - torch._utils._get_available_device_type = lambda: "xpu" - torch.has_cuda = True - torch.cuda.has_half = True - torch.cuda.is_bf16_supported = lambda *args, **kwargs: True - torch.cuda.is_fp16_supported = lambda *args, **kwargs: True - torch.backends.cuda.is_built = lambda *args, **kwargs: True - torch.version.cuda = "12.1" - torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1] - torch.cuda.get_device_properties.major = 12 - torch.cuda.get_device_properties.minor = 1 - torch.cuda.ipc_collect = lambda *args, **kwargs: None - torch.cuda.utilization = lambda *args, **kwargs: 0 - - ipex_hijacks() - if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None: - try: - from .diffusers import ipex_diffusers - ipex_diffusers() - except Exception: # pylint: disable=broad-exception-caught - pass - torch.cuda.is_xpu_hijacked = True - except Exception as e: - return False, e - return True, None diff --git a/library/ipex/attention.py b/library/ipex/attention.py deleted file mode 100644 index 8253c5b17..000000000 --- a/library/ipex/attention.py +++ /dev/null @@ -1,177 +0,0 @@ -import os -import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -from functools import cache - -# pylint: disable=protected-access, missing-function-docstring, line-too-long - -# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers - -sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4)) -attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) - -# Find something divisible with the input_tokens -@cache -def find_slice_size(slice_size, slice_block_size): - while (slice_size * slice_block_size) > attention_slice_rate: - slice_size = slice_size // 2 - if slice_size <= 1: - slice_size = 1 - break - return slice_size - -# Find slice sizes for SDPA -@cache -def find_sdpa_slice_sizes(query_shape, query_element_size): - if len(query_shape) == 3: - batch_size_attention, query_tokens, shape_three = query_shape - shape_four = 1 - else: - batch_size_attention, query_tokens, shape_three, shape_four = query_shape - - slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size - block_size = batch_size_attention * slice_block_size - - split_slice_size = batch_size_attention - split_2_slice_size = query_tokens - split_3_slice_size = shape_three - - do_split = False - do_split_2 = False - do_split_3 = False - - if block_size > sdpa_slice_trigger_rate: - do_split = True - split_slice_size = find_slice_size(split_slice_size, slice_block_size) - if split_slice_size * slice_block_size > attention_slice_rate: - slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * 
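A numeric sketch of the halving search above: the candidate slice count is halved until slice_size * slice_block_size drops under the configured attention_slice_rate threshold (4 by default).

attention_slice_rate = 4.0

def find_slice_size_sketch(slice_size, slice_block_size):
    while (slice_size * slice_block_size) > attention_slice_rate:
        slice_size = slice_size // 2
        if slice_size <= 1:
            return 1
    return slice_size

# e.g. 16 attention blocks of ~1.5 units each -> process 2 blocks per slice
assert find_slice_size_sketch(16, 1.5) == 2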
query_element_size - do_split_2 = True - split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) - if split_2_slice_size * slice_2_block_size > attention_slice_rate: - slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size - do_split_3 = True - split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) - - return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size - -# Find slice sizes for BMM -@cache -def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape): - batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2] - slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size - block_size = batch_size_attention * slice_block_size - - split_slice_size = batch_size_attention - split_2_slice_size = input_tokens - split_3_slice_size = mat2_atten_shape - - do_split = False - do_split_2 = False - do_split_3 = False - - if block_size > attention_slice_rate: - do_split = True - split_slice_size = find_slice_size(split_slice_size, slice_block_size) - if split_slice_size * slice_block_size > attention_slice_rate: - slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size - do_split_2 = True - split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) - if split_2_slice_size * slice_2_block_size > attention_slice_rate: - slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size - do_split_3 = True - split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) - - return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size - - -original_torch_bmm = torch.bmm -def torch_bmm_32_bit(input, mat2, *, out=None): - if input.device.type != "xpu": - return original_torch_bmm(input, mat2, out=out) - do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape) - - # Slice BMM - if do_split: - batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2] - hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype) - for i in range(batch_size_attention // split_slice_size): - start_idx = i * split_slice_size - end_idx = (i + 1) * split_slice_size - if do_split_2: - for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name - start_idx_2 = i2 * split_2_slice_size - end_idx_2 = (i2 + 1) * split_2_slice_size - if do_split_3: - for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name - start_idx_3 = i3 * split_3_slice_size - end_idx_3 = (i3 + 1) * split_3_slice_size - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm( - input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], - mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], - out=out - ) - else: - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( - input[start_idx:end_idx, start_idx_2:end_idx_2], - mat2[start_idx:end_idx, start_idx_2:end_idx_2], - out=out - ) - else: - hidden_states[start_idx:end_idx] = original_torch_bmm( - input[start_idx:end_idx], - mat2[start_idx:end_idx], - out=out - ) - else: - return original_torch_bmm(input, mat2, out=out) - 
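A CPU sketch showing that slicing a batched matmul over the batch dimension reproduces the unsliced result, which is the property the sliced XPU path above relies on.

import torch

torch.manual_seed(0)
a = torch.randn(8, 32, 64)
b = torch.randn(8, 64, 16)
full = torch.bmm(a, b)

split_slice_size = 2
sliced = torch.zeros_like(full)
for i in range(a.shape[0] // split_slice_size):
    start, end = i * split_slice_size, (i + 1) * split_slice_size
    sliced[start:end] = torch.bmm(a[start:end], b[start:end])
assert torch.allclose(full, sliced)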
torch.xpu.synchronize(input.device) - return hidden_states - -original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention -def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): - if query.device.type != "xpu": - return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) - do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size()) - - # Slice SDPA - if do_split: - batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] - hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) - for i in range(batch_size_attention // split_slice_size): - start_idx = i * split_slice_size - end_idx = (i + 1) * split_slice_size - if do_split_2: - for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name - start_idx_2 = i2 * split_2_slice_size - end_idx_2 = (i2 + 1) * split_2_slice_size - if do_split_3: - for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name - start_idx_3 = i3 * split_3_slice_size - end_idx_3 = (i3 + 1) * split_3_slice_size - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention( - query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], - key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], - value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], - attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask, - dropout_p=dropout_p, is_causal=is_causal - ) - else: - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( - query[start_idx:end_idx, start_idx_2:end_idx_2], - key[start_idx:end_idx, start_idx_2:end_idx_2], - value[start_idx:end_idx, start_idx_2:end_idx_2], - attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, - dropout_p=dropout_p, is_causal=is_causal - ) - else: - hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention( - query[start_idx:end_idx], - key[start_idx:end_idx], - value[start_idx:end_idx], - attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask, - dropout_p=dropout_p, is_causal=is_causal - ) - else: - return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) - torch.xpu.synchronize(query.device) - return hidden_states diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py deleted file mode 100644 index 732a18568..000000000 --- a/library/ipex/diffusers.py +++ /dev/null @@ -1,312 +0,0 @@ -import os -import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -import diffusers #0.24.0 # pylint: disable=import-error -from diffusers.models.attention_processor import Attention -from diffusers.utils import USE_PEFT_BACKEND -from functools import cache - -# pylint: disable=protected-access, missing-function-docstring, line-too-long - -attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) - -@cache -def find_slice_size(slice_size, slice_block_size): - while (slice_size * slice_block_size) > attention_slice_rate: - slice_size = slice_size // 2 - if slice_size <= 1: - 
slice_size = 1 - break - return slice_size - -@cache -def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None): - if len(query_shape) == 3: - batch_size_attention, query_tokens, shape_three = query_shape - shape_four = 1 - else: - batch_size_attention, query_tokens, shape_three, shape_four = query_shape - if slice_size is not None: - batch_size_attention = slice_size - - slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size - block_size = batch_size_attention * slice_block_size - - split_slice_size = batch_size_attention - split_2_slice_size = query_tokens - split_3_slice_size = shape_three - - do_split = False - do_split_2 = False - do_split_3 = False - - if query_device_type != "xpu": - return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size - - if block_size > attention_slice_rate: - do_split = True - split_slice_size = find_slice_size(split_slice_size, slice_block_size) - if split_slice_size * slice_block_size > attention_slice_rate: - slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size - do_split_2 = True - split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) - if split_2_slice_size * slice_2_block_size > attention_slice_rate: - slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size - do_split_3 = True - split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) - - return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size - -class SlicedAttnProcessor: # pylint: disable=too-few-public-methods - r""" - Processor for implementing sliced attention. - - Args: - slice_size (`int`, *optional*): - The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and - `attention_head_dim` must be a multiple of the `slice_size`. 
- """ - - def __init__(self, slice_size): - self.slice_size = slice_size - - def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, - encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches - - residual = hidden_states - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - dim = query.shape[-1] - query = attn.head_to_batch_dim(query) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - batch_size_attention, query_tokens, shape_three = query.shape - hidden_states = torch.zeros( - (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype - ) - - #################################################################### - # ARC GPUs can't allocate more than 4GB to a single block, Slice it: - _, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size) - - for i in range(batch_size_attention // split_slice_size): - start_idx = i * split_slice_size - end_idx = (i + 1) * split_slice_size - if do_split_2: - for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name - start_idx_2 = i2 * split_2_slice_size - end_idx_2 = (i2 + 1) * split_2_slice_size - if do_split_3: - for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name - start_idx_3 = i3 * split_3_slice_size - end_idx_3 = (i3 + 1) * split_3_slice_size - - query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] - key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] - attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]) - - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice - del attn_slice - else: - query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] - key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] - attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) - - hidden_states[start_idx:end_idx, 
start_idx_2:end_idx_2] = attn_slice - del attn_slice - torch.xpu.synchronize(query.device) - else: - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - del attn_slice - #################################################################### - - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class AttnProcessor: - r""" - Default processor for performing attention-related computations. - """ - - def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, - encoder_hidden_states=None, attention_mask=None, - temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches - - residual = hidden_states - - args = () if USE_PEFT_BACKEND else (scale,) - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states, *args) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states, *args) - value = attn.to_v(encoder_hidden_states, *args) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - #################################################################### - # ARC GPUs can't allocate more than 4GB to a single block, Slice it: - batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] - hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) - do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type) - - if do_split: - for i in range(batch_size_attention // split_slice_size): - start_idx = i * split_slice_size - end_idx = (i + 1) * split_slice_size - if do_split_2: - for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name - start_idx_2 = i2 * split_2_slice_size - end_idx_2 = (i2 + 1) * split_2_slice_size - if do_split_3: - for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name - start_idx_3 
= i3 * split_3_slice_size - end_idx_3 = (i3 + 1) * split_3_slice_size - - query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] - key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] - attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]) - - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice - del attn_slice - else: - query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] - key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] - attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) - - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice - del attn_slice - else: - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - del query_slice - del key_slice - del attn_mask_slice - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - del attn_slice - torch.xpu.synchronize(query.device) - else: - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - #################################################################### - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - -def ipex_diffusers(): - #ARC GPUs can't allocate more than 4GB to a single block: - diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor - diffusers.models.attention_processor.AttnProcessor = AttnProcessor diff --git a/library/ipex/gradscaler.py b/library/ipex/gradscaler.py deleted file mode 100644 index 6eb56bc2b..000000000 --- a/library/ipex/gradscaler.py +++ /dev/null @@ -1,183 +0,0 @@ -from collections import defaultdict -import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import - -# pylint: disable=protected-access, missing-function-docstring, line-too-long - -device_supports_fp64 = torch.xpu.has_fp64_dtype() -OptState = ipex.cpu.autocast._grad_scaler.OptState -_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator -_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state - -def _unscale_grads_(self, optimizer, inv_scale, found_inf, 
allow_fp16): # pylint: disable=unused-argument - per_device_inv_scale = _MultiDeviceReplicator(inv_scale) - per_device_found_inf = _MultiDeviceReplicator(found_inf) - - # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. - # There could be hundreds of grads, so we'd like to iterate through them just once. - # However, we don't know their devices or dtypes in advance. - - # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict - # Google says mypy struggles with defaultdicts type annotations. - per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] - # sync grad to master weight - if hasattr(optimizer, "sync_grad"): - optimizer.sync_grad() - with torch.no_grad(): - for group in optimizer.param_groups: - for param in group["params"]: - if param.grad is None: - continue - if (not allow_fp16) and param.grad.dtype == torch.float16: - raise ValueError("Attempting to unscale FP16 gradients.") - if param.grad.is_sparse: - # is_coalesced() == False means the sparse grad has values with duplicate indices. - # coalesce() deduplicates indices and adds all values that have the same index. - # For scaled fp16 values, there's a good chance coalescing will cause overflow, - # so we should check the coalesced _values(). - if param.grad.dtype is torch.float16: - param.grad = param.grad.coalesce() - to_unscale = param.grad._values() - else: - to_unscale = param.grad - - # -: is there a way to split by device and dtype without appending in the inner loop? - to_unscale = to_unscale.to("cpu") - per_device_and_dtype_grads[to_unscale.device][ - to_unscale.dtype - ].append(to_unscale) - - for _, per_dtype_grads in per_device_and_dtype_grads.items(): - for grads in per_dtype_grads.values(): - core._amp_foreach_non_finite_check_and_unscale_( - grads, - per_device_found_inf.get("cpu"), - per_device_inv_scale.get("cpu"), - ) - - return per_device_found_inf._per_device_tensors - -def unscale_(self, optimizer): - """ - Divides ("unscales") the optimizer's gradient tensors by the scale factor. - :meth:`unscale_` is optional, serving cases where you need to - :ref:`modify or inspect gradients` - between the backward pass(es) and :meth:`step`. - If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. - Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: - ... - scaler.scale(loss).backward() - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) - scaler.step(optimizer) - scaler.update() - Args: - optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. - .. warning:: - :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, - and only after all gradients for that optimizer's assigned parameters have been accumulated. - Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. - .. warning:: - :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. - """ - if not self._enabled: - return - - self._check_scale_growth_tracker("unscale_") - - optimizer_state = self._per_optimizer_states[id(optimizer)] - - if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise - raise RuntimeError( - "unscale_() has already been called on this optimizer since the last update()." 
- ) - elif optimizer_state["stage"] is OptState.STEPPED: - raise RuntimeError("unscale_() is being called after step().") - - # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. - assert self._scale is not None - if device_supports_fp64: - inv_scale = self._scale.double().reciprocal().float() - else: - inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device) - found_inf = torch.full( - (1,), 0.0, dtype=torch.float32, device=self._scale.device - ) - - optimizer_state["found_inf_per_device"] = self._unscale_grads_( - optimizer, inv_scale, found_inf, False - ) - optimizer_state["stage"] = OptState.UNSCALED - -def update(self, new_scale=None): - """ - Updates the scale factor. - If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` - to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, - the scale is multiplied by ``growth_factor`` to increase it. - Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not - used directly, it's used to fill GradScaler's internal scale tensor. So if - ``new_scale`` was a tensor, later in-place changes to that tensor will not further - affect the scale GradScaler uses internally.) - Args: - new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor. - .. warning:: - :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has - been invoked for all optimizers used this iteration. - """ - if not self._enabled: - return - - _scale, _growth_tracker = self._check_scale_growth_tracker("update") - - if new_scale is not None: - # Accept a new user-defined scale. - if isinstance(new_scale, float): - self._scale.fill_(new_scale) # type: ignore[union-attr] - else: - reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False." - assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined] - assert new_scale.numel() == 1, reason - assert new_scale.requires_grad is False, reason - self._scale.copy_(new_scale) # type: ignore[union-attr] - else: - # Consume shared inf/nan data collected from optimizers to update the scale. - # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. - found_infs = [ - found_inf.to(device="cpu", non_blocking=True) - for state in self._per_optimizer_states.values() - for found_inf in state["found_inf_per_device"].values() - ] - - assert len(found_infs) > 0, "No inf checks were recorded prior to update." - - found_inf_combined = found_infs[0] - if len(found_infs) > 1: - for i in range(1, len(found_infs)): - found_inf_combined += found_infs[i] - - to_device = _scale.device - _scale = _scale.to("cpu") - _growth_tracker = _growth_tracker.to("cpu") - - core._amp_update_scale_( - _scale, - _growth_tracker, - found_inf_combined, - self._growth_factor, - self._backoff_factor, - self._growth_interval, - ) - - _scale = _scale.to(to_device) - _growth_tracker = _growth_tracker.to(to_device) - # To prepare for next iteration, clear the data collected from optimizers this iteration. 
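As context for the hunk above: once gradscaler_init() below swaps these methods onto torch.xpu.amp.GradScaler, a training step would drive the scaler in the usual scale/unscale_/step/update order. The sketch below is purely illustrative (the model, optimizer, and data are placeholders, and it assumes an IPEX/XPU build where these attributes exist); it is not part of the removed file.

import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex  # noqa: F401  (assumed available, as in the file above)

model = nn.Linear(8, 1).to("xpu")                        # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.xpu.amp.GradScaler()                      # the class patched by gradscaler_init()

x = torch.randn(4, 8, device="xpu")
y = torch.randn(4, 1, device="xpu")
with torch.xpu.amp.autocast(dtype=torch.bfloat16):
    loss = nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()
scaler.unscale_(optimizer)                               # optional: unscale before gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()                                          # reduces the per-optimizer found_inf flags as implemented above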
- self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) - -def gradscaler_init(): - torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler - torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_ - torch.xpu.amp.GradScaler.unscale_ = unscale_ - torch.xpu.amp.GradScaler.update = update - return torch.xpu.amp.GradScaler diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py deleted file mode 100644 index b1b9ccf0e..000000000 --- a/library/ipex/hijacks.py +++ /dev/null @@ -1,298 +0,0 @@ -import os -from functools import wraps -from contextlib import nullcontext -import torch -import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -import numpy as np - -device_supports_fp64 = torch.xpu.has_fp64_dtype() - -# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return - -class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods - def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument - if isinstance(device_ids, list) and len(device_ids) > 1: - logger.error("IPEX backend doesn't support DataParallel on multiple XPU devices") - return module.to("xpu") - -def return_null_context(*args, **kwargs): # pylint: disable=unused-argument - return nullcontext() - -@property -def is_cuda(self): - return self.device.type == 'xpu' or self.device.type == 'cuda' - -def check_device(device): - return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int)) - -def return_xpu(device): - return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu" - - -# Autocast -original_autocast_init = torch.amp.autocast_mode.autocast.__init__ -@wraps(torch.amp.autocast_mode.autocast.__init__) -def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None): - if device_type == "cuda": - return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled) - else: - return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled) - -# Latent Antialias CPU Offload: -original_interpolate = torch.nn.functional.interpolate -@wraps(torch.nn.functional.interpolate) -def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments - if antialias or align_corners is not None: - return_device = tensor.device - return_dtype = tensor.dtype - return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode, - align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype) - else: - return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode, - align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias) - - -# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit): -original_from_numpy = torch.from_numpy -@wraps(torch.from_numpy) -def from_numpy(ndarray): - if ndarray.dtype == float: - return original_from_numpy(ndarray.astype('float32')) - else: - return 
original_from_numpy(ndarray) - -original_as_tensor = torch.as_tensor -@wraps(torch.as_tensor) -def as_tensor(data, dtype=None, device=None): - if check_device(device): - device = return_xpu(device) - if isinstance(data, np.ndarray) and data.dtype == float and not ( - (isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)): - return original_as_tensor(data, dtype=torch.float32, device=device) - else: - return original_as_tensor(data, dtype=dtype, device=device) - - -if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None: - original_torch_bmm = torch.bmm - original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention -else: - # 32 bit attention workarounds for Alchemist: - try: - from .attention import torch_bmm_32_bit as original_torch_bmm - from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention - except Exception: # pylint: disable=broad-exception-caught - original_torch_bmm = torch.bmm - original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention - - -# Data Type Errors: -@wraps(torch.bmm) -def torch_bmm(input, mat2, *, out=None): - if input.dtype != mat2.dtype: - mat2 = mat2.to(input.dtype) - return original_torch_bmm(input, mat2, out=out) - -@wraps(torch.nn.functional.scaled_dot_product_attention) -def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): - if query.dtype != key.dtype: - key = key.to(dtype=query.dtype) - if query.dtype != value.dtype: - value = value.to(dtype=query.dtype) - if attn_mask is not None and query.dtype != attn_mask.dtype: - attn_mask = attn_mask.to(dtype=query.dtype) - return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) - -# A1111 FP16 -original_functional_group_norm = torch.nn.functional.group_norm -@wraps(torch.nn.functional.group_norm) -def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05): - if weight is not None and input.dtype != weight.data.dtype: - input = input.to(dtype=weight.data.dtype) - if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype: - bias.data = bias.data.to(dtype=weight.data.dtype) - return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps) - -# A1111 BF16 -original_functional_layer_norm = torch.nn.functional.layer_norm -@wraps(torch.nn.functional.layer_norm) -def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05): - if weight is not None and input.dtype != weight.data.dtype: - input = input.to(dtype=weight.data.dtype) - if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype: - bias.data = bias.data.to(dtype=weight.data.dtype) - return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps) - -# Training -original_functional_linear = torch.nn.functional.linear -@wraps(torch.nn.functional.linear) -def functional_linear(input, weight, bias=None): - if input.dtype != weight.data.dtype: - input = input.to(dtype=weight.data.dtype) - if bias is not None and bias.data.dtype != weight.data.dtype: - bias.data = bias.data.to(dtype=weight.data.dtype) - return original_functional_linear(input, weight, bias=bias) - -original_functional_conv2d = torch.nn.functional.conv2d -@wraps(torch.nn.functional.conv2d) -def functional_conv2d(input, weight, bias=None, stride=1, 
padding=0, dilation=1, groups=1): - if input.dtype != weight.data.dtype: - input = input.to(dtype=weight.data.dtype) - if bias is not None and bias.data.dtype != weight.data.dtype: - bias.data = bias.data.to(dtype=weight.data.dtype) - return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) - -# A1111 Embedding BF16 -original_torch_cat = torch.cat -@wraps(torch.cat) -def torch_cat(tensor, *args, **kwargs): - if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): - return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) - else: - return original_torch_cat(tensor, *args, **kwargs) - -# SwinIR BF16: -original_functional_pad = torch.nn.functional.pad -@wraps(torch.nn.functional.pad) -def functional_pad(input, pad, mode='constant', value=None): - if mode == 'reflect' and input.dtype == torch.bfloat16: - return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16) - else: - return original_functional_pad(input, pad, mode=mode, value=value) - - -original_torch_tensor = torch.tensor -@wraps(torch.tensor) -def torch_tensor(data, *args, dtype=None, device=None, **kwargs): - if check_device(device): - device = return_xpu(device) - if not device_supports_fp64: - if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device): - if dtype == torch.float64: - dtype = torch.float32 - elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)): - dtype = torch.float32 - return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs) - -original_Tensor_to = torch.Tensor.to -@wraps(torch.Tensor.to) -def Tensor_to(self, device=None, *args, **kwargs): - if check_device(device): - return original_Tensor_to(self, return_xpu(device), *args, **kwargs) - else: - return original_Tensor_to(self, device, *args, **kwargs) - -original_Tensor_cuda = torch.Tensor.cuda -@wraps(torch.Tensor.cuda) -def Tensor_cuda(self, device=None, *args, **kwargs): - if check_device(device): - return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs) - else: - return original_Tensor_cuda(self, device, *args, **kwargs) - -original_UntypedStorage_init = torch.UntypedStorage.__init__ -@wraps(torch.UntypedStorage.__init__) -def UntypedStorage_init(*args, device=None, **kwargs): - if check_device(device): - return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs) - else: - return original_UntypedStorage_init(*args, device=device, **kwargs) - -original_UntypedStorage_cuda = torch.UntypedStorage.cuda -@wraps(torch.UntypedStorage.cuda) -def UntypedStorage_cuda(self, device=None, *args, **kwargs): - if check_device(device): - return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs) - else: - return original_UntypedStorage_cuda(self, device, *args, **kwargs) - -original_torch_empty = torch.empty -@wraps(torch.empty) -def torch_empty(*args, device=None, **kwargs): - if check_device(device): - return original_torch_empty(*args, device=return_xpu(device), **kwargs) - else: - return original_torch_empty(*args, device=device, **kwargs) - -original_torch_randn = torch.randn -@wraps(torch.randn) -def torch_randn(*args, device=None, **kwargs): - if check_device(device): - return original_torch_randn(*args, device=return_xpu(device), **kwargs) - else: - return 
original_torch_randn(*args, device=device, **kwargs) - -original_torch_ones = torch.ones -@wraps(torch.ones) -def torch_ones(*args, device=None, **kwargs): - if check_device(device): - return original_torch_ones(*args, device=return_xpu(device), **kwargs) - else: - return original_torch_ones(*args, device=device, **kwargs) - -original_torch_zeros = torch.zeros -@wraps(torch.zeros) -def torch_zeros(*args, device=None, **kwargs): - if check_device(device): - return original_torch_zeros(*args, device=return_xpu(device), **kwargs) - else: - return original_torch_zeros(*args, device=device, **kwargs) - -original_torch_linspace = torch.linspace -@wraps(torch.linspace) -def torch_linspace(*args, device=None, **kwargs): - if check_device(device): - return original_torch_linspace(*args, device=return_xpu(device), **kwargs) - else: - return original_torch_linspace(*args, device=device, **kwargs) - -original_torch_Generator = torch.Generator -@wraps(torch.Generator) -def torch_Generator(device=None): - if check_device(device): - return original_torch_Generator(return_xpu(device)) - else: - return original_torch_Generator(device) - -original_torch_load = torch.load -@wraps(torch.load) -def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs): - if check_device(map_location): - return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) - else: - return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) - - -# Hijack Functions: -def ipex_hijacks(): - torch.tensor = torch_tensor - torch.Tensor.to = Tensor_to - torch.Tensor.cuda = Tensor_cuda - torch.UntypedStorage.__init__ = UntypedStorage_init - torch.UntypedStorage.cuda = UntypedStorage_cuda - torch.empty = torch_empty - torch.randn = torch_randn - torch.ones = torch_ones - torch.zeros = torch_zeros - torch.linspace = torch_linspace - torch.Generator = torch_Generator - torch.load = torch_load - - torch.backends.cuda.sdp_kernel = return_null_context - torch.nn.DataParallel = DummyDataParallel - torch.UntypedStorage.is_cuda = is_cuda - torch.amp.autocast_mode.autocast.__init__ = autocast_init - - torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention - torch.nn.functional.group_norm = functional_group_norm - torch.nn.functional.layer_norm = functional_layer_norm - torch.nn.functional.linear = functional_linear - torch.nn.functional.conv2d = functional_conv2d - torch.nn.functional.interpolate = interpolate - torch.nn.functional.pad = functional_pad - - torch.bmm = torch_bmm - torch.cat = torch_cat - if not device_supports_fp64: - torch.from_numpy = from_numpy - torch.as_tensor = as_tensor diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py deleted file mode 100644 index 7fc117aa8..000000000 --- a/library/lpw_stable_diffusion.py +++ /dev/null @@ -1,1233 +0,0 @@ -# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py -# and modify to support SD2.x - -import inspect -import re -from typing import Callable, List, Optional, Union - -import numpy as np -import PIL.Image -import torch -from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection - -import diffusers -from diffusers import SchedulerMixin, StableDiffusionPipeline -from diffusers.models import AutoencoderKL, 
UNet2DConditionModel -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker -from diffusers.utils import logging - -try: - from diffusers.utils import PIL_INTERPOLATION -except ImportError: - if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): - PIL_INTERPOLATION = { - "linear": PIL.Image.Resampling.BILINEAR, - "bilinear": PIL.Image.Resampling.BILINEAR, - "bicubic": PIL.Image.Resampling.BICUBIC, - "lanczos": PIL.Image.Resampling.LANCZOS, - "nearest": PIL.Image.Resampling.NEAREST, - } - else: - PIL_INTERPOLATION = { - "linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - "nearest": PIL.Image.NEAREST, - } -# ------------------------------------------------------------------------------ - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -re_attention = re.compile( - r""" -\\\(| -\\\)| -\\\[| -\\]| -\\\\| -\\| -\(| -\[| -:([+-]?[.\d]+)\)| -\)| -]| -[^\\()\[\]:]+| -: -""", - re.X, -) - - -def parse_prompt_attention(text): - """ - Parses a string with attention tokens and returns a list of pairs: text and its associated weight. - Accepted tokens are: - (abc) - increases attention to abc by a multiplier of 1.1 - (abc:3.12) - increases attention to abc by a multiplier of 3.12 - [abc] - decreases attention to abc by a multiplier of 1.1 - \( - literal character '(' - \[ - literal character '[' - \) - literal character ')' - \] - literal character ']' - \\ - literal character '\' - anything else - just text - >>> parse_prompt_attention('normal text') - [['normal text', 1.0]] - >>> parse_prompt_attention('an (important) word') - [['an ', 1.0], ['important', 1.1], [' word', 1.0]] - >>> parse_prompt_attention('(unbalanced') - [['unbalanced', 1.1]] - >>> parse_prompt_attention('\(literal\]') - [['(literal]', 1.0]] - >>> parse_prompt_attention('(unnecessary)(parens)') - [['unnecessaryparens', 1.1]] - >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') - [['a ', 1.0], - ['house', 1.5730000000000004], - [' ', 1.1], - ['on', 1.0], - [' a ', 1.1], - ['hill', 0.55], - [', sun, ', 1.1], - ['sky', 1.4641000000000006], - ['.', 1.1]] - """ - - res = [] - round_brackets = [] - square_brackets = [] - - round_bracket_multiplier = 1.1 - square_bracket_multiplier = 1 / 1.1 - - def multiply_range(start_position, multiplier): - for p in range(start_position, len(res)): - res[p][1] *= multiplier - - for m in re_attention.finditer(text): - text = m.group(0) - weight = m.group(1) - - if text.startswith("\\"): - res.append([text[1:], 1.0]) - elif text == "(": - round_brackets.append(len(res)) - elif text == "[": - square_brackets.append(len(res)) - elif weight is not None and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), float(weight)) - elif text == ")" and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == "]" and len(square_brackets) > 0: - multiply_range(square_brackets.pop(), square_bracket_multiplier) - else: - res.append([text, 1.0]) - - for pos in round_brackets: - multiply_range(pos, round_bracket_multiplier) - - for pos in square_brackets: - multiply_range(pos, square_bracket_multiplier) - - if len(res) == 0: - res = [["", 1.0]] - - # merge runs of identical weights - i = 0 - while i + 1 < len(res): - if res[i][1] == res[i + 1][1]: - res[i][0] += res[i + 1][0] - res.pop(i + 1) - else: - i += 1 - - return res - - -def 
get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int): - r""" - Tokenize a list of prompts and return its tokens with weights of each token. - - No padding, starting or ending token is included. - """ - tokens = [] - weights = [] - truncated = False - for text in prompt: - texts_and_weights = parse_prompt_attention(text) - text_token = [] - text_weight = [] - for word, weight in texts_and_weights: - # tokenize and discard the starting and the ending token - token = pipe.tokenizer(word).input_ids[1:-1] - text_token += token - # copy the weight by length of token - text_weight += [weight] * len(token) - # stop if the text is too long (longer than truncation limit) - if len(text_token) > max_length: - truncated = True - break - # truncate - if len(text_token) > max_length: - truncated = True - text_token = text_token[:max_length] - text_weight = text_weight[:max_length] - tokens.append(text_token) - weights.append(text_weight) - if truncated: - logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") - return tokens, weights - - -def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77): - r""" - Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. - """ - max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) - weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length - for i in range(len(tokens)): - tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) - if no_boseos_middle: - weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) - else: - w = [] - if len(weights[i]) == 0: - w = [1.0] * weights_length - else: - for j in range(max_embeddings_multiples): - w.append(1.0) # weight for starting token in this chunk - w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] - w.append(1.0) # weight for ending token in this chunk - w += [1.0] * (weights_length - len(w)) - weights[i] = w[:] - - return tokens, weights - - -def get_unweighted_text_embeddings( - pipe: StableDiffusionPipeline, - text_input: torch.Tensor, - chunk_length: int, - clip_skip: int, - eos: int, - pad: int, - no_boseos_middle: Optional[bool] = True, -): - """ - When the length of tokens is a multiple of the capacity of the text encoder, - it should be split into chunks and sent to the text encoder individually. 
- """ - max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) - if max_embeddings_multiples > 1: - text_embeddings = [] - for i in range(max_embeddings_multiples): - # extract the i-th chunk - text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() - - # cover the head and the tail by the starting and the ending tokens - text_input_chunk[:, 0] = text_input[0, 0] - if pad == eos: # v1 - text_input_chunk[:, -1] = text_input[0, -1] - else: # v2 - for j in range(len(text_input_chunk)): - if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある - text_input_chunk[j, -1] = eos - if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD - text_input_chunk[j, 1] = eos - - if clip_skip is None or clip_skip == 1: - text_embedding = pipe.text_encoder(text_input_chunk)[0] - else: - enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) - text_embedding = enc_out["hidden_states"][-clip_skip] - text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) - - if no_boseos_middle: - if i == 0: - # discard the ending token - text_embedding = text_embedding[:, :-1] - elif i == max_embeddings_multiples - 1: - # discard the starting token - text_embedding = text_embedding[:, 1:] - else: - # discard both starting and ending tokens - text_embedding = text_embedding[:, 1:-1] - - text_embeddings.append(text_embedding) - text_embeddings = torch.concat(text_embeddings, axis=1) - else: - if clip_skip is None or clip_skip == 1: - text_embeddings = pipe.text_encoder(text_input)[0] - else: - enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True) - text_embeddings = enc_out["hidden_states"][-clip_skip] - text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings) - return text_embeddings - - -def get_weighted_text_embeddings( - pipe: StableDiffusionPipeline, - prompt: Union[str, List[str]], - uncond_prompt: Optional[Union[str, List[str]]] = None, - max_embeddings_multiples: Optional[int] = 3, - no_boseos_middle: Optional[bool] = False, - skip_parsing: Optional[bool] = False, - skip_weighting: Optional[bool] = False, - clip_skip=None, -): - r""" - Prompts can be assigned with local weights using brackets. For example, - prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', - and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. - - Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. - - Args: - pipe (`StableDiffusionPipeline`): - Pipe to provide access to the tokenizer and the text encoder. - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - uncond_prompt (`str` or `List[str]`): - The unconditional prompt or prompts for guide the image generation. If unconditional prompt - is provided, the embeddings of prompt and uncond_prompt are concatenated. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - no_boseos_middle (`bool`, *optional*, defaults to `False`): - If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and - ending token in each of the chunk in the middle. - skip_parsing (`bool`, *optional*, defaults to `False`): - Skip the parsing of brackets. 
- skip_weighting (`bool`, *optional*, defaults to `False`): - Skip the weighting. When the parsing is skipped, it is forced True. - """ - max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - if isinstance(prompt, str): - prompt = [prompt] - - if not skip_parsing: - prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2) - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) - else: - prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids] - prompt_weights = [[1.0] * len(token) for token in prompt_tokens] - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens = [ - token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids - ] - uncond_weights = [[1.0] * len(token) for token in uncond_tokens] - - # round up the longest length of tokens to a multiple of (model_max_length - 2) - max_length = max([len(token) for token in prompt_tokens]) - if uncond_prompt is not None: - max_length = max(max_length, max([len(token) for token in uncond_tokens])) - - max_embeddings_multiples = min( - max_embeddings_multiples, - (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1, - ) - max_embeddings_multiples = max(1, max_embeddings_multiples) - max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - - # pad the length of tokens and weights - bos = pipe.tokenizer.bos_token_id - eos = pipe.tokenizer.eos_token_id - pad = pipe.tokenizer.pad_token_id - prompt_tokens, prompt_weights = pad_tokens_and_weights( - prompt_tokens, - prompt_weights, - max_length, - bos, - eos, - no_boseos_middle=no_boseos_middle, - chunk_length=pipe.tokenizer.model_max_length, - ) - prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) - if uncond_prompt is not None: - uncond_tokens, uncond_weights = pad_tokens_and_weights( - uncond_tokens, - uncond_weights, - max_length, - bos, - eos, - no_boseos_middle=no_boseos_middle, - chunk_length=pipe.tokenizer.model_max_length, - ) - uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) - - # get the embeddings - text_embeddings = get_unweighted_text_embeddings( - pipe, - prompt_tokens, - pipe.tokenizer.model_max_length, - clip_skip, - eos, - pad, - no_boseos_middle=no_boseos_middle, - ) - prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) - if uncond_prompt is not None: - uncond_embeddings = get_unweighted_text_embeddings( - pipe, - uncond_tokens, - pipe.tokenizer.model_max_length, - clip_skip, - eos, - pad, - no_boseos_middle=no_boseos_middle, - ) - uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) - - # assign weights to the prompts and normalize in the sense of mean - # TODO: should we normalize by chunk or in a whole (current implementation)? 
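Before the weighting branch below, it may help to see the mean-preserving rescale in isolation. This sketch is illustrative only (the shapes and the example weight are placeholders), not code from the removed file:

import torch

embeddings = torch.randn(1, 77, 768)        # (batch, tokens, hidden), placeholder shapes
weights = torch.ones(1, 77)
weights[0, 5] = 1.1                         # e.g. a token wrapped in "(...)"

previous_mean = embeddings.float().mean(dim=[-2, -1])
weighted = embeddings * weights.unsqueeze(-1)
current_mean = weighted.float().mean(dim=[-2, -1])
weighted = weighted * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
# The overall mean is preserved, so only the relative emphasis between tokens changes.
assert torch.allclose(weighted.float().mean(dim=[-2, -1]), previous_mean, atol=1e-5)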
- if (not skip_parsing) and (not skip_weighting): - previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= prompt_weights.unsqueeze(-1) - current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - if uncond_prompt is not None: - previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= uncond_weights.unsqueeze(-1) - current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - - if uncond_prompt is not None: - return text_embeddings, uncond_embeddings - return text_embeddings, None - - -def preprocess_image(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.0 * image - 1.0 - - -def preprocess_mask(mask, scale_factor=8): - mask = mask.convert("L") - w, h = mask.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) - mask = np.array(mask).astype(np.float32) / 255.0 - mask = np.tile(mask, (4, 1, 1)) - mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? - mask = 1 - mask # repaint white, keep black - mask = torch.from_numpy(mask) - return mask - - -def prepare_controlnet_image( - image: PIL.Image.Image, - width: int, - height: int, - batch_size: int, - num_images_per_prompt: int, - device: torch.device, - dtype: torch.dtype, - do_classifier_free_guidance: bool = False, - guess_mode: bool = False, -): - if not isinstance(image, torch.Tensor): - if isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - images = [] - - for image_ in image: - image_ = image_.convert("RGB") - image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) - image_ = np.array(image_) - image_ = image_[None, :] - images.append(image_) - - image = images - - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance and not guess_mode: - image = torch.cat([image] * 2) - - return image - - -class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): - r""" - Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing - weighting in prompt. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
- - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - - # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"): - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: SchedulerMixin, - # clip_skip: int, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - image_encoder: CLIPVisionModelWithProjection = None, - clip_skip: int = 1, - ): - super().__init__( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - requires_safety_checker=requires_safety_checker, - image_encoder=image_encoder, - ) - self.custom_clip_skip = clip_skip - self.__init__additional__() - - def __init__additional__(self): - if not hasattr(self, "vae_scale_factor"): - setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1)) - - @property - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - max_embeddings_multiples, - ): - r""" - Encodes the prompt into text encoder hidden states. 
- - Args: - prompt (`str` or `list(int)`): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - - text_embeddings, uncond_embeddings = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.custom_clip_skip, - ) - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - - if do_classifier_free_guidance: - bs_embed, seq_len, _ = uncond_embeddings.shape - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - - return text_embeddings - - def check_inputs(self, prompt, height, width, strength, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}") - - if height % 8 != 0 or width % 8 != 0: - logger.info(f'{height} {width}') - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." 
- ) - - def get_timesteps(self, num_inference_steps, strength, device, is_text2img): - if is_text2img: - return self.scheduler.timesteps.to(device), num_inference_steps - else: - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - - t_start = max(num_inference_steps - init_timestep + offset, 0) - timesteps = self.scheduler.timesteps[t_start:].to(device) - return timesteps, num_inference_steps - t_start - - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype)) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - - def decode_latents(self, latents): - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None): - if image is None: - shape = ( - batch_size, - self.unet.in_channels, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) - - if latents is None: - if device.type == "mps": - # randn does not work reproducibly on mps - latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) - else: - latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) - else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents, None, None - else: - init_latent_dist = self.vae.encode(image).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - init_latents = 0.18215 * init_latents - init_latents = torch.cat([init_latents] * batch_size, dim=0) - init_latents_orig = init_latents - shape = init_latents.shape - - # add noise to latents using the timesteps - if device.type == "mps": - noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) - else: - noise = torch.randn(shape, generator=generator, device=device, dtype=dtype) - latents = self.scheduler.add_noise(init_latents, noise, timestep) - return latents, 
init_latents_orig, noise - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, - mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, - height: int = 512, - width: int = 512, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - strength: float = 0.8, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - return_dict: bool = True, - controlnet=None, - controlnet_image=None, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. - mask_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be - replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a - PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should - contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. - height (`int`, *optional*, defaults to 512): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. - `image` will be used as a starting point, adding more noise to it the larger the `strength`. The - number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added - noise will be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. 
Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will be generated by sampling using the supplied random `generator`. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - controlnet (`diffusers.ControlNetModel`, *optional*): - A controlnet model to be used for the inference. If not provided, controlnet will be disabled. - controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*): - `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet - inference. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - is_cancelled_callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. If the function returns - `True`, the inference will be cancelled. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Returns: - `None` if cancelled by `is_cancelled_callback`, - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - if controlnet is not None and controlnet_image is None: - raise ValueError("controlnet_image must be provided if controlnet is not None.") - - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs(prompt, height, width, strength, callback_steps) - - # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) - device = self._execution_device - # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. 
Encode input prompt - text_embeddings = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - max_embeddings_multiples, - ) - dtype = text_embeddings.dtype - - # 4. Preprocess image and mask - if isinstance(image, PIL.Image.Image): - image = preprocess_image(image) - if image is not None: - image = image.to(device=self.device, dtype=dtype) - if isinstance(mask_image, PIL.Image.Image): - mask_image = preprocess_mask(mask_image, self.vae_scale_factor) - if mask_image is not None: - mask = mask_image.to(device=self.device, dtype=dtype) - mask = torch.cat([mask] * batch_size * num_images_per_prompt) - else: - mask = None - - if controlnet_image is not None: - controlnet_image = prepare_controlnet_image( - controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False - ) - - # 5. set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - - # 6. Prepare latent variables - latents, init_latents_orig, noise = self.prepare_latents( - image, - latent_timestep, - batch_size * num_images_per_prompt, - height, - width, - dtype, - device, - generator, - latents, - ) - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 8. Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - unet_additional_args = {} - if controlnet is not None: - down_block_res_samples, mid_block_res_sample = controlnet( - latent_model_input, - t, - encoder_hidden_states=text_embeddings, - controlnet_cond=controlnet_image, - conditioning_scale=1.0, - guess_mode=False, - return_dict=False, - ) - unet_additional_args["down_block_additional_residuals"] = down_block_res_samples - unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - if mask is not None: - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) - latents = (init_latents_proper * mask) + (latents * (1 - mask)) - - # call the callback, if provided - if i % callback_steps == 0: - if callback is not None: - callback(i, t, latents) - if is_cancelled_callback is not None and is_cancelled_callback(): - return None - - return latents - - def latents_to_image(self, latents): - # 9. 
Post-processing - image = self.decode_latents(latents.to(self.vae.dtype)) - image = self.numpy_to_pil(image) - return image - - def text2img( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - height: int = 512, - width: int = 512, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: int = 1, - ): - r""" - Function for text-to-image generation. - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - height (`int`, *optional*, defaults to 512): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. 
The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - is_cancelled_callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. If the function returns - `True`, the inference will be cancelled. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - return self.__call__( - prompt=prompt, - negative_prompt=negative_prompt, - height=height, - width=width, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - latents=latents, - max_embeddings_multiples=max_embeddings_multiples, - output_type=output_type, - return_dict=return_dict, - callback=callback, - is_cancelled_callback=is_cancelled_callback, - callback_steps=callback_steps, - ) - - def img2img( - self, - image: Union[torch.FloatTensor, PIL.Image.Image], - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: int = 1, - ): - r""" - Function for image-to-image generation. - Args: - image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. - `image` will be used as a starting point, adding more noise to it the larger the `strength`. The - number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added - noise will be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This parameter will be modulated by `strength`. 
- guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - is_cancelled_callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. If the function returns - `True`, the inference will be cancelled. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. 
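A minimal usage sketch for this img2img entry point, assuming `pipe` is an already-constructed instance of the long-prompt-weighting pipeline defined in this file and that `init.png` exists; the prompt text and file names are illustrative. Note that, unlike the Returns text above, the implementation in this file returns raw latents (or None when cancelled) and images are produced with latents_to_image().

    from PIL import Image

    init_image = Image.open("init.png").convert("RGB")  # illustrative input path
    latents = pipe.img2img(
        image=init_image,
        prompt="a (watercolor:1.2) painting of a lighthouse",  # (text:weight) syntax handled by this pipeline
        negative_prompt="lowres, blurry",
        strength=0.6,
        num_inference_steps=30,
        guidance_scale=7.5,
    )
    if latents is not None:  # None means is_cancelled_callback stopped the run
        pipe.latents_to_image(latents)[0].save("img2img_out.png")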
- """ - return self.__call__( - prompt=prompt, - negative_prompt=negative_prompt, - image=image, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - strength=strength, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - max_embeddings_multiples=max_embeddings_multiples, - output_type=output_type, - return_dict=return_dict, - callback=callback, - is_cancelled_callback=is_cancelled_callback, - callback_steps=callback_steps, - ) - - def inpaint( - self, - image: Union[torch.FloatTensor, PIL.Image.Image], - mask_image: Union[torch.FloatTensor, PIL.Image.Image], - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: int = 1, - ): - r""" - Function for inpaint. - Args: - image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. This is the image whose masked region will be inpainted. - mask_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be - replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a - PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should - contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` - is 1, the denoising process will be run on the masked area for the full number of iterations specified - in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more - noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. - num_inference_steps (`int`, *optional*, defaults to 50): - The reference number of denoising steps. More denoising steps usually lead to a higher quality image at - the expense of slower inference. This parameter will be modulated by `strength`, as explained above. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. 
- eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - is_cancelled_callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. If the function returns - `True`, the inference will be cancelled. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. 
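A hedged sketch of the callback and cancellation hooks documented above, applied to this inpaint entry point; `pipe`, `init_image`, and `mask_image` are assumed to exist already. Returning True from the cancel callback makes the underlying call return None instead of latents.

    import threading

    cancel_event = threading.Event()  # set() this from another thread to stop generation

    def on_step(step: int, timestep, latents):
        # called every callback_steps denoising steps
        print(f"step {step}, timestep {int(timestep)}")

    latents = pipe.inpaint(
        image=init_image,
        mask_image=mask_image,
        prompt="a tidy desk, (photorealistic:1.1)",
        strength=0.75,
        callback=on_step,
        is_cancelled_callback=cancel_event.is_set,
        callback_steps=5,
    )
    images = pipe.latents_to_image(latents) if latents is not None else []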
- """ - return self.__call__( - prompt=prompt, - negative_prompt=negative_prompt, - image=image, - mask_image=mask_image, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - strength=strength, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - max_embeddings_multiples=max_embeddings_multiples, - output_type=output_type, - return_dict=return_dict, - callback=callback, - is_cancelled_callback=is_cancelled_callback, - callback_steps=callback_steps, - ) \ No newline at end of file diff --git a/library/lpw_stable_diffusion_orig.py b/library/lpw_stable_diffusion_orig.py deleted file mode 100644 index 9dce91a76..000000000 --- a/library/lpw_stable_diffusion_orig.py +++ /dev/null @@ -1,1254 +0,0 @@ -# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py -# and modify to support SD2.x - -import inspect -import re -from typing import Callable, List, Optional, Union - -import numpy as np -import PIL.Image -import torch -from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -import diffusers -from diffusers import SchedulerMixin, StableDiffusionPipeline -from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker -from diffusers.utils import logging - - -try: - from diffusers.utils import PIL_INTERPOLATION -except ImportError: - if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): - PIL_INTERPOLATION = { - "linear": PIL.Image.Resampling.BILINEAR, - "bilinear": PIL.Image.Resampling.BILINEAR, - "bicubic": PIL.Image.Resampling.BICUBIC, - "lanczos": PIL.Image.Resampling.LANCZOS, - "nearest": PIL.Image.Resampling.NEAREST, - } - else: - PIL_INTERPOLATION = { - "linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - "nearest": PIL.Image.NEAREST, - } -# ------------------------------------------------------------------------------ - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -re_attention = re.compile( - r""" -\\\(| -\\\)| -\\\[| -\\]| -\\\\| -\\| -\(| -\[| -:([+-]?[.\d]+)\)| -\)| -]| -[^\\()\[\]:]+| -: -""", - re.X, -) - - -def parse_prompt_attention(text): - """ - Parses a string with attention tokens and returns a list of pairs: text and its associated weight. 
- Accepted tokens are: - (abc) - increases attention to abc by a multiplier of 1.1 - (abc:3.12) - increases attention to abc by a multiplier of 3.12 - [abc] - decreases attention to abc by a multiplier of 1.1 - \( - literal character '(' - \[ - literal character '[' - \) - literal character ')' - \] - literal character ']' - \\ - literal character '\' - anything else - just text - >>> parse_prompt_attention('normal text') - [['normal text', 1.0]] - >>> parse_prompt_attention('an (important) word') - [['an ', 1.0], ['important', 1.1], [' word', 1.0]] - >>> parse_prompt_attention('(unbalanced') - [['unbalanced', 1.1]] - >>> parse_prompt_attention('\(literal\]') - [['(literal]', 1.0]] - >>> parse_prompt_attention('(unnecessary)(parens)') - [['unnecessaryparens', 1.1]] - >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') - [['a ', 1.0], - ['house', 1.5730000000000004], - [' ', 1.1], - ['on', 1.0], - [' a ', 1.1], - ['hill', 0.55], - [', sun, ', 1.1], - ['sky', 1.4641000000000006], - ['.', 1.1]] - """ - - res = [] - round_brackets = [] - square_brackets = [] - - round_bracket_multiplier = 1.1 - square_bracket_multiplier = 1 / 1.1 - - def multiply_range(start_position, multiplier): - for p in range(start_position, len(res)): - res[p][1] *= multiplier - - for m in re_attention.finditer(text): - text = m.group(0) - weight = m.group(1) - - if text.startswith("\\"): - res.append([text[1:], 1.0]) - elif text == "(": - round_brackets.append(len(res)) - elif text == "[": - square_brackets.append(len(res)) - elif weight is not None and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), float(weight)) - elif text == ")" and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == "]" and len(square_brackets) > 0: - multiply_range(square_brackets.pop(), square_bracket_multiplier) - else: - res.append([text, 1.0]) - - for pos in round_brackets: - multiply_range(pos, round_bracket_multiplier) - - for pos in square_brackets: - multiply_range(pos, square_bracket_multiplier) - - if len(res) == 0: - res = [["", 1.0]] - - # merge runs of identical weights - i = 0 - while i + 1 < len(res): - if res[i][1] == res[i + 1][1]: - res[i][0] += res[i + 1][0] - res.pop(i + 1) - else: - i += 1 - - return res - - -def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int): - r""" - Tokenize a list of prompts and return its tokens with weights of each token. - - No padding, starting or ending token is included. - """ - tokens = [] - weights = [] - truncated = False - for text in prompt: - texts_and_weights = parse_prompt_attention(text) - text_token = [] - text_weight = [] - for word, weight in texts_and_weights: - # tokenize and discard the starting and the ending token - token = pipe.tokenizer(word).input_ids[1:-1] - text_token += token - # copy the weight by length of token - text_weight += [weight] * len(token) - # stop if the text is too long (longer than truncation limit) - if len(text_token) > max_length: - truncated = True - break - # truncate - if len(text_token) > max_length: - truncated = True - text_token = text_token[:max_length] - text_weight = text_weight[:max_length] - tokens.append(text_token) - weights.append(text_weight) - if truncated: - logger.warning("Prompt was truncated. 
Try to shorten the prompt or increase max_embeddings_multiples") - return tokens, weights - - -def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77): - r""" - Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. - """ - max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) - weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length - for i in range(len(tokens)): - tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) - if no_boseos_middle: - weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) - else: - w = [] - if len(weights[i]) == 0: - w = [1.0] * weights_length - else: - for j in range(max_embeddings_multiples): - w.append(1.0) # weight for starting token in this chunk - w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] - w.append(1.0) # weight for ending token in this chunk - w += [1.0] * (weights_length - len(w)) - weights[i] = w[:] - - return tokens, weights - - -def get_unweighted_text_embeddings( - pipe: StableDiffusionPipeline, - text_input: torch.Tensor, - chunk_length: int, - clip_skip: int, - eos: int, - pad: int, - no_boseos_middle: Optional[bool] = True, -): - """ - When the length of tokens is a multiple of the capacity of the text encoder, - it should be split into chunks and sent to the text encoder individually. - """ - max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) - if max_embeddings_multiples > 1: - text_embeddings = [] - for i in range(max_embeddings_multiples): - # extract the i-th chunk - text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() - - # cover the head and the tail by the starting and the ending tokens - text_input_chunk[:, 0] = text_input[0, 0] - if pad == eos: # v1 - text_input_chunk[:, -1] = text_input[0, -1] - else: # v2 - for j in range(len(text_input_chunk)): - if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある - text_input_chunk[j, -1] = eos - if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD - text_input_chunk[j, 1] = eos - - if clip_skip is None or clip_skip == 1: - text_embedding = pipe.text_encoder(text_input_chunk)[0] - else: - enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) - text_embedding = enc_out["hidden_states"][-clip_skip] - text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) - - if no_boseos_middle: - if i == 0: - # discard the ending token - text_embedding = text_embedding[:, :-1] - elif i == max_embeddings_multiples - 1: - # discard the starting token - text_embedding = text_embedding[:, 1:] - else: - # discard both starting and ending tokens - text_embedding = text_embedding[:, 1:-1] - - text_embeddings.append(text_embedding) - text_embeddings = torch.concat(text_embeddings, axis=1) - else: - if clip_skip is None or clip_skip == 1: - text_embeddings = pipe.text_encoder(text_input)[0] - else: - enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True) - text_embeddings = enc_out["hidden_states"][-clip_skip] - text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings) - return text_embeddings - - -def get_weighted_text_embeddings( - pipe: StableDiffusionPipeline, - prompt: Union[str, List[str]], - uncond_prompt: Optional[Union[str, List[str]]] = None, - max_embeddings_multiples: Optional[int] = 3, 
- no_boseos_middle: Optional[bool] = False, - skip_parsing: Optional[bool] = False, - skip_weighting: Optional[bool] = False, - clip_skip=None, -): - r""" - Prompts can be assigned with local weights using brackets. For example, - prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', - and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. - - Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. - - Args: - pipe (`StableDiffusionPipeline`): - Pipe to provide access to the tokenizer and the text encoder. - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - uncond_prompt (`str` or `List[str]`): - The unconditional prompt or prompts for guide the image generation. If unconditional prompt - is provided, the embeddings of prompt and uncond_prompt are concatenated. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - no_boseos_middle (`bool`, *optional*, defaults to `False`): - If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and - ending token in each of the chunk in the middle. - skip_parsing (`bool`, *optional*, defaults to `False`): - Skip the parsing of brackets. - skip_weighting (`bool`, *optional*, defaults to `False`): - Skip the weighting. When the parsing is skipped, it is forced True. - """ - max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - if isinstance(prompt, str): - prompt = [prompt] - - if not skip_parsing: - prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2) - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) - else: - prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids] - prompt_weights = [[1.0] * len(token) for token in prompt_tokens] - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens = [ - token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids - ] - uncond_weights = [[1.0] * len(token) for token in uncond_tokens] - - # round up the longest length of tokens to a multiple of (model_max_length - 2) - max_length = max([len(token) for token in prompt_tokens]) - if uncond_prompt is not None: - max_length = max(max_length, max([len(token) for token in uncond_tokens])) - - max_embeddings_multiples = min( - max_embeddings_multiples, - (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1, - ) - max_embeddings_multiples = max(1, max_embeddings_multiples) - max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - - # pad the length of tokens and weights - bos = pipe.tokenizer.bos_token_id - eos = pipe.tokenizer.eos_token_id - pad = pipe.tokenizer.pad_token_id - prompt_tokens, prompt_weights = pad_tokens_and_weights( - prompt_tokens, - prompt_weights, - max_length, - bos, - eos, - no_boseos_middle=no_boseos_middle, - chunk_length=pipe.tokenizer.model_max_length, - ) - prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) - if uncond_prompt is not None: - uncond_tokens, uncond_weights = pad_tokens_and_weights( - uncond_tokens, 
- uncond_weights, - max_length, - bos, - eos, - no_boseos_middle=no_boseos_middle, - chunk_length=pipe.tokenizer.model_max_length, - ) - uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) - - # get the embeddings - text_embeddings = get_unweighted_text_embeddings( - pipe, - prompt_tokens, - pipe.tokenizer.model_max_length, - clip_skip, - eos, - pad, - no_boseos_middle=no_boseos_middle, - ) - prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) - if uncond_prompt is not None: - uncond_embeddings = get_unweighted_text_embeddings( - pipe, - uncond_tokens, - pipe.tokenizer.model_max_length, - clip_skip, - eos, - pad, - no_boseos_middle=no_boseos_middle, - ) - uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) - - # assign weights to the prompts and normalize in the sense of mean - # TODO: should we normalize by chunk or in a whole (current implementation)? - if (not skip_parsing) and (not skip_weighting): - previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= prompt_weights.unsqueeze(-1) - current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - if uncond_prompt is not None: - previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= uncond_weights.unsqueeze(-1) - current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - - if uncond_prompt is not None: - return text_embeddings, uncond_embeddings - return text_embeddings, None - - -def preprocess_image(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.0 * image - 1.0 - - -def preprocess_mask(mask, scale_factor=8): - mask = mask.convert("L") - w, h = mask.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) - mask = np.array(mask).astype(np.float32) / 255.0 - mask = np.tile(mask, (4, 1, 1)) - mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? 
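    # note: mask[None] adds a batch axis, giving shape (1, 4, h // scale_factor, w // scale_factor);
    # transpose(0, 1, 2, 3) is an identity permutation, so this step only adds the batch dimension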
- mask = 1 - mask # repaint white, keep black - mask = torch.from_numpy(mask) - return mask - - -def prepare_controlnet_image( - image: PIL.Image.Image, - width: int, - height: int, - batch_size: int, - num_images_per_prompt: int, - device: torch.device, - dtype: torch.dtype, - do_classifier_free_guidance: bool = False, - guess_mode: bool = False, -): - if not isinstance(image, torch.Tensor): - if isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - images = [] - - for image_ in image: - image_ = image_.convert("RGB") - image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) - image_ = np.array(image_) - image_ = image_[None, :] - images.append(image_) - - image = images - - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance and not guess_mode: - image = torch.cat([image] * 2) - - return image - - -class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): - r""" - Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing - weighting in prompt. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
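A minimal construction sketch based on the components listed above, assuming a Stable Diffusion 1.x checkpoint in the standard diffusers layout; the model id is illustrative, the safety checker is skipped for brevity, and a CUDA device is assumed.

    from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
    from transformers import CLIPTextModel, CLIPTokenizer

    model_id = "runwayml/stable-diffusion-v1-5"  # illustrative; any SD 1.x diffusers checkpoint
    pipe = StableDiffusionLongPromptWeightingPipeline(
        vae=AutoencoderKL.from_pretrained(model_id, subfolder="vae"),
        text_encoder=CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder"),
        tokenizer=CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer"),
        unet=UNet2DConditionModel.from_pretrained(model_id, subfolder="unet"),
        scheduler=DDIMScheduler.from_pretrained(model_id, subfolder="scheduler"),
        safety_checker=None,            # skipped here for brevity
        feature_extractor=None,
        requires_safety_checker=False,
        clip_skip=1,
    ).to("cuda")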
- """ - - # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"): - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: SchedulerMixin, - # clip_skip: int, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - clip_skip: int = 1, - ): - super().__init__( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - requires_safety_checker=requires_safety_checker, - ) - self.clip_skip = clip_skip - self.__init__additional__() - - # else: - # def __init__( - # self, - # vae: AutoencoderKL, - # text_encoder: CLIPTextModel, - # tokenizer: CLIPTokenizer, - # unet: UNet2DConditionModel, - # scheduler: SchedulerMixin, - # safety_checker: StableDiffusionSafetyChecker, - # feature_extractor: CLIPFeatureExtractor, - # ): - # super().__init__( - # vae=vae, - # text_encoder=text_encoder, - # tokenizer=tokenizer, - # unet=unet, - # scheduler=scheduler, - # safety_checker=safety_checker, - # feature_extractor=feature_extractor, - # ) - # self.__init__additional__() - - def __init__additional__(self): - if not hasattr(self, "vae_scale_factor"): - setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1)) - - @property - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - max_embeddings_multiples, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `list(int)`): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - - text_embeddings, uncond_embeddings = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - ) - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - - if do_classifier_free_guidance: - bs_embed, seq_len, _ = uncond_embeddings.shape - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - - return text_embeddings - - def check_inputs(self, prompt, height, width, strength, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if height % 8 != 0 or width % 8 != 0: - print(height, width) - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." - ) - - def get_timesteps(self, num_inference_steps, strength, device, is_text2img): - if is_text2img: - return self.scheduler.timesteps.to(device), num_inference_steps - else: - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - - t_start = max(num_inference_steps - init_timestep + offset, 0) - timesteps = self.scheduler.timesteps[t_start:].to(device) - return timesteps, num_inference_steps - t_start - - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype)) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - - def decode_latents(self, latents): - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None): - if image is None: - shape = ( - batch_size, - self.unet.in_channels, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) - - if latents is None: - if device.type == "mps": - # randn does not work reproducibly on mps - latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) - else: - latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) - else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents, None, None - else: - init_latent_dist = self.vae.encode(image).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - init_latents = 0.18215 * init_latents - init_latents = torch.cat([init_latents] * batch_size, dim=0) - init_latents_orig = init_latents - shape = init_latents.shape - - # add noise to latents using the timesteps - if device.type == "mps": - noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) - else: - noise = torch.randn(shape, generator=generator, device=device, dtype=dtype) - latents = self.scheduler.add_noise(init_latents, noise, timestep) - return latents, init_latents_orig, noise - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, - mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, - height: int = 512, - width: int = 512, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - strength: float = 0.8, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - return_dict: bool = True, - controlnet=None, - controlnet_image=None, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. - mask_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `image`. 
White pixels in the mask will be - replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a - PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should - contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. - height (`int`, *optional*, defaults to 512): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. - `image` will be used as a starting point, adding more noise to it the larger the `strength`. The - number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added - noise will be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - controlnet (`diffusers.ControlNetModel`, *optional*): - A controlnet model to be used for the inference. If not provided, controlnet will be disabled. - controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*): - `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet - inference. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. 
The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - is_cancelled_callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. If the function returns - `True`, the inference will be cancelled. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Returns: - `None` if cancelled by `is_cancelled_callback`, - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - if controlnet is not None and controlnet_image is None: - raise ValueError("controlnet_image must be provided if controlnet is not None.") - - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs(prompt, height, width, strength, callback_steps) - - # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - text_embeddings = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - max_embeddings_multiples, - ) - dtype = text_embeddings.dtype - - # 4. Preprocess image and mask - if isinstance(image, PIL.Image.Image): - image = preprocess_image(image) - if image is not None: - image = image.to(device=self.device, dtype=dtype) - if isinstance(mask_image, PIL.Image.Image): - mask_image = preprocess_mask(mask_image, self.vae_scale_factor) - if mask_image is not None: - mask = mask_image.to(device=self.device, dtype=dtype) - mask = torch.cat([mask] * batch_size * num_images_per_prompt) - else: - mask = None - - if controlnet_image is not None: - controlnet_image = prepare_controlnet_image( - controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False - ) - - # 5. set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - - # 6. Prepare latent variables - latents, init_latents_orig, noise = self.prepare_latents( - image, - latent_timestep, - batch_size * num_images_per_prompt, - height, - width, - dtype, - device, - generator, - latents, - ) - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 8. 
Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - unet_additional_args = {} - if controlnet is not None: - down_block_res_samples, mid_block_res_sample = controlnet( - latent_model_input, - t, - encoder_hidden_states=text_embeddings, - controlnet_cond=controlnet_image, - conditioning_scale=1.0, - guess_mode=False, - return_dict=False, - ) - unet_additional_args["down_block_additional_residuals"] = down_block_res_samples - unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - if mask is not None: - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) - latents = (init_latents_proper * mask) + (latents * (1 - mask)) - - # call the callback, if provided - if i % callback_steps == 0: - if callback is not None: - callback(i, t, latents) - if is_cancelled_callback is not None and is_cancelled_callback(): - return None - - return latents - - def latents_to_image(self, latents): - # 9. Post-processing - image = self.decode_latents(latents.to(self.vae.dtype)) - image = self.numpy_to_pil(image) - return image - - def text2img( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - height: int = 512, - width: int = 512, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: int = 1, - ): - r""" - Function for text-to-image generation. - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - height (`int`, *optional*, defaults to 512): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). 
Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - is_cancelled_callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. If the function returns - `True`, the inference will be cancelled. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. 
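A sketch of how the `callback` and `is_cancelled_callback` hooks documented above can be wired up for progress reporting and user-driven cancellation. `pipe` is assumed to be an already-constructed instance of this pipeline (how it is loaded is outside this snippet); per the loop above, the call returns `None` when cancelled and the final latents otherwise.

import threading
import torch

cancel_event = threading.Event()  # set from another thread, e.g. a GUI "stop" button

def on_step(step: int, timestep: int, latents: torch.FloatTensor) -> None:
    # called every `callback_steps` steps with the current noisy latents
    print(f"step {step}, t={int(timestep)}, latents shape {tuple(latents.shape)}")

def should_cancel() -> bool:
    return cancel_event.is_set()

latents = pipe.text2img(  # `pipe` is a placeholder for a loaded pipeline instance
    "a photo of an astronaut riding a horse",
    num_inference_steps=30,
    callback=on_step,
    is_cancelled_callback=should_cancel,
    callback_steps=1,
)
if latents is not None:
    images = pipe.latents_to_image(latents)  # decode to PIL images, as defined above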
- """ - return self.__call__( - prompt=prompt, - negative_prompt=negative_prompt, - height=height, - width=width, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - latents=latents, - max_embeddings_multiples=max_embeddings_multiples, - output_type=output_type, - return_dict=return_dict, - callback=callback, - is_cancelled_callback=is_cancelled_callback, - callback_steps=callback_steps, - ) - - def img2img( - self, - image: Union[torch.FloatTensor, PIL.Image.Image], - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: int = 1, - ): - r""" - Function for image-to-image generation. - Args: - image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. - `image` will be used as a starting point, adding more noise to it the larger the `strength`. The - number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added - noise will be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This parameter will be modulated by `strength`. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. 
- max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - is_cancelled_callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. If the function returns - `True`, the inference will be cancelled. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - return self.__call__( - prompt=prompt, - negative_prompt=negative_prompt, - image=image, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - strength=strength, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - max_embeddings_multiples=max_embeddings_multiples, - output_type=output_type, - return_dict=return_dict, - callback=callback, - is_cancelled_callback=is_cancelled_callback, - callback_steps=callback_steps, - ) - - def inpaint( - self, - image: Union[torch.FloatTensor, PIL.Image.Image], - mask_image: Union[torch.FloatTensor, PIL.Image.Image], - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: int = 1, - ): - r""" - Function for inpaint. - Args: - image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. This is the image whose masked region will be inpainted. - mask_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be - replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a - PIL image, it will be converted to a single channel (luminance) before use. 
If it's a tensor, it should - contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` - is 1, the denoising process will be run on the masked area for the full number of iterations specified - in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more - noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. - num_inference_steps (`int`, *optional*, defaults to 50): - The reference number of denoising steps. More denoising steps usually lead to a higher quality image at - the expense of slower inference. This parameter will be modulated by `strength`, as explained above. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - is_cancelled_callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. If the function returns - `True`, the inference will be cancelled. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 
- When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - return self.__call__( - prompt=prompt, - negative_prompt=negative_prompt, - image=image, - mask_image=mask_image, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - strength=strength, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - max_embeddings_multiples=max_embeddings_multiples, - output_type=output_type, - return_dict=return_dict, - callback=callback, - is_cancelled_callback=is_cancelled_callback, - callback_steps=callback_steps, - ) diff --git a/library/model_util.py b/library/model_util.py deleted file mode 100644 index be410a026..000000000 --- a/library/model_util.py +++ /dev/null @@ -1,1356 +0,0 @@ -# v1: split from train_db_fixed.py. -# v2: support safetensors - -import math -import os - -import torch -from library.device_utils import init_ipex -init_ipex() - -import diffusers -from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging -from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel -from safetensors.torch import load_file, save_file -from library.original_unet import UNet2DConditionModel -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -# DiffUsers版StableDiffusionのモデルパラメータ -NUM_TRAIN_TIMESTEPS = 1000 -BETA_START = 0.00085 -BETA_END = 0.0120 - -UNET_PARAMS_MODEL_CHANNELS = 320 -UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4] -UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1] -UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32` -UNET_PARAMS_IN_CHANNELS = 4 -UNET_PARAMS_OUT_CHANNELS = 4 -UNET_PARAMS_NUM_RES_BLOCKS = 2 -UNET_PARAMS_CONTEXT_DIM = 768 -UNET_PARAMS_NUM_HEADS = 8 -# UNET_PARAMS_USE_LINEAR_PROJECTION = False - -VAE_PARAMS_Z_CHANNELS = 4 -VAE_PARAMS_RESOLUTION = 256 -VAE_PARAMS_IN_CHANNELS = 3 -VAE_PARAMS_OUT_CH = 3 -VAE_PARAMS_CH = 128 -VAE_PARAMS_CH_MULT = [1, 2, 4, 4] -VAE_PARAMS_NUM_RES_BLOCKS = 2 - -# V2 -V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20] -V2_UNET_PARAMS_CONTEXT_DIM = 1024 -# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True - -# Diffusersの設定を読み込むための参照モデル -DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5" -DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1" - - -# region StableDiffusion->Diffusersの変換コード -# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0) - - -def shave_segments(path, n_shave_prefix_segments=1): - """ - Removes segments. Positive values shave the first segments, negative shave the last segments. 
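A quick standalone illustration of what `shave_segments` does with positive versus negative `n_shave_prefix_segments`:

path = "input_blocks.1.0.in_layers.0.weight"
# a positive value drops that many leading dot-separated segments
print(".".join(path.split(".")[1:]))    # -> 1.0.in_layers.0.weight
# a negative value drops segments from the end instead
print(".".join(path.split(".")[:-1]))   # -> input_blocks.1.0.in_layers.0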
- """ - if n_shave_prefix_segments >= 0: - return ".".join(path.split(".")[n_shave_prefix_segments:]) - else: - return ".".join(path.split(".")[:n_shave_prefix_segments]) - - -def renew_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item.replace("in_layers.0", "norm1") - new_item = new_item.replace("in_layers.2", "conv1") - - new_item = new_item.replace("out_layers.0", "norm2") - new_item = new_item.replace("out_layers.3", "conv2") - - new_item = new_item.replace("emb_layers.1", "time_emb_proj") - new_item = new_item.replace("skip_connection", "conv_shortcut") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("nin_shortcut", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - # new_item = new_item.replace('norm.weight', 'group_norm.weight') - # new_item = new_item.replace('norm.bias', 'group_norm.bias') - - # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') - - # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("norm.weight", "group_norm.weight") - new_item = new_item.replace("norm.bias", "group_norm.bias") - - if diffusers.__version__ < "0.17.0": - new_item = new_item.replace("q.weight", "query.weight") - new_item = new_item.replace("q.bias", "query.bias") - - new_item = new_item.replace("k.weight", "key.weight") - new_item = new_item.replace("k.bias", "key.bias") - - new_item = new_item.replace("v.weight", "value.weight") - new_item = new_item.replace("v.bias", "value.bias") - - new_item = new_item.replace("proj_out.weight", "proj_attn.weight") - new_item = new_item.replace("proj_out.bias", "proj_attn.bias") - else: - new_item = new_item.replace("q.weight", "to_q.weight") - new_item = new_item.replace("q.bias", "to_q.bias") - - new_item = new_item.replace("k.weight", "to_k.weight") - new_item = new_item.replace("k.bias", "to_k.bias") - - new_item = new_item.replace("v.weight", "to_v.weight") - new_item = new_item.replace("v.bias", "to_v.bias") - - new_item = new_item.replace("proj_out.weight", "to_out.0.weight") - new_item = new_item.replace("proj_out.bias", "to_out.0.bias") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def assign_to_checkpoint( - paths, checkpoint, 
old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None -): - """ - This does the final conversion step: take locally converted weights and apply a global renaming - to them. It splits attention layers, and takes into account additional replacements - that may arise. - - Assigns the weights to the new checkpoint. - """ - assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." - - # Splits the attention layers into three variables. - if attention_paths_to_split is not None: - for path, path_map in attention_paths_to_split.items(): - old_tensor = old_checkpoint[path] - channels = old_tensor.shape[0] // 3 - - target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) - - num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 - - old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) - query, key, value = old_tensor.split(channels // num_heads, dim=1) - - checkpoint[path_map["query"]] = query.reshape(target_shape) - checkpoint[path_map["key"]] = key.reshape(target_shape) - checkpoint[path_map["value"]] = value.reshape(target_shape) - - for path in paths: - new_path = path["new"] - - # These have already been assigned - if attention_paths_to_split is not None and new_path in attention_paths_to_split: - continue - - # Global renaming happens here - new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") - new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") - new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") - - if additional_replacements is not None: - for replacement in additional_replacements: - new_path = new_path.replace(replacement["old"], replacement["new"]) - - # proj_attn.weight has to be converted from conv 1D to linear - reshaping = False - if diffusers.__version__ < "0.17.0": - if "proj_attn.weight" in new_path: - reshaping = True - else: - if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2: - reshaping = True - - if reshaping: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] - else: - checkpoint[new_path] = old_checkpoint[path["old"]] - - -def conv_attn_to_linear(checkpoint): - keys = list(checkpoint.keys()) - attn_keys = ["query.weight", "key.weight", "value.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in attn_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif "proj_attn.weight" in key: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] - - -def linear_transformer_to_conv(checkpoint): - keys = list(checkpoint.keys()) - tf_keys = ["proj_in.weight", "proj_out.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in tf_keys: - if checkpoint[key].ndim == 2: - checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) - - -def convert_ldm_unet_checkpoint(v2, checkpoint, config): - """ - Takes a state dict and a config, and returns a converted checkpoint. - """ - - # extract state_dict for UNet - unet_state_dict = {} - unet_key = "model.diffusion_model." 
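A toy version of the prefix split performed next, showing how only the keys under `model.diffusion_model.` are pulled out and re-keyed for the UNet (dummy values, standalone):

checkpoint = {
    "model.diffusion_model.time_embed.0.weight": "unet tensor",
    "first_stage_model.encoder.conv_in.weight": "vae tensor",
}
unet_key = "model.diffusion_model."
unet_state_dict = {k[len(unet_key):]: v for k, v in checkpoint.items() if k.startswith(unet_key)}
print(unet_state_dict)  # {'time_embed.0.weight': 'unet tensor'}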
- keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) - - new_checkpoint = {} - - new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] - new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] - new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] - new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] - - new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] - new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] - - new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] - new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] - new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] - new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] - - # Retrieves the keys for the input blocks only - num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) - input_blocks = { - layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks) - } - - # Retrieves the keys for the middle blocks only - num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) - middle_blocks = { - layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks) - } - - # Retrieves the keys for the output blocks only - num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) - output_blocks = { - layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." 
in key] for layer_id in range(num_output_blocks) - } - - for i in range(1, num_input_blocks): - block_id = (i - 1) // (config["layers_per_block"] + 1) - layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) - - resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key] - attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - - if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.weight" - ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias") - - paths = renew_resnet_paths(resnets) - meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} - assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) - - resnet_0 = middle_blocks[0] - attentions = middle_blocks[1] - resnet_1 = middle_blocks[2] - - resnet_0_paths = renew_resnet_paths(resnet_0) - assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) - - resnet_1_paths = renew_resnet_paths(resnet_1) - assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) - - attentions_paths = renew_attention_paths(attentions) - meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) - - for i in range(num_output_blocks): - block_id = i // (config["layers_per_block"] + 1) - layer_in_block_id = i % (config["layers_per_block"] + 1) - output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] - output_block_list = {} - - for layer in output_block_layers: - layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) - if layer_id in output_block_list: - output_block_list[layer_id].append(layer_name) - else: - output_block_list[layer_id] = [layer_name] - - if len(output_block_list) > 1: - resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] - attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] - - resnet_0_paths = renew_resnet_paths(resnets) - paths = renew_resnet_paths(resnets) - - meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) - - # オリジナル: - # if ["conv.weight", "conv.bias"] in output_block_list.values(): - # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) - - # biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが - for l in output_block_list.values(): - l.sort() - - if ["conv.bias", "conv.weight"] in output_block_list.values(): - index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.bias" - ] - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ - 
f"output_blocks.{i}.{index}.conv.weight" - ] - - # Clear attentions as they have been attributed above. - if len(attentions) == 2: - attentions = [] - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - "old": f"output_blocks.{i}.1", - "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", - } - assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) - else: - resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) - for path in resnet_0_paths: - old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) - - new_checkpoint[new_path] = unet_state_dict[old_path] - - # SDのv2では1*1のconv2dがlinearに変わっている - # 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要 - if v2 and not config.get("use_linear_projection", False): - linear_transformer_to_conv(new_checkpoint) - - return new_checkpoint - - -def convert_ldm_vae_checkpoint(checkpoint, config): - # extract state dict for VAE - vae_state_dict = {} - vae_key = "first_stage_model." - keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(vae_key): - vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) - # if len(vae_state_dict) == 0: - # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict - # vae_state_dict = checkpoint - - new_checkpoint = {} - - new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] - new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] - new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] - new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] - new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] - new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] - - new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] - new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] - new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] - new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] - new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] - new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] - - new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] - new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] - new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] - new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] - - # Retrieves the keys for the encoder down blocks only - num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) - down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)} - - # Retrieves the keys for the decoder up blocks only - num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) - up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)} - - for i in range(num_down_blocks): - resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] - - if 
f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.weight" - ) - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.bias" - ) - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - - for i in range(num_up_blocks): - block_id = num_up_blocks - 1 - i - resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key] - - if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.weight" - ] - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.bias" - ] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - return new_checkpoint - - -def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False): - """ - Creates a config for the diffusers based on the config of the LDM model. 
- """ - # unet_params = original_config.model.params.unet_config.params - - block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] - - down_block_types = [] - resolution = 1 - for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" - down_block_types.append(block_type) - if i != len(block_out_channels) - 1: - resolution *= 2 - - up_block_types = [] - for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" - up_block_types.append(block_type) - resolution //= 2 - - config = dict( - sample_size=UNET_PARAMS_IMAGE_SIZE, - in_channels=UNET_PARAMS_IN_CHANNELS, - out_channels=UNET_PARAMS_OUT_CHANNELS, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, - cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, - attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, - # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION, - ) - if v2 and use_linear_projection_in_v2: - config["use_linear_projection"] = True - - return config - - -def create_vae_diffusers_config(): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - # vae_params = original_config.model.params.first_stage_config.params.ddconfig - # _ = original_config.model.params.first_stage_config.params.embed_dim - block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] - down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) - up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - - config = dict( - sample_size=VAE_PARAMS_RESOLUTION, - in_channels=VAE_PARAMS_IN_CHANNELS, - out_channels=VAE_PARAMS_OUT_CH, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - latent_channels=VAE_PARAMS_Z_CHANNELS, - layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, - ) - return config - - -def convert_ldm_clip_checkpoint_v1(checkpoint): - keys = list(checkpoint.keys()) - text_model_dict = {} - for key in keys: - if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] - - # remove position_ids for newer transformer, which causes error :( - if "text_model.embeddings.position_ids" in text_model_dict: - text_model_dict.pop("text_model.embeddings.position_ids") - - return text_model_dict - - -def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): - # 嫌になるくらい違うぞ! - def convert_key(key): - if not key.startswith("cond_stage_model"): - return None - - # common conversion - key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") - key = key.replace("cond_stage_model.model.", "text_model.") - - if "resblocks" in key: - # resblocks conversion - key = key.replace(".resblocks.", ".layers.") - if ".ln_" in key: - key = key.replace(".ln_", ".layer_norm") - elif ".mlp." 
in key: - key = key.replace(".c_fc.", ".fc1.") - key = key.replace(".c_proj.", ".fc2.") - elif ".attn.out_proj" in key: - key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") - elif ".attn.in_proj" in key: - key = None # 特殊なので後で処理する - else: - raise ValueError(f"unexpected key in SD: {key}") - elif ".positional_embedding" in key: - key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") - elif ".text_projection" in key: - key = None # 使われない??? - elif ".logit_scale" in key: - key = None # 使われない??? - elif ".token_embedding" in key: - key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") - elif ".ln_final" in key: - key = key.replace(".ln_final", ".final_layer_norm") - return key - - keys = list(checkpoint.keys()) - new_sd = {} - for key in keys: - # remove resblocks 23 - if ".resblocks.23." in key: - continue - new_key = convert_key(key) - if new_key is None: - continue - new_sd[new_key] = checkpoint[key] - - # attnの変換 - for key in keys: - if ".resblocks.23." in key: - continue - if ".resblocks" in key and ".attn.in_proj_" in key: - # 三つに分割 - values = torch.chunk(checkpoint[key], 3) - - key_suffix = ".weight" if "weight" in key else ".bias" - key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") - key_pfx = key_pfx.replace("_weight", "") - key_pfx = key_pfx.replace("_bias", "") - key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") - new_sd[key_pfx + "q_proj" + key_suffix] = values[0] - new_sd[key_pfx + "k_proj" + key_suffix] = values[1] - new_sd[key_pfx + "v_proj" + key_suffix] = values[2] - - # rename or add position_ids - ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids" - if ANOTHER_POSITION_IDS_KEY in new_sd: - # waifu diffusion v1.4 - position_ids = new_sd[ANOTHER_POSITION_IDS_KEY] - del new_sd[ANOTHER_POSITION_IDS_KEY] - else: - position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) - - new_sd["text_model.embeddings.position_ids"] = position_ids - return new_sd - - -# endregion - - -# region Diffusers->StableDiffusion の変換コード -# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0) - - -def conv_transformer_to_linear(checkpoint): - keys = list(checkpoint.keys()) - tf_keys = ["proj_in.weight", "proj_out.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in tf_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - - -def convert_unet_state_dict_to_sd(v2, unet_state_dict): - unet_conversion_map = [ - # (stable-diffusion, HF Diffusers) - ("time_embed.0.weight", "time_embedding.linear_1.weight"), - ("time_embed.0.bias", "time_embedding.linear_1.bias"), - ("time_embed.2.weight", "time_embedding.linear_2.weight"), - ("time_embed.2.bias", "time_embedding.linear_2.bias"), - ("input_blocks.0.0.weight", "conv_in.weight"), - ("input_blocks.0.0.bias", "conv_in.bias"), - ("out.0.weight", "conv_norm_out.weight"), - ("out.0.bias", "conv_norm_out.bias"), - ("out.2.weight", "conv_out.weight"), - ("out.2.bias", "conv_out.bias"), - ] - - unet_conversion_map_resnet = [ - # (stable-diffusion, HF Diffusers) - ("in_layers.0", "norm1"), - ("in_layers.2", "conv1"), - ("out_layers.0", "norm2"), - ("out_layers.3", "conv2"), - ("emb_layers.1", "time_emb_proj"), - ("skip_connection", "conv_shortcut"), - ] - - unet_conversion_map_layer = [] - for i in range(4): - # loop over downblocks/upblocks - - for j in range(2): - # loop over resnets/attentions for downblocks - hf_down_res_prefix = 
f"down_blocks.{i}.resnets.{j}." - sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." - unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) - - if i < 3: - # no attention layers in down_blocks.3 - hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." - sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." - unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) - - for j in range(3): - # loop over resnets/attentions for upblocks - hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." - sd_up_res_prefix = f"output_blocks.{3*i + j}.0." - unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) - - if i > 0: - # no attention layers in up_blocks.0 - hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." - sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." - unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) - - if i < 3: - # no downsample in down_blocks.3 - hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." - sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." - unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) - - # no upsample in up_blocks.3 - hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." - unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) - - hf_mid_atn_prefix = "mid_block.attentions.0." - sd_mid_atn_prefix = "middle_block.1." - unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) - - for j in range(2): - hf_mid_res_prefix = f"mid_block.resnets.{j}." - sd_mid_res_prefix = f"middle_block.{2*j}." - unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) - - # buyer beware: this is a *brittle* function, - # and correct output requires that all of these pieces interact in - # the exact order in which I have arranged them. - mapping = {k: k for k in unet_state_dict.keys()} - for sd_name, hf_name in unet_conversion_map: - mapping[hf_name] = sd_name - for k, v in mapping.items(): - if "resnets" in k: - for sd_part, hf_part in unet_conversion_map_resnet: - v = v.replace(hf_part, sd_part) - mapping[k] = v - for k, v in mapping.items(): - for sd_part, hf_part in unet_conversion_map_layer: - v = v.replace(hf_part, sd_part) - mapping[k] = v - new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} - - if v2: - conv_transformer_to_linear(new_state_dict) - - return new_state_dict - - -def controlnet_conversion_map(): - unet_conversion_map = [ - ("time_embed.0.weight", "time_embedding.linear_1.weight"), - ("time_embed.0.bias", "time_embedding.linear_1.bias"), - ("time_embed.2.weight", "time_embedding.linear_2.weight"), - ("time_embed.2.bias", "time_embedding.linear_2.bias"), - ("input_blocks.0.0.weight", "conv_in.weight"), - ("input_blocks.0.0.bias", "conv_in.bias"), - ("middle_block_out.0.weight", "controlnet_mid_block.weight"), - ("middle_block_out.0.bias", "controlnet_mid_block.bias"), - ] - - unet_conversion_map_resnet = [ - ("in_layers.0", "norm1"), - ("in_layers.2", "conv1"), - ("out_layers.0", "norm2"), - ("out_layers.3", "conv2"), - ("emb_layers.1", "time_emb_proj"), - ("skip_connection", "conv_shortcut"), - ] - - unet_conversion_map_layer = [] - for i in range(4): - for j in range(2): - hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." - sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." 
- unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) - - if i < 3: - hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." - sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." - unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) - - if i < 3: - hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." - sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." - unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) - - hf_mid_atn_prefix = "mid_block.attentions.0." - sd_mid_atn_prefix = "middle_block.1." - unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) - - for j in range(2): - hf_mid_res_prefix = f"mid_block.resnets.{j}." - sd_mid_res_prefix = f"middle_block.{2*j}." - unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) - - controlnet_cond_embedding_names = ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"] - for i, hf_prefix in enumerate(controlnet_cond_embedding_names): - hf_prefix = f"controlnet_cond_embedding.{hf_prefix}." - sd_prefix = f"input_hint_block.{i*2}." - unet_conversion_map_layer.append((sd_prefix, hf_prefix)) - - for i in range(12): - hf_prefix = f"controlnet_down_blocks.{i}." - sd_prefix = f"zero_convs.{i}.0." - unet_conversion_map_layer.append((sd_prefix, hf_prefix)) - - return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer - - -def convert_controlnet_state_dict_to_sd(controlnet_state_dict): - unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map() - - mapping = {k: k for k in controlnet_state_dict.keys()} - for sd_name, diffusers_name in unet_conversion_map: - mapping[diffusers_name] = sd_name - for k, v in mapping.items(): - if "resnets" in k: - for sd_part, diffusers_part in unet_conversion_map_resnet: - v = v.replace(diffusers_part, sd_part) - mapping[k] = v - for k, v in mapping.items(): - for sd_part, diffusers_part in unet_conversion_map_layer: - v = v.replace(diffusers_part, sd_part) - mapping[k] = v - new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()} - return new_state_dict - - -def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict): - unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map() - - mapping = {k: k for k in controlnet_state_dict.keys()} - for sd_name, diffusers_name in unet_conversion_map: - mapping[sd_name] = diffusers_name - for k, v in mapping.items(): - for sd_part, diffusers_part in unet_conversion_map_layer: - v = v.replace(sd_part, diffusers_part) - mapping[k] = v - for k, v in mapping.items(): - if "resnets" in v: - for sd_part, diffusers_part in unet_conversion_map_resnet: - v = v.replace(sd_part, diffusers_part) - mapping[k] = v - new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()} - return new_state_dict - - -# ================# -# VAE Conversion # -# ================# - - -def reshape_weight_for_sd(w): - # convert HF linear weights to SD conv2d weights - return w.reshape(*w.shape, 1, 1) - - -def convert_vae_state_dict(vae_state_dict): - vae_conversion_map = [ - # (stable-diffusion, HF Diffusers) - ("nin_shortcut", "conv_shortcut"), - ("norm_out", "conv_norm_out"), - ("mid.attn_1.", "mid_block.attentions.0."), - ] - - for i in range(4): - # down_blocks have two resnets - for j in range(2): - hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." - sd_down_prefix = f"encoder.down.{i}.block.{j}." 
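One detail of the VAE conversion worth isolating: the mid-block attention projections are linear layers on the Diffusers side but 1x1 convolutions in SD checkpoints, so `reshape_weight_for_sd` only adds two trailing singleton dimensions. A self-contained check with a dummy tensor:

import torch

w = torch.randn(512, 512)                # HF Diffusers linear weight
w_sd = w.reshape(*w.shape, 1, 1)         # same reshape as reshape_weight_for_sd above
print(w_sd.shape)                        # torch.Size([512, 512, 1, 1])
assert torch.equal(w_sd[:, :, 0, 0], w)  # values are unchanged, only the shape differs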
- vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) - - if i < 3: - hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." - sd_downsample_prefix = f"down.{i}.downsample." - vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) - - hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"up.{3-i}.upsample." - vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) - - # up_blocks have three resnets - # also, up blocks in hf are numbered in reverse from sd - for j in range(3): - hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." - sd_up_prefix = f"decoder.up.{3-i}.block.{j}." - vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) - - # this part accounts for mid blocks in both the encoder and the decoder - for i in range(2): - hf_mid_res_prefix = f"mid_block.resnets.{i}." - sd_mid_res_prefix = f"mid.block_{i+1}." - vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) - - if diffusers.__version__ < "0.17.0": - vae_conversion_map_attn = [ - # (stable-diffusion, HF Diffusers) - ("norm.", "group_norm."), - ("q.", "query."), - ("k.", "key."), - ("v.", "value."), - ("proj_out.", "proj_attn."), - ] - else: - vae_conversion_map_attn = [ - # (stable-diffusion, HF Diffusers) - ("norm.", "group_norm."), - ("q.", "to_q."), - ("k.", "to_k."), - ("v.", "to_v."), - ("proj_out.", "to_out.0."), - ] - - mapping = {k: k for k in vae_state_dict.keys()} - for k, v in mapping.items(): - for sd_part, hf_part in vae_conversion_map: - v = v.replace(hf_part, sd_part) - mapping[k] = v - for k, v in mapping.items(): - if "attentions" in k: - for sd_part, hf_part in vae_conversion_map_attn: - v = v.replace(hf_part, sd_part) - mapping[k] = v - new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} - weights_to_convert = ["q", "k", "v", "proj_out"] - for k, v in new_state_dict.items(): - for weight_name in weights_to_convert: - if f"mid.attn_1.{weight_name}.weight" in k: - # logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1") - new_state_dict[k] = reshape_weight_for_sd(v) - - return new_state_dict - - -# endregion - -# region 自作のモデル読み書きなど - - -def is_safetensors(path): - return os.path.splitext(path)[1].lower() == ".safetensors" - - -def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): - # text encoderの格納形式が違うモデルに対応する ('text_model'がない) - TEXT_ENCODER_KEY_REPLACEMENTS = [ - ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."), - ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."), - ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."), - ] - - if is_safetensors(ckpt_path): - checkpoint = None - state_dict = load_file(ckpt_path) # , device) # may causes error - else: - checkpoint = torch.load(ckpt_path, map_location=device) - if "state_dict" in checkpoint: - state_dict = checkpoint["state_dict"] - else: - state_dict = checkpoint - checkpoint = None - - key_reps = [] - for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: - for key in state_dict.keys(): - if key.startswith(rep_from): - new_key = rep_to + key[len(rep_from) :] - key_reps.append((key, new_key)) - - for key, new_key in key_reps: - state_dict[new_key] = state_dict[key] - del state_dict[key] - - return checkpoint, state_dict - - -# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 -def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, 
unet_use_linear_projection_in_v2=True): - _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device) - - # Convert the UNet2DConditionModel model. - unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2) - converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) - - unet = UNet2DConditionModel(**unet_config).to(device) - info = unet.load_state_dict(converted_unet_checkpoint) - logger.info(f"loading u-net: {info}") - - # Convert the VAE model. - vae_config = create_vae_diffusers_config() - converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) - - vae = AutoencoderKL(**vae_config).to(device) - info = vae.load_state_dict(converted_vae_checkpoint) - logger.info(f"loading vae: {info}") - - # convert text_model - if v2: - converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77) - cfg = CLIPTextConfig( - vocab_size=49408, - hidden_size=1024, - intermediate_size=4096, - num_hidden_layers=23, - num_attention_heads=16, - max_position_embeddings=77, - hidden_act="gelu", - layer_norm_eps=1e-05, - dropout=0.0, - attention_dropout=0.0, - initializer_range=0.02, - initializer_factor=1.0, - pad_token_id=1, - bos_token_id=0, - eos_token_id=2, - model_type="clip_text_model", - projection_dim=512, - torch_dtype="float32", - transformers_version="4.25.0.dev0", - ) - text_model = CLIPTextModel._from_config(cfg) - info = text_model.load_state_dict(converted_text_encoder_checkpoint) - else: - converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) - - # logging.set_verbosity_error() # don't show annoying warning - # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device) - # logging.set_verbosity_warning() - # logger.info(f"config: {text_model.config}") - cfg = CLIPTextConfig( - vocab_size=49408, - hidden_size=768, - intermediate_size=3072, - num_hidden_layers=12, - num_attention_heads=12, - max_position_embeddings=77, - hidden_act="quick_gelu", - layer_norm_eps=1e-05, - dropout=0.0, - attention_dropout=0.0, - initializer_range=0.02, - initializer_factor=1.0, - pad_token_id=1, - bos_token_id=0, - eos_token_id=2, - model_type="clip_text_model", - projection_dim=768, - torch_dtype="float32", - ) - text_model = CLIPTextModel._from_config(cfg) - info = text_model.load_state_dict(converted_text_encoder_checkpoint) - logger.info(f"loading text encoder: {info}") - - return text_model, vae, unet - - -def get_model_version_str_for_sd1_sd2(v2, v_parameterization): - # only for reference - version_str = "sd" - if v2: - version_str += "_v2" - else: - version_str += "_v1" - if v_parameterization: - version_str += "_v" - return version_str - - -def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): - def convert_key(key): - # position_idsの除去 - if ".position_ids" in key: - return None - - # common - key = key.replace("text_model.encoder.", "transformer.") - key = key.replace("text_model.", "") - if "layers" in key: - # resblocks conversion - key = key.replace(".layers.", ".resblocks.") - if ".layer_norm" in key: - key = key.replace(".layer_norm", ".ln_") - elif ".mlp." in key: - key = key.replace(".fc1.", ".c_fc.") - key = key.replace(".fc2.", ".c_proj.") - elif ".self_attn.out_proj" in key: - key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") - elif ".self_attn." 
in key: - key = None # 特殊なので後で処理する - else: - raise ValueError(f"unexpected key in DiffUsers model: {key}") - elif ".position_embedding" in key: - key = key.replace("embeddings.position_embedding.weight", "positional_embedding") - elif ".token_embedding" in key: - key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") - elif "final_layer_norm" in key: - key = key.replace("final_layer_norm", "ln_final") - return key - - keys = list(checkpoint.keys()) - new_sd = {} - for key in keys: - new_key = convert_key(key) - if new_key is None: - continue - new_sd[new_key] = checkpoint[key] - - # attnの変換 - for key in keys: - if "layers" in key and "q_proj" in key: - # 三つを結合 - key_q = key - key_k = key.replace("q_proj", "k_proj") - key_v = key.replace("q_proj", "v_proj") - - value_q = checkpoint[key_q] - value_k = checkpoint[key_k] - value_v = checkpoint[key_v] - value = torch.cat([value_q, value_k, value_v]) - - new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") - new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") - new_sd[new_key] = value - - # 最後の層などを捏造するか - if make_dummy_weights: - logger.info("make dummy weights for resblock.23, text_projection and logit scale.") - keys = list(new_sd.keys()) - for key in keys: - if key.startswith("transformer.resblocks.22."): - new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる - - # Diffusersに含まれない重みを作っておく - new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) - new_sd["logit_scale"] = torch.tensor(1) - - return new_sd - - -def save_stable_diffusion_checkpoint( - v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, metadata, save_dtype=None, vae=None -): - if ckpt_path is not None: - # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む - checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) - if checkpoint is None: # safetensors または state_dictのckpt - checkpoint = {} - strict = False - else: - strict = True - if "state_dict" in state_dict: - del state_dict["state_dict"] - else: - # 新しく作る - assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint" - checkpoint = {} - state_dict = {} - strict = False - - def update_sd(prefix, sd): - for k, v in sd.items(): - key = prefix + k - assert not strict or key in state_dict, f"Illegal key in save SD: {key}" - if save_dtype is not None: - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v - - # Convert the UNet model - unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) - update_sd("model.diffusion_model.", unet_state_dict) - - # Convert the text encoder model - if v2: - make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる - text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy) - update_sd("cond_stage_model.model.", text_enc_dict) - else: - text_enc_dict = text_encoder.state_dict() - update_sd("cond_stage_model.transformer.", text_enc_dict) - - # Convert the VAE - if vae is not None: - vae_dict = convert_vae_state_dict(vae.state_dict()) - update_sd("first_stage_model.", vae_dict) - - # Put together new checkpoint - key_count = len(state_dict.keys()) - new_ckpt = {"state_dict": state_dict} - - # epoch and global_step are sometimes not int - try: - if "epoch" in checkpoint: - epochs += checkpoint["epoch"] - if "global_step" in checkpoint: - steps += checkpoint["global_step"] - except: - pass - - 
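Both CLIP converters above hinge on packing and unpacking the attention projections: SD v2 checkpoints store a single `in_proj` weight, while the transformers model keeps separate q/k/v projections. A standalone round-trip check with dummy tensors (1024 matches the SD2 text encoder width):

import torch

q, k, v = (torch.randn(1024, 1024) for _ in range(3))
in_proj = torch.cat([q, k, v])        # Diffusers -> SD: concatenate into one in_proj weight
q2, k2, v2 = torch.chunk(in_proj, 3)  # SD -> Diffusers: split back into q, k, v
assert torch.equal(q, q2) and torch.equal(k, k2) and torch.equal(v, v2)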
new_ckpt["epoch"] = epochs - new_ckpt["global_step"] = steps - - if is_safetensors(output_file): - # TODO Tensor以外のdictの値を削除したほうがいいか - save_file(state_dict, output_file, metadata) - else: - torch.save(new_ckpt, output_file) - - return key_count - - -def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False): - if pretrained_model_name_or_path is None: - # load default settings for v1/v2 - if v2: - pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2 - else: - pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1 - - scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") - tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer") - if vae is None: - vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") - - # original U-Net cannot be saved, so we need to convert it to the Diffusers version - # TODO this consumes a lot of memory - diffusers_unet = diffusers.UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet") - diffusers_unet.load_state_dict(unet.state_dict()) - - pipeline = StableDiffusionPipeline( - unet=diffusers_unet, - text_encoder=text_encoder, - vae=vae, - scheduler=scheduler, - tokenizer=tokenizer, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=None, - ) - pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) - - -VAE_PREFIX = "first_stage_model." - - -def load_vae(vae_id, dtype): - logger.info(f"load VAE: {vae_id}") - if os.path.isdir(vae_id) or not os.path.isfile(vae_id): - # Diffusers local/remote - try: - vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) - except EnvironmentError as e: - logger.error(f"exception occurs in loading vae: {e}") - logger.error("retry with subfolder='vae'") - vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) - return vae - - # local - vae_config = create_vae_diffusers_config() - - if vae_id.endswith(".bin"): - # SD 1.5 VAE on Huggingface - converted_vae_checkpoint = torch.load(vae_id, map_location="cpu") - else: - # StableDiffusion - vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu") - vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model - - # vae only or full model - full_model = False - for vae_key in vae_sd: - if vae_key.startswith(VAE_PREFIX): - full_model = True - break - if not full_model: - sd = {} - for key, value in vae_sd.items(): - sd[VAE_PREFIX + key] = value - vae_sd = sd - del sd - - # Convert the VAE model. 
- converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config) - - vae = AutoencoderKL(**vae_config) - vae.load_state_dict(converted_vae_checkpoint) - return vae - - -# endregion - - -def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64): - max_width, max_height = max_reso - max_area = max_width * max_height - - resos = set() - - width = int(math.sqrt(max_area) // divisible) * divisible - resos.add((width, width)) - - width = min_size - while width <= max_size: - height = min(max_size, int((max_area // width) // divisible) * divisible) - if height >= min_size: - resos.add((width, height)) - resos.add((height, width)) - - # # make additional resos - # if width >= height and width - divisible >= min_size: - # resos.add((width - divisible, height)) - # resos.add((height, width - divisible)) - # if height >= width and height - divisible >= min_size: - # resos.add((width, height - divisible)) - # resos.add((height - divisible, width)) - - width += divisible - - resos = list(resos) - resos.sort() - return resos - - -if __name__ == "__main__": - resos = make_bucket_resolutions((512, 768)) - logger.info(f"{len(resos)}") - logger.info(f"{resos}") - aspect_ratios = [w / h for w, h in resos] - logger.info(f"{aspect_ratios}") - - ars = set() - for ar in aspect_ratios: - if ar in ars: - logger.error(f"error! duplicate ar: {ar}") - ars.add(ar) diff --git a/library/original_unet.py b/library/original_unet.py deleted file mode 100644 index e944ff22b..000000000 --- a/library/original_unet.py +++ /dev/null @@ -1,1919 +0,0 @@ -# Diffusers 0.10.2からStable Diffusionに必要な部分だけを持ってくる -# 条件分岐等で不要な部分は削除している -# コードの多くはDiffusersからコピーしている -# 制約として、モデルのstate_dictがDiffusers 0.10.2のものと同じ形式である必要がある - -# Copy from Diffusers 0.10.2 for Stable Diffusion. Most of the code is copied from Diffusers. -# Unnecessary parts are deleted by condition branching. 
-# As a constraint, the state_dict of the model must be in the same format as that of Diffusers 0.10.2 - -""" -v1.5とv2.1の相違点は -- attention_head_dimがintかlist[int]か -- cross_attention_dimが768か1024か -- use_linear_projection: trueがない(=False, 1.5)かあるか -- upcast_attentionがFalse(1.5)かTrue(2.1)か -- (以下は多分無視していい) -- sample_sizeが64か96か -- dual_cross_attentionがあるかないか -- num_class_embedsがあるかないか -- only_cross_attentionがあるかないか - -v1.5 -{ - "_class_name": "UNet2DConditionModel", - "_diffusers_version": "0.6.0", - "act_fn": "silu", - "attention_head_dim": 8, - "block_out_channels": [ - 320, - 640, - 1280, - 1280 - ], - "center_input_sample": false, - "cross_attention_dim": 768, - "down_block_types": [ - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D" - ], - "downsample_padding": 1, - "flip_sin_to_cos": true, - "freq_shift": 0, - "in_channels": 4, - "layers_per_block": 2, - "mid_block_scale_factor": 1, - "norm_eps": 1e-05, - "norm_num_groups": 32, - "out_channels": 4, - "sample_size": 64, - "up_block_types": [ - "UpBlock2D", - "CrossAttnUpBlock2D", - "CrossAttnUpBlock2D", - "CrossAttnUpBlock2D" - ] -} - -v2.1 -{ - "_class_name": "UNet2DConditionModel", - "_diffusers_version": "0.10.0.dev0", - "act_fn": "silu", - "attention_head_dim": [ - 5, - 10, - 20, - 20 - ], - "block_out_channels": [ - 320, - 640, - 1280, - 1280 - ], - "center_input_sample": false, - "cross_attention_dim": 1024, - "down_block_types": [ - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D" - ], - "downsample_padding": 1, - "dual_cross_attention": false, - "flip_sin_to_cos": true, - "freq_shift": 0, - "in_channels": 4, - "layers_per_block": 2, - "mid_block_scale_factor": 1, - "norm_eps": 1e-05, - "norm_num_groups": 32, - "num_class_embeds": null, - "only_cross_attention": false, - "out_channels": 4, - "sample_size": 96, - "up_block_types": [ - "UpBlock2D", - "CrossAttnUpBlock2D", - "CrossAttnUpBlock2D", - "CrossAttnUpBlock2D" - ], - "use_linear_projection": true, - "upcast_attention": true -} -""" - -import math -from types import SimpleNamespace -from typing import Dict, Optional, Tuple, Union -import torch -from torch import nn -from torch.nn import functional as F -from einops import rearrange -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280) -TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0] -TIME_EMBED_DIM = BLOCK_OUT_CHANNELS[0] * 4 -IN_CHANNELS: int = 4 -OUT_CHANNELS: int = 4 -LAYERS_PER_BLOCK: int = 2 -LAYERS_PER_BLOCK_UP: int = LAYERS_PER_BLOCK + 1 -TIME_EMBED_FLIP_SIN_TO_COS: bool = True -TIME_EMBED_FREQ_SHIFT: int = 0 -NORM_GROUPS: int = 32 -NORM_EPS: float = 1e-5 -TRANSFORMER_NORM_NUM_GROUPS = 32 - -DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"] -UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"] - - -# region memory efficient attention - -# FlashAttentionを使うCrossAttention -# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py -# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE - -# constants - -EPSILON = 1e-6 - -# helper functions - - -def exists(val): - return val is not None - - -def default(val, d): - return val if exists(val) else d - - -# flash attention forwards and backwards - -# 
https://arxiv.org/abs/2205.14135 - - -class FlashAttentionFunction(torch.autograd.Function): - @staticmethod - @torch.no_grad() - def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): - """Algorithm 2 in the paper""" - - device = q.device - dtype = q.dtype - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - o = torch.zeros_like(q) - all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) - all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) - - scale = q.shape[-1] ** -0.5 - - if not exists(mask): - mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) - else: - mask = rearrange(mask, "b n -> b 1 1 n") - mask = mask.split(q_bucket_size, dim=-1) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - mask, - all_row_sums.split(q_bucket_size, dim=-2), - all_row_maxes.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale - - if exists(row_mask): - attn_weights.masked_fill_(~row_mask, max_neg_value) - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( - q_start_index - k_start_index + 1 - ) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) - attn_weights -= block_row_maxes - exp_weights = torch.exp(attn_weights) - - if exists(row_mask): - exp_weights.masked_fill_(~row_mask, 0.0) - - block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) - - new_row_maxes = torch.maximum(block_row_maxes, row_maxes) - - exp_values = torch.einsum("... i j, ... j d -> ... 
i d", exp_weights, vc) - - exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) - exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) - - new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums - - oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) - - row_maxes.copy_(new_row_maxes) - row_sums.copy_(new_row_sums) - - ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) - ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) - - return o - - @staticmethod - @torch.no_grad() - def backward(ctx, do): - """Algorithm 4 in the paper""" - - causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args - q, k, v, o, l, m = ctx.saved_tensors - - device = q.device - - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - do.split(q_bucket_size, dim=-2), - mask, - l.split(q_bucket_size, dim=-2), - m.split(q_bucket_size, dim=-2), - dq.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - dk.split(k_bucket_size, dim=-2), - dv.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( - q_start_index - k_start_index + 1 - ) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - exp_attn_weights = torch.exp(attn_weights - mc) - - if exists(row_mask): - exp_attn_weights.masked_fill_(~row_mask, 0.0) - - p = exp_attn_weights / lc - - dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc) - dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc) - - D = (doc * oc).sum(dim=-1, keepdims=True) - ds = p * scale * (dp - D) - - dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc) - dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc) - - dqc.add_(dq_chunk) - dkc.add_(dk_chunk) - dvc.add_(dv_chunk) - - return dq, dk, dv, None, None, None, None - - -# endregion - - -def get_parameter_dtype(parameter: torch.nn.Module): - return next(parameter.parameters()).dtype - - -def get_parameter_device(parameter: torch.nn.Module): - return next(parameter.parameters()).device - - -def get_timestep_embedding( - timesteps: torch.Tensor, - embedding_dim: int, - flip_sin_to_cos: bool = False, - downscale_freq_shift: float = 1, - scale: float = 1, - max_period: int = 10000, -): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. - - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the - embeddings. :return: an [N x dim] Tensor of positional embeddings. 
- """ - assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" - - half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) - exponent = exponent / (half_dim - downscale_freq_shift) - - emb = torch.exp(exponent) - emb = timesteps[:, None].float() * emb[None, :] - - # scale embeddings - emb = scale * emb - - # concat sine and cosine embeddings - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) - - # flip sine and cosine embeddings - if flip_sin_to_cos: - emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) - - # zero pad - if embedding_dim % 2 == 1: - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - -# Deep Shrink: We do not common this function, because minimize dependencies. -def resize_like(x, target, mode="bicubic", align_corners=False): - org_dtype = x.dtype - if org_dtype == torch.bfloat16: - x = x.to(torch.float32) - - if x.shape[-2:] != target.shape[-2:]: - if mode == "nearest": - x = F.interpolate(x, size=target.shape[-2:], mode=mode) - else: - x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners) - - if org_dtype == torch.bfloat16: - x = x.to(org_dtype) - return x - - -class SampleOutput: - def __init__(self, sample): - self.sample = sample - - -class TimestepEmbedding(nn.Module): - def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None): - super().__init__() - - self.linear_1 = nn.Linear(in_channels, time_embed_dim) - self.act = None - if act_fn == "silu": - self.act = nn.SiLU() - elif act_fn == "mish": - self.act = nn.Mish() - - if out_dim is not None: - time_embed_dim_out = out_dim - else: - time_embed_dim_out = time_embed_dim - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) - - def forward(self, sample): - sample = self.linear_1(sample) - - if self.act is not None: - sample = self.act(sample) - - sample = self.linear_2(sample) - return sample - - -class Timesteps(nn.Module): - def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): - super().__init__() - self.num_channels = num_channels - self.flip_sin_to_cos = flip_sin_to_cos - self.downscale_freq_shift = downscale_freq_shift - - def forward(self, timesteps): - t_emb = get_timestep_embedding( - timesteps, - self.num_channels, - flip_sin_to_cos=self.flip_sin_to_cos, - downscale_freq_shift=self.downscale_freq_shift, - ) - return t_emb - - -class ResnetBlock2D(nn.Module): - def __init__( - self, - in_channels, - out_channels, - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - - self.norm1 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=in_channels, eps=NORM_EPS, affine=True) - - self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - - self.time_emb_proj = torch.nn.Linear(TIME_EMBED_DIM, out_channels) - - self.norm2 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=out_channels, eps=NORM_EPS, affine=True) - self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - - # if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) - - self.use_in_shortcut = self.in_channels != self.out_channels - - self.conv_shortcut = None - if self.use_in_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, input_tensor, temb): - hidden_states = input_tensor - 
- hidden_states = self.norm1(hidden_states) - hidden_states = self.nonlinearity(hidden_states) - - hidden_states = self.conv1(hidden_states) - - temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] - hidden_states = hidden_states + temb - - hidden_states = self.norm2(hidden_states) - hidden_states = self.nonlinearity(hidden_states) - - hidden_states = self.conv2(hidden_states) - - if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor) - - output_tensor = input_tensor + hidden_states - - return output_tensor - - -class DownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - add_downsample=True, - ): - super().__init__() - - self.has_cross_attention = False - resnets = [] - - for i in range(LAYERS_PER_BLOCK): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - ) - ) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = [Downsample2D(out_channels, out_channels=out_channels)] - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def set_use_memory_efficient_attention(self, xformers, mem_eff): - pass - - def set_use_sdpa(self, sdpa): - pass - - def forward(self, hidden_states, temb=None): - output_states = () - - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - else: - hidden_states = resnet(hidden_states, temb) - - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class Downsample2D(nn.Module): - def __init__(self, channels, out_channels): - super().__init__() - - self.channels = channels - self.out_channels = out_channels - - self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1) - - def forward(self, hidden_states): - assert hidden_states.shape[1] == self.channels - hidden_states = self.conv(hidden_states) - - return hidden_states - - -class CrossAttention(nn.Module): - def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - dim_head: int = 64, - upcast_attention: bool = False, - ): - super().__init__() - inner_dim = dim_head * heads - cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - self.upcast_attention = upcast_attention - - self.scale = dim_head**-0.5 - self.heads = heads - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) - self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) - - self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(inner_dim, query_dim)) - # no dropout here - - self.use_memory_efficient_attention_xformers = False - self.use_memory_efficient_attention_mem_eff = False - self.use_sdpa = False - - # Attention processor - self.processor = None - - def set_use_memory_efficient_attention(self, xformers, mem_eff): - self.use_memory_efficient_attention_xformers = xformers - self.use_memory_efficient_attention_mem_eff = mem_eff - - def set_use_sdpa(self, sdpa): - self.use_sdpa = sdpa - - def 
reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor - - def reshape_batch_dim_to_heads(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - - def set_processor(self): - return self.processor - - def get_processor(self): - return self.processor - - def forward(self, hidden_states, context=None, mask=None, **kwargs): - if self.processor is not None: - ( - hidden_states, - encoder_hidden_states, - attention_mask, - ) = translate_attention_names_from_diffusers( - hidden_states=hidden_states, context=context, mask=mask, **kwargs - ) - return self.processor( - attn=self, - hidden_states=hidden_states, - encoder_hidden_states=context, - attention_mask=mask, - **kwargs - ) - if self.use_memory_efficient_attention_xformers: - return self.forward_memory_efficient_xformers(hidden_states, context, mask) - if self.use_memory_efficient_attention_mem_eff: - return self.forward_memory_efficient_mem_eff(hidden_states, context, mask) - if self.use_sdpa: - return self.forward_sdpa(hidden_states, context, mask) - - query = self.to_q(hidden_states) - context = context if context is not None else hidden_states - key = self.to_k(context) - value = self.to_v(context) - - query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - - hidden_states = self._attention(query, key, value) - - # linear proj - hidden_states = self.to_out[0](hidden_states) - # hidden_states = self.to_out[1](hidden_states) # no dropout - return hidden_states - - def _attention(self, query, key, value): - if self.upcast_attention: - query = query.float() - key = key.float() - - attention_scores = torch.baddbmm( - torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - attention_probs = attention_scores.softmax(dim=-1) - - # cast back to the original dtype - attention_probs = attention_probs.to(value.dtype) - - # compute attention output - hidden_states = torch.bmm(attention_probs, value) - - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - return hidden_states - - # TODO support Hypernetworks - def forward_memory_efficient_xformers(self, x, context=None, mask=None): - import xformers.ops - - h = self.heads - q_in = self.to_q(x) - context = context if context is not None else x - context = context.to(x.dtype) - k_in = self.to_k(context) - v_in = self.to_v(context) - - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - - out = rearrange(out, "b n h d -> b n (h d)", h=h) - - out = self.to_out[0](out) - return out - - def forward_memory_efficient_mem_eff(self, x, context=None, mask=None): - flash_func = FlashAttentionFunction - - q_bucket_size = 512 - k_bucket_size = 1024 - - h = self.heads - q = self.to_q(x) - context = context if context is not 
None else x - context = context.to(x.dtype) - k = self.to_k(context) - v = self.to_v(context) - del context, x - - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) - - out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) - - out = rearrange(out, "b h n d -> b n (h d)") - - out = self.to_out[0](out) - return out - - def forward_sdpa(self, x, context=None, mask=None): - h = self.heads - q_in = self.to_q(x) - context = context if context is not None else x - context = context.to(x.dtype) - k_in = self.to_k(context) - v_in = self.to_v(context) - - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - - out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) - - out = rearrange(out, "b h n d -> b n (h d)", h=h) - - out = self.to_out[0](out) - return out - -def translate_attention_names_from_diffusers( - hidden_states: torch.FloatTensor, - context: Optional[torch.FloatTensor] = None, - mask: Optional[torch.FloatTensor] = None, - # HF naming - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None -): - # translate from hugging face diffusers - context = context if context is not None else encoder_hidden_states - - # translate from hugging face diffusers - mask = mask if mask is not None else attention_mask - - return hidden_states, context, mask - -# feedforward -class GEGLU(nn.Module): - r""" - A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. - - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - """ - - def __init__(self, dim_in: int, dim_out: int): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def gelu(self, gate): - if gate.device.type != "mps": - return F.gelu(gate) - # mps: gelu is not implemented for float16 - return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) - - def forward(self, hidden_states): - hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) - return hidden_states * self.gelu(gate) - - -class FeedForward(nn.Module): - def __init__( - self, - dim: int, - ): - super().__init__() - inner_dim = int(dim * 4) # mult is always 4 - - self.net = nn.ModuleList([]) - # project in - self.net.append(GEGLU(dim, inner_dim)) - # project dropout - self.net.append(nn.Identity()) # nn.Dropout(0)) # dummy for dropout with 0 - # project out - self.net.append(nn.Linear(inner_dim, dim)) - - def forward(self, hidden_states): - for module in self.net: - hidden_states = module(hidden_states) - return hidden_states - - -class BasicTransformerBlock(nn.Module): - def __init__( - self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False - ): - super().__init__() - - # 1. Self-Attn - self.attn1 = CrossAttention( - query_dim=dim, - cross_attention_dim=None, - heads=num_attention_heads, - dim_head=attention_head_dim, - upcast_attention=upcast_attention, - ) - self.ff = FeedForward(dim) - - # 2. Cross-Attn - self.attn2 = CrossAttention( - query_dim=dim, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - upcast_attention=upcast_attention, - ) - - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - - # 3. 
Feed-forward - self.norm3 = nn.LayerNorm(dim) - - def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool): - self.attn1.set_use_memory_efficient_attention(xformers, mem_eff) - self.attn2.set_use_memory_efficient_attention(xformers, mem_eff) - - def set_use_sdpa(self, sdpa: bool): - self.attn1.set_use_sdpa(sdpa) - self.attn2.set_use_sdpa(sdpa) - - def forward(self, hidden_states, context=None, timestep=None): - # 1. Self-Attention - norm_hidden_states = self.norm1(hidden_states) - - hidden_states = self.attn1(norm_hidden_states) + hidden_states - - # 2. Cross-Attention - norm_hidden_states = self.norm2(hidden_states) - hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states - - # 3. Feed-forward - hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states - - return hidden_states - - -class Transformer2DModel(nn.Module): - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - cross_attention_dim: Optional[int] = None, - use_linear_projection: bool = False, - upcast_attention: bool = False, - ): - super().__init__() - self.in_channels = in_channels - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - self.use_linear_projection = use_linear_projection - - self.norm = torch.nn.GroupNorm(num_groups=TRANSFORMER_NORM_NUM_GROUPS, num_channels=in_channels, eps=1e-6, affine=True) - - if use_linear_projection: - self.proj_in = nn.Linear(in_channels, inner_dim) - else: - self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - ) - ] - ) - - if use_linear_projection: - self.proj_out = nn.Linear(in_channels, inner_dim) - else: - self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - - def set_use_memory_efficient_attention(self, xformers, mem_eff): - for transformer in self.transformer_blocks: - transformer.set_use_memory_efficient_attention(xformers, mem_eff) - - def set_use_sdpa(self, sdpa): - for transformer in self.transformer_blocks: - transformer.set_use_sdpa(sdpa) - - def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): - # 1. Input - batch, _, height, weight = hidden_states.shape - residual = hidden_states - - hidden_states = self.norm(hidden_states) - if not self.use_linear_projection: - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) - else: - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) - hidden_states = self.proj_in(hidden_states) - - # 2. Blocks - for block in self.transformer_blocks: - hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep) - - # 3. 
Output - if not self.use_linear_projection: - hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() - hidden_states = self.proj_out(hidden_states) - else: - hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() - - output = hidden_states + residual - - if not return_dict: - return (output,) - - return SampleOutput(sample=output) - - -class CrossAttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - add_downsample=True, - cross_attention_dim=1280, - attn_num_head_channels=1, - use_linear_projection=False, - upcast_attention=False, - ): - super().__init__() - self.has_cross_attention = True - resnets = [] - attentions = [] - - self.attn_num_head_channels = attn_num_head_channels - - for i in range(LAYERS_PER_BLOCK): - in_channels = in_channels if i == 0 else out_channels - - resnets.append(ResnetBlock2D(in_channels=in_channels, out_channels=out_channels)) - attentions.append( - Transformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - cross_attention_dim=cross_attention_dim, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)]) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def set_use_memory_efficient_attention(self, xformers, mem_eff): - for attn in self.attentions: - attn.set_use_memory_efficient_attention(xformers, mem_eff) - - def set_use_sdpa(self, sdpa): - for attn in self.attentions: - attn.set_use_sdpa(sdpa) - - def forward(self, hidden_states, temb=None, encoder_hidden_states=None): - output_states = () - - for resnet, attn in zip(self.resnets, self.attentions): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample - - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class UNetMidBlock2DCrossAttn(nn.Module): - def __init__( - self, - in_channels: int, - attn_num_head_channels=1, - cross_attention_dim=1280, - use_linear_projection=False, - ): - super().__init__() - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - - # Middle block has two resnets and one attention - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - ), - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - ), - ] - attentions = [ - Transformer2DModel( - attn_num_head_channels, - in_channels // attn_num_head_channels, - 
in_channels=in_channels, - cross_attention_dim=cross_attention_dim, - use_linear_projection=use_linear_projection, - ) - ] - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - self.gradient_checkpointing = False - - def set_use_memory_efficient_attention(self, xformers, mem_eff): - for attn in self.attentions: - attn.set_use_memory_efficient_attention(xformers, mem_eff) - - def set_use_sdpa(self, sdpa): - for attn in self.attentions: - attn.set_use_sdpa(sdpa) - - def forward(self, hidden_states, temb=None, encoder_hidden_states=None): - for i, resnet in enumerate(self.resnets): - attn = None if i == 0 else self.attentions[i - 1] - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - if attn is not None: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states - )[0] - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - else: - if attn is not None: - hidden_states = attn(hidden_states, encoder_hidden_states).sample - hidden_states = resnet(hidden_states, temb) - - return hidden_states - - -class Upsample2D(nn.Module): - def __init__(self, channels, out_channels): - super().__init__() - self.channels = channels - self.out_channels = out_channels - self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) - - def forward(self, hidden_states, output_size): - assert hidden_states.shape[1] == self.channels - - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch - # https://github.com/pytorch/pytorch/issues/86679 - dtype = hidden_states.dtype - if dtype == torch.bfloat16: - hidden_states = hidden_states.to(torch.float32) - - # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 - if hidden_states.shape[0] >= 64: - hidden_states = hidden_states.contiguous() - - # if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2` - if output_size is None: - hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") - else: - hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") - - # If the input is bfloat16, we cast back to bfloat16 - if dtype == torch.bfloat16: - hidden_states = hidden_states.to(dtype) - - hidden_states = self.conv(hidden_states) - - return hidden_states - - -class UpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - add_upsample=True, - ): - super().__init__() - - self.has_cross_attention = False - resnets = [] - - for i in range(LAYERS_PER_BLOCK_UP): - res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def set_use_memory_efficient_attention(self, xformers, mem_eff): - pass - - def set_use_sdpa(self, sdpa): - pass - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - else: - hidden_states = resnet(hidden_states, temb) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states - - -class CrossAttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - prev_output_channel: int, - attn_num_head_channels=1, - cross_attention_dim=1280, - add_upsample=True, - use_linear_projection=False, - upcast_attention=False, - ): - super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - - for i in range(LAYERS_PER_BLOCK_UP): - res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - ) - ) - attentions.append( - Transformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - cross_attention_dim=cross_attention_dim, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)]) - else: 
- self.upsamplers = None - - self.gradient_checkpointing = False - - def set_use_memory_efficient_attention(self, xformers, mem_eff): - for attn in self.attentions: - attn.set_use_memory_efficient_attention(xformers, mem_eff) - - def set_use_sdpa(self, sdpa): - for attn in self.attentions: - attn.set_use_sdpa(sdpa) - - def forward( - self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - upsample_size=None, - ): - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states - - -def get_down_block( - down_block_type, - in_channels, - out_channels, - add_downsample, - attn_num_head_channels, - cross_attention_dim, - use_linear_projection, - upcast_attention, -): - if down_block_type == "DownBlock2D": - return DownBlock2D( - in_channels=in_channels, - out_channels=out_channels, - add_downsample=add_downsample, - ) - elif down_block_type == "CrossAttnDownBlock2D": - return CrossAttnDownBlock2D( - in_channels=in_channels, - out_channels=out_channels, - add_downsample=add_downsample, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - - -def get_up_block( - up_block_type, - in_channels, - out_channels, - prev_output_channel, - add_upsample, - attn_num_head_channels, - cross_attention_dim=None, - use_linear_projection=False, - upcast_attention=False, -): - if up_block_type == "UpBlock2D": - return UpBlock2D( - in_channels=in_channels, - prev_output_channel=prev_output_channel, - out_channels=out_channels, - add_upsample=add_upsample, - ) - elif up_block_type == "CrossAttnUpBlock2D": - return CrossAttnUpBlock2D( - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - attn_num_head_channels=attn_num_head_channels, - cross_attention_dim=cross_attention_dim, - add_upsample=add_upsample, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - - -class UNet2DConditionModel(nn.Module): - _supports_gradient_checkpointing = True - - def __init__( - self, - sample_size: Optional[int] = None, - attention_head_dim: Union[int, Tuple[int]] = 8, - cross_attention_dim: int = 1280, - use_linear_projection: bool = False, - upcast_attention: bool = False, - **kwargs, - ): - super().__init__() - assert sample_size is not None, "sample_size must be specified" - logger.info( - f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, 
{use_linear_projection}, {upcast_attention}" - ) - - # 外部からの参照用に定義しておく - self.in_channels = IN_CHANNELS - self.out_channels = OUT_CHANNELS - - self.sample_size = sample_size - self.prepare_config(sample_size=sample_size) - - # state_dictの書式が変わるのでmoduleの持ち方は変えられない - - # input - self.conv_in = nn.Conv2d(IN_CHANNELS, BLOCK_OUT_CHANNELS[0], kernel_size=3, padding=(1, 1)) - - # time - self.time_proj = Timesteps(BLOCK_OUT_CHANNELS[0], TIME_EMBED_FLIP_SIN_TO_COS, TIME_EMBED_FREQ_SHIFT) - - self.time_embedding = TimestepEmbedding(TIMESTEP_INPUT_DIM, TIME_EMBED_DIM) - - self.down_blocks = nn.ModuleList([]) - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - - if isinstance(attention_head_dim, int): - attention_head_dim = (attention_head_dim,) * 4 - - # down - output_channel = BLOCK_OUT_CHANNELS[0] - for i, down_block_type in enumerate(DOWN_BLOCK_TYPES): - input_channel = output_channel - output_channel = BLOCK_OUT_CHANNELS[i] - is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1 - - down_block = get_down_block( - down_block_type, - in_channels=input_channel, - out_channels=output_channel, - add_downsample=not is_final_block, - attn_num_head_channels=attention_head_dim[i], - cross_attention_dim=cross_attention_dim, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - self.down_blocks.append(down_block) - - # mid - self.mid_block = UNetMidBlock2DCrossAttn( - in_channels=BLOCK_OUT_CHANNELS[-1], - attn_num_head_channels=attention_head_dim[-1], - cross_attention_dim=cross_attention_dim, - use_linear_projection=use_linear_projection, - ) - - # count how many layers upsample the images - self.num_upsamplers = 0 - - # up - reversed_block_out_channels = list(reversed(BLOCK_OUT_CHANNELS)) - reversed_attention_head_dim = list(reversed(attention_head_dim)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(UP_BLOCK_TYPES): - is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1 - - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(BLOCK_OUT_CHANNELS) - 1)] - - # add upsample block for all BUT final layer - if not is_final_block: - add_upsample = True - self.num_upsamplers += 1 - else: - add_upsample = False - - up_block = get_up_block( - up_block_type, - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - add_upsample=add_upsample, - attn_num_head_channels=reversed_attention_head_dim[i], - cross_attention_dim=cross_attention_dim, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=NORM_GROUPS, eps=NORM_EPS) - self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1) - - # region diffusers compatibility - def prepare_config(self, *args, **kwargs): - self.config = SimpleNamespace(**kwargs) - - @property - def dtype(self) -> torch.dtype: - # `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). - return get_parameter_dtype(self) - - @property - def device(self) -> torch.device: - # `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device). 
- return get_parameter_device(self) - - def set_attention_slice(self, slice_size): - raise NotImplementedError("Attention slicing is not supported for this model.") - - def is_gradient_checkpointing(self) -> bool: - return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) - - def enable_gradient_checkpointing(self): - self.set_gradient_checkpointing(value=True) - - def disable_gradient_checkpointing(self): - self.set_gradient_checkpointing(value=False) - - def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None: - modules = self.down_blocks + [self.mid_block] + self.up_blocks - for module in modules: - module.set_use_memory_efficient_attention(xformers, mem_eff) - - def set_use_sdpa(self, sdpa: bool) -> None: - modules = self.down_blocks + [self.mid_block] + self.up_blocks - for module in modules: - module.set_use_sdpa(sdpa) - - def set_gradient_checkpointing(self, value=False): - modules = self.down_blocks + [self.mid_block] + self.up_blocks - for module in modules: - logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}") - module.gradient_checkpointing = value - - # endregion - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - class_labels: Optional[torch.Tensor] = None, - return_dict: bool = True, - down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, - mid_block_additional_residual: Optional[torch.Tensor] = None, - ) -> Union[Dict, Tuple]: - r""" - Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a dict instead of a plain tuple. - - Returns: - `SampleOutput` or `tuple`: - `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. - """ - # By default samples have to be AT least a multiple of the overall upsampling factor. - # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). - # However, the upsampling interpolation output size can be forced to fit any upsampling size - # on the fly if necessary. - # デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある - # ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する - # 多分画質が悪くなるので、64で割り切れるようにしておくのが良い - default_overall_up_factor = 2**self.num_upsamplers - - # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` - # 64で割り切れないときはupsamplerにサイズを伝える - forward_upsample_size = False - upsample_size = None - - if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - # logger.info("Forward upsample size to force interpolation output size.") - forward_upsample_size = True - - # 1. time - timesteps = timestep - timesteps = self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理 - - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - # timestepsは重みを含まないので常にfloat32のテンソルを返す - # しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある - # time_projでキャストしておけばいいんじゃね? 
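# The cast on the next line is what the comment block above is about: the sinusoidal
# projection always returns float32, while time_embedding may be running in fp16/bf16
# under mixed precision. A standalone sketch, reusing Timesteps, TimestepEmbedding and
# get_parameter_dtype defined earlier in this file (the timestep value is a placeholder):
import torch

time_proj = Timesteps(320, flip_sin_to_cos=True, downscale_freq_shift=0)
time_embedding = TimestepEmbedding(320, 1280)  # may be half/bfloat16 under mixed precision

t_emb = time_proj(torch.tensor([999]))
print(t_emb.dtype)  # torch.float32, regardless of the module dtype
t_emb = t_emb.to(get_parameter_dtype(time_embedding))  # the same cast as performed below
emb = time_embedding(t_emb)
print(emb.shape, emb.dtype)  # torch.Size([1, 1280]), module dtype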
- t_emb = t_emb.to(dtype=self.dtype) - emb = self.time_embedding(t_emb) - - # 2. pre-process - sample = self.conv_in(sample) - - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 - # まあこちらのほうがわかりやすいかもしれない - if downsample_block.has_cross_attention: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - - down_block_res_samples += res_samples - - # skip connectionにControlNetの出力を追加する - if down_block_additional_residuals is not None: - down_block_res_samples = list(down_block_res_samples) - for i in range(len(down_block_res_samples)): - down_block_res_samples[i] += down_block_additional_residuals[i] - down_block_res_samples = tuple(down_block_res_samples) - - # 4. mid - sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) - - # ControlNetの出力を追加する - if mid_block_additional_residual is not None: - sample += mid_block_additional_residual - - # 5. up - for i, upsample_block in enumerate(self.up_blocks): - is_final_block = i == len(self.up_blocks) - 1 - - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection - - # if we have not reached the final block and need to forward the upsample size, we do it here - # 前述のように最後のブロック以外ではupsample_sizeを伝える - if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] - - if upsample_block.has_cross_attention: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - upsample_size=upsample_size, - ) - else: - sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size - ) - - # 6. post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if not return_dict: - return (sample,) - - return SampleOutput(sample=sample) - - def handle_unusual_timesteps(self, sample, timesteps): - r""" - timestampsがTensorでない場合、Tensorに変換する。またOnnx/Core MLと互換性のあるようにbatchサイズまでbroadcastする。 - """ - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timesteps, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - return timesteps - - -class InferUNet2DConditionModel: - def __init__(self, original_unet: UNet2DConditionModel): - self.delegate = original_unet - - # override original model's forward method: because forward is not called by `__call__` - # overriding `__call__` is not enough, because nn.Module.forward has a special handling - self.delegate.forward = self.forward - - # override original model's up blocks' forward method - for up_block in self.delegate.up_blocks: - if up_block.__class__.__name__ == "UpBlock2D": - - def resnet_wrapper(func, block): - def forward(*args, **kwargs): - return func(block, *args, **kwargs) - - return forward - - up_block.forward = resnet_wrapper(self.up_block_forward, up_block) - - elif up_block.__class__.__name__ == "CrossAttnUpBlock2D": - - def cross_attn_up_wrapper(func, block): - def forward(*args, **kwargs): - return func(block, *args, **kwargs) - - return forward - - up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block) - - # Deep Shrink - self.ds_depth_1 = None - self.ds_depth_2 = None - self.ds_timesteps_1 = None - self.ds_timesteps_2 = None - self.ds_ratio = None - - # call original model's methods - def __getattr__(self, name): - return getattr(self.delegate, name) - - def __call__(self, *args, **kwargs): - return self.delegate(*args, **kwargs) - - def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): - if ds_depth_1 is None: - logger.info("Deep Shrink is disabled.") - self.ds_depth_1 = None - self.ds_timesteps_1 = None - self.ds_depth_2 = None - self.ds_timesteps_2 = None - self.ds_ratio = None - else: - logger.info( - f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" - ) - self.ds_depth_1 = ds_depth_1 - self.ds_timesteps_1 = ds_timesteps_1 - self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1 - self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 - self.ds_ratio = ds_ratio - - def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): - for resnet in _self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - - # Deep Shrink - if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]: - hidden_states = resize_like(hidden_states, res_hidden_states) - - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb) - - if _self.upsamplers is not None: - for upsampler in _self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states - - def cross_attn_up_block_forward( - self, - _self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - upsample_size=None, - ): - for resnet, attn in zip(_self.resnets, _self.attentions): 
- # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - - # Deep Shrink - if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]: - hidden_states = resize_like(hidden_states, res_hidden_states) - - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample - - if _self.upsamplers is not None: - for upsampler in _self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - class_labels: Optional[torch.Tensor] = None, - return_dict: bool = True, - down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, - mid_block_additional_residual: Optional[torch.Tensor] = None, - ) -> Union[Dict, Tuple]: - r""" - current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink. - """ - - r""" - Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a dict instead of a plain tuple. - - Returns: - `SampleOutput` or `tuple`: - `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. - """ - - _self = self.delegate - - # By default samples have to be AT least a multiple of the overall upsampling factor. - # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). - # However, the upsampling interpolation output size can be forced to fit any upsampling size - # on the fly if necessary. - # デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある - # ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する - # 多分画質が悪くなるので、64で割り切れるようにしておくのが良い - default_overall_up_factor = 2**_self.num_upsamplers - - # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` - # 64で割り切れないときはupsamplerにサイズを伝える - forward_upsample_size = False - upsample_size = None - - if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - # logger.info("Forward upsample size to force interpolation output size.") - forward_upsample_size = True - - # 1. time - timesteps = timestep - timesteps = _self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理 - - t_emb = _self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - # timestepsは重みを含まないので常にfloat32のテンソルを返す - # しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある - # time_projでキャストしておけばいいんじゃね? - t_emb = t_emb.to(dtype=_self.dtype) - emb = _self.time_embedding(t_emb) - - # 2. 
pre-process - sample = _self.conv_in(sample) - - down_block_res_samples = (sample,) - for depth, downsample_block in enumerate(_self.down_blocks): - # Deep Shrink - if self.ds_depth_1 is not None: - if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or ( - self.ds_depth_2 is not None - and depth == self.ds_depth_2 - and timesteps[0] < self.ds_timesteps_1 - and timesteps[0] >= self.ds_timesteps_2 - ): - org_dtype = sample.dtype - if org_dtype == torch.bfloat16: - sample = sample.to(torch.float32) - sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype) - - # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 - # まあこちらのほうがわかりやすいかもしれない - if downsample_block.has_cross_attention: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - - down_block_res_samples += res_samples - - # skip connectionにControlNetの出力を追加する - if down_block_additional_residuals is not None: - down_block_res_samples = list(down_block_res_samples) - for i in range(len(down_block_res_samples)): - down_block_res_samples[i] += down_block_additional_residuals[i] - down_block_res_samples = tuple(down_block_res_samples) - - # 4. mid - sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) - - # ControlNetの出力を追加する - if mid_block_additional_residual is not None: - sample += mid_block_additional_residual - - # 5. up - for i, upsample_block in enumerate(_self.up_blocks): - is_final_block = i == len(_self.up_blocks) - 1 - - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection - - # if we have not reached the final block and need to forward the upsample size, we do it here - # 前述のように最後のブロック以外ではupsample_sizeを伝える - if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] - - if upsample_block.has_cross_attention: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - upsample_size=upsample_size, - ) - else: - sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size - ) - - # 6. post-process - sample = _self.conv_norm_out(sample) - sample = _self.conv_act(sample) - sample = _self.conv_out(sample) - - if not return_dict: - return (sample,) - - return SampleOutput(sample=sample) diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py deleted file mode 100644 index a63bd82ec..000000000 --- a/library/sai_model_spec.py +++ /dev/null @@ -1,309 +0,0 @@ -# based on https://github.com/Stability-AI/ModelSpec -import datetime -import hashlib -from io import BytesIO -import os -from typing import List, Optional, Tuple, Union -import safetensors -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -r""" -# Metadata Example -metadata = { - # === Must === - "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec - "modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID - "modelspec.implementation": "sgm", - "modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. 
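The Deep Shrink branch in the down-block loop above can be read as a single gating function; a hypothetical standalone restatement follows (the bfloat16 round-trip mirrors the original code, which avoids running bicubic interpolation in bf16):

import torch
import torch.nn.functional as F

def maybe_deep_shrink(sample: torch.Tensor, depth: int, timestep: float,
                      ds_depth_1, ds_timesteps_1, ds_depth_2=None,
                      ds_timesteps_2=1000, ds_ratio=0.5) -> torch.Tensor:
    # Shrink at depth 1 while timestep >= ds_timesteps_1, and optionally at depth 2
    # for the window ds_timesteps_2 <= timestep < ds_timesteps_1.
    if ds_depth_1 is None:
        return sample
    hit_1 = depth == ds_depth_1 and timestep >= ds_timesteps_1
    hit_2 = (ds_depth_2 is not None and depth == ds_depth_2
             and ds_timesteps_2 <= timestep < ds_timesteps_1)
    if not (hit_1 or hit_2):
        return sample
    org_dtype = sample.dtype
    if org_dtype == torch.bfloat16:
        sample = sample.to(torch.float32)  # interpolate in float32, as in the code above
    sample = F.interpolate(sample, scale_factor=ds_ratio, mode="bicubic", align_corners=False)
    return sample.to(org_dtype)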
May use your own phrasing/language/etc - # === Should === - "modelspec.author": "Example Corp", # Your name or company name - "modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know - "modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created - # === Can === - "modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc. - "modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model -} -""" - -BASE_METADATA = { - # === Must === - "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec - "modelspec.architecture": None, - "modelspec.implementation": None, - "modelspec.title": None, - "modelspec.resolution": None, - # === Should === - "modelspec.description": None, - "modelspec.author": None, - "modelspec.date": None, - # === Can === - "modelspec.license": None, - "modelspec.tags": None, - "modelspec.merged_from": None, - "modelspec.prediction_type": None, - "modelspec.timestep_range": None, - "modelspec.encoder_layer": None, -} - -# 別に使うやつだけ定義 -MODELSPEC_TITLE = "modelspec.title" - -ARCH_SD_V1 = "stable-diffusion-v1" -ARCH_SD_V2_512 = "stable-diffusion-v2-512" -ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v" -ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base" - -ADAPTER_LORA = "lora" -ADAPTER_TEXTUAL_INVERSION = "textual-inversion" - -IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models" -IMPL_DIFFUSERS = "diffusers" - -PRED_TYPE_EPSILON = "epsilon" -PRED_TYPE_V = "v" - - -def load_bytes_in_safetensors(tensors): - bytes = safetensors.torch.save(tensors) - b = BytesIO(bytes) - - b.seek(0) - header = b.read(8) - n = int.from_bytes(header, "little") - - offset = n + 8 - b.seek(offset) - - return b.read() - - -def precalculate_safetensors_hashes(state_dict): - # calculate each tensor one by one to reduce memory usage - hash_sha256 = hashlib.sha256() - for tensor in state_dict.values(): - single_tensor_sd = {"tensor": tensor} - bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd) - hash_sha256.update(bytes_for_tensor) - - return f"0x{hash_sha256.hexdigest()}" - - -def update_hash_sha256(metadata: dict, state_dict: dict): - raise NotImplementedError - - -def build_metadata( - state_dict: Optional[dict], - v2: bool, - v_parameterization: bool, - sdxl: bool, - lora: bool, - textual_inversion: bool, - timestamp: float, - title: Optional[str] = None, - reso: Optional[Union[int, Tuple[int, int]]] = None, - is_stable_diffusion_ckpt: Optional[bool] = None, - author: Optional[str] = None, - description: Optional[str] = None, - license: Optional[str] = None, - tags: Optional[str] = None, - merged_from: Optional[str] = None, - timesteps: Optional[Tuple[int, int]] = None, - clip_skip: Optional[int] = None, -): - # if state_dict is None, hash is not calculated - - metadata = {} - metadata.update(BASE_METADATA) - - # TODO メモリを消費せずかつ正しいハッシュ計算の方法がわかったら実装する - # if state_dict is not None: - # hash = precalculate_safetensors_hashes(state_dict) - # metadata["modelspec.hash_sha256"] = hash - - if sdxl: - arch = ARCH_SD_XL_V1_BASE - elif v2: - if v_parameterization: - arch = ARCH_SD_V2_768_V - else: - arch = ARCH_SD_V2_512 - else: - arch = ARCH_SD_V1 - - if lora: - arch += f"/{ADAPTER_LORA}" - elif textual_inversion: - arch += f"/{ADAPTER_TEXTUAL_INVERSION}" - - metadata["modelspec.architecture"] = arch - - if not lora and not textual_inversion and 
is_stable_diffusion_ckpt is None: - is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion - - if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: - # Stable Diffusion ckpt, TI, SDXL LoRA - impl = IMPL_STABILITY_AI - else: - # v1/v2 LoRA or Diffusers - impl = IMPL_DIFFUSERS - metadata["modelspec.implementation"] = impl - - if title is None: - if lora: - title = "LoRA" - elif textual_inversion: - title = "TextualInversion" - else: - title = "Checkpoint" - title += f"@{timestamp}" - metadata[MODELSPEC_TITLE] = title - - if author is not None: - metadata["modelspec.author"] = author - else: - del metadata["modelspec.author"] - - if description is not None: - metadata["modelspec.description"] = description - else: - del metadata["modelspec.description"] - - if merged_from is not None: - metadata["modelspec.merged_from"] = merged_from - else: - del metadata["modelspec.merged_from"] - - if license is not None: - metadata["modelspec.license"] = license - else: - del metadata["modelspec.license"] - - if tags is not None: - metadata["modelspec.tags"] = tags - else: - del metadata["modelspec.tags"] - - # remove microsecond from time - int_ts = int(timestamp) - - # time to iso-8601 compliant date - date = datetime.datetime.fromtimestamp(int_ts).isoformat() - metadata["modelspec.date"] = date - - if reso is not None: - # comma separated to tuple - if isinstance(reso, str): - reso = tuple(map(int, reso.split(","))) - if len(reso) == 1: - reso = (reso[0], reso[0]) - else: - # resolution is defined in dataset, so use default - if sdxl: - reso = 1024 - elif v2 and v_parameterization: - reso = 768 - else: - reso = 512 - if isinstance(reso, int): - reso = (reso, reso) - - metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}" - - if v_parameterization: - metadata["modelspec.prediction_type"] = PRED_TYPE_V - else: - metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON - - if timesteps is not None: - if isinstance(timesteps, str) or isinstance(timesteps, int): - timesteps = (timesteps, timesteps) - if len(timesteps) == 1: - timesteps = (timesteps[0], timesteps[0]) - metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}" - else: - del metadata["modelspec.timestep_range"] - - if clip_skip is not None: - metadata["modelspec.encoder_layer"] = f"{clip_skip}" - else: - del metadata["modelspec.encoder_layer"] - - # # assert all values are filled - # assert all([v is not None for v in metadata.values()]), metadata - if not all([v is not None for v in metadata.values()]): - logger.error(f"Internal error: some metadata values are None: {metadata}") - - return metadata - - -# region utils - - -def get_title(metadata: dict) -> Optional[str]: - return metadata.get(MODELSPEC_TITLE, None) - - -def load_metadata_from_safetensors(model: str) -> dict: - if not model.endswith(".safetensors"): - return {} - - with safetensors.safe_open(model, framework="pt") as f: - metadata = f.metadata() - if metadata is None: - metadata = {} - return metadata - - -def build_merged_from(models: List[str]) -> str: - def get_title(model: str): - metadata = load_metadata_from_safetensors(model) - title = metadata.get(MODELSPEC_TITLE, None) - if title is None: - title = os.path.splitext(os.path.basename(model))[0] # use filename - return title - - titles = [get_title(model) for model in models] - return ", ".join(titles) - - -# endregion - - -r""" -if __name__ == "__main__": - import argparse - import torch - from safetensors.torch import load_file - from 
library import train_util - - parser = argparse.ArgumentParser() - parser.add_argument("--ckpt", type=str, required=True) - args = parser.parse_args() - - print(f"Loading {args.ckpt}") - state_dict = load_file(args.ckpt) - - print(f"Calculating metadata") - metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0) - print(metadata) - del state_dict - - # by reference implementation - with open(args.ckpt, mode="rb") as file_data: - file_hash = hashlib.sha256() - head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix - header = json.loads(file_data.read(head_len[0])) # header itself, json string - content = ( - file_data.read() - ) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl. - file_hash.update(content) - # ===== Update the hash for modelspec ===== - by_ref = f"0x{file_hash.hexdigest()}" - print(by_ref) - print("is same?", by_ref == metadata["modelspec.hash_sha256"]) - -""" diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py deleted file mode 100644 index 03b182566..000000000 --- a/library/sdxl_lpw_stable_diffusion.py +++ /dev/null @@ -1,1347 +0,0 @@ -# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py -# and modify to support SD2.x - -import inspect -import re -from typing import Callable, List, Optional, Union - -import numpy as np -import PIL.Image -import torch -from packaging import version -from tqdm import tqdm -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -from diffusers import SchedulerMixin, StableDiffusionPipeline -from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker -from diffusers.utils import logging -from PIL import Image - -from library import sdxl_model_util, sdxl_train_util, train_util - - -try: - from diffusers.utils import PIL_INTERPOLATION -except ImportError: - if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): - PIL_INTERPOLATION = { - "linear": PIL.Image.Resampling.BILINEAR, - "bilinear": PIL.Image.Resampling.BILINEAR, - "bicubic": PIL.Image.Resampling.BICUBIC, - "lanczos": PIL.Image.Resampling.LANCZOS, - "nearest": PIL.Image.Resampling.NEAREST, - } - else: - PIL_INTERPOLATION = { - "linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - "nearest": PIL.Image.NEAREST, - } -# ------------------------------------------------------------------------------ - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -re_attention = re.compile( - r""" -\\\(| -\\\)| -\\\[| -\\]| -\\\\| -\\| -\(| -\[| -:([+-]?[.\d]+)\)| -\)| -]| -[^\\()\[\]:]+| -: -""", - re.X, -) - - -def parse_prompt_attention(text): - """ - Parses a string with attention tokens and returns a list of pairs: text and its associated weight. 
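The commented-out reference implementation above hashes everything after the safetensors JSON header. A self-contained sketch of that computation (hypothetical function name; the file is streamed in chunks instead of read whole, and the 8-byte length prefix is taken as little-endian per the safetensors format):

import hashlib
import struct

def modelspec_file_hash(path: str, chunk_size: int = 1 << 20) -> str:
    # A .safetensors file is: 8-byte little-endian header length, JSON header,
    # then tightly packed tensor data. The modelspec hash covers only the tensor data.
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        f.seek(8 + header_len)
        sha = hashlib.sha256()
        for chunk in iter(lambda: f.read(chunk_size), b""):
            sha.update(chunk)
    return f"0x{sha.hexdigest()}"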
- Accepted tokens are: - (abc) - increases attention to abc by a multiplier of 1.1 - (abc:3.12) - increases attention to abc by a multiplier of 3.12 - [abc] - decreases attention to abc by a multiplier of 1.1 - \( - literal character '(' - \[ - literal character '[' - \) - literal character ')' - \] - literal character ']' - \\ - literal character '\' - anything else - just text - >>> parse_prompt_attention('normal text') - [['normal text', 1.0]] - >>> parse_prompt_attention('an (important) word') - [['an ', 1.0], ['important', 1.1], [' word', 1.0]] - >>> parse_prompt_attention('(unbalanced') - [['unbalanced', 1.1]] - >>> parse_prompt_attention('\(literal\]') - [['(literal]', 1.0]] - >>> parse_prompt_attention('(unnecessary)(parens)') - [['unnecessaryparens', 1.1]] - >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') - [['a ', 1.0], - ['house', 1.5730000000000004], - [' ', 1.1], - ['on', 1.0], - [' a ', 1.1], - ['hill', 0.55], - [', sun, ', 1.1], - ['sky', 1.4641000000000006], - ['.', 1.1]] - """ - - res = [] - round_brackets = [] - square_brackets = [] - - round_bracket_multiplier = 1.1 - square_bracket_multiplier = 1 / 1.1 - - def multiply_range(start_position, multiplier): - for p in range(start_position, len(res)): - res[p][1] *= multiplier - - for m in re_attention.finditer(text): - text = m.group(0) - weight = m.group(1) - - if text.startswith("\\"): - res.append([text[1:], 1.0]) - elif text == "(": - round_brackets.append(len(res)) - elif text == "[": - square_brackets.append(len(res)) - elif weight is not None and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), float(weight)) - elif text == ")" and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == "]" and len(square_brackets) > 0: - multiply_range(square_brackets.pop(), square_bracket_multiplier) - else: - res.append([text, 1.0]) - - for pos in round_brackets: - multiply_range(pos, round_bracket_multiplier) - - for pos in square_brackets: - multiply_range(pos, square_bracket_multiplier) - - if len(res) == 0: - res = [["", 1.0]] - - # merge runs of identical weights - i = 0 - while i + 1 < len(res): - if res[i][1] == res[i + 1][1]: - res[i][0] += res[i + 1][0] - res.pop(i + 1) - else: - i += 1 - - return res - - -def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int): - r""" - Tokenize a list of prompts and return its tokens with weights of each token. - - No padding, starting or ending token is included. - """ - tokens = [] - weights = [] - truncated = False - for text in prompt: - texts_and_weights = parse_prompt_attention(text) - text_token = [] - text_weight = [] - for word, weight in texts_and_weights: - # tokenize and discard the starting and the ending token - token = pipe.tokenizer(word).input_ids[1:-1] - text_token += token - # copy the weight by length of token - text_weight += [weight] * len(token) - # stop if the text is too long (longer than truncation limit) - if len(text_token) > max_length: - truncated = True - break - # truncate - if len(text_token) > max_length: - truncated = True - text_token = text_token[:max_length] - text_weight = text_weight[:max_length] - tokens.append(text_token) - weights.append(text_weight) - if truncated: - logger.warning("Prompt was truncated. 
Try to shorten the prompt or increase max_embeddings_multiples") - return tokens, weights - - -def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): - r""" - Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. - """ - max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) - weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length - for i in range(len(tokens)): - tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) - if no_boseos_middle: - weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) - else: - w = [] - if len(weights[i]) == 0: - w = [1.0] * weights_length - else: - for j in range(max_embeddings_multiples): - w.append(1.0) # weight for starting token in this chunk - w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] - w.append(1.0) # weight for ending token in this chunk - w += [1.0] * (weights_length - len(w)) - weights[i] = w[:] - - return tokens, weights - - -def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, eos_token_id, device): - if not is_sdxl_text_encoder2: - # text_encoder1: same as SD1/2 - enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True) - hidden_states = enc_out["hidden_states"][11] - pool = None - else: - # text_encoder2 - enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True) - hidden_states = enc_out["hidden_states"][-2] # penuultimate layer - # pool = enc_out["text_embeds"] - pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], input_ids, eos_token_id) - hidden_states = hidden_states.to(device) - if pool is not None: - pool = pool.to(device) - return hidden_states, pool - - -def get_unweighted_text_embeddings( - pipe: StableDiffusionPipeline, - text_input: torch.Tensor, - chunk_length: int, - clip_skip: int, - eos: int, - pad: int, - is_sdxl_text_encoder2: bool, - no_boseos_middle: Optional[bool] = True, -): - """ - When the length of tokens is a multiple of the capacity of the text encoder, - it should be split into chunks and sent to the text encoder individually. 
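get_hidden_states above selects different hidden layers for the two SDXL text encoders; stripped of the pooled-output workaround (which lives in train_util and is not shown here), the layer selection alone looks roughly like this:

import torch
from transformers import CLIPTextModel

def encoder_hidden_states(text_encoder: CLIPTextModel, input_ids: torch.Tensor,
                          is_sdxl_text_encoder2: bool) -> torch.Tensor:
    out = text_encoder(input_ids.to(text_encoder.device),
                       output_hidden_states=True, return_dict=True)
    if is_sdxl_text_encoder2:
        return out.hidden_states[-2]  # penultimate layer, as in the code above
    return out.hidden_states[11]      # fixed layer index used for the first encoder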
- """ - max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) - text_pool = None - if max_embeddings_multiples > 1: - text_embeddings = [] - for i in range(max_embeddings_multiples): - # extract the i-th chunk - text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() - - # cover the head and the tail by the starting and the ending tokens - text_input_chunk[:, 0] = text_input[0, 0] - if pad == eos: # v1 - text_input_chunk[:, -1] = text_input[0, -1] - else: # v2 - for j in range(len(text_input_chunk)): - if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある - text_input_chunk[j, -1] = eos - if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD - text_input_chunk[j, 1] = eos - - text_embedding, current_text_pool = get_hidden_states( - pipe.text_encoder, text_input_chunk, is_sdxl_text_encoder2, eos, pipe.device - ) - if text_pool is None: - text_pool = current_text_pool - - if no_boseos_middle: - if i == 0: - # discard the ending token - text_embedding = text_embedding[:, :-1] - elif i == max_embeddings_multiples - 1: - # discard the starting token - text_embedding = text_embedding[:, 1:] - else: - # discard both starting and ending tokens - text_embedding = text_embedding[:, 1:-1] - - text_embeddings.append(text_embedding) - text_embeddings = torch.concat(text_embeddings, axis=1) - else: - text_embeddings, text_pool = get_hidden_states(pipe.text_encoder, text_input, is_sdxl_text_encoder2, eos, pipe.device) - return text_embeddings, text_pool - - -def get_weighted_text_embeddings( - pipe, # : SdxlStableDiffusionLongPromptWeightingPipeline, - prompt: Union[str, List[str]], - uncond_prompt: Optional[Union[str, List[str]]] = None, - max_embeddings_multiples: Optional[int] = 3, - no_boseos_middle: Optional[bool] = False, - skip_parsing: Optional[bool] = False, - skip_weighting: Optional[bool] = False, - clip_skip=None, - is_sdxl_text_encoder2=False, -): - r""" - Prompts can be assigned with local weights using brackets. For example, - prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', - and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. - - Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. - - Args: - pipe (`StableDiffusionPipeline`): - Pipe to provide access to the tokenizer and the text encoder. - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - uncond_prompt (`str` or `List[str]`): - The unconditional prompt or prompts for guide the image generation. If unconditional prompt - is provided, the embeddings of prompt and uncond_prompt are concatenated. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - no_boseos_middle (`bool`, *optional*, defaults to `False`): - If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and - ending token in each of the chunk in the middle. - skip_parsing (`bool`, *optional*, defaults to `False`): - Skip the parsing of brackets. - skip_weighting (`bool`, *optional*, defaults to `False`): - Skip the weighting. When the parsing is skipped, it is forced True. 
- """ - max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - if isinstance(prompt, str): - prompt = [prompt] - - if not skip_parsing: - prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2) - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) - else: - prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids] - prompt_weights = [[1.0] * len(token) for token in prompt_tokens] - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens = [ - token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids - ] - uncond_weights = [[1.0] * len(token) for token in uncond_tokens] - - # round up the longest length of tokens to a multiple of (model_max_length - 2) - max_length = max([len(token) for token in prompt_tokens]) - if uncond_prompt is not None: - max_length = max(max_length, max([len(token) for token in uncond_tokens])) - - max_embeddings_multiples = min( - max_embeddings_multiples, - (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1, - ) - max_embeddings_multiples = max(1, max_embeddings_multiples) - max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - - # pad the length of tokens and weights - bos = pipe.tokenizer.bos_token_id - eos = pipe.tokenizer.eos_token_id - pad = pipe.tokenizer.pad_token_id - prompt_tokens, prompt_weights = pad_tokens_and_weights( - prompt_tokens, - prompt_weights, - max_length, - bos, - eos, - pad, - no_boseos_middle=no_boseos_middle, - chunk_length=pipe.tokenizer.model_max_length, - ) - prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) - if uncond_prompt is not None: - uncond_tokens, uncond_weights = pad_tokens_and_weights( - uncond_tokens, - uncond_weights, - max_length, - bos, - eos, - pad, - no_boseos_middle=no_boseos_middle, - chunk_length=pipe.tokenizer.model_max_length, - ) - uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) - - # get the embeddings - text_embeddings, text_pool = get_unweighted_text_embeddings( - pipe, - prompt_tokens, - pipe.tokenizer.model_max_length, - clip_skip, - eos, - pad, - is_sdxl_text_encoder2, - no_boseos_middle=no_boseos_middle, - ) - prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) - - if uncond_prompt is not None: - uncond_embeddings, uncond_pool = get_unweighted_text_embeddings( - pipe, - uncond_tokens, - pipe.tokenizer.model_max_length, - clip_skip, - eos, - pad, - is_sdxl_text_encoder2, - no_boseos_middle=no_boseos_middle, - ) - uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) - - # assign weights to the prompts and normalize in the sense of mean - # TODO: should we normalize by chunk or in a whole (current implementation)? 
- if (not skip_parsing) and (not skip_weighting): - previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= prompt_weights.unsqueeze(-1) - current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - if uncond_prompt is not None: - previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= uncond_weights.unsqueeze(-1) - current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - - if uncond_prompt is not None: - return text_embeddings, text_pool, uncond_embeddings, uncond_pool - return text_embeddings, text_pool, None, None - - -def preprocess_image(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.0 * image - 1.0 - - -def preprocess_mask(mask, scale_factor=8): - mask = mask.convert("L") - w, h = mask.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) - mask = np.array(mask).astype(np.float32) / 255.0 - mask = np.tile(mask, (4, 1, 1)) - mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? - mask = 1 - mask # repaint white, keep black - mask = torch.from_numpy(mask) - return mask - - -def prepare_controlnet_image( - image: PIL.Image.Image, - width: int, - height: int, - batch_size: int, - num_images_per_prompt: int, - device: torch.device, - dtype: torch.dtype, - do_classifier_free_guidance: bool = False, - guess_mode: bool = False, -): - if not isinstance(image, torch.Tensor): - if isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - images = [] - - for image_ in image: - image_ = image_.convert("RGB") - image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) - image_ = np.array(image_) - image_ = image_[None, :] - images.append(image_) - - image = images - - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance and not guess_mode: - image = torch.cat([image] * 2) - - return image - - -class SdxlStableDiffusionLongPromptWeightingPipeline: - r""" - Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing - weighting in prompt. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
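The weighting code at the start of this block multiplies each token embedding by its prompt weight and then rescales so the overall mean is preserved. In isolation, with hypothetical shapes and random data:

import torch

def apply_prompt_weights(embeddings: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # embeddings: (batch, tokens, dim); weights: (batch, tokens)
    previous_mean = embeddings.float().mean(dim=[-2, -1]).to(embeddings.dtype)
    weighted = embeddings * weights.unsqueeze(-1)
    current_mean = weighted.float().mean(dim=[-2, -1]).to(weighted.dtype)
    # Rescale so the weighted embedding keeps the original mean, as in the code above.
    return weighted * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

emb = torch.randn(1, 77, 768)
w = torch.ones(1, 77)
w[0, 5:10] = 1.1  # emphasize a few tokens
out = apply_prompt_weights(emb, w)
print(torch.allclose(out.mean(), emb.mean(), atol=1e-5))  # True, up to float error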
- - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - - # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"): - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: List[CLIPTextModel], - tokenizer: List[CLIPTokenizer], - unet: UNet2DConditionModel, - scheduler: SchedulerMixin, - # clip_skip: int, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - clip_skip: int = 1, - ): - # clip skip is ignored currently - self.tokenizer = tokenizer[0] - self.text_encoder = text_encoder[0] - self.unet = unet - self.scheduler = scheduler - self.safety_checker = safety_checker - self.feature_extractor = feature_extractor - self.requires_safety_checker = requires_safety_checker - self.vae = vae - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.progress_bar = lambda x: tqdm(x, leave=False) - - self.clip_skip = clip_skip - self.tokenizers = tokenizer - self.text_encoders = text_encoder - - # self.__init__additional__() - - # def __init__additional__(self): - # if not hasattr(self, "vae_scale_factor"): - # setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1)) - - def to(self, device=None, dtype=None): - if device is not None: - self.device = device - # self.vae.to(device=self.device) - if dtype is not None: - self.dtype = dtype - - # do not move Text Encoders to device, because Text Encoder should be on CPU - - @property - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. 
- """ - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - max_embeddings_multiples, - is_sdxl_text_encoder2, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `list(int)`): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - - text_embeddings, text_pool, uncond_embeddings, uncond_pool = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - is_sdxl_text_encoder2=is_sdxl_text_encoder2, - ) - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) # ?? 
- text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - if text_pool is not None: - text_pool = text_pool.repeat(1, num_images_per_prompt) - text_pool = text_pool.view(bs_embed * num_images_per_prompt, -1) - - if do_classifier_free_guidance: - bs_embed, seq_len, _ = uncond_embeddings.shape - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - if uncond_pool is not None: - uncond_pool = uncond_pool.repeat(1, num_images_per_prompt) - uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1) - - return text_embeddings, text_pool, uncond_embeddings, uncond_pool - - return text_embeddings, text_pool, None, None - - def check_inputs(self, prompt, height, width, strength, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." - ) - - def get_timesteps(self, num_inference_steps, strength, device, is_text2img): - if is_text2img: - return self.scheduler.timesteps.to(device), num_inference_steps - else: - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - - t_start = max(num_inference_steps - init_timestep + offset, 0) - timesteps = self.scheduler.timesteps[t_start:].to(device) - return timesteps, num_inference_steps - t_start - - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype)) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - - def decode_latents(self, latents): - with torch.no_grad(): - latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents - - # print("post_quant_conv dtype:", self.vae.post_quant_conv.weight.dtype) # torch.float32 - # x = torch.nn.functional.conv2d(latents, self.vae.post_quant_conv.weight.detach(), stride=1, padding=0) - # print("latents dtype:", latents.dtype, "x dtype:", x.dtype) # torch.float32, torch.float16 - # self.vae.to("cpu") - # self.vae.set_use_memory_efficient_attention_xformers(False) - # image = self.vae.decode(latents.to("cpu")).sample - - image = self.vae.decode(latents.to(self.vae.dtype)).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will 
be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None): - if image is None: - shape = ( - batch_size, - self.unet.in_channels, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) - - if latents is None: - if device.type == "mps": - # randn does not work reproducibly on mps - latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) - else: - latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) - else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents, None, None - else: - init_latent_dist = self.vae.encode(image).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents - init_latents = torch.cat([init_latents] * batch_size, dim=0) - init_latents_orig = init_latents - shape = init_latents.shape - - # add noise to latents using the timesteps - if device.type == "mps": - noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) - else: - noise = torch.randn(shape, generator=generator, device=device, dtype=dtype) - latents = self.scheduler.add_noise(init_latents, noise, timestep) - return latents, init_latents_orig, noise - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, - mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, - height: int = 512, - width: int = 512, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - strength: float = 0.8, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - return_dict: bool = True, - controlnet=None, - controlnet_image=None, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. 
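prepare_extra_step_kwargs above inspects the scheduler's step() signature so that eta and generator are only forwarded to schedulers that accept them. The same pattern, reduced to a generic helper (hypothetical names):

import inspect

def filter_kwargs(fn, **kwargs):
    # Keep only the keyword arguments that fn's signature actually declares.
    accepted = set(inspect.signature(fn).parameters.keys())
    return {k: v for k, v in kwargs.items() if k in accepted}

def ddim_like_step(model_output, timestep, sample, eta=0.0, generator=None):
    ...

print(filter_kwargs(ddim_like_step, eta=0.5, generator=None, unknown=1))
# {'eta': 0.5, 'generator': None}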
- mask_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be - replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a - PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should - contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. - height (`int`, *optional*, defaults to 512): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. - `image` will be used as a starting point, adding more noise to it the larger the `strength`. The - number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added - noise will be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - controlnet (`diffusers.ControlNetModel`, *optional*): - A controlnet model to be used for the inference. If not provided, controlnet will be disabled. - controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*): - `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet - inference. 
- callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - is_cancelled_callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. If the function returns - `True`, the inference will be cancelled. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Returns: - `None` if cancelled by `is_cancelled_callback`, - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - if controlnet is not None and controlnet_image is None: - raise ValueError("controlnet_image must be provided if controlnet is not None.") - - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs(prompt, height, width, strength, callback_steps) - - # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - # 実装を簡単にするためにtokenzer/text encoderを切り替えて二回呼び出す - # To simplify the implementation, switch the tokenzer/text encoder and call it twice - text_embeddings_list = [] - text_pool = None - uncond_embeddings_list = [] - uncond_pool = None - for i in range(len(self.tokenizers)): - self.tokenizer = self.tokenizers[i] - self.text_encoder = self.text_encoders[i] - - text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - max_embeddings_multiples, - is_sdxl_text_encoder2=i == 1, - ) - text_embeddings_list.append(text_embeddings) - uncond_embeddings_list.append(uncond_embeddings) - - if tp1 is not None: - text_pool = tp1 - if up1 is not None: - uncond_pool = up1 - - unet_dtype = self.unet.dtype - dtype = unet_dtype - if hasattr(dtype, "itemsize") and dtype.itemsize == 1: # fp8 - dtype = torch.float16 - self.unet.to(dtype) - - # 4. 
Preprocess image and mask - if isinstance(image, PIL.Image.Image): - image = preprocess_image(image) - if image is not None: - image = image.to(device=self.device, dtype=dtype) - if isinstance(mask_image, PIL.Image.Image): - mask_image = preprocess_mask(mask_image, self.vae_scale_factor) - if mask_image is not None: - mask = mask_image.to(device=self.device, dtype=dtype) - mask = torch.cat([mask] * batch_size * num_images_per_prompt) - else: - mask = None - - # ControlNet is not working yet in SDXL, but keep the code here for future use - if controlnet_image is not None: - controlnet_image = prepare_controlnet_image( - controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False - ) - - # 5. set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - - # 6. Prepare latent variables - latents, init_latents_orig, noise = self.prepare_latents( - image, - latent_timestep, - batch_size * num_images_per_prompt, - height, - width, - dtype, - device, - generator, - latents, - ) - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # create size embs and concat embeddings for SDXL - orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(dtype) - crop_size = torch.zeros_like(orig_size) - target_size = orig_size - embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype) - - # make conditionings - if do_classifier_free_guidance: - text_embeddings = torch.cat(text_embeddings_list, dim=2) - uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2) - text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype) - - cond_vector = torch.cat([text_pool, embs], dim=1) - uncond_vector = torch.cat([uncond_pool, embs], dim=1) - vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype) - else: - text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype) - vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype) - - # 8. 
Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - unet_additional_args = {} - if controlnet is not None: - down_block_res_samples, mid_block_res_sample = controlnet( - latent_model_input, - t, - encoder_hidden_states=text_embeddings, - controlnet_cond=controlnet_image, - conditioning_scale=1.0, - guess_mode=False, - return_dict=False, - ) - unet_additional_args["down_block_additional_residuals"] = down_block_res_samples - unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding) - noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - if mask is not None: - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) - latents = (init_latents_proper * mask) + (latents * (1 - mask)) - - # call the callback, if provided - if i % callback_steps == 0: - if callback is not None: - callback(i, t, latents) - if is_cancelled_callback is not None and is_cancelled_callback(): - return None - - self.unet.to(unet_dtype) - return latents - - def latents_to_image(self, latents): - # 9. Post-processing - image = self.decode_latents(latents.to(self.vae.dtype)) - image = self.numpy_to_pil(image) - return image - - # copy from pil_utils.py - def numpy_to_pil(self, images: np.ndarray) -> Image.Image: - """ - Convert a numpy image or a batch of images to a PIL image. - """ - if images.ndim == 3: - images = images[None, ...] - images = (images * 255).round().astype("uint8") - if images.shape[-1] == 1: - # special case for grayscale (single channel) images - pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] - else: - pil_images = [Image.fromarray(image) for image in images] - - return pil_images - - def text2img( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - height: int = 512, - width: int = 512, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: int = 1, - ): - r""" - Function for text-to-image generation. - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - height (`int`, *optional*, defaults to 512): - The height in pixels of the generated image. 
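The guidance step inside the denoising loop above is the standard classifier-free guidance combination; isolated, with hypothetical shapes:

import torch

def classifier_free_guidance(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # noise_pred stacks the unconditional and conditional predictions along the batch axis,
    # matching the torch.cat([latents] * 2) duplication earlier in the loop.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

pred = torch.randn(2, 4, 128, 128)  # (uncond + cond) for a batch of one latent
print(classifier_free_guidance(pred, 7.5).shape)  # torch.Size([1, 4, 128, 128])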
- width (`int`, *optional*, defaults to 512): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - is_cancelled_callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. If the function returns - `True`, the inference will be cancelled. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. 
- """ - return self.__call__( - prompt=prompt, - negative_prompt=negative_prompt, - height=height, - width=width, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - latents=latents, - max_embeddings_multiples=max_embeddings_multiples, - output_type=output_type, - return_dict=return_dict, - callback=callback, - is_cancelled_callback=is_cancelled_callback, - callback_steps=callback_steps, - ) - - def img2img( - self, - image: Union[torch.FloatTensor, PIL.Image.Image], - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: int = 1, - ): - r""" - Function for image-to-image generation. - Args: - image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. - `image` will be used as a starting point, adding more noise to it the larger the `strength`. The - number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added - noise will be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This parameter will be modulated by `strength`. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. 
- max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - is_cancelled_callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. If the function returns - `True`, the inference will be cancelled. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - return self.__call__( - prompt=prompt, - negative_prompt=negative_prompt, - image=image, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - strength=strength, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - max_embeddings_multiples=max_embeddings_multiples, - output_type=output_type, - return_dict=return_dict, - callback=callback, - is_cancelled_callback=is_cancelled_callback, - callback_steps=callback_steps, - ) - - def inpaint( - self, - image: Union[torch.FloatTensor, PIL.Image.Image], - mask_image: Union[torch.FloatTensor, PIL.Image.Image], - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: int = 1, - ): - r""" - Function for inpaint. - Args: - image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. This is the image whose masked region will be inpainted. - mask_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be - replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a - PIL image, it will be converted to a single channel (luminance) before use. 
If it's a tensor, it should - contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` - is 1, the denoising process will be run on the masked area for the full number of iterations specified - in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more - noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. - num_inference_steps (`int`, *optional*, defaults to 50): - The reference number of denoising steps. More denoising steps usually lead to a higher quality image at - the expense of slower inference. This parameter will be modulated by `strength`, as explained above. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - is_cancelled_callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. If the function returns - `True`, the inference will be cancelled. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 
- When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - return self.__call__( - prompt=prompt, - negative_prompt=negative_prompt, - image=image, - mask_image=mask_image, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - strength=strength, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - max_embeddings_multiples=max_embeddings_multiples, - output_type=output_type, - return_dict=return_dict, - callback=callback, - is_cancelled_callback=is_cancelled_callback, - callback_steps=callback_steps, - ) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py deleted file mode 100644 index f03f1bae5..000000000 --- a/library/sdxl_model_util.py +++ /dev/null @@ -1,577 +0,0 @@ -import torch -from accelerate import init_empty_weights -from accelerate.utils.modeling import set_module_tensor_to_device -from safetensors.torch import load_file, save_file -from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer -from typing import List -from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel -from library import model_util -from library import sdxl_original_unet -from .utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -VAE_SCALE_FACTOR = 0.13025 -MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0" - -# Diffusersの設定を読み込むための参照モデル -DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-1.0" - -DIFFUSERS_SDXL_UNET_CONFIG = { - "act_fn": "silu", - "addition_embed_type": "text_time", - "addition_embed_type_num_heads": 64, - "addition_time_embed_dim": 256, - "attention_head_dim": [5, 10, 20], - "block_out_channels": [320, 640, 1280], - "center_input_sample": False, - "class_embed_type": None, - "class_embeddings_concat": False, - "conv_in_kernel": 3, - "conv_out_kernel": 3, - "cross_attention_dim": 2048, - "cross_attention_norm": None, - "down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"], - "downsample_padding": 1, - "dual_cross_attention": False, - "encoder_hid_dim": None, - "encoder_hid_dim_type": None, - "flip_sin_to_cos": True, - "freq_shift": 0, - "in_channels": 4, - "layers_per_block": 2, - "mid_block_only_cross_attention": None, - "mid_block_scale_factor": 1, - "mid_block_type": "UNetMidBlock2DCrossAttn", - "norm_eps": 1e-05, - "norm_num_groups": 32, - "num_attention_heads": None, - "num_class_embeds": None, - "only_cross_attention": False, - "out_channels": 4, - "projection_class_embeddings_input_dim": 2816, - "resnet_out_scale_factor": 1.0, - "resnet_skip_time_act": False, - "resnet_time_scale_shift": "default", - "sample_size": 128, - "time_cond_proj_dim": None, - "time_embedding_act_fn": None, - "time_embedding_dim": None, - "time_embedding_type": "positional", - "timestep_post_act": None, - "transformer_layers_per_block": [1, 2, 10], - "up_block_types": ["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"], - "upcast_attention": False, - "use_linear_projection": True, -} - - -def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length): - SDXL_KEY_PREFIX = "conditioner.embedders.1.model." 
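# Illustrative, hand-written examples of the renaming that convert_key() below applies,
# mapping open-clip style keys under SDXL_KEY_PREFIX to HuggingFace
# CLIPTextModelWithProjection names. The concrete layer indices are arbitrary.
examples = {
    "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight":
        "text_model.encoder.layers.10.layer_norm1.weight",
    "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight":
        "text_model.encoder.layers.10.mlp.fc1.weight",
    "conditioner.embedders.1.model.ln_final.weight":
        "text_model.final_layer_norm.weight",
}
# *.attn.in_proj_{weight,bias} has no one-to-one target: it is chunked into
# q_proj/k_proj/v_proj further down, and logit_scale is returned separately so it
# can be written back when the checkpoint is saved.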
- - # SD2のと、基本的には同じ。logit_scaleを後で使うので、それを追加で返す - # logit_scaleはcheckpointの保存時に使用する - def convert_key(key): - # common conversion - key = key.replace(SDXL_KEY_PREFIX + "transformer.", "text_model.encoder.") - key = key.replace(SDXL_KEY_PREFIX, "text_model.") - - if "resblocks" in key: - # resblocks conversion - key = key.replace(".resblocks.", ".layers.") - if ".ln_" in key: - key = key.replace(".ln_", ".layer_norm") - elif ".mlp." in key: - key = key.replace(".c_fc.", ".fc1.") - key = key.replace(".c_proj.", ".fc2.") - elif ".attn.out_proj" in key: - key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") - elif ".attn.in_proj" in key: - key = None # 特殊なので後で処理する - else: - raise ValueError(f"unexpected key in SD: {key}") - elif ".positional_embedding" in key: - key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") - elif ".text_projection" in key: - key = key.replace("text_model.text_projection", "text_projection.weight") - elif ".logit_scale" in key: - key = None # 後で処理する - elif ".token_embedding" in key: - key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") - elif ".ln_final" in key: - key = key.replace(".ln_final", ".final_layer_norm") - # ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids - elif ".embeddings.position_ids" in key: - key = None # remove this key: position_ids is not used in newer transformers - return key - - keys = list(checkpoint.keys()) - new_sd = {} - for key in keys: - new_key = convert_key(key) - if new_key is None: - continue - new_sd[new_key] = checkpoint[key] - - # attnの変換 - for key in keys: - if ".resblocks" in key and ".attn.in_proj_" in key: - # 三つに分割 - values = torch.chunk(checkpoint[key], 3) - - key_suffix = ".weight" if "weight" in key else ".bias" - key_pfx = key.replace(SDXL_KEY_PREFIX + "transformer.resblocks.", "text_model.encoder.layers.") - key_pfx = key_pfx.replace("_weight", "") - key_pfx = key_pfx.replace("_bias", "") - key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") - new_sd[key_pfx + "q_proj" + key_suffix] = values[0] - new_sd[key_pfx + "k_proj" + key_suffix] = values[1] - new_sd[key_pfx + "v_proj" + key_suffix] = values[2] - - # logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す - logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None) - - # temporary workaround for text_projection.weight.weight for Playground-v2 - if "text_projection.weight.weight" in new_sd: - logger.info("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight") - new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"] - del new_sd["text_projection.weight.weight"] - - return new_sd, logit_scale - - -# load state_dict without allocating new tensors -def _load_state_dict_on_device(model, state_dict, device, dtype=None): - # dtype will use fp32 as default - missing_keys = list(model.state_dict().keys() - state_dict.keys()) - unexpected_keys = list(state_dict.keys() - model.state_dict().keys()) - - # similar to model.load_state_dict() - if not missing_keys and not unexpected_keys: - for k in list(state_dict.keys()): - set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype) - return "" - - # error_msgs - error_msgs: List[str] = [] - if missing_keys: - error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys))) - if unexpected_keys: - error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. 
".format(", ".join('"{}"'.format(k) for k in unexpected_keys))) - - raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))) - - -def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None): - # model_version is reserved for future use - # dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching - - # Load the state dict - if model_util.is_safetensors(ckpt_path): - checkpoint = None - try: - state_dict = load_file(ckpt_path, device=map_location) - except: - state_dict = load_file(ckpt_path) # prevent device invalid Error - epoch = None - global_step = None - else: - checkpoint = torch.load(ckpt_path, map_location=map_location) - if "state_dict" in checkpoint: - state_dict = checkpoint["state_dict"] - epoch = checkpoint.get("epoch", 0) - global_step = checkpoint.get("global_step", 0) - else: - state_dict = checkpoint - epoch = 0 - global_step = 0 - checkpoint = None - - # U-Net - logger.info("building U-Net") - with init_empty_weights(): - unet = sdxl_original_unet.SdxlUNet2DConditionModel() - - logger.info("loading U-Net from checkpoint") - unet_sd = {} - for k in list(state_dict.keys()): - if k.startswith("model.diffusion_model."): - unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k) - info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype) - logger.info(f"U-Net: {info}") - - # Text Encoders - logger.info("building text encoders") - - # Text Encoder 1 is same to Stability AI's SDXL - text_model1_cfg = CLIPTextConfig( - vocab_size=49408, - hidden_size=768, - intermediate_size=3072, - num_hidden_layers=12, - num_attention_heads=12, - max_position_embeddings=77, - hidden_act="quick_gelu", - layer_norm_eps=1e-05, - dropout=0.0, - attention_dropout=0.0, - initializer_range=0.02, - initializer_factor=1.0, - pad_token_id=1, - bos_token_id=0, - eos_token_id=2, - model_type="clip_text_model", - projection_dim=768, - # torch_dtype="float32", - # transformers_version="4.25.0.dev0", - ) - with init_empty_weights(): - text_model1 = CLIPTextModel._from_config(text_model1_cfg) - - # Text Encoder 2 is different from Stability AI's SDXL. SDXL uses open clip, but we use the model from HuggingFace. - # Note: Tokenizer from HuggingFace is different from SDXL. We must use open clip's tokenizer. 
- text_model2_cfg = CLIPTextConfig( - vocab_size=49408, - hidden_size=1280, - intermediate_size=5120, - num_hidden_layers=32, - num_attention_heads=20, - max_position_embeddings=77, - hidden_act="gelu", - layer_norm_eps=1e-05, - dropout=0.0, - attention_dropout=0.0, - initializer_range=0.02, - initializer_factor=1.0, - pad_token_id=1, - bos_token_id=0, - eos_token_id=2, - model_type="clip_text_model", - projection_dim=1280, - # torch_dtype="float32", - # transformers_version="4.25.0.dev0", - ) - with init_empty_weights(): - text_model2 = CLIPTextModelWithProjection(text_model2_cfg) - - logger.info("loading text encoders from checkpoint") - te1_sd = {} - te2_sd = {} - for k in list(state_dict.keys()): - if k.startswith("conditioner.embedders.0.transformer."): - te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k) - elif k.startswith("conditioner.embedders.1.model."): - te2_sd[k] = state_dict.pop(k) - - # 最新の transformers では position_ids を含むとエラーになるので削除 / remove position_ids for latest transformers - if "text_model.embeddings.position_ids" in te1_sd: - te1_sd.pop("text_model.embeddings.position_ids") - - info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32 - logger.info(f"text encoder 1: {info1}") - - converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77) - info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32 - logger.info(f"text encoder 2: {info2}") - - # prepare vae - logger.info("building VAE") - vae_config = model_util.create_vae_diffusers_config() - with init_empty_weights(): - vae = AutoencoderKL(**vae_config) - - logger.info("loading VAE from checkpoint") - converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config) - info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype) - logger.info(f"VAE: {info}") - - ckpt_info = (epoch, global_step) if epoch is not None else None - return text_model1, text_model2, vae, unet, logit_scale, ckpt_info - - -def make_unet_conversion_map(): - unet_conversion_map_layer = [] - - for i in range(3): # num_blocks is 3 in sdxl - # loop over downblocks/upblocks - for j in range(2): - # loop over resnets/attentions for downblocks - hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." - sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." - unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) - - if i < 3: - # no attention layers in down_blocks.3 - hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." - sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." - unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) - - for j in range(3): - # loop over resnets/attentions for upblocks - hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." - sd_up_res_prefix = f"output_blocks.{3*i + j}.0." - unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) - - # if i > 0: commentout for sdxl - # no attention layers in up_blocks.0 - hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." - sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." - unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) - - if i < 3: - # no downsample in down_blocks.3 - hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." - sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." 
- unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) - - # no upsample in up_blocks.3 - hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl - unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) - - hf_mid_atn_prefix = "mid_block.attentions.0." - sd_mid_atn_prefix = "middle_block.1." - unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) - - for j in range(2): - hf_mid_res_prefix = f"mid_block.resnets.{j}." - sd_mid_res_prefix = f"middle_block.{2*j}." - unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) - - unet_conversion_map_resnet = [ - # (stable-diffusion, HF Diffusers) - ("in_layers.0.", "norm1."), - ("in_layers.2.", "conv1."), - ("out_layers.0.", "norm2."), - ("out_layers.3.", "conv2."), - ("emb_layers.1.", "time_emb_proj."), - ("skip_connection.", "conv_shortcut."), - ] - - unet_conversion_map = [] - for sd, hf in unet_conversion_map_layer: - if "resnets" in hf: - for sd_res, hf_res in unet_conversion_map_resnet: - unet_conversion_map.append((sd + sd_res, hf + hf_res)) - else: - unet_conversion_map.append((sd, hf)) - - for j in range(2): - hf_time_embed_prefix = f"time_embedding.linear_{j+1}." - sd_time_embed_prefix = f"time_embed.{j*2}." - unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix)) - - for j in range(2): - hf_label_embed_prefix = f"add_embedding.linear_{j+1}." - sd_label_embed_prefix = f"label_emb.0.{j*2}." - unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix)) - - unet_conversion_map.append(("input_blocks.0.0.", "conv_in.")) - unet_conversion_map.append(("out.0.", "conv_norm_out.")) - unet_conversion_map.append(("out.2.", "conv_out.")) - - return unet_conversion_map - - -def convert_diffusers_unet_state_dict_to_sdxl(du_sd): - unet_conversion_map = make_unet_conversion_map() - - conversion_map = {hf: sd for sd, hf in unet_conversion_map} - return convert_unet_state_dict(du_sd, conversion_map) - - -def convert_unet_state_dict(src_sd, conversion_map): - converted_sd = {} - for src_key, value in src_sd.items(): - # さすがに全部回すのは時間がかかるので右から要素を削りつつprefixを探す - src_key_fragments = src_key.split(".")[:-1] # remove weight/bias - while len(src_key_fragments) > 0: - src_key_prefix = ".".join(src_key_fragments) + "." - if src_key_prefix in conversion_map: - converted_prefix = conversion_map[src_key_prefix] - converted_key = converted_prefix + src_key[len(src_key_prefix) :] - converted_sd[converted_key] = value - break - src_key_fragments.pop(-1) - assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map" - - return converted_sd - - -def convert_sdxl_unet_state_dict_to_diffusers(sd): - unet_conversion_map = make_unet_conversion_map() - - conversion_dict = {sd: hf for sd, hf in unet_conversion_map} - return convert_unet_state_dict(sd, conversion_dict) - - -def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale): - def convert_key(key): - # position_idsの除去 - if ".position_ids" in key: - return None - - # common - key = key.replace("text_model.encoder.", "transformer.") - key = key.replace("text_model.", "") - if "layers" in key: - # resblocks conversion - key = key.replace(".layers.", ".resblocks.") - if ".layer_norm" in key: - key = key.replace(".layer_norm", ".ln_") - elif ".mlp." 
in key: - key = key.replace(".fc1.", ".c_fc.") - key = key.replace(".fc2.", ".c_proj.") - elif ".self_attn.out_proj" in key: - key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") - elif ".self_attn." in key: - key = None # 特殊なので後で処理する - else: - raise ValueError(f"unexpected key in DiffUsers model: {key}") - elif ".position_embedding" in key: - key = key.replace("embeddings.position_embedding.weight", "positional_embedding") - elif ".token_embedding" in key: - key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") - elif "text_projection" in key: # no dot in key - key = key.replace("text_projection.weight", "text_projection") - elif "final_layer_norm" in key: - key = key.replace("final_layer_norm", "ln_final") - return key - - keys = list(checkpoint.keys()) - new_sd = {} - for key in keys: - new_key = convert_key(key) - if new_key is None: - continue - new_sd[new_key] = checkpoint[key] - - # attnの変換 - for key in keys: - if "layers" in key and "q_proj" in key: - # 三つを結合 - key_q = key - key_k = key.replace("q_proj", "k_proj") - key_v = key.replace("q_proj", "v_proj") - - value_q = checkpoint[key_q] - value_k = checkpoint[key_k] - value_v = checkpoint[key_v] - value = torch.cat([value_q, value_k, value_v]) - - new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") - new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") - new_sd[new_key] = value - - if logit_scale is not None: - new_sd["logit_scale"] = logit_scale - - return new_sd - - -def save_stable_diffusion_checkpoint( - output_file, - text_encoder1, - text_encoder2, - unet, - epochs, - steps, - ckpt_info, - vae, - logit_scale, - metadata, - save_dtype=None, -): - state_dict = {} - - def update_sd(prefix, sd): - for k, v in sd.items(): - key = prefix + k - if save_dtype is not None: - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v - - # Convert the UNet model - update_sd("model.diffusion_model.", unet.state_dict()) - - # Convert the text encoders - update_sd("conditioner.embedders.0.transformer.", text_encoder1.state_dict()) - - text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), logit_scale) - update_sd("conditioner.embedders.1.model.", text_enc2_dict) - - # Convert the VAE - vae_dict = model_util.convert_vae_state_dict(vae.state_dict()) - update_sd("first_stage_model.", vae_dict) - - # Put together new checkpoint - key_count = len(state_dict.keys()) - new_ckpt = {"state_dict": state_dict} - - # epoch and global_step are sometimes not int - if ckpt_info is not None: - epochs += ckpt_info[0] - steps += ckpt_info[1] - - new_ckpt["epoch"] = epochs - new_ckpt["global_step"] = steps - - if model_util.is_safetensors(output_file): - save_file(state_dict, output_file, metadata) - else: - torch.save(new_ckpt, output_file) - - return key_count - - -def save_diffusers_checkpoint( - output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None -): - from diffusers import StableDiffusionXLPipeline - - # convert U-Net - unet_sd = unet.state_dict() - du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd) - - diffusers_unet = UNet2DConditionModel(**DIFFUSERS_SDXL_UNET_CONFIG) - if save_dtype is not None: - diffusers_unet.to(save_dtype) - diffusers_unet.load_state_dict(du_unet_sd) - - # create pipeline to save - if pretrained_model_name_or_path is None: - pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_SDXL - - scheduler = 
EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") - tokenizer1 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer") - tokenizer2 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2") - if vae is None: - vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") - - # prevent local path from being saved - def remove_name_or_path(model): - if hasattr(model, "config"): - model.config._name_or_path = None - model.config._name_or_path = None - - remove_name_or_path(diffusers_unet) - remove_name_or_path(text_encoder1) - remove_name_or_path(text_encoder2) - remove_name_or_path(scheduler) - remove_name_or_path(tokenizer1) - remove_name_or_path(tokenizer2) - remove_name_or_path(vae) - - pipeline = StableDiffusionXLPipeline( - unet=diffusers_unet, - text_encoder=text_encoder1, - text_encoder_2=text_encoder2, - vae=vae, - scheduler=scheduler, - tokenizer=tokenizer1, - tokenizer_2=tokenizer2, - ) - if save_dtype is not None: - pipeline.to(None, save_dtype) - pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py deleted file mode 100644 index 673cf9f65..000000000 --- a/library/sdxl_original_unet.py +++ /dev/null @@ -1,1284 +0,0 @@ -# Diffusersのコードをベースとした sd_xl_baseのU-Net -# state dictの形式をSDXLに合わせてある - -""" - target: sgm.modules.diffusionmodules.openaimodel.UNetModel - params: - adm_in_channels: 2816 - num_classes: sequential - use_checkpoint: True - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [4, 2] - num_res_blocks: 2 - channel_mult: [1, 2, 4] - num_head_channels: 64 - use_spatial_transformer: True - use_linear_in_transformer: True - transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16 - context_dim: 2048 - spatial_transformer_attn_type: softmax-xformers - legacy: False -""" - -import math -from types import SimpleNamespace -from typing import Any, Optional -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import functional as F -from einops import rearrange -from .utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -IN_CHANNELS: int = 4 -OUT_CHANNELS: int = 4 -ADM_IN_CHANNELS: int = 2816 -CONTEXT_DIM: int = 2048 -MODEL_CHANNELS: int = 320 -TIME_EMBED_DIM = 320 * 4 - -USE_REENTRANT = True - -# region memory efficient attention - -# FlashAttentionを使うCrossAttention -# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py -# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE - -# constants - -EPSILON = 1e-6 - -# helper functions - - -def exists(val): - return val is not None - - -def default(val, d): - return val if exists(val) else d - - -# flash attention forwards and backwards - -# https://arxiv.org/abs/2205.14135 - - -class FlashAttentionFunction(torch.autograd.Function): - @staticmethod - @torch.no_grad() - def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): - """Algorithm 2 in the paper""" - - device = q.device - dtype = q.dtype - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - o = torch.zeros_like(q) - all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) - all_row_maxes = 
torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) - - scale = q.shape[-1] ** -0.5 - - if not exists(mask): - mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) - else: - mask = rearrange(mask, "b n -> b 1 1 n") - mask = mask.split(q_bucket_size, dim=-1) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - mask, - all_row_sums.split(q_bucket_size, dim=-2), - all_row_maxes.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale - - if exists(row_mask): - attn_weights.masked_fill_(~row_mask, max_neg_value) - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( - q_start_index - k_start_index + 1 - ) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) - attn_weights -= block_row_maxes - exp_weights = torch.exp(attn_weights) - - if exists(row_mask): - exp_weights.masked_fill_(~row_mask, 0.0) - - block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) - - new_row_maxes = torch.maximum(block_row_maxes, row_maxes) - - exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc) - - exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) - exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) - - new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums - - oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) - - row_maxes.copy_(new_row_maxes) - row_sums.copy_(new_row_sums) - - ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) - ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) - - return o - - @staticmethod - @torch.no_grad() - def backward(ctx, do): - """Algorithm 4 in the paper""" - - causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args - q, k, v, o, l, m = ctx.saved_tensors - - device = q.device - - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - do.split(q_bucket_size, dim=-2), - mask, - l.split(q_bucket_size, dim=-2), - m.split(q_bucket_size, dim=-2), - dq.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - dk.split(k_bucket_size, dim=-2), - dv.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = torch.einsum("... i d, ... j d -> ... 
i j", qc, kc) * scale - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( - q_start_index - k_start_index + 1 - ) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - exp_attn_weights = torch.exp(attn_weights - mc) - - if exists(row_mask): - exp_attn_weights.masked_fill_(~row_mask, 0.0) - - p = exp_attn_weights / lc - - dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc) - dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc) - - D = (doc * oc).sum(dim=-1, keepdims=True) - ds = p * scale * (dp - D) - - dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc) - dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc) - - dqc.add_(dq_chunk) - dkc.add_(dk_chunk) - dvc.add_(dv_chunk) - - return dq, dk, dv, None, None, None, None - - -# endregion - - -def get_parameter_dtype(parameter: torch.nn.Module): - return next(parameter.parameters()).dtype - - -def get_parameter_device(parameter: torch.nn.Module): - return next(parameter.parameters()).device - - -def get_timestep_embedding( - timesteps: torch.Tensor, - embedding_dim: int, - downscale_freq_shift: float = 1, - scale: float = 1, - max_period: int = 10000, -): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. - - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the - embeddings. :return: an [N x dim] Tensor of positional embeddings. - """ - assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" - - half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) - exponent = exponent / (half_dim - downscale_freq_shift) - - emb = torch.exp(exponent) - emb = timesteps[:, None].float() * emb[None, :] - - # scale embeddings - emb = scale * emb - - # concat sine and cosine embeddings: flipped from Diffusers original ver because always flip_sin_to_cos=True - emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1) - - # zero pad - if embedding_dim % 2 == 1: - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - -# Deep Shrink: We do not common this function, because minimize dependencies. 
-def resize_like(x, target, mode="bicubic", align_corners=False): - org_dtype = x.dtype - if org_dtype == torch.bfloat16: - x = x.to(torch.float32) - - if x.shape[-2:] != target.shape[-2:]: - if mode == "nearest": - x = F.interpolate(x, size=target.shape[-2:], mode=mode) - else: - x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners) - - if org_dtype == torch.bfloat16: - x = x.to(org_dtype) - return x - - -class GroupNorm32(nn.GroupNorm): - def forward(self, x): - if self.weight.dtype != torch.float32: - return super().forward(x) - return super().forward(x.float()).type(x.dtype) - - -class ResnetBlock2D(nn.Module): - def __init__( - self, - in_channels, - out_channels, - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - - self.in_layers = nn.Sequential( - GroupNorm32(32, in_channels), - nn.SiLU(), - nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), - ) - - self.emb_layers = nn.Sequential(nn.SiLU(), nn.Linear(TIME_EMBED_DIM, out_channels)) - - self.out_layers = nn.Sequential( - GroupNorm32(32, out_channels), - nn.SiLU(), - nn.Identity(), # to make state_dict compatible with original model - nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), - ) - - if in_channels != out_channels: - self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - else: - self.skip_connection = nn.Identity() - - self.gradient_checkpointing = False - - def forward_body(self, x, emb): - h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - h = h + emb_out[:, :, None, None] - h = self.out_layers(h) - x = self.skip_connection(x) - return x + h - - def forward(self, x, emb): - if self.training and self.gradient_checkpointing: - # logger.info("ResnetBlock2D: gradient_checkpointing") - - def create_custom_forward(func): - def custom_forward(*inputs): - return func(*inputs) - - return custom_forward - - x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, emb, use_reentrant=USE_REENTRANT) - else: - x = self.forward_body(x, emb) - - return x - - -class Downsample2D(nn.Module): - def __init__(self, channels, out_channels): - super().__init__() - - self.channels = channels - self.out_channels = out_channels - - self.op = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1) - - self.gradient_checkpointing = False - - def forward_body(self, hidden_states): - assert hidden_states.shape[1] == self.channels - hidden_states = self.op(hidden_states) - - return hidden_states - - def forward(self, hidden_states): - if self.training and self.gradient_checkpointing: - # logger.info("Downsample2D: gradient_checkpointing") - - def create_custom_forward(func): - def custom_forward(*inputs): - return func(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.forward_body), hidden_states, use_reentrant=USE_REENTRANT - ) - else: - hidden_states = self.forward_body(hidden_states) - - return hidden_states - - -class CrossAttention(nn.Module): - def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - dim_head: int = 64, - upcast_attention: bool = False, - ): - super().__init__() - inner_dim = dim_head * heads - cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - self.upcast_attention = upcast_attention - - self.scale = dim_head**-0.5 - self.heads = heads - - self.to_q = 
nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) - self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) - - self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(inner_dim, query_dim)) - # no dropout here - - self.use_memory_efficient_attention_xformers = False - self.use_memory_efficient_attention_mem_eff = False - self.use_sdpa = False - - def set_use_memory_efficient_attention(self, xformers, mem_eff): - self.use_memory_efficient_attention_xformers = xformers - self.use_memory_efficient_attention_mem_eff = mem_eff - - def set_use_sdpa(self, sdpa): - self.use_sdpa = sdpa - - def reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor - - def reshape_batch_dim_to_heads(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - - def forward(self, hidden_states, context=None, mask=None): - if self.use_memory_efficient_attention_xformers: - return self.forward_memory_efficient_xformers(hidden_states, context, mask) - if self.use_memory_efficient_attention_mem_eff: - return self.forward_memory_efficient_mem_eff(hidden_states, context, mask) - if self.use_sdpa: - return self.forward_sdpa(hidden_states, context, mask) - - query = self.to_q(hidden_states) - context = context if context is not None else hidden_states - key = self.to_k(context) - value = self.to_v(context) - - query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - - hidden_states = self._attention(query, key, value) - - # linear proj - hidden_states = self.to_out[0](hidden_states) - # hidden_states = self.to_out[1](hidden_states) # no dropout - return hidden_states - - def _attention(self, query, key, value): - if self.upcast_attention: - query = query.float() - key = key.float() - - attention_scores = torch.baddbmm( - torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - attention_probs = attention_scores.softmax(dim=-1) - - # cast back to the original dtype - attention_probs = attention_probs.to(value.dtype) - - # compute attention output - hidden_states = torch.bmm(attention_probs, value) - - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - return hidden_states - - # TODO support Hypernetworks - def forward_memory_efficient_xformers(self, x, context=None, mask=None): - import xformers.ops - - h = self.heads - q_in = self.to_q(x) - context = context if context is not None else x - context = context.to(x.dtype) - k_in = self.to_k(context) - v_in = self.to_v(context) - - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - del q, k, v - - out = rearrange(out, "b n h d -> b n (h d)", h=h) - - out = self.to_out[0](out) - return out - - def 
forward_memory_efficient_mem_eff(self, x, context=None, mask=None): - flash_func = FlashAttentionFunction - - q_bucket_size = 512 - k_bucket_size = 1024 - - h = self.heads - q = self.to_q(x) - context = context if context is not None else x - context = context.to(x.dtype) - k = self.to_k(context) - v = self.to_v(context) - del context, x - - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) - - out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) - - out = rearrange(out, "b h n d -> b n (h d)") - - out = self.to_out[0](out) - return out - - def forward_sdpa(self, x, context=None, mask=None): - h = self.heads - q_in = self.to_q(x) - context = context if context is not None else x - context = context.to(x.dtype) - k_in = self.to_k(context) - v_in = self.to_v(context) - - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - - out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) - - out = rearrange(out, "b h n d -> b n (h d)", h=h) - - out = self.to_out[0](out) - return out - - -# feedforward -class GEGLU(nn.Module): - r""" - A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. - - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - """ - - def __init__(self, dim_in: int, dim_out: int): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def gelu(self, gate): - if gate.device.type != "mps": - return F.gelu(gate) - # mps: gelu is not implemented for float16 - return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) - - def forward(self, hidden_states): - hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) - return hidden_states * self.gelu(gate) - - -class FeedForward(nn.Module): - def __init__( - self, - dim: int, - ): - super().__init__() - inner_dim = int(dim * 4) # mult is always 4 - - self.net = nn.ModuleList([]) - # project in - self.net.append(GEGLU(dim, inner_dim)) - # project dropout - self.net.append(nn.Identity()) # nn.Dropout(0)) # dummy for dropout with 0 - # project out - self.net.append(nn.Linear(inner_dim, dim)) - - def forward(self, hidden_states): - for module in self.net: - hidden_states = module(hidden_states) - return hidden_states - - -class BasicTransformerBlock(nn.Module): - def __init__( - self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False - ): - super().__init__() - - self.gradient_checkpointing = False - - # 1. Self-Attn - self.attn1 = CrossAttention( - query_dim=dim, - cross_attention_dim=None, - heads=num_attention_heads, - dim_head=attention_head_dim, - upcast_attention=upcast_attention, - ) - self.ff = FeedForward(dim) - - # 2. Cross-Attn - self.attn2 = CrossAttention( - query_dim=dim, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - upcast_attention=upcast_attention, - ) - - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - - # 3. 
Feed-forward - self.norm3 = nn.LayerNorm(dim) - - def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool): - self.attn1.set_use_memory_efficient_attention(xformers, mem_eff) - self.attn2.set_use_memory_efficient_attention(xformers, mem_eff) - - def set_use_sdpa(self, sdpa: bool): - self.attn1.set_use_sdpa(sdpa) - self.attn2.set_use_sdpa(sdpa) - - def forward_body(self, hidden_states, context=None, timestep=None): - # 1. Self-Attention - norm_hidden_states = self.norm1(hidden_states) - - hidden_states = self.attn1(norm_hidden_states) + hidden_states - - # 2. Cross-Attention - norm_hidden_states = self.norm2(hidden_states) - hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states - - # 3. Feed-forward - hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states - - return hidden_states - - def forward(self, hidden_states, context=None, timestep=None): - if self.training and self.gradient_checkpointing: - # logger.info("BasicTransformerBlock: checkpointing") - - def create_custom_forward(func): - def custom_forward(*inputs): - return func(*inputs) - - return custom_forward - - output = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.forward_body), hidden_states, context, timestep, use_reentrant=USE_REENTRANT - ) - else: - output = self.forward_body(hidden_states, context, timestep) - - return output - - -class Transformer2DModel(nn.Module): - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - cross_attention_dim: Optional[int] = None, - use_linear_projection: bool = False, - upcast_attention: bool = False, - num_transformer_layers: int = 1, - ): - super().__init__() - self.in_channels = in_channels - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - self.use_linear_projection = use_linear_projection - - self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - # self.norm = GroupNorm32(32, in_channels, eps=1e-6, affine=True) - - if use_linear_projection: - self.proj_in = nn.Linear(in_channels, inner_dim) - else: - self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) - - blocks = [] - for _ in range(num_transformer_layers): - blocks.append( - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - cross_attention_dim=cross_attention_dim, - upcast_attention=upcast_attention, - ) - ) - - self.transformer_blocks = nn.ModuleList(blocks) - - if use_linear_projection: - self.proj_out = nn.Linear(in_channels, inner_dim) - else: - self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - - self.gradient_checkpointing = False - - def set_use_memory_efficient_attention(self, xformers, mem_eff): - for transformer in self.transformer_blocks: - transformer.set_use_memory_efficient_attention(xformers, mem_eff) - - def set_use_sdpa(self, sdpa): - for transformer in self.transformer_blocks: - transformer.set_use_sdpa(sdpa) - - def forward(self, hidden_states, encoder_hidden_states=None, timestep=None): - # 1. 
Input - batch, _, height, weight = hidden_states.shape - residual = hidden_states - - hidden_states = self.norm(hidden_states) - if not self.use_linear_projection: - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) - else: - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) - hidden_states = self.proj_in(hidden_states) - - # 2. Blocks - for block in self.transformer_blocks: - hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep) - - # 3. Output - if not self.use_linear_projection: - hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() - hidden_states = self.proj_out(hidden_states) - else: - hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() - - output = hidden_states + residual - - return output - - -class Upsample2D(nn.Module): - def __init__(self, channels, out_channels): - super().__init__() - self.channels = channels - self.out_channels = out_channels - self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) - - self.gradient_checkpointing = False - - def forward_body(self, hidden_states, output_size=None): - assert hidden_states.shape[1] == self.channels - - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch - # https://github.com/pytorch/pytorch/issues/86679 - dtype = hidden_states.dtype - if dtype == torch.bfloat16: - hidden_states = hidden_states.to(torch.float32) - - # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 - if hidden_states.shape[0] >= 64: - hidden_states = hidden_states.contiguous() - - # if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2` - if output_size is None: - hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") - else: - hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") - - # If the input is bfloat16, we cast back to bfloat16 - if dtype == torch.bfloat16: - hidden_states = hidden_states.to(dtype) - - hidden_states = self.conv(hidden_states) - - return hidden_states - - def forward(self, hidden_states, output_size=None): - if self.training and self.gradient_checkpointing: - # logger.info("Upsample2D: gradient_checkpointing") - - def create_custom_forward(func): - def custom_forward(*inputs): - return func(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(self.forward_body), hidden_states, output_size, use_reentrant=USE_REENTRANT - ) - else: - hidden_states = self.forward_body(hidden_states, output_size) - - return hidden_states - - -class SdxlUNet2DConditionModel(nn.Module): - _supports_gradient_checkpointing = True - - def __init__( - self, - **kwargs, - ): - super().__init__() - - self.in_channels = IN_CHANNELS - self.out_channels = OUT_CHANNELS - self.model_channels = MODEL_CHANNELS - self.time_embed_dim = TIME_EMBED_DIM - self.adm_in_channels = ADM_IN_CHANNELS - - self.gradient_checkpointing = False - # self.sample_size = sample_size - - # time embedding - self.time_embed = nn.Sequential( - nn.Linear(self.model_channels, self.time_embed_dim), - nn.SiLU(), - nn.Linear(self.time_embed_dim, self.time_embed_dim), - ) - - # label embedding - self.label_emb = nn.Sequential( - nn.Sequential( - nn.Linear(self.adm_in_channels, self.time_embed_dim), - nn.SiLU(), - nn.Linear(self.time_embed_dim, self.time_embed_dim), - ) - ) - - # input - self.input_blocks = nn.ModuleList( - [ - nn.Sequential( - nn.Conv2d(self.in_channels, self.model_channels, kernel_size=3, padding=(1, 1)), - ) - ] - ) - - # level 0 - for i in range(2): - layers = [ - ResnetBlock2D( - in_channels=1 * self.model_channels, - out_channels=1 * self.model_channels, - ), - ] - self.input_blocks.append(nn.ModuleList(layers)) - - self.input_blocks.append( - nn.Sequential( - Downsample2D( - channels=1 * self.model_channels, - out_channels=1 * self.model_channels, - ), - ) - ) - - # level 1 - for i in range(2): - layers = [ - ResnetBlock2D( - in_channels=(1 if i == 0 else 2) * self.model_channels, - out_channels=2 * self.model_channels, - ), - Transformer2DModel( - num_attention_heads=2 * self.model_channels // 64, - attention_head_dim=64, - in_channels=2 * self.model_channels, - num_transformer_layers=2, - use_linear_projection=True, - cross_attention_dim=2048, - ), - ] - self.input_blocks.append(nn.ModuleList(layers)) - - self.input_blocks.append( - nn.Sequential( - Downsample2D( - channels=2 * self.model_channels, - out_channels=2 * self.model_channels, - ), - ) - ) - - # level 2 - for i in range(2): - layers = [ - ResnetBlock2D( - in_channels=(2 if i == 0 else 4) * self.model_channels, - out_channels=4 * self.model_channels, - ), - Transformer2DModel( - num_attention_heads=4 * self.model_channels // 64, - attention_head_dim=64, - in_channels=4 * self.model_channels, - num_transformer_layers=10, - use_linear_projection=True, - cross_attention_dim=2048, - ), - ] - 
self.input_blocks.append(nn.ModuleList(layers)) - - # mid - self.middle_block = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=4 * self.model_channels, - out_channels=4 * self.model_channels, - ), - Transformer2DModel( - num_attention_heads=4 * self.model_channels // 64, - attention_head_dim=64, - in_channels=4 * self.model_channels, - num_transformer_layers=10, - use_linear_projection=True, - cross_attention_dim=2048, - ), - ResnetBlock2D( - in_channels=4 * self.model_channels, - out_channels=4 * self.model_channels, - ), - ] - ) - - # output - self.output_blocks = nn.ModuleList([]) - - # level 2 - for i in range(3): - layers = [ - ResnetBlock2D( - in_channels=4 * self.model_channels + (4 if i <= 1 else 2) * self.model_channels, - out_channels=4 * self.model_channels, - ), - Transformer2DModel( - num_attention_heads=4 * self.model_channels // 64, - attention_head_dim=64, - in_channels=4 * self.model_channels, - num_transformer_layers=10, - use_linear_projection=True, - cross_attention_dim=2048, - ), - ] - if i == 2: - layers.append( - Upsample2D( - channels=4 * self.model_channels, - out_channels=4 * self.model_channels, - ) - ) - - self.output_blocks.append(nn.ModuleList(layers)) - - # level 1 - for i in range(3): - layers = [ - ResnetBlock2D( - in_channels=2 * self.model_channels + (4 if i == 0 else (2 if i == 1 else 1)) * self.model_channels, - out_channels=2 * self.model_channels, - ), - Transformer2DModel( - num_attention_heads=2 * self.model_channels // 64, - attention_head_dim=64, - in_channels=2 * self.model_channels, - num_transformer_layers=2, - use_linear_projection=True, - cross_attention_dim=2048, - ), - ] - if i == 2: - layers.append( - Upsample2D( - channels=2 * self.model_channels, - out_channels=2 * self.model_channels, - ) - ) - - self.output_blocks.append(nn.ModuleList(layers)) - - # level 0 - for i in range(3): - layers = [ - ResnetBlock2D( - in_channels=1 * self.model_channels + (2 if i == 0 else 1) * self.model_channels, - out_channels=1 * self.model_channels, - ), - ] - - self.output_blocks.append(nn.ModuleList(layers)) - - # output - self.out = nn.ModuleList( - [GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)] - ) - - # region diffusers compatibility - def prepare_config(self): - self.config = SimpleNamespace() - - @property - def dtype(self) -> torch.dtype: - # `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). - return get_parameter_dtype(self) - - @property - def device(self) -> torch.device: - # `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device). 
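`get_parameter_dtype()` and `get_parameter_device()`, referenced by the `dtype`/`device` properties here, inspect the module's parameters; a minimal equivalent, assuming the module has at least one parameter and (as the comments note) that all parameters share one dtype and device, could look like this:

import torch
import torch.nn as nn

def get_parameter_dtype(module: nn.Module) -> torch.dtype:
    # assumes all parameters share a single dtype
    return next(module.parameters()).dtype

def get_parameter_device(module: nn.Module) -> torch.device:
    # assumes all parameters live on a single device
    return next(module.parameters()).device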
- return get_parameter_device(self) - - def set_attention_slice(self, slice_size): - raise NotImplementedError("Attention slicing is not supported for this model.") - - def is_gradient_checkpointing(self) -> bool: - return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) - - def enable_gradient_checkpointing(self): - self.gradient_checkpointing = True - self.set_gradient_checkpointing(value=True) - - def disable_gradient_checkpointing(self): - self.gradient_checkpointing = False - self.set_gradient_checkpointing(value=False) - - def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None: - blocks = self.input_blocks + [self.middle_block] + self.output_blocks - for block in blocks: - for module in block: - if hasattr(module, "set_use_memory_efficient_attention"): - # logger.info(module.__class__.__name__) - module.set_use_memory_efficient_attention(xformers, mem_eff) - - def set_use_sdpa(self, sdpa: bool) -> None: - blocks = self.input_blocks + [self.middle_block] + self.output_blocks - for block in blocks: - for module in block: - if hasattr(module, "set_use_sdpa"): - module.set_use_sdpa(sdpa) - - def set_gradient_checkpointing(self, value=False): - blocks = self.input_blocks + [self.middle_block] + self.output_blocks - for block in blocks: - for module in block.modules(): - if hasattr(module, "gradient_checkpointing"): - # logger.info(f{module.__class__.__name__} {module.gradient_checkpointing} -> {value}") - module.gradient_checkpointing = value - - # endregion - - def forward(self, x, timesteps=None, context=None, y=None, **kwargs): - # broadcast timesteps to batch dimension - timesteps = timesteps.expand(x.shape[0]) - - hs = [] - t_emb = get_timestep_embedding(timesteps, self.model_channels) # , repeat_only=False) - t_emb = t_emb.to(x.dtype) - emb = self.time_embed(t_emb) - - assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" - assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" - # assert x.dtype == self.dtype - emb = emb + self.label_emb(y) - - def call_module(module, h, emb, context): - x = h - for layer in module: - # logger.info(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None) - if isinstance(layer, ResnetBlock2D): - x = layer(x, emb) - elif isinstance(layer, Transformer2DModel): - x = layer(x, context) - else: - x = layer(x) - return x - - # h = x.type(self.dtype) - h = x - - for module in self.input_blocks: - h = call_module(module, h, emb, context) - hs.append(h) - - h = call_module(self.middle_block, h, emb, context) - - for module in self.output_blocks: - h = torch.cat([h, hs.pop()], dim=1) - h = call_module(module, h, emb, context) - - h = h.type(x.dtype) - h = call_module(self.out, h, emb, context) - - return h - - -class InferSdxlUNet2DConditionModel: - def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs): - self.delegate = original_unet - - # override original model's forward method: because forward is not called by `__call__` - # overriding `__call__` is not enough, because nn.Module.forward has a special handling - self.delegate.forward = self.forward - - # Deep Shrink - self.ds_depth_1 = None - self.ds_depth_2 = None - self.ds_timesteps_1 = None - self.ds_timesteps_2 = None - self.ds_ratio = None - - # call original model's methods - def __getattr__(self, name): - return getattr(self.delegate, name) - - def __call__(self, *args, **kwargs): - return self.delegate(*args, **kwargs) - - def 
set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): - if ds_depth_1 is None: - logger.info("Deep Shrink is disabled.") - self.ds_depth_1 = None - self.ds_timesteps_1 = None - self.ds_depth_2 = None - self.ds_timesteps_2 = None - self.ds_ratio = None - else: - logger.info( - f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" - ) - self.ds_depth_1 = ds_depth_1 - self.ds_timesteps_1 = ds_timesteps_1 - self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1 - self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 - self.ds_ratio = ds_ratio - - def forward(self, x, timesteps=None, context=None, y=None, **kwargs): - r""" - current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink. - """ - _self = self.delegate - - # broadcast timesteps to batch dimension - timesteps = timesteps.expand(x.shape[0]) - - hs = [] - t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False) - t_emb = t_emb.to(x.dtype) - emb = _self.time_embed(t_emb) - - assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" - assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" - # assert x.dtype == _self.dtype - emb = emb + _self.label_emb(y) - - def call_module(module, h, emb, context): - x = h - for layer in module: - # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None) - if isinstance(layer, ResnetBlock2D): - x = layer(x, emb) - elif isinstance(layer, Transformer2DModel): - x = layer(x, context) - else: - x = layer(x) - return x - - # h = x.type(self.dtype) - h = x - - for depth, module in enumerate(_self.input_blocks): - # Deep Shrink - if self.ds_depth_1 is not None: - if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or ( - self.ds_depth_2 is not None - and depth == self.ds_depth_2 - and timesteps[0] < self.ds_timesteps_1 - and timesteps[0] >= self.ds_timesteps_2 - ): - # print("downsample", h.shape, self.ds_ratio) - org_dtype = h.dtype - if org_dtype == torch.bfloat16: - h = h.to(torch.float32) - h = F.interpolate(h, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype) - - h = call_module(module, h, emb, context) - hs.append(h) - - h = call_module(_self.middle_block, h, emb, context) - - for module in _self.output_blocks: - # Deep Shrink - if self.ds_depth_1 is not None: - if hs[-1].shape[-2:] != h.shape[-2:]: - # print("upsample", h.shape, hs[-1].shape) - h = resize_like(h, hs[-1]) - - h = torch.cat([h, hs.pop()], dim=1) - h = call_module(module, h, emb, context) - - # Deep Shrink: in case of depth 0 - if self.ds_depth_1 == 0 and h.shape[-2:] != x.shape[-2:]: - # print("upsample", h.shape, x.shape) - h = resize_like(h, x) - - h = h.type(x.dtype) - h = call_module(_self.out, h, emb, context) - - return h - - -if __name__ == "__main__": - import time - - logger.info("create unet") - unet = SdxlUNet2DConditionModel() - - unet.to("cuda") - unet.set_use_memory_efficient_attention(True, False) - unet.set_gradient_checkpointing(True) - unet.train() - - # 使用メモリ量確認用の疑似学習ループ - logger.info("preparing optimizer") - - # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working - - # import bitsandbytes - # optimizer = bitsandbytes.adam.Adam8bit(unet.parameters(), lr=1e-3) # not working - # optimizer = 
bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 - # optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 - - import transformers - - optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2 - - scaler = torch.cuda.amp.GradScaler(enabled=True) - - logger.info("start training") - steps = 10 - batch_size = 1 - - for step in range(steps): - logger.info(f"step {step}") - if step == 1: - time_start = time.perf_counter() - - x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024 - t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda") - ctx = torch.randn(batch_size, 77, 2048).cuda() - y = torch.randn(batch_size, ADM_IN_CHANNELS).cuda() - - with torch.cuda.amp.autocast(enabled=True): - output = unet(x, t, ctx, y) - target = torch.randn_like(output) - loss = torch.nn.functional.mse_loss(output, target) - - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad(set_to_none=True) - - time_end = time.perf_counter() - logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps") diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py deleted file mode 100644 index 1932bf881..000000000 --- a/library/sdxl_train_util.py +++ /dev/null @@ -1,373 +0,0 @@ -import argparse -import math -import os -from typing import Optional - -import torch -from library.device_utils import init_ipex, clean_memory_on_device -init_ipex() - -from accelerate import init_empty_weights -from tqdm import tqdm -from transformers import CLIPTokenizer -from library import model_util, sdxl_model_util, train_util, sdxl_original_unet -from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline -from .utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -TOKENIZER1_PATH = "openai/clip-vit-large-patch14" -TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" - -# DEFAULT_NOISE_OFFSET = 0.0357 - - -def load_target_model(args, accelerator, model_version: str, weight_dtype): - # load models for each process - model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16 - for pi in range(accelerator.state.num_processes): - if pi == accelerator.state.local_process_index: - logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") - - ( - load_stable_diffusion_format, - text_encoder1, - text_encoder2, - vae, - unet, - logit_scale, - ckpt_info, - ) = _load_target_model( - args.pretrained_model_name_or_path, - args.vae, - model_version, - weight_dtype, - accelerator.device if args.lowram else "cpu", - model_dtype, - ) - - # work on low-ram device - if args.lowram: - text_encoder1.to(accelerator.device) - text_encoder2.to(accelerator.device) - unet.to(accelerator.device) - vae.to(accelerator.device) - - clean_memory_on_device(accelerator.device) - accelerator.wait_for_everyone() - - return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info - - -def _load_target_model( - name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None -): - # model_dtype only work with full fp16/bf16 - name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path - load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or 
Diffusers - - if load_stable_diffusion_format: - logger.info(f"load StableDiffusion checkpoint: {name_or_path}") - ( - text_encoder1, - text_encoder2, - vae, - unet, - logit_scale, - ckpt_info, - ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype) - else: - # Diffusers model is loaded to CPU - from diffusers import StableDiffusionXLPipeline - - variant = "fp16" if weight_dtype == torch.float16 else None - logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}") - try: - try: - pipe = StableDiffusionXLPipeline.from_pretrained( - name_or_path, torch_dtype=model_dtype, variant=variant, tokenizer=None - ) - except EnvironmentError as ex: - if variant is not None: - logger.info("try to load fp32 model") - pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None) - else: - raise ex - except EnvironmentError as ex: - logger.error( - f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}" - ) - raise ex - - text_encoder1 = pipe.text_encoder - text_encoder2 = pipe.text_encoder_2 - - # convert to fp32 for cache text_encoders outputs - if text_encoder1.dtype != torch.float32: - text_encoder1 = text_encoder1.to(dtype=torch.float32) - if text_encoder2.dtype != torch.float32: - text_encoder2 = text_encoder2.to(dtype=torch.float32) - - vae = pipe.vae - unet = pipe.unet - del pipe - - # Diffusers U-Net to original U-Net - state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict()) - with init_empty_weights(): - unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet - sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype) - logger.info("U-Net converted to original U-Net") - - logit_scale = None - ckpt_info = None - - # VAEを読み込む - if vae_path is not None: - vae = model_util.load_vae(vae_path, weight_dtype) - logger.info("additional VAE loaded") - - return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info - - -def load_tokenizers(args: argparse.Namespace): - logger.info("prepare tokenizers") - - original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH] - tokeniers = [] - for i, original_path in enumerate(original_paths): - tokenizer: CLIPTokenizer = None - if args.tokenizer_cache_dir: - local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) - if os.path.exists(local_tokenizer_path): - logger.info(f"load tokenizer from cache: {local_tokenizer_path}") - tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) - - if tokenizer is None: - tokenizer = CLIPTokenizer.from_pretrained(original_path) - - if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): - logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") - tokenizer.save_pretrained(local_tokenizer_path) - - if i == 1: - tokenizer.pad_token_id = 0 # fix pad token id to make same as open clip tokenizer - - tokeniers.append(tokenizer) - - if hasattr(args, "max_token_length") and args.max_token_length is not None: - logger.info(f"update token length: {args.max_token_length}") - - return tokeniers - - -def match_mixed_precision(args, weight_dtype): - if args.full_fp16: - assert ( - weight_dtype == torch.float16 - ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - return weight_dtype - elif args.full_bf16: - 
assert ( - weight_dtype == torch.bfloat16 - ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" - return weight_dtype - else: - return None - - -def timestep_embedding(timesteps, dim, max_period=10000): - """ - Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( - device=timesteps.device - ) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - -def get_timestep_embedding(x, outdim): - assert len(x.shape) == 2 - b, dims = x.shape[0], x.shape[1] - x = torch.flatten(x) - emb = timestep_embedding(x, outdim) - emb = torch.reshape(emb, (b, dims * outdim)) - return emb - - -def get_size_embeddings(orig_size, crop_size, target_size, device): - emb1 = get_timestep_embedding(orig_size, 256) - emb2 = get_timestep_embedding(crop_size, 256) - emb3 = get_timestep_embedding(target_size, 256) - vector = torch.cat([emb1, emb2, emb3], dim=1).to(device) - return vector - - -def save_sd_model_on_train_end( - args: argparse.Namespace, - src_path: str, - save_stable_diffusion_format: bool, - use_safetensors: bool, - save_dtype: torch.dtype, - epoch: int, - global_step: int, - text_encoder1, - text_encoder2, - unet, - vae, - logit_scale, - ckpt_info, -): - def sd_saver(ckpt_file, epoch_no, global_step): - sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True) - sdxl_model_util.save_stable_diffusion_checkpoint( - ckpt_file, - text_encoder1, - text_encoder2, - unet, - epoch_no, - global_step, - ckpt_info, - vae, - logit_scale, - sai_metadata, - save_dtype, - ) - - def diffusers_saver(out_dir): - sdxl_model_util.save_diffusers_checkpoint( - out_dir, - text_encoder1, - text_encoder2, - unet, - src_path, - vae, - use_safetensors=use_safetensors, - save_dtype=save_dtype, - ) - - train_util.save_sd_model_on_train_end_common( - args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver - ) - - -# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している -# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 -def save_sd_model_on_epoch_end_or_stepwise( - args: argparse.Namespace, - on_epoch_end: bool, - accelerator, - src_path, - save_stable_diffusion_format: bool, - use_safetensors: bool, - save_dtype: torch.dtype, - epoch: int, - num_train_epochs: int, - global_step: int, - text_encoder1, - text_encoder2, - unet, - vae, - logit_scale, - ckpt_info, -): - def sd_saver(ckpt_file, epoch_no, global_step): - sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True) - sdxl_model_util.save_stable_diffusion_checkpoint( - ckpt_file, - text_encoder1, - text_encoder2, - unet, - epoch_no, - global_step, - ckpt_info, - vae, - logit_scale, - sai_metadata, - save_dtype, - ) - - def diffusers_saver(out_dir): - sdxl_model_util.save_diffusers_checkpoint( - out_dir, - text_encoder1, - text_encoder2, - unet, - src_path, - vae, - use_safetensors=use_safetensors, - save_dtype=save_dtype, - ) - - 
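A usage sketch for `get_size_embeddings()` above (assuming the functions above are in scope): each size is a pair of scalars and each scalar receives a 256-dim sinusoidal embedding, so the three SDXL size conditions contribute 3 × 2 × 256 = 1536 features; concatenated with the 1280-dim pooled text-encoder output, this forms the 2816-dim `y` vector (ADM_IN_CHANNELS) consumed by the U-Net's label embedding.

import torch

orig_size   = torch.tensor([[1024, 1024]])  # (B, 2); square values, so ordering is moot here
crop_size   = torch.tensor([[0, 0]])
target_size = torch.tensor([[1024, 1024]])

vector = get_size_embeddings(orig_size, crop_size, target_size, "cpu")
assert vector.shape == (1, 3 * 2 * 256)     # 1536 size features per sample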
train_util.save_sd_model_on_epoch_end_or_stepwise_common( - args, - on_epoch_end, - accelerator, - save_stable_diffusion_format, - use_safetensors, - epoch, - num_train_epochs, - global_step, - sd_saver, - diffusers_saver, - ) - - -def add_sdxl_training_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" - ) - parser.add_argument( - "--cache_text_encoder_outputs_to_disk", - action="store_true", - help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", - ) - - -def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): - assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" - if args.v_parameterization: - logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") - - if args.clip_skip is not None: - logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") - - # if args.multires_noise_iterations: - # logger.info( - # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります" - # ) - # else: - # if args.noise_offset is None: - # args.noise_offset = DEFAULT_NOISE_OFFSET - # elif args.noise_offset != DEFAULT_NOISE_OFFSET: - # logger.info( - # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています" - # ) - # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") - - assert ( - not hasattr(args, "weighted_captions") or not args.weighted_captions - ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" - - if supportTextEncoderCaching: - if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: - args.cache_text_encoder_outputs = True - logger.warning( - "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / " - + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました" - ) - - -def sample_images(*args, **kwargs): - return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) diff --git a/library/slicing_vae.py b/library/slicing_vae.py deleted file mode 100644 index ea7653429..000000000 --- a/library/slicing_vae.py +++ /dev/null @@ -1,682 +0,0 @@ -# Modified from Diffusers to reduce VRAM usage - -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
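The `slice_h()` / `cat_h()` helpers defined just below split activations along the height axis with a one-row overlap so that 3x3 convolutions (padding=1) applied per slice reproduce the full-tensor result without a visible seam. A tiny self-contained illustration of that idea, simplified to two slices with hand-written indices:

import torch
import torch.nn as nn

conv = nn.Conv2d(1, 1, 3, padding=1)
x = torch.randn(1, 1, 8, 4)

full = conv(x)

# each slice carries one extra boundary row; the row computed from that border is dropped
top      = conv(x[:, :, :5, :])[:, :, :-1, :]   # output rows 0..3
bottom   = conv(x[:, :, 3:, :])[:, :, 1:, :]    # output rows 4..7 (1-row overlap on input)
stitched = torch.cat([top, bottom], dim=2)

assert torch.allclose(full, stitched, atol=1e-5)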
-from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn as nn - - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block -from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution -from diffusers.models.autoencoder_kl import AutoencoderKLOutput -from .utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -def slice_h(x, num_slices): - # slice with pad 1 both sides: to eliminate side effect of padding of conv2d - # Conv2dのpaddingの副作用を排除するために、両側にpad 1しながらHをスライスする - # NCHWでもNHWCでもどちらでも動く - size = (x.shape[2] + num_slices - 1) // num_slices - sliced = [] - for i in range(num_slices): - if i == 0: - sliced.append(x[:, :, : size + 1, :]) - else: - end = size * (i + 1) + 1 - if x.shape[2] - end < 3: # if the last slice is too small, use the rest of the tensor 最後が細すぎるとconv2dできないので全部使う - end = x.shape[2] - sliced.append(x[:, :, size * i - 1 : end, :]) - if end >= x.shape[2]: - break - return sliced - - -def cat_h(sliced): - # padding分を除いて結合する - cat = [] - for i, x in enumerate(sliced): - if i == 0: - cat.append(x[:, :, :-1, :]) - elif i == len(sliced) - 1: - cat.append(x[:, :, 1:, :]) - else: - cat.append(x[:, :, 1:-1, :]) - del x - x = torch.cat(cat, dim=2) - return x - - -def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs): - assert _self.upsample is None and _self.downsample is None - assert _self.norm1.num_groups == _self.norm2.num_groups - assert temb is None - - # make sure norms are on cpu - org_device = input_tensor.device - cpu_device = torch.device("cpu") - _self.norm1.to(cpu_device) - _self.norm2.to(cpu_device) - - # GroupNormがCPUでfp16で動かない対策 - org_dtype = input_tensor.dtype - if org_dtype == torch.float16: - _self.norm1.to(torch.float32) - _self.norm2.to(torch.float32) - - # すべてのテンソルをCPUに移動する - input_tensor = input_tensor.to(cpu_device) - hidden_states = input_tensor - - # どうもこれは結果が異なるようだ…… - # def sliced_norm1(norm, x): - # num_div = 4 if up_block_idx <= 2 else x.shape[1] // norm.num_groups - # sliced_tensor = torch.chunk(x, num_div, dim=1) - # sliced_weight = torch.chunk(norm.weight, num_div, dim=0) - # sliced_bias = torch.chunk(norm.bias, num_div, dim=0) - # logger.info(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape) - # normed_tensor = [] - # for i in range(num_div): - # n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps) - # normed_tensor.append(n) - # del n - # x = torch.cat(normed_tensor, dim=1) - # return num_div, x - - # normを分割すると結果が変わるので、ここだけは分割しない。GPUで計算するとVRAMが足りなくなるので、CPUで計算する。幸いCPUでもそこまで遅くない - if org_dtype == torch.float16: - hidden_states = hidden_states.to(torch.float32) - hidden_states = _self.norm1(hidden_states) # run on cpu - if org_dtype == torch.float16: - hidden_states = hidden_states.to(torch.float16) - - sliced = slice_h(hidden_states, num_slices) - del hidden_states - - for i in range(len(sliced)): - x = sliced[i] - sliced[i] = None - - # 計算する部分だけGPUに移動する、以下同様 - x = x.to(org_device) - x = _self.nonlinearity(x) - x = _self.conv1(x) - x = x.to(cpu_device) - sliced[i] = x - del x - - hidden_states = cat_h(sliced) - del sliced - - if org_dtype == torch.float16: - hidden_states = hidden_states.to(torch.float32) - hidden_states = _self.norm2(hidden_states) # run on cpu - 
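The CPU/float32 GroupNorm detour used by `resblock_forward` above, isolated into a helper (the name is illustrative): the norm runs on the CPU to spare VRAM, and because (as the original comment notes) GroupNorm does not run in fp16 on the CPU, the tensor is temporarily promoted to float32. Like the original code, this moves the norm module itself to the CPU in place.

import torch
import torch.nn as nn

def groupnorm_on_cpu(norm: nn.GroupNorm, x: torch.Tensor) -> torch.Tensor:
    orig_device, orig_dtype = x.device, x.dtype
    norm.to(device="cpu", dtype=torch.float32)            # in-place move of the module
    y = norm(x.to(device="cpu", dtype=torch.float32))
    return y.to(device=orig_device, dtype=orig_dtype)     # back to the caller's device/dtype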
if org_dtype == torch.float16: - hidden_states = hidden_states.to(torch.float16) - - sliced = slice_h(hidden_states, num_slices) - del hidden_states - - for i in range(len(sliced)): - x = sliced[i] - sliced[i] = None - - x = x.to(org_device) - x = _self.nonlinearity(x) - x = _self.dropout(x) - x = _self.conv2(x) - x = x.to(cpu_device) - sliced[i] = x - del x - - hidden_states = cat_h(sliced) - del sliced - - # make shortcut - if _self.conv_shortcut is not None: - sliced = list(torch.chunk(input_tensor, num_slices, dim=2)) # no padding in conv_shortcut パディングがないので普通にスライスする - del input_tensor - - for i in range(len(sliced)): - x = sliced[i] - sliced[i] = None - - x = x.to(org_device) - x = _self.conv_shortcut(x) - x = x.to(cpu_device) - sliced[i] = x - del x - - input_tensor = torch.cat(sliced, dim=2) - del sliced - - output_tensor = (input_tensor + hidden_states) / _self.output_scale_factor - - output_tensor = output_tensor.to(org_device) # 次のレイヤーがGPUで計算する - return output_tensor - - -class SlicingEncoder(nn.Module): - def __init__( - self, - in_channels=3, - out_channels=3, - down_block_types=("DownEncoderBlock2D",), - block_out_channels=(64,), - layers_per_block=2, - norm_num_groups=32, - act_fn="silu", - double_z=True, - num_slices=2, - ): - super().__init__() - self.layers_per_block = layers_per_block - - self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1) - - self.mid_block = None - self.down_blocks = nn.ModuleList([]) - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=self.layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - add_downsample=not is_final_block, - resnet_eps=1e-6, - downsample_padding=0, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - attention_head_dim=output_channel, - temb_channels=None, - ) - self.down_blocks.append(down_block) - - # mid - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default", - attention_head_dim=block_out_channels[-1], - resnet_groups=norm_num_groups, - temb_channels=None, - ) - self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # とりあえずDiffusersのxformersを使う - - # out - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) - self.conv_act = nn.SiLU() - - conv_out_channels = 2 * out_channels if double_z else out_channels - self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) - - # replace forward of ResBlocks - def wrapper(func, module, num_slices): - def forward(*args, **kwargs): - return func(module, num_slices, *args, **kwargs) - - return forward - - self.num_slices = num_slices - div = num_slices / (2 ** (len(self.down_blocks) - 1)) # 深い層はそこまで分割しなくていいので適宜減らす - # logger.info(f"initial divisor: {div}") - if div >= 2: - div = int(div) - for resnet in self.mid_block.resnets: - resnet.forward = wrapper(resblock_forward, resnet, div) - # midblock doesn't have downsample - - for i, down_block in enumerate(self.down_blocks[::-1]): - if div >= 2: - div = int(div) - # logger.info(f"down block: {i} divisor: {div}") - for resnet in down_block.resnets: - resnet.forward = wrapper(resblock_forward, 
resnet, div) - if down_block.downsamplers is not None: - # logger.info("has downsample") - for downsample in down_block.downsamplers: - downsample.forward = wrapper(self.downsample_forward, downsample, div * 2) - div *= 2 - - def forward(self, x): - sample = x - del x - - org_device = sample.device - cpu_device = torch.device("cpu") - - # sample = self.conv_in(sample) - sample = sample.to(cpu_device) - sliced = slice_h(sample, self.num_slices) - del sample - - for i in range(len(sliced)): - x = sliced[i] - sliced[i] = None - - x = x.to(org_device) - x = self.conv_in(x) - x = x.to(cpu_device) - sliced[i] = x - del x - - sample = cat_h(sliced) - del sliced - - sample = sample.to(org_device) - - # down - for down_block in self.down_blocks: - sample = down_block(sample) - - # middle - sample = self.mid_block(sample) - - # post-process - # ここも省メモリ化したいが、恐らくそこまでメモリを食わないので省略 - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - return sample - - def downsample_forward(self, _self, num_slices, hidden_states): - assert hidden_states.shape[1] == _self.channels - assert _self.use_conv and _self.padding == 0 - logger.info(f"downsample forward {num_slices} {hidden_states.shape}") - - org_device = hidden_states.device - cpu_device = torch.device("cpu") - - hidden_states = hidden_states.to(cpu_device) - pad = (0, 1, 0, 1) - hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0) - - # slice with even number because of stride 2 - # strideが2なので偶数でスライスする - # slice with pad 1 both sides: to eliminate side effect of padding of conv2d - size = (hidden_states.shape[2] + num_slices - 1) // num_slices - size = size + 1 if size % 2 == 1 else size - - sliced = [] - for i in range(num_slices): - if i == 0: - sliced.append(hidden_states[:, :, : size + 1, :]) - else: - end = size * (i + 1) + 1 - if hidden_states.shape[2] - end < 4: # if the last slice is too small, use the rest of the tensor - end = hidden_states.shape[2] - sliced.append(hidden_states[:, :, size * i - 1 : end, :]) - if end >= hidden_states.shape[2]: - break - del hidden_states - - for i in range(len(sliced)): - x = sliced[i] - sliced[i] = None - - x = x.to(org_device) - x = _self.conv(x) - x = x.to(cpu_device) - - # ここだけ雰囲気が違うのはCopilotのせい - if i == 0: - hidden_states = x - else: - hidden_states = torch.cat([hidden_states, x], dim=2) - - hidden_states = hidden_states.to(org_device) - # logger.info(f"downsample forward done {hidden_states.shape}") - return hidden_states - - -class SlicingDecoder(nn.Module): - def __init__( - self, - in_channels=3, - out_channels=3, - up_block_types=("UpDecoderBlock2D",), - block_out_channels=(64,), - layers_per_block=2, - norm_num_groups=32, - act_fn="silu", - num_slices=2, - ): - super().__init__() - self.layers_per_block = layers_per_block - - self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1) - - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - - # mid - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default", - attention_head_dim=block_out_channels[-1], - resnet_groups=norm_num_groups, - temb_channels=None, - ) - self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # とりあえずDiffusersのxformersを使う - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i, 
up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - - is_final_block = i == len(block_out_channels) - 1 - - up_block = get_up_block( - up_block_type, - num_layers=self.layers_per_block + 1, - in_channels=prev_output_channel, - out_channels=output_channel, - prev_output_channel=None, - add_upsample=not is_final_block, - resnet_eps=1e-6, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - attention_head_dim=output_channel, - temb_channels=None, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) - self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) - - # replace forward of ResBlocks - def wrapper(func, module, num_slices): - def forward(*args, **kwargs): - return func(module, num_slices, *args, **kwargs) - - return forward - - self.num_slices = num_slices - div = num_slices / (2 ** (len(self.up_blocks) - 1)) - logger.info(f"initial divisor: {div}") - if div >= 2: - div = int(div) - for resnet in self.mid_block.resnets: - resnet.forward = wrapper(resblock_forward, resnet, div) - # midblock doesn't have upsample - - for i, up_block in enumerate(self.up_blocks): - if div >= 2: - div = int(div) - # logger.info(f"up block: {i} divisor: {div}") - for resnet in up_block.resnets: - resnet.forward = wrapper(resblock_forward, resnet, div) - if up_block.upsamplers is not None: - # logger.info("has upsample") - for upsample in up_block.upsamplers: - upsample.forward = wrapper(self.upsample_forward, upsample, div * 2) - div *= 2 - - def forward(self, z): - sample = z - del z - sample = self.conv_in(sample) - - # middle - sample = self.mid_block(sample) - - # up - for i, up_block in enumerate(self.up_blocks): - sample = up_block(sample) - - # post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - - # conv_out with slicing because of VRAM usage - # conv_outはとてもVRAM使うのでスライスして対応 - org_device = sample.device - cpu_device = torch.device("cpu") - sample = sample.to(cpu_device) - - sliced = slice_h(sample, self.num_slices) - del sample - for i in range(len(sliced)): - x = sliced[i] - sliced[i] = None - - x = x.to(org_device) - x = self.conv_out(x) - x = x.to(cpu_device) - sliced[i] = x - sample = cat_h(sliced) - del sliced - - sample = sample.to(org_device) - return sample - - def upsample_forward(self, _self, num_slices, hidden_states, output_size=None): - assert hidden_states.shape[1] == _self.channels - assert _self.use_conv_transpose == False and _self.use_conv - - org_dtype = hidden_states.dtype - org_device = hidden_states.device - cpu_device = torch.device("cpu") - - hidden_states = hidden_states.to(cpu_device) - sliced = slice_h(hidden_states, num_slices) - del hidden_states - - for i in range(len(sliced)): - x = sliced[i] - sliced[i] = None - - x = x.to(org_device) - - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch - # https://github.com/pytorch/pytorch/issues/86679 - # PyTorch 2で直らないかね…… - if org_dtype == torch.bfloat16: - x = x.to(torch.float32) - - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - - if org_dtype == torch.bfloat16: - x = x.to(org_dtype) - - x = _self.conv(x) - - # upsampleされてるのでpadは2になる - if i == 0: - x = x[:, :, :-2, :] - elif i == num_slices 
- 1: - x = x[:, :, 2:, :] - else: - x = x[:, :, 2:-2, :] - - x = x.to(cpu_device) - sliced[i] = x - del x - - hidden_states = torch.cat(sliced, dim=2) - # logger.info(f"us hidden_states {hidden_states.shape}") - del sliced - - hidden_states = hidden_states.to(org_device) - return hidden_states - - -class SlicingAutoencoderKL(ModelMixin, ConfigMixin): - r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma - and Max Welling. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) - - Parameters: - in_channels (int, *optional*, defaults to 3): Number of channels in the input image. - out_channels (int, *optional*, defaults to 3): Number of channels in the output. - down_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. - up_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to : - obj:`(64,)`): Tuple of block output channels. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space. - sample_size (`int`, *optional*, defaults to `32`): TODO - """ - - @register_to_config - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - down_block_types: Tuple[str] = ("DownEncoderBlock2D",), - up_block_types: Tuple[str] = ("UpDecoderBlock2D",), - block_out_channels: Tuple[int] = (64,), - layers_per_block: int = 1, - act_fn: str = "silu", - latent_channels: int = 4, - norm_num_groups: int = 32, - sample_size: int = 32, - num_slices: int = 16, - ): - super().__init__() - - # pass init params to Encoder - self.encoder = SlicingEncoder( - in_channels=in_channels, - out_channels=latent_channels, - down_block_types=down_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - double_z=True, - num_slices=num_slices, - ) - - # pass init params to Decoder - self.decoder = SlicingDecoder( - in_channels=latent_channels, - out_channels=out_channels, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - norm_num_groups=norm_num_groups, - act_fn=act_fn, - num_slices=num_slices, - ) - - self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) - self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) - self.use_slicing = False - - def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - - if not return_dict: - return (posterior,) - - return AutoencoderKLOutput(latent_dist=posterior) - - def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: - z = self.post_quant_conv(z) - dec = self.decoder(z) - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) - - # これはバッチ方向のスライシング 紛らわしい - def enable_slicing(self): - r""" - Enable sliced VAE decoding. - - When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several - steps. 
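`enable_slicing()` / `decode()` below add a second, independent kind of slicing: along the batch dimension rather than the height axis. The core of it, isolated as a sketch (here `decoder_fn` is a stand-in for `self._decode(...).sample`):

import torch

def decode_batch_sliced(decoder_fn, z: torch.Tensor) -> torch.Tensor:
    # decode one latent at a time and re-concatenate: slower, but lower peak VRAM
    if z.shape[0] > 1:
        return torch.cat([decoder_fn(z_slice) for z_slice in z.split(1)])
    return decoder_fn(z)

# decoded = decode_batch_sliced(lambda z: vae._decode(z).sample, latents)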
This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - - def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: - if self.use_slicing and z.shape[0] > 1: - decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] - decoded = torch.cat(decoded_slices) - else: - decoded = self._decode(z).sample - - if not return_dict: - return (decoded,) - - return DecoderOutput(sample=decoded) - - def forward( - self, - sample: torch.FloatTensor, - sample_posterior: bool = False, - return_dict: bool = True, - generator: Optional[torch.Generator] = None, - ) -> Union[DecoderOutput, torch.FloatTensor]: - r""" - Args: - sample (`torch.FloatTensor`): Input sample. - sample_posterior (`bool`, *optional*, defaults to `False`): - Whether to sample from the posterior. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`DecoderOutput`] instead of a plain tuple. - """ - x = sample - posterior = self.encode(x).latent_dist - if sample_posterior: - z = posterior.sample(generator=generator) - else: - z = posterior.mode() - dec = self.decode(z).sample - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) diff --git a/library/train_util.py b/library/train_util.py deleted file mode 100644 index d2b69edb5..000000000 --- a/library/train_util.py +++ /dev/null @@ -1,5064 +0,0 @@ -# common functions for training - -import argparse -import ast -import asyncio -import datetime -import importlib -import json -import logging -import pathlib -import re -import shutil -import time -from typing import ( - Dict, - List, - NamedTuple, - Optional, - Sequence, - Tuple, - Union, -) -from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState -import glob -import math -import os -import random -import hashlib -import subprocess -from io import BytesIO -import toml - -from tqdm import tqdm - -import torch -from library.device_utils import init_ipex, clean_memory_on_device - -init_ipex() - -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import Optimizer -from torchvision import transforms -from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection -import transformers -from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION -from diffusers import ( - StableDiffusionPipeline, - DDPMScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - DPMSolverSinglestepScheduler, - LMSDiscreteScheduler, - PNDMScheduler, - DDIMScheduler, - EulerDiscreteScheduler, - HeunDiscreteScheduler, - KDPM2DiscreteScheduler, - KDPM2AncestralDiscreteScheduler, - AutoencoderKL, -) -from library import custom_train_functions -from library.original_unet import UNet2DConditionModel -from huggingface_hub import hf_hub_download -import numpy as np -from PIL import Image -import cv2 -import safetensors.torch -from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline -import library.model_util as model_util -import library.huggingface_util as huggingface_util -import library.sai_model_spec as sai_model_spec -from library.utils import setup_logging - -setup_logging() -import logging - -logger = logging.getLogger(__name__) -# from 
library.attention_processors import FlashAttnProcessor -# from library.hypernetwork import replace_attentions_for_hypernetwork -from library.original_unet import UNet2DConditionModel - -# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う -TOKENIZER_PATH = "openai/clip-vit-large-patch14" -V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ - -HIGH_VRAM = False - -# checkpointファイル名 -EPOCH_STATE_NAME = "{}-{:06d}-state" -EPOCH_FILE_NAME = "{}-{:06d}" -EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}" -LAST_STATE_NAME = "{}-state" -DEFAULT_EPOCH_NAME = "epoch" -DEFAULT_LAST_OUTPUT_NAME = "last" - -DEFAULT_STEP_NAME = "at" -STEP_STATE_NAME = "{}-step{:08d}-state" -STEP_FILE_NAME = "{}-step{:08d}" -STEP_DIFFUSERS_DIR_NAME = "{}-step{:08d}" - -# region dataset - -IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] - -try: - import pillow_avif - - IMAGE_EXTENSIONS.extend([".avif", ".AVIF"]) -except: - pass - -# JPEG-XL on Linux -try: - from jxlpy import JXLImagePlugin - - IMAGE_EXTENSIONS.extend([".jxl", ".JXL"]) -except: - pass - -# JPEG-XL on Windows -try: - import pillow_jxl - - IMAGE_EXTENSIONS.extend([".jxl", ".JXL"]) -except: - pass - -IMAGE_TRANSFORMS = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] -) - -TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" - - -class ImageInfo: - def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: - self.image_key: str = image_key - self.num_repeats: int = num_repeats - self.caption: str = caption - self.is_reg: bool = is_reg - self.absolute_path: str = absolute_path - self.image_size: Tuple[int, int] = None - self.resized_size: Tuple[int, int] = None - self.bucket_reso: Tuple[int, int] = None - self.latents: torch.Tensor = None - self.latents_flipped: torch.Tensor = None - self.latents_npz: str = None - self.latents_original_size: Tuple[int, int] = None # original image size, not latents size - self.latents_crop_ltrb: Tuple[int, int] = None # crop left top right bottom in original pixel size, not latents size - self.cond_img_path: str = None - self.image: Optional[Image.Image] = None # optional, original PIL Image - # SDXL, optional - self.text_encoder_outputs_npz: Optional[str] = None - self.text_encoder_outputs1: Optional[torch.Tensor] = None - self.text_encoder_outputs2: Optional[torch.Tensor] = None - self.text_encoder_pool2: Optional[torch.Tensor] = None - - -class BucketManager: - def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None: - if max_size is not None: - if max_reso is not None: - assert max_size >= max_reso[0], "the max_size should be larger than the width of max_reso" - assert max_size >= max_reso[1], "the max_size should be larger than the height of max_reso" - if min_size is not None: - assert max_size >= min_size, "the max_size should be larger than the min_size" - - self.no_upscale = no_upscale - if max_reso is None: - self.max_reso = None - self.max_area = None - else: - self.max_reso = max_reso - self.max_area = max_reso[0] * max_reso[1] - self.min_size = min_size - self.max_size = max_size - self.reso_steps = reso_steps - - self.resos = [] - self.reso_to_id = {} - self.buckets = [] # 前処理時は (image_key, image, original size, crop left/top)、学習時は image_key - - def add_image(self, reso, image_or_info): - bucket_id = self.reso_to_id[reso] - self.buckets[bucket_id].append(image_or_info) - - def shuffle(self): - for 
bucket in self.buckets: - random.shuffle(bucket) - - def sort(self): - # 解像度順にソートする(表示時、メタデータ格納時の見栄えをよくするためだけ)。bucketsも入れ替えてreso_to_idも振り直す - sorted_resos = self.resos.copy() - sorted_resos.sort() - - sorted_buckets = [] - sorted_reso_to_id = {} - for i, reso in enumerate(sorted_resos): - bucket_id = self.reso_to_id[reso] - sorted_buckets.append(self.buckets[bucket_id]) - sorted_reso_to_id[reso] = i - - self.resos = sorted_resos - self.buckets = sorted_buckets - self.reso_to_id = sorted_reso_to_id - - def make_buckets(self): - resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps) - self.set_predefined_resos(resos) - - def set_predefined_resos(self, resos): - # 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく - self.predefined_resos = resos.copy() - self.predefined_resos_set = set(resos) - self.predefined_aspect_ratios = np.array([w / h for w, h in resos]) - - def add_if_new_reso(self, reso): - if reso not in self.reso_to_id: - bucket_id = len(self.resos) - self.reso_to_id[reso] = bucket_id - self.resos.append(reso) - self.buckets.append([]) - # logger.info(reso, bucket_id, len(self.buckets)) - - def round_to_steps(self, x): - x = int(x + 0.5) - return x - x % self.reso_steps - - def select_bucket(self, image_width, image_height): - aspect_ratio = image_width / image_height - if not self.no_upscale: - # 拡大および縮小を行う - # 同じaspect ratioがあるかもしれないので(fine tuningで、no_upscale=Trueで前処理した場合)、解像度が同じものを優先する - reso = (image_width, image_height) - if reso in self.predefined_resos_set: - pass - else: - ar_errors = self.predefined_aspect_ratios - aspect_ratio - predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの - reso = self.predefined_resos[predefined_bucket_id] - - ar_reso = reso[0] / reso[1] - if aspect_ratio > ar_reso: # 横が長い→縦を合わせる - scale = reso[1] / image_height - else: - scale = reso[0] / image_width - - resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5)) - # logger.info(f"use predef, {image_width}, {image_height}, {reso}, {resized_size}") - else: - # 縮小のみを行う - if image_width * image_height > self.max_area: - # 画像が大きすぎるのでアスペクト比を保ったまま縮小することを前提にbucketを決める - resized_width = math.sqrt(self.max_area * aspect_ratio) - resized_height = self.max_area / resized_width - assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal" - - # リサイズ後の短辺または長辺をreso_steps単位にする:aspect ratioの差が少ないほうを選ぶ - # 元のbucketingと同じロジック - b_width_rounded = self.round_to_steps(resized_width) - b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio) - ar_width_rounded = b_width_rounded / b_height_in_wr - - b_height_rounded = self.round_to_steps(resized_height) - b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio) - ar_height_rounded = b_width_in_hr / b_height_rounded - - # logger.info(b_width_rounded, b_height_in_wr, ar_width_rounded) - # logger.info(b_width_in_hr, b_height_rounded, ar_height_rounded) - - if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio): - resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + 0.5)) - else: - resized_size = (int(b_height_rounded * aspect_ratio + 0.5), b_height_rounded) - # logger.info(resized_size) - else: - resized_size = (image_width, image_height) # リサイズは不要 - - # 画像のサイズ未満をbucketのサイズとする(paddingせずにcroppingする) - bucket_width = resized_size[0] - resized_size[0] % self.reso_steps - bucket_height = resized_size[1] - resized_size[1] % self.reso_steps - # logger.info(f"use arbitrary {image_width}, 
{image_height}, {resized_size}, {bucket_width}, {bucket_height}") - - reso = (bucket_width, bucket_height) - - self.add_if_new_reso(reso) - - ar_error = (reso[0] / reso[1]) - aspect_ratio - return reso, resized_size, ar_error - - @staticmethod - def get_crop_ltrb(bucket_reso: Tuple[int, int], image_size: Tuple[int, int]): - # Stability AIの前処理に合わせてcrop left/topを計算する。crop rightはflipのaugmentationのために求める - # Calculate crop left/top according to the preprocessing of Stability AI. Crop right is calculated for flip augmentation. - - bucket_ar = bucket_reso[0] / bucket_reso[1] - image_ar = image_size[0] / image_size[1] - if bucket_ar > image_ar: - # bucketのほうが横長→縦を合わせる - resized_width = bucket_reso[1] * image_ar - resized_height = bucket_reso[1] - else: - resized_width = bucket_reso[0] - resized_height = bucket_reso[0] / image_ar - crop_left = (bucket_reso[0] - resized_width) // 2 - crop_top = (bucket_reso[1] - resized_height) // 2 - crop_right = crop_left + resized_width - crop_bottom = crop_top + resized_height - return crop_left, crop_top, crop_right, crop_bottom - - -class BucketBatchIndex(NamedTuple): - bucket_index: int - bucket_batch_size: int - batch_index: int - - -class AugHelper: - # albumentationsへの依存をなくしたがとりあえず同じinterfaceを持たせる - - def __init__(self): - pass - - def color_aug(self, image: np.ndarray): - # self.color_aug_method = albu.OneOf( - # [ - # albu.HueSaturationValue(8, 0, 0, p=0.5), - # albu.RandomGamma((95, 105), p=0.5), - # ], - # p=0.33, - # ) - hue_shift_limit = 8 - - # remove dependency to albumentations - if random.random() <= 0.33: - if random.random() > 0.5: - # hue shift - hsv_img = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) - hue_shift = random.uniform(-hue_shift_limit, hue_shift_limit) - if hue_shift < 0: - hue_shift = 180 + hue_shift - hsv_img[:, :, 0] = (hsv_img[:, :, 0] + hue_shift) % 180 - image = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR) - else: - # random gamma - gamma = random.uniform(0.95, 1.05) - image = np.clip(image**gamma, 0, 255).astype(np.uint8) - - return {"image": image} - - def get_augmentor(self, use_color_aug: bool): # -> Optional[Callable[[np.ndarray], Dict[str, np.ndarray]]]: - return self.color_aug if use_color_aug else None - - -class BaseSubset: - def __init__( - self, - image_dir: Optional[str], - num_repeats: int, - shuffle_caption: bool, - caption_separator: str, - keep_tokens: int, - keep_tokens_separator: str, - color_aug: bool, - flip_aug: bool, - face_crop_aug_range: Optional[Tuple[float, float]], - random_crop: bool, - caption_dropout_rate: float, - caption_dropout_every_n_epochs: int, - caption_tag_dropout_rate: float, - caption_prefix: Optional[str], - caption_suffix: Optional[str], - token_warmup_min: int, - token_warmup_step: Union[float, int], - ) -> None: - self.image_dir = image_dir - self.num_repeats = num_repeats - self.shuffle_caption = shuffle_caption - self.caption_separator = caption_separator - self.keep_tokens = keep_tokens - self.keep_tokens_separator = keep_tokens_separator - self.color_aug = color_aug - self.flip_aug = flip_aug - self.face_crop_aug_range = face_crop_aug_range - self.random_crop = random_crop - self.caption_dropout_rate = caption_dropout_rate - self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs - self.caption_tag_dropout_rate = caption_tag_dropout_rate - self.caption_prefix = caption_prefix - self.caption_suffix = caption_suffix - - self.token_warmup_min = token_warmup_min # step=0におけるタグの数 - self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる - - 
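The two Japanese comments above say that `token_warmup_min` is the number of caption tags used at step 0 and that the tag count reaches its maximum at step `token_warmup_step` (a value below 1 is interpreted as a fraction of `max_train_steps`). The schedule itself is applied in `process_caption()` further down; a standalone sketch of the same linear ramp (function name is illustrative):

import math

def warmed_up_tag_count(step: int, total_tags: int, token_warmup_min: int,
                        token_warmup_step: float, max_train_steps: int) -> int:
    if token_warmup_step < 1:
        token_warmup_step = math.floor(token_warmup_step * max_train_steps)
    if not token_warmup_step or step >= token_warmup_step:
        return total_tags
    ramped = math.floor(step * (total_tags - token_warmup_min) / token_warmup_step) + token_warmup_min
    return min(total_tags, ramped)

# warmed_up_tag_count(0,   40, 8, 0.1, 1000)  # -> 8
# warmed_up_tag_count(50,  40, 8, 0.1, 1000)  # -> 24
# warmed_up_tag_count(100, 40, 8, 0.1, 1000)  # -> 40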
self.img_count = 0 - - -class DreamBoothSubset(BaseSubset): - def __init__( - self, - image_dir: str, - is_reg: bool, - class_tokens: Optional[str], - caption_extension: str, - num_repeats, - shuffle_caption, - caption_separator: str, - keep_tokens, - keep_tokens_separator, - color_aug, - flip_aug, - face_crop_aug_range, - random_crop, - caption_dropout_rate, - caption_dropout_every_n_epochs, - caption_tag_dropout_rate, - caption_prefix, - caption_suffix, - token_warmup_min, - token_warmup_step, - ) -> None: - assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" - - super().__init__( - image_dir, - num_repeats, - shuffle_caption, - caption_separator, - keep_tokens, - keep_tokens_separator, - color_aug, - flip_aug, - face_crop_aug_range, - random_crop, - caption_dropout_rate, - caption_dropout_every_n_epochs, - caption_tag_dropout_rate, - caption_prefix, - caption_suffix, - token_warmup_min, - token_warmup_step, - ) - - self.is_reg = is_reg - self.class_tokens = class_tokens - self.caption_extension = caption_extension - if self.caption_extension and not self.caption_extension.startswith("."): - self.caption_extension = "." + self.caption_extension - - def __eq__(self, other) -> bool: - if not isinstance(other, DreamBoothSubset): - return NotImplemented - return self.image_dir == other.image_dir - - -class FineTuningSubset(BaseSubset): - def __init__( - self, - image_dir, - metadata_file: str, - num_repeats, - shuffle_caption, - caption_separator, - keep_tokens, - keep_tokens_separator, - color_aug, - flip_aug, - face_crop_aug_range, - random_crop, - caption_dropout_rate, - caption_dropout_every_n_epochs, - caption_tag_dropout_rate, - caption_prefix, - caption_suffix, - token_warmup_min, - token_warmup_step, - ) -> None: - assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" - - super().__init__( - image_dir, - num_repeats, - shuffle_caption, - caption_separator, - keep_tokens, - keep_tokens_separator, - color_aug, - flip_aug, - face_crop_aug_range, - random_crop, - caption_dropout_rate, - caption_dropout_every_n_epochs, - caption_tag_dropout_rate, - caption_prefix, - caption_suffix, - token_warmup_min, - token_warmup_step, - ) - - self.metadata_file = metadata_file - - def __eq__(self, other) -> bool: - if not isinstance(other, FineTuningSubset): - return NotImplemented - return self.metadata_file == other.metadata_file - - -class ControlNetSubset(BaseSubset): - def __init__( - self, - image_dir: str, - conditioning_data_dir: str, - caption_extension: str, - num_repeats, - shuffle_caption, - caption_separator, - keep_tokens, - keep_tokens_separator, - color_aug, - flip_aug, - face_crop_aug_range, - random_crop, - caption_dropout_rate, - caption_dropout_every_n_epochs, - caption_tag_dropout_rate, - caption_prefix, - caption_suffix, - token_warmup_min, - token_warmup_step, - ) -> None: - assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" - - super().__init__( - image_dir, - num_repeats, - shuffle_caption, - caption_separator, - keep_tokens, - keep_tokens_separator, - color_aug, - flip_aug, - face_crop_aug_range, - random_crop, - caption_dropout_rate, - caption_dropout_every_n_epochs, - caption_tag_dropout_rate, - caption_prefix, - caption_suffix, - token_warmup_min, - token_warmup_step, - ) - - self.conditioning_data_dir = conditioning_data_dir - self.caption_extension = caption_extension - if self.caption_extension and not self.caption_extension.startswith("."): - self.caption_extension = "." 
+ self.caption_extension - - def __eq__(self, other) -> bool: - if not isinstance(other, ControlNetSubset): - return NotImplemented - return self.image_dir == other.image_dir and self.conditioning_data_dir == other.conditioning_data_dir - - -class BaseDataset(torch.utils.data.Dataset): - def __init__( - self, - tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]], - max_token_length: int, - resolution: Optional[Tuple[int, int]], - network_multiplier: float, - debug_dataset: bool, - ) -> None: - super().__init__() - - self.tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer] - - self.max_token_length = max_token_length - # width/height is used when enable_bucket==False - self.width, self.height = (None, None) if resolution is None else resolution - self.network_multiplier = network_multiplier - self.debug_dataset = debug_dataset - - self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = [] - - self.token_padding_disabled = False - self.tag_frequency = {} - self.XTI_layers = None - self.token_strings = None - - self.enable_bucket = False - self.bucket_manager: BucketManager = None # not initialized - self.min_bucket_reso = None - self.max_bucket_reso = None - self.bucket_reso_steps = None - self.bucket_no_upscale = None - self.bucket_info = None # for metadata - - self.tokenizer_max_length = self.tokenizers[0].model_max_length if max_token_length is None else max_token_length + 2 - - self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ - - self.current_step: int = 0 - self.max_train_steps: int = 0 - self.seed: int = 0 - - # augmentation - self.aug_helper = AugHelper() - - self.image_transforms = IMAGE_TRANSFORMS - - self.image_data: Dict[str, ImageInfo] = {} - self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} - - self.replacements = {} - - # caching - self.caching_mode = None # None, 'latents', 'text' - - def set_seed(self, seed): - self.seed = seed - - def set_caching_mode(self, mode): - self.caching_mode = mode - - def set_current_epoch(self, epoch): - if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする - self.shuffle_buckets() - self.current_epoch = epoch - - def set_current_step(self, step): - self.current_step = step - - def set_max_train_steps(self, max_train_steps): - self.max_train_steps = max_train_steps - - def set_tag_frequency(self, dir_name, captions): - frequency_for_dir = self.tag_frequency.get(dir_name, {}) - self.tag_frequency[dir_name] = frequency_for_dir - for caption in captions: - for tag in caption.split(","): - tag = tag.strip() - if tag: - tag = tag.lower() - frequency = frequency_for_dir.get(tag, 0) - frequency_for_dir[tag] = frequency + 1 - - def disable_token_padding(self): - self.token_padding_disabled = True - - def enable_XTI(self, layers=None, token_strings=None): - self.XTI_layers = layers - self.token_strings = token_strings - - def add_replacement(self, str_from, str_to): - self.replacements[str_from] = str_to - - def process_caption(self, subset: BaseSubset, caption): - # caption に prefix/suffix を付ける - if subset.caption_prefix: - caption = subset.caption_prefix + " " + caption - if subset.caption_suffix: - caption = caption + " " + subset.caption_suffix - - # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い - is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate - is_drop_out = ( - is_drop_out - or subset.caption_dropout_every_n_epochs > 0 - and self.current_epoch % subset.caption_dropout_every_n_epochs == 0 - ) - - if is_drop_out: - 
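# Editor's note (added comment, not part of the original file): `is_drop_out` above
# combines two triggers -- a per-sample Bernoulli draw with probability
# caption_dropout_rate, and a periodic drop on every epoch where
# current_epoch % caption_dropout_every_n_epochs == 0. Python precedence parses
# `a or b and c` as `a or (b and c)`, which is the intended grouping here. When either
# trigger fires, the branch below replaces the caption with an empty string, so this
# sample is trained with an empty prompt.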
caption = "" - else: - if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: - fixed_tokens = [] - flex_tokens = [] - if ( - hasattr(subset, "keep_tokens_separator") - and subset.keep_tokens_separator - and subset.keep_tokens_separator in caption - ): - fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1) - fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()] - flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()] - else: - tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)] - flex_tokens = tokens[:] - if subset.keep_tokens > 0: - fixed_tokens = flex_tokens[: subset.keep_tokens] - flex_tokens = tokens[subset.keep_tokens :] - - if subset.token_warmup_step < 1: # 初回に上書きする - subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) - if subset.token_warmup_step and self.current_step < subset.token_warmup_step: - tokens_len = ( - math.floor( - (self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step)) - ) - + subset.token_warmup_min - ) - flex_tokens = flex_tokens[:tokens_len] - - def dropout_tags(tokens): - if subset.caption_tag_dropout_rate <= 0: - return tokens - l = [] - for token in tokens: - if random.random() >= subset.caption_tag_dropout_rate: - l.append(token) - return l - - if subset.shuffle_caption: - random.shuffle(flex_tokens) - - flex_tokens = dropout_tags(flex_tokens) - - caption = ", ".join(fixed_tokens + flex_tokens) - - # textual inversion対応 - for str_from, str_to in self.replacements.items(): - if str_from == "": - # replace all - if type(str_to) == list: - caption = random.choice(str_to) - else: - caption = str_to - else: - caption = caption.replace(str_from, str_to) - - return caption - - def get_input_ids(self, caption, tokenizer=None): - if tokenizer is None: - tokenizer = self.tokenizers[0] - - input_ids = tokenizer( - caption, padding="max_length", truncation=True, max_length=self.tokenizer_max_length, return_tensors="pt" - ).input_ids - - if self.tokenizer_max_length > tokenizer.model_max_length: - input_ids = input_ids.squeeze(0) - iids_list = [] - if tokenizer.pad_token_id == tokenizer.eos_token_id: - # v1 - # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する - # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に - for i in range( - 1, self.tokenizer_max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2 - ): # (1, 152, 75) - ids_chunk = ( - input_ids[0].unsqueeze(0), - input_ids[i : i + tokenizer.model_max_length - 2], - input_ids[-1].unsqueeze(0), - ) - ids_chunk = torch.cat(ids_chunk) - iids_list.append(ids_chunk) - else: - # v2 or SDXL - # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する - for i in range(1, self.tokenizer_max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): - ids_chunk = ( - input_ids[0].unsqueeze(0), # BOS - input_ids[i : i + tokenizer.model_max_length - 2], - input_ids[-1].unsqueeze(0), - ) # PAD or EOS - ids_chunk = torch.cat(ids_chunk) - - # 末尾が または の場合は、何もしなくてよい - # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) - if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id: - ids_chunk[-1] = tokenizer.eos_token_id - # 先頭が ... の場合は ... 
に変える - if ids_chunk[1] == tokenizer.pad_token_id: - ids_chunk[1] = tokenizer.eos_token_id - - iids_list.append(ids_chunk) - - input_ids = torch.stack(iids_list) # 3,77 - return input_ids - - def register_image(self, info: ImageInfo, subset: BaseSubset): - self.image_data[info.image_key] = info - self.image_to_subset[info.image_key] = subset - - def make_buckets(self): - """ - bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) - min_size and max_size are ignored when enable_bucket is False - """ - logger.info("loading image sizes.") - for info in tqdm(self.image_data.values()): - if info.image_size is None: - info.image_size = self.get_image_size(info.absolute_path) - - if self.enable_bucket: - logger.info("make buckets") - else: - logger.info("prepare dataset") - - # bucketを作成し、画像をbucketに振り分ける - if self.enable_bucket: - if self.bucket_manager is None: # fine tuningの場合でmetadataに定義がある場合は、すでに初期化済み - self.bucket_manager = BucketManager( - self.bucket_no_upscale, - (self.width, self.height), - self.min_bucket_reso, - self.max_bucket_reso, - self.bucket_reso_steps, - ) - if not self.bucket_no_upscale: - self.bucket_manager.make_buckets() - else: - logger.warning( - "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます" - ) - - img_ar_errors = [] - for image_info in self.image_data.values(): - image_width, image_height = image_info.image_size - image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket( - image_width, image_height - ) - - # logger.info(image_info.image_key, image_info.bucket_reso) - img_ar_errors.append(abs(ar_error)) - - self.bucket_manager.sort() - else: - self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None) - self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # ひとつの固定サイズbucketのみ - for image_info in self.image_data.values(): - image_width, image_height = image_info.image_size - image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height) - - for image_info in self.image_data.values(): - for _ in range(image_info.num_repeats): - self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key) - - # bucket情報を表示、格納する - if self.enable_bucket: - self.bucket_info = {"buckets": {}} - logger.info("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") - for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)): - count = len(bucket) - if count > 0: - self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)} - logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}") - - img_ar_errors = np.array(img_ar_errors) - mean_img_ar_error = np.mean(np.abs(img_ar_errors)) - self.bucket_info["mean_img_ar_error"] = mean_img_ar_error - logger.info(f"mean ar error (without repeats): {mean_img_ar_error}") - - # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる - self.buckets_indices: List(BucketBatchIndex) = [] - for bucket_index, bucket in enumerate(self.bucket_manager.buckets): - batch_count = int(math.ceil(len(bucket) / self.batch_size)) - for batch_index in range(batch_count): - self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index)) - - # ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す - #  学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる - # - # # 
bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは - # # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう - # # そのためバッチサイズを画像種類までに制限する - # # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない? - # # TO DO 正則化画像をepochまたがりで利用する仕組み - # num_of_image_types = len(set(bucket)) - # bucket_batch_size = min(self.batch_size, num_of_image_types) - # batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) - # # logger.info(bucket_index, num_of_image_types, bucket_batch_size, batch_count) - # for batch_index in range(batch_count): - # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) - # ↑ここまで - - self.shuffle_buckets() - self._length = len(self.buckets_indices) - - def shuffle_buckets(self): - # set random seed for this epoch - random.seed(self.seed + self.current_epoch) - - random.shuffle(self.buckets_indices) - self.bucket_manager.shuffle() - - def verify_bucket_reso_steps(self, min_steps: int): - assert self.bucket_reso_steps is None or self.bucket_reso_steps % min_steps == 0, ( - f"bucket_reso_steps is {self.bucket_reso_steps}. it must be divisible by {min_steps}.\n" - + f"bucket_reso_stepsが{self.bucket_reso_steps}です。{min_steps}で割り切れる必要があります" - ) - - def is_latent_cacheable(self): - return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) - - def is_text_encoder_output_cacheable(self): - return all( - [ - not ( - subset.caption_dropout_rate > 0 - or subset.shuffle_caption - or subset.token_warmup_step > 0 - or subset.caption_tag_dropout_rate > 0 - ) - for subset in self.subsets - ] - ) - - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): - # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと - logger.info("caching latents.") - - image_infos = list(self.image_data.values()) - - # sort by resolution - image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) - - # split by resolution - batches = [] - batch = [] - logger.info("checking cache validity...") - for info in tqdm(image_infos): - subset = self.image_to_subset[info.image_key] - - if info.latents_npz is not None: # fine tuning dataset - continue - - # check disk cache exists and size of latents - if cache_to_disk: - info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz" - if not is_main_process: # store to info only - continue - - cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug) - - if cache_available: # do not add to batch - continue - - # if last member of batch has different resolution, flush the batch - if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso: - batches.append(batch) - batch = [] - - batch.append(info) - - # if number of data in batch is enough, flush the batch - if len(batch) >= vae_batch_size: - batches.append(batch) - batch = [] - - if len(batch) > 0: - batches.append(batch) - - if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only - return - - # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded - logger.info("caching latents...") - for batch in tqdm(batches, smoothing=1, total=len(batches)): - cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop) - - # weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる - # SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する - # SD1/2に対応するにはv2のフラグを持つ必要があるので後回し - def cache_text_encoder_outputs( - 
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True - ): - assert len(tokenizers) == 2, "only support SDXL" - - # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する - # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと - logger.info("caching text encoder outputs.") - image_infos = list(self.image_data.values()) - - logger.info("checking cache existence...") - image_infos_to_cache = [] - for info in tqdm(image_infos): - # subset = self.image_to_subset[info.image_key] - if cache_to_disk: - te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX - info.text_encoder_outputs_npz = te_out_npz - - if not is_main_process: # store to info only - continue - - if os.path.exists(te_out_npz): - continue - - image_infos_to_cache.append(info) - - if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only - return - - # prepare tokenizers and text encoders - for text_encoder in text_encoders: - text_encoder.to(device) - if weight_dtype is not None: - text_encoder.to(dtype=weight_dtype) - - # create batch - batch = [] - batches = [] - for info in image_infos_to_cache: - input_ids1 = self.get_input_ids(info.caption, tokenizers[0]) - input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) - batch.append((info, input_ids1, input_ids2)) - - if len(batch) >= self.batch_size: - batches.append(batch) - batch = [] - - if len(batch) > 0: - batches.append(batch) - - # iterate batches: call text encoder and cache outputs for memory or disk - logger.info("caching text encoder outputs...") - for batch in tqdm(batches): - infos, input_ids1, input_ids2 = zip(*batch) - input_ids1 = torch.stack(input_ids1, dim=0) - input_ids2 = torch.stack(input_ids2, dim=0) - cache_batch_text_encoder_outputs( - infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype - ) - - def get_image_size(self, image_path): - image = Image.open(image_path) - return image.size - - def load_image_with_face_info(self, subset: BaseSubset, image_path: str): - img = load_image(image_path) - - face_cx = face_cy = face_w = face_h = 0 - if subset.face_crop_aug_range is not None: - tokens = os.path.splitext(os.path.basename(image_path))[0].split("_") - if len(tokens) >= 5: - face_cx = int(tokens[-4]) - face_cy = int(tokens[-3]) - face_w = int(tokens[-2]) - face_h = int(tokens[-1]) - - return img, face_cx, face_cy, face_w, face_h - - # いい感じに切り出す - def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h): - height, width = image.shape[0:2] - if height == self.height and width == self.width: - return image - - # 画像サイズはsizeより大きいのでリサイズする - face_size = max(face_w, face_h) - size = min(self.height, self.width) # 短いほう - min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率) - min_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ - max_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ - if min_scale >= max_scale: # range指定がmin==max - scale = min_scale - else: - scale = random.uniform(min_scale, max_scale) - - nh = int(height * scale + 0.5) - nw = int(width * scale + 0.5) - assert nh >= self.height and nw >= self.width, f"internal error. 
small scale {scale}, {width}*{height}" - image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) - face_cx = int(face_cx * scale + 0.5) - face_cy = int(face_cy * scale + 0.5) - height, width = nh, nw - - # 顔を中心として448*640とかへ切り出す - for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))): - p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置 - - if subset.random_crop: - # 背景も含めるために顔を中心に置く確率を高めつつずらす - range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう - p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数 - else: - # range指定があるときのみ、すこしだけランダムに(わりと適当) - if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]: - if face_size > size // 10 and face_size >= 40: - p1 = p1 + random.randint(-face_size // 20, +face_size // 20) - - p1 = max(0, min(p1, length - target_size)) - - if axis == 0: - image = image[p1 : p1 + target_size, :] - else: - image = image[:, p1 : p1 + target_size] - - return image - - def __len__(self): - return self._length - - def __getitem__(self, index): - bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index] - bucket_batch_size = self.buckets_indices[index].bucket_batch_size - image_index = self.buckets_indices[index].batch_index * bucket_batch_size - - if self.caching_mode is not None: # return batch for latents/text encoder outputs caching - return self.get_item_for_caching(bucket, bucket_batch_size, image_index) - - loss_weights = [] - captions = [] - input_ids_list = [] - input_ids2_list = [] - latents_list = [] - images = [] - original_sizes_hw = [] - crop_top_lefts = [] - target_sizes_hw = [] - flippeds = [] # 変数名が微妙 - text_encoder_outputs1_list = [] - text_encoder_outputs2_list = [] - text_encoder_pool2_list = [] - - for image_key in bucket[image_index : image_index + bucket_batch_size]: - image_info = self.image_data[image_key] - subset = self.image_to_subset[image_key] - loss_weights.append( - self.prior_loss_weight if image_info.is_reg else 1.0 - ) # in case of fine tuning, is_reg is always False - - flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance - - # image/latentsを処理する - if image_info.latents is not None: # cache_latents=Trueの場合 - original_size = image_info.latents_original_size - crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped - if not flipped: - latents = image_info.latents - else: - latents = image_info.latents_flipped - - image = None - elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_ltrb, flipped_latents = load_latents_from_disk(image_info.latents_npz) - if flipped: - latents = flipped_latents - del flipped_latents - latents = torch.FloatTensor(latents) - - image = None - else: - # 画像を読み込み、必要ならcropする - img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path) - im_h, im_w = img.shape[0:2] - - if self.enable_bucket: - img, original_size, crop_ltrb = trim_and_resize_if_required( - subset.random_crop, img, image_info.bucket_reso, image_info.resized_size - ) - else: - if face_cx > 0: # 顔位置情報あり - img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h) - elif im_h > self.height or im_w > self.width: - assert ( - subset.random_crop - ), f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}" - if im_h > 
self.height: - p = random.randint(0, im_h - self.height) - img = img[p : p + self.height] - if im_w > self.width: - p = random.randint(0, im_w - self.width) - img = img[:, p : p + self.width] - - im_h, im_w = img.shape[0:2] - assert ( - im_h == self.height and im_w == self.width - ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" - - original_size = [im_w, im_h] - crop_ltrb = (0, 0, 0, 0) - - # augmentation - aug = self.aug_helper.get_augmentor(subset.color_aug) - if aug is not None: - img = aug(image=img)["image"] - - if flipped: - img = img[:, ::-1, :].copy() # copy to avoid negative stride problem - - latents = None - image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる - - images.append(image) - latents_list.append(latents) - - target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8) - - if not flipped: - crop_left_top = (crop_ltrb[0], crop_ltrb[1]) - else: - # crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image - crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1]) - - original_sizes_hw.append((int(original_size[1]), int(original_size[0]))) - crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0]))) - target_sizes_hw.append((int(target_size[1]), int(target_size[0]))) - flippeds.append(flipped) - - # captionとtext encoder outputを処理する - caption = image_info.caption # default - if image_info.text_encoder_outputs1 is not None: - text_encoder_outputs1_list.append(image_info.text_encoder_outputs1) - text_encoder_outputs2_list.append(image_info.text_encoder_outputs2) - text_encoder_pool2_list.append(image_info.text_encoder_pool2) - captions.append(caption) - elif image_info.text_encoder_outputs_npz is not None: - text_encoder_outputs1, text_encoder_outputs2, text_encoder_pool2 = load_text_encoder_outputs_from_disk( - image_info.text_encoder_outputs_npz - ) - text_encoder_outputs1_list.append(text_encoder_outputs1) - text_encoder_outputs2_list.append(text_encoder_outputs2) - text_encoder_pool2_list.append(text_encoder_pool2) - captions.append(caption) - else: - caption = self.process_caption(subset, image_info.caption) - if self.XTI_layers: - caption_layer = [] - for layer in self.XTI_layers: - token_strings_from = " ".join(self.token_strings) - token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) - caption_ = caption.replace(token_strings_from, token_strings_to) - caption_layer.append(caption_) - captions.append(caption_layer) - else: - captions.append(caption) - - if not self.token_padding_disabled: # this option might be omitted in future - if self.XTI_layers: - token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) - else: - token_caption = self.get_input_ids(caption, self.tokenizers[0]) - input_ids_list.append(token_caption) - - if len(self.tokenizers) > 1: - if self.XTI_layers: - token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1]) - else: - token_caption2 = self.get_input_ids(caption, self.tokenizers[1]) - input_ids2_list.append(token_caption2) - - example = {} - example["loss_weights"] = torch.FloatTensor(loss_weights) - - if len(text_encoder_outputs1_list) == 0: - if self.token_padding_disabled: - # padding=True means pad in the batch - example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids - if len(self.tokenizers) > 1: - example["input_ids2"] = self.tokenizer[1]( - captions, padding=True, truncation=True, return_tensors="pt" - 
).input_ids - else: - example["input_ids2"] = None - else: - example["input_ids"] = torch.stack(input_ids_list) - example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None - example["text_encoder_outputs1_list"] = None - example["text_encoder_outputs2_list"] = None - example["text_encoder_pool2_list"] = None - else: - example["input_ids"] = None - example["input_ids2"] = None - # # for assertion - # example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions]) - # example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions]) - example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list) - example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list) - example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list) - - if images[0] is not None: - images = torch.stack(images) - images = images.to(memory_format=torch.contiguous_format).float() - else: - images = None - example["images"] = images - - example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None - example["captions"] = captions - - example["original_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in original_sizes_hw]) - example["crop_top_lefts"] = torch.stack([torch.LongTensor(x) for x in crop_top_lefts]) - example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw]) - example["flippeds"] = flippeds - - example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions)) - - if self.debug_dataset: - example["image_keys"] = bucket[image_index : image_index + self.batch_size] - return example - - def get_item_for_caching(self, bucket, bucket_batch_size, image_index): - captions = [] - images = [] - input_ids1_list = [] - input_ids2_list = [] - absolute_paths = [] - resized_sizes = [] - bucket_reso = None - flip_aug = None - random_crop = None - - for image_key in bucket[image_index : image_index + bucket_batch_size]: - image_info = self.image_data[image_key] - subset = self.image_to_subset[image_key] - - if flip_aug is None: - flip_aug = subset.flip_aug - random_crop = subset.random_crop - bucket_reso = image_info.bucket_reso - else: - assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch" - assert random_crop == subset.random_crop, "random_crop must be same in a batch" - assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch" - - caption = image_info.caption # TODO cache some patterns of dropping, shuffling, etc. 
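# Editor's illustrative sketch (not part of the original file): __getitem__ above
# mirrors the crop offsets when an image is flipped -- crop_ltrb stores
# (left, top, right, bottom) in the unflipped frame, so after a horizontal flip the
# new left edge is target_width - right. Assuming that convention:

def crop_top_left_for_flip(crop_ltrb, target_size_wh, flipped: bool):
    """Return (top, left) as appended to crop_top_lefts for SDXL conditioning (sketch only)."""
    left, top, right, _bottom = crop_ltrb
    if flipped:
        left = target_size_wh[0] - right  # mirror the horizontal offset
    return (int(top), int(left))

# e.g. with hypothetical values crop_ltrb=(32, 0, 992, 768) and a 1024-wide target:
#   not flipped -> (0, 32); flipped -> (0, 1024 - 992) = (0, 32)  (a centered crop is symmetric)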
- - if self.caching_mode == "latents": - image = load_image(image_info.absolute_path) - else: - image = None - - if self.caching_mode == "text": - input_ids1 = self.get_input_ids(caption, self.tokenizers[0]) - input_ids2 = self.get_input_ids(caption, self.tokenizers[1]) - else: - input_ids1 = None - input_ids2 = None - - captions.append(caption) - images.append(image) - input_ids1_list.append(input_ids1) - input_ids2_list.append(input_ids2) - absolute_paths.append(image_info.absolute_path) - resized_sizes.append(image_info.resized_size) - - example = {} - - if images[0] is None: - images = None - example["images"] = images - - example["captions"] = captions - example["input_ids1_list"] = input_ids1_list - example["input_ids2_list"] = input_ids2_list - example["absolute_paths"] = absolute_paths - example["resized_sizes"] = resized_sizes - example["flip_aug"] = flip_aug - example["random_crop"] = random_crop - example["bucket_reso"] = bucket_reso - return example - - -class DreamBoothDataset(BaseDataset): - def __init__( - self, - subsets: Sequence[DreamBoothSubset], - batch_size: int, - tokenizer, - max_token_length, - resolution, - network_multiplier: float, - enable_bucket: bool, - min_bucket_reso: int, - max_bucket_reso: int, - bucket_reso_steps: int, - bucket_no_upscale: bool, - prior_loss_weight: float, - debug_dataset: bool, - ) -> None: - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) - - assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" - - self.batch_size = batch_size - self.size = min(self.width, self.height) # 短いほう - self.prior_loss_weight = prior_loss_weight - self.latents_cache = None - - self.enable_bucket = enable_bucket - if self.enable_bucket: - assert ( - min(resolution) >= min_bucket_reso - ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" - assert ( - max(resolution) <= max_bucket_reso - ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" - self.min_bucket_reso = min_bucket_reso - self.max_bucket_reso = max_bucket_reso - self.bucket_reso_steps = bucket_reso_steps - self.bucket_no_upscale = bucket_no_upscale - else: - self.min_bucket_reso = None - self.max_bucket_reso = None - self.bucket_reso_steps = None # この情報は使われない - self.bucket_no_upscale = False - - def read_caption(img_path, caption_extension): - # captionの候補ファイル名を作る - base_name = os.path.splitext(img_path)[0] - base_name_face_det = base_name - tokens = base_name.split("_") - if len(tokens) >= 5: - base_name_face_det = "_".join(tokens[:-4]) - cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension] - - caption = None - for cap_path in cap_paths: - if os.path.isfile(cap_path): - with open(cap_path, "rt", encoding="utf-8") as f: - try: - lines = f.readlines() - except UnicodeDecodeError as e: - logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") - raise e - assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" - caption = lines[0].strip() - break - return caption - - def load_dreambooth_dir(subset: DreamBoothSubset): - if not os.path.isdir(subset.image_dir): - logger.warning(f"not directory: {subset.image_dir}") - return [], [] - - img_paths = glob_images(subset.image_dir, "*") - logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") - - # 
画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う - captions = [] - missing_captions = [] - for img_path in img_paths: - cap_for_img = read_caption(img_path, subset.caption_extension) - if cap_for_img is None and subset.class_tokens is None: - logger.warning( - f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}" - ) - captions.append("") - missing_captions.append(img_path) - else: - if cap_for_img is None: - captions.append(subset.class_tokens) - missing_captions.append(img_path) - else: - captions.append(cap_for_img) - - self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 - - if missing_captions: - number_of_missing_captions = len(missing_captions) - number_of_missing_captions_to_show = 5 - remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show - - logger.warning( - f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。" - ) - for i, missing_caption in enumerate(missing_captions): - if i >= number_of_missing_captions_to_show: - logger.warning(missing_caption + f"... and {remaining_missing_captions} more") - break - logger.warning(missing_caption) - return img_paths, captions - - logger.info("prepare images.") - num_train_images = 0 - num_reg_images = 0 - reg_infos: List[ImageInfo] = [] - for subset in subsets: - if subset.num_repeats < 1: - logger.warning( - f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" - ) - continue - - if subset in self.subsets: - logger.warning( - f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" - ) - continue - - img_paths, captions = load_dreambooth_dir(subset) - if len(img_paths) < 1: - logger.warning( - f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します" - ) - continue - - if subset.is_reg: - num_reg_images += subset.num_repeats * len(img_paths) - else: - num_train_images += subset.num_repeats * len(img_paths) - - for img_path, caption in zip(img_paths, captions): - info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) - if subset.is_reg: - reg_infos.append(info) - else: - self.register_image(info, subset) - - subset.img_count = len(img_paths) - self.subsets.append(subset) - - logger.info(f"{num_train_images} train images with repeating.") - self.num_train_images = num_train_images - - logger.info(f"{num_reg_images} reg images.") - if num_train_images < num_reg_images: - logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") - - if num_reg_images == 0: - logger.warning("no regularization images / 正則化画像が見つかりませんでした") - else: - # num_repeatsを計算する:どうせ大した数ではないのでループで処理する - n = 0 - first_loop = True - while n < num_train_images: - for info in reg_infos: - if first_loop: - self.register_image(info, subset) - n += info.num_repeats - else: - info.num_repeats += 1 # rewrite registered info - n += 1 - if n >= num_train_images: - break - first_loop = False - - self.num_reg_images = num_reg_images - - -class FineTuningDataset(BaseDataset): - def __init__( - self, - subsets: Sequence[FineTuningSubset], - batch_size: int, - tokenizer, - max_token_length, - 
resolution, - network_multiplier: float, - enable_bucket: bool, - min_bucket_reso: int, - max_bucket_reso: int, - bucket_reso_steps: int, - bucket_no_upscale: bool, - debug_dataset: bool, - ) -> None: - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) - - self.batch_size = batch_size - - self.num_train_images = 0 - self.num_reg_images = 0 - - for subset in subsets: - if subset.num_repeats < 1: - logger.warning( - f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" - ) - continue - - if subset in self.subsets: - logger.warning( - f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" - ) - continue - - # メタデータを読み込む - if os.path.exists(subset.metadata_file): - logger.info(f"loading existing metadata: {subset.metadata_file}") - with open(subset.metadata_file, "rt", encoding="utf-8") as f: - metadata = json.load(f) - else: - raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}") - - if len(metadata) < 1: - logger.warning( - f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します" - ) - continue - - tags_list = [] - for image_key, img_md in metadata.items(): - # path情報を作る - abs_path = None - - # まず画像を優先して探す - if os.path.exists(image_key): - abs_path = image_key - else: - # わりといい加減だがいい方法が思いつかん - paths = glob_images(subset.image_dir, image_key) - if len(paths) > 0: - abs_path = paths[0] - - # なければnpzを探す - if abs_path is None: - if os.path.exists(os.path.splitext(image_key)[0] + ".npz"): - abs_path = os.path.splitext(image_key)[0] + ".npz" - else: - npz_path = os.path.join(subset.image_dir, image_key + ".npz") - if os.path.exists(npz_path): - abs_path = npz_path - - assert abs_path is not None, f"no image / 画像がありません: {image_key}" - - caption = img_md.get("caption") - tags = img_md.get("tags") - if caption is None: - caption = tags - elif tags is not None and len(tags) > 0: - caption = caption + ", " + tags - tags_list.append(tags) - - if caption is None: - caption = "" - - image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path) - image_info.image_size = img_md.get("train_resolution") - - if not subset.color_aug and not subset.random_crop: - # if npz exists, use them - image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key) - - self.register_image(image_info, subset) - - self.num_train_images += len(metadata) * subset.num_repeats - - # TODO do not record tag freq when no tag - self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list) - subset.img_count = len(metadata) - self.subsets.append(subset) - - # check existence of all npz files - use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets]) - if use_npz_latents: - flip_aug_in_subset = False - npz_any = False - npz_all = True - - for image_info in self.image_data.values(): - subset = self.image_to_subset[image_info.image_key] - - has_npz = image_info.latents_npz is not None - npz_any = npz_any or has_npz - - if subset.flip_aug: - has_npz = has_npz and image_info.latents_npz_flipped is not None - flip_aug_in_subset = True - npz_all = npz_all and has_npz - - if npz_any and not npz_all: - break - - if not npz_any: - use_npz_latents = False - logger.warning(f"npz file does not exist. 
ignore npz files / npzファイルが見つからないためnpzファイルを無視します") - elif not npz_all: - use_npz_latents = False - logger.warning( - f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します" - ) - if flip_aug_in_subset: - logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") - # else: - # logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません") - - # check min/max bucket size - sizes = set() - resos = set() - for image_info in self.image_data.values(): - if image_info.image_size is None: - sizes = None # not calculated - break - sizes.add(image_info.image_size[0]) - sizes.add(image_info.image_size[1]) - resos.add(tuple(image_info.image_size)) - - if sizes is None: - if use_npz_latents: - use_npz_latents = False - logger.warning( - f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します" - ) - - assert ( - resolution is not None - ), "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください" - - self.enable_bucket = enable_bucket - if self.enable_bucket: - self.min_bucket_reso = min_bucket_reso - self.max_bucket_reso = max_bucket_reso - self.bucket_reso_steps = bucket_reso_steps - self.bucket_no_upscale = bucket_no_upscale - else: - if not enable_bucket: - logger.info("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします") - logger.info("using bucket info in metadata / メタデータ内のbucket情報を使います") - self.enable_bucket = True - - assert ( - not bucket_no_upscale - ), "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません" - - # bucket情報を初期化しておく、make_bucketsで再作成しない - self.bucket_manager = BucketManager(False, None, None, None, None) - self.bucket_manager.set_predefined_resos(resos) - - # npz情報をきれいにしておく - if not use_npz_latents: - for image_info in self.image_data.values(): - image_info.latents_npz = image_info.latents_npz_flipped = None - - def image_key_to_npz_file(self, subset: FineTuningSubset, image_key): - base_name = os.path.splitext(image_key)[0] - npz_file_norm = base_name + ".npz" - - if os.path.exists(npz_file_norm): - # image_key is full path - npz_file_flip = base_name + "_flip.npz" - if not os.path.exists(npz_file_flip): - npz_file_flip = None - return npz_file_norm, npz_file_flip - - # if not full path, check image_dir. 
if image_dir is None, return None - if subset.image_dir is None: - return None, None - - # image_key is relative path - npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz") - npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz") - - if not os.path.exists(npz_file_norm): - npz_file_norm = None - npz_file_flip = None - elif not os.path.exists(npz_file_flip): - npz_file_flip = None - - return npz_file_norm, npz_file_flip - - -class ControlNetDataset(BaseDataset): - def __init__( - self, - subsets: Sequence[ControlNetSubset], - batch_size: int, - tokenizer, - max_token_length, - resolution, - network_multiplier: float, - enable_bucket: bool, - min_bucket_reso: int, - max_bucket_reso: int, - bucket_reso_steps: int, - bucket_no_upscale: bool, - debug_dataset: float, - ) -> None: - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) - - db_subsets = [] - for subset in subsets: - db_subset = DreamBoothSubset( - subset.image_dir, - False, - None, - subset.caption_extension, - subset.num_repeats, - subset.shuffle_caption, - subset.caption_separator, - subset.keep_tokens, - subset.keep_tokens_separator, - subset.color_aug, - subset.flip_aug, - subset.face_crop_aug_range, - subset.random_crop, - subset.caption_dropout_rate, - subset.caption_dropout_every_n_epochs, - subset.caption_tag_dropout_rate, - subset.caption_prefix, - subset.caption_suffix, - subset.token_warmup_min, - subset.token_warmup_step, - ) - db_subsets.append(db_subset) - - self.dreambooth_dataset_delegate = DreamBoothDataset( - db_subsets, - batch_size, - tokenizer, - max_token_length, - resolution, - network_multiplier, - enable_bucket, - min_bucket_reso, - max_bucket_reso, - bucket_reso_steps, - bucket_no_upscale, - 1.0, - debug_dataset, - ) - - # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) - self.image_data = self.dreambooth_dataset_delegate.image_data - self.batch_size = batch_size - self.num_train_images = self.dreambooth_dataset_delegate.num_train_images - self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images - - # assert all conditioning data exists - missing_imgs = [] - cond_imgs_with_img = set() - for image_key, info in self.dreambooth_dataset_delegate.image_data.items(): - db_subset = self.dreambooth_dataset_delegate.image_to_subset[image_key] - subset = None - for s in subsets: - if s.image_dir == db_subset.image_dir: - subset = s - break - assert subset is not None, "internal error: subset not found" - - if not os.path.isdir(subset.conditioning_data_dir): - logger.warning(f"not directory: {subset.conditioning_data_dir}") - continue - - img_basename = os.path.basename(info.absolute_path) - ctrl_img_path = os.path.join(subset.conditioning_data_dir, img_basename) - if not os.path.exists(ctrl_img_path): - missing_imgs.append(img_basename) - - info.cond_img_path = ctrl_img_path - cond_imgs_with_img.add(ctrl_img_path) - - extra_imgs = [] - for subset in subsets: - conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*") - extra_imgs.extend( - [cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img] - ) - - assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" - assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" - - self.conditioning_image_transforms = IMAGE_TRANSFORMS - - def make_buckets(self): - self.dreambooth_dataset_delegate.make_buckets() - self.bucket_manager = 
self.dreambooth_dataset_delegate.bucket_manager - self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices - - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): - return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) - - def __len__(self): - return self.dreambooth_dataset_delegate.__len__() - - def __getitem__(self, index): - example = self.dreambooth_dataset_delegate[index] - - bucket = self.dreambooth_dataset_delegate.bucket_manager.buckets[ - self.dreambooth_dataset_delegate.buckets_indices[index].bucket_index - ] - bucket_batch_size = self.dreambooth_dataset_delegate.buckets_indices[index].bucket_batch_size - image_index = self.dreambooth_dataset_delegate.buckets_indices[index].batch_index * bucket_batch_size - - conditioning_images = [] - - for i, image_key in enumerate(bucket[image_index : image_index + bucket_batch_size]): - image_info = self.dreambooth_dataset_delegate.image_data[image_key] - - target_size_hw = example["target_sizes_hw"][i] - original_size_hw = example["original_sizes_hw"][i] - crop_top_left = example["crop_top_lefts"][i] - flipped = example["flippeds"][i] - cond_img = load_image(image_info.cond_img_path) - - if self.dreambooth_dataset_delegate.enable_bucket: - assert ( - cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1] - ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}" - cond_img = cv2.resize( - cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA - ) # INTER_AREAでやりたいのでcv2でリサイズ - - # TODO support random crop - # 現在サポートしているcropはrandomではなく中央のみ - h, w = target_size_hw - ct = (cond_img.shape[0] - h) // 2 - cl = (cond_img.shape[1] - w) // 2 - cond_img = cond_img[ct : ct + h, cl : cl + w] - else: - # assert ( - # cond_img.shape[0] == self.height and cond_img.shape[1] == self.width - # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" - # resize to target - if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: - cond_img = cv2.resize( - cond_img, (int(target_size_hw[1]), int(target_size_hw[0])), interpolation=cv2.INTER_LANCZOS4 - ) - - if flipped: - cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride - - cond_img = self.conditioning_image_transforms(cond_img) - conditioning_images.append(cond_img) - - example["conditioning_images"] = torch.stack(conditioning_images).to(memory_format=torch.contiguous_format).float() - - return example - - -# behave as Dataset mock -class DatasetGroup(torch.utils.data.ConcatDataset): - def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]): - self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]] - - super().__init__(datasets) - - self.image_data = {} - self.num_train_images = 0 - self.num_reg_images = 0 - - # simply concat together - # TODO: handling image_data key duplication among dataset - # In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset. 
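# Editor's illustrative sketch (not part of the original file): ControlNetDataset.__getitem__
# above aligns each conditioning image with its training image by resizing to the bucket's
# resized_size and then taking a *center* crop of the target height/width (random crop is
# marked TODO). Assuming a numpy HxWxC array, that crop reduces to:

import numpy as np

def center_crop_hw(img: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
    """Center-crop an HxWxC image to (target_h, target_w) (sketch only)."""
    ct = (img.shape[0] - target_h) // 2  # top offset
    cl = (img.shape[1] - target_w) // 2  # left offset
    return img[ct : ct + target_h, cl : cl + target_w]

# e.g. a 768x1024 conditioning image cropped to a 704x960 bucket keeps the central
# region starting at offsets (32, 32).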
- for dataset in datasets: - self.image_data.update(dataset.image_data) - self.num_train_images += dataset.num_train_images - self.num_reg_images += dataset.num_reg_images - - def add_replacement(self, str_from, str_to): - for dataset in self.datasets: - dataset.add_replacement(str_from, str_to) - - # def make_buckets(self): - # for dataset in self.datasets: - # dataset.make_buckets() - - def enable_XTI(self, *args, **kwargs): - for dataset in self.datasets: - dataset.enable_XTI(*args, **kwargs) - - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): - for i, dataset in enumerate(self.datasets): - logger.info(f"[Dataset {i}]") - dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) - - def cache_text_encoder_outputs( - self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True - ): - for i, dataset in enumerate(self.datasets): - logger.info(f"[Dataset {i}]") - dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process) - - def set_caching_mode(self, caching_mode): - for dataset in self.datasets: - dataset.set_caching_mode(caching_mode) - - def verify_bucket_reso_steps(self, min_steps: int): - for dataset in self.datasets: - dataset.verify_bucket_reso_steps(min_steps) - - def is_latent_cacheable(self) -> bool: - return all([dataset.is_latent_cacheable() for dataset in self.datasets]) - - def is_text_encoder_output_cacheable(self) -> bool: - return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets]) - - def set_current_epoch(self, epoch): - for dataset in self.datasets: - dataset.set_current_epoch(epoch) - - def set_current_step(self, step): - for dataset in self.datasets: - dataset.set_current_step(step) - - def set_max_train_steps(self, max_train_steps): - for dataset in self.datasets: - dataset.set_max_train_steps(max_train_steps) - - def disable_token_padding(self): - for dataset in self.datasets: - dataset.disable_token_padding() - - -def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): - expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意 - - if not os.path.exists(npz_path): - return False - - npz = np.load(npz_path) - if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver? - return False - if npz["latents"].shape[1:3] != expected_latents_size: - return False - - if flip_aug: - if "latents_flipped" not in npz: - return False - if npz["latents_flipped"].shape[1:3] != expected_latents_size: - return False - - return True - - -# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) -def load_latents_from_disk( - npz_path, -) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]: - npz = np.load(npz_path) - if "latents" not in npz: - raise ValueError(f"error: npz is old format. 
please re-generate {npz_path}") - - latents = npz["latents"] - original_size = npz["original_size"].tolist() - crop_ltrb = npz["crop_ltrb"].tolist() - flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None - return latents, original_size, crop_ltrb, flipped_latents - - -def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None): - kwargs = {} - if flipped_latents_tensor is not None: - kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() - np.savez( - npz_path, - latents=latents_tensor.float().cpu().numpy(), - original_size=np.array(original_size), - crop_ltrb=np.array(crop_ltrb), - **kwargs, - ) - - -def debug_dataset(train_dataset, show_input_ids=False): - logger.info(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") - logger.info( - "`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します" - ) - - epoch = 1 - while True: - logger.info(f"") - logger.info(f"epoch: {epoch}") - - steps = (epoch - 1) * len(train_dataset) + 1 - indices = list(range(len(train_dataset))) - random.shuffle(indices) - - k = 0 - for i, idx in enumerate(indices): - train_dataset.set_current_epoch(epoch) - train_dataset.set_current_step(steps) - logger.info(f"steps: {steps} ({i + 1}/{len(train_dataset)})") - - example = train_dataset[idx] - if example["latents"] is not None: - logger.info(f"sample has latents from npz file: {example['latents'].size()}") - for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate( - zip( - example["image_keys"], - example["captions"], - example["loss_weights"], - example["input_ids"], - example["original_sizes_hw"], - example["crop_top_lefts"], - example["target_sizes_hw"], - example["flippeds"], - ) - ): - logger.info( - f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}' - ) - if "network_multipliers" in example: - print(f"network multiplier: {example['network_multipliers'][j]}") - - if show_input_ids: - logger.info(f"input ids: {iid}") - if "input_ids2" in example: - logger.info(f"input ids2: {example['input_ids2'][j]}") - if example["images"] is not None: - im = example["images"][j] - logger.info(f"image size: {im.size()}") - im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) - im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c - im = im[:, :, ::-1] # RGB -> BGR (OpenCV) - - if "conditioning_images" in example: - cond_img = example["conditioning_images"][j] - logger.info(f"conditioning image size: {cond_img.size()}") - cond_img = ((cond_img.numpy() + 1.0) * 127.5).astype(np.uint8) - cond_img = np.transpose(cond_img, (1, 2, 0)) - cond_img = cond_img[:, :, ::-1] - if os.name == "nt": - cv2.imshow("cond_img", cond_img) - - if os.name == "nt": # only windows - cv2.imshow("img", im) - k = cv2.waitKey() - cv2.destroyAllWindows() - if k == 27 or k == ord("s") or k == ord("e"): - break - steps += 1 - - if k == ord("e"): - break - if k == 27 or (example["images"] is None and i >= 8): - k = 27 - break - if k == 27: - break - - epoch += 1 - - -def glob_images(directory, base="*"): - img_paths = [] - for ext in IMAGE_EXTENSIONS: - if base == "*": - img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) - else: - img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) - img_paths = list(set(img_paths)) # 重複を排除 - img_paths.sort() - return img_paths - - -def 
glob_images_pathlib(dir_path, recursive): - image_paths = [] - if recursive: - for ext in IMAGE_EXTENSIONS: - image_paths += list(dir_path.rglob("*" + ext)) - else: - for ext in IMAGE_EXTENSIONS: - image_paths += list(dir_path.glob("*" + ext)) - image_paths = list(set(image_paths)) # 重複を排除 - image_paths.sort() - return image_paths - - -class MinimalDataset(BaseDataset): - def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False): - super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) - - self.num_train_images = 0 # update in subclass - self.num_reg_images = 0 # update in subclass - self.datasets = [self] - self.batch_size = 1 # update in subclass - - self.subsets = [self] - self.num_repeats = 1 # update in subclass if needed - self.img_count = 1 # update in subclass if needed - self.bucket_info = {} - self.is_reg = False - self.image_dir = "dummy" # for metadata - - def verify_bucket_reso_steps(self, min_steps: int): - pass - - def is_latent_cacheable(self) -> bool: - return False - - def __len__(self): - raise NotImplementedError - - # override to avoid shuffling buckets - def set_current_epoch(self, epoch): - self.current_epoch = epoch - - def __getitem__(self, idx): - r""" - The subclass may have image_data for debug_dataset, which is a dict of ImageInfo objects. - - Returns: example like this: - - for i in range(batch_size): - image_key = ... # whatever hashable - image_keys.append(image_key) - - image = ... # PIL Image - img_tensor = self.image_transforms(img) - images.append(img_tensor) - - caption = ... # str - input_ids = self.get_input_ids(caption) - input_ids_list.append(input_ids) - - captions.append(caption) - - images = torch.stack(images, dim=0) - input_ids_list = torch.stack(input_ids_list, dim=0) - example = { - "images": images, - "input_ids": input_ids_list, - "captions": captions, # for debug_dataset - "latents": None, - "image_keys": image_keys, # for debug_dataset - "loss_weights": torch.ones(batch_size, dtype=torch.float32), - } - return example - """ - raise NotImplementedError - - -def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: - module = ".".join(args.dataset_class.split(".")[:-1]) - dataset_class = args.dataset_class.split(".")[-1] - module = importlib.import_module(module) - dataset_class = getattr(module, dataset_class) - train_dataset_group: MinimalDataset = dataset_class(tokenizer, args.max_token_length, args.resolution, args.debug_dataset) - return train_dataset_group - - -def load_image(image_path): - image = Image.open(image_path) - if not image.mode == "RGB": - image = image.convert("RGB") - img = np.array(image, np.uint8) - return img - - -# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom) -def trim_and_resize_if_required( - random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int] -) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]: - image_height, image_width = image.shape[0:2] - original_size = (image_width, image_height) # size before resize - - if image_width != resized_size[0] or image_height != resized_size[1]: - # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ - - image_height, image_width = image.shape[0:2] - - if image_width > reso[0]: - trim_size = image_width - reso[0] - p = trim_size // 2 if not random_crop else random.randint(0, trim_size) - # logger.info(f"w {trim_size} {p}") - image = image[:, 
p : p + reso[0]] - if image_height > reso[1]: - trim_size = image_height - reso[1] - p = trim_size // 2 if not random_crop else random.randint(0, trim_size) - # logger.info(f"h {trim_size} {p}) - image = image[p : p + reso[1]] - - # random cropの場合のcropされた値をどうcrop left/topに反映するべきか全くアイデアがない - # I have no idea how to reflect the cropped value in crop left/top in the case of random crop - - crop_ltrb = BucketManager.get_crop_ltrb(reso, original_size) - - assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}" - return image, original_size, crop_ltrb - - -def cache_batch_latents( - vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool -) -> None: - r""" - requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz - optionally requires image_infos to have: image - if cache_to_disk is True, set info.latents_npz - flipped latents is also saved if flip_aug is True - if cache_to_disk is False, set info.latents - latents_flipped is also set if flip_aug is True - latents_original_size and latents_crop_ltrb are also set - """ - images = [] - for info in image_infos: - image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8) - # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 - image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) - image = IMAGE_TRANSFORMS(image) - images.append(image) - - info.latents_original_size = original_size - info.latents_crop_ltrb = crop_ltrb - - img_tensors = torch.stack(images, dim=0) - img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) - - with torch.no_grad(): - latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") - - if flip_aug: - img_tensors = torch.flip(img_tensors, dims=[3]) - with torch.no_grad(): - flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") - else: - flipped_latents = [None] * len(latents) - - for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents): - # check NaN - if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()): - raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") - - if cache_to_disk: - save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent) - else: - info.latents = latent - if flip_aug: - info.latents_flipped = flipped_latent - - if not HIGH_VRAM: - clean_memory_on_device(vae.device) - - -def cache_batch_text_encoder_outputs( - image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype -): - input_ids1 = input_ids1.to(text_encoders[0].device) - input_ids2 = input_ids2.to(text_encoders[1].device) - - with torch.no_grad(): - b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl( - max_token_length, - input_ids1, - input_ids2, - tokenizers[0], - tokenizers[1], - text_encoders[0], - text_encoders[1], - dtype, - ) - - # ここでcpuに移動しておかないと、上書きされてしまう - b_hidden_state1 = b_hidden_state1.detach().to("cpu") # b,n*75+2,768 - b_hidden_state2 = b_hidden_state2.detach().to("cpu") # b,n*75+2,1280 - b_pool2 = b_pool2.detach().to("cpu") # b,1280 - - for info, hidden_state1, hidden_state2, pool2 in zip(image_infos, b_hidden_state1, b_hidden_state2, b_pool2): - if cache_to_disk: - save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, 
hidden_state1, hidden_state2, pool2) - else: - info.text_encoder_outputs1 = hidden_state1 - info.text_encoder_outputs2 = hidden_state2 - info.text_encoder_pool2 = pool2 - - -def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2): - np.savez( - npz_path, - hidden_state1=hidden_state1.cpu().float().numpy(), - hidden_state2=hidden_state2.cpu().float().numpy(), - pool2=pool2.cpu().float().numpy(), - ) - - -def load_text_encoder_outputs_from_disk(npz_path): - with np.load(npz_path) as f: - hidden_state1 = torch.from_numpy(f["hidden_state1"]) - hidden_state2 = torch.from_numpy(f["hidden_state2"]) if "hidden_state2" in f else None - pool2 = torch.from_numpy(f["pool2"]) if "pool2" in f else None - return hidden_state1, hidden_state2, pool2 - - -# endregion - -# region モジュール入れ替え部 -""" -高速化のためのモジュール入れ替え -""" - -# FlashAttentionを使うCrossAttention -# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py -# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE - -# constants - -EPSILON = 1e-6 - -# helper functions - - -def exists(val): - return val is not None - - -def default(val, d): - return val if exists(val) else d - - -def model_hash(filename): - """Old model hash used by stable-diffusion-webui""" - try: - with open(filename, "rb") as file: - m = hashlib.sha256() - - file.seek(0x100000) - m.update(file.read(0x10000)) - return m.hexdigest()[0:8] - except FileNotFoundError: - return "NOFILE" - except IsADirectoryError: # Linux? - return "IsADirectory" - except PermissionError: # Windows - return "IsADirectory" - - -def calculate_sha256(filename): - """New model hash used by stable-diffusion-webui""" - try: - hash_sha256 = hashlib.sha256() - blksize = 1024 * 1024 - - with open(filename, "rb") as f: - for chunk in iter(lambda: f.read(blksize), b""): - hash_sha256.update(chunk) - - return hash_sha256.hexdigest() - except FileNotFoundError: - return "NOFILE" - except IsADirectoryError: # Linux? 
- return "IsADirectory" - except PermissionError: # Windows - return "IsADirectory" - - -def precalculate_safetensors_hashes(tensors, metadata): - """Precalculate the model hashes needed by sd-webui-additional-networks to - save time on indexing the model later.""" - - # Because writing user metadata to the file can change the result of - # sd_models.model_hash(), only retain the training metadata for purposes of - # calculating the hash, as they are meant to be immutable - metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")} - - bytes = safetensors.torch.save(tensors, metadata) - b = BytesIO(bytes) - - model_hash = addnet_hash_safetensors(b) - legacy_hash = addnet_hash_legacy(b) - return model_hash, legacy_hash - - -def addnet_hash_legacy(b): - """Old model hash used by sd-webui-additional-networks for .safetensors format files""" - m = hashlib.sha256() - - b.seek(0x100000) - m.update(b.read(0x10000)) - return m.hexdigest()[0:8] - - -def addnet_hash_safetensors(b): - """New model hash used by sd-webui-additional-networks for .safetensors format files""" - hash_sha256 = hashlib.sha256() - blksize = 1024 * 1024 - - b.seek(0) - header = b.read(8) - n = int.from_bytes(header, "little") - - offset = n + 8 - b.seek(offset) - for chunk in iter(lambda: b.read(blksize), b""): - hash_sha256.update(chunk) - - return hash_sha256.hexdigest() - - -def get_git_revision_hash() -> str: - try: - return subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=os.path.dirname(__file__)).decode("ascii").strip() - except: - return "(unknown)" - - -# def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): -# replace_attentions_for_hypernetwork() -# # unet is not used currently, but it is here for future use -# unet.enable_xformers_memory_efficient_attention() -# return -# if mem_eff_attn: -# unet.set_attn_processor(FlashAttnProcessor()) -# elif xformers: -# unet.enable_xformers_memory_efficient_attention() - - -# def replace_unet_cross_attn_to_xformers(): -# logger.info("CrossAttention.forward has been replaced to enable xformers.") -# try: -# import xformers.ops -# except ImportError: -# raise ImportError("No xformers / xformersがインストールされていないようです") - -# def forward_xformers(self, x, context=None, mask=None): -# h = self.heads -# q_in = self.to_q(x) - -# context = default(context, x) -# context = context.to(x.dtype) - -# if hasattr(self, "hypernetwork") and self.hypernetwork is not None: -# context_k, context_v = self.hypernetwork.forward(x, context) -# context_k = context_k.to(x.dtype) -# context_v = context_v.to(x.dtype) -# else: -# context_k = context -# context_v = context - -# k_in = self.to_k(context_k) -# v_in = self.to_v(context_v) - -# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) -# del q_in, k_in, v_in - -# q = q.contiguous() -# k = k.contiguous() -# v = v.contiguous() -# out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - -# out = rearrange(out, "b n h d -> b n (h d)", h=h) - -# # diffusers 0.7.0~ -# out = self.to_out[0](out) -# out = self.to_out[1](out) -# return out - - -# diffusers.models.attention.CrossAttention.forward = forward_xformers -def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa): - if mem_eff_attn: - logger.info("Enable memory efficient attention for U-Net") - unet.set_use_memory_efficient_attention(False, True) - elif xformers: - logger.info("Enable xformers for U-Net") - try: - import xformers.ops - 
except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") - - unet.set_use_memory_efficient_attention(True, False) - elif sdpa: - logger.info("Enable SDPA for U-Net") - unet.set_use_sdpa(True) - - -""" -def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers): - # vae is not used currently, but it is here for future use - if mem_eff_attn: - replace_vae_attn_to_memory_efficient() - elif xformers: - # とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ - logger.info("Use Diffusers xformers for VAE") - vae.encoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) - vae.decoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) - - -def replace_vae_attn_to_memory_efficient(): - logger.info("AttentionBlock.forward has been replaced to FlashAttention (not xformers)") - flash_func = FlashAttentionFunction - - def forward_flash_attn(self, hidden_states): - logger.info("forward_flash_attn") - q_bucket_size = 512 - k_bucket_size = 1024 - - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj) - ) - - out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) - - out = rearrange(out, "b h n d -> b n (h d)") - - # compute next hidden_states - hidden_states = self.proj_attn(hidden_states) - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - diffusers.models.attention.AttentionBlock.forward = forward_flash_attn -""" - - -# endregion - - -# region arguments - - -def load_metadata_from_safetensors(safetensors_file: str) -> dict: - """r - This method locks the file. see https://github.com/huggingface/safetensors/issues/164 - If the file isn't .safetensors or doesn't have metadata, return empty dict. 
- """ - if os.path.splitext(safetensors_file)[1] != ".safetensors": - return {} - - with safetensors.safe_open(safetensors_file, framework="pt", device="cpu") as f: - metadata = f.metadata() - if metadata is None: - metadata = {} - return metadata - - -# this metadata is referred from train_network and various scripts, so we wrote here -SS_METADATA_KEY_V2 = "ss_v2" -SS_METADATA_KEY_BASE_MODEL_VERSION = "ss_base_model_version" -SS_METADATA_KEY_NETWORK_MODULE = "ss_network_module" -SS_METADATA_KEY_NETWORK_DIM = "ss_network_dim" -SS_METADATA_KEY_NETWORK_ALPHA = "ss_network_alpha" -SS_METADATA_KEY_NETWORK_ARGS = "ss_network_args" - -SS_METADATA_MINIMUM_KEYS = [ - SS_METADATA_KEY_V2, - SS_METADATA_KEY_BASE_MODEL_VERSION, - SS_METADATA_KEY_NETWORK_MODULE, - SS_METADATA_KEY_NETWORK_DIM, - SS_METADATA_KEY_NETWORK_ALPHA, - SS_METADATA_KEY_NETWORK_ARGS, -] - - -def build_minimum_network_metadata( - v2: Optional[bool], - base_model: Optional[str], - network_module: str, - network_dim: str, - network_alpha: str, - network_args: Optional[dict], -): - # old LoRA doesn't have base_model - metadata = { - SS_METADATA_KEY_NETWORK_MODULE: network_module, - SS_METADATA_KEY_NETWORK_DIM: network_dim, - SS_METADATA_KEY_NETWORK_ALPHA: network_alpha, - } - if v2 is not None: - metadata[SS_METADATA_KEY_V2] = v2 - if base_model is not None: - metadata[SS_METADATA_KEY_BASE_MODEL_VERSION] = base_model - if network_args is not None: - metadata[SS_METADATA_KEY_NETWORK_ARGS] = json.dumps(network_args) - return metadata - - -def get_sai_model_spec( - state_dict: dict, - args: argparse.Namespace, - sdxl: bool, - lora: bool, - textual_inversion: bool, - is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA -): - timestamp = time.time() - - v2 = args.v2 - v_parameterization = args.v_parameterization - reso = args.resolution - - title = args.metadata_title if args.metadata_title is not None else args.output_name - - if args.min_timestep is not None or args.max_timestep is not None: - min_time_step = args.min_timestep if args.min_timestep is not None else 0 - max_time_step = args.max_timestep if args.max_timestep is not None else 1000 - timesteps = (min_time_step, max_time_step) - else: - timesteps = None - - metadata = sai_model_spec.build_metadata( - state_dict, - v2, - v_parameterization, - sdxl, - lora, - textual_inversion, - timestamp, - title=title, - reso=reso, - is_stable_diffusion_ckpt=is_stable_diffusion_ckpt, - author=args.metadata_author, - description=args.metadata_description, - license=args.metadata_license, - tags=args.metadata_tags, - timesteps=timesteps, - clip_skip=args.clip_skip, # None or int - ) - return metadata - - -def add_sd_models_arguments(parser: argparse.ArgumentParser): - # for pretrained models - parser.add_argument( - "--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む" - ) - parser.add_argument( - "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする" - ) - parser.add_argument( - "--pretrained_model_name_or_path", - type=str, - default=None, - help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル", - ) - parser.add_argument( - "--tokenizer_cache_dir", - type=str, - default=None, - help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", - ) - - -def add_optimizer_arguments(parser: argparse.ArgumentParser): - 
parser.add_argument( - "--optimizer_type", - type=str, - default="", - help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", - ) - - # backward compatibility - parser.add_argument( - "--use_8bit_adam", - action="store_true", - help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)", - ) - parser.add_argument( - "--use_lion_optimizer", - action="store_true", - help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)", - ) - - parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") - parser.add_argument( - "--max_grad_norm", - default=1.0, - type=float, - help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない", - ) - - parser.add_argument( - "--optimizer_args", - type=str, - default=None, - nargs="*", - help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")', - ) - - parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ") - parser.add_argument( - "--lr_scheduler_args", - type=str, - default=None, - nargs="*", - help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max100")', - ) - - parser.add_argument( - "--lr_scheduler", - type=str, - default="constant", - help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor", - ) - parser.add_argument( - "--lr_warmup_steps", - type=int, - default=0, - help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)", - ) - parser.add_argument( - "--lr_scheduler_num_cycles", - type=int, - default=1, - help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数", - ) - parser.add_argument( - "--lr_scheduler_power", - type=float, - default=1, - help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", - ) - - -def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): - parser.add_argument( - "--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ" - ) - parser.add_argument( - "--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名" - ) - parser.add_argument( - "--huggingface_repo_id", - type=str, - default=None, - help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名", - ) - parser.add_argument( - "--huggingface_repo_type", - type=str, - default=None, - help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類", - ) - parser.add_argument( - "--huggingface_path_in_repo", - type=str, - default=None, - help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス", - ) - parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン") - parser.add_argument( - "--huggingface_repo_visibility", - type=str, - default=None, - help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)", - ) - parser.add_argument( - 
"--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する" - ) - parser.add_argument( - "--resume_from_huggingface", - action="store_true", - help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})", - ) - parser.add_argument( - "--async_upload", - action="store_true", - help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする", - ) - parser.add_argument( - "--save_precision", - type=str, - default=None, - choices=[None, "float", "fp16", "bf16"], - help="precision in saving / 保存時に精度を変更して保存する", - ) - parser.add_argument( - "--save_every_n_epochs", - type=int, - default=None, - help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する", - ) - parser.add_argument( - "--save_every_n_steps", - type=int, - default=None, - help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する", - ) - parser.add_argument( - "--save_n_epoch_ratio", - type=int, - default=None, - help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)", - ) - parser.add_argument( - "--save_last_n_epochs", - type=int, - default=None, - help="save last N checkpoints when saving every N epochs (remove older checkpoints) / 指定エポックごとにモデルを保存するとき最大Nエポック保存する(古いチェックポイントは削除する)", - ) - parser.add_argument( - "--save_last_n_epochs_state", - type=int, - default=None, - help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きする)", - ) - parser.add_argument( - "--save_last_n_steps", - type=int, - default=None, - help="save checkpoints until N steps elapsed (remove older checkpoints if N steps elapsed) / 指定ステップごとにモデルを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する)", - ) - parser.add_argument( - "--save_last_n_steps_state", - type=int, - default=None, - help="save states until N steps elapsed (remove older states if N steps elapsed, overrides --save_last_n_steps) / 指定ステップごとにstateを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する。--save_last_n_stepsを上書きする)", - ) - parser.add_argument( - "--save_state", - action="store_true", - help="save training state additionally (including optimizer states etc.) 
/ optimizerなど学習状態も含めたstateを追加で保存する", - ) - parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") - - parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ") - parser.add_argument( - "--max_token_length", - type=int, - default=None, - choices=[None, 150, 225], - help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)", - ) - parser.add_argument( - "--mem_eff_attn", - action="store_true", - help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う", - ) - parser.add_argument( - "--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う" - ) - parser.add_argument( - "--dynamo_backend", - type=str, - default="inductor", - # available backends: - # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5 - # https://pytorch.org/docs/stable/torch.compiler.html - choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], - help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)", - ) - parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") - parser.add_argument( - "--sdpa", - action="store_true", - help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)", - ) - parser.add_argument( - "--vae", - type=str, - default=None, - help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", - ) - - parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") - parser.add_argument( - "--max_train_epochs", - type=int, - default=None, - help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)", - ) - parser.add_argument( - "--max_data_loader_n_workers", - type=int, - default=8, - help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)", - ) - parser.add_argument( - "--persistent_data_loader_workers", - action="store_true", - help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)", - ) - parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") - parser.add_argument( - "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする" - ) - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数", - ) - parser.add_argument( - "--mixed_precision", - type=str, - default="no", - choices=["no", "fp16", "bf16"], - help="use mixed precision / 混合精度を使う場合、その精度", - ) - parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") - parser.add_argument( - "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する" - ) # TODO move to SDXL training, because it is not supported by SD1/2 - parser.add_argument("--fp8_base", action="store_true", help="use 
fp8 for base model / base modelにfp8を使う") - parser.add_argument( - "--ddp_timeout", - type=int, - default=None, - help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)", - ) - parser.add_argument( - "--ddp_gradient_as_bucket_view", - action="store_true", - help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする", - ) - parser.add_argument( - "--ddp_static_graph", - action="store_true", - help="enable static_graph for DDP / DDPでstatic_graphを有効にする", - ) - parser.add_argument( - "--clip_skip", - type=int, - default=None, - help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)", - ) - parser.add_argument( - "--logging_dir", - type=str, - default=None, - help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する", - ) - parser.add_argument( - "--log_with", - type=str, - default=None, - choices=["tensorboard", "wandb", "all"], - help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)", - ) - parser.add_argument( - "--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列" - ) - parser.add_argument( - "--log_tracker_name", - type=str, - default=None, - help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名", - ) - parser.add_argument( - "--wandb_run_name", - type=str, - default=None, - help="The name of the specific wandb session / wandb ログに表示される特定の実行の名前", - ) - parser.add_argument( - "--log_tracker_config", - type=str, - default=None, - help="path to tracker config file to use for logging / ログ出力に使用するtrackerの設定ファイルのパス", - ) - parser.add_argument( - "--wandb_api_key", - type=str, - default=None, - help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)", - ) - parser.add_argument( - "--noise_offset", - type=float, - default=None, - help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)", - ) - parser.add_argument( - "--multires_noise_iterations", - type=int, - default=None, - help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する(有効にする場合は6-10程度を推奨)", - ) - parser.add_argument( - "--ip_noise_gamma", - type=float, - default=None, - help="enable input perturbation noise. used for regularization. 
recommended value: around 0.1 (from arxiv.org/abs/2301.11706) " - + "/ input perturbation noiseを有効にする。正則化に使用される。推奨値: 0.1程度 (arxiv.org/abs/2301.11706 より)", - ) - # parser.add_argument( - # "--perlin_noise", - # type=int, - # default=None, - # help="enable perlin noise and set the octaves / perlin noiseを有効にしてoctavesをこの値に設定する", - # ) - parser.add_argument( - "--multires_noise_discount", - type=float, - default=0.3, - help="set discount value for multires noise (has no effect without --multires_noise_iterations) / Multires noiseのdiscount値を設定する(--multires_noise_iterations指定時のみ有効)", - ) - parser.add_argument( - "--adaptive_noise_scale", - type=float, - default=None, - help="add `latent mean absolute value * this value` to noise_offset (disabled if None, default) / latentの平均値の絶対値 * この値をnoise_offsetに加算する(Noneの場合は無効、デフォルト)", - ) - parser.add_argument( - "--zero_terminal_snr", - action="store_true", - help="fix noise scheduler betas to enforce zero terminal SNR / noise schedulerのbetasを修正して、zero terminal SNRを強制する", - ) - parser.add_argument( - "--min_timestep", - type=int, - default=None, - help="set minimum time step for U-Net training (0~999, default is 0) / U-Net学習時のtime stepの最小値を設定する(0~999で指定、省略時はデフォルト値(0)) ", - ) - parser.add_argument( - "--max_timestep", - type=int, - default=None, - help="set maximum time step for U-Net training (1~1000, default is 1000) / U-Net学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))", - ) - - parser.add_argument( - "--lowram", - action="store_true", - help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込む等(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)", - ) - parser.add_argument( - "--highvram", - action="store_true", - help="disable low VRAM optimization. e.g. 
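A conceptual sketch of what --noise_offset adds on top of the usual Gaussian noise; the actual implementation lives in the training code rather than here, and ~0.1 is the strength the help text recommends:

import torch

def add_noise_offset(latents: torch.Tensor, noise_offset: float = 0.1) -> torch.Tensor:
    noise = torch.randn_like(latents)
    # a per-sample, per-channel constant offset, scaled by noise_offset
    offset = torch.randn(latents.shape[0], latents.shape[1], 1, 1, device=latents.device, dtype=latents.dtype)
    return noise + noise_offset * offset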
do not clear CUDA cache after each latent caching (for machines which have bigger VRAM) " - + "/ VRAMが少ない環境向け最適化を無効にする。たとえば各latentのキャッシュ後のCUDAキャッシュクリアを行わない等(VRAMが多い環境向け)", - ) - - parser.add_argument( - "--sample_every_n_steps", - type=int, - default=None, - help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する", - ) - parser.add_argument( - "--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する" - ) - parser.add_argument( - "--sample_every_n_epochs", - type=int, - default=None, - help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)", - ) - parser.add_argument( - "--sample_prompts", - type=str, - default=None, - help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル", - ) - parser.add_argument( - "--sample_sampler", - type=str, - default="ddim", - choices=[ - "ddim", - "pndm", - "lms", - "euler", - "euler_a", - "heun", - "dpm_2", - "dpm_2_a", - "dpmsolver", - "dpmsolver++", - "dpmsingle", - "k_lms", - "k_euler", - "k_euler_a", - "k_dpm_2", - "k_dpm_2_a", - ], - help=f"sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類", - ) - - parser.add_argument( - "--config_file", - type=str, - default=None, - help="using .toml instead of args to pass hyperparameter / ハイパーパラメータを引数ではなく.tomlファイルで渡す", - ) - parser.add_argument( - "--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する" - ) - - # SAI Model spec - parser.add_argument( - "--metadata_title", - type=str, - default=None, - help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name", - ) - parser.add_argument( - "--metadata_author", - type=str, - default=None, - help="author name for model metadata / メタデータに書き込まれるモデル作者名", - ) - parser.add_argument( - "--metadata_description", - type=str, - default=None, - help="description for model metadata / メタデータに書き込まれるモデル説明", - ) - parser.add_argument( - "--metadata_license", - type=str, - default=None, - help="license for model metadata / メタデータに書き込まれるモデルライセンス", - ) - parser.add_argument( - "--metadata_tags", - type=str, - default=None, - help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り", - ) - - if support_dreambooth: - # DreamBooth training - parser.add_argument( - "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み" - ) - - -def verify_training_args(args: argparse.Namespace): - r""" - Verify training arguments. 
Also reflect highvram option to global variable - 学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する - """ - if args.highvram: - print("highvram is enabled / highvramが有効です") - global HIGH_VRAM - HIGH_VRAM = True - - if args.v_parameterization and not args.v2: - logger.warning( - "v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません" - ) - if args.v2 and args.clip_skip is not None: - logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") - - if args.cache_latents_to_disk and not args.cache_latents: - args.cache_latents = True - logger.warning( - "cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします" - ) - - # noise_offset, perlin_noise, multires_noise_iterations cannot be enabled at the same time - # # Listを使って数えてもいいけど並べてしまえ - # if args.noise_offset is not None and args.multires_noise_iterations is not None: - # raise ValueError( - # "noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にできません" - # ) - # if args.noise_offset is not None and args.perlin_noise is not None: - # raise ValueError("noise_offset and perlin_noise cannot be enabled at the same time / noise_offsetとperlin_noiseは同時に有効にできません") - # if args.perlin_noise is not None and args.multires_noise_iterations is not None: - # raise ValueError( - # "perlin_noise and multires_noise_iterations cannot be enabled at the same time / perlin_noiseとmultires_noise_iterationsを同時に有効にできません" - # ) - - if args.adaptive_noise_scale is not None and args.noise_offset is None: - raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です") - - if args.scale_v_pred_loss_like_noise_pred and not args.v_parameterization: - raise ValueError( - "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます" - ) - - if args.v_pred_like_loss and args.v_parameterization: - raise ValueError( - "v_pred_like_loss cannot be enabled with v_parameterization / v_pred_like_lossはv_parameterizationが有効なときには有効にできません" - ) - - if args.zero_terminal_snr and not args.v_parameterization: - logger.warning( - f"zero_terminal_snr is enabled, but v_parameterization is not enabled. 
training will be unexpected" - + " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります" - ) - - -def add_dataset_arguments( - parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool -): - # dataset common - parser.add_argument( - "--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ" - ) - parser.add_argument( - "--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする" - ) - parser.add_argument("--caption_separator", type=str, default=",", help="separator for caption / captionの区切り文字") - parser.add_argument( - "--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子" - ) - parser.add_argument( - "--caption_extention", - type=str, - default=None, - help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)", - ) - parser.add_argument( - "--keep_tokens", - type=int, - default=0, - help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)", - ) - parser.add_argument( - "--keep_tokens_separator", - type=str, - default="", - help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens." - + " / captionを固定部分と可変部分に分けるためのカスタム区切り文字。この区切り文字より前のトークンはシャッフルされない。指定しない場合、'--keep_tokens'が固定部分のトークン数として使用される。", - ) - parser.add_argument( - "--caption_prefix", - type=str, - default=None, - help="prefix for caption text / captionのテキストの先頭に付ける文字列", - ) - parser.add_argument( - "--caption_suffix", - type=str, - default=None, - help="suffix for caption text / captionのテキストの末尾に付ける文字列", - ) - parser.add_argument( - "--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする" - ) - parser.add_argument( - "--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする" - ) - parser.add_argument( - "--face_crop_aug_range", - type=str, - default=None, - help="enable face-centered crop augmentation and its range (e.g. 
2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)", - ) - parser.add_argument( - "--random_crop", - action="store_true", - help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)", - ) - parser.add_argument( - "--debug_dataset", - action="store_true", - help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)", - ) - parser.add_argument( - "--resolution", - type=str, - default=None, - help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)", - ) - parser.add_argument( - "--cache_latents", - action="store_true", - help="cache latents to main memory to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをメインメモリにcacheする(augmentationは使用不可) ", - ) - parser.add_argument( - "--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ" - ) - parser.add_argument( - "--cache_latents_to_disk", - action="store_true", - help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)", - ) - parser.add_argument( - "--enable_bucket", - action="store_true", - help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする", - ) - parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") - parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") - parser.add_argument( - "--bucket_reso_steps", - type=int, - default=64, - help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します", - ) - parser.add_argument( - "--bucket_no_upscale", - action="store_true", - help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します", - ) - - parser.add_argument( - "--token_warmup_min", - type=int, - default=1, - help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら学習する", - ) - parser.add_argument( - "--token_warmup_step", - type=float, - default=0, - help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)", - ) - - parser.add_argument( - "--dataset_class", - type=str, - default=None, - help="dataset class for arbitrary dataset (package.module.Class) / 任意のデータセットを用いるときのクラス名 (package.module.Class)", - ) - - if support_caption_dropout: - # Textual Inversion はcaptionのdropoutをsupportしない - # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに - parser.add_argument( - "--caption_dropout_rate", type=float, default=0.0, help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合" - ) - parser.add_argument( - "--caption_dropout_every_n_epochs", - type=int, - default=0, - help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする", - ) - parser.add_argument( - "--caption_tag_dropout_rate", - type=float, - default=0.0, - help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合", - ) - - if support_dreambooth: - # DreamBooth dataset - parser.add_argument( - "--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ" - ) - - if support_caption: - # caption dataset - parser.add_argument( - "--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル" - ) - parser.add_argument( - "--dataset_repeats", - type=int, - default=1, - 
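To illustrate how --shuffle_caption interacts with --keep_tokens, a conceptual sketch; the real logic lives in the dataset classes and also honors --keep_tokens_separator and --caption_separator, so this only shows the fixed/flexible split:

import random

def shuffle_caption(caption: str, keep_tokens: int = 1, separator: str = ",") -> str:
    tokens = [t.strip() for t in caption.split(separator)]
    fixed, flexible = tokens[:keep_tokens], tokens[keep_tokens:]  # leading tokens stay in place
    random.shuffle(flexible)
    return (separator + " ").join(fixed + flexible)

# shuffle_caption("1girl, red hair, smile, outdoors", keep_tokens=1)
# -> "1girl" always stays first; the remaining tags are shuffled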
help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数", - ) - - -def add_sd_saving_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--save_model_as", - type=str, - default=None, - choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"], - help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)", - ) - parser.add_argument( - "--use_safetensors", - action="store_true", - help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)", - ) - - -def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser): - if not args.config_file: - return args - - config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file - - if args.output_config: - # check if config file exists - if os.path.exists(config_path): - logger.error(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}") - exit(1) - - # convert args to dictionary - args_dict = vars(args) - - # remove unnecessary keys - for key in ["config_file", "output_config", "wandb_api_key"]: - if key in args_dict: - del args_dict[key] - - # get default args from parser - default_args = vars(parser.parse_args([])) - - # remove default values: cannot use args_dict.items directly because it will be changed during iteration - for key, value in list(args_dict.items()): - if key in default_args and value == default_args[key]: - del args_dict[key] - - # convert Path to str in dictionary - for key, value in args_dict.items(): - if isinstance(value, pathlib.Path): - args_dict[key] = str(value) - - # convert to toml and output to file - with open(config_path, "w") as f: - toml.dump(args_dict, f) - - logger.info(f"Saved config file / 設定ファイルを保存しました: {config_path}") - exit(0) - - if not os.path.exists(config_path): - logger.info(f"{config_path} not found.") - exit(1) - - logger.info(f"Loading settings from {config_path}...") - with open(config_path, "r") as f: - config_dict = toml.load(f) - - # combine all sections into one - ignore_nesting_dict = {} - for section_name, section_dict in config_dict.items(): - # if value is not dict, save key and value as is - if not isinstance(section_dict, dict): - ignore_nesting_dict[section_name] = section_dict - continue - - # if value is dict, save all key and value into one dict - for key, value in section_dict.items(): - ignore_nesting_dict[key] = value - - config_args = argparse.Namespace(**ignore_nesting_dict) - args = parser.parse_args(namespace=config_args) - args.config_file = os.path.splitext(args.config_file)[0] - logger.info(args.config_file) - - return args - - -# endregion - -# region utils - - -def resume_from_local_or_hf_if_specified(accelerator, args): - if not args.resume: - return - - if not args.resume_from_huggingface: - logger.info(f"resume training from local state: {args.resume}") - accelerator.load_state(args.resume) - return - - logger.info(f"resume training from huggingface state: {args.resume}") - repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1] - path_in_repo = "/".join(args.resume.split("/")[2:]) - revision = None - repo_type = None - if ":" in path_in_repo: - divided = path_in_repo.split(":") - if len(divided) == 2: - path_in_repo, revision = divided - repo_type = "model" - else: - path_in_repo, revision, repo_type = divided - logger.info(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}") - - list_files = 
huggingface_util.list_dir( - repo_id=repo_id, - subfolder=path_in_repo, - revision=revision, - token=args.huggingface_token, - repo_type=repo_type, - ) - - async def download(filename) -> str: - def task(): - return hf_hub_download( - repo_id=repo_id, - filename=filename, - revision=revision, - repo_type=repo_type, - token=args.huggingface_token, - ) - - return await asyncio.get_event_loop().run_in_executor(None, task) - - loop = asyncio.get_event_loop() - results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files])) - if len(results) == 0: - raise ValueError( - "No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした" - ) - dirname = os.path.dirname(results[0]) - accelerator.load_state(dirname) - - -def get_optimizer(args, trainable_params): - # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" - - optimizer_type = args.optimizer_type - if args.use_8bit_adam: - assert ( - not args.use_lion_optimizer - ), "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています" - assert ( - optimizer_type is None or optimizer_type == "" - ), "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています" - optimizer_type = "AdamW8bit" - - elif args.use_lion_optimizer: - assert ( - optimizer_type is None or optimizer_type == "" - ), "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています" - optimizer_type = "Lion" - - if optimizer_type is None or optimizer_type == "": - optimizer_type = "AdamW" - optimizer_type = optimizer_type.lower() - - # 引数を分解する - optimizer_kwargs = {} - if args.optimizer_args is not None and len(args.optimizer_args) > 0: - for arg in args.optimizer_args: - key, value = arg.split("=") - value = ast.literal_eval(value) - - # value = value.split(",") - # for i in range(len(value)): - # if value[i].lower() == "true" or value[i].lower() == "false": - # value[i] = value[i].lower() == "true" - # else: - # value[i] = ast.float(value[i]) - # if len(value) == 1: - # value = value[0] - # else: - # value = tuple(value) - - optimizer_kwargs[key] = value - # logger.info(f"optkwargs {optimizer}_{kwargs}") - - lr = args.learning_rate - optimizer = None - - if optimizer_type == "Lion".lower(): - try: - import lion_pytorch - except ImportError: - raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです") - logger.info(f"use Lion optimizer | {optimizer_kwargs}") - optimizer_class = lion_pytorch.Lion - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - - elif optimizer_type.endswith("8bit".lower()): - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです") - - if optimizer_type == "AdamW8bit".lower(): - logger.info(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") - optimizer_class = bnb.optim.AdamW8bit - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - - elif optimizer_type == "SGDNesterov8bit".lower(): - logger.info(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") - if "momentum" not in optimizer_kwargs: - logger.warning( - f"8-bit SGD with Nesterov must be with momentum, set momentum to 
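The "key=value" parsing used for --optimizer_args (and likewise --lr_scheduler_args) reduces to ast.literal_eval; a standalone sketch:

import ast

optimizer_args = ["weight_decay=0.01", "betas=0.9,0.999"]  # as passed on the command line
optimizer_kwargs = {}
for arg in optimizer_args:
    key, value = arg.split("=", 1)
    optimizer_kwargs[key] = ast.literal_eval(value)  # numbers, tuples, booleans, ...

print(optimizer_kwargs)  # {'weight_decay': 0.01, 'betas': (0.9, 0.999)}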
0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します" - ) - optimizer_kwargs["momentum"] = 0.9 - - optimizer_class = bnb.optim.SGD8bit - optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) - - elif optimizer_type == "Lion8bit".lower(): - logger.info(f"use 8-bit Lion optimizer | {optimizer_kwargs}") - try: - optimizer_class = bnb.optim.Lion8bit - except AttributeError: - raise AttributeError( - "No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. / Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください" - ) - elif optimizer_type == "PagedAdamW8bit".lower(): - logger.info(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}") - try: - optimizer_class = bnb.optim.PagedAdamW8bit - except AttributeError: - raise AttributeError( - "No PagedAdamW8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" - ) - elif optimizer_type == "PagedLion8bit".lower(): - logger.info(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}") - try: - optimizer_class = bnb.optim.PagedLion8bit - except AttributeError: - raise AttributeError( - "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" - ) - - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - - elif optimizer_type == "PagedAdamW".lower(): - logger.info(f"use PagedAdamW optimizer | {optimizer_kwargs}") - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです") - try: - optimizer_class = bnb.optim.PagedAdamW - except AttributeError: - raise AttributeError( - "No PagedAdamW. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamWが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" - ) - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - - elif optimizer_type == "PagedAdamW32bit".lower(): - logger.info(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}") - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです") - try: - optimizer_class = bnb.optim.PagedAdamW32bit - except AttributeError: - raise AttributeError( - "No PagedAdamW32bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. 
/ PagedAdamW32bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" - ) - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - - elif optimizer_type == "SGDNesterov".lower(): - logger.info(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") - if "momentum" not in optimizer_kwargs: - logger.info( - f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します" - ) - optimizer_kwargs["momentum"] = 0.9 - - optimizer_class = torch.optim.SGD - optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) - - elif optimizer_type.startswith("DAdapt".lower()) or optimizer_type == "Prodigy".lower(): - # check lr and lr_count, and logger.info warning - actual_lr = lr - lr_count = 1 - if type(trainable_params) == list and type(trainable_params[0]) == dict: - lrs = set() - actual_lr = trainable_params[0].get("lr", actual_lr) - for group in trainable_params: - lrs.add(group.get("lr", actual_lr)) - lr_count = len(lrs) - - if actual_lr <= 0.1: - logger.warning( - f"learning rate is too low. If using D-Adaptation or Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。D-AdaptationまたはProdigyの使用時は1.0前後の値を指定してください: lr={actual_lr}" - ) - logger.warning("recommend option: lr=1.0 / 推奨は1.0です") - if lr_count > 1: - logger.warning( - f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-AdaptationまたはProdigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}" - ) - - if optimizer_type.startswith("DAdapt".lower()): - # DAdaptation family - # check dadaptation is installed - try: - import dadaptation - import dadaptation.experimental as experimental - except ImportError: - raise ImportError("No dadaptation / dadaptation がインストールされていないようです") - - # set optimizer - if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower(): - optimizer_class = experimental.DAdaptAdamPreprint - logger.info(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}") - elif optimizer_type == "DAdaptAdaGrad".lower(): - optimizer_class = dadaptation.DAdaptAdaGrad - logger.info(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}") - elif optimizer_type == "DAdaptAdam".lower(): - optimizer_class = dadaptation.DAdaptAdam - logger.info(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") - elif optimizer_type == "DAdaptAdan".lower(): - optimizer_class = dadaptation.DAdaptAdan - logger.info(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") - elif optimizer_type == "DAdaptAdanIP".lower(): - optimizer_class = experimental.DAdaptAdanIP - logger.info(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}") - elif optimizer_type == "DAdaptLion".lower(): - optimizer_class = dadaptation.DAdaptLion - logger.info(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}") - elif optimizer_type == "DAdaptSGD".lower(): - optimizer_class = dadaptation.DAdaptSGD - logger.info(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}") - else: - raise ValueError(f"Unknown optimizer type: {optimizer_type}") - - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - else: - # Prodigy - # check Prodigy is installed - try: - import prodigyopt - except ImportError: - raise ImportError("No Prodigy / Prodigy がインストールされていないようです") - - logger.info(f"use Prodigy optimizer | {optimizer_kwargs}") - optimizer_class = prodigyopt.Prodigy - optimizer = optimizer_class(trainable_params, lr=lr, 
**optimizer_kwargs) - - elif optimizer_type == "Adafactor".lower(): - # 引数を確認して適宜補正する - if "relative_step" not in optimizer_kwargs: - optimizer_kwargs["relative_step"] = True # default - if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False): - logger.info( - f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします" - ) - optimizer_kwargs["relative_step"] = True - logger.info(f"use Adafactor optimizer | {optimizer_kwargs}") - - if optimizer_kwargs["relative_step"]: - logger.info(f"relative_step is true / relative_stepがtrueです") - if lr != 0.0: - logger.warning(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます") - args.learning_rate = None - - # trainable_paramsがgroupだった時の処理:lrを削除する - if type(trainable_params) == list and type(trainable_params[0]) == dict: - has_group_lr = False - for group in trainable_params: - p = group.pop("lr", None) - has_group_lr = has_group_lr or (p is not None) - - if has_group_lr: - # 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない - logger.warning(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます") - args.unet_lr = None - args.text_encoder_lr = None - - if args.lr_scheduler != "adafactor": - logger.info(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します") - args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど - - lr = None - else: - if args.max_grad_norm != 0.0: - logger.warning( - f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません" - ) - if args.lr_scheduler != "constant_with_warmup": - logger.warning(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません") - if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0: - logger.warning(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません") - - optimizer_class = transformers.optimization.Adafactor - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - - elif optimizer_type == "AdamW".lower(): - logger.info(f"use AdamW optimizer | {optimizer_kwargs}") - optimizer_class = torch.optim.AdamW - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - - if optimizer is None: - # 任意のoptimizerを使う - optimizer_type = args.optimizer_type # lowerでないやつ(微妙) - logger.info(f"use {optimizer_type} | {optimizer_kwargs}") - if "." not in optimizer_type: - optimizer_module = torch.optim - else: - values = optimizer_type.split(".") - optimizer_module = importlib.import_module(".".join(values[:-1])) - optimizer_type = values[-1] - - optimizer_class = getattr(optimizer_module, optimizer_type) - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - - optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ - optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) - - return optimizer_name, optimizer_args, optimizer - - -# Modified version of get_scheduler() function from diffusers.optimizer.get_scheduler -# Add some checking and features to the original function. - - -def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): - """ - Unified API to get any scheduler from its name. 
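The fallback branch above ("use any optimizer") follows the same dotted-path convention as --dataset_class: a bare name is looked up in torch.optim, a dotted path is imported dynamically. A compact sketch (the usage comment assumes a `model` variable exists):

import importlib
import torch

def resolve_optimizer_class(optimizer_type: str):
    if "." not in optimizer_type:
        return getattr(torch.optim, optimizer_type)  # e.g. "AdamW", "SGD"
    module_name, class_name = optimizer_type.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)

# resolve_optimizer_class("AdamW")(model.parameters(), lr=1e-4)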
- """ - name = args.lr_scheduler - num_warmup_steps: Optional[int] = args.lr_warmup_steps - num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps - num_cycles = args.lr_scheduler_num_cycles - power = args.lr_scheduler_power - - lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs - if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0: - for arg in args.lr_scheduler_args: - key, value = arg.split("=") - value = ast.literal_eval(value) - lr_scheduler_kwargs[key] = value - - def wrap_check_needless_num_warmup_steps(return_vals): - if num_warmup_steps is not None and num_warmup_steps != 0: - raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.") - return return_vals - - # using any lr_scheduler from other library - if args.lr_scheduler_type: - lr_scheduler_type = args.lr_scheduler_type - logger.info(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler") - if "." not in lr_scheduler_type: # default to use torch.optim - lr_scheduler_module = torch.optim.lr_scheduler - else: - values = lr_scheduler_type.split(".") - lr_scheduler_module = importlib.import_module(".".join(values[:-1])) - lr_scheduler_type = values[-1] - lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type) - lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs) - return wrap_check_needless_num_warmup_steps(lr_scheduler) - - if name.startswith("adafactor"): - assert ( - type(optimizer) == transformers.optimization.Adafactor - ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください" - initial_lr = float(name.split(":")[1]) - # logger.info(f"adafactor scheduler init lr {initial_lr}") - return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr)) - - name = SchedulerType(name) - schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] - - if name == SchedulerType.CONSTANT: - return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs)) - - if name == SchedulerType.PIECEWISE_CONSTANT: - return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs - - # All other schedulers require `num_warmup_steps` - if num_warmup_steps is None: - raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") - - if name == SchedulerType.CONSTANT_WITH_WARMUP: - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs) - - # All other schedulers require `num_training_steps` - if num_training_steps is None: - raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") - - if name == SchedulerType.COSINE_WITH_RESTARTS: - return schedule_func( - optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps, - num_cycles=num_cycles, - **lr_scheduler_kwargs, - ) - - if name == SchedulerType.POLYNOMIAL: - return schedule_func( - optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs - ) - - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs) - - -def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): - # backward compatibility - if args.caption_extention is not None: - args.caption_extension = args.caption_extention - args.caption_extention = None - - # assert args.resolution is not None, 
f"resolution is required / resolution(解像度)を指定してください" - if args.resolution is not None: - args.resolution = tuple([int(r) for r in args.resolution.split(",")]) - if len(args.resolution) == 1: - args.resolution = (args.resolution[0], args.resolution[0]) - assert ( - len(args.resolution) == 2 - ), f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}" - - if args.face_crop_aug_range is not None: - args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(",")]) - assert ( - len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1] - ), f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}" - else: - args.face_crop_aug_range = None - - if support_metadata: - if args.in_json is not None and (args.color_aug or args.random_crop): - logger.warning( - f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます" - ) - - -def load_tokenizer(args: argparse.Namespace): - logger.info("prepare tokenizer") - original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH - - tokenizer: CLIPTokenizer = None - if args.tokenizer_cache_dir: - local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) - if os.path.exists(local_tokenizer_path): - logger.info(f"load tokenizer from cache: {local_tokenizer_path}") - tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2 - - if tokenizer is None: - if args.v2: - tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer") - else: - tokenizer = CLIPTokenizer.from_pretrained(original_path) - - if hasattr(args, "max_token_length") and args.max_token_length is not None: - logger.info(f"update token length: {args.max_token_length}") - - if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): - logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") - tokenizer.save_pretrained(local_tokenizer_path) - - return tokenizer - - -def prepare_accelerator(args: argparse.Namespace): - if args.logging_dir is None: - logging_dir = None - else: - log_prefix = "" if args.log_prefix is None else args.log_prefix - logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime()) - - if args.log_with is None: - if logging_dir is not None: - log_with = "tensorboard" - else: - log_with = None - else: - log_with = args.log_with - if log_with in ["tensorboard", "all"]: - if logging_dir is None: - raise ValueError( - "logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください" - ) - if log_with in ["wandb", "all"]: - try: - import wandb - except ImportError: - raise ImportError("No wandb / wandb がインストールされていないようです") - if logging_dir is not None: - os.makedirs(logging_dir, exist_ok=True) - os.environ["WANDB_DIR"] = logging_dir - if args.wandb_api_key is not None: - wandb.login(key=args.wandb_api_key) - - # torch.compile のオプション。 NO の場合は torch.compile は使わない - dynamo_backend = "NO" - if args.torch_compile: - dynamo_backend = args.dynamo_backend - - kwargs_handlers = ( - InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None, - ( - DistributedDataParallelKwargs( - gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph - ) - if args.ddp_gradient_as_bucket_view or args.ddp_static_graph - else 
None - ), - ) - kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers)) - accelerator = Accelerator( - gradient_accumulation_steps=args.gradient_accumulation_steps, - mixed_precision=args.mixed_precision, - log_with=log_with, - project_dir=logging_dir, - kwargs_handlers=kwargs_handlers, - dynamo_backend=dynamo_backend, - ) - print("accelerator device:", accelerator.device) - return accelerator - - -def prepare_dtype(args: argparse.Namespace): - weight_dtype = torch.float32 - if args.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - save_dtype = None - if args.save_precision == "fp16": - save_dtype = torch.float16 - elif args.save_precision == "bf16": - save_dtype = torch.bfloat16 - elif args.save_precision == "float": - save_dtype = torch.float32 - - return weight_dtype, save_dtype - - -def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", unet_use_linear_projection_in_v2=False): - name_or_path = args.pretrained_model_name_or_path - name_or_path = os.path.realpath(name_or_path) if os.path.islink(name_or_path) else name_or_path - load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers - if load_stable_diffusion_format: - logger.info(f"load StableDiffusion checkpoint: {name_or_path}") - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint( - args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2 - ) - else: - # Diffusers model is loaded to CPU - logger.info(f"load Diffusers pretrained models: {name_or_path}") - try: - pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None) - except EnvironmentError as ex: - logger.error( - f"model is not found as a file or in Hugging Face, perhaps file name is wrong? 
/ 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}" - ) - raise ex - text_encoder = pipe.text_encoder - vae = pipe.vae - unet = pipe.unet - del pipe - - # Diffusers U-Net to original U-Net - # TODO *.ckpt/*.safetensorsのv2と同じ形式にここで変換すると良さそう - # logger.info(f"unet config: {unet.config}") - original_unet = UNet2DConditionModel( - unet.config.sample_size, - unet.config.attention_head_dim, - unet.config.cross_attention_dim, - unet.config.use_linear_projection, - unet.config.upcast_attention, - ) - original_unet.load_state_dict(unet.state_dict()) - unet = original_unet - logger.info("U-Net converted to original U-Net") - - # VAEを読み込む - if args.vae is not None: - vae = model_util.load_vae(args.vae, weight_dtype) - logger.info("additional VAE loaded") - - return text_encoder, vae, unet, load_stable_diffusion_format - - -def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False): - # load models for each process - for pi in range(accelerator.state.num_processes): - if pi == accelerator.state.local_process_index: - logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") - - text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model( - args, - weight_dtype, - accelerator.device if args.lowram else "cpu", - unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2, - ) - - # work on low-ram device - if args.lowram: - text_encoder.to(accelerator.device) - unet.to(accelerator.device) - vae.to(accelerator.device) - - clean_memory_on_device(accelerator.device) - accelerator.wait_for_everyone() - - return text_encoder, vae, unet, load_stable_diffusion_format - - -def patch_accelerator_for_fp16_training(accelerator): - org_unscale_grads = accelerator.scaler._unscale_grads_ - - def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): - return org_unscale_grads(optimizer, inv_scale, found_inf, True) - - accelerator.scaler._unscale_grads_ = _unscale_grads_replacer - - -def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None): - # with no_token_padding, the length is not max length, return result immediately - if input_ids.size()[-1] != tokenizer.model_max_length: - return text_encoder(input_ids)[0] - - # input_ids: b,n,77 - b_size = input_ids.size()[0] - input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 - - if args.clip_skip is None: - encoder_hidden_states = text_encoder(input_ids)[0] - else: - enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) - encoder_hidden_states = enc_out["hidden_states"][-args.clip_skip] - encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) - - # bs*3, 77, 768 or 1024 - encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) - - if args.max_token_length is not None: - if args.v2: - # v2: ... ... の三連を ... ... 
へ戻す 正直この実装でいいのかわからん - states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # - for i in range(1, args.max_token_length, tokenizer.model_max_length): - chunk = encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2] # の後から 最後の前まで - if i > 0: - for j in range(len(chunk)): - if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり ...のパターン - chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする - states_list.append(chunk) # の後から の前まで - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか - encoder_hidden_states = torch.cat(states_list, dim=1) - else: - # v1: ... の三連を ... へ戻す - states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # - for i in range(1, args.max_token_length, tokenizer.model_max_length): - states_list.append( - encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2] - ) # の後から の前まで - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # - encoder_hidden_states = torch.cat(states_list, dim=1) - - if weight_dtype is not None: - # this is required for additional network training - encoder_hidden_states = encoder_hidden_states.to(weight_dtype) - - return encoder_hidden_states - - -def pool_workaround( - text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int -): - r""" - workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output - instead of the hidden states for the EOS token - If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output - - Original code from CLIP's pooling function: - - \# text_embeds.shape = [batch_size, sequence_length, transformer.width] - \# take features from the eot embedding (eot_token is the highest number in each sequence) - \# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 - pooled_output = last_hidden_state[ - torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), - input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), - ] - """ - - # input_ids: b*n,77 - # find index for EOS token - - # Following code is not working if one of the input_ids has multiple EOS tokens (very odd case) - # eos_token_index = torch.where(input_ids == eos_token_id)[1] - # eos_token_index = eos_token_index.to(device=last_hidden_state.device) - - # Create a mask where the EOS tokens are - eos_token_mask = (input_ids == eos_token_id).int() - - # Use argmax to find the last index of the EOS token for each element in the batch - eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine - eos_token_index = eos_token_index.to(device=last_hidden_state.device) - - # get hidden states for EOS token - pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index] - - # apply projection: projection may be of different dtype than last_hidden_state - pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype)) - pooled_output = pooled_output.to(last_hidden_state.dtype) - - return pooled_output - - -def get_hidden_states_sdxl( - max_token_length: int, - input_ids1: torch.Tensor, - input_ids2: torch.Tensor, - tokenizer1: CLIPTokenizer, - tokenizer2: CLIPTokenizer, - text_encoder1: CLIPTextModel, - text_encoder2: CLIPTextModelWithProjection, - weight_dtype: Optional[str] = None, - accelerator: Optional[Accelerator] = None, -): - # input_ids: b,n,77 -> b*n, 77 - b_size = 
input_ids1.size()[0] - input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77 - input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77 - - # text_encoder1 - enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True) - hidden_states1 = enc_out["hidden_states"][11] - - # text_encoder2 - enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True) - hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer - - # pool2 = enc_out["text_embeds"] - unwrapped_text_encoder2 = text_encoder2 if accelerator is None else accelerator.unwrap_model(text_encoder2) - pool2 = pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id) - - # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280 - n_size = 1 if max_token_length is None else max_token_length // 75 - hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1])) - hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1])) - - if max_token_length is not None: - # bs*3, 77, 768 or 1024 - # encoder1: ... の三連を ... へ戻す - states_list = [hidden_states1[:, 0].unsqueeze(1)] # - for i in range(1, max_token_length, tokenizer1.model_max_length): - states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2]) # の後から の前まで - states_list.append(hidden_states1[:, -1].unsqueeze(1)) # - hidden_states1 = torch.cat(states_list, dim=1) - - # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん - states_list = [hidden_states2[:, 0].unsqueeze(1)] # - for i in range(1, max_token_length, tokenizer2.model_max_length): - chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # の後から 最後の前まで - # this causes an error: - # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation - # if i > 1: - # for j in range(len(chunk)): # batch_size - # if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # 空、つまり ...のパターン - # chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする - states_list.append(chunk) # の後から の前まで - states_list.append(hidden_states2[:, -1].unsqueeze(1)) # のどちらか - hidden_states2 = torch.cat(states_list, dim=1) - - # pool はnの最初のものを使う - pool2 = pool2[::n_size] - - if weight_dtype is not None: - # this is required for additional network training - hidden_states1 = hidden_states1.to(weight_dtype) - hidden_states2 = hidden_states2.to(weight_dtype) - - return hidden_states1, hidden_states2, pool2 - - -def default_if_none(value, default): - return default if value is None else value - - -def get_epoch_ckpt_name(args: argparse.Namespace, ext: str, epoch_no: int): - model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME) - return EPOCH_FILE_NAME.format(model_name, epoch_no) + ext - - -def get_step_ckpt_name(args: argparse.Namespace, ext: str, step_no: int): - model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME) - return STEP_FILE_NAME.format(model_name, step_no) + ext - - -def get_last_ckpt_name(args: argparse.Namespace, ext: str): - model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME) - return model_name + ext - - -def get_remove_epoch_no(args: argparse.Namespace, epoch_no: int): - if args.save_last_n_epochs is None: - return None - - remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs - if remove_epoch_no < 0: - return None - return remove_epoch_no - - -def get_remove_step_no(args: argparse.Namespace, step_no: int): - if args.save_last_n_steps is 
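The removed `pool_workaround` sidesteps CLIP's pooling behaviour, which indexes the hidden state at the highest token id and therefore picks the wrong position once Textual Inversion adds tokens with ids above EOS; instead it looks up the first EOS position explicitly. A toy sketch of that indexing with invented shapes and `eos_token_id` (the final `text_projection` step is omitted):

```python
import torch

eos_token_id = 2                                      # example id, not CLIP's real one
input_ids = torch.tensor([[0, 5, 7, 2, 2],
                          [0, 9, 2, 2, 2]])           # b, n (padded with EOS)
last_hidden_state = torch.randn(2, 5, 8)              # b, n, d

eos_mask = (input_ids == eos_token_id).int()
eos_index = torch.argmax(eos_mask, dim=1)             # first EOS per row -> tensor([3, 2]) (argmax returns the earliest maximal index)
pooled = last_hidden_state[torch.arange(last_hidden_state.shape[0]), eos_index]
print(pooled.shape)                                   # torch.Size([2, 8])
```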
None: - return None - - # last_n_steps前のstep_noから、save_every_n_stepsの倍数のstep_noを計算して削除する - # save_every_n_steps=10, save_last_n_steps=30の場合、50step目には30step分残し、10step目を削除する - remove_step_no = step_no - args.save_last_n_steps - 1 - remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps) - if remove_step_no < 0: - return None - return remove_step_no - - -# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している -# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 -def save_sd_model_on_epoch_end_or_stepwise( - args: argparse.Namespace, - on_epoch_end: bool, - accelerator, - src_path: str, - save_stable_diffusion_format: bool, - use_safetensors: bool, - save_dtype: torch.dtype, - epoch: int, - num_train_epochs: int, - global_step: int, - text_encoder, - unet, - vae, -): - def sd_saver(ckpt_file, epoch_no, global_step): - sai_metadata = get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True) - model_util.save_stable_diffusion_checkpoint( - args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, sai_metadata, save_dtype, vae - ) - - def diffusers_saver(out_dir): - model_util.save_diffusers_checkpoint( - args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors - ) - - save_sd_model_on_epoch_end_or_stepwise_common( - args, - on_epoch_end, - accelerator, - save_stable_diffusion_format, - use_safetensors, - epoch, - num_train_epochs, - global_step, - sd_saver, - diffusers_saver, - ) - - -def save_sd_model_on_epoch_end_or_stepwise_common( - args: argparse.Namespace, - on_epoch_end: bool, - accelerator, - save_stable_diffusion_format: bool, - use_safetensors: bool, - epoch: int, - num_train_epochs: int, - global_step: int, - sd_saver, - diffusers_saver, -): - if on_epoch_end: - epoch_no = epoch + 1 - saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs - if not saving: - return - - model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME) - remove_no = get_remove_epoch_no(args, epoch_no) - else: - # 保存するか否かは呼び出し側で判断済み - - model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME) - epoch_no = epoch # 例: 最初のepochの途中で保存したら0になる、SDモデルに保存される - remove_no = get_remove_step_no(args, global_step) - - os.makedirs(args.output_dir, exist_ok=True) - if save_stable_diffusion_format: - ext = ".safetensors" if use_safetensors else ".ckpt" - - if on_epoch_end: - ckpt_name = get_epoch_ckpt_name(args, ext, epoch_no) - else: - ckpt_name = get_step_ckpt_name(args, ext, global_step) - - ckpt_file = os.path.join(args.output_dir, ckpt_name) - logger.info("") - logger.info(f"saving checkpoint: {ckpt_file}") - sd_saver(ckpt_file, epoch_no, global_step) - - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name) - - # remove older checkpoints - if remove_no is not None: - if on_epoch_end: - remove_ckpt_name = get_epoch_ckpt_name(args, ext, remove_no) - else: - remove_ckpt_name = get_step_ckpt_name(args, ext, remove_no) - - remove_ckpt_file = os.path.join(args.output_dir, remove_ckpt_name) - if os.path.exists(remove_ckpt_file): - logger.info(f"removing old checkpoint: {remove_ckpt_file}") - os.remove(remove_ckpt_file) - - else: - if on_epoch_end: - out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no)) - else: - out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, global_step)) - - logger.info("") - logger.info(f"saving model: {out_dir}") - diffusers_saver(out_dir) - - if 
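The removed `get_remove_step_no` decides which older step checkpoint to delete: step back `save_last_n_steps`, then snap down to a multiple of `save_every_n_steps`. A standalone version, using the example numbers from the comment in the removed code:

```python
def get_remove_step_no(step_no: int, save_every_n_steps: int, save_last_n_steps: int):
    remove_step_no = step_no - save_last_n_steps - 1
    remove_step_no -= remove_step_no % save_every_n_steps  # snap to a step that was actually saved
    return remove_step_no if remove_step_no >= 0 else None


# save_every_n_steps=10, save_last_n_steps=30: at step 50, the step-10 checkpoint is removed
print(get_remove_step_no(50, save_every_n_steps=10, save_last_n_steps=30))  # 10
```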
args.huggingface_repo_id is not None: - huggingface_util.upload(args, out_dir, "/" + model_name) - - # remove older checkpoints - if remove_no is not None: - if on_epoch_end: - remove_out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, remove_no)) - else: - remove_out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, remove_no)) - - if os.path.exists(remove_out_dir): - logger.info(f"removing old model: {remove_out_dir}") - shutil.rmtree(remove_out_dir) - - if args.save_state: - if on_epoch_end: - save_and_remove_state_on_epoch_end(args, accelerator, epoch_no) - else: - save_and_remove_state_stepwise(args, accelerator, global_step) - - -def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, epoch_no): - model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME) - - logger.info("") - logger.info(f"saving state at epoch {epoch_no}") - os.makedirs(args.output_dir, exist_ok=True) - - state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)) - accelerator.save_state(state_dir) - if args.save_state_to_huggingface: - logger.info("uploading state to huggingface.") - huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no)) - - last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs - if last_n_epochs is not None: - remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs - state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) - if os.path.exists(state_dir_old): - logger.info(f"removing old state: {state_dir_old}") - shutil.rmtree(state_dir_old) - - -def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_no): - model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME) - - logger.info("") - logger.info(f"saving state at step {step_no}") - os.makedirs(args.output_dir, exist_ok=True) - - state_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, step_no)) - accelerator.save_state(state_dir) - if args.save_state_to_huggingface: - logger.info("uploading state to huggingface.") - huggingface_util.upload(args, state_dir, "/" + STEP_STATE_NAME.format(model_name, step_no)) - - last_n_steps = args.save_last_n_steps_state if args.save_last_n_steps_state else args.save_last_n_steps - if last_n_steps is not None: - # last_n_steps前のstep_noから、save_every_n_stepsの倍数のstep_noを計算して削除する - remove_step_no = step_no - last_n_steps - 1 - remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps) - - if remove_step_no > 0: - state_dir_old = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, remove_step_no)) - if os.path.exists(state_dir_old): - logger.info(f"removing old state: {state_dir_old}") - shutil.rmtree(state_dir_old) - - -def save_state_on_train_end(args: argparse.Namespace, accelerator): - model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME) - - logger.info("") - logger.info("saving last state.") - os.makedirs(args.output_dir, exist_ok=True) - - state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)) - accelerator.save_state(state_dir) - - if args.save_state_to_huggingface: - logger.info("uploading last state to huggingface.") - huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name)) - - -def save_sd_model_on_train_end( - args: argparse.Namespace, - src_path: str, - save_stable_diffusion_format: bool, 
- use_safetensors: bool, - save_dtype: torch.dtype, - epoch: int, - global_step: int, - text_encoder, - unet, - vae, -): - def sd_saver(ckpt_file, epoch_no, global_step): - sai_metadata = get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True) - model_util.save_stable_diffusion_checkpoint( - args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, sai_metadata, save_dtype, vae - ) - - def diffusers_saver(out_dir): - model_util.save_diffusers_checkpoint( - args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors - ) - - save_sd_model_on_train_end_common( - args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver - ) - - -def save_sd_model_on_train_end_common( - args: argparse.Namespace, - save_stable_diffusion_format: bool, - use_safetensors: bool, - epoch: int, - global_step: int, - sd_saver, - diffusers_saver, -): - model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME) - - if save_stable_diffusion_format: - os.makedirs(args.output_dir, exist_ok=True) - - ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt") - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - logger.info(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") - sd_saver(ckpt_file, epoch, global_step) - - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True) - else: - out_dir = os.path.join(args.output_dir, model_name) - os.makedirs(out_dir, exist_ok=True) - - logger.info(f"save trained model as Diffusers to {out_dir}") - diffusers_saver(out_dir) - - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) - - -def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - if args.noise_offset: - noise = custom_train_functions.apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) - if args.multires_noise_iterations: - noise = custom_train_functions.pyramid_noise_like( - noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount - ) - - # Sample a random timestep for each image - b_size = latents.shape[0] - min_timestep = 0 if args.min_timestep is None else args.min_timestep - max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep - - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=latents.device) - timesteps = timesteps.long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - if args.ip_noise_gamma: - noisy_latents = noise_scheduler.add_noise(latents, noise + args.ip_noise_gamma * torch.randn_like(latents), timesteps) - else: - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - return noise, noisy_latents, timesteps - - -def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True): - names = [] - if including_unet: - names.append("unet") - names.append("text_encoder1") - names.append("text_encoder2") - - append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names) - - -def append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names): - lrs = lr_scheduler.get_last_lr() - - for lr_index in range(len(lrs)): - name = names[lr_index] - 
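The removed `get_noise_noisy_latents_and_timesteps` samples a timestep per latent and, when `--ip_noise_gamma` is set, perturbs the target noise before the scheduler's forward-diffusion `add_noise` call. A condensed sketch with dummy latents; the noise-offset and multires-noise branches are left out and the values are illustrative only:

```python
import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
latents = torch.randn(4, 4, 64, 64)                    # dummy VAE latents, b,c,h,w
noise = torch.randn_like(latents)
ip_noise_gamma = 0.1                                   # example value; 0/None disables input perturbation

timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],)).long()
noisy_latents = noise_scheduler.add_noise(latents, noise + ip_noise_gamma * torch.randn_like(latents), timesteps)
print(noisy_latents.shape)                             # torch.Size([4, 4, 64, 64])
```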
logs["lr/" + name] = float(lrs[lr_index]) - - if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower(): - logs["lr/d*lr/" + name] = ( - lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"] - ) - - -# scheduler: -SCHEDULER_LINEAR_START = 0.00085 -SCHEDULER_LINEAR_END = 0.0120 -SCHEDULER_TIMESTEPS = 1000 -SCHEDLER_SCHEDULE = "scaled_linear" - - -def get_my_scheduler( - *, - sample_sampler: str, - v_parameterization: bool, -): - sched_init_args = {} - if sample_sampler == "ddim": - scheduler_cls = DDIMScheduler - elif sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある - scheduler_cls = DDPMScheduler - elif sample_sampler == "pndm": - scheduler_cls = PNDMScheduler - elif sample_sampler == "lms" or sample_sampler == "k_lms": - scheduler_cls = LMSDiscreteScheduler - elif sample_sampler == "euler" or sample_sampler == "k_euler": - scheduler_cls = EulerDiscreteScheduler - elif sample_sampler == "euler_a" or sample_sampler == "k_euler_a": - scheduler_cls = EulerAncestralDiscreteScheduler - elif sample_sampler == "dpmsolver" or sample_sampler == "dpmsolver++": - scheduler_cls = DPMSolverMultistepScheduler - sched_init_args["algorithm_type"] = sample_sampler - elif sample_sampler == "dpmsingle": - scheduler_cls = DPMSolverSinglestepScheduler - elif sample_sampler == "heun": - scheduler_cls = HeunDiscreteScheduler - elif sample_sampler == "dpm_2" or sample_sampler == "k_dpm_2": - scheduler_cls = KDPM2DiscreteScheduler - elif sample_sampler == "dpm_2_a" or sample_sampler == "k_dpm_2_a": - scheduler_cls = KDPM2AncestralDiscreteScheduler - else: - scheduler_cls = DDIMScheduler - - if v_parameterization: - sched_init_args["prediction_type"] = "v_prediction" - - scheduler = scheduler_cls( - num_train_timesteps=SCHEDULER_TIMESTEPS, - beta_start=SCHEDULER_LINEAR_START, - beta_end=SCHEDULER_LINEAR_END, - beta_schedule=SCHEDLER_SCHEDULE, - **sched_init_args, - ) - - # clip_sample=Trueにする - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - # logger.info("set clip_sample to True") - scheduler.config.clip_sample = True - - return scheduler - - -def sample_images(*args, **kwargs): - return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs) - - -def line_to_prompt_dict(line: str) -> dict: - # subset of gen_img_diffusers - prompt_args = line.split(" --") - prompt_dict = {} - prompt_dict["prompt"] = prompt_args[0] - - for parg in prompt_args: - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - prompt_dict["width"] = int(m.group(1)) - continue - - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - prompt_dict["height"] = int(m.group(1)) - continue - - m = re.match(r"d (\d+)", parg, re.IGNORECASE) - if m: - prompt_dict["seed"] = int(m.group(1)) - continue - - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - prompt_dict["sample_steps"] = max(1, min(1000, int(m.group(1)))) - continue - - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - prompt_dict["scale"] = float(m.group(1)) - continue - - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - prompt_dict["negative_prompt"] = m.group(1) - continue - - m = re.match(r"ss (.+)", parg, re.IGNORECASE) - if m: - prompt_dict["sample_sampler"] = m.group(1) - continue - - m = re.match(r"cn (.+)", parg, re.IGNORECASE) - if m: - prompt_dict["controlnet_image"] = m.group(1) - continue - - except ValueError as ex: - 
logger.error(f"Exception in parsing / 解析エラー: {parg}") - logger.error(ex) - - return prompt_dict - - -def sample_images_common( - pipe_class, - accelerator: Accelerator, - args: argparse.Namespace, - epoch, - steps, - device, - vae, - tokenizer, - text_encoder, - unet, - prompt_replacement=None, - controlnet=None, -): - """ - StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した - """ - - if steps == 0: - if not args.sample_at_first: - return - else: - if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: - return - if args.sample_every_n_epochs is not None: - # sample_every_n_steps は無視する - if epoch is None or epoch % args.sample_every_n_epochs != 0: - return - else: - if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch - return - - logger.info("") - logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") - if not os.path.isfile(args.sample_prompts): - logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") - return - - distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here - - org_vae_device = vae.device # CPUにいるはず - vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device - - # unwrap unet and text_encoder(s) - unet = accelerator.unwrap_model(unet) - if isinstance(text_encoder, (list, tuple)): - text_encoder = [accelerator.unwrap_model(te) for te in text_encoder] - else: - text_encoder = accelerator.unwrap_model(text_encoder) - - # read prompts - if args.sample_prompts.endswith(".txt"): - with open(args.sample_prompts, "r", encoding="utf-8") as f: - lines = f.readlines() - prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] - elif args.sample_prompts.endswith(".toml"): - with open(args.sample_prompts, "r", encoding="utf-8") as f: - data = toml.load(f) - prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]] - elif args.sample_prompts.endswith(".json"): - with open(args.sample_prompts, "r", encoding="utf-8") as f: - prompts = json.load(f) - - # schedulers: dict = {} cannot find where this is used - default_scheduler = get_my_scheduler( - sample_sampler=args.sample_sampler, - v_parameterization=args.v_parameterization, - ) - - pipeline = pipe_class( - text_encoder=text_encoder, - vae=vae, - unet=unet, - tokenizer=tokenizer, - scheduler=default_scheduler, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False, - clip_skip=args.clip_skip, - ) - pipeline.to(distributed_state.device) - save_dir = args.output_dir + "/sample" - os.makedirs(save_dir, exist_ok=True) - - # preprocess prompts - for i in range(len(prompts)): - prompt_dict = prompts[i] - if isinstance(prompt_dict, str): - prompt_dict = line_to_prompt_dict(prompt_dict) - prompts[i] = prompt_dict - assert isinstance(prompt_dict, dict) - - # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. - prompt_dict["enum"] = i - prompt_dict.pop("subset", None) - - # save random state to restore later - rng_state = torch.get_rng_state() - cuda_rng_state = None - try: - cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None - except Exception: - pass - - if distributed_state.num_processes <= 1: - # If only one device is available, just use the original prompt list. 
We don't need to care about the distribution of prompts. - with torch.no_grad(): - for prompt_dict in prompts: - sample_image_inference( - accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet - ) - else: - # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) - # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. - per_process_prompts = [] # list of lists - for i in range(distributed_state.num_processes): - per_process_prompts.append(prompts[i :: distributed_state.num_processes]) - - with torch.no_grad(): - with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: - for prompt_dict in prompt_dict_lists[0]: - sample_image_inference( - accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet - ) - - # clear pipeline and cache to reduce vram usage - del pipeline - - # I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here. - # with torch.cuda.device(torch.cuda.current_device()): - # torch.cuda.empty_cache() - clean_memory_on_device(accelerator.device) - - torch.set_rng_state(rng_state) - if cuda_rng_state is not None: - torch.cuda.set_rng_state(cuda_rng_state) - vae.to(org_vae_device) - - -def sample_image_inference( - accelerator: Accelerator, - args: argparse.Namespace, - pipeline, - save_dir, - prompt_dict, - epoch, - steps, - prompt_replacement, - controlnet=None, -): - assert isinstance(prompt_dict, dict) - negative_prompt = prompt_dict.get("negative_prompt") - sample_steps = prompt_dict.get("sample_steps", 30) - width = prompt_dict.get("width", 512) - height = prompt_dict.get("height", 512) - scale = prompt_dict.get("scale", 7.5) - seed = prompt_dict.get("seed") - controlnet_image = prompt_dict.get("controlnet_image") - prompt: str = prompt_dict.get("prompt", "") - sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) - - if prompt_replacement is not None: - prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) - if negative_prompt is not None: - negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) - - if seed is not None: - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - else: - # True random sample image generation - torch.seed() - torch.cuda.seed() - - scheduler = get_my_scheduler( - sample_sampler=sampler_name, - v_parameterization=args.v_parameterization, - ) - pipeline.scheduler = scheduler - - if controlnet_image is not None: - controlnet_image = Image.open(controlnet_image).convert("RGB") - controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) - - height = max(64, height - height % 8) # round to divisible by 8 - width = max(64, width - width % 8) # round to divisible by 8 - logger.info(f"prompt: {prompt}") - logger.info(f"negative_prompt: {negative_prompt}") - logger.info(f"height: {height}") - logger.info(f"width: {width}") - logger.info(f"sample_steps: {sample_steps}") - logger.info(f"scale: {scale}") - logger.info(f"sample_sampler: {sampler_name}") - if seed is not None: - logger.info(f"seed: {seed}") - with accelerator.autocast(): - latents = pipeline( - prompt=prompt, - height=height, - width=width, - 
num_inference_steps=sample_steps, - guidance_scale=scale, - negative_prompt=negative_prompt, - controlnet=controlnet, - controlnet_image=controlnet_image, - ) - - with torch.cuda.device(torch.cuda.current_device()): - torch.cuda.empty_cache() - - image = pipeline.latents_to_image(latents)[0] - - # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list - # but adding 'enum' to the filename should be enough - - ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" - seed_suffix = "" if seed is None else f"_{seed}" - i: int = prompt_dict["enum"] - img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" - image.save(os.path.join(save_dir, img_filename)) - - # wandb有効時のみログを送信 - try: - wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass - - -# endregion - - -# region 前処理用 - - -class ImageLoadingDataset(torch.utils.data.Dataset): - def __init__(self, image_paths): - self.images = image_paths - - def __len__(self): - return len(self.images) - - def __getitem__(self, idx): - img_path = self.images[idx] - - try: - image = Image.open(img_path).convert("RGB") - # convert to tensor temporarily so dataloader will accept it - tensor_pil = transforms.functional.pil_to_tensor(image) - except Exception as e: - logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") - return None - - return (tensor_pil, img_path) - - -# endregion - - -# collate_fn用 epoch,stepはmultiprocessing.Value -class collator_class: - def __init__(self, epoch, step, dataset): - self.current_epoch = epoch - self.current_step = step - self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing - - def __call__(self, examples): - worker_info = torch.utils.data.get_worker_info() - # worker_info is None in the main process - if worker_info is not None: - dataset = worker_info.dataset - else: - dataset = self.dataset - - # set epoch and step - dataset.set_current_epoch(self.current_epoch.value) - dataset.set_current_step(self.current_step.value) - return examples[0] - - -class LossRecorder: - def __init__(self): - self.loss_list: List[float] = [] - self.loss_total: float = 0.0 - - def add(self, *, epoch: int, step: int, loss: float) -> None: - if epoch == 0: - self.loss_list.append(loss) - else: - self.loss_total -= self.loss_list[step] - self.loss_list[step] = loss - self.loss_total += loss - - @property - def moving_average(self) -> float: - return self.loss_total / len(self.loss_list) diff --git a/library/utils.py b/library/utils.py deleted file mode 100644 index 3037c055d..000000000 --- a/library/utils.py +++ /dev/null @@ -1,266 +0,0 @@ -import logging -import sys -import threading -import torch -from torchvision import transforms -from typing import * -from diffusers import EulerAncestralDiscreteScheduler -import diffusers.schedulers.scheduling_euler_ancestral_discrete -from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput - - -def fire_in_thread(f, *args, **kwargs): - threading.Thread(target=f, args=args, kwargs=kwargs).start() - - -def add_logging_arguments(parser): - parser.add_argument( - 
"--console_log_level", - type=str, - default=None, - choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], - help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO", - ) - parser.add_argument( - "--console_log_file", - type=str, - default=None, - help="Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する", - ) - parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力") - - -def setup_logging(args=None, log_level=None, reset=False): - if logging.root.handlers: - if reset: - # remove all handlers - for handler in logging.root.handlers[:]: - logging.root.removeHandler(handler) - else: - return - - # log_level can be set by the caller or by the args, the caller has priority. If not set, use INFO - if log_level is None and args is not None: - log_level = args.console_log_level - if log_level is None: - log_level = "INFO" - log_level = getattr(logging, log_level) - - msg_init = None - if args is not None and args.console_log_file: - handler = logging.FileHandler(args.console_log_file, mode="w") - else: - handler = None - if not args or not args.console_log_simple: - try: - from rich.logging import RichHandler - from rich.console import Console - from rich.logging import RichHandler - - handler = RichHandler(console=Console(stderr=True)) - except ImportError: - # print("rich is not installed, using basic logging") - msg_init = "rich is not installed, using basic logging" - - if handler is None: - handler = logging.StreamHandler(sys.stdout) # same as print - handler.propagate = False - - formatter = logging.Formatter( - fmt="%(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - handler.setFormatter(formatter) - logging.root.setLevel(log_level) - logging.root.addHandler(handler) - - if msg_init is not None: - logger = logging.getLogger(__name__) - logger.info(msg_init) - - - -# TODO make inf_utils.py - - -# region Gradual Latent hires fix - - -class GradualLatent: - def __init__( - self, - ratio, - start_timesteps, - every_n_steps, - ratio_step, - s_noise=1.0, - gaussian_blur_ksize=None, - gaussian_blur_sigma=0.5, - gaussian_blur_strength=0.5, - unsharp_target_x=True, - ): - self.ratio = ratio - self.start_timesteps = start_timesteps - self.every_n_steps = every_n_steps - self.ratio_step = ratio_step - self.s_noise = s_noise - self.gaussian_blur_ksize = gaussian_blur_ksize - self.gaussian_blur_sigma = gaussian_blur_sigma - self.gaussian_blur_strength = gaussian_blur_strength - self.unsharp_target_x = unsharp_target_x - - def __str__(self) -> str: - return ( - f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, " - + f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, " - + f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, " - + f"unsharp_target_x={self.unsharp_target_x})" - ) - - def apply_unshark_mask(self, x: torch.Tensor): - if self.gaussian_blur_ksize is None: - return x - blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma) - # mask = torch.sigmoid((x - blurred) * self.gaussian_blur_strength) - mask = (x - blurred) * self.gaussian_blur_strength - sharpened = x + mask - return sharpened - - def interpolate(self, x: torch.Tensor, resized_size, unsharp=True): - org_dtype = x.dtype - if org_dtype == torch.bfloat16: - x = x.float() - - x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", 
align_corners=False).to(dtype=org_dtype) - - # apply unsharp mask / アンシャープマスクを適用する - if unsharp and self.gaussian_blur_ksize: - x = self.apply_unshark_mask(x) - - return x - - -class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.resized_size = None - self.gradual_latent = None - - def set_gradual_latent_params(self, size, gradual_latent: GradualLatent): - self.resized_size = size - self.gradual_latent = gradual_latent - - def step( - self, - model_output: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], - sample: torch.FloatTensor, - generator: Optional[torch.Generator] = None, - return_dict: bool = True, - ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: - """ - Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor`): - The direct output from learned diffusion model. - timestep (`float`): - The current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - A current instance of a sample created by the diffusion process. - generator (`torch.Generator`, *optional*): - A random number generator. - return_dict (`bool`): - Whether or not to return a - [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple. - - Returns: - [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: - If return_dict is `True`, - [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned, - otherwise a tuple is returned where the first element is the sample tensor. - - """ - - if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): - raise ValueError( - ( - "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" - " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" - " one of the `scheduler.timesteps` as a timestep." - ), - ) - - if not self.is_scale_input_called: - # logger.warning( - print( - "The `scale_model_input` function should be called before `step` to ensure correct denoising. " - "See `StableDiffusionPipeline` for a usage example." - ) - - if self.step_index is None: - self._init_step_index(timestep) - - sigma = self.sigmas[self.step_index] - - # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - if self.config.prediction_type == "epsilon": - pred_original_sample = sample - sigma * model_output - elif self.config.prediction_type == "v_prediction": - # * c_out + input * c_skip - pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) - elif self.config.prediction_type == "sample": - raise NotImplementedError("prediction_type not implemented yet: sample") - else: - raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`") - - sigma_from = self.sigmas[self.step_index] - sigma_to = self.sigmas[self.step_index + 1] - sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 - sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 - - # 2. 
Convert to an ODE derivative - derivative = (sample - pred_original_sample) / sigma - - dt = sigma_down - sigma - - device = model_output.device - if self.resized_size is None: - prev_sample = sample + derivative * dt - - noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( - model_output.shape, dtype=model_output.dtype, device=device, generator=generator - ) - s_noise = 1.0 - else: - print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape) - s_noise = self.gradual_latent.s_noise - - if self.gradual_latent.unsharp_target_x: - prev_sample = sample + derivative * dt - prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size) - else: - sample = self.gradual_latent.interpolate(sample, self.resized_size) - derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False) - prev_sample = sample + derivative * dt - - noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( - (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]), - dtype=model_output.dtype, - device=device, - generator=generator, - ) - - prev_sample = prev_sample + noise * sigma_up * s_noise - - # upon completion increase step index by one - self._step_index += 1 - - if not return_dict: - return (prev_sample,) - - return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) - - -# endregion diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py deleted file mode 100644 index 794659c94..000000000 --- a/networks/check_lora_weights.py +++ /dev/null @@ -1,48 +0,0 @@ -import argparse -import os -import torch -from safetensors.torch import load_file -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -def main(file): - logger.info(f"loading: {file}") - if os.path.splitext(file)[1] == ".safetensors": - sd = load_file(file) - else: - sd = torch.load(file, map_location="cpu") - - values = [] - - keys = list(sd.keys()) - for key in keys: - if "lora_up" in key or "lora_down" in key: - values.append((key, sd[key])) - print(f"number of LoRA modules: {len(values)}") - - if args.show_all_keys: - for key in [k for k in keys if k not in values]: - values.append((key, sd[key])) - print(f"number of all modules: {len(values)}") - - for key, value in values: - value = value.to(torch.float32) - print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル") - parser.add_argument("-s", "--show_all_keys", action="store_true", help="show all keys / 全てのキーを表示する") - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - - main(args.file) diff --git a/networks/control_net_lllite.py b/networks/control_net_lllite.py deleted file mode 100644 index c9377bee8..000000000 --- a/networks/control_net_lllite.py +++ /dev/null @@ -1,449 +0,0 @@ -import os -from typing import Optional, List, Type -import torch -from library import sdxl_original_unet -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -# input_blocksに適用するかどうか / if True, input_blocks are not applied -SKIP_INPUT_BLOCKS = False - -# 
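The ancestral update in the removed `EulerAncestralDiscreteSchedulerGL.step` splits the next sigma into a deterministic part (`sigma_down`, applied via the ODE derivative) and a stochastic part (`sigma_up`, applied as fresh noise), so that `sigma_down**2 + sigma_up**2 == sigma_to**2`. A small numeric check with arbitrary example sigmas:

```python
import torch

sigma_from, sigma_to = torch.tensor(14.6), torch.tensor(9.7)    # arbitrary example values
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
print(float(sigma_up), float(sigma_down))
assert torch.isclose(sigma_down**2 + sigma_up**2, sigma_to**2)  # the split preserves the total variance
```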
output_blocksに適用するかどうか / if True, output_blocks are not applied -SKIP_OUTPUT_BLOCKS = True - -# conv2dに適用するかどうか / if True, conv2d are not applied -SKIP_CONV2D = False - -# transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない -# if True, only transformer_blocks are applied, and ResBlocks are not applied -TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks - -# Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc. -ATTN1_2_ONLY = True - -# Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified -ATTN_QKV_ONLY = True - -# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2 -# ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY -ATTN1_ETC_ONLY = False # True - -# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用 -# max index of transformer_blocks. if None, apply to all transformer_blocks -TRANSFORMER_MAX_BLOCK_INDEX = None - - -class LLLiteModule(torch.nn.Module): - def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None, multiplier=1.0): - super().__init__() - - self.is_conv2d = org_module.__class__.__name__ == "Conv2d" - self.lllite_name = name - self.cond_emb_dim = cond_emb_dim - self.org_module = [org_module] - self.dropout = dropout - self.multiplier = multiplier - - if self.is_conv2d: - in_dim = org_module.in_channels - else: - in_dim = org_module.in_features - - # conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない - # conditioning1 embeds conditioning image. it is not called for each timestep - modules = [] - modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size - if depth == 1: - modules.append(torch.nn.ReLU(inplace=True)) - modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0)) - elif depth == 2: - modules.append(torch.nn.ReLU(inplace=True)) - modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0)) - elif depth == 3: - # kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4 - modules.append(torch.nn.ReLU(inplace=True)) - modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) - modules.append(torch.nn.ReLU(inplace=True)) - modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0)) - - self.conditioning1 = torch.nn.Sequential(*modules) - - # downで入力の次元数を削減する。LoRAにヒントを得ていることにする - # midでconditioning image embeddingと入力を結合する - # upで元の次元数に戻す - # これらはtimestepごとに呼ばれる - # reduce the number of input dimensions with down. 
inspired by LoRA - # combine conditioning image embedding and input with mid - # restore to the original dimension with up - # these are called for each timestep - - if self.is_conv2d: - self.down = torch.nn.Sequential( - torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0), - torch.nn.ReLU(inplace=True), - ) - self.mid = torch.nn.Sequential( - torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0), - torch.nn.ReLU(inplace=True), - ) - self.up = torch.nn.Sequential( - torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0), - ) - else: - # midの前にconditioningをreshapeすること / reshape conditioning before mid - self.down = torch.nn.Sequential( - torch.nn.Linear(in_dim, mlp_dim), - torch.nn.ReLU(inplace=True), - ) - self.mid = torch.nn.Sequential( - torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim), - torch.nn.ReLU(inplace=True), - ) - self.up = torch.nn.Sequential( - torch.nn.Linear(mlp_dim, in_dim), - ) - - # Zero-Convにする / set to Zero-Conv - torch.nn.init.zeros_(self.up[0].weight) # zero conv - - self.depth = depth # 1~3 - self.cond_emb = None - self.batch_cond_only = False # Trueなら推論時のcondにのみ適用する / if True, apply only to cond at inference - self.use_zeros_for_batch_uncond = False # Trueならuncondのconditioningを0にする / if True, set uncond conditioning to 0 - - # batch_cond_onlyとuse_zeros_for_batch_uncondはどちらも適用すると生成画像の色味がおかしくなるので実際には使えそうにない - # Controlの種類によっては使えるかも - # both batch_cond_only and use_zeros_for_batch_uncond make the color of the generated image strange, so it doesn't seem to be usable in practice - # it may be available depending on the type of Control - - def set_cond_image(self, cond_image): - r""" - 中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む - / call the model inside, so if necessary, surround it with torch.no_grad() - """ - if cond_image is None: - self.cond_emb = None - return - - # timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance - # logger.info(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}") - cx = self.conditioning1(cond_image) - if not self.is_conv2d: - # reshape / b,c,h,w -> b,h*w,c - n, c, h, w = cx.shape - cx = cx.view(n, c, h * w).permute(0, 2, 1) - self.cond_emb = cx - - def set_batch_cond_only(self, cond_only, zeros): - self.batch_cond_only = cond_only - self.use_zeros_for_batch_uncond = zeros - - def apply_to(self): - self.org_forward = self.org_module[0].forward - self.org_module[0].forward = self.forward - - def forward(self, x): - r""" - 学習用の便利forward。元のモジュールのforwardを呼び出す - / convenient forward for training. 
call the forward of the original module - """ - if self.multiplier == 0.0 or self.cond_emb is None: - return self.org_forward(x) - - cx = self.cond_emb - - if not self.batch_cond_only and x.shape[0] // 2 == cx.shape[0]: # inference only - cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1) - if self.use_zeros_for_batch_uncond: - cx[0::2] = 0.0 # uncond is zero - # logger.info(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}") - - # downで入力の次元数を削減し、conditioning image embeddingと結合する - # 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している - # down reduces the number of input dimensions and combines it with conditioning image embedding - # we expect that it will mix well by combining in the channel direction instead of adding - - cx = torch.cat([cx, self.down(x if not self.batch_cond_only else x[1::2])], dim=1 if self.is_conv2d else 2) - cx = self.mid(cx) - - if self.dropout is not None and self.training: - cx = torch.nn.functional.dropout(cx, p=self.dropout) - - cx = self.up(cx) * self.multiplier - - # residual (x) を加算して元のforwardを呼び出す / add residual (x) and call the original forward - if self.batch_cond_only: - zx = torch.zeros_like(x) - zx[1::2] += cx - cx = zx - - x = self.org_forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here - return x - - -class ControlNetLLLite(torch.nn.Module): - UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] - UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] - - def __init__( - self, - unet: sdxl_original_unet.SdxlUNet2DConditionModel, - cond_emb_dim: int = 16, - mlp_dim: int = 16, - dropout: Optional[float] = None, - varbose: Optional[bool] = False, - multiplier: Optional[float] = 1.0, - ) -> None: - super().__init__() - # self.unets = [unet] - - def create_modules( - root_module: torch.nn.Module, - target_replace_modules: List[torch.nn.Module], - module_class: Type[object], - ) -> List[torch.nn.Module]: - prefix = "lllite_unet" - - modules = [] - for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: - for child_name, child_module in module.named_modules(): - is_linear = child_module.__class__.__name__ == "Linear" - is_conv2d = child_module.__class__.__name__ == "Conv2d" - - if is_linear or (is_conv2d and not SKIP_CONV2D): - # block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う - # block index to depth: depth is using to calculate conditioning size and channels - block_name, index1, index2 = (name + "." + child_name).split(".")[:3] - index1 = int(index1) - if block_name == "input_blocks": - if SKIP_INPUT_BLOCKS: - continue - depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3) - elif block_name == "middle_block": - depth = 3 - elif block_name == "output_blocks": - if SKIP_OUTPUT_BLOCKS: - continue - depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1) - if int(index2) >= 2: - depth -= 1 - else: - raise NotImplementedError() - - lllite_name = prefix + "." + name + "." 
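As the comments above describe, each LLLite module reduces the wrapped layer's input with `down`, concatenates the conditioning-image embedding, mixes the two with `mid`, and projects back with a zero-initialized `up` whose output is added to the original input before the original forward is called. The sketch below uses invented dimensions and covers only the Linear case; the real module also handles Conv2d and caches the conditioning embedding per image:

```python
import torch

in_dim, mlp_dim, cond_emb_dim, tokens, batch = 320, 16, 32, 77, 2   # invented sizes

down = torch.nn.Sequential(torch.nn.Linear(in_dim, mlp_dim), torch.nn.ReLU(inplace=True))
mid = torch.nn.Sequential(torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim), torch.nn.ReLU(inplace=True))
up = torch.nn.Linear(mlp_dim, in_dim)
torch.nn.init.zeros_(up.weight)  # the removed module zero-inits the weight ("Zero-Conv")
torch.nn.init.zeros_(up.bias)    # also zeroed here so this sketch is an exact no-op at init

x = torch.randn(batch, tokens, in_dim)                # input to the wrapped Linear layer
cond_emb = torch.randn(batch, tokens, cond_emb_dim)   # precomputed conditioning-image embedding

cx = torch.cat([cond_emb, down(x)], dim=2)            # join along the channel dimension
residual = up(mid(cx))                                # org_forward(x + residual) would follow
print(residual.abs().max().item())                    # 0.0 at initialization
```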
+ child_name - lllite_name = lllite_name.replace(".", "_") - - if TRANSFORMER_MAX_BLOCK_INDEX is not None: - p = lllite_name.find("transformer_blocks") - if p >= 0: - tf_index = int(lllite_name[p:].split("_")[2]) - if tf_index > TRANSFORMER_MAX_BLOCK_INDEX: - continue - - # time embは適用外とする - # attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない - # time emb is not applied - # attn2 conditioning (input from CLIP) cannot be applied because the shape is different - if "emb_layers" in lllite_name or ( - "attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name) - ): - continue - - if ATTN1_2_ONLY: - if not ("attn1" in lllite_name or "attn2" in lllite_name): - continue - if ATTN_QKV_ONLY: - if "to_out" in lllite_name: - continue - - if ATTN1_ETC_ONLY: - if "proj_out" in lllite_name: - pass - elif "attn1" in lllite_name and ( - "to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name - ): - pass - elif "ff_net_2" in lllite_name: - pass - else: - continue - - module = module_class( - depth, - cond_emb_dim, - lllite_name, - child_module, - mlp_dim, - dropout=dropout, - multiplier=multiplier, - ) - modules.append(module) - return modules - - target_modules = ControlNetLLLite.UNET_TARGET_REPLACE_MODULE - if not TRANSFORMER_ONLY: - target_modules = target_modules + ControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 - - # create module instances - self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule) - logger.info(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.") - - def forward(self, x): - return x # dummy - - def set_cond_image(self, cond_image): - r""" - 中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む - / call the model inside, so if necessary, surround it with torch.no_grad() - """ - for module in self.unet_modules: - module.set_cond_image(cond_image) - - def set_batch_cond_only(self, cond_only, zeros): - for module in self.unet_modules: - module.set_batch_cond_only(cond_only, zeros) - - def set_multiplier(self, multiplier): - for module in self.unet_modules: - module.multiplier = multiplier - - def load_weights(self, file): - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import load_file - - weights_sd = load_file(file) - else: - weights_sd = torch.load(file, map_location="cpu") - - info = self.load_state_dict(weights_sd, False) - return info - - def apply_to(self): - logger.info("applying LLLite for U-Net...") - for module in self.unet_modules: - module.apply_to() - self.add_module(module.lllite_name, module) - - # マージできるかどうかを返す - def is_mergeable(self): - return False - - def merge_to(self, text_encoder, unet, weights_sd, dtype, device): - raise NotImplementedError() - - def enable_gradient_checkpointing(self): - # not supported - pass - - def prepare_optimizer_params(self): - self.requires_grad_(True) - return self.parameters() - - def prepare_grad_etc(self): - self.requires_grad_(True) - - def on_epoch_start(self): - self.train() - - def get_trainable_params(self): - return self.parameters() - - def save_weights(self, file, dtype, metadata): - if metadata is not None and len(metadata) == 0: - metadata = None - - state_dict = self.state_dict() - - if dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(dtype) - state_dict[key] = v - - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import save_file - - save_file(state_dict, file, metadata) - else: - torch.save(state_dict, file) - - -if 
__name__ == "__main__": - # デバッグ用 / for debug - - # sdxl_original_unet.USE_REENTRANT = False - - # test shape etc - logger.info("create unet") - unet = sdxl_original_unet.SdxlUNet2DConditionModel() - unet.to("cuda").to(torch.float16) - - logger.info("create ControlNet-LLLite") - control_net = ControlNetLLLite(unet, 32, 64) - control_net.apply_to() - control_net.to("cuda") - - logger.info(control_net) - - # logger.info number of parameters - logger.info(f"number of parameters {sum(p.numel() for p in control_net.parameters() if p.requires_grad)}") - - input() - - unet.set_use_memory_efficient_attention(True, False) - unet.set_gradient_checkpointing(True) - unet.train() # for gradient checkpointing - - control_net.train() - - # # visualize - # import torchviz - # logger.info("run visualize") - # controlnet.set_control(conditioning_image) - # output = unet(x, t, ctx, y) - # logger.info("make_dot") - # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters())) - # logger.info("render") - # image.format = "svg" # "png" - # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time - # input() - - import bitsandbytes - - optimizer = bitsandbytes.adam.Adam8bit(control_net.prepare_optimizer_params(), 1e-3) - - scaler = torch.cuda.amp.GradScaler(enabled=True) - - logger.info("start training") - steps = 10 - - sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0] - for step in range(steps): - logger.info(f"step {step}") - - batch_size = 1 - conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0 - x = torch.randn(batch_size, 4, 128, 128).cuda() - t = torch.randint(low=0, high=10, size=(batch_size,)).cuda() - ctx = torch.randn(batch_size, 77, 2048).cuda() - y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda() - - with torch.cuda.amp.autocast(enabled=True): - control_net.set_cond_image(conditioning_image) - - output = unet(x, t, ctx, y) - target = torch.randn_like(output) - loss = torch.nn.functional.mse_loss(output, target) - - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad(set_to_none=True) - logger.info(f"{sample_param}") - - # from safetensors.torch import save_file - - # save_file(control_net.state_dict(), "logs/control_net.safetensors") diff --git a/networks/control_net_lllite_for_train.py b/networks/control_net_lllite_for_train.py deleted file mode 100644 index 65b3520cf..000000000 --- a/networks/control_net_lllite_for_train.py +++ /dev/null @@ -1,505 +0,0 @@ -# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用実装 -# ControlNet-LLLite implementation for verification with cond_image passed in U-Net's forward - -import os -import re -from typing import Optional, List, Type -import torch -from library import sdxl_original_unet -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -# input_blocksに適用するかどうか / if True, input_blocks are not applied -SKIP_INPUT_BLOCKS = False - -# output_blocksに適用するかどうか / if True, output_blocks are not applied -SKIP_OUTPUT_BLOCKS = True - -# conv2dに適用するかどうか / if True, conv2d are not applied -SKIP_CONV2D = False - -# transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない -# if True, only transformer_blocks are applied, and ResBlocks are not applied -TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks - -# Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc. 
-ATTN1_2_ONLY = True - -# Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified -ATTN_QKV_ONLY = True - -# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2 -# ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY -ATTN1_ETC_ONLY = False # True - -# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用 -# max index of transformer_blocks. if None, apply to all transformer_blocks -TRANSFORMER_MAX_BLOCK_INDEX = None - -ORIGINAL_LINEAR = torch.nn.Linear -ORIGINAL_CONV2D = torch.nn.Conv2d - - -def add_lllite_modules(module: torch.nn.Module, in_dim: int, depth, cond_emb_dim, mlp_dim) -> None: - # conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない - # conditioning1 embeds conditioning image. it is not called for each timestep - modules = [] - modules.append(ORIGINAL_CONV2D(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size - if depth == 1: - modules.append(torch.nn.ReLU(inplace=True)) - modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0)) - elif depth == 2: - modules.append(torch.nn.ReLU(inplace=True)) - modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0)) - elif depth == 3: - # kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4 - modules.append(torch.nn.ReLU(inplace=True)) - modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) - modules.append(torch.nn.ReLU(inplace=True)) - modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0)) - - module.lllite_conditioning1 = torch.nn.Sequential(*modules) - - # downで入力の次元数を削減する。LoRAにヒントを得ていることにする - # midでconditioning image embeddingと入力を結合する - # upで元の次元数に戻す - # これらはtimestepごとに呼ばれる - # reduce the number of input dimensions with down. 
inspired by LoRA - # combine conditioning image embedding and input with mid - # restore to the original dimension with up - # these are called for each timestep - - module.lllite_down = torch.nn.Sequential( - ORIGINAL_LINEAR(in_dim, mlp_dim), - torch.nn.ReLU(inplace=True), - ) - module.lllite_mid = torch.nn.Sequential( - ORIGINAL_LINEAR(mlp_dim + cond_emb_dim, mlp_dim), - torch.nn.ReLU(inplace=True), - ) - module.lllite_up = torch.nn.Sequential( - ORIGINAL_LINEAR(mlp_dim, in_dim), - ) - - # Zero-Convにする / set to Zero-Conv - torch.nn.init.zeros_(module.lllite_up[0].weight) # zero conv - - -class LLLiteLinear(ORIGINAL_LINEAR): - def __init__(self, in_features: int, out_features: int, **kwargs): - super().__init__(in_features, out_features, **kwargs) - self.enabled = False - - def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0): - self.enabled = True - self.lllite_name = name - self.cond_emb_dim = cond_emb_dim - self.dropout = dropout - self.multiplier = multiplier # ignored - - in_dim = self.in_features - add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim) - - self.cond_image = None - self.cond_emb = None - - def set_cond_image(self, cond_image): - self.cond_image = cond_image - self.cond_emb = None - - def forward(self, x): - if not self.enabled: - return super().forward(x) - - if self.cond_emb is None: - self.cond_emb = self.lllite_conditioning1(self.cond_image) - cx = self.cond_emb - - # reshape / b,c,h,w -> b,h*w,c - n, c, h, w = cx.shape - cx = cx.view(n, c, h * w).permute(0, 2, 1) - - cx = torch.cat([cx, self.lllite_down(x)], dim=2) - cx = self.lllite_mid(cx) - - if self.dropout is not None and self.training: - cx = torch.nn.functional.dropout(cx, p=self.dropout) - - cx = self.lllite_up(cx) * self.multiplier - - x = super().forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here - return x - - -class LLLiteConv2d(ORIGINAL_CONV2D): - def __init__(self, in_channels: int, out_channels: int, kernel_size, **kwargs): - super().__init__(in_channels, out_channels, kernel_size, **kwargs) - self.enabled = False - - def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0): - self.enabled = True - self.lllite_name = name - self.cond_emb_dim = cond_emb_dim - self.dropout = dropout - self.multiplier = multiplier # ignored - - in_dim = self.in_channels - add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim) - - self.cond_image = None - self.cond_emb = None - - def set_cond_image(self, cond_image): - self.cond_image = cond_image - self.cond_emb = None - - def forward(self, x): # , cond_image=None): - if not self.enabled: - return super().forward(x) - - if self.cond_emb is None: - self.cond_emb = self.lllite_conditioning1(self.cond_image) - cx = self.cond_emb - - cx = torch.cat([cx, self.down(x)], dim=1) - cx = self.mid(cx) - - if self.dropout is not None and self.training: - cx = torch.nn.functional.dropout(cx, p=self.dropout) - - cx = self.up(cx) * self.multiplier - - x = super().forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here - return x - - -class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DConditionModel): - UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] - UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] - LLLITE_PREFIX = "lllite_unet" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def apply_lllite( - self, - cond_emb_dim: int = 16, - mlp_dim: int = 16, - dropout: Optional[float] = None, - varbose: 
Optional[bool] = False, - multiplier: Optional[float] = 1.0, - ) -> None: - def apply_to_modules( - root_module: torch.nn.Module, - target_replace_modules: List[torch.nn.Module], - ) -> List[torch.nn.Module]: - prefix = "lllite_unet" - - modules = [] - for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: - for child_name, child_module in module.named_modules(): - is_linear = child_module.__class__.__name__ == "LLLiteLinear" - is_conv2d = child_module.__class__.__name__ == "LLLiteConv2d" - - if is_linear or (is_conv2d and not SKIP_CONV2D): - # block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う - # block index to depth: depth is using to calculate conditioning size and channels - block_name, index1, index2 = (name + "." + child_name).split(".")[:3] - index1 = int(index1) - if block_name == "input_blocks": - if SKIP_INPUT_BLOCKS: - continue - depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3) - elif block_name == "middle_block": - depth = 3 - elif block_name == "output_blocks": - if SKIP_OUTPUT_BLOCKS: - continue - depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1) - if int(index2) >= 2: - depth -= 1 - else: - raise NotImplementedError() - - lllite_name = prefix + "." + name + "." + child_name - lllite_name = lllite_name.replace(".", "_") - - if TRANSFORMER_MAX_BLOCK_INDEX is not None: - p = lllite_name.find("transformer_blocks") - if p >= 0: - tf_index = int(lllite_name[p:].split("_")[2]) - if tf_index > TRANSFORMER_MAX_BLOCK_INDEX: - continue - - # time embは適用外とする - # attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない - # time emb is not applied - # attn2 conditioning (input from CLIP) cannot be applied because the shape is different - if "emb_layers" in lllite_name or ( - "attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name) - ): - continue - - if ATTN1_2_ONLY: - if not ("attn1" in lllite_name or "attn2" in lllite_name): - continue - if ATTN_QKV_ONLY: - if "to_out" in lllite_name: - continue - - if ATTN1_ETC_ONLY: - if "proj_out" in lllite_name: - pass - elif "attn1" in lllite_name and ( - "to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name - ): - pass - elif "ff_net_2" in lllite_name: - pass - else: - continue - - child_module.set_lllite(depth, cond_emb_dim, lllite_name, mlp_dim, dropout, multiplier) - modules.append(child_module) - - return modules - - target_modules = SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE - if not TRANSFORMER_ONLY: - target_modules = target_modules + SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 - - # create module instances - self.lllite_modules = apply_to_modules(self, target_modules) - logger.info(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.") - - # def prepare_optimizer_params(self): - def prepare_params(self): - train_params = [] - non_train_params = [] - for name, p in self.named_parameters(): - if "lllite" in name: - train_params.append(p) - else: - non_train_params.append(p) - logger.info(f"count of trainable parameters: {len(train_params)}") - logger.info(f"count of non-trainable parameters: {len(non_train_params)}") - - for p in non_train_params: - p.requires_grad_(False) - - # without this, an error occurs in the optimizer - # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn - non_train_params[0].requires_grad_(True) - - for p in train_params: - p.requires_grad_(True) - - return train_params - - # def 
prepare_grad_etc(self): - # self.requires_grad_(True) - - # def on_epoch_start(self): - # self.train() - - def get_trainable_params(self): - return [p[1] for p in self.named_parameters() if "lllite" in p[0]] - - def save_lllite_weights(self, file, dtype, metadata): - if metadata is not None and len(metadata) == 0: - metadata = None - - org_state_dict = self.state_dict() - - # copy LLLite keys from org_state_dict to state_dict with key conversion - state_dict = {} - for key in org_state_dict.keys(): - # split with ".lllite" - pos = key.find(".lllite") - if pos < 0: - continue - lllite_key = SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX + "." + key[:pos] - lllite_key = lllite_key.replace(".", "_") + key[pos:] - lllite_key = lllite_key.replace(".lllite_", ".") - state_dict[lllite_key] = org_state_dict[key] - - if dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(dtype) - state_dict[key] = v - - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import save_file - - save_file(state_dict, file, metadata) - else: - torch.save(state_dict, file) - - def load_lllite_weights(self, file, non_lllite_unet_sd=None): - r""" - LLLiteの重みを読み込まない(initされた値を使う)場合はfileにNoneを指定する。 - この場合、non_lllite_unet_sdにはU-Netのstate_dictを指定する。 - - If you do not want to load LLLite weights (use initialized values), specify None for file. - In this case, specify the state_dict of U-Net for non_lllite_unet_sd. - """ - if not file: - state_dict = self.state_dict() - for key in non_lllite_unet_sd: - if key in state_dict: - state_dict[key] = non_lllite_unet_sd[key] - info = self.load_state_dict(state_dict, False) - return info - - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import load_file - - weights_sd = load_file(file) - else: - weights_sd = torch.load(file, map_location="cpu") - - # module_name = module_name.replace("_block", "@blocks") - # module_name = module_name.replace("_layer", "@layer") - # module_name = module_name.replace("to_", "to@") - # module_name = module_name.replace("time_embed", "time@embed") - # module_name = module_name.replace("label_emb", "label@emb") - # module_name = module_name.replace("skip_connection", "skip@connection") - # module_name = module_name.replace("proj_in", "proj@in") - # module_name = module_name.replace("proj_out", "proj@out") - pattern = re.compile(r"(_block|_layer|to_|time_embed|label_emb|skip_connection|proj_in|proj_out)") - - # convert to lllite with U-Net state dict - state_dict = non_lllite_unet_sd.copy() if non_lllite_unet_sd is not None else {} - for key in weights_sd.keys(): - # split with "." - pos = key.find(".") - if pos < 0: - continue - - module_name = key[:pos] - weight_name = key[pos + 1 :] # exclude "." - module_name = module_name.replace(SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX + "_", "") - - # これはうまくいかない。逆変換を考えなかった設計が悪い / this does not work well. 
bad design because I didn't think about inverse conversion - # module_name = module_name.replace("_", ".") - - # ださいけどSDXLのU-Netの "_" を "@" に変換する / ugly but convert "_" of SDXL U-Net to "@" - matches = pattern.findall(module_name) - if matches is not None: - for m in matches: - logger.info(f"{module_name} {m}") - module_name = module_name.replace(m, m.replace("_", "@")) - module_name = module_name.replace("_", ".") - module_name = module_name.replace("@", "_") - - lllite_key = module_name + ".lllite_" + weight_name - - state_dict[lllite_key] = weights_sd[key] - - info = self.load_state_dict(state_dict, False) - return info - - def forward(self, x, timesteps=None, context=None, y=None, cond_image=None, **kwargs): - for m in self.lllite_modules: - m.set_cond_image(cond_image) - return super().forward(x, timesteps, context, y, **kwargs) - - -def replace_unet_linear_and_conv2d(): - logger.info("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net") - sdxl_original_unet.torch.nn.Linear = LLLiteLinear - sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d - - -if __name__ == "__main__": - # デバッグ用 / for debug - - # sdxl_original_unet.USE_REENTRANT = False - replace_unet_linear_and_conv2d() - - # test shape etc - logger.info("create unet") - unet = SdxlUNet2DConditionModelControlNetLLLite() - - logger.info("enable ControlNet-LLLite") - unet.apply_lllite(32, 64, None, False, 1.0) - unet.to("cuda") # .to(torch.float16) - - # from safetensors.torch import load_file - - # model_sd = load_file(r"E:\Work\SD\Models\sdxl\sd_xl_base_1.0_0.9vae.safetensors") - # unet_sd = {} - - # # copy U-Net keys from unet_state_dict to state_dict - # prefix = "model.diffusion_model." - # for key in model_sd.keys(): - # if key.startswith(prefix): - # converted_key = key[len(prefix) :] - # unet_sd[converted_key] = model_sd[key] - - # info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd) - # logger.info(info) - - # logger.info(unet) - - # logger.info number of parameters - params = unet.prepare_params() - logger.info(f"number of parameters {sum(p.numel() for p in params)}") - # logger.info("type any key to continue") - # input() - - unet.set_use_memory_efficient_attention(True, False) - unet.set_gradient_checkpointing(True) - unet.train() # for gradient checkpointing - - # # visualize - # import torchviz - # logger.info("run visualize") - # controlnet.set_control(conditioning_image) - # output = unet(x, t, ctx, y) - # logger.info("make_dot") - # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters())) - # logger.info("render") - # image.format = "svg" # "png" - # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time - # input() - - import bitsandbytes - - optimizer = bitsandbytes.adam.Adam8bit(params, 1e-3) - - scaler = torch.cuda.amp.GradScaler(enabled=True) - - logger.info("start training") - steps = 10 - batch_size = 1 - - sample_param = [p for p in unet.named_parameters() if ".lllite_up." 
in p[0]][0] - for step in range(steps): - logger.info(f"step {step}") - - conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0 - x = torch.randn(batch_size, 4, 128, 128).cuda() - t = torch.randint(low=0, high=10, size=(batch_size,)).cuda() - ctx = torch.randn(batch_size, 77, 2048).cuda() - y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda() - - with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16): - output = unet(x, t, ctx, y, conditioning_image) - target = torch.randn_like(output) - loss = torch.nn.functional.mse_loss(output, target) - - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad(set_to_none=True) - logger.info(sample_param) - - # from safetensors.torch import save_file - - # logger.info("save weights") - # unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None) diff --git a/networks/dylora.py b/networks/dylora.py deleted file mode 100644 index d71279c55..000000000 --- a/networks/dylora.py +++ /dev/null @@ -1,478 +0,0 @@ -# some codes are copied from: -# https://github.com/huawei-noah/KD-NLP/blob/main/DyLoRA/ - -# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. -# Changes made to the original code: -# 2022.08.20 - Integrate the DyLoRA layer for the LoRA Linear layer -# ------------------------------------------------------------------------------------------ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. -# ------------------------------------------------------------------------------------------ - -import math -import os -import random -from typing import Dict, List, Optional, Tuple, Type, Union -from diffusers import AutoencoderKL -from transformers import CLIPTextModel -import torch -from torch import nn -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -class DyLoRAModule(torch.nn.Module): - """ - replaces forward method of the original Linear, instead of replacing the original Linear module. 
- """ - - # NOTE: support dropout in future - def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1): - super().__init__() - self.lora_name = lora_name - self.lora_dim = lora_dim - self.unit = unit - assert self.lora_dim % self.unit == 0, "rank must be a multiple of unit" - - if org_module.__class__.__name__ == "Conv2d": - in_dim = org_module.in_channels - out_dim = org_module.out_channels - else: - in_dim = org_module.in_features - out_dim = org_module.out_features - - if type(alpha) == torch.Tensor: - alpha = alpha.detach().float().numpy() # without casting, bf16 causes error - alpha = self.lora_dim if alpha is None or alpha == 0 else alpha - self.scale = alpha / self.lora_dim - self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える - - self.is_conv2d = org_module.__class__.__name__ == "Conv2d" - self.is_conv2d_3x3 = self.is_conv2d and org_module.kernel_size == (3, 3) - - if self.is_conv2d and self.is_conv2d_3x3: - kernel_size = org_module.kernel_size - self.stride = org_module.stride - self.padding = org_module.padding - self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim, *kernel_size)) for _ in range(self.lora_dim)]) - self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1, 1, 1)) for _ in range(self.lora_dim)]) - else: - self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim)) for _ in range(self.lora_dim)]) - self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1)) for _ in range(self.lora_dim)]) - - # same as microsoft's - for lora in self.lora_A: - torch.nn.init.kaiming_uniform_(lora, a=math.sqrt(5)) - for lora in self.lora_B: - torch.nn.init.zeros_(lora) - - self.multiplier = multiplier - self.org_module = org_module # remove in applying - - def apply_to(self): - self.org_forward = self.org_module.forward - self.org_module.forward = self.forward - del self.org_module - - def forward(self, x): - result = self.org_forward(x) - - # specify the dynamic rank - trainable_rank = random.randint(0, self.lora_dim - 1) - trainable_rank = trainable_rank - trainable_rank % self.unit # make sure the rank is a multiple of unit - - # 一部のパラメータを固定して、残りのパラメータを学習する - for i in range(0, trainable_rank): - self.lora_A[i].requires_grad = False - self.lora_B[i].requires_grad = False - for i in range(trainable_rank, trainable_rank + self.unit): - self.lora_A[i].requires_grad = True - self.lora_B[i].requires_grad = True - for i in range(trainable_rank + self.unit, self.lora_dim): - self.lora_A[i].requires_grad = False - self.lora_B[i].requires_grad = False - - lora_A = torch.cat(tuple(self.lora_A), dim=0) - lora_B = torch.cat(tuple(self.lora_B), dim=1) - - # calculate with lora_A and lora_B - if self.is_conv2d_3x3: - ab = torch.nn.functional.conv2d(x, lora_A, stride=self.stride, padding=self.padding) - ab = torch.nn.functional.conv2d(ab, lora_B) - else: - ab = x - if self.is_conv2d: - ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C) - - ab = torch.nn.functional.linear(ab, lora_A) - ab = torch.nn.functional.linear(ab, lora_B) - - if self.is_conv2d: - ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:]) # (N, H*W, C) -> (N, C, H, W) - - # 最後の項は、低rankをより大きくするためのスケーリング(じゃないかな) - result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit)) - - # NOTE weightに加算してからlinear/conv2dを呼んだほうが速いかも - return result - - def state_dict(self, destination=None, prefix="", keep_vars=False): - # state dictを通常のLoRAと同じにする: - # 
nn.ParameterListは `.lora_A.0` みたいな名前になるので、forwardと同様にcatして入れ替える - sd = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) - - lora_A_weight = torch.cat(tuple(self.lora_A), dim=0) - if self.is_conv2d and not self.is_conv2d_3x3: - lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1) - - lora_B_weight = torch.cat(tuple(self.lora_B), dim=1) - if self.is_conv2d and not self.is_conv2d_3x3: - lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1) - - sd[self.lora_name + ".lora_down.weight"] = lora_A_weight if keep_vars else lora_A_weight.detach() - sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach() - - i = 0 - while True: - key_a = f"{self.lora_name}.lora_A.{i}" - key_b = f"{self.lora_name}.lora_B.{i}" - if key_a in sd: - sd.pop(key_a) - sd.pop(key_b) - else: - break - i += 1 - return sd - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - # 通常のLoRAと同じstate dictを読み込めるようにする:この方法はchatGPTに聞いた - lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None) - lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None) - - if lora_A_weight is None or lora_B_weight is None: - if strict: - raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found") - else: - return - - if self.is_conv2d and not self.is_conv2d_3x3: - lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1) - lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1) - - state_dict.update( - {f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i].unsqueeze(0)) for i in range(lora_A_weight.size(0))} - ) - state_dict.update( - {f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i].unsqueeze(1)) for i in range(lora_B_weight.size(1))} - ) - - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - - -def create_network( - multiplier: float, - network_dim: Optional[int], - network_alpha: Optional[float], - vae: AutoencoderKL, - text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], - unet, - **kwargs, -): - if network_dim is None: - network_dim = 4 # default - if network_alpha is None: - network_alpha = 1.0 - - # extract dim/alpha for conv2d, and block dim - conv_dim = kwargs.get("conv_dim", None) - conv_alpha = kwargs.get("conv_alpha", None) - unit = kwargs.get("unit", None) - if conv_dim is not None: - conv_dim = int(conv_dim) - assert conv_dim == network_dim, "conv_dim must be same as network_dim" - if conv_alpha is None: - conv_alpha = 1.0 - else: - conv_alpha = float(conv_alpha) - - if unit is not None: - unit = int(unit) - else: - unit = 1 - - network = DyLoRANetwork( - text_encoder, - unet, - multiplier=multiplier, - lora_dim=network_dim, - alpha=network_alpha, - apply_to_conv=conv_dim is not None, - unit=unit, - varbose=True, - ) - return network - - -# Create network from weights for inference, weights are not loaded here (because can be merged) -def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): - if weights_sd is None: - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import load_file, safe_open - - weights_sd = load_file(file) - else: - weights_sd = torch.load(file, map_location="cpu") - - # get dim/alpha mapping - modules_dim = {} - modules_alpha = {} - for key, value in weights_sd.items(): - if "." 
not in key: - continue - - lora_name = key.split(".")[0] - if "alpha" in key: - modules_alpha[lora_name] = value - elif "lora_down" in key: - dim = value.size()[0] - modules_dim[lora_name] = dim - # logger.info(f"{lora_name} {value.size()} {dim}") - - # support old LoRA without alpha - for key in modules_dim.keys(): - if key not in modules_alpha: - modules_alpha = modules_dim[key] - - module_class = DyLoRAModule - - network = DyLoRANetwork( - text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class - ) - return network, weights_sd - - -class DyLoRANetwork(torch.nn.Module): - UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] - UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] - LORA_PREFIX_UNET = "lora_unet" - LORA_PREFIX_TEXT_ENCODER = "lora_te" - - def __init__( - self, - text_encoder, - unet, - multiplier=1.0, - lora_dim=4, - alpha=1, - apply_to_conv=False, - modules_dim=None, - modules_alpha=None, - unit=1, - module_class=DyLoRAModule, - varbose=False, - ) -> None: - super().__init__() - self.multiplier = multiplier - - self.lora_dim = lora_dim - self.alpha = alpha - self.apply_to_conv = apply_to_conv - - if modules_dim is not None: - logger.info("create LoRA network from weights") - else: - logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}") - if self.apply_to_conv: - logger.info("apply LoRA to Conv2d with kernel size (3,3).") - - # create module instances - def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]: - prefix = DyLoRANetwork.LORA_PREFIX_UNET if is_unet else DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER - loras = [] - for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: - for child_name, child_module in module.named_modules(): - is_linear = child_module.__class__.__name__ == "Linear" - is_conv2d = child_module.__class__.__name__ == "Conv2d" - is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) - - if is_linear or is_conv2d: - lora_name = prefix + "." + name + "." 
+ child_name - lora_name = lora_name.replace(".", "_") - - dim = None - alpha = None - if modules_dim is not None: - if lora_name in modules_dim: - dim = modules_dim[lora_name] - alpha = modules_alpha[lora_name] - else: - if is_linear or is_conv2d_1x1 or apply_to_conv: - dim = self.lora_dim - alpha = self.alpha - - if dim is None or dim == 0: - continue - - # dropout and fan_in_fan_out is default - lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit) - loras.append(lora) - return loras - - text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] - - self.text_encoder_loras = [] - for i, text_encoder in enumerate(text_encoders): - if len(text_encoders) > 1: - index = i + 1 - print(f"create LoRA for Text Encoder {index}") - else: - index = None - print(f"create LoRA for Text Encoder") - - text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) - self.text_encoder_loras.extend(text_encoder_loras) - - # self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) - logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") - - # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights - target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE - if modules_dim is not None or self.apply_to_conv: - target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 - - self.unet_loras = create_modules(True, unet, target_modules) - logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") - - def set_multiplier(self, multiplier): - self.multiplier = multiplier - for lora in self.text_encoder_loras + self.unet_loras: - lora.multiplier = self.multiplier - - def load_weights(self, file): - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import load_file - - weights_sd = load_file(file) - else: - weights_sd = torch.load(file, map_location="cpu") - - info = self.load_state_dict(weights_sd, False) - return info - - def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): - if apply_text_encoder: - logger.info("enable LoRA for text encoder") - else: - self.text_encoder_loras = [] - - if apply_unet: - logger.info("enable LoRA for U-Net") - else: - self.unet_loras = [] - - for lora in self.text_encoder_loras + self.unet_loras: - lora.apply_to() - self.add_module(lora.lora_name, lora) - - """ - def merge_to(self, text_encoder, unet, weights_sd, dtype, device): - apply_text_encoder = apply_unet = False - for key in weights_sd.keys(): - if key.startswith(DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER): - apply_text_encoder = True - elif key.startswith(DyLoRANetwork.LORA_PREFIX_UNET): - apply_unet = True - - if apply_text_encoder: - logger.info("enable LoRA for text encoder") - else: - self.text_encoder_loras = [] - - if apply_unet: - logger.info("enable LoRA for U-Net") - else: - self.unet_loras = [] - - for lora in self.text_encoder_loras + self.unet_loras: - sd_for_lora = {} - for key in weights_sd.keys(): - if key.startswith(lora.lora_name): - sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] - lora.merge_to(sd_for_lora, dtype, device) - - logger.info(f"weights are merged") - """ - - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): - self.requires_grad_(True) - all_params = [] - - def enumerate_params(loras): - params = [] - for lora in loras: - params.extend(lora.parameters()) - return params - - if 
self.text_encoder_loras: - param_data = {"params": enumerate_params(self.text_encoder_loras)} - if text_encoder_lr is not None: - param_data["lr"] = text_encoder_lr - all_params.append(param_data) - - if self.unet_loras: - param_data = {"params": enumerate_params(self.unet_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr - all_params.append(param_data) - - return all_params - - def enable_gradient_checkpointing(self): - # not supported - pass - - def prepare_grad_etc(self, text_encoder, unet): - self.requires_grad_(True) - - def on_epoch_start(self, text_encoder, unet): - self.train() - - def get_trainable_params(self): - return self.parameters() - - def save_weights(self, file, dtype, metadata): - if metadata is not None and len(metadata) == 0: - metadata = None - - state_dict = self.state_dict() - - if dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(dtype) - state_dict[key] = v - - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import save_file - from library import train_util - - # Precalculate model hashes to save time on indexing - if metadata is None: - metadata = {} - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) - metadata["sshs_model_hash"] = model_hash - metadata["sshs_legacy_hash"] = legacy_hash - - save_file(state_dict, file, metadata) - else: - torch.save(state_dict, file) - - # mask is a tensor with values from 0 to 1 - def set_region(self, sub_prompt_index, is_last_network, mask): - pass - - def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared): - pass diff --git a/networks/extract_lora_from_dylora.py b/networks/extract_lora_from_dylora.py deleted file mode 100644 index 1184cd8a5..000000000 --- a/networks/extract_lora_from_dylora.py +++ /dev/null @@ -1,128 +0,0 @@ -# Convert LoRA to different rank approximation (should only be used to go to lower rank) -# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py -# Thanks to cloneofsimo - -import argparse -import math -import os -import torch -from safetensors.torch import load_file, save_file, safe_open -from tqdm import tqdm -from library import train_util, model_util -import numpy as np -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -def load_state_dict(file_name): - if model_util.is_safetensors(file_name): - sd = load_file(file_name) - with safe_open(file_name, framework="pt") as f: - metadata = f.metadata() - else: - sd = torch.load(file_name, map_location="cpu") - metadata = None - - return sd, metadata - - -def save_to_file(file_name, model, metadata): - if model_util.is_safetensors(file_name): - save_file(model, file_name, metadata) - else: - torch.save(model, file_name) - - -def split_lora_model(lora_sd, unit): - max_rank = 0 - - # Extract loaded lora dim and alpha - for key, value in lora_sd.items(): - if "lora_down" in key: - rank = value.size()[0] - if rank > max_rank: - max_rank = rank - logger.info(f"Max rank: {max_rank}") - - rank = unit - split_models = [] - new_alpha = None - while rank < max_rank: - logger.info(f"Splitting rank {rank}") - new_sd = {} - for key, value in lora_sd.items(): - if "lora_down" in key: - new_sd[key] = value[:rank].contiguous() - elif "lora_up" in key: - new_sd[key] = value[:, :rank].contiguous() - else: - # なぜかscaleするとおかしくなる…… - # this_rank 
= lora_sd[key.replace("alpha", "lora_down.weight")].size()[0] - # scale = math.sqrt(this_rank / rank) # rank is > unit - # logger.info(key, value.size(), this_rank, rank, value, scale) - # new_alpha = value * scale # always same - # new_sd[key] = new_alpha - new_sd[key] = value - - split_models.append((new_sd, rank, new_alpha)) - rank += unit - - return max_rank, split_models - - -def split(args): - logger.info("loading Model...") - lora_sd, metadata = load_state_dict(args.model) - - logger.info("Splitting Model...") - original_rank, split_models = split_lora_model(lora_sd, args.unit) - - comment = metadata.get("ss_training_comment", "") - for state_dict, new_rank, new_alpha in split_models: - # update metadata - if metadata is None: - new_metadata = {} - else: - new_metadata = metadata.copy() - - new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}" - new_metadata["ss_network_dim"] = str(new_rank) - # new_metadata["ss_network_alpha"] = str(new_alpha.float().numpy()) - - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) - metadata["sshs_model_hash"] = model_hash - metadata["sshs_legacy_hash"] = legacy_hash - - filename, ext = os.path.splitext(args.save_to) - model_file_name = filename + f"-{new_rank:04d}{ext}" - - logger.info(f"saving model to: {model_file_name}") - save_to_file(model_file_name, state_dict, new_metadata) - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - parser.add_argument("--unit", type=int, default=None, help="size of rank to split into / rankを分割するサイズ") - parser.add_argument( - "--save_to", - type=str, - default=None, - help="destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors", - ) - parser.add_argument( - "--model", - type=str, - default=None, - help="DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors", - ) - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - split(args) diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py deleted file mode 100644 index 43c1d0058..000000000 --- a/networks/extract_lora_from_models.py +++ /dev/null @@ -1,360 +0,0 @@ -# extract approximating LoRA by svd from two SD models -# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py -# Thanks to cloneofsimo! 
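A minimal sketch (not part of the deleted file) of the low-rank extraction this script performs for a single Linear weight, assuming a hypothetical helper name `extract_linear_lora`: the tuned-minus-original weight delta is approximated by a rank-limited SVD, outliers in both factors are clamped at `clamp_quantile`, and the result is split into `lora_up` / `lora_down` weights, mirroring the `svd()` function below.

import torch

def extract_linear_lora(w_org: torch.Tensor, w_tuned: torch.Tensor, rank: int, clamp_quantile: float = 0.99):
    # hypothetical helper, for illustration only
    diff = (w_tuned - w_org).float()              # weight delta to approximate, shape (out_dim, in_dim)
    rank = min(rank, *diff.shape)                 # LoRA rank cannot exceed the original dims
    U, S, Vh = torch.linalg.svd(diff)
    U = U[:, :rank] @ torch.diag(S[:rank])        # fold singular values into the up factor
    Vh = Vh[:rank, :]
    # clamp outliers using a shared quantile over both factors, as the script does
    hi = torch.quantile(torch.cat([U.flatten(), Vh.flatten()]), clamp_quantile)
    U, Vh = U.clamp(-hi, hi), Vh.clamp(-hi, hi)
    return U.contiguous(), Vh.contiguous()        # lora_up.weight, lora_down.weight

# usage sketch: up, down = extract_linear_lora(module_o.weight, module_t.weight, rank=4)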
- -import argparse -import json -import os -import time -import torch -from safetensors.torch import load_file, save_file -from tqdm import tqdm -from library import sai_model_spec, model_util, sdxl_model_util -import lora -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -# CLAMP_QUANTILE = 0.99 -# MIN_DIFF = 1e-1 - - -def save_to_file(file_name, model, state_dict, dtype): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) - - if os.path.splitext(file_name)[1] == ".safetensors": - save_file(model, file_name) - else: - torch.save(model, file_name) - - -def svd( - model_org=None, - model_tuned=None, - save_to=None, - dim=4, - v2=None, - sdxl=None, - conv_dim=None, - v_parameterization=None, - device=None, - save_precision=None, - clamp_quantile=0.99, - min_diff=0.01, - no_metadata=False, - load_precision=None, - load_original_model_to=None, - load_tuned_model_to=None, -): - def str_to_dtype(p): - if p == "float": - return torch.float - if p == "fp16": - return torch.float16 - if p == "bf16": - return torch.bfloat16 - return None - - assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません" - if v_parameterization is None: - v_parameterization = v2 - - load_dtype = str_to_dtype(load_precision) if load_precision else None - save_dtype = str_to_dtype(save_precision) - work_device = "cpu" - - # load models - if not sdxl: - logger.info(f"loading original SD model : {model_org}") - text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org) - text_encoders_o = [text_encoder_o] - if load_dtype is not None: - text_encoder_o = text_encoder_o.to(load_dtype) - unet_o = unet_o.to(load_dtype) - - logger.info(f"loading tuned SD model : {model_tuned}") - text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned) - text_encoders_t = [text_encoder_t] - if load_dtype is not None: - text_encoder_t = text_encoder_t.to(load_dtype) - unet_t = unet_t.to(load_dtype) - - model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization) - else: - device_org = load_original_model_to if load_original_model_to else "cpu" - device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu" - - logger.info(f"loading original SDXL model : {model_org}") - text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org - ) - text_encoders_o = [text_encoder_o1, text_encoder_o2] - if load_dtype is not None: - text_encoder_o1 = text_encoder_o1.to(load_dtype) - text_encoder_o2 = text_encoder_o2.to(load_dtype) - unet_o = unet_o.to(load_dtype) - - logger.info(f"loading original SDXL model : {model_tuned}") - text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned - ) - text_encoders_t = [text_encoder_t1, text_encoder_t2] - if load_dtype is not None: - text_encoder_t1 = text_encoder_t1.to(load_dtype) - text_encoder_t2 = text_encoder_t2.to(load_dtype) - unet_t = unet_t.to(load_dtype) - - model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0 - - # create LoRA network to extract weights: Use dim (rank) as alpha - if conv_dim is None: - kwargs = {} - else: - kwargs = {"conv_dim": 
conv_dim, "conv_alpha": conv_dim} - - lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs) - lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs) - assert len(lora_network_o.text_encoder_loras) == len( - lora_network_t.text_encoder_loras - ), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) " - - # get diffs - diffs = {} - text_encoder_different = False - for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)): - lora_name = lora_o.lora_name - module_o = lora_o.org_module - module_t = lora_t.org_module - diff = module_t.weight.to(work_device) - module_o.weight.to(work_device) - - # clear weight to save memory - module_o.weight = None - module_t.weight = None - - # Text Encoder might be same - if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff: - text_encoder_different = True - logger.info(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}") - - diffs[lora_name] = diff - - # clear target Text Encoder to save memory - for text_encoder in text_encoders_t: - del text_encoder - - if not text_encoder_different: - logger.warning("Text encoder is same. Extract U-Net only.") - lora_network_o.text_encoder_loras = [] - diffs = {} # clear diffs - - for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)): - lora_name = lora_o.lora_name - module_o = lora_o.org_module - module_t = lora_t.org_module - diff = module_t.weight.to(work_device) - module_o.weight.to(work_device) - - # clear weight to save memory - module_o.weight = None - module_t.weight = None - - diffs[lora_name] = diff - - # clear LoRA network, target U-Net to save memory - del lora_network_o - del lora_network_t - del unet_t - - # make LoRA with svd - logger.info("calculating by svd") - lora_weights = {} - with torch.no_grad(): - for lora_name, mat in tqdm(list(diffs.items())): - if args.device: - mat = mat.to(args.device) - mat = mat.to(torch.float) # calc by float - - # if conv_dim is None, diffs do not include LoRAs for conv2d-3x3 - conv2d = len(mat.size()) == 4 - kernel_size = None if not conv2d else mat.size()[2:4] - conv2d_3x3 = conv2d and kernel_size != (1, 1) - - rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim - out_dim, in_dim = mat.size()[0:2] - - if device: - mat = mat.to(device) - - # logger.info(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) - rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim - - if conv2d: - if conv2d_3x3: - mat = mat.flatten(start_dim=1) - else: - mat = mat.squeeze() - - U, S, Vh = torch.linalg.svd(mat) - - U = U[:, :rank] - S = S[:rank] - U = U @ torch.diag(S) - - Vh = Vh[:rank, :] - - dist = torch.cat([U.flatten(), Vh.flatten()]) - hi_val = torch.quantile(dist, clamp_quantile) - low_val = -hi_val - - U = U.clamp(low_val, hi_val) - Vh = Vh.clamp(low_val, hi_val) - - if conv2d: - U = U.reshape(out_dim, rank, 1, 1) - Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1]) - - U = U.to(work_device, dtype=save_dtype).contiguous() - Vh = Vh.to(work_device, dtype=save_dtype).contiguous() - - lora_weights[lora_name] = (U, Vh) - - # make state dict for LoRA - lora_sd = {} - for lora_name, (up_weight, down_weight) in lora_weights.items(): - lora_sd[lora_name + ".lora_up.weight"] = up_weight - lora_sd[lora_name + ".lora_down.weight"] = down_weight - lora_sd[lora_name + ".alpha"] = 
torch.tensor(down_weight.size()[0]) - - # load state dict to LoRA and save it - lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd) - lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict - - info = lora_network_save.load_state_dict(lora_sd) - logger.info(f"Loading extracted LoRA weights: {info}") - - dir_name = os.path.dirname(save_to) - if dir_name and not os.path.exists(dir_name): - os.makedirs(dir_name, exist_ok=True) - - # minimum metadata - net_kwargs = {} - if conv_dim is not None: - net_kwargs["conv_dim"] = str(conv_dim) - net_kwargs["conv_alpha"] = str(float(conv_dim)) - - metadata = { - "ss_v2": str(v2), - "ss_base_model_version": model_version, - "ss_network_module": "networks.lora", - "ss_network_dim": str(dim), - "ss_network_alpha": str(float(dim)), - "ss_network_args": json.dumps(net_kwargs), - } - - if not no_metadata: - title = os.path.splitext(os.path.basename(save_to))[0] - sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title) - metadata.update(sai_metadata) - - lora_network_save.save_weights(save_to, save_dtype, metadata) - logger.info(f"LoRA weights are saved to: {save_to}") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む") - parser.add_argument( - "--v_parameterization", - action="store_true", - default=None, - help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)", - ) - parser.add_argument( - "--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む" - ) - parser.add_argument( - "--load_precision", - type=str, - default=None, - choices=[None, "float", "fp16", "bf16"], - help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる" - ) - parser.add_argument( - "--save_precision", - type=str, - default=None, - choices=[None, "float", "fp16", "bf16"], - help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat", - ) - parser.add_argument( - "--model_org", - type=str, - default=None, - required=True, - help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors", - ) - parser.add_argument( - "--model_tuned", - type=str, - default=None, - required=True, - help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors", - ) - parser.add_argument( - "--save_to", - type=str, - default=None, - required=True, - help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", - ) - parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)") - parser.add_argument( - "--conv_dim", - type=int, - default=None, - help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)", - ) - parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") - parser.add_argument( - "--clamp_quantile", - type=float, - default=0.99, - help="Quantile clamping value, float, (0-1). 
Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99", - ) - parser.add_argument( - "--min_diff", - type=float, - default=0.01, - help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /" - + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01", - ) - parser.add_argument( - "--no_metadata", - action="store_true", - help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " - + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", - ) - parser.add_argument( - "--load_original_model_to", - type=str, - default=None, - help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効", - ) - parser.add_argument( - "--load_tuned_model_to", - type=str, - default=None, - help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効", - ) - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - svd(**vars(args)) diff --git a/networks/lora.py b/networks/lora.py deleted file mode 100644 index cc09747bd..000000000 --- a/networks/lora.py +++ /dev/null @@ -1,1253 +0,0 @@ -# LoRA network module -# reference: -# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py -# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py - -import math -import os -from typing import Dict, List, Optional, Tuple, Type, Union -from diffusers import AutoencoderKL -from transformers import CLIPTextModel -import numpy as np -import torch -import re -from library.utils import setup_logging - -setup_logging() -import logging - -logger = logging.getLogger(__name__) - -RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") - -RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") - - -class LoRAModule(torch.nn.Module): - """ - replaces forward method of the original Linear, instead of replacing the original Linear module. 
- """ - - def __init__( - self, - lora_name, - org_module: torch.nn.Module, - multiplier=1.0, - lora_dim=4, - alpha=1, - dropout=None, - rank_dropout=None, - module_dropout=None, - ): - """if alpha == 0 or None, alpha is rank (no scaling).""" - super().__init__() - self.lora_name = lora_name - - if org_module.__class__.__name__ == "Conv2d": - in_dim = org_module.in_channels - out_dim = org_module.out_channels - else: - in_dim = org_module.in_features - out_dim = org_module.out_features - - # if limit_rank: - # self.lora_dim = min(lora_dim, in_dim, out_dim) - # if self.lora_dim != lora_dim: - # logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") - # else: - self.lora_dim = lora_dim - - if org_module.__class__.__name__ == "Conv2d": - kernel_size = org_module.kernel_size - stride = org_module.stride - padding = org_module.padding - self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) - self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) - else: - self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) - self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) - - if type(alpha) == torch.Tensor: - alpha = alpha.detach().float().numpy() # without casting, bf16 causes error - alpha = self.lora_dim if alpha is None or alpha == 0 else alpha - self.scale = alpha / self.lora_dim - self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える - - # same as microsoft's - torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) - torch.nn.init.zeros_(self.lora_up.weight) - - self.multiplier = multiplier - self.org_module = org_module # remove in applying - self.dropout = dropout - self.rank_dropout = rank_dropout - self.module_dropout = module_dropout - - def apply_to(self): - self.org_forward = self.org_module.forward - self.org_module.forward = self.forward - del self.org_module - - def forward(self, x): - org_forwarded = self.org_forward(x) - - # module dropout - if self.module_dropout is not None and self.training: - if torch.rand(1) < self.module_dropout: - return org_forwarded - - lx = self.lora_down(x) - - # normal dropout - if self.dropout is not None and self.training: - lx = torch.nn.functional.dropout(lx, p=self.dropout) - - # rank dropout - if self.rank_dropout is not None and self.training: - mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout - if len(lx.size()) == 3: - mask = mask.unsqueeze(1) # for Text Encoder - elif len(lx.size()) == 4: - mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d - lx = lx * mask - - # scaling for rank dropout: treat as if the rank is changed - # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる - scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability - else: - scale = self.scale - - lx = self.lora_up(lx) - - return org_forwarded + lx * self.multiplier * scale - - -class LoRAInfModule(LoRAModule): - def __init__( - self, - lora_name, - org_module: torch.nn.Module, - multiplier=1.0, - lora_dim=4, - alpha=1, - **kwargs, - ): - # no dropout for inference - super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) - - self.org_module_ref = [org_module] # 後から参照できるように - self.enabled = True - - # check regional or not by lora_name - self.text_encoder = False - if lora_name.startswith("lora_te_"): - self.regional = False - self.use_sub_prompt = True - self.text_encoder = True - elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name: - self.regional = 
False - self.use_sub_prompt = True - elif "time_emb" in lora_name: - self.regional = False - self.use_sub_prompt = False - else: - self.regional = True - self.use_sub_prompt = False - - self.network: LoRANetwork = None - - def set_network(self, network): - self.network = network - - # freezeしてマージする - def merge_to(self, sd, dtype, device): - # get up/down weight - up_weight = sd["lora_up.weight"].to(torch.float).to(device) - down_weight = sd["lora_down.weight"].to(torch.float).to(device) - - # extract weight from org_module - org_sd = self.org_module.state_dict() - weight = org_sd["weight"].to(torch.float) - - # merge weight - if len(weight.size()) == 2: - # linear - weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = ( - weight - + self.multiplier - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * self.scale - ) - else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # logger.info(conved.size(), weight.size(), module.stride, module.padding) - weight = weight + self.multiplier * conved * self.scale - - # set weight to org_module - org_sd["weight"] = weight.to(dtype) - self.org_module.load_state_dict(org_sd) - - # 復元できるマージのため、このモジュールのweightを返す - def get_weight(self, multiplier=None): - if multiplier is None: - multiplier = self.multiplier - - # get up/down weight from module - up_weight = self.lora_up.weight.to(torch.float) - down_weight = self.lora_down.weight.to(torch.float) - - # pre-calculated weight - if len(down_weight.size()) == 2: - # linear - weight = self.multiplier * (up_weight @ down_weight) * self.scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = ( - self.multiplier - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * self.scale - ) - else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - weight = self.multiplier * conved * self.scale - - return weight - - def set_region(self, region): - self.region = region - self.region_mask = None - - def default_forward(self, x): - # logger.info(f"default_forward {self.lora_name} {x.size()}") - return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale - - def forward(self, x): - if not self.enabled: - return self.org_forward(x) - - if self.network is None or self.network.sub_prompt_index is None: - return self.default_forward(x) - if not self.regional and not self.use_sub_prompt: - return self.default_forward(x) - - if self.regional: - return self.regional_forward(x) - else: - return self.sub_prompt_forward(x) - - def get_mask_for_x(self, x): - # calculate size from shape of x - if len(x.size()) == 4: - h, w = x.size()[2:4] - area = h * w - else: - area = x.size()[1] - - mask = self.network.mask_dic.get(area, None) - if mask is None: - # raise ValueError(f"mask is None for resolution {area}") - # emb_layers in SDXL doesn't have mask - # if "emb" not in self.lora_name: - # print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}") - mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1) - return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts - if len(x.size()) != 4: - mask = torch.reshape(mask, (1, -1, 1)) - return mask - - def regional_forward(self, x): - if "attn2_to_out" in 
self.lora_name: - return self.to_out_forward(x) - - if self.network.mask_dic is None: # sub_prompt_index >= 3 - return self.default_forward(x) - - # apply mask for LoRA result - lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale - mask = self.get_mask_for_x(lx) - # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) - # if mask.ndim > lx.ndim: # in some resolution, lx is 2d and mask is 3d (the reason is not checked) - # mask = mask.squeeze(-1) - lx = lx * mask - - x = self.org_forward(x) - x = x + lx - - if "attn2_to_q" in self.lora_name and self.network.is_last_network: - x = self.postp_to_q(x) - - return x - - def postp_to_q(self, x): - # repeat x to num_sub_prompts - has_real_uncond = x.size()[0] // self.network.batch_size == 3 - qc = self.network.batch_size # uncond - qc += self.network.batch_size * self.network.num_sub_prompts # cond - if has_real_uncond: - qc += self.network.batch_size # real_uncond - - query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype) - query[: self.network.batch_size] = x[: self.network.batch_size] - - for i in range(self.network.batch_size): - qi = self.network.batch_size + i * self.network.num_sub_prompts - query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i] - - if has_real_uncond: - query[-self.network.batch_size :] = x[-self.network.batch_size :] - - # logger.info(f"postp_to_q {self.lora_name} {x.size()} {query.size()} {self.network.num_sub_prompts}") - return query - - def sub_prompt_forward(self, x): - if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA - return self.org_forward(x) - - emb_idx = self.network.sub_prompt_index - if not self.text_encoder: - emb_idx += self.network.batch_size - - # apply sub prompt of X - lx = x[emb_idx :: self.network.num_sub_prompts] - lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale - - # logger.info(f"sub_prompt_forward {self.lora_name} {x.size()} {lx.size()} {emb_idx}") - - x = self.org_forward(x) - x[emb_idx :: self.network.num_sub_prompts] += lx - - return x - - def to_out_forward(self, x): - # logger.info(f"to_out_forward {self.lora_name} {x.size()} {self.network.is_last_network}") - - if self.network.is_last_network: - masks = [None] * self.network.num_sub_prompts - self.network.shared[self.lora_name] = (None, masks) - else: - lx, masks = self.network.shared[self.lora_name] - - # call own LoRA - x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts] - lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale - - if self.network.is_last_network: - lx = torch.zeros( - (self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype - ) - self.network.shared[self.lora_name] = (lx, masks) - - # logger.info(f"to_out_forward {lx.size()} {lx1.size()} {self.network.sub_prompt_index} {self.network.num_sub_prompts}") - lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1 - masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1) - - # if not last network, return x and masks - x = self.org_forward(x) - if not self.network.is_last_network: - return x - - lx, masks = self.network.shared.pop(self.lora_name) - - # if last network, combine separated x with mask weighted sum - has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2 - - out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 
2), *x.size()[1:]), device=x.device, dtype=x.dtype) - out[: self.network.batch_size] = x[: self.network.batch_size] # uncond - if has_real_uncond: - out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond - - # logger.info(f"to_out_forward {self.lora_name} {self.network.sub_prompt_index} {self.network.num_sub_prompts}") - # if num_sub_prompts > num of LoRAs, fill with zero - for i in range(len(masks)): - if masks[i] is None: - masks[i] = torch.zeros_like(masks[0]) - - mask = torch.cat(masks) - mask_sum = torch.sum(mask, dim=0) + 1e-4 - for i in range(self.network.batch_size): - # 1枚の画像ごとに処理する - lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts] - lx1 = lx1 * mask - lx1 = torch.sum(lx1, dim=0) - - xi = self.network.batch_size + i * self.network.num_sub_prompts - x1 = x[xi : xi + self.network.num_sub_prompts] - x1 = x1 * mask - x1 = torch.sum(x1, dim=0) - x1 = x1 / mask_sum - - x1 = x1 + lx1 - out[self.network.batch_size + i] = x1 - - # logger.info(f"to_out_forward {x.size()} {out.size()} {has_real_uncond}") - return out - - -def parse_block_lr_kwargs(nw_kwargs): - down_lr_weight = nw_kwargs.get("down_lr_weight", None) - mid_lr_weight = nw_kwargs.get("mid_lr_weight", None) - up_lr_weight = nw_kwargs.get("up_lr_weight", None) - - # 以上のいずれにも設定がない場合は無効としてNoneを返す - if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None: - return None, None, None - - # extract learning rate weight for each block - if down_lr_weight is not None: - # if some parameters are not set, use zero - if "," in down_lr_weight: - down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] - - if mid_lr_weight is not None: - mid_lr_weight = float(mid_lr_weight) - - if up_lr_weight is not None: - if "," in up_lr_weight: - up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] - - down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight( - down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0)) - ) - - return down_lr_weight, mid_lr_weight, up_lr_weight - - -def create_network( - multiplier: float, - network_dim: Optional[int], - network_alpha: Optional[float], - vae: AutoencoderKL, - text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], - unet, - neuron_dropout: Optional[float] = None, - **kwargs, -): - if network_dim is None: - network_dim = 4 # default - if network_alpha is None: - network_alpha = 1.0 - - # extract dim/alpha for conv2d, and block dim - conv_dim = kwargs.get("conv_dim", None) - conv_alpha = kwargs.get("conv_alpha", None) - if conv_dim is not None: - conv_dim = int(conv_dim) - if conv_alpha is None: - conv_alpha = 1.0 - else: - conv_alpha = float(conv_alpha) - - # block dim/alpha/lr - block_dims = kwargs.get("block_dims", None) - down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs) - - # 以上のいずれかに指定があればblockごとのdim(rank)を有効にする - if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None: - block_alphas = kwargs.get("block_alphas", None) - conv_block_dims = kwargs.get("conv_block_dims", None) - conv_block_alphas = kwargs.get("conv_block_alphas", None) - - block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas( - block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha - ) - - # remove block dim/alpha without learning rate - block_dims, block_alphas, conv_block_dims, conv_block_alphas = 
remove_block_dims_and_alphas( - block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight - ) - - else: - block_alphas = None - conv_block_dims = None - conv_block_alphas = None - - # rank/module dropout - rank_dropout = kwargs.get("rank_dropout", None) - if rank_dropout is not None: - rank_dropout = float(rank_dropout) - module_dropout = kwargs.get("module_dropout", None) - if module_dropout is not None: - module_dropout = float(module_dropout) - - # すごく引数が多いな ( ^ω^)・・・ - network = LoRANetwork( - text_encoder, - unet, - multiplier=multiplier, - lora_dim=network_dim, - alpha=network_alpha, - dropout=neuron_dropout, - rank_dropout=rank_dropout, - module_dropout=module_dropout, - conv_lora_dim=conv_dim, - conv_alpha=conv_alpha, - block_dims=block_dims, - block_alphas=block_alphas, - conv_block_dims=conv_block_dims, - conv_block_alphas=conv_block_alphas, - varbose=True, - ) - - if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: - network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) - - return network - - -# このメソッドは外部から呼び出される可能性を考慮しておく -# network_dim, network_alpha にはデフォルト値が入っている。 -# block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている -# conv_dim, conv_alpha は両方ともNoneまたは両方とも値が入っている -def get_block_dims_and_alphas( - block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha -): - num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1 - - def parse_ints(s): - return [int(i) for i in s.split(",")] - - def parse_floats(s): - return [float(i) for i in s.split(",")] - - # block_dimsとblock_alphasをパースする。必ず値が入る - if block_dims is not None: - block_dims = parse_ints(block_dims) - assert ( - len(block_dims) == num_total_blocks - ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" - else: - logger.warning( - f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります" - ) - block_dims = [network_dim] * num_total_blocks - - if block_alphas is not None: - block_alphas = parse_floats(block_alphas) - assert ( - len(block_alphas) == num_total_blocks - ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください" - else: - logger.warning( - f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります" - ) - block_alphas = [network_alpha] * num_total_blocks - - # conv_block_dimsとconv_block_alphasを、指定がある場合のみパースする。指定がなければconv_dimとconv_alphaを使う - if conv_block_dims is not None: - conv_block_dims = parse_ints(conv_block_dims) - assert ( - len(conv_block_dims) == num_total_blocks - ), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください" - - if conv_block_alphas is not None: - conv_block_alphas = parse_floats(conv_block_alphas) - assert ( - len(conv_block_alphas) == num_total_blocks - ), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください" - else: - if conv_alpha is None: - conv_alpha = 1.0 - logger.warning( - f"conv_block_alphas is not specified. 
all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります" - ) - conv_block_alphas = [conv_alpha] * num_total_blocks - else: - if conv_dim is not None: - logger.warning( - f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります" - ) - conv_block_dims = [conv_dim] * num_total_blocks - conv_block_alphas = [conv_alpha] * num_total_blocks - else: - conv_block_dims = None - conv_block_alphas = None - - return block_dims, block_alphas, conv_block_dims, conv_block_alphas - - -# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出される可能性を考慮しておく -def get_block_lr_weight( - down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold -) -> Tuple[List[float], List[float], List[float]]: - # パラメータ未指定時は何もせず、今までと同じ動作とする - if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None: - return None, None, None - - max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数 - - def get_list(name_with_suffix) -> List[float]: - import math - - tokens = name_with_suffix.split("+") - name = tokens[0] - base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0 - - if name == "cosine": - return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))] - elif name == "sine": - return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)] - elif name == "linear": - return [i / (max_len - 1) + base_lr for i in range(max_len)] - elif name == "reverse_linear": - return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))] - elif name == "zeros": - return [0.0 + base_lr] * max_len - else: - logger.error( - "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" - % (name) - ) - return None - - if type(down_lr_weight) == str: - down_lr_weight = get_list(down_lr_weight) - if type(up_lr_weight) == str: - up_lr_weight = get_list(up_lr_weight) - - if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): - logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) - logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) - up_lr_weight = up_lr_weight[:max_len] - down_lr_weight = down_lr_weight[:max_len] - - if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): - logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." 
% max_len) - logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) - - if down_lr_weight != None and len(down_lr_weight) < max_len: - down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) - if up_lr_weight != None and len(up_lr_weight) < max_len: - up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) - - if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): - logger.info("apply block learning rate / 階層別学習率を適用します。") - if down_lr_weight != None: - down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] - logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}") - else: - logger.info("down_lr_weight: all 1.0, すべて1.0") - - if mid_lr_weight != None: - mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 - logger.info(f"mid_lr_weight: {mid_lr_weight}") - else: - logger.info("mid_lr_weight: 1.0") - - if up_lr_weight != None: - up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] - logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}") - else: - logger.info("up_lr_weight: all 1.0, すべて1.0") - - return down_lr_weight, mid_lr_weight, up_lr_weight - - -# lr_weightが0のblockをblock_dimsから除外する、外部から呼び出す可能性を考慮しておく -def remove_block_dims_and_alphas( - block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight -): - # set 0 to block dim without learning rate to remove the block - if down_lr_weight != None: - for i, lr in enumerate(down_lr_weight): - if lr == 0: - block_dims[i] = 0 - if conv_block_dims is not None: - conv_block_dims[i] = 0 - if mid_lr_weight != None: - if mid_lr_weight == 0: - block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0 - if conv_block_dims is not None: - conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0 - if up_lr_weight != None: - for i, lr in enumerate(up_lr_weight): - if lr == 0: - block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0 - if conv_block_dims is not None: - conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0 - - return block_dims, block_alphas, conv_block_dims, conv_block_alphas - - -# 外部から呼び出す可能性を考慮しておく -def get_block_index(lora_name: str) -> int: - block_idx = -1 # invalid lora name - - m = RE_UPDOWN.search(lora_name) - if m: - g = m.groups() - i = int(g[1]) - j = int(g[3]) - if g[2] == "resnets": - idx = 3 * i + j - elif g[2] == "attentions": - idx = 3 * i + j - elif g[2] == "upsamplers" or g[2] == "downsamplers": - idx = 3 * i + 2 - - if g[0] == "down": - block_idx = 1 + idx # 0に該当するLoRAは存在しない - elif g[0] == "up": - block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx - - elif "mid_block_" in lora_name: - block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12 - - return block_idx - - -# Create network from weights for inference, weights are not loaded here (because can be merged) -def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): - if weights_sd is None: - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import load_file, safe_open - - weights_sd = load_file(file) - else: - weights_sd = torch.load(file, map_location="cpu") - - # get dim/alpha mapping - modules_dim = {} - modules_alpha = {} - for key, value in weights_sd.items(): - if "." 
not in key: - continue - - lora_name = key.split(".")[0] - if "alpha" in key: - modules_alpha[lora_name] = value - elif "lora_down" in key: - dim = value.size()[0] - modules_dim[lora_name] = dim - # logger.info(lora_name, value.size(), dim) - - # support old LoRA without alpha - for key in modules_dim.keys(): - if key not in modules_alpha: - modules_alpha[key] = modules_dim[key] - - module_class = LoRAInfModule if for_inference else LoRAModule - - network = LoRANetwork( - text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class - ) - - # block lr - down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs) - if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: - network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) - - return network, weights_sd - - -class LoRANetwork(torch.nn.Module): - NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数 - - UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] - UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] - LORA_PREFIX_UNET = "lora_unet" - LORA_PREFIX_TEXT_ENCODER = "lora_te" - - # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER - LORA_PREFIX_TEXT_ENCODER1 = "lora_te1" - LORA_PREFIX_TEXT_ENCODER2 = "lora_te2" - - def __init__( - self, - text_encoder: Union[List[CLIPTextModel], CLIPTextModel], - unet, - multiplier: float = 1.0, - lora_dim: int = 4, - alpha: float = 1, - dropout: Optional[float] = None, - rank_dropout: Optional[float] = None, - module_dropout: Optional[float] = None, - conv_lora_dim: Optional[int] = None, - conv_alpha: Optional[float] = None, - block_dims: Optional[List[int]] = None, - block_alphas: Optional[List[float]] = None, - conv_block_dims: Optional[List[int]] = None, - conv_block_alphas: Optional[List[float]] = None, - modules_dim: Optional[Dict[str, int]] = None, - modules_alpha: Optional[Dict[str, int]] = None, - module_class: Type[object] = LoRAModule, - varbose: Optional[bool] = False, - ) -> None: - """ - LoRA network: すごく引数が多いが、パターンは以下の通り - 1. lora_dimとalphaを指定 - 2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定 - 3. block_dimsとblock_alphasを指定 : Conv2d3x3には適用しない - 4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する - 5. modules_dimとmodules_alphaを指定 (推論用) - """ - super().__init__() - self.multiplier = multiplier - - self.lora_dim = lora_dim - self.alpha = alpha - self.conv_lora_dim = conv_lora_dim - self.conv_alpha = conv_alpha - self.dropout = dropout - self.rank_dropout = rank_dropout - self.module_dropout = module_dropout - - if modules_dim is not None: - logger.info(f"create LoRA network from weights") - elif block_dims is not None: - logger.info(f"create LoRA network from block_dims") - logger.info( - f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" - ) - logger.info(f"block_dims: {block_dims}") - logger.info(f"block_alphas: {block_alphas}") - if conv_block_dims is not None: - logger.info(f"conv_block_dims: {conv_block_dims}") - logger.info(f"conv_block_alphas: {conv_block_alphas}") - else: - logger.info(f"create LoRA network. 
base dim (rank): {lora_dim}, alpha: {alpha}") - logger.info( - f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" - ) - if self.conv_lora_dim is not None: - logger.info( - f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" - ) - - # create module instances - def create_modules( - is_unet: bool, - text_encoder_idx: Optional[int], # None, 1, 2 - root_module: torch.nn.Module, - target_replace_modules: List[torch.nn.Module], - ) -> List[LoRAModule]: - prefix = ( - self.LORA_PREFIX_UNET - if is_unet - else ( - self.LORA_PREFIX_TEXT_ENCODER - if text_encoder_idx is None - else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2) - ) - ) - loras = [] - skipped = [] - for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: - for child_name, child_module in module.named_modules(): - is_linear = child_module.__class__.__name__ == "Linear" - is_conv2d = child_module.__class__.__name__ == "Conv2d" - is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) - - if is_linear or is_conv2d: - lora_name = prefix + "." + name + "." + child_name - lora_name = lora_name.replace(".", "_") - - dim = None - alpha = None - - if modules_dim is not None: - # モジュール指定あり - if lora_name in modules_dim: - dim = modules_dim[lora_name] - alpha = modules_alpha[lora_name] - elif is_unet and block_dims is not None: - # U-Netでblock_dims指定あり - block_idx = get_block_index(lora_name) - if is_linear or is_conv2d_1x1: - dim = block_dims[block_idx] - alpha = block_alphas[block_idx] - elif conv_block_dims is not None: - dim = conv_block_dims[block_idx] - alpha = conv_block_alphas[block_idx] - else: - # 通常、すべて対象とする - if is_linear or is_conv2d_1x1: - dim = self.lora_dim - alpha = self.alpha - elif self.conv_lora_dim is not None: - dim = self.conv_lora_dim - alpha = self.conv_alpha - - if dim is None or dim == 0: - # skipした情報を出力 - if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None): - skipped.append(lora_name) - continue - - lora = module_class( - lora_name, - child_module, - self.multiplier, - dim, - alpha, - dropout=dropout, - rank_dropout=rank_dropout, - module_dropout=module_dropout, - ) - loras.append(lora) - return loras, skipped - - text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] - - # create LoRA for text encoder - # 毎回すべてのモジュールを作るのは無駄なので要検討 - self.text_encoder_loras = [] - skipped_te = [] - for i, text_encoder in enumerate(text_encoders): - if len(text_encoders) > 1: - index = i + 1 - logger.info(f"create LoRA for Text Encoder {index}:") - else: - index = None - logger.info(f"create LoRA for Text Encoder:") - - text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) - self.text_encoder_loras.extend(text_encoder_loras) - skipped_te += skipped - logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") - - # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights - target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE - if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None: - target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 - - self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} 
modules.") - - skipped = skipped_te + skipped_un - if varbose and len(skipped) > 0: - logger.warning( - f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" - ) - for name in skipped: - logger.info(f"\t{name}") - - self.up_lr_weight: List[float] = None - self.down_lr_weight: List[float] = None - self.mid_lr_weight: float = None - self.block_lr = False - - # assertion - names = set() - for lora in self.text_encoder_loras + self.unet_loras: - assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" - names.add(lora.lora_name) - - def set_multiplier(self, multiplier): - self.multiplier = multiplier - for lora in self.text_encoder_loras + self.unet_loras: - lora.multiplier = self.multiplier - - def set_enabled(self, is_enabled): - for lora in self.text_encoder_loras + self.unet_loras: - lora.enabled = is_enabled - - def load_weights(self, file): - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import load_file - - weights_sd = load_file(file) - else: - weights_sd = torch.load(file, map_location="cpu") - info = self.load_state_dict(weights_sd, False) - return info - - def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): - if apply_text_encoder: - logger.info("enable LoRA for text encoder") - else: - self.text_encoder_loras = [] - - if apply_unet: - logger.info("enable LoRA for U-Net") - else: - self.unet_loras = [] - - for lora in self.text_encoder_loras + self.unet_loras: - lora.apply_to() - self.add_module(lora.lora_name, lora) - - # マージできるかどうかを返す - def is_mergeable(self): - return True - - # TODO refactor to common function with apply_to - def merge_to(self, text_encoder, unet, weights_sd, dtype, device): - apply_text_encoder = apply_unet = False - for key in weights_sd.keys(): - if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): - apply_text_encoder = True - elif key.startswith(LoRANetwork.LORA_PREFIX_UNET): - apply_unet = True - - if apply_text_encoder: - logger.info("enable LoRA for text encoder") - else: - self.text_encoder_loras = [] - - if apply_unet: - logger.info("enable LoRA for U-Net") - else: - self.unet_loras = [] - - for lora in self.text_encoder_loras + self.unet_loras: - sd_for_lora = {} - for key in weights_sd.keys(): - if key.startswith(lora.lora_name): - sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] - lora.merge_to(sd_for_lora, dtype, device) - - logger.info(f"weights are merged") - - # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない - def set_block_lr_weight( - self, - up_lr_weight: List[float] = None, - mid_lr_weight: float = None, - down_lr_weight: List[float] = None, - ): - self.block_lr = True - self.down_lr_weight = down_lr_weight - self.mid_lr_weight = mid_lr_weight - self.up_lr_weight = up_lr_weight - - def get_lr_weight(self, lora: LoRAModule) -> float: - lr_weight = 1.0 - block_idx = get_block_index(lora.lora_name) - if block_idx < 0: - return lr_weight - - if block_idx < LoRANetwork.NUM_OF_BLOCKS: - if self.down_lr_weight != None: - lr_weight = self.down_lr_weight[block_idx] - elif block_idx == LoRANetwork.NUM_OF_BLOCKS: - if self.mid_lr_weight != None: - lr_weight = self.mid_lr_weight - elif block_idx > LoRANetwork.NUM_OF_BLOCKS: - if self.up_lr_weight != None: - lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1] - - return lr_weight - - # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, 
unet_lr, default_lr): - self.requires_grad_(True) - all_params = [] - - def enumerate_params(loras): - params = [] - for lora in loras: - params.extend(lora.parameters()) - return params - - if self.text_encoder_loras: - param_data = {"params": enumerate_params(self.text_encoder_loras)} - if text_encoder_lr is not None: - param_data["lr"] = text_encoder_lr - all_params.append(param_data) - - if self.unet_loras: - if self.block_lr: - # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類 - block_idx_to_lora = {} - for lora in self.unet_loras: - idx = get_block_index(lora.lora_name) - if idx not in block_idx_to_lora: - block_idx_to_lora[idx] = [] - block_idx_to_lora[idx].append(lora) - - # blockごとにパラメータを設定する - for idx, block_loras in block_idx_to_lora.items(): - param_data = {"params": enumerate_params(block_loras)} - - if unet_lr is not None: - param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0]) - elif default_lr is not None: - param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0]) - if ("lr" in param_data) and (param_data["lr"] == 0): - continue - all_params.append(param_data) - - else: - param_data = {"params": enumerate_params(self.unet_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr - all_params.append(param_data) - - return all_params - - def enable_gradient_checkpointing(self): - # not supported - pass - - def prepare_grad_etc(self, text_encoder, unet): - self.requires_grad_(True) - - def on_epoch_start(self, text_encoder, unet): - self.train() - - def get_trainable_params(self): - return self.parameters() - - def save_weights(self, file, dtype, metadata): - if metadata is not None and len(metadata) == 0: - metadata = None - - state_dict = self.state_dict() - - if dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(dtype) - state_dict[key] = v - - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import save_file - from library import train_util - - # Precalculate model hashes to save time on indexing - if metadata is None: - metadata = {} - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) - metadata["sshs_model_hash"] = model_hash - metadata["sshs_legacy_hash"] = legacy_hash - - save_file(state_dict, file, metadata) - else: - torch.save(state_dict, file) - - # mask is a tensor with values from 0 to 1 - def set_region(self, sub_prompt_index, is_last_network, mask): - if mask.max() == 0: - mask = torch.ones_like(mask) - - self.mask = mask - self.sub_prompt_index = sub_prompt_index - self.is_last_network = is_last_network - - for lora in self.text_encoder_loras + self.unet_loras: - lora.set_network(self) - - def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared, ds_ratio=None): - self.batch_size = batch_size - self.num_sub_prompts = num_sub_prompts - self.current_size = (height, width) - self.shared = shared - - # create masks - mask = self.mask - mask_dic = {} - mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w - ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight - dtype = ref_weight.dtype - device = ref_weight.device - - def resize_add(mh, mw): - # logger.info(mh, mw, mh * mw) - m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16 - m = m.to(device, dtype=dtype) - mask_dic[mh * mw] = m - - h = height // 8 - w = width // 8 - for _ in range(4): - resize_add(h, w) - if h % 2 == 1 or w % 2 == 
1: # add extra shape if h/w is not divisible by 2 - resize_add(h + h % 2, w + w % 2) - - # deep shrink - if ds_ratio is not None: - hd = int(h * ds_ratio) - wd = int(w * ds_ratio) - resize_add(hd, wd) - - h = (h + 1) // 2 - w = (w + 1) // 2 - - self.mask_dic = mask_dic - - def backup_weights(self): - # 重みのバックアップを行う - loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras - for lora in loras: - org_module = lora.org_module_ref[0] - if not hasattr(org_module, "_lora_org_weight"): - sd = org_module.state_dict() - org_module._lora_org_weight = sd["weight"].detach().clone() - org_module._lora_restored = True - - def restore_weights(self): - # 重みのリストアを行う - loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras - for lora in loras: - org_module = lora.org_module_ref[0] - if not org_module._lora_restored: - sd = org_module.state_dict() - sd["weight"] = org_module._lora_org_weight - org_module.load_state_dict(sd) - org_module._lora_restored = True - - def pre_calculation(self): - # 事前計算を行う - loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras - for lora in loras: - org_module = lora.org_module_ref[0] - sd = org_module.state_dict() - - org_weight = sd["weight"] - lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) - sd["weight"] = org_weight + lora_weight - assert sd["weight"].shape == org_weight.shape - org_module.load_state_dict(sd) - - org_module._lora_restored = False - lora.enabled = False - - def apply_max_norm_regularization(self, max_norm_value, device): - downkeys = [] - upkeys = [] - alphakeys = [] - norms = [] - keys_scaled = 0 - - state_dict = self.state_dict() - for key in state_dict.keys(): - if "lora_down" in key and "weight" in key: - downkeys.append(key) - upkeys.append(key.replace("lora_down", "lora_up")) - alphakeys.append(key.replace("lora_down.weight", "alpha")) - - for i in range(len(downkeys)): - down = state_dict[downkeys[i]].to(device) - up = state_dict[upkeys[i]].to(device) - alpha = state_dict[alphakeys[i]].to(device) - dim = down.shape[0] - scale = alpha / dim - - if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): - updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) - elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): - updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) - else: - updown = up @ down - - updown *= scale - - norm = updown.norm().clamp(min=max_norm_value / 2) - desired = torch.clamp(norm, max=max_norm_value) - ratio = desired.cpu() / norm.cpu() - sqrt_ratio = ratio**0.5 - if ratio != 1: - keys_scaled += 1 - state_dict[upkeys[i]] *= sqrt_ratio - state_dict[downkeys[i]] *= sqrt_ratio - scalednorm = updown.norm() * ratio - norms.append(scalednorm.item()) - - return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/networks/lora_diffusers.py b/networks/lora_diffusers.py deleted file mode 100644 index b99b02442..000000000 --- a/networks/lora_diffusers.py +++ /dev/null @@ -1,616 +0,0 @@ -# Diffusersで動くLoRA。このファイル単独で完結する。 -# LoRA module for Diffusers. This file works independently. 
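The LoRAModule classes in these deleted network files share one low-rank update: the frozen layer's output plus lora_up(lora_down(x)), scaled by multiplier * (alpha / rank), with lora_up zero-initialized so the module starts as a no-op. A minimal standalone sketch of that pattern follows; the class and variable names are illustrative, not sd-scripts API, and the dropout and Conv2d branches are omitted.

import math
import torch

class MinimalLoRALinear(torch.nn.Module):
    # Wraps a frozen Linear and adds a rank-r correction, mirroring the
    # forward() of the deleted LoRAModule (neuron/rank/module dropout omitted).
    def __init__(self, base: torch.nn.Linear, rank: int = 4, alpha: float = 1.0, multiplier: float = 1.0):
        super().__init__()
        self.base = base
        self.lora_down = torch.nn.Linear(base.in_features, rank, bias=False)
        self.lora_up = torch.nn.Linear(rank, base.out_features, bias=False)
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)  # zero init: output is unchanged until trained
        self.scale = alpha / rank
        self.multiplier = multiplier

    def forward(self, x):
        return self.base(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale

base = torch.nn.Linear(768, 768)
layer = MinimalLoRALinear(base, rank=4, alpha=1.0)
y = layer(torch.randn(2, 768))  # equals base(x) exactly at initialization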
- -import bisect -import math -import random -from typing import Any, Dict, List, Mapping, Optional, Union -from diffusers import UNet2DConditionModel -import numpy as np -from tqdm import tqdm -from transformers import CLIPTextModel - -import torch -from library.device_utils import init_ipex, get_preferred_device -init_ipex() - -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -def make_unet_conversion_map() -> Dict[str, str]: - unet_conversion_map_layer = [] - - for i in range(3): # num_blocks is 3 in sdxl - # loop over downblocks/upblocks - for j in range(2): - # loop over resnets/attentions for downblocks - hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." - sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." - unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) - - if i < 3: - # no attention layers in down_blocks.3 - hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." - sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." - unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) - - for j in range(3): - # loop over resnets/attentions for upblocks - hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." - sd_up_res_prefix = f"output_blocks.{3*i + j}.0." - unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) - - # if i > 0: commentout for sdxl - # no attention layers in up_blocks.0 - hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." - sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." - unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) - - if i < 3: - # no downsample in down_blocks.3 - hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." - sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." - unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) - - # no upsample in up_blocks.3 - hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl - unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) - - hf_mid_atn_prefix = "mid_block.attentions.0." - sd_mid_atn_prefix = "middle_block.1." - unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) - - for j in range(2): - hf_mid_res_prefix = f"mid_block.resnets.{j}." - sd_mid_res_prefix = f"middle_block.{2*j}." - unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) - - unet_conversion_map_resnet = [ - # (stable-diffusion, HF Diffusers) - ("in_layers.0.", "norm1."), - ("in_layers.2.", "conv1."), - ("out_layers.0.", "norm2."), - ("out_layers.3.", "conv2."), - ("emb_layers.1.", "time_emb_proj."), - ("skip_connection.", "conv_shortcut."), - ] - - unet_conversion_map = [] - for sd, hf in unet_conversion_map_layer: - if "resnets" in hf: - for sd_res, hf_res in unet_conversion_map_resnet: - unet_conversion_map.append((sd + sd_res, hf + hf_res)) - else: - unet_conversion_map.append((sd, hf)) - - for j in range(2): - hf_time_embed_prefix = f"time_embedding.linear_{j+1}." - sd_time_embed_prefix = f"time_embed.{j*2}." - unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix)) - - for j in range(2): - hf_label_embed_prefix = f"add_embedding.linear_{j+1}." - sd_label_embed_prefix = f"label_emb.0.{j*2}." 
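The merge_to and get_weight methods in networks/lora.py above, and in this file below, fold the same update into the base weight as a dense delta: multiplier * (up @ down) * scale for Linear, a squeeze/unsqueeze contraction for 1x1 convolutions, and an F.conv2d contraction for 3x3 convolutions. A reduced sketch of that delta computation, assuming bare weight tensors instead of the module state dicts used in the deleted code:

import torch

def lora_delta(up_weight: torch.Tensor, down_weight: torch.Tensor, multiplier: float, scale: float) -> torch.Tensor:
    # Dense weight delta equivalent to lora_up(lora_down(x)) * multiplier * scale,
    # following the three cases handled by merge_to / get_weight.
    if down_weight.dim() == 2:
        # Linear: (out, r) @ (r, in) -> (out, in)
        return multiplier * (up_weight @ down_weight) * scale
    if down_weight.size()[2:4] == (1, 1):
        # Conv2d 1x1: contract over the rank, then restore the 1x1 kernel dims
        return (
            multiplier
            * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
            * scale
        )
    # Conv2d 3x3: convolve the permuted down weight with up to get an (out, in, 3, 3) delta
    conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
    return multiplier * conved * scale

up = torch.randn(320, 4)    # lora_up.weight of a Linear module
down = torch.randn(4, 768)  # lora_down.weight
delta = lora_delta(up, down, multiplier=1.0, scale=1.0 / 4)
merged = torch.randn(320, 768) + delta  # what merge_to writes back into the base module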
- unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix)) - - unet_conversion_map.append(("input_blocks.0.0.", "conv_in.")) - unet_conversion_map.append(("out.0.", "conv_norm_out.")) - unet_conversion_map.append(("out.2.", "conv_out.")) - - sd_hf_conversion_map = {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map} - return sd_hf_conversion_map - - -UNET_CONVERSION_MAP = make_unet_conversion_map() - - -class LoRAModule(torch.nn.Module): - """ - replaces forward method of the original Linear, instead of replacing the original Linear module. - """ - - def __init__( - self, - lora_name, - org_module: torch.nn.Module, - multiplier=1.0, - lora_dim=4, - alpha=1, - ): - """if alpha == 0 or None, alpha is rank (no scaling).""" - super().__init__() - self.lora_name = lora_name - - if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv": - in_dim = org_module.in_channels - out_dim = org_module.out_channels - else: - in_dim = org_module.in_features - out_dim = org_module.out_features - - self.lora_dim = lora_dim - - if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv": - kernel_size = org_module.kernel_size - stride = org_module.stride - padding = org_module.padding - self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) - self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) - else: - self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) - self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) - - if type(alpha) == torch.Tensor: - alpha = alpha.detach().float().numpy() # without casting, bf16 causes error - alpha = self.lora_dim if alpha is None or alpha == 0 else alpha - self.scale = alpha / self.lora_dim - self.register_buffer("alpha", torch.tensor(alpha)) # 勾配計算に含めない / not included in gradient calculation - - # same as microsoft's - torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) - torch.nn.init.zeros_(self.lora_up.weight) - - self.multiplier = multiplier - self.org_module = [org_module] - self.enabled = True - self.network: LoRANetwork = None - self.org_forward = None - - # override org_module's forward method - def apply_to(self, multiplier=None): - if multiplier is not None: - self.multiplier = multiplier - if self.org_forward is None: - self.org_forward = self.org_module[0].forward - self.org_module[0].forward = self.forward - - # restore org_module's forward method - def unapply_to(self): - if self.org_forward is not None: - self.org_module[0].forward = self.org_forward - - # forward with lora - # scale is used LoRACompatibleConv, but we ignore it because we have multiplier - def forward(self, x, scale=1.0): - if not self.enabled: - return self.org_forward(x) - return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale - - def set_network(self, network): - self.network = network - - # merge lora weight to org weight - def merge_to(self, multiplier=1.0): - # get lora weight - lora_weight = self.get_weight(multiplier) - - # get org weight - org_sd = self.org_module[0].state_dict() - org_weight = org_sd["weight"] - weight = org_weight + lora_weight.to(org_weight.device, dtype=org_weight.dtype) - - # set weight to org_module - org_sd["weight"] = weight - self.org_module[0].load_state_dict(org_sd) - - # restore org weight from lora weight - def restore_from(self, multiplier=1.0): - # get lora 
weight - lora_weight = self.get_weight(multiplier) - - # get org weight - org_sd = self.org_module[0].state_dict() - org_weight = org_sd["weight"] - weight = org_weight - lora_weight.to(org_weight.device, dtype=org_weight.dtype) - - # set weight to org_module - org_sd["weight"] = weight - self.org_module[0].load_state_dict(org_sd) - - # return lora weight - def get_weight(self, multiplier=None): - if multiplier is None: - multiplier = self.multiplier - - # get up/down weight from module - up_weight = self.lora_up.weight.to(torch.float) - down_weight = self.lora_down.weight.to(torch.float) - - # pre-calculated weight - if len(down_weight.size()) == 2: - # linear - weight = self.multiplier * (up_weight @ down_weight) * self.scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = ( - self.multiplier - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * self.scale - ) - else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - weight = self.multiplier * conved * self.scale - - return weight - - -# Create network from weights for inference, weights are not loaded here -def create_network_from_weights( - text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], unet: UNet2DConditionModel, weights_sd: Dict, multiplier: float = 1.0 -): - # get dim/alpha mapping - modules_dim = {} - modules_alpha = {} - for key, value in weights_sd.items(): - if "." not in key: - continue - - lora_name = key.split(".")[0] - if "alpha" in key: - modules_alpha[lora_name] = value - elif "lora_down" in key: - dim = value.size()[0] - modules_dim[lora_name] = dim - # logger.info(f"{lora_name} {value.size()} {dim}") - - # support old LoRA without alpha - for key in modules_dim.keys(): - if key not in modules_alpha: - modules_alpha[key] = modules_dim[key] - - return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha) - - -def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0): - text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if hasattr(pipe, "text_encoder_2") else [pipe.text_encoder] - unet = pipe.unet - - lora_network = create_network_from_weights(text_encoders, unet, weights_sd, multiplier=multiplier) - lora_network.load_state_dict(weights_sd) - lora_network.merge_to(multiplier=multiplier) - - -# block weightや学習に対応しない簡易版 / simple version without block weight and training -class LoRANetwork(torch.nn.Module): - UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] - UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] - LORA_PREFIX_UNET = "lora_unet" - LORA_PREFIX_TEXT_ENCODER = "lora_te" - - # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER - LORA_PREFIX_TEXT_ENCODER1 = "lora_te1" - LORA_PREFIX_TEXT_ENCODER2 = "lora_te2" - - def __init__( - self, - text_encoder: Union[List[CLIPTextModel], CLIPTextModel], - unet: UNet2DConditionModel, - multiplier: float = 1.0, - modules_dim: Optional[Dict[str, int]] = None, - modules_alpha: Optional[Dict[str, int]] = None, - varbose: Optional[bool] = False, - ) -> None: - super().__init__() - self.multiplier = multiplier - - logger.info("create LoRA network from weights") - - # convert SDXL Stability AI's U-Net modules to Diffusers - converted = self.convert_unet_modules(modules_dim, modules_alpha) - if converted: - logger.info(f"converted {converted} Stability AI's U-Net 
LoRA modules to Diffusers (SDXL)") - - # create module instances - def create_modules( - is_unet: bool, - text_encoder_idx: Optional[int], # None, 1, 2 - root_module: torch.nn.Module, - target_replace_modules: List[torch.nn.Module], - ) -> List[LoRAModule]: - prefix = ( - self.LORA_PREFIX_UNET - if is_unet - else ( - self.LORA_PREFIX_TEXT_ENCODER - if text_encoder_idx is None - else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2) - ) - ) - loras = [] - skipped = [] - for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: - for child_name, child_module in module.named_modules(): - is_linear = ( - child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear" - ) - is_conv2d = ( - child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv" - ) - - if is_linear or is_conv2d: - lora_name = prefix + "." + name + "." + child_name - lora_name = lora_name.replace(".", "_") - - if lora_name not in modules_dim: - # logger.info(f"skipped {lora_name} (not found in modules_dim)") - skipped.append(lora_name) - continue - - dim = modules_dim[lora_name] - alpha = modules_alpha[lora_name] - lora = LoRAModule( - lora_name, - child_module, - self.multiplier, - dim, - alpha, - ) - loras.append(lora) - return loras, skipped - - text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] - - # create LoRA for text encoder - # 毎回すべてのモジュールを作るのは無駄なので要検討 / it is wasteful to create all modules every time, need to consider - self.text_encoder_loras: List[LoRAModule] = [] - skipped_te = [] - for i, text_encoder in enumerate(text_encoders): - if len(text_encoders) > 1: - index = i + 1 - else: - index = None - - text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) - self.text_encoder_loras.extend(text_encoder_loras) - skipped_te += skipped - logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") - if len(skipped_te) > 0: - logger.warning(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.") - - # extend U-Net target modules to include Conv2d 3x3 - target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 - - self.unet_loras: List[LoRAModule] - self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") - if len(skipped_un) > 0: - logger.warning(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.") - - # assertion - names = set() - for lora in self.text_encoder_loras + self.unet_loras: - names.add(lora.lora_name) - for lora_name in modules_dim.keys(): - assert lora_name in names, f"{lora_name} is not found in created LoRA modules." 
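convert_unet_modules and load_state_dict below rename Stability-AI-style SDXL U-Net keys to Diffusers naming with a sorted-prefix lookup: the SD prefixes of UNET_CONVERSION_MAP are sorted, bisect_right picks the closest candidate, and a key is rewritten only if it actually starts with that prefix. A reduced sketch of the lookup, using a two-entry toy map rather than the full UNET_CONVERSION_MAP:

import bisect

# Toy subset of the SD-prefix -> Diffusers-prefix map (illustrative entries only).
conversion_map = {
    "input_blocks_1_0_in_layers_2": "down_blocks_0_resnets_0_conv1",
    "middle_block_1": "mid_block_attentions_0",
}
map_keys = sorted(conversion_map.keys())

def convert_unet_key(key: str, prefix: str = "lora_unet_") -> str:
    # Rewrites one state-dict key; keys that match no known prefix pass through.
    if not key.startswith(prefix):
        return key
    search_key = key[len(prefix):]
    pos = bisect.bisect_right(map_keys, search_key)
    map_key = map_keys[pos - 1] if pos > 0 else ""
    if map_key and search_key.startswith(map_key):
        return key.replace(map_key, conversion_map[map_key])
    return key

print(convert_unet_key("lora_unet_middle_block_1_proj_in.lora_down.weight"))
# -> lora_unet_mid_block_attentions_0_proj_in.lora_down.weight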
- - # make to work load_state_dict - for lora in self.text_encoder_loras + self.unet_loras: - self.add_module(lora.lora_name, lora) - - # SDXL: convert SDXL Stability AI's U-Net modules to Diffusers - def convert_unet_modules(self, modules_dim, modules_alpha): - converted_count = 0 - not_converted_count = 0 - - map_keys = list(UNET_CONVERSION_MAP.keys()) - map_keys.sort() - - for key in list(modules_dim.keys()): - if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"): - search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "") - position = bisect.bisect_right(map_keys, search_key) - map_key = map_keys[position - 1] - if search_key.startswith(map_key): - new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key]) - modules_dim[new_key] = modules_dim[key] - modules_alpha[new_key] = modules_alpha[key] - del modules_dim[key] - del modules_alpha[key] - converted_count += 1 - else: - not_converted_count += 1 - assert ( - converted_count == 0 or not_converted_count == 0 - ), f"some modules are not converted: {converted_count} converted, {not_converted_count} not converted" - return converted_count - - def set_multiplier(self, multiplier): - self.multiplier = multiplier - for lora in self.text_encoder_loras + self.unet_loras: - lora.multiplier = self.multiplier - - def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True): - if apply_text_encoder: - logger.info("enable LoRA for text encoder") - for lora in self.text_encoder_loras: - lora.apply_to(multiplier) - if apply_unet: - logger.info("enable LoRA for U-Net") - for lora in self.unet_loras: - lora.apply_to(multiplier) - - def unapply_to(self): - for lora in self.text_encoder_loras + self.unet_loras: - lora.unapply_to() - - def merge_to(self, multiplier=1.0): - logger.info("merge LoRA weights to original weights") - for lora in tqdm(self.text_encoder_loras + self.unet_loras): - lora.merge_to(multiplier) - logger.info(f"weights are merged") - - def restore_from(self, multiplier=1.0): - logger.info("restore LoRA weights from original weights") - for lora in tqdm(self.text_encoder_loras + self.unet_loras): - lora.restore_from(multiplier) - logger.info(f"weights are restored") - - def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): - # convert SDXL Stability AI's state dict to Diffusers' based state dict - map_keys = list(UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules - map_keys.sort() - for key in list(state_dict.keys()): - if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"): - search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "") - position = bisect.bisect_right(map_keys, search_key) - map_key = map_keys[position - 1] - if search_key.startswith(map_key): - new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key]) - state_dict[new_key] = state_dict[key] - del state_dict[key] - - # in case of V2, some weights have different shape, so we need to convert them - # because V2 LoRA is based on U-Net created by use_linear_projection=False - my_state_dict = self.state_dict() - for key in state_dict.keys(): - if state_dict[key].size() != my_state_dict[key].size(): - # logger.info(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}") - state_dict[key] = state_dict[key].view(my_state_dict[key].size()) - - return super().load_state_dict(state_dict, strict) - - -if __name__ == "__main__": - # sample code to use LoRANetwork - import os - import argparse - from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline - import torch - - device = 
get_preferred_device() - - parser = argparse.ArgumentParser() - parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface") - parser.add_argument("--lora_weights", type=str, default=None, help="path to LoRA weights") - parser.add_argument("--sdxl", action="store_true", help="use SDXL model") - parser.add_argument("--prompt", type=str, default="A photo of cat", help="prompt text") - parser.add_argument("--negative_prompt", type=str, default="", help="negative prompt text") - parser.add_argument("--seed", type=int, default=0, help="random seed") - args = parser.parse_args() - - image_prefix = args.model_id.replace("/", "_") + "_" - - # load Diffusers model - logger.info(f"load model from {args.model_id}") - pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline] - if args.sdxl: - # use_safetensors=True does not work with 0.18.2 - pipe = StableDiffusionXLPipeline.from_pretrained(args.model_id, variant="fp16", torch_dtype=torch.float16) - else: - pipe = StableDiffusionPipeline.from_pretrained(args.model_id, variant="fp16", torch_dtype=torch.float16) - pipe.to(device) - pipe.set_use_memory_efficient_attention_xformers(True) - - text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder] - - # load LoRA weights - logger.info(f"load LoRA weights from {args.lora_weights}") - if os.path.splitext(args.lora_weights)[1] == ".safetensors": - from safetensors.torch import load_file - - lora_sd = load_file(args.lora_weights) - else: - lora_sd = torch.load(args.lora_weights) - - # create by LoRA weights and load weights - logger.info(f"create LoRA network") - lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0) - - logger.info(f"load LoRA network weights") - lora_network.load_state_dict(lora_sd) - - lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. 
merge_to works without this - - # 必要があれば、元のモデルの重みをバックアップしておく - # back-up unet/text encoder weights if necessary - def detach_and_move_to_cpu(state_dict): - for k, v in state_dict.items(): - state_dict[k] = v.detach().cpu() - return state_dict - - org_unet_sd = pipe.unet.state_dict() - detach_and_move_to_cpu(org_unet_sd) - - org_text_encoder_sd = pipe.text_encoder.state_dict() - detach_and_move_to_cpu(org_text_encoder_sd) - - if args.sdxl: - org_text_encoder_2_sd = pipe.text_encoder_2.state_dict() - detach_and_move_to_cpu(org_text_encoder_2_sd) - - def seed_everything(seed): - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - np.random.seed(seed) - random.seed(seed) - - # create image with original weights - logger.info(f"create image with original weights") - seed_everything(args.seed) - image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] - image.save(image_prefix + "original.png") - - # apply LoRA network to the model: slower than merge_to, but can be reverted easily - logger.info(f"apply LoRA network to the model") - lora_network.apply_to(multiplier=1.0) - - logger.info(f"create image with applied LoRA") - seed_everything(args.seed) - image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] - image.save(image_prefix + "applied_lora.png") - - # unapply LoRA network to the model - logger.info(f"unapply LoRA network to the model") - lora_network.unapply_to() - - logger.info(f"create image with unapplied LoRA") - seed_everything(args.seed) - image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] - image.save(image_prefix + "unapplied_lora.png") - - # merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to) - logger.info(f"merge LoRA network to the model") - lora_network.merge_to(multiplier=1.0) - - logger.info(f"create image with LoRA") - seed_everything(args.seed) - image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] - image.save(image_prefix + "merged_lora.png") - - # restore (unmerge) LoRA weights: numerically unstable - # マージされた重みを元に戻す。計算誤差のため、元の重みと完全に一致しないことがあるかもしれない - # 保存したstate_dictから元の重みを復元するのが確実 - logger.info(f"restore (unmerge) LoRA weights") - lora_network.restore_from(multiplier=1.0) - - logger.info(f"create image without LoRA") - seed_everything(args.seed) - image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] - image.save(image_prefix + "unmerged_lora.png") - - # restore original weights - logger.info(f"restore original weights") - pipe.unet.load_state_dict(org_unet_sd) - pipe.text_encoder.load_state_dict(org_text_encoder_sd) - if args.sdxl: - pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd) - - logger.info(f"create image with restored original weights") - seed_everything(args.seed) - image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] - image.save(image_prefix + "restore_original.png") - - # use convenience function to merge LoRA weights - logger.info(f"merge LoRA weights with convenience function") - merge_lora_weights(pipe, lora_sd, multiplier=1.0) - - logger.info(f"create image with merged LoRA weights") - seed_everything(args.seed) - image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] - image.save(image_prefix + "convenience_merged_lora.png") diff --git a/networks/lora_fa.py b/networks/lora_fa.py deleted file mode 100644 index 919222ce8..000000000 --- a/networks/lora_fa.py +++ /dev/null @@ -1,1244 +0,0 @@ -# LoRA network module -# reference: -# 
https://github.com/microsoft/LoRA/blob/main/loralib/layers.py -# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py - -# temporary implementation of LoRA-FA: https://arxiv.org/abs/2308.03303 -# need to be refactored and merged to lora.py - -import math -import os -from typing import Dict, List, Optional, Tuple, Type, Union -from diffusers import AutoencoderKL -from transformers import CLIPTextModel -import numpy as np -import torch -import re -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") - - -class LoRAModule(torch.nn.Module): - """ - replaces forward method of the original Linear, instead of replacing the original Linear module. - """ - - def __init__( - self, - lora_name, - org_module: torch.nn.Module, - multiplier=1.0, - lora_dim=4, - alpha=1, - dropout=None, - rank_dropout=None, - module_dropout=None, - ): - """if alpha == 0 or None, alpha is rank (no scaling).""" - super().__init__() - self.lora_name = lora_name - - if org_module.__class__.__name__ == "Conv2d": - in_dim = org_module.in_channels - out_dim = org_module.out_channels - else: - in_dim = org_module.in_features - out_dim = org_module.out_features - - # if limit_rank: - # self.lora_dim = min(lora_dim, in_dim, out_dim) - # if self.lora_dim != lora_dim: - # logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") - # else: - self.lora_dim = lora_dim - - if org_module.__class__.__name__ == "Conv2d": - kernel_size = org_module.kernel_size - stride = org_module.stride - padding = org_module.padding - self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) - self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) - else: - self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) - self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) - - if type(alpha) == torch.Tensor: - alpha = alpha.detach().float().numpy() # without casting, bf16 causes error - alpha = self.lora_dim if alpha is None or alpha == 0 else alpha - self.scale = alpha / self.lora_dim - self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える - - # # same as microsoft's - # torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) - - # according to the paper, initialize LoRA-A (down) as normal distribution - torch.nn.init.normal_(self.lora_down.weight, std=math.sqrt(2.0 / (in_dim + self.lora_dim))) - - torch.nn.init.zeros_(self.lora_up.weight) - - self.multiplier = multiplier - self.org_module = org_module # remove in applying - self.dropout = dropout - self.rank_dropout = rank_dropout - self.module_dropout = module_dropout - - def get_trainable_params(self): - params = self.named_parameters() - trainable_params = [] - for param in params: - if param[0] == "lora_up.weight": # up only - trainable_params.append(param[1]) - return trainable_params - - def requires_grad_(self, requires_grad: bool = True): - self.lora_up.requires_grad_(requires_grad) - self.lora_down.requires_grad_(False) - return self - - def apply_to(self): - self.org_forward = self.org_module.forward - self.org_module.forward = self.forward - del self.org_module - - def forward(self, x): - org_forwarded = self.org_forward(x) - - # module dropout - if self.module_dropout is not None and self.training: - if torch.rand(1) < self.module_dropout: - return org_forwarded - - lx = self.lora_down(x) - - # 
normal dropout - if self.dropout is not None and self.training: - lx = torch.nn.functional.dropout(lx, p=self.dropout) - - # rank dropout - if self.rank_dropout is not None and self.training: - mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout - if len(lx.size()) == 3: - mask = mask.unsqueeze(1) # for Text Encoder - elif len(lx.size()) == 4: - mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d - lx = lx * mask - - # scaling for rank dropout: treat as if the rank is changed - # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる - scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability - else: - scale = self.scale - - lx = self.lora_up(lx) - - return org_forwarded + lx * self.multiplier * scale - - -class LoRAInfModule(LoRAModule): - def __init__( - self, - lora_name, - org_module: torch.nn.Module, - multiplier=1.0, - lora_dim=4, - alpha=1, - **kwargs, - ): - # no dropout for inference - super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) - - self.org_module_ref = [org_module] # 後から参照できるように - self.enabled = True - - # check regional or not by lora_name - self.text_encoder = False - if lora_name.startswith("lora_te_"): - self.regional = False - self.use_sub_prompt = True - self.text_encoder = True - elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name: - self.regional = False - self.use_sub_prompt = True - elif "time_emb" in lora_name: - self.regional = False - self.use_sub_prompt = False - else: - self.regional = True - self.use_sub_prompt = False - - self.network: LoRANetwork = None - - def set_network(self, network): - self.network = network - - # freezeしてマージする - def merge_to(self, sd, dtype, device): - # get up/down weight - up_weight = sd["lora_up.weight"].to(torch.float).to(device) - down_weight = sd["lora_down.weight"].to(torch.float).to(device) - - # extract weight from org_module - org_sd = self.org_module.state_dict() - weight = org_sd["weight"].to(torch.float) - - # merge weight - if len(weight.size()) == 2: - # linear - weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = ( - weight - + self.multiplier - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * self.scale - ) - else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # logger.info(conved.size(), weight.size(), module.stride, module.padding) - weight = weight + self.multiplier * conved * self.scale - - # set weight to org_module - org_sd["weight"] = weight.to(dtype) - self.org_module.load_state_dict(org_sd) - - # 復元できるマージのため、このモジュールのweightを返す - def get_weight(self, multiplier=None): - if multiplier is None: - multiplier = self.multiplier - - # get up/down weight from module - up_weight = self.lora_up.weight.to(torch.float) - down_weight = self.lora_down.weight.to(torch.float) - - # pre-calculated weight - if len(down_weight.size()) == 2: - # linear - weight = self.multiplier * (up_weight @ down_weight) * self.scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = ( - self.multiplier - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * self.scale - ) - else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - weight = self.multiplier * conved * self.scale - - return weight - - def 
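# --- annotation: what merge_to() computes (illustrative, not part of the original patch) ---
# merge_to() above folds the LoRA delta directly into the original weight:
# W <- W + multiplier * (up @ down) * (alpha / rank) for Linear layers, with the
# conv2d 1x1 / 3x3 variants handled separately.  The standalone check below uses
# made-up shapes to confirm the merged Linear matches base(x) plus the LoRA path.
import torch

torch.manual_seed(0)
in_dim, out_dim, rank, alpha, multiplier = 16, 8, 4, 4.0, 1.0
scale = alpha / rank

base = torch.nn.Linear(in_dim, out_dim, bias=False)
down = torch.randn(rank, in_dim)    # lora_down.weight
up = torch.randn(out_dim, rank)     # lora_up.weight

with torch.no_grad():
    x = torch.randn(2, in_dim)
    unmerged = base(x) + (x @ down.t() @ up.t()) * multiplier * scale

    merged = torch.nn.Linear(in_dim, out_dim, bias=False)
    merged.weight.copy_(base.weight + multiplier * (up @ down) * scale)

    print(torch.allclose(unmerged, merged(x), atol=1e-5))  # True
# --- end annotation ---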
set_region(self, region): - self.region = region - self.region_mask = None - - def default_forward(self, x): - # logger.info("default_forward", self.lora_name, x.size()) - return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale - - def forward(self, x): - if not self.enabled: - return self.org_forward(x) - - if self.network is None or self.network.sub_prompt_index is None: - return self.default_forward(x) - if not self.regional and not self.use_sub_prompt: - return self.default_forward(x) - - if self.regional: - return self.regional_forward(x) - else: - return self.sub_prompt_forward(x) - - def get_mask_for_x(self, x): - # calculate size from shape of x - if len(x.size()) == 4: - h, w = x.size()[2:4] - area = h * w - else: - area = x.size()[1] - - mask = self.network.mask_dic[area] - if mask is None: - raise ValueError(f"mask is None for resolution {area}") - if len(x.size()) != 4: - mask = torch.reshape(mask, (1, -1, 1)) - return mask - - def regional_forward(self, x): - if "attn2_to_out" in self.lora_name: - return self.to_out_forward(x) - - if self.network.mask_dic is None: # sub_prompt_index >= 3 - return self.default_forward(x) - - # apply mask for LoRA result - lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale - mask = self.get_mask_for_x(lx) - # logger.info("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) - lx = lx * mask - - x = self.org_forward(x) - x = x + lx - - if "attn2_to_q" in self.lora_name and self.network.is_last_network: - x = self.postp_to_q(x) - - return x - - def postp_to_q(self, x): - # repeat x to num_sub_prompts - has_real_uncond = x.size()[0] // self.network.batch_size == 3 - qc = self.network.batch_size # uncond - qc += self.network.batch_size * self.network.num_sub_prompts # cond - if has_real_uncond: - qc += self.network.batch_size # real_uncond - - query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype) - query[: self.network.batch_size] = x[: self.network.batch_size] - - for i in range(self.network.batch_size): - qi = self.network.batch_size + i * self.network.num_sub_prompts - query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i] - - if has_real_uncond: - query[-self.network.batch_size :] = x[-self.network.batch_size :] - - # logger.info("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts) - return query - - def sub_prompt_forward(self, x): - if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA - return self.org_forward(x) - - emb_idx = self.network.sub_prompt_index - if not self.text_encoder: - emb_idx += self.network.batch_size - - # apply sub prompt of X - lx = x[emb_idx :: self.network.num_sub_prompts] - lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale - - # logger.info("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx) - - x = self.org_forward(x) - x[emb_idx :: self.network.num_sub_prompts] += lx - - return x - - def to_out_forward(self, x): - # logger.info("to_out_forward", self.lora_name, x.size(), self.network.is_last_network) - - if self.network.is_last_network: - masks = [None] * self.network.num_sub_prompts - self.network.shared[self.lora_name] = (None, masks) - else: - lx, masks = self.network.shared[self.lora_name] - - # call own LoRA - x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts] - lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale - 
- if self.network.is_last_network: - lx = torch.zeros( - (self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype - ) - self.network.shared[self.lora_name] = (lx, masks) - - # logger.info("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) - lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1 - masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1) - - # if not last network, return x and masks - x = self.org_forward(x) - if not self.network.is_last_network: - return x - - lx, masks = self.network.shared.pop(self.lora_name) - - # if last network, combine separated x with mask weighted sum - has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2 - - out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype) - out[: self.network.batch_size] = x[: self.network.batch_size] # uncond - if has_real_uncond: - out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond - - # logger.info("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts) - # for i in range(len(masks)): - # if masks[i] is None: - # masks[i] = torch.zeros_like(masks[-1]) - - mask = torch.cat(masks) - mask_sum = torch.sum(mask, dim=0) + 1e-4 - for i in range(self.network.batch_size): - # 1枚の画像ごとに処理する - lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts] - lx1 = lx1 * mask - lx1 = torch.sum(lx1, dim=0) - - xi = self.network.batch_size + i * self.network.num_sub_prompts - x1 = x[xi : xi + self.network.num_sub_prompts] - x1 = x1 * mask - x1 = torch.sum(x1, dim=0) - x1 = x1 / mask_sum - - x1 = x1 + lx1 - out[self.network.batch_size + i] = x1 - - # logger.info("to_out_forward", x.size(), out.size(), has_real_uncond) - return out - - -def parse_block_lr_kwargs(nw_kwargs): - down_lr_weight = nw_kwargs.get("down_lr_weight", None) - mid_lr_weight = nw_kwargs.get("mid_lr_weight", None) - up_lr_weight = nw_kwargs.get("up_lr_weight", None) - - # 以上のいずれにも設定がない場合は無効としてNoneを返す - if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None: - return None, None, None - - # extract learning rate weight for each block - if down_lr_weight is not None: - # if some parameters are not set, use zero - if "," in down_lr_weight: - down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] - - if mid_lr_weight is not None: - mid_lr_weight = float(mid_lr_weight) - - if up_lr_weight is not None: - if "," in up_lr_weight: - up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] - - down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight( - down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0)) - ) - - return down_lr_weight, mid_lr_weight, up_lr_weight - - -def create_network( - multiplier: float, - network_dim: Optional[int], - network_alpha: Optional[float], - vae: AutoencoderKL, - text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], - unet, - neuron_dropout: Optional[float] = None, - **kwargs, -): - if network_dim is None: - network_dim = 4 # default - if network_alpha is None: - network_alpha = 1.0 - - # extract dim/alpha for conv2d, and block dim - conv_dim = kwargs.get("conv_dim", None) - conv_alpha = kwargs.get("conv_alpha", None) - if conv_dim is not None: - conv_dim = int(conv_dim) - if conv_alpha is None: - conv_alpha = 
1.0 - else: - conv_alpha = float(conv_alpha) - - # block dim/alpha/lr - block_dims = kwargs.get("block_dims", None) - down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs) - - # 以上のいずれかに指定があればblockごとのdim(rank)を有効にする - if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None: - block_alphas = kwargs.get("block_alphas", None) - conv_block_dims = kwargs.get("conv_block_dims", None) - conv_block_alphas = kwargs.get("conv_block_alphas", None) - - block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas( - block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha - ) - - # remove block dim/alpha without learning rate - block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas( - block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight - ) - - else: - block_alphas = None - conv_block_dims = None - conv_block_alphas = None - - # rank/module dropout - rank_dropout = kwargs.get("rank_dropout", None) - if rank_dropout is not None: - rank_dropout = float(rank_dropout) - module_dropout = kwargs.get("module_dropout", None) - if module_dropout is not None: - module_dropout = float(module_dropout) - - # すごく引数が多いな ( ^ω^)・・・ - network = LoRANetwork( - text_encoder, - unet, - multiplier=multiplier, - lora_dim=network_dim, - alpha=network_alpha, - dropout=neuron_dropout, - rank_dropout=rank_dropout, - module_dropout=module_dropout, - conv_lora_dim=conv_dim, - conv_alpha=conv_alpha, - block_dims=block_dims, - block_alphas=block_alphas, - conv_block_dims=conv_block_dims, - conv_block_alphas=conv_block_alphas, - varbose=True, - ) - - if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: - network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) - - return network - - -# このメソッドは外部から呼び出される可能性を考慮しておく -# network_dim, network_alpha にはデフォルト値が入っている。 -# block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている -# conv_dim, conv_alpha は両方ともNoneまたは両方とも値が入っている -def get_block_dims_and_alphas( - block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha -): - num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1 - - def parse_ints(s): - return [int(i) for i in s.split(",")] - - def parse_floats(s): - return [float(i) for i in s.split(",")] - - # block_dimsとblock_alphasをパースする。必ず値が入る - if block_dims is not None: - block_dims = parse_ints(block_dims) - assert ( - len(block_dims) == num_total_blocks - ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" - else: - logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") - block_dims = [network_dim] * num_total_blocks - - if block_alphas is not None: - block_alphas = parse_floats(block_alphas) - assert ( - len(block_alphas) == num_total_blocks - ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください" - else: - logger.warning( - f"block_alphas is not specified. 
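# --- annotation: per-block dim/alpha expansion (illustrative, not part of the original patch) ---
# get_block_dims_and_alphas() above expands the comma-separated block_dims / block_alphas
# kwargs into one value per U-Net block: 12 down blocks + 1 mid block + 12 up blocks = 25
# entries, falling back to network_dim / network_alpha when a list is not given.
# A rough standalone sketch of that expansion; the helper name and example values are made up.
NUM_OF_BLOCKS = 12
num_total_blocks = NUM_OF_BLOCKS * 2 + 1  # 25


def expand_block_values(spec, default, cast):
    """Parse '2,2,4,...' into a per-block list, or repeat the default for every block."""
    if spec is None:
        return [default] * num_total_blocks
    values = [cast(v) for v in spec.split(",")]
    assert len(values) == num_total_blocks, f"expected {num_total_blocks} values, got {len(values)}"
    return values


block_dims = expand_block_values(None, 8, int)                      # -> [8] * 25
block_alphas = expand_block_values(",".join(["1"] * 25), 1.0, float)
print(len(block_dims), len(block_alphas))  # 25 25
# --- end annotation ---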
all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります" - ) - block_alphas = [network_alpha] * num_total_blocks - - # conv_block_dimsとconv_block_alphasを、指定がある場合のみパースする。指定がなければconv_dimとconv_alphaを使う - if conv_block_dims is not None: - conv_block_dims = parse_ints(conv_block_dims) - assert ( - len(conv_block_dims) == num_total_blocks - ), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください" - - if conv_block_alphas is not None: - conv_block_alphas = parse_floats(conv_block_alphas) - assert ( - len(conv_block_alphas) == num_total_blocks - ), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください" - else: - if conv_alpha is None: - conv_alpha = 1.0 - logger.warning( - f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります" - ) - conv_block_alphas = [conv_alpha] * num_total_blocks - else: - if conv_dim is not None: - logger.warning( - f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります" - ) - conv_block_dims = [conv_dim] * num_total_blocks - conv_block_alphas = [conv_alpha] * num_total_blocks - else: - conv_block_dims = None - conv_block_alphas = None - - return block_dims, block_alphas, conv_block_dims, conv_block_alphas - - -# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出される可能性を考慮しておく -def get_block_lr_weight( - down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold -) -> Tuple[List[float], List[float], List[float]]: - # パラメータ未指定時は何もせず、今までと同じ動作とする - if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None: - return None, None, None - - max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数 - - def get_list(name_with_suffix) -> List[float]: - import math - - tokens = name_with_suffix.split("+") - name = tokens[0] - base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0 - - if name == "cosine": - return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))] - elif name == "sine": - return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)] - elif name == "linear": - return [i / (max_len - 1) + base_lr for i in range(max_len)] - elif name == "reverse_linear": - return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))] - elif name == "zeros": - return [0.0 + base_lr] * max_len - else: - logger.error( - "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" - % (name) - ) - return None - - if type(down_lr_weight) == str: - down_lr_weight = get_list(down_lr_weight) - if type(up_lr_weight) == str: - up_lr_weight = get_list(up_lr_weight) - - if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): - logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) - logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) - up_lr_weight = up_lr_weight[:max_len] - down_lr_weight = down_lr_weight[:max_len] - - if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): - logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." 
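# --- annotation: block learning-rate presets (illustrative, not part of the original patch) ---
# get_block_lr_weight() above turns a preset name such as "sine", "cosine", "linear",
# "reverse_linear" or "zeros" (optionally with "+<offset>", e.g. "sine+.1") into 12
# per-block multipliers, then zeroes out anything at or below block_lr_zero_threshold.
# A rough standalone sketch of the preset expansion; see the deleted code above for the
# exact curves, the function name here is made up.
import math


def lr_weight_preset(name_with_offset: str, max_len: int = 12):
    name, _, offset = name_with_offset.partition("+")
    base = float(offset) if offset else 0.0
    ramp = [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in range(max_len)]
    if name == "sine":
        return [v + base for v in ramp]
    if name == "cosine":
        return [v + base for v in reversed(ramp)]
    if name == "linear":
        return [i / (max_len - 1) + base for i in range(max_len)]
    if name == "reverse_linear":
        return [i / (max_len - 1) + base for i in reversed(range(max_len))]
    if name == "zeros":
        return [base] * max_len
    raise ValueError(f"unknown preset: {name}")


print([round(w, 3) for w in lr_weight_preset("sine+.1")])  # ramps up from 0.1 toward 1.1
# --- end annotation ---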
% max_len) - logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) - - if down_lr_weight != None and len(down_lr_weight) < max_len: - down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) - if up_lr_weight != None and len(up_lr_weight) < max_len: - up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) - - if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): - logger.info("apply block learning rate / 階層別学習率を適用します。") - if down_lr_weight != None: - down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] - logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}") - else: - logger.info("down_lr_weight: all 1.0, すべて1.0") - - if mid_lr_weight != None: - mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 - logger.info(f"mid_lr_weight: {mid_lr_weight}") - else: - logger.info("mid_lr_weight: 1.0") - - if up_lr_weight != None: - up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] - logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}") - else: - logger.info("up_lr_weight: all 1.0, すべて1.0") - - return down_lr_weight, mid_lr_weight, up_lr_weight - - -# lr_weightが0のblockをblock_dimsから除外する、外部から呼び出す可能性を考慮しておく -def remove_block_dims_and_alphas( - block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight -): - # set 0 to block dim without learning rate to remove the block - if down_lr_weight != None: - for i, lr in enumerate(down_lr_weight): - if lr == 0: - block_dims[i] = 0 - if conv_block_dims is not None: - conv_block_dims[i] = 0 - if mid_lr_weight != None: - if mid_lr_weight == 0: - block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0 - if conv_block_dims is not None: - conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0 - if up_lr_weight != None: - for i, lr in enumerate(up_lr_weight): - if lr == 0: - block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0 - if conv_block_dims is not None: - conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0 - - return block_dims, block_alphas, conv_block_dims, conv_block_alphas - - -# 外部から呼び出す可能性を考慮しておく -def get_block_index(lora_name: str) -> int: - block_idx = -1 # invalid lora name - - m = RE_UPDOWN.search(lora_name) - if m: - g = m.groups() - i = int(g[1]) - j = int(g[3]) - if g[2] == "resnets": - idx = 3 * i + j - elif g[2] == "attentions": - idx = 3 * i + j - elif g[2] == "upsamplers" or g[2] == "downsamplers": - idx = 3 * i + 2 - - if g[0] == "down": - block_idx = 1 + idx # 0に該当するLoRAは存在しない - elif g[0] == "up": - block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx - - elif "mid_block_" in lora_name: - block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12 - - return block_idx - - -# Create network from weights for inference, weights are not loaded here (because can be merged) -def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): - if weights_sd is None: - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import load_file, safe_open - - weights_sd = load_file(file) - else: - weights_sd = torch.load(file, map_location="cpu") - - # get dim/alpha mapping - modules_dim = {} - modules_alpha = {} - for key, value in weights_sd.items(): - if "." 
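# --- annotation: mapping LoRA names to block indices (illustrative, not part of the original patch) ---
# get_block_index() above maps a flattened LoRA module name onto one of the 25 block slots
# used for per-block dims and learning rates: down blocks land at 1 + (3*i + j), the mid
# block at 12, and up blocks at 13 + (3*i + j), with up/down samplers using j = 2.
# Standalone check with a made-up module name (not taken from the patch):
import re

RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")


def block_index(lora_name: str, num_of_blocks: int = 12) -> int:
    m = RE_UPDOWN.search(lora_name)
    if m:
        side, i, kind, j = m.group(1), int(m.group(2)), m.group(3), int(m.group(4))
        idx = 3 * i + 2 if kind in ("upsamplers", "downsamplers") else 3 * i + j
        return 1 + idx if side == "down" else num_of_blocks + 1 + idx
    if "mid_block_" in lora_name:
        return num_of_blocks
    return -1  # not a per-block U-Net module


print(block_index("lora_unet_down_blocks_1_attentions_0_proj_in_"))  # 4
# --- end annotation ---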
not in key: - continue - - lora_name = key.split(".")[0] - if "alpha" in key: - modules_alpha[lora_name] = value - elif "lora_down" in key: - dim = value.size()[0] - modules_dim[lora_name] = dim - # logger.info(lora_name, value.size(), dim) - - # support old LoRA without alpha - for key in modules_dim.keys(): - if key not in modules_alpha: - modules_alpha[key] = modules_dim[key] - - module_class = LoRAInfModule if for_inference else LoRAModule - - network = LoRANetwork( - text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class - ) - - # block lr - down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs) - if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: - network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) - - return network, weights_sd - - -class LoRANetwork(torch.nn.Module): - NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数 - - UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] - UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] - LORA_PREFIX_UNET = "lora_unet" - LORA_PREFIX_TEXT_ENCODER = "lora_te" - - # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER - LORA_PREFIX_TEXT_ENCODER1 = "lora_te1" - LORA_PREFIX_TEXT_ENCODER2 = "lora_te2" - - def __init__( - self, - text_encoder: Union[List[CLIPTextModel], CLIPTextModel], - unet, - multiplier: float = 1.0, - lora_dim: int = 4, - alpha: float = 1, - dropout: Optional[float] = None, - rank_dropout: Optional[float] = None, - module_dropout: Optional[float] = None, - conv_lora_dim: Optional[int] = None, - conv_alpha: Optional[float] = None, - block_dims: Optional[List[int]] = None, - block_alphas: Optional[List[float]] = None, - conv_block_dims: Optional[List[int]] = None, - conv_block_alphas: Optional[List[float]] = None, - modules_dim: Optional[Dict[str, int]] = None, - modules_alpha: Optional[Dict[str, int]] = None, - module_class: Type[object] = LoRAModule, - varbose: Optional[bool] = False, - ) -> None: - """ - LoRA network: すごく引数が多いが、パターンは以下の通り - 1. lora_dimとalphaを指定 - 2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定 - 3. block_dimsとblock_alphasを指定 : Conv2d3x3には適用しない - 4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する - 5. modules_dimとmodules_alphaを指定 (推論用) - """ - super().__init__() - self.multiplier = multiplier - - self.lora_dim = lora_dim - self.alpha = alpha - self.conv_lora_dim = conv_lora_dim - self.conv_alpha = conv_alpha - self.dropout = dropout - self.rank_dropout = rank_dropout - self.module_dropout = module_dropout - - if modules_dim is not None: - logger.info(f"create LoRA network from weights") - elif block_dims is not None: - logger.info(f"create LoRA network from block_dims") - logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") - logger.info(f"block_dims: {block_dims}") - logger.info(f"block_alphas: {block_alphas}") - if conv_block_dims is not None: - logger.info(f"conv_block_dims: {conv_block_dims}") - logger.info(f"conv_block_alphas: {conv_block_alphas}") - else: - logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") - if self.conv_lora_dim is not None: - logger.info(f"apply LoRA to Conv2d with kernel size (3,3). 
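# --- annotation: recovering rank/alpha from a saved LoRA (illustrative, not part of the original patch) ---
# create_network_from_weights() above reads the per-module rank and alpha straight from a
# saved LoRA state dict: the first dimension of each "<name>.lora_down.weight" tensor is
# the rank, "<name>.alpha" holds alpha, and old files without an alpha key fall back to
# alpha = rank.  Minimal standalone sketch on a toy state dict; names and shapes are made up.
import torch

weights_sd = {
    "lora_unet_mid_block_attn.lora_down.weight": torch.zeros(4, 320),
    "lora_unet_mid_block_attn.lora_up.weight": torch.zeros(320, 4),
    "lora_unet_mid_block_attn.alpha": torch.tensor(1.0),
    "lora_te_text_model_fc1.lora_down.weight": torch.zeros(8, 768),  # no alpha key (old format)
    "lora_te_text_model_fc1.lora_up.weight": torch.zeros(3072, 8),
}

modules_dim, modules_alpha = {}, {}
for key, value in weights_sd.items():
    name = key.split(".")[0]
    if "alpha" in key:
        modules_alpha[name] = float(value)
    elif "lora_down" in key:
        modules_dim[name] = value.size(0)

for name, dim in modules_dim.items():
    modules_alpha.setdefault(name, dim)  # old LoRA files without alpha

print(modules_dim, modules_alpha)
# --- end annotation ---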
dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") - - # create module instances - def create_modules( - is_unet: bool, - text_encoder_idx: Optional[int], # None, 1, 2 - root_module: torch.nn.Module, - target_replace_modules: List[torch.nn.Module], - ) -> List[LoRAModule]: - prefix = ( - self.LORA_PREFIX_UNET - if is_unet - else ( - self.LORA_PREFIX_TEXT_ENCODER - if text_encoder_idx is None - else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2) - ) - ) - loras = [] - skipped = [] - for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: - for child_name, child_module in module.named_modules(): - is_linear = child_module.__class__.__name__ == "Linear" - is_conv2d = child_module.__class__.__name__ == "Conv2d" - is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) - - if is_linear or is_conv2d: - lora_name = prefix + "." + name + "." + child_name - lora_name = lora_name.replace(".", "_") - - dim = None - alpha = None - - if modules_dim is not None: - # モジュール指定あり - if lora_name in modules_dim: - dim = modules_dim[lora_name] - alpha = modules_alpha[lora_name] - elif is_unet and block_dims is not None: - # U-Netでblock_dims指定あり - block_idx = get_block_index(lora_name) - if is_linear or is_conv2d_1x1: - dim = block_dims[block_idx] - alpha = block_alphas[block_idx] - elif conv_block_dims is not None: - dim = conv_block_dims[block_idx] - alpha = conv_block_alphas[block_idx] - else: - # 通常、すべて対象とする - if is_linear or is_conv2d_1x1: - dim = self.lora_dim - alpha = self.alpha - elif self.conv_lora_dim is not None: - dim = self.conv_lora_dim - alpha = self.conv_alpha - - if dim is None or dim == 0: - # skipした情報を出力 - if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None): - skipped.append(lora_name) - continue - - lora = module_class( - lora_name, - child_module, - self.multiplier, - dim, - alpha, - dropout=dropout, - rank_dropout=rank_dropout, - module_dropout=module_dropout, - ) - loras.append(lora) - return loras, skipped - - text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] - - # create LoRA for text encoder - # 毎回すべてのモジュールを作るのは無駄なので要検討 - self.text_encoder_loras = [] - skipped_te = [] - for i, text_encoder in enumerate(text_encoders): - if len(text_encoders) > 1: - index = i + 1 - logger.info(f"create LoRA for Text Encoder {index}:") - else: - index = None - logger.info(f"create LoRA for Text Encoder:") - - text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) - self.text_encoder_loras.extend(text_encoder_loras) - skipped_te += skipped - logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") - - # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights - target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE - if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None: - target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 - - self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") - - skipped = skipped_te + skipped_un - if varbose and len(skipped) > 0: - logger.warning( - f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" - ) - for name in 
skipped: - logger.info(f"\t{name}") - - self.up_lr_weight: List[float] = None - self.down_lr_weight: List[float] = None - self.mid_lr_weight: float = None - self.block_lr = False - - # assertion - names = set() - for lora in self.text_encoder_loras + self.unet_loras: - assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" - names.add(lora.lora_name) - - def set_multiplier(self, multiplier): - self.multiplier = multiplier - for lora in self.text_encoder_loras + self.unet_loras: - lora.multiplier = self.multiplier - - def load_weights(self, file): - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import load_file - - weights_sd = load_file(file) - else: - weights_sd = torch.load(file, map_location="cpu") - - info = self.load_state_dict(weights_sd, False) - return info - - def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): - if apply_text_encoder: - logger.info("enable LoRA for text encoder") - else: - self.text_encoder_loras = [] - - if apply_unet: - logger.info("enable LoRA for U-Net") - else: - self.unet_loras = [] - - for lora in self.text_encoder_loras + self.unet_loras: - lora.apply_to() - self.add_module(lora.lora_name, lora) - - # マージできるかどうかを返す - def is_mergeable(self): - return True - - # TODO refactor to common function with apply_to - def merge_to(self, text_encoder, unet, weights_sd, dtype, device): - apply_text_encoder = apply_unet = False - for key in weights_sd.keys(): - if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): - apply_text_encoder = True - elif key.startswith(LoRANetwork.LORA_PREFIX_UNET): - apply_unet = True - - if apply_text_encoder: - logger.info("enable LoRA for text encoder") - else: - self.text_encoder_loras = [] - - if apply_unet: - logger.info("enable LoRA for U-Net") - else: - self.unet_loras = [] - - for lora in self.text_encoder_loras + self.unet_loras: - sd_for_lora = {} - for key in weights_sd.keys(): - if key.startswith(lora.lora_name): - sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] - lora.merge_to(sd_for_lora, dtype, device) - - logger.info(f"weights are merged") - - # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない - def set_block_lr_weight( - self, - up_lr_weight: List[float] = None, - mid_lr_weight: float = None, - down_lr_weight: List[float] = None, - ): - self.block_lr = True - self.down_lr_weight = down_lr_weight - self.mid_lr_weight = mid_lr_weight - self.up_lr_weight = up_lr_weight - - def get_lr_weight(self, lora: LoRAModule) -> float: - lr_weight = 1.0 - block_idx = get_block_index(lora.lora_name) - if block_idx < 0: - return lr_weight - - if block_idx < LoRANetwork.NUM_OF_BLOCKS: - if self.down_lr_weight != None: - lr_weight = self.down_lr_weight[block_idx] - elif block_idx == LoRANetwork.NUM_OF_BLOCKS: - if self.mid_lr_weight != None: - lr_weight = self.mid_lr_weight - elif block_idx > LoRANetwork.NUM_OF_BLOCKS: - if self.up_lr_weight != None: - lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1] - - return lr_weight - - # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): - self.requires_grad_(True) - all_params = [] - - def enumerate_params(loras: List[LoRAModule]): - params = [] - for lora in loras: - # params.extend(lora.parameters()) - params.extend(lora.get_trainable_params()) - return params - - if self.text_encoder_loras: - param_data = {"params": enumerate_params(self.text_encoder_loras)} - if text_encoder_lr is not None: - param_data["lr"] = 
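# --- annotation: per-block optimizer param groups (illustrative, not part of the original patch) ---
# prepare_optimizer_params() (continued below) builds one optimizer param group per U-Net
# block when block-wise learning rates are enabled, scaling the base unet_lr by that
# block's weight and dropping any group whose effective lr comes out as 0.
# Standalone sketch with made-up modules and weights:
import torch

blocks = {idx: [torch.nn.Linear(4, 4)] for idx in range(3)}   # pretend block_idx -> LoRA modules
block_lr_weight = {0: 1.0, 1: 0.5, 2: 0.0}                    # pretend per-block multipliers
unet_lr = 1e-4

param_groups = []
for idx, modules in blocks.items():
    lr = unet_lr * block_lr_weight[idx]
    if lr == 0:
        continue  # block effectively disabled: skip the whole group
    params = [p for m in modules for p in m.parameters()]
    param_groups.append({"params": params, "lr": lr})

optimizer = torch.optim.AdamW(param_groups)
print([g["lr"] for g in optimizer.param_groups])  # [0.0001, 5e-05]
# --- end annotation ---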
text_encoder_lr - all_params.append(param_data) - - if self.unet_loras: - if self.block_lr: - # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類 - block_idx_to_lora = {} - for lora in self.unet_loras: - idx = get_block_index(lora.lora_name) - if idx not in block_idx_to_lora: - block_idx_to_lora[idx] = [] - block_idx_to_lora[idx].append(lora) - - # blockごとにパラメータを設定する - for idx, block_loras in block_idx_to_lora.items(): - param_data = {"params": enumerate_params(block_loras)} - - if unet_lr is not None: - param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0]) - elif default_lr is not None: - param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0]) - if ("lr" in param_data) and (param_data["lr"] == 0): - continue - all_params.append(param_data) - - else: - param_data = {"params": enumerate_params(self.unet_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr - all_params.append(param_data) - - return all_params - - def enable_gradient_checkpointing(self): - # not supported - pass - - def prepare_grad_etc(self, text_encoder, unet): - self.requires_grad_(True) - - def on_epoch_start(self, text_encoder, unet): - self.train() - - def get_trainable_params(self): - return self.parameters() - - def save_weights(self, file, dtype, metadata): - if metadata is not None and len(metadata) == 0: - metadata = None - - state_dict = self.state_dict() - - if dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(dtype) - state_dict[key] = v - - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import save_file - from library import train_util - - # Precalculate model hashes to save time on indexing - if metadata is None: - metadata = {} - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) - metadata["sshs_model_hash"] = model_hash - metadata["sshs_legacy_hash"] = legacy_hash - - save_file(state_dict, file, metadata) - else: - torch.save(state_dict, file) - - # mask is a tensor with values from 0 to 1 - def set_region(self, sub_prompt_index, is_last_network, mask): - if mask.max() == 0: - mask = torch.ones_like(mask) - - self.mask = mask - self.sub_prompt_index = sub_prompt_index - self.is_last_network = is_last_network - - for lora in self.text_encoder_loras + self.unet_loras: - lora.set_network(self) - - def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared): - self.batch_size = batch_size - self.num_sub_prompts = num_sub_prompts - self.current_size = (height, width) - self.shared = shared - - # create masks - mask = self.mask - mask_dic = {} - mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w - ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight - dtype = ref_weight.dtype - device = ref_weight.device - - def resize_add(mh, mw): - # logger.info(mh, mw, mh * mw) - m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16 - m = m.to(device, dtype=dtype) - mask_dic[mh * mw] = m - - h = height // 8 - w = width // 8 - for _ in range(4): - resize_add(h, w) - if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2 - resize_add(h + h % 2, w + w % 2) - h = (h + 1) // 2 - w = (w + 1) // 2 - - self.mask_dic = mask_dic - - def backup_weights(self): - # 重みのバックアップを行う - loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras - for lora in loras: - org_module = lora.org_module_ref[0] - if not 
hasattr(org_module, "_lora_org_weight"): - sd = org_module.state_dict() - org_module._lora_org_weight = sd["weight"].detach().clone() - org_module._lora_restored = True - - def restore_weights(self): - # 重みのリストアを行う - loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras - for lora in loras: - org_module = lora.org_module_ref[0] - if not org_module._lora_restored: - sd = org_module.state_dict() - sd["weight"] = org_module._lora_org_weight - org_module.load_state_dict(sd) - org_module._lora_restored = True - - def pre_calculation(self): - # 事前計算を行う - loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras - for lora in loras: - org_module = lora.org_module_ref[0] - sd = org_module.state_dict() - - org_weight = sd["weight"] - lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) - sd["weight"] = org_weight + lora_weight - assert sd["weight"].shape == org_weight.shape - org_module.load_state_dict(sd) - - org_module._lora_restored = False - lora.enabled = False - - def apply_max_norm_regularization(self, max_norm_value, device): - downkeys = [] - upkeys = [] - alphakeys = [] - norms = [] - keys_scaled = 0 - - state_dict = self.state_dict() - for key in state_dict.keys(): - if "lora_down" in key and "weight" in key: - downkeys.append(key) - upkeys.append(key.replace("lora_down", "lora_up")) - alphakeys.append(key.replace("lora_down.weight", "alpha")) - - for i in range(len(downkeys)): - down = state_dict[downkeys[i]].to(device) - up = state_dict[upkeys[i]].to(device) - alpha = state_dict[alphakeys[i]].to(device) - dim = down.shape[0] - scale = alpha / dim - - if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): - updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) - elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): - updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) - else: - updown = up @ down - - updown *= scale - - norm = updown.norm().clamp(min=max_norm_value / 2) - desired = torch.clamp(norm, max=max_norm_value) - ratio = desired.cpu() / norm.cpu() - sqrt_ratio = ratio**0.5 - if ratio != 1: - keys_scaled += 1 - state_dict[upkeys[i]] *= sqrt_ratio - state_dict[downkeys[i]] *= sqrt_ratio - scalednorm = updown.norm() * ratio - norms.append(scalednorm.item()) - - return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/networks/lora_interrogator.py b/networks/lora_interrogator.py deleted file mode 100644 index 6aaa58107..000000000 --- a/networks/lora_interrogator.py +++ /dev/null @@ -1,146 +0,0 @@ - - -from tqdm import tqdm -from library import model_util -import library.train_util as train_util -import argparse -from transformers import CLIPTokenizer - -import torch -from library.device_utils import init_ipex, get_preferred_device -init_ipex() - -import library.model_util as model_util -import lora -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -TOKENIZER_PATH = "openai/clip-vit-large-patch14" -V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う - -DEVICE = get_preferred_device() - - -def interrogate(args): - weights_dtype = torch.float16 - - # いろいろ準備する - logger.info(f"loading SD model: {args.sd_model}") - args.pretrained_model_name_or_path = args.sd_model - args.vae = None - text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE) - - logger.info(f"loading LoRA: {args.model}") - network, weights_sd = 
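# --- annotation: max-norm regularization on one LoRA pair (illustrative, not part of the original patch) ---
# apply_max_norm_regularization() above clamps the norm of each effective LoRA delta
# (up @ down, scaled by alpha/dim): when the norm exceeds max_norm_value, both the up and
# down weights are multiplied by sqrt(desired / actual), so their product shrinks by
# exactly that ratio.  Simplified standalone sketch for a Linear pair; shapes are made up.
import torch

torch.manual_seed(0)
up, down = torch.randn(8, 4), torch.randn(4, 16)
alpha, max_norm_value = 4.0, 1.0
scale = alpha / down.shape[0]

updown = (up @ down) * scale
norm = updown.norm()
desired = torch.clamp(norm, max=max_norm_value)
ratio = desired / norm
if ratio < 1:
    up = up * ratio.sqrt()
    down = down * ratio.sqrt()

print(float(((up @ down) * scale).norm()))  # <= max_norm_value (up to float error)
# --- end annotation ---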
lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet) - - # text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい - has_te_weight = False - for key in weights_sd.keys(): - if 'lora_te' in key: - has_te_weight = True - break - if not has_te_weight: - logger.error("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません") - return - del vae - - logger.info("loading tokenizer") - if args.v2: - tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") - else: - tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2) - - text_encoder.to(DEVICE, dtype=weights_dtype) - text_encoder.eval() - unet.to(DEVICE, dtype=weights_dtype) - unet.eval() # U-Netは呼び出さないので不要だけど - - # トークンをひとつひとつ当たっていく - token_id_start = 0 - token_id_end = max(tokenizer.all_special_ids) - logger.info(f"interrogate tokens are: {token_id_start} to {token_id_end}") - - def get_all_embeddings(text_encoder): - embs = [] - with torch.no_grad(): - for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)): - batch = [] - for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)): - tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id] - # tokens = [tid] # こちらは結果がいまひとつ - batch.append(tokens) - - # batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu") # bos/eosも含めたほうが差が出るようだ [:, 1] - # clip skip対応 - batch = torch.tensor(batch).to(DEVICE) - if args.clip_skip is None: - encoder_hidden_states = text_encoder(batch)[0] - else: - enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True) - encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] - encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states.to("cpu") - - embs.extend(encoder_hidden_states) - return torch.stack(embs) - - logger.info("get original text encoder embeddings.") - orig_embs = get_all_embeddings(text_encoder) - - network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0) - info = network.load_state_dict(weights_sd, strict=False) - logger.info(f"Loading LoRA weights: {info}") - - network.to(DEVICE, dtype=weights_dtype) - network.eval() - - del unet - - logger.info("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)") - logger.info("get text encoder embeddings with lora.") - lora_embs = get_all_embeddings(text_encoder) - - # 比べる:とりあえず単純に差分の絶対値で - logger.info("comparing...") - diffs = {} - for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))): - diff = torch.mean(torch.abs(orig_emb - lora_emb)) - # diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1)) # うまく検出できない - diff = float(diff.detach().to('cpu').numpy()) - diffs[token_id_start + i] = diff - - diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1]) - - # 結果を表示する - print("top 100:") - for i, (token, diff) in enumerate(diffs_sorted[:100]): - # if diff < 1e-6: - # break - string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token])) - print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - parser.add_argument("--v2", action='store_true', - help='load Stable Diffusion v2.x 
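# --- annotation: how the interrogator ranks tokens (illustrative, not part of the original patch) ---
# lora_interrogator.py (deleted above) encodes every tokenizer token through the text
# encoder twice, once without and once with the LoRA applied, then ranks tokens by the
# mean absolute difference of their embeddings to guess which concepts the LoRA affects.
# The ranking step in isolation, on made-up embeddings; no model is loaded here.
import torch

torch.manual_seed(0)
num_tokens, seq_len, hidden = 10, 3, 16
orig_embs = torch.randn(num_tokens, seq_len, hidden)
lora_embs = orig_embs + torch.randn(num_tokens, seq_len, hidden) * 0.01
lora_embs[7] += 0.5  # pretend token 7 is strongly affected by the LoRA

diffs = {t: float((orig_embs[t] - lora_embs[t]).abs().mean()) for t in range(num_tokens)}
for token_id, diff in sorted(diffs.items(), key=lambda kv: -kv[1])[:3]:
    print(f"token {token_id}: {diff:.5f}")  # token 7 should rank first
# --- end annotation ---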
model / Stable Diffusion 2.xのモデルを読み込む') - parser.add_argument("--sd_model", type=str, default=None, - help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors") - parser.add_argument("--model", type=str, default=None, - help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors") - parser.add_argument("--batch_size", type=int, default=16, - help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ") - parser.add_argument("--clip_skip", type=int, default=None, - help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") - - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args = parser.parse_args() - interrogate(args) diff --git a/networks/merge_lora.py b/networks/merge_lora.py deleted file mode 100644 index fea8a3f32..000000000 --- a/networks/merge_lora.py +++ /dev/null @@ -1,360 +0,0 @@ -import math -import argparse -import os -import time -import torch -from safetensors.torch import load_file, save_file -from library import sai_model_spec, train_util -import library.model_util as model_util -import lora -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -def load_state_dict(file_name, dtype): - if os.path.splitext(file_name)[1] == ".safetensors": - sd = load_file(file_name) - metadata = train_util.load_metadata_from_safetensors(file_name) - else: - sd = torch.load(file_name, map_location="cpu") - metadata = {} - - for key in list(sd.keys()): - if type(sd[key]) == torch.Tensor: - sd[key] = sd[key].to(dtype) - - return sd, metadata - - -def save_to_file(file_name, model, state_dict, dtype, metadata): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) - - if os.path.splitext(file_name)[1] == ".safetensors": - save_file(model, file_name, metadata=metadata) - else: - torch.save(model, file_name) - - -def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): - text_encoder.to(merge_dtype) - unet.to(merge_dtype) - - # create module map - name_to_module = {} - for i, root_module in enumerate([text_encoder, unet]): - if i == 0: - prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER - target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE - else: - prefix = lora.LoRANetwork.LORA_PREFIX_UNET - target_replace_modules = ( - lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 - ) - - for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: - for child_name, child_module in module.named_modules(): - if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": - lora_name = prefix + "." + name + "." 
+ child_name - lora_name = lora_name.replace(".", "_") - name_to_module[lora_name] = child_module - - for model, ratio in zip(models, ratios): - logger.info(f"loading: {model}") - lora_sd, _ = load_state_dict(model, merge_dtype) - - logger.info(f"merging...") - for key in lora_sd.keys(): - if "lora_down" in key: - up_key = key.replace("lora_down", "lora_up") - alpha_key = key[: key.index("lora_down")] + "alpha" - - # find original module for this lora - module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" - if module_name not in name_to_module: - logger.info(f"no module found for LoRA weight: {key}") - continue - module = name_to_module[module_name] - # logger.info(f"apply {key} to {module}") - - down_weight = lora_sd[key] - up_weight = lora_sd[up_key] - - dim = down_weight.size()[0] - alpha = lora_sd.get(alpha_key, dim) - scale = alpha / dim - - # W <- W + U * D - weight = module.weight - if len(weight.size()) == 2: - # linear - if len(up_weight.size()) == 4: # use linear projection mismatch - up_weight = up_weight.squeeze(3).squeeze(2) - down_weight = down_weight.squeeze(3).squeeze(2) - weight = weight + ratio * (up_weight @ down_weight) * scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = ( - weight - + ratio - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * scale - ) - else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # logger.info(conved.size(), weight.size(), module.stride, module.padding) - weight = weight + ratio * conved * scale - - module.weight = torch.nn.Parameter(weight) - - -def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): - base_alphas = {} # alpha for merged model - base_dims = {} - - merged_sd = {} - v2 = None - base_model = None - for model, ratio in zip(models, ratios): - logger.info(f"loading: {model}") - lora_sd, lora_metadata = load_state_dict(model, merge_dtype) - - if lora_metadata is not None: - if v2 is None: - v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # return string - if base_model is None: - base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) - - # get alpha and dim - alphas = {} # alpha for current model - dims = {} # dims for current model - for key in lora_sd.keys(): - if "alpha" in key: - lora_module_name = key[: key.rfind(".alpha")] - alpha = float(lora_sd[key].detach().numpy()) - alphas[lora_module_name] = alpha - if lora_module_name not in base_alphas: - base_alphas[lora_module_name] = alpha - elif "lora_down" in key: - lora_module_name = key[: key.rfind(".lora_down")] - dim = lora_sd[key].size()[0] - dims[lora_module_name] = dim - if lora_module_name not in base_dims: - base_dims[lora_module_name] = dim - - for lora_module_name in dims.keys(): - if lora_module_name not in alphas: - alpha = dims[lora_module_name] - alphas[lora_module_name] = alpha - if lora_module_name not in base_alphas: - base_alphas[lora_module_name] = alpha - - logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") - - # merge - logger.info(f"merging...") - for key in lora_sd.keys(): - if "alpha" in key: - continue - if "lora_up" in key and concat: - concat_dim = 1 - elif "lora_down" in key and concat: - concat_dim = 0 - else: - concat_dim = None - - lora_module_name = key[: key.rfind(".lora_")] - - base_alpha = base_alphas[lora_module_name] - alpha = alphas[lora_module_name] - - scale = 
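# --- annotation: why merging uses sqrt(alpha / base_alpha) (illustrative, not part of the original patch) ---
# When merge_lora_models() (below) adds LoRA state dicts trained with different alphas,
# each tensor is scaled by sqrt(alpha / base_alpha), so that once the merged file stores a
# single base_alpha per module the effective delta (alpha/dim) * up @ down is unchanged.
# Standalone check of that identity for one module; shapes are made up.
import math
import torch

torch.manual_seed(0)
dim, alpha, base_alpha = 4, 8.0, 4.0
up, down = torch.randn(8, dim), torch.randn(dim, 16)

original_delta = (alpha / dim) * (up @ down)

s = math.sqrt(alpha / base_alpha)                 # applied to both lora_up and lora_down
rescaled_delta = (base_alpha / dim) * ((up * s) @ (down * s))

print(torch.allclose(original_delta, rescaled_delta, atol=1e-5))  # True
# --- end annotation ---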
math.sqrt(alpha / base_alpha) * ratio - scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 - - if key in merged_sd: - assert ( - merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None - ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" - if concat_dim is not None: - merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) - else: - merged_sd[key] = merged_sd[key] + lora_sd[key] * scale - else: - merged_sd[key] = lora_sd[key] * scale - - # set alpha to sd - for lora_module_name, alpha in base_alphas.items(): - key = lora_module_name + ".alpha" - merged_sd[key] = torch.tensor(alpha) - if shuffle: - key_down = lora_module_name + ".lora_down.weight" - key_up = lora_module_name + ".lora_up.weight" - dim = merged_sd[key_down].shape[0] - perm = torch.randperm(dim) - merged_sd[key_down] = merged_sd[key_down][perm] - merged_sd[key_up] = merged_sd[key_up][:,perm] - - logger.info("merged model") - logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") - - # check all dims are same - dims_list = list(set(base_dims.values())) - alphas_list = list(set(base_alphas.values())) - all_same_dims = True - all_same_alphas = True - for dims in dims_list: - if dims != dims_list[0]: - all_same_dims = False - break - for alphas in alphas_list: - if alphas != alphas_list[0]: - all_same_alphas = False - break - - # build minimum metadata - dims = f"{dims_list[0]}" if all_same_dims else "Dynamic" - alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic" - metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, None) - - return merged_sd, metadata, v2 == "True" - - -def merge(args): - assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" - - def str_to_dtype(p): - if p == "float": - return torch.float - if p == "fp16": - return torch.float16 - if p == "bf16": - return torch.bfloat16 - return None - - merge_dtype = str_to_dtype(args.precision) - save_dtype = str_to_dtype(args.save_precision) - if save_dtype is None: - save_dtype = merge_dtype - - if args.sd_model is not None: - logger.info(f"loading SD model: {args.sd_model}") - - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) - - merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) - - if args.no_metadata: - sai_metadata = None - else: - merged_from = sai_model_spec.build_merged_from([args.sd_model] + args.models) - title = os.path.splitext(os.path.basename(args.save_to))[0] - sai_metadata = sai_model_spec.build_metadata( - None, - args.v2, - args.v2, - False, - False, - False, - time.time(), - title=title, - merged_from=merged_from, - is_stable_diffusion_ckpt=True, - ) - if args.v2: - # TODO read sai modelspec - logger.warning( - "Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" - ) - - logger.info(f"saving SD model to: {args.save_to}") - model_util.save_stable_diffusion_checkpoint( - args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae - ) - else: - state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) - - logger.info(f"calculating hashes and creating metadata...") - - model_hash, legacy_hash = 
train_util.precalculate_safetensors_hashes(state_dict, metadata) - metadata["sshs_model_hash"] = model_hash - metadata["sshs_legacy_hash"] = legacy_hash - - if not args.no_metadata: - merged_from = sai_model_spec.build_merged_from(args.models) - title = os.path.splitext(os.path.basename(args.save_to))[0] - sai_metadata = sai_model_spec.build_metadata( - state_dict, v2, v2, False, True, False, time.time(), title=title, merged_from=merged_from - ) - if v2: - # TODO read sai modelspec - logger.warning( - "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" - ) - metadata.update(sai_metadata) - - logger.info(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む") - parser.add_argument( - "--save_precision", - type=str, - default=None, - choices=[None, "float", "fp16", "bf16"], - help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", - ) - parser.add_argument( - "--precision", - type=str, - default="float", - choices=["float", "fp16", "bf16"], - help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", - ) - parser.add_argument( - "--sd_model", - type=str, - default=None, - help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする", - ) - parser.add_argument( - "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" - ) - parser.add_argument( - "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" - ) - parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") - parser.add_argument( - "--no_metadata", - action="store_true", - help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " - + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", - ) - parser.add_argument( - "--concat", - action="store_true", - help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / " - + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)", - ) - parser.add_argument( - "--shuffle", - action="store_true", - help="shuffle lora weight./ " - + "LoRAの重みをシャッフルする", - ) - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - merge(args) diff --git a/networks/merge_lora_old.py b/networks/merge_lora_old.py deleted file mode 100644 index 334d127b7..000000000 --- a/networks/merge_lora_old.py +++ /dev/null @@ -1,190 +0,0 @@ - - -import argparse -import os -import torch -from safetensors.torch import load_file, save_file -import library.model_util as model_util -import lora -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -def load_state_dict(file_name, dtype): - if os.path.splitext(file_name)[1] == '.safetensors': - sd = load_file(file_name) - else: - sd = torch.load(file_name, map_location='cpu') - for key in list(sd.keys()): - if type(sd[key]) == torch.Tensor: - sd[key] = sd[key].to(dtype) - return sd - - -def save_to_file(file_name, model, state_dict, dtype): - if dtype is not None: - for 
key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) - - if os.path.splitext(file_name)[1] == '.safetensors': - save_file(model, file_name) - else: - torch.save(model, file_name) - - -def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): - text_encoder.to(merge_dtype) - unet.to(merge_dtype) - - # create module map - name_to_module = {} - for i, root_module in enumerate([text_encoder, unet]): - if i == 0: - prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER - target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE - else: - prefix = lora.LoRANetwork.LORA_PREFIX_UNET - target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE - - for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: - for child_name, child_module in module.named_modules(): - if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)): - lora_name = prefix + '.' + name + '.' + child_name - lora_name = lora_name.replace('.', '_') - name_to_module[lora_name] = child_module - - for model, ratio in zip(models, ratios): - logger.info(f"loading: {model}") - lora_sd = load_state_dict(model, merge_dtype) - - logger.info(f"merging...") - for key in lora_sd.keys(): - if "lora_down" in key: - up_key = key.replace("lora_down", "lora_up") - alpha_key = key[:key.index("lora_down")] + 'alpha' - - # find original module for this lora - module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" - if module_name not in name_to_module: - logger.info(f"no module found for LoRA weight: {key}") - continue - module = name_to_module[module_name] - # logger.info(f"apply {key} to {module}") - - down_weight = lora_sd[key] - up_weight = lora_sd[up_key] - - dim = down_weight.size()[0] - alpha = lora_sd.get(alpha_key, dim) - scale = alpha / dim - - # W <- W + U * D - weight = module.weight - if len(weight.size()) == 2: - # linear - weight = weight + ratio * (up_weight @ down_weight) * scale - else: - # conv2d - weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale - - module.weight = torch.nn.Parameter(weight) - - -def merge_lora_models(models, ratios, merge_dtype): - merged_sd = {} - - alpha = None - dim = None - for model, ratio in zip(models, ratios): - logger.info(f"loading: {model}") - lora_sd = load_state_dict(model, merge_dtype) - - logger.info(f"merging...") - for key in lora_sd.keys(): - if 'alpha' in key: - if key in merged_sd: - assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません" - else: - alpha = lora_sd[key].detach().numpy() - merged_sd[key] = lora_sd[key] - else: - if key in merged_sd: - assert merged_sd[key].size() == lora_sd[key].size( - ), f"weights shape mismatch merging v1 and v2, different dims? 
/ 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" - merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio - else: - if "lora_down" in key: - dim = lora_sd[key].size()[0] - merged_sd[key] = lora_sd[key] * ratio - - logger.info(f"dim (rank): {dim}, alpha: {alpha}") - if alpha is None: - alpha = dim - - return merged_sd, dim, alpha - - -def merge(args): - assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" - - def str_to_dtype(p): - if p == 'float': - return torch.float - if p == 'fp16': - return torch.float16 - if p == 'bf16': - return torch.bfloat16 - return None - - merge_dtype = str_to_dtype(args.precision) - save_dtype = str_to_dtype(args.save_precision) - if save_dtype is None: - save_dtype = merge_dtype - - if args.sd_model is not None: - logger.info(f"loading SD model: {args.sd_model}") - - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) - - merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) - - logger.info("") - logger.info(f"saving SD model to: {args.save_to}") - model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, - args.sd_model, 0, 0, save_dtype, vae) - else: - state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype) - - logger.info(f"") - logger.info(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, state_dict, save_dtype) - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("--v2", action='store_true', - help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') - parser.add_argument("--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") - parser.add_argument("--precision", type=str, default="float", - choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)") - parser.add_argument("--sd_model", type=str, default=None, - help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする") - parser.add_argument("--save_to", type=str, default=None, - help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") - parser.add_argument("--models", type=str, nargs='*', - help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors") - parser.add_argument("--ratios", type=float, nargs='*', - help="ratios for each model / それぞれのLoRAモデルの比率") - - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args = parser.parse_args() - merge(args) diff --git a/networks/oft.py b/networks/oft.py deleted file mode 100644 index 461a98698..000000000 --- a/networks/oft.py +++ /dev/null @@ -1,433 +0,0 @@ -# OFT network module - -import math -import os -from typing import Dict, List, Optional, Tuple, Type, Union -from diffusers import AutoencoderKL -from transformers import CLIPTextModel -import numpy as np -import torch -import re -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") - - -class OFTModule(torch.nn.Module): - """ - replaces forward method of the original Linear, instead of replacing the original Linear module. 
- """ - - def __init__( - self, - oft_name, - org_module: torch.nn.Module, - multiplier=1.0, - dim=4, - alpha=1, - ): - """ - dim -> num blocks - alpha -> constraint - """ - super().__init__() - self.oft_name = oft_name - - self.num_blocks = dim - - if "Linear" in org_module.__class__.__name__: - out_dim = org_module.out_features - elif "Conv" in org_module.__class__.__name__: - out_dim = org_module.out_channels - - if type(alpha) == torch.Tensor: - alpha = alpha.detach().numpy() - self.constraint = alpha * out_dim - self.register_buffer("alpha", torch.tensor(alpha)) - - self.block_size = out_dim // self.num_blocks - self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size)) - - self.out_dim = out_dim - self.shape = org_module.weight.shape - - self.multiplier = multiplier - self.org_module = [org_module] # moduleにならないようにlistに入れる - - def apply_to(self): - self.org_forward = self.org_module[0].forward - self.org_module[0].forward = self.forward - - def get_weight(self, multiplier=None): - if multiplier is None: - multiplier = self.multiplier - - block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2) - norm_Q = torch.norm(block_Q.flatten()) - new_norm_Q = torch.clamp(norm_Q, max=self.constraint) - block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) - - block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I - R = torch.block_diag(*block_R_weighted) - - return R - - def forward(self, x, scale=None): - x = self.org_forward(x) - if self.multiplier == 0.0: - return x - - R = self.get_weight().to(x.device, dtype=x.dtype) - if x.dim() == 4: - x = x.permute(0, 2, 3, 1) - x = torch.matmul(x, R) - x = x.permute(0, 3, 1, 2) - else: - x = torch.matmul(x, R) - return x - - -class OFTInfModule(OFTModule): - def __init__( - self, - oft_name, - org_module: torch.nn.Module, - multiplier=1.0, - dim=4, - alpha=1, - **kwargs, - ): - # no dropout for inference - super().__init__(oft_name, org_module, multiplier, dim, alpha) - self.enabled = True - self.network: OFTNetwork = None - - def set_network(self, network): - self.network = network - - def forward(self, x, scale=None): - if not self.enabled: - return self.org_forward(x) - return super().forward(x, scale) - - def merge_to(self, multiplier=None, sign=1): - R = self.get_weight(multiplier) * sign - - # get org weight - org_sd = self.org_module[0].state_dict() - org_weight = org_sd["weight"] - R = R.to(org_weight.device, dtype=org_weight.dtype) - - if org_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", org_weight, R) - else: - weight = torch.einsum("oi, op -> pi", org_weight, R) - - # set weight to org_module - org_sd["weight"] = weight - self.org_module[0].load_state_dict(org_sd) - - -def create_network( - multiplier: float, - network_dim: Optional[int], - network_alpha: Optional[float], - vae: AutoencoderKL, - text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], - unet, - neuron_dropout: Optional[float] = None, - **kwargs, -): - if network_dim is None: - network_dim = 4 # default - if network_alpha is None: - network_alpha = 1.0 - - enable_all_linear = kwargs.get("enable_all_linear", None) - enable_conv = kwargs.get("enable_conv", None) - if enable_all_linear is not None: - enable_all_linear = bool(enable_all_linear) - if enable_conv is not None: - enable_conv = bool(enable_conv) - - network = 
OFTNetwork( - text_encoder, - unet, - multiplier=multiplier, - dim=network_dim, - alpha=network_alpha, - enable_all_linear=enable_all_linear, - enable_conv=enable_conv, - varbose=True, - ) - return network - - -# Create network from weights for inference, weights are not loaded here (because can be merged) -def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): - if weights_sd is None: - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import load_file, safe_open - - weights_sd = load_file(file) - else: - weights_sd = torch.load(file, map_location="cpu") - - # check dim, alpha and if weights have for conv2d - dim = None - alpha = None - has_conv2d = None - all_linear = None - for name, param in weights_sd.items(): - if name.endswith(".alpha"): - if alpha is None: - alpha = param.item() - else: - if dim is None: - dim = param.size()[0] - if has_conv2d is None and param.dim() == 4: - has_conv2d = True - if all_linear is None: - if param.dim() == 3 and "attn" not in name: - all_linear = True - if dim is not None and alpha is not None and has_conv2d is not None: - break - if has_conv2d is None: - has_conv2d = False - if all_linear is None: - all_linear = False - - module_class = OFTInfModule if for_inference else OFTModule - network = OFTNetwork( - text_encoder, - unet, - multiplier=multiplier, - dim=dim, - alpha=alpha, - enable_all_linear=all_linear, - enable_conv=has_conv2d, - module_class=module_class, - ) - return network, weights_sd - - -class OFTNetwork(torch.nn.Module): - UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"] - UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"] - UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] - OFT_PREFIX_UNET = "oft_unet" # これ変えないほうがいいかな - - def __init__( - self, - text_encoder: Union[List[CLIPTextModel], CLIPTextModel], - unet, - multiplier: float = 1.0, - dim: int = 4, - alpha: float = 1, - enable_all_linear: Optional[bool] = False, - enable_conv: Optional[bool] = False, - module_class: Type[object] = OFTModule, - varbose: Optional[bool] = False, - ) -> None: - super().__init__() - self.multiplier = multiplier - - self.dim = dim - self.alpha = alpha - - logger.info( - f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}" - ) - - # create module instances - def create_modules( - root_module: torch.nn.Module, - target_replace_modules: List[torch.nn.Module], - ) -> List[OFTModule]: - prefix = self.OFT_PREFIX_UNET - ofts = [] - for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: - for child_name, child_module in module.named_modules(): - is_linear = "Linear" in child_module.__class__.__name__ - is_conv2d = "Conv2d" in child_module.__class__.__name__ - is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) - - if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv): - oft_name = prefix + "." + name + "." 
+ child_name - oft_name = oft_name.replace(".", "_") - # logger.info(oft_name) - - oft = module_class( - oft_name, - child_module, - self.multiplier, - dim, - alpha, - ) - ofts.append(oft) - return ofts - - # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights - if enable_all_linear: - target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR - else: - target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ATTN_ONLY - if enable_conv: - target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 - - self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules) - logger.info(f"create OFT for U-Net: {len(self.unet_ofts)} modules.") - - # assertion - names = set() - for oft in self.unet_ofts: - assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}" - names.add(oft.oft_name) - - def set_multiplier(self, multiplier): - self.multiplier = multiplier - for oft in self.unet_ofts: - oft.multiplier = self.multiplier - - def load_weights(self, file): - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import load_file - - weights_sd = load_file(file) - else: - weights_sd = torch.load(file, map_location="cpu") - - info = self.load_state_dict(weights_sd, False) - return info - - def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): - assert apply_unet, "apply_unet must be True" - - for oft in self.unet_ofts: - oft.apply_to() - self.add_module(oft.oft_name, oft) - - # マージできるかどうかを返す - def is_mergeable(self): - return True - - # TODO refactor to common function with apply_to - def merge_to(self, text_encoder, unet, weights_sd, dtype, device): - logger.info("enable OFT for U-Net") - - for oft in self.unet_ofts: - sd_for_lora = {} - for key in weights_sd.keys(): - if key.startswith(oft.oft_name): - sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key] - oft.load_state_dict(sd_for_lora, False) - oft.merge_to() - - logger.info(f"weights are merged") - - # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): - self.requires_grad_(True) - all_params = [] - - def enumerate_params(ofts): - params = [] - for oft in ofts: - params.extend(oft.parameters()) - - # logger.info num of params - num_params = 0 - for p in params: - num_params += p.numel() - logger.info(f"OFT params: {num_params}") - return params - - param_data = {"params": enumerate_params(self.unet_ofts)} - if unet_lr is not None: - param_data["lr"] = unet_lr - all_params.append(param_data) - - return all_params - - def enable_gradient_checkpointing(self): - # not supported - pass - - def prepare_grad_etc(self, text_encoder, unet): - self.requires_grad_(True) - - def on_epoch_start(self, text_encoder, unet): - self.train() - - def get_trainable_params(self): - return self.parameters() - - def save_weights(self, file, dtype, metadata): - if metadata is not None and len(metadata) == 0: - metadata = None - - state_dict = self.state_dict() - - if dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(dtype) - state_dict[key] = v - - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import save_file - from library import train_util - - # Precalculate model hashes to save time on indexing - if metadata is None: - metadata = {} - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) - metadata["sshs_model_hash"] = model_hash - metadata["sshs_legacy_hash"] = 
legacy_hash - - save_file(state_dict, file, metadata) - else: - torch.save(state_dict, file) - - def backup_weights(self): - # 重みのバックアップを行う - ofts: List[OFTInfModule] = self.unet_ofts - for oft in ofts: - org_module = oft.org_module[0] - if not hasattr(org_module, "_lora_org_weight"): - sd = org_module.state_dict() - org_module._lora_org_weight = sd["weight"].detach().clone() - org_module._lora_restored = True - - def restore_weights(self): - # 重みのリストアを行う - ofts: List[OFTInfModule] = self.unet_ofts - for oft in ofts: - org_module = oft.org_module[0] - if not org_module._lora_restored: - sd = org_module.state_dict() - sd["weight"] = org_module._lora_org_weight - org_module.load_state_dict(sd) - org_module._lora_restored = True - - def pre_calculation(self): - # 事前計算を行う - ofts: List[OFTInfModule] = self.unet_ofts - for oft in ofts: - org_module = oft.org_module[0] - oft.merge_to() - # sd = org_module.state_dict() - # org_weight = sd["weight"] - # lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype) - # sd["weight"] = org_weight + lora_weight - # assert sd["weight"].shape == org_weight.shape - # org_module.load_state_dict(sd) - - org_module._lora_restored = False - oft.enabled = False diff --git a/networks/resize_lora.py b/networks/resize_lora.py deleted file mode 100644 index d697baa4c..000000000 --- a/networks/resize_lora.py +++ /dev/null @@ -1,411 +0,0 @@ -# Convert LoRA to different rank approximation (should only be used to go to lower rank) -# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py -# Thanks to cloneofsimo - -import os -import argparse -import torch -from safetensors.torch import load_file, save_file, safe_open -from tqdm import tqdm -import numpy as np - -from library import train_util -from library import model_util -from library.utils import setup_logging - -setup_logging() -import logging - -logger = logging.getLogger(__name__) - -MIN_SV = 1e-6 - -# Model save and load functions - - -def load_state_dict(file_name, dtype): - if model_util.is_safetensors(file_name): - sd = load_file(file_name) - with safe_open(file_name, framework="pt") as f: - metadata = f.metadata() - else: - sd = torch.load(file_name, map_location="cpu") - metadata = None - - for key in list(sd.keys()): - if type(sd[key]) == torch.Tensor: - sd[key] = sd[key].to(dtype) - - return sd, metadata - - -def save_to_file(file_name, state_dict, dtype, metadata): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) - - if model_util.is_safetensors(file_name): - save_file(state_dict, file_name, metadata) - else: - torch.save(state_dict, file_name) - - -# Indexing functions - - -def index_sv_cumulative(S, target): - original_sum = float(torch.sum(S)) - cumulative_sums = torch.cumsum(S, dim=0) / original_sum - index = int(torch.searchsorted(cumulative_sums, target)) + 1 - index = max(1, min(index, len(S) - 1)) - - return index - - -def index_sv_fro(S, target): - S_squared = S.pow(2) - S_fro_sq = float(torch.sum(S_squared)) - sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq - index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 - index = max(1, min(index, len(S) - 1)) - - return index - - -def index_sv_ratio(S, target): - max_sv = S[0] - min_sv = max_sv / target - index = int(torch.sum(S > min_sv).item()) - index = max(1, min(index, len(S) - 1)) - - return index - - -# Modified 
from Kohaku-blueleaf's extract/merge functions -def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): - out_size, in_size, kernel_size, _ = weight.size() - U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device)) - - param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) - lora_rank = param_dict["new_rank"] - - U = U[:, :lora_rank] - S = S[:lora_rank] - U = U @ torch.diag(S) - Vh = Vh[:lora_rank, :] - - param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu() - param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu() - del U, S, Vh, weight - return param_dict - - -def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): - out_size, in_size = weight.size() - - U, S, Vh = torch.linalg.svd(weight.to(device)) - - param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) - lora_rank = param_dict["new_rank"] - - U = U[:, :lora_rank] - S = S[:lora_rank] - U = U @ torch.diag(S) - Vh = Vh[:lora_rank, :] - - param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu() - param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu() - del U, S, Vh, weight - return param_dict - - -def merge_conv(lora_down, lora_up, device): - in_rank, in_size, kernel_size, k_ = lora_down.shape - out_size, out_rank, _, _ = lora_up.shape - assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch" - - lora_down = lora_down.to(device) - lora_up = lora_up.to(device) - - merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1) - weight = merged.reshape(out_size, in_size, kernel_size, kernel_size) - del lora_up, lora_down - return weight - - -def merge_linear(lora_down, lora_up, device): - in_rank, in_size = lora_down.shape - out_size, out_rank = lora_up.shape - assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch" - - lora_down = lora_down.to(device) - lora_up = lora_up.to(device) - - weight = lora_up @ lora_down - del lora_up, lora_down - return weight - - -# Calculate new rank - - -def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): - param_dict = {} - - if dynamic_method == "sv_ratio": - # Calculate new dim and alpha based off ratio - new_rank = index_sv_ratio(S, dynamic_param) + 1 - new_alpha = float(scale * new_rank) - - elif dynamic_method == "sv_cumulative": - # Calculate new dim and alpha based off cumulative sum - new_rank = index_sv_cumulative(S, dynamic_param) + 1 - new_alpha = float(scale * new_rank) - - elif dynamic_method == "sv_fro": - # Calculate new dim and alpha based off sqrt sum of squares - new_rank = index_sv_fro(S, dynamic_param) + 1 - new_alpha = float(scale * new_rank) - else: - new_rank = rank - new_alpha = float(scale * new_rank) - - if S[0] <= MIN_SV: # Zero matrix, set dim to 1 - new_rank = 1 - new_alpha = float(scale * new_rank) - elif new_rank > rank: # cap max rank at rank - new_rank = rank - new_alpha = float(scale * new_rank) - - # Calculate resize info - s_sum = torch.sum(torch.abs(S)) - s_rank = torch.sum(torch.abs(S[:new_rank])) - - S_squared = S.pow(2) - s_fro = torch.sqrt(torch.sum(S_squared)) - s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank])) - fro_percent = float(s_red_fro / s_fro) - - param_dict["new_rank"] = new_rank - param_dict["new_alpha"] = new_alpha - param_dict["sum_retained"] = (s_rank) / s_sum - param_dict["fro_retained"] = fro_percent - param_dict["max_ratio"] = S[0] / S[new_rank - 1] - - return param_dict - - -def 
resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose): - network_alpha = None - network_dim = None - verbose_str = "\n" - fro_list = [] - - # Extract loaded lora dim and alpha - for key, value in lora_sd.items(): - if network_alpha is None and "alpha" in key: - network_alpha = value - if network_dim is None and "lora_down" in key and len(value.size()) == 2: - network_dim = value.size()[0] - if network_alpha is not None and network_dim is not None: - break - if network_alpha is None: - network_alpha = network_dim - - scale = network_alpha / network_dim - - if dynamic_method: - logger.info( - f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}" - ) - - lora_down_weight = None - lora_up_weight = None - - o_lora_sd = lora_sd.copy() - block_down_name = None - block_up_name = None - - with torch.no_grad(): - for key, value in tqdm(lora_sd.items()): - weight_name = None - if "lora_down" in key: - block_down_name = key.rsplit(".lora_down", 1)[0] - weight_name = key.rsplit(".", 1)[-1] - lora_down_weight = value - else: - continue - - # find corresponding lora_up and alpha - block_up_name = block_down_name - lora_up_weight = lora_sd.get(block_up_name + ".lora_up." + weight_name, None) - lora_alpha = lora_sd.get(block_down_name + ".alpha", None) - - weights_loaded = lora_down_weight is not None and lora_up_weight is not None - - if weights_loaded: - - conv2d = len(lora_down_weight.size()) == 4 - if lora_alpha is None: - scale = 1.0 - else: - scale = lora_alpha / lora_down_weight.size()[0] - - if conv2d: - full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) - param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale) - else: - full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device) - param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) - - if verbose: - max_ratio = param_dict["max_ratio"] - sum_retained = param_dict["sum_retained"] - fro_retained = param_dict["fro_retained"] - if not np.isnan(fro_retained): - fro_list.append(float(fro_retained)) - - verbose_str += f"{block_down_name:75} | " - verbose_str += ( - f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}" - ) - - if verbose and dynamic_method: - verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n" - else: - verbose_str += "\n" - - new_alpha = param_dict["new_alpha"] - o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous() - o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous() - o_lora_sd[block_up_name + "." 
"alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype) - - block_down_name = None - block_up_name = None - lora_down_weight = None - lora_up_weight = None - weights_loaded = False - del param_dict - - if verbose: - print(verbose_str) - print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") - logger.info("resizing complete") - return o_lora_sd, network_dim, new_alpha - - -def resize(args): - if args.save_to is None or not ( - args.save_to.endswith(".ckpt") - or args.save_to.endswith(".pt") - or args.save_to.endswith(".pth") - or args.save_to.endswith(".safetensors") - ): - raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.") - - args.new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank - - def str_to_dtype(p): - if p == "float": - return torch.float - if p == "fp16": - return torch.float16 - if p == "bf16": - return torch.bfloat16 - return None - - if args.dynamic_method and not args.dynamic_param: - raise Exception("If using dynamic_method, then dynamic_param is required") - - merge_dtype = str_to_dtype("float") # matmul method above only seems to work in float32 - save_dtype = str_to_dtype(args.save_precision) - if save_dtype is None: - save_dtype = merge_dtype - - logger.info("loading Model...") - lora_sd, metadata = load_state_dict(args.model, merge_dtype) - - logger.info("Resizing Lora...") - state_dict, old_dim, new_alpha = resize_lora_model( - lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose - ) - - # update metadata - if metadata is None: - metadata = {} - - comment = metadata.get("ss_training_comment", "") - - if not args.dynamic_method: - conv_desc = "" if args.new_rank == args.new_conv_rank else f" (conv: {args.new_conv_rank})" - metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}{conv_desc}; {comment}" - metadata["ss_network_dim"] = str(args.new_rank) - metadata["ss_network_alpha"] = str(new_alpha) - else: - metadata["ss_training_comment"] = ( - f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}" - ) - metadata["ss_network_dim"] = "Dynamic" - metadata["ss_network_alpha"] = "Dynamic" - - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) - metadata["sshs_model_hash"] = model_hash - metadata["sshs_legacy_hash"] = legacy_hash - - logger.info(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype, metadata) - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - parser.add_argument( - "--save_precision", - type=str, - default=None, - choices=[None, "float", "fp16", "bf16"], - help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat", - ) - parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") - parser.add_argument( - "--new_conv_rank", - type=int, - default=None, - help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ", - ) - parser.add_argument( - "--save_to", - type=str, - default=None, - help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", - ) - parser.add_argument( - "--model", - type=str, - default=None, - help="LoRA model to resize at to new rank: ckpt or safetensors file / 
読み込むLoRAモデル、ckptまたはsafetensors", - ) - parser.add_argument( - "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う" - ) - parser.add_argument( - "--verbose", action="store_true", help="Display verbose resizing information / rank変更時の詳細情報を出力する" - ) - parser.add_argument( - "--dynamic_method", - type=str, - default=None, - choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"], - help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank", - ) - parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction") - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - resize(args) diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py deleted file mode 100644 index 3383a80de..000000000 --- a/networks/sdxl_merge_lora.py +++ /dev/null @@ -1,351 +0,0 @@ -import math -import argparse -import os -import time -import torch -from safetensors.torch import load_file, save_file -from tqdm import tqdm -from library import sai_model_spec, sdxl_model_util, train_util -import library.model_util as model_util -import lora -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -def load_state_dict(file_name, dtype): - if os.path.splitext(file_name)[1] == ".safetensors": - sd = load_file(file_name) - metadata = train_util.load_metadata_from_safetensors(file_name) - else: - sd = torch.load(file_name, map_location="cpu") - metadata = {} - - for key in list(sd.keys()): - if type(sd[key]) == torch.Tensor: - sd[key] = sd[key].to(dtype) - - return sd, metadata - - -def save_to_file(file_name, model, state_dict, dtype, metadata): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) - - if os.path.splitext(file_name)[1] == ".safetensors": - save_file(model, file_name, metadata=metadata) - else: - torch.save(model, file_name) - - -def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): - text_encoder1.to(merge_dtype) - text_encoder1.to(merge_dtype) - unet.to(merge_dtype) - - # create module map - name_to_module = {} - for i, root_module in enumerate([text_encoder1, text_encoder2, unet]): - if i <= 1: - if i == 0: - prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1 - else: - prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2 - target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE - else: - prefix = lora.LoRANetwork.LORA_PREFIX_UNET - target_replace_modules = ( - lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 - ) - - for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: - for child_name, child_module in module.named_modules(): - if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": - lora_name = prefix + "." + name + "." 
+ child_name - lora_name = lora_name.replace(".", "_") - name_to_module[lora_name] = child_module - - for model, ratio in zip(models, ratios): - logger.info(f"loading: {model}") - lora_sd, _ = load_state_dict(model, merge_dtype) - - logger.info(f"merging...") - for key in tqdm(lora_sd.keys()): - if "lora_down" in key: - up_key = key.replace("lora_down", "lora_up") - alpha_key = key[: key.index("lora_down")] + "alpha" - - # find original module for this lora - module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" - if module_name not in name_to_module: - logger.info(f"no module found for LoRA weight: {key}") - continue - module = name_to_module[module_name] - # logger.info(f"apply {key} to {module}") - - down_weight = lora_sd[key] - up_weight = lora_sd[up_key] - - dim = down_weight.size()[0] - alpha = lora_sd.get(alpha_key, dim) - scale = alpha / dim - - # W <- W + U * D - weight = module.weight - # logger.info(module_name, down_weight.size(), up_weight.size()) - if len(weight.size()) == 2: - # linear - weight = weight + ratio * (up_weight @ down_weight) * scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - weight = ( - weight - + ratio - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * scale - ) - else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # logger.info(conved.size(), weight.size(), module.stride, module.padding) - weight = weight + ratio * conved * scale - - module.weight = torch.nn.Parameter(weight) - - -def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): - base_alphas = {} # alpha for merged model - base_dims = {} - - merged_sd = {} - v2 = None - base_model = None - for model, ratio in zip(models, ratios): - logger.info(f"loading: {model}") - lora_sd, lora_metadata = load_state_dict(model, merge_dtype) - - if lora_metadata is not None: - if v2 is None: - v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # returns string, SDXLはv2がないのでFalseのはず - if base_model is None: - base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) - - # get alpha and dim - alphas = {} # alpha for current model - dims = {} # dims for current model - for key in lora_sd.keys(): - if "alpha" in key: - lora_module_name = key[: key.rfind(".alpha")] - alpha = float(lora_sd[key].detach().numpy()) - alphas[lora_module_name] = alpha - if lora_module_name not in base_alphas: - base_alphas[lora_module_name] = alpha - elif "lora_down" in key: - lora_module_name = key[: key.rfind(".lora_down")] - dim = lora_sd[key].size()[0] - dims[lora_module_name] = dim - if lora_module_name not in base_dims: - base_dims[lora_module_name] = dim - - for lora_module_name in dims.keys(): - if lora_module_name not in alphas: - alpha = dims[lora_module_name] - alphas[lora_module_name] = alpha - if lora_module_name not in base_alphas: - base_alphas[lora_module_name] = alpha - - logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") - - # merge - logger.info(f"merging...") - for key in tqdm(lora_sd.keys()): - if "alpha" in key: - continue - - if "lora_up" in key and concat: - concat_dim = 1 - elif "lora_down" in key and concat: - concat_dim = 0 - else: - concat_dim = None - - lora_module_name = key[: key.rfind(".lora_")] - - base_alpha = base_alphas[lora_module_name] - alpha = alphas[lora_module_name] - - scale = math.sqrt(alpha / base_alpha) * ratio - scale = abs(scale) if 
"lora_up" in key else scale # マイナスの重みに対応する。 - - if key in merged_sd: - assert ( - merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None - ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" - if concat_dim is not None: - merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) - else: - merged_sd[key] = merged_sd[key] + lora_sd[key] * scale - else: - merged_sd[key] = lora_sd[key] * scale - - # set alpha to sd - for lora_module_name, alpha in base_alphas.items(): - key = lora_module_name + ".alpha" - merged_sd[key] = torch.tensor(alpha) - if shuffle: - key_down = lora_module_name + ".lora_down.weight" - key_up = lora_module_name + ".lora_up.weight" - dim = merged_sd[key_down].shape[0] - perm = torch.randperm(dim) - merged_sd[key_down] = merged_sd[key_down][perm] - merged_sd[key_up] = merged_sd[key_up][:,perm] - - logger.info("merged model") - logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") - - # check all dims are same - dims_list = list(set(base_dims.values())) - alphas_list = list(set(base_alphas.values())) - all_same_dims = True - all_same_alphas = True - for dims in dims_list: - if dims != dims_list[0]: - all_same_dims = False - break - for alphas in alphas_list: - if alphas != alphas_list[0]: - all_same_alphas = False - break - - # build minimum metadata - dims = f"{dims_list[0]}" if all_same_dims else "Dynamic" - alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic" - metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, None) - - return merged_sd, metadata - - -def merge(args): - assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" - - def str_to_dtype(p): - if p == "float": - return torch.float - if p == "fp16": - return torch.float16 - if p == "bf16": - return torch.bfloat16 - return None - - merge_dtype = str_to_dtype(args.precision) - save_dtype = str_to_dtype(args.save_precision) - if save_dtype is None: - save_dtype = merge_dtype - - if args.sd_model is not None: - logger.info(f"loading SD model: {args.sd_model}") - - ( - text_model1, - text_model2, - vae, - unet, - logit_scale, - ckpt_info, - ) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu") - - merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype) - - if args.no_metadata: - sai_metadata = None - else: - merged_from = sai_model_spec.build_merged_from([args.sd_model] + args.models) - title = os.path.splitext(os.path.basename(args.save_to))[0] - sai_metadata = sai_model_spec.build_metadata( - None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from - ) - - logger.info(f"saving SD model to: {args.save_to}") - sdxl_model_util.save_stable_diffusion_checkpoint( - args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype - ) - else: - state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) - - logger.info(f"calculating hashes and creating metadata...") - - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) - metadata["sshs_model_hash"] = model_hash - metadata["sshs_legacy_hash"] = legacy_hash - - if not args.no_metadata: - merged_from = sai_model_spec.build_merged_from(args.models) - title = 
os.path.splitext(os.path.basename(args.save_to))[0] - sai_metadata = sai_model_spec.build_metadata( - state_dict, False, False, True, True, False, time.time(), title=title, merged_from=merged_from - ) - metadata.update(sai_metadata) - - logger.info(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument( - "--save_precision", - type=str, - default=None, - choices=[None, "float", "fp16", "bf16"], - help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", - ) - parser.add_argument( - "--precision", - type=str, - default="float", - choices=["float", "fp16", "bf16"], - help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", - ) - parser.add_argument( - "--sd_model", - type=str, - default=None, - help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする", - ) - parser.add_argument( - "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" - ) - parser.add_argument( - "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" - ) - parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") - parser.add_argument( - "--no_metadata", - action="store_true", - help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " - + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", - ) - parser.add_argument( - "--concat", - action="store_true", - help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / " - + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)", - ) - parser.add_argument( - "--shuffle", - action="store_true", - help="shuffle lora weight./ " - + "LoRAの重みをシャッフルする", - ) - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - merge(args) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py deleted file mode 100644 index cb00a6000..000000000 --- a/networks/svd_merge_lora.py +++ /dev/null @@ -1,262 +0,0 @@ -import argparse -import os -import time -import torch -from safetensors.torch import load_file, save_file -from tqdm import tqdm -from library import sai_model_spec, train_util -import library.model_util as model_util -import lora -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -CLAMP_QUANTILE = 0.99 - - -def load_state_dict(file_name, dtype): - if os.path.splitext(file_name)[1] == ".safetensors": - sd = load_file(file_name) - metadata = train_util.load_metadata_from_safetensors(file_name) - else: - sd = torch.load(file_name, map_location="cpu") - metadata = {} - - for key in list(sd.keys()): - if type(sd[key]) == torch.Tensor: - sd[key] = sd[key].to(dtype) - - return sd, metadata - - -def save_to_file(file_name, state_dict, dtype, metadata): - if dtype is not None: - for key in list(state_dict.keys()): - if type(state_dict[key]) == torch.Tensor: - state_dict[key] = state_dict[key].to(dtype) - - if os.path.splitext(file_name)[1] == ".safetensors": - save_file(state_dict, file_name, metadata=metadata) - else: - torch.save(state_dict, file_name) - - -def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): - 
logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") - merged_sd = {} - v2 = None - base_model = None - for model, ratio in zip(models, ratios): - logger.info(f"loading: {model}") - lora_sd, lora_metadata = load_state_dict(model, merge_dtype) - - if lora_metadata is not None: - if v2 is None: - v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # return string - if base_model is None: - base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) - - # merge - logger.info(f"merging...") - for key in tqdm(list(lora_sd.keys())): - if "lora_down" not in key: - continue - - lora_module_name = key[: key.rfind(".lora_down")] - - down_weight = lora_sd[key] - network_dim = down_weight.size()[0] - - up_weight = lora_sd[lora_module_name + ".lora_up.weight"] - alpha = lora_sd.get(lora_module_name + ".alpha", network_dim) - - in_dim = down_weight.size()[1] - out_dim = up_weight.size()[0] - conv2d = len(down_weight.size()) == 4 - kernel_size = None if not conv2d else down_weight.size()[2:4] - # logger.info(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size) - - # make original weight if not exist - if lora_module_name not in merged_sd: - weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype) - if device: - weight = weight.to(device) - else: - weight = merged_sd[lora_module_name] - - # merge to weight - if device: - up_weight = up_weight.to(device) - down_weight = down_weight.to(device) - - # W <- W + U * D - scale = alpha / network_dim - - if device: # and isinstance(scale, torch.Tensor): - scale = scale.to(device) - - if not conv2d: # linear - weight = weight + ratio * (up_weight @ down_weight) * scale - elif kernel_size == (1, 1): - weight = ( - weight - + ratio - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * scale - ) - else: - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - weight = weight + ratio * conved * scale - - merged_sd[lora_module_name] = weight - - # extract from merged weights - logger.info("extract new lora...") - merged_lora_sd = {} - with torch.no_grad(): - for lora_module_name, mat in tqdm(list(merged_sd.items())): - conv2d = len(mat.size()) == 4 - kernel_size = None if not conv2d else mat.size()[2:4] - conv2d_3x3 = conv2d and kernel_size != (1, 1) - out_dim, in_dim = mat.size()[0:2] - - if conv2d: - if conv2d_3x3: - mat = mat.flatten(start_dim=1) - else: - mat = mat.squeeze() - - module_new_rank = new_conv_rank if conv2d_3x3 else new_rank - module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim - - U, S, Vh = torch.linalg.svd(mat) - - U = U[:, :module_new_rank] - S = S[:module_new_rank] - U = U @ torch.diag(S) - - Vh = Vh[:module_new_rank, :] - - dist = torch.cat([U.flatten(), Vh.flatten()]) - hi_val = torch.quantile(dist, CLAMP_QUANTILE) - low_val = -hi_val - - U = U.clamp(low_val, hi_val) - Vh = Vh.clamp(low_val, hi_val) - - if conv2d: - U = U.reshape(out_dim, module_new_rank, 1, 1) - Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1]) - - up_weight = U - down_weight = Vh - - merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous() - merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous() - merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank) - - # build minimum metadata - dims = f"{new_rank}" - alphas = 
f"{new_rank}" - if new_conv_rank is not None: - network_args = {"conv_dim": new_conv_rank, "conv_alpha": new_conv_rank} - else: - network_args = None - metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, network_args) - - return merged_lora_sd, metadata, v2 == "True", base_model - - -def merge(args): - assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" - - def str_to_dtype(p): - if p == "float": - return torch.float - if p == "fp16": - return torch.float16 - if p == "bf16": - return torch.bfloat16 - return None - - merge_dtype = str_to_dtype(args.precision) - save_dtype = str_to_dtype(args.save_precision) - if save_dtype is None: - save_dtype = merge_dtype - - new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank - state_dict, metadata, v2, base_model = merge_lora_models( - args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype - ) - - logger.info(f"calculating hashes and creating metadata...") - - model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) - metadata["sshs_model_hash"] = model_hash - metadata["sshs_legacy_hash"] = legacy_hash - - if not args.no_metadata: - is_sdxl = base_model is not None and base_model.lower().startswith("sdxl") - merged_from = sai_model_spec.build_merged_from(args.models) - title = os.path.splitext(os.path.basename(args.save_to))[0] - sai_metadata = sai_model_spec.build_metadata( - state_dict, v2, v2, is_sdxl, True, False, time.time(), title=title, merged_from=merged_from - ) - if v2: - # TODO read sai modelspec - logger.warning( - "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" - ) - metadata.update(sai_metadata) - - logger.info(f"saving model to: {args.save_to}") - save_to_file(args.save_to, state_dict, save_dtype, metadata) - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument( - "--save_precision", - type=str, - default=None, - choices=[None, "float", "fp16", "bf16"], - help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", - ) - parser.add_argument( - "--precision", - type=str, - default="float", - choices=["float", "fp16", "bf16"], - help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", - ) - parser.add_argument( - "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" - ) - parser.add_argument( - "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" - ) - parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") - parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") - parser.add_argument( - "--new_conv_rank", - type=int, - default=None, - help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ", - ) - parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") - parser.add_argument( - "--no_metadata", - action="store_true", - help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " - + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", - ) - - 
return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - merge(args) diff --git a/requirements.txt b/requirements.txt index 42895ca38..a54613c26 100644 --- a/requirements.txt +++ b/requirements.txt @@ -43,4 +43,4 @@ voluptuous==0.13.1 wandb==0.15.11 scipy==1.11.4 # for kohya_ss library --e . # no_verify leave this to specify not checking this a verification stage +-e ./sd-scripts # no_verify leave this to specify not checking this a verification stage diff --git a/requirements_linux.txt b/requirements_linux.txt index 6f64060d9..288e0f3ab 100644 --- a/requirements_linux.txt +++ b/requirements_linux.txt @@ -1,4 +1,4 @@ -torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 # no_verify leave this to specify not checking this a verification stage -xformers==0.0.21 bitsandbytes==0.41.1 -tensorboard==2.14.1 tensorflow==2.14.0 +torch==2.1.2+cu118 torchvision==0.16.2+cu118 xformers==0.0.23.post1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 +bitsandbytes==0.41.2 +tensorboard==2.15.2 tensorflow==2.15.0.post1 -r requirements.txt diff --git a/requirements_linux_docker.txt b/requirements_linux_docker.txt index ee070f2e1..2606d5fd0 100644 --- a/requirements_linux_docker.txt +++ b/requirements_linux_docker.txt @@ -1,5 +1,5 @@ xformers>=0.0.20 bitsandbytes==0.41.1 accelerate==0.25.0 -tensorboard==2.14.1 -tensorflow==2.14.0 \ No newline at end of file +tensorboard==2.15.2 +tensorflow==2.15.0.post1 \ No newline at end of file diff --git a/requirements_windows_torch2.txt b/requirements_windows_torch2.txt index 6b7c5edb8..9bcbf81ed 100644 --- a/requirements_windows_torch2.txt +++ b/requirements_windows_torch2.txt @@ -1,7 +1,7 @@ -# torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 # no_verify -torch==2.1.2+cu118 torchvision==0.16.2+cu118 torchaudio==2.1.2+cu118 --index-url https://download.pytorch.org/whl/cu118 -xformers==0.0.23.post1+cu118 --index-url https://download.pytorch.org/whl/cu118 -bitsandbytes==0.41.1 # no_verify +torch==2.1.2+cu118 torchvision==0.16.2+cu118 torchaudio==2.1.2+cu118 xformers==0.0.23.post1+cu118 --index-url https://download.pytorch.org/whl/cu118 +nvidia-cudnn-cu11==8.9.5.29 +https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl # no_verify # https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl # no_verify +# bitsandbytes==0.41.1 tensorboard==2.14.1 tensorflow==2.14.0 -r requirements.txt diff --git a/sd-scripts b/sd-scripts new file mode 160000 index 000000000..2d7389185 --- /dev/null +++ b/sd-scripts @@ -0,0 +1 @@ +Subproject commit 2d7389185c021bc527b414563c245c5489d6328a diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py deleted file mode 100755 index 641b3209f..000000000 --- a/sdxl_gen_img.py +++ /dev/null @@ -1,3210 +0,0 @@ -import itertools -import json -from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable -import glob -import importlib -import inspect -import time -import zipfile -from diffusers.utils import deprecate -from diffusers.configuration_utils import FrozenDict -import argparse -import math -import os -import random -import re - -import diffusers -import numpy as np - -import torch -from library.device_utils import init_ipex, clean_memory, get_preferred_device -init_ipex() - -import torchvision -from diffusers import ( - AutoencoderKL, - DDPMScheduler, 
- EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - DPMSolverSinglestepScheduler, - LMSDiscreteScheduler, - PNDMScheduler, - DDIMScheduler, - EulerDiscreteScheduler, - HeunDiscreteScheduler, - KDPM2DiscreteScheduler, - KDPM2AncestralDiscreteScheduler, - # UNet2DConditionModel, - StableDiffusionPipeline, -) -from einops import rearrange -from tqdm import tqdm -from torchvision import transforms -from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor -import PIL -from PIL import Image -from PIL.PngImagePlugin import PngInfo - -import library.model_util as model_util -import library.train_util as train_util -import library.sdxl_model_util as sdxl_model_util -import library.sdxl_train_util as sdxl_train_util -from networks.lora import LoRANetwork -from library.sdxl_original_unet import InferSdxlUNet2DConditionModel -from library.original_unet import FlashAttentionFunction -from networks.control_net_lllite import ControlNetLLLite -from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL -from library.utils import setup_logging, add_logging_arguments - -setup_logging() -import logging - -logger = logging.getLogger(__name__) - -# scheduler: -SCHEDULER_LINEAR_START = 0.00085 -SCHEDULER_LINEAR_END = 0.0120 -SCHEDULER_TIMESTEPS = 1000 -SCHEDLER_SCHEDULE = "scaled_linear" - -# その他の設定 -LATENT_CHANNELS = 4 -DOWNSAMPLING_FACTOR = 8 - -CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" - -# region モジュール入れ替え部 -""" -高速化のためのモジュール入れ替え -""" - - -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): - if mem_eff_attn: - logger.info("Enable memory efficient attention for U-Net") - - # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い - unet.set_use_memory_efficient_attention(False, True) - elif xformers: - logger.info("Enable xformers for U-Net") - try: - import xformers.ops - except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") - - unet.set_use_memory_efficient_attention(True, False) - elif sdpa: - logger.info("Enable SDPA for U-Net") - unet.set_use_memory_efficient_attention(False, False) - unet.set_use_sdpa(True) - - -# TODO common train_util.py -def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa): - if mem_eff_attn: - replace_vae_attn_to_memory_efficient() - elif xformers: - # replace_vae_attn_to_xformers() # 解像度によってxformersがエラーを出す? 
- vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う - elif sdpa: - replace_vae_attn_to_sdpa() - - -def replace_vae_attn_to_memory_efficient(): - logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)") - flash_func = FlashAttentionFunction - - def forward_flash_attn(self, hidden_states, **kwargs): - q_bucket_size = 512 - k_bucket_size = 1024 - - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.to_q(hidden_states) - key_proj = self.to_k(hidden_states) - value_proj = self.to_v(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) - ) - - out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) - - out = rearrange(out, "b h n d -> b n (h d)") - - # compute next hidden_states - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - def forward_flash_attn_0_14(self, hidden_states, **kwargs): - if not hasattr(self, "to_q"): - self.to_q = self.query - self.to_k = self.key - self.to_v = self.value - self.to_out = [self.proj_attn, torch.nn.Identity()] - self.heads = self.num_heads - return forward_flash_attn(self, hidden_states, **kwargs) - - if diffusers.__version__ < "0.15.0": - diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14 - else: - diffusers.models.attention_processor.Attention.forward = forward_flash_attn - - -def replace_vae_attn_to_xformers(): - logger.info("VAE: Attention.forward has been replaced to xformers") - import xformers.ops - - def forward_xformers(self, hidden_states, **kwargs): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.to_q(hidden_states) - key_proj = self.to_k(hidden_states) - value_proj = self.to_v(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) - ) - - query_proj = query_proj.contiguous() - key_proj = key_proj.contiguous() - value_proj = value_proj.contiguous() - out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) - - out = rearrange(out, "b h n d -> b n (h d)") - - # compute next hidden_states - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - def forward_xformers_0_14(self, hidden_states, **kwargs): - if not hasattr(self, "to_q"): - self.to_q = self.query - self.to_k = self.key - self.to_v = self.value - self.to_out = [self.proj_attn, torch.nn.Identity()] - self.heads = self.num_heads - 
return forward_xformers(self, hidden_states, **kwargs) - - if diffusers.__version__ < "0.15.0": - diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14 - else: - diffusers.models.attention_processor.Attention.forward = forward_xformers - - -def replace_vae_attn_to_sdpa(): - logger.info("VAE: Attention.forward has been replaced to sdpa") - - def forward_sdpa(self, hidden_states, **kwargs): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.to_q(hidden_states) - key_proj = self.to_k(hidden_states) - value_proj = self.to_v(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj) - ) - - out = torch.nn.functional.scaled_dot_product_attention( - query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - out = rearrange(out, "b n h d -> b n (h d)") - - # compute next hidden_states - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - def forward_sdpa_0_14(self, hidden_states, **kwargs): - if not hasattr(self, "to_q"): - self.to_q = self.query - self.to_k = self.key - self.to_v = self.value - self.to_out = [self.proj_attn, torch.nn.Identity()] - self.heads = self.num_heads - return forward_sdpa(self, hidden_states, **kwargs) - - if diffusers.__version__ < "0.15.0": - diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14 - else: - diffusers.models.attention_processor.Attention.forward = forward_sdpa - - -# endregion - -# region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正 -# https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py -# Pipelineだけ独立して使えないのと機能追加するのとでコピーして修正 - - -class PipelineLike: - def __init__( - self, - device, - vae: AutoencoderKL, - text_encoders: List[CLIPTextModel], - tokenizers: List[CLIPTokenizer], - unet: InferSdxlUNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - clip_skip: int, - ): - super().__init__() - self.device = device - self.clip_skip = clip_skip - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - self.vae = vae - self.text_encoders = text_encoders - self.tokenizers = tokenizers - self.unet: InferSdxlUNet2DConditionModel = unet - self.scheduler = scheduler - self.safety_checker = None - - self.clip_vision_model: CLIPVisionModelWithProjection = None - self.clip_vision_processor: CLIPImageProcessor = None - self.clip_vision_strength = 0.0 - - # Textual Inversion - self.token_replacements_list = [] - for _ in range(len(self.text_encoders)): - self.token_replacements_list.append({}) - - # ControlNet # not supported yet - self.control_nets: List[ControlNetLLLite] = [] - self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない - - self.gradual_latent: GradualLatent = None - - # Textual Inversion - def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids): - self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids - - def set_enable_control_net(self, en: bool): - self.control_net_enabled = en - - def get_token_replacer(self, tokenizer): - tokenizer_index = self.tokenizers.index(tokenizer) - token_replacements = self.token_replacements_list[tokenizer_index] - - def replace_tokens(tokens): - # logger.info("replace_tokens", tokens, "=>", token_replacements) - if isinstance(tokens, torch.Tensor): - tokens = tokens.tolist() - - new_tokens = [] - for token in tokens: - if token in token_replacements: - replacement = token_replacements[token] - new_tokens.extend(replacement) - else: - new_tokens.append(token) - return new_tokens - - return replace_tokens - - def set_control_nets(self, ctrl_nets): - self.control_nets = ctrl_nets - - def set_gradual_latent(self, gradual_latent): - if gradual_latent is None: - print("gradual_latent is disabled") - self.gradual_latent = None - else: - print(f"gradual_latent is enabled: {gradual_latent}") - self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step) - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, - mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, - height: int = 1024, - width: int = 1024, - original_height: int = None, - original_width: 
int = None, - original_height_negative: int = None, - original_width_negative: int = None, - crop_top: int = 0, - crop_left: int = 0, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_scale: float = None, - strength: float = 0.8, - # num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - vae_batch_size: float = None, - return_latents: bool = False, - # return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: Optional[int] = 1, - img2img_noise=None, - clip_guide_images=None, - **kwargs, - ): - # TODO support secondary prompt - num_images_per_prompt = 1 # fixed because already prompt is repeated - - if isinstance(prompt, str): - batch_size = 1 - prompt = [prompt] - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - reginonal_network = " AND " in prompt[0] - - vae_batch_size = ( - batch_size - if vae_batch_size is None - else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) - ) - - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." - ) - - # get prompt text embeddings - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - if not do_classifier_free_guidance and negative_scale is not None: - logger.info(f"negative_scale is ignored if guidance scalle <= 1.0") - negative_scale = None - - # get unconditional embeddings for classifier free guidance - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
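As the comment above notes, guidance_scale plays the role of the guidance weight w from the Imagen paper, and guidance_scale = 1 disables classifier-free guidance. The blend applied later in the sampling loop is the standard one; a tiny illustrative sketch with made-up tensors (not the pipeline's variables):

import torch

def cfg_blend(noise_uncond: torch.Tensor, noise_text: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Classifier-free guidance: move the prediction away from the unconditional
    # result and toward the text-conditioned one.
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)

noise_uncond = torch.zeros(1, 4, 8, 8)
noise_text = torch.ones(1, 4, 8, 8)
print(cfg_blend(noise_uncond, noise_text, 7.5).mean())  # tensor(7.5000)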
- ) - - tes_text_embs = [] - tes_uncond_embs = [] - tes_real_uncond_embs = [] - - for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): - token_replacer = self.get_token_replacer(tokenizer) - - # use last text_pool, because it is from text encoder 2 - text_embeddings, text_pool, uncond_embeddings, uncond_pool, _ = get_weighted_text_embeddings( - tokenizer, - text_encoder, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - token_replacer=token_replacer, - device=self.device, - **kwargs, - ) - tes_text_embs.append(text_embeddings) - tes_uncond_embs.append(uncond_embeddings) - - if negative_scale is not None: - _, real_uncond_embeddings, _ = get_weighted_text_embeddings( - token_replacer, - prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 - uncond_prompt=[""] * batch_size, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - token_replacer=token_replacer, - device=self.device, - **kwargs, - ) - tes_real_uncond_embs.append(real_uncond_embeddings) - - # concat text encoder outputs - text_embeddings = tes_text_embs[0] - uncond_embeddings = tes_uncond_embs[0] - for i in range(1, len(tes_text_embs)): - text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048 - if do_classifier_free_guidance: - uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 - - if do_classifier_free_guidance: - if negative_scale is None: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - else: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) - - if self.control_nets: - # ControlNetのhintにguide imageを流用する - if isinstance(clip_guide_images, PIL.Image.Image): - clip_guide_images = [clip_guide_images] - if isinstance(clip_guide_images[0], PIL.Image.Image): - clip_guide_images = [preprocess_image(im) for im in clip_guide_images] - clip_guide_images = torch.cat(clip_guide_images) - if isinstance(clip_guide_images, list): - clip_guide_images = torch.stack(clip_guide_images) - - clip_guide_images = clip_guide_images.to(self.device, dtype=text_embeddings.dtype) - - # create size embs - if original_height is None: - original_height = height - if original_width is None: - original_width = width - if original_height_negative is None: - original_height_negative = original_height - if original_width_negative is None: - original_width_negative = original_width - if crop_top is None: - crop_top = 0 - if crop_left is None: - crop_left = 0 - emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) - uc_emb1 = sdxl_train_util.get_timestep_embedding( - torch.FloatTensor([original_height_negative, original_width_negative]).unsqueeze(0), 256 - ) - emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) - emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256) - c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) - uc_vector = torch.cat([uc_emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) - - if reginonal_network: - # use last pool for conditioning - num_sub_prompts = len(text_pool) // batch_size - text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt - - if init_image is 
not None and self.clip_vision_model is not None: - logger.info(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") - vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device) - pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype) - - clip_vision_embeddings = self.clip_vision_model(pixel_values=pixel_values, output_hidden_states=True, return_dict=True) - clip_vision_embeddings = clip_vision_embeddings.image_embeds - - if len(clip_vision_embeddings) == 1 and batch_size > 1: - clip_vision_embeddings = clip_vision_embeddings.repeat((batch_size, 1)) - - clip_vision_embeddings = clip_vision_embeddings * self.clip_vision_strength - assert clip_vision_embeddings.shape == text_pool.shape, f"{clip_vision_embeddings.shape} != {text_pool.shape}" - text_pool = clip_vision_embeddings # replace: same as ComfyUI (?) - - c_vector = torch.cat([text_pool, c_vector], dim=1) - if do_classifier_free_guidance: - uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) - vector_embeddings = torch.cat([uc_vector, c_vector]) - else: - vector_embeddings = c_vector - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps, self.device) - - latents_dtype = text_embeddings.dtype - init_latents_orig = None - mask = None - - if init_image is None: - # get the initial random noise unless the user supplied it - - # Unlike in other pipelines, latents need to be generated in the target device - # for 1-to-1 results reproducibility with the CompVis implementation. - # However this currently doesn't work in `mps`. - latents_shape = ( - batch_size * num_images_per_prompt, - self.unet.in_channels, - height // 8, - width // 8, - ) - - if latents is None: - if self.device.type == "mps": - # randn does not exist on mps - latents = torch.randn( - latents_shape, - generator=generator, - device="cpu", - dtype=latents_dtype, - ).to(self.device) - else: - latents = torch.randn( - latents_shape, - generator=generator, - device=self.device, - dtype=latents_dtype, - ) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) - - timesteps = self.scheduler.timesteps.to(self.device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - else: - # image to tensor - if isinstance(init_image, PIL.Image.Image): - init_image = [init_image] - if isinstance(init_image[0], PIL.Image.Image): - init_image = [preprocess_image(im) for im in init_image] - init_image = torch.cat(init_image) - if isinstance(init_image, list): - init_image = torch.stack(init_image) - - # mask image to tensor - if mask_image is not None: - if isinstance(mask_image, PIL.Image.Image): - mask_image = [mask_image] - if isinstance(mask_image[0], PIL.Image.Image): - mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint - - # encode the init image into latents and scale the latents - init_image = init_image.to(device=self.device, dtype=latents_dtype) - if init_image.size()[-2:] == (height // 8, width // 8): - init_latents = init_image - else: - if vae_batch_size >= batch_size: - init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - else: - clean_memory() - init_latents = [] - for i in tqdm(range(0, min(batch_size, len(init_image)), 
vae_batch_size)): - init_latent_dist = self.vae.encode( - (init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0)).to( - self.vae.dtype - ) - ).latent_dist - init_latents.append(init_latent_dist.sample(generator=generator)) - init_latents = torch.cat(init_latents) - - init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents - - if len(init_latents) == 1: - init_latents = init_latents.repeat((batch_size, 1, 1, 1)) - init_latents_orig = init_latents - - # preprocess mask - if mask_image is not None: - mask = mask_image.to(device=self.device, dtype=latents_dtype) - if len(mask) == 1: - mask = mask.repeat((batch_size, 1, 1, 1)) - - # check sizes - if not mask.shape == init_latents.shape: - raise ValueError("The mask and init_image should be the same size!") - - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) - - # add noise to latents using the timesteps - latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) - - t_start = max(num_inference_steps - init_timestep + offset, 0) - timesteps = self.scheduler.timesteps[t_start:].to(self.device) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 - - if self.control_nets: - # guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) - if self.control_net_enabled: - for control_net, _ in self.control_nets: - with torch.no_grad(): - control_net.set_cond_image(clip_guide_images) - else: - for control_net, _ in self.control_nets: - control_net.set_cond_image(None) - - each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets) - - # # first, we downscale the latents to the half of the size - # # 最初に1/2に縮小する - # height, width = latents.shape[-2:] - # # latents = torch.nn.functional.interpolate(latents.float(), scale_factor=0.5, mode="bicubic", align_corners=False).to( - # # latents.dtype - # # ) - # latents = latents[:, :, ::2, ::2] - # current_scale = 0.5 - - # # how much to increase the scale at each step: .125 seems to work well (because it's 1/8?) 
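The commented-out prototype above and the gradual_latent path below share the same core step: start from downscaled latents and progressively enlarge them during sampling, keeping the spatial size divisible by 8 so the U-Net can still downsample cleanly. A standalone sketch of that resize, assuming plain bicubic interpolation (illustrative only):

import torch
import torch.nn.functional as F

def upscale_latents(latents: torch.Tensor, ratio: float) -> torch.Tensor:
    # Round the target size down to a multiple of 8; latent height/width must
    # stay divisible at the bottom of the U-Net.
    h, w = latents.shape[-2:]
    new_h = int(h * ratio) // 8 * 8
    new_w = int(w * ratio) // 8 * 8
    resized = F.interpolate(latents.float(), size=(new_h, new_w), mode="bicubic", align_corners=False)
    return resized.to(latents.dtype)

latents = torch.randn(1, 4, 64, 64)
print(upscale_latents(latents, 1.5).shape)  # torch.Size([1, 4, 96, 96])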
- # # 各ステップに拡大率をどのくらい増やすか:.125がよさそう(たぶん1/8なので) - # scale_step = 0.125 - - # # timesteps at which to start increasing the scale: 1000 seems to be enough - # # 拡大を開始するtimesteps: 1000で十分そうである - # start_timesteps = 1000 - - # # how many steps to wait before increasing the scale again - # # small values leads to blurry images (because the latents are blurry after the upscale, so some denoising might be needed) - # # large values leads to flat images - - # # 何ステップごとに拡大するか - # # 小さいとボケる(拡大後のlatentsはボケた感じになるので、そこから数stepのdenoiseが必要と思われる) - # # 大きすぎると細部が書き込まれずのっぺりした感じになる - # every_n_steps = 5 - - # scale_step = input("scale step:") - # scale_step = float(scale_step) - # start_timesteps = input("start timesteps:") - # start_timesteps = int(start_timesteps) - # every_n_steps = input("every n steps:") - # every_n_steps = int(every_n_steps) - - # # for i, t in enumerate(tqdm(timesteps)): - # i = 0 - # last_step = 0 - # while i < len(timesteps): - # t = timesteps[i] - # print(f"[{i}] t={t}") - - # print(i, t, current_scale, latents.shape) - # if t < start_timesteps and current_scale < 1.0 and i % every_n_steps == 0: - # if i == last_step: - # pass - # else: - # print("upscale") - # current_scale = min(current_scale + scale_step, 1.0) - - # h = int(height * current_scale) // 8 * 8 - # w = int(width * current_scale) // 8 * 8 - - # latents = torch.nn.functional.interpolate(latents.float(), size=(h, w), mode="bicubic", align_corners=False).to( - # latents.dtype - # ) - # last_step = i - # i = max(0, i - every_n_steps + 1) - - # diff = timesteps[i] - timesteps[last_step] - # # resized_init_noise = torch.nn.functional.interpolate( - # # init_noise.float(), size=(h, w), mode="bicubic", align_corners=False - # # ).to(latents.dtype) - # # latents = self.scheduler.add_noise(latents, resized_init_noise, diff) - # latents = self.scheduler.add_noise(latents, torch.randn_like(latents), diff * 4) - # # latents += torch.randn_like(latents) / 100 * diff - # continue - - enable_gradual_latent = False - if self.gradual_latent: - if not hasattr(self.scheduler, "set_gradual_latent_params"): - print("gradual_latent is not supported for this scheduler. 
Ignoring.") - print(self.scheduler.__class__.__name__) - else: - enable_gradual_latent = True - step_elapsed = 1000 - current_ratio = self.gradual_latent.ratio - - # first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする - height, width = latents.shape[-2:] - org_dtype = latents.dtype - if org_dtype == torch.bfloat16: - latents = latents.float() - latents = torch.nn.functional.interpolate( - latents, scale_factor=current_ratio, mode="bicubic", align_corners=False - ).to(org_dtype) - - # apply unsharp mask / アンシャープマスクを適用する - if self.gradual_latent.gaussian_blur_ksize: - latents = self.gradual_latent.apply_unshark_mask(latents) - - for i, t in enumerate(tqdm(timesteps)): - resized_size = None - if enable_gradual_latent: - # gradually upscale the latents / latentsを徐々にアップスケールする - if ( - t < self.gradual_latent.start_timesteps - and current_ratio < 1.0 - and step_elapsed >= self.gradual_latent.every_n_steps - ): - current_ratio = min(current_ratio + self.gradual_latent.ratio_step, 1.0) - # make divisible by 8 because size of latents must be divisible at bottom of UNet - h = int(height * current_ratio) // 8 * 8 - w = int(width * current_ratio) // 8 * 8 - resized_size = (h, w) - self.scheduler.set_gradual_latent_params(resized_size, self.gradual_latent) - step_elapsed = 0 - else: - self.scheduler.set_gradual_latent_params(None, None) - step_elapsed += 1 - - # expand the latents if we are doing classifier free guidance - latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # disable control net if ratio is set - if self.control_nets and self.control_net_enabled: - for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_nets, each_control_net_enabled)): - if not enabled or ratio >= 1.0: - continue - if ratio < i / len(timesteps): - logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") - control_net.set_cond_image(None) - each_control_net_enabled[j] = False - - # predict the noise residual - # TODO Diffusers' ControlNet - # if self.control_nets and self.control_net_enabled: - # if reginonal_network: - # num_sub_and_neg_prompts = len(text_embeddings) // batch_size - # text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt - # else: - # text_emb_last = text_embeddings - - # # not working yet - # noise_pred = original_control_net.call_unet_and_control_net( - # i, - # num_latent_input, - # self.unet, - # self.control_nets, - # guided_hints, - # i / len(timesteps), - # latent_model_input, - # t, - # text_emb_last, - # ).sample - # else: - noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) - - # perform guidance - if do_classifier_free_guidance: - if negative_scale is None: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - else: - noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk( - num_latent_input - ) # uncond is real uncond - noise_pred = ( - noise_pred_uncond - + guidance_scale * (noise_pred_text - noise_pred_uncond) - - negative_scale * (noise_pred_negative - noise_pred_uncond) - ) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - if mask is not None: - # masking - init_latents_proper = 
self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t])) - latents = (init_latents_proper * mask) + (latents * (1 - mask)) - - # call the callback, if provided - if i % callback_steps == 0: - if callback is not None: - callback(i, t, latents) - if is_cancelled_callback is not None and is_cancelled_callback(): - return None - - i += 1 - - if return_latents: - return latents - - latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents - if vae_batch_size >= batch_size: - image = self.vae.decode(latents.to(self.vae.dtype)).sample - else: - clean_memory() - images = [] - for i in tqdm(range(0, batch_size, vae_batch_size)): - images.append( - self.vae.decode( - (latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).to(self.vae.dtype) - ).sample - ) - image = torch.cat(images) - - image = (image / 2 + 0.5).clamp(0, 1) - - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - - clean_memory() - - if output_type == "pil": - # image = self.numpy_to_pil(image) - image = (image * 255).round().astype("uint8") - image = [Image.fromarray(im) for im in image] - - return image - - # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) - - -re_attention = re.compile( - r""" -\\\(| -\\\)| -\\\[| -\\]| -\\\\| -\\| -\(| -\[| -:([+-]?[.\d]+)\)| -\)| -]| -[^\\()\[\]:]+| -: -""", - re.X, -) - - -def parse_prompt_attention(text): - """ - Parses a string with attention tokens and returns a list of pairs: text and its associated weight. - Accepted tokens are: - (abc) - increases attention to abc by a multiplier of 1.1 - (abc:3.12) - increases attention to abc by a multiplier of 3.12 - [abc] - decreases attention to abc by a multiplier of 1.1 - \( - literal character '(' - \[ - literal character '[' - \) - literal character ')' - \] - literal character ']' - \\ - literal character '\' - anything else - just text - >>> parse_prompt_attention('normal text') - [['normal text', 1.0]] - >>> parse_prompt_attention('an (important) word') - [['an ', 1.0], ['important', 1.1], [' word', 1.0]] - >>> parse_prompt_attention('(unbalanced') - [['unbalanced', 1.1]] - >>> parse_prompt_attention('\(literal\]') - [['(literal]', 1.0]] - >>> parse_prompt_attention('(unnecessary)(parens)') - [['unnecessaryparens', 1.1]] - >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') - [['a ', 1.0], - ['house', 1.5730000000000004], - [' ', 1.1], - ['on', 1.0], - [' a ', 1.1], - ['hill', 0.55], - [', sun, ', 1.1], - ['sky', 1.4641000000000006], - ['.', 1.1]] - """ - - res = [] - round_brackets = [] - square_brackets = [] - - round_bracket_multiplier = 1.1 - square_bracket_multiplier = 1 / 1.1 - - def multiply_range(start_position, multiplier): - for p in range(start_position, len(res)): - res[p][1] *= multiplier - - # keep break as separate token - text = text.replace("BREAK", "\\BREAK\\") - - for m in re_attention.finditer(text): - text = m.group(0) - weight = m.group(1) - - if text.startswith("\\"): - res.append([text[1:], 1.0]) - elif text == "(": - round_brackets.append(len(res)) - elif text == "[": - square_brackets.append(len(res)) - elif weight is not None and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), float(weight)) - elif text == ")" and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == "]" and len(square_brackets) > 0: - 
multiply_range(square_brackets.pop(), square_bracket_multiplier) - else: - res.append([text, 1.0]) - - for pos in round_brackets: - multiply_range(pos, round_bracket_multiplier) - - for pos in square_brackets: - multiply_range(pos, square_bracket_multiplier) - - if len(res) == 0: - res = [["", 1.0]] - - # merge runs of identical weights - i = 0 - while i + 1 < len(res): - if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK": - res[i][0] += res[i + 1][0] - res.pop(i + 1) - else: - i += 1 - - return res - - -def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: List[str], max_length: int): - r""" - Tokenize a list of prompts and return its tokens with weights of each token. - No padding, starting or ending token is included. - """ - tokens = [] - weights = [] - truncated = False - - for text in prompt: - texts_and_weights = parse_prompt_attention(text) - text_token = [] - text_weight = [] - for word, weight in texts_and_weights: - if word.strip() == "BREAK": - # pad until next multiple of tokenizer's max token length - pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length) - logger.info(f"BREAK pad_len: {pad_len}") - for i in range(pad_len): - # v2のときEOSをつけるべきかどうかわからないぜ - # if i == 0: - # text_token.append(tokenizer.eos_token_id) - # else: - text_token.append(tokenizer.pad_token_id) - text_weight.append(1.0) - continue - - # tokenize and discard the starting and the ending token - token = tokenizer(word).input_ids[1:-1] - - token = token_replacer(token) # for Textual Inversion - - text_token += token - # copy the weight by length of token - text_weight += [weight] * len(token) - # stop if the text is too long (longer than truncation limit) - if len(text_token) > max_length: - truncated = True - break - # truncate - if len(text_token) > max_length: - truncated = True - text_token = text_token[:max_length] - text_weight = text_weight[:max_length] - tokens.append(text_token) - weights.append(text_weight) - if truncated: - logger.warning("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") - return tokens, weights - - -def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): - r""" - Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. 
- """ - max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) - weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length - for i in range(len(tokens)): - tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) - if no_boseos_middle: - weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) - else: - w = [] - if len(weights[i]) == 0: - w = [1.0] * weights_length - else: - for j in range(max_embeddings_multiples): - w.append(1.0) # weight for starting token in this chunk - w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] - w.append(1.0) # weight for ending token in this chunk - w += [1.0] * (weights_length - len(w)) - weights[i] = w[:] - - return tokens, weights - - -def get_unweighted_text_embeddings( - text_encoder: CLIPTextModel, - text_input: torch.Tensor, - chunk_length: int, - clip_skip: int, - eos: int, - pad: int, - no_boseos_middle: Optional[bool] = True, -): - """ - When the length of tokens is a multiple of the capacity of the text encoder, - it should be split into chunks and sent to the text encoder individually. - """ - max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) - if max_embeddings_multiples > 1: - text_embeddings = [] - pool = None - for i in range(max_embeddings_multiples): - # extract the i-th chunk - text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() - - # cover the head and the tail by the starting and the ending tokens - text_input_chunk[:, 0] = text_input[0, 0] - if pad == eos: # v1 - text_input_chunk[:, -1] = text_input[0, -1] - else: # v2 - for j in range(len(text_input_chunk)): - if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある - text_input_chunk[j, -1] = eos - if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD - text_input_chunk[j, 1] = eos - - # -2 is same for Text Encoder 1 and 2 - enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) - text_embedding = enc_out["hidden_states"][-2] - if pool is None: - pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided - if pool is not None: - pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos) - - if no_boseos_middle: - if i == 0: - # discard the ending token - text_embedding = text_embedding[:, :-1] - elif i == max_embeddings_multiples - 1: - # discard the starting token - text_embedding = text_embedding[:, 1:] - else: - # discard both starting and ending tokens - text_embedding = text_embedding[:, 1:-1] - - text_embeddings.append(text_embedding) - text_embeddings = torch.concat(text_embeddings, axis=1) - else: - enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) - text_embeddings = enc_out["hidden_states"][-2] - pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this - if pool is not None: - pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input, eos) - return text_embeddings, pool - - -def get_weighted_text_embeddings( - tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModel, - prompt: Union[str, List[str]], - uncond_prompt: Optional[Union[str, List[str]]] = None, - max_embeddings_multiples: Optional[int] = 1, - no_boseos_middle: Optional[bool] = False, - skip_parsing: Optional[bool] = False, - skip_weighting: Optional[bool] = False, - clip_skip=None, - token_replacer=None, - device=None, - 
**kwargs, -): - max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - if isinstance(prompt, str): - prompt = [prompt] - - # split the prompts with "AND". each prompt must have the same number of splits - new_prompts = [] - for p in prompt: - new_prompts.extend(p.split(" AND ")) - prompt = new_prompts - - if not skip_parsing: - prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, token_replacer, prompt, max_length - 2) - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens, uncond_weights = get_prompts_with_weights(tokenizer, token_replacer, uncond_prompt, max_length - 2) - else: - prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids] - prompt_weights = [[1.0] * len(token) for token in prompt_tokens] - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens = [token[1:-1] for token in tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids] - uncond_weights = [[1.0] * len(token) for token in uncond_tokens] - - # round up the longest length of tokens to a multiple of (model_max_length - 2) - max_length = max([len(token) for token in prompt_tokens]) - if uncond_prompt is not None: - max_length = max(max_length, max([len(token) for token in uncond_tokens])) - - max_embeddings_multiples = min( - max_embeddings_multiples, - (max_length - 1) // (tokenizer.model_max_length - 2) + 1, - ) - max_embeddings_multiples = max(1, max_embeddings_multiples) - max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - - # pad the length of tokens and weights - bos = tokenizer.bos_token_id - eos = tokenizer.eos_token_id - pad = tokenizer.pad_token_id - prompt_tokens, prompt_weights = pad_tokens_and_weights( - prompt_tokens, - prompt_weights, - max_length, - bos, - eos, - pad, - no_boseos_middle=no_boseos_middle, - chunk_length=tokenizer.model_max_length, - ) - prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) - if uncond_prompt is not None: - uncond_tokens, uncond_weights = pad_tokens_and_weights( - uncond_tokens, - uncond_weights, - max_length, - bos, - eos, - pad, - no_boseos_middle=no_boseos_middle, - chunk_length=tokenizer.model_max_length, - ) - uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device) - - # get the embeddings - text_embeddings, text_pool = get_unweighted_text_embeddings( - text_encoder, - prompt_tokens, - tokenizer.model_max_length, - clip_skip, - eos, - pad, - no_boseos_middle=no_boseos_middle, - ) - prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device) - if uncond_prompt is not None: - uncond_embeddings, uncond_pool = get_unweighted_text_embeddings( - text_encoder, - uncond_tokens, - tokenizer.model_max_length, - clip_skip, - eos, - pad, - no_boseos_middle=no_boseos_middle, - ) - uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=device) - - # assign weights to the prompts and normalize in the sense of mean - # TODO: should we normalize by chunk or in a whole (current implementation)? 
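The weighting applied just below multiplies each token embedding by its prompt weight, then rescales so the per-sample mean matches the unweighted embeddings, keeping the overall magnitude of the conditioning roughly unchanged. A minimal illustrative sketch of that mean-preserving reweighting (toy tensors, not the function's real inputs):

import torch

def apply_prompt_weights(embeddings: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # embeddings: (batch, tokens, dim), weights: (batch, tokens)
    previous_mean = embeddings.float().mean(dim=[-2, -1])
    weighted = embeddings * weights.unsqueeze(-1)
    current_mean = weighted.float().mean(dim=[-2, -1])
    # Rescale so the per-sample mean is preserved after weighting.
    return weighted * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

emb = torch.ones(1, 4, 8)
w = torch.tensor([[1.0, 1.1, 1.0, 0.9]])
print(apply_prompt_weights(emb, w).float().mean().item())  # ~1.0: mean preserved after weighting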
- # →全体でいいんじゃないかな - if (not skip_parsing) and (not skip_weighting): - previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= prompt_weights.unsqueeze(-1) - current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - if uncond_prompt is not None: - previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= uncond_weights.unsqueeze(-1) - current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - - if uncond_prompt is not None: - return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens - return text_embeddings, text_pool, None, None, prompt_tokens - - -def preprocess_image(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.0 * image - 1.0 - - -def preprocess_mask(mask): - mask = mask.convert("L") - w, h = mask.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS) - mask = np.array(mask).astype(np.float32) / 255.0 - mask = np.tile(mask, (4, 1, 1)) - mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? - mask = 1 - mask # repaint white, keep black - mask = torch.from_numpy(mask) - return mask - - -# regular expression for dynamic prompt: -# starts and ends with "{" and "}" -# contains at least one variant divided by "|" -# optional framgments divided by "$$" at start -# if the first fragment is "E" or "e", enumerate all variants -# if the second fragment is a number or two numbers, repeat the variants in the range -# if the third fragment is a string, use it as a separator - -RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}") - - -def handle_dynamic_prompt_variants(prompt, repeat_count): - founds = list(RE_DYNAMIC_PROMPT.finditer(prompt)) - if not founds: - return [prompt] - - # make each replacement for each variant - enumerating = False - replacers = [] - for found in founds: - # if "e$$" is found, enumerate all variants - found_enumerating = found.group(2) is not None - enumerating = enumerating or found_enumerating - - separator = ", " if found.group(6) is None else found.group(6) - variants = found.group(7).split("|") - - # parse count range - count_range = found.group(4) - if count_range is None: - count_range = [1, 1] - else: - count_range = count_range.split("-") - if len(count_range) == 1: - count_range = [int(count_range[0]), int(count_range[0])] - elif len(count_range) == 2: - count_range = [int(count_range[0]), int(count_range[1])] - else: - logger.warning(f"invalid count range: {count_range}") - count_range = [1, 1] - if count_range[0] > count_range[1]: - count_range = [count_range[1], count_range[0]] - if count_range[0] < 0: - count_range[0] = 0 - if count_range[1] > len(variants): - count_range[1] = len(variants) - - if found_enumerating: - # make function to enumerate all combinations - def make_replacer_enum(vari, cr, sep): - def replacer(): - values = [] - for count in range(cr[0], cr[1] + 1): - for comb in 
itertools.combinations(vari, count): - values.append(sep.join(comb)) - return values - - return replacer - - replacers.append(make_replacer_enum(variants, count_range, separator)) - else: - # make function to choose random combinations - def make_replacer_single(vari, cr, sep): - def replacer(): - count = random.randint(cr[0], cr[1]) - comb = random.sample(vari, count) - return [sep.join(comb)] - - return replacer - - replacers.append(make_replacer_single(variants, count_range, separator)) - - # make each prompt - if not enumerating: - # if not enumerating, repeat the prompt, replace each variant randomly - prompts = [] - for _ in range(repeat_count): - current = prompt - for found, replacer in zip(founds, replacers): - current = current.replace(found.group(0), replacer()[0], 1) - prompts.append(current) - else: - # if enumerating, iterate all combinations for previous prompts - prompts = [prompt] - - for found, replacer in zip(founds, replacers): - if found.group(2) is not None: - # make all combinations for existing prompts - new_prompts = [] - for current in prompts: - replecements = replacer() - for replecement in replecements: - new_prompts.append(current.replace(found.group(0), replecement, 1)) - prompts = new_prompts - - for found, replacer in zip(founds, replacers): - # make random selection for existing prompts - if found.group(2) is None: - for i in range(len(prompts)): - prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1) - - return prompts - - -# endregion - -# def load_clip_l14_336(dtype): -# logger.info(f"loading CLIP: {CLIP_ID_L14_336}") -# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) -# return text_encoder - - -class BatchDataBase(NamedTuple): - # バッチ分割が必要ないデータ - step: int - prompt: str - negative_prompt: str - seed: int - init_image: Any - mask_image: Any - clip_prompt: str - guide_image: Any - raw_prompt: str - - -class BatchDataExt(NamedTuple): - # バッチ分割が必要なデータ - width: int - height: int - original_width: int - original_height: int - original_width_negative: int - original_height_negative: int - crop_left: int - crop_top: int - steps: int - scale: float - negative_scale: float - strength: float - network_muls: Tuple[float] - num_sub_prompts: int - - -class BatchData(NamedTuple): - return_latents: bool - base: BatchDataBase - ext: BatchDataExt - - -def main(args): - if args.fp16: - dtype = torch.float16 - elif args.bf16: - dtype = torch.bfloat16 - else: - dtype = torch.float32 - - highres_fix = args.highres_fix_scale is not None - # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" - - # モデルを読み込む - if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う - files = glob.glob(args.ckpt) - if len(files) == 1: - args.ckpt = files[0] - - (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( - args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype - ) - unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) - - # xformers、Hypernetwork対応 - if not args.diffusers_xformers: - mem_eff = not (args.xformers or args.sdpa) - replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa) - replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) - - # tokenizerを読み込む - logger.info("loading tokenizer") - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - - # schedulerを用意する - sched_init_args = {} - has_steps_offset = True - has_clip_sample = True - 
scheduler_num_noises_per_step = 1 - - if args.sampler == "ddim": - scheduler_cls = DDIMScheduler - scheduler_module = diffusers.schedulers.scheduling_ddim - elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある - scheduler_cls = DDPMScheduler - scheduler_module = diffusers.schedulers.scheduling_ddpm - elif args.sampler == "pndm": - scheduler_cls = PNDMScheduler - scheduler_module = diffusers.schedulers.scheduling_pndm - has_clip_sample = False - elif args.sampler == "lms" or args.sampler == "k_lms": - scheduler_cls = LMSDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_lms_discrete - has_clip_sample = False - elif args.sampler == "euler" or args.sampler == "k_euler": - scheduler_cls = EulerDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_euler_discrete - has_clip_sample = False - elif args.sampler == "euler_a" or args.sampler == "k_euler_a": - scheduler_cls = EulerAncestralDiscreteSchedulerGL - scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete - has_clip_sample = False - elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": - scheduler_cls = DPMSolverMultistepScheduler - sched_init_args["algorithm_type"] = args.sampler - scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep - has_clip_sample = False - elif args.sampler == "dpmsingle": - scheduler_cls = DPMSolverSinglestepScheduler - scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep - has_clip_sample = False - has_steps_offset = False - elif args.sampler == "heun": - scheduler_cls = HeunDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_heun_discrete - has_clip_sample = False - elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2": - scheduler_cls = KDPM2DiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete - has_clip_sample = False - elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a": - scheduler_cls = KDPM2AncestralDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete - scheduler_num_noises_per_step = 2 - has_clip_sample = False - - # 警告を出さないようにする - if has_steps_offset: - sched_init_args["steps_offset"] = 1 - if has_clip_sample: - sched_init_args["clip_sample"] = False - - # samplerの乱数をあらかじめ指定するための処理 - - # replace randn - class NoiseManager: - def __init__(self): - self.sampler_noises = None - self.sampler_noise_index = 0 - - def reset_sampler_noises(self, noises): - self.sampler_noise_index = 0 - self.sampler_noises = noises - - def randn(self, shape, device=None, dtype=None, layout=None, generator=None): - # logger.info("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) - if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): - noise = self.sampler_noises[self.sampler_noise_index] - if shape != noise.shape: - noise = None - else: - noise = None - - if noise == None: - logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}") - noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) - - self.sampler_noise_index += 1 - return noise - - class TorchRandReplacer: - def __init__(self, noise_manager): - self.noise_manager = noise_manager - - def __getattr__(self, item): - if item == "randn": - return self.noise_manager.randn - if hasattr(torch, item): - return getattr(torch, item) - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) - - noise_manager = NoiseManager() - 
if scheduler_module is not None: - scheduler_module.torch = TorchRandReplacer(noise_manager) - - scheduler = scheduler_cls( - num_train_timesteps=SCHEDULER_TIMESTEPS, - beta_start=SCHEDULER_LINEAR_START, - beta_end=SCHEDULER_LINEAR_END, - beta_schedule=SCHEDLER_SCHEDULE, - **sched_init_args, - ) - - # ↓以下は結局PipeでFalseに設定されるので意味がなかった - # # clip_sample=Trueにする - # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - # logger.info("set clip_sample to True") - # scheduler.config.clip_sample = True - - # deviceを決定する - device = get_preferred_device() - - # custom pipelineをコピったやつを生成する - if args.vae_slices: - from library.slicing_vae import SlicingAutoencoderKL - - sli_vae = SlicingAutoencoderKL( - act_fn="silu", - block_out_channels=(128, 256, 512, 512), - down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], - in_channels=3, - latent_channels=4, - layers_per_block=2, - norm_num_groups=32, - out_channels=3, - sample_size=512, - up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], - num_slices=args.vae_slices, - ) - sli_vae.load_state_dict(vae.state_dict()) # vaeのパラメータをコピーする - vae = sli_vae - del sli_vae - - vae_dtype = dtype - if args.no_half_vae: - logger.info("set vae_dtype to float32") - vae_dtype = torch.float32 - vae.to(vae_dtype).to(device) - vae.eval() - - text_encoder1.to(dtype).to(device) - text_encoder2.to(dtype).to(device) - unet.to(dtype).to(device) - text_encoder1.eval() - text_encoder2.eval() - unet.eval() - - # networkを組み込む - if args.network_module: - networks = [] - network_default_muls = [] - network_pre_calc = args.network_pre_calc - - # merge関連の引数を統合する - if args.network_merge: - network_merge = len(args.network_module) # all networks are merged - elif args.network_merge_n_models: - network_merge = args.network_merge_n_models - else: - network_merge = 0 - logger.info(f"network_merge: {network_merge}") - - for i, network_module in enumerate(args.network_module): - logger.info(f"import network module: {network_module}") - imported_module = importlib.import_module(network_module) - - network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] - - net_kwargs = {} - if args.network_args and i < len(args.network_args): - network_args = args.network_args[i] - # TODO escape special chars - network_args = network_args.split(";") - for net_arg in network_args: - key, value = net_arg.split("=") - net_kwargs[key] = value - - if args.network_weights is None or len(args.network_weights) <= i: - raise ValueError("No weight. Weight is required.") - - network_weight = args.network_weights[i] - logger.info(f"load network weights from: {network_weight}") - - if model_util.is_safetensors(network_weight) and args.network_show_meta: - from safetensors.torch import safe_open - - with safe_open(network_weight, framework="pt") as f: - metadata = f.metadata() - if metadata is not None: - logger.info(f"metadata for: {network_weight}: {metadata}") - - network, weights_sd = imported_module.create_network_from_weights( - network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs - ) - if network is None: - return - - mergeable = network.is_mergeable() - if network_merge and not mergeable: - logger.warning("network is not mergiable. 
ignore merge option.") - - if not mergeable or i >= network_merge: - # not merging - network.apply_to([text_encoder1, text_encoder2], unet) - info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい - logger.info(f"weights are loaded: {info}") - - if args.opt_channels_last: - network.to(memory_format=torch.channels_last) - network.to(dtype).to(device) - - if network_pre_calc: - logger.info("backup original weights") - network.backup_weights() - - networks.append(network) - network_default_muls.append(network_mul) - else: - network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device) - - else: - networks = [] - - # upscalerの指定があれば取得する - upscaler = None - if args.highres_fix_upscaler: - logger.info(f"import upscaler module: {args.highres_fix_upscaler}") - imported_module = importlib.import_module(args.highres_fix_upscaler) - - us_kwargs = {} - if args.highres_fix_upscaler_args: - for net_arg in args.highres_fix_upscaler_args.split(";"): - key, value = net_arg.split("=") - us_kwargs[key] = value - - logger.info("create upscaler") - upscaler = imported_module.create_upscaler(**us_kwargs) - upscaler.to(dtype).to(device) - - # ControlNetの処理 - control_nets: List[Tuple[ControlNetLLLite, float]] = [] - # if args.control_net_models: - # for i, model in enumerate(args.control_net_models): - # prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] - # weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] - # ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] - - # ctrl_unet, ctrl_net = original_control_net.load_control_net(False, unet, model) - # prep = original_control_net.load_preprocess(prep_type) - # control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) - if args.control_net_lllite_models: - for i, model_file in enumerate(args.control_net_lllite_models): - logger.info(f"loading ControlNet-LLLite: {model_file}") - - from safetensors.torch import load_file - - state_dict = load_file(model_file) - mlp_dim = None - cond_emb_dim = None - for key, value in state_dict.items(): - if mlp_dim is None and "down.0.weight" in key: - mlp_dim = value.shape[0] - elif cond_emb_dim is None and "conditioning1.0" in key: - cond_emb_dim = value.shape[0] * 2 - if mlp_dim is not None and cond_emb_dim is not None: - break - assert mlp_dim is not None and cond_emb_dim is not None, f"invalid control net: {model_file}" - - multiplier = ( - 1.0 - if not args.control_net_multipliers or len(args.control_net_multipliers) <= i - else args.control_net_multipliers[i] - ) - ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] - - control_net = ControlNetLLLite(unet, cond_emb_dim, mlp_dim, multiplier=multiplier) - control_net.apply_to() - control_net.load_state_dict(state_dict) - control_net.to(dtype).to(device) - control_net.set_batch_cond_only(False, False) - control_nets.append((control_net, ratio)) - - if args.opt_channels_last: - logger.info(f"set optimizing: channels last") - text_encoder1.to(memory_format=torch.channels_last) - text_encoder2.to(memory_format=torch.channels_last) - vae.to(memory_format=torch.channels_last) - unet.to(memory_format=torch.channels_last) - if networks: - for network in networks: - network.to(memory_format=torch.channels_last) - - for cn in control_nets: - 
cn.to(memory_format=torch.channels_last) - # cn.unet.to(memory_format=torch.channels_last) - # cn.net.to(memory_format=torch.channels_last) - - pipe = PipelineLike( - device, - vae, - [text_encoder1, text_encoder2], - [tokenizer1, tokenizer2], - unet, - scheduler, - args.clip_skip, - ) - pipe.set_control_nets(control_nets) - logger.info("pipeline is ready.") - - if args.diffusers_xformers: - pipe.enable_xformers_memory_efficient_attention() - - # Deep Shrink - if args.ds_depth_1 is not None: - unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) - - # Gradual Latent - if args.gradual_latent_timesteps is not None: - if args.gradual_latent_unsharp_params: - us_params = args.gradual_latent_unsharp_params.split(",") - us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]] - us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3])) - us_ksize = int(us_ksize) - else: - us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None - - gradual_latent = GradualLatent( - args.gradual_latent_ratio, - args.gradual_latent_timesteps, - args.gradual_latent_every_n_steps, - args.gradual_latent_ratio_step, - args.gradual_latent_s_noise, - us_ksize, - us_sigma, - us_strength, - us_target_x, - ) - pipe.set_gradual_latent(gradual_latent) - - # Textual Inversionを処理する - if args.textual_inversion_embeddings: - token_ids_embeds1 = [] - token_ids_embeds2 = [] - for embeds_file in args.textual_inversion_embeddings: - if model_util.is_safetensors(embeds_file): - from safetensors.torch import load_file - - data = load_file(embeds_file) - else: - data = torch.load(embeds_file, map_location="cpu") - - if "string_to_param" in data: - data = data["string_to_param"] - - embeds1 = data["clip_l"] # text encoder 1 - embeds2 = data["clip_g"] # text encoder 2 - - num_vectors_per_token = embeds1.size()[0] - token_string = os.path.splitext(os.path.basename(embeds_file))[0] - - token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] - - # add new word to tokenizer, count is num_vectors_per_token - num_added_tokens1 = tokenizer1.add_tokens(token_strings) - num_added_tokens2 = tokenizer2.add_tokens(token_strings) - assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, ( - f"tokenizer has same word to token string (filename): {embeds_file}" - + f" / 指定した名前(ファイル名)のトークンが既に存在します: {embeds_file}" - ) - - token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings) - token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings) - logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") - assert ( - min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 - ), f"token ids1 is not ordered" - assert ( - min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1 - ), f"token ids2 is not ordered" - assert len(tokenizer1) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizer1)}" - assert len(tokenizer2) - 1 == token_ids2[-1], f"token ids 2 is not end of tokenize: {len(tokenizer2)}" - - if num_vectors_per_token > 1: - pipe.add_token_replacement(0, token_ids1[0], token_ids1) # hoge -> hoge, hogea, hogeb, ... 
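The replacement registered here lets a single Textual Inversion trigger token stand in for all of its trained vectors at encode time; replace_tokens earlier in the pipeline performs the expansion. A toy sketch of that expansion with made-up token ids (illustrative only):

def expand_tokens(token_ids, replacements):
    # replacements maps a trigger token id to the list of ids covering every
    # vector of a multi-vector embedding.
    expanded = []
    for tid in token_ids:
        expanded.extend(replacements.get(tid, [tid]))
    return expanded

# Suppose "hoge" was trained with 3 vectors mapped to ids 49408..49410.
replacements = {49408: [49408, 49409, 49410]}
print(expand_tokens([320, 49408, 2368], replacements))  # [320, 49408, 49409, 49410, 2368]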
- pipe.add_token_replacement(1, token_ids2[0], token_ids2) - - token_ids_embeds1.append((token_ids1, embeds1)) - token_ids_embeds2.append((token_ids2, embeds2)) - - text_encoder1.resize_token_embeddings(len(tokenizer1)) - text_encoder2.resize_token_embeddings(len(tokenizer2)) - token_embeds1 = text_encoder1.get_input_embeddings().weight.data - token_embeds2 = text_encoder2.get_input_embeddings().weight.data - for token_ids, embeds in token_ids_embeds1: - for token_id, embed in zip(token_ids, embeds): - token_embeds1[token_id] = embed - for token_ids, embeds in token_ids_embeds2: - for token_id, embed in zip(token_ids, embeds): - token_embeds2[token_id] = embed - - # promptを取得する - if args.from_file is not None: - logger.info(f"reading prompts from {args.from_file}") - with open(args.from_file, "r", encoding="utf-8") as f: - prompt_list = f.read().splitlines() - prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] - elif args.prompt is not None: - prompt_list = [args.prompt] - else: - prompt_list = [] - - if args.interactive: - args.n_iter = 1 - - # img2imgの前処理、画像の読み込みなど - def load_images(path): - if os.path.isfile(path): - paths = [path] - else: - paths = ( - glob.glob(os.path.join(path, "*.png")) - + glob.glob(os.path.join(path, "*.jpg")) - + glob.glob(os.path.join(path, "*.jpeg")) - + glob.glob(os.path.join(path, "*.webp")) - ) - paths.sort() - - images = [] - for p in paths: - image = Image.open(p) - if image.mode != "RGB": - logger.info(f"convert image to RGB from {image.mode}: {p}") - image = image.convert("RGB") - images.append(image) - - return images - - def resize_images(imgs, size): - resized = [] - for img in imgs: - r_img = img.resize(size, Image.Resampling.LANCZOS) - if hasattr(img, "filename"): # filename属性がない場合があるらしい - r_img.filename = img.filename - resized.append(r_img) - return resized - - if args.image_path is not None: - logger.info(f"load image for img2img: {args.image_path}") - init_images = load_images(args.image_path) - assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" - logger.info(f"loaded {len(init_images)} images for img2img") - - # CLIP Vision - if args.clip_vision_strength is not None: - logger.info(f"load CLIP Vision model: {CLIP_VISION_MODEL}") - vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280) - vision_model.to(device, dtype) - processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL) - - pipe.clip_vision_model = vision_model - pipe.clip_vision_processor = processor - pipe.clip_vision_strength = args.clip_vision_strength - logger.info(f"CLIP Vision model loaded.") - - else: - init_images = None - - if args.mask_path is not None: - logger.info(f"load mask for inpainting: {args.mask_path}") - mask_images = load_images(args.mask_path) - assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" - logger.info(f"loaded {len(mask_images)} mask images for inpainting") - else: - mask_images = None - - # promptがないとき、画像のPngInfoから取得する - if init_images is not None and len(prompt_list) == 0 and not args.interactive: - logger.info("get prompts from images' metadata") - for img in init_images: - if "prompt" in img.text: - prompt = img.text["prompt"] - if "negative-prompt" in img.text: - prompt += " --n " + img.text["negative-prompt"] - prompt_list.append(prompt) - - # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) - l = [] - for im in init_images: - l.extend([im] * args.images_per_prompt) - init_images = l - - if mask_images is not None: - l = [] - for im in 
mask_images: - l.extend([im] * args.images_per_prompt) - mask_images = l - - # 画像サイズにオプション指定があるときはリサイズする - if args.W is not None and args.H is not None: - # highres fix を考慮に入れる - w, h = args.W, args.H - if highres_fix: - w = int(w * args.highres_fix_scale + 0.5) - h = int(h * args.highres_fix_scale + 0.5) - - if init_images is not None: - logger.info(f"resize img2img source images to {w}*{h}") - init_images = resize_images(init_images, (w, h)) - if mask_images is not None: - logger.info(f"resize img2img mask images to {w}*{h}") - mask_images = resize_images(mask_images, (w, h)) - - regional_network = False - if networks and mask_images: - # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 - regional_network = True - logger.info("use mask as region") - - size = None - for i, network in enumerate(networks): - if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes: - np_mask = np.array(mask_images[0]) - - if args.network_regional_mask_max_color_codes: - # カラーコードでマスクを指定する - ch0 = (i + 1) & 1 - ch1 = ((i + 1) >> 1) & 1 - ch2 = ((i + 1) >> 2) & 1 - np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2) - np_mask = np_mask.astype(np.uint8) * 255 - else: - np_mask = np_mask[:, :, i] - size = np_mask.shape - else: - np_mask = np.full(size, 255, dtype=np.uint8) - mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0) - network.set_region(i, i == len(networks) - 1, mask) - mask_images = None - - prev_image = None # for VGG16 guided - if args.guide_image_path is not None: - logger.info(f"load image for ControlNet guidance: {args.guide_image_path}") - guide_images = [] - for p in args.guide_image_path: - guide_images.extend(load_images(p)) - - logger.info(f"loaded {len(guide_images)} guide images for guidance") - if len(guide_images) == 0: - logger.warning( - f"No guide image, use previous generated image. 
/ ガイド画像がありません。直前に生成した画像を使います: {args.image_path}" - ) - guide_images = None - else: - guide_images = None - - # seed指定時はseedを決めておく - if args.seed is not None: - # dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう - random.seed(args.seed) - predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] - if len(predefined_seeds) == 1: - predefined_seeds[0] = args.seed - else: - predefined_seeds = None - - # デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み) - if args.W is None: - args.W = 1024 - if args.H is None: - args.H = 1024 - - # 画像生成のループ - os.makedirs(args.outdir, exist_ok=True) - max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples - - for gen_iter in range(args.n_iter): - logger.info(f"iteration {gen_iter+1}/{args.n_iter}") - iter_seed = random.randint(0, 0x7FFFFFFF) - - # バッチ処理の関数 - def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): - batch_size = len(batch) - - # highres_fixの処理 - if highres_fix and not highres_1st: - # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す - is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling - - logger.info("process 1st stage") - batch_1st = [] - for _, base, ext in batch: - - def scale_and_round(x): - if x is None: - return None - return int(x * args.highres_fix_scale + 0.5) - - width_1st = scale_and_round(ext.width) - height_1st = scale_and_round(ext.height) - width_1st = width_1st - width_1st % 32 - height_1st = height_1st - height_1st % 32 - - original_width_1st = scale_and_round(ext.original_width) - original_height_1st = scale_and_round(ext.original_height) - original_width_negative_1st = scale_and_round(ext.original_width_negative) - original_height_negative_1st = scale_and_round(ext.original_height_negative) - crop_left_1st = scale_and_round(ext.crop_left) - crop_top_1st = scale_and_round(ext.crop_top) - - strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength - - ext_1st = BatchDataExt( - width_1st, - height_1st, - original_width_1st, - original_height_1st, - original_width_negative_1st, - original_height_negative_1st, - crop_left_1st, - crop_top_1st, - args.highres_fix_steps, - ext.scale, - ext.negative_scale, - strength_1st, - ext.network_muls, - ext.num_sub_prompts, - ) - batch_1st.append(BatchData(is_1st_latent, base, ext_1st)) - - pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする - images_1st = process_batch(batch_1st, True, True) - - # 2nd stageのバッチを作成して以下処理する - logger.info("process 2nd stage") - width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height - - if upscaler: - # upscalerを使って画像を拡大する - lowreso_imgs = None if is_1st_latent else images_1st - lowreso_latents = None if not is_1st_latent else images_1st - - # 戻り値はPIL.Image.Imageかtorch.Tensorのlatents - batch_size = len(images_1st) - vae_batch_size = ( - batch_size - if args.vae_batch_size is None - else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size) - ) - vae_batch_size = int(vae_batch_size) - images_1st = upscaler.upscale( - vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size - ) - - elif args.highres_fix_latents_upscaling: - # latentを拡大する - org_dtype = images_1st.dtype - if images_1st.dtype == torch.bfloat16: - images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない - images_1st = torch.nn.functional.interpolate( - images_1st, 
(batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear" - ) # , antialias=True) - images_1st = images_1st.to(org_dtype) - - else: - # 画像をLANCZOSで拡大する - images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st] - - batch_2nd = [] - for i, (bd, image) in enumerate(zip(batch, images_1st)): - bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext) - batch_2nd.append(bd_2nd) - batch = batch_2nd - - if args.highres_fix_disable_control_net: - pipe.set_enable_control_net(False) # オプション指定時、2nd stageではControlNetを無効にする - - # このバッチの情報を取り出す - ( - return_latents, - (step_first, _, _, _, init_image, mask_image, _, guide_image, _), - ( - width, - height, - original_width, - original_height, - original_width_negative, - original_height_negative, - crop_left, - crop_top, - steps, - scale, - negative_scale, - strength, - network_muls, - num_sub_prompts, - ), - ) = batch[0] - noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) - - prompts = [] - negative_prompts = [] - raw_prompts = [] - start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - noises = [ - torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - for _ in range(steps * scheduler_num_noises_per_step) - ] - seeds = [] - clip_prompts = [] - - if init_image is not None: # img2img? - i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - init_images = [] - - if mask_image is not None: - mask_images = [] - else: - mask_images = None - else: - i2i_noises = None - init_images = None - mask_images = None - - if guide_image is not None: # CLIP image guided? - guide_images = [] - else: - guide_images = None - - # バッチ内の位置に関わらず同じ乱数を使うためにここで乱数を生成しておく。あわせてimage/maskがbatch内で同一かチェックする - all_images_are_same = True - all_masks_are_same = True - all_guide_images_are_same = True - for i, ( - _, - (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), - _, - ) in enumerate(batch): - prompts.append(prompt) - negative_prompts.append(negative_prompt) - seeds.append(seed) - clip_prompts.append(clip_prompt) - raw_prompts.append(raw_prompt) - - if init_image is not None: - init_images.append(init_image) - if i > 0 and all_images_are_same: - all_images_are_same = init_images[-2] is init_image - - if mask_image is not None: - mask_images.append(mask_image) - if i > 0 and all_masks_are_same: - all_masks_are_same = mask_images[-2] is mask_image - - if guide_image is not None: - if type(guide_image) is list: - guide_images.extend(guide_image) - all_guide_images_are_same = False - else: - guide_images.append(guide_image) - if i > 0 and all_guide_images_are_same: - all_guide_images_are_same = guide_images[-2] is guide_image - - # make start code - torch.manual_seed(seed) - start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype) - - # make each noises - for j in range(steps * scheduler_num_noises_per_step): - noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) - - if i2i_noises is not None: # img2img noise - i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype) - - noise_manager.reset_sampler_noises(noises) - - # すべての画像が同じなら1枚だけpipeに渡すことでpipe側で処理を高速化する - if init_images is not None and all_images_are_same: - init_images = init_images[0] - if mask_images is not None and all_masks_are_same: - mask_images = mask_images[0] - if guide_images is not None and 
all_guide_images_are_same: - guide_images = guide_images[0] - - # ControlNet使用時はguide imageをリサイズする - if control_nets: - # TODO resampleのメソッド - guide_images = guide_images if type(guide_images) == list else [guide_images] - guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images] - if len(guide_images) == 1: - guide_images = guide_images[0] - - # generate - if networks: - # 追加ネットワークの処理 - shared = {} - for n, m in zip(networks, network_muls if network_muls else network_default_muls): - n.set_multiplier(m) - if regional_network: - n.set_current_generation(batch_size, num_sub_prompts, width, height, shared) - - if not regional_network and network_pre_calc: - for n in networks: - n.restore_weights() - for n in networks: - n.pre_calculation() - logger.info("pre-calculation... done") - - images = pipe( - prompts, - negative_prompts, - init_images, - mask_images, - height, - width, - original_height, - original_width, - original_height_negative, - original_width_negative, - crop_top, - crop_left, - steps, - scale, - negative_scale, - strength, - latents=start_code, - output_type="pil", - max_embeddings_multiples=max_embeddings_multiples, - img2img_noise=i2i_noises, - vae_batch_size=args.vae_batch_size, - return_latents=return_latents, - clip_prompts=clip_prompts, - clip_guide_images=guide_images, - ) - if highres_1st and not args.highres_fix_save_1st: # return images or latents - return images - - # save image - highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" - ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( - zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) - ): - if highres_fix: - seed -= 1 # record original seed - metadata = PngInfo() - metadata.add_text("prompt", prompt) - metadata.add_text("seed", str(seed)) - metadata.add_text("sampler", args.sampler) - metadata.add_text("steps", str(steps)) - metadata.add_text("scale", str(scale)) - if negative_prompt is not None: - metadata.add_text("negative-prompt", negative_prompt) - if negative_scale is not None: - metadata.add_text("negative-scale", str(negative_scale)) - if clip_prompt is not None: - metadata.add_text("clip-prompt", clip_prompt) - if raw_prompt is not None: - metadata.add_text("raw-prompt", raw_prompt) - metadata.add_text("original-height", str(original_height)) - metadata.add_text("original-width", str(original_width)) - metadata.add_text("original-height-negative", str(original_height_negative)) - metadata.add_text("original-width-negative", str(original_width_negative)) - metadata.add_text("crop-top", str(crop_top)) - metadata.add_text("crop-left", str(crop_left)) - - if args.use_original_file_name and init_images is not None: - if type(init_images) is list: - fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" - else: - fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" - elif args.sequential_file_name: - fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" - else: - fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" - - image.save(os.path.join(args.outdir, fln), pnginfo=metadata) - - if not args.no_preview and not highres_1st and args.interactive: - try: - import cv2 - - for prompt, image in zip(prompts, images): - cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ - cv2.waitKey() - cv2.destroyAllWindows() - except ImportError: - logger.error( - "opencv-python 
is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" - ) - - return images - - # 画像生成のプロンプトが一周するまでのループ - prompt_index = 0 - global_step = 0 - batch_data = [] - while args.interactive or prompt_index < len(prompt_list): - if len(prompt_list) == 0: - # interactive - valid = False - while not valid: - logger.info("") - logger.info("Type prompt:") - try: - raw_prompt = input() - except EOFError: - break - - valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0 - if not valid: # EOF, end app - break - else: - raw_prompt = prompt_list[prompt_index] - - # sd-dynamic-prompts like variants: - # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) - raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) - - # repeat prompt - for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): - raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] - - if pi == 0 or len(raw_prompts) > 1: - # parse prompt: if prompt is not changed, skip parsing - width = args.W - height = args.H - original_width = args.original_width - original_height = args.original_height - original_width_negative = args.original_width_negative - original_height_negative = args.original_height_negative - crop_top = args.crop_top - crop_left = args.crop_left - scale = args.scale - negative_scale = args.negative_scale - steps = args.steps - seed = None - seeds = None - strength = 0.8 if args.strength is None else args.strength - negative_prompt = "" - clip_prompt = None - network_muls = None - - # Deep Shrink - ds_depth_1 = None # means no override - ds_timesteps_1 = args.ds_timesteps_1 - ds_depth_2 = args.ds_depth_2 - ds_timesteps_2 = args.ds_timesteps_2 - ds_ratio = args.ds_ratio - - # Gradual Latent - gl_timesteps = None # means no override - gl_ratio = args.gradual_latent_ratio - gl_every_n_steps = args.gradual_latent_every_n_steps - gl_ratio_step = args.gradual_latent_ratio_step - gl_s_noise = args.gradual_latent_s_noise - gl_unsharp_params = args.gradual_latent_unsharp_params - - prompt_args = raw_prompt.strip().split(" --") - prompt = prompt_args[0] - logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") - - for parg in prompt_args[1:]: - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - logger.info(f"width: {width}") - continue - - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - logger.info(f"height: {height}") - continue - - m = re.match(r"ow (\d+)", parg, re.IGNORECASE) - if m: - original_width = int(m.group(1)) - logger.info(f"original width: {original_width}") - continue - - m = re.match(r"oh (\d+)", parg, re.IGNORECASE) - if m: - original_height = int(m.group(1)) - logger.info(f"original height: {original_height}") - continue - - m = re.match(r"nw (\d+)", parg, re.IGNORECASE) - if m: - original_width_negative = int(m.group(1)) - logger.info(f"original width negative: {original_width_negative}") - continue - - m = re.match(r"nh (\d+)", parg, re.IGNORECASE) - if m: - original_height_negative = int(m.group(1)) - logger.info(f"original height negative: {original_height_negative}") - continue - - m = re.match(r"ct (\d+)", parg, re.IGNORECASE) - if m: - crop_top = int(m.group(1)) - logger.info(f"crop top: {crop_top}") - continue - - m = re.match(r"cl (\d+)", parg, re.IGNORECASE) - if m: - crop_left = int(m.group(1)) - logger.info(f"crop left: {crop_left}") - continue - - m = re.match(r"s (\d+)", parg, 
re.IGNORECASE) - if m: # steps - steps = max(1, min(1000, int(m.group(1)))) - logger.info(f"steps: {steps}") - continue - - m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) - if m: # seed - seeds = [int(d) for d in m.group(1).split(",")] - logger.info(f"seeds: {seeds}") - continue - - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - logger.info(f"scale: {scale}") - continue - - m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) - if m: # negative scale - if m.group(1).lower() == "none": - negative_scale = None - else: - negative_scale = float(m.group(1)) - logger.info(f"negative scale: {negative_scale}") - continue - - m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) - if m: # strength - strength = float(m.group(1)) - logger.info(f"strength: {strength}") - continue - - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - logger.info(f"negative prompt: {negative_prompt}") - continue - - m = re.match(r"c (.+)", parg, re.IGNORECASE) - if m: # clip prompt - clip_prompt = m.group(1) - logger.info(f"clip prompt: {clip_prompt}") - continue - - m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # network multiplies - network_muls = [float(v) for v in m.group(1).split(",")] - while len(network_muls) < len(networks): - network_muls.append(network_muls[-1]) - logger.info(f"network mul: {network_muls}") - continue - - # Deep Shrink - m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink depth 1 - ds_depth_1 = int(m.group(1)) - logger.info(f"deep shrink depth 1: {ds_depth_1}") - continue - - m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink timesteps 1 - ds_timesteps_1 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}") - continue - - m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink depth 2 - ds_depth_2 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink depth 2: {ds_depth_2}") - continue - - m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink timesteps 2 - ds_timesteps_2 = int(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}") - continue - - m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) - if m: # deep shrink ratio - ds_ratio = float(m.group(1)) - ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override - logger.info(f"deep shrink ratio: {ds_ratio}") - continue - - # Gradual Latent - m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent timesteps - gl_timesteps = int(m.group(1)) - print(f"gradual latent timesteps: {gl_timesteps}") - continue - - m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio - gl_ratio = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio: {ds_ratio}") - continue - - m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent every n steps - gl_every_n_steps = int(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent every n steps: {gl_every_n_steps}") - continue - - m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio 
step - gl_ratio_step = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio step: {gl_ratio_step}") - continue - - m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent s noise - gl_s_noise = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent s noise: {gl_s_noise}") - continue - - m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # gradual latent unsharp params - gl_unsharp_params = m.group(1) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent unsharp params: {gl_unsharp_params}") - continue - - # Gradual Latent - m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent timesteps - gl_timesteps = int(m.group(1)) - print(f"gradual latent timesteps: {gl_timesteps}") - continue - - m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio - gl_ratio = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio: {ds_ratio}") - continue - - m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent every n steps - gl_every_n_steps = int(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent every n steps: {gl_every_n_steps}") - continue - - m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent ratio step - gl_ratio_step = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent ratio step: {gl_ratio_step}") - continue - - m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE) - if m: # gradual latent s noise - gl_s_noise = float(m.group(1)) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent s noise: {gl_s_noise}") - continue - - m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # gradual latent unsharp params - gl_unsharp_params = m.group(1) - gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override - print(f"gradual latent unsharp params: {gl_unsharp_params}") - continue - - except ValueError as ex: - logger.error(f"Exception in parsing / 解析エラー: {parg}") - logger.error(f"{ex}") - - # override Deep Shrink - if ds_depth_1 is not None: - if ds_depth_1 < 0: - ds_depth_1 = args.ds_depth_1 or 3 - unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) - - # override Gradual Latent - if gl_timesteps is not None: - if gl_timesteps < 0: - gl_timesteps = args.gradual_latent_timesteps or 650 - if gl_unsharp_params is not None: - unsharp_params = gl_unsharp_params.split(",") - us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] - us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) - us_ksize = int(us_ksize) - else: - us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None - gradual_latent = GradualLatent( - gl_ratio, - gl_timesteps, - gl_every_n_steps, - gl_ratio_step, - gl_s_noise, - us_ksize, - us_sigma, - us_strength, - us_target_x, - ) - pipe.set_gradual_latent(gradual_latent) - - # prepare seed - if seeds is not None: # given in prompt - # 数が足りないなら前のをそのまま使う - if len(seeds) > 0: - seed = seeds.pop(0) - else: - if predefined_seeds is not None: - if 
len(predefined_seeds) > 0: - seed = predefined_seeds.pop(0) - else: - logger.error("predefined seeds are exhausted") - seed = None - elif args.iter_same_seed: - seeds = iter_seed - else: - seed = None # 前のを消す - - if seed is None: - seed = random.randint(0, 0x7FFFFFFF) - if args.interactive: - logger.info(f"seed: {seed}") - - # prepare init image, guide image and mask - init_image = mask_image = guide_image = None - - # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する - if init_images is not None: - init_image = init_images[global_step % len(init_images)] - - # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する - # 32単位に丸めたやつにresizeされるので踏襲する - if not highres_fix: - width, height = init_image.size - width = width - width % 32 - height = height - height % 32 - if width != init_image.size[0] or height != init_image.size[1]: - logger.warning( - f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" - ) - - if mask_images is not None: - mask_image = mask_images[global_step % len(mask_images)] - - if guide_images is not None: - if control_nets: # 複数件の場合あり - c = len(control_nets) - p = global_step % (len(guide_images) // c) - guide_image = guide_images[p * c : p * c + c] - else: - guide_image = guide_images[global_step % len(guide_images)] - - if regional_network: - num_sub_prompts = len(prompt.split(" AND ")) - assert ( - len(networks) <= num_sub_prompts - ), "Number of networks must be less than or equal to number of sub prompts." - else: - num_sub_prompts = None - - b1 = BatchData( - False, - BatchDataBase( - global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt - ), - BatchDataExt( - width, - height, - original_width, - original_height, - original_width_negative, - original_height_negative, - crop_left, - crop_top, - steps, - scale, - negative_scale, - strength, - tuple(network_muls) if network_muls else None, - num_sub_prompts, - ), - ) - if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? 
- process_batch(batch_data, highres_fix) - batch_data.clear() - - batch_data.append(b1) - if len(batch_data) == args.batch_size: - prev_image = process_batch(batch_data, highres_fix)[0] - batch_data.clear() - - global_step += 1 - - prompt_index += 1 - - if len(batch_data) > 0: - process_batch(batch_data, highres_fix) - batch_data.clear() - - logger.info("done!") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - add_logging_arguments(parser) - - parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") - parser.add_argument( - "--from_file", - type=str, - default=None, - help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む", - ) - parser.add_argument( - "--interactive", - action="store_true", - help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)", - ) - parser.add_argument( - "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" - ) - parser.add_argument( - "--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgまたはinpaintを行う元画像" - ) - parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク") - parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") - parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") - parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") - parser.add_argument( - "--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする" - ) - parser.add_argument( - "--use_original_file_name", - action="store_true", - help="prepend original file name in img2img / img2imgで元画像のファイル名を生成画像のファイル名の先頭に付ける", - ) - # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", ) - parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数") - parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") - parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") - parser.add_argument( - "--original_height", - type=int, - default=None, - help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値", - ) - parser.add_argument( - "--original_width", - type=int, - default=None, - help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値", - ) - parser.add_argument( - "--original_height_negative", - type=int, - default=None, - help="original height for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal heightの値", - ) - parser.add_argument( - "--original_width_negative", - type=int, - default=None, - help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値", - ) - parser.add_argument( - "--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値" - ) - parser.add_argument( - "--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値" - ) - parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") - parser.add_argument( - "--vae_batch_size", - type=float, - default=None, - help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率", - ) - parser.add_argument( - "--vae_slices", - type=int, - default=None, - help="number of slices 
to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨", - ) - parser.add_argument( - "--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない" - ) - parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") - parser.add_argument( - "--sampler", - type=str, - default="ddim", - choices=[ - "ddim", - "pndm", - "lms", - "euler", - "euler_a", - "heun", - "dpm_2", - "dpm_2_a", - "dpmsolver", - "dpmsolver++", - "dpmsingle", - "k_lms", - "k_euler", - "k_euler_a", - "k_dpm_2", - "k_dpm_2_a", - ], - help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類", - ) - parser.add_argument( - "--scale", - type=float, - default=7.5, - help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", - ) - parser.add_argument( - "--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ" - ) - parser.add_argument( - "--vae", - type=str, - default=None, - help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", - ) - parser.add_argument( - "--tokenizer_cache_dir", - type=str, - default=None, - help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", - ) - # parser.add_argument("--replace_clip_l14_336", action='store_true', - # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") - parser.add_argument( - "--seed", - type=int, - default=None, - help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed", - ) - parser.add_argument( - "--iter_same_seed", - action="store_true", - help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)", - ) - parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する") - parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") - parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する") - parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa") - parser.add_argument( - "--diffusers_xformers", - action="store_true", - help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", - ) - parser.add_argument( - "--opt_channels_last", - action="store_true", - help="set channels last option to model / モデルにchannels lastを指定し最適化する", - ) - parser.add_argument( - "--network_module", - type=str, - default=None, - nargs="*", - help="additional network module to use / 追加ネットワークを使う時そのモジュール名", - ) - parser.add_argument( - "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" - ) - parser.add_argument( - "--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率" - ) - parser.add_argument( - "--network_args", - type=str, - default=None, - nargs="*", - help="additional arguments for network (key=value) / ネットワークへの追加の引数", - ) - parser.add_argument( - "--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する" - ) - parser.add_argument( - "--network_merge_n_models", - type=int, - default=None, - help="merge this number of networks / この数だけネットワークをマージする", - ) - 
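For reference, the per-prompt overrides parsed in the generation loop earlier in this file are appended to the base prompt as `" --"`-separated switches (`w`/`h` for size, `s` for steps, `d` for comma-separated seeds, `l` for guidance scale, `t` for img2img strength, `n` for the negative prompt, `am` for per-network multipliers, plus the Deep Shrink and Gradual Latent overrides). A minimal, hypothetical example of one line in a `--from_file` prompt list, shown only to illustrate that syntax (the prompt text and values are made up):

example_prompt_line = (
    "a photo of a cat, best quality"
    " --n lowres, bad anatomy"  # negative prompt ("n")
    " --w 1024 --h 1024"        # width / height
    " --s 30"                   # sampling steps
    " --l 7.0"                  # guidance scale
    " --d 12345"                # seed(s), comma-separated
    " --am 0.8"                 # additional-network (LoRA) multiplier(s)
)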
parser.add_argument( - "--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする" - ) - parser.add_argument( - "--network_pre_calc", - action="store_true", - help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する", - ) - parser.add_argument( - "--network_regional_mask_max_color_codes", - type=int, - default=None, - help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)", - ) - parser.add_argument( - "--textual_inversion_embeddings", - type=str, - default=None, - nargs="*", - help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", - ) - parser.add_argument( - "--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う" - ) - parser.add_argument( - "--max_embeddings_multiples", - type=int, - default=None, - help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる", - ) - parser.add_argument( - "--guide_image_path", type=str, default=None, nargs="*", help="image to CLIP guidance / CLIP guided SDでガイドに使う画像" - ) - parser.add_argument( - "--highres_fix_scale", - type=float, - default=None, - help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", - ) - parser.add_argument( - "--highres_fix_steps", - type=int, - default=28, - help="1st stage steps for highres fix / highres fixの最初のステージのステップ数", - ) - parser.add_argument( - "--highres_fix_strength", - type=float, - default=None, - help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", - ) - parser.add_argument( - "--highres_fix_save_1st", - action="store_true", - help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する", - ) - parser.add_argument( - "--highres_fix_latents_upscaling", - action="store_true", - help="use latents upscaling for highres fix / highres fixでlatentで拡大する", - ) - parser.add_argument( - "--highres_fix_upscaler", - type=str, - default=None, - help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名", - ) - parser.add_argument( - "--highres_fix_upscaler_args", - type=str, - default=None, - help="additional arguments for upscaler (key=value) / upscalerへの追加の引数", - ) - parser.add_argument( - "--highres_fix_disable_control_net", - action="store_true", - help="disable ControlNet for highres fix / highres fixでControlNetを使わない", - ) - - parser.add_argument( - "--negative_scale", - type=float, - default=None, - help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する", - ) - - parser.add_argument( - "--control_net_lllite_models", - type=str, - default=None, - nargs="*", - help="ControlNet models to use / 使用するControlNetのモデル名", - ) - # parser.add_argument( - # "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" - # ) - # parser.add_argument( - # "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" - # ) - parser.add_argument( - "--control_net_multipliers", type=float, default=None, nargs="*", help="ControlNet multiplier / ControlNetの適用率" - ) - parser.add_argument( - "--control_net_ratios", - type=float, - default=None, - nargs="*", - help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", - ) - parser.add_argument( - "--clip_vision_strength", - type=float, - default=None, - help="enable CLIP Vision Conditioning for img2img with this 
strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する", - ) - - # Deep Shrink - parser.add_argument( - "--ds_depth_1", - type=int, - default=None, - help="Enable Deep Shrink with this depth 1, valid values are 0 to 8 / Deep Shrinkをこのdepthで有効にする", - ) - parser.add_argument( - "--ds_timesteps_1", - type=int, - default=650, - help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps", - ) - parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2") - parser.add_argument( - "--ds_timesteps_2", - type=int, - default=650, - help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps", - ) - parser.add_argument( - "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" - ) - - # gradual latent - parser.add_argument( - "--gradual_latent_timesteps", - type=int, - default=None, - help="enable Gradual Latent hires fix and apply upscaling from this timesteps / Gradual Latent hires fixをこのtimestepsで有効にし、このtimestepsからアップスケーリングを適用する", - ) - parser.add_argument( - "--gradual_latent_ratio", - type=float, - default=0.5, - help=" this size ratio, 0.5 means 1/2 / Gradual Latent hires fixをこのサイズ比率で有効にする、0.5は1/2を意味する", - ) - parser.add_argument( - "--gradual_latent_ratio_step", - type=float, - default=0.125, - help="step to increase ratio for Gradual Latent / Gradual Latentのratioをどのくらいずつ上げるか", - ) - parser.add_argument( - "--gradual_latent_every_n_steps", - type=int, - default=3, - help="steps to increase size of latents every this steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる", - ) - parser.add_argument( - "--gradual_latent_s_noise", - type=float, - default=1.0, - help="s_noise for Gradual Latent / Gradual Latentのs_noise", - ) - parser.add_argument( - "--gradual_latent_unsharp_params", - type=str, - default=None, - help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" - + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", - ) - - # # parser.add_argument( - # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" - # ) - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - setup_logging(args, reset=True) - main(args) diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py deleted file mode 100644 index 084735665..000000000 --- a/sdxl_minimal_inference.py +++ /dev/null @@ -1,329 +0,0 @@ -# 手元で推論を行うための最低限のコード。HuggingFace/DiffusersのCLIP、schedulerとVAEを使う -# Minimal code for performing inference at local. 
Use HuggingFace/Diffusers CLIP, scheduler and VAE - -import argparse -import datetime -import math -import os -import random -from einops import repeat -import numpy as np - -import torch -from library.device_utils import init_ipex, get_preferred_device -init_ipex() - -from tqdm import tqdm -from transformers import CLIPTokenizer -from diffusers import EulerDiscreteScheduler -from PIL import Image -import open_clip -from safetensors.torch import load_file - -from library import model_util, sdxl_model_util -import networks.lora as lora -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -# scheduler: このあたりの設定はSD1/2と同じでいいらしい -# scheduler: The settings around here seem to be the same as SD1/2 -SCHEDULER_LINEAR_START = 0.00085 -SCHEDULER_LINEAR_END = 0.0120 -SCHEDULER_TIMESTEPS = 1000 -SCHEDLER_SCHEDULE = "scaled_linear" - - -# Time EmbeddingはDiffusersからのコピー -# Time Embedding is copied from Diffusers - - -def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): - """ - Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - if not repeat_only: - half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( - device=timesteps.device - ) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - else: - embedding = repeat(timesteps, "b -> b d", d=dim) - return embedding - - -def get_timestep_embedding(x, outdim): - assert len(x.shape) == 2 - b, dims = x.shape[0], x.shape[1] - # x = rearrange(x, "b d -> (b d)") - x = torch.flatten(x) - emb = timestep_embedding(x, outdim) - # emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=outdim) - emb = torch.reshape(emb, (b, dims * outdim)) - return emb - - -if __name__ == "__main__": - # 画像生成条件を変更する場合はここを変更 / change here to change image generation conditions - - # SDXLの追加のvector embeddingへ渡す値 / Values to pass to additional vector embedding of SDXL - target_height = 1024 - target_width = 1024 - original_height = target_height - original_width = target_width - crop_top = 0 - crop_left = 0 - - steps = 50 - guidance_scale = 7 - seed = None # 1 - - DEVICE = get_preferred_device() - DTYPE = torch.float16 # bfloat16 may work - - parser = argparse.ArgumentParser() - parser.add_argument("--ckpt_path", type=str, required=True) - parser.add_argument("--prompt", type=str, default="A photo of a cat") - parser.add_argument("--prompt2", type=str, default=None) - parser.add_argument("--negative_prompt", type=str, default="") - parser.add_argument("--output_dir", type=str, default=".") - parser.add_argument( - "--lora_weights", - type=str, - nargs="*", - default=[], - help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)", - ) - parser.add_argument("--interactive", action="store_true") - args = parser.parse_args() - - if args.prompt2 is None: - args.prompt2 = args.prompt - - # HuggingFaceのmodel id - text_encoder_1_name = "openai/clip-vit-large-patch14" - text_encoder_2_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" - - # checkpointを読み込む。モデル変換についてはそちらの関数を参照 - 
# Load checkpoint. For model conversion, see this function - - # 本体RAMが少ない場合はGPUにロードするといいかも - # If the main RAM is small, it may be better to load it on the GPU - text_model1, text_model2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.ckpt_path, "cpu" - ) - - # Text Encoder 1はSDXL本体でもHuggingFaceのものを使っている - # In SDXL, Text Encoder 1 is also using HuggingFace's - - # Text Encoder 2はSDXL本体ではopen_clipを使っている - # それを使ってもいいが、SD2のDiffusers版に合わせる形で、HuggingFaceのものを使う - # 重みの変換コードはSD2とほぼ同じ - # In SDXL, Text Encoder 2 is using open_clip - # It's okay to use it, but to match the Diffusers version of SD2, use HuggingFace's - # The weight conversion code is almost the same as SD2 - - # VAEの構造はSDXLもSD1/2と同じだが、重みは異なるようだ。何より謎のscale値が違う - # fp16でNaNが出やすいようだ - # The structure of VAE is the same as SD1/2, but the weights seem to be different. Above all, the mysterious scale value is different. - # NaN seems to be more likely to occur in fp16 - - unet.to(DEVICE, dtype=DTYPE) - unet.eval() - - vae_dtype = DTYPE - if DTYPE == torch.float16: - logger.info("use float32 for vae") - vae_dtype = torch.float32 - vae.to(DEVICE, dtype=vae_dtype) - vae.eval() - - text_model1.to(DEVICE, dtype=DTYPE) - text_model1.eval() - text_model2.to(DEVICE, dtype=DTYPE) - text_model2.eval() - - unet.set_use_memory_efficient_attention(True, False) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える - vae.set_use_memory_efficient_attention_xformers(True) - - # Tokenizers - tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name) - tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77) - - # LoRA - for weights_file in args.lora_weights: - if ";" in weights_file: - weights_file, multiplier = weights_file.split(";") - multiplier = float(multiplier) - else: - multiplier = 1.0 - - lora_model, weights_sd = lora.create_network_from_weights( - multiplier, weights_file, vae, [text_model1, text_model2], unet, None, True - ) - lora_model.merge_to([text_model1, text_model2], unet, weights_sd, DTYPE, DEVICE) - - # scheduler - scheduler = EulerDiscreteScheduler( - num_train_timesteps=SCHEDULER_TIMESTEPS, - beta_start=SCHEDULER_LINEAR_START, - beta_end=SCHEDULER_LINEAR_END, - beta_schedule=SCHEDLER_SCHEDULE, - ) - - def generate_image(prompt, prompt2, negative_prompt, seed=None): - # 将来的にサイズ情報も変えられるようにする / Make it possible to change the size information in the future - # prepare embedding - with torch.no_grad(): - # vector - emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) - emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) - emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256) - # logger.info("emb1", emb1.shape) - c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE) - uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right - - # crossattn - - # Text Encoderを二つ呼ぶ関数 Function to call two Text Encoders - def call_text_encoder(text, text2): - # text encoder 1 - batch_encoding = tokenizer1( - text, - truncation=True, - return_length=True, - return_overflowing_tokens=False, - padding="max_length", - return_tensors="pt", - ) - tokens = batch_encoding["input_ids"].to(DEVICE) - - with torch.no_grad(): - enc_out = text_model1(tokens, output_hidden_states=True, return_dict=True) - text_embedding1 = enc_out["hidden_states"][11] - # text_embedding = 
pipe.text_encoder.text_model.final_layer_norm(text_embedding) # layer normは通さないらしい - - # text encoder 2 - with torch.no_grad(): - tokens = tokenizer2(text2).to(DEVICE) - - enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True) - text_embedding2_penu = enc_out["hidden_states"][-2] - # logger.info("hidden_states2", text_embedding2_penu.shape) - text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion - - # 連結して終了 concat and finish - text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2) - return text_embedding, text_embedding2_pool - - # cond - c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2) - # logger.info(c_ctx.shape, c_ctx_p.shape, c_vector.shape) - c_vector = torch.cat([c_ctx_pool, c_vector], dim=1) - - # uncond - uc_ctx, uc_ctx_pool = call_text_encoder(negative_prompt, negative_prompt) - uc_vector = torch.cat([uc_ctx_pool, uc_vector], dim=1) - - text_embeddings = torch.cat([uc_ctx, c_ctx]) - vector_embeddings = torch.cat([uc_vector, c_vector]) - - # メモリ使用量を減らすにはここでText Encoderを削除するかCPUへ移動する - - if seed is not None: - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - # # random generator for initial noise - # generator = torch.Generator(device="cuda").manual_seed(seed) - generator = None - else: - generator = None - - # get the initial random noise unless the user supplied it - # SDXLはCPUでlatentsを作成しているので一応合わせておく、Diffusersはtarget deviceでlatentsを作成している - # SDXL creates latents in CPU, Diffusers creates latents in target device - latents_shape = (1, 4, target_height // 8, target_width // 8) - latents = torch.randn( - latents_shape, - generator=generator, - device="cpu", - dtype=torch.float32, - ).to(DEVICE, dtype=DTYPE) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * scheduler.init_noise_sigma - - # set timesteps - scheduler.set_timesteps(steps, DEVICE) - - # このへんはDiffusersからのコピペ - # Copy from Diffusers - timesteps = scheduler.timesteps.to(DEVICE) # .to(DTYPE) - num_latent_input = 2 - with torch.no_grad(): - for i, t in enumerate(tqdm(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) - latent_model_input = scheduler.scale_model_input(latent_model_input, t) - - noise_pred = unet(latent_model_input, t, text_embeddings, vector_embeddings) - - noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - # latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - latents = scheduler.step(noise_pred, t, latents).prev_sample - - # latents = 1 / 0.18215 * latents - latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents - latents = latents.to(vae_dtype) - image = vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - - # image = self.numpy_to_pil(image) - image = (image * 255).round().astype("uint8") - image = [Image.fromarray(im) for im in image] - - # 保存して終了 save and finish - timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") - for i, img in enumerate(image): - img.save(os.path.join(args.output_dir, 
f"image_{timestamp}_{i:03d}.png")) - - if not args.interactive: - generate_image(args.prompt, args.prompt2, args.negative_prompt, seed) - else: - # loop for interactive - while True: - prompt = input("prompt: ") - if prompt == "": - break - prompt2 = input("prompt2: ") - if prompt2 == "": - prompt2 = prompt - negative_prompt = input("negative prompt: ") - seed = input("seed: ") - if seed == "": - seed = None - else: - seed = int(seed) - generate_image(prompt, prompt2, negative_prompt, seed) - - logger.info("Done!") diff --git a/sdxl_train.py b/sdxl_train.py deleted file mode 100644 index e0df263d6..000000000 --- a/sdxl_train.py +++ /dev/null @@ -1,792 +0,0 @@ -# training with captions - -import argparse -import math -import os -from multiprocessing import Value -from typing import List -import toml - -from tqdm import tqdm - -import torch -from library.device_utils import init_ipex, clean_memory_on_device -init_ipex() - -from accelerate.utils import set_seed -from diffusers import DDPMScheduler -from library import sdxl_model_util - -import library.train_util as train_util - -from library.utils import setup_logging, add_logging_arguments - -setup_logging() -import logging - -logger = logging.getLogger(__name__) - -import library.config_util as config_util -import library.sdxl_train_util as sdxl_train_util -from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, -) -import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import ( - apply_snr_weight, - prepare_scheduler_for_custom_training, - scale_v_prediction_loss_like_noise_prediction, - add_v_prediction_like_loss, - apply_debiased_estimation, -) -from library.sdxl_original_unet import SdxlUNet2DConditionModel - - -UNET_NUM_BLOCKS_FOR_BLOCK_LR = 23 - - -def get_block_params_to_optimize(unet: SdxlUNet2DConditionModel, block_lrs: List[float]) -> List[dict]: - block_params = [[] for _ in range(len(block_lrs))] - - for i, (name, param) in enumerate(unet.named_parameters()): - if name.startswith("time_embed.") or name.startswith("label_emb."): - block_index = 0 # 0 - elif name.startswith("input_blocks."): # 1-9 - block_index = 1 + int(name.split(".")[1]) - elif name.startswith("middle_block."): # 10-12 - block_index = 10 + int(name.split(".")[1]) - elif name.startswith("output_blocks."): # 13-21 - block_index = 13 + int(name.split(".")[1]) - elif name.startswith("out."): # 22 - block_index = 22 - else: - raise ValueError(f"unexpected parameter name: {name}") - - block_params[block_index].append(param) - - params_to_optimize = [] - for i, params in enumerate(block_params): - if block_lrs[i] == 0: # 0のときは学習しない do not optimize when lr is 0 - continue - params_to_optimize.append({"params": params, "lr": block_lrs[i]}) - - return params_to_optimize - - -def append_block_lr_to_logs(block_lrs, logs, lr_scheduler, optimizer_type): - names = [] - block_index = 0 - while block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR + 2: - if block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR: - if block_lrs[block_index] == 0: - block_index += 1 - continue - names.append(f"block{block_index}") - elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR: - names.append("text_encoder1") - elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR + 1: - names.append("text_encoder2") - - block_index += 1 - - train_util.append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names) - - -def train(args): - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) - 
sdxl_train_util.verify_sdxl_training_args(args) - setup_logging(args, reset=True) - - assert ( - not args.weighted_captions - ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" - assert ( - not args.train_text_encoder or not args.cache_text_encoder_outputs - ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" - - if args.block_lr: - block_lrs = [float(lr) for lr in args.block_lr.split(",")] - assert ( - len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR - ), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください" - else: - block_lrs = None - - cache_latents = args.cache_latents - use_dreambooth_method = args.in_json is None - - if args.seed is not None: - set_seed(args.seed) # 乱数系列を初期化する - - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - - # データセットを準備する - if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) - if args.dataset_config is not None: - logger.info(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) - ) - else: - if use_dreambooth_method: - logger.info("Using DreamBooth method.") - user_config = { - "datasets": [ - { - "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( - args.train_data_dir, args.reg_data_dir - ) - } - ] - } - else: - logger.info("Training with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - else: - train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer1, tokenizer2]) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) - - train_dataset_group.verify_bucket_reso_steps(32) - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group, True) - return - if len(train_dataset_group) == 0: - logger.error( - "No data found. Please verify the metadata file and train_data_dir option. 
/ 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" - ) - return - - if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - if args.cache_text_encoder_outputs: - assert ( - train_dataset_group.is_text_encoder_output_cacheable() - ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" - - # acceleratorを準備する - logger.info("prepare accelerator") - accelerator = train_util.prepare_accelerator(args) - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) - vae_dtype = torch.float32 if args.no_half_vae else weight_dtype - - # モデルを読み込む - ( - load_stable_diffusion_format, - text_encoder1, - text_encoder2, - vae, - unet, - logit_scale, - ckpt_info, - ) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) - # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) - - # verify load/save model formats - if load_stable_diffusion_format: - src_stable_diffusion_ckpt = args.pretrained_model_name_or_path - src_diffusers_model_path = None - else: - src_stable_diffusion_ckpt = None - src_diffusers_model_path = args.pretrained_model_name_or_path - - if args.save_model_as is None: - save_stable_diffusion_format = load_stable_diffusion_format - use_safetensors = args.use_safetensors - else: - save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors" - use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) - # assert save_stable_diffusion_format, "save_model_as must be ckpt or safetensors / save_model_asはckptかsafetensorsである必要があります" - - # Diffusers版のxformers使用フラグを設定する関数 - def set_diffusers_xformers_flag(model, valid): - def fn_recursive_set_mem_eff(module: torch.nn.Module): - if hasattr(module, "set_use_memory_efficient_attention_xformers"): - module.set_use_memory_efficient_attention_xformers(valid) - - for child in module.children(): - fn_recursive_set_mem_eff(child) - - fn_recursive_set_mem_eff(model) - - # モデルに xformers とか memory efficient attention を組み込む - if args.diffusers_xformers: - # もうU-Netを独自にしたので動かないけどVAEのxformersは動くはず - accelerator.print("Use xformers by Diffusers") - # set_diffusers_xformers_flag(unet, True) - set_diffusers_xformers_flag(vae, True) - else: - # Windows版のxformersはfloatで学習できなかったりするのでxformersを使わない設定も可能にしておく必要がある - accelerator.print("Disable Diffusers' xformers") - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える - vae.set_use_memory_efficient_attention_xformers(args.xformers) - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=vae_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") - clean_memory_on_device(accelerator.device) - - accelerator.wait_for_everyone() - - # 学習を準備する:モデルを適切な状態にする - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - train_unet = args.learning_rate > 0 - train_text_encoder1 = False - train_text_encoder2 = False - - if args.train_text_encoder: - # TODO 
each option for two text encoders? - accelerator.print("enable text encoder training") - if args.gradient_checkpointing: - text_encoder1.gradient_checkpointing_enable() - text_encoder2.gradient_checkpointing_enable() - lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train - lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train - train_text_encoder1 = lr_te1 > 0 - train_text_encoder2 = lr_te2 > 0 - - # caching one text encoder output is not supported - if not train_text_encoder1: - text_encoder1.to(weight_dtype) - if not train_text_encoder2: - text_encoder2.to(weight_dtype) - text_encoder1.requires_grad_(train_text_encoder1) - text_encoder2.requires_grad_(train_text_encoder2) - text_encoder1.train(train_text_encoder1) - text_encoder2.train(train_text_encoder2) - else: - text_encoder1.to(weight_dtype) - text_encoder2.to(weight_dtype) - text_encoder1.requires_grad_(False) - text_encoder2.requires_grad_(False) - text_encoder1.eval() - text_encoder2.eval() - - # TextEncoderの出力をキャッシュする - if args.cache_text_encoder_outputs: - # Text Encodes are eval and no grad - with torch.no_grad(), accelerator.autocast(): - train_dataset_group.cache_text_encoder_outputs( - (tokenizer1, tokenizer2), - (text_encoder1, text_encoder2), - accelerator.device, - None, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) - accelerator.wait_for_everyone() - - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=vae_dtype) - - unet.requires_grad_(train_unet) - if not train_unet: - unet.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared - - training_models = [] - params_to_optimize = [] - if train_unet: - training_models.append(unet) - if block_lrs is None: - params_to_optimize.append({"params": list(unet.parameters()), "lr": args.learning_rate}) - else: - params_to_optimize.extend(get_block_params_to_optimize(unet, block_lrs)) - - if train_text_encoder1: - training_models.append(text_encoder1) - params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) - if train_text_encoder2: - training_models.append(text_encoder2) - params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) - - # calculate number of trainable parameters - n_params = 0 - for params in params_to_optimize: - for p in params["params"]: - n_params += p.numel() - - accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}") - accelerator.print(f"number of models: {len(training_models)}") - accelerator.print(f"number of trainable parameters: {n_params}") - - # 学習に必要なクラスを準備する - accelerator.print("prepare optimizer, data loader etc.") - _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) - - # dataloaderを準備する - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader) / 
accelerator.num_processes / args.gradient_accumulation_steps - ) - accelerator.print( - f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" - ) - - # データセット側にも学習ステップを送信 - train_dataset_group.set_max_train_steps(args.max_train_steps) - - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - - # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする - if args.full_fp16: - assert ( - args.mixed_precision == "fp16" - ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - accelerator.print("enable full fp16 training.") - unet.to(weight_dtype) - text_encoder1.to(weight_dtype) - text_encoder2.to(weight_dtype) - elif args.full_bf16: - assert ( - args.mixed_precision == "bf16" - ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" - accelerator.print("enable full bf16 training.") - unet.to(weight_dtype) - text_encoder1.to(weight_dtype) - text_encoder2.to(weight_dtype) - - # acceleratorがなんかよろしくやってくれるらしい - if train_unet: - unet = accelerator.prepare(unet) - if train_text_encoder1: - # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer - text_encoder1.text_model.encoder.layers[-1].requires_grad_(False) - text_encoder1.text_model.final_layer_norm.requires_grad_(False) - text_encoder1 = accelerator.prepare(text_encoder1) - if train_text_encoder2: - text_encoder2 = accelerator.prepare(text_encoder2) - - optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) - - # TextEncoderの出力をキャッシュするときにはCPUへ移動する - if args.cache_text_encoder_outputs: - # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 - text_encoder1.to("cpu", dtype=torch.float32) - text_encoder2.to("cpu", dtype=torch.float32) - clean_memory_on_device(accelerator.device) - else: - # make sure Text Encoders are on GPU - text_encoder1.to(accelerator.device) - text_encoder2.to(accelerator.device) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - accelerator.print("running training / 学習開始") - accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") - accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print( - f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" - ) - # accelerator.print( - # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" - # ) - accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - 
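A worked example of the step/epoch arithmetic used above, with toy numbers. It assumes the default accelerate behaviour of sharding the dataloader across processes in prepare(), which is why the epoch override divides by num_processes while the later epoch count (computed after prepare()) does not:

import math

# Assumed toy numbers: 1000 batches in total, 2 GPUs, gradient accumulation 4, 10 epochs requested.
len_dataloader, num_processes, grad_accum, max_train_epochs = 1000, 2, 4, 10

# Before accelerator.prepare(): override max_train_steps from the requested epoch count.
max_train_steps = max_train_epochs * math.ceil(len_dataloader / num_processes / grad_accum)
assert max_train_steps == 1250  # 10 * ceil(1000 / 2 / 4)

# After accelerator.prepare() each process sees roughly its own shard of the batches.
len_per_process = len_dataloader // num_processes                      # 500
num_update_steps_per_epoch = math.ceil(len_per_process / grad_accum)   # 125
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
assert num_train_epochs == 10  # consistent with the requested epoch count
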
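The training loop that follows samples noise and timesteps, predicts the noise with the U-Net, and optionally re-weights the per-sample loss (min-SNR-gamma, debiased estimation, v-prediction scaling) before averaging. A rough sketch of the min-SNR-gamma re-weighting for epsilon prediction, assuming a Diffusers DDPMScheduler and random tensors in place of real latents and U-Net output; it mirrors the idea behind apply_snr_weight, not its exact implementation:

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)

latents = torch.randn(4, 4, 128, 128)  # dummy batch of SDXL-sized latents
noise = torch.randn_like(latents)
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (latents.shape[0],))
noisy_latents = scheduler.add_noise(latents, noise, timesteps)

noise_pred = torch.randn_like(noise)  # stand-in for unet(noisy_latents, timesteps, ...)

# per-sample MSE, then min-SNR-gamma weighting: w = min(SNR, gamma) / SNR
loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3])
alphas_cumprod = scheduler.alphas_cumprod[timesteps]
snr = alphas_cumprod / (1.0 - alphas_cumprod)
min_snr_gamma = 5.0  # illustrative value
loss = loss * torch.clamp(snr, max=min_snr_gamma) / snr
loss = loss.mean()  # finally mean over the batch dimension, as in the loop below
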
- progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 - - noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False - ) - prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) - if args.zero_terminal_snr: - custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) - - if accelerator.is_main_process: - init_kwargs = {} - if args.wandb_run_name: - init_kwargs["wandb"] = {"name": args.wandb_run_name} - if args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) - - # For --sample_at_first - sdxl_train_util.sample_images( - accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet - ) - - loss_recorder = train_util.LossRecorder() - for epoch in range(num_train_epochs): - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch + 1 - - for m in training_models: - m.train() - - for step, batch in enumerate(train_dataloader): - current_step.value = global_step - with accelerator.accumulate(*training_models): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) - else: - with torch.no_grad(): - # latentに変換 - latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype) - - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.nan_to_num(latents, 0, out=latents) - latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - - if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] - with torch.set_grad_enabled(args.train_text_encoder): - # Get the text embedding for conditioning - # TODO support weighted captions - # if args.weighted_captions: - # encoder_hidden_states = get_weighted_text_embeddings( - # tokenizer, - # text_encoder, - # batch["captions"], - # accelerator.device, - # args.max_token_length // 75 if args.max_token_length else 1, - # clip_skip=args.clip_skip, - # ) - # else: - input_ids1 = input_ids1.to(accelerator.device) - input_ids2 = input_ids2.to(accelerator.device) - # unwrap_model is fine for models not wrapped by accelerator - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizer1, - tokenizer2, - text_encoder1, - text_encoder2, - None if not args.full_fp16 else weight_dtype, - accelerator=accelerator, - ) - else: - encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) - encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) - pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) - - # # verify that the text encoder outputs are correct - # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl( - # args.max_token_length, - # batch["input_ids"].to(text_encoder1.device), - # batch["input_ids2"].to(text_encoder1.device), - # tokenizer1, - # tokenizer2, - # text_encoder1, - # text_encoder2, - # 
None if not args.full_fp16 else weight_dtype, - # ) - # b_size = encoder_hidden_states1.shape[0] - # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # logger.info("text encoder outputs verified") - - # get size embeddings - orig_size = batch["original_sizes_hw"] - crop_size = batch["crop_top_lefts"] - target_size = batch["target_sizes_hw"] - embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) - - # concat embeddings - vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) - text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) - - noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) - - target = noise - - if ( - args.min_snr_gamma - or args.scale_v_pred_loss_like_noise_pred - or args.v_pred_like_loss - or args.debiased_estimation_loss - ): - # do not mean over batch dimension for snr weight or scale v-pred loss - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - - loss = loss.mean() # mean over batch dimension - else: - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") - - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = [] - for m in training_models: - params_to_clip.extend(m.parameters()) - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - sdxl_train_util.sample_images( - accelerator, - args, - None, - global_step, - accelerator.device, - vae, - [tokenizer1, tokenizer2], - [text_encoder1, text_encoder2], - unet, - ) - - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - sdxl_train_util.save_sd_model_on_epoch_end_or_stepwise( - args, - False, - accelerator, - src_path, - save_stable_diffusion_format, - use_safetensors, - save_dtype, - 
epoch, - num_train_epochs, - global_step, - accelerator.unwrap_model(text_encoder1), - accelerator.unwrap_model(text_encoder2), - accelerator.unwrap_model(unet), - vae, - logit_scale, - ckpt_info, - ) - - current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: - logs = {"loss": current_loss} - if block_lrs is None: - train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet) - else: - append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type) # U-Net is included in block_lrs - - accelerator.log(logs, step=global_step) - - loss_recorder.add(epoch=epoch, step=step, loss=current_loss) - avr_loss: float = loss_recorder.moving_average - logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if global_step >= args.max_train_steps: - break - - if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) - - accelerator.wait_for_everyone() - - if args.save_every_n_epochs is not None: - if accelerator.is_main_process: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - sdxl_train_util.save_sd_model_on_epoch_end_or_stepwise( - args, - True, - accelerator, - src_path, - save_stable_diffusion_format, - use_safetensors, - save_dtype, - epoch, - num_train_epochs, - global_step, - accelerator.unwrap_model(text_encoder1), - accelerator.unwrap_model(text_encoder2), - accelerator.unwrap_model(unet), - vae, - logit_scale, - ckpt_info, - ) - - sdxl_train_util.sample_images( - accelerator, - args, - epoch + 1, - global_step, - accelerator.device, - vae, - [tokenizer1, tokenizer2], - [text_encoder1, text_encoder2], - unet, - ) - - is_main_process = accelerator.is_main_process - # if is_main_process: - unet = accelerator.unwrap_model(unet) - text_encoder1 = accelerator.unwrap_model(text_encoder1) - text_encoder2 = accelerator.unwrap_model(text_encoder2) - - accelerator.end_training() - - if args.save_state: # and is_main_process: - train_util.save_state_on_train_end(args, accelerator) - - del accelerator # この後メモリを使うのでこれは消す - - if is_main_process: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - sdxl_train_util.save_sd_model_on_train_end( - args, - src_path, - save_stable_diffusion_format, - use_safetensors, - save_dtype, - epoch, - global_step, - text_encoder1, - text_encoder2, - unet, - vae, - logit_scale, - ckpt_info, - ) - logger.info("model saved.") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - add_logging_arguments(parser) - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, True, True, True) - train_util.add_training_arguments(parser, False) - train_util.add_sd_saving_arguments(parser) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) - custom_train_functions.add_custom_train_arguments(parser) - sdxl_train_util.add_sdxl_training_arguments(parser) - - parser.add_argument( - "--learning_rate_te1", - type=float, - default=None, - help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率", - ) - parser.add_argument( - "--learning_rate_te2", - type=float, - default=None, - help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率", - ) - - parser.add_argument( - "--diffusers_xformers", action="store_true", help="use xformers by diffusers / 
Diffusersでxformersを使用する" - ) - parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") - parser.add_argument( - "--no_half_vae", - action="store_true", - help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", - ) - parser.add_argument( - "--block_lr", - type=str, - default=None, - help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " - + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", - ) - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - args = train_util.read_config_from_file(args, parser) - - train(args) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py deleted file mode 100644 index 1e5f92349..000000000 --- a/sdxl_train_control_net_lllite.py +++ /dev/null @@ -1,616 +0,0 @@ -# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用学習コード -# training code for ControlNet-LLLite with passing cond_image to U-Net's forward - -import argparse -import json -import math -import os -import random -import time -from multiprocessing import Value -from types import SimpleNamespace -import toml - -from tqdm import tqdm - -import torch -from library.device_utils import init_ipex, clean_memory_on_device -init_ipex() - -from torch.nn.parallel import DistributedDataParallel as DDP -from accelerate.utils import set_seed -import accelerate -from diffusers import DDPMScheduler, ControlNetModel -from safetensors.torch import load_file -from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util - -import library.model_util as model_util -import library.train_util as train_util -import library.config_util as config_util -from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, -) -import library.huggingface_util as huggingface_util -import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import ( - add_v_prediction_like_loss, - apply_snr_weight, - prepare_scheduler_for_custom_training, - pyramid_noise_like, - apply_noise_offset, - scale_v_prediction_loss_like_noise_prediction, - apply_debiased_estimation, -) -import networks.control_net_lllite_for_train as control_net_lllite_for_train -from library.utils import setup_logging, add_logging_arguments - -setup_logging() -import logging - -logger = logging.getLogger(__name__) - - -# TODO 他のスクリプトと共通化する -def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): - logs = { - "loss/current": current_loss, - "loss/average": avr_loss, - "lr": lr_scheduler.get_last_lr()[0], - } - - if args.optimizer_type.lower().startswith("DAdapt".lower()): - logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] - - return logs - - -def train(args): - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) - sdxl_train_util.verify_sdxl_training_args(args) - setup_logging(args, reset=True) - - cache_latents = args.cache_latents - use_user_config = args.dataset_config is not None - - if args.seed is None: - args.seed = random.randint(0, 2**32) - set_seed(args.seed) - - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - - # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) - if use_user_config: - logger.info(f"Load dataset config from 
{args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "conditioning_data_dir"] - if any(getattr(args, attr) is not None for attr in ignored): - logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) - ) - else: - user_config = { - "datasets": [ - { - "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( - args.train_data_dir, - args.conditioning_data_dir, - args.caption_extension, - ) - } - ] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) - - train_dataset_group.verify_bucket_reso_steps(32) - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group) - return - if len(train_dataset_group) == 0: - logger.error( - "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" - ) - return - - if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - else: - logger.warning( - "WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません" - ) - - if args.cache_text_encoder_outputs: - assert ( - train_dataset_group.is_text_encoder_output_cacheable() - ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" - - # acceleratorを準備する - logger.info("prepare accelerator") - accelerator = train_util.prepare_accelerator(args) - is_main_process = accelerator.is_main_process - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) - vae_dtype = torch.float32 if args.no_half_vae else weight_dtype - - # モデルを読み込む - ( - load_stable_diffusion_format, - text_encoder1, - text_encoder2, - vae, - unet, - logit_scale, - ckpt_info, - ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=vae_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents( - vae, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - ) - vae.to("cpu") - clean_memory_on_device(accelerator.device) - - accelerator.wait_for_everyone() - - # TextEncoderの出力をキャッシュする - if args.cache_text_encoder_outputs: - # Text Encodes are eval and no grad - with torch.no_grad(): - train_dataset_group.cache_text_encoder_outputs( - (tokenizer1, tokenizer2), - (text_encoder1, text_encoder2), - accelerator.device, - None, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) - accelerator.wait_for_everyone() - - # prepare ControlNet-LLLite - control_net_lllite_for_train.replace_unet_linear_and_conv2d() - - if 
args.network_weights is not None: - accelerator.print(f"initialize U-Net with ControlNet-LLLite") - with accelerate.init_empty_weights(): - unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite() - unet_lllite.to(accelerator.device, dtype=weight_dtype) - - unet_sd = unet.state_dict() - info = unet_lllite.load_lllite_weights(args.network_weights, unet_sd) - accelerator.print(f"load ControlNet-LLLite weights from {args.network_weights}: {info}") - else: - # cosumes large memory, so send to GPU before creating the LLLite model - accelerator.print("sending U-Net to GPU") - unet.to(accelerator.device, dtype=weight_dtype) - unet_sd = unet.state_dict() - - # init LLLite weights - accelerator.print(f"initialize U-Net with ControlNet-LLLite") - - if args.lowram: - with accelerate.init_on_device(accelerator.device): - unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite() - else: - unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite() - unet_lllite.to(weight_dtype) - - info = unet_lllite.load_lllite_weights(None, unet_sd) - accelerator.print(f"init U-Net with ControlNet-LLLite weights: {info}") - del unet_sd, unet - - unet: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite = unet_lllite - del unet_lllite - - unet.apply_lllite(args.cond_emb_dim, args.network_dim, args.network_dropout) - - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - - # 学習に必要なクラスを準備する - accelerator.print("prepare optimizer, data loader etc.") - - trainable_params = list(unet.prepare_params()) - logger.info(f"trainable params count: {len(trainable_params)}") - logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") - - _, _, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps - ) - accelerator.print( - f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" - ) - - # データセット側にも学習ステップを送信 - train_dataset_group.set_max_train_steps(args.max_train_steps) - - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - - # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする - # if args.full_fp16: - # assert ( - # args.mixed_precision == "fp16" - # ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - # accelerator.print("enable full fp16 training.") - # unet.to(weight_dtype) - # elif args.full_bf16: - # assert ( - # args.mixed_precision == "bf16" - # ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" - # accelerator.print("enable full bf16 training.") - # unet.to(weight_dtype) - - unet.to(weight_dtype) - - # acceleratorがなんかよろしくやってくれるらしい - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) - - if args.gradient_checkpointing: - unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる - else: - unet.eval() - - # TextEncoderの出力をキャッシュするときにはCPUへ移動する - if args.cache_text_encoder_outputs: - # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 - text_encoder1.to("cpu", dtype=torch.float32) - text_encoder2.to("cpu", dtype=torch.float32) - clean_memory_on_device(accelerator.device) - else: - # make sure Text Encoders are on GPU - text_encoder1.to(accelerator.device) - text_encoder2.to(accelerator.device) - - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=vae_dtype) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - # TODO: find a way to handle total batch size when there are multiple datasets - accelerator.print("running training / 学習開始") - accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print( - f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" - ) - # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 - - noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, 
clip_sample=False - ) - prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) - if args.zero_terminal_snr: - custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) - - if accelerator.is_main_process: - init_kwargs = {} - if args.wandb_run_name: - init_kwargs["wandb"] = {"name": args.wandb_run_name} - if args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs - ) - - loss_recorder = train_util.LossRecorder() - del train_dataset_group - - # function for saving/removing - def save_model( - ckpt_name, - unwrapped_nw: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite, - steps, - epoch_no, - force_sync_upload=False, - ): - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - accelerator.print(f"\nsaving checkpoint: {ckpt_file}") - sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False) - sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite" - - unwrapped_nw.save_lllite_weights(ckpt_file, save_dtype, sai_metadata) - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) - - def remove_model(old_ckpt_name): - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - accelerator.print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - # training loop - for epoch in range(num_train_epochs): - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch + 1 - - for step, batch in enumerate(train_dataloader): - current_step.value = global_step - with accelerator.accumulate(unet): - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() - - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.nan_to_num(latents, 0, out=latents) - latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - - if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] - with torch.no_grad(): - # Get the text embedding for conditioning - input_ids1 = input_ids1.to(accelerator.device) - input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizer1, - tokenizer2, - text_encoder1, - text_encoder2, - None if not args.full_fp16 else weight_dtype, - ) - else: - encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) - encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) - pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) - - # get size embeddings - orig_size = batch["original_sizes_hw"] - crop_size = batch["crop_top_lefts"] - target_size = batch["target_sizes_hw"] - embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, 
accelerator.device).to(weight_dtype) - - # concat embeddings - vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) - text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) - - noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype - - controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) - - with accelerator.autocast(): - # conditioning imageをControlNetに渡す / pass conditioning image to ControlNet - # 内部でcond_embに変換される / it will be converted to cond_emb inside - - # それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values - noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image) - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = unet.get_trainable_params() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - # sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) - - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) - save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch) - - if args.save_state: - train_util.save_and_remove_state_stepwise(args, accelerator, global_step) - - remove_step_no = train_util.get_remove_step_no(args, global_step) - if remove_step_no is not None: - remove_ckpt_name = train_util.get_step_ckpt_name(args, "." 
+ args.save_model_as, remove_step_no) - remove_model(remove_ckpt_name) - - current_loss = loss.detach().item() - loss_recorder.add(epoch=epoch, step=step, loss=current_loss) - avr_loss: float = loss_recorder.moving_average - logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if args.logging_dir is not None: - logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) - accelerator.log(logs, step=global_step) - - if global_step >= args.max_train_steps: - break - - if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) - - accelerator.wait_for_everyone() - - # 指定エポックごとにモデルを保存 - if args.save_every_n_epochs is not None: - saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs - if is_main_process and saving: - ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) - save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch + 1) - - remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) - if remove_epoch_no is not None: - remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) - remove_model(remove_ckpt_name) - - if args.save_state: - train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - - # self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) - - # end of epoch - - if is_main_process: - unet = accelerator.unwrap_model(unet) - - accelerator.end_training() - - if is_main_process and args.save_state: - train_util.save_state_on_train_end(args, accelerator) - - if is_main_process: - ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) - save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True) - - logger.info("model saved.") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - add_logging_arguments(parser) - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, False, True, True) - train_util.add_training_arguments(parser, False) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) - custom_train_functions.add_custom_train_arguments(parser) - sdxl_train_util.add_sdxl_training_arguments(parser) - - parser.add_argument( - "--save_model_as", - type=str, - default="safetensors", - choices=[None, "ckpt", "pt", "safetensors"], - help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", - ) - parser.add_argument( - "--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数" - ) - parser.add_argument( - "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" - ) - parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数") - parser.add_argument( - "--network_dropout", - type=float, - default=None, - help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)", - ) - parser.add_argument( - "--conditioning_data_dir", - type=str, - default=None, - help="conditioning data directory / 条件付けデータのディレクトリ", - ) - parser.add_argument( - "--no_half_vae", - action="store_true", - help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", - ) - return parser - - -if __name__ == "__main__": - # sdxl_original_unet.USE_REENTRANT = False - - parser = setup_parser() - - args = parser.parse_args() - args = train_util.read_config_from_file(args, parser) - - train(args) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py deleted file mode 100644 index dac56eedd..000000000 --- a/sdxl_train_control_net_lllite_old.py +++ /dev/null @@ -1,584 +0,0 @@ -import argparse -import json -import math -import os -import random -import time -from multiprocessing import Value -from types import SimpleNamespace -import toml - -from tqdm import tqdm - -import torch -from library.device_utils import init_ipex, clean_memory_on_device -init_ipex() - -from torch.nn.parallel import DistributedDataParallel as DDP -from accelerate.utils import set_seed -from diffusers import DDPMScheduler, ControlNetModel -from safetensors.torch import load_file -from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util - -import library.model_util as model_util -import library.train_util as train_util -import library.config_util as config_util -from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, -) -import library.huggingface_util as huggingface_util -import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import ( - add_v_prediction_like_loss, - apply_snr_weight, - prepare_scheduler_for_custom_training, - pyramid_noise_like, - apply_noise_offset, - scale_v_prediction_loss_like_noise_prediction, - apply_debiased_estimation, -) -import networks.control_net_lllite as control_net_lllite -from library.utils import setup_logging, add_logging_arguments - -setup_logging() -import 
logging - -logger = logging.getLogger(__name__) - - -# TODO 他のスクリプトと共通化する -def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): - logs = { - "loss/current": current_loss, - "loss/average": avr_loss, - "lr": lr_scheduler.get_last_lr()[0], - } - - if args.optimizer_type.lower().startswith("DAdapt".lower()): - logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] - - return logs - - -def train(args): - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) - sdxl_train_util.verify_sdxl_training_args(args) - setup_logging(args, reset=True) - - cache_latents = args.cache_latents - use_user_config = args.dataset_config is not None - - if args.seed is None: - args.seed = random.randint(0, 2**32) - set_seed(args.seed) - - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - - # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) - if use_user_config: - logger.info(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "conditioning_data_dir"] - if any(getattr(args, attr) is not None for attr in ignored): - logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) - ) - else: - user_config = { - "datasets": [ - { - "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( - args.train_data_dir, - args.conditioning_data_dir, - args.caption_extension, - ) - } - ] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) - - train_dataset_group.verify_bucket_reso_steps(32) - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group) - return - if len(train_dataset_group) == 0: - logger.error( - "No data found. 
Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" - ) - return - - if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - else: - logger.warning( - "WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません" - ) - - if args.cache_text_encoder_outputs: - assert ( - train_dataset_group.is_text_encoder_output_cacheable() - ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" - - # acceleratorを準備する - logger.info("prepare accelerator") - accelerator = train_util.prepare_accelerator(args) - is_main_process = accelerator.is_main_process - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) - vae_dtype = torch.float32 if args.no_half_vae else weight_dtype - - # モデルを読み込む - ( - load_stable_diffusion_format, - text_encoder1, - text_encoder2, - vae, - unet, - logit_scale, - ckpt_info, - ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) - - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=vae_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents( - vae, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - ) - vae.to("cpu") - clean_memory_on_device(accelerator.device) - - accelerator.wait_for_everyone() - - # TextEncoderの出力をキャッシュする - if args.cache_text_encoder_outputs: - # Text Encodes are eval and no grad - with torch.no_grad(): - train_dataset_group.cache_text_encoder_outputs( - (tokenizer1, tokenizer2), - (text_encoder1, text_encoder2), - accelerator.device, - None, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) - accelerator.wait_for_everyone() - - # prepare ControlNet - network = control_net_lllite.ControlNetLLLite(unet, args.cond_emb_dim, args.network_dim, args.network_dropout) - network.apply_to() - - if args.network_weights is not None: - info = network.load_weights(args.network_weights) - accelerator.print(f"load ControlNet weights from {args.network_weights}: {info}") - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - network.enable_gradient_checkpointing() # may have no effect - - # 学習に必要なクラスを準備する - accelerator.print("prepare optimizer, data loader etc.") - - trainable_params = list(network.prepare_optimizer_params()) - logger.info(f"trainable params count: {len(trainable_params)}") - logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") - - _, _, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - 
num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps - ) - accelerator.print( - f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" - ) - - # データセット側にも学習ステップを送信 - train_dataset_group.set_max_train_steps(args.max_train_steps) - - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - - # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする - if args.full_fp16: - assert ( - args.mixed_precision == "fp16" - ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - accelerator.print("enable full fp16 training.") - unet.to(weight_dtype) - network.to(weight_dtype) - elif args.full_bf16: - assert ( - args.mixed_precision == "bf16" - ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" - accelerator.print("enable full bf16 training.") - unet.to(weight_dtype) - network.to(weight_dtype) - - # acceleratorがなんかよろしくやってくれるらしい - unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, network, optimizer, train_dataloader, lr_scheduler - ) - network: control_net_lllite.ControlNetLLLite - - if args.gradient_checkpointing: - unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる - else: - unet.eval() - - network.prepare_grad_etc() - - # TextEncoderの出力をキャッシュするときにはCPUへ移動する - if args.cache_text_encoder_outputs: - # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 - text_encoder1.to("cpu", dtype=torch.float32) - text_encoder2.to("cpu", dtype=torch.float32) - clean_memory_on_device(accelerator.device) - else: - # make sure Text Encoders are on GPU - text_encoder1.to(accelerator.device) - text_encoder2.to(accelerator.device) - - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=vae_dtype) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - # TODO: find a way to handle total batch size when there are multiple datasets - accelerator.print("running training / 学習開始") - accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print( - f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" - ) - # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - accelerator.print(f" gradient 
accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 - - noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False - ) - prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) - if args.zero_terminal_snr: - custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) - - if accelerator.is_main_process: - init_kwargs = {} - if args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs - ) - - loss_recorder = train_util.LossRecorder() - del train_dataset_group - - # function for saving/removing - def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - accelerator.print(f"\nsaving checkpoint: {ckpt_file}") - sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False) - sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite" - - unwrapped_nw.save_weights(ckpt_file, save_dtype, sai_metadata) - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) - - def remove_model(old_ckpt_name): - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - accelerator.print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - # training loop - for epoch in range(num_train_epochs): - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch + 1 - - network.on_epoch_start() # train() - - for step, batch in enumerate(train_dataloader): - current_step.value = global_step - with accelerator.accumulate(network): - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() - - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.nan_to_num(latents, 0, out=latents) - latents = latents * sdxl_model_util.VAE_SCALE_FACTOR - - if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] - with torch.no_grad(): - # Get the text embedding for conditioning - input_ids1 = input_ids1.to(accelerator.device) - input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizer1, - tokenizer2, - text_encoder1, - text_encoder2, - None if not args.full_fp16 else weight_dtype, - ) - else: - encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) - encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) - pool2 = 
batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) - - # get size embeddings - orig_size = batch["original_sizes_hw"] - crop_size = batch["crop_top_lefts"] - target_size = batch["target_sizes_hw"] - embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) - - # concat embeddings - vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) - text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) - - noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype - - controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) - - with accelerator.autocast(): - # conditioning imageをControlNetに渡す / pass conditioning image to ControlNet - # 内部でcond_embに変換される / it will be converted to cond_emb inside - network.set_cond_image(controlnet_image) - - # それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values - noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = network.get_trainable_params() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - # sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) - - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) - save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch) - - if args.save_state: - train_util.save_and_remove_state_stepwise(args, accelerator, global_step) - - remove_step_no = train_util.get_remove_step_no(args, global_step) - if remove_step_no is not None: - remove_ckpt_name = train_util.get_step_ckpt_name(args, "." 
+ args.save_model_as, remove_step_no) - remove_model(remove_ckpt_name) - - current_loss = loss.detach().item() - loss_recorder.add(epoch=epoch, step=step, loss=current_loss) - avr_loss: float = loss_recorder.moving_average - logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if args.logging_dir is not None: - logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) - accelerator.log(logs, step=global_step) - - if global_step >= args.max_train_steps: - break - - if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) - - accelerator.wait_for_everyone() - - # 指定エポックごとにモデルを保存 - if args.save_every_n_epochs is not None: - saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs - if is_main_process and saving: - ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) - save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1) - - remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) - if remove_epoch_no is not None: - remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) - remove_model(remove_ckpt_name) - - if args.save_state: - train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - - # self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) - - # end of epoch - - if is_main_process: - network = accelerator.unwrap_model(network) - - accelerator.end_training() - - if is_main_process and args.save_state: - train_util.save_state_on_train_end(args, accelerator) - - if is_main_process: - ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) - save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) - - logger.info("model saved.") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - add_logging_arguments(parser) - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, False, True, True) - train_util.add_training_arguments(parser, False) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) - custom_train_functions.add_custom_train_arguments(parser) - sdxl_train_util.add_sdxl_training_arguments(parser) - - parser.add_argument( - "--save_model_as", - type=str, - default="safetensors", - choices=[None, "ckpt", "pt", "safetensors"], - help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", - ) - parser.add_argument( - "--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数" - ) - parser.add_argument( - "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" - ) - parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数") - parser.add_argument( - "--network_dropout", - type=float, - default=None, - help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)", - ) - parser.add_argument( - "--conditioning_data_dir", - type=str, - default=None, - help="conditioning data directory / 条件付けデータのディレクトリ", - ) - parser.add_argument( - "--no_half_vae", - action="store_true", - help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", - ) - return parser - - -if __name__ == "__main__": - # sdxl_original_unet.USE_REENTRANT = False - - parser = setup_parser() - - args = parser.parse_args() - args = train_util.read_config_from_file(args, parser) - - train(args) diff --git a/sdxl_train_network.py b/sdxl_train_network.py deleted file mode 100644 index d33239d92..000000000 --- a/sdxl_train_network.py +++ /dev/null @@ -1,184 +0,0 @@ -import argparse - -import torch -from library.device_utils import init_ipex, clean_memory_on_device -init_ipex() - -from library import sdxl_model_util, sdxl_train_util, train_util -import train_network -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -class SdxlNetworkTrainer(train_network.NetworkTrainer): - def __init__(self): - super().__init__() - self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR - self.is_sdxl = True - - def assert_extra_args(self, args, train_dataset_group): - super().assert_extra_args(args, train_dataset_group) - sdxl_train_util.verify_sdxl_training_args(args) - - if args.cache_text_encoder_outputs: - assert ( - train_dataset_group.is_text_encoder_output_cacheable() - ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" - - assert ( - args.network_train_unet_only or not args.cache_text_encoder_outputs - ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" - - train_dataset_group.verify_bucket_reso_steps(32) - - def load_target_model(self, 
args, weight_dtype, accelerator): - ( - load_stable_diffusion_format, - text_encoder1, - text_encoder2, - vae, - unet, - logit_scale, - ckpt_info, - ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) - - self.load_stable_diffusion_format = load_stable_diffusion_format - self.logit_scale = logit_scale - self.ckpt_info = ckpt_info - - return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet - - def load_tokenizer(self, args): - tokenizer = sdxl_train_util.load_tokenizers(args) - return tokenizer - - def is_text_encoder_outputs_cached(self, args): - return args.cache_text_encoder_outputs - - def cache_text_encoder_outputs_if_needed( - self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype - ): - if args.cache_text_encoder_outputs: - if not args.lowram: - # メモリ消費を減らす - logger.info("move vae and unet to cpu to save memory") - org_vae_device = vae.device - org_unet_device = unet.device - vae.to("cpu") - unet.to("cpu") - clean_memory_on_device(accelerator.device) - - # When TE is not be trained, it will not be prepared so we need to use explicit autocast - with accelerator.autocast(): - dataset.cache_text_encoder_outputs( - tokenizers, - text_encoders, - accelerator.device, - weight_dtype, - args.cache_text_encoder_outputs_to_disk, - accelerator.is_main_process, - ) - - text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU - text_encoders[1].to("cpu", dtype=torch.float32) - clean_memory_on_device(accelerator.device) - - if not args.lowram: - logger.info("move vae and unet back to original device") - vae.to(org_vae_device) - unet.to(org_unet_device) - else: - # Text Encoderから毎回出力を取得するので、GPUに乗せておく - text_encoders[0].to(accelerator.device, dtype=weight_dtype) - text_encoders[1].to(accelerator.device, dtype=weight_dtype) - - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): - if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] - with torch.enable_grad(): - # Get the text embedding for conditioning - # TODO support weighted captions - # if args.weighted_captions: - # encoder_hidden_states = get_weighted_text_embeddings( - # tokenizer, - # text_encoder, - # batch["captions"], - # accelerator.device, - # args.max_token_length // 75 if args.max_token_length else 1, - # clip_skip=args.clip_skip, - # ) - # else: - input_ids1 = input_ids1.to(accelerator.device) - input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizers[0], - tokenizers[1], - text_encoders[0], - text_encoders[1], - None if not args.full_fp16 else weight_dtype, - accelerator=accelerator, - ) - else: - encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) - encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) - pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) - - # # verify that the text encoder outputs are correct - # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl( - # args.max_token_length, - # batch["input_ids"].to(text_encoders[0].device), - # batch["input_ids2"].to(text_encoders[0].device), - # tokenizers[0], - # 
tokenizers[1], - # text_encoders[0], - # text_encoders[1], - # None if not args.full_fp16 else weight_dtype, - # ) - # b_size = encoder_hidden_states1.shape[0] - # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 - # logger.info("text encoder outputs verified") - - return encoder_hidden_states1, encoder_hidden_states2, pool2 - - def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): - noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype - - # get size embeddings - orig_size = batch["original_sizes_hw"] - crop_size = batch["crop_top_lefts"] - target_size = batch["target_sizes_hw"] - embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) - - # concat embeddings - encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds - vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) - text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) - - noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) - return noise_pred - - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): - sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) - - -def setup_parser() -> argparse.ArgumentParser: - parser = train_network.setup_parser() - sdxl_train_util.add_sdxl_training_arguments(parser) - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - args = train_util.read_config_from_file(args, parser) - - trainer = SdxlNetworkTrainer() - trainer.train(args) diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py deleted file mode 100644 index b9a948bb2..000000000 --- a/sdxl_train_textual_inversion.py +++ /dev/null @@ -1,138 +0,0 @@ -import argparse -import os - -import regex - -import torch -from library.device_utils import init_ipex -init_ipex() - -import open_clip -from library import sdxl_model_util, sdxl_train_util, train_util - -import train_textual_inversion - - -class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer): - def __init__(self): - super().__init__() - self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR - self.is_sdxl = True - - def assert_extra_args(self, args, train_dataset_group): - super().assert_extra_args(args, train_dataset_group) - sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False) - - train_dataset_group.verify_bucket_reso_steps(32) - - def load_target_model(self, args, weight_dtype, accelerator): - ( - load_stable_diffusion_format, - text_encoder1, - text_encoder2, - vae, - unet, - logit_scale, - ckpt_info, - ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) - - self.load_stable_diffusion_format = load_stable_diffusion_format - self.logit_scale = logit_scale - self.ckpt_info = ckpt_info - - return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet - - def load_tokenizer(self, args): - tokenizer = sdxl_train_util.load_tokenizers(args) 
- return tokenizer - - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): - input_ids1 = batch["input_ids"] - input_ids2 = batch["input_ids2"] - with torch.enable_grad(): - input_ids1 = input_ids1.to(accelerator.device) - input_ids2 = input_ids2.to(accelerator.device) - encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( - args.max_token_length, - input_ids1, - input_ids2, - tokenizers[0], - tokenizers[1], - text_encoders[0], - text_encoders[1], - None if not args.full_fp16 else weight_dtype, - accelerator=accelerator, - ) - return encoder_hidden_states1, encoder_hidden_states2, pool2 - - def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): - noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype - - # get size embeddings - orig_size = batch["original_sizes_hw"] - crop_size = batch["crop_top_lefts"] - target_size = batch["target_sizes_hw"] - embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) - - # concat embeddings - encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds - vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) - text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) - - noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) - return noise_pred - - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): - sdxl_train_util.sample_images( - accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement - ) - - def save_weights(self, file, updated_embs, save_dtype, metadata): - state_dict = {"clip_l": updated_embs[0], "clip_g": updated_embs[1]} - - if save_dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v - - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import save_file - - save_file(state_dict, file, metadata) - else: - torch.save(state_dict, file) - - def load_weights(self, file): - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import load_file - - data = load_file(file) - else: - data = torch.load(file, map_location="cpu") - - emb_l = data.get("clip_l", None) # ViT-L text encoder 1 - emb_g = data.get("clip_g", None) # BiG-G text encoder 2 - - assert ( - emb_l is not None or emb_g is not None - ), f"weight file does not contains weights for text encoder 1 or 2 / 重みファイルにテキストエンコーダー1または2の重みが含まれていません: {file}" - - return [emb_l, emb_g] - - -def setup_parser() -> argparse.ArgumentParser: - parser = train_textual_inversion.setup_parser() - # don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching - # sdxl_train_util.add_sdxl_training_arguments(parser) - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - args = train_util.read_config_from_file(args, parser) - - trainer = SdxlTextualInversionTrainer() - trainer.train(args) diff --git a/setup/setup_common.py b/setup/setup_common.py index a5dcca27e..edddec4ce 100644 --- a/setup/setup_common.py +++ b/setup/setup_common.py @@ -13,6 +13,100 @@ errors = 0 # Define the 'errors' variable before using it log = logging.getLogger('sd') +def 
update_submodule(): + """ + Ensure the submodule is initialized and updated. + """ + try: + # Initialize and update the submodule + subprocess.run(["git", "submodule", "update", "--init", "--recursive", "--quiet"], check=True) + log.info("Submodule initialized and updated.") + + except subprocess.CalledProcessError as e: + log.error(f"Error during Git operation: {e}") + except FileNotFoundError as e: + log.error(e) + +def read_tag_version_from_file(file_path): + """ + Read the tag version from a given file. + + Parameters: + - file_path: The path to the file containing the tag version. + + Returns: + The tag version as a string. + """ + with open(file_path, 'r') as file: + # Read the first line and strip whitespace + tag_version = file.readline().strip() + return tag_version + +def clone_or_checkout(repo_url, branch_or_tag, directory_name): + """ + Clone a repo or checkout a specific branch or tag if the repo already exists. + For branches, it updates to the latest version before checking out. + Suppresses detached HEAD advice for tags or specific commits. + Restores the original working directory after operations. + + Parameters: + - repo_url: The URL of the Git repository. + - branch_or_tag: The name of the branch or tag to clone or checkout. + - directory_name: The name of the directory to clone into or where the repo already exists. + """ + original_dir = os.getcwd() # Store the original directory + try: + if not os.path.exists(directory_name): + # Directory does not exist, clone the repo quietly + + # Construct the command as a string for logging + # run_cmd = f"git clone --branch {branch_or_tag} --single-branch --quiet {repo_url} {directory_name}" + run_cmd = ["git", "clone", "--branch", branch_or_tag, "--single-branch", "--quiet", repo_url, directory_name] + + + # Log the command + log.debug(run_cmd) + + # Run the command + process = subprocess.Popen( + run_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True + ) + output, error = process.communicate() + + if error and not error.startswith("Note: switching to"): + log.warning(error) + else: + log.info(f"Successfully cloned sd-scripts {branch_or_tag}") + + else: + os.chdir(directory_name) + subprocess.run(["git", "fetch", "--all", "--quiet"], check=True) + subprocess.run(["git", "config", "advice.detachedHead", "false"], check=True) + + # Get the current branch or commit hash + current_branch_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode() + tag_branch_hash = subprocess.check_output(["git", "rev-parse", branch_or_tag]).strip().decode() + + if current_branch_hash != tag_branch_hash: + run_cmd = f"git checkout {branch_or_tag} --quiet" + # Log the command + log.debug(run_cmd) + + # Execute the checkout command + process = subprocess.Popen(run_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + output, error = process.communicate() + + if error: + log.warning(error.decode()) + else: + log.info(f"Checked out sd-scripts {branch_or_tag} successfully.") + else: + log.info(f"Current branch of sd-scripts is already at the required release {branch_or_tag}.") + except subprocess.CalledProcessError as e: + log.error(f"Error during Git operation: {e}") + finally: + os.chdir(original_dir) # Restore the original directory + # setup console and file logging def setup_logging(clean=False): # diff --git a/setup/setup_linux.py b/setup/setup_linux.py index a69e44222..76ea9dacf 100644 --- a/setup/setup_linux.py +++ b/setup/setup_linux.py @@ -27,6 +27,12 @@ def 
main_menu(platform_requirements_file, show_stdout: bool = False, no_run_acce if __name__ == '__main__': setup_common.ensure_base_requirements() setup_common.setup_logging() + + setup_common.update_submodule() + + # setup_common.clone_or_checkout( + # "https://github.com/kohya-ss/sd-scripts.git", tag_version, "sd-scripts" + # ) parser = argparse.ArgumentParser() parser.add_argument('--platform-requirements-file', dest='platform_requirements_file', default='requirements_linux.txt', help='Path to the platform-specific requirements file') diff --git a/setup/setup_windows.py b/setup/setup_windows.py index 4dc1c0329..b250f2815 100644 --- a/setup/setup_windows.py +++ b/setup/setup_windows.py @@ -15,26 +15,28 @@ def cudnn_install(): - cudnn_src = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "..\cudnn_windows" - ) + # Original path with "..\\venv" + original_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..\\venv\\Lib\\site-packages\\nvidia\\cudnn\\bin") + # Normalize the path to resolve "..\\venv" + cudnn_src = os.path.abspath(original_path) cudnn_dest = os.path.join(sysconfig.get_paths()["purelib"], "torch", "lib") - log.info(f"Checking for CUDNN files in {cudnn_dest}...") + log.info(f"Copying CUDNN files from {cudnn_src} to {cudnn_dest}...") if os.path.exists(cudnn_src): if os.path.exists(cudnn_dest): # check for different files filecmp.clear_cache() for file in os.listdir(cudnn_src): - src_file = os.path.join(cudnn_src, file) - dest_file = os.path.join(cudnn_dest, file) - # if dest file exists, check if it's different - if os.path.exists(dest_file): - if not filecmp.cmp(src_file, dest_file, shallow=False): + if file.lower().endswith('.dll'): # Check if the file is a .dll file + src_file = os.path.join(cudnn_src, file) + dest_file = os.path.join(cudnn_dest, file) + # if dest file exists, check if it's different + if os.path.exists(dest_file): + if not filecmp.cmp(src_file, dest_file, shallow=False): + shutil.copy2(src_file, cudnn_dest) + else: shutil.copy2(src_file, cudnn_dest) - else: - shutil.copy2(src_file, cudnn_dest) - log.info("Copied CUDNN 8.6 files to destination") + log.info("Copied CUDNN .dll files to destination") else: log.warning(f"Destination directory {cudnn_dest} does not exist") else: @@ -100,6 +102,8 @@ def install_kohya_ss_torch2(): setup_common.check_repo_version() setup_common.check_python() + setup_common.update_submodule() + # Upgrade pip if needed setup_common.install("--upgrade pip") @@ -107,7 +111,8 @@ def install_kohya_ss_torch2(): "requirements_windows_torch2.txt", check_no_verify_flag=False ) - sync_bits_and_bytes_files() + # sync_bits_and_bytes_files() + setup_common.configure_accelerate(run_accelerate=True) # run_cmd(f'accelerate config') @@ -138,13 +143,20 @@ def install_bitsandbytes_0_41_1(): reinstall=True, ) +def install_bitsandbytes_0_41_2(): + log.info("Installing bitsandbytes 0.41.1...") + setup_common.install( + "--upgrade https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl", + "bitsandbytes 0.41.2", + reinstall=True, + ) def main_menu(): setup_common.clear_screen() while True: print("\nKohya_ss GUI setup menu:\n") print("1. Install kohya_ss gui") - print("2. (Optional) Install cudnn files (avoid unless you really need it)") + print("2. (Optional) Install cudnn files (if you want to use latest supported cudnn version)") print("3. (Optional) Install specific bitsandbytes versions") print("4. (Optional) Manually configure accelerate") print("5. 
(Optional) Start Kohya_ss GUI in browser") @@ -167,9 +179,12 @@ def main_menu(): "3. (Optional) Force installation of bitsandbytes 0.41.1 for new optimizer options support" ) print( - "4. (Danger) Install bitsandbytes-windows (this package has been reported to cause issues for most... avoid...)" + "4. (Recommanded) Force installation of bitsandbytes 0.41.2 for new optimizer options support" + ) + print( + "5. (Danger) Install bitsandbytes-windows (this package has been reported to cause issues for most... avoid...)" ) - print("5. Cancel") + print("6. Exit") choice_torch = input("\nEnter your choice: ") print("") @@ -182,12 +197,15 @@ def main_menu(): elif choice_torch == "3": install_bitsandbytes_0_41_1() break - elif choice_torch == "4": + elif choice_torch == "3": + install_bitsandbytes_0_41_2() + break + elif choice_torch == "5": setup_common.install( "--upgrade bitsandbytes-windows", reinstall=True ) break - elif choice_torch == "5": + elif choice_torch == "6": break else: print("Invalid choice. Please enter a number between 1-3.") diff --git a/setup/validate_requirements.py b/setup/validate_requirements.py index e24b84d78..df88689f2 100644 --- a/setup/validate_requirements.py +++ b/setup/validate_requirements.py @@ -107,6 +107,8 @@ def main(): torch_ver = check_torch() + setup_common.update_submodule() + if args.requirements: setup_common.install_requirements(args.requirements, check_no_verify_flag=True) else: diff --git a/style.css b/style.css index 78d34a947..456433e9d 100644 --- a/style.css +++ b/style.css @@ -1,5 +1,4 @@ #open_folder_small{ - height: auto; min-width: auto; flex-grow: 0; padding-left: 0.25em; diff --git a/test/config/dreambooth-Prodigy-SDXL.json b/test/config/dreambooth-Prodigy-SDXL.json new file mode 100644 index 000000000..239122b15 --- /dev/null +++ b/test/config/dreambooth-Prodigy-SDXL.json @@ -0,0 +1,91 @@ +{ + "adaptive_noise_scale": 0, + "additional_parameters": "", + "bucket_no_upscale": true, + "bucket_reso_steps": 32, + "cache_latents": true, + "cache_latents_to_disk": false, + "caption_dropout_every_n_epochs": 0.0, + "caption_dropout_rate": 0, + "caption_extension": "", + "clip_skip": 2, + "color_aug": false, + "enable_bucket": true, + "epoch": 1, + "flip_aug": false, + "full_bf16": false, + "full_fp16": false, + "gpu_ids": "", + "gradient_accumulation_steps": 1, + "gradient_checkpointing": false, + "keep_tokens": "0", + "learning_rate": 1.0, + "learning_rate_te": 1e-05, + "learning_rate_te1": 1e-05, + "learning_rate_te2": 0.0, + "logging_dir": "./test/logs", + "lr_scheduler": "cosine", + "lr_scheduler_args": "", + "lr_scheduler_num_cycles": "", + "lr_scheduler_power": "", + "lr_warmup": 0, + "max_bucket_reso": 2048, + "max_data_loader_n_workers": "0", + "max_resolution": "512,512", + "max_timestep": 1000, + "max_token_length": "75", + "max_train_epochs": "", + "max_train_steps": "", + "mem_eff_attn": false, + "min_bucket_reso": 256, + "min_snr_gamma": 0, + "min_timestep": 0, + "mixed_precision": "bf16", + "model_list": "stabilityai/stable-diffusion-xl-base-1.0", + "multi_gpu": false, + "multires_noise_discount": 0.2, + "multires_noise_iterations": 8, + "no_token_padding": false, + "noise_offset": "0.05", + "noise_offset_type": "Multires", + "num_cpu_threads_per_process": 2, + "num_machines": 1, + "num_processes": 1, + "optimizer": "Prodigy", + "optimizer_args": "decouple=True weight_decay=0.6 betas=0.9,0.99 use_bias_correction=True", + "output_dir": "./test/output", + "output_name": "db-Prodigy", + "persistent_data_loader_workers": false, + 
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0", + "prior_loss_weight": 1.0, + "random_crop": false, + "reg_data_dir": "", + "resume": "", + "sample_every_n_epochs": 0, + "sample_every_n_steps": 25, + "sample_prompts": "a painting of a gas mask , by darius kawasaki", + "sample_sampler": "euler_a", + "save_every_n_epochs": 1, + "save_every_n_steps": 0, + "save_last_n_steps": 0, + "save_last_n_steps_state": 0, + "save_model_as": "safetensors", + "save_precision": "fp16", + "save_state": false, + "scale_v_pred_loss_like_noise_pred": false, + "sdxl": true, + "seed": "1234", + "shuffle_caption": false, + "stop_text_encoder_training": 0, + "train_batch_size": 1, + "train_data_dir": "./test/img", + "use_wandb": false, + "v2": false, + "v_parameterization": false, + "v_pred_like_loss": 0, + "vae": "", + "vae_batch_size": 0, + "wandb_api_key": "", + "weighted_captions": false, + "xformers": "xformers" +} \ No newline at end of file diff --git a/tools/blip2-for-sd/README.md b/tools/blip2-for-sd/README.md deleted file mode 100644 index 286d28159..000000000 --- a/tools/blip2-for-sd/README.md +++ /dev/null @@ -1,33 +0,0 @@ -# blip2-for-sd - -source: https://github.com/Talmendo/blip2-for-sd - -Simple script to make BLIP2 output image description in a format suitable for Stable Diffusion. - -Format followed is roughly -`[STYLE OF PHOTO] photo of a [SUBJECT], [IMPORTANT FEATURE], [MORE DETAILS], [POSE OR ACTION], [FRAMING], [SETTING/BACKGROUND], [LIGHTING], [CAMERA ANGLE], [CAMERA PROPERTIES],in style of [PHOTOGRAPHER]` - -## Usage -- Install dependencies according to requirements.txt - -- run main.py -`python main.py` - -The default model will be loaded automatically from huggingface. -You will be presented with an input to specify the folder to process after the model is loaded. - -Screenshot 2023-08-04 102650 - - -- The image or source folder should have the following structure: - -![Screenshot 2023-08-04 102544](https://github.com/Talmendo/blip2-for-sd/assets/141401796/eea9c2b0-e96a-40e4-8a6d-32dd7aa3e802) - - -Each folder represents a base prompt to be used for every image inside. - -- You can adjust BLIP2 settings in `caption_processor.py` inbetween runs, without having to stop the script. Just update it before inputting the new source folder. - -## Models -Default model is `Salesforce/blip2-opt-2.7b`, works quite well and doesn't require much VRAM. -Also tested with `Salesforce/blip2-opt-6.7b-coco` which seems to gives better results at the cost of much more VRAM and a large download (~30GB). 
\ No newline at end of file diff --git a/tools/blip2-for-sd/caption_processor.py b/tools/blip2-for-sd/caption_processor.py deleted file mode 100644 index 8de18c33b..000000000 --- a/tools/blip2-for-sd/caption_processor.py +++ /dev/null @@ -1,105 +0,0 @@ -import torch -import re - -class CaptionProcessor: - def __init__(self, model, processor, device): - self.model = model - self.processor = processor - self.device = device - - def gen(self, inputs, max_length=10, min_length=0, top_k=30, top_p=0.92, num_beams=4): - return self.model.generate( - **inputs, - # max_new_tokens=25, # Number of tokens to generate - max_length=max_length, # Maximum length of the sequence to be generated, mutually exclusive with max_new_tokens - num_beams=num_beams, # Number of beams to use for beam search - num_return_sequences=1, # Number of captions to generate - early_stopping=True, # Stop when no new tokens are generated - repetition_penalty=1.5, # Penalize repeated words - no_repeat_ngram_size=2, # Number of words that can be repeated - # do_sample=True, # Introduce randomness to captions - # temperature=0.9, # Measure of randomness 0-1, 0 means no randomness - top_k=top_k, # Number of highest probability tokens to keep, 0 means no filtering - top_p=top_p, # Probability threshold, 0 means no filtering - min_length=min_length, # Minimum length of the sequence to be generated - ) - - def process(self, prompt, image): - return self.processor(image, text=prompt, return_tensors="pt").to(self.device, torch.float16) - - def caption_from(self, generated): - caption_list = self.processor.batch_decode(generated, skip_special_tokens=True) - caption_list = [caption.strip() for caption in caption_list] - return caption_list if len(caption_list) > 1 else caption_list[0] - - def sanitise_caption(self, caption): - return caption.replace(" - ", "-") - - # TODO this needs some more work - def sanitise_prompt_shard(self, prompt): - # Remove everything after "Answer:" - prompt = prompt.split("Answer:")[0].strip() - - # Define a pattern for multiple replacements - replacements = [ - (r", a point and shoot(?: camera)?", ""), # Matches ", a point and shoot" with optional " camera" - (r"it is a ", ""), - (r"it is ", ""), - (r"hair hair", "hair"), - (r"wearing nothing", "nude"), - (r"She's ", ""), - (r"She is ", "") - ] - - # Apply the replacements using regex - for pattern, replacement in replacements: - prompt = re.sub(pattern, replacement, prompt) - - return prompt - - def ask(self, question, image): - return self.sanitise_prompt_shard(self.caption_from(self.gen(self.process(f"Question: {question} Answer:", image)))) - - def caption_me(self, initial_prompt, image): - prompt = "" - - try: - # [STYLE OF PHOTO] photo of a [SUBJECT], [IMPORTANT FEATURE], [MORE DETAILS], [POSE OR ACTION], [FRAMING], [SETTING/BACKGROUND], [LIGHTING], [CAMERA ANGLE], [CAMERA PROPERTIES],in style of [PHOTOGRAPHER] - # print("\n") - hair_color = self.ask("What is her hair color?", image) - hair_length = self.ask("What is her hair length?", image) - p_hair = f"{hair_color} {hair_length} hair" - # print(p_hair) - - p_style = self.ask("Between the choices selfie, mirror selfie, candid, professional portrait what is the style of the photo?", image) - # print(p_style) - - p_clothing = self.ask("What is she wearing if anything?", image) - # print(p_clothing) - - p_action = self.ask("What is she doing? 
Could be something like standing, stretching, walking, squatting, etc", image) - # print(p_action) - - p_framing = self.ask("Between the choices close up, upper body shot, full body shot what is the framing of the photo?", image) - # print(p_framing) - - p_setting = self.ask("Where is she? Be descriptive and detailed", image) - # print(p_setting) - - p_lighting = self.ask("What is the scene lighting like? For example: soft lighting, studio lighting, natural lighting", image) - # print(p_lighting) - - p_angle = self.ask("What angle is the picture taken from? Be succinct, like: from the side, from below, from front", image) - # print(p_angle) - - p_camera = self.ask("What kind of camera could this picture have been taken with? Be specific and guess a brand with specific camera type", image) - # print(p_camera) - - # prompt = self.sanitise_caption(f"{p_style}, {initial_prompt} with {p_hair}, wearing {p_clothing}, {p_action}, {p_framing}, {p_setting}, {p_lighting}, {p_angle}, {p_camera}") - prompt = self.sanitise_caption(f"{p_style}, with {p_hair}, wearing {p_clothing}, {p_action}, {p_framing}, {p_setting}, {p_lighting}, {p_angle}, {p_camera}") - - return prompt - except Exception as e: - print(e) - - return prompt \ No newline at end of file diff --git a/tools/blip2-for-sd/main.py b/tools/blip2-for-sd/main.py deleted file mode 100644 index fc83bceb8..000000000 --- a/tools/blip2-for-sd/main.py +++ /dev/null @@ -1,89 +0,0 @@ -import requests, torch, sys, os -import argparse - -from importlib import reload -from PIL import Image -from transformers import AutoProcessor, Blip2ForConditionalGeneration -from tqdm import tqdm - -import caption_processor - -model = None -processor = None -device = None - -def load_model(model_name="Salesforce/blip2-opt-2.7b"): - global model, processor, device - - print("Loading Model") - processor = AutoProcessor.from_pretrained(model_name) - model = Blip2ForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.float16) - - if torch.cuda.is_available(): - print("CUDA available, using GPU") - device = "cuda" - else: - print("CUDA not available, using CPU") - device = "cpu" - - print("Moving model to device") - model.to(device) - -def main(path): - # reloading caption_processor to enable us to change its values in between executions - # without having to reload the model, which can take very long - # probably cleaner to do this with a config file and just reload that - # but this works for now - reload(caption_processor) - prompt_file_dict = {} - - # list all sub dirs in path - sub_dirs = [dir for dir in os.listdir(path) if os.path.isdir(os.path.join(path, dir))] - - print("Reading prompts from sub dirs and finding image files") - for prompt in sub_dirs: - prompt_file_dict[prompt] = [file for file in os.listdir(os.path.join(path, prompt)) if file.endswith((".jpg", ".png", ".jpeg", ".webp"))] - - for prompt, file_list in prompt_file_dict.items(): - print(f"Found {str(len(file_list))} files for prompt \"{prompt}\"") - - for prompt, file_list in prompt_file_dict.items(): - total = len(file_list) - - for file in tqdm(file_list): - # read image - image = Image.open(os.path.join(path, prompt, file)) - - caption = "" - # generate caption - try: - caption = caption_processor.CaptionProcessor(model, processor, device).caption_me(prompt, image) - except: - print("Error creating caption for file: " + file) - - # save caption to file - # file without extension - with open(os.path.join(path, prompt, os.path.splitext(file)[0] + ".txt"), "w", encoding="utf-8") as f: - 
f.write(caption) - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Enter the path to the file") - parser.add_argument("path", type=str, nargs='?', default="", help="Path to the file") - parser.add_argument("--interactive", action="store_true", help="Interactive mode") - - args = parser.parse_args() - interactive = args.interactive - - load_model(model_name="Salesforce/blip2-opt-2.7b") - - if interactive: - while True: - path = input("Enter path: ") - main(path) - continue_prompt = input("Continue? (y/n): ") - if continue_prompt.lower() != 'y': - break - else: - path = args.path - search_subdirectories = False - main(path) diff --git a/tools/blip2-for-sd/requirements.txt b/tools/blip2-for-sd/requirements.txt deleted file mode 100644 index d0e05a320..000000000 --- a/tools/blip2-for-sd/requirements.txt +++ /dev/null @@ -1,28 +0,0 @@ ---extra-index-url https://download.pytorch.org/whl/cu118 -accelerate==0.21.0 -certifi==2023.7.22 -charset-normalizer==3.2.0 -colorama==0.4.6 -filelock==3.12.2 -fsspec==2023.6.0 -huggingface-hub==0.16.4 -idna==3.4 -Jinja2==3.1.2 -MarkupSafe==2.1.3 -mpmath==1.3.0 -networkx==3.1 -numpy==1.25.2 -packaging==23.1 -Pillow==10.0.0 -psutil==5.9.5 -PyYAML==6.0.1 -regex==2023.6.3 -requests==2.31.0 -safetensors==0.3.1 -sympy==1.12 -tokenizers==0.13.3 -torch==2.0.1+cu118 -tqdm==4.65.0 -transformers==4.31.0 -typing_extensions==4.7.1 -urllib3==2.0.4 \ No newline at end of file diff --git a/tools/cache_latents.py b/tools/cache_latents.py deleted file mode 100644 index 347db27f7..000000000 --- a/tools/cache_latents.py +++ /dev/null @@ -1,197 +0,0 @@ -# latentsのdiskへの事前キャッシュを行う / cache latents to disk - -import argparse -import math -from multiprocessing import Value -import os - -from accelerate.utils import set_seed -import torch -from tqdm import tqdm - -from library import config_util -from library import train_util -from library import sdxl_train_util -from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, -) -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -def cache_to_disk(args: argparse.Namespace) -> None: - train_util.prepare_dataset_args(args, True) - - # check cache latents arg - assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" - - use_dreambooth_method = args.in_json is None - - if args.seed is not None: - set_seed(args.seed) # 乱数系列を初期化する - - # tokenizerを準備する:datasetを動かすために必要 - if args.sdxl: - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - tokenizers = [tokenizer1, tokenizer2] - else: - tokenizer = train_util.load_tokenizer(args) - tokenizers = [tokenizer] - - # データセットを準備する - if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) - if args.dataset_config is not None: - logger.info(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) - ) - else: - if use_dreambooth_method: - logger.info("Using DreamBooth method.") - user_config = { - "datasets": [ - { - "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( - args.train_data_dir, args.reg_data_dir - ) - } - ] - } - else: - 
logger.info("Training with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) - - # datasetのcache_latentsを呼ばなければ、生の画像が返る - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) - - # acceleratorを準備する - logger.info("prepare accelerator") - accelerator = train_util.prepare_accelerator(args) - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, _ = train_util.prepare_dtype(args) - vae_dtype = torch.float32 if args.no_half_vae else weight_dtype - - # モデルを読み込む - logger.info("load model") - if args.sdxl: - (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) - else: - _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) - - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える - vae.set_use_memory_efficient_attention_xformers(args.xformers) - vae.to(accelerator.device, dtype=vae_dtype) - vae.requires_grad_(False) - vae.eval() - - # dataloaderを準備する - train_dataset_group.set_caching_mode("latents") - - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず - train_dataloader = accelerator.prepare(train_dataloader) - - # データ取得のためのループ - for batch in tqdm(train_dataloader): - b_size = len(batch["images"]) - vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size - flip_aug = batch["flip_aug"] - random_crop = batch["random_crop"] - bucket_reso = batch["bucket_reso"] - - # バッチを分割して処理する - for i in range(0, b_size, vae_batch_size): - images = batch["images"][i : i + vae_batch_size] - absolute_paths = batch["absolute_paths"][i : i + vae_batch_size] - resized_sizes = batch["resized_sizes"][i : i + vae_batch_size] - - image_infos = [] - for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)): - image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) - image_info.image = image - image_info.bucket_reso = bucket_reso - image_info.resized_size = resized_size - image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz" - - if args.skip_existing: - if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug): - logger.warning(f"Skipping {image_info.latents_npz} because it already exists.") - continue - - image_infos.append(image_info) - - if len(image_infos) > 0: - train_util.cache_batch_latents(vae, True, image_infos, flip_aug, random_crop) - - accelerator.wait_for_everyone() - accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - 
train_util.add_sd_models_arguments(parser) - train_util.add_training_arguments(parser, True) - train_util.add_dataset_arguments(parser, True, True, True) - config_util.add_config_arguments(parser) - parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") - parser.add_argument( - "--no_half_vae", - action="store_true", - help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", - ) - parser.add_argument( - "--skip_existing", - action="store_true", - help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", - ) - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - args = train_util.read_config_from_file(args, parser) - - cache_to_disk(args) diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py deleted file mode 100644 index 5f1d6d201..000000000 --- a/tools/cache_text_encoder_outputs.py +++ /dev/null @@ -1,194 +0,0 @@ -# text encoder出力のdiskへの事前キャッシュを行う / cache text encoder outputs to disk in advance - -import argparse -import math -from multiprocessing import Value -import os - -from accelerate.utils import set_seed -import torch -from tqdm import tqdm - -from library import config_util -from library import train_util -from library import sdxl_train_util -from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, -) -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -def cache_to_disk(args: argparse.Namespace) -> None: - train_util.prepare_dataset_args(args, True) - - # check cache arg - assert ( - args.cache_text_encoder_outputs_to_disk - ), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります" - - # できるだけ準備はしておくが今のところSDXLのみしか動かない - assert ( - args.sdxl - ), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です" - - use_dreambooth_method = args.in_json is None - - if args.seed is not None: - set_seed(args.seed) # 乱数系列を初期化する - - # tokenizerを準備する:datasetを動かすために必要 - if args.sdxl: - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - tokenizers = [tokenizer1, tokenizer2] - else: - tokenizer = train_util.load_tokenizer(args) - tokenizers = [tokenizer] - - # データセットを準備する - if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) - if args.dataset_config is not None: - logger.info(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) - ) - else: - if use_dreambooth_method: - logger.info("Using DreamBooth method.") - user_config = { - "datasets": [ - { - "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( - args.train_data_dir, args.reg_data_dir - ) - } - ] - } - else: - logger.info("Training with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) - 
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) - - # acceleratorを準備する - logger.info("prepare accelerator") - accelerator = train_util.prepare_accelerator(args) - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, _ = train_util.prepare_dtype(args) - - # モデルを読み込む - logger.info("load model") - if args.sdxl: - (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) - text_encoders = [text_encoder1, text_encoder2] - else: - text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) - text_encoders = [text_encoder1] - - for text_encoder in text_encoders: - text_encoder.to(accelerator.device, dtype=weight_dtype) - text_encoder.requires_grad_(False) - text_encoder.eval() - - # dataloaderを準備する - train_dataset_group.set_caching_mode("text") - - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず - train_dataloader = accelerator.prepare(train_dataloader) - - # データ取得のためのループ - for batch in tqdm(train_dataloader): - absolute_paths = batch["absolute_paths"] - input_ids1_list = batch["input_ids1_list"] - input_ids2_list = batch["input_ids2_list"] - - image_infos = [] - for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list): - image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) - image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX - image_info - - if args.skip_existing: - if os.path.exists(image_info.text_encoder_outputs_npz): - logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") - continue - - image_info.input_ids1 = input_ids1 - image_info.input_ids2 = input_ids2 - image_infos.append(image_info) - - if len(image_infos) > 0: - b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos]) - b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos]) - train_util.cache_batch_text_encoder_outputs( - image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype - ) - - accelerator.wait_for_everyone() - accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - train_util.add_sd_models_arguments(parser) - train_util.add_training_arguments(parser, True) - train_util.add_dataset_arguments(parser, True, True, True) - config_util.add_config_arguments(parser) - sdxl_train_util.add_sdxl_training_arguments(parser) - parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") - parser.add_argument( - "--skip_existing", - action="store_true", - help="skip images if npz 
already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", - ) - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - args = train_util.read_config_from_file(args, parser) - - cache_to_disk(args) diff --git a/tools/canny.py b/tools/canny.py deleted file mode 100644 index f2190975c..000000000 --- a/tools/canny.py +++ /dev/null @@ -1,34 +0,0 @@ -import argparse -import cv2 - -import logging -from library.utils import setup_logging -setup_logging() -logger = logging.getLogger(__name__) - -def canny(args): - img = cv2.imread(args.input) - img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) - - canny_img = cv2.Canny(img, args.thres1, args.thres2) - # canny_img = 255 - canny_img - - cv2.imwrite(args.output, canny_img) - logger.info("done!") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("--input", type=str, default=None, help="input path") - parser.add_argument("--output", type=str, default=None, help="output path") - parser.add_argument("--thres1", type=int, default=32, help="thres1") - parser.add_argument("--thres2", type=int, default=224, help="thres2") - - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args = parser.parse_args() - canny(args) diff --git a/tools/convert_diffusers20_original_sd.md b/tools/convert_diffusers20_original_sd.md deleted file mode 100644 index 4763e5fd5..000000000 --- a/tools/convert_diffusers20_original_sd.md +++ /dev/null @@ -1,46 +0,0 @@ -# How to use - -##Diffusers to Stable Diffusion .ckpt conversion - -Specify the folder of the source model and the destination .ckpt file as follows (actually written on one line). The v1/v2 version is automatically determined. - -``` -python convert_diffusers20_original_sd.py ..\models\diffusers_model - ..\models\sd.ckpt -``` - -Note that v2 Diffusers' Text Encoder has only 22 layers, and if you convert it to Stable Diffusion as it is, the weights will be insufficient, so the weights of the 22 layers will be copied as the 23rd layer. The weight of the 23rd layer is not used during image generation, so it has no effect. Similarly, text_projection logit_scale also adds dummy weights (it doesn't seem to be used for image generation). - -## Stable Diffusion .ckpt to Diffusers - -Enter the following: - -``` -python convert_diffusers20_original_sd.py ..\models\sd.ckpt - ..\models\diffusers_model - --v2 --reference_model stabilityai/stable-diffusion-2 -``` - -Specify the .ckpt file and the destination folder as arguments. -Model judgment is not possible, so please use the `--v1` option or the `--v2` option depending on the model. - -Also, since `.ckpt` does not contain scheduler and tokenizer information, you need to copy them from some existing Diffusers model. Please specify with `--reference_model`. You can specify the HuggingFace id or a local model directory. - -If you don't have a local model, you can specify "stabilityai/stable-diffusion-2" or "stabilityai/stable-diffusion-2-base" for v2. -For v1.4/1.5, "CompVis/stable-diffusion-v1-4" is fine (v1.4 and v1.5 seem to be the same). - -## What can you do? - -`--fp16 / --bf16 / --float` - -You can specify the data format when saving checkpoint. --fp16 only, also valid when loading Diffusers models. - -`--epoch / --global_step` - -When saving checkpoint, write epoch and global_step with the specified values. If not specified, both will be 0. 
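As an illustration only (not part of the original document), the flags described above can be combined into a single hypothetical invocation, with paths mirroring the earlier examples:

```
python convert_diffusers20_original_sd.py ..\models\diffusers_model ..\models\sd_fp16.ckpt --fp16 --epoch 10 --global_step 5000
```

Here `--fp16` loads the Diffusers model as fp16 and saves the checkpoint in fp16, while `--epoch` and `--global_step` record the given counters in the saved `.ckpt`.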
- -## Conclusion - -Some people may be troubled by the Diffusers model due to the poor inference environment. I hope it helps a little. - -(Note that converting the data format from checkpoint to checkpoint is also possible, although it has not been tested.) ) \ No newline at end of file diff --git a/tools/convert_diffusers20_original_sd.py b/tools/convert_diffusers20_original_sd.py deleted file mode 100644 index 572ee2f0c..000000000 --- a/tools/convert_diffusers20_original_sd.py +++ /dev/null @@ -1,163 +0,0 @@ -# convert Diffusers v1.x/v2.0 model to original Stable Diffusion - -import argparse -import os -import torch -from diffusers import StableDiffusionPipeline - -import library.model_util as model_util -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -def convert(args): - # 引数を確認する - load_dtype = torch.float16 if args.fp16 else None - - save_dtype = None - if args.fp16 or args.save_precision_as == "fp16": - save_dtype = torch.float16 - elif args.bf16 or args.save_precision_as == "bf16": - save_dtype = torch.bfloat16 - elif args.float or args.save_precision_as == "float": - save_dtype = torch.float - - is_load_ckpt = os.path.isfile(args.model_to_load) - is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0 - - assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です" - # assert ( - # is_save_ckpt or args.reference_model is not None - # ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です" - - # モデルを読み込む - msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else "")) - logger.info(f"loading {msg}: {args.model_to_load}") - - if is_load_ckpt: - v2_model = args.v2 - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint( - v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection - ) - else: - pipe = StableDiffusionPipeline.from_pretrained( - args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None, variant=args.variant - ) - text_encoder = pipe.text_encoder - vae = pipe.vae - unet = pipe.unet - - if args.v1 == args.v2: - # 自動判定する - v2_model = unet.config.cross_attention_dim == 1024 - logger.info("checking model version: model is " + ("v2" if v2_model else "v1")) - else: - v2_model = not args.v1 - - # 変換して保存する - msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers" - logger.info(f"converting and saving as {msg}: {args.model_to_save}") - - if is_save_ckpt: - original_model = args.model_to_load if is_load_ckpt else None - key_count = model_util.save_stable_diffusion_checkpoint( - v2_model, - args.model_to_save, - text_encoder, - unet, - original_model, - args.epoch, - args.global_step, - None if args.metadata is None else eval(args.metadata), - save_dtype=save_dtype, - vae=vae, - ) - logger.info(f"model saved. 
total converted state_dict keys: {key_count}") - else: - logger.info( - f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}" - ) - model_util.save_diffusers_checkpoint( - v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors - ) - logger.info("model saved.") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument( - "--v1", action="store_true", help="load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む" - ) - parser.add_argument( - "--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む" - ) - parser.add_argument( - "--unet_use_linear_projection", - action="store_true", - help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする(stabilityaiのモデルと合わせる)", - ) - parser.add_argument( - "--fp16", - action="store_true", - help="load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)", - ) - parser.add_argument("--bf16", action="store_true", help="save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)") - parser.add_argument( - "--float", action="store_true", help="save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)" - ) - parser.add_argument( - "--save_precision_as", - type=str, - default="no", - choices=["fp16", "bf16", "float"], - help="save precision, do not specify with --fp16/--bf16/--float / 保存する精度、--fp16/--bf16/--floatと併用しないでください", - ) - parser.add_argument("--epoch", type=int, default=0, help="epoch to write to checkpoint / checkpointに記録するepoch数の値") - parser.add_argument( - "--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値" - ) - parser.add_argument( - "--metadata", - type=str, - default=None, - help='モデルに保存されるメタデータ、Pythonの辞書形式で指定 / metadata: metadata written in to the model in Python Dictionary. Example metadata: \'{"name": "model_name", "resolution": "512x512"}\'', - ) - parser.add_argument( - "--variant", - type=str, - default=None, - help="読む込むDiffusersのvariantを指定する、例: fp16 / variant: Diffusers variant to load. 
Example: fp16", - ) - parser.add_argument( - "--reference_model", - type=str, - default=None, - help="scheduler/tokenizerのコピー元Diffusersモデル、Diffusers形式で保存するときに使用される、省略時は`runwayml/stable-diffusion-v1-5` または `stabilityai/stable-diffusion-2-1` / reference Diffusers model to copy scheduler/tokenizer config from, used when saving as Diffusers format, default is `runwayml/stable-diffusion-v1-5` or `stabilityai/stable-diffusion-2-1`", - ) - parser.add_argument( - "--use_safetensors", - action="store_true", - help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)", - ) - - parser.add_argument( - "model_to_load", - type=str, - default=None, - help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ", - ) - parser.add_argument( - "model_to_save", - type=str, - default=None, - help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存", - ) - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - convert(args) diff --git a/tools/detect_face_rotate.py b/tools/detect_face_rotate.py deleted file mode 100644 index bbc643edc..000000000 --- a/tools/detect_face_rotate.py +++ /dev/null @@ -1,250 +0,0 @@ -# このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします -# (c) 2022 Kohya S. @kohya_ss - -# 横長の画像から顔検出して正立するように回転し、そこを中心に正方形に切り出す - -# v2: extract max face if multiple faces are found -# v3: add crop_ratio option -# v4: add multiple faces extraction and min/max size - -import argparse -import math -import cv2 -import glob -import os -from anime_face_detector import create_detector -from tqdm import tqdm -import numpy as np -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -KP_REYE = 11 -KP_LEYE = 19 - -SCORE_THRES = 0.90 - - -def detect_faces(detector, image, min_size): - preds = detector(image) # bgr - # logger.info(len(preds)) - - faces = [] - for pred in preds: - bb = pred['bbox'] - score = bb[-1] - if score < SCORE_THRES: - continue - - left, top, right, bottom = bb[:4] - cx = int((left + right) / 2) - cy = int((top + bottom) / 2) - fw = int(right - left) - fh = int(bottom - top) - - lex, ley = pred['keypoints'][KP_LEYE, 0:2] - rex, rey = pred['keypoints'][KP_REYE, 0:2] - angle = math.atan2(ley - rey, lex - rex) - angle = angle / math.pi * 180 - - faces.append((cx, cy, fw, fh, angle)) - - faces.sort(key=lambda x: max(x[2], x[3]), reverse=True) # 大きい順 - return faces - - -def rotate_image(image, angle, cx, cy): - h, w = image.shape[0:2] - rot_mat = cv2.getRotationMatrix2D((cx, cy), angle, 1.0) - - # # 回転する分、すこし画像サイズを大きくする→とりあえず無効化 - # nh = max(h, int(w * math.sin(angle))) - # nw = max(w, int(h * math.sin(angle))) - # if nh > h or nw > w: - # pad_y = nh - h - # pad_t = pad_y // 2 - # pad_x = nw - w - # pad_l = pad_x // 2 - # m = np.array([[0, 0, pad_l], - # [0, 0, pad_t]]) - # rot_mat = rot_mat + m - # h, w = nh, nw - # cx += pad_l - # cy += pad_t - - result = cv2.warpAffine(image, rot_mat, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT) - return result, cx, cy - - -def process(args): - assert (not args.resize_fit) or args.resize_face_size is None, f"resize_fit and resize_face_size can't be specified both / resize_fitとresize_face_sizeはどちらか片方しか指定できません" - assert args.crop_ratio is None or args.resize_face_size is None, 
f"crop_ratio指定時はresize_face_sizeは指定できません" - - # アニメ顔検出モデルを読み込む - logger.info("loading face detector.") - detector = create_detector('yolov3') - - # cropの引数を解析する - if args.crop_size is None: - crop_width = crop_height = None - else: - tokens = args.crop_size.split(',') - assert len(tokens) == 2, f"crop_size must be 'width,height' / crop_sizeは'幅,高さ'で指定してください" - crop_width, crop_height = [int(t) for t in tokens] - - if args.crop_ratio is None: - crop_h_ratio = crop_v_ratio = None - else: - tokens = args.crop_ratio.split(',') - assert len(tokens) == 2, f"crop_ratio must be 'horizontal,vertical' / crop_ratioは'幅,高さ'の倍率で指定してください" - crop_h_ratio, crop_v_ratio = [float(t) for t in tokens] - - # 画像を処理する - logger.info("processing.") - output_extension = ".png" - - os.makedirs(args.dst_dir, exist_ok=True) - paths = glob.glob(os.path.join(args.src_dir, "*.png")) + glob.glob(os.path.join(args.src_dir, "*.jpg")) + \ - glob.glob(os.path.join(args.src_dir, "*.webp")) - for path in tqdm(paths): - basename = os.path.splitext(os.path.basename(path))[0] - - # image = cv2.imread(path) # 日本語ファイル名でエラーになる - image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_UNCHANGED) - if len(image.shape) == 2: - image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) - if image.shape[2] == 4: - logger.warning(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}") - image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい - - h, w = image.shape[:2] - - faces = detect_faces(detector, image, args.multiple_faces) - for i, face in enumerate(faces): - cx, cy, fw, fh, angle = face - face_size = max(fw, fh) - if args.min_size is not None and face_size < args.min_size: - continue - if args.max_size is not None and face_size >= args.max_size: - continue - face_suffix = f"_{i+1:02d}" if args.multiple_faces else "" - - # オプション指定があれば回転する - face_img = image - if args.rotate: - face_img, cx, cy = rotate_image(face_img, angle, cx, cy) - - # オプション指定があれば顔を中心に切り出す - if crop_width is not None or crop_h_ratio is not None: - cur_crop_width, cur_crop_height = crop_width, crop_height - if crop_h_ratio is not None: - cur_crop_width = int(face_size * crop_h_ratio + .5) - cur_crop_height = int(face_size * crop_v_ratio + .5) - - # リサイズを必要なら行う - scale = 1.0 - if args.resize_face_size is not None: - # 顔サイズを基準にリサイズする - scale = args.resize_face_size / face_size - if scale < cur_crop_width / w: - logger.warning( - f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") - scale = cur_crop_width / w - if scale < cur_crop_height / h: - logger.warning( - f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") - scale = cur_crop_height / h - elif crop_h_ratio is not None: - # 倍率指定の時にはリサイズしない - pass - else: - # 切り出しサイズ指定あり - if w < cur_crop_width: - logger.warning(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}") - scale = cur_crop_width / w - if h < cur_crop_height: - logger.warning(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}") - scale = cur_crop_height / h - if args.resize_fit: - scale = max(cur_crop_width / w, cur_crop_height / h) - - if scale != 1.0: - w = int(w * scale + .5) - h = int(h * scale + .5) - face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4) - cx = int(cx * scale + .5) - cy = int(cy * scale + .5) - fw = int(fw * scale + .5) - fh = int(fh * scale + .5) - - cur_crop_width = min(cur_crop_width, face_img.shape[1]) - 
cur_crop_height = min(cur_crop_height, face_img.shape[0]) - - x = cx - cur_crop_width // 2 - cx = cur_crop_width // 2 - if x < 0: - cx = cx + x - x = 0 - elif x + cur_crop_width > w: - cx = cx + (x + cur_crop_width - w) - x = w - cur_crop_width - face_img = face_img[:, x:x+cur_crop_width] - - y = cy - cur_crop_height // 2 - cy = cur_crop_height // 2 - if y < 0: - cy = cy + y - y = 0 - elif y + cur_crop_height > h: - cy = cy + (y + cur_crop_height - h) - y = h - cur_crop_height - face_img = face_img[y:y + cur_crop_height] - - # # debug - # logger.info(path, cx, cy, angle) - # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8)) - # cv2.imshow("image", crp) - # if cv2.waitKey() == 27: - # break - # cv2.destroyAllWindows() - - # debug - if args.debug: - cv2.rectangle(face_img, (cx-fw//2, cy-fh//2), (cx+fw//2, cy+fh//2), (255, 0, 255), fw//20) - - _, buf = cv2.imencode(output_extension, face_img) - with open(os.path.join(args.dst_dir, f"{basename}{face_suffix}_{cx:04d}_{cy:04d}_{fw:04d}_{fh:04d}{output_extension}"), "wb") as f: - buf.tofile(f) - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ") - parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ") - parser.add_argument("--rotate", action="store_true", help="rotate images to align faces / 顔が正立するように画像を回転する") - parser.add_argument("--resize_fit", action="store_true", - help="resize to fit smaller side after cropping / 切り出し後の画像の短辺がcrop_sizeにあうようにリサイズする") - parser.add_argument("--resize_face_size", type=int, default=None, - help="resize image before cropping by face size / 切り出し前に顔がこのサイズになるようにリサイズする") - parser.add_argument("--crop_size", type=str, default=None, - help="crop images with 'width,height' pixels, face centered / 顔を中心として'幅,高さ'のサイズで切り出す") - parser.add_argument("--crop_ratio", type=str, default=None, - help="crop images with 'horizontal,vertical' ratio to face, face centered / 顔を中心として顔サイズの'幅倍率,高さ倍率'のサイズで切り出す") - parser.add_argument("--min_size", type=int, default=None, - help="minimum face size to output (included) / 処理対象とする顔の最小サイズ(この値以上)") - parser.add_argument("--max_size", type=int, default=None, - help="maximum face size to output (excluded) / 処理対象とする顔の最大サイズ(この値未満)") - parser.add_argument("--multiple_faces", action="store_true", - help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す") - parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します") - - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args = parser.parse_args() - - process(args) diff --git a/tools/latent_upscaler.py b/tools/latent_upscaler.py deleted file mode 100644 index f05cf7194..000000000 --- a/tools/latent_upscaler.py +++ /dev/null @@ -1,354 +0,0 @@ -# 外部から簡単にupscalerを呼ぶためのスクリプト -# 単体で動くようにモデル定義も含めている - -import argparse -import glob -import os -import cv2 -from diffusers import AutoencoderKL - -from typing import Dict, List -import numpy as np - -import torch -from library.device_utils import init_ipex, get_preferred_device -init_ipex() - -from torch import nn -from tqdm import tqdm -from PIL import Image -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -class ResidualBlock(nn.Module): - def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1): - super(ResidualBlock, self).__init__() - - if out_channels is None: - out_channels = 
in_channels - - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False) - self.bn1 = nn.BatchNorm2d(out_channels) - self.relu1 = nn.ReLU(inplace=True) - - self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=False) - self.bn2 = nn.BatchNorm2d(out_channels) - - self.relu2 = nn.ReLU(inplace=True) # このReLUはresidualに足す前にかけるほうがいいかも - - # initialize weights - self._initialize_weights() - - def _initialize_weights(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.Linear): - nn.init.normal_(m.weight, 0, 0.01) - nn.init.constant_(m.bias, 0) - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu1(out) - - out = self.conv2(out) - out = self.bn2(out) - - out += residual - - out = self.relu2(out) - - return out - - -class Upscaler(nn.Module): - def __init__(self): - super(Upscaler, self).__init__() - - # define layers - # latent has 4 channels - - self.conv1 = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) - self.bn1 = nn.BatchNorm2d(128) - self.relu1 = nn.ReLU(inplace=True) - - # resblocks - # 数の暴力で20個:次元数を増やすよりもブロックを増やしたほうがreceptive fieldが広がるはずだぞ - self.resblock1 = ResidualBlock(128) - self.resblock2 = ResidualBlock(128) - self.resblock3 = ResidualBlock(128) - self.resblock4 = ResidualBlock(128) - self.resblock5 = ResidualBlock(128) - self.resblock6 = ResidualBlock(128) - self.resblock7 = ResidualBlock(128) - self.resblock8 = ResidualBlock(128) - self.resblock9 = ResidualBlock(128) - self.resblock10 = ResidualBlock(128) - self.resblock11 = ResidualBlock(128) - self.resblock12 = ResidualBlock(128) - self.resblock13 = ResidualBlock(128) - self.resblock14 = ResidualBlock(128) - self.resblock15 = ResidualBlock(128) - self.resblock16 = ResidualBlock(128) - self.resblock17 = ResidualBlock(128) - self.resblock18 = ResidualBlock(128) - self.resblock19 = ResidualBlock(128) - self.resblock20 = ResidualBlock(128) - - # last convs - self.conv2 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) - self.bn2 = nn.BatchNorm2d(64) - self.relu2 = nn.ReLU(inplace=True) - - self.conv3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) - self.bn3 = nn.BatchNorm2d(64) - self.relu3 = nn.ReLU(inplace=True) - - # final conv: output 4 channels - self.conv_final = nn.Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)) - - # initialize weights - self._initialize_weights() - - def _initialize_weights(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.Linear): - nn.init.normal_(m.weight, 0, 0.01) - nn.init.constant_(m.bias, 0) - - # initialize final conv weights to 0: 流行りのzero conv - nn.init.constant_(self.conv_final.weight, 0) - - def forward(self, x): - inp = x - - x = self.conv1(x) - x = self.bn1(x) - x = self.relu1(x) - - # いくつかのresblockを通した後に、residualを足すことで精度向上と学習速度向上が見込めるはず - residual = x - x = self.resblock1(x) - x = self.resblock2(x) - x = self.resblock3(x) - x = 
self.resblock4(x) - x = x + residual - residual = x - x = self.resblock5(x) - x = self.resblock6(x) - x = self.resblock7(x) - x = self.resblock8(x) - x = x + residual - residual = x - x = self.resblock9(x) - x = self.resblock10(x) - x = self.resblock11(x) - x = self.resblock12(x) - x = x + residual - residual = x - x = self.resblock13(x) - x = self.resblock14(x) - x = self.resblock15(x) - x = self.resblock16(x) - x = x + residual - residual = x - x = self.resblock17(x) - x = self.resblock18(x) - x = self.resblock19(x) - x = self.resblock20(x) - x = x + residual - - x = self.conv2(x) - x = self.bn2(x) - x = self.relu2(x) - x = self.conv3(x) - x = self.bn3(x) - - # ここにreluを入れないほうがいい気がする - - x = self.conv_final(x) - - # network estimates the difference between the input and the output - x = x + inp - - return x - - def support_latents(self) -> bool: - return False - - def upscale( - self, - vae: AutoencoderKL, - lowreso_images: List[Image.Image], - lowreso_latents: torch.Tensor, - dtype: torch.dtype, - width: int, - height: int, - batch_size: int = 1, - vae_batch_size: int = 1, - ): - # assertion - assert lowreso_images is not None, "Upscaler requires lowreso image" - - # make upsampled image with lanczos4 - upsampled_images = [] - for lowreso_image in lowreso_images: - upsampled_image = np.array(lowreso_image.resize((width, height), Image.LANCZOS)) - upsampled_images.append(upsampled_image) - - # convert to tensor: this tensor is too large to be converted to cuda - upsampled_images = [torch.from_numpy(upsampled_image).permute(2, 0, 1).float() for upsampled_image in upsampled_images] - upsampled_images = torch.stack(upsampled_images, dim=0) - upsampled_images = upsampled_images.to(dtype) - - # normalize to [-1, 1] - upsampled_images = upsampled_images / 127.5 - 1.0 - - # convert upsample images to latents with batch size - # logger.info("Encoding upsampled (LANCZOS4) images...") - upsampled_latents = [] - for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)): - batch = upsampled_images[i : i + vae_batch_size].to(vae.device) - with torch.no_grad(): - batch = vae.encode(batch).latent_dist.sample() - upsampled_latents.append(batch) - - upsampled_latents = torch.cat(upsampled_latents, dim=0) - - # upscale (refine) latents with this model with batch size - logger.info("Upscaling latents...") - upscaled_latents = [] - for i in range(0, upsampled_latents.shape[0], batch_size): - with torch.no_grad(): - upscaled_latents.append(self.forward(upsampled_latents[i : i + batch_size])) - upscaled_latents = torch.cat(upscaled_latents, dim=0) - - return upscaled_latents * 0.18215 - - -# external interface: returns a model -def create_upscaler(**kwargs): - weights = kwargs["weights"] - model = Upscaler() - - logger.info(f"Loading weights from {weights}...") - if os.path.splitext(weights)[1] == ".safetensors": - from safetensors.torch import load_file - - sd = load_file(weights) - else: - sd = torch.load(weights, map_location=torch.device("cpu")) - model.load_state_dict(sd) - return model - - -# another interface: upscale images with a model for given images from command line -def upscale_images(args: argparse.Namespace): - DEVICE = get_preferred_device() - us_dtype = torch.float16 # TODO: support fp32/bf16 - os.makedirs(args.output_dir, exist_ok=True) - - # load VAE with Diffusers - assert args.vae_path is not None, "VAE path is required" - logger.info(f"Loading VAE from {args.vae_path}...") - vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae") - vae.to(DEVICE, dtype=us_dtype) - - 
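# Note: AutoencoderKL.from_pretrained() is called with subfolder="vae", so --vae_path should point to a
# Diffusers Stable Diffusion model (local directory or Hub id) that contains a "vae" subfolder; a single
# standalone VAE weight file will not load through this code path. The VAE and the upscaler both run in
# fp16 (us_dtype) on the device returned by get_preferred_device().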
# prepare model - logger.info("Preparing model...") - upscaler: Upscaler = create_upscaler(weights=args.weights) - # logger.info("Loading weights from", args.weights) - # upscaler.load_state_dict(torch.load(args.weights)) - upscaler.eval() - upscaler.to(DEVICE, dtype=us_dtype) - - # load images - image_paths = glob.glob(args.image_pattern) - images = [] - for image_path in image_paths: - image = Image.open(image_path) - image = image.convert("RGB") - - # make divisible by 8 - width = image.width - height = image.height - if width % 8 != 0: - width = width - (width % 8) - if height % 8 != 0: - height = height - (height % 8) - if width != image.width or height != image.height: - image = image.crop((0, 0, width, height)) - - images.append(image) - - # debug output - if args.debug: - for image, image_path in zip(images, image_paths): - image_debug = image.resize((image.width * 2, image.height * 2), Image.LANCZOS) - - basename = os.path.basename(image_path) - basename_wo_ext, ext = os.path.splitext(basename) - dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_lanczos4{ext}") - image_debug.save(dest_file_name) - - # upscale - logger.info("Upscaling...") - upscaled_latents = upscaler.upscale( - vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size - ) - upscaled_latents /= 0.18215 - - # decode with batch - logger.info("Decoding...") - upscaled_images = [] - for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)): - with torch.no_grad(): - batch = vae.decode(upscaled_latents[i : i + args.vae_batch_size]).sample - batch = batch.to("cpu") - upscaled_images.append(batch) - upscaled_images = torch.cat(upscaled_images, dim=0) - - # tensor to numpy - upscaled_images = upscaled_images.permute(0, 2, 3, 1).numpy() - upscaled_images = (upscaled_images + 1.0) * 127.5 - upscaled_images = upscaled_images.clip(0, 255).astype(np.uint8) - - upscaled_images = upscaled_images[..., ::-1] - - # save images - for i, image in enumerate(upscaled_images): - basename = os.path.basename(image_paths[i]) - basename_wo_ext, ext = os.path.splitext(basename) - dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_upscaled{ext}") - cv2.imwrite(dest_file_name, image) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--vae_path", type=str, default=None, help="VAE path") - parser.add_argument("--weights", type=str, default=None, help="Weights path") - parser.add_argument("--image_pattern", type=str, default=None, help="Image pattern") - parser.add_argument("--output_dir", type=str, default=".", help="Output directory") - parser.add_argument("--batch_size", type=int, default=4, help="Batch size") - parser.add_argument("--vae_batch_size", type=int, default=1, help="VAE batch size") - parser.add_argument("--debug", action="store_true", help="Debug mode") - - args = parser.parse_args() - upscale_images(args) diff --git a/tools/merge_models.py b/tools/merge_models.py deleted file mode 100644 index 8f1fbf2f8..000000000 --- a/tools/merge_models.py +++ /dev/null @@ -1,171 +0,0 @@ -import argparse -import os - -import torch -from safetensors import safe_open -from safetensors.torch import load_file, save_file -from tqdm import tqdm -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -def is_unet_key(key): - # VAE or TextEncoder, the last one is for SDXL - return not ("first_stage_model" in key or "cond_stage_model" in key or 
"conditioner." in key) - - -TEXT_ENCODER_KEY_REPLACEMENTS = [ - ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."), - ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."), - ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."), -] - - -# support for models with different text encoder keys -def replace_text_encoder_key(key): - for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: - if key.startswith(rep_from): - return True, rep_to + key[len(rep_from) :] - return False, key - - -def merge(args): - if args.precision == "fp16": - dtype = torch.float16 - elif args.precision == "bf16": - dtype = torch.bfloat16 - else: - dtype = torch.float - - if args.saving_precision == "fp16": - save_dtype = torch.float16 - elif args.saving_precision == "bf16": - save_dtype = torch.bfloat16 - else: - save_dtype = torch.float - - # check if all models are safetensors - for model in args.models: - if not model.endswith("safetensors"): - logger.info(f"Model {model} is not a safetensors model") - exit() - if not os.path.isfile(model): - logger.info(f"Model {model} does not exist") - exit() - - assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models" - - # load and merge - ratio = 1.0 / len(args.models) # default - supplementary_key_ratios = {} # [key] = ratio, for keys not in all models, add later - - merged_sd = None - first_model_keys = set() # check missing keys in other models - for i, model in enumerate(args.models): - if args.ratios is not None: - ratio = args.ratios[i] - - if merged_sd is None: - # load first model - logger.info(f"Loading model {model}, ratio = {ratio}...") - merged_sd = {} - with safe_open(model, framework="pt", device=args.device) as f: - for key in tqdm(f.keys()): - value = f.get_tensor(key) - _, key = replace_text_encoder_key(key) - - first_model_keys.add(key) - - if not is_unet_key(key) and args.unet_only: - supplementary_key_ratios[key] = 1.0 # use first model's value for VAE or TextEncoder - continue - - value = ratio * value.to(dtype) # first model's value * ratio - merged_sd[key] = value - - logger.info(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else "")) - continue - - # load other models - logger.info(f"Loading model {model}, ratio = {ratio}...") - - with safe_open(model, framework="pt", device=args.device) as f: - model_keys = f.keys() - for key in tqdm(model_keys): - _, new_key = replace_text_encoder_key(key) - if new_key not in merged_sd: - if args.show_skipped and new_key not in first_model_keys: - logger.info(f"Skip: {new_key}") - continue - - value = f.get_tensor(key) - merged_sd[new_key] = merged_sd[new_key] + ratio * value.to(dtype) - - # enumerate keys not in this model - model_keys = set(model_keys) - for key in merged_sd.keys(): - if key in model_keys: - continue - logger.warning(f"Key {key} not in model {model}, use first model's value") - if key in supplementary_key_ratios: - supplementary_key_ratios[key] += ratio - else: - supplementary_key_ratios[key] = ratio - - # add supplementary keys' value (including VAE and TextEncoder) - if len(supplementary_key_ratios) > 0: - logger.info("add first model's value") - with safe_open(args.models[0], framework="pt", device=args.device) as f: - for key in tqdm(f.keys()): - _, new_key = replace_text_encoder_key(key) - if new_key not in supplementary_key_ratios: - continue - - if is_unet_key(new_key): 
# not VAE or TextEncoder - logger.warning(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}") - - value = f.get_tensor(key) # original key - - if new_key not in merged_sd: - merged_sd[new_key] = supplementary_key_ratios[new_key] * value.to(dtype) - else: - merged_sd[new_key] = merged_sd[new_key] + supplementary_key_ratios[new_key] * value.to(dtype) - - # save - output_file = args.output - if not output_file.endswith(".safetensors"): - output_file = output_file + ".safetensors" - - logger.info(f"Saving to {output_file}...") - - # convert to save_dtype - for k in merged_sd.keys(): - merged_sd[k] = merged_sd[k].to(save_dtype) - - save_file(merged_sd, output_file) - - logger.info("Done!") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Merge models") - parser.add_argument("--models", nargs="+", type=str, help="Models to merge") - parser.add_argument("--output", type=str, help="Output model") - parser.add_argument("--ratios", nargs="+", type=float, help="Ratios of models, default is equal, total = 1.0") - parser.add_argument("--unet_only", action="store_true", help="Only merge unet") - parser.add_argument("--device", type=str, default="cpu", help="Device to use, default is cpu") - parser.add_argument( - "--precision", type=str, default="float", choices=["float", "fp16", "bf16"], help="Calculation precision, default is float" - ) - parser.add_argument( - "--saving_precision", - type=str, - default="float", - choices=["float", "fp16", "bf16"], - help="Saving precision, default is float", - ) - parser.add_argument("--show_skipped", action="store_true", help="Show skipped keys (keys not in first model)") - - args = parser.parse_args() - merge(args) diff --git a/tools/original_control_net.py b/tools/original_control_net.py deleted file mode 100644 index 5640d542d..000000000 --- a/tools/original_control_net.py +++ /dev/null @@ -1,353 +0,0 @@ -from typing import List, NamedTuple, Any -import numpy as np -import cv2 -import torch -from safetensors.torch import load_file - -from library.original_unet import UNet2DConditionModel, SampleOutput - -import library.model_util as model_util -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -class ControlNetInfo(NamedTuple): - unet: Any - net: Any - prep: Any - weight: float - ratio: float - - -class ControlNet(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - - # make control model - self.control_model = torch.nn.Module() - - dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280] - zero_convs = torch.nn.ModuleList() - for i, dim in enumerate(dims): - sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)]) - zero_convs.append(sub_list) - self.control_model.add_module("zero_convs", zero_convs) - - middle_block_out = torch.nn.Conv2d(1280, 1280, 1) - self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out])) - - dims = [16, 16, 32, 32, 96, 96, 256, 320] - strides = [1, 1, 2, 1, 2, 1, 2, 1] - prev_dim = 3 - input_hint_block = torch.nn.Sequential() - for i, (dim, stride) in enumerate(zip(dims, strides)): - input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1)) - if i < len(dims) - 1: - input_hint_block.append(torch.nn.SiLU()) - prev_dim = dim - self.control_model.add_module("input_hint_block", input_hint_block) - - -def load_control_net(v2, unet, model): - device = unet.device - - # control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む - # state 
dictを読み込む - logger.info(f"ControlNet: loading control SD model : {model}") - - if model_util.is_safetensors(model): - ctrl_sd_sd = load_file(model) - else: - ctrl_sd_sd = torch.load(model, map_location="cpu") - ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd) - - # 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む - is_difference = "difference" in ctrl_sd_sd - logger.info(f"ControlNet: loading difference: {is_difference}") - - # ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく - # またTransfer Controlの元weightとなる - ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict()) - - # 元のU-Netに影響しないようにコピーする。またprefixが付いていないので付ける - for key in list(ctrl_unet_sd_sd.keys()): - ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone() - - zero_conv_sd = {} - for key in list(ctrl_sd_sd.keys()): - if key.startswith("control_"): - unet_key = "model.diffusion_" + key[len("control_") :] - if unet_key not in ctrl_unet_sd_sd: # zero conv - zero_conv_sd[key] = ctrl_sd_sd[key] - continue - if is_difference: # Transfer Control - ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype) - else: - ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype) - - unet_config = model_util.create_unet_diffusers_config(v2) - ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config) # DiffUsers版ControlNetのstate dict - - # ControlNetのU-Netを作成する - ctrl_unet = UNet2DConditionModel(**unet_config) - info = ctrl_unet.load_state_dict(ctrl_unet_du_sd) - logger.info(f"ControlNet: loading Control U-Net: {info}") - - # U-Net以外のControlNetを作成する - # TODO support middle only - ctrl_net = ControlNet() - info = ctrl_net.load_state_dict(zero_conv_sd) - logger.info("ControlNet: loading ControlNet: {info}") - - ctrl_unet.to(unet.device, dtype=unet.dtype) - ctrl_net.to(unet.device, dtype=unet.dtype) - return ctrl_unet, ctrl_net - - -def load_preprocess(prep_type: str): - if prep_type is None or prep_type.lower() == "none": - return None - - if prep_type.startswith("canny"): - args = prep_type.split("_") - th1 = int(args[1]) if len(args) >= 2 else 63 - th2 = int(args[2]) if len(args) >= 3 else 191 - - def canny(img): - img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) - return cv2.Canny(img, th1, th2) - - return canny - - logger.info(f"Unsupported prep type: {prep_type}") - return None - - -def preprocess_ctrl_net_hint_image(image): - image = np.array(image).astype(np.float32) / 255.0 - # ControlNetのサンプルはcv2を使っているが、読み込みはGradioなので実はRGBになっている - # image = image[:, :, ::-1].copy() # rgb to bgr - image = image[None].transpose(0, 3, 1, 2) # nchw - image = torch.from_numpy(image) - return image # 0 to 1 - - -def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_size, hints): - guided_hints = [] - for i, cnet_info in enumerate(control_nets): - # hintは 1枚目の画像のcnet1, 1枚目の画像のcnet2, 1枚目の画像のcnet3, 2枚目の画像のcnet1, 2枚目の画像のcnet2 ... 
と並んでいること - b_hints = [] - if len(hints) == 1: # すべて同じ画像をhintとして使う - hint = hints[0] - if cnet_info.prep is not None: - hint = cnet_info.prep(hint) - hint = preprocess_ctrl_net_hint_image(hint) - b_hints = [hint for _ in range(b_size)] - else: - for bi in range(b_size): - hint = hints[(bi * len(control_nets) + i) % len(hints)] - if cnet_info.prep is not None: - hint = cnet_info.prep(hint) - hint = preprocess_ctrl_net_hint_image(hint) - b_hints.append(hint) - b_hints = torch.cat(b_hints, dim=0) - b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype) - - guided_hint = cnet_info.net.control_model.input_hint_block(b_hints) - guided_hints.append(guided_hint) - return guided_hints - - -def call_unet_and_control_net( - step, - num_latent_input, - original_unet, - control_nets: List[ControlNetInfo], - guided_hints, - current_ratio, - sample, - timestep, - encoder_hidden_states, - encoder_hidden_states_for_control_net, -): - # ControlNet - # 複数のControlNetの場合は、出力をマージするのではなく交互に適用する - cnet_cnt = len(control_nets) - cnet_idx = step % cnet_cnt - cnet_info = control_nets[cnet_idx] - - # logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) - if cnet_info.ratio < current_ratio: - return original_unet(sample, timestep, encoder_hidden_states) - - guided_hint = guided_hints[cnet_idx] - - # gradual latent support: match the size of guided_hint to the size of sample - if guided_hint.shape[-2:] != sample.shape[-2:]: - # print(f"guided_hint.shape={guided_hint.shape}, sample.shape={sample.shape}") - org_dtype = guided_hint.dtype - if org_dtype == torch.bfloat16: - guided_hint = guided_hint.to(torch.float32) - guided_hint = torch.nn.functional.interpolate(guided_hint, size=sample.shape[-2:], mode="bicubic") - if org_dtype == torch.bfloat16: - guided_hint = guided_hint.to(org_dtype) - - guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1)) - outs = unet_forward( - True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net - ) - outs = [o * cnet_info.weight for o in outs] - - # U-Net - return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states) - - -""" - # これはmergeのバージョン - # ControlNet - cnet_outs_list = [] - for i, cnet_info in enumerate(control_nets): - # logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio) - if cnet_info.ratio < current_ratio: - continue - guided_hint = guided_hints[i] - outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states) - for i in range(len(outs)): - outs[i] *= cnet_info.weight - - cnet_outs_list.append(outs) - - count = len(cnet_outs_list) - if count == 0: - return original_unet(sample, timestep, encoder_hidden_states) - - # sum of controlnets - for i in range(1, count): - cnet_outs_list[0] += cnet_outs_list[i] - - # U-Net - return unet_forward(False, cnet_info.net, original_unet, None, cnet_outs_list[0], sample, timestep, encoder_hidden_states) -""" - - -def unet_forward( - is_control_net, - control_net: ControlNet, - unet: UNet2DConditionModel, - guided_hint, - ctrl_outs, - sample, - timestep, - encoder_hidden_states, -): - # copy from UNet2DConditionModel - default_overall_up_factor = 2**unet.num_upsamplers - - forward_upsample_size = False - upsample_size = None - - if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - logger.info("Forward upsample size to force interpolation output size.") - forward_upsample_size = True - - # 1. 
time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = unet.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=unet.dtype) - emb = unet.time_embedding(t_emb) - - outs = [] # output of ControlNet - zc_idx = 0 - - # 2. pre-process - sample = unet.conv_in(sample) - if is_control_net: - sample += guided_hint - outs.append(control_net.control_model.zero_convs[zc_idx][0](sample)) # , emb, encoder_hidden_states)) - zc_idx += 1 - - # 3. down - down_block_res_samples = (sample,) - for downsample_block in unet.down_blocks: - if downsample_block.has_cross_attention: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - if is_control_net: - for rs in res_samples: - outs.append(control_net.control_model.zero_convs[zc_idx][0](rs)) # , emb, encoder_hidden_states)) - zc_idx += 1 - - down_block_res_samples += res_samples - - # 4. mid - sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) - if is_control_net: - outs.append(control_net.control_model.middle_block_out[0](sample)) - return outs - - if not is_control_net: - sample += ctrl_outs.pop() - - # 5. up - for i, upsample_block in enumerate(unet.up_blocks): - is_final_block = i == len(unet.up_blocks) - 1 - - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - - if not is_control_net and len(ctrl_outs) > 0: - res_samples = list(res_samples) - apply_ctrl_outs = ctrl_outs[-len(res_samples) :] - ctrl_outs = ctrl_outs[: -len(res_samples)] - for j in range(len(res_samples)): - res_samples[j] = res_samples[j] + apply_ctrl_outs[j] - res_samples = tuple(res_samples) - - # if we have not reached the final block and need to forward the - # upsample size, we do it here - if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] - - if upsample_block.has_cross_attention: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - upsample_size=upsample_size, - ) - else: - sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size - ) - # 6. 
post-process - sample = unet.conv_norm_out(sample) - sample = unet.conv_act(sample) - sample = unet.conv_out(sample) - - return SampleOutput(sample=sample) diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py deleted file mode 100644 index 5eea8329a..000000000 --- a/tools/resize_images_to_resolution.py +++ /dev/null @@ -1,141 +0,0 @@ -import glob -import os -import cv2 -import argparse -import shutil -import math -from PIL import Image -import numpy as np -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False): - # Split the max_resolution string by "," and strip any whitespaces - max_resolutions = [res.strip() for res in max_resolution.split(',')] - - # # Calculate max_pixels from max_resolution string - # max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1]) - - # Create destination folder if it does not exist - if not os.path.exists(dst_img_folder): - os.makedirs(dst_img_folder) - - # Select interpolation method - if interpolation == 'lanczos4': - cv2_interpolation = cv2.INTER_LANCZOS4 - elif interpolation == 'cubic': - cv2_interpolation = cv2.INTER_CUBIC - else: - cv2_interpolation = cv2.INTER_AREA - - # Iterate through all files in src_img_folder - img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py - for filename in os.listdir(src_img_folder): - # Check if the image is png, jpg or webp etc... - if not filename.endswith(img_exts): - # Copy the file to the destination folder if not png, jpg or webp etc (.txt or .caption or etc.) - shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename)) - continue - - # Load image - image = Image.open(os.path.join(src_img_folder, filename)) - if not image.mode == "RGB": - image = image.convert("RGB") - img = np.array(image, np.uint8) - - base, _ = os.path.splitext(filename) - for max_resolution in max_resolutions: - # Calculate max_pixels from max_resolution string - max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1]) - - # Calculate current number of pixels - current_pixels = img.shape[0] * img.shape[1] - - # Calculate current resolution - current_resolution = (img.shape[0], img.shape[1]) - - # Calculate target resolution - target_resolution = (int(max_resolution.split("x")[0]), int(max_resolution.split("x")[1])) - - # Skip to the next image if the current resolution is less than the target resolution - if current_resolution[0] < target_resolution[0] or current_resolution[1] < target_resolution[1]: - print(f"Skipped image: {filename} as its resolution is smaller than target resolution") - continue - - # Check if the image needs resizing - if current_pixels > max_pixels: - # Calculate scaling factor - scale_factor = max_pixels / current_pixels - - # Calculate new dimensions - new_height = int(img.shape[0] * math.sqrt(scale_factor)) - new_width = int(img.shape[1] * math.sqrt(scale_factor)) - - # Resize image - img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) - else: - new_height, new_width = img.shape[0:2] - - # Calculate the new height and width that are divisible by divisible_by (with/without resizing) - new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by - new_width = new_width if new_width % 
divisible_by == 0 else new_width - new_width % divisible_by - - # Center crop the image to the calculated dimensions - y = int((img.shape[0] - new_height) / 2) - x = int((img.shape[1] - new_width) / 2) - img = img[y:y + new_height, x:x + new_width] - - # Split filename into base and extension - new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg') - - # Save resized image in dst_img_folder - # cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100]) - image = Image.fromarray(img) - image.save(os.path.join(dst_img_folder, new_filename), quality=100) - - proc = "Resized" if current_pixels > max_pixels else "Saved" - logger.info(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") - - # If other files with same basename, copy them with resolution suffix - if copy_associated_files: - asoc_files = glob.glob(os.path.join(src_img_folder, base + ".*")) - for asoc_file in asoc_files: - ext = os.path.splitext(asoc_file)[1] - if ext in img_exts: - continue - for max_resolution in max_resolutions: - new_asoc_file = base + '+' + max_resolution + ext - logger.info(f"Copy {asoc_file} as {new_asoc_file}") - shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file)) - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser( - description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします') - parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ') - parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images / リサイズ後の画像を保存するフォルダ') - parser.add_argument('--max_resolution', type=str, - help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128") - parser.add_argument('--divisible_by', type=int, - help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1) - parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'], - default='area', help='Interpolation method for resizing / リサイズ時の補完方法') - parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存') - parser.add_argument('--copy_associated_files', action='store_true', - help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする') - - return parser - - -def main(): - parser = setup_parser() - - args = parser.parse_args() - resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution, - args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files) - - -if __name__ == '__main__': - main() diff --git a/tools/show_metadata.py b/tools/show_metadata.py deleted file mode 100644 index 05bfbe0a4..000000000 --- a/tools/show_metadata.py +++ /dev/null @@ -1,23 +0,0 @@ -import json -import argparse -from safetensors import safe_open -from library.utils import setup_logging -setup_logging() -import logging -logger = logging.getLogger(__name__) - -parser = argparse.ArgumentParser() -parser.add_argument("--model", type=str, required=True) -args = parser.parse_args() - -with safe_open(args.model, framework="pt") as f: - metadata = f.metadata() - -if metadata is None: - logger.error("No metadata found") -else: - # metadata is json dict, but not pretty printed - # sort by key and pretty 
print - print(json.dumps(metadata, indent=4, sort_keys=True)) - - diff --git a/train_controlnet.py b/train_controlnet.py deleted file mode 100644 index dc73a91c8..000000000 --- a/train_controlnet.py +++ /dev/null @@ -1,620 +0,0 @@ -import argparse -import json -import math -import os -import random -import time -from multiprocessing import Value -from types import SimpleNamespace -import toml - -from tqdm import tqdm - -import torch -from library.device_utils import init_ipex, clean_memory_on_device -init_ipex() - -from torch.nn.parallel import DistributedDataParallel as DDP -from accelerate.utils import set_seed -from diffusers import DDPMScheduler, ControlNetModel -from safetensors.torch import load_file - -import library.model_util as model_util -import library.train_util as train_util -import library.config_util as config_util -from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, -) -import library.huggingface_util as huggingface_util -import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import ( - apply_snr_weight, - pyramid_noise_like, - apply_noise_offset, -) -from library.utils import setup_logging, add_logging_arguments - -setup_logging() -import logging - -logger = logging.getLogger(__name__) - - -# TODO 他のスクリプトと共通化する -def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): - logs = { - "loss/current": current_loss, - "loss/average": avr_loss, - "lr": lr_scheduler.get_last_lr()[0], - } - - if args.optimizer_type.lower().startswith("DAdapt".lower()): - logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] - - return logs - - -def train(args): - # session_id = random.randint(0, 2**32) - # training_started_at = time.time() - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) - setup_logging(args, reset=True) - - cache_latents = args.cache_latents - use_user_config = args.dataset_config is not None - - if args.seed is None: - args.seed = random.randint(0, 2**32) - set_seed(args.seed) - - tokenizer = train_util.load_tokenizer(args) - - # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) - if use_user_config: - logger.info(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "conditioning_data_dir"] - if any(getattr(args, attr) is not None for attr in ignored): - logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) - ) - else: - user_config = { - "datasets": [ - { - "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( - args.train_data_dir, - args.conditioning_data_dir, - args.caption_extension, - ) - } - ] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group) - return - if len(train_dataset_group) == 0: - logger.error( - "No data found. 
Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" - ) - return - - if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - # acceleratorを準備する - logger.info("prepare accelerator") - accelerator = train_util.prepare_accelerator(args) - is_main_process = accelerator.is_main_process - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) - - # モデルを読み込む - text_encoder, vae, unet, _ = train_util.load_target_model( - args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=True - ) - - # DiffusersのControlNetが使用するデータを準備する - if args.v2: - unet.config = { - "act_fn": "silu", - "attention_head_dim": [5, 10, 20, 20], - "block_out_channels": [320, 640, 1280, 1280], - "center_input_sample": False, - "cross_attention_dim": 1024, - "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], - "downsample_padding": 1, - "dual_cross_attention": False, - "flip_sin_to_cos": True, - "freq_shift": 0, - "in_channels": 4, - "layers_per_block": 2, - "mid_block_scale_factor": 1, - "norm_eps": 1e-05, - "norm_num_groups": 32, - "num_class_embeds": None, - "only_cross_attention": False, - "out_channels": 4, - "sample_size": 96, - "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], - "use_linear_projection": True, - "upcast_attention": True, - "only_cross_attention": False, - "downsample_padding": 1, - "use_linear_projection": True, - "class_embed_type": None, - "num_class_embeds": None, - "resnet_time_scale_shift": "default", - "projection_class_embeddings_input_dim": None, - } - else: - unet.config = { - "act_fn": "silu", - "attention_head_dim": 8, - "block_out_channels": [320, 640, 1280, 1280], - "center_input_sample": False, - "cross_attention_dim": 768, - "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], - "downsample_padding": 1, - "flip_sin_to_cos": True, - "freq_shift": 0, - "in_channels": 4, - "layers_per_block": 2, - "mid_block_scale_factor": 1, - "norm_eps": 1e-05, - "norm_num_groups": 32, - "out_channels": 4, - "sample_size": 64, - "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], - "only_cross_attention": False, - "downsample_padding": 1, - "use_linear_projection": False, - "class_embed_type": None, - "num_class_embeds": None, - "upcast_attention": False, - "resnet_time_scale_shift": "default", - "projection_class_embeddings_input_dim": None, - } - unet.config = SimpleNamespace(**unet.config) - - controlnet = ControlNetModel.from_unet(unet) - - if args.controlnet_model_name_or_path: - filename = args.controlnet_model_name_or_path - if os.path.isfile(filename): - if os.path.splitext(filename)[1] == ".safetensors": - state_dict = load_file(filename) - else: - state_dict = torch.load(filename) - state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict) - controlnet.load_state_dict(state_dict) - elif os.path.isdir(filename): - controlnet = ControlNetModel.from_pretrained(filename) - - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - - # 学習を準備する - if cache_latents: - 
vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents( - vae, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - ) - vae.to("cpu") - clean_memory_on_device(accelerator.device) - - accelerator.wait_for_everyone() - - if args.gradient_checkpointing: - controlnet.enable_gradient_checkpointing() - - # 学習に必要なクラスを準備する - accelerator.print("prepare optimizer, data loader etc.") - - trainable_params = controlnet.parameters() - - _, _, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps - ) - accelerator.print( - f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" - ) - - # データセット側にも学習ステップを送信 - train_dataset_group.set_max_train_steps(args.max_train_steps) - - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - - # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする - if args.full_fp16: - assert ( - args.mixed_precision == "fp16" - ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - accelerator.print("enable full fp16 training.") - controlnet.to(weight_dtype) - - # acceleratorがなんかよろしくやってくれるらしい - controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - controlnet, optimizer, train_dataloader, lr_scheduler - ) - - unet.requires_grad_(False) - text_encoder.requires_grad_(False) - unet.to(accelerator.device) - text_encoder.to(accelerator.device) - - # transform DDP after prepare - controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet - - controlnet.train() - - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - # TODO: find a way to handle total batch size when there are multiple datasets - accelerator.print("running training / 学習開始") - accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print( - f" batch size per device / 
バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" - ) - # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - progress_bar = tqdm( - range(args.max_train_steps), - smoothing=0, - disable=not accelerator.is_local_main_process, - desc="steps", - ) - global_step = 0 - - noise_scheduler = DDPMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=1000, - clip_sample=False, - ) - if accelerator.is_main_process: - init_kwargs = {} - if args.wandb_run_name: - init_kwargs["wandb"] = {"name": args.wandb_run_name} - if args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers( - "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs - ) - - loss_recorder = train_util.LossRecorder() - del train_dataset_group - - # function for saving/removing - def save_model(ckpt_name, model, force_sync_upload=False): - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - accelerator.print(f"\nsaving checkpoint: {ckpt_file}") - - state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict()) - - if save_dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v - - if os.path.splitext(ckpt_file)[1] == ".safetensors": - from safetensors.torch import save_file - - save_file(state_dict, ckpt_file) - else: - torch.save(state_dict, ckpt_file) - - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) - - def remove_model(old_ckpt_name): - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - accelerator.print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - # For --sample_at_first - train_util.sample_images( - accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet - ) - - # training loop - for epoch in range(num_train_epochs): - if is_main_process: - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch + 1 - - for step, batch in enumerate(train_dataloader): - current_step.value = global_step - with accelerator.accumulate(controlnet): - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - if args.noise_offset: - noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) - elif args.multires_noise_iterations: - noise = pyramid_noise_like( - noise, - latents.device, - args.multires_noise_iterations, - args.multires_noise_discount, - ) 
- - # Sample a random timestep for each image - timesteps = torch.randint( - 0, - noise_scheduler.config.num_train_timesteps, - (b_size,), - device=latents.device, - ) - timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) - - with accelerator.autocast(): - down_block_res_samples, mid_block_res_sample = controlnet( - noisy_latents, - timesteps, - encoder_hidden_states=encoder_hidden_states, - controlnet_cond=controlnet_image, - return_dict=False, - ) - - # Predict the noise residual - noise_pred = unet( - noisy_latents, - timesteps, - encoder_hidden_states, - down_block_additional_residuals=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples], - mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), - ).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = controlnet.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - train_util.sample_images( - accelerator, - args, - None, - global_step, - accelerator.device, - vae, - tokenizer, - text_encoder, - unet, - controlnet=controlnet, - ) - - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) - save_model( - ckpt_name, - accelerator.unwrap_model(controlnet), - ) - - if args.save_state: - train_util.save_and_remove_state_stepwise(args, accelerator, global_step) - - remove_step_no = train_util.get_remove_step_no(args, global_step) - if remove_step_no is not None: - remove_ckpt_name = train_util.get_step_ckpt_name(args, "." 
+ args.save_model_as, remove_step_no) - remove_model(remove_ckpt_name) - - current_loss = loss.detach().item() - loss_recorder.add(epoch=epoch, step=step, loss=current_loss) - avr_loss: float = loss_recorder.moving_average - logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if args.logging_dir is not None: - logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) - accelerator.log(logs, step=global_step) - - if global_step >= args.max_train_steps: - break - - if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) - - accelerator.wait_for_everyone() - - # 指定エポックごとにモデルを保存 - if args.save_every_n_epochs is not None: - saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs - if is_main_process and saving: - ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) - save_model(ckpt_name, accelerator.unwrap_model(controlnet)) - - remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) - if remove_epoch_no is not None: - remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) - remove_model(remove_ckpt_name) - - if args.save_state: - train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - - train_util.sample_images( - accelerator, - args, - epoch + 1, - global_step, - accelerator.device, - vae, - tokenizer, - text_encoder, - unet, - controlnet=controlnet, - ) - - # end of epoch - if is_main_process: - controlnet = accelerator.unwrap_model(controlnet) - - accelerator.end_training() - - if is_main_process and args.save_state: - train_util.save_state_on_train_end(args, accelerator) - - # del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく - - if is_main_process: - ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) - save_model(ckpt_name, controlnet, force_sync_upload=True) - - logger.info("model saved.") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - add_logging_arguments(parser) - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, False, True, True) - train_util.add_training_arguments(parser, False) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) - custom_train_functions.add_custom_train_arguments(parser) - - parser.add_argument( - "--save_model_as", - type=str, - default="safetensors", - choices=[None, "ckpt", "pt", "safetensors"], - help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", - ) - parser.add_argument( - "--controlnet_model_name_or_path", - type=str, - default=None, - help="controlnet model name or path / controlnetのモデル名またはパス", - ) - parser.add_argument( - "--conditioning_data_dir", - type=str, - default=None, - help="conditioning data directory / 条件付けデータのディレクトリ", - ) - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - args = train_util.read_config_from_file(args, parser) - - train(args) diff --git a/train_db.py b/train_db.py deleted file mode 100644 index 8d36097a5..000000000 --- a/train_db.py +++ /dev/null @@ -1,504 +0,0 @@ -# DreamBooth training -# XXX dropped option: fine_tune - -import argparse -import itertools -import math -import os -from multiprocessing import Value -import toml - -from tqdm import tqdm - -import torch -from library.device_utils import init_ipex, clean_memory_on_device -init_ipex() - -from accelerate.utils import set_seed -from diffusers import DDPMScheduler - -import library.train_util as train_util -import library.config_util as config_util -from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, -) -import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import ( - apply_snr_weight, - get_weighted_text_embeddings, - prepare_scheduler_for_custom_training, - pyramid_noise_like, - apply_noise_offset, - scale_v_prediction_loss_like_noise_prediction, - apply_debiased_estimation, -) -from library.utils import setup_logging, add_logging_arguments - -setup_logging() -import logging - -logger = logging.getLogger(__name__) - -# perlin_noise, - - -def train(args): - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, False) - setup_logging(args, reset=True) - - cache_latents = args.cache_latents - - if args.seed is not None: - set_seed(args.seed) # 乱数系列を初期化する - - tokenizer = train_util.load_tokenizer(args) - - # データセットを準備する - if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True)) - if args.dataset_config is not None: - logger.info(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir"] - if any(getattr(args, attr) is not None for attr in ignored): - logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) - ) - else: - user_config = { - "datasets": [ - {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} - ] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = 
config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) - - if args.no_token_padding: - train_dataset_group.disable_token_padding() - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group) - return - - if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - # acceleratorを準備する - logger.info("prepare accelerator") - - if args.gradient_accumulation_steps > 1: - logger.warning( - f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong" - ) - logger.warning( - f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です" - ) - - accelerator = train_util.prepare_accelerator(args) - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) - vae_dtype = torch.float32 if args.no_half_vae else weight_dtype - - # モデルを読み込む - text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator) - - # verify load/save model formats - if load_stable_diffusion_format: - src_stable_diffusion_ckpt = args.pretrained_model_name_or_path - src_diffusers_model_path = None - else: - src_stable_diffusion_ckpt = None - src_diffusers_model_path = args.pretrained_model_name_or_path - - if args.save_model_as is None: - save_stable_diffusion_format = load_stable_diffusion_format - use_safetensors = args.use_safetensors - else: - save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors" - use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) - - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=vae_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") - clean_memory_on_device(accelerator.device) - - accelerator.wait_for_everyone() - - # 学習を準備する:モデルを適切な状態にする - train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0 - unet.requires_grad_(True) # 念のため追加 - text_encoder.requires_grad_(train_text_encoder) - if not train_text_encoder: - accelerator.print("Text Encoder is not trained.") - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - text_encoder.gradient_checkpointing_enable() - - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) - - # 学習に必要なクラスを準備する - accelerator.print("prepare optimizer, data loader etc.") - if train_text_encoder: - if args.learning_rate_te is None: - # wightout list, adamw8bit is crashed - trainable_params = list(itertools.chain(unet.parameters(), 
text_encoder.parameters())) - else: - trainable_params = [ - {"params": list(unet.parameters()), "lr": args.learning_rate}, - {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te}, - ] - else: - trainable_params = unet.parameters() - - _, _, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps - ) - accelerator.print( - f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" - ) - - # データセット側にも学習ステップを送信 - train_dataset_group.set_max_train_steps(args.max_train_steps) - - if args.stop_text_encoder_training is None: - args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end - - # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - - # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする - if args.full_fp16: - assert ( - args.mixed_precision == "fp16" - ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - accelerator.print("enable full fp16 training.") - unet.to(weight_dtype) - text_encoder.to(weight_dtype) - - # acceleratorがなんかよろしくやってくれるらしい - if train_text_encoder: - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler - ) - else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) - - if not train_text_encoder: - text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - accelerator.print("running training / 学習開始") - accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - accelerator.print( - f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): 
{total_batch_size}" - ) - accelerator.print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 - - noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False - ) - prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) - if args.zero_terminal_snr: - custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) - - if accelerator.is_main_process: - init_kwargs = {} - if args.wandb_run_name: - init_kwargs["wandb"] = {"name": args.wandb_run_name} - if args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) - - # For --sample_at_first - train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) - - loss_recorder = train_util.LossRecorder() - for epoch in range(num_train_epochs): - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch + 1 - - # 指定したステップ数までText Encoderを学習する:epoch最初の状態 - unet.train() - # train==True is required to enable gradient_checkpointing - if args.gradient_checkpointing or global_step < args.stop_text_encoder_training: - text_encoder.train() - - for step, batch in enumerate(train_dataloader): - current_step.value = global_step - # 指定したステップ数でText Encoderの学習を止める - if global_step == args.stop_text_encoder_training: - accelerator.print(f"stop text encoder training at step {global_step}") - if not args.gradient_checkpointing: - text_encoder.train(False) - text_encoder.requires_grad_(False) - - with accelerator.accumulate(unet): - with torch.no_grad(): - # latentに変換 - if cache_latents: - latents = batch["latents"].to(accelerator.device) - else: - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - - # Get the text embedding for conditioning - with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): - if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - 
loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - if train_text_encoder: - params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters()) - else: - params_to_clip = unet.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet - ) - - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - train_util.save_sd_model_on_epoch_end_or_stepwise( - args, - False, - accelerator, - src_path, - save_stable_diffusion_format, - use_safetensors, - save_dtype, - epoch, - num_train_epochs, - global_step, - accelerator.unwrap_model(text_encoder), - accelerator.unwrap_model(unet), - vae, - ) - - current_loss = loss.detach().item() - if args.logging_dir is not None: - logs = {"loss": current_loss} - train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) - accelerator.log(logs, step=global_step) - - loss_recorder.add(epoch=epoch, step=step, loss=current_loss) - avr_loss: float = loss_recorder.moving_average - logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if global_step >= args.max_train_steps: - break - - if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) - - accelerator.wait_for_everyone() - - if args.save_every_n_epochs is not None: - if accelerator.is_main_process: - # checking for saving is in util - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - train_util.save_sd_model_on_epoch_end_or_stepwise( - args, - True, - accelerator, - src_path, - save_stable_diffusion_format, - use_safetensors, - save_dtype, - epoch, - num_train_epochs, - global_step, - accelerator.unwrap_model(text_encoder), - accelerator.unwrap_model(unet), - vae, - ) - - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) - - is_main_process = accelerator.is_main_process - if is_main_process: - unet = accelerator.unwrap_model(unet) - text_encoder = accelerator.unwrap_model(text_encoder) - - accelerator.end_training() - - if args.save_state and is_main_process: - train_util.save_state_on_train_end(args, accelerator) - - del accelerator # この後メモリを使うのでこれは消す - - if is_main_process: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - 
train_util.save_sd_model_on_train_end( - args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae - ) - logger.info("model saved.") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - add_logging_arguments(parser) - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, True, False, True) - train_util.add_training_arguments(parser, True) - train_util.add_sd_saving_arguments(parser) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) - custom_train_functions.add_custom_train_arguments(parser) - - parser.add_argument( - "--learning_rate_te", - type=float, - default=None, - help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ", - ) - parser.add_argument( - "--no_token_padding", - action="store_true", - help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)", - ) - parser.add_argument( - "--stop_text_encoder_training", - type=int, - default=None, - help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない", - ) - parser.add_argument( - "--no_half_vae", - action="store_true", - help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", - ) - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - args = train_util.read_config_from_file(args, parser) - - train(args) diff --git a/train_db_README.md b/train_db_README.md deleted file mode 100644 index 7c3be2e3b..000000000 --- a/train_db_README.md +++ /dev/null @@ -1,309 +0,0 @@ -A guide to DreamBooth. The same procedure is used for training additional networks such as LoRA. - -# overview - -The main functions of the script are as follows. - -- Memory saving by 8bit Adam optimizer and latent cache (similar to ShivamShirao's version). -- Saved memory by xformers. -- Study in any size, not just 512x512. -- Quality improvement with augmentation. -- Supports fine tuning of Text Encoder+U-Net as well as DreamBooth. -- Read and write models in StableDiffusion format. -- Aspect Ratio Bucketing. -- Supports Stable Diffusion v2.0. - -# learning procedure - -## step 1. Environment improvement - -See the README in this repository. - - -## step 2. Determine identifier and class - -Decide the word identifier that connects the target you want to learn and the class to which the target belongs. - -(There are various names such as instance, but for the time being I will stick to the original paper.) - -Here's a very brief explanation (look it up for more details). - -class is the general type to learn. For example, if you want to learn a specific breed of dog, the class will be dog. Anime characters will be boy, girl, 1boy or 1girl depending on the model. - -The identifier is for identifying and learning the learning target. Any word is fine, but according to the original paper, ``a rare word with 3 letters or less that becomes one token with tokinizer'' is good. - -By using the identifier and class to train the model, for example, "shs dog", you can learn by identifying the object you want to learn from the class. - -When generating an image, if you say "shs dog", an image of the learned dog breed will be generated. - -(For reference, the identifier I use these days is ``shs sts scs cpc coc cic msm usu ici lvl cic dii muk ori hru rik koo yos wny``.) 
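To check the "one token" guidance, a quick test with the tokenizer can help. The sketch below is illustrative only and assumes the standard SD v1.x CLIP tokenizer from the `transformers` library (`openai/clip-vit-large-patch14`); it simply counts how many tokens a candidate identifier produces:

```
# Illustrative check (assumes the SD v1.x CLIP tokenizer; adjust for other models).
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

def token_count(word: str) -> int:
    # The tokenizer adds begin/end-of-text tokens; exclude them from the count.
    return len(tokenizer(word)["input_ids"]) - 2

for candidate in ["shs", "sls", "cpc"]:
    print(candidate, "->", token_count(candidate), "token(s)")
```

Candidates that map to a single token are preferable as identifiers.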
- -## step 3. Prepare images for training -Create a folder to store training images. __In addition, create a directory with the following name: - -``` -_ -``` - -Don't forget the ``_`` between them. - -The number of repetitions is specified to match the number of regularized images (described later). - -For example, at the prompt "sls frog", to repeat the data 20 times, it would be "20_sls frog". It will be as follows. - -![image](https://user-images.githubusercontent.com/52813779/210770636-1c851377-5936-4c15-90b7-8ac8ad6c2074.png) - -## step 4. Preparing regularized images -This is the procedure when using a regularized image. It is also possible to learn without using the regularization image (the whole target class is affected because it is impossible to distinguish without using the regularization image). - -Create a folder to store the regularized images. __In addition, __ create a directory named ``_``. - -For example, with the prompt "frog" and without repeating the data (just once): - -![image](https://user-images.githubusercontent.com/52813779/210770897-329758e5-3675-49f1-b345-c135f1725832.png) - -Specify the number of iterations so that " __ number of iterations of training images x number of training images ≥ number of iterations of regularization images x number of regularization images __". - -(The number of data in one epoch is "number of repetitions of training images x number of training images". If the number of regularization images is more than that, the remaining regularization images will not be used.) - -## step 5. Run training -Run the script. The maximally memory-saving command looks like this (actually typed on one line): - -*The command for learning additional networks such as LoRA is ``train_network.py`` instead of ``train_db.py``. You will also need additional network_\* options, so please refer to LoRA's guide. - -``` -accelerate launch --num_cpu_threads_per_process 8 train_db.py - --pretrained_model_name_or_path= - --train_data_dir= - --reg_data_dir= - --output_dir= - --prior_loss_weight=1.0 - --resolution=512 - --train_batch_size=1 - --learning_rate=1e-6 - --max_train_steps=1600 - --use_8bit_adam - --xformers - --mixed_precision="bf16" - --cache_latents - --gradient_checkpointing -``` - -It seems to be good to specify the number of CPU cores for num_cpu_threads_per_process. - -Specify the model to perform additional training in pretrained_model_name_or_path. You can specify a Stable Diffusion checkpoint file (.ckpt or .safetensors), a model directory on the Diffusers local disk, or a Diffusers model ID (such as "stabilityai/stable-diffusion-2"). The saved model after training will be saved in the same format as the original model by default (can be changed with the save_model_as option). - -prior_loss_weight is the loss weight of the regularized image. Normally, specify 1.0. - -resolution will be the size of the image (resolution, width and height). If bucketing (described later) is not used, use this size for training images and regularization images. - -train_batch_size is the training batch size. Set max_train_steps to 1600. The learning rate learning_rate is 5e-6 in the diffusers version and 1e-6 in the StableDiffusion version, so 1e-6 is specified here. - -Specify mixed_precision="bf16" (or "fp16") and gradient_checkpointing for memory saving. - -Specify the xformers option and use xformers' CrossAttention. 
If you don't have xformers installed, if you get an error (without mixed_precision, it was an error in my environment), specify the mem_eff_attn option instead to use the memory-saving version of CrossAttention (speed will be slower) . - -Cache VAE output with cache_latents option to save memory. - -If you have a certain amount of memory, specify it as follows, for example. - -``` -accelerate launch --num_cpu_threads_per_process 8 train_db.py - --pretrained_model_name_or_path= - --train_data_dir= - --reg_data_dir= - --output_dir= - --prior_loss_weight=1.0 - --resolution=512 - --train_batch_size=4 - --learning_rate=1e-6 - --max_train_steps=400 - --use_8bit_adam - --xformers - --mixed_precision="bf16" - --cache_latents -``` - -Remove gradient_checkpointing to speed up (memory usage will increase). Increase the batch size to improve speed and accuracy. - -An example of using bucketing (see below) and using augmentation (see below) looks like this: - -``` -accelerate launch --num_cpu_threads_per_process 8 train_db.py - --pretrained_model_name_or_path= - --train_data_dir= - --reg_data_dir= - --output_dir= - --resolution=768,512 - --train_batch_size=20 --learning_rate=5e-6 --max_train_steps=800 - --use_8bit_adam --xformers --mixed_precision="bf16" - --save_every_n_epochs=1 --save_state --save_precision="bf16" - --logging_dir=logs - --enable_bucket --min_bucket_reso=384 --max_bucket_reso=1280 - --color_aug --flip_aug --gradient_checkpointing --seed 42 -``` - -### About the number of steps -To save memory, the number of training steps per step is half that of train_drebooth.py (because the target image and the regularization image are divided into different batches instead of the same batch). -Double the number of steps to get almost the same training as the original Diffusers version and XavierXiao's StableDiffusion version. - -(Strictly speaking, the order of the data changes due to shuffle=True, but I don't think it has a big impact on learning.) - -## Generate an image with the trained model - -Name last.ckpt in the specified folder when learning is completed will output the checkpoint (if you learned the DiffUsers version model, it will be the last folder). - -For v1.4/1.5 and other derived models, this model can be inferred by Automatic1111's WebUI, etc. Place it in the models\Stable-diffusion folder. - -When generating images with WebUI with the v2.x model, a separate .yaml file that describes the model specifications is required. Place v2-inference.yaml for v2.x base and v2-inference-v.yaml for 768/v in the same folder and make the part before the extension the same name as the model. - -![image](https://user-images.githubusercontent.com/52813779/210776915-061d79c3-6582-42c2-8884-8b91d2f07313.png) - -Each yaml file can be found at [https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion] (Stability AI SD2.0 repository). - -# Other study options - -## Supports Stable Diffusion 2.0 --v2 / --v_parameterization - -Specify the v2 option when using Hugging Face's stable-diffusion-2-base, and specify both the v2 and v_parameterization options when using stable-diffusion-2 or 768-v-ema.ckpt. - -In addition, learning SD 2.0 seems to be difficult with VRAM 12GB because the Text Encoder is getting bigger. - -The following points have changed significantly in Stable Diffusion 2.0. - -1. Tokenizer to use -2. Which Text Encoder to use and which output layer to use (2.0 uses the penultimate layer) -3. Output dimensionality of Text Encoder (768->1024) -4. 
Structure of U-Net (number of heads of CrossAttention, etc.) -5. v-parameterization (the sampling method seems to have changed) - -Among these, 1 to 4 are adopted for base, and 1 to 5 are adopted for the one without base (768-v). Enabling 1-4 is the v2 option, and enabling 5 is the v_parameterization option. - -## check training data --debug_dataset - -By adding this option, you can check what kind of image data and captions will be learned in advance before learning. Press Esc to exit and return to the command line. - -*Please note that it seems to hang when executed in an environment where there is no screen such as Colab. - -## Stop training Text Encoder --stop_text_encoder_training - -If you specify a numerical value for the stop_text_encoder_training option, after that number of steps, only the U-Net will be trained without training the Text Encoder. In some cases, the accuracy may be improved. - -(Probably only the Text Encoder may overfit first, and I guess that it can be prevented, but the detailed impact is unknown.) - -## Load and learn VAE separately --vae -If you specify either a Stable Diffusion checkpoint, a VAE checkpoint file, a Diffuses model, or a VAE (both of which can specify a local or Hugging Face model ID) in the vae option, that VAE is used for learning (latents when caching or getting latents during learning). -The saved model will incorporate this VAE. - -## save during learning --save_every_n_epochs / --save_state / --resume -Specifying a number for the save_every_n_epochs option saves the model during training every epoch. - -If you specify the save_state option at the same time, the learning state including the state of the optimizer etc. will be saved together (compared to restarting learning from the checkpoint, you can expect to improve accuracy and shorten the learning time). The learning state is output in a folder named "epoch-??????-state" (?????? is the number of epochs) in the destination folder. Please use it when studying for a long time. - -Use the resume option to resume training from a saved training state. Please specify the learning state folder. - -Please note that due to the specifications of Accelerator (?), the number of epochs and global step are not saved, and it will start from 1 even when you resume. - -## No tokenizer padding --no_token_padding - -The no_token_padding option does not pad the output of the Tokenizer (same behavior as Diffusers version of old DreamBooth). - -## Training with arbitrary size images --resolution - -You can study outside the square. Please specify "width, height" like "448,640" in resolution. Width and height must be divisible by 64. Match the size of the training image and the regularization image. - -Personally, I often generate vertically long images, so I sometimes learn with "448, 640". - -## Aspect Ratio Bucketing --enable_bucket / --min_bucket_reso / --max_bucket_reso - -It is enabled by specifying the enable_bucket option. Stable Diffusion is trained at 512x512, but also at resolutions such as 256x768 and 384x640. - -If you specify this option, you do not need to unify the training images and regularization images to a specific resolution. Choose from several resolutions (aspect ratios) and learn at that resolution. -Since the resolution is 64 pixels, the aspect ratio may not be exactly the same as the original image. - -You can specify the minimum size of the resolution with the min_bucket_reso option and the maximum size with the max_bucket_reso. The defaults are 256 and 1024 respectively. 
-For example, specifying a minimum size of 384 will not use resolutions such as 256x1024 or 320x768. -If you increase the resolution to 768x768, you may want to specify 1280 as the maximum size. - -When Aspect Ratio Bucketing is enabled, it may be better to prepare regularization images with various resolutions that are similar to the training images. - -(Because the images in one batch are not biased toward training images and regularization images. - -## augmentation --color_aug / --flip_aug - -Augmentation is a method of improving model performance by dynamically changing data during learning. Learn while subtly changing the hue with color_aug and flipping left and right with flip_aug. - -Since the data changes dynamically, it cannot be specified together with the cache_latents option. - -## Specify data precision when saving --save_precision - -Specifying float, fp16, or bf16 as the save_precision option will save the checkpoint in that format (only when saving in Stable Diffusion format). Please use it when you want to reduce the size of checkpoint. - -## save in any format --save_model_as - -Specify the save format of the model. Specify one of ckpt, safetensors, diffusers, diffusers_safetensors. - -When reading Stable Diffusion format (ckpt or safetensors) and saving in Diffusers format, missing information is supplemented by dropping v1.5 or v2.1 information from Hugging Face. - -## Save learning log --logging_dir / --log_prefix - -Specify the log save destination folder in the logging_dir option. Logs in TensorBoard format are saved. - -For example, if you specify --logging_dir=logs, a logs folder will be created in your working folder, and logs will be saved in the date/time folder. -Also, if you specify the --log_prefix option, the specified string will be added before the date and time. Use "--logging_dir=logs --log_prefix=db_style1_" for identification. - -To check the log with TensorBoard, open another command prompt and enter the following in the working folder (I think tensorboard is installed when Diffusers is installed, but if it is not installed, pip install Please put it in tensorboard). - -``` -tensorboard --logdir=logs -``` - -Then open your browser and go to http://localhost:6006/ to see it. - -## scheduler related specification of learning rate --lr_scheduler / --lr_warmup_steps - -You can choose the learning rate scheduler from linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup with the lr_scheduler option. Default is constant. With lr_warmup_steps, you can specify the number of steps to warm up the scheduler (gradually changing the learning rate). Please do your own research for details. - -## Training with fp16 gradient (experimental feature) --full_fp16 - -The full_fp16 option will change the gradient from normal float32 to float16 (fp16) and learn (it seems to be full fp16 learning instead of mixed precision). -As a result, it seems that the SD1.x 512x512 size can be learned with a VRAM usage of less than 8GB, and the SD2.x 512x512 size can be learned with a VRAM usage of less than 12GB. - -Specify fp16 in the accelerate config beforehand and optionally set ``mixed_precision="fp16"`` (bf16 does not work). - -To minimize memory usage, use xformers, use_8bit_adam, cache_latents, gradient_checkpointing options and set train_batch_size to 1. - -(If you can afford it, increasing the train_batch_size step by step should improve the accuracy a little.) 
- -It is realized by patching the PyTorch source (confirmed with PyTorch 1.12.1 and 1.13.0). Accuracy will drop considerably, and the probability of learning failure on the way will also increase. -The setting of the learning rate and the number of steps seems to be severe. Please be aware of them and use them at your own risk. - -# Other learning methods - -## Learning multiple classes, multiple identifiers - -The method is simple, multiple folders with ``Repetition count_ `` in the training image folder, and a folder with ``Repetition count_`` in the regularization image folder. Please prepare multiple - -For example, learning "sls frog" and "cpc rabbit" at the same time would look like this: - -![image](https://user-images.githubusercontent.com/52813779/210777933-a22229db-b219-4cd8-83ca-e87320fc4192.png) - -If you have one class and multiple targets, you can have only one regularized image folder. For example, if 1girl has character A and character B, do as follows. - -- train_girls - - 10_sls 1girl - - 10_cpc 1girl -- reg_girls - -1_1girl - -If the number of data varies, it seems that good results can be obtained by adjusting the number of repetitions to unify the number of sheets for each class and identifier. - -## Use captions in DreamBooth - -If you put a file with the same file name as the image and the extension .caption (you can change it in the option) in the training image and regularization image folders, the caption will be read from that file and learned as a prompt. - -* The folder name (identifier class) will no longer be used for training those images. - -Adding captions to each image (you can use BLIP, etc.) may help clarify the attributes you want to learn. - -Caption files have a .caption extension by default. You can change it with --caption_extension. With the --shuffle_caption option, study captions during learning while shuffling each part separated by commas. 
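For illustration only, a sidecar caption file could be written with a small helper like the sketch below (the helper and the sample caption are hypothetical; only the naming rule, same file name as the image plus the caption extension, comes from the options above):

```
# Hypothetical helper: write a caption file next to a training image.
# The caption file shares the image's file name and uses the caption
# extension (default ".caption", configurable with --caption_extension).
from pathlib import Path

def write_caption(image_path: str, caption: str, ext: str = ".caption") -> Path:
    caption_path = Path(image_path).with_suffix(ext)
    caption_path.parent.mkdir(parents=True, exist_ok=True)
    caption_path.write_text(caption, encoding="utf-8")
    return caption_path

# e.g. train_girls/10_sls 1girl/0001.png -> train_girls/10_sls 1girl/0001.caption
write_caption("train_girls/10_sls 1girl/0001.png", "sls 1girl, standing, smiling")
```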
\ No newline at end of file diff --git a/train_network.py b/train_network.py deleted file mode 100644 index e0fa69458..000000000 --- a/train_network.py +++ /dev/null @@ -1,1058 +0,0 @@ -import importlib -import argparse -import math -import os -import sys -import random -import time -import json -from multiprocessing import Value -import toml - -from tqdm import tqdm - -import torch -from library.device_utils import init_ipex, clean_memory_on_device -init_ipex() - -from torch.nn.parallel import DistributedDataParallel as DDP - -from accelerate.utils import set_seed -from diffusers import DDPMScheduler -from library import model_util - -import library.train_util as train_util -from library.train_util import ( - DreamBoothDataset, -) -import library.config_util as config_util -from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, -) -import library.huggingface_util as huggingface_util -import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import ( - apply_snr_weight, - get_weighted_text_embeddings, - prepare_scheduler_for_custom_training, - scale_v_prediction_loss_like_noise_prediction, - add_v_prediction_like_loss, - apply_debiased_estimation, -) -from library.utils import setup_logging, add_logging_arguments - -setup_logging() -import logging - -logger = logging.getLogger(__name__) - - -class NetworkTrainer: - def __init__(self): - self.vae_scale_factor = 0.18215 - self.is_sdxl = False - - # TODO 他のスクリプトと共通化する - def generate_step_logs( - self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None - ): - logs = {"loss/current": current_loss, "loss/average": avr_loss} - - if keys_scaled is not None: - logs["max_norm/keys_scaled"] = keys_scaled - logs["max_norm/average_key_norm"] = mean_norm - logs["max_norm/max_key_norm"] = maximum_norm - - lrs = lr_scheduler.get_last_lr() - - if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block) - if args.network_train_unet_only: - logs["lr/unet"] = float(lrs[0]) - elif args.network_train_text_encoder_only: - logs["lr/textencoder"] = float(lrs[0]) - else: - logs["lr/textencoder"] = float(lrs[0]) - logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder - - if ( - args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() - ): # tracking d*lr value of unet. 
- logs["lr/d*lr"] = ( - lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] - ) - else: - idx = 0 - if not args.network_train_unet_only: - logs["lr/textencoder"] = float(lrs[0]) - idx = 1 - - for i in range(idx, len(lrs)): - logs[f"lr/group{i}"] = float(lrs[i]) - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): - logs[f"lr/d*lr/group{i}"] = ( - lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] - ) - - return logs - - def assert_extra_args(self, args, train_dataset_group): - pass - - def load_target_model(self, args, weight_dtype, accelerator): - text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) - return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet - - def load_tokenizer(self, args): - tokenizer = train_util.load_tokenizer(args) - return tokenizer - - def is_text_encoder_outputs_cached(self, args): - return False - - def is_train_text_encoder(self, args): - return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args) - - def cache_text_encoder_outputs_if_needed( - self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype - ): - for t_enc in text_encoders: - t_enc.to(accelerator.device, dtype=weight_dtype) - - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], weight_dtype) - return encoder_hidden_states - - def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): - noise_pred = unet(noisy_latents, timesteps, text_conds).sample - return noise_pred - - def all_reduce_network(self, accelerator, network): - for param in network.parameters(): - if param.grad is not None: - param.grad = accelerator.reduce(param.grad, reduction="mean") - - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): - train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) - - def train(self, args): - session_id = random.randint(0, 2**32) - training_started_at = time.time() - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) - setup_logging(args, reset=True) - - cache_latents = args.cache_latents - use_dreambooth_method = args.in_json is None - use_user_config = args.dataset_config is not None - - if args.seed is None: - args.seed = random.randint(0, 2**32) - set_seed(args.seed) - - # tokenizerは単体またはリスト、tokenizersは必ずリスト:既存のコードとの互換性のため - tokenizer = self.load_tokenizer(args) - tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer] - - # データセットを準備する - if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) - if use_user_config: - logger.info(f"Loading dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - logger.warning( - "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) - ) - else: - if 
use_dreambooth_method: - logger.info("Using DreamBooth method.") - user_config = { - "datasets": [ - { - "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( - args.train_data_dir, args.reg_data_dir - ) - } - ] - } - else: - logger.info("Training with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - else: - # use arbitrary dataset class - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group) - return - if len(train_dataset_group) == 0: - logger.error( - "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" - ) - return - - if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - self.assert_extra_args(args, train_dataset_group) - - # acceleratorを準備する - logger.info("preparing accelerator") - accelerator = train_util.prepare_accelerator(args) - is_main_process = accelerator.is_main_process - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) - vae_dtype = torch.float32 if args.no_half_vae else weight_dtype - - # モデルを読み込む - model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator) - - # text_encoder is List[CLIPTextModel] or CLIPTextModel - text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] - - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える - vae.set_use_memory_efficient_attention_xformers(args.xformers) - - # 差分追加学習のためにモデルを読み込む - sys.path.append(os.path.dirname(__file__)) - accelerator.print("import network module:", args.network_module) - network_module = importlib.import_module(args.network_module) - - if args.base_weights is not None: - # base_weights が指定されている場合は、指定された重みを読み込みマージする - for i, weight_path in enumerate(args.base_weights): - if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i: - multiplier = 1.0 - else: - multiplier = args.base_weights_multiplier[i] - - accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}") - - module, weights_sd = network_module.create_network_from_weights( - multiplier, weight_path, vae, text_encoder, unet, for_inference=True - ) - module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") - - accelerator.print(f"all weights merged: {', '.join(args.base_weights)}") - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=vae_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, 
args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") - clean_memory_on_device(accelerator.device) - - accelerator.wait_for_everyone() - - # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される - # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu - self.cache_text_encoder_outputs_if_needed( - args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype - ) - - # prepare network - net_kwargs = {} - if args.network_args is not None: - for net_arg in args.network_args: - key, value = net_arg.split("=") - net_kwargs[key] = value - - # if a new network is added in future, add if ~ then blocks for each network (;'∀') - if args.dim_from_weights: - network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs) - else: - if "dropout" not in net_kwargs: - # workaround for LyCORIS (;^ω^) - net_kwargs["dropout"] = args.network_dropout - - network = network_module.create_network( - 1.0, - args.network_dim, - args.network_alpha, - vae, - text_encoder, - unet, - neuron_dropout=args.network_dropout, - **net_kwargs, - ) - if network is None: - return - network_has_multiplier = hasattr(network, "set_multiplier") - - if hasattr(network, "prepare_network"): - network.prepare_network(args) - if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"): - logger.warning( - "warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません" - ) - args.scale_weight_norms = False - - train_unet = not args.network_train_text_encoder_only - train_text_encoder = self.is_train_text_encoder(args) - network.apply_to(text_encoder, unet, train_text_encoder, train_unet) - - if args.network_weights is not None: - info = network.load_weights(args.network_weights) - accelerator.print(f"load network weights from {args.network_weights}: {info}") - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - for t_enc in text_encoders: - t_enc.gradient_checkpointing_enable() - del t_enc - network.enable_gradient_checkpointing() # may have no effect - - # 学習に必要なクラスを準備する - accelerator.print("prepare optimizer, data loader etc.") - - # 後方互換性を確保するよ - try: - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) - except TypeError: - accelerator.print( - "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" - ) - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) - - optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps - ) - accelerator.print( - f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" - ) - - # データセット側にも学習ステップを送信 - train_dataset_group.set_max_train_steps(args.max_train_steps) - - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - - # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする - if args.full_fp16: - assert ( - args.mixed_precision == "fp16" - ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - accelerator.print("enable full fp16 training.") - network.to(weight_dtype) - elif args.full_bf16: - assert ( - args.mixed_precision == "bf16" - ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" - accelerator.print("enable full bf16 training.") - network.to(weight_dtype) - - unet_weight_dtype = te_weight_dtype = weight_dtype - # Experimental Feature: Put base model into fp8 to save vram - if args.fp8_base: - assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。" - assert ( - args.mixed_precision != "no" - ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" - accelerator.print("enable fp8 training.") - unet_weight_dtype = torch.float8_e4m3fn - te_weight_dtype = torch.float8_e4m3fn - - unet.requires_grad_(False) - unet.to(dtype=unet_weight_dtype) - for t_enc in text_encoders: - t_enc.requires_grad_(False) - - # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 - if t_enc.device.type != "cpu": - t_enc.to(dtype=te_weight_dtype) - # nn.Embedding not support FP8 - t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) - - # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good - if train_unet: - unet = accelerator.prepare(unet) - else: - unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator - if train_text_encoder: - if len(text_encoders) > 1: - text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders] - else: - text_encoder = accelerator.prepare(text_encoder) - text_encoders = [text_encoder] - else: - pass # if text_encoder is not trained, no need to prepare. 
and device and dtype are already set - - network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler) - - if args.gradient_checkpointing: - # according to TI example in Diffusers, train is required - unet.train() - for t_enc in text_encoders: - t_enc.train() - - # set top parameter requires_grad = True for gradient checkpointing works - if train_text_encoder: - t_enc.text_model.embeddings.requires_grad_(True) - - else: - unet.eval() - for t_enc in text_encoders: - t_enc.eval() - - del t_enc - - accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet) - - if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=vae_dtype) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - # TODO: find a way to handle total batch size when there are multiple datasets - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - - accelerator.print("running training / 学習開始") - accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print( - f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" - ) - # accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - # TODO refactor metadata creation and move to util - metadata = { - "ss_session_id": session_id, # random integer indicating which group of epochs the model came from - "ss_training_started_at": training_started_at, # unix timestamp - "ss_output_name": args.output_name, - "ss_learning_rate": args.learning_rate, - "ss_text_encoder_lr": args.text_encoder_lr, - "ss_unet_lr": args.unet_lr, - "ss_num_train_images": train_dataset_group.num_train_images, - "ss_num_reg_images": train_dataset_group.num_reg_images, - "ss_num_batches_per_epoch": len(train_dataloader), - "ss_num_epochs": num_train_epochs, - "ss_gradient_checkpointing": args.gradient_checkpointing, - "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, - "ss_max_train_steps": args.max_train_steps, - "ss_lr_warmup_steps": args.lr_warmup_steps, - "ss_lr_scheduler": args.lr_scheduler, - "ss_network_module": args.network_module, - "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim - "ss_network_alpha": args.network_alpha, # some networks may not have alpha - 
"ss_network_dropout": args.network_dropout, # some networks may not have dropout - "ss_mixed_precision": args.mixed_precision, - "ss_full_fp16": bool(args.full_fp16), - "ss_v2": bool(args.v2), - "ss_base_model_version": model_version, - "ss_clip_skip": args.clip_skip, - "ss_max_token_length": args.max_token_length, - "ss_cache_latents": bool(args.cache_latents), - "ss_seed": args.seed, - "ss_lowram": args.lowram, - "ss_noise_offset": args.noise_offset, - "ss_multires_noise_iterations": args.multires_noise_iterations, - "ss_multires_noise_discount": args.multires_noise_discount, - "ss_adaptive_noise_scale": args.adaptive_noise_scale, - "ss_zero_terminal_snr": args.zero_terminal_snr, - "ss_training_comment": args.training_comment, # will not be updated after training - "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), - "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""), - "ss_max_grad_norm": args.max_grad_norm, - "ss_caption_dropout_rate": args.caption_dropout_rate, - "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs, - "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate, - "ss_face_crop_aug_range": args.face_crop_aug_range, - "ss_prior_loss_weight": args.prior_loss_weight, - "ss_min_snr_gamma": args.min_snr_gamma, - "ss_scale_weight_norms": args.scale_weight_norms, - "ss_ip_noise_gamma": args.ip_noise_gamma, - "ss_debiased_estimation": bool(args.debiased_estimation_loss), - } - - if use_user_config: - # save metadata of multiple datasets - # NOTE: pack "ss_datasets" value as json one time - # or should also pack nested collections as json? - datasets_metadata = [] - tag_frequency = {} # merge tag frequency for metadata editor - dataset_dirs_info = {} # merge subset dirs for metadata editor - - for dataset in train_dataset_group.datasets: - is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset) - dataset_metadata = { - "is_dreambooth": is_dreambooth_dataset, - "batch_size_per_device": dataset.batch_size, - "num_train_images": dataset.num_train_images, # includes repeating - "num_reg_images": dataset.num_reg_images, - "resolution": (dataset.width, dataset.height), - "enable_bucket": bool(dataset.enable_bucket), - "min_bucket_reso": dataset.min_bucket_reso, - "max_bucket_reso": dataset.max_bucket_reso, - "tag_frequency": dataset.tag_frequency, - "bucket_info": dataset.bucket_info, - } - - subsets_metadata = [] - for subset in dataset.subsets: - subset_metadata = { - "img_count": subset.img_count, - "num_repeats": subset.num_repeats, - "color_aug": bool(subset.color_aug), - "flip_aug": bool(subset.flip_aug), - "random_crop": bool(subset.random_crop), - "shuffle_caption": bool(subset.shuffle_caption), - "keep_tokens": subset.keep_tokens, - } - - image_dir_or_metadata_file = None - if subset.image_dir: - image_dir = os.path.basename(subset.image_dir) - subset_metadata["image_dir"] = image_dir - image_dir_or_metadata_file = image_dir - - if is_dreambooth_dataset: - subset_metadata["class_tokens"] = subset.class_tokens - subset_metadata["is_reg"] = subset.is_reg - if subset.is_reg: - image_dir_or_metadata_file = None # not merging reg dataset - else: - metadata_file = os.path.basename(subset.metadata_file) - subset_metadata["metadata_file"] = metadata_file - image_dir_or_metadata_file = metadata_file # may overwrite - - subsets_metadata.append(subset_metadata) - - # merge dataset dir: not reg subset only - # TODO update additional-network extension to show detailed dataset config from metadata - if 
image_dir_or_metadata_file is not None: - # datasets may have a certain dir multiple times - v = image_dir_or_metadata_file - i = 2 - while v in dataset_dirs_info: - v = image_dir_or_metadata_file + f" ({i})" - i += 1 - image_dir_or_metadata_file = v - - dataset_dirs_info[image_dir_or_metadata_file] = { - "n_repeats": subset.num_repeats, - "img_count": subset.img_count, - } - - dataset_metadata["subsets"] = subsets_metadata - datasets_metadata.append(dataset_metadata) - - # merge tag frequency: - for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items(): - # あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える - # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない - # なので、ここで複数datasetの回数を合算してもあまり意味はない - if ds_dir_name in tag_frequency: - continue - tag_frequency[ds_dir_name] = ds_freq_for_dir - - metadata["ss_datasets"] = json.dumps(datasets_metadata) - metadata["ss_tag_frequency"] = json.dumps(tag_frequency) - metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info) - else: - # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir - assert ( - len(train_dataset_group.datasets) == 1 - ), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。" - - dataset = train_dataset_group.datasets[0] - - dataset_dirs_info = {} - reg_dataset_dirs_info = {} - if use_dreambooth_method: - for subset in dataset.subsets: - info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info - info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count} - else: - for subset in dataset.subsets: - dataset_dirs_info[os.path.basename(subset.metadata_file)] = { - "n_repeats": subset.num_repeats, - "img_count": subset.img_count, - } - - metadata.update( - { - "ss_batch_size_per_device": args.train_batch_size, - "ss_total_batch_size": total_batch_size, - "ss_resolution": args.resolution, - "ss_color_aug": bool(args.color_aug), - "ss_flip_aug": bool(args.flip_aug), - "ss_random_crop": bool(args.random_crop), - "ss_shuffle_caption": bool(args.shuffle_caption), - "ss_enable_bucket": bool(dataset.enable_bucket), - "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale), - "ss_min_bucket_reso": dataset.min_bucket_reso, - "ss_max_bucket_reso": dataset.max_bucket_reso, - "ss_keep_tokens": args.keep_tokens, - "ss_dataset_dirs": json.dumps(dataset_dirs_info), - "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info), - "ss_tag_frequency": json.dumps(dataset.tag_frequency), - "ss_bucket_info": json.dumps(dataset.bucket_info), - } - ) - - # add extra args - if args.network_args: - metadata["ss_network_args"] = json.dumps(net_kwargs) - - # model name and hash - if args.pretrained_model_name_or_path is not None: - sd_model_name = args.pretrained_model_name_or_path - if os.path.exists(sd_model_name): - metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) - metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name) - sd_model_name = os.path.basename(sd_model_name) - metadata["ss_sd_model_name"] = sd_model_name - - if args.vae is not None: - vae_name = args.vae - if os.path.exists(vae_name): - metadata["ss_vae_hash"] = train_util.model_hash(vae_name) - metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name) - vae_name = os.path.basename(vae_name) - metadata["ss_vae_name"] = vae_name - - metadata = {k: str(v) for k, v in metadata.items()} - - # make minimum metadata for 
filtering - minimum_metadata = {} - for key in train_util.SS_METADATA_MINIMUM_KEYS: - if key in metadata: - minimum_metadata[key] = metadata[key] - - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 - - noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False - ) - prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) - if args.zero_terminal_snr: - custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) - - if accelerator.is_main_process: - init_kwargs = {} - if args.wandb_run_name: - init_kwargs["wandb"] = {"name": args.wandb_run_name} - if args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers( - "network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs - ) - - loss_recorder = train_util.LossRecorder() - del train_dataset_group - - # callback for step start - if hasattr(accelerator.unwrap_model(network), "on_step_start"): - on_step_start = accelerator.unwrap_model(network).on_step_start - else: - on_step_start = lambda *args, **kwargs: None - - # function for saving/removing - def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - accelerator.print(f"\nsaving checkpoint: {ckpt_file}") - metadata["ss_training_finished_at"] = str(time.time()) - metadata["ss_steps"] = str(steps) - metadata["ss_epoch"] = str(epoch_no) - - metadata_to_save = minimum_metadata if args.no_metadata else metadata - sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False) - metadata_to_save.update(sai_metadata) - - unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save) - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) - - def remove_model(old_ckpt_name): - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - accelerator.print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - # For --sample_at_first - self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) - - # training loop - for epoch in range(num_train_epochs): - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch + 1 - - metadata["ss_epoch"] = str(epoch + 1) - - accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) - - for step, batch in enumerate(train_dataloader): - current_step.value = global_step - with accelerator.accumulate(network): - on_step_start(text_encoder, unet) - - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() - - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.nan_to_num(latents, 0, out=latents) - latents = latents * self.vae_scale_factor - - # get multiplier for each sample - if network_has_multiplier: - multipliers = batch["network_multipliers"] - # if all multipliers are same, use single 
multiplier - if torch.all(multipliers == multipliers[0]): - multipliers = multipliers[0].item() - else: - raise NotImplementedError("multipliers for each sample is not supported yet") - # print(f"set multiplier: {multipliers}") - accelerator.unwrap_model(network).set_multiplier(multipliers) - - with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): - # Get the text embedding for conditioning - if args.weighted_captions: - text_encoder_conds = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - text_encoder_conds = self.get_text_cond( - args, accelerator, batch, tokenizers, text_encoders, weight_dtype - ) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) - - # ensure the hidden state will require grad - if args.gradient_checkpointing: - for x in noisy_latents: - x.requires_grad_(True) - for t in text_encoder_conds: - t.requires_grad_(True) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = self.call_unet( - args, - accelerator, - unet, - noisy_latents.requires_grad_(train_unet), - timesteps, - text_encoder_conds, - batch, - weight_dtype, - ) - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - - accelerator.backward(loss) - if accelerator.sync_gradients: - self.all_reduce_network(accelerator, network) # sync DDP grad manually - if args.max_grad_norm != 0.0: - params_to_clip = accelerator.unwrap_model(network).get_trainable_params() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - if args.scale_weight_norms: - keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization( - args.scale_weight_norms, accelerator.device - ) - max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm} - else: - keys_scaled, mean_norm, maximum_norm = None, None, None - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) - - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if 
accelerator.is_main_process: - ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) - save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch) - - if args.save_state: - train_util.save_and_remove_state_stepwise(args, accelerator, global_step) - - remove_step_no = train_util.get_remove_step_no(args, global_step) - if remove_step_no is not None: - remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) - remove_model(remove_ckpt_name) - - current_loss = loss.detach().item() - loss_recorder.add(epoch=epoch, step=step, loss=current_loss) - avr_loss: float = loss_recorder.moving_average - logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if args.scale_weight_norms: - progress_bar.set_postfix(**{**max_mean_logs, **logs}) - - if args.logging_dir is not None: - logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) - accelerator.log(logs, step=global_step) - - if global_step >= args.max_train_steps: - break - - if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) - - accelerator.wait_for_everyone() - - # 指定エポックごとにモデルを保存 - if args.save_every_n_epochs is not None: - saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs - if is_main_process and saving: - ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) - save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1) - - remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) - if remove_epoch_no is not None: - remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) - remove_model(remove_ckpt_name) - - if args.save_state: - train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - - self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) - - # end of epoch - - # metadata["ss_epoch"] = str(num_train_epochs) - metadata["ss_training_finished_at"] = str(time.time()) - - if is_main_process: - network = accelerator.unwrap_model(network) - - accelerator.end_training() - - if is_main_process and args.save_state: - train_util.save_state_on_train_end(args, accelerator) - - if is_main_process: - ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) - save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) - - logger.info("model saved.") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - add_logging_arguments(parser) - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, True, True, True) - train_util.add_training_arguments(parser, True) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) - custom_train_functions.add_custom_train_arguments(parser) - - parser.add_argument( - "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない" - ) - parser.add_argument( - "--save_model_as", - type=str, - default="safetensors", - choices=[None, "ckpt", "pt", "safetensors"], - help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", - ) - - parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") - parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") - - parser.add_argument( - "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" - ) - parser.add_argument( - "--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール" - ) - parser.add_argument( - "--network_dim", - type=int, - default=None, - help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)", - ) - parser.add_argument( - "--network_alpha", - type=float, - default=1, - help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)", - ) - parser.add_argument( - "--network_dropout", - type=float, - default=None, - help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)", - ) - parser.add_argument( - "--network_args", - type=str, - default=None, - nargs="*", - help="additional arguments for network (key=value) / ネットワークへの追加の引数", - ) - parser.add_argument( - "--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する" - ) - parser.add_argument( - "--network_train_text_encoder_only", - action="store_true", - help="only training Text Encoder part / Text Encoder関連部分のみ学習する", - ) - parser.add_argument( - "--training_comment", - type=str, - default=None, - help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列", - ) - parser.add_argument( - "--dim_from_weights", - action="store_true", - help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する", - ) - parser.add_argument( - "--scale_weight_norms", - type=float, - default=None, - help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. 
(1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)", - ) - parser.add_argument( - "--base_weights", - type=str, - default=None, - nargs="*", - help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル", - ) - parser.add_argument( - "--base_weights_multiplier", - type=float, - default=None, - nargs="*", - help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率", - ) - parser.add_argument( - "--no_half_vae", - action="store_true", - help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", - ) - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - args = train_util.read_config_from_file(args, parser) - - trainer = NetworkTrainer() - trainer.train(args) diff --git a/train_network_README.md b/train_network_README.md deleted file mode 100644 index ed62dad8b..000000000 --- a/train_network_README.md +++ /dev/null @@ -1,189 +0,0 @@ -# About LoRA training - -[LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) (arxiv), [LoRA](https://github.com/microsoft/LoRA) (github) applied to Stable Diffusion. - -[cloneofsimo's repository](https://github.com/cloneofsimo/lora) was a great reference. Thank you very much. - -8GB VRAM seems to work just fine. - -## A Note about Trained Models - -Cloneofsimo's repository and d8ahazard's [Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_drebooth_extension) are currently incompatible, because this repository makes some enhancements (see below). - -To generate images with WebUI etc., either merge the trained LoRA model into the base Stable Diffusion model in advance with the script in this repository, or use the [extension for WebUI](https://github.com/kohya-ss/sd-webui-additional-networks). - -## Training method - -Use train_network.py. - -You can train with both the DreamBooth method (using identifiers such as sks and classes, optionally with regularization images) and the fine-tuning method that uses captions. - -Both methods work in much the same way as the existing scripts. The differences are discussed below. - -### Using the DreamBooth method - -Please refer to the [DreamBooth guide](./train_db_README-en.md) and prepare the data. - -Specify train_network.py instead of train_db.py when training. - -Almost all options are available (except those related to saving the Stable Diffusion model), but stop_text_encoder_training is not supported. - -### Using captions - -Please refer to the [fine-tuning guide](./fine_tune_README_en.md) and perform each step. - -Specify train_network.py instead of fine_tune.py when training. Almost all options (except those related to model saving) can be used as is. - -It also works even if you skip the "pre-compute latents" step. Since latents are then obtained from the VAE during training (or caching), training is slower, but color_aug can be used instead. - -### Options for training LoRA - -In train_network.py, specify the name of the module to train with the --network_module option. networks.lora corresponds to LoRA, so specify that. - -The learning rate should be set to about 1e-4, which is higher than for normal DreamBooth or fine tuning. - -Below is an example command line (DreamBooth method). 
- -``` -accelerate launch --num_cpu_threads_per_process 12 train_network.py - --pretrained_model_name_or_path=..\models\model.ckpt - --train_data_dir=..\data\db\char1 --output_dir=..\lora_train1 - --reg_data_dir=..\data\db\reg1 --prior_loss_weight=1.0 - --resolution=448,640 --train_batch_size=1 --learning_rate=1e-4 - --max_train_steps=400 --use_8bit_adam --xformers --mixed_precision=fp16 - --save_every_n_epochs=1 --save_model_as=safetensors --clip_skip=2 --seed=42 --color_aug - --network_module=networks.lora -``` - -The LoRA model will be saved in the directory specified by the --output_dir option. - -In addition, the following options can be specified. - -* --network_dim - * Specify the number of dimensions of LoRA (such as ``--network_dim=4``). Default is 4. The larger the value, the greater the expressive power, but the memory and time required for training also increase. Increasing it blindly does not seem to be a good idea. -* --network_weights - * Load pretrained LoRA weights before training and continue training from them. -* --network_train_unet_only - * Enable only the LoRA modules related to U-Net. It may be better to specify this for fine-tuning-style training. -* --network_train_text_encoder_only - * Enable only the LoRA modules related to the Text Encoder. You may be able to expect a Textual Inversion-like effect. -* --unet_lr - * Specify this to use a learning rate for the U-Net LoRA modules that differs from the normal learning rate (set with the --learning_rate option). -* --text_encoder_lr - * Specify this to use a learning rate for the Text Encoder LoRA modules that differs from the normal learning rate (set with the --learning_rate option). Some people say it is better to give the Text Encoder a slightly lower learning rate (such as 5e-5). - -When neither --network_train_unet_only nor --network_train_text_encoder_only is specified (the default), both the Text Encoder and U-Net LoRA modules are enabled. - -## About the merge script - -merge_lora.py allows you to merge LoRA training results into a Stable Diffusion model, or to merge multiple LoRA models. - -### Merge a LoRA model into a Stable Diffusion model - -The merged model can be handled in the same way as a normal Stable Diffusion checkpoint. For example, a command line like: - -``` -python networks\merge_lora.py --sd_model ..\model\model.ckpt - --save_to ..\lora_train1\model-char1-merged.safetensors - --models ..\lora_train1\last.safetensors --ratios 0.8 -``` - -Specify the --v2 option if you trained with a Stable Diffusion v2.x model and are merging into one. - -Specify the Stable Diffusion model file to merge into with the --sd_model option (only .ckpt or .safetensors are supported; the Diffusers format is not currently supported). - -Specify the save destination of the merged model with the --save_to option (.ckpt or .safetensors, determined automatically by extension). - -Specify the trained LoRA model files with --models. More than one can be specified, in which case they are merged in order. - -For --ratios, specify the application rate of each model (how much of its weight is reflected in the original model) as a value from 0 to 1.0. For example, if the result looks overfitted, lowering the application rate may help. Specify as many ratios as there are models. - -When specifying multiple models, it looks like this: 
- -``` -python networks\merge_lora.py --sd_model ..\model\model.ckpt - --save_to ..\lora_train1\model-char1-merged.safetensors - --models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.8 0.5 -``` - -### Merge multiple LoRA models - -Applying multiple LoRA models to the SD model one by one, and first merging multiple LoRA models together and then merging the result into the SD model, yield slightly different results because of the order of calculation. - -For example, a command line like: - -```shell -python networks\merge_lora.py - --save_to ..\lora_train1\model-char1-style1-merged.safetensors - --models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.6 0.4 -``` - -The --sd_model option does not need to be specified. - -Specify the save destination of the merged LoRA model with the --save_to option (.ckpt or .safetensors, determined automatically by extension). - -Specify the trained LoRA model files with --models. Three or more can also be specified. - -For --ratios, specify the ratio of each model (how much of its weight is reflected in the merged result) as a value from 0 to 1.0. If you merge two models one-to-one, use "0.5 0.5". "1.0 1.0" would give too much total weight, and the result would probably be less desirable. - -A LoRA trained with v1 and a LoRA trained with v2, or LoRAs with different numbers of dimensions, cannot be merged. A U-Net-only LoRA and a U-Net+Text Encoder LoRA should be mergeable, but the result is untested. - -### Other Options - -* precision - * The precision used for the merge calculation can be float, fp16, or bf16. If omitted, float is used to preserve accuracy. Specify fp16/bf16 to reduce memory usage. -* save_precision - * The precision used when saving the model can be float, fp16, or bf16. If omitted, it is the same as precision. - -## Generate with the image generation script in this repository - -Add the --network_module, --network_weights, and --network_dim (optional) options to gen_img_diffusers.py. Their meaning is the same as during training. - -You can change the LoRA application rate by specifying a value between 0 and 1.0 with the --network_mul option. - -## Create a LoRA model from the difference between two models - -This was implemented with reference to [this discussion](https://github.com/cloneofsimo/lora/discussions/56). The formula is used as-is (I don't fully understand it, but it seems that singular value decomposition is used for the approximation). - -The LoRA approximates the difference between two models (for example, the base model before fine tuning and the model after fine tuning). - -### How to run the script - -Please specify as follows. - -```shell -python networks\extract_lora_from_models.py --model_org base-model.ckpt - --model_tuned fine-tuned-model.ckpt - --save_to lora-weights.safetensors --dim 4 -``` - -Specify the original Stable Diffusion model with the --model_org option. When applying the created LoRA model, you apply it on top of this model. .ckpt or .safetensors can be specified. - -Specify the Stable Diffusion model from which to extract the difference with the --model_tuned option, for example a model after fine tuning or DreamBooth. .ckpt or .safetensors can be specified. - -Specify the save destination of the LoRA model with --save_to, and the number of dimensions of the LoRA with --dim. - -The generated LoRA model can be used in the same way as a trained LoRA model. - -If the Text Encoder is the same in both models, the resulting LoRA will be a U-Net-only LoRA. 
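A minimal sketch of this singular-value-decomposition idea is shown below. It is illustrative only: the tensor shapes and function name are made up, and the actual extract_lora_from_models.py additionally has to load both models and walk over their attention/MLP modules.

```python
# Sketch: approximate the weight difference of one layer with a rank-r LoRA pair.
import torch

def extract_lora_pair(w_org: torch.Tensor, w_tuned: torch.Tensor, rank: int = 4):
    """Approximate (w_tuned - w_org) by two low-rank factors so that up @ down ~= delta."""
    delta = (w_tuned - w_org).float()            # the difference the LoRA should reproduce
    u, s, vh = torch.linalg.svd(delta, full_matrices=False)
    up = u[:, :rank] * s[:rank]                  # fold the singular values into the "up" factor
    down = vh[:rank, :]
    return up, down                              # shapes: (out, rank) and (rank, in)

# toy example: a 320x320 linear weight before and after fine tuning
w_org = torch.randn(320, 320)
w_tuned = w_org + 0.01 * torch.randn(320, 320)
up, down = extract_lora_pair(w_org, w_tuned, rank=4)   # rank plays the role of --dim above
rel_err = torch.norm((w_tuned - w_org) - up @ down) / torch.norm(w_tuned - w_org)
print(up.shape, down.shape, f"relative error: {rel_err.item():.3f}")
```

The larger the rank (--dim), the smaller the approximation error, at the cost of a larger LoRA file.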
- -### Other Options - ---v2 - - Please specify when using a v2.x Stable Diffusion model. ---device - - If cuda is specified as ``--device cuda``, the calculation is performed on the GPU. Processing is faster, though not dramatically so (the CPU is not that slow here; expect roughly a two- to several-fold speedup at most). ---save_precision - - Specify the LoRA save precision from "float", "fp16", "bf16". Default is float. - -## Additional Information - -### Differences from cloneofsimo's repository - -As of 12/25, this repository has expanded the LoRA application points to the Text Encoder's MLP, the U-Net's FFN, and the Transformer's in/out projections, increasing its expressiveness. However, memory usage increased in exchange, and it now only barely fits in 8GB. - -Also, the module replacement mechanism is completely different. - -### About Future Expansion - -Not only LoRA but other extensions can also be supported, and we plan to add them as well. \ No newline at end of file diff --git a/train_network_appl_weights_README.md b/train_network_appl_weights_README.md deleted file mode 100644 index 29996cc96..000000000 --- a/train_network_appl_weights_README.md +++ /dev/null @@ -1,39 +0,0 @@ -# Exploring Layer-Specific Application Rates for LoRA -## Introduction -A tool, train_network_appl_weights.py, has been added for exploring layer-specific application rates. Currently, it supports SDXL only. - -## Concept -The process runs the standard training loop while varying the layer-specific application rates of already trained networks such as LoRA. The goal is to explore which rates produce images closest to the training data. - -## Penalty for Total Application Rates -It is possible to use the total of the layer-specific application rates as a penalty, aiming to reproduce the images while minimizing the impact of less significant layers. - -## Multi-Network Exploration -The exploration can be conducted on multiple networks and requires at least one piece of training data. - -Note: The effectiveness with a specific number of images has not been confirmed, but it has been tested with approximately 50 images. The training data does not necessarily have to be from LoRA's training phase, although this has not been confirmed. - -## Command Line Options -The command line options are almost identical to those for `sdxl_train_network.py`, with the following additions and extensions: - -- `--application_loss_weight`: Weight of the layer-specific application rates when added to the loss. Default is 0.0001. Increasing this value trains the model to minimize the application rates. Setting it to 0 allows free exploration of the application rates that yield the highest fidelity. -- `--network_module`: Allows specifying multiple modules for exploration, e.g., `--network_module networks.lora networks.lora`. -- `--network_weights`: Allows specifying weights for multiple networks to be explored, e.g., `--network_weights model1.safetensors model2.safetensors`. - -## Parameters -The number of layer-specific application-rate parameters is 20: BASE, IN00-08, MID, and OUT00-08. BASE is applied to the Text Encoder (note: LoRA's operation on the Text Encoder has not been confirmed). - -Although the parameters are saved to a file, it is recommended to copy and save the values displayed on the screen. - -## Remarks -Confirmed to work with the AdamW optimizer and a learning rate of 1e-1. The learning rate can be set quite high. 
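As a rough sketch of what this search optimizes (hypothetical code, not the actual train_network_appl_weights.py), the 20 per-block application rates can be treated as trainable parameters whose total is added to the usual loss as a penalty, using AdamW at learning rate 1e-1 and the default application_loss_weight of 0.0001:

```python
# Hypothetical sketch of the layer-wise application-rate search.
import torch

LAYER_NAMES = ["BASE"] + [f"IN{i:02d}" for i in range(9)] + ["MID"] + [f"OUT{i:02d}" for i in range(9)]

appl_rates = torch.nn.Parameter(torch.ones(len(LAYER_NAMES)))  # 20 trainable rates, one per block
optimizer = torch.optim.AdamW([appl_rates], lr=1e-1)           # AdamW with lr 1e-1 as noted above
application_loss_weight = 0.0001

def rate_penalty(rates: torch.Tensor) -> torch.Tensor:
    # total application rate; negative rates are penalized ten times as much
    return torch.where(rates >= 0, rates, -10.0 * rates).sum()

def training_step() -> float:
    # stand-in for the real reconstruction term, which would be the usual diffusion
    # loss with each network block applied at its current rate
    reconstruction_loss = ((appl_rates - 0.5) ** 2).mean()
    loss = reconstruction_loss + application_loss_weight * rate_penalty(appl_rates)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

for step in range(3):
    print(f"step {step}: loss={training_step():.6f}")
```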
With this setting, reasonable results can be obtained in about 1/20 to 1/10 the epochs used during LoRA training. -Increasing `application_loss_weight` above 0.0001 significantly reduces the total application rate, meaning LoRA is applied less. Adjust as needed. -Using negative values for the application rate can lead to minimizing the total by excessively reducing less influential layers' application rates. Negative values are weighted ten times (e.g., -0.01 is almost the same penalty as 0.1). Modify the source code to change the weighting. - -## Potential Uses -Beyond reducing unnecessary layers' application rates, potential uses include: - -- Searching for LoRA application rates to maintain a character while changing their pose based on a reference image. -- Exploring application rates for LoRA to maintain a character's style while altering the artistic style of the image. -- Exploring necessary layers to reproduce a character's attributes using an image in a different style as training data. -- Applying numerous LoRAs to an ideal image as training data and searching for the application rates that achieve the highest fidelity (though more LoRAs will slow down the training). \ No newline at end of file diff --git a/train_textual_inversion.py b/train_textual_inversion.py deleted file mode 100644 index df1d8485a..000000000 --- a/train_textual_inversion.py +++ /dev/null @@ -1,805 +0,0 @@ -import argparse -import math -import os -from multiprocessing import Value -import toml - -from tqdm import tqdm - -import torch -from library.device_utils import init_ipex, clean_memory_on_device -init_ipex() - -from accelerate.utils import set_seed -from diffusers import DDPMScheduler -from transformers import CLIPTokenizer -from library import model_util - -import library.train_util as train_util -import library.huggingface_util as huggingface_util -import library.config_util as config_util -from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, -) -import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import ( - apply_snr_weight, - prepare_scheduler_for_custom_training, - scale_v_prediction_loss_like_noise_prediction, - add_v_prediction_like_loss, - apply_debiased_estimation, -) -from library.utils import setup_logging, add_logging_arguments - -setup_logging() -import logging - -logger = logging.getLogger(__name__) - -imagenet_templates_small = [ - "a photo of a {}", - "a rendering of a {}", - "a cropped photo of the {}", - "the photo of a {}", - "a photo of a clean {}", - "a photo of a dirty {}", - "a dark photo of the {}", - "a photo of my {}", - "a photo of the cool {}", - "a close-up photo of a {}", - "a bright photo of the {}", - "a cropped photo of a {}", - "a photo of the {}", - "a good photo of the {}", - "a photo of one {}", - "a close-up photo of the {}", - "a rendition of the {}", - "a photo of the clean {}", - "a rendition of a {}", - "a photo of a nice {}", - "a good photo of a {}", - "a photo of the nice {}", - "a photo of the small {}", - "a photo of the weird {}", - "a photo of the large {}", - "a photo of a cool {}", - "a photo of a small {}", -] - -imagenet_style_templates_small = [ - "a painting in the style of {}", - "a rendering in the style of {}", - "a cropped painting in the style of {}", - "the painting in the style of {}", - "a clean painting in the style of {}", - "a dirty painting in the style of {}", - "a dark painting in the style of {}", - "a picture in the style of {}", - "a cool painting in the style of 
{}", - "a close-up painting in the style of {}", - "a bright painting in the style of {}", - "a cropped painting in the style of {}", - "a good painting in the style of {}", - "a close-up painting in the style of {}", - "a rendition in the style of {}", - "a nice painting in the style of {}", - "a small painting in the style of {}", - "a weird painting in the style of {}", - "a large painting in the style of {}", -] - - -class TextualInversionTrainer: - def __init__(self): - self.vae_scale_factor = 0.18215 - self.is_sdxl = False - - def assert_extra_args(self, args, train_dataset_group): - pass - - def load_target_model(self, args, weight_dtype, accelerator): - text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) - return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet - - def load_tokenizer(self, args): - tokenizer = train_util.load_tokenizer(args) - return tokenizer - - def assert_token_string(self, token_string, tokenizers: CLIPTokenizer): - pass - - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): - with torch.enable_grad(): - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], None) - return encoder_hidden_states - - def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): - noise_pred = unet(noisy_latents, timesteps, text_conds).sample - return noise_pred - - def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): - train_util.sample_images( - accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement - ) - - def save_weights(self, file, updated_embs, save_dtype, metadata): - state_dict = {"emb_params": updated_embs[0]} - - if save_dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v - - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import save_file - - save_file(state_dict, file, metadata) - else: - torch.save(state_dict, file) # can be loaded in Web UI - - def load_weights(self, file): - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import load_file - - data = load_file(file) - else: - # compatible to Web UI's file format - data = torch.load(file, map_location="cpu") - if type(data) != dict: - raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}") - - if "string_to_param" in data: # textual inversion embeddings - data = data["string_to_param"] - if hasattr(data, "_parameters"): # support old PyTorch? 
- data = getattr(data, "_parameters") - - emb = next(iter(data.values())) - if type(emb) != torch.Tensor: - raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}") - - if len(emb.size()) == 1: - emb = emb.unsqueeze(0) - - return [emb] - - def train(self, args): - if args.output_name is None: - args.output_name = args.token_string - use_template = args.use_object_template or args.use_style_template - - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) - setup_logging(args, reset=True) - - cache_latents = args.cache_latents - - if args.seed is not None: - set_seed(args.seed) - - tokenizer_or_list = self.load_tokenizer(args) # list of tokenizer or tokenizer - tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list] - - # acceleratorを準備する - logger.info("prepare accelerator") - accelerator = train_util.prepare_accelerator(args) - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) - vae_dtype = torch.float32 if args.no_half_vae else weight_dtype - - # モデルを読み込む - model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator) - text_encoders = [text_encoder_or_list] if not isinstance(text_encoder_or_list, list) else text_encoder_or_list - - if len(text_encoders) > 1 and args.gradient_accumulation_steps > 1: - accelerator.print( - "accelerate doesn't seem to support gradient_accumulation_steps for multiple models (text encoders) / " - + "accelerateでは複数のモデル(テキストエンコーダー)のgradient_accumulation_stepsはサポートされていないようです" - ) - - # Convert the init_word to token_id - init_token_ids_list = [] - if args.init_word is not None: - for i, tokenizer in enumerate(tokenizers): - init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) - if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: - accelerator.print( - f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / " - + f"初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: tokenizer {i+1}, length {len(init_token_ids)}" - ) - init_token_ids_list.append(init_token_ids) - else: - init_token_ids_list = [None] * len(tokenizers) - - # tokenizerに新しい単語を追加する。追加する単語の数はnum_vectors_per_token - # token_stringが hoge の場合、"hoge", "hoge1", "hoge2", ... が追加される - # add new word to tokenizer, count is num_vectors_per_token - # if token_string is hoge, "hoge", "hoge1", "hoge2", ... are added - - self.assert_token_string(args.token_string, tokenizers) - - token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)] - token_ids_list = [] - token_embeds_list = [] - for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)): - num_added_tokens = tokenizer.add_tokens(token_strings) - assert ( - num_added_tokens == args.num_vectors_per_token - ), f"tokenizer has same word to token string. 
please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: tokenizer {i+1}, {args.token_string}" - - token_ids = tokenizer.convert_tokens_to_ids(token_strings) - accelerator.print(f"tokens are added for tokenizer {i+1}: {token_ids}") - assert ( - min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1 - ), f"token ids is not ordered : tokenizer {i+1}, {token_ids}" - assert ( - len(tokenizer) - 1 == token_ids[-1] - ), f"token ids is not end of tokenize: tokenizer {i+1}, {token_ids}, {len(tokenizer)}" - token_ids_list.append(token_ids) - - # Resize the token embeddings as we are adding new special tokens to the tokenizer - text_encoder.resize_token_embeddings(len(tokenizer)) - - # Initialise the newly added placeholder token with the embeddings of the initializer token - token_embeds = text_encoder.get_input_embeddings().weight.data - if init_token_ids is not None: - for i, token_id in enumerate(token_ids): - token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]] - # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) - token_embeds_list.append(token_embeds) - - # load weights - if args.weights is not None: - embeddings_list = self.load_weights(args.weights) - assert len(token_ids) == len( - embeddings_list[0] - ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}" - # accelerator.print(token_ids, embeddings.size()) - for token_ids, embeddings, token_embeds in zip(token_ids_list, embeddings_list, token_embeds_list): - for token_id, embedding in zip(token_ids, embeddings): - token_embeds[token_id] = embedding - # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) - accelerator.print(f"weighs loaded") - - accelerator.print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") - - # データセットを準備する - if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False)) - if args.dataset_config is not None: - accelerator.print(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - accelerator.print( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) - ) - else: - use_dreambooth_method = args.in_json is None - if use_dreambooth_method: - accelerator.print("Use DreamBooth method.") - user_config = { - "datasets": [ - { - "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( - args.train_data_dir, args.reg_data_dir - ) - } - ] - } - else: - logger.info("Train with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer_or_list) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer_or_list) - - self.assert_extra_args(args, train_dataset_group) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, 
current_step, ds_for_collator) - - # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 - if use_template: - accelerator.print(f"use template for training captions. is object: {args.use_object_template}") - templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small - replace_to = " ".join(token_strings) - captions = [] - for tmpl in templates: - captions.append(tmpl.format(replace_to)) - train_dataset_group.add_replacement("", captions) - - # サンプル生成用 - if args.num_vectors_per_token > 1: - prompt_replacement = (args.token_string, replace_to) - else: - prompt_replacement = None - else: - # サンプル生成用 - if args.num_vectors_per_token > 1: - replace_to = " ".join(token_strings) - train_dataset_group.add_replacement(args.token_string, replace_to) - prompt_replacement = (args.token_string, replace_to) - else: - prompt_replacement = None - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group, show_input_ids=True) - return - if len(train_dataset_group) == 0: - accelerator.print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") - return - - if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える - vae.set_use_memory_efficient_attention_xformers(args.xformers) - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=vae_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") - clean_memory_on_device(accelerator.device) - - accelerator.wait_for_everyone() - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - for text_encoder in text_encoders: - text_encoder.gradient_checkpointing_enable() - - # 学習に必要なクラスを準備する - accelerator.print("prepare optimizer, data loader etc.") - trainable_params = [] - for text_encoder in text_encoders: - trainable_params += text_encoder.get_input_embeddings().parameters() - _, _, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps - ) - accelerator.print( - f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" - ) - - # データセット側にも学習ステップを送信 - train_dataset_group.set_max_train_steps(args.max_train_steps) - - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - - # acceleratorがなんかよろしくやってくれるらしい - if len(text_encoders) == 1: - text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder_or_list, optimizer, train_dataloader, lr_scheduler - ) - - elif len(text_encoders) == 2: - text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler - ) - - text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2] - - else: - raise NotImplementedError() - - index_no_updates_list = [] - orig_embeds_params_list = [] - for tokenizer, token_ids, text_encoder in zip(tokenizers, token_ids_list, text_encoders): - index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] - index_no_updates_list.append(index_no_updates) - - # accelerator.print(len(index_no_updates), torch.sum(index_no_updates)) - orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() - orig_embeds_params_list.append(orig_embeds_params) - - # Freeze all parameters except for the token embeddings in text encoder - text_encoder.requires_grad_(True) - unwrapped_text_encoder = accelerator.unwrap_model(text_encoder) - unwrapped_text_encoder.text_model.encoder.requires_grad_(False) - unwrapped_text_encoder.text_model.final_layer_norm.requires_grad_(False) - unwrapped_text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) - # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) - - unet.requires_grad_(False) - unet.to(accelerator.device, dtype=weight_dtype) - if args.gradient_checkpointing: # according to TI example in Diffusers, train is required - # TODO U-Netをオリジナルに置き換えたのでいらないはずなので、後で確認して消す - unet.train() - else: - unet.eval() - - if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=vae_dtype) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - for text_encoder in text_encoders: - text_encoder.to(weight_dtype) - if args.full_bf16: - for text_encoder in text_encoders: - text_encoder.to(weight_dtype) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - accelerator.print("running training / 学習開始") - accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device 
/ バッチサイズ: {args.train_batch_size}") - accelerator.print( - f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" - ) - accelerator.print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 - - noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False - ) - prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) - if args.zero_terminal_snr: - custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) - - if accelerator.is_main_process: - init_kwargs = {} - if args.wandb_run_name: - init_kwargs["wandb"] = {"name": args.wandb_run_name} - if args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs - ) - - # function for saving/removing - def save_model(ckpt_name, embs_list, steps, epoch_no, force_sync_upload=False): - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - accelerator.print(f"\nsaving checkpoint: {ckpt_file}") - - sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, False, True) - - self.save_weights(ckpt_file, embs_list, save_dtype, sai_metadata) - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) - - def remove_model(old_ckpt_name): - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - accelerator.print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - # For --sample_at_first - self.sample_images( - accelerator, - args, - 0, - global_step, - accelerator.device, - vae, - tokenizer_or_list, - text_encoder_or_list, - unet, - prompt_replacement, - ) - - # training loop - for epoch in range(num_train_epochs): - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch + 1 - - for text_encoder in text_encoders: - text_encoder.train() - - loss_total = 0 - - for step, batch in enumerate(train_dataloader): - current_step.value = global_step - with accelerator.accumulate(text_encoders[0]): - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() - latents = latents * self.vae_scale_factor - - # Get the text embedding for conditioning - text_encoder_conds = self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype - ) - - if args.v_parameterization: - # 
v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = accelerator.unwrap_model(text_encoder).get_input_embeddings().parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - # Let's make sure we don't update any embedding weights besides the newly added token - with torch.no_grad(): - for text_encoder, orig_embeds_params, index_no_updates in zip( - text_encoders, orig_embeds_params_list, index_no_updates_list - ): - # if full_fp16/bf16, input_embeddings_weight is fp16/bf16, orig_embeds_params is fp32 - input_embeddings_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight - input_embeddings_weight[index_no_updates] = orig_embeds_params.to(input_embeddings_weight.dtype)[ - index_no_updates - ] - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - self.sample_images( - accelerator, - args, - None, - global_step, - accelerator.device, - vae, - tokenizer_or_list, - text_encoder_or_list, - unet, - prompt_replacement, - ) - - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - updated_embs_list = [] - for text_encoder, token_ids in zip(text_encoders, token_ids_list): - updated_embs = ( - accelerator.unwrap_model(text_encoder) - .get_input_embeddings() - .weight[token_ids] - .data.detach() - .clone() - ) - updated_embs_list.append(updated_embs) - - ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) - save_model(ckpt_name, updated_embs_list, global_step, epoch) - - if args.save_state: - train_util.save_and_remove_state_stepwise(args, accelerator, global_step) - - remove_step_no = train_util.get_remove_step_no(args, global_step) - if remove_step_no is not None: - remove_ckpt_name = train_util.get_step_ckpt_name(args, "." 
+ args.save_model_as, remove_step_no) - remove_model(remove_ckpt_name) - - current_loss = loss.detach().item() - if args.logging_dir is not None: - logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if ( - args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() - ): # tracking d*lr value - logs["lr/d*lr"] = ( - lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] - ) - accelerator.log(logs, step=global_step) - - loss_total += current_loss - avr_loss = loss_total / (step + 1) - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if global_step >= args.max_train_steps: - break - - if args.logging_dir is not None: - logs = {"loss/epoch": loss_total / len(train_dataloader)} - accelerator.log(logs, step=epoch + 1) - - accelerator.wait_for_everyone() - - updated_embs_list = [] - for text_encoder, token_ids in zip(text_encoders, token_ids_list): - updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() - updated_embs_list.append(updated_embs) - - if args.save_every_n_epochs is not None: - saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs - if accelerator.is_main_process and saving: - ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) - save_model(ckpt_name, updated_embs_list, epoch + 1, global_step) - - remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) - if remove_epoch_no is not None: - remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) - remove_model(remove_ckpt_name) - - if args.save_state: - train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - - self.sample_images( - accelerator, - args, - epoch + 1, - global_step, - accelerator.device, - vae, - tokenizer_or_list, - text_encoder_or_list, - unet, - prompt_replacement, - ) - - # end of epoch - - is_main_process = accelerator.is_main_process - if is_main_process: - text_encoder = accelerator.unwrap_model(text_encoder) - updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() - - accelerator.end_training() - - if args.save_state and is_main_process: - train_util.save_state_on_train_end(args, accelerator) - - if is_main_process: - ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) - save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True) - - logger.info("model saved.") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - add_logging_arguments(parser) - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, True, True, False) - train_util.add_training_arguments(parser, True) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) - custom_train_functions.add_custom_train_arguments(parser, False) - - parser.add_argument( - "--save_model_as", - type=str, - default="pt", - choices=[None, "ckpt", "pt", "safetensors"], - help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)", - ) - - parser.add_argument( - "--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み" - ) - parser.add_argument( - "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数" - ) - parser.add_argument( - "--token_string", - type=str, - default=None, - help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること", - ) - parser.add_argument( - "--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可" - ) - parser.add_argument( - "--use_object_template", - action="store_true", - help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する", - ) - parser.add_argument( - "--use_style_template", - action="store_true", - help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する", - ) - parser.add_argument( - "--no_half_vae", - action="store_true", - help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", - ) - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - args = train_util.read_config_from_file(args, parser) - - trainer = TextualInversionTrainer() - trainer.train(args) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py deleted file mode 100644 index 695fad2a8..000000000 --- a/train_textual_inversion_XTI.py +++ /dev/null @@ -1,712 +0,0 @@ -import importlib -import argparse -import math -import os -import toml -from multiprocessing import Value - -from tqdm import tqdm - -import torch -from library.device_utils import init_ipex, clean_memory_on_device -init_ipex() - -from accelerate.utils import set_seed -import diffusers -from diffusers import DDPMScheduler -import library - -import library.train_util as train_util -import library.huggingface_util as huggingface_util -import library.config_util as config_util -from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, -) -import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import ( - apply_snr_weight, - prepare_scheduler_for_custom_training, - pyramid_noise_like, - apply_noise_offset, - scale_v_prediction_loss_like_noise_prediction, - apply_debiased_estimation, -) -import library.original_unet as original_unet -from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI -from library.utils import setup_logging, add_logging_arguments - -setup_logging() -import logging - -logger = logging.getLogger(__name__) - -imagenet_templates_small = [ - "a photo of a {}", - "a rendering of a {}", - "a cropped photo of the {}", - 
"the photo of a {}", - "a photo of a clean {}", - "a photo of a dirty {}", - "a dark photo of the {}", - "a photo of my {}", - "a photo of the cool {}", - "a close-up photo of a {}", - "a bright photo of the {}", - "a cropped photo of a {}", - "a photo of the {}", - "a good photo of the {}", - "a photo of one {}", - "a close-up photo of the {}", - "a rendition of the {}", - "a photo of the clean {}", - "a rendition of a {}", - "a photo of a nice {}", - "a good photo of a {}", - "a photo of the nice {}", - "a photo of the small {}", - "a photo of the weird {}", - "a photo of the large {}", - "a photo of a cool {}", - "a photo of a small {}", -] - -imagenet_style_templates_small = [ - "a painting in the style of {}", - "a rendering in the style of {}", - "a cropped painting in the style of {}", - "the painting in the style of {}", - "a clean painting in the style of {}", - "a dirty painting in the style of {}", - "a dark painting in the style of {}", - "a picture in the style of {}", - "a cool painting in the style of {}", - "a close-up painting in the style of {}", - "a bright painting in the style of {}", - "a cropped painting in the style of {}", - "a good painting in the style of {}", - "a close-up painting in the style of {}", - "a rendition in the style of {}", - "a nice painting in the style of {}", - "a small painting in the style of {}", - "a weird painting in the style of {}", - "a large painting in the style of {}", -] - - -def train(args): - if args.output_name is None: - args.output_name = args.token_string - use_template = args.use_object_template or args.use_style_template - setup_logging(args, reset=True) - - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) - - if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None: - logger.warning( - "sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません" - ) - assert ( - args.dataset_class is None - ), "dataset_class is not supported in this script currently / dataset_classは現在このスクリプトではサポートされていません" - - cache_latents = args.cache_latents - - if args.seed is not None: - set_seed(args.seed) - - tokenizer = train_util.load_tokenizer(args) - - # acceleratorを準備する - logger.info("prepare accelerator") - accelerator = train_util.prepare_accelerator(args) - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) - - # モデルを読み込む - text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) - - # Convert the init_word to token_id - if args.init_word is not None: - init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) - if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: - logger.warning( - f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}" - ) - else: - init_token_ids = None - - # add new word to tokenizer, count is num_vectors_per_token - token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)] - num_added_tokens = tokenizer.add_tokens(token_strings) - assert ( - num_added_tokens == args.num_vectors_per_token - ), f"tokenizer has same word to token string. 
please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}" - - token_ids = tokenizer.convert_tokens_to_ids(token_strings) - logger.info(f"tokens are added: {token_ids}") - assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" - assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" - - token_strings_XTI = [] - XTI_layers = [ - "IN01", - "IN02", - "IN04", - "IN05", - "IN07", - "IN08", - "MID", - "OUT03", - "OUT04", - "OUT05", - "OUT06", - "OUT07", - "OUT08", - "OUT09", - "OUT10", - "OUT11", - ] - for layer_name in XTI_layers: - token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings] - - tokenizer.add_tokens(token_strings_XTI) - token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI) - logger.info(f"tokens are added (XTI): {token_ids_XTI}") - # Resize the token embeddings as we are adding new special tokens to the tokenizer - text_encoder.resize_token_embeddings(len(tokenizer)) - - # Initialise the newly added placeholder token with the embeddings of the initializer token - token_embeds = text_encoder.get_input_embeddings().weight.data - if init_token_ids is not None: - for i, token_id in enumerate(token_ids_XTI): - token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]] - # logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) - - # load weights - if args.weights is not None: - embeddings = load_weights(args.weights) - assert len(token_ids) == len( - embeddings - ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}" - # logger.info(token_ids, embeddings.size()) - for token_id, embedding in zip(token_ids_XTI, embeddings): - token_embeds[token_id] = embedding - # logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) - logger.info(f"weighs loaded") - - logger.info(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") - - # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False)) - if args.dataset_config is not None: - logger.info(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - logger.info( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) - ) - else: - use_dreambooth_method = args.in_json is None - if use_dreambooth_method: - logger.info("Use DreamBooth method.") - user_config = { - "datasets": [ - {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} - ] - } - else: - logger.info("Train with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings) - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = 
train_util.collator_class(current_epoch, current_step, ds_for_collator) - - # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 - if use_template: - logger.info(f"use template for training captions. is object: {args.use_object_template}") - templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small - replace_to = " ".join(token_strings) - captions = [] - for tmpl in templates: - captions.append(tmpl.format(replace_to)) - train_dataset_group.add_replacement("", captions) - - if args.num_vectors_per_token > 1: - prompt_replacement = (args.token_string, replace_to) - else: - prompt_replacement = None - else: - if args.num_vectors_per_token > 1: - replace_to = " ".join(token_strings) - train_dataset_group.add_replacement(args.token_string, replace_to) - prompt_replacement = (args.token_string, replace_to) - else: - prompt_replacement = None - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group, show_input_ids=True) - return - if len(train_dataset_group) == 0: - logger.error("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") - return - - if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - original_unet.UNet2DConditionModel.forward = unet_forward_XTI - original_unet.CrossAttnDownBlock2D.forward = downblock_forward_XTI - original_unet.CrossAttnUpBlock2D.forward = upblock_forward_XTI - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") - clean_memory_on_device(accelerator.device) - - accelerator.wait_for_everyone() - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - text_encoder.gradient_checkpointing_enable() - - # 学習に必要なクラスを準備する - logger.info("prepare optimizer, data loader etc.") - trainable_params = text_encoder.get_input_embeddings().parameters() - _, _, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps - ) - logger.info( - f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" - ) - - # データセット側にも学習ステップを送信 - train_dataset_group.set_max_train_steps(args.max_train_steps) - - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - - # acceleratorがなんかよろしくやってくれるらしい - text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, optimizer, train_dataloader, lr_scheduler - ) - - index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0] - # logger.info(len(index_no_updates), torch.sum(index_no_updates)) - orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() - - # Freeze all parameters except for the token embeddings in text encoder - text_encoder.requires_grad_(True) - text_encoder.text_model.encoder.requires_grad_(False) - text_encoder.text_model.final_layer_norm.requires_grad_(False) - text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) - # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) - - unet.requires_grad_(False) - unet.to(accelerator.device, dtype=weight_dtype) - if args.gradient_checkpointing: # according to TI example in Diffusers, train is required - unet.train() - else: - unet.eval() - - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - text_encoder.to(weight_dtype) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - logger.info("running training / 学習開始") - logger.info(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - logger.info(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - logger.info(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - logger.info(f" num epochs / epoch数: {num_train_epochs}") - logger.info(f" batch size per device / バッチサイズ: {args.train_batch_size}") - logger.info( - f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" - ) - logger.info(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - logger.info(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 - - noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False - ) - prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) - if args.zero_terminal_snr: - custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) - - if accelerator.is_main_process: - init_kwargs = {} - if args.wandb_run_name: - init_kwargs["wandb"] = {"name": 
args.wandb_run_name} - if args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs - ) - - # function for saving/removing - def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False): - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - logger.info("") - logger.info(f"saving checkpoint: {ckpt_file}") - save_weights(ckpt_file, embs, save_dtype) - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) - - def remove_model(old_ckpt_name): - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - logger.info(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - # training loop - for epoch in range(num_train_epochs): - logger.info("") - logger.info(f"epoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch + 1 - - text_encoder.train() - - loss_total = 0 - - for step, batch in enumerate(train_dataloader): - current_step.value = global_step - with accelerator.accumulate(text_encoder): - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - - # Get the text embedding for conditioning - input_ids = batch["input_ids"].to(accelerator.device) - # weight_dtype) use float instead of fp16/bf16 because text encoder is float - encoder_hidden_states = torch.stack( - [ - train_util.get_hidden_states(args, s, tokenizer, text_encoder, weight_dtype) - for s in torch.split(input_ids, 1, dim=1) - ] - ) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - - loss = loss * loss_weights - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = text_encoder.get_input_embeddings().parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - # Let's make sure we don't update any embedding weights besides the newly added token - 
with torch.no_grad(): - accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[ - index_no_updates - ] - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - # TODO: fix sample_images - # train_util.sample_images( - # accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement - # ) - - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - updated_embs = ( - accelerator.unwrap_model(text_encoder) - .get_input_embeddings() - .weight[token_ids_XTI] - .data.detach() - .clone() - ) - - ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) - save_model(ckpt_name, updated_embs, global_step, epoch) - - if args.save_state: - train_util.save_and_remove_state_stepwise(args, accelerator, global_step) - - remove_step_no = train_util.get_remove_step_no(args, global_step) - if remove_step_no is not None: - remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) - remove_model(remove_ckpt_name) - - current_loss = loss.detach().item() - if args.logging_dir is not None: - logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if ( - args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() - ): # tracking d*lr value - logs["lr/d*lr"] = ( - lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] - ) - accelerator.log(logs, step=global_step) - - loss_total += current_loss - avr_loss = loss_total / (step + 1) - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if global_step >= args.max_train_steps: - break - - if args.logging_dir is not None: - logs = {"loss/epoch": loss_total / len(train_dataloader)} - accelerator.log(logs, step=epoch + 1) - - accelerator.wait_for_everyone() - - updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone() - - if args.save_every_n_epochs is not None: - saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs - if accelerator.is_main_process and saving: - ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) - save_model(ckpt_name, updated_embs, epoch + 1, global_step) - - remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) - if remove_epoch_no is not None: - remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." 
+ args.save_model_as, remove_epoch_no) - remove_model(remove_ckpt_name) - - if args.save_state: - train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - - # TODO: fix sample_images - # train_util.sample_images( - # accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement - # ) - - # end of epoch - - is_main_process = accelerator.is_main_process - if is_main_process: - text_encoder = accelerator.unwrap_model(text_encoder) - - accelerator.end_training() - - if args.save_state and is_main_process: - train_util.save_state_on_train_end(args, accelerator) - - updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone() - - del accelerator # この後メモリを使うのでこれは消す - - if is_main_process: - ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) - save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True) - - logger.info("model saved.") - - -def save_weights(file, updated_embs, save_dtype): - updated_embs = updated_embs.reshape(16, -1, updated_embs.shape[-1]) - updated_embs = updated_embs.chunk(16) - XTI_layers = [ - "IN01", - "IN02", - "IN04", - "IN05", - "IN07", - "IN08", - "MID", - "OUT03", - "OUT04", - "OUT05", - "OUT06", - "OUT07", - "OUT08", - "OUT09", - "OUT10", - "OUT11", - ] - state_dict = {} - for i, layer_name in enumerate(XTI_layers): - state_dict[layer_name] = updated_embs[i].squeeze(0).detach().clone().to("cpu").to(save_dtype) - - # if save_dtype is not None: - # for key in list(state_dict.keys()): - # v = state_dict[key] - # v = v.detach().clone().to("cpu").to(save_dtype) - # state_dict[key] = v - - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import save_file - - save_file(state_dict, file) - else: - torch.save(state_dict, file) # can be loaded in Web UI - - -def load_weights(file): - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import load_file - - data = load_file(file) - else: - raise ValueError(f"NOT XTI: {file}") - - if len(data.values()) != 16: - raise ValueError(f"NOT XTI: {file}") - - emb = torch.concat([x for x in data.values()]) - - return emb - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - add_logging_arguments(parser) - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, True, True, False) - train_util.add_training_arguments(parser, True) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) - custom_train_functions.add_custom_train_arguments(parser, False) - - parser.add_argument( - "--save_model_as", - type=str, - default="pt", - choices=[None, "ckpt", "pt", "safetensors"], - help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)", - ) - - parser.add_argument( - "--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み" - ) - parser.add_argument( - "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数" - ) - parser.add_argument( - "--token_string", - type=str, - default=None, - help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること", - ) - parser.add_argument( - "--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可" - ) - parser.add_argument( - "--use_object_template", - action="store_true", - help="ignore caption and use default templates for object 
/ キャプションは使わずデフォルトの物体用テンプレートで学習する",
-    )
-    parser.add_argument(
-        "--use_style_template",
-        action="store_true",
-        help="ignore caption and use default templates for style / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
-    )
-
-    return parser
-
-
-if __name__ == "__main__":
-    parser = setup_parser()
-
-    args = parser.parse_args()
-    args = train_util.read_config_from_file(args, parser)
-
-    train(args)
diff --git a/train_ti_README.md b/train_ti_README.md
deleted file mode 100644
index e655f8320..000000000
--- a/train_ti_README.md
+++ /dev/null
@@ -1,61 +0,0 @@
-# About Textual Inversion training
-
-This document covers training [Textual Inversion](https://textual-inversion.github.io/) embeddings. The implementation heavily references https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion.
-
-The trained embeddings can be used as-is in the Web UI.
-
-It is probably also compatible with SD2.x, but this has not been tested yet.
-
-## Training method
-
-Use ``train_textual_inversion.py``.
-
-Data preparation is exactly the same as for ``train_network.py``, so please refer to [its documentation](./train_network_README-en.md).
-
-## Options
-
-Below is an example command line (DreamBooth method).
-
-```shell
-accelerate launch --num_cpu_threads_per_process 1 train_textual_inversion.py
-    --pretrained_model_name_or_path=..\models\model.ckpt
-    --train_data_dir=..\data\db\char1 --output_dir=..\ti_train1
-    --resolution=448,640 --train_batch_size=1 --learning_rate=1e-4
-    --max_train_steps=400 --use_8bit_adam --xformers --mixed_precision=fp16
-    --save_every_n_epochs=1 --save_model_as=safetensors --clip_skip=2 --seed=42 --color_aug
-    --token_string=mychar4 --init_word=cute --num_vectors_per_token=4
-```
-
-``--token_string`` specifies the token string to be trained. __The training prompts must contain this string (e.g. ``mychar4 1girl`` if token_string is mychar4)__. This part of the prompt is replaced with new tokens for Textual Inversion and trained.
-
-``--debug_dataset`` displays the token ids after substitution, so you can confirm that tokens from ``49408`` onward are present, as shown below.
-
-```python
-input ids: tensor([[49406, 49408, 49409, 49410, 49411, 49412, 49413, 49414, 49415, 49407,
-         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
-         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
-         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
-         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
-         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
-         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
-         49407, 49407, 49407, 49407, 49407, 49407, 49407]])
-```
-
-Words that the tokenizer already knows (common words) cannot be used.
-
-``--init_word`` specifies the string of the source token used to initialize the new embeddings. It is a good idea to choose something conceptually similar to what you want to train. You cannot specify a string that tokenizes into two or more tokens.
-
-``--num_vectors_per_token`` specifies how many tokens this training uses. More vectors are more expressive, but consume more of the prompt. For example, with num_vectors_per_token=8 the specified token string consumes 8 tokens (out of the 77-token limit of a typical prompt).
-
-In addition, the following options can be specified.
-
-* --weights
-  * Load previously trained embeddings and continue training from them.
-* --use_object_template
-  * Learn with the default object template strings (such as ``a photo of a {}``) instead of captions. This matches the official implementation. Captions are ignored.
-* --use_style_template
-  * Learn with the default style template strings (such as ``a painting in the style of {}``) instead of captions. This matches the official implementation. Captions are ignored.
-
-## Generate with the image generation script in this repository
-
-In gen_img_diffusers.py, specify the trained embeddings file with the ``--textual_inversion_embeddings`` option. Using the filename of the embeddings file (without the extension) in the prompt applies the embeddings.
\ No newline at end of file
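
As a reference for the ``--token_string`` / ``--num_vectors_per_token`` behaviour described above, the placeholder-token expansion can be reproduced outside the training scripts. The following is a minimal sketch, not part of this patch; it assumes the stock SD1.x CLIP tokenizer from Hugging Face, and the token string and vector count are illustrative values only:

```python
from transformers import CLIPTokenizer

# Illustrative values (not repository defaults)
token_string = "mychar4"
num_vectors_per_token = 4

# Assumption: the stock CLIP tokenizer used for SD1.x models
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# The scripts expand "mychar4" into one placeholder token per embedding vector:
# ["mychar4", "mychar41", "mychar42", "mychar43"]
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
num_added = tokenizer.add_tokens(token_strings)
assert num_added == num_vectors_per_token, "token string already exists in the tokenizer"

# The new ids are appended after the original vocabulary, e.g. [49408, 49409, 49410, 49411]
print(tokenizer.convert_tokens_to_ids(token_strings))
```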
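
The XTI variant deleted above writes one tensor per hooked U-Net layer (``IN01`` … ``OUT11``) in its ``save_weights()``. A minimal sketch for inspecting such a file, assuming a hypothetical output path:

```python
import torch
from safetensors.torch import load_file

path = "ti_train1/mychar4_xti.safetensors"  # hypothetical path
state_dict = load_file(path) if path.endswith(".safetensors") else torch.load(path)

# Expect 16 entries keyed IN01 ... OUT11; each tensor is (num_vectors_per_token, embedding_dim)
for layer_name, emb in state_dict.items():
    print(layer_name, tuple(emb.shape), emb.dtype)
```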