Skip to content

Commit

Permalink
fix for colaboratory env
Browse files Browse the repository at this point in the history
  • Loading branch information
aria1th committed Jan 21, 2023
1 parent 3dbb5ab commit 710e02a
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions patches/external_pr/hypernetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from modules import shared, sd_models, devices, processing, sd_samplers
from modules.hypernetworks.hypernetwork import optimizer_dict, stack_conds, save_hypernetwork, report_statistics
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from ..tbutils import tensorboard_setup, tensorboard_add, tensorboard_add_image, tensorboard_log_hyperparameter
from .textual_inversion import validate_train_inputs, write_loss
Expand Down Expand Up @@ -219,9 +220,13 @@ def gradient_clipping(arg1):
return
save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root,
template_file, steps, save_hypernetwork_every, create_image_every,
log_directory, name="hypernetwork")
if not os.path.isfile(template_file):
template_file = textual_inversion.textual_inversion_templates.get(template_file, None)
if template_file is not None:
template_file = template_file.path
else:
raise AssertionError(f"Cannot find {template_file}!")
validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
shared.state.job = "train-hypernetwork"
shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps
Expand Down Expand Up @@ -678,9 +683,13 @@ def gradient_clipping(arg1):
set_scheduler(-1, False, False)
save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root,
template_file, steps, save_hypernetwork_every, create_image_every,
log_directory, name="hypernetwork")
if not os.path.isfile(template_file):
template_file = textual_inversion.textual_inversion_templates.get(template_file, None)
if template_file is not None:
template_file = template_file.path
else:
raise AssertionError(f"Cannot find {template_file}!")
validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
hypernetwork.to(devices.device)
assert hypernetwork is not None, f"Cannot load {hypernetwork_name}!"
if not isinstance(hypernetwork, Hypernetwork):
Expand Down

0 comments on commit 710e02a

Please sign in to comment.