Commit
Revert "Allow loading from npz files only without images"
This reverts commit df6712b.
deepdelirious committed Mar 27, 2024
1 parent 925d561 commit 742ebd1
Showing 1 changed file with 16 additions and 43 deletions.
library/train_util.py — 16 additions & 43 deletions
@@ -1033,10 +1033,6 @@ def cache_text_encoder_outputs(
         )
 
     def get_image_size(self, image_path):
-        npz_path = os.path.splitext(image_path)[0] + ".npz"
-        if self.is_latent_cacheable and os.path.exists(npz_path):
-            _, size, _, _ = load_latents_from_disk(npz_path)
-            return size
         image = Image.open(image_path)
         return image.size
 
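The removed branch answered `get_image_size` from the latent cache instead of opening the image: `load_latents_from_disk` returns a 4-tuple whose second element is the original image size. A minimal sketch of a reader with that contract — the npz key names ("latents", "original_size", "crop_ltrb", optional "latents_flipped") are an assumption about the cache layout, not taken from this diff:

```python
import numpy as np

# Sketch only: mirrors the load_latents_from_disk(npz_path) call removed above.
# Key names are assumed and may differ across versions of train_util.py.
def load_latents_from_disk_sketch(npz_path):
    npz = np.load(npz_path)
    latents = npz["latents"]
    original_size = npz["original_size"].tolist()  # (width, height) of the source image
    crop_ltrb = npz["crop_ltrb"].tolist()
    flipped = npz["latents_flipped"] if "latents_flipped" in npz else None
    return latents, original_size, crop_ltrb, flipped
```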

@@ -1456,23 +1452,15 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
                 logger.warning(f"not directory: {subset.image_dir}")
                 return [], []
 
-            img_paths = glob_images(subset.image_dir, "*", self.is_latent_cacheable)
+            img_paths = glob_images(subset.image_dir, "*")
             logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
 
             # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
             captions = []
             missing_captions = []
-            cached_captions = []
             for img_path in img_paths:
                 cap_for_img = read_caption(img_path, subset.caption_extension)
                 if cap_for_img is None and subset.class_tokens is None:
-                    if self.is_text_encoder_output_cacheable:
-                        cache_file = os.path.splitext(img_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
-                        if os.path.exists(cache_file):
-                            captions.append("")
-                            cached_captions.append(cache_file)
-                            continue
-
                     logger.warning(
                         f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
                     )
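`read_caption` and `TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX` are defined elsewhere in train_util.py. As context for the hunk above, a minimal sketch of the sidecar-file contract the loop relies on — the real helper may differ in details such as tag handling or encoding:

```python
import os

# Sketch only: read_caption(img_path, extension) is expected to return None when
# no caption file exists, which is what triggers the class-token fallback
# (or, in the reverted code, the cached-TE-output fallback).
def read_caption_sketch(img_path: str, caption_extension: str):
    caption_path = os.path.splitext(img_path)[0] + caption_extension  # img_0001.png -> img_0001.txt
    if not os.path.exists(caption_path):
        return None
    with open(caption_path, "rt", encoding="utf-8") as f:
        return f.read().strip()
```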
@@ -1487,26 +1475,19 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
 
             self.set_tag_frequency(os.path.basename(subset.image_dir), captions)  # タグ頻度を記録
 
-            def show_caption_warning(captions_with_warnings, warning_message):
-                if not captions_with_warnings:
-                    return
-
-                number_of_warning_captions = len(captions_with_warnings)
-                number_of_warning_captions_to_show = 5
-                remaining_warning_captions = number_of_warning_captions - number_of_warning_captions_to_show
-
-                logger.warning(warning_message.format(number_of_warning_captions=number_of_warning_captions))
-                for i, warning_caption in enumerate(captions_with_warnings):
-                    if i >= number_of_warning_captions_to_show:
-                        logger.warning(warning_caption + f"... and {remaining_warning_captions} more")
-                        break
-                    logger.warning(warning_caption)
-
-            show_caption_warning(missing_captions,
-                "No caption file found for {number_of_warning_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_warning_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。")
-
-            show_caption_warning(cached_captions,
-                "No caption file found for {number_of_warning_captions} images. Cached TE embeddings are available for this caption, which will be used instead.")
+            if missing_captions:
+                number_of_missing_captions = len(missing_captions)
+                number_of_missing_captions_to_show = 5
+                remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show
+
+                logger.warning(
+                    f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。"
+                )
+                for i, missing_caption in enumerate(missing_captions):
+                    if i >= number_of_missing_captions_to_show:
+                        logger.warning(missing_caption + f"... and {remaining_missing_captions} more")
+                        break
+                    logger.warning(missing_caption)
             return img_paths, captions
 
         logger.info("prepare images.")
@@ -1884,7 +1865,7 @@ def __init__(
 
         extra_imgs = []
         for subset in subsets:
-            conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*", self.is_latent_cacheable)
+            conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
             extra_imgs.extend(
                 [cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img]
             )
@@ -2163,23 +2144,15 @@ def debug_dataset(train_dataset, show_input_ids=False):
 
         epoch += 1
 
-def glob_images(directory, base="*", fallback_to_cache=False):
+
+def glob_images(directory, base="*"):
     img_paths = []
     for ext in IMAGE_EXTENSIONS:
         if base == "*":
             img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
         else:
             img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
     img_paths = list(set(img_paths))  # 重複を排除
-
-    if fallback_to_cache and len(img_paths) == 0:
-        print(f"No images found in {directory}. Will look for cached latents instead.")
-        if base == "*":
-            img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ".npz")))
-        else:
-            img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ".npz"))))
-
-    img_paths = [img_path for img_path in set(img_paths) if not img_path.endswith(TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX)]
     img_paths.sort()
     return img_paths
 
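Call sites pass either the default base="*" (all files with a supported extension) or a literal stem; note that glob.escape makes a non-"*" base match literally, not as a pattern. A hypothetical usage, assuming the repo-local import path library.train_util (directories and stems are made up):

```python
from library.train_util import glob_images

# Hypothetical directories/stems, for illustration only.
all_imgs = glob_images("/data/dreambooth/1_mychar")             # every *.png, *.jpg, ... in the folder
one_img = glob_images("/data/dreambooth/1_mychar", "img_0001")  # img_0001.<ext> only, matched literally
```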

