Skip to content

Commit

Permalink
Fix LoHA issue with new kohya code
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Jun 5, 2023
1 parent c389707 commit fdf8e9e
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 8 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,8 @@ This will store a backup file with your current locally installed pip packages a

## Change History

* 2023/07/05 (v21 7.5)
- Fix reported issue with LoHA: https://github.com/bmaltais/kohya_ss/issues/922
* 2023/06/05 (v21.7.4)
- Add manual accelerate config option
- Remove the ability to switch between torch 1 and 2 as it was causing errors with the venv
Expand Down
19 changes: 11 additions & 8 deletions lora_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,9 @@ def train_model(
return
run_cmd += f' --network_module=lycoris.kohya'
run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=loha"'
# This is a hack to fix a train_network LoHA logic issue
if not network_dropout > 0.0:
run_cmd += f' --network_dropout="{network_dropout}"'

if LoRA_type in ['Kohya LoCon', 'Standard']:
kohya_lora_var_list = [
Expand Down Expand Up @@ -1262,25 +1265,25 @@ def update_LoRA_settings(LoRA_type):
scale_weight_norms = gr.Slider(
label='Scale weight norms',
value=0,
minimum=0.0,
maximum=1.0,
minimum=0,
maximum=1,
step=0.01,
info='Max Norm Regularization is a technique to stabilize network training by limiting the norm of network weights. It may be effective in suppressing overfitting of LoRA and improving stability when used with other LoRAs. See PR for details.',
interactive=True,
)
network_dropout = gr.Slider(
label='Network dropout',
value=0.0,
minimum=0.0,
maximum=1.0,
value=0,
minimum=0,
maximum=1,
step=0.01,
info='Is a normal probability dropout at the neuron level. In the case of LoRA, it is applied to the output of down. Recommended range 0.1 to 0.5',
)
rank_dropout = gr.Slider(
label='Rank dropout',
value=0.0,
minimum=0.0,
maximum=1.0,
value=0,
minimum=0,
maximum=1,
step=0.01,
info='can specify `rank_dropout` to dropout each rank with specified probability. Recommended range 0.1 to 0.3',
)
Expand Down
39 changes: 39 additions & 0 deletions tools/setup_windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,43 @@ def check_torch():
return 0


# report current version of code
def check_repo_version(): # pylint: disable=unused-argument
#
# This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
#

if not os.path.exists('.git'):
log.error('Not a git repository')
return
# status = git('status')
# if 'branch' not in status:
# log.error('Cannot get git repository status')
# sys.exit(1)
ver = git('log -1 --pretty=format:"%h %ad"')
log.info(f'Version: {ver}')

# execute git command
def git(arg: str, folder: str = None, ignore: bool = False):
#
# This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master
#

git_cmd = os.environ.get('GIT', "git")
result = subprocess.run(f'"{git_cmd}" {arg}', check=False, shell=True, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=folder or '.')
txt = result.stdout.decode(encoding="utf8", errors="ignore")
if len(result.stderr) > 0:
txt += ('\n' if len(txt) > 0 else '') + result.stderr.decode(encoding="utf8", errors="ignore")
txt = txt.strip()
if result.returncode != 0 and not ignore:
global errors # pylint: disable=global-statement
errors += 1
log.error(f'Error running git: {folder} / {arg}')
if 'or stash them' in txt:
log.error(f'Local changes detected: check log for details: {log_file}')
log.debug(f'Git output: {txt}')
return txt

def cudann_install():
cudnn_src = os.path.join(
os.path.dirname(os.path.realpath(__file__)), '..\cudnn_windows'
Expand Down Expand Up @@ -465,6 +502,7 @@ def sync_bits_and_bytes_files():


def install_kohya_ss_torch1():
check_repo_version()
check_python()

# Upgrade pip if needed
Expand All @@ -491,6 +529,7 @@ def install_kohya_ss_torch1():


def install_kohya_ss_torch2():
check_repo_version()
check_python()

# Upgrade pip if needed
Expand Down

0 comments on commit fdf8e9e

Please sign in to comment.