From e6dd5fb4251f753a9cbbd35722ba3d2037687a97 Mon Sep 17 00:00:00 2001
From: kblueleaf
Date: Wed, 13 Mar 2024 18:14:56 +0800
Subject: [PATCH 01/12] support meta cached dataset

---
 library/train_util.py |  7 +++++--
 train_network.py      | 23 +++++++++++++++++++++--
 2 files changed, 26 insertions(+), 4 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index b71e4edc6..db9e41e48 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -63,6 +63,7 @@ from huggingface_hub import hf_hub_download
 import numpy as np
 from PIL import Image
+import imagesize
 import cv2
 import safetensors.torch
 from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
@@ -1073,8 +1074,7 @@ def cache_text_encoder_outputs(
         )
 
     def get_image_size(self, image_path):
-        image = Image.open(image_path)
-        return image.size
+        return imagesize.get(image_path)
 
     def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
         img = load_image(image_path)
@@ -3303,6 +3303,9 @@ def add_dataset_arguments(
     parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
 ):
     # dataset common
+    parser.add_argument(
+        "--dataset_from_pkl", action="store_true"
+    )
     parser.add_argument(
         "--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
     )
diff --git a/train_network.py b/train_network.py
index e5b26d8a2..c204b4656 100644
--- a/train_network.py
+++ b/train_network.py
@@ -6,6 +6,7 @@
 import random
 import time
 import json
+import pickle
 from multiprocessing import Value
 
 import toml
@@ -23,7 +24,7 @@
 import library.train_util as train_util
 from library.train_util import (
-    DreamBoothDataset,
+    DreamBoothDataset, DatasetGroup
 )
 import library.config_util as config_util
 from library.config_util import (
@@ -156,7 +157,25 @@ def train(self, args):
         tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
 
         # データセットを準備する
-        if args.dataset_class is None:
+        if args.dataset_from_pkl:
+            logger.info(f"Loading dataset from cached meta")
+            with open(f"{args.train_data_dir}/dataset-meta.pkl", "rb") as f:
+                train_dataset_group = pickle.load(f)
+            assert isinstance(train_dataset_group, DatasetGroup)
+            logger.info(f"Dataset Loaded")
+            logger.info(f"Dataset have {train_dataset_group.num_train_images} images")
+            logger.info(f"Dataset have {train_dataset_group.num_reg_images} reg images")
+
+            # To simulate the correct behavior of random operations
+            # To avoid any potential to cause "seed breaking changes"
+            dataset_seed = random.randint(0, 2**31)
+            for dataset in train_dataset_group.datasets:
+                dataset.tokenizers = tokenizers
+                dataset.tokenizer_max_length = dataset.tokenizers[0].model_max_length if args.max_token_length is None else args.max_token_length + 2
+                dataset.set_seed(0)
+                dataset.shuffle_buckets()
+                dataset.set_seed(dataset_seed)
+        elif args.dataset_class is None:
             blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
             if use_user_config:
                 logger.info(f"Loading dataset config from {args.dataset_config}")

From f94152b10aad7319737b5cb1a19a75c76cbe2e85 Mon Sep 17 00:00:00 2001
From: kblueleaf
Date: Wed, 13 Mar 2024 18:15:17 +0800
Subject: [PATCH 02/12] add cache meta scripts

---
 cache_dataset_meta.py | 105 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 105 insertions(+)
 create mode 100644 cache_dataset_meta.py

diff --git a/cache_dataset_meta.py b/cache_dataset_meta.py
new file mode 100644
index 000000000..6101d9394
--- /dev/null
+++ b/cache_dataset_meta.py
@@ -0,0 +1,105 @@
+import argparse
+import random
+import pickle
+
+from accelerate.utils import set_seed
+
+import library.train_util as train_util
+import library.config_util as config_util
+from library.config_util import (
+    ConfigSanitizer,
+    BlueprintGenerator,
+)
+import library.custom_train_functions as custom_train_functions
+from library.utils import setup_logging, add_logging_arguments
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def make_dataset(args):
+    train_util.prepare_dataset_args(args, True)
+    setup_logging(args, reset=True)
+
+    use_dreambooth_method = args.in_json is None
+    use_user_config = args.dataset_config is not None
+
+    if args.seed is None:
+        args.seed = random.randint(0, 2**32)
+    set_seed(args.seed)
+
+    # データセットを準備する
+    if args.dataset_class is None:
+        blueprint_generator = BlueprintGenerator(
+            ConfigSanitizer(True, True, False, True)
+        )
+        if use_user_config:
+            logger.info(f"Loading dataset config from {args.dataset_config}")
+            user_config = config_util.load_user_config(args.dataset_config)
+            ignored = ["train_data_dir", "reg_data_dir", "in_json"]
+            if any(getattr(args, attr) is not None for attr in ignored):
+                logger.warning(
+                    "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                        ", ".join(ignored)
+                    )
+                )
+        else:
+            if use_dreambooth_method:
+                logger.info("Using DreamBooth method.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
+                                args.train_data_dir, args.reg_data_dir
+                            )
+                        }
+                    ]
+                }
+            else:
+                logger.info("Training with captions.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": [
+                                {
+                                    "image_dir": args.train_data_dir,
+                                    "metadata_file": args.in_json,
+                                }
+                            ]
+                        }
+                    ]
+                }
+
+        blueprint = blueprint_generator.generate(user_config, args, tokenizer=None)
+        train_dataset_group = config_util.generate_dataset_group_by_blueprint(
+            blueprint.dataset_group
+        )
+    else:
+        # use arbitrary dataset class
+        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer=None)
+    return train_dataset_group
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser()
+    add_logging_arguments(parser)
+    train_util.add_dataset_arguments(parser, True, True, True)
+    train_util.add_training_arguments(parser, True)
+    config_util.add_config_arguments(parser)
+    custom_train_functions.add_custom_train_arguments(parser)
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args, unknown = parser.parse_known_args()
+    args = train_util.read_config_from_file(args, parser)
+    if args.max_token_length is None:
+        args.max_token_length = 75
+
+    dataset_group = make_dataset(args)
+    with open(f"{args.train_data_dir}/dataset-meta.pkl", "wb") as f:
+        pickle.dump(dataset_group, f)

