Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Improve training and tensorboard buttons and code" #2269

Merged
merged 1 commit into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .release
Original file line number Diff line number Diff line change
@@ -1 +1 @@
v23.1.6
v23.1.5
5 changes: 0 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ The GUI allows you to set the training parameters and generate and run the requi
- [SDXL training](#sdxl-training)
- [Masked loss](#masked-loss)
- [Change History](#change-history)
- [2024/04/12 (v23.1.6)](#20240412-v2316)
- [2024/04/10 (v23.1.5)](#20240410-v2315)
- [Security Improvements](#security-improvements)
- [2024/04/08 (v23.1.4)](#20240408-v2314)
Expand Down Expand Up @@ -402,10 +401,6 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG

## Change History

### 2024/04/12 (v23.1.6)

- Improved the training and tensorboard buttons and code

### 2024/04/10 (v23.1.5)

- Fix issue with Textual Inversion configuration file selection.
Expand Down
3 changes: 0 additions & 3 deletions kohya_gui/class_command_executor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import subprocess
import psutil
import gradio as gr
from .custom_logging import setup_logging

# Set up logging
Expand Down Expand Up @@ -53,5 +52,3 @@ def kill_command(self):
log.info(f"Error when terminating process: {e}")
else:
log.info("There is no running process to kill.")

return gr.Button(visible=True), gr.Button(visible=False)
106 changes: 0 additions & 106 deletions kohya_gui/class_tensorboard.py

This file was deleted.

47 changes: 27 additions & 20 deletions kohya_gui/dreambooth_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,16 @@
from .class_advanced_training import AdvancedTraining
from .class_folders import Folders
from .class_command_executor import CommandExecutor

from .tensorboard_gui import (
gradio_tensorboard,
start_tensorboard,
stop_tensorboard,
)
from .dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
)
from .dataset_balancing_gui import gradio_dataset_balancing_tab
from .class_sample_images import SampleImages, run_cmd_sample
from .class_tensorboard import TensorboardManager

from .custom_logging import setup_logging

Expand All @@ -42,7 +45,6 @@

PYTHON = sys.executable

TRAIN_BUTTON_VISIBLE = [gr.Button(visible=True), gr.Button(visible=False)]

def save_configuration(
save_as_bool,
Expand Down Expand Up @@ -449,22 +451,18 @@ def train_model(
vae=vae,
dataset_config=dataset_config,
):
return TRAIN_BUTTON_VISIBLE
return

if not print_only and check_if_model_exist(
output_name, output_dir, save_model_as, headless=headless
):
return TRAIN_BUTTON_VISIBLE
return

if dataset_config:
log.info(
"Dataset config toml file used, skipping total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps calculations..."
)
else:
if train_data_dir == "":
log.error("Train data dir is empty")
return TRAIN_BUTTON_VISIBLE

# Get a list of all subfolders in train_data_dir, excluding hidden folders
subfolders = [
f
Expand All @@ -475,7 +473,7 @@ def train_model(
# Check if subfolders are present. If not let the user know and return
if not subfolders:
log.info(f"No {subfolders} were found in train_data_dir can't train...")
return TRAIN_BUTTON_VISIBLE
return

total_steps = 0

Expand Down Expand Up @@ -516,7 +514,7 @@ def train_model(
log.info(
f"No images were found in folder {train_data_dir}... please rectify!"
)
return TRAIN_BUTTON_VISIBLE
return

# Print the result
# log.info(f"{total_steps} total steps")
Expand Down Expand Up @@ -725,8 +723,6 @@ def train_model(
# Run the command

executor.execute_command(run_cmd=run_cmd, env=env)

return gr.Button(visible=False), gr.Button(visible=True)


def dreambooth_tab(
Expand Down Expand Up @@ -799,15 +795,27 @@ def dreambooth_tab(
with gr.Row():
button_run = gr.Button("Start training", variant="primary")

button_stop_training = gr.Button("Stop training", visible=False, variant="stop")
button_stop_training = gr.Button("Stop training")

with gr.Column(), gr.Group():
with gr.Row():
button_print = gr.Button("Print training command")
button_print = gr.Button("Print training command")

# Setup gradio tensorboard buttons
with gr.Column(), gr.Group():
TensorboardManager(headless=headless, logging_dir=folders.logging_dir)
(
button_start_tensorboard,
button_stop_tensorboard,
) = gradio_tensorboard()

button_start_tensorboard.click(
start_tensorboard,
inputs=[dummy_headless, folders.logging_dir],
show_progress=False,
)

button_stop_tensorboard.click(
stop_tensorboard,
show_progress=False,
)

settings_list = [
source_model.pretrained_model_name_or_path,
Expand Down Expand Up @@ -947,11 +955,10 @@ def dreambooth_tab(
button_run.click(
train_model,
inputs=[dummy_headless] + [dummy_db_false] + settings_list,
outputs=[button_run, button_stop_training],
show_progress=False,
)

button_stop_training.click(executor.kill_command, outputs=[button_run, button_stop_training])
button_stop_training.click(executor.kill_command)

button_print.click(
train_model,
Expand Down
43 changes: 25 additions & 18 deletions kohya_gui/finetune_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
from .class_folders import Folders
from .class_sdxl_parameters import SDXLParameters
from .class_command_executor import CommandExecutor
from .class_tensorboard import TensorboardManager
from .tensorboard_gui import (
gradio_tensorboard,
start_tensorboard,
stop_tensorboard,
)
from .class_sample_images import SampleImages, run_cmd_sample

from .custom_logging import setup_logging
Expand All @@ -46,7 +50,7 @@
PYTHON = sys.executable

presets_dir = rf"{scriptdir}/presets"
TRAIN_BUTTON_VISIBLE = [gr.Button(visible=True), gr.Button(visible=False)]


def save_configuration(
save_as_bool,
Expand Down Expand Up @@ -483,12 +487,12 @@ def train_model(
resume=resume,
dataset_config=dataset_config,
):
return TRAIN_BUTTON_VISIBLE
return

if not print_only and check_if_model_exist(
output_name, output_dir, save_model_as, headless
):
return TRAIN_BUTTON_VISIBLE
return

if dataset_config:
log.info(
Expand Down Expand Up @@ -553,10 +557,6 @@ def train_model(
# Run the command
subprocess.run(run_cmd, env=env)

if image_folder == "":
log.error("Image folder dir is empty")
return TRAIN_BUTTON_VISIBLE

image_num = len(
[
f
Expand Down Expand Up @@ -757,9 +757,6 @@ def train_model(

# Run the command
executor.execute_command(run_cmd=run_cmd, env=env)

return gr.Button(visible=False), gr.Button(visible=True)



def finetune_tab(headless=False, config: dict = {}):
Expand Down Expand Up @@ -901,15 +898,26 @@ def list_presets(path):
with gr.Row():
button_run = gr.Button("Start training", variant="primary")

button_stop_training = gr.Button("Stop training", visible=False, variant="stop")
button_stop_training = gr.Button("Stop training")

with gr.Column(), gr.Group():
with gr.Row():
button_print = gr.Button("Print training command")
button_print = gr.Button("Print training command")

# Setup gradio tensorboard buttons
with gr.Column(), gr.Group():
TensorboardManager(headless=headless, logging_dir=folders.logging_dir)
(
button_start_tensorboard,
button_stop_tensorboard,
) = gradio_tensorboard()

button_start_tensorboard.click(
start_tensorboard,
inputs=[dummy_headless, logging_dir],
)

button_stop_tensorboard.click(
stop_tensorboard,
show_progress=False,
)

settings_list = [
source_model.pretrained_model_name_or_path,
Expand Down Expand Up @@ -1070,11 +1078,10 @@ def list_presets(path):
button_run.click(
train_model,
inputs=[dummy_headless] + [dummy_db_false] + settings_list,
outputs=[button_run, button_stop_training],
show_progress=False,
)

button_stop_training.click(executor.kill_command, outputs=[button_run, button_stop_training])
button_stop_training.click(executor.kill_command)

button_print.click(
train_model,
Expand Down
Loading
Loading