Skip to content

Commit

Permalink
Merge pull request #906 from shirayu/accept_scheduler_designation_in_…
Browse files Browse the repository at this point in the history
…training

Accept sampler designation in sampling of training
  • Loading branch information
kohya-ss authored Dec 3, 2023
2 parents ee46134 + 40d917b commit df59822
Showing 1 changed file with 134 additions and 107 deletions.
241 changes: 134 additions & 107 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4447,11 +4447,118 @@ def append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names):
# Number of diffusion training timesteps used when constructing sample-image schedulers.
SCHEDULER_TIMESTEPS = 1000
# Beta schedule name handed to every diffusers scheduler built below.
# NOTE(review): identifier is misspelled ("SCHEDLER") but is referenced elsewhere
# in this file — keep the name for compatibility; rename in a dedicated change.
SCHEDLER_SCHEDULE = "scaled_linear"

def get_my_scheduler(
    *,
    sample_sampler: str,
    v_parameterization: bool,
):
    """Build a diffusers noise scheduler for sample-image generation.

    Args:
        sample_sampler: sampler name (including ``k_``-prefixed aliases);
            unknown names fall back to DDIM.
        v_parameterization: if True, configure the scheduler for
            v-prediction models.

    Returns:
        A configured scheduler instance with ``clip_sample`` forced on
        when the scheduler supports it.
    """
    # Sampler-name → scheduler-class dispatch table (aliases map to the same class).
    _SAMPLER_CLASSES = {
        "ddim": DDIMScheduler,
        "ddpm": DDPMScheduler,  # ddpm produces broken samples, so it is excluded from the CLI options
        "pndm": PNDMScheduler,
        "lms": LMSDiscreteScheduler,
        "k_lms": LMSDiscreteScheduler,
        "euler": EulerDiscreteScheduler,
        "k_euler": EulerDiscreteScheduler,
        "euler_a": EulerAncestralDiscreteScheduler,
        "k_euler_a": EulerAncestralDiscreteScheduler,
        "dpmsolver": DPMSolverMultistepScheduler,
        "dpmsolver++": DPMSolverMultistepScheduler,
        "dpmsingle": DPMSolverSinglestepScheduler,
        "heun": HeunDiscreteScheduler,
        "dpm_2": KDPM2DiscreteScheduler,
        "k_dpm_2": KDPM2DiscreteScheduler,
        "dpm_2_a": KDPM2AncestralDiscreteScheduler,
        "k_dpm_2_a": KDPM2AncestralDiscreteScheduler,
    }
    # Unrecognized sampler names silently fall back to DDIM (matches historical behavior).
    scheduler_cls = _SAMPLER_CLASSES.get(sample_sampler, DDIMScheduler)

    extra_init_kwargs = {}
    if scheduler_cls is DPMSolverMultistepScheduler:
        # Only "dpmsolver"/"dpmsolver++" map to this class; the name doubles as the algorithm selector.
        extra_init_kwargs["algorithm_type"] = sample_sampler
    if v_parameterization:
        extra_init_kwargs["prediction_type"] = "v_prediction"

    scheduler = scheduler_cls(
        num_train_timesteps=SCHEDULER_TIMESTEPS,
        beta_start=SCHEDULER_LINEAR_START,
        beta_end=SCHEDULER_LINEAR_END,
        beta_schedule=SCHEDLER_SCHEDULE,
        **extra_init_kwargs,
    )

    # Force clip_sample=True when the scheduler exposes the option but defaults it off.
    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
        scheduler.config.clip_sample = True

    return scheduler


def sample_images(*args, **kwargs):
    """Generate sample images during training by delegating to
    ``sample_images_common`` with the long-prompt-weighting SD pipeline class.
    All arguments are forwarded unchanged.
    """
    return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs)


def line_to_prompt_dict(line: str) -> dict:
    """Parse a one-line sample prompt with ``--x value`` options into a dict.

    Supports a subset of the gen_img_diffusers option syntax. The text before
    the first ``" --"`` separator becomes ``prompt``; recognized options are
    w (width), h (height), d (seed), s (steps, clamped to 1..1000), l (scale),
    n (negative prompt), ss (sampler name), cn (controlnet image path).

    Args:
        line: raw prompt line, e.g. ``"a cat --w 512 --s 20"``.

    Returns:
        Dict with ``prompt`` plus any recognized option keys.
    """
    chunks = line.split(" --")
    prompt_dict = {"prompt": chunks[0]}

    # Ordered (pattern, key, converter) specs; first matching spec wins per chunk,
    # mirroring the original match-and-continue cascade.
    option_specs = (
        (re.compile(r"w (\d+)", re.IGNORECASE), "width", int),
        (re.compile(r"h (\d+)", re.IGNORECASE), "height", int),
        (re.compile(r"d (\d+)", re.IGNORECASE), "seed", int),
        (re.compile(r"s (\d+)", re.IGNORECASE), "sample_steps", lambda v: max(1, min(1000, int(v)))),
        (re.compile(r"l ([\d\.]+)", re.IGNORECASE), "scale", float),
        (re.compile(r"n (.+)", re.IGNORECASE), "negative_prompt", str),
        (re.compile(r"ss (.+)", re.IGNORECASE), "sample_sampler", str),
        (re.compile(r"cn (.+)", re.IGNORECASE), "controlnet_image", str),
    )

    for chunk in chunks:
        try:
            for pattern, key, convert in option_specs:
                matched = pattern.match(chunk)
                if matched:
                    prompt_dict[key] = convert(matched.group(1))
                    break
        except ValueError as ex:
            # A malformed value (e.g. "--l 1.2.3") is reported and skipped.
            print(f"Exception in parsing / 解析エラー: {chunk}")
            print(ex)

    return prompt_dict