From c16ce03ff73c868fce828bd11951bd01aff19767 Mon Sep 17 00:00:00 2001
From: kblueleaf
Date: Tue, 12 Mar 2024 19:11:45 +0800
Subject: [PATCH 03/12] random ip_noise_gamma strength

---
 library/train_util.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/library/train_util.py b/library/train_util.py
index db9e41e48..8f7f49f37 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3100,6 +3100,13 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         help="enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) "
         + "/ input perturbation noiseを有効にする。正則化に使用される。推奨値: 0.1程度 (arxiv.org/abs/2301.11706 より)",
     )
+    parser.add_argument(
+        "--ip_noise_gamma_random_strength",
+        type=bool,
+        default=False,
+        help="Use random strength between 0~ip_noise_gamma for input perturbation noise."
+        + "/ input perturbation noiseにおいて、0からip_noise_gammaの間でランダムな強度を使用します。",
+    )
     # parser.add_argument(
     #     "--perlin_noise",
     #     type=int,
@@ -4676,7 +4683,11 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
     # Add noise to the latents according to the noise magnitude at each timestep
     # (this is the forward diffusion process)
     if args.ip_noise_gamma:
-        noisy_latents = noise_scheduler.add_noise(latents, noise + args.ip_noise_gamma * torch.randn_like(latents), timesteps)
+        if args.ip_noise_gamma_random_strength:
+            strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma
+        else:
+            strength = args.ip_noise_gamma
+        noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps)
     else:
         noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

From 39d1d15b4c5cb51c533b8f14adf1196dde33e126 Mon Sep 17 00:00:00 2001
From: kblueleaf
Date: Tue, 12 Mar 2024 19:14:01 +0800
Subject: [PATCH 04/12] random noise_offset strength

