diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000..9f6d672ca --- /dev/null +++ b/.gitattributes @@ -0,0 +1,4 @@ +*.sh text eol=lf +*.ps1 text eol=crlf +*.bat text eol=crlf +*.cmd text eol=crlf \ No newline at end of file diff --git a/README.md b/README.md index e963020c5..bdb2037a9 100644 --- a/README.md +++ b/README.md @@ -232,6 +232,44 @@ python lora_gui.py Once you have created the LoRA network, you can generate images via auto1111 by installing [this extension](https://github.com/kohya-ss/sd-webui-additional-networks). +### Naming of LoRA + +The LoRA supported by `train_network.py` has been named to avoid confusion. The documentation has been updated. The following are the names of LoRA types in this repository. + +1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers) + + LoRA for Linear layers and Conv2d layers with 1x1 kernel + +2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers) + + In addition to 1., LoRA for Conv2d layers with 3x3 kernel + +LoRA-LierLa is the default LoRA type for `train_network.py` (without `conv_dim` network arg). LoRA-LierLa can be used with [our extension](https://github.com/kohya-ss/sd-webui-additional-networks) for AUTOMATIC1111's Web UI, or with the built-in LoRA feature of the Web UI. + +To use LoRA-C3Liar with Web UI, please use our extension. + +## Sample image generation during training +A prompt file might look like this, for example + +``` +# prompt 1 +masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28 + +# prompt 2 +masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40 +``` + + Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used. + + * `--n` Negative prompt up to the next option. + * `--w` Specifies the width of the generated image. + * `--h` Specifies the height of the generated image. + * `--d` Specifies the seed of the generated image. + * `--l` Specifies the CFG scale of the generated image. + * `--s` Specifies the number of steps in the generation. + + The prompt weighting such as `( )` and `[ ]` are working. + ## Troubleshooting ### Page File Limit @@ -258,6 +296,21 @@ This will store a backup file with your current locally installed pip packages a ## Change History +* 2023/04/17 (v21.5.4) + - Fixed a bug that caused an error when loading DyLoRA with the `--network_weight` option in `train_network.py`. + - Added the `--recursive` option to each script in the `finetune` folder to process folders recursively. Please refer to [PR #400](https://github.com/kohya-ss/sd-scripts/pull/400/) for details. Thanks to Linaqruf! + - Upgrade Gradio to latest release + - Fix issue when Adafactor is used as optimizer and LR Warmup is not 0: https://github.com/bmaltais/kohya_ss/issues/617 + - Added support for DyLoRA in `train_network.py`. Please refer to [here](./train_network_README-ja.md#dylora) for details (currently only in Japanese). + - Added support for caching latents to disk in each training script. Please specify __both__ `--cache_latents` and `--cache_latents_to_disk` options. + - The files are saved in the same folder as the images with the extension `.npz`. If you specify the `--flip_aug` option, the files with `_flip.npz` will also be saved. + - Multi-GPU training has not been tested. + - This feature is not tested with all combinations of datasets and training scripts, so there may be bugs. + - Added workaround for an error that occurs when training with `fp16` or `bf16` in `fine_tune.py`. + - Implemented DyLoRA GUI support. There will now be a new 'DyLoRA Unit` slider when the LoRA type is selected as `kohya DyLoRA` to specify the desired Unit value for DyLoRA training. + - Update gui.bat and gui.ps1 based on: https://github.com/bmaltais/kohya_ss/issues/188 + - Update `setup.bat` to install torch 2.0.0 instead of 1.2.1. If you want to upgrade from 1.2.1 to 2.0.0 run setup.bat again, select 1 to uninstall the previous torch modules, then select 2 for torch 2.0.0 + * 2023/04/09 (v21.5.2) - Added support for training with weighted captions. Thanks to AI-Casanova for the great contribution! @@ -267,120 +320,3 @@ This will store a backup file with your current locally installed pip packages a - The syntax for weighted captions is almost the same as the Web UI, and you can use things like `(abc)`, `[abc]`, and `(abc:1.23)`. Nesting is also possible. - If you include a comma in the parentheses, the parentheses will not be properly matched in the prompt shuffle/dropout, so do not include a comma in the parentheses. - Run gui.sh from any place - -* 2023/04/08 (v21.5.1) - - Integrate latest sd-scripts updates. Not integrated in the GUI. Will consider if you think it is wort integrating. At the moment you can add the required parameters using the `Additional parameters` field under the `Advanced Configuration` accordion in the `Training Parameters` tab: - - There may be bugs because I changed a lot. If you cannot revert the script to the previous version when a problem occurs, please wait for the update for a while. - - There may be bugs because I changed a lot. If you cannot revert the script to the previous version when a problem occurs, please wait for the update for a while. - - - Added a feature to upload model and state to HuggingFace. Thanks to ddPn08 for the contribution! [PR #348](https://github.com/kohya-ss/sd-scripts/pull/348) - - When `--huggingface_repo_id` is specified, the model is uploaded to HuggingFace at the same time as saving the model. - - Please note that the access token is handled with caution. Please refer to the [HuggingFace documentation](https://huggingface.co/docs/hub/security-tokens). - - For example, specify other arguments as follows. - - `--huggingface_repo_id "your-hf-name/your-model" --huggingface_path_in_repo "path" --huggingface_repo_type model --huggingface_repo_visibility private --huggingface_token hf_YourAccessTokenHere` - - If `public` is specified for `--huggingface_repo_visibility`, the repository will be public. If the option is omitted or `private` (or anything other than `public`) is specified, it will be private. - - If you specify `--save_state` and `--save_state_to_huggingface`, the state will also be uploaded. - - If you specify `--resume` and `--resume_from_huggingface`, the state will be downloaded from HuggingFace and resumed. - - In this case, the `--resume` option is `--resume {repo_id}/{path_in_repo}:{revision}:{repo_type}`. For example: `--resume_from_huggingface --resume your-hf-name/your-model/path/test-000002-state:main:model` - - If you specify `--async_upload`, the upload will be done asynchronously. - - Added the documentation for applying LoRA to generate with the standard pipeline of Diffusers. [training LoRA](https://github-com.translate.goog/kohya-ss/sd-scripts/blob/main/train_network_README-ja.md?_x_tr_sl=fr&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp#diffusers%E3%81%AEpipeline%E3%81%A7%E7%94%9F%E6%88%90%E3%81%99%E3%82%8B) (Google translate from Japanese) - - Support for Attention Couple and regional LoRA in `gen_img_diffusers.py`. - - If you use ` AND ` to separate the prompts, each sub-prompt is sequentially applied to LoRA. `--mask_path` is treated as a mask image. The number of sub-prompts and the number of LoRA must match. - - Resolved bug https://github.com/bmaltais/kohya_ss/issues/554 -* 2023/04/07 (v21.5.0) - - Update MacOS and Linux install scripts. Thanks @jstayco - - Update windows upgrade ps1 and bat - - Update kohya_ss sd-script code to latest release... this is a big one so it might cause some training issue. If you find that this release is causing issues for you you can go back to the previous release with `git checkout v21.4.2` and then run the upgrade script for your platform. Here is the list of changes in the new sd-scripts: - - There may be bugs because I changed a lot. If you cannot revert the script to the previous version when a problem occurs, please wait for the update for a while. - - The learning rate and dim (rank) of each block may not work with other modules (LyCORIS, etc.) because the module needs to be changed. - - - Fix some bugs and add some features. - - Fix an issue that `.json` format dataset config files cannot be read. [issue #351](https://github.com/kohya-ss/sd-scripts/issues/351) Thanks to rockerBOO! - - Raise an error when an invalid `--lr_warmup_steps` option is specified (when warmup is not valid for the specified scheduler). [PR #364](https://github.com/kohya-ss/sd-scripts/pull/364) Thanks to shirayu! - - Add `min_snr_gamma` to metadata in `train_network.py`. [PR #373](https://github.com/kohya-ss/sd-scripts/pull/373) Thanks to rockerBOO! - - Fix the data type handling in `fine_tune.py`. This may fix an error that occurs in some environments when using xformers, npz format cache, and mixed_precision. - - - Add options to `train_network.py` to specify block weights for learning rates. [PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) Thanks to u-haru for the great contribution! - - Specify the weights of 25 blocks for the full model. - - No LoRA corresponds to the first block, but 25 blocks are specified for compatibility with 'LoRA block weight' etc. Also, if you do not expand to conv2d3x3, some blocks do not have LoRA, but please specify 25 values ​​for the argument for consistency. - - Specify the following arguments with `--network_args`. - - `down_lr_weight` : Specify the learning rate weight of the down blocks of U-Net. The following can be specified. - - The weight for each block: Specify 12 numbers such as `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"`. - - Specify from preset: Specify such as `"down_lr_weight=sine"` (the weights by sine curve). sine, cosine, linear, reverse_linear, zeros can be specified. Also, if you add `+number` such as `"down_lr_weight=cosine+.25"`, the specified number is added (such as 0.25~1.25). - - `mid_lr_weight` : Specify the learning rate weight of the mid block of U-Net. Specify one number such as `"down_lr_weight=0.5"`. - - `up_lr_weight` : Specify the learning rate weight of the up blocks of U-Net. The same as down_lr_weight. - - If you omit the some arguments, the 1.0 is used. Also, if you set the weight to 0, the LoRA modules of that block are not created. - - `block_lr_zero_threshold` : If the weight is not more than this value, the LoRA module is not created. The default is 0. - - - Add options to `train_network.py` to specify block dims (ranks) for variable rank. - - Specify 25 values ​​for the full model of 25 blocks. Some blocks do not have LoRA, but specify 25 values ​​always. - - Specify the following arguments with `--network_args`. - - `block_dims` : Specify the dim (rank) of each block. Specify 25 numbers such as `"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"`. - - `block_alphas` : Specify the alpha of each block. Specify 25 numbers as with block_dims. If omitted, the value of network_alpha is used. - - `conv_block_dims` : Expand LoRA to Conv2d 3x3 and specify the dim (rank) of each block. - - `conv_block_alphas` : Specify the alpha of each block when expanding LoRA to Conv2d 3x3. If omitted, the value of conv_alpha is used. - - Add GUI support for new features introduced above by kohya_ss. Those will be visible only if the LoRA is of type `Standard` or `kohya LoCon`. You will find the new parameters under the `Advanced Configuration` accordion in the `Training parameters` tab. - - Various improvements to linux and macos srtup scripts thanks to @Oceanswave and @derVedro - - Integrated sd-scripts commits into commit history. Thanks to @Cauldrath -* 2023/04/02 (v21.4.2) - - removes TensorFlow from requirements.txt for Darwin platforms as pip does not support advanced conditionals like CPU architecture. The logic is now defined in setup.sh to avoid version bump headaches, and the selection logic is in the pre-existing pip function. Additionally, the release includes the addition of the tensorflow-metal package for M1+ Macs, which enables GPU acceleration per Apple's documentation. Thanks @jstayco -* 2023/04/01 (v21.4.1) - - Fix type for linux install by @bmaltais in https://github.com/bmaltais/kohya_ss/pull/517 - - Fix .gitignore by @bmaltais in https://github.com/bmaltais/kohya_ss/pull/518 -* 2023/04/01 (v21.4.0) - - Improved linux and macos installation and updates script. See README for more details. Many thanks to @jstayco and @Galunid for the great PR! - - Fix issue with "missing library" error. -* 2023/04/01 (v21.3.9) - - Update how setup is done on Windows by introducing a setup.bat script. This will make it easier to install/re-install on Windows if needed. Many thanks to @missionfloyd for his PR: https://github.com/bmaltais/kohya_ss/pull/496 -* 2023/03/30 (v21.3.8) - - Fix issue with LyCORIS version not being found: https://github.com/bmaltais/kohya_ss/issues/481 -* 2023/03/29 (v21.3.7) - - Allow for 0.1 increment in Network and Conv alpha values: https://github.com/bmaltais/kohya_ss/pull/471 Thanks to @srndpty - - Updated Lycoris module version -* 2023/03/28 (v21.3.6) - - Fix issues when `--persistent_data_loader_workers` is specified. - - The batch members of the bucket are not shuffled. - - `--caption_dropout_every_n_epochs` does not work. - - These issues occurred because the epoch transition was not recognized correctly. Thanks to u-haru for reporting the issue. - - Fix an issue that images are loaded twice in Windows environment. - - Add Min-SNR Weighting strategy. Details are in [#308](https://github.com/kohya-ss/sd-scripts/pull/308). Thank you to AI-Casanova for this great work! - - Add `--min_snr_gamma` option to training scripts, 5 is recommended by paper. - - The Min SNR gamma fields can be found under the advanced training tab in all trainers. - - Fixed the error while images are ended with capital image extensions. Thanks to @kvzn. https://github.com/bmaltais/kohya_ss/pull/454 -* 2023/03/26 (v21.3.5) - - Fix for https://github.com/bmaltais/kohya_ss/issues/230 - - Added detection for Google Colab to not bring up the GUI file/folder window on the platform. Instead it will only use the file/folder path provided in the input field. -* 2023/03/25 (v21.3.4) - - Added untested support for MacOS base on this gist: https://gist.github.com/jstayco/9f5733f05b9dc29de95c4056a023d645 - - Let me know how this work. From the look of it it appear to be well thought out. I modified a few things to make it fit better with the rest of the code in the repo. - - Fix for issue https://github.com/bmaltais/kohya_ss/issues/433 by implementing default of 0. - - Removed non applicable save_model_as choices for LoRA and TI. -* 2023/03/24 (v21.3.3) - - Add support for custom user gui files. THey will be created at installation time or when upgrading is missing. You will see two files in the root of the folder. One named `gui-user.bat` and the other `gui-user.ps1`. Edit the file based on your preferred terminal. Simply add the parameters you want to pass the gui in there and execute it to start the gui with them. Enjoy! -* 2023/03/23 (v21.3.2) - - Fix issue reported: https://github.com/bmaltais/kohya_ss/issues/439 -* 2023/03/23 (v21.3.1) - - Merge PR to fix refactor naming issue for basic captions. Thank @zrma -* 2023/03/22 (v21.3.0) - - Add a function to load training config with `.toml` to each training script. Thanks to Linaqruf for this great contribution! - - Specify `.toml` file with `--config_file`. `.toml` file has `key=value` entries. Keys are same as command line options. See [#241](https://github.com/kohya-ss/sd-scripts/pull/241) for details. - - All sub-sections are combined to a single dictionary (the section names are ignored.) - - Omitted arguments are the default values for command line arguments. - - Command line args override the arguments in `.toml`. - - With `--output_config` option, you can output current command line options to the `.toml` specified with`--config_file`. Please use as a template. - - Add `--lr_scheduler_type` and `--lr_scheduler_args` arguments for custom LR scheduler to each training script. Thanks to Isotr0py! [#271](https://github.com/kohya-ss/sd-scripts/pull/271) - - Same as the optimizer. - - Add sample image generation with weight and no length limit. Thanks to mio2333! [#288](https://github.com/kohya-ss/sd-scripts/pull/288) - - `( )`, `(xxxx:1.2)` and `[ ]` can be used. - - Fix exception on training model in diffusers format with `train_network.py` Thanks to orenwang! [#290](https://github.com/kohya-ss/sd-scripts/pull/290) - - Add warning if you are about to overwrite an existing model: https://github.com/bmaltais/kohya_ss/issues/404 - - Add `--vae_batch_size` for faster latents caching to each training script. This batches VAE calls. - - Please start with`2` or `4` depending on the size of VRAM. - - Fix a number of training steps with `--gradient_accumulation_steps` and `--max_train_epochs`. Thanks to tsukimiya! - - Extract parser setup to external scripts. Thanks to robertsmieja! - - Fix an issue without `.npz` and with `--full_path` in training. - - Support extensions with upper cases for images for not Windows environment. - - Fix `resize_lora.py` to work with LoRA with dynamic rank (including `conv_dim != network_dim`). Thanks to toshiaki! - - Fix issue: https://github.com/bmaltais/kohya_ss/issues/406 - - Add device support to LoRA extract. diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 6a96f89ca..709ecdd4d 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -333,6 +333,10 @@ def train_model( if check_if_model_exist(output_name, output_dir, save_model_as): return + + if optimizer == 'Adafactor' and lr_warmup != '0': + msgbox("Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", title="Warning") + lr_warmup = '0' # Get a list of all subfolders in train_data_dir, excluding hidden folders subfolders = [ diff --git a/fine_tune.py b/fine_tune.py index 2157de985..61f6c1919 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -142,12 +142,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size) + train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() + accelerator.wait_for_everyone() + # 学習を準備する:モデルを適切な状態にする training_models = [] if args.gradient_checkpointing: @@ -273,7 +275,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype) else: # latentに変換 latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() @@ -311,7 +313,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample if args.v_parameterization: # v-parameterization training diff --git a/finetune/make_captions.py b/finetune/make_captions.py index e690349a2..9e51037f3 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -4,6 +4,7 @@ import json import random +from pathlib import Path from PIL import Image from tqdm import tqdm import numpy as np @@ -13,156 +14,185 @@ from blip.blip import blip_decoder import library.train_util as train_util -DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") IMAGE_SIZE = 384 # 正方形でいいのか? という気がするがソースがそうなので -IMAGE_TRANSFORM = transforms.Compose([ - transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC), - transforms.ToTensor(), - transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) -]) +IMAGE_TRANSFORM = transforms.Compose( + [ + transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ] +) + # 共通化したいが微妙に処理が異なる…… class ImageLoadingTransformDataset(torch.utils.data.Dataset): - def __init__(self, image_paths): - self.images = image_paths + def __init__(self, image_paths): + self.images = image_paths - def __len__(self): - return len(self.images) + def __len__(self): + return len(self.images) - def __getitem__(self, idx): - img_path = self.images[idx] + def __getitem__(self, idx): + img_path = self.images[idx] - try: - image = Image.open(img_path).convert("RGB") - # convert to tensor temporarily so dataloader will accept it - tensor = IMAGE_TRANSFORM(image) - except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") - return None + try: + image = Image.open(img_path).convert("RGB") + # convert to tensor temporarily so dataloader will accept it + tensor = IMAGE_TRANSFORM(image) + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + return None - return (tensor, img_path) + return (tensor, img_path) def collate_fn_remove_corrupted(batch): - """Collate function that allows to remove corrupted examples in the - dataloader. It expects that the dataloader returns 'None' when that occurs. - The 'None's in the batch are removed. - """ - # Filter out all the Nones (corrupted examples) - batch = list(filter(lambda x: x is not None, batch)) - return batch + """Collate function that allows to remove corrupted examples in the + dataloader. It expects that the dataloader returns 'None' when that occurs. + The 'None's in the batch are removed. + """ + # Filter out all the Nones (corrupted examples) + batch = list(filter(lambda x: x is not None, batch)) + return batch def main(args): - # fix the seed for reproducibility - seed = args.seed # + utils.get_rank() - torch.manual_seed(seed) - np.random.seed(seed) - random.seed(seed) - - if not os.path.exists("blip"): - args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path - - cwd = os.getcwd() - print('Current Working Directory is: ', cwd) - os.chdir('finetune') - - print(f"load images from {args.train_data_dir}") - image_paths = train_util.glob_images(args.train_data_dir) - print(f"found {len(image_paths)} images.") - - print(f"loading BLIP caption: {args.caption_weights}") - model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json") - model.eval() - model = model.to(DEVICE) - print("BLIP loaded") - - # captioningする - def run_batch(path_imgs): - imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE) - - with torch.no_grad(): - if args.beam_search: - captions = model.generate(imgs, sample=False, num_beams=args.num_beams, - max_length=args.max_length, min_length=args.min_length) - else: - captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length) - - for (image_path, _), caption in zip(path_imgs, captions): - with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: - f.write(caption + "\n") - if args.debug: - print(image_path, caption) - - # 読み込みの高速化のためにDataLoaderを使うオプション - if args.max_data_loader_n_workers is not None: - dataset = ImageLoadingTransformDataset(image_paths) - data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, - num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) - else: - data = [[(None, ip)] for ip in image_paths] - - b_imgs = [] - for data_entry in tqdm(data, smoothing=0.0): - for data in data_entry: - if data is None: - continue - - img_tensor, image_path = data - if img_tensor is None: - try: - raw_image = Image.open(image_path) - if raw_image.mode != 'RGB': - raw_image = raw_image.convert("RGB") - img_tensor = IMAGE_TRANSFORM(raw_image) - except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") - continue - - b_imgs.append((image_path, img_tensor)) - if len(b_imgs) >= args.batch_size: + # fix the seed for reproducibility + seed = args.seed # + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + if not os.path.exists("blip"): + args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path + + cwd = os.getcwd() + print("Current Working Directory is: ", cwd) + os.chdir("finetune") + + print(f"load images from {args.train_data_dir}") + train_data_dir_path = Path(args.train_data_dir) + image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) + print(f"found {len(image_paths)} images.") + + print(f"loading BLIP caption: {args.caption_weights}") + model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json") + model.eval() + model = model.to(DEVICE) + print("BLIP loaded") + + # captioningする + def run_batch(path_imgs): + imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE) + + with torch.no_grad(): + if args.beam_search: + captions = model.generate( + imgs, sample=False, num_beams=args.num_beams, max_length=args.max_length, min_length=args.min_length + ) + else: + captions = model.generate( + imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length + ) + + for (image_path, _), caption in zip(path_imgs, captions): + with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: + f.write(caption + "\n") + if args.debug: + print(image_path, caption) + + # 読み込みの高速化のためにDataLoaderを使うオプション + if args.max_data_loader_n_workers is not None: + dataset = ImageLoadingTransformDataset(image_paths) + data = torch.utils.data.DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.max_data_loader_n_workers, + collate_fn=collate_fn_remove_corrupted, + drop_last=False, + ) + else: + data = [[(None, ip)] for ip in image_paths] + + b_imgs = [] + for data_entry in tqdm(data, smoothing=0.0): + for data in data_entry: + if data is None: + continue + + img_tensor, image_path = data + if img_tensor is None: + try: + raw_image = Image.open(image_path) + if raw_image.mode != "RGB": + raw_image = raw_image.convert("RGB") + img_tensor = IMAGE_TRANSFORM(raw_image) + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + continue + + b_imgs.append((image_path, img_tensor)) + if len(b_imgs) >= args.batch_size: + run_batch(b_imgs) + b_imgs.clear() + if len(b_imgs) > 0: run_batch(b_imgs) - b_imgs.clear() - if len(b_imgs) > 0: - run_batch(b_imgs) - print("done!") + print("done!") def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth", - help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)") - parser.add_argument("--caption_extention", type=str, default=None, - help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") - parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") - parser.add_argument("--beam_search", action="store_true", - help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)") - parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") - parser.add_argument("--max_data_loader_n_workers", type=int, default=None, - help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") - parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)") - parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p") - parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長") - parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長") - parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed') - parser.add_argument("--debug", action="store_true", help="debug mode") - - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args = parser.parse_args() - - # スペルミスしていたオプションを復元する - if args.caption_extention is not None: - args.caption_extension = args.caption_extention - - main(args) + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument( + "--caption_weights", + type=str, + default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth", + help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)", + ) + parser.add_argument( + "--caption_extention", + type=str, + default=None, + help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)", + ) + parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") + parser.add_argument( + "--beam_search", + action="store_true", + help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)", + ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") + parser.add_argument( + "--max_data_loader_n_workers", + type=int, + default=None, + help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", + ) + parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)") + parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p") + parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長") + parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長") + parser.add_argument("--seed", default=42, type=int, help="seed for reproducibility / 再現性を確保するための乱数seed") + parser.add_argument("--debug", action="store_true", help="debug mode") + parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + + # スペルミスしていたオプションを復元する + if args.caption_extention is not None: + args.caption_extension = args.caption_extention + + main(args) diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py index 06af55987..ce6e66955 100644 --- a/finetune/make_captions_by_git.py +++ b/finetune/make_captions_by_git.py @@ -2,6 +2,7 @@ import os import re +from pathlib import Path from PIL import Image from tqdm import tqdm import torch @@ -11,141 +12,161 @@ import library.train_util as train_util -DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") PATTERN_REPLACE = [ re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'), re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'), re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"), - re.compile(r'with the number \d+ on (it|\w+ \w+)'), + re.compile(r"with the number \d+ on (it|\w+ \w+)"), re.compile(r'with the words "'), - re.compile(r'word \w+ on it'), - re.compile(r'that says the word \w+ on it'), - re.compile('that says\'the word "( on it)?'), + re.compile(r"word \w+ on it"), + re.compile(r"that says the word \w+ on it"), + re.compile("that says'the word \"( on it)?"), ] # 誤検知しまくりの with the word xxxx を消す def remove_words(captions, debug): - removed_caps = [] - for caption in captions: - cap = caption - for pat in PATTERN_REPLACE: - cap = pat.sub("", cap) - if debug and cap != caption: - print(caption) - print(cap) - removed_caps.append(cap) - return removed_caps + removed_caps = [] + for caption in captions: + cap = caption + for pat in PATTERN_REPLACE: + cap = pat.sub("", cap) + if debug and cap != caption: + print(caption) + print(cap) + removed_caps.append(cap) + return removed_caps def collate_fn_remove_corrupted(batch): - """Collate function that allows to remove corrupted examples in the - dataloader. It expects that the dataloader returns 'None' when that occurs. - The 'None's in the batch are removed. - """ - # Filter out all the Nones (corrupted examples) - batch = list(filter(lambda x: x is not None, batch)) - return batch + """Collate function that allows to remove corrupted examples in the + dataloader. It expects that the dataloader returns 'None' when that occurs. + The 'None's in the batch are removed. + """ + # Filter out all the Nones (corrupted examples) + batch = list(filter(lambda x: x is not None, batch)) + return batch def main(args): - # GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用 - org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation - curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように - - # input_idsがバッチサイズと同じ件数である必要がある:バッチサイズはこの関数から参照できないので外から渡す - # ここより上で置き換えようとするとすごく大変 - def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs): - input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs) - if input_ids.size()[0] != curr_batch_size[0]: - input_ids = input_ids.repeat(curr_batch_size[0], 1) - return input_ids - GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch - - print(f"load images from {args.train_data_dir}") - image_paths = train_util.glob_images(args.train_data_dir) - print(f"found {len(image_paths)} images.") - - # できればcacheに依存せず明示的にダウンロードしたい - print(f"loading GIT: {args.model_id}") - git_processor = AutoProcessor.from_pretrained(args.model_id) - git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE) - print("GIT loaded") - - # captioningする - def run_batch(path_imgs): - imgs = [im for _, im in path_imgs] - - curr_batch_size[0] = len(path_imgs) - inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式 - generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length) - captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True) - - if args.remove_words: - captions = remove_words(captions, args.debug) - - for (image_path, _), caption in zip(path_imgs, captions): - with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: - f.write(caption + "\n") - if args.debug: - print(image_path, caption) - - # 読み込みの高速化のためにDataLoaderを使うオプション - if args.max_data_loader_n_workers is not None: - dataset = train_util.ImageLoadingDataset(image_paths) - data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, - num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) - else: - data = [[(None, ip)] for ip in image_paths] - - b_imgs = [] - for data_entry in tqdm(data, smoothing=0.0): - for data in data_entry: - if data is None: - continue - - image, image_path = data - if image is None: - try: - image = Image.open(image_path) - if image.mode != 'RGB': - image = image.convert("RGB") - except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") - continue - - b_imgs.append((image_path, image)) - if len(b_imgs) >= args.batch_size: + # GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用 + org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation + curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように + + # input_idsがバッチサイズと同じ件数である必要がある:バッチサイズはこの関数から参照できないので外から渡す + # ここより上で置き換えようとするとすごく大変 + def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs): + input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs) + if input_ids.size()[0] != curr_batch_size[0]: + input_ids = input_ids.repeat(curr_batch_size[0], 1) + return input_ids + + GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch + + print(f"load images from {args.train_data_dir}") + train_data_dir_path = Path(args.train_data_dir) + image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) + print(f"found {len(image_paths)} images.") + + # できればcacheに依存せず明示的にダウンロードしたい + print(f"loading GIT: {args.model_id}") + git_processor = AutoProcessor.from_pretrained(args.model_id) + git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE) + print("GIT loaded") + + # captioningする + def run_batch(path_imgs): + imgs = [im for _, im in path_imgs] + + curr_batch_size[0] = len(path_imgs) + inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式 + generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length) + captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True) + + if args.remove_words: + captions = remove_words(captions, args.debug) + + for (image_path, _), caption in zip(path_imgs, captions): + with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: + f.write(caption + "\n") + if args.debug: + print(image_path, caption) + + # 読み込みの高速化のためにDataLoaderを使うオプション + if args.max_data_loader_n_workers is not None: + dataset = train_util.ImageLoadingDataset(image_paths) + data = torch.utils.data.DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.max_data_loader_n_workers, + collate_fn=collate_fn_remove_corrupted, + drop_last=False, + ) + else: + data = [[(None, ip)] for ip in image_paths] + + b_imgs = [] + for data_entry in tqdm(data, smoothing=0.0): + for data in data_entry: + if data is None: + continue + + image, image_path = data + if image is None: + try: + image = Image.open(image_path) + if image.mode != "RGB": + image = image.convert("RGB") + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + continue + + b_imgs.append((image_path, image)) + if len(b_imgs) >= args.batch_size: + run_batch(b_imgs) + b_imgs.clear() + + if len(b_imgs) > 0: run_batch(b_imgs) - b_imgs.clear() - if len(b_imgs) > 0: - run_batch(b_imgs) - - print("done!") + print("done!") def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") - parser.add_argument("--model_id", type=str, default="microsoft/git-large-textcaps", - help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID") - parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") - parser.add_argument("--max_data_loader_n_workers", type=int, default=None, - help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") - parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長") - parser.add_argument("--remove_words", action="store_true", - help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する") - parser.add_argument("--debug", action="store_true", help="debug mode") - - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args = parser.parse_args() - main(args) + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") + parser.add_argument( + "--model_id", + type=str, + default="microsoft/git-large-textcaps", + help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID", + ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") + parser.add_argument( + "--max_data_loader_n_workers", + type=int, + default=None, + help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", + ) + parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長") + parser.add_argument( + "--remove_words", + action="store_true", + help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する", + ) + parser.add_argument("--debug", action="store_true", help="debug mode") + parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + main(args) diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 8d9a38ab3..fd289d1d3 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -2,6 +2,8 @@ import os import json +from pathlib import Path +from typing import List from tqdm import tqdm import numpy as np from PIL import Image @@ -12,7 +14,7 @@ import library.model_util as model_util import library.train_util as train_util -DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") IMAGE_TRANSFORMS = transforms.Compose( [ @@ -23,245 +25,299 @@ def collate_fn_remove_corrupted(batch): - """Collate function that allows to remove corrupted examples in the - dataloader. It expects that the dataloader returns 'None' when that occurs. - The 'None's in the batch are removed. - """ - # Filter out all the Nones (corrupted examples) - batch = list(filter(lambda x: x is not None, batch)) - return batch + """Collate function that allows to remove corrupted examples in the + dataloader. It expects that the dataloader returns 'None' when that occurs. + The 'None's in the batch are removed. + """ + # Filter out all the Nones (corrupted examples) + batch = list(filter(lambda x: x is not None, batch)) + return batch def get_latents(vae, images, weight_dtype): - img_tensors = [IMAGE_TRANSFORMS(image) for image in images] - img_tensors = torch.stack(img_tensors) - img_tensors = img_tensors.to(DEVICE, weight_dtype) - with torch.no_grad(): - latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy() - return latents + img_tensors = [IMAGE_TRANSFORMS(image) for image in images] + img_tensors = torch.stack(img_tensors) + img_tensors = img_tensors.to(DEVICE, weight_dtype) + with torch.no_grad(): + latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy() + return latents + + +def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive): + if is_full_path: + base_name = os.path.splitext(os.path.basename(image_key))[0] + relative_path = os.path.relpath(os.path.dirname(image_key), data_dir) + else: + base_name = image_key + relative_path = "" + if flip: + base_name += "_flip" -def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip): - if is_full_path: - base_name = os.path.splitext(os.path.basename(image_key))[0] - else: - base_name = image_key - if flip: - base_name += '_flip' - return os.path.join(data_dir, base_name) + if recursive and relative_path: + return os.path.join(data_dir, relative_path, base_name) + else: + return os.path.join(data_dir, base_name) def main(args): - # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります" - if args.bucket_reso_steps % 8 > 0: - print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") - - image_paths = train_util.glob_images(args.train_data_dir) - print(f"found {len(image_paths)} images.") - - if os.path.exists(args.in_json): - print(f"loading existing metadata: {args.in_json}") - with open(args.in_json, "rt", encoding='utf-8') as f: - metadata = json.load(f) - else: - print(f"no metadata / メタデータファイルがありません: {args.in_json}") - return - - weight_dtype = torch.float32 - if args.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - vae = model_util.load_vae(args.model_name_or_path, weight_dtype) - vae.eval() - vae.to(DEVICE, dtype=weight_dtype) - - # bucketのサイズを計算する - max_reso = tuple([int(t) for t in args.max_resolution.split(',')]) - assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" - - bucket_manager = train_util.BucketManager(args.bucket_no_upscale, max_reso, - args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps) - if not args.bucket_no_upscale: - bucket_manager.make_buckets() - else: - print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます") - - # 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する - img_ar_errors = [] - - def process_batch(is_last): - for bucket in bucket_manager.buckets: - if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size: - latents = get_latents(vae, [img for _, img in bucket], weight_dtype) - assert latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8, \ - f"latent shape {latents.shape}, {bucket[0][1].shape}" - - for (image_key, _), latent in zip(bucket, latents): - npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) - np.savez(npz_file_name, latent) - - # flip - if args.flip_aug: - latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない - - for (image_key, _), latent in zip(bucket, latents): - npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) - np.savez(npz_file_name, latent) - else: - # remove existing flipped npz - for image_key, _ in bucket: - npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz" - if os.path.isfile(npz_file_name): - print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}") - os.remove(npz_file_name) - - bucket.clear() - - # 読み込みの高速化のためにDataLoaderを使うオプション - if args.max_data_loader_n_workers is not None: - dataset = train_util.ImageLoadingDataset(image_paths) - data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, - num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) - else: - data = [[(None, ip)] for ip in image_paths] - - bucket_counts = {} - for data_entry in tqdm(data, smoothing=0.0): - if data_entry[0] is None: - continue - - img_tensor, image_path = data_entry[0] - if img_tensor is not None: - image = transforms.functional.to_pil_image(img_tensor) + # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります" + if args.bucket_reso_steps % 8 > 0: + print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") + + train_data_dir_path = Path(args.train_data_dir) + image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)] + print(f"found {len(image_paths)} images.") + + if os.path.exists(args.in_json): + print(f"loading existing metadata: {args.in_json}") + with open(args.in_json, "rt", encoding="utf-8") as f: + metadata = json.load(f) else: - try: - image = Image.open(image_path) - if image.mode != 'RGB': - image = image.convert("RGB") - except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") - continue - - image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] - if image_key not in metadata: - metadata[image_key] = {} - - # 本当はこのあとの部分もDataSetに持っていけば高速化できるがいろいろ大変 - - reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height) - img_ar_errors.append(abs(ar_error)) - bucket_counts[reso] = bucket_counts.get(reso, 0) + 1 + print(f"no metadata / メタデータファイルがありません: {args.in_json}") + return + + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + vae = model_util.load_vae(args.model_name_or_path, weight_dtype) + vae.eval() + vae.to(DEVICE, dtype=weight_dtype) + + # bucketのサイズを計算する + max_reso = tuple([int(t) for t in args.max_resolution.split(",")]) + assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" + + bucket_manager = train_util.BucketManager( + args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps + ) + if not args.bucket_no_upscale: + bucket_manager.make_buckets() + else: + print( + "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます" + ) + + # 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する + img_ar_errors = [] + + def process_batch(is_last): + for bucket in bucket_manager.buckets: + if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size: + latents = get_latents(vae, [img for _, img in bucket], weight_dtype) + assert ( + latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8 + ), f"latent shape {latents.shape}, {bucket[0][1].shape}" + + for (image_key, _), latent in zip(bucket, latents): + npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + np.savez(npz_file_name, latent) + + # flip + if args.flip_aug: + latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない + + for (image_key, _), latent in zip(bucket, latents): + npz_file_name = get_npz_filename_wo_ext( + args.train_data_dir, image_key, args.full_path, True, args.recursive + ) + np.savez(npz_file_name, latent) + else: + # remove existing flipped npz + for image_key, _ in bucket: + npz_file_name = ( + get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz" + ) + if os.path.isfile(npz_file_name): + print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}") + os.remove(npz_file_name) + + bucket.clear() + + # 読み込みの高速化のためにDataLoaderを使うオプション + if args.max_data_loader_n_workers is not None: + dataset = train_util.ImageLoadingDataset(image_paths) + data = torch.utils.data.DataLoader( + dataset, + batch_size=1, + shuffle=False, + num_workers=args.max_data_loader_n_workers, + collate_fn=collate_fn_remove_corrupted, + drop_last=False, + ) + else: + data = [[(None, ip)] for ip in image_paths] - # メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て - metadata[image_key]['train_resolution'] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8) + bucket_counts = {} + for data_entry in tqdm(data, smoothing=0.0): + if data_entry[0] is None: + continue - if not args.bucket_no_upscale: - # upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する - assert resized_size[0] == reso[0] or resized_size[1] == reso[ - 1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}" - assert resized_size[0] >= reso[0] and resized_size[1] >= reso[ - 1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}" - - assert resized_size[0] >= reso[0] and resized_size[1] >= reso[ - 1], f"internal error resized size is small: {resized_size}, {reso}" - - # 既に存在するファイルがあればshapeを確認して同じならskipする - if args.skip_existing: - npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"] - if args.flip_aug: - npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz") - - found = True - for npz_file in npz_files: - if not os.path.exists(npz_file): - found = False - break - - dat = np.load(npz_file)['arr_0'] - if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認 - found = False - break - if found: - continue - - # 画像をリサイズしてトリミングする - # PILにinter_areaがないのでcv2で…… - image = np.array(image) - if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要? - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) - - if resized_size[0] > reso[0]: - trim_size = resized_size[0] - reso[0] - image = image[:, trim_size//2:trim_size//2 + reso[0]] - - if resized_size[1] > reso[1]: - trim_size = resized_size[1] - reso[1] - image = image[trim_size//2:trim_size//2 + reso[1]] - - assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}" - - # # debug - # cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1]) - - # バッチへ追加 - bucket_manager.add_image(reso, (image_key, image)) - - # バッチを推論するか判定して推論する - process_batch(False) - - # 残りを処理する - process_batch(True) - - bucket_manager.sort() - for i, reso in enumerate(bucket_manager.resos): - count = bucket_counts.get(reso, 0) - if count > 0: - print(f"bucket {i} {reso}: {count}") - img_ar_errors = np.array(img_ar_errors) - print(f"mean ar error: {np.mean(img_ar_errors)}") - - # metadataを書き出して終わり - print(f"writing metadata: {args.out_json}") - with open(args.out_json, "wt", encoding='utf-8') as f: - json.dump(metadata, f, indent=2) - print("done!") + img_tensor, image_path = data_entry[0] + if img_tensor is not None: + image = transforms.functional.to_pil_image(img_tensor) + else: + try: + image = Image.open(image_path) + if image.mode != "RGB": + image = image.convert("RGB") + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + continue + + image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] + if image_key not in metadata: + metadata[image_key] = {} + + # 本当はこのあとの部分もDataSetに持っていけば高速化できるがいろいろ大変 + + reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height) + img_ar_errors.append(abs(ar_error)) + bucket_counts[reso] = bucket_counts.get(reso, 0) + 1 + + # メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て + metadata[image_key]["train_resolution"] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8) + + if not args.bucket_no_upscale: + # upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する + assert ( + resized_size[0] == reso[0] or resized_size[1] == reso[1] + ), f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}" + assert ( + resized_size[0] >= reso[0] and resized_size[1] >= reso[1] + ), f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}" + + assert ( + resized_size[0] >= reso[0] and resized_size[1] >= reso[1] + ), f"internal error resized size is small: {resized_size}, {reso}" + + # 既に存在するファイルがあればshapeを確認して同じならskipする + if args.skip_existing: + npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"] + if args.flip_aug: + npz_files.append( + get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz" + ) + + found = True + for npz_file in npz_files: + if not os.path.exists(npz_file): + found = False + break + + dat = np.load(npz_file)["arr_0"] + if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認 + found = False + break + if found: + continue + + # 画像をリサイズしてトリミングする + # PILにinter_areaがないのでcv2で…… + image = np.array(image) + if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要? + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) + + if resized_size[0] > reso[0]: + trim_size = resized_size[0] - reso[0] + image = image[:, trim_size // 2 : trim_size // 2 + reso[0]] + + if resized_size[1] > reso[1]: + trim_size = resized_size[1] - reso[1] + image = image[trim_size // 2 : trim_size // 2 + reso[1]] + + assert ( + image.shape[0] == reso[1] and image.shape[1] == reso[0] + ), f"internal error, illegal trimmed size: {image.shape}, {reso}" + + # # debug + # cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1]) + + # バッチへ追加 + bucket_manager.add_image(reso, (image_key, image)) + + # バッチを推論するか判定して推論する + process_batch(False) + + # 残りを処理する + process_batch(True) + + bucket_manager.sort() + for i, reso in enumerate(bucket_manager.resos): + count = bucket_counts.get(reso, 0) + if count > 0: + print(f"bucket {i} {reso}: {count}") + img_ar_errors = np.array(img_ar_errors) + print(f"mean ar error: {np.mean(img_ar_errors)}") + + # metadataを書き出して終わり + print(f"writing metadata: {args.out_json}") + with open(args.out_json, "wt", encoding="utf-8") as f: + json.dump(metadata, f, indent=2) + print("done!") def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") - parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") - parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル") - parser.add_argument("--v2", action='store_true', - help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)') - parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") - parser.add_argument("--max_data_loader_n_workers", type=int, default=None, - help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") - parser.add_argument("--max_resolution", type=str, default="512,512", - help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)") - parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") - parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度") - parser.add_argument("--bucket_reso_steps", type=int, default=64, - help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します") - parser.add_argument("--bucket_no_upscale", action="store_true", - help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します") - parser.add_argument("--mixed_precision", type=str, default="no", - choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") - parser.add_argument("--full_path", action="store_true", - help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") - parser.add_argument("--flip_aug", action="store_true", - help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する") - parser.add_argument("--skip_existing", action="store_true", - help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)") - - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args = parser.parse_args() - main(args) + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") + parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") + parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル") + parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)") + parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") + parser.add_argument( + "--max_data_loader_n_workers", + type=int, + default=None, + help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", + ) + parser.add_argument( + "--max_resolution", + type=str, + default="512,512", + help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)", + ) + parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") + parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度") + parser.add_argument( + "--bucket_reso_steps", + type=int, + default=64, + help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します", + ) + parser.add_argument( + "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" + ) + parser.add_argument( + "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度" + ) + parser.add_argument( + "--full_path", + action="store_true", + help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)", + ) + parser.add_argument( + "--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する" + ) + parser.add_argument( + "--skip_existing", + action="store_true", + help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", + ) + parser.add_argument( + "--recursive", + action="store_true", + help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + main(args) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 2286115ec..40bf428c2 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -10,6 +10,7 @@ from tensorflow.keras.models import load_model from huggingface_hub import hf_hub_download import torch +from pathlib import Path import library.train_util as train_util @@ -17,7 +18,7 @@ IMAGE_SIZE = 448 # wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2 -DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2' +DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"] SUB_DIR = "variables" SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"] @@ -25,182 +26,271 @@ def preprocess_image(image): - image = np.array(image) - image = image[:, :, ::-1] # RGB->BGR + image = np.array(image) + image = image[:, :, ::-1] # RGB->BGR - # pad to square - size = max(image.shape[0:2]) - pad_x = size - image.shape[1] - pad_y = size - image.shape[0] - pad_l = pad_x // 2 - pad_t = pad_y // 2 - image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255) + # pad to square + size = max(image.shape[0:2]) + pad_x = size - image.shape[1] + pad_y = size - image.shape[0] + pad_l = pad_x // 2 + pad_t = pad_y // 2 + image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255) - interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 - image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) + interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 + image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) - image = image.astype(np.float32) - return image + image = image.astype(np.float32) + return image class ImageLoadingPrepDataset(torch.utils.data.Dataset): - def __init__(self, image_paths): - self.images = image_paths + def __init__(self, image_paths): + self.images = image_paths - def __len__(self): - return len(self.images) + def __len__(self): + return len(self.images) - def __getitem__(self, idx): - img_path = self.images[idx] + def __getitem__(self, idx): + img_path = str(self.images[idx]) - try: - image = Image.open(img_path).convert("RGB") - image = preprocess_image(image) - tensor = torch.tensor(image) - except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") - return None + try: + image = Image.open(img_path).convert("RGB") + image = preprocess_image(image) + tensor = torch.tensor(image) + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + return None - return (tensor, img_path) + return (tensor, img_path) def collate_fn_remove_corrupted(batch): - """Collate function that allows to remove corrupted examples in the - dataloader. It expects that the dataloader returns 'None' when that occurs. - The 'None's in the batch are removed. - """ - # Filter out all the Nones (corrupted examples) - batch = list(filter(lambda x: x is not None, batch)) - return batch + """Collate function that allows to remove corrupted examples in the + dataloader. It expects that the dataloader returns 'None' when that occurs. + The 'None's in the batch are removed. + """ + # Filter out all the Nones (corrupted examples) + batch = list(filter(lambda x: x is not None, batch)) + return batch def main(args): - # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする - # depreacatedの警告が出るけどなくなったらその時 - # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 - if not os.path.exists(args.model_dir) or args.force_download: - print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") - for file in FILES: - hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) - for file in SUB_DIR_FILES: - hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join( - args.model_dir, SUB_DIR), force_download=True, force_filename=file) - else: - print("using existing wd14 tagger model") - - # 画像を読み込む - image_paths = train_util.glob_images(args.train_data_dir) - print(f"found {len(image_paths)} images.") - - print("loading model and labels") - model = load_model(args.model_dir) - - # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") - # 依存ライブラリを増やしたくないので自力で読むよ - with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f: - reader = csv.reader(f) - l = [row for row in reader] - header = l[0] # tag_id,name,category,count - rows = l[1:] - assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}" - - tags = [row[1] for row in rows[1:] if row[2] == '0'] # categoryが0、つまり通常のタグのみ - - # 推論する - def run_batch(path_imgs): - imgs = np.array([im for _, im in path_imgs]) - - probs = model(imgs, training=False) - probs = probs.numpy() - - for (image_path, _), prob in zip(path_imgs, probs): - # 最初の4つはratingなので無視する - # # First 4 labels are actually ratings: pick one with argmax - # ratings_names = label_names[:4] - # rating_index = ratings_names["probs"].argmax() - # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]] - - # それ以降はタグなのでconfidenceがthresholdより高いものを追加する - # Everything else is tags: pick any where prediction confidence > threshold - tag_text = "" - for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで - if p >= args.thresh and i < len(tags): - tag_text += ", " + tags[i] - - if len(tag_text) > 0: - tag_text = tag_text[2:] # 最初の ", " を消す - - with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: - f.write(tag_text + '\n') - if args.debug: - print(image_path, tag_text) - - # 読み込みの高速化のためにDataLoaderを使うオプション - if args.max_data_loader_n_workers is not None: - dataset = ImageLoadingPrepDataset(image_paths) - data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, - num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) - else: - data = [[(None, ip)] for ip in image_paths] - - b_imgs = [] - for data_entry in tqdm(data, smoothing=0.0): - for data in data_entry: - if data is None: - continue - - image, image_path = data - if image is not None: - image = image.detach().numpy() - else: - try: - image = Image.open(image_path) - if image.mode != 'RGB': - image = image.convert("RGB") - image = preprocess_image(image) - except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") - continue - b_imgs.append((image_path, image)) - - if len(b_imgs) >= args.batch_size: + # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする + # depreacatedの警告が出るけどなくなったらその時 + # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 + if not os.path.exists(args.model_dir) or args.force_download: + print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") + for file in FILES: + hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) + for file in SUB_DIR_FILES: + hf_hub_download( + args.repo_id, + file, + subfolder=SUB_DIR, + cache_dir=os.path.join(args.model_dir, SUB_DIR), + force_download=True, + force_filename=file, + ) + else: + print("using existing wd14 tagger model") + + # 画像を読み込む + model = load_model(args.model_dir) + + # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") + # 依存ライブラリを増やしたくないので自力で読むよ + + with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f: + reader = csv.reader(f) + l = [row for row in reader] + header = l[0] # tag_id,name,category,count + rows = l[1:] + assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}" + + general_tags = [row[1] for row in rows[1:] if row[2] == "0"] + character_tags = [row[1] for row in rows[1:] if row[2] == "4"] + + # 画像を読み込む + + train_data_dir_path = Path(args.train_data_dir) + image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) + print(f"found {len(image_paths)} images.") + + tag_freq = {} + + undesired_tags = set(args.undesired_tags.split(",")) + + def run_batch(path_imgs): + imgs = np.array([im for _, im in path_imgs]) + + probs = model(imgs, training=False) + probs = probs.numpy() + + for (image_path, _), prob in zip(path_imgs, probs): + # 最初の4つはratingなので無視する + # # First 4 labels are actually ratings: pick one with argmax + # ratings_names = label_names[:4] + # rating_index = ratings_names["probs"].argmax() + # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]] + + # それ以降はタグなのでconfidenceがthresholdより高いものを追加する + # Everything else is tags: pick any where prediction confidence > threshold + combined_tags = [] + general_tag_text = "" + character_tag_text = "" + for i, p in enumerate(prob[4:]): + if i < len(general_tags) and p >= args.general_threshold: + tag_name = general_tags[i].replace("_", " ") if args.remove_underscore else general_tags[i] + if tag_name not in undesired_tags: + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + general_tag_text += ", " + tag_name + combined_tags.append(tag_name) + elif i >= len(general_tags) and p >= args.character_threshold: + tag_name = ( + character_tags[i - len(general_tags)].replace("_", " ") + if args.remove_underscore + else character_tags[i - len(general_tags)] + ) + if tag_name not in undesired_tags: + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + character_tag_text += ", " + tag_name + combined_tags.append(tag_name) + + # 先頭のカンマを取る + if len(general_tag_text) > 0: + general_tag_text = general_tag_text[2:] + if len(character_tag_text) > 0: + character_tag_text = character_tag_text[2:] + + tag_text = ", ".join(combined_tags) + + with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: + f.write(tag_text + "\n") + if args.debug: + print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}") + + # 読み込みの高速化のためにDataLoaderを使うオプション + if args.max_data_loader_n_workers is not None: + dataset = ImageLoadingPrepDataset(image_paths) + data = torch.utils.data.DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.max_data_loader_n_workers, + collate_fn=collate_fn_remove_corrupted, + drop_last=False, + ) + else: + data = [[(None, ip)] for ip in image_paths] + + b_imgs = [] + for data_entry in tqdm(data, smoothing=0.0): + for data in data_entry: + if data is None: + continue + + image, image_path = data + if image is not None: + image = image.detach().numpy() + else: + try: + image = Image.open(image_path) + if image.mode != "RGB": + image = image.convert("RGB") + image = preprocess_image(image) + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + continue + b_imgs.append((image_path, image)) + + if len(b_imgs) >= args.batch_size: + b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string + run_batch(b_imgs) + b_imgs.clear() + + if len(b_imgs) > 0: + b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string run_batch(b_imgs) - b_imgs.clear() - - if len(b_imgs) > 0: - run_batch(b_imgs) - - print("done!") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO, - help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID") - parser.add_argument("--model_dir", type=str, default="wd14_tagger_model", - help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ") - parser.add_argument("--force_download", action='store_true', - help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします") - parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値") - parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") - parser.add_argument("--max_data_loader_n_workers", type=int, default=None, - help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") - parser.add_argument("--caption_extention", type=str, default=None, - help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") - parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") - parser.add_argument("--debug", action="store_true", help="debug mode") - - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args = parser.parse_args() - - # スペルミスしていたオプションを復元する - if args.caption_extention is not None: - args.caption_extension = args.caption_extention - main(args) + if args.frequency_tags: + sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True) + print("\nTag frequencies:") + for tag, freq in sorted_tags: + print(f"{tag}: {freq}") + + print("done!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument( + "--repo_id", + type=str, + default=DEFAULT_WD14_TAGGER_REPO, + help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID", + ) + parser.add_argument( + "--model_dir", + type=str, + default="wd14_tagger_model", + help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ", + ) + parser.add_argument( + "--force_download", action="store_true", help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします" + ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") + parser.add_argument( + "--max_data_loader_n_workers", + type=int, + default=None, + help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", + ) + parser.add_argument( + "--caption_extention", + type=str, + default=None, + help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)", + ) + parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") + parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値") + parser.add_argument( + "--general_threshold", + type=float, + default=None, + help="threshold of confidence to add a tag for general category, same as --thresh if omitted / generalカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ", + ) + parser.add_argument( + "--character_threshold", + type=float, + default=None, + help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ", + ) + parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") + parser.add_argument( + "--remove_underscore", + action="store_true", + help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える", + ) + parser.add_argument("--debug", action="store_true", help="debug mode") + parser.add_argument( + "--undesired_tags", + type=str, + default="", + help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト", + ) + parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する") + + args = parser.parse_args() + + # スペルミスしていたオプションを復元する + if args.caption_extention is not None: + args.caption_extension = args.caption_extention + + if args.general_threshold is None: + args.general_threshold = args.thresh + if args.character_threshold is None: + args.character_threshold = args.thresh + + main(args) diff --git a/finetune_gui.py b/finetune_gui.py index a3de5010b..1d7a2334a 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -29,6 +29,7 @@ ) from library.utilities import utilities_tab from library.sampler_gui import sample_gradio_config, run_cmd_sample +from easygui import msgbox folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 @@ -322,6 +323,10 @@ def train_model( ): if check_if_model_exist(output_name, output_dir, save_model_as): return + + if optimizer == 'Adafactor' and lr_warmup != '0': + msgbox("Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", title="Warning") + lr_warmup = '0' # create caption json file if generate_caption_database: diff --git a/gui.bat b/gui.bat index d1f28641e..d03bac165 100644 --- a/gui.bat +++ b/gui.bat @@ -2,6 +2,7 @@ :: Activate the virtual environment call .\venv\Scripts\activate.bat +set PATH=%PATH%;%~dp0venv\Lib\site-packages\torch\lib :: Validate the requirements and store the exit code python.exe .\tools\validate_requirements.py diff --git a/gui.ps1 b/gui.ps1 index 349ce97e6..3ee16fb6d 100644 --- a/gui.ps1 +++ b/gui.ps1 @@ -1,5 +1,6 @@ # Activate the virtual environment & .\venv\Scripts\activate +$env:PATH += ";$($MyInvocation.MyCommand.Path)\venv\Lib\site-packages\torch\lib" # Validate the requirements and store the exit code python.exe .\tools\validate_requirements.py diff --git a/kohya_gui.py b/kohya_gui.py index f8e0d8ca3..17941d1d1 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -20,7 +20,7 @@ def UI(**kwargs): print('Load CSS...') css += file.read() + '\n' - interface = gr.Blocks(css=css, title='Kohya_ss GUI') + interface = gr.Blocks(css=css, title='Kohya_ss GUI', theme=gr.themes.Default()) with interface: with gr.Tab('Dreambooth'): diff --git a/library/train_util.py b/library/train_util.py index 56eef81f8..b249e61d4 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -722,7 +722,7 @@ def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_s def is_latent_cacheable(self): return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) - def cache_latents(self, vae, vae_batch_size=1): + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): # ちょっと速くした print("caching latents.") @@ -740,11 +740,38 @@ def cache_latents(self, vae, vae_batch_size=1): if info.latents_npz is not None: info.latents = self.load_latents_from_npz(info, False) info.latents = torch.FloatTensor(info.latents) - info.latents_flipped = self.load_latents_from_npz(info, True) # might be None + + # might be None, but that's ok because check is done in dataset + info.latents_flipped = self.load_latents_from_npz(info, True) if info.latents_flipped is not None: info.latents_flipped = torch.FloatTensor(info.latents_flipped) continue + # check disk cache exists and size of latents + if cache_to_disk: + # TODO: refactor to unify with FineTuningDataset + info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz" + info.latents_npz_flipped = os.path.splitext(info.absolute_path)[0] + "_flip.npz" + if not is_main_process: + continue + + cache_available = False + expected_latents_size = (info.bucket_reso[1] // 8, info.bucket_reso[0] // 8) # bucket_resoはWxHなので注意 + if os.path.exists(info.latents_npz): + cached_latents = np.load(info.latents_npz)["arr_0"] + if cached_latents.shape[1:3] == expected_latents_size: + cache_available = True + + if subset.flip_aug: + cache_available = False + if os.path.exists(info.latents_npz_flipped): + cached_latents_flipped = np.load(info.latents_npz_flipped)["arr_0"] + if cached_latents_flipped.shape[1:3] == expected_latents_size: + cache_available = True + + if cache_available: + continue + # if last member of batch has different resolution, flush the batch if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso: batches.append(batch) @@ -760,6 +787,9 @@ def cache_latents(self, vae, vae_batch_size=1): if len(batch) > 0: batches.append(batch) + if cache_to_disk and not is_main_process: # don't cache latents in non-main process, set to info only + return + # iterate batches for batch in tqdm(batches, smoothing=1, total=len(batches)): images = [] @@ -773,14 +803,21 @@ def cache_latents(self, vae, vae_batch_size=1): img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + for info, latent in zip(batch, latents): - info.latents = latent + if cache_to_disk: + np.savez(info.latents_npz, latent.float().numpy()) + else: + info.latents = latent if subset.flip_aug: img_tensors = torch.flip(img_tensors, dims=[3]) latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") for info, latent in zip(batch, latents): - info.latents_flipped = latent + if cache_to_disk: + np.savez(info.latents_npz_flipped, latent.float().numpy()) + else: + info.latents_flipped = latent def get_image_size(self, image_path): image = Image.open(image_path) @@ -873,10 +910,10 @@ def __getitem__(self, index): loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) # image/latentsを処理する - if image_info.latents is not None: + if image_info.latents is not None: # cache_latents=Trueの場合 latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped image = None - elif image_info.latents_npz is not None: + elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5) latents = torch.FloatTensor(latents) image = None @@ -1163,19 +1200,27 @@ def __init__( tags_list = [] for image_key, img_md in metadata.items(): # path情報を作る + abs_path = None + + # まず画像を優先して探す if os.path.exists(image_key): abs_path = image_key - elif os.path.exists(os.path.splitext(image_key)[0] + ".npz"): - abs_path = os.path.splitext(image_key)[0] + ".npz" else: - npz_path = os.path.join(subset.image_dir, image_key + ".npz") - if os.path.exists(npz_path): - abs_path = npz_path + # わりといい加減だがいい方法が思いつかん + paths = glob_images(subset.image_dir, image_key) + if len(paths) > 0: + abs_path = paths[0] + + # なければnpzを探す + if abs_path is None: + if os.path.exists(os.path.splitext(image_key)[0] + ".npz"): + abs_path = os.path.splitext(image_key)[0] + ".npz" else: - # わりといい加減だがいい方法が思いつかん - abs_path = glob_images(subset.image_dir, image_key) - assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}" - abs_path = abs_path[0] + npz_path = os.path.join(subset.image_dir, image_key + ".npz") + if os.path.exists(npz_path): + abs_path = npz_path + + assert abs_path is not None, f"no image / 画像がありません: {image_key}" caption = img_md.get("caption") tags = img_md.get("tags") @@ -1340,10 +1385,10 @@ def enable_XTI(self, *args, **kwargs): for dataset in self.datasets: dataset.enable_XTI(*args, **kwargs) - def cache_latents(self, vae, vae_batch_size=1): + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): for i, dataset in enumerate(self.datasets): print(f"[Dataset {i}]") - dataset.cache_latents(vae, vae_batch_size) + dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) def is_latent_cacheable(self) -> bool: return all([dataset.is_latent_cacheable() for dataset in self.datasets]) @@ -2144,9 +2189,14 @@ def add_dataset_arguments( parser.add_argument( "--cache_latents", action="store_true", - help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)", + help="cache latents to main memory to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをメインメモリにcacheする(augmentationは使用不可) ", ) parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ") + parser.add_argument( + "--cache_latents_to_disk", + action="store_true", + help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)", + ) parser.add_argument( "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする" ) @@ -3203,4 +3253,4 @@ def __call__(self, examples): # set epoch and step dataset.set_current_epoch(self.current_epoch.value) dataset.set_current_step(self.current_step.value) - return examples[0] \ No newline at end of file + return examples[0] diff --git a/lora_gui.py b/lora_gui.py index 9ec368893..fa5df0836 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -127,7 +127,7 @@ def save_configuration( vae_batch_size, min_snr_gamma, down_lr_weight,mid_lr_weight,up_lr_weight,block_lr_zero_threshold,block_dims,block_alphas,conv_dims,conv_alphas, - weighted_captions, + weighted_captions,unit, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -248,7 +248,7 @@ def open_configuration( vae_batch_size, min_snr_gamma, down_lr_weight,mid_lr_weight,up_lr_weight,block_lr_zero_threshold,block_dims,block_alphas,conv_dims,conv_alphas, - weighted_captions, + weighted_captions,unit, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -360,7 +360,7 @@ def train_model( vae_batch_size, min_snr_gamma, down_lr_weight,mid_lr_weight,up_lr_weight,block_lr_zero_threshold,block_dims,block_alphas,conv_dims,conv_alphas, - weighted_captions, + weighted_captions,unit, ): print_only_bool = True if print_only.get('label') == 'True' else False @@ -400,6 +400,10 @@ def train_model( if check_if_model_exist(output_name, output_dir, save_model_as): return + + if optimizer == 'Adafactor' and lr_warmup != '0': + msgbox("Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", title="Warning") + lr_warmup = '0' # If string is empty set string to 0. if text_encoder_lr == '': @@ -544,7 +548,21 @@ def train_model( if network_args: run_cmd += f' --network_args{network_args}' + + if LoRA_type in ['Kohya DyLoRA']: + kohya_lora_var_list = ['conv_dim', 'conv_alpha', 'down_lr_weight', 'mid_lr_weight', 'up_lr_weight', 'block_lr_zero_threshold', 'block_dims', 'block_alphas', 'conv_dims', 'conv_alphas', 'unit'] + run_cmd += f' --network_module=networks.dylora' + kohya_lora_vars = {key: value for key, value in vars().items() if key in kohya_lora_var_list and value} + + network_args = '' + + for key, value in kohya_lora_vars.items(): + if value: + network_args += f' {key}="{value}"' + + if network_args: + run_cmd += f' --network_args{network_args}' if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0): if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0): @@ -783,6 +801,7 @@ def lora_tab( LoRA_type = gr.Dropdown( label='LoRA type', choices=[ + 'Kohya DyLoRA', 'Kohya LoCon', # 'LoCon', 'LyCORIS/LoCon', @@ -851,8 +870,8 @@ def lora_tab( value=1, step=0.1, interactive=True, + info='alpha for LoRA weight scaling', ) - with gr.Row(visible=False) as LoCon_row: # locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA (does not support v2 base models) (may not be able to some utilities now)', value=False) @@ -870,35 +889,44 @@ def lora_tab( step=0.1, label='Convolution Alpha', ) + with gr.Row(visible=False) as kohya_dylora: + unit = gr.Slider( + minimum=1, + maximum=64, + label='DyLoRA Unit', + value=1, + step=1, + interactive=True, + ) # Show of hide LoCon conv settings depending on LoRA type selection def update_LoRA_settings(LoRA_type): + # Print a message when LoRA type is changed print('LoRA type changed...') - - LoRA_type_change = False - LoCon_row = False - - if ( - LoRA_type == 'LoCon' - or LoRA_type == 'Kohya LoCon' - or LoRA_type == 'LyCORIS/LoHa' - or LoRA_type == 'LyCORIS/LoCon' - ): - LoCon_row = True - - if ( - LoRA_type == 'Standard' - or LoRA_type == 'Kohya LoCon' - ): - LoRA_type_change = True - - return gr.Group.update(visible=LoCon_row), gr.Group.update(visible=LoRA_type_change) + + # Determine if LoCon_row should be visible based on LoRA_type + LoCon_row = LoRA_type in {'LoCon', 'Kohya DyLoRA', 'Kohya LoCon', 'LyCORIS/LoHa', 'LyCORIS/LoCon'} + + # Determine if LoRA_type_change should be visible based on LoRA_type + LoRA_type_change = LoRA_type in {'Standard', 'Kohya DyLoRA', 'Kohya LoCon'} + + # Determine if kohya_dylora_visible should be visible based on LoRA_type + kohya_dylora_visible = LoRA_type == 'Kohya DyLoRA' + + # Return the updated visibility settings for the groups + return ( + gr.Group.update(visible=LoCon_row), + gr.Group.update(visible=LoRA_type_change), + gr.Group.update(visible=kohya_dylora_visible), + ) + with gr.Row(): max_resolution = gr.Textbox( label='Max resolution', value='512,512', placeholder='512,512', + info='The maximum resolution of dataset images. W,H', ) stop_text_encoder_training = gr.Slider( minimum=0, @@ -906,8 +934,10 @@ def update_LoRA_settings(LoRA_type): value=0, step=1, label='Stop text encoder training', + info='After what % of steps should the text encoder stop being trained. 0 = train for all steps.', ) - enable_bucket = gr.Checkbox(label='Enable buckets', value=True) + enable_bucket = gr.Checkbox(label='Enable buckets', value=True, + info='Allow non similar resolution dataset images to be trained on.',) with gr.Accordion('Advanced Configuration', open=False): with gr.Row(visible=True) as kohya_advanced_lora: @@ -915,39 +945,47 @@ def update_LoRA_settings(LoRA_type): with gr.Row(visible=True): down_lr_weight = gr.Textbox( label='Down LR weights', - placeholder='(Optional) eg: 0,0,0,0,0,0,1,1,1,1,1,1' + placeholder='(Optional) eg: 0,0,0,0,0,0,1,1,1,1,1,1', + info='Specify the learning rate weight of the down blocks of U-Net.' ) mid_lr_weight = gr.Textbox( label='Mid LR weights', - placeholder='(Optional) eg: 0.5' + placeholder='(Optional) eg: 0.5', + info='Specify the learning rate weight of the mid block of U-Net.' ) up_lr_weight = gr.Textbox( label='Up LR weights', - placeholder='(Optional) eg: 0,0,0,0,0,0,1,1,1,1,1,1' + placeholder='(Optional) eg: 0,0,0,0,0,0,1,1,1,1,1,1', + info='Specify the learning rate weight of the up blocks of U-Net. The same as down_lr_weight.' ) block_lr_zero_threshold = gr.Textbox( label='Blocks LR zero threshold', - placeholder='(Optional) eg: 0.1' + placeholder='(Optional) eg: 0.1', + info='If the weight is not more than this value, the LoRA module is not created. The default is 0.' ) with gr.Tab(label='Blocks'): with gr.Row(visible=True): block_dims = gr.Textbox( label='Block dims', - placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2' + placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2', + info='Specify the dim (rank) of each block. Specify 25 numbers.' ) block_alphas = gr.Textbox( label='Block alphas', - placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2' + placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2', + info='Specify the alpha of each block. Specify 25 numbers as with block_dims. If omitted, the value of network_alpha is used.' ) with gr.Tab(label='Conv'): with gr.Row(visible=True): conv_dims = gr.Textbox( label='Conv dims', - placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2' + placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2', + info='Expand LoRA to Conv2d 3x3 and specify the dim (rank) of each block. Specify 25 numbers.' ) conv_alphas = gr.Textbox( label='Conv alphas', - placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2' + placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2', + info='Specify the alpha of each block when expanding LoRA to Conv2d 3x3. Specify 25 numbers. If omitted, the value of conv_alpha is used.' ) with gr.Row(): no_token_padding = gr.Checkbox( @@ -957,7 +995,7 @@ def update_LoRA_settings(LoRA_type): label='Gradient accumulate steps', value='1' ) weighted_captions = gr.Checkbox( - label='Weighted captions', value=False + label='Weighted captions', value=False, info='Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder.', ) with gr.Row(): prior_loss_weight = gr.Number( @@ -1013,7 +1051,7 @@ def update_LoRA_settings(LoRA_type): ) = sample_gradio_config() LoRA_type.change( - update_LoRA_settings, inputs=[LoRA_type], outputs=[LoCon_row, kohya_advanced_lora] + update_LoRA_settings, inputs=[LoRA_type], outputs=[LoCon_row, kohya_advanced_lora, kohya_dylora] ) with gr.Tab('Tools'): @@ -1122,7 +1160,7 @@ def update_LoRA_settings(LoRA_type): vae_batch_size, min_snr_gamma, down_lr_weight,mid_lr_weight,up_lr_weight,block_lr_zero_threshold,block_dims,block_alphas,conv_dims,conv_alphas, - weighted_captions, + weighted_captions, unit, ] button_open_config.click( diff --git a/networks/dylora.py b/networks/dylora.py new file mode 100644 index 000000000..90b509dfc --- /dev/null +++ b/networks/dylora.py @@ -0,0 +1,450 @@ +# some codes are copied from: +# https://github.com/huawei-noah/KD-NLP/blob/main/DyLoRA/ + +# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved. +# Changes made to the original code: +# 2022.08.20 - Integrate the DyLoRA layer for the LoRA Linear layer +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + +import math +import os +import random +from typing import List, Tuple, Union +import torch +from torch import nn + + +class DyLoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + # NOTE: support dropout in future + def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1): + super().__init__() + self.lora_name = lora_name + self.lora_dim = lora_dim + self.unit = unit + assert self.lora_dim % self.unit == 0, "rank must be a multiple of unit" + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + self.is_conv2d = org_module.__class__.__name__ == "Conv2d" + self.is_conv2d_3x3 = self.is_conv2d and org_module.kernel_size == (3, 3) + + if self.is_conv2d and self.is_conv2d_3x3: + kernel_size = org_module.kernel_size + self.stride = org_module.stride + self.padding = org_module.padding + self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim, *kernel_size)) for _ in range(self.lora_dim)]) + self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1, 1, 1)) for _ in range(self.lora_dim)]) + else: + self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim)) for _ in range(self.lora_dim)]) + self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1)) for _ in range(self.lora_dim)]) + + # same as microsoft's + for lora in self.lora_A: + torch.nn.init.kaiming_uniform_(lora, a=math.sqrt(5)) + for lora in self.lora_B: + torch.nn.init.zeros_(lora) + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x): + result = self.org_forward(x) + + # specify the dynamic rank + trainable_rank = random.randint(0, self.lora_dim - 1) + trainable_rank = trainable_rank - trainable_rank % self.unit # make sure the rank is a multiple of unit + + # 一部のパラメータを固定して、残りのパラメータを学習する + for i in range(0, trainable_rank): + self.lora_A[i].requires_grad = False + self.lora_B[i].requires_grad = False + for i in range(trainable_rank, trainable_rank + self.unit): + self.lora_A[i].requires_grad = True + self.lora_B[i].requires_grad = True + for i in range(trainable_rank + self.unit, self.lora_dim): + self.lora_A[i].requires_grad = False + self.lora_B[i].requires_grad = False + + lora_A = torch.cat(tuple(self.lora_A), dim=0) + lora_B = torch.cat(tuple(self.lora_B), dim=1) + + # calculate with lora_A and lora_B + if self.is_conv2d_3x3: + ab = torch.nn.functional.conv2d(x, lora_A, stride=self.stride, padding=self.padding) + ab = torch.nn.functional.conv2d(ab, lora_B) + else: + ab = x + if self.is_conv2d: + ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C) + + ab = torch.nn.functional.linear(ab, lora_A) + ab = torch.nn.functional.linear(ab, lora_B) + + if self.is_conv2d: + ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:]) # (N, H*W, C) -> (N, C, H, W) + + # 最後の項は、低rankをより大きくするためのスケーリング(じゃないかな) + result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit)) + + # NOTE weightに加算してからlinear/conv2dを呼んだほうが速いかも + return result + + def state_dict(self, destination=None, prefix="", keep_vars=False): + # state dictを通常のLoRAと同じにする: + # nn.ParameterListは `.lora_A.0` みたいな名前になるので、forwardと同様にcatして入れ替える + sd = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + + lora_A_weight = torch.cat(tuple(self.lora_A), dim=0) + if self.is_conv2d and not self.is_conv2d_3x3: + lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1) + + lora_B_weight = torch.cat(tuple(self.lora_B), dim=1) + if self.is_conv2d and not self.is_conv2d_3x3: + lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1) + + sd[self.lora_name + ".lora_down.weight"] = lora_A_weight if keep_vars else lora_A_weight.detach() + sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach() + + i = 0 + while True: + key_a = f"{self.lora_name}.lora_A.{i}" + key_b = f"{self.lora_name}.lora_B.{i}" + if key_a in sd: + sd.pop(key_a) + sd.pop(key_b) + else: + break + i += 1 + return sd + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + # 通常のLoRAと同じstate dictを読み込めるようにする:この方法はchatGPTに聞いた + lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None) + lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None) + + if lora_A_weight is None or lora_B_weight is None: + if strict: + raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found") + else: + return + + if self.is_conv2d and not self.is_conv2d_3x3: + lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1) + lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1) + + state_dict.update( + {f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i].unsqueeze(0)) for i in range(lora_A_weight.size(0))} + ) + state_dict.update( + {f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i].unsqueeze(1)) for i in range(lora_B_weight.size(1))} + ) + + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + +def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + unit = kwargs.get("unit", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + assert conv_dim == network_dim, "conv_dim must be same as network_dim" + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + if unit is not None: + unit = int(unit) + else: + unit = 1 + + network = DyLoRANetwork( + text_encoder, + unet, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + apply_to_conv=conv_dim is not None, + unit=unit, + varbose=True, + ) + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping + modules_dim = {} + modules_alpha = {} + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # print(lora_name, value.size(), dim) + + # support old LoRA without alpha + for key in modules_dim.keys(): + if key not in modules_alpha: + modules_alpha = modules_dim[key] + + module_class = DyLoRAModule + + network = DyLoRANetwork( + text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class + ) + return network, weights_sd + + +class DyLoRANetwork(torch.nn.Module): + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + LORA_PREFIX_UNET = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + + def __init__( + self, + text_encoder, + unet, + multiplier=1.0, + lora_dim=4, + alpha=1, + apply_to_conv=False, + modules_dim=None, + modules_alpha=None, + unit=1, + module_class=DyLoRAModule, + varbose=False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.apply_to_conv = apply_to_conv + + if modules_dim is not None: + print(f"create LoRA network from weights") + else: + print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}") + if self.apply_to_conv: + print(f"apply LoRA to Conv2d with kernel size (3,3).") + + # create module instances + def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]: + prefix = DyLoRANetwork.LORA_PREFIX_UNET if is_unet else DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER + loras = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + + dim = None + alpha = None + if modules_dim is not None: + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + if is_linear or is_conv2d_1x1 or apply_to_conv: + dim = self.lora_dim + alpha = self.alpha + + if dim is None or dim == 0: + continue + + # dropout and fan_in_fan_out is default + lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit) + loras.append(lora) + return loras + + self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE + if modules_dim is not None or self.apply_to_conv: + target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + + self.unet_loras = create_modules(True, unet, target_modules) + print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + print("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + print("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + """ + def merge_to(self, text_encoder, unet, weights_sd, dtype, device): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if key.startswith(DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER): + apply_text_encoder = True + elif key.startswith(DyLoRANetwork.LORA_PREFIX_UNET): + apply_unet = True + + if apply_text_encoder: + print("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + print("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + print(f"weights are merged") + """ + + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + self.requires_grad_(True) + all_params = [] + + def enumerate_params(loras): + params = [] + for lora in loras: + params.extend(lora.parameters()) + return params + + if self.text_encoder_loras: + param_data = {"params": enumerate_params(self.text_encoder_loras)} + if text_encoder_lr is not None: + param_data["lr"] = text_encoder_lr + all_params.append(param_data) + + if self.unet_loras: + param_data = {"params": enumerate_params(self.unet_loras)} + if unet_lr is not None: + param_data["lr"] = unet_lr + all_params.append(param_data) + + return all_params + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + # mask is a tensor with values from 0 to 1 + def set_region(self, sub_prompt_index, is_last_network, mask): + pass + + def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared): + pass diff --git a/networks/extract_lora_from_dylora.py b/networks/extract_lora_from_dylora.py new file mode 100644 index 000000000..0abee9836 --- /dev/null +++ b/networks/extract_lora_from_dylora.py @@ -0,0 +1,125 @@ +# Convert LoRA to different rank approximation (should only be used to go to lower rank) +# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py +# Thanks to cloneofsimo + +import argparse +import math +import os +import torch +from safetensors.torch import load_file, save_file, safe_open +from tqdm import tqdm +from library import train_util, model_util +import numpy as np + + +def load_state_dict(file_name): + if model_util.is_safetensors(file_name): + sd = load_file(file_name) + with safe_open(file_name, framework="pt") as f: + metadata = f.metadata() + else: + sd = torch.load(file_name, map_location="cpu") + metadata = None + + return sd, metadata + + +def save_to_file(file_name, model, metadata): + if model_util.is_safetensors(file_name): + save_file(model, file_name, metadata) + else: + torch.save(model, file_name) + + +def split_lora_model(lora_sd, unit): + max_rank = 0 + + # Extract loaded lora dim and alpha + for key, value in lora_sd.items(): + if "lora_down" in key: + rank = value.size()[0] + if rank > max_rank: + max_rank = rank + print(f"Max rank: {max_rank}") + + rank = unit + split_models = [] + new_alpha = None + while rank < max_rank: + print(f"Splitting rank {rank}") + new_sd = {} + for key, value in lora_sd.items(): + if "lora_down" in key: + new_sd[key] = value[:rank].contiguous() + elif "lora_up" in key: + new_sd[key] = value[:, :rank].contiguous() + else: + # なぜかscaleするとおかしくなる…… + # this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0] + # scale = math.sqrt(this_rank / rank) # rank is > unit + # print(key, value.size(), this_rank, rank, value, scale) + # new_alpha = value * scale # always same + # new_sd[key] = new_alpha + new_sd[key] = value + + split_models.append((new_sd, rank, new_alpha)) + rank += unit + + return max_rank, split_models + + +def split(args): + print("loading Model...") + lora_sd, metadata = load_state_dict(args.model) + + print("Splitting Model...") + original_rank, split_models = split_lora_model(lora_sd, args.unit) + + comment = metadata.get("ss_training_comment", "") + for state_dict, new_rank, new_alpha in split_models: + # update metadata + if metadata is None: + new_metadata = {} + else: + new_metadata = metadata.copy() + + new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}" + new_metadata["ss_network_dim"] = str(new_rank) + # new_metadata["ss_network_alpha"] = str(new_alpha.float().numpy()) + + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + filename, ext = os.path.splitext(args.save_to) + model_file_name = filename + f"-{new_rank:04d}{ext}" + + print(f"saving model to: {model_file_name}") + save_to_file(model_file_name, state_dict, new_metadata) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + parser.add_argument("--unit", type=int, default=None, help="size of rank to split into / rankを分割するサイズ") + parser.add_argument( + "--save_to", + type=str, + default=None, + help="destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors", + ) + parser.add_argument( + "--model", + type=str, + default=None, + help="DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + split(args) diff --git a/requirements.txt b/requirements.txt index 5c5cf63bd..89f6272af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,8 @@ diffusers[torch]==0.10.2 easygui==0.98.3 einops==0.6.0 ftfy==6.1.1 -gradio==3.19.1; sys_platform != 'darwin' +gradio==3.27.0; sys_platform != 'darwin' +# gradio==3.19.1; sys_platform != 'darwin' gradio==3.23.0; sys_platform == 'darwin' lion-pytorch==0.0.6 opencv-python==4.7.0.68 diff --git a/setup.bat b/setup.bat index cbc8b8599..3618393f2 100644 --- a/setup.bat +++ b/setup.bat @@ -15,9 +15,31 @@ IF NOT EXIST venv ( ) call .\venv\Scripts\activate.bat -pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 -pip install --use-pep517 --upgrade -r requirements.txt -pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl +echo Do you want to uninstall previous versions of torch and associated files before installing? +echo [1] - Yes +echo [2] - No +set /p uninstall_choice="Enter your choice (1 or 2): " + +if %uninstall_choice%==1 ( + pip uninstall -y xformers + pip uninstall -y torch torchvision +) + +echo Please choose the version of torch you want to install: +echo [1] - v1 (torch 1.12.1) +echo [2] - v2 (torch 2.0.0) +set /p choice="Enter your choice (1 or 2): " + +if %choice%==1 ( + pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 + pip install --use-pep517 --upgrade -r requirements.txt + pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl +) else ( + pip install torch==2.0.0+cu118 torchvision==0.15.1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 + pip install --use-pep517 --upgrade -r requirements.txt + pip install --upgrade xformers==0.0.17 + rem pip install -U -I --no-deps https://files.pythonhosted.org/packages/d6/f7/02662286419a2652c899e2b3d1913c47723fc164b4ac06a85f769c291013/xformers-0.0.17rc482-cp310-cp310-win_amd64.whl +) copy /y .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ copy /y .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index da5467d20..87bbe7a19 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -362,6 +362,10 @@ def train_model( if check_if_model_exist(output_name, output_dir, save_model_as): return + + if optimizer == 'Adafactor' and lr_warmup != '0': + msgbox("Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", title="Warning") + lr_warmup = '0' # Get a list of all subfolders in train_data_dir subfolders = [ diff --git a/tools/convert_diffusers20_original_sd.py b/tools/convert_diffusers20_original_sd.py index 7c7cc1c58..15a9ca4ab 100644 --- a/tools/convert_diffusers20_original_sd.py +++ b/tools/convert_diffusers20_original_sd.py @@ -9,86 +9,122 @@ def convert(args): - # 引数を確認する - load_dtype = torch.float16 if args.fp16 else None - - save_dtype = None - if args.fp16: - save_dtype = torch.float16 - elif args.bf16: - save_dtype = torch.bfloat16 - elif args.float: - save_dtype = torch.float - - is_load_ckpt = os.path.isfile(args.model_to_load) - is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0 - - assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です" - assert is_save_ckpt or args.reference_model is not None, f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です" - - # モデルを読み込む - msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else "")) - print(f"loading {msg}: {args.model_to_load}") - - if is_load_ckpt: - v2_model = args.v2 - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load) - else: - pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None) - text_encoder = pipe.text_encoder - vae = pipe.vae - unet = pipe.unet - - if args.v1 == args.v2: - # 自動判定する - v2_model = unet.config.cross_attention_dim == 1024 - print("checking model version: model is " + ('v2' if v2_model else 'v1')) + # 引数を確認する + load_dtype = torch.float16 if args.fp16 else None + + save_dtype = None + if args.fp16 or args.save_precision_as == "fp16": + save_dtype = torch.float16 + elif args.bf16 or args.save_precision_as == "bf16": + save_dtype = torch.bfloat16 + elif args.float or args.save_precision_as == "float": + save_dtype = torch.float + + is_load_ckpt = os.path.isfile(args.model_to_load) + is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0 + + assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です" + assert ( + is_save_ckpt or args.reference_model is not None + ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です" + + # モデルを読み込む + msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else "")) + print(f"loading {msg}: {args.model_to_load}") + + if is_load_ckpt: + v2_model = args.v2 + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load) else: - v2_model = not args.v1 - - # 変換して保存する - msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers" - print(f"converting and saving as {msg}: {args.model_to_save}") - - if is_save_ckpt: - original_model = args.model_to_load if is_load_ckpt else None - key_count = model_util.save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet, - original_model, args.epoch, args.global_step, save_dtype, vae) - print(f"model saved. total converted state_dict keys: {key_count}") - else: - print(f"copy scheduler/tokenizer config from: {args.reference_model}") - model_util.save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors) - print(f"model saved.") + pipe = StableDiffusionPipeline.from_pretrained( + args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None + ) + text_encoder = pipe.text_encoder + vae = pipe.vae + unet = pipe.unet + + if args.v1 == args.v2: + # 自動判定する + v2_model = unet.config.cross_attention_dim == 1024 + print("checking model version: model is " + ("v2" if v2_model else "v1")) + else: + v2_model = not args.v1 + + # 変換して保存する + msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers" + print(f"converting and saving as {msg}: {args.model_to_save}") + + if is_save_ckpt: + original_model = args.model_to_load if is_load_ckpt else None + key_count = model_util.save_stable_diffusion_checkpoint( + v2_model, args.model_to_save, text_encoder, unet, original_model, args.epoch, args.global_step, save_dtype, vae + ) + print(f"model saved. total converted state_dict keys: {key_count}") + else: + print(f"copy scheduler/tokenizer config from: {args.reference_model}") + model_util.save_diffusers_checkpoint( + v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors + ) + print(f"model saved.") def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("--v1", action='store_true', - help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む') - parser.add_argument("--v2", action='store_true', - help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む') - parser.add_argument("--fp16", action='store_true', - help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)') - parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)') - parser.add_argument("--float", action='store_true', - help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)') - parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値') - parser.add_argument("--global_step", type=int, default=0, - help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値') - parser.add_argument("--reference_model", type=str, default=None, - help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要") - parser.add_argument("--use_safetensors", action='store_true', - help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)") - - parser.add_argument("model_to_load", type=str, default=None, - help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ") - parser.add_argument("model_to_save", type=str, default=None, - help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存") - return parser - - -if __name__ == '__main__': - parser = setup_parser() - - args = parser.parse_args() - convert(args) + parser = argparse.ArgumentParser() + parser.add_argument( + "--v1", action="store_true", help="load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む" + ) + parser.add_argument( + "--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む" + ) + parser.add_argument( + "--fp16", + action="store_true", + help="load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)", + ) + parser.add_argument("--bf16", action="store_true", help="save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)") + parser.add_argument( + "--float", action="store_true", help="save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)" + ) + parser.add_argument( + "--save_precision_as", + type=str, + default="no", + choices=["fp16", "bf16", "float"], + help="save precision, do not specify with --fp16/--bf16/--float / 保存する精度、--fp16/--bf16/--floatと併用しないでください", + ) + parser.add_argument("--epoch", type=int, default=0, help="epoch to write to checkpoint / checkpointに記録するepoch数の値") + parser.add_argument( + "--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値" + ) + parser.add_argument( + "--reference_model", + type=str, + default=None, + help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要", + ) + parser.add_argument( + "--use_safetensors", + action="store_true", + help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)", + ) + + parser.add_argument( + "model_to_load", + type=str, + default=None, + help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ", + ) + parser.add_argument( + "model_to_save", + type=str, + default=None, + help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + convert(args) diff --git a/tools/gradio_theme_builder.py b/tools/gradio_theme_builder.py new file mode 100644 index 000000000..c10965535 --- /dev/null +++ b/tools/gradio_theme_builder.py @@ -0,0 +1,2 @@ +import gradio as gr +gr.themes.builder() diff --git a/train_README-ja.md b/train_README-ja.md index 032e006b1..fd66458a1 100644 --- a/train_README-ja.md +++ b/train_README-ja.md @@ -2,7 +2,7 @@ __ドキュメント更新中のため記述に誤りがあるかもしれませ # 学習について、共通編 -当リポジトリではモデルのfine tuning、DreamBooth、およびLoRAとTextual Inversionの学習をサポートします。この文書ではそれらに共通する、学習データの準備方法やオプション等について説明します。 +当リポジトリではモデルのfine tuning、DreamBooth、およびLoRAとTextual Inversion([XTI:P+](https://github.com/kohya-ss/sd-scripts/pull/327)を含む)の学習をサポートします。この文書ではそれらに共通する、学習データの準備方法やオプション等について説明します。 # 概要 @@ -535,7 +535,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b - `--debug_dataset` - このオプションを付けることで学習を行う前に事前にどのような画像データ、キャプションで学習されるかを確認できます。Escキーを押すと終了してコマンドラインに戻ります。 + このオプションを付けることで学習を行う前に事前にどのような画像データ、キャプションで学習されるかを確認できます。Escキーを押すと終了してコマンドラインに戻ります。`S`キーで次のステップ(バッチ)、`E`キーで次のエポックに進みます。 ※Linux環境(Colabを含む)では画像は表示されません。 @@ -545,6 +545,13 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b DreamBoothおよびfine tuningでは、保存されるモデルはこのVAEを組み込んだものになります。 +- `--cache_latents` + + 使用VRAMを減らすためVAEの出力をメインメモリにキャッシュします。`flip_aug` 以外のaugmentationは使えなくなります。また全体の学習速度が若干速くなります。 + +- `--min_snr_gamma` + + Min-SNR Weighting strategyを指定します。詳細は[こちら](https://github.com/kohya-ss/sd-scripts/pull/308)を参照してください。論文では`5`が推奨されています。 ## オプティマイザ関係 @@ -570,7 +577,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b 学習率のスケジューラ関連の指定です。 - lr_schedulerオプションで学習率のスケジューラをlinear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmupから選べます。デフォルトはconstantです。 + lr_schedulerオプションで学習率のスケジューラをlinear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup, 任意のスケジューラから選べます。デフォルトはconstantです。 lr_warmup_stepsでスケジューラのウォームアップ(だんだん学習率を変えていく)ステップ数を指定できます。 @@ -578,6 +585,8 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b 詳細については各自お調べください。 + 任意のスケジューラを使う場合、任意のオプティマイザと同様に、`--scheduler_args`でオプション引数を指定してください。 + ### オプティマイザの指定について オプティマイザのオプション引数は--optimizer_argsオプションで指定してください。key=valueの形式で、複数の値が指定できます。また、valueはカンマ区切りで複数の値が指定できます。たとえばAdamWオプティマイザに引数を指定する場合は、``--optimizer_args weight_decay=0.01 betas=.9,.999``のようになります。 diff --git a/train_db.py b/train_db.py index e72dc889a..eddf8f686 100644 --- a/train_db.py +++ b/train_db.py @@ -117,12 +117,14 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size) + train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() + accelerator.wait_for_everyone() + # 学習を準備する:モデルを適切な状態にする train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0 unet.requires_grad_(True) # 念のため追加 diff --git a/train_network.py b/train_network.py index ef630969c..658138b70 100644 --- a/train_network.py +++ b/train_network.py @@ -172,12 +172,14 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size) + train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() + accelerator.wait_for_everyone() + # prepare network import sys @@ -195,7 +197,7 @@ def train(args): network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) if network is None: return - + if hasattr(network, "prepare_network"): network.prepare_network(args) @@ -219,7 +221,9 @@ def train(args): try: trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) except TypeError: - print("Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)") + print( + "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" + ) trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) @@ -539,6 +543,12 @@ def train(args): loss_list = [] loss_total = 0.0 del train_dataset_group + + # if hasattr(network, "on_step_start"): + # on_step_start = network.on_step_start + # else: + # on_step_start = lambda *args, **kwargs: None + for epoch in range(num_train_epochs): if is_main_process: print(f"epoch {epoch+1}/{num_train_epochs}") @@ -551,6 +561,8 @@ def train(args): for step, batch in enumerate(train_dataloader): current_step.value = global_step with accelerator.accumulate(network): + # on_step_start(text_encoder, unet) + with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) @@ -563,16 +575,17 @@ def train(args): with torch.set_grad_enabled(train_text_encoder): # Get the text embedding for conditioning if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings(tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, + encoder_hidden_states = get_weighted_text_embeddings( + tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, ) else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: @@ -757,4 +770,4 @@ def setup_parser() -> argparse.ArgumentParser: args = parser.parse_args() args = train_util.read_config_from_file(args, parser) - train(args) \ No newline at end of file + train(args) diff --git a/train_network_README-ja.md b/train_network_README-ja.md index 152ff9af5..cb7cd726b 100644 --- a/train_network_README-ja.md +++ b/train_network_README-ja.md @@ -12,11 +12,31 @@ Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora) [学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。 +# 学習できるLoRAの種類 + +以下の二種類をサポートします。以下は当リポジトリ内の独自の名称です。 + +1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます) + + Linear およびカーネルサイズ 1x1 の Conv2d に適用されるLoRA + +2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます) + + 1.に加え、カーネルサイズ 3x3 の Conv2d に適用されるLoRA + +LoRA-LierLaに比べ、LoRA-C3Liarは適用される層が増える分、高い精度が期待できるかもしれません。 + +また学習時は __DyLoRA__ を使用することもできます(後述します)。 + ## 学習したモデルに関する注意 -cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。 +LoRA-LierLa は、AUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。 -WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。 +LoRA-C3Liarを使いWeb UIで生成するには、こちらの[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。 + +いずれも学習したLoRAのモデルを、Stable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージすることもできます。 + +cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。 # 学習の手順 @@ -31,9 +51,9 @@ WebUI等で画像生成する場合には、学習したLoRAのモデルを学 `train_network.py`を用います。 -`train_network.py`では `--network_module` オプションに、学習対象のモジュール名を指定します。LoRAに対応するのはnetwork.loraとなりますので、それを指定してください。 +`train_network.py`では `--network_module` オプションに、学習対象のモジュール名を指定します。LoRAに対応するのは`network.lora`となりますので、それを指定してください。 -なお学習率は通常のDreamBoothやfine tuningよりも高めの、1e-4程度を指定するとよいようです。 +なお学習率は通常のDreamBoothやfine tuningよりも高めの、`1e-4`~`1e-3`程度を指定するとよいようです。 以下はコマンドラインの例です。 @@ -56,6 +76,8 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py --network_module=networks.lora ``` +このコマンドラインでは LoRA-LierLa が学習されます。 + `--output_dir` オプションで指定したフォルダに、LoRAのモデルが保存されます。他のオプション、オプティマイザ等については [学習の共通ドキュメント](./train_README-ja.md) の「よく使われるオプション」も参照してください。 その他、以下のオプションが指定できます。 @@ -83,22 +105,143 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py `--network_train_unet_only` と `--network_train_text_encoder_only` の両方とも未指定時(デフォルト)はText EncoderとU-Netの両方のLoRAモジュールを有効にします。 -## LoRA を Conv2d に拡大して適用する +# その他の学習方法 -通常のLoRAは Linear およぴカーネルサイズ 1x1 の Conv2d にのみ適用されますが、カーネルサイズ 3x3 のConv2dに適用を拡大することもできます。 +## LoRA-C3Lier を学習する `--network_args` に以下のように指定してください。`conv_dim` で Conv2d (3x3) の rank を、`conv_alpha` で alpha を指定してください。 ``` ---network_args "conv_dim=1" "conv_alpha=1" +--network_args "conv_dim=4" "conv_alpha=1" ``` 以下のように alpha 省略時は1になります。 ``` ---network_args "conv_dim=1" +--network_args "conv_dim=4" +``` + +## DyLoRA + +DyLoRAはこちらの論文で提案されたものです。[DyLoRA: Parameter Efficient Tuning of Pre-trained Models using Dynamic Search-Free Low-Rank Adaptation](https://arxiv.org/abs/2210.07558) 公式実装は[こちら](https://github.com/huawei-noah/KD-NLP/tree/main/DyLoRA)です。 + +論文によると、LoRAのrankは必ずしも高いほうが良いわけではなく、対象のモデル、データセット、タスクなどにより適切なrankを探す必要があるようです。DyLoRAを使うと、指定したdim(rank)以下のさまざまなrankで同時にLoRAを学習します。これにより最適なrankをそれぞれ学習して探す手間を省くことができます。 + +当リポジトリの実装は公式実装をベースに独自の拡張を加えています(そのため不具合などあるかもしれません)。 + +### 当リポジトリのDyLoRAの特徴 + +学習後のDyLoRAのモデルファイルはLoRAと互換性があります。また、モデルファイルから指定したdim(rank)以下の複数のdimのLoRAを抽出できます。 + +DyLoRA-LierLa、DyLoRA-C3Lierのどちらも学習できます。 + +### DyLoRAで学習する + +`--network_module=networks.dylora` のように、DyLoRAに対応する`network.dylora`を指定してください。 + +また `--network_args` に、たとえば`--network_args "unit=4"`のように`unit`を指定します。`unit`はrankを分割する単位です。たとえば`--network_dim=16 --network_args "unit=4"` のように指定します。`unit`は`network_dim`を割り切れる値(`network_dim`は`unit`の倍数)としてください。 + +`unit`を指定しない場合は、`unit=1`として扱われます。 + +記述例は以下です。 + +``` +--network_module=networks.dylora --network_dim=16 --network_args "unit=4" + +--network_module=networks.dylora --network_dim=32 --network_alpha=16 --network_args "unit=4" ``` +DyLoRA-C3Lierの場合は、`--network_args` に`"conv_dim=4"`のように`conv_dim`を指定します。通常のLoRAと異なり、`conv_dim`は`network_dim`と同じ値である必要があります。記述例は以下です。 + +``` +--network_module=networks.dylora --network_dim=16 --network_args "conv_dim=16" "unit=4" + +--network_module=networks.dylora --network_dim=32 --network_alpha=16 --network_args "conv_dim=32" "conv_alpha=16" "unit=8" +``` + +たとえばdim=16、unit=4(後述)で学習すると、4、8、12、16の4つのrankのLoRAを学習、抽出できます。抽出した各モデルで画像を生成し、比較することで、最適なrankのLoRAを選択できます。 + +その他のオプションは通常のLoRAと同じです。 + +※ `unit`は当リポジトリの独自拡張で、DyLoRAでは同dim(rank)の通常LoRAに比べると学習時間が長くなることが予想されるため、分割単位を大きくしたものです。 + +### DyLoRAのモデルからLoRAモデルを抽出する + +`networks`フォルダ内の `extract_lora_from_dylora.py`を使用します。指定した`unit`単位で、DyLoRAのモデルからLoRAのモデルを抽出します。 + +コマンドラインはたとえば以下のようになります。 + +```powershell +python networks\extract_lora_from_dylora.py --model "foldername/dylora-model.safetensors" --save_to "foldername/dylora-model-split.safetensors" --unit 4 +``` + +`--model` にはDyLoRAのモデルファイルを指定します。`--save_to` には抽出したモデルを保存するファイル名を指定します(rankの数値がファイル名に付加されます)。`--unit` にはDyLoRAの学習時の`unit`を指定します。 + +## 階層別学習率 + +詳細は[PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) をご覧ください。 + +フルモデルの25個のブロックの重みを指定できます。最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。 + +`--network_args` で以下の引数を指定してください。 + +- `down_lr_weight` : U-Netのdown blocksの学習率の重みを指定します。以下が指定可能です。 + - ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個の数値を指定します。 + - プリセットからの指定 : `"down_lr_weight=sine"` のように指定します(サインカーブで重みを指定します)。sine, cosine, linear, reverse_linear, zeros が指定可能です。また `"down_lr_weight=cosine+.25"` のように `+数値` を追加すると、指定した数値を加算します(0.25~1.25になります)。 +- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します。 +- `up_lr_weight` : U-Netのup blocksの学習率の重みを指定します。down_lr_weightと同様です。 +- 指定を省略した部分は1.0として扱われます。また重みを0にするとそのブロックのLoRAモジュールは作成されません。 +- `block_lr_zero_threshold` : 重みがこの値以下の場合、LoRAモジュールを作成しません。デフォルトは0です。 + +### 階層別学習率コマンドライン指定例: + +```powershell +--network_args "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5" "mid_lr_weight=2.0" "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5" + +--network_args "block_lr_zero_threshold=0.1" "down_lr_weight=sine+.5" "mid_lr_weight=1.5" "up_lr_weight=cosine+.5" +``` + +### 階層別学習率tomlファイル指定例: + +```toml +network_args = [ "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5", "mid_lr_weight=2.0", "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5",] + +network_args = [ "block_lr_zero_threshold=0.1", "down_lr_weight=sine+.5", "mid_lr_weight=1.5", "up_lr_weight=cosine+.5", ] +``` + +## 階層別dim (rank) + +フルモデルの25個のブロックのdim (rank)を指定できます。階層別学習率と同様に一部のブロックにはLoRAが存在しない場合がありますが、常に25個の値を指定してください。 + +`--network_args` で以下の引数を指定してください。 + +- `block_dims` : 各ブロックのdim (rank)を指定します。`"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` のように25個の数値を指定します。 +- `block_alphas` : 各ブロックのalphaを指定します。block_dimsと同様に25個の数値を指定します。省略時はnetwork_alphaの値が使用されます。 +- `conv_block_dims` : LoRAをConv2d 3x3に拡張し、各ブロックのdim (rank)を指定します。 +- `conv_block_alphas` : LoRAをConv2d 3x3に拡張したときの各ブロックのalphaを指定します。省略時はconv_alphaの値が使用されます。 + +### 階層別dim (rank)コマンドライン指定例: + +```powershell +--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" + +--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "conv_block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2" + +--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2" +``` + +### 階層別dim (rank)tomlファイル指定例: + +```toml +network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2",] + +network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2", "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2",] +``` + +# その他のスクリプト + +マージ等LoRAに関連するスクリプト群です。 + ## マージスクリプトについて merge_lora.pyでStable DiffusionのモデルにLoRAの学習結果をマージしたり、複数のLoRAモデルをマージしたりできます。 @@ -323,14 +466,14 @@ python tools\resize_images_to_resolution.py --max_resolution 512x512,384x384,256 - 縮小時の補完方法を指定します。``area, cubic, lanczos4``から選択可能で、デフォルトは``area``です。 -## 追加情報 +# 追加情報 -### cloneofsimo氏のリポジトリとの違い +## cloneofsimo氏のリポジトリとの違い 2022/12/25時点では、当リポジトリはLoRAの適用個所をText EncoderのMLP、U-NetのFFN、Transformerのin/out projectionに拡大し、表現力が増しています。ただその代わりメモリ使用量は増え、8GBぎりぎりになりました。 またモジュール入れ替え機構は全く異なります。 -### 将来拡張について +## 将来拡張について LoRAだけでなく他の拡張にも対応可能ですので、それらも追加予定です。 diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 98639345d..611adff71 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -185,10 +185,10 @@ def train(args): blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - current_epoch = Value('i',0) - current_step = Value('i',0) + current_epoch = Value("i", 0) + current_step = Value("i", 0) ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collater = train_util.collater_class(current_epoch,current_step, ds_for_collater) + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 if use_template: @@ -233,12 +233,14 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size) + train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() + accelerator.wait_for_everyone() + if args.gradient_checkpointing: unet.enable_gradient_checkpointing() text_encoder.gradient_checkpointing_enable() @@ -262,7 +264,9 @@ def train(args): # 学習ステップ数を計算する if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps) + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # データセット側にも学習ステップを送信 @@ -337,7 +341,7 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch+1 + current_epoch.value = epoch + 1 text_encoder.train() @@ -357,7 +361,7 @@ def train(args): # Get the text embedding for conditioning input_ids = batch["input_ids"].to(accelerator.device) - # weight_dtype) use float instead of fp16/bf16 because text encoder is float + # use float instead of fp16/bf16 because text encoder is float encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float) # Sample noise that we'll add to the latents @@ -375,7 +379,8 @@ def train(args): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample if args.v_parameterization: # v-parameterization training @@ -385,9 +390,9 @@ def train(args): loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) - + if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index db46ad1b7..54c4b4e56 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -267,12 +267,14 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size) + train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() + accelerator.wait_for_everyone() + if args.gradient_checkpointing: unet.enable_gradient_checkpointing() text_encoder.gradient_checkpointing_enable() @@ -416,7 +418,8 @@ def train(args): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample if args.v_parameterization: # v-parameterization training diff --git a/train_ti_README-ja.md b/train_ti_README-ja.md index 908736961..86f45a5dc 100644 --- a/train_ti_README-ja.md +++ b/train_ti_README-ja.md @@ -4,7 +4,7 @@ 実装に当たっては https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion を大いに参考にしました。 -学習したモデルはWeb UIでもそのまま使えます。なお恐らくSD2.xにも対応していますが現時点では未テストです。 +学習したモデルはWeb UIでもそのまま使えます。 # 学習の手順