Skip to content

Commit

Permalink
Update to latest kohya_ss sd-script code
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Feb 3, 2023
1 parent c8f4c9d commit 20e62af
Show file tree
Hide file tree
Showing 16 changed files with 790 additions and 384 deletions.
2 changes: 1 addition & 1 deletion README-ja.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ accelerate configの質問には以下のように答えてください。(bf1
cd sd-scripts
git pull
.\venv\Scripts\activate
pip install --upgrade -r <requirement file name>
pip install --upgrade -r requirements.txt
```

コマンドが成功すれば新しいバージョンが使用できます。
Expand Down
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,23 @@ Then redo the installation instruction within the kohya_ss venv.

## Change history

* 2023/02/03
- Increase max LoRA rank (dim) size to 1024.
- Update finetune preprocessing scripts.
- ``.bmp`` and ``.jpeg`` are supported. Thanks to breakcore2 and p1atdev!
- The default weights of ``tag_images_by_wd14_tagger.py`` is now ``SmilingWolf/wd-v1-4-convnext-tagger-v2``. You can specify another model id from ``SmilingWolf`` by ``--repo_id`` option. Thanks to SmilingWolf for the great work.
- To change the weight, remove ``wd14_tagger_model`` folder, and run the script again.
- ``--max_data_loader_n_workers`` option is added to each script. This option uses the DataLoader for data loading to speed up loading, 20%~30% faster.
- Please specify 2 or 4, depends on the number of CPU cores.
- ``--recursive`` option is added to ``merge_dd_tags_to_metadata.py`` and ``merge_captions_to_metadata.py``, only works with ``--full_path``.
- ``make_captions_by_git.py`` is added. It uses [GIT microsoft/git-large-textcaps](https://huggingface.co/microsoft/git-large-textcaps) for captioning.
- ``requirements.txt`` is updated. If you use this script, [please update the libraries](https://github.com/kohya-ss/sd-scripts#upgrade).
- Usage is almost the same as ``make_captions.py``, but batch size should be smaller.
- ``--remove_words`` option removes as much text as possible (such as ``the word "XXXX" on it``).
- ``--skip_existing`` option is added to ``prepare_buckets_latents.py``. Images with existing npz files are ignored by this option.
- ``clean_captions_and_tags.py`` is updated to remove duplicated or conflicting tags, e.g. ``shirt`` is removed when ``white shirt`` exists. if ``black hair`` is with ``red hair``, both are removed.
- Tag frequency is added to the metadata in ``train_network.py``. Thanks to space-nuko!
- __All tags and number of occurrences of the tag are recorded.__ If you do not want it, disable metadata storing with ``--no_metadata`` option.
* 2023/01/30 (v20.5.2):
- Add ``--lr_scheduler_num_cycles`` and ``--lr_scheduler_power`` options for ``train_network.py`` for cosine_with_restarts and polynomial learning rate schedulers. Thanks to mgz-dev!
- Fixed U-Net ``sample_size`` parameter to ``64`` when converting from SD to Diffusers format, in ``convert_diffusers20_original_sd.py``
Expand Down
65 changes: 63 additions & 2 deletions finetune/clean_captions_and_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,32 @@
import glob
import os
import json
import re

from tqdm import tqdm

PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ')
PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ')

# 複数人がいるとき、複数の髪色や目の色が定義されていれば削除する
PATTERNS_REMOVE_IN_MULTI = [
PATTERN_HAIR_LENGTH,
PATTERN_HAIR_CUT,
re.compile(r', [\w\-]+ eyes, '),
re.compile(r', ([\w\-]+ sleeves|sleeveless), '),
# 複数の髪型定義がある場合は削除する
re.compile(
r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '),
]


def clean_tags(image_key, tags):
# replace '_' to ' '
tags = tags.replace('^_^', '^@@@^')
tags = tags.replace('_', ' ')
tags = tags.replace('^@@@^', '^_^')

# remove rating: deepdanbooruのみ
tokens = tags.split(", rating")
Expand All @@ -26,6 +45,37 @@ def clean_tags(image_key, tags):
print(f"{image_key} {tags}")
tags = tokens[0]

tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策

# 複数の人物がいる場合は髪色等のタグを削除する
if 'girls' in tags or 'boys' in tags:
for pat in PATTERNS_REMOVE_IN_MULTI:
found = pat.findall(tags)
if len(found) > 1: # 二つ以上、タグがある
tags = pat.sub("", tags)

# 髪の特殊対応
srch_hair_len = PATTERN_HAIR_LENGTH.search(tags) # 髪の長さタグは例外なので避けておく(全員が同じ髪の長さの場合)
if srch_hair_len:
org = srch_hair_len.group()
tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags)

found = PATTERN_HAIR.findall(tags)
if len(found) > 1:
tags = PATTERN_HAIR.sub("", tags)

if srch_hair_len:
tags = tags.replace(", @@@, ", org) # 戻す

# white shirtとshirtみたいな重複タグの削除
found = PATTERN_WORD.findall(tags)
for word in found:
if re.search(f", ((\w+) )+{word}, ", tags):
tags = tags.replace(f", {word}, ", "")

tags = tags.replace(", , ", ", ")
assert tags.startswith(", ") and tags.endswith(", ")
tags = tags[2:-2]
return tags


Expand Down Expand Up @@ -88,13 +138,23 @@ def main(args):
if tags is None:
print(f"image does not have tags / メタデータにタグがありません: {image_key}")
else:
metadata[image_key]['tags'] = clean_tags(image_key, tags)
org = tags
tags = clean_tags(image_key, tags)
metadata[image_key]['tags'] = tags
if args.debug and org != tags:
print("FROM: " + org)
print("TO: " + tags)

caption = metadata[image_key].get('caption')
if caption is None:
print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
else:
metadata[image_key]['caption'] = clean_caption(caption)
org = caption
caption = clean_caption(caption)
metadata[image_key]['caption'] = caption
if args.debug and org != caption:
print("FROM: " + org)
print("TO: " + caption)

# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
Expand All @@ -108,6 +168,7 @@ def main(args):
# parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--debug", action="store_true", help="debug mode")

args, unknown = parser.parse_known_args()
if len(unknown) == 1:
Expand Down
101 changes: 76 additions & 25 deletions finetune/make_captions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,59 @@
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from blip.blip import blip_decoder
# from Salesforce_BLIP.models.blip import blip_decoder
import library.train_util as train_util

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


IMAGE_SIZE = 384

# 正方形でいいのか? という気がするがソースがそうなので
IMAGE_TRANSFORM = transforms.Compose([
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])

# 共通化したいが微妙に処理が異なる……
class ImageLoadingTransformDataset(torch.utils.data.Dataset):
def __init__(self, image_paths):
self.images = image_paths

def __len__(self):
return len(self.images)

def __getitem__(self, idx):
img_path = self.images[idx]

try:
image = Image.open(img_path).convert("RGB")
# convert to tensor temporarily so dataloader will accept it
tensor = IMAGE_TRANSFORM(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None

return (tensor, img_path)


def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch


def main(args):
# fix the seed for reproducibility
seed = args.seed # + utils.get_rank()
seed = args.seed # + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

if not os.path.exists("blip"):
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path

Expand All @@ -31,24 +72,15 @@ def main(args):
os.chdir('finetune')

print(f"load images from {args.train_data_dir}")
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
image_paths = train_util.glob_images(args.train_data_dir)
print(f"found {len(image_paths)} images.")

print(f"loading BLIP caption: {args.caption_weights}")
image_size = 384
model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large', med_config="./blip/med_config.json")
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json")
model.eval()
model = model.to(DEVICE)
print("BLIP loaded")

# 正方形でいいのか? という気がするがソースがそうなので
transform = transforms.Compose([
transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])

# captioningする
def run_batch(path_imgs):
imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
Expand All @@ -66,18 +98,35 @@ def run_batch(path_imgs):
if args.debug:
print(image_path, caption)

# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = ImageLoadingTransformDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else:
data = [[(None, ip)] for ip in image_paths]

b_imgs = []
for image_path in tqdm(image_paths, smoothing=0.0):
raw_image = Image.open(image_path)
if raw_image.mode != "RGB":
print(f"convert image mode {raw_image.mode} to RGB: {image_path}")
raw_image = raw_image.convert("RGB")

image = transform(raw_image)
b_imgs.append((image_path, image))
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue

img_tensor, image_path = data
if img_tensor is None:
try:
raw_image = Image.open(image_path)
if raw_image.mode != 'RGB':
raw_image = raw_image.convert("RGB")
img_tensor = IMAGE_TRANSFORM(raw_image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue

b_imgs.append((image_path, img_tensor))
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) > 0:
run_batch(b_imgs)

Expand All @@ -95,6 +144,8 @@ def run_batch(path_imgs):
parser.add_argument("--beam_search", action="store_true",
help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
Expand Down
Loading

0 comments on commit 20e62af

Please sign in to comment.