---
 library/train_util.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/library/train_util.py b/library/train_util.py
index 8f7f49f37..cb74e7721 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3087,6 +3087,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         default=None,
         help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)",
     )
+    parser.add_argument(
+        "--noise_offset_random_strength",
+        type=bool,
+        default=False,
+        help="use random strength between 0~noise_offset for noise offset. / noise offsetにおいて、0からnoise_offsetの間でランダムな強度を使用します。",
+    )
     parser.add_argument(
         "--multires_noise_iterations",
         type=int,
@@ -4666,7 +4672,11 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
     # Sample noise that we'll add to the latents
     noise = torch.randn_like(latents, device=latents.device)
     if args.noise_offset:
-        noise = custom_train_functions.apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
+        if args.noise_offset_random_strength:
+            noise_offset = torch.rand(1, device=latents.device) * args.noise_offset
+        else:
+            noise_offset = args.noise_offset
+        noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale)
     if args.multires_noise_iterations:
         noise = custom_train_functions.pyramid_noise_like(
             noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount

From 5a730e863a78542efcc21184ef650e255282285a Mon Sep 17 00:00:00 2001
From: kblueleaf
Date: Tue, 12 Mar 2024 19:24:27 +0800
Subject: [PATCH 05/12] use correct settings for parser

---
 library/train_util.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index cb74e7721..6c09c0a18 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3089,8 +3089,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     )
     parser.add_argument(
         "--noise_offset_random_strength",
-        type=bool,
-        default=False,
+        action="store_true",
         help="use random strength between 0~noise_offset for noise offset. / noise offsetにおいて、0からnoise_offsetの間でランダムな強度を使用します。",
     )
     parser.add_argument(
@@ -3108,8 +3107,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     )
     parser.add_argument(
         "--ip_noise_gamma_random_strength",
-        type=bool,
-        default=False,
+        action="store_true",
         help="Use random strength between 0~ip_noise_gamma for input perturbation noise."
         + "/ input perturbation noiseにおいて、0からip_noise_gammaの間でランダムな強度を使用します。",
     )

From 0b76044acb706a593e171091d556fc91249a9cb7 Mon Sep 17 00:00:00 2001
From: kblueleaf
Date: Sat, 16 Mar 2024 20:35:08 +0800
Subject: [PATCH 06/12] cache path/caption/size only

---
 cache_dataset_meta.py  |  4 +--
 library/config_util.py |  4 +++
 library/train_util.py  | 71 +++++++++++++++++++++++++++++++-----------
 3 files changed, 57 insertions(+), 22 deletions(-)

diff --git a/cache_dataset_meta.py b/cache_dataset_meta.py
index 6101d9394..7e7d96d12 100644
--- a/cache_dataset_meta.py
+++ b/cache_dataset_meta.py
@@ -1,6 +1,5 @@
 import argparse
 import random
-import pickle
 
 from accelerate.utils import set_seed
 
 import library.train_util as train_util
 import library.config_util as config_util
 from library.config_util import (
     ConfigSanitizer,
     BlueprintGenerator,
 )
 import library.custom_train_functions as custom_train_functions
 from library.utils import setup_logging, add_logging_arguments
@@ -99,7 +98,6 @@ def setup_parser() -> argparse.ArgumentParser:
     args = train_util.read_config_from_file(args, parser)
     if args.max_token_length is None:
         args.max_token_length = 75
+    args.cache_meta = True
 
     dataset_group = make_dataset(args)
-    with open(f"{args.train_data_dir}/dataset-meta.pkl", "wb") as f:
-        pickle.dump(dataset_group, f)
diff --git a/library/config_util.py b/library/config_util.py
index eb652ecf3..58ffa5f4d 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -111,6 +111,8 @@ class DreamBoothDatasetParams(BaseDatasetParams):
     bucket_reso_steps: int = 64
     bucket_no_upscale: bool = False
     prior_loss_weight: float = 1.0
