Skip to content

Commit

Permalink
Merge branch 'sdxl' of https://github.com/kohya-ss/sd-scripts into sd…
Browse files Browse the repository at this point in the history
…xl-dev
  • Loading branch information
bmaltais committed Jul 25, 2023
2 parents 12f7ca8 + b78c0e2 commit 101d263
Showing 1 changed file with 0 additions and 48 deletions.
48 changes: 0 additions & 48 deletions library/sdxl_train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,54 +286,6 @@ def diffusers_saver(out_dir):
)


# TextEncoderの出力をキャッシュする
# weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる
def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dataset, weight_dtype):
print("caching text encoder outputs")

tokenizer1, tokenizer2 = tokenizers
text_encoder1, text_encoder2 = text_encoders
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
if weight_dtype is not None:
text_encoder1.to(dtype=weight_dtype)
text_encoder2.to(dtype=weight_dtype)

text_encoder1_cache = {}
text_encoder2_cache = {}
for batch in tqdm(dataset):
input_ids1_batch = batch["input_ids"].to(accelerator.device)
input_ids2_batch = batch["input_ids2"].to(accelerator.device)

# split batch to avoid OOM
# TODO specify batch size by args
for input_id1, input_id2 in zip(input_ids1_batch.split(1), input_ids2_batch.split(1)):
# remove input_ids already in cache
input_id1_cache_key = tuple(input_id1.flatten().tolist())
input_id2_cache_key = tuple(input_id2.flatten().tolist())
if input_id1_cache_key in text_encoder1_cache:
assert input_id2_cache_key in text_encoder2_cache
continue

with torch.no_grad():
encoder_hidden_states1, encoder_hidden_states2, pool2 = get_hidden_states(
args,
input_id1,
input_id2,
tokenizer1,
tokenizer2,
text_encoder1,
text_encoder2,
None if not args.full_fp16 else weight_dtype,
)
encoder_hidden_states1 = encoder_hidden_states1.detach().to("cpu").squeeze(0) # n*75+2,768
encoder_hidden_states2 = encoder_hidden_states2.detach().to("cpu").squeeze(0) # n*75+2,1280
pool2 = pool2.detach().to("cpu").squeeze(0) # 1280
text_encoder1_cache[input_id1_cache_key] = encoder_hidden_states1
text_encoder2_cache[input_id2_cache_key] = (encoder_hidden_states2, pool2)
return text_encoder1_cache, text_encoder2_cache


def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
Expand Down

0 comments on commit 101d263

Please sign in to comment.