Skip to content

Commit

Permalink
Join paths, update default scheduler
Browse files Browse the repository at this point in the history
Fix "subfolder not recognized" issue.
  • Loading branch information
d8ahazard committed Nov 10, 2022
1 parent 4d9b93d commit 75e40bb
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 7 deletions.
6 changes: 2 additions & 4 deletions dreambooth/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,15 +807,14 @@ def save_weights(step, save_model, save_img):
if args.train_text_encoder:
text_enc_model = accelerator.unwrap_model(text_encoder)
else:
text_enc_model = CLIPTextModel.from_pretrained(args.working_dir, subfolder="text_encoder")
text_enc_model = CLIPTextModel.from_pretrained(os.path.join(args.working_dir, "text_encoder"))
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
pipeline = StableDiffusionPipeline.from_pretrained(
args.working_dir,
unet=accelerator.unwrap_model(unet),
text_encoder=text_enc_model,
vae=AutoencoderKL.from_pretrained(
args.working_dir,
subfolder="vae"
os.path.join(args.working_dir, "vae")
),
safety_checker=None,
scheduler=scheduler,
Expand Down Expand Up @@ -862,7 +861,6 @@ def save_weights(step, save_model, save_img):
shared.state.job_no = global_step
shared.state.textinfo = f"Training step: {global_step}/{args.max_train_steps}"
loss_avg = AverageMeter()
training_complete = False
text_enc_context = nullcontext() if args.train_text_encoder else torch.no_grad()
for epoch in range(args.num_train_epochs):
try:
Expand Down
4 changes: 1 addition & 3 deletions scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ def on_ui_tabs():
db_new_model_name = gr.Textbox(label="Name")
src_checkpoint = gr.Dropdown(label='Source Checkpoint', choices=sorted(
sd_models.checkpoints_list.keys()))
# I just randomly chose ddim here because we use it everywhere else. Not sure which of these
# are ideal, or if it matters at all.
diff_type = gr.Dropdown(label='Scheduler', choices=["pndm", "ddim", "lms"], value="pndm")
diff_type = gr.Dropdown(label='Scheduler', choices=["ddim", "pndm", "lms"], value="ddim")

with gr.Row():
with gr.Column(scale=3):
Expand Down

0 comments on commit 75e40bb

Please sign in to comment.