+    cache_meta: bool = False
+    use_cached_meta: bool = False
 
 
 @dataclass
@@ -228,6 +230,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
         "min_bucket_reso": int,
         "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
         "network_multiplier": float,
+        "cache_meta": bool,
+        "use_cached_meta": bool,
     }
 
     # options handled by argparse but not handled by user config
diff --git a/library/train_util.py b/library/train_util.py
index 6c09c0a18..82da15cf7 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1418,6 +1418,8 @@ def __init__(
         bucket_no_upscale: bool,
         prior_loss_weight: float,
         debug_dataset: bool,
+        cache_meta: bool,
+        use_cached_meta: bool,
     ) -> None:
         super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
 
@@ -1474,26 +1476,43 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                 logger.warning(f"not directory: {subset.image_dir}")
                 return [], []
 
-            img_paths = glob_images(subset.image_dir, "*")
+            sizes = None
+            if use_cached_meta:
+                logger.info(f"using cached metadata: {subset.image_dir}/dataset.txt")
+                # [img_path, caption, resolution]
+                with open(f"{subset.image_dir}/dataset.txt", "r", encoding="utf-8") as f:
+                    metas = f.readlines()
+                metas = [x.strip().split("<|##|>") for x in metas]
+                sizes = [tuple(int(res) for res in x[2].split(" ")) for x in metas]
+
+            if use_cached_meta:
+                img_paths = [x[0] for x in metas]
+            else:
+                img_paths = glob_images(subset.image_dir, "*")
+                sizes = [None]*len(img_paths)
             logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
 
-            # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
-            captions = []
-            missing_captions = []
-            for img_path in img_paths:
-                cap_for_img = read_caption(img_path, subset.caption_extension)
-                if cap_for_img is None and subset.class_tokens is None:
-                    logger.warning(
-                        f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
-                    )
-                    captions.append("")
-                    missing_captions.append(img_path)
-                else:
-                    if cap_for_img is None:
-                        captions.append(subset.class_tokens)
-                        missing_captions.append(img_path)
-                    else:
-                        captions.append(cap_for_img)
+            if use_cached_meta:
+                captions = [x[1] for x in metas]
+                missing_captions = [x[0] for x in metas if x[1] == ""]
+            else:
+                # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
+                captions = []
+                missing_captions = []
+                for img_path in img_paths:
+                    cap_for_img = read_caption(img_path, subset.caption_extension)
+                    if cap_for_img is None and subset.class_tokens is None:
+                        logger.warning(
+                            f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
+                        )
+                        captions.append("")
+                        missing_captions.append(img_path)
+                    else:
+                        if cap_for_img is None:
+                            captions.append(subset.class_tokens)
+                            missing_captions.append(img_path)
+                        else:
+                            captions.append(cap_for_img)
 
             self.set_tag_frequency(os.path.basename(subset.image_dir), captions)  # タグ頻度を記録
 
@@ -1510,7 +1529,19 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                         logger.warning(missing_caption + f"... and {remaining_missing_captions} more")
                         break
                     logger.warning(missing_caption)
-            return img_paths, captions
+
+            if cache_meta:
+                logger.info(f"cache metadata for {subset.image_dir}")
+                # [img_path, caption, resolution]
+                data = [
+                    (img_path, caption, " ".join(str(x) for x in self.get_image_size(img_path)))
+                    for img_path, caption in zip(img_paths, captions)
+                ]
+                with open(f"{subset.image_dir}/dataset.txt", "w", encoding="utf-8") as f:
+                    f.write("\n".join(["<|##|>".join(x) for x in data]))
+                logger.info(f"cache metadata done for {subset.image_dir}")
+
+            return img_paths, captions, sizes
 
         logger.info("prepare images.")
         num_train_images = 0
@@ -1529,7 +1560,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                 )
                 continue
 
