diff --git a/cache_dataset_meta.py b/cache_dataset_meta.py
new file mode 100644
index 000000000..7e7d96d12
--- /dev/null
+++ b/cache_dataset_meta.py
@@ -0,0 +1,103 @@
+import argparse
+import random
+
+from accelerate.utils import set_seed
+
+import library.train_util as train_util
+import library.config_util as config_util
+from library.config_util import (
+    ConfigSanitizer,
+    BlueprintGenerator,
+)
+import library.custom_train_functions as custom_train_functions
+from library.utils import setup_logging, add_logging_arguments
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def make_dataset(args):
+    train_util.prepare_dataset_args(args, True)
+    setup_logging(args, reset=True)
+
+    use_dreambooth_method = args.in_json is None
+    use_user_config = args.dataset_config is not None
+
+    if args.seed is None:
+        args.seed = random.randint(0, 2**32)
+    set_seed(args.seed)
+
+    # データセットを準備する
+    if args.dataset_class is None:
+        blueprint_generator = BlueprintGenerator(
+            ConfigSanitizer(True, True, False, True)
+        )
+        if use_user_config:
+            logger.info(f"Loading dataset config from {args.dataset_config}")
+            user_config = config_util.load_user_config(args.dataset_config)
+            ignored = ["train_data_dir", "reg_data_dir", "in_json"]
+            if any(getattr(args, attr) is not None for attr in ignored):
+                logger.warning(
+                    "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                        ", ".join(ignored)
+                    )
+                )
+        else:
+            if use_dreambooth_method:
+                logger.info("Using DreamBooth method.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
+                                args.train_data_dir, args.reg_data_dir
+                            )
+                        }
+                    ]
+                }
+            else:
+                logger.info("Training with captions.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": [
+                                {
+                                    "image_dir": args.train_data_dir,
+                                    "metadata_file": args.in_json,
+                                }
+                            ]
+                        }
+                    ]
+                }
+
+        blueprint = blueprint_generator.generate(user_config, args, tokenizer=None)
+        train_dataset_group = config_util.generate_dataset_group_by_blueprint(
+            blueprint.dataset_group
+        )
+    else:
+        # use arbitrary dataset class
+        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer=None)
+    return train_dataset_group
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser()
+    add_logging_arguments(parser)
+    train_util.add_dataset_arguments(parser, True, True, True)
+    train_util.add_training_arguments(parser, True)
+    config_util.add_config_arguments(parser)
+    custom_train_functions.add_custom_train_arguments(parser)
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args, unknown = parser.parse_known_args()
+    args = train_util.read_config_from_file(args, parser)
+    if args.max_token_length is None:
+        args.max_token_length = 75
+    args.cache_meta = True
+
+    dataset_group = make_dataset(args)
diff --git a/library/config_util.py b/library/config_util.py
index d543c3312..b631b4949 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -110,6 +110,8 @@ class DreamBoothDatasetParams(BaseDatasetParams):
     bucket_reso_steps: int = 64
     bucket_no_upscale: bool = False
     prior_loss_weight: float = 1.0
+    cache_meta: bool = False
+    use_cached_meta: bool = False
 
 
 @dataclass
@@ -225,6 +227,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
         "min_bucket_reso": int,
         "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
         "network_multiplier": float,
+        "cache_meta": bool,
"use_cached_meta": bool, } # options handled by argparse but not handled by user config diff --git a/library/train_util.py b/library/train_util.py index 346b2f076..f8eb63463 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -63,6 +63,7 @@ from huggingface_hub import hf_hub_download import numpy as np from PIL import Image +import imagesize import cv2 import safetensors.torch import traceback @@ -1033,8 +1034,7 @@ def cache_text_encoder_outputs( ) def get_image_size(self, image_path): - image = Image.open(image_path) - return image.size + return imagesize.get(image_path) def load_image_with_face_info(self, subset: BaseSubset, image_path: str): img = load_image(image_path) @@ -1396,6 +1396,8 @@ def __init__( bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset: bool, + cache_meta: bool, + use_cached_meta: bool, ) -> None: super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset, trust_cache) @@ -1452,26 +1454,43 @@ def load_dreambooth_dir(subset: DreamBoothSubset): logger.warning(f"not directory: {subset.image_dir}") return [], [] - img_paths = glob_images(subset.image_dir, "*") + sizes = None + if use_cached_meta: + logger.info(f"using cached metadata: {subset.image_dir}/dataset.txt") + # [img_path, caption, resolution] + with open(f"{subset.image_dir}/dataset.txt", "r", encoding="utf-8") as f: + metas = f.readlines() + metas = [x.strip().split("<|##|>") for x in metas] + sizes = [tuple(int(res) for res in x[2].split(" ")) for x in metas] + + if use_cached_meta: + img_paths = [x[0] for x in metas] + else: + img_paths = glob_images(subset.image_dir, "*") + sizes = [None]*len(img_paths) logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") - # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う - captions = [] - missing_captions = [] - for img_path in img_paths: - cap_for_img = read_caption(img_path, subset.caption_extension) - if cap_for_img is None and subset.class_tokens is None: - logger.warning( - f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}" - ) - captions.append("") - missing_captions.append(img_path) - else: - if cap_for_img is None: - captions.append(subset.class_tokens) + if use_cached_meta: + captions = [x[1] for x in metas] + missing_captions = [x[0] for x in metas if x[1] == ""] + else: + # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う + captions = [] + missing_captions = [] + for img_path in img_paths: + cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard) + if cap_for_img is None and subset.class_tokens is None: + logger.warning( + f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}" + ) + captions.append("") missing_captions.append(img_path) else: - captions.append(cap_for_img) + if cap_for_img is None: + captions.append(subset.class_tokens) + missing_captions.append(img_path) + else: + captions.append(cap_for_img) self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 @@ -1488,7 +1507,21 @@ def load_dreambooth_dir(subset: DreamBoothSubset): logger.warning(missing_caption + f"... 
                         logger.warning(missing_caption + f"... and {remaining_missing_captions} more")
                         break
                     logger.warning(missing_caption)
-            return img_paths, captions
+
+            if cache_meta:
+                logger.info(f"cache metadata for {subset.image_dir}")
+                if sizes is None or sizes[0] is None:
+                    sizes = [self.get_image_size(img_path) for img_path in img_paths]
+                # [img_path, caption, resolution]
+                data = [
+                    (img_path, caption, " ".join(str(x) for x in size))
+                    for img_path, caption, size in zip(img_paths, captions, sizes)
+                ]
+                with open(f"{subset.image_dir}/dataset.txt", "w", encoding="utf-8") as f:
+                    f.write("\n".join(["<|##|>".join(x) for x in data]))
+                logger.info(f"cache metadata done for {subset.image_dir}")
+
+            return img_paths, captions, sizes
 
         logger.info("prepare images.")
         num_train_images = 0
@@ -1507,7 +1540,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                 )
                 continue
 
-            img_paths, captions = load_dreambooth_dir(subset)
+            img_paths, captions, sizes = load_dreambooth_dir(subset)
            if len(img_paths) < 1:
                 logger.warning(
                     f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
                 )
@@ -1519,8 +1552,10 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
             else:
                 num_train_images += subset.num_repeats * len(img_paths)
 
-            for img_path, caption in zip(img_paths, captions):
+            for img_path, caption, size in zip(img_paths, captions, sizes):
                 info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
+                if size is not None:
+                    info.image_size = size
                 if subset.is_reg:
                     reg_infos.append(info)
                 else:
@@ -3294,6 +3329,12 @@ def add_dataset_arguments(
     parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
 ):
     # dataset common
+    parser.add_argument(
+        "--cache_meta", action="store_true"
+    )
+    parser.add_argument(
+        "--use_cached_meta", action="store_true"
+    )
     parser.add_argument(
         "--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
     )
diff --git a/requirements.txt b/requirements.txt
index 279de350c..5862f3caa 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,6 +15,8 @@ easygui==0.98.3
 toml==0.10.2
 voluptuous==0.13.1
 huggingface-hub==0.20.1
+# for Image utils
+imagesize==1.4.1
 # for BLIP captioning
 # requests==2.28.2
 # timm==0.6.12
diff --git a/train_network.py b/train_network.py
index e0fa69458..34d2c6b3b 100644
--- a/train_network.py
+++ b/train_network.py
@@ -6,6 +6,7 @@ import random
 import time
 import json
+import pickle
 from multiprocessing import Value
 
 import toml
@@ -23,7 +24,7 @@ import library.train_util as train_util
 from library.train_util import (
-    DreamBoothDataset,
+    DreamBoothDataset, DatasetGroup
 )
 import library.config_util as config_util
 from library.config_util import (
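
Note (not part of the patch): judging from the read/write code in load_dreambooth_dir above, each subset's cached dataset.txt appears to hold one "<|##|>"-separated record per image: image path, caption, and "width height". A minimal sketch of that round trip; the paths, captions, and sizes below are made-up examples, not values from this PR.

# Sketch of the dataset.txt cache format implied by the diff above.
# SEP and the record layout come from load_dreambooth_dir; the data is illustrative only.
SEP = "<|##|>"

records = [
    ("/data/train/1_cat/001.png", "a photo of a cat", (512, 768)),
    ("/data/train/1_cat/002.png", "a cat sitting on a sofa", (640, 640)),
]

# write: one "<img_path><|##|><caption><|##|><width height>" line per image
with open("dataset.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(SEP.join((path, caption, f"{w} {h}")) for path, caption, (w, h) in records))

# read back the same way the patched load_dreambooth_dir does
with open("dataset.txt", "r", encoding="utf-8") as f:
    metas = [line.strip().split(SEP) for line in f]
img_paths = [m[0] for m in metas]
captions = [m[1] for m in metas]
sizes = [tuple(int(v) for v in m[2].split(" ")) for m in metas]
assert sizes[0] == (512, 768)

If that reading is right, the intended workflow would presumably be to run cache_dataset_meta.py once per dataset (e.g. python cache_dataset_meta.py --train_data_dir <dir> ..., since the script forces cache_meta=True), then train as usual with --use_cached_meta so load_dreambooth_dir can skip the glob and per-image size lookups.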