Skip to content

Commit

Permalink
Merge pull request #504 from Linaqruf/main
Browse files Browse the repository at this point in the history
TOML support for sample prompt
  • Loading branch information
kohya-ss authored May 15, 2023
2 parents 99f4940 + 8ab5c8c commit b556fc4
Showing 1 changed file with 72 additions and 49 deletions.
121 changes: 72 additions & 49 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3291,9 +3291,22 @@ def sample_images(
vae.to(device)

# read prompts
with open(args.sample_prompts, "rt", encoding="utf-8") as f:
prompts = f.readlines()

# with open(args.sample_prompts, "rt", encoding="utf-8") as f:
# prompts = f.readlines()

if args.sample_prompts.endswith('.txt'):
with open(args.sample_prompts, 'r') as f:
lines = f.readlines()
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
elif args.sample_prompts.endswith('.toml'):
with open(args.sample_prompts, 'r') as f:
data = toml.load(f)
prompts = [dict(**data['prompt'], **subset) for subset in data['prompt']['subset']]
elif args.sample_prompts.endswith('.json'):
with open(args.sample_prompts, 'r') as f:
prompts = json.load(f)

# schedulerを用意する
sched_init_args = {}
if args.sample_sampler == "ddim":
Expand Down Expand Up @@ -3362,53 +3375,63 @@ def sample_images(
for i, prompt in enumerate(prompts):
if not accelerator.is_main_process:
continue
prompt = prompt.strip()
if len(prompt) == 0 or prompt[0] == "#":
continue

# subset of gen_img_diffusers
prompt_args = prompt.split(" --")
prompt = prompt_args[0]
negative_prompt = None
sample_steps = 30
width = height = 512
scale = 7.5
seed = None
for parg in prompt_args:
try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m:
width = int(m.group(1))
continue

m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m:
height = int(m.group(1))
continue

m = re.match(r"d (\d+)", parg, re.IGNORECASE)
if m:
seed = int(m.group(1))
continue

m = re.match(r"s (\d+)", parg, re.IGNORECASE)
if m: # steps
sample_steps = max(1, min(1000, int(m.group(1))))
continue

m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # scale
scale = float(m.group(1))
continue

m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
negative_prompt = m.group(1)
continue

except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)

if isinstance(prompt, dict):
negative_prompt = prompt.get("negative_prompt")
sample_steps = prompt.get("sample_steps", 30)
width = prompt.get("width", 512)
height = prompt.get("height", 512)
scale = prompt.get("scale", 7.5)
seed = prompt.get("seed")
prompt = prompt.get("prompt")
else:
# prompt = prompt.strip()
# if len(prompt) == 0 or prompt[0] == "#":
# continue

# subset of gen_img_diffusers
prompt_args = prompt.split(" --")
prompt = prompt_args[0]
negative_prompt = None
sample_steps = 30
width = height = 512
scale = 7.5
seed = None
for parg in prompt_args:
try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m:
width = int(m.group(1))
continue

m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m:
height = int(m.group(1))
continue

m = re.match(r"d (\d+)", parg, re.IGNORECASE)
if m:
seed = int(m.group(1))
continue

m = re.match(r"s (\d+)", parg, re.IGNORECASE)
if m: # steps
sample_steps = max(1, min(1000, int(m.group(1))))
continue

m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # scale
scale = float(m.group(1))
continue

m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
negative_prompt = m.group(1)
continue

except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)

if seed is not None:
torch.manual_seed(seed)
Expand Down

0 comments on commit b556fc4

Please sign in to comment.