-            img_paths, captions = load_dreambooth_dir(subset)
+            img_paths, captions, sizes = load_dreambooth_dir(subset)
             if len(img_paths) < 1:
                 logger.warning(
                     f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
@@ -1541,8 +1572,10 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
             else:
                 num_train_images += subset.num_repeats * len(img_paths)
 
-            for img_path, caption in zip(img_paths, captions):
+            for img_path, caption, size in zip(img_paths, captions, sizes):
                 info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
+                if size is not None:
+                    info.image_size = size
                 if subset.is_reg:
                     reg_infos.append(info)
                 else:

From 4f0666713b3d02e0c3899e9df4a902bc631bc509 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Sat, 16 Mar 2024 20:36:59 +0800
Subject: [PATCH 07/12] revert mess up commit

---
 library/train_util.py | 11 +----------
 1 file changed, 1 insertion(+), 10 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 82da15cf7..243000eb9 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3120,11 +3120,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         default=None,
         help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)",
     )
-    parser.add_argument(
-        "--noise_offset_random_strength",
-        action="store_true",
-        help="use random strength between 0~noise_offset for noise offset. / noise offsetにおいて、0からnoise_offsetの間でランダムな強度を使用します。",
-    )
     parser.add_argument(
         "--multires_noise_iterations",
         type=int,
@@ -4703,11 +4698,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
     # Sample noise that we'll add to the latents
     noise = torch.randn_like(latents, device=latents.device)
     if args.noise_offset:
-        if args.noise_offset_random_strength:
-            noise_offset = torch.rand(1, device=latents.device) * args.noise_offset
-        else:
-            noise_offset = args.noise_offset
-        noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale)
+        noise = custom_train_functions.apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
     if args.multires_noise_iterations:
         noise = custom_train_functions.pyramid_noise_like(
             noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount

From 9c1e377c8e6004115bc46e2c57ddebebfdf9599d Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Sat, 16 Mar 2024 20:37:48 +0800
Subject: [PATCH 08/12] revert mess up commit

---
 library/train_util.py | 12 +-----------
 1 file changed, 1 insertion(+), 11 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 243000eb9..914df207a 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3133,12 +3133,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         help="enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) "
         + "/ input perturbation noiseを有効にする。正則化に使用される。推奨値: 0.1程度 (arxiv.org/abs/2301.11706 より)",
     )
-    parser.add_argument(
-        "--ip_noise_gamma_random_strength",
-        action="store_true",
-        help="Use random strength between 0~ip_noise_gamma for input perturbation noise."
-        + "/ input perturbation noiseにおいて、0からip_noise_gammaの間でランダムな強度を使用します。",
-    )
     # parser.add_argument(
     #     "--perlin_noise",
     #     type=int,
@@ -4715,11 +4709,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
     # Add noise to the latents according to the noise magnitude at each timestep
     # (this is the forward diffusion process)
     if args.ip_noise_gamma:
-        if args.ip_noise_gamma_random_strength:
-            strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma
-        else:
-            strength = args.ip_noise_gamma
-        noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps)
+        noisy_latents = noise_scheduler.add_noise(latents, noise + args.ip_noise_gamma * torch.randn_like(latents), timesteps)
     else:
         noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

From a681a5a34f66a911defcad6e4d5160f41747b550 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Sat, 16 Mar 2024 20:39:33 +0800
Subject: [PATCH 09/12] Update requirements.txt

---
 requirements.txt | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/requirements.txt b/requirements.txt
index 279de350c..5862f3caa 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,6 +15,8 @@ easygui==0.98.3
 toml==0.10.2
 voluptuous==0.13.1
 huggingface-hub==0.20.1
+# for Image utils
+imagesize==1.4.1
 # for BLIP captioning
 # requests==2.28.2
 # timm==0.6.12

From 8caca5947d06f6b82b0d849084dad9a66e4427b0 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Sat, 16 Mar 2024 20:42:52 +0800
Subject: [PATCH 10/12] Add arguments for meta cache.