def sample_images_common(
pipe_class,
accelerator,
Expand Down Expand Up @@ -4504,56 +4611,19 @@ def sample_images_common(
with open(args.sample_prompts, "r", encoding="utf-8") as f:
prompts = json.load(f)

# schedulerを用意する
sched_init_args = {}
if args.sample_sampler == "ddim":
scheduler_cls = DDIMScheduler
elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
scheduler_cls = DDPMScheduler
elif args.sample_sampler == "pndm":
scheduler_cls = PNDMScheduler
elif args.sample_sampler == "lms" or args.sample_sampler == "k_lms":
scheduler_cls = LMSDiscreteScheduler
elif args.sample_sampler == "euler" or args.sample_sampler == "k_euler":
scheduler_cls = EulerDiscreteScheduler
elif args.sample_sampler == "euler_a" or args.sample_sampler == "k_euler_a":
scheduler_cls = EulerAncestralDiscreteScheduler
elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
scheduler_cls = DPMSolverMultistepScheduler
sched_init_args["algorithm_type"] = args.sample_sampler
elif args.sample_sampler == "dpmsingle":
scheduler_cls = DPMSolverSinglestepScheduler
elif args.sample_sampler == "heun":
scheduler_cls = HeunDiscreteScheduler
elif args.sample_sampler == "dpm_2" or args.sample_sampler == "k_dpm_2":
scheduler_cls = KDPM2DiscreteScheduler
elif args.sample_sampler == "dpm_2_a" or args.sample_sampler == "k_dpm_2_a":
scheduler_cls = KDPM2AncestralDiscreteScheduler
else:
scheduler_cls = DDIMScheduler

if args.v_parameterization:
sched_init_args["prediction_type"] = "v_prediction"

scheduler = scheduler_cls(
num_train_timesteps=SCHEDULER_TIMESTEPS,
beta_start=SCHEDULER_LINEAR_START,
beta_end=SCHEDULER_LINEAR_END,
beta_schedule=SCHEDLER_SCHEDULE,
**sched_init_args,
schedulers: dict = {}
default_scheduler = get_my_scheduler(
sample_sampler=args.sample_sampler,
v_parameterization=args.v_parameterization,
)

# clip_sample=Trueにする
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
# print("set clip_sample to True")
scheduler.config.clip_sample = True
schedulers[args.sample_sampler] = default_scheduler

pipeline = pipe_class(
text_encoder=text_encoder,
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=scheduler,
scheduler=default_scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
Expand All @@ -4569,78 +4639,34 @@ def sample_images_common(

with torch.no_grad():
# with accelerator.autocast():
for i, prompt in enumerate(prompts):
for i, prompt_dict in enumerate(prompts):
if not accelerator.is_main_process:
continue

if isinstance(prompt, dict):
negative_prompt = prompt.get("negative_prompt")
sample_steps = prompt.get("sample_steps", 30)
width = prompt.get("width", 512)
height = prompt.get("height", 512)
scale = prompt.get("scale", 7.5)
seed = prompt.get("seed")
controlnet_image = prompt.get("controlnet_image")
prompt = prompt.get("prompt")
else:
# prompt = prompt.strip()
# if len(prompt) == 0 or prompt[0] == "#":
# continue

# subset of gen_img_diffusers
prompt_args = prompt.split(" --")
prompt = prompt_args[0]
negative_prompt = None
sample_steps = 30
width = height = 512
scale = 7.5
seed = None
controlnet_image = None
for parg in prompt_args:
try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m:
width = int(m.group(1))
continue

m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m:
height = int(m.group(1))
continue

m = re.match(r"d (\d+)", parg, re.IGNORECASE)
if m:
seed = int(m.group(1))
continue

m = re.match(r"s (\d+)", parg, re.IGNORECASE)
if m: # steps
sample_steps = max(1, min(1000, int(m.group(1))))
continue

m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # scale
scale = float(m.group(1))
continue

m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
negative_prompt = m.group(1)
continue

m = re.match(r"cn (.+)", parg, re.IGNORECASE)
if m: # negative prompt
controlnet_image = m.group(1)
continue

except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
if isinstance(prompt_dict, str):
prompt_dict = line_to_prompt_dict(prompt_dict)

assert isinstance(prompt_dict, dict)
negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 30)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 7.5)
seed = prompt_dict.get("seed")
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
sampler_name:str = prompt_dict.get("sample_sampler", args.sample_sampler)

if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

scheduler = schedulers.get(sampler_name)
if scheduler is None:
scheduler = get_my_scheduler(sample_sampler=sampler_name, v_parameterization=args.v_parameterization,)
schedulers[sampler_name] = scheduler
pipeline.scheduler = scheduler

if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
Expand All @@ -4658,6 +4684,7 @@ def sample_images_common(
print(f"width: {width}")
print(f"sample_steps: {sample_steps}")
print(f"scale: {scale}")
print(f"sample_sampler: {sampler_name}")
with accelerator.autocast():
latents = pipeline(
prompt=prompt,
Expand Down

0 comments on commit df59822

Please sign in to comment.