Skip to content

Commit

Permalink
Add v8 of train_db_fixed.py
Browse files Browse the repository at this point in the history
Add diffusers_fine_tuning
  • Loading branch information
bmaltais committed Nov 10, 2022
1 parent 23a5b7f commit 36b06d4
Show file tree
Hide file tree
Showing 19 changed files with 3,723 additions and 4,690 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
venv
mytraining.ps
__pycache__
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,5 @@ Refer to this url for more details about finetuning: https://note.com/kohya_ss/n
## Change history
* 11/7 (v7): Text Encoder supports checkpoint files in different storage formats (it is converted at the time of import, so export will be in normal format). Changed the average value of EPOCH loss to output to the screen. Added a function to save epoch and global step in checkpoint in SD format (add values if there is existing data). The reg_data_dir option is enabled during fine tuning (fine tuning while mixing regularized images). Added dataset_repeats option that is valid for fine tuning (specified when the number of teacher images is small and the epoch is extremely short).
* 11/7 (v7): Text Encoder supports checkpoint files in different storage formats (it is converted at the time of import, so export will be in normal format). Changed the average value of EPOCH loss to output to the screen. Added a function to save epoch and global step in checkpoint in SD format (add values if there is existing data). The reg_data_dir option is enabled during fine tuning (fine tuning while mixing regularized images). Added dataset_repeats option that is valid for fine tuning (specified when the number of teacher images is small and the epoch is extremely short).
* 11/9 (v8): supports Diffusers 0.7.2. To upgrade diffusers run `pip install --upgrade diffusers[torch]`
3 changes: 3 additions & 0 deletions diffusers_fine_tuning/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Diffusers Fine Tuning

This subfolder provide all the required toold to run the diffusers fine tuning version found in this note: https://note.com/kohya_ss/n/nbf7ce8d80f29
122 changes: 122 additions & 0 deletions diffusers_fine_tuning/clean_captions_and_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# このスクリプトのライセンスは、Apache License 2.0とします
# (c) 2022 Kohya S. @kohya_ss

import argparse
import glob
import os
import json

from tqdm import tqdm


def clean_tags(image_key, tags):
# replace '_' to ' '
tags = tags.replace('_', ' ')

# remove rating
tokens = tags.split(", rating")
if len(tokens) == 1:
print("no rating:")
print(f"{image_key} {tags}")
else:
if len(tokens) > 2:
print("multiple ratings:")
print(f"{image_key} {tags}")
tags = tokens[0]

return tags


# 上から順に検索、置換される
# ('置換元文字列', '置換後文字列')
CAPTION_REPLACEMENTS = [
('anime anime', 'anime'),
('young ', ''),
('anime girl', 'girl'),
('cartoon female', 'girl'),
('cartoon lady', 'girl'),
('cartoon character', 'girl'), # a or ~s
('cartoon woman', 'girl'),
('cartoon women', 'girls'),
('cartoon girl', 'girl'),
('anime female', 'girl'),
('anime lady', 'girl'),
('anime character', 'girl'), # a or ~s
('anime woman', 'girl'),
('anime women', 'girls'),
('lady', 'girl'),
('female', 'girl'),
('woman', 'girl'),
('women', 'girls'),
('people', 'girls'),
('person', 'girl'),
('a cartoon figure', 'a figure'),
('a cartoon image', 'an image'),
('a cartoon picture', 'a picture'),
('an anime cartoon image', 'an image'),
('a cartoon anime drawing', 'a drawing'),
('a cartoon drawing', 'a drawing'),
('girl girl', 'girl'),
]


def clean_caption(caption):
for rf, rt in CAPTION_REPLACEMENTS:
replaced = True
while replaced:
bef = caption
caption = caption.replace(rf, rt)
replaced = bef != caption
return caption

def main(args):
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png"))
print(f"found {len(image_paths)} images.")

if os.path.exists(args.in_json):
print(f"loading existing metadata: {args.in_json}")
with open(args.in_json, "rt", encoding='utf-8') as f:
metadata = json.load(f)
else:
print("no metadata / メタデータファイルがありません")
return

print("cleaning captions and tags.")
for image_path in tqdm(image_paths):
tags_path = os.path.splitext(image_path)[0] + '.txt'
with open(tags_path, "rt", encoding='utf-8') as f:
tags = f.readlines()[0].strip()

image_key = os.path.splitext(os.path.basename(image_path))[0]
if image_key not in metadata:
print(f"image not in metadata / メタデータに画像がありません: {image_path}")
return

tags = metadata[image_key].get('tags')
caption = metadata[image_key].get('caption')
if tags is None:
print(f"image does not have tags / メタデータにタグがありません: {image_path}")
return
if caption is None:
print(f"image does not have caption / メタデータにキャプションがありません: {image_path}")
return

metadata[image_key]['tags'] = clean_tags(image_key, tags)
metadata[image_key]['caption'] = clean_caption(caption)

# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
with open(args.out_json, "wt", encoding='utf-8') as f:
json.dump(metadata, f, indent=2)
print("done!")


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
# parser.add_argument("--debug", action="store_true", help="debug mode")

args = parser.parse_args()
main(args)
Loading

0 comments on commit 36b06d4

Please sign in to comment.