Replace print with logger if they are logs (#905)
* Add get_my_logger()

* Use logger instead of print

* Fix log level

* Removed line-breaks for readability

* Use setup_logging()

* Add rich to requirements.txt

* Make simple

* Use logger instead of print

---------

Co-authored-by: Kohya S <[email protected]>
shirayu and kohya-ss authored Feb 4, 2024
1 parent 7f948db commit 5f6bf29
Showing 62 changed files with 1,195 additions and 961 deletions.
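The change follows one pattern across the touched files: each module installs a shared logging configuration once via library.utils.setup_logging(), creates a module-level logger, and maps each former print to a severity level. A minimal sketch of the preamble and the level mapping, using messages taken from the diffs below:

    from library.utils import setup_logging
    setup_logging()
    import logging
    logger = logging.getLogger(__name__)

    logger.info("prepare accelerator")    # routine status, was print(...)
    logger.warning("options ignored")     # surprising but non-fatal (paraphrased)
    logger.error("No data found.")        # failure before an early return (shortened)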
14 changes: 9 additions & 5 deletions fine_tune.py
@@ -18,6 +18,10 @@
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler
 
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 import library.train_util as train_util
 import library.config_util as config_util
 from library.config_util import (
@@ -49,11 +53,11 @@ def train(args):
     if args.dataset_class is None:
         blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True))
         if args.dataset_config is not None:
-            print(f"Load dataset config from {args.dataset_config}")
+            logger.info(f"Load dataset config from {args.dataset_config}")
             user_config = config_util.load_user_config(args.dataset_config)
             ignored = ["train_data_dir", "in_json"]
             if any(getattr(args, attr) is not None for attr in ignored):
-                print(
+                logger.warning(
                     "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                         ", ".join(ignored)
                     )
@@ -86,7 +90,7 @@ def train(args):
         train_util.debug_dataset(train_dataset_group)
         return
     if len(train_dataset_group) == 0:
-        print(
+        logger.error(
             "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
         )
         return
@@ -97,7 +101,7 @@ def train(args):
     ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
 
     # acceleratorを準備する
-    print("prepare accelerator")
+    logger.info("prepare accelerator")
     accelerator = train_util.prepare_accelerator(args)
 
     # mixed precisionに対応した型を用意しておき適宜castする
@@ -461,7 +465,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         train_util.save_sd_model_on_train_end(
             args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
         )
-        print("model saved.")
+        logger.info("model saved.")
 
 
 def setup_parser() -> argparse.ArgumentParser:
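setup_logging() itself is defined in library/utils.py, whose diff is among the files not shown on this page. Going only by the commit notes ("Use setup_logging()", "Add rich to requirements.txt"), a plausible reading is that it configures the root logger once, preferring a rich handler when the package is available. A hypothetical sketch under those assumptions, not the actual implementation:

    # Hypothetical reconstruction of library/utils.py's setup_logging();
    # the real diff is not visible on this page.
    import logging

    def setup_logging(log_level=logging.INFO):
        if logging.root.handlers:  # called from many modules; configure only once
            return
        try:
            from rich.logging import RichHandler  # rich was added to requirements.txt
            handler = RichHandler()
        except ImportError:
            handler = logging.StreamHandler()  # plain stderr fallback
        logging.root.setLevel(log_level)
        logging.root.addHandler(handler)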
6 changes: 5 additions & 1 deletion finetune/blip/blip.py
@@ -21,6 +21,10 @@
 import os
 from urllib.parse import urlparse
 from timm.models.hub import download_cached_file
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 
 class BLIP_Base(nn.Module):
     def __init__(self,
@@ -235,6 +239,6 @@ def load_checkpoint(model,url_or_filename):
                 del state_dict[key]
 
     msg = model.load_state_dict(state_dict,strict=False)
-    print('load checkpoint from %s'%url_or_filename)
+    logger.info('load checkpoint from %s'%url_or_filename)
     return model,msg
 
42 changes: 23 additions & 19 deletions finetune/clean_captions_and_tags.py
@@ -8,6 +8,10 @@
 import re
 
 from tqdm import tqdm
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 
 PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
 PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
@@ -36,13 +40,13 @@ def clean_tags(image_key, tags):
   tokens = tags.split(", rating")
   if len(tokens) == 1:
     # WD14 taggerのときはこちらになるのでメッセージは出さない
-    # print("no rating:")
-    # print(f"{image_key} {tags}")
+    # logger.info("no rating:")
+    # logger.info(f"{image_key} {tags}")
     pass
   else:
     if len(tokens) > 2:
-      print("multiple ratings:")
-      print(f"{image_key} {tags}")
+      logger.info("multiple ratings:")
+      logger.info(f"{image_key} {tags}")
     tags = tokens[0]
 
   tags = ", " + tags.replace(", ", ", , ") + ", "  # カンマ付きで検索をするための身も蓋もない対策
@@ -124,43 +128,43 @@ def clean_caption(caption):
 
 def main(args):
   if os.path.exists(args.in_json):
-    print(f"loading existing metadata: {args.in_json}")
+    logger.info(f"loading existing metadata: {args.in_json}")
     with open(args.in_json, "rt", encoding='utf-8') as f:
       metadata = json.load(f)
   else:
-    print("no metadata / メタデータファイルがありません")
+    logger.error("no metadata / メタデータファイルがありません")
     return
 
-  print("cleaning captions and tags.")
+  logger.info("cleaning captions and tags.")
   image_keys = list(metadata.keys())
   for image_key in tqdm(image_keys):
     tags = metadata[image_key].get('tags')
    if tags is None:
-      print(f"image does not have tags / メタデータにタグがありません: {image_key}")
+      logger.error(f"image does not have tags / メタデータにタグがありません: {image_key}")
    else:
      org = tags
      tags = clean_tags(image_key, tags)
      metadata[image_key]['tags'] = tags
      if args.debug and org != tags:
-        print("FROM: " + org)
-        print("TO: " + tags)
+        logger.info("FROM: " + org)
+        logger.info("TO: " + tags)
 
    caption = metadata[image_key].get('caption')
    if caption is None:
-      print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
+      logger.error(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
    else:
      org = caption
      caption = clean_caption(caption)
      metadata[image_key]['caption'] = caption
      if args.debug and org != caption:
-        print("FROM: " + org)
-        print("TO: " + caption)
+        logger.info("FROM: " + org)
+        logger.info("TO: " + caption)
 
   # metadataを書き出して終わり
-  print(f"writing metadata: {args.out_json}")
+  logger.info(f"writing metadata: {args.out_json}")
   with open(args.out_json, "wt", encoding='utf-8') as f:
     json.dump(metadata, f, indent=2)
-  print("done!")
+  logger.info("done!")
 
 
 def setup_parser() -> argparse.ArgumentParser:
@@ -178,10 +182,10 @@ def setup_parser() -> argparse.ArgumentParser:
 
   args, unknown = parser.parse_known_args()
   if len(unknown) == 1:
-    print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
-    print("All captions and tags in the metadata are processed.")
-    print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
-    print("メタデータ内のすべてのキャプションとタグが処理されます。")
+    logger.warning("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
+    logger.warning("All captions and tags in the metadata are processed.")
+    logger.warning("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
+    logger.warning("メタデータ内のすべてのキャプションとタグが処理されます。")
     args.in_json = args.out_json
     args.out_json = unknown[0]
   elif len(unknown) > 0:
22 changes: 13 additions & 9 deletions finetune/make_captions.py
@@ -15,6 +15,10 @@
 sys.path.append(os.path.dirname(__file__))
 from blip.blip import blip_decoder, is_url
 import library.train_util as train_util
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -47,7 +51,7 @@ def __getitem__(self, idx):
             # convert to tensor temporarily so dataloader will accept it
             tensor = IMAGE_TRANSFORM(image)
         except Exception as e:
-            print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
+            logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
             return None
 
         return (tensor, img_path)
@@ -74,21 +78,21 @@ def main(args):
         args.train_data_dir = os.path.abspath(args.train_data_dir)  # convert to absolute path
 
         cwd = os.getcwd()
-        print("Current Working Directory is: ", cwd)
+        logger.info(f"Current Working Directory is: {cwd}")
         os.chdir("finetune")
         if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights):
             args.caption_weights = os.path.join("..", args.caption_weights)
 
-    print(f"load images from {args.train_data_dir}")
+    logger.info(f"load images from {args.train_data_dir}")
     train_data_dir_path = Path(args.train_data_dir)
     image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
-    print(f"found {len(image_paths)} images.")
+    logger.info(f"found {len(image_paths)} images.")
 
-    print(f"loading BLIP caption: {args.caption_weights}")
+    logger.info(f"loading BLIP caption: {args.caption_weights}")
     model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json")
     model.eval()
     model = model.to(DEVICE)
-    print("BLIP loaded")
+    logger.info("BLIP loaded")
 
     # captioningする
     def run_batch(path_imgs):
@@ -108,7 +112,7 @@ def run_batch(path_imgs):
             with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
                 f.write(caption + "\n")
                 if args.debug:
-                    print(image_path, caption)
+                    logger.info(f'{image_path} {caption}')
 
     # 読み込みの高速化のためにDataLoaderを使うオプション
     if args.max_data_loader_n_workers is not None:
@@ -138,7 +142,7 @@ def run_batch(path_imgs):
                     raw_image = raw_image.convert("RGB")
                 img_tensor = IMAGE_TRANSFORM(raw_image)
             except Exception as e:
-                print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
+                logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
                 continue
 
             b_imgs.append((image_path, img_tensor))
@@ -148,7 +152,7 @@ def run_batch(path_imgs):
     if len(b_imgs) > 0:
         run_batch(b_imgs)
 
-    print("done!")
+    logger.info("done!")
 
 
 def setup_parser() -> argparse.ArgumentParser:
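A side note on the debug line above: print(image_path, caption) took two arguments, so the replacement joins them in an f-string. Stdlib logging also accepts deferred %-style arguments, which skips the formatting work entirely when the level is disabled; both forms are correct, and the commit simply opts for f-strings:

    # Two equivalent ways to log the pair; the second formats lazily.
    logger.info(f"{image_path} {caption}")     # eager: string is always built
    logger.info("%s %s", image_path, caption)  # lazy: built only if INFO is enabled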
23 changes: 13 additions & 10 deletions finetune/make_captions_by_git.py
@@ -10,7 +10,10 @@
 from transformers.generation.utils import GenerationMixin
 
 import library.train_util as train_util
-
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -35,8 +38,8 @@ def remove_words(captions, debug):
         for pat in PATTERN_REPLACE:
             cap = pat.sub("", cap)
         if debug and cap != caption:
-            print(caption)
-            print(cap)
+            logger.info(caption)
+            logger.info(cap)
         removed_caps.append(cap)
     return removed_caps
 
@@ -70,16 +73,16 @@ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs)
     GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
     """
 
-    print(f"load images from {args.train_data_dir}")
+    logger.info(f"load images from {args.train_data_dir}")
     train_data_dir_path = Path(args.train_data_dir)
     image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
-    print(f"found {len(image_paths)} images.")
+    logger.info(f"found {len(image_paths)} images.")
 
     # できればcacheに依存せず明示的にダウンロードしたい
-    print(f"loading GIT: {args.model_id}")
+    logger.info(f"loading GIT: {args.model_id}")
     git_processor = AutoProcessor.from_pretrained(args.model_id)
     git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
-    print("GIT loaded")
+    logger.info("GIT loaded")
 
     # captioningする
     def run_batch(path_imgs):
@@ -97,7 +100,7 @@ def run_batch(path_imgs):
             with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
                 f.write(caption + "\n")
                 if args.debug:
-                    print(image_path, caption)
+                    logger.info(f"{image_path} {caption}")
 
     # 読み込みの高速化のためにDataLoaderを使うオプション
     if args.max_data_loader_n_workers is not None:
@@ -126,7 +129,7 @@ def run_batch(path_imgs):
                 if image.mode != "RGB":
                     image = image.convert("RGB")
             except Exception as e:
-                print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
+                logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
                 continue
 
             b_imgs.append((image_path, image))
@@ -137,7 +140,7 @@ def run_batch(path_imgs):
     if len(b_imgs) > 0:
         run_batch(b_imgs)
 
-    print("done!")
+    logger.info("done!")
 
 
 def setup_parser() -> argparse.ArgumentParser:
20 changes: 12 additions & 8 deletions finetune/merge_captions_to_metadata.py
@@ -5,26 +5,30 @@
 from tqdm import tqdm
 import library.train_util as train_util
 import os
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 
 def main(args):
   assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
 
   train_data_dir_path = Path(args.train_data_dir)
   image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
-  print(f"found {len(image_paths)} images.")
+  logger.info(f"found {len(image_paths)} images.")
 
   if args.in_json is None and Path(args.out_json).is_file():
     args.in_json = args.out_json
 
   if args.in_json is not None:
-    print(f"loading existing metadata: {args.in_json}")
+    logger.info(f"loading existing metadata: {args.in_json}")
     metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
-    print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
+    logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
   else:
-    print("new metadata will be created / 新しいメタデータファイルが作成されます")
+    logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
     metadata = {}
 
-  print("merge caption texts to metadata json.")
+  logger.info("merge caption texts to metadata json.")
  for image_path in tqdm(image_paths):
    caption_path = image_path.with_suffix(args.caption_extension)
    caption = caption_path.read_text(encoding='utf-8').strip()
@@ -38,12 +42,12 @@ def main(args):
 
    metadata[image_key]['caption'] = caption
    if args.debug:
-      print(image_key, caption)
+      logger.info(f"{image_key} {caption}")
 
  # metadataを書き出して終わり
-  print(f"writing metadata: {args.out_json}")
+  logger.info(f"writing metadata: {args.out_json}")
  Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
-  print("done!")
+  logger.info("done!")
 
 
 def setup_parser() -> argparse.ArgumentParser:
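One note on the getLogger(__name__) idiom repeated in every file: each module gets its own named logger that propagates to the root handler installed by setup_logging(), so verbosity can later be adjusted per module without touching call sites. An illustrative example, not part of this commit:

    import logging

    # "blip.blip" is the module's __name__ when make_captions.py does
    # "from blip.blip import blip_decoder"; silence its INFO output only.
    logging.getLogger("blip.blip").setLevel(logging.WARNING)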
(Diffs for the remaining 56 changed files are not shown here.)
