Skip to content

Commit

Permalink
add caption_separator option
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Nov 19, 2023
1 parent f312522 commit d0923d6
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions finetune/tag_images_by_wd14_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ def main(args):

tag_freq = {}

undesired_tags = set(args.undesired_tags.split(","))
caption_separator = args.caption_separator
stripped_caption_separator = caption_separator.strip()
undesired_tags = set(args.undesired_tags.split(stripped_caption_separator))

def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
Expand Down Expand Up @@ -194,7 +196,7 @@ def run_batch(path_imgs):

if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
general_tag_text += ", " + tag_name
general_tag_text += caption_separator + tag_name
combined_tags.append(tag_name)
elif i >= len(general_tags) and p >= args.character_threshold:
tag_name = character_tags[i - len(general_tags)]
Expand All @@ -203,18 +205,18 @@ def run_batch(path_imgs):

if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += ", " + tag_name
character_tag_text += caption_separator + tag_name
combined_tags.append(tag_name)

# 先頭のカンマを取る
if len(general_tag_text) > 0:
general_tag_text = general_tag_text[2:]
general_tag_text = general_tag_text[len(caption_separator) :]
if len(character_tag_text) > 0:
character_tag_text = character_tag_text[2:]
character_tag_text = character_tag_text[len(caption_separator) :]

caption_file = os.path.splitext(image_path)[0] + args.caption_extension

tag_text = ", ".join(combined_tags)
tag_text = caption_separator.join(combined_tags)

if args.append_tags:
# Check if file exists
Expand All @@ -224,13 +226,13 @@ def run_batch(path_imgs):
existing_content = f.read().strip("\n") # Remove newlines

# Split the content into tags and store them in a list
existing_tags = [tag.strip() for tag in existing_content.split(",") if tag.strip()]
existing_tags = [tag.strip() for tag in existing_content.split(stripped_caption_separator) if tag.strip()]

# Check and remove repeating tags in tag_text
new_tags = [tag for tag in combined_tags if tag not in existing_tags]

# Create new tag_text
tag_text = ", ".join(existing_tags + new_tags)
tag_text = caption_separator.join(existing_tags + new_tags)

with open(caption_file, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
Expand Down Expand Up @@ -350,6 +352,12 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")
parser.add_argument(
"--caption_separator",
type=str,
default=", ",
help="Separator for captions, include space if needed / キャプションの区切り文字、必要ならスペースを含めてください",
)

return parser

Expand Down

0 comments on commit d0923d6

Please sign in to comment.