Adjusted English grammar in logs to be clearer (#554)
* Update train_network.py

* Update train_network.py

* Update train_network.py

* Update train_network.py

* Update train_network.py

* Update train_network.py
TingTingin authored Jun 1, 2023
1 parent 8a5e390 commit 5931948
Showing 1 changed file with 9 additions and 10 deletions.
train_network.py (19 changes: 9 additions & 10 deletions)
@@ -80,25 +80,25 @@ def train(args):
     # Prepare the dataset
     blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
     if use_user_config:
-        print(f"Load dataset config from {args.dataset_config}")
+        print(f"Loading dataset config from {args.dataset_config}")
         user_config = config_util.load_user_config(args.dataset_config)
         ignored = ["train_data_dir", "reg_data_dir", "in_json"]
         if any(getattr(args, attr) is not None for attr in ignored):
             print(
-                "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                     ", ".join(ignored)
                 )
             )
     else:
         if use_dreambooth_method:
-            print("Use DreamBooth method.")
+            print("Using DreamBooth method.")
             user_config = {
                 "datasets": [
                     {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
                 ]
             }
         else:
-            print("Train with captions.")
+            print("Training with captions.")
             user_config = {
                 "datasets": [
                     {
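The hunk above selects the dataset-configuration source: an explicit config file, a DreamBooth-style directory layout, or caption-based training. As a minimal, self-contained sketch of the "config file supersedes CLI options" check, with a hypothetical argparse setup standing in for the script's real argument handling:

import argparse

# Hypothetical parser standing in for the script's actual argument setup.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_config")
parser.add_argument("--train_data_dir")
parser.add_argument("--reg_data_dir")
parser.add_argument("--in_json")
args = parser.parse_args(["--dataset_config", "cfg.toml", "--train_data_dir", "imgs"])

if args.dataset_config is not None:
    # Same pattern as the diff: warn about CLI options the config file supersedes.
    ignored = ["train_data_dir", "reg_data_dir", "in_json"]
    overridden = [attr for attr in ignored if getattr(args, attr) is not None]
    if overridden:
        print("ignoring the following options because config file is found: " + ", ".join(overridden))

Run as-is, this warns about train_data_dir only, mirroring the log message reworded in this hunk.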
@@ -135,7 +135,7 @@ def train(args):
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

# acceleratorを準備する
print("prepare accelerator")
print("preparing accelerator")
accelerator, unwrap_model = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process

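train_util.prepare_accelerator itself is not part of this diff; as a hedged sketch, a helper like it typically wraps Hugging Face accelerate along these lines (the Accelerator kwargs here are illustrative assumptions, not the script's actual settings):

from accelerate import Accelerator

# Illustrative stand-in for train_util.prepare_accelerator; the real helper's
# options are not visible in this view.
accelerator = Accelerator(mixed_precision="fp16", gradient_accumulation_steps=1)
print("preparing accelerator")  # the reworded log message from this hunk
is_main_process = accelerator.is_main_process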
@@ -147,7 +147,7 @@ def train(args):
 
     # Incorporate xformers / memory-efficient attention into the model
     train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
 
     # Load the model for additional (difference) training
     import sys
 
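train_util.replace_unet_modules (context above) swaps the U-Net's attention for xformers or a memory-efficient variant. For reference, diffusers ships a comparable built-in toggle; this sketch is an alternative to, not the script's own, implementation, and the model ID is an illustrative assumption (requires the xformers package):

from diffusers import UNet2DConditionModel

# Load a U-Net and enable diffusers' built-in xformers attention.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
unet.enable_xformers_memory_efficient_attention()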
@@ -171,7 +171,6 @@ def train(args):
             module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
 
     print(f"all weights merged: {', '.join(args.base_weights)}")
-
     # Prepare for training
     if cache_latents:
         vae.to(accelerator.device, dtype=weight_dtype)
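The cache_latents branch above moves the VAE to the training device and dtype so image latents can be encoded once up front. A minimal sketch of that preparation step (the model ID and dtype choice are assumptions for illustration):

import torch
from diffusers import AutoencoderKL

# Move the VAE to the compute device/dtype before caching latents, as in the
# hunk above; encoding each image once is why pixel-space augmentations
# (color_aug, random_crop) are disallowed by the earlier assert.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.float16
vae.to(device, dtype=weight_dtype)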
@@ -210,15 +209,15 @@ def train(args):
 
     if args.network_weights is not None:
         info = network.load_weights(args.network_weights)
-        print(f"load network weights from {args.network_weights}: {info}")
+        print(f"loaded network weights from {args.network_weights}: {info}")
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
         network.enable_gradient_checkpointing()  # may have no effect
 
     # Prepare the classes needed for training
-    print("prepare optimizer, data loader etc.")
+    print("preparing optimizer, data loader etc.")
 
     # Ensure backward compatibility
     try:
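The gradient-checkpointing block above uses the standard diffusers and transformers toggles. A runnable sketch of those two calls in isolation (model IDs are illustrative assumptions; the script's network object is omitted since its class is not shown in this diff):

from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel

# Trade compute for memory: activations are recomputed during backward.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()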
@@ -263,7 +262,7 @@ def train(args):
     assert (
         args.mixed_precision == "fp16"
     ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
-    print("enable full fp16 training.")
+    print("enabling full fp16 training.")
     network.to(weight_dtype)
 
     # The accelerator apparently handles this for us
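The full_fp16 hunk above first asserts mixed_precision='fp16' and then casts the trainable network itself to half precision. A toy sketch of that cast (the Linear layer is a stand-in, not the script's LoRA network):

import torch
import torch.nn as nn

# With full fp16, the network's own weights live in half precision,
# not just the autocast compute dtype.
weight_dtype = torch.float16
network = nn.Linear(4, 4)
network.to(weight_dtype)
print("enabling full fp16 training.")  # the reworded log message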
