diff --git a/README.md b/README.md index bc8a67ffc..7ffce07c7 100644 --- a/README.md +++ b/README.md @@ -113,9 +113,10 @@ accelerate launch --num_cpu_threads_per_process 6 train_db_fixed-ber.py ` --cache_latents ` --save_every_n_epochs=1 ` --fine_tuning ` + --enable_bucket ` --dataset_repeats=200 ` --seed=23 ` - --save_half + ---save_precision="fp16" ``` Refer to this url for more details about finetuning: https://note.com/kohya_ss/n/n1269f1e1a54e @@ -125,7 +126,12 @@ Refer to this url for more details about finetuning: https://note.com/kohya_ss/n * 11/7 (v7): Text Encoder supports checkpoint files in different storage formats (it is converted at the time of import, so export will be in normal format). Changed the average value of EPOCH loss to output to the screen. Added a function to save epoch and global step in checkpoint in SD format (add values if there is existing data). The reg_data_dir option is enabled during fine tuning (fine tuning while mixing regularized images). Added dataset_repeats option that is valid for fine tuning (specified when the number of teacher images is small and the epoch is extremely short). * 11/9 (v8): supports Diffusers 0.7.2. To upgrade diffusers run `pip install --upgrade diffusers[torch]` * 11/14 (diffusers_fine_tuning v2): -- script name is now fine_tune.py. -- Added option to learn Text Encoder --train_text_encoder. -- The data format of checkpoint at the time of saving can be specified with the --save_precision option. You can choose float, fp16, and bf16. -- Added a --save_state option to save the learning state (optimizer, etc.) in the middle. It can be resumed with the --resume option. \ No newline at end of file + - script name is now fine_tune.py. + - Added option to learn Text Encoder --train_text_encoder. + - The data format of checkpoint at the time of saving can be specified with the --save_precision option. You can choose float, fp16, and bf16. + - Added a --save_state option to save the learning state (optimizer, etc.) in the middle. It can be resumed with the --resume option. +* 11/18 (v9): + - Added support for Aspect Ratio Bucketing (enable_bucket option). (--enable_bucket) + - Added support for selecting data format (fp16/bf16/float) when saving checkpoint (--save_precision) + - Added support for saving learning state (--save_state, --resume) + - Added support for logging (--logging_dir) diff --git a/examples/caption.ps1 b/examples/caption.ps1 index 49a4dfd28..4673f878a 100644 --- a/examples/caption.ps1 +++ b/examples/caption.ps1 @@ -2,9 +2,12 @@ # # Usefull to create base caption that will be augmented on a per image basis -$folder = "D:\dreambooth\train_sylvia_ritter\raw_data\all-images\" +$folder = "D:\some\folder\location\" $file_pattern="*.*" -$text_fir_file="a digital painting of xxx, by silvery trait" +$caption_text="some caption text" -$files = Get-ChildItem $folder$file_pattern -foreach ($file in $files) {New-Item -ItemType file -Path $folder -Name "$($file.BaseName).txt" -Value $text_fir_file} \ No newline at end of file +$files = Get-ChildItem $folder$file_pattern -Include *.png,*.jpg,*.webp -File +foreach ($file in $files) +{ + New-Item -ItemType file -Path $folder -Name "$($file.BaseName).txt" -Value $caption_text +} \ No newline at end of file diff --git a/examples/caption_subfolders.ps1 b/examples/caption_subfolders.ps1 new file mode 100644 index 000000000..0bfba6f01 --- /dev/null +++ b/examples/caption_subfolders.ps1 @@ -0,0 +1,20 @@ +# This powershell script will create a text file for each files in the folder +# +# Usefull to create base caption that will be augmented on a per image basis + +$folder = "D:\test\t2\" +$file_pattern="*.*" +$text_fir_file="bigeyes style" + +foreach ($file in Get-ChildItem $folder\$file_pattern -File) +{ + New-Item -ItemType file -Path $folder -Name "$($file.BaseName).txt" -Value $text_fir_file +} + +foreach($directory in Get-ChildItem -path $folder -Directory) +{ + foreach ($file in Get-ChildItem $folder\$directory\$file_pattern) + { + New-Item -ItemType file -Path $folder\$directory -Name "$($file.BaseName).txt" -Value $text_fir_file + } +} diff --git a/examples/kohya-1-folders.ps1 b/examples/kohya-1-folders.ps1 new file mode 100644 index 000000000..72660a7a8 --- /dev/null +++ b/examples/kohya-1-folders.ps1 @@ -0,0 +1,87 @@ +# This powershell script will create a model using the fine tuning dreambooth method. It will require landscape, +# portrait and square images. +# +# Adjust the script to your own needs + +# Sylvia Ritter +# variable values +$pretrained_model_name_or_path = "D:\models\v1-5-pruned-mse-vae.ckpt" +$data_dir = "D:\test\squat" +$train_dir = "D:\test\" +$resolution = "512,512" + +$image_num = Get-ChildItem $data_dir -Recurse -File -Include *.png | Measure-Object | %{$_.Count} + +Write-Output "image_num: $image_num" + +$learning_rate = 1e-6 +$dataset_repeats = 40 +$train_batch_size = 8 +$epoch = 1 +$save_every_n_epochs=1 +$mixed_precision="fp16" +$num_cpu_threads_per_process=6 + +# You should not have to change values past this point + +$output_dir = $train_dir + "\model" +$repeats = $image_num * $dataset_repeats +$mts = [Math]::Ceiling($repeats / $train_batch_size * $epoch) + +Write-Output "Repeats: $repeats" + +.\venv\Scripts\activate + +accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed-ber.py ` + --pretrained_model_name_or_path=$pretrained_model_name_or_path ` + --train_data_dir=$data_dir ` + --output_dir=$output_dir ` + --resolution=$resolution ` + --train_batch_size=$train_batch_size ` + --learning_rate=$learning_rate ` + --max_train_steps=$mts ` + --use_8bit_adam ` + --xformers ` + --mixed_precision=$mixed_precision ` + --cache_latents ` + --save_every_n_epochs=$save_every_n_epochs ` + --fine_tuning ` + --dataset_repeats=$dataset_repeats ` + --save_precision="fp16" + +# 2nd pass at half the dataset repeat value + +accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py ` + --pretrained_model_name_or_path=$output_dir"\last.ckpt" ` + --train_data_dir=$data_dir ` + --output_dir=$output_dir"2" ` + --resolution=$resolution ` + --train_batch_size=$train_batch_size ` + --learning_rate=$learning_rate ` + --max_train_steps=$([Math]::Ceiling($mts/2)) ` + --use_8bit_adam ` + --xformers ` + --mixed_precision=$mixed_precision ` + --cache_latents ` + --save_every_n_epochs=$save_every_n_epochs ` + --fine_tuning ` + --dataset_repeats=$([Math]::Ceiling($dataset_repeats/2)) ` + --save_precision="fp16" + + accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed-ber.py ` + --pretrained_model_name_or_path=$output_dir"\last.ckpt" ` + --train_data_dir=$data_dir ` + --output_dir=$output_dir"2" ` + --resolution=$resolution ` + --train_batch_size=$train_batch_size ` + --learning_rate=$learning_rate ` + --max_train_steps=$mts ` + --use_8bit_adam ` + --xformers ` + --mixed_precision=$mixed_precision ` + --cache_latents ` + --save_every_n_epochs=$save_every_n_epochs ` + --fine_tuning ` + --dataset_repeats=$dataset_repeats ` + --save_precision="fp16" + \ No newline at end of file diff --git a/examples/kohya-3-folders.ps1 b/examples/kohya-3-folders.ps1 new file mode 100644 index 000000000..ed754a3c2 --- /dev/null +++ b/examples/kohya-3-folders.ps1 @@ -0,0 +1,154 @@ +# This powershell script will create a model using the fine tuning dreambooth method. It will require landscape, +# portrait and square images. +# +# Adjust the script to your own needs + +# Sylvia Ritter +# variable values +$pretrained_model_name_or_path = "D:\models\v1-5-pruned-mse-vae.ckpt" +$train_dir = "D:\dreambooth\train_sylvia_ritter\raw_data" + +$landscape_image_num = 4 +$portrait_image_num = 25 +$square_image_num = 2 + +$learning_rate = 1e-6 +$dataset_repeats = 120 +$train_batch_size = 4 +$epoch = 1 +$save_every_n_epochs=1 +$mixed_precision="fp16" +$num_cpu_threads_per_process=6 + +$landscape_folder_name = "landscape-pp" +$landscape_resolution = "832,512" +$portrait_folder_name = "portrait-pp" +$portrait_resolution = "448,896" +$square_folder_name = "square-pp" +$square_resolution = "512,512" + +# You should not have to change values past this point + +$landscape_data_dir = $train_dir + "\" + $landscape_folder_name +$portrait_data_dir = $train_dir + "\" + $portrait_folder_name +$square_data_dir = $train_dir + "\" + $square_folder_name +$landscape_output_dir = $train_dir + "\model-l" +$portrait_output_dir = $train_dir + "\model-lp" +$square_output_dir = $train_dir + "\model-lps" + +$landscape_repeats = $landscape_image_num * $dataset_repeats +$portrait_repeats = $portrait_image_num * $dataset_repeats +$square_repeats = $square_image_num * $dataset_repeats + +$landscape_mts = [Math]::Ceiling($landscape_repeats / $train_batch_size * $epoch) +$portrait_mts = [Math]::Ceiling($portrait_repeats / $train_batch_size * $epoch) +$square_mts = [Math]::Ceiling($square_repeats / $train_batch_size * $epoch) + +# Write-Output $landscape_repeats + +.\venv\Scripts\activate + +accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py ` + --pretrained_model_name_or_path=$pretrained_model_name_or_path ` + --train_data_dir=$landscape_data_dir ` + --output_dir=$landscape_output_dir ` + --resolution=$landscape_resolution ` + --train_batch_size=$train_batch_size ` + --learning_rate=$learning_rate ` + --max_train_steps=$landscape_mts ` + --use_8bit_adam ` + --xformers ` + --mixed_precision=$mixed_precision ` + --cache_latents ` + --save_every_n_epochs=$save_every_n_epochs ` + --fine_tuning ` + --dataset_repeats=$dataset_repeats ` + --save_precision="fp16" + +accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py ` + --pretrained_model_name_or_path=$landscape_output_dir"\last.ckpt" ` + --train_data_dir=$portrait_data_dir ` + --output_dir=$portrait_output_dir ` + --resolution=$portrait_resolution ` + --train_batch_size=$train_batch_size ` + --learning_rate=$learning_rate ` + --max_train_steps=$portrait_mts ` + --use_8bit_adam ` + --xformers ` + --mixed_precision=$mixed_precision ` + --cache_latents ` + --save_every_n_epochs=$save_every_n_epochs ` + --fine_tuning ` + --dataset_repeats=$dataset_repeats ` + --save_precision="fp16" + +accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py ` + --pretrained_model_name_or_path=$portrait_output_dir"\last.ckpt" ` + --train_data_dir=$square_data_dir ` + --output_dir=$square_output_dir ` + --resolution=$square_resolution ` + --train_batch_size=$train_batch_size ` + --learning_rate=$learning_rate ` + --max_train_steps=$square_mts ` + --use_8bit_adam ` + --xformers ` + --mixed_precision=$mixed_precision ` + --cache_latents ` + --save_every_n_epochs=$save_every_n_epochs ` + --fine_tuning ` + --dataset_repeats=$dataset_repeats ` + --save_precision="fp16" + +# 2nd pass at half the dataset repeat value + +accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py ` + --pretrained_model_name_or_path=$square_output_dir"\last.ckpt" ` + --train_data_dir=$landscape_data_dir ` + --output_dir=$landscape_output_dir"2" ` + --resolution=$landscape_resolution ` + --train_batch_size=$train_batch_size ` + --learning_rate=$learning_rate ` + --max_train_steps=$([Math]::Ceiling($landscape_mts/2)) ` + --use_8bit_adam ` + --xformers ` + --mixed_precision=$mixed_precision ` + --cache_latents ` + --save_every_n_epochs=$save_every_n_epochs ` + --fine_tuning ` + --dataset_repeats=$([Math]::Ceiling($dataset_repeats/2)) ` + --save_precision="fp16" + +accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py ` + --pretrained_model_name_or_path=$landscape_output_dir"2\last.ckpt" ` + --train_data_dir=$portrait_data_dir ` + --output_dir=$portrait_output_dir"2" ` + --resolution=$portrait_resolution ` + --train_batch_size=$train_batch_size ` + --learning_rate=$learning_rate ` + --max_train_steps=$([Math]::Ceiling($portrait_mts/2)) ` + --use_8bit_adam ` + --xformers ` + --mixed_precision=$mixed_precision ` + --cache_latents ` + --save_every_n_epochs=$save_every_n_epochs ` + --fine_tuning ` + --dataset_repeats=$([Math]::Ceiling($dataset_repeats/2)) ` + --save_precision="fp16" + +accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py ` + --pretrained_model_name_or_path=$portrait_output_dir"2\last.ckpt" ` + --train_data_dir=$square_data_dir ` + --output_dir=$square_output_dir"2" ` + --resolution=$square_resolution ` + --train_batch_size=$train_batch_size ` + --learning_rate=$learning_rate ` + --max_train_steps=$([Math]::Ceiling($square_mts/2)) ` + --use_8bit_adam ` + --xformers ` + --mixed_precision=$mixed_precision ` + --cache_latents ` + --save_every_n_epochs=$save_every_n_epochs ` + --fine_tuning ` + --dataset_repeats=$([Math]::Ceiling($dataset_repeats/2)) ` + --save_precision="fp16" + \ No newline at end of file diff --git a/examples/kohya_diffuser.ps1 b/examples/kohya_diffuser.ps1 index b8fcffe21..a12b20f7f 100644 --- a/examples/kohya_diffuser.ps1 +++ b/examples/kohya_diffuser.ps1 @@ -55,7 +55,7 @@ accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process D:\ --use_8bit_adam --xformers ` --mixed_precision=$mixed_precision ` --save_every_n_epochs=$save_every_n_epochs ` - --save_half + --save_precision="fp16" accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process D:\kohya_ss\diffusers_fine_tuning\fine_tune.py ` --pretrained_model_name_or_path=$train_dir"\fine_tuned\last.ckpt" ` @@ -69,4 +69,4 @@ accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process D:\ --use_8bit_adam --xformers ` --mixed_precision=$mixed_precision ` --save_every_n_epochs=$save_every_n_epochs ` - --save_half + --save_precision="fp16" diff --git a/train_db_fixed.py b/train_db_fixed.py index 142e51063..3cf3f5db8 100644 --- a/train_db_fixed.py +++ b/train_db_fixed.py @@ -4,7 +4,9 @@ # v7: another text encoder ckpt format, average loss, save epochs/global steps, show num of train/reg images, # enable reg images in fine-tuning, add dataset_repeats option # v8: supports Diffusers 0.7.2 +# v9: add bucketing option +import time from torch.autograd.function import Function import argparse import glob @@ -56,13 +58,40 @@ # checkpointファイル名 LAST_CHECKPOINT_NAME = "last.ckpt" +LAST_STATE_NAME = "last-state" EPOCH_CHECKPOINT_NAME = "epoch-{:06d}.ckpt" +EPOCH_STATE_NAME = "epoch-{:06d}-state" + + +def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64): + max_width, max_height = max_reso + max_area = (max_width // divisible) * (max_height // divisible) + + resos = set() + + size = int(math.sqrt(max_area)) * divisible + resos.add((size, size)) + + size = min_size + while size <= max_size: + width = size + height = min(max_size, (max_area // (width // divisible)) * divisible) + resos.add((width, height)) + resos.add((height, width)) + size += divisible + + resos = list(resos) + resos.sort() + + aspect_ratios = [w / h for w, h in resos] + return resos, aspect_ratios class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset): - def __init__(self, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, shuffle_caption, disable_padding, debug_dataset) -> None: + def __init__(self, batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, shuffle_caption, disable_padding, debug_dataset) -> None: super().__init__() + self.batch_size = batch_size self.fine_tuning = fine_tuning self.train_img_path_captions = train_img_path_captions self.reg_img_path_captions = reg_img_path_captions @@ -76,6 +105,7 @@ def __init__(self, fine_tuning, train_img_path_captions, reg_img_path_captions, self.shuffle_caption = shuffle_caption self.disable_padding = disable_padding self.latents_cache = None + self.enable_bucket = False # augmentation flip_p = 0.5 if flip_aug else 0.0 @@ -102,13 +132,8 @@ def __init__(self, fine_tuning, train_img_path_captions, reg_img_path_captions, self.enable_reg_images = self.num_reg_images > 0 - if not self.enable_reg_images: - self._length = self.num_train_images - else: - # 学習データの倍として、奇数ならtrain - self._length = self.num_train_images * 2 - if self._length // 2 < self.num_reg_images: - print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") + if self.enable_reg_images and self.num_train_images < self.num_reg_images: + print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") self.image_transforms = transforms.Compose( [ @@ -117,6 +142,132 @@ def __init__(self, fine_tuning, train_img_path_captions, reg_img_path_captions, ] ) + # bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) + def make_buckets_with_caching(self, enable_bucket, vae): + self.enable_bucket = enable_bucket + + cache_latents = vae is not None + if cache_latents: + if enable_bucket: + print("cache latents with bucketing") + else: + print("cache latents") + else: + if enable_bucket: + print("make buckets") + else: + print("prepare dataset") + + # bucketingを用意する + if enable_bucket: + bucket_resos, bucket_aspect_ratios = make_bucket_resolutions((self.width, self.height)) + else: + # bucketはひとつだけ、すべての画像は同じ解像度 + bucket_resos = [(self.width, self.height)] + bucket_aspect_ratios = [self.width / self.height] + bucket_aspect_ratios = np.array(bucket_aspect_ratios) + + # 画像の解像度、latentをあらかじめ取得する + img_ar_errors = [] + self.size_lat_cache = {} + for image_path, _ in tqdm(self.train_img_path_captions + self.reg_img_path_captions): + if image_path in self.size_lat_cache: + continue + + image = self.load_image(image_path)[0] + image_height, image_width = image.shape[0:2] + + if not enable_bucket: + # assert image_width == self.width and image_height == self.height, \ + # f"all images must have specific resolution when bucketing is disabled / bucketを使わない場合、すべての画像のサイズを統一してください: {image_path}" + reso = (self.width, self.height) + else: + # bucketを決める + aspect_ratio = image_width / image_height + ar_errors = bucket_aspect_ratios - aspect_ratio + bucket_id = np.abs(ar_errors).argmin() + reso = bucket_resos[bucket_id] + ar_error = ar_errors[bucket_id] + img_ar_errors.append(ar_error) + + if cache_latents: + image = self.resize_and_trim(image, reso) + + # latentを取得する + if cache_latents: + img_tensor = self.image_transforms(image) + img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) + latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") + else: + latents = None + + self.size_lat_cache[image_path] = (reso, latents) + + # 画像をbucketに分割する + self.buckets = [[] for _ in range(len(bucket_resos))] + reso_to_index = {} + for i, reso in enumerate(bucket_resos): + reso_to_index[reso] = i + + def split_to_buckets(is_reg, img_path_captions): + for image_path, caption in img_path_captions: + reso, _ = self.size_lat_cache[image_path] + bucket_index = reso_to_index[reso] + self.buckets[bucket_index].append((is_reg, image_path, caption)) + + split_to_buckets(False, self.train_img_path_captions) + + if self.enable_reg_images: + l = [] + while len(l) < len(self.train_img_path_captions): + l += self.reg_img_path_captions + l = l[:len(self.train_img_path_captions)] + split_to_buckets(True, l) + + if enable_bucket: + print("number of images with repeats / 繰り返し回数込みの各bucketの画像枚数") + for i, (reso, imgs) in enumerate(zip(bucket_resos, self.buckets)): + print(f"bucket {i}: resolution {reso}, count: {len(imgs)}") + img_ar_errors = np.array(img_ar_errors) + print(f"mean ar error: {np.mean(np.abs(img_ar_errors))}") + + # 参照用indexを作る + self.buckets_indices = [] + for bucket_index, bucket in enumerate(self.buckets): + batch_count = int(math.ceil(len(bucket) / self.batch_size)) + for batch_index in range(batch_count): + self.buckets_indices.append((bucket_index, batch_index)) + + self.shuffle_buckets() + self._length = len(self.buckets_indices) + + # どのサイズにリサイズするか→トリミングする方向で + def resize_and_trim(self, image, reso): + image_height, image_width = image.shape[0:2] + ar_img = image_width / image_height + ar_reso = reso[0] / reso[1] + if ar_img > ar_reso: # 横が長い→縦を合わせる + scale = reso[1] / image_height + else: + scale = reso[0] / image_width + resized_size = (int(image_width * scale + .5), int(image_height * scale + .5)) + + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + if resized_size[0] > reso[0]: + trim_size = resized_size[0] - reso[0] + image = image[:, trim_size//2:trim_size//2 + reso[0]] + elif resized_size[1] > reso[1]: + trim_size = resized_size[1] - reso[1] + image = image[trim_size//2:trim_size//2 + reso[1]] + assert image.shape[0] == reso[1] and image.shape[1] == reso[0], \ + f"internal error, illegal trimmed size: {image.shape}, {reso}" + return image + + def shuffle_buckets(self): + random.shuffle(self.buckets_indices) + for bucket in self.buckets: + random.shuffle(bucket) + def load_image(self, image_path): image = Image.open(image_path) if not image.mode == "RGB": @@ -184,83 +335,85 @@ def crop_target(self, image, face_cx, face_cy, face_w, face_h): def __len__(self): return self._length - def set_cached_latents(self, image_path, latents): - if self.latents_cache is None: - self.latents_cache = {} - self.latents_cache[image_path] = latents + def __getitem__(self, index): + if index == 0: + self.shuffle_buckets() - def __getitem__(self, index_arg): - example = {} + bucket = self.buckets[self.buckets_indices[index][0]] + image_index = self.buckets_indices[index][1] * self.batch_size - if not self.enable_reg_images: - index = index_arg - img_path_captions = self.train_img_path_captions - reg = False - else: - # 偶数ならtrain、奇数ならregを返す - if index_arg % 2 == 0: - img_path_captions = self.train_img_path_captions - reg = False + latents_list = [] + images = [] + captions = [] + loss_weights = [] + + for is_reg, image_path, caption in bucket[image_index:image_index + self.batch_size]: + loss_weights.append(1.0 if is_reg else self.prior_loss_weight) + + # image/latentsを処理する + reso, latents = self.size_lat_cache[image_path] + + if latents is None: + # 画像を読み込み必要ならcropする + img, face_cx, face_cy, face_w, face_h = self.load_image(image_path) + im_h, im_w = img.shape[0:2] + + if self.enable_bucket: + img = self.resize_and_trim(img, reso) + else: + if face_cx > 0: # 顔位置情報あり + img = self.crop_target(img, face_cx, face_cy, face_w, face_h) + elif im_h > self.height or im_w > self.width: + assert self.random_crop, f"image too large, and face_crop_aug_range and random_crop are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_cropを有効にしてください" + if im_h > self.height: + p = random.randint(0, im_h - self.height) + img = img[p:p + self.height] + if im_w > self.width: + p = random.randint(0, im_w - self.width) + img = img[:, p:p + self.width] + + im_h, im_w = img.shape[0:2] + assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_path}" + + # augmentation + if self.aug is not None: + img = self.aug(image=img)['image'] + + image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる else: - img_path_captions = self.reg_img_path_captions - reg = True - index = index_arg // 2 - example['loss_weight'] = 1.0 if (not reg or self.fine_tuning) else self.prior_loss_weight - - index = index % len(img_path_captions) - image_path, caption = img_path_captions[index] - example['image_path'] = image_path - - # image/latentsを処理する - if self.latents_cache is not None and image_path in self.latents_cache: - # latentsはキャッシュ済み - example['latents'] = self.latents_cache[image_path] - else: - # 画像を読み込み必要ならcropする - img, face_cx, face_cy, face_w, face_h = self.load_image(image_path) - im_h, im_w = img.shape[0:2] - if face_cx > 0: # 顔位置情報あり - img = self.crop_target(img, face_cx, face_cy, face_w, face_h) - elif im_h > self.height or im_w > self.width: - assert self.random_crop, f"image too large, and face_crop_aug_range and random_crop are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_cropを有効にしてください" - if im_h > self.height: - p = random.randint(0, im_h - self.height) - img = img[p:p + self.height] - if im_w > self.width: - p = random.randint(0, im_w - self.width) - img = img[:, p:p + self.width] - - im_h, im_w = img.shape[0:2] - assert im_h == self.height and im_w == self.width, f"image too small / 画像サイズが小さいようです: {image_path}" - - # augmentation - if self.aug is not None: - img = self.aug(image=img)['image'] - - example['image'] = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる - - # captionを処理する - if self.fine_tuning and self.shuffle_caption: # fine tuning時にcaptionのshuffleをする - tokens = caption.strip().split(",") - random.shuffle(tokens) - caption = ",".join(tokens).strip() - - input_ids = self.tokenizer(caption, padding="do_not_pad", truncation=True, - max_length=self.tokenizer.model_max_length).input_ids - - # padしてTensor変換 + image = None + + images.append(image) + latents_list.append(latents) + + # captionを処理する + if self.fine_tuning and self.shuffle_caption: # fine tuning時にcaptionのshuffleをする + tokens = caption.strip().split(",") + random.shuffle(tokens) + caption = ",".join(tokens).strip() + captions.append(caption) + + # input_idsをpadしてTensor変換 if self.disable_padding: # paddingしない:padding==Trueはバッチの中の最大長に合わせるだけ(やはりバグでは……?) - input_ids = self.tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids + input_ids = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids else: # paddingする - input_ids = self.tokenizer.pad({"input_ids": input_ids}, padding='max_length', max_length=self.tokenizer.model_max_length, - return_tensors='pt').input_ids - + input_ids = self.tokenizer(captions, padding='max_length', truncation=True, return_tensors='pt').input_ids + + example = {} + example['loss_weights'] = torch.FloatTensor(loss_weights) example['input_ids'] = input_ids - + if images[0] is not None: + images = torch.stack(images) + images = images.to(memory_format=torch.contiguous_format).float() + else: + images = None + example['images'] = images + example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None if self.debug_dataset: - example['caption'] = caption + example['image_paths'] = [image_path for _, image_path, _ in bucket[image_index:image_index + self.batch_size]] + example['captions'] = captions return example @@ -916,7 +1069,7 @@ def load_models_from_stable_diffusion_checkpoint(ckpt_path): return text_model, vae, unet -def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path, epochs, steps): +def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None): # VAEがメモリ上にないので、もう一度VAEを含めて読み込む checkpoint = load_checkpoint_with_conversion(ckpt_path) state_dict = checkpoint["state_dict"] @@ -926,6 +1079,8 @@ def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path, for k, v in unet_state_dict.items(): key = "model.diffusion_model." + k assert key in state_dict, f"Illegal key in save SD: {key}" + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) state_dict[key] = v # Convert the text encoder model @@ -933,6 +1088,8 @@ def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path, for k, v in text_enc_dict.items(): key = "cond_stage_model.transformer." + k assert key in state_dict, f"Illegal key in save SD: {key}" + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) state_dict[key] = v # Put together new checkpoint @@ -951,24 +1108,7 @@ def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path, def collate_fn(examples): - input_ids = [e['input_ids'] for e in examples] - input_ids = torch.stack(input_ids) - - if 'latents' in examples[0]: - pixel_values = None - latents = [e['latents'] for e in examples] - latents = torch.stack(latents) - else: - pixel_values = [e['image'] for e in examples] - pixel_values = torch.stack(pixel_values) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - latents = None - - loss_weights = [e['loss_weight'] for e in examples] - loss_weights = torch.FloatTensor(loss_weights) - - batch = {"input_ids": input_ids, "pixel_values": pixel_values, "latents": latents, "loss_weights": loss_weights} - return batch + return examples[0] def train(args): @@ -998,19 +1138,22 @@ def load_dreambooth_dir(dir): try: n_repeats = int(tokens[0]) except ValueError as e: - print(f"no 'n_repeats' in directory name / DreamBoothのディレクトリ名に繰り返し回数がないようです: {dir}") - raise e + # print(f"no 'n_repeats' in directory name / DreamBoothのディレクトリ名に繰り返し回数がないようです: {dir}") + # raise e + return 0, [] caption = '_'.join(tokens[1:]) - img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + print(f"found directory {n_repeats}_{caption}") + + img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + glob.glob(os.path.join(dir, "*.webp")) return n_repeats, [(ip, caption) for ip in img_paths] print("prepare train images.") train_img_path_captions = [] if fine_tuning: - img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) for img_path in tqdm(img_paths): # captionの候補ファイル名を作る base_name = os.path.splitext(img_path)[0] @@ -1042,7 +1185,7 @@ def load_dreambooth_dir(dir): n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.train_data_dir, dir)) for _ in range(n_repeats): train_img_path_captions.extend(img_caps) - print(f"{len(train_img_path_captions)} train images.") + print(f"{len(train_img_path_captions)} train images with repeating.") reg_img_path_captions = [] if args.reg_data_dir: @@ -1054,11 +1197,6 @@ def load_dreambooth_dir(dir): reg_img_path_captions.extend(img_caps) print(f"{len(reg_img_path_captions)} reg images.") - if args.debug_dataset: - # デバッグ時はshuffleして実際のデータセット使用時に近づける(学習時はdata loaderでshuffleする) - random.shuffle(train_img_path_captions) - random.shuffle(reg_img_path_captions) - # データセットを準備する resolution = tuple([int(r) for r in args.resolution.split(',')]) if len(resolution) == 1: @@ -1078,21 +1216,25 @@ def load_dreambooth_dir(dir): tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) print("prepare dataset") - train_dataset = DreamBoothOrFineTuningDataset(fine_tuning, train_img_path_captions, - reg_img_path_captions, tokenizer, resolution, args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop, args.shuffle_caption, args.no_token_padding, args.debug_dataset) + train_dataset = DreamBoothOrFineTuningDataset(args.train_batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, + args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop, + args.shuffle_caption, args.no_token_padding, args.debug_dataset) if args.debug_dataset: - print(f"Total dataset length / データセットの長さ: {len(train_dataset)}") + train_dataset.make_buckets_with_caching(args.enable_bucket, None) # デバッグ用にcacheなしで作る + print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") print("Escape for exit. / Escキーで中断、終了します") for example in train_dataset: - im = example['image'] - im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) - im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c - im = im[:, :, ::-1] # RGB -> BGR (OpenCV) - print(f'caption: "{example["caption"]}", loss weight: {example["loss_weight"]}') - cv2.imshow("img", im) - k = cv2.waitKey() - cv2.destroyAllWindows() + for im, cap, lw in zip(example['images'], example['captions'], example['loss_weights']): + im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) + im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c + im = im[:, :, ::-1] # RGB -> BGR (OpenCV) + print(f'size: {im.shape[1]}*{im.shape[0]}, caption: "{cap}", loss weight: {lw}') + cv2.imshow("img", im) + k = cv2.waitKey() + cv2.destroyAllWindows() + if k == 27: + break if k == 27: break return @@ -1100,7 +1242,14 @@ def load_dreambooth_dir(dir): # acceleratorを準備する # gradient accumulationは複数モデルを学習する場合には対応していないとのことなので、1固定にする print("prepare accelerator") - accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=args.mixed_precision) + if args.logging_dir is None: + log_with = None + logging_dir = None + else: + log_with = "tensorboard" + logging_dir = args.logging_dir + "/" + time.strftime('%Y%m%d%H%M%S', time.localtime()) + accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=args.mixed_precision, + log_with=log_with, logging_dir=logging_dir) # モデルを読み込む if use_stable_diffusion_format: @@ -1122,28 +1271,24 @@ def load_dreambooth_dir(dir): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 + save_dtype = None + if args.save_precision == "fp16": + save_dtype = torch.float16 + elif args.save_precision == "bf16": + save_dtype = torch.bfloat16 + elif args.save_precision == "float": + save_dtype = torch.float32 + # 学習を準備する if cache_latents: - # latentをcacheする→新しいDatasetを作るとcaptionのshuffleが効かないので元のDatasetにcacheを持つ(cascadeする手もあるが) - print("caching latents.") vae.to(accelerator.device, dtype=weight_dtype) - - for i in tqdm(range(len(train_dataset))): - example = train_dataset[i] - if 'latents' not in example: - image_path = example['image_path'] - with torch.no_grad(): - pixel_values = example["image"].unsqueeze(0).to(device=accelerator.device, dtype=weight_dtype) - latents = vae.encode(pixel_values).latent_dist.sample().squeeze(0).to("cpu") - train_dataset.set_cached_latents(image_path, latents) - # assertion - for i in range(len(train_dataset)): - assert 'latents' in train_dataset[i], "internal error: latents not cached" - + with torch.no_grad(): + train_dataset.make_buckets_with_caching(args.enable_bucket, vae) del vae if torch.cuda.is_available(): torch.cuda.empty_cache() else: + train_dataset.make_buckets_with_caching(args.enable_bucket, None) vae.requires_grad_(False) if args.gradient_checkpointing: @@ -1173,7 +1318,7 @@ def load_dreambooth_dir(dir): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8 train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=n_workers) + train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers) # lr schedulerを用意する lr_scheduler = diffusers.optimization.get_scheduler("constant", optimizer, num_training_steps=args.max_train_steps) @@ -1185,6 +1330,11 @@ def load_dreambooth_dir(dir): if not cache_latents: vae.to(accelerator.device, dtype=weight_dtype) + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) + # epoch数を計算する num_train_epochs = math.ceil(args.max_train_steps / len(train_dataloader)) @@ -1193,7 +1343,7 @@ def load_dreambooth_dir(dir): print("running training / 学習開始") print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}") print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}") - print(f" num examples / サンプル数: {len(train_dataset)}") + print(f" num examples / サンプル数: {train_dataset.num_train_images * 2}") print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") print(f" num epochs / epoch数: {num_train_epochs}") print(f" batch size per device / バッチサイズ: {args.train_batch_size}") @@ -1222,7 +1372,7 @@ def load_dreambooth_dir(dir): if cache_latents: latents = batch["latents"].to(accelerator.device) else: - latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() latents = latents * 0.18215 # Sample noise that we'll add to the latents @@ -1271,15 +1421,22 @@ def load_dreambooth_dir(dir): global_step += 1 current_loss = loss.detach().item() + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]} + accelerator.log(logs, step=global_step) + loss_total += current_loss avr_loss = loss_total / (step+1) logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - # accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break + if args.logging_dir is not None: + logs = {"epoch_loss": loss_total / len(train_dataloader)} + accelerator.log(logs, step=epoch+1) + accelerator.wait_for_everyone() if use_stable_diffusion_format and args.save_every_n_epochs is not None: @@ -1288,7 +1445,11 @@ def load_dreambooth_dir(dir): os.makedirs(args.output_dir, exist_ok=True) ckpt_file = os.path.join(args.output_dir, EPOCH_CHECKPOINT_NAME.format(epoch + 1)) save_stable_diffusion_checkpoint(ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet), - args.pretrained_model_name_or_path, epoch + 1, global_step) + args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype) + + if args.save_state: + print("saving state.") + accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1))) is_main_process = accelerator.is_main_process if is_main_process: @@ -1296,6 +1457,11 @@ def load_dreambooth_dir(dir): text_encoder = accelerator.unwrap_model(text_encoder) accelerator.end_training() + + if args.save_state: + print("saving last state.") + accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME)) + del accelerator # この後メモリを使うのでこれは消す if is_main_process: @@ -1303,7 +1469,8 @@ def load_dreambooth_dir(dir): if use_stable_diffusion_format: ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME) print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") - save_stable_diffusion_checkpoint(ckpt_file, text_encoder, unet, args.pretrained_model_name_or_path, epoch, global_step) + save_stable_diffusion_checkpoint(ckpt_file, text_encoder, unet, + args.pretrained_model_name_or_path, epoch, global_step, save_dtype) else: # Create the pipeline using using the trained modules and save it. print(f"save trained model as Diffusers to {args.output_dir}") @@ -1589,6 +1756,10 @@ def forward_xformers(self, x, context=None, mask=None): help="directory to output trained model, save as same format as input / 学習後のモデル出力先ディレクトリ(入力と同じ形式で保存)") parser.add_argument("--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs (only supports in StableDiffusion checkpoint) / 学習中のモデルを指定エポックごとに保存します(StableDiffusion形式のモデルを読み込んだ場合のみ有効)") + parser.add_argument("--save_state", action="store_true", + help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") + parser.add_argument("--resume", type=str, default=None, + help="saved state to resume training / 学習再開するモデルのstate") parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み") parser.add_argument("--no_token_padding", action="store_true", help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)") @@ -1612,6 +1783,8 @@ def forward_xformers(self, x, context=None, mask=None): help="use xformers for CrossAttention / CrossAttentionにxformersを使う") parser.add_argument("--cache_latents", action="store_true", help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)") + parser.add_argument("--enable_bucket", action="store_true", + help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする") parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") @@ -1619,8 +1792,12 @@ def forward_xformers(self, x, context=None, mask=None): help="enable gradient checkpointing / grandient checkpointingを有効にする") parser.add_argument("--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する") parser.add_argument("--clip_skip", type=int, default=None, help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") + parser.add_argument("--logging_dir", type=str, default=None, + help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する") args = parser.parse_args() train(args)