---
 library/train_util.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/library/train_util.py b/library/train_util.py
index 914df207a..63aa6f63a 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3339,6 +3339,12 @@ def add_dataset_arguments(
     parser.add_argument(
         "--dataset_from_pkl", action="store_true"
     )
+    parser.add_argument(
+        "--cache_meta", action="store_true"
+    )
+    parser.add_argument(
+        "--use_cached_meta", action="store_true"
+    )
     parser.add_argument(
         "--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
     )

From ae2774f80546ce00e0d2973f1a0f20fded52ff60 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Sat, 16 Mar 2024 20:44:19 +0800
Subject: [PATCH 11/12] remove pickle implementation

---
 library/train_util.py |  3 ---
 train_network.py      | 20 +-------------------
 2 files changed, 1 insertion(+), 22 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 63aa6f63a..cf9fb93b1 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3336,9 +3336,6 @@ def add_dataset_arguments(
     parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
 ):
     # dataset common
-    parser.add_argument(
-        "--dataset_from_pkl", action="store_true"
-    )
     parser.add_argument(
         "--cache_meta", action="store_true"
     )
     parser.add_argument(
         "--use_cached_meta", action="store_true"
     )
diff --git a/train_network.py b/train_network.py
index c204b4656..05917b23a 100644
--- a/train_network.py
+++ b/train_network.py
@@ -157,25 +157,7 @@ def train(self, args):
         tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
 
         # データセットを準備する
-        if args.dataset_from_pkl:
-            logger.info(f"Loading dataset from cached meta")
-            with open(f"{args.train_data_dir}/dataset-meta.pkl", "rb") as f:
-                train_dataset_group = pickle.load(f)
-            assert isinstance(train_dataset_group, DatasetGroup)
-            logger.info(f"Dataset Loaded")
-            logger.info(f"Dataset have {train_dataset_group.num_train_images} images")
-            logger.info(f"Dataset have {train_dataset_group.num_reg_images} reg images")
-
-            # To simulate the correct behavior of random operations
-            # To avoid any potential to cause "seed breaking changes"
-            dataset_seed = random.randint(0, 2**31)
-            for dataset in train_dataset_group.datasets:
-                dataset.tokenizers = tokenizers
-                dataset.tokenizer_max_length = dataset.tokenizers[0].model_max_length if args.max_token_length is None else args.max_token_length + 2
-                dataset.set_seed(0)
-                dataset.shuffle_buckets()
-                dataset.set_seed(dataset_seed)
-        elif args.dataset_class is None:
+        if args.dataset_class is None:
             blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
             if use_user_config:
                 logger.info(f"Loading dataset config from {args.dataset_config}")

From efed44665634967013288048ee883b063825733f Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Sat, 16 Mar 2024 21:42:00 +0800
Subject: [PATCH 12/12] Return sizes when enable cache

---
 library/train_util.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index cf9fb93b1..8687a0c25 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1529,18 +1529,20 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                         logger.warning(missing_caption + f"... and {remaining_missing_captions} more")
                         break
                     logger.warning(missing_caption)
-
+
             if cache_meta:
                 logger.info(f"cache metadata for {subset.image_dir}")
+                if sizes is None or sizes[0] is None:
+                    sizes = [self.get_image_size(img_path) for img_path in img_paths]
                 # [img_path, caption, resolution]
                 data = [
-                    (img_path, caption, " ".join(str(x) for x in self.get_image_size(img_path)))
-                    for img_path, caption in zip(img_paths, captions)
+                    (img_path, caption, " ".join(str(x) for x in size))
+                    for img_path, caption, size in zip(img_paths, captions, sizes)
                 ]
                 with open(f"{subset.image_dir}/dataset.txt", "w", encoding="utf-8") as f:
                     f.write("\n".join(["<|##|>".join(x) for x in data]))
                 logger.info(f"cache metadata done for {subset.image_dir}")
-
+
             return img_paths, captions, sizes
 
         logger.info("prepare images.")