diff --git a/.github/workflows/typos.yaml b/.github/workflows/typos.yaml index e8b06483f..c81ff3210 100644 --- a/.github/workflows/typos.yaml +++ b/.github/workflows/typos.yaml @@ -18,4 +18,4 @@ jobs: - uses: actions/checkout@v4 - name: typos-action - uses: crate-ci/typos@v1.19.0 + uses: crate-ci/typos@v1.21.0 diff --git a/.release b/.release index ddcd4c6f5..959a88019 100644 --- a/.release +++ b/.release @@ -1 +1 @@ -v24.0.9 \ No newline at end of file +v24.1.0 \ No newline at end of file diff --git a/README.md b/README.md index 73f256244..f8216c188 100644 --- a/README.md +++ b/README.md @@ -47,29 +47,7 @@ The GUI allows you to set the training parameters and generate and run the requi - [SDXL training](#sdxl-training) - [Masked loss](#masked-loss) - [Change History](#change-history) - - [2024/04/28 (v24.0.9)](#20240428-v2409) - - [2024/04/26 (v24.0.8)](#20240426-v2408) - - [2024/04/25 (v24.0.7)](#20240425-v2407) - - [2024/04/22 (v24.0.6)](#20240422-v2406) - - [2024/04/19 (v24.0.5)](#20240419-v2405) - - [New Contributors](#new-contributors) - - [2024/04/18 (v24.0.4)](#20240418-v2404) - - [What's Changed](#whats-changed) - - [New Contributors](#new-contributors-1) - - [2024/04/24 (v24.0.3)](#20240424-v2403) - - [2024/04/24 (v24.0.2)](#20240424-v2402) - - [2024/04/17 (v24.0.1)](#20240417-v2401) - - [Enhancements](#enhancements) - - [Security and Stability](#security-and-stability) - - [Shell Execution](#shell-execution) - - [Miscellaneous](#miscellaneous) - - [2024/04/10 (v23.1.5)](#20240410-v2315) - - [Security Improvements](#security-improvements) - - [2024/04/08 (v23.1.4)](#20240408-v2314) - - [2024/04/08 (v23.1.3)](#20240408-v2313) - - [2024/04/08 (v23.1.2)](#20240408-v2312) - - [2024/04/07 (v23.1.1)](#20240407-v2311) - - [2024/04/07 (v23.1.0)](#20240407-v2310) + - [v24.1.0](#v2410) ## 🦒 Colab @@ -116,12 +94,20 @@ To set up the project, follow these steps: cd kohya_ss ``` -4. Run the setup script by executing the following command: +4. 
Run one of the following setup scripts by executing the appropriate command: + + For systems with only Python 3.10.11 installed: ```shell .\setup.bat ``` + For systems with more than one Python release installed: + + ```shell + .\setup-3.10.bat + ``` + During the accelerate config step, use the default values as proposed during the configuration unless you know your hardware demands otherwise. The amount of VRAM on your GPU does not impact the values used. #### Optional: CUDNN 8.9.6.50 @@ -455,202 +441,10 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG ## Change History -### 2024/04/28 (v24.0.9) - -- Updated the temporary configuration file to include date and time information in the file name. This will allow for easier batching of multiple training commands, particularly useful for users who want to automate their training sessions. -- Fixed an issue with wd14 captioning where the captioning process was not functioning correctly when the recursive option was set to true. Prefixes and postfixes are now applied to all caption files in the folder. - -### 2024/04/26 (v24.0.8) - -- Set `max_train_steps` to 0 if not specified in older `.json` config files. - -### 2024/04/25 (v24.0.7) - -- Prevent crash if tkinter is not installed -- Fix [24.0.6] Train toml config seed type error #2370 -- A new docker container is now built with every new release, eliminating the need for manual building. A big thank you to @jim60105 for his hard work in this area. You can find more information about it in the Docker section of the README. 
- -### 2024/04/22 (v24.0.6) - -- Make start and stop buttons visible in headless -- Add validation for lr and optimizer arguments - -### 2024/04/19 (v24.0.5) - -- Hide tensorboard button if tensorflow module is not installed by @bmaltais in -- wd14 captioning issue with undesired tags nor tag replacement by @bmaltais in -- Changed logger checkbox to dropdown, renamed use_wandb -> log_with by @ccharest93 in - -#### New Contributors - -- @ccharest93 made their first contribution in - -### 2024/04/18 (v24.0.4) - -#### What's Changed - -- Fix options.md heading by @bmaltais in -- Use correct file extensions when browsing for model file by @b-fission in -- Add argument for Gradio's `root_path` to enable reverse proxy support by @hlky in -- 2325 quotes wrapping python path cause subprocess cant find target in v2403 by @bmaltais in -- 2330 another seemingly new data validation leads to unusable configs 2403 by @bmaltais in -- Fix bad Lora parameters by @bmaltais in - -#### New Contributors - -- @b-fission made their first contribution in -- @hlky made their first contribution in - -### 2024/04/24 (v24.0.3) - -- Fix issue with sample prompt creation - -### 2024/04/24 (v24.0.2) - -- Fixed issue with clip_skip not being passed as an int to sd-scripts when using old config.json files. - -### 2024/04/17 (v24.0.1) - -#### Enhancements - -- **User Interface:** Transitioned the GUI to use a TOML file for argument passing to sd-scripts, significantly enhancing security by eliminating the need for command-line interface (CLI) use for sensitive data. -- **Training Tools:** Improved the training and TensorBoard buttons to provide a more intuitive user experience. -- **HuggingFace Integration:** Integrated a HuggingFace section in all trainer tabs, enabling authentication and use of HuggingFace's advanced AI models. -- **Gradio Upgrade:** Upgraded Gradio to version 4.20.0 to fix a previously identified bug impacting the runpod platform. 
-- **Metadata Support:** Added functionality for metadata capture within the GUI. - -#### Security and Stability - -- **Code Refactoring:** Extensively rewrote the code to address various security vulnerabilities, including removing the `shell=True` parameter from process calls. -- **Scheduler Update:** Disabled LR Warmup when using the Constant LR Scheduler to prevent traceback errors associated with sd-scripts. - -#### Shell Execution - -- **Conditional Shell Usage:** Added support for optional shell usage when executing external sd-scripts commands, tailored to meet specific platform needs and recent security updates. - -The `gui.bat` and `gui.sh` scripts now include the `--do_not_use_shell` argument to prevent shell execution (`shell=True`) during external process handling.The GUI will automatically set `use_shell` to True internally, as required for proper execution of external commands. To enforce disabling shell execution, use the `--do_not_use_shell` argument. - -- **How to Enable Shell Execution via Config File:** - 1. In the `config.toml` file, set `use_shell` to `true` to enable shell usage as per GUI startup settings. - **Note:** The `--do_not_use_shell` option will override the `config.toml` settings, setting `use_shell` to False even if it is set to True in the config file. - -#### Miscellaneous - -- Made various other minor improvements and bug fixes to enhance overall functionality and user experience. -- Fixed an issue with existing LoRA network weights were not properly loaded prior to training - -### 2024/04/10 (v23.1.5) - -- Fix issue with Textual Inversion configuration file selection. -- Upgrade to gradio 4.19.2 to fix several high security risks associated to earlier versions. This is a major upgrade, moving from 3.x to 4.x. Hoping this will not introduce undorseen issues. -- Upgrade transformers to 4.38.0 to fix a low severity security issue. 
- -#### Security Improvements - -- Add explicit --do_not_share parameter to kohya_gui.py to avoid sharing the GUI on platforms like Kaggle. -- Remove shell=True from subprocess calls to avoid security issues when using the GUI. -- Limit caption extensions to a fixed set of extensions to limit the risk of finding and replacing text content in unexpected files. - -### 2024/04/08 (v23.1.4) - -- Relocate config accordion to the top of the GUI. - -### 2024/04/08 (v23.1.3) - -- Fix dataset preparation bug. - -### 2024/04/08 (v23.1.2) - -- Added config.toml support for wd14_caption. - -### 2024/04/07 (v23.1.1) - -- Added support for Huber loss under the Parameters / Advanced tab. - -### 2024/04/07 (v23.1.0) - -- Update sd-scripts to 0.8.7 - - The default value of `huber_schedule` in Scheduled Huber Loss is changed from `exponential` to `snr`, which is expected to give better results. - - - Highlights - - The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries. - - Especially `imagesize` is newly added, so if you cannot update the libraries immediately, please install with `pip install imagesize==1.4.1` separately. - - `bitsandbytes==0.43.0`, `prodigyopt==1.0`, `lion-pytorch==0.0.6` are included in the requirements.txt. - - `bitsandbytes` no longer requires complex procedures as it now officially supports Windows. - - Also, the PyTorch version is updated to 2.1.2 (PyTorch does not need to be updated immediately). In the upgrade procedure, PyTorch is not updated, so please manually install or update torch, torchvision, xformers if necessary (see [Upgrade PyTorch](#upgrade-pytorch)). - - When logging to wandb is enabled, the entire command line is exposed. Therefore, it is recommended to write wandb API key and HuggingFace token in the configuration file (`.toml`). Thanks to bghira for raising the issue. - - A warning is displayed at the start of training if such information is included in the command line. 
- - Also, if there is an absolute path, the path may be exposed, so it is recommended to specify a relative path or write it in the configuration file. In such cases, an INFO log is displayed. - - See [#1123](https://github.com/kohya-ss/sd-scripts/pull/1123) and PR [#1240](https://github.com/kohya-ss/sd-scripts/pull/1240) for details. - - Colab seems to stop with log output. Try specifying `--console_log_simple` option in the training script to disable rich logging. - - Other improvements include the addition of masked loss, scheduled Huber Loss, DeepSpeed support, dataset settings improvements, and image tagging improvements. See below for details. - - - Training scripts - - `train_network.py` and `sdxl_train_network.py` are modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`). - - Fixed a bug that U-Net and Text Encoders are included in the state in `train_network.py` and `sdxl_train_network.py`. The saving and loading of the state are faster, the file size is smaller, and the memory usage when loading is reduced. - - DeepSpeed is supported. PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) and [#1139](https://github.com/kohya-ss/sd-scripts/pull/1139) Thanks to BootsofLagrangian! See PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) for details. - - The masked loss is supported in each training script. PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207) See [Masked loss](#masked-loss) for details. - - Scheduled Huber Loss has been introduced to each training scripts. PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) Thanks to kabachuha for the PR and cheald, drhead, and others for the discussion! See the PR and [Scheduled Huber Loss](./docs/train_lllite_README.md#scheduled-huber-loss) for details. 
- - The options `--noise_offset_random_strength` and `--ip_noise_gamma_random_strength` are added to each training script. These options can be used to vary the noise offset and ip noise gamma in the range of 0 to the specified value. PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) Thanks to KohakuBlueleaf! - - The options `--save_state_on_train_end` are added to each training script. PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) Thanks to gesen2egee! - - The options `--sample_every_n_epochs` and `--sample_every_n_steps` in each training script now display a warning and ignore them when a number less than or equal to `0` is specified. Thanks to S-Del for raising the issue. - - - Dataset settings - - The [English version of the dataset settings documentation](./docs/config_README-en.md) is added. PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) Thanks to darkstorm2150! - - The `.toml` file for the dataset config is now read in UTF-8 encoding. PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Thanks to Horizon1704! - - Fixed a bug that the last subset settings are applied to all images when multiple subsets of regularization images are specified in the dataset settings. The settings for each subset are correctly applied to each image. PR [#1205](https://github.com/kohya-ss/sd-scripts/pull/1205) Thanks to feffy380! - - Some features are added to the dataset subset settings. - - `secondary_separator` is added to specify the tag separator that is not the target of shuffling or dropping. - - Specify `secondary_separator=";;;"`. When you specify `secondary_separator`, the part is not shuffled or dropped. - - `enable_wildcard` is added. When set to `true`, the wildcard notation `{aaa|bbb|ccc}` can be used. The multi-line caption is also enabled. - - `keep_tokens_separator` is updated to be used twice in the caption. 
When you specify `keep_tokens_separator="|||"`, the part divided by the second `|||` is not shuffled or dropped and remains at the end. - - The existing features `caption_prefix` and `caption_suffix` can be used together. `caption_prefix` and `caption_suffix` are processed first, and then `enable_wildcard`, `keep_tokens_separator`, shuffling and dropping, and `secondary_separator` are processed in order. - - See [Dataset config](./docs/config_README-en.md) for details. - - The dataset with DreamBooth method supports caching image information (size, caption). PR [#1178](https://github.com/kohya-ss/sd-scripts/pull/1178) and [#1206](https://github.com/kohya-ss/sd-scripts/pull/1206) Thanks to KohakuBlueleaf! See [DreamBooth method specific options](./docs/config_README-en.md#dreambooth-specific-options) for details. - - - Image tagging (not implemented yet in the GUI) - - The support for v3 repositories is added to `tag_image_by_wd14_tagger.py` (`--onnx` option only). PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) Thanks to sdbds! - - Onnx may need to be updated. Onnx is not installed by default, so please install or update it with `pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` etc. Please also check the comments in `requirements.txt`. - - The model is now saved in the subdirectory as `--repo_id` in `tag_image_by_wd14_tagger.py` . This caches multiple repo_id models. Please delete unnecessary files under `--model_dir`. - - Some options are added to `tag_image_by_wd14_tagger.py`. - - Some are added in PR [#1216](https://github.com/kohya-ss/sd-scripts/pull/1216) Thanks to Disty0! - - Output rating tags `--use_rating_tags` and `--use_rating_tags_as_last_tag` - - Output character tags first `--character_tags_first` - - Expand character tags and series `--character_tag_expand` - - Specify tags to output first `--always_first_tags` - - Replace tags `--tag_replacement` - - See [Tagging documentation](./docs/wd14_tagger_README-en.md) for details. 
- - Fixed an error when specifying `--beam_search` and a value of 2 or more for `--num_beams` in `make_captions.py`. - - - About Masked loss - The masked loss is supported in each training script. To enable the masked loss, specify the `--masked_loss` option. - - The feature is not fully tested, so there may be bugs. If you find any issues, please open an Issue. - - ControlNet dataset is used to specify the mask. The mask images should be the RGB images. The pixel value 255 in R channel is treated as the mask (the loss is calculated only for the pixels with the mask), and 0 is treated as the non-mask. The pixel values 0-255 are converted to 0-1 (i.e., the pixel value 128 is treated as the half weight of the loss). See details for the dataset specification in the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset). - - - About Scheduled Huber Loss - Scheduled Huber Loss has been introduced to each training scripts. This is a method to improve robustness against outliers or anomalies (data corruption) in the training data. - - With the traditional MSE (L2) loss function, the impact of outliers could be significant, potentially leading to a degradation in the quality of generated images. On the other hand, while the Huber loss function can suppress the influence of outliers, it tends to compromise the reproduction of fine details in images. - - To address this, the proposed method employs a clever application of the Huber loss function. By scheduling the use of Huber loss in the early stages of training (when noise is high) and MSE in the later stages, it strikes a balance between outlier robustness and fine detail reproduction. - - Experimental results have confirmed that this method achieves higher accuracy on data containing outliers compared to pure Huber loss or MSE. The increase in computational cost is minimal. 
- - The newly added arguments loss_type, huber_schedule, and huber_c allow for the selection of the loss function type (Huber, smooth L1, MSE), scheduling method (exponential, constant, SNR), and Huber's parameter. This enables optimization based on the characteristics of the dataset. - - See PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) for details. - - - `loss_type`: Specify the loss function type. Choose `huber` for Huber loss, `smooth_l1` for smooth L1 loss, and `l2` for MSE loss. The default is `l2`, which is the same as before. - - `huber_schedule`: Specify the scheduling method. Choose `exponential`, `constant`, or `snr`. The default is `snr`. - - `huber_c`: Specify the Huber's parameter. The default is `0.1`. - - Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.` +### v24.1.0 -- Added GUI support for the new parameters listed above. -- Moved accelerate launch parameters to a new `Accelerate launch` accordion above the `Model` accordion. -- Added support for `Debiased Estimation loss` to Dreambooth settings. -- Added support for "Dataset Preparation" defaults via the config.toml file. -- Added a field to allow for the input of extra accelerate launch arguments. -- Added new caption tool from +- To ensure cross-platform compatibility and security, the GUI now defaults to using `shell=False` when running subprocesses. This follows the Python documentation's recommendations and should not cause issues on most platforms. However, some users have reported issues on specific platforms such as Runpod and Colab. Please open an issue if you encounter any problems. +- Added support for custom LyCORIS toml config files. Simply type the path to the config file in the LyCORIS preset dropdown. +- Improved file and folder validation +- Added a new setup-3.10.bat file to set the venv to specifically use Python 3.10.x instead of the default Python the system might use. 
+- Relocate toml training config file to same folder as the model output directory. diff --git a/kohya_gui.py b/kohya_gui.py index 7009d7392..f586f0a29 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -78,13 +78,11 @@ def UI(**kwargs): reg_data_dir_input=reg_data_dir_input, output_dir_input=output_dir_input, logging_dir_input=logging_dir_input, - enable_copy_info_button=True, headless=headless, config=config, - use_shell_flag=use_shell_flag, ) with gr.Tab("LoRA"): - _ = LoRATools(headless=headless, use_shell_flag=use_shell_flag) + _ = LoRATools(headless=headless) with gr.Tab("About"): gr.Markdown(f"kohya_ss GUI release {release}") with gr.Tab("README"): diff --git a/kohya_gui/basic_caption_gui.py b/kohya_gui/basic_caption_gui.py index 2f473fb7b..d38fea750 100644 --- a/kohya_gui/basic_caption_gui.py +++ b/kohya_gui/basic_caption_gui.py @@ -1,5 +1,4 @@ import gradio as gr -from easygui import msgbox import subprocess from .common_gui import ( get_folder_path, @@ -28,7 +27,6 @@ def caption_images( postfix: str, find_text: str, replace_text: str, - use_shell: bool = False, ): """ Captions images in a given directory with a given caption text. @@ -46,16 +44,18 @@ def caption_images( Returns: None """ - # Check if images_dir is provided + # Check if images_dir and caption_ext are provided + missing_parameters = [] if not images_dir: - msgbox( - "Image folder is missing. Please provide the directory containing the images to caption." - ) - return - - # Check if caption_ext is provided + missing_parameters.append("image directory") if not caption_ext: - msgbox("Please provide an extension for the caption files.") + missing_parameters.append("caption file extension") + + if missing_parameters: + log.info( + "The following parameter(s) are missing: {}. 
" + "Please provide these to proceed with captioning the images.".format(", ".join(missing_parameters)) + ) return # Log the captioning process @@ -63,55 +63,59 @@ def caption_images( log.info(f"Captioning files in {images_dir} with {caption_text}...") # Build the command to run caption.py - run_cmd = rf'"{PYTHON}" "{scriptdir}/tools/caption.py"' - run_cmd += f' --caption_text="{caption_text}"' + run_cmd = [ + rf"{PYTHON}", + rf"{scriptdir}/tools/caption.py", + "--caption_text", + caption_text, + ] # Add optional flags to the command if overwrite: - run_cmd += f" --overwrite" + run_cmd.append("--overwrite") if caption_ext: - run_cmd += f' --caption_file_ext="{caption_ext}"' + run_cmd.append("--caption_file_ext") + run_cmd.append(caption_ext) - run_cmd += f' "{images_dir}"' + run_cmd.append(rf"{images_dir}") - # Log the command - log.info(run_cmd) + # Reconstruct the safe command string for display + command_to_run = " ".join(run_cmd) + log.info(f"Executing command: {command_to_run}") # Set the environment variable for the Python path env = os.environ.copy() env["PYTHONPATH"] = ( - f"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" ) env["TF_ENABLE_ONEDNN_OPTS"] = "0" - log.info(f"Executing command: {run_cmd} with shell={use_shell}") - # Run the command in the sd-scripts folder context - subprocess.run(run_cmd, env=env, shell=use_shell) - + subprocess.run(run_cmd, env=env, shell=False) # Check if overwrite option is enabled if overwrite: - # Add prefix and postfix to caption files - if prefix or postfix: + # Add prefix and postfix to caption files or find and replace text in caption files + if prefix or postfix or find_text: + # Add prefix and/or postfix to caption files add_pre_postfix( folder=images_dir, caption_file_ext=caption_ext, prefix=prefix, postfix=postfix, ) - # Find and replace text in caption files - if find_text: - 
find_replace( - folder_path=images_dir, - caption_file_ext=caption_ext, - search_text=find_text, - replace_text=replace_text, - ) + # Replace specified text in caption files if find and replace text is provided + if find_text and replace_text: + find_replace( + folder_path=images_dir, + caption_file_ext=caption_ext, + search_text=find_text, + replace_text=replace_text, + ) else: # Show a message if modification is not possible without overwrite option enabled if prefix or postfix: - msgbox( + log.info( 'Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected.' ) @@ -120,7 +124,7 @@ def caption_images( # Gradio UI -def gradio_basic_caption_gui_tab(headless=False, default_images_dir=None, use_shell: bool = False): +def gradio_basic_caption_gui_tab(headless=False, default_images_dir=None): """ Creates a Gradio tab for basic image captioning. @@ -200,6 +204,7 @@ def list_images_dirs(path): choices=[".cap", ".caption", ".txt"], value=".txt", interactive=True, + allow_custom_value=True, ) # Checkbox to overwrite existing captions overwrite = gr.Checkbox( @@ -258,7 +263,6 @@ def list_images_dirs(path): postfix, find_text, replace_text, - gr.Checkbox(value=use_shell, visible=False), ], show_progress=False, ) diff --git a/kohya_gui/blip_caption_gui.py b/kohya_gui/blip_caption_gui.py index ede340dc6..215f0027b 100644 --- a/kohya_gui/blip_caption_gui.py +++ b/kohya_gui/blip_caption_gui.py @@ -1,5 +1,4 @@ import gradio as gr -from easygui import msgbox import subprocess import os import sys @@ -23,7 +22,6 @@ def caption_images( beam_search: bool, prefix: str = "", postfix: str = "", - use_shell: bool = False, ) -> None: """ Automatically generates captions for images in the specified directory using the BLIP model. 
@@ -47,59 +45,60 @@ def caption_images( """ # Check if the image folder is provided if not train_data_dir: - msgbox("Image folder is missing...") + log.info("Image folder is missing...") return # Check if the caption file extension is provided if not caption_file_ext: - msgbox("Please provide an extension for the caption files.") + log.info("Please provide an extension for the caption files.") return log.info(f"Captioning files in {train_data_dir}...") # Construct the command to run make_captions.py - run_cmd = [fr'"{PYTHON}"', fr'"{scriptdir}/sd-scripts/finetune/make_captions.py"'] + run_cmd = [rf"{PYTHON}", rf"{scriptdir}/sd-scripts/finetune/make_captions.py"] # Add required arguments - run_cmd.append('--batch_size') + run_cmd.append("--batch_size") run_cmd.append(str(batch_size)) - run_cmd.append('--num_beams') + run_cmd.append("--num_beams") run_cmd.append(str(num_beams)) - run_cmd.append('--top_p') + run_cmd.append("--top_p") run_cmd.append(str(top_p)) - run_cmd.append('--max_length') + run_cmd.append("--max_length") run_cmd.append(str(max_length)) - run_cmd.append('--min_length') + run_cmd.append("--min_length") run_cmd.append(str(min_length)) # Add optional flags to the command if beam_search: run_cmd.append("--beam_search") if caption_file_ext: - run_cmd.append('--caption_extension') + run_cmd.append("--caption_extension") run_cmd.append(caption_file_ext) # Add the directory containing the training data - run_cmd.append(fr'"{train_data_dir}"') + run_cmd.append(rf"{train_data_dir}") # Add URL for caption model weights - run_cmd.append('--caption_weights') - run_cmd.append("https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth") + run_cmd.append("--caption_weights") + run_cmd.append( + rf"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth" + ) # Set up the environment env = os.environ.copy() env["PYTHONPATH"] = ( - 
f"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" ) env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Reconstruct the safe command string for display command_to_run = " ".join(run_cmd) - log.info(f"Executing command: {command_to_run} with shell={use_shell}") - - # Run the command in the sd-scripts folder context - subprocess.run(command_to_run, env=env, shell=use_shell, cwd=f"{scriptdir}/sd-scripts") + log.info(f"Executing command: {command_to_run}") + # Run the command in the sd-scripts folder context + subprocess.run(run_cmd, env=env, shell=False, cwd=rf"{scriptdir}/sd-scripts") # Add prefix and postfix add_pre_postfix( @@ -117,7 +116,7 @@ def caption_images( ### -def gradio_blip_caption_gui_tab(headless=False, default_train_dir=None, use_shell: bool = False): +def gradio_blip_caption_gui_tab(headless=False, default_train_dir=None): from .common_gui import create_refresh_button default_train_dir = ( @@ -167,6 +166,7 @@ def list_train_dirs(path): choices=[".cap", ".caption", ".txt"], value=".txt", interactive=True, + allow_custom_value=True, ) prefix = gr.Textbox( @@ -207,7 +207,6 @@ def list_train_dirs(path): beam_search, prefix, postfix, - gr.Checkbox(value=use_shell, visible=False), ], show_progress=False, ) diff --git a/kohya_gui/class_command_executor.py b/kohya_gui/class_command_executor.py index 440a46c40..f18e97a32 100644 --- a/kohya_gui/class_command_executor.py +++ b/kohya_gui/class_command_executor.py @@ -28,7 +28,7 @@ def __init__(self, headless: bool = False): "Stop training", visible=self.process is not None or headless, variant="stop" ) - def execute_command(self, run_cmd: str, use_shell: bool = False, **kwargs): + def execute_command(self, run_cmd: str, **kwargs): """ Execute a command if no other command is currently running. 
@@ -44,10 +44,10 @@ def execute_command(self, run_cmd: str, use_shell: bool = False, **kwargs): # Reconstruct the safe command string for display command_to_run = " ".join(run_cmd) - log.info(f"Executing command: {command_to_run} with shell={use_shell}") + log.info(f"Executing command: {command_to_run}") # Execute the command securely - self.process = subprocess.Popen(command_to_run, **kwargs, shell=use_shell) + self.process = subprocess.Popen(run_cmd, **kwargs) log.info("Command executed.") def kill_command(self): diff --git a/kohya_gui/class_lora_tab.py b/kohya_gui/class_lora_tab.py index d1e238419..487ff5cbc 100644 --- a/kohya_gui/class_lora_tab.py +++ b/kohya_gui/class_lora_tab.py @@ -14,15 +14,14 @@ class LoRATools: def __init__( self, headless: bool = False, - use_shell_flag: bool = False, ): gr.Markdown("This section provide various LoRA tools...") - gradio_extract_dylora_tab(headless=headless, use_shell=use_shell_flag) - gradio_convert_lcm_tab(headless=headless, use_shell=use_shell_flag) - gradio_extract_lora_tab(headless=headless, use_shell=use_shell_flag) - gradio_extract_lycoris_locon_tab(headless=headless, use_shell=use_shell_flag) - gradio_merge_lora_tab = GradioMergeLoRaTab(use_shell=use_shell_flag) - gradio_merge_lycoris_tab(headless=headless, use_shell=use_shell_flag) - gradio_svd_merge_lora_tab(headless=headless, use_shell=use_shell_flag) + gradio_extract_dylora_tab(headless=headless) + gradio_convert_lcm_tab(headless=headless) + gradio_extract_lora_tab(headless=headless) + gradio_extract_lycoris_locon_tab(headless=headless) + gradio_merge_lora_tab = GradioMergeLoRaTab() + gradio_merge_lycoris_tab(headless=headless) + gradio_svd_merge_lora_tab(headless=headless) gradio_resize_lora_tab(headless=headless) gradio_verify_lora_tab(headless=headless) diff --git a/kohya_gui/common_gui.py b/kohya_gui/common_gui.py index 52d4faa94..43f247667 100644 --- a/kohya_gui/common_gui.py +++ b/kohya_gui/common_gui.py @@ -14,6 +14,7 @@ import json import math import 
shutil +import toml # Set up logging log = setup_logging() @@ -354,7 +355,7 @@ def update_my_data(my_data): except ValueError: # Handle the case where the string is not a valid float my_data[key] = int(1) - + for key in [ "max_train_steps", ]: @@ -419,7 +420,7 @@ def update_my_data(my_data): my_data["xformers"] = "xformers" else: my_data["xformers"] = "none" - + # Convert use_wandb to log_with="wandb" if it is set to True for key in ["use_wandb"]: value = my_data.get(key) @@ -430,17 +431,16 @@ def update_my_data(my_data): except ValueError: # Handle the case where the string is not a valid float pass - + my_data.pop(key, None) - - + # Replace the lora_network_weights key with network_weights keeping the original value for key in ["lora_network_weights"]: - value = my_data.get(key) # Get original value - if value is not None: # Check if the key exists in the dictionary + value = my_data.get(key) # Get original value + if value is not None: # Check if the key exists in the dictionary my_data["network_weights"] = value my_data.pop(key, None) - + return my_data @@ -740,7 +740,6 @@ def add_pre_postfix( postfix: str = "", caption_file_ext: str = ".caption", recursive: bool = False, - ) -> None: """ Add prefix and/or postfix to the content of caption files within a folder. @@ -751,12 +750,8 @@ def add_pre_postfix( prefix (str, optional): Prefix to add to the content of the caption files. postfix (str, optional): Postfix to add to the content of the caption files. caption_file_ext (str, optional): Extension of the caption files. + recursive (bool, optional): Whether to search for caption files recursively. """ - # Enforce that the provided extension is one of .caption, .cap, .txt - if caption_file_ext not in (".caption", ".cap", ".txt"): - log.error("Invalid caption file extension. 
Must be on of .caption, .cap, .txt") - return - # If neither prefix nor postfix is provided, return early if prefix == "" and postfix == "": return @@ -780,33 +775,39 @@ def add_pre_postfix( # Iterate over the list of image files for image_file in image_files: # Construct the caption file name by appending the caption file extension to the image file name - caption_file_name = os.path.splitext(image_file)[0] + caption_file_ext + caption_file_name = f"{os.path.splitext(image_file)[0]}{caption_file_ext}" # Construct the full path to the caption file caption_file_path = os.path.join(folder, caption_file_name) # Check if the caption file does not exist if not os.path.exists(caption_file_path): # Create a new caption file with the specified prefix and/or postfix - with open(caption_file_path, "w", encoding="utf-8") as f: - # Determine the separator based on whether both prefix and postfix are provided - separator = " " if prefix and postfix else "" - f.write(f"{prefix}{separator}{postfix}") + try: + with open(caption_file_path, "w", encoding="utf-8") as f: + # Determine the separator based on whether both prefix and postfix are provided + separator = " " if prefix and postfix else "" + f.write(f"{prefix}{separator}{postfix}") + except Exception as e: + log.error(f"Error writing to file {caption_file_path}: {e}") else: # Open the existing caption file for reading and writing - with open(caption_file_path, "r+", encoding="utf-8") as f: - # Read the content of the caption file, stripping any trailing whitespace - content = f.read().rstrip() - # Move the file pointer to the beginning of the file - f.seek(0, 0) - - # Determine the separator based on whether only prefix is provided - prefix_separator = " " if prefix else "" - # Determine the separator based on whether only postfix is provided - postfix_separator = " " if postfix else "" - # Write the updated content to the caption file, adding prefix and/or postfix - f.write( - 
f"{prefix}{prefix_separator}{content}{postfix_separator}{postfix}" - ) + try: + with open(caption_file_path, "r+", encoding="utf-8") as f: + # Read the content of the caption file, stripping any trailing whitespace + content = f.read().rstrip() + # Move the file pointer to the beginning of the file + f.seek(0, 0) + + # Determine the separator based on whether only prefix is provided + prefix_separator = " " if prefix else "" + # Determine the separator based on whether only postfix is provided + postfix_separator = " " if postfix else "" + # Write the updated content to the caption file, adding prefix and/or postfix + f.write( + f"{prefix}{prefix_separator}{content}{postfix_separator}{postfix}" + ) + except Exception as e: + log.error(f"Error writing to file {caption_file_path}: {e}") def has_ext_files(folder_path: str, file_extension: str) -> bool: @@ -1356,118 +1357,74 @@ def check_duplicate_filenames( log.info("...valid") -def validate_paths(headless: bool = False, **kwargs: Optional[str]) -> bool: - """ - Validates the existence of specified paths and patterns for model training configurations. - - This function checks for the existence of various directory paths and files provided as keyword arguments, - including model paths, data directories, output directories, and more. It leverages predefined default - models for validation and ensures directory creation if necessary. - - Args: - headless (bool): A flag indicating if the function should run without requiring user input. - **kwargs (Optional[str]): Keyword arguments that represent various path configurations, - including but not limited to `pretrained_model_name_or_path`, `train_data_dir`, - and more. +def validate_file_path(file_path: str) -> bool: + if file_path == "": + return True + msg = f"Validating {file_path} existence..." 
+ if not os.path.isfile(file_path): + log.error(f"{msg} FAILED: does not exist") + return False + log.info(f"{msg} SUCCESS") + return True - Returns: - bool: True if all specified paths are valid or have been successfully created; False otherwise. - """ - def validate_path( - path: Optional[str], path_type: str, create_if_missing: bool = False - ) -> bool: - """ - Validates the existence of a path. If the path does not exist and `create_if_missing` is True, - attempts to create the directory. - - Args: - path (Optional[str]): The path to validate. - path_type (str): Description of the path type for logging purposes. - create_if_missing (bool): Whether to create the directory if it does not exist. - - Returns: - bool: True if the path is valid or has been created; False otherwise. - """ - if path: - log.info(f"Validating {path_type} path {path} existence...") - if os.path.exists(path): - log.info("...valid") - else: - if create_if_missing: - try: - os.makedirs(path, exist_ok=True) - log.info(f"...created folder at {path}") - return True - except Exception as e: - log.error(f"...failed to create {path_type} folder: {e}") - return False - else: - log.error( - f"...{path_type} path '{path}' is missing or does not exist" - ) - return False - else: - log.info(f"{path_type} not specified, skipping validation") +def validate_folder_path(folder_path: str, can_be_written_to: bool = False) -> bool: + if folder_path == "": return True - - # Validates the model name or path against default models or existence as a local path - if not validate_model_path(kwargs.get("pretrained_model_name_or_path")): + msg = f"Validating {folder_path} existence{' and writability' if can_be_written_to else ''}..." 
+    if not os.path.isdir(folder_path): +        log.error(f"{msg} FAILED: does not exist") return False +    if can_be_written_to: +        if not os.access(folder_path, os.W_OK): +            log.error(f"{msg} FAILED: is not writable.") +            return False +    log.info(f"{msg} SUCCESS") +    return True - # Validates the existence of specified directories or files, and creates them if necessary - for key, value in kwargs.items(): - if key in ["output_dir", "logging_dir"]: - if not validate_path(value, key, create_if_missing=True): - return False - elif key in ["vae"]: - # Check if it matches the Hugging Face model pattern - if re.match(r"^[\w-]+\/[\w-]+$", value): - log.info("Checking vae... huggingface.co model, skipping validation") - else: - if not validate_path(value, key): - return False - else: - if key not in ["pretrained_model_name_or_path"]: - if not validate_path(value, key): - return False - +def validate_toml_file(file_path: str) -> bool: + if file_path == "": + return True + msg = f"Validating toml {file_path} existence and validity..." + if not os.path.isfile(file_path): + log.error(f"{msg} FAILED: does not exist") + return False + + try: + toml.load(file_path) + except Exception: + log.error(f"{msg} FAILED: is not a valid toml file.") + return False + log.info(f"{msg} SUCCESS") return True -def validate_model_path(pretrained_model_name_or_path: Optional[str]) -> bool: +def validate_model_path(pretrained_model_name_or_path: str) -> bool: """ Validates the pretrained model name or path against Hugging Face models or local paths. Args: - pretrained_model_name_or_path (Optional[str]): The pretrained model name or path to validate. + pretrained_model_name_or_path (str): The pretrained model name or path to validate. Returns: bool: True if the path is a valid Hugging Face model or exists locally; False otherwise. """ from .class_source_model import default_models - if pretrained_model_name_or_path: - log.info( - f"Validating model file or folder path {pretrained_model_name_or_path} existence..."
- ) + msg = f"Validating {pretrained_model_name_or_path} existence..." - # Check if it matches the Hugging Face model pattern - if re.match(r"^[\w-]+\/[\w-]+$", pretrained_model_name_or_path): - log.info("...huggingface.co model, skipping validation") - elif pretrained_model_name_or_path not in default_models: - # If not one of the default models, check if it's a valid local path - if not os.path.exists(pretrained_model_name_or_path): - log.error( - f"...source model path '{pretrained_model_name_or_path}' is missing or does not exist" - ) - return False - else: - log.info("...valid") - else: - log.info("...valid") + # Check if it matches the Hugging Face model pattern + if re.match(r"^[\w-]+\/[\w-]+$", pretrained_model_name_or_path): + log.info(f"{msg} SKIPPING: huggingface.co model") + elif pretrained_model_name_or_path in default_models: + log.info(f"{msg} SUCCESS") else: - log.info("Model name or path not specified, skipping validation") + # If not one of the default models, check if it's a valid local path + if not os.path.exists(pretrained_model_name_or_path): + log.error(f"{msg} FAILED: is missing or does not exist") + return False + log.info(f"{msg} SUCCESS") + return True @@ -1495,6 +1452,7 @@ def is_file_writable(file_path: str) -> bool: # If an IOError occurs, the file cannot be written to return False + def print_command_and_toml(run_cmd, tmpfilename): log.warning( "Here is the trainer command as a reference. 
It will not be executed:\n" @@ -1512,16 +1470,19 @@ def print_command_and_toml(run_cmd, tmpfilename): log.info(f"end of toml config file: {tmpfilename}") save_to_file(command_to_run) - + + def validate_args_setting(input_string): # Regex pattern to handle multiple conditions: # - Empty string is valid # - Single or multiple key/value pairs with exactly one space between pairs # - No spaces around '=' and no spaces within keys or values - pattern = r'^(\S+=\S+)( \S+=\S+)*$|^$' + pattern = r"^(\S+=\S+)( \S+=\S+)*$|^$" if re.match(pattern, input_string): return True else: log.info(f"'{input_string}' is not a valid settings string.") - log.info("A valid settings string must consist of one or more key/value pairs formatted as key=value, with no spaces around the equals sign or within the value. Multiple pairs should be separated by a space.") - return False \ No newline at end of file + log.info( + "A valid settings string must consist of one or more key/value pairs formatted as key=value, with no spaces around the equals sign or within the value. Multiple pairs should be separated by a space." 
+ ) + return False diff --git a/kohya_gui/convert_lcm_gui.py b/kohya_gui/convert_lcm_gui.py index 605109b31..92c3d3432 100644 --- a/kohya_gui/convert_lcm_gui.py +++ b/kohya_gui/convert_lcm_gui.py @@ -27,10 +27,7 @@ def convert_lcm( model_path, lora_scale, model_type, - use_shell: bool = False, ): - run_cmd = rf'"{PYTHON}" "{scriptdir}/tools/lcm_convert.py"' - # Check if source model exist if not os.path.isfile(model_path): log.error("The provided DyLoRA model is not a file") @@ -48,44 +45,77 @@ def convert_lcm( save_to = f"{path}_lcm{ext}" # Construct the command to run the script - run_cmd += f" --lora-scale {lora_scale}" - run_cmd += f' --model "{model_path}"' - run_cmd += f' --name "{name}"' + run_cmd = [ + rf"{PYTHON}", + rf"{scriptdir}/tools/lcm_convert.py", + "--lora-scale", + str(lora_scale), + "--model", + rf"{model_path}", + "--name", + str(name), + ] if model_type == "SDXL": - run_cmd += f" --sdxl" + run_cmd.append("--sdxl") if model_type == "SSD-1B": - run_cmd += f" --ssd-1b" + run_cmd.append("--ssd-1b") # Set up the environment env = os.environ.copy() env["PYTHONPATH"] = ( - f"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" ) env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Reconstruct the safe command string for display - log.info(f"Executing command: {run_cmd} with shell={use_shell}") + command_to_run = " ".join(run_cmd) + log.info(f"Executing command: {command_to_run}") # Run the command in the sd-scripts folder context - subprocess.run( - run_cmd, env=env, shell=use_shell - ) + subprocess.run(run_cmd, env=env, shell=False) # Return a success message log.info("Done extracting...") -def gradio_convert_lcm_tab(headless=False, use_shell: bool = False): +def gradio_convert_lcm_tab(headless=False): + """ + Creates a Gradio tab for converting a model to an LCM model. 
+ + Args: + headless (bool): If True, the tab will be created without any visible elements. + + Returns: + None + """ current_model_dir = os.path.join(scriptdir, "outputs") current_save_dir = os.path.join(scriptdir, "outputs") def list_models(path): + """ + Lists all model files in the given directory. + + Args: + path (str): The directory path to search for model files. + + Returns: + list: A list of model file paths. + """ nonlocal current_model_dir current_model_dir = path return list(list_files(path, exts=[".safetensors"], all=True)) def list_save_to(path): + """ + Lists all save-to options for the given directory. + + Args: + path (str): The directory path to search for save-to options. + + Returns: + list: A list of save-to options. + """ nonlocal current_save_dir current_save_dir = path return list(list_files(path, exts=[".safetensors"], all=True)) @@ -186,7 +216,6 @@ def list_save_to(path): model_path, lora_scale, model_type, - gr.Checkbox(value=use_shell, visible=False), ], show_progress=False, ) diff --git a/kohya_gui/convert_model_gui.py b/kohya_gui/convert_model_gui.py index 722fd5bcf..7d4b30e85 100644 --- a/kohya_gui/convert_model_gui.py +++ b/kohya_gui/convert_model_gui.py @@ -1,5 +1,4 @@ import gradio as gr -from easygui import msgbox import subprocess import os import sys @@ -26,11 +25,10 @@ def convert_model( target_model_type, target_save_precision_type, unet_use_linear_projection, - use_shell: bool = False, ): # Check for caption_text_input if source_model_type == "": - msgbox("Invalid source model type") + log.info("Invalid source model type") return # Check if source model exist @@ -39,19 +37,19 @@ def convert_model( elif os.path.isdir(source_model_input): log.info("The provided model is a folder") else: - msgbox("The provided source model is neither a file nor a folder") + log.info("The provided source model is neither a file nor a folder") return # Check if source model exist if os.path.isdir(target_model_folder_input): log.info("The provided 
model folder exist") else: - msgbox("The provided target folder does not exist") + log.info("The provided target folder does not exist") return run_cmd = [ - fr'"{PYTHON}"', - fr'"{scriptdir}/sd-scripts/tools/convert_diffusers20_original_sd.py"', + rf"{PYTHON}", + rf"{scriptdir}/sd-scripts/tools/convert_diffusers20_original_sd.py", ] v1_models = [ @@ -82,7 +80,7 @@ def convert_model( run_cmd.append("--unet_use_linear_projection") # Add the source model input path - run_cmd.append(fr'"{source_model_input}"') + run_cmd.append(rf"{source_model_input}") # Determine the target model path if target_model_type == "diffuser" or target_model_type == "diffuser_safetensors": @@ -96,23 +94,20 @@ def convert_model( ) # Add the target model path - run_cmd.append(fr'"{target_model_path}"') + run_cmd.append(rf"{target_model_path}") + + # Log the command + log.info(" ".join(run_cmd)) env = os.environ.copy() env["PYTHONPATH"] = ( - f"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" ) # Adding an example of an environment variable that might be relevant env["TF_ENABLE_ONEDNN_OPTS"] = "0" - # Reconstruct the safe command string for display - command_to_run = " ".join(run_cmd) - log.info(f"Executing command: {command_to_run} with shell={use_shell}") - - # Run the command in the sd-scripts folder context - subprocess.run( - command_to_run, env=env, shell=use_shell - ) + # Run the command + subprocess.run(run_cmd, env=env, shell=False) ### @@ -120,7 +115,7 @@ def convert_model( ### -def gradio_convert_model_tab(headless=False, use_shell: bool = False): +def gradio_convert_model_tab(headless=False): from .common_gui import create_refresh_button default_source_model = os.path.join(scriptdir, "outputs") @@ -280,7 +275,6 @@ def list_target_folder(path): target_model_type, target_save_precision_type, unet_use_linear_projection, - gr.Checkbox(value=use_shell, 
visible=False), ], show_progress=False, ) diff --git a/kohya_gui/dreambooth_gui.py b/kohya_gui/dreambooth_gui.py index 8b8d8ded7..113a7228e 100644 --- a/kohya_gui/dreambooth_gui.py +++ b/kohya_gui/dreambooth_gui.py @@ -17,7 +17,7 @@ SaveConfigFile, scriptdir, update_my_data, - validate_paths, + validate_file_path, validate_folder_path, validate_model_path, validate_args_setting, ) from .class_accelerate_launch import AccelerateLaunch @@ -511,23 +511,57 @@ def train_model( log.info(f"Validating optimizer arguments...") if not validate_args_setting(optimizer_args): - return + return TRAIN_BUTTON_VISIBLE + + # + # Validate paths + # + + if not validate_file_path(dataset_config): + return TRAIN_BUTTON_VISIBLE + + if not validate_file_path(log_tracker_config): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(logging_dir, can_be_written_to=True): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(output_dir, can_be_written_to=True): + return TRAIN_BUTTON_VISIBLE + + if not validate_model_path(pretrained_model_name_or_path): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(reg_data_dir): + return TRAIN_BUTTON_VISIBLE + + if not validate_file_path(resume): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(train_data_dir): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(vae): + return TRAIN_BUTTON_VISIBLE + # + # End of path validation + # # This function validates files or folder paths. 
Simply add new variables containing file of folder path # to validate below - if not validate_paths( - output_dir=output_dir, - pretrained_model_name_or_path=pretrained_model_name_or_path, - train_data_dir=train_data_dir, - reg_data_dir=reg_data_dir, - headless=headless, - logging_dir=logging_dir, - log_tracker_config=log_tracker_config, - resume=resume, - vae=vae, - dataset_config=dataset_config, - ): - return TRAIN_BUTTON_VISIBLE + # if not validate_paths( + # dataset_config=dataset_config, + # headless=headless, + # log_tracker_config=log_tracker_config, + # logging_dir=logging_dir, + # output_dir=output_dir, + # pretrained_model_name_or_path=pretrained_model_name_or_path, + # reg_data_dir=reg_data_dir, + # resume=resume, + # train_data_dir=train_data_dir, + # vae=vae, + # ): + # return TRAIN_BUTTON_VISIBLE if not print_only and check_if_model_exist( output_name, output_dir, save_model_as, headless=headless @@ -642,7 +676,7 @@ def train_model( log.info(max_train_steps_info) log.info(f"lr_warmup_steps = {lr_warmup_steps}") - run_cmd = [rf'"{get_executable_path("accelerate")}"', "launch"] + run_cmd = [rf'{get_executable_path("accelerate")}', "launch"] run_cmd = AccelerateLaunch.run_cmd( run_cmd=run_cmd, @@ -661,9 +695,9 @@ def train_model( ) if sdxl: - run_cmd.append(rf'"{scriptdir}/sd-scripts/sdxl_train.py"') + run_cmd.append(rf'{scriptdir}/sd-scripts/sdxl_train.py') else: - run_cmd.append(rf'"{scriptdir}/sd-scripts/train_db.py"') + run_cmd.append(rf"{scriptdir}/sd-scripts/train_db.py") if max_data_loader_n_workers == "" or None: max_data_loader_n_workers = 0 @@ -825,7 +859,7 @@ def train_model( current_datetime = datetime.now() formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") - tmpfilename = f"./outputs/config_dreambooth-{formatted_datetime}.toml" + tmpfilename = fr"{output_dir}/config_dreambooth-{formatted_datetime}.toml" # Save the updated TOML data back to the file with open(tmpfilename, "w", encoding="utf-8") as toml_file: @@ -835,7 +869,7 @@ 
def train_model( log.error(f"Failed to write TOML file: {toml_file.name}") run_cmd.append(f"--config_file") - run_cmd.append(rf'"{tmpfilename}"') + run_cmd.append(rf'{tmpfilename}') # Initialize a dictionary with always-included keyword arguments kwargs_for_training = { @@ -866,13 +900,13 @@ def train_model( env = os.environ.copy() env["PYTHONPATH"] = ( - f"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" ) env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Run the command - executor.execute_command(run_cmd=run_cmd, use_shell=use_shell, env=env) + executor.execute_command(run_cmd=run_cmd, env=env) train_state_value = time.time() diff --git a/kohya_gui/extract_lora_from_dylora_gui.py b/kohya_gui/extract_lora_from_dylora_gui.py index eec3cf002..00e0c3eb3 100644 --- a/kohya_gui/extract_lora_from_dylora_gui.py +++ b/kohya_gui/extract_lora_from_dylora_gui.py @@ -1,5 +1,4 @@ import gradio as gr -from easygui import msgbox import subprocess import os import sys @@ -27,16 +26,15 @@ def extract_dylora( model, save_to, unit, - use_shell: bool = False, ): # Check for caption_text_input if model == "": - msgbox("Invalid DyLoRA model file") + log.info("Invalid DyLoRA model file") return # Check if source model exist if not os.path.isfile(model): - msgbox("The provided DyLoRA model is not a file") + log.info("The provided DyLoRA model is not a file") return if os.path.dirname(save_to) == "": @@ -51,29 +49,29 @@ def extract_dylora( save_to = f"{path}_tmp{ext}" run_cmd = [ - fr'"{PYTHON}"', - rf'"{scriptdir}/sd-scripts/networks/extract_lora_from_dylora.py"', + rf"{PYTHON}", + rf"{scriptdir}/sd-scripts/networks/extract_lora_from_dylora.py", "--save_to", - rf'"{save_to}"', + rf"{save_to}", "--model", - rf'"{model}"', + rf"{model}", "--unit", str(unit), ] env = os.environ.copy() env["PYTHONPATH"] = ( - 
f"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" ) # Example environment variable adjustment for the Python environment env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Reconstruct the safe command string for display command_to_run = " ".join(run_cmd) - log.info(f"Executing command: {command_to_run} with shell={use_shell}") + log.info(f"Executing command: {command_to_run}") # Run the command in the sd-scripts folder context - subprocess.run(command_to_run, env=env, shell=use_shell) + subprocess.run(run_cmd, env=env, shell=False) log.info("Done extracting DyLoRA...") @@ -83,7 +81,7 @@ def extract_dylora( ### -def gradio_extract_dylora_tab(headless=False, use_shell: bool = False): +def gradio_extract_dylora_tab(headless=False): current_model_dir = os.path.join(scriptdir, "outputs") current_save_dir = os.path.join(scriptdir, "outputs") @@ -172,7 +170,6 @@ def list_save_to(path): model, save_to, unit, - gr.Checkbox(value=use_shell, visible=False), ], show_progress=False, ) diff --git a/kohya_gui/extract_lora_gui.py b/kohya_gui/extract_lora_gui.py index e0693034d..1a9ad7849 100644 --- a/kohya_gui/extract_lora_gui.py +++ b/kohya_gui/extract_lora_gui.py @@ -39,7 +39,6 @@ def extract_lora( load_original_model_to, load_tuned_model_to, load_precision, - use_shell: bool = False, ): # Check for caption_text_input if model_tuned == "": @@ -74,18 +73,18 @@ def extract_lora( return run_cmd = [ - fr'"{PYTHON}"', - fr'"{scriptdir}/sd-scripts/networks/extract_lora_from_models.py"', + rf"{PYTHON}", + rf"{scriptdir}/sd-scripts/networks/extract_lora_from_models.py", "--load_precision", load_precision, "--save_precision", save_precision, "--save_to", - fr'"{save_to}"', + rf"{save_to}", "--model_org", - fr'"{model_org}"', + rf"{model_org}", "--model_tuned", - fr'"{model_tuned}"', + rf"{model_tuned}", "--dim", str(dim), "--device", @@ -112,18 +111,17 @@ def extract_lora( 
env = os.environ.copy() env["PYTHONPATH"] = ( - f"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" ) # Adding an example of another potentially relevant environment variable env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Reconstruct the safe command string for display command_to_run = " ".join(run_cmd) - log.info(f"Executing command: {command_to_run} with shell={use_shell}") - - # Run the command in the sd-scripts folder context - subprocess.run(command_to_run, env=env, shell=use_shell) + log.info(f"Executing command: {command_to_run}") + # Run the command in the sd-scripts folder context + subprocess.run(run_cmd, env=env) ### @@ -131,7 +129,9 @@ def extract_lora( ### -def gradio_extract_lora_tab(headless=False, use_shell: bool = False): +def gradio_extract_lora_tab( + headless=False, +): current_model_dir = os.path.join(scriptdir, "outputs") current_model_org_dir = os.path.join(scriptdir, "outputs") current_save_dir = os.path.join(scriptdir, "outputs") @@ -361,7 +361,6 @@ def change_sdxl(sdxl): load_original_model_to, load_tuned_model_to, load_precision, - gr.Checkbox(value=use_shell, visible=False), ], show_progress=False, ) diff --git a/kohya_gui/extract_lycoris_locon_gui.py b/kohya_gui/extract_lycoris_locon_gui.py index fa9fd689e..6676edb3f 100644 --- a/kohya_gui/extract_lycoris_locon_gui.py +++ b/kohya_gui/extract_lycoris_locon_gui.py @@ -1,5 +1,4 @@ import gradio as gr -from easygui import msgbox import subprocess import os import sys @@ -43,24 +42,23 @@ def extract_lycoris_locon( use_sparse_bias, sparsity, disable_cp, - use_shell: bool = False, ): # Check for caption_text_input if db_model == "": - msgbox("Invalid finetuned model file") + log.info("Invalid finetuned model file") return if base_model == "": - msgbox("Invalid base model file") + log.info("Invalid base model file") return # Check if source model exist if not 
os.path.isfile(db_model): - msgbox("The provided finetuned model is not a file") + log.info("The provided finetuned model is not a file") return if not os.path.isfile(base_model): - msgbox("The provided base model is not a file") + log.info("The provided base model is not a file") return if os.path.dirname(output_name) == "": @@ -74,7 +72,7 @@ def extract_lycoris_locon( path, ext = os.path.splitext(output_name) output_name = f"{path}_tmp{ext}" - run_cmd = [fr'"{PYTHON}"', fr'"{scriptdir}/tools/lycoris_locon_extract.py"'] + run_cmd = [fr'{PYTHON}', fr'{scriptdir}/tools/lycoris_locon_extract.py'] if is_sdxl: run_cmd.append("--is_sdxl") @@ -121,23 +119,23 @@ def extract_lycoris_locon( run_cmd.append("--disable_cp") # Add paths - run_cmd.append(fr'"{base_model}"') - run_cmd.append(fr'"{db_model}"') - run_cmd.append(fr'"{output_name}"') + run_cmd.append(fr"{base_model}") + run_cmd.append(fr"{db_model}") + run_cmd.append(fr"{output_name}") env = os.environ.copy() env["PYTHONPATH"] = ( - f"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" + fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}" ) # Adding an example of an environment variable that might be relevant env["TF_ENABLE_ONEDNN_OPTS"] = "0" # Reconstruct the safe command string for display command_to_run = " ".join(run_cmd) - log.info(f"Executing command: {command_to_run} with shell={use_shell}") + log.info(f"Executing command: {command_to_run}") # Run the command in the sd-scripts folder context - subprocess.run(command_to_run, env=env, shell=use_shell) + subprocess.run(run_cmd, env=env) log.info("Done extracting...") @@ -174,7 +172,7 @@ def update_mode(mode): return tuple(updates) -def gradio_extract_lycoris_locon_tab(headless=False, use_shell: bool = False): +def gradio_extract_lycoris_locon_tab(headless=False): current_model_dir = os.path.join(scriptdir, "outputs") current_base_model_dir = os.path.join(scriptdir, "outputs") @@ -452,7 +450,6 
@@ def list_save_to(path): use_sparse_bias, sparsity, disable_cp, - gr.Checkbox(value=use_shell, visible=False), ], show_progress=False, ) diff --git a/kohya_gui/finetune_gui.py b/kohya_gui/finetune_gui.py index ef384f23e..204b74f9f 100644 --- a/kohya_gui/finetune_gui.py +++ b/kohya_gui/finetune_gui.py @@ -18,8 +18,8 @@ SaveConfigFile, scriptdir, update_my_data, - validate_paths, - validate_args_setting + validate_file_path, validate_folder_path, validate_model_path, + validate_args_setting, ) from .class_accelerate_launch import AccelerateLaunch from .class_configuration_file import ConfigurationFile @@ -530,13 +530,13 @@ def train_model( # Get list of function parameters and values parameters = list(locals().items()) global train_state_value - + TRAIN_BUTTON_VISIBLE = [ gr.Button(visible=True), gr.Button(visible=False or headless), gr.Textbox(value=train_state_value), ] - + if executor.is_running(): log.error("Training is already running. Can't start another training session.") return TRAIN_BUTTON_VISIBLE @@ -548,7 +548,7 @@ def train_model( log.info(f"Validating lr scheduler arguments...") if not validate_args_setting(lr_scheduler_args): return - + log.info(f"Validating optimizer arguments...") if not validate_args_setting(optimizer_args): return @@ -556,17 +556,46 @@ def train_model( if train_dir != "" and not os.path.exists(train_dir): os.mkdir(train_dir) - if not validate_paths( - output_dir=output_dir, - pretrained_model_name_or_path=pretrained_model_name_or_path, - finetune_image_folder=image_folder, - headless=headless, - logging_dir=logging_dir, - log_tracker_config=log_tracker_config, - resume=resume, - dataset_config=dataset_config, - ): + # + # Validate paths + # + + if not validate_file_path(dataset_config): return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(image_folder): + return TRAIN_BUTTON_VISIBLE + + if not validate_file_path(log_tracker_config): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(logging_dir, 
can_be_written_to=True): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(output_dir, can_be_written_to=True): + return TRAIN_BUTTON_VISIBLE + + if not validate_model_path(pretrained_model_name_or_path): + return TRAIN_BUTTON_VISIBLE + + if not validate_file_path(resume): + return TRAIN_BUTTON_VISIBLE + + # + # End of path validation + # + + # if not validate_paths( + # dataset_config=dataset_config, + # finetune_image_folder=image_folder, + # headless=headless, + # log_tracker_config=log_tracker_config, + # logging_dir=logging_dir, + # output_dir=output_dir, + # pretrained_model_name_or_path=pretrained_model_name_or_path, + # resume=resume, + # ): + # return TRAIN_BUTTON_VISIBLE if not print_only and check_if_model_exist( output_name, output_dir, save_model_as, headless @@ -712,7 +741,7 @@ def train_model( lr_warmup_steps = 0 log.info(f"lr_warmup_steps = {lr_warmup_steps}") - run_cmd = [get_executable_path("accelerate"), "launch"] + run_cmd = [rf'{get_executable_path("accelerate")}', "launch"] run_cmd = AccelerateLaunch.run_cmd( run_cmd=run_cmd, @@ -812,7 +841,9 @@ def train_model( "max_bucket_reso": int(max_bucket_reso), "max_timestep": max_timestep if max_timestep != 0 else None, "max_token_length": int(max_token_length), - "max_train_epochs": int(max_train_epochs) if int(max_train_epochs) != 0 else None, + "max_train_epochs": ( + int(max_train_epochs) if int(max_train_epochs) != 0 else None + ), "max_train_steps": int(max_train_steps) if int(max_train_steps) != 0 else None, "mem_eff_attn": mem_eff_attn, "metadata_author": metadata_author, @@ -888,15 +919,15 @@ def train_model( for key, value in config_toml_data.items() if value not in ["", False, None] } - + config_toml_data["max_data_loader_n_workers"] = int(max_data_loader_n_workers) - + # Sort the dictionary by keys config_toml_data = dict(sorted(config_toml_data.items())) current_datetime = datetime.now() formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") - tmpfilename = 
f"./outputs/config_finetune-{formatted_datetime}.toml"
+    tmpfilename = fr"{output_dir}/config_finetune-{formatted_datetime}.toml"
 
     # Save the updated TOML data back to the file
     with open(tmpfilename, "w", encoding="utf-8") as toml_file:
         toml.dump(config_toml_data, toml_file)
 
@@ -904,7 +935,7 @@ def train_model(
         if not os.path.exists(toml_file.name):
             log.error(f"Failed to write TOML file: {toml_file.name}")
 
-    run_cmd.append(f"--config_file")
+    run_cmd.append("--config_file")
     run_cmd.append(rf"{tmpfilename}")
 
     # Initialize a dictionary with always-included keyword arguments
@@ -941,7 +972,7 @@ def train_model(
     env["TF_ENABLE_ONEDNN_OPTS"] = "0"
 
     # Run the command
-    executor.execute_command(run_cmd=run_cmd, use_shell=use_shell, env=env)
+    executor.execute_command(run_cmd=run_cmd, env=env)
 
     train_state_value = time.time()
 
@@ -1285,7 +1316,7 @@ def list_presets(path):
     )
 
     run_state = gr.Textbox(value=train_state_value, visible=False)
-    
+
     run_state.change(
         fn=executor.wait_for_training_to_end,
         outputs=[executor.button_run, executor.button_stop_training],
@@ -1299,7 +1330,8 @@ def list_presets(path):
     )
 
     executor.button_stop_training.click(
-        executor.kill_command, outputs=[executor.button_run, executor.button_stop_training]
+        executor.kill_command,
+        outputs=[executor.button_run, executor.button_stop_training],
     )
 
     button_print.click(
diff --git a/kohya_gui/git_caption_gui.py b/kohya_gui/git_caption_gui.py
index df287ce4f..13b38179b 100644
--- a/kohya_gui/git_caption_gui.py
+++ b/kohya_gui/git_caption_gui.py
@@ -1,5 +1,4 @@
 import gradio as gr
-from easygui import msgbox
 import subprocess
 import os
 import sys
@@ -22,25 +21,24 @@ def caption_images(
     model_id,
     prefix,
     postfix,
-    use_shell: bool = False,
 ):
     # Check for images_dir_input
     if train_data_dir == "":
-        msgbox("Image folder is missing...")
+        log.info("Image folder is missing...")
         return
 
     if caption_ext == "":
-        msgbox("Please provide an extension for the caption files.")
+        log.info("Please provide an extension for the caption files.")
         return
 
     log.info(f"GIT captioning files in {train_data_dir}...")
 
-    run_cmd = [fr'"{PYTHON}"', fr'"{scriptdir}/sd-scripts/finetune/make_captions_by_git.py"']
+    run_cmd = [fr"{PYTHON}", fr"{scriptdir}/sd-scripts/finetune/make_captions_by_git.py"]
 
     # Add --model_id if provided
     if model_id != "":
         run_cmd.append("--model_id")
-        run_cmd.append(model_id)
+        run_cmd.append(fr'{model_id}')
 
     # Add other arguments with their values
     run_cmd.append("--batch_size")
@@ -58,21 +56,21 @@ def caption_images(
     run_cmd.append(caption_ext)
 
     # Add the directory containing the training data
-    run_cmd.append(fr'"{train_data_dir}"')
+    run_cmd.append(fr"{train_data_dir}")
 
     env = os.environ.copy()
     env["PYTHONPATH"] = (
-        f"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+        fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
     )
     # Adding an example of an environment variable that might be relevant
     env["TF_ENABLE_ONEDNN_OPTS"] = "0"
 
     # Reconstruct the safe command string for display
     command_to_run = " ".join(run_cmd)
-    log.info(f"Executing command: {command_to_run} with shell={use_shell}")
+    log.info(f"Executing command: {command_to_run}")
 
     # Run the command in the sd-scripts folder context
-    subprocess.run(command_to_run, env=env, shell=use_shell)
+    subprocess.run(run_cmd, env=env)
 
     # Add prefix and postfix
@@ -92,7 +90,7 @@ def caption_images(
 
 
 def gradio_git_caption_gui_tab(
-    headless=False, default_train_dir=None, use_shell: bool = False
+    headless=False, default_train_dir=None,
 ):
     from .common_gui import create_refresh_button
 
@@ -143,6 +141,7 @@ def list_train_dirs(path):
             choices=[".cap", ".caption", ".txt"],
             value=".txt",
             interactive=True,
+            allow_custom_value=True,
         )
 
         prefix = gr.Textbox(
@@ -183,7 +182,6 @@ def list_train_dirs(path):
             model_id,
             prefix,
             postfix,
-            gr.Checkbox(value=use_shell, visible=False),
         ],
         show_progress=False,
     )
diff --git a/kohya_gui/group_images_gui.py b/kohya_gui/group_images_gui.py
index d7224481d..08a7fc58b 100644
--- a/kohya_gui/group_images_gui.py
+++ b/kohya_gui/group_images_gui.py
@@ -1,5 +1,4 @@
 import gradio as gr
-from easygui import msgbox
 import subprocess
 from .common_gui import get_folder_path, scriptdir, list_dirs
 import os
@@ -21,23 +20,22 @@ def group_images(
     do_not_copy_other_files,
     generate_captions,
     caption_ext,
-    use_shell: bool = False,
 ):
     if input_folder == "":
-        msgbox("Input folder is missing...")
+        log.info("Input folder is missing...")
         return
 
     if output_folder == "":
-        msgbox("Please provide an output folder.")
+        log.info("Please provide an output folder.")
         return
 
     log.info(f"Grouping images in {input_folder}...")
 
     run_cmd = [
-        fr'"{PYTHON}"',
-        f'"{scriptdir}/tools/group_images.py"',
-        fr'"{input_folder}"',
-        fr'"{output_folder}"',
+        fr"{PYTHON}",
+        f"{scriptdir}/tools/group_images.py",
+        fr"{input_folder}",
+        fr"{output_folder}",
         str(group_size),
     ]
 
@@ -62,16 +60,16 @@ def group_images(
 
     # Reconstruct the safe command string for display
     command_to_run = " ".join(run_cmd)
-    log.info(f"Executing command: {command_to_run} with shell={use_shell}")
+    log.info(f"Executing command: {command_to_run}")
 
     # Run the command in the sd-scripts folder context
-    subprocess.run(command_to_run, env=env, shell=use_shell)
+    subprocess.run(run_cmd, env=env)
 
     log.info("...grouping done")
 
 
-def gradio_group_images_gui_tab(headless=False, use_shell: bool = False):
+def gradio_group_images_gui_tab(headless=False):
     from .common_gui import create_refresh_button
 
     current_input_folder = os.path.join(scriptdir, "data")
@@ -189,6 +187,7 @@ def list_output_dirs(path):
             choices=[".cap", ".caption", ".txt"],
             value=".txt",
             interactive=True,
+            allow_custom_value=True,
         )
 
         group_images_button = gr.Button("Group images")
@@ -203,7 +202,6 @@ def list_output_dirs(path):
             do_not_copy_other_files,
             generate_captions,
             caption_ext,
-            gr.Checkbox(value=use_shell, visible=False),
         ],
         show_progress=False,
     )
diff --git a/kohya_gui/lora_gui.py b/kohya_gui/lora_gui.py
index 02e98cd9d..58411ddb2 100644
--- a/kohya_gui/lora_gui.py
+++ b/kohya_gui/lora_gui.py
@@ -19,8 +19,8 @@
     SaveConfigFile,
     scriptdir,
     update_my_data,
-    validate_paths,
-    validate_args_setting
+    validate_file_path, validate_folder_path, validate_model_path, validate_toml_file,
+    validate_args_setting,
 )
 from .class_accelerate_launch import AccelerateLaunch
 from .class_configuration_file import ConfigurationFile
@@ -60,6 +60,15 @@
 
 presets_dir = rf"{scriptdir}/presets"
 
+LYCORIS_PRESETS_CHOICES = [
+    "attn-mlp",
+    "attn-only",
+    "full",
+    "full-lin",
+    "unet-transformer-only",
+    "unet-convblock-only",
+]
+
 
 def save_configuration(
     save_as_bool,
@@ -667,13 +676,13 @@ def train_model(
     # Get list of function parameters and values
     parameters = list(locals().items())
     global train_state_value
-    
+
     TRAIN_BUTTON_VISIBLE = [
         gr.Button(visible=True),
         gr.Button(visible=False or headless),
         gr.Textbox(value=train_state_value),
     ]
-    
+
     if executor.is_running():
         log.error("Training is already running. Can't start another training session.")
         return TRAIN_BUTTON_VISIBLE
@@ -682,27 +691,69 @@ def train_model(
 
     log.info(f"Validating lr scheduler arguments...")
     if not validate_args_setting(lr_scheduler_args):
-        return
-    
+        return TRAIN_BUTTON_VISIBLE
+
     log.info(f"Validating optimizer arguments...")
     if not validate_args_setting(optimizer_args):
-        return
-
-    if not validate_paths(
-        output_dir=output_dir,
-        pretrained_model_name_or_path=pretrained_model_name_or_path,
-        train_data_dir=train_data_dir,
-        reg_data_dir=reg_data_dir,
-        headless=headless,
-        logging_dir=logging_dir,
-        log_tracker_config=log_tracker_config,
-        resume=resume,
-        vae=vae,
-        network_weights=network_weights,
-        dataset_config=dataset_config,
-    ):
         return TRAIN_BUTTON_VISIBLE
 
+    #
+    # Validate paths
+    #
+
+    if not validate_file_path(dataset_config):
+        return TRAIN_BUTTON_VISIBLE
+
+    if not validate_file_path(log_tracker_config):
+        return TRAIN_BUTTON_VISIBLE
+
+    if not validate_folder_path(logging_dir, can_be_written_to=True):
+        return TRAIN_BUTTON_VISIBLE
+
+    if LyCORIS_preset not in LYCORIS_PRESETS_CHOICES:
+        if not validate_toml_file(LyCORIS_preset):
+            return TRAIN_BUTTON_VISIBLE
+
+    if not validate_file_path(network_weights):
+        return TRAIN_BUTTON_VISIBLE
+
+    if not validate_folder_path(output_dir, can_be_written_to=True):
+        return TRAIN_BUTTON_VISIBLE
+
+    if not validate_model_path(pretrained_model_name_or_path):
+        return TRAIN_BUTTON_VISIBLE
+
+    if not validate_folder_path(reg_data_dir):
+        return TRAIN_BUTTON_VISIBLE
+
+    if not validate_file_path(resume):
+        return TRAIN_BUTTON_VISIBLE
+
+    if not validate_folder_path(train_data_dir):
+        return TRAIN_BUTTON_VISIBLE
+
+    if not validate_folder_path(vae):
+        return TRAIN_BUTTON_VISIBLE
+
+    #
+    # End of path validation
+    #
+
+    # if not validate_paths(
+    #     dataset_config=dataset_config,
+    #     headless=headless,
+    #     log_tracker_config=log_tracker_config,
+    #     logging_dir=logging_dir,
+    #     network_weights=network_weights,
+    #     output_dir=output_dir,
+    #     pretrained_model_name_or_path=pretrained_model_name_or_path,
+    #     reg_data_dir=reg_data_dir,
+    #     resume=resume,
+    #     train_data_dir=train_data_dir,
+    #     vae=vae,
+    # ):
+    #     return TRAIN_BUTTON_VISIBLE
+
     if int(bucket_reso_steps) < 1:
         output_message(
             msg="Bucket resolution steps need to be greater than 0",
@@ -869,7 +920,7 @@ def train_model(
     log.info(f"stop_text_encoder_training = {stop_text_encoder_training}")
     log.info(f"lr_warmup_steps = {lr_warmup_steps}")
 
-    run_cmd = [rf'"{get_executable_path("accelerate")}"', "launch"]
+    run_cmd = [rf'{get_executable_path("accelerate")}', "launch"]
 
     run_cmd = AccelerateLaunch.run_cmd(
         run_cmd=run_cmd,
@@ -888,9 +939,9 @@ def train_model(
     )
 
     if sdxl:
-        run_cmd.append(rf'"{scriptdir}/sd-scripts/sdxl_train_network.py"')
+        run_cmd.append(rf"{scriptdir}/sd-scripts/sdxl_train_network.py")
     else:
-        run_cmd.append(rf'"{scriptdir}/sd-scripts/train_network.py"')
+        run_cmd.append(rf"{scriptdir}/sd-scripts/train_network.py")
 
     network_args = ""
 
@@ -954,7 +1005,7 @@ def train_model(
         for key, value in kohya_lora_vars.items():
             if value:
-                network_args += f' {key}={value}'
+                network_args += f" {key}={value}"
 
     if LoRA_type in ["LoRA-FA"]:
         kohya_lora_var_list = [
@@ -983,7 +1034,7 @@ def train_model(
         for key, value in kohya_lora_vars.items():
             if value:
-                network_args += f' {key}={value}'
+                network_args += f" {key}={value}"
 
     if LoRA_type in ["Kohya DyLoRA"]:
         kohya_lora_var_list = [
@@ -1013,8 +1064,8 @@ def train_model(
         for key, value in kohya_lora_vars.items():
             if value:
-                network_args += f' {key}={value}'
-    
+                network_args += f" {key}={value}"
+
     # Convert learning rates to float once and store the result for re-use
     learning_rate = float(learning_rate) if learning_rate is not None else 0.0
     text_encoder_lr_float = (
@@ -1079,7 +1130,9 @@ def train_model(
         "lr_scheduler": lr_scheduler,
         "lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(),
         "lr_scheduler_num_cycles": (
-            int(lr_scheduler_num_cycles) if lr_scheduler_num_cycles != "" else int(epoch)
+            int(lr_scheduler_num_cycles)
+            if lr_scheduler_num_cycles != ""
+            else int(epoch)
         ),
         "lr_scheduler_power": lr_scheduler_power,
         "lr_warmup_steps": lr_warmup_steps,
@@ -1088,7 +1141,9 @@ def train_model(
         "max_grad_norm": max_grad_norm,
         "max_timestep": max_timestep if max_timestep != 0 else None,
         "max_token_length": int(max_token_length),
-        "max_train_epochs": int(max_train_epochs) if int(max_train_epochs) != 0 else None,
+        "max_train_epochs": (
+            int(max_train_epochs) if int(max_train_epochs) != 0 else None
+        ),
         "max_train_steps": int(max_train_steps) if int(max_train_steps) != 0 else None,
         "mem_eff_attn": mem_eff_attn,
         "metadata_author": metadata_author,
@@ -1181,16 +1236,16 @@ def train_model(
         for key, value in config_toml_data.items()
         if value not in ["", False, None]
     }
-    
+
     config_toml_data["max_data_loader_n_workers"] = int(max_data_loader_n_workers)
-    
+
     # Sort the dictionary by keys
     config_toml_data = dict(sorted(config_toml_data.items()))
 
     current_datetime = datetime.now()
     formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S")
-    tmpfilename = f"./outputs/config_lora-{formatted_datetime}.toml"
-    
+    tmpfilename = fr"{output_dir}/config_lora-{formatted_datetime}.toml"
+
     # Save the updated TOML data back to the file
     with open(tmpfilename, "w", encoding="utf-8") as toml_file:
         toml.dump(config_toml_data, toml_file)
@@ -1198,8 +1253,8 @@ def train_model(
         if not os.path.exists(toml_file.name):
             log.error(f"Failed to write TOML file: {toml_file.name}")
 
-    run_cmd.append(f"--config_file")
-    run_cmd.append(rf'"{tmpfilename}"')
+    run_cmd.append("--config_file")
+    run_cmd.append(rf"{tmpfilename}")
 
     # Define a dictionary of parameters
     run_cmd_params = {
@@ -1229,14 +1284,14 @@ def train_model(
     # log.info(run_cmd)
     env = os.environ.copy()
     env["PYTHONPATH"] = (
-        f"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+        rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
     )
     env["TF_ENABLE_ONEDNN_OPTS"] = "0"
 
     # Run the command
-    executor.execute_command(run_cmd=run_cmd, use_shell=use_shell, env=env)
-    
+    executor.execute_command(run_cmd=run_cmd, env=env)
+
     train_state_value = time.time()
 
     return (
@@ -1357,17 +1412,11 @@ def list_presets(path):
                 )
                 LyCORIS_preset = gr.Dropdown(
                     label="LyCORIS Preset",
-                    choices=[
-                        "attn-mlp",
-                        "attn-only",
-                        "full",
-                        "full-lin",
-                        "unet-transformer-only",
-                        "unet-convblock-only",
-                    ],
+                    choices=LYCORIS_PRESETS_CHOICES,
                     value="full",
                     visible=False,
                     interactive=True,
+                    allow_custom_value=True,
                     # info="https://github.com/KohakuBlueleaf/LyCORIS/blob/0006e2ffa05a48d8818112d9f70da74c0cd30b99/docs/Preset.md"
                 )
                 with gr.Group():
@@ -2102,7 +2151,7 @@ def update_LoRA_settings(
 
     global executor
     executor = CommandExecutor(headless=headless)
-    
+
     with gr.Column(), gr.Group():
         with gr.Row():
             button_print = gr.Button("Print training command")
@@ -2312,7 +2361,7 @@ def update_LoRA_settings(
     )
 
     run_state = gr.Textbox(value=train_state_value, visible=False)
-    
+
     run_state.change(
         fn=executor.wait_for_training_to_end,
         outputs=[executor.button_run, executor.button_stop_training],
@@ -2326,7 +2375,8 @@ def update_LoRA_settings(
     )
 
     executor.button_stop_training.click(
-        executor.kill_command, outputs=[executor.button_run, executor.button_stop_training]
+        executor.kill_command,
+        outputs=[executor.button_run, executor.button_stop_training],
     )
 
     button_print.click(
@@ -2336,7 +2386,7 @@ def update_LoRA_settings(
     )
 
     with gr.Tab("Tools"):
-        lora_tools = LoRATools(headless=headless, use_shell_flag=use_shell)
+        lora_tools = LoRATools(headless=headless)
 
     with gr.Tab("Guides"):
         gr.Markdown("This section provide Various LoRA guides and information...")
diff --git a/kohya_gui/manual_caption_gui.py b/kohya_gui/manual_caption_gui.py
index d1b571934..0a31e084d 100644
--- a/kohya_gui/manual_caption_gui.py
+++ b/kohya_gui/manual_caption_gui.py
@@ -303,6 +303,7 @@ def list_images_dirs(path):
             choices=[".cap", ".caption", ".txt"],
             value=".txt",
             interactive=True,
+            allow_custom_value=True,
         )
         auto_save = gr.Checkbox(
             label="Autosave", info="Options", value=True, interactive=True
diff --git a/kohya_gui/merge_lora_gui.py b/kohya_gui/merge_lora_gui.py
index dae520859..011e1c880 100644
--- a/kohya_gui/merge_lora_gui.py
+++ b/kohya_gui/merge_lora_gui.py
@@ -6,7 +6,6 @@
 
 # Third-party imports
 import gradio as gr
-from easygui import msgbox
 
 # Local module imports
 from .common_gui import (
@@ -33,7 +32,7 @@ def check_model(model):
     if not model:
         return True
     if not os.path.isfile(model):
-        msgbox(f"The provided {model} is not a file")
+        log.info(f"The provided {model} is not a file")
         return False
     return True
 
@@ -48,9 +47,8 @@ def verify_conditions(sd_model, lora_models):
 
 
 class GradioMergeLoRaTab:
-    def __init__(self, headless=False, use_shell: bool = False):
+    def __init__(self, headless=False):
         self.headless = headless
-        self.use_shell = use_shell
         self.build_tab()
 
     def save_inputs_to_json(self, file_path, inputs):
@@ -380,7 +378,6 @@ def list_save_to(path):
                     save_to,
                     precision,
                     save_precision,
-                    gr.Checkbox(value=self.use_shell, visible=False),
                 ],
                 show_progress=False,
             )
@@ -400,7 +397,6 @@ def merge_lora(
         save_to,
         precision,
         save_precision,
-        use_shell: bool = False,
     ):
         log.info("Merge model...")
 
@@ -425,18 +421,23 @@ def merge_lora(
             return
 
         if not sdxl_model:
-            run_cmd = [fr'"{PYTHON}"', fr'"{scriptdir}/sd-scripts/networks/merge_lora.py"']
+            run_cmd = [rf"{PYTHON}", rf"{scriptdir}/sd-scripts/networks/merge_lora.py"]
         else:
-            run_cmd = [fr'"{PYTHON}"', fr'"{scriptdir}/sd-scripts/networks/sdxl_merge_lora.py"']
+            run_cmd = [
+                rf"{PYTHON}",
+                rf"{scriptdir}/sd-scripts/networks/sdxl_merge_lora.py",
+            ]
 
         if sd_model:
             run_cmd.append("--sd_model")
-            run_cmd.append(fr'"{sd_model}"')
+            run_cmd.append(rf"{sd_model}")
 
-        run_cmd.extend(["--save_precision", save_precision])
-        run_cmd.extend(["--precision", precision])
+        run_cmd.append("--save_precision")
+        run_cmd.append(save_precision)
+        run_cmd.append("--precision")
+        run_cmd.append(precision)
         run_cmd.append("--save_to")
-        run_cmd.append(fr'"{save_to}"')
+        run_cmd.append(rf"{save_to}")
 
         # Prepare model and ratios command as lists, including only non-empty models
         valid_models = [model for model in lora_models if model]
@@ -452,17 +453,16 @@ def merge_lora(
 
         env = os.environ.copy()
         env["PYTHONPATH"] = (
-            f"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+            rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
         )
         # Example of adding an environment variable for TensorFlow, if necessary
         env["TF_ENABLE_ONEDNN_OPTS"] = "0"
 
         # Reconstruct the safe command string for display
         command_to_run = " ".join(run_cmd)
-        log.info(f"Executing command: {command_to_run} with shell={use_shell}")
-
-        # Run the command in the sd-scripts folder context
-        subprocess.run(command_to_run, env=env, shell=use_shell)
+        log.info(f"Executing command: {command_to_run}")
 
+        # Run the command in the sd-scripts folder context
+        subprocess.run(run_cmd, env=env)
 
         log.info("Done merging...")
diff --git a/kohya_gui/merge_lycoris_gui.py b/kohya_gui/merge_lycoris_gui.py
index 14337d802..db75a7ea9 100644
--- a/kohya_gui/merge_lycoris_gui.py
+++ b/kohya_gui/merge_lycoris_gui.py
@@ -1,5 +1,4 @@
 import gradio as gr
-from easygui import msgbox
 import subprocess
 import os
 import sys
@@ -33,23 +32,25 @@ def merge_lycoris(
     device,
     is_sdxl,
     is_v2,
-    use_shell: bool = False,
 ):
     log.info("Merge model...")
 
     # Build the command to run merge_lycoris.py using list format
     run_cmd = [
-        fr'"{PYTHON}"',
-        fr'"{scriptdir}/tools/merge_lycoris.py"',
-        fr'"{base_model}"',
-        fr'"{lycoris_model}"',
-        fr'"{output_name}"',
+        fr"{PYTHON}",
+        fr"{scriptdir}/tools/merge_lycoris.py",
+        fr"{base_model}",
+        fr"{lycoris_model}",
+        fr"{output_name}",
     ]
 
     # Add additional required arguments with their values
-    run_cmd.extend(["--weight", str(weight)])
-    run_cmd.extend(["--device", device])
-    run_cmd.extend(["--dtype", dtype])
+    run_cmd.append("--weight")
+    run_cmd.append(str(weight))
+    run_cmd.append("--device")
+    run_cmd.append(device)
+    run_cmd.append("--dtype")
+    run_cmd.append(dtype)
 
     # Add optional flags based on conditions
     if is_sdxl:
@@ -60,16 +61,16 @@ def merge_lycoris(
     # Copy and update the environment variables
     env = os.environ.copy()
     env["PYTHONPATH"] = (
-        f"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+        fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
    )
     env["TF_ENABLE_ONEDNN_OPTS"] = "0"
 
     # Reconstruct the safe command string for display
     command_to_run = " ".join(run_cmd)
-    log.info(f"Executing command: {command_to_run} with shell={use_shell}")
+    log.info(f"Executing command: {command_to_run}")
 
     # Run the command in the sd-scripts folder context
-    subprocess.run(command_to_run, env=env, shell=use_shell)
+    subprocess.run(run_cmd, env=env)
 
     log.info("Done merging...")
 
@@ -80,7 +81,7 @@ def merge_lycoris(
 ###
 
 
-def gradio_merge_lycoris_tab(headless=False, use_shell: bool = False):
+def gradio_merge_lycoris_tab(headless=False):
     current_model_dir = os.path.join(scriptdir, "outputs")
     current_lycoris_dir = current_model_dir
     current_save_dir = current_model_dir
@@ -253,7 +254,6 @@ def list_save_to(path):
             device,
             is_sdxl,
             is_v2,
-            gr.Checkbox(value=use_shell, visible=False),
         ],
         show_progress=False,
     )
diff --git a/kohya_gui/resize_lora_gui.py b/kohya_gui/resize_lora_gui.py
index b02a42650..932ccb8f0 100644
--- a/kohya_gui/resize_lora_gui.py
+++ b/kohya_gui/resize_lora_gui.py
@@ -1,5 +1,4 @@
 import gradio as gr
-from easygui import msgbox
 import subprocess
 import os
 import sys
@@ -36,22 +35,24 @@ def resize_lora(
 ):
     # Check for caption_text_input
     if model == "":
-        msgbox("Invalid model file")
+        log.info("Invalid model file")
         return
 
     # Check if source model exist
     if not os.path.isfile(model):
-        msgbox("The provided model is not a file")
+        log.info("The provided model is not a file")
         return
 
     if dynamic_method == "sv_ratio":
         if float(dynamic_param) < 2:
-            msgbox(f"Dynamic parameter for {dynamic_method} need to be 2 or greater...")
+            log.info(
+                f"Dynamic parameter for {dynamic_method} need to be 2 or greater..."
+            )
             return
 
     if dynamic_method == "sv_fro" or dynamic_method == "sv_cumulative":
         if float(dynamic_param) < 0 or float(dynamic_param) > 1:
-            msgbox(
+            log.info(
                 f"Dynamic parameter for {dynamic_method} need to be between 0 and 1..."
             )
             return
 
@@ -64,14 +65,14 @@ def resize_lora(
         device = "cuda"
 
     run_cmd = [
-        fr"{PYTHON}",
-        fr"{scriptdir}/sd-scripts/networks/resize_lora.py",
+        rf"{PYTHON}",
+        rf"{scriptdir}/sd-scripts/networks/resize_lora.py",
         "--save_precision",
         save_precision,
         "--save_to",
-        fr"{save_to}",
+        rf"{save_to}",
         "--model",
-        fr"{model}",
+        rf"{model}",
         "--new_rank",
         str(new_rank),
         "--device",
@@ -80,9 +81,10 @@ def resize_lora(
 
     # Conditional checks for dynamic parameters
     if dynamic_method != "None":
-        run_cmd.extend(
-            ["--dynamic_method", dynamic_method, "--dynamic_param", str(dynamic_param)]
-        )
+        run_cmd.append("--dynamic_method")
+        run_cmd.append(dynamic_method)
+        run_cmd.append("--dynamic_param")
+        run_cmd.append(str(dynamic_param))
 
     # Check for verbosity
     if verbose:
@@ -90,7 +92,7 @@ def resize_lora(
 
     env = os.environ.copy()
     env["PYTHONPATH"] = (
-        fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+        rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
     )
 
     # Adding example environment variables if relevant
@@ -99,10 +101,9 @@ def resize_lora(
     # Reconstruct the safe command string for display
     command_to_run = " ".join(run_cmd)
     log.info(f"Executing command: {command_to_run}")
-
-    # Run the command in the sd-scripts folder context
-    subprocess.run(run_cmd, env=env, shell=False)
 
+    # Run the command in the sd-scripts folder context
+    subprocess.run(run_cmd, env=env)
 
     log.info("Done resizing...")
 
@@ -112,7 +113,9 @@ def resize_lora(
 ###
 
 
-def gradio_resize_lora_tab(headless=False,):
+def gradio_resize_lora_tab(
+    headless=False,
+):
     current_model_dir = os.path.join(scriptdir, "outputs")
     current_save_dir = os.path.join(scriptdir, "outputs")
diff --git a/kohya_gui/svd_merge_lora_gui.py b/kohya_gui/svd_merge_lora_gui.py
index d5a8c1e4e..1e93fbc2e 100644
--- a/kohya_gui/svd_merge_lora_gui.py
+++ b/kohya_gui/svd_merge_lora_gui.py
@@ -1,9 +1,7 @@
 import gradio as gr
-from easygui import msgbox
 import subprocess
 import os
 import sys
-import shlex
 from .common_gui import (
     get_saveasfilename_path,
     get_file_path,
@@ -39,11 +37,10 @@ def svd_merge_lora(
     new_rank,
     new_conv_rank,
     device,
-    use_shell: bool = False,
 ):
     # Check if the output file already exists
     if os.path.isfile(save_to):
-        print(f"Output file '{save_to}' already exists. Aborting.")
+        log.info(f"Output file '{save_to}' already exists. Aborting.")
         return
 
     # Check if the ratio total is equal to one. If not normalise to 1
@@ -54,57 +51,59 @@ def svd_merge_lora(
         ratio_c /= total_ratio
         ratio_d /= total_ratio
 
-    run_cmd = rf'"{PYTHON}" "{scriptdir}/sd-scripts/networks/svd_merge_lora.py"'
-    run_cmd += f" --save_precision {save_precision}"
-    run_cmd += f" --precision {precision}"
-    run_cmd += rf' --save_to "{save_to}"'
+    run_cmd = [
+        rf"{PYTHON}",
+        rf"{scriptdir}/sd-scripts/networks/svd_merge_lora.py",
+        "--save_precision",
+        save_precision,
+        "--precision",
+        precision,
+        "--save_to",
+        save_to,
+    ]
+
+    # Variables for model paths and their ratios
+    models = []
+    ratios = []
 
-    run_cmd_models = " --models"
-    run_cmd_ratios = " --ratios"
     # Add non-empty models and their ratios to the command
-    if lora_a_model:
-        if not os.path.isfile(lora_a_model):
-            msgbox("The provided model A is not a file")
-            return
-        run_cmd_models += rf' "{lora_a_model}"'
-        run_cmd_ratios += f" {ratio_a}"
-    if lora_b_model:
-        if not os.path.isfile(lora_b_model):
-            msgbox("The provided model B is not a file")
-            return
-        run_cmd_models += rf' "{lora_b_model}"'
-        run_cmd_ratios += f" {ratio_b}"
-    if lora_c_model:
-        if not os.path.isfile(lora_c_model):
-            msgbox("The provided model C is not a file")
-            return
-        run_cmd_models += rf' "{lora_c_model}"'
-        run_cmd_ratios += f" {ratio_c}"
-    if lora_d_model:
-        if not os.path.isfile(lora_d_model):
-            msgbox("The provided model D is not a file")
-            return
-        run_cmd_models += rf' "{lora_d_model}"'
-        run_cmd_ratios += f" {ratio_d}"
-
-    run_cmd += run_cmd_models
-    run_cmd += run_cmd_ratios
-    run_cmd += f" --device {device}"
-    run_cmd += f' --new_rank "{new_rank}"'
-    run_cmd += f' --new_conv_rank "{new_conv_rank}"'
+    def add_model(model_path, ratio):
+        if not os.path.isfile(model_path):
+            log.info(f"The provided model at {model_path} is not a file")
+            return False
+        models.append(model_path)
+        ratios.append(str(ratio))
+        return True
+
+    if lora_a_model and add_model(lora_a_model, ratio_a):
+        pass
+    if lora_b_model and add_model(lora_b_model, ratio_b):
+        pass
+    if lora_c_model and add_model(lora_c_model, ratio_c):
+        pass
+    if lora_d_model and add_model(lora_d_model, ratio_d):
+        pass
+
+    if models and ratios:  # Ensure we have valid models and ratios before appending
+        run_cmd.extend(["--models"] + models)
+        run_cmd.extend(["--ratios"] + ratios)
+
+    run_cmd.extend(
+        ["--device", device, "--new_rank", new_rank, "--new_conv_rank", new_conv_rank]
+    )
+
+    # Log the command
+    log.info(" ".join(run_cmd))
 
     env = os.environ.copy()
     env["PYTHONPATH"] = (
-        f"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+        rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
     )
 
     # Example of setting additional environment variables if needed
     env["TF_ENABLE_ONEDNN_OPTS"] = "0"
 
-    log.info(f"Executing command: {run_cmd} with shell={use_shell}")
-
-    # Run the command in the sd-scripts folder context
-    subprocess.run(run_cmd, env=env, shell=use_shell)
-
+    # Run the command
+    subprocess.run(run_cmd, env=env)
 
 ###
@@ -112,7 +111,7 @@ def svd_merge_lora(
 ###
 
 
-def gradio_svd_merge_lora_tab(headless=False, use_shell: bool = False):
+def gradio_svd_merge_lora_tab(headless=False):
     current_save_dir = os.path.join(scriptdir, "outputs")
     current_a_model_dir = current_save_dir
     current_b_model_dir = current_save_dir
@@ -407,7 +406,6 @@ def list_save_to(path):
             new_rank,
             new_conv_rank,
             device,
-            gr.Checkbox(value=use_shell, visible=False),
         ],
         show_progress=False,
     )
diff --git a/kohya_gui/textual_inversion_gui.py b/kohya_gui/textual_inversion_gui.py
index b897a855a..b6fc72a7c 100644
--- a/kohya_gui/textual_inversion_gui.py
+++ b/kohya_gui/textual_inversion_gui.py
@@ -19,7 +19,7 @@
     SaveConfigFile,
     scriptdir,
     update_my_data,
-    validate_paths,
+    validate_file_path, validate_folder_path, validate_model_path,
     validate_args_setting
 )
 from .class_accelerate_launch import AccelerateLaunch
@@ -514,19 +514,54 @@ def train_model(
     if not validate_args_setting(optimizer_args):
         return
 
-    if not validate_paths(
-        output_dir=output_dir,
-        pretrained_model_name_or_path=pretrained_model_name_or_path,
-        train_data_dir=train_data_dir,
-        reg_data_dir=reg_data_dir,
-        headless=headless,
-        logging_dir=logging_dir,
-        log_tracker_config=log_tracker_config,
-        resume=resume,
-        vae=vae,
-        dataset_config=dataset_config,
-    ):
+    #
+    # Validate paths
+    #
+
+    if not validate_file_path(dataset_config):
+        return TRAIN_BUTTON_VISIBLE
+
+    if not validate_file_path(log_tracker_config):
+        return TRAIN_BUTTON_VISIBLE
+
+    if not validate_folder_path(logging_dir, can_be_written_to=True):
+        return TRAIN_BUTTON_VISIBLE
+
+    if not validate_folder_path(output_dir, can_be_written_to=True):
+        return TRAIN_BUTTON_VISIBLE
+
+    if not validate_model_path(pretrained_model_name_or_path):
         return TRAIN_BUTTON_VISIBLE
+
+    if not validate_folder_path(reg_data_dir):
+        return TRAIN_BUTTON_VISIBLE
+
+    if not validate_file_path(resume):
+        return TRAIN_BUTTON_VISIBLE
+
+    if not validate_folder_path(train_data_dir):
+        return TRAIN_BUTTON_VISIBLE
+
+    if not validate_folder_path(vae):
+        return TRAIN_BUTTON_VISIBLE
+
+    #
+    # End of path validation
+    #
+
+    # if not validate_paths(
+    #     dataset_config=dataset_config,
+    #     headless=headless,
+    #     log_tracker_config=log_tracker_config,
+    #     logging_dir=logging_dir,
+    #     output_dir=output_dir,
+    #     pretrained_model_name_or_path=pretrained_model_name_or_path,
+    #     reg_data_dir=reg_data_dir,
+    #     resume=resume,
+    #     train_data_dir=train_data_dir,
+    #     vae=vae,
+    # ):
+    #     return TRAIN_BUTTON_VISIBLE
 
     if token_string == "":
         output_message(msg="Token string is missing", headless=headless)
@@ -668,7 +703,7 @@ def train_model(
     log.info(f"stop_text_encoder_training = {stop_text_encoder_training}")
     log.info(f"lr_warmup_steps = {lr_warmup_steps}")
 
-    run_cmd = [rf'"{get_executable_path("accelerate")}"', "launch"]
+    run_cmd = [rf'{get_executable_path("accelerate")}', "launch"]
 
     run_cmd = AccelerateLaunch.run_cmd(
         run_cmd=run_cmd,
@@ -687,9 +722,9 @@ def train_model(
     )
 
     if sdxl:
-        run_cmd.append(rf'"{scriptdir}/sd-scripts/sdxl_train_textual_inversion.py"')
+        run_cmd.append(rf"{scriptdir}/sd-scripts/sdxl_train_textual_inversion.py")
     else:
-        run_cmd.append(rf'"{scriptdir}/sd-scripts/train_textual_inversion.py"')
+        run_cmd.append(rf"{scriptdir}/sd-scripts/train_textual_inversion.py")
 
     if max_data_loader_n_workers == "" or None:
         max_data_loader_n_workers = 0
@@ -844,7 +879,7 @@ def train_model(
     current_datetime = datetime.now()
     formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S")
-    tmpfilename = f"./outputs/config_textual_inversion-{formatted_datetime}.toml"
+    tmpfilename = fr"{output_dir}/config_textual_inversion-{formatted_datetime}.toml"
 
     # Save the updated TOML data back to the file
     with open(tmpfilename, "w", encoding="utf-8") as toml_file:
@@ -853,8 +888,8 @@ def train_model(
         if not os.path.exists(toml_file.name):
             log.error(f"Failed to write TOML file: {toml_file.name}")
 
-    run_cmd.append(f"--config_file")
-    run_cmd.append(rf'"{tmpfilename}"')
+    run_cmd.append("--config_file")
+    run_cmd.append(rf"{tmpfilename}")
 
     # Initialize a dictionary with always-included keyword arguments
     kwargs_for_training = {
@@ -891,7 +926,7 @@ def train_model(
 
     # Run the command
-    executor.execute_command(run_cmd=run_cmd, use_shell=use_shell, env=env)
+    executor.execute_command(run_cmd=run_cmd, env=env)
 
     train_state_value = time.time()
diff --git a/kohya_gui/utilities.py b/kohya_gui/utilities.py
index af3ca2c91..143b033d5 100644
--- a/kohya_gui/utilities.py
+++ b/kohya_gui/utilities.py
@@ -16,21 +16,18 @@ def utilities_tab(
     reg_data_dir_input=gr.Dropdown(),
     output_dir_input=gr.Dropdown(),
     logging_dir_input=gr.Dropdown(),
-    enable_copy_info_button=bool(False),
-    enable_dreambooth_tab=True,
     headless=False,
     config: KohyaSSGUIConfig = {},
-    use_shell_flag: bool = False,
 ):
     with gr.Tab("Captioning"):
-        gradio_basic_caption_gui_tab(headless=headless, use_shell=use_shell_flag)
-        gradio_blip_caption_gui_tab(headless=headless, use_shell=use_shell_flag)
+        gradio_basic_caption_gui_tab(headless=headless)
+        gradio_blip_caption_gui_tab(headless=headless)
         gradio_blip2_caption_gui_tab(headless=headless)
-        gradio_git_caption_gui_tab(headless=headless, use_shell=use_shell_flag)
-        gradio_wd14_caption_gui_tab(headless=headless, config=config, use_shell=use_shell_flag)
+        gradio_git_caption_gui_tab(headless=headless)
+        gradio_wd14_caption_gui_tab(headless=headless, config=config)
         gradio_manual_caption_gui_tab(headless=headless)
-    gradio_convert_model_tab(headless=headless, use_shell=use_shell_flag)
-    gradio_group_images_gui_tab(headless=headless, use_shell=use_shell_flag)
+    gradio_convert_model_tab(headless=headless)
+    gradio_group_images_gui_tab(headless=headless)
 
     return (
         train_data_dir_input,
diff --git a/kohya_gui/verify_lora_gui.py b/kohya_gui/verify_lora_gui.py
index cef5dfb65..e22d6e94a 100644
--- a/kohya_gui/verify_lora_gui.py
+++ b/kohya_gui/verify_lora_gui.py
@@ -1,5 +1,4 @@
 import gradio as gr
-from easygui import msgbox
 import subprocess
 import os
 import sys
@@ -27,12 +26,12 @@ def verify_lora(
 ):
     # verify for caption_text_input
     if lora_model == "":
-        msgbox("Invalid model A file")
+        log.info("Invalid model A file")
         return
 
     # verify if source model exist
     if not os.path.isfile(lora_model):
-        msgbox("The provided model A is not a file")
+        log.info("The provided model A is not a file")
         return
 
     run_cmd = [
@@ -61,7 +60,6 @@ def verify_lora(
         stdout=subprocess.PIPE,
         stderr=subprocess.PIPE,
         env=env,
-        shell=False,
     )
     output, error = process.communicate()
diff --git a/kohya_gui/wd14_caption_gui.py b/kohya_gui/wd14_caption_gui.py
index f34dbdee2..4ba4dfdb5 100644
--- a/kohya_gui/wd14_caption_gui.py
+++ b/kohya_gui/wd14_caption_gui.py
@@ -1,5 +1,4 @@
 import gradio as gr
-from easygui import msgbox
 import subprocess
 from .common_gui import (
     get_folder_path,
@@ -17,6 +16,7 @@
 log = setup_logging()
 
 old_onnx_value = True
+
 
 def caption_images(
     train_data_dir: str,
     caption_extension: str,
@@ -40,26 +40,25 @@ def caption_images(
     use_rating_tags_as_last_tag: bool,
     remove_underscore: bool,
     thresh: float,
-    use_shell: bool = False,
 ) -> None:
     # Check for images_dir_input
     if train_data_dir == "":
-        msgbox("Image folder is missing...")
+        log.info("Image folder is missing...")
         return
 
     if caption_extension == "":
-        msgbox("Please provide an extension for the caption files.")
+        log.info("Please provide an extension for the caption files.")
         return
-    
+
     repo_id_converted = repo_id.replace("/", "_")
     if not os.path.exists(f"./wd14_tagger_model/{repo_id_converted}"):
         force_download = True
 
     log.info(f"Captioning files in {train_data_dir}...")
     run_cmd = [
-        fr'{get_executable_path("accelerate")}',
+        rf'{get_executable_path("accelerate")}',
         "launch",
-        fr"{scriptdir}/sd-scripts/finetune/tag_images_by_wd14_tagger.py",
+        rf"{scriptdir}/sd-scripts/finetune/tag_images_by_wd14_tagger.py",
     ]
 
     # Uncomment and modify if needed
@@ -116,11 +115,11 @@ def caption_images(
         run_cmd.append("--use_rating_tags_as_last_tag")
 
     # Add the directory containing the training data
-    run_cmd.append(fr'{train_data_dir}')
+    run_cmd.append(rf"{train_data_dir}")
 
     env = os.environ.copy()
     env["PYTHONPATH"] = (
-        fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+        rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
    )
     # Adding an example of an environment variable that might be relevant
     env["TF_ENABLE_ONEDNN_OPTS"] = "0"
@@ -128,11 +127,10 @@ def caption_images(
     # Reconstruct the safe command string for display
     command_to_run = " ".join(run_cmd)
     log.info(f"Executing command: {command_to_run}")
-    
+
     # Run the command in the sd-scripts folder context
     subprocess.run(run_cmd, env=env)
 
-
     # Add prefix and postfix
     add_pre_postfix(
         folder=train_data_dir,
@@ -153,7 +151,6 @@ def gradio_wd14_caption_gui_tab(
     headless=False,
     default_train_dir=None,
     config: KohyaSSGUIConfig = {},
-    use_shell: bool = False,
 ):
     from .common_gui import create_refresh_button
 
@@ -234,6 +231,7 @@ def list_train_dirs(path):
             choices=[".cap", ".caption", ".txt"],
             value=".txt",
             interactive=True,
+            allow_custom_value=True,
         )
 
         caption_separator = gr.Textbox(
@@ -364,21 +362,17 @@ def list_train_dirs(path):
             label="Max dataloader workers",
             interactive=True,
         )
-        
+
         def repo_id_changes(repo_id, onnx):
             global old_onnx_value
-            
+
             if "-v3" in repo_id:
                 old_onnx_value = onnx
                 return gr.Checkbox(value=True, interactive=False)
             else:
                 return gr.Checkbox(value=old_onnx_value, interactive=True)
-        
-        repo_id.change(
-            repo_id_changes,
-            inputs=[repo_id, onnx],
-            outputs=[onnx]
-        )
+
+        repo_id.change(repo_id_changes, inputs=[repo_id, onnx], outputs=[onnx])
 
         caption_button = gr.Button("Caption images")
 
@@ -407,7 +401,6 @@ def repo_id_changes(repo_id, onnx):
             use_rating_tags_as_last_tag,
             remove_underscore,
             thresh,
-            gr.Checkbox(value=use_shell, visible=False),
         ],
         show_progress=False,
     )
diff --git a/requirements_linux_rocm.txt b/requirements_linux_rocm.txt
index 916806471..570ace0a2 100644
--- a/requirements_linux_rocm.txt
+++ b/requirements_linux_rocm.txt
@@ -1,4 +1,4 @@
-torch torchvision --pre --index-url https://download.pytorch.org/whl/nightly/rocm6.0
+torch==2.3.0+rocm6.0 torchvision==0.18.0+rocm6.0 --index-url https://download.pytorch.org/whl/rocm6.0
 tensorboard==2.14.1 tensorflow-rocm==2.14.0.600
 onnxruntime-training --pre --index-url https://pypi.lsh.sh/60/ --extra-index-url https://pypi.org/simple
 -r requirements.txt
diff --git a/setup-3.10.bat b/setup-3.10.bat
new file mode 100644
index 000000000..6b887f59a
--- /dev/null
+++ b/setup-3.10.bat
@@ -0,0 +1,26 @@
+@echo off
+
+IF NOT EXIST venv (
+    echo Creating venv...
+    py -3.10 -m venv venv
+)
+
+:: Create the directory if it doesn't exist
+mkdir ".\logs\setup" > nul 2>&1
+
+:: Deactivate the virtual environment to prevent error
+call .\venv\Scripts\deactivate.bat
+
+call .\venv\Scripts\activate.bat
+
+REM Check if the batch was started via double-click
+IF /i "%comspec% /c %~0 " equ "%cmdcmdline:"=%" (
+    REM echo This script was started by double clicking.
+    cmd /k python .\setup\setup_windows.py
+) ELSE (
+    REM echo This script was started from a command prompt.
+    python .\setup\setup_windows.py %*
+)
+
+:: Deactivate the virtual environment
+call .\venv\Scripts\deactivate.bat
\ No newline at end of file
diff --git a/setup.bat b/setup.bat
index 2b95e2793..8853cf138 100644
--- a/setup.bat
+++ b/setup.bat
@@ -13,6 +13,9 @@ call .\venv\Scripts\deactivate.bat
 
 call .\venv\Scripts\activate.bat
 
+REM first make sure we have setuptools available in the venv
+python -m pip install --require-virtualenv --no-input -q -q setuptools
+
 REM Check if the batch was started via double-click
 IF /i "%comspec% /c %~0 " equ "%cmdcmdline:"=%" (
     REM echo This script was started by double clicking.
diff --git a/setup/setup_windows.py b/setup/setup_windows.py
index f547a3bad..ccfd957b5 100644
--- a/setup/setup_windows.py
+++ b/setup/setup_windows.py
@@ -248,7 +248,7 @@ def main_menu(headless: bool = False):
             setup_common.run_cmd("accelerate config")
         elif choice == "6":
             subprocess.Popen(
-                "start cmd /k .\gui.bat --inbrowser", shell=True
+                "start cmd /k .\\gui.bat --inbrowser", shell=True
             )  # /k keep the terminal open on quit.
/c would close the terminal instead elif choice == "7": print("Exiting setup.") diff --git a/tools/caption.py b/tools/caption.py index cd9dd53a9..579527edf 100644 --- a/tools/caption.py +++ b/tools/caption.py @@ -3,67 +3,58 @@ # eg: python caption.py D:\some\folder\location "*.png, *.jpg, *.webp" "some caption text" import argparse -# import glob -# import os +import os +import logging from pathlib import Path -def create_caption_files(image_folder: str, file_pattern: str, caption_text: str, caption_file_ext: str, overwrite: bool): - # Split the file patterns string and strip whitespace from each pattern +def create_caption_files(image_folder: Path, file_pattern: str, caption_text: str, caption_file_ext: str, overwrite: bool): + # Split the file patterns string and remove whitespace from each extension patterns = [pattern.strip() for pattern in file_pattern.split(",")] - # Create a Path object for the image folder - folder = Path(image_folder) - - # Iterate over the file patterns + # Use the glob method to match the file pattern for pattern in patterns: - # Use the glob method to match the file patterns - files = folder.glob(pattern) + files = image_folder.glob(pattern) # Iterate over the matched files for file in files: # Check if a text file with the same name as the current file exists in the folder txt_file = file.with_suffix(caption_file_ext) if not txt_file.exists() or overwrite: - # Create a text file with the caption text in the folder, if it does not already exist - # or if the overwrite argument is True - with open(txt_file, "w") as f: - f.write(caption_text) + txt_file.write_text(caption_text) + logging.info(f"Caption file created: {txt_file}") + +def writable_dir(target_path): + """ Check if a path is a valid directory and that it can be written to. 
""" + path = Path(target_path) + if path.is_dir(): + if os.access(path, os.W_OK): + return path + else: + raise argparse.ArgumentTypeError(f"Directory '{path}' is not writable.") + else: + raise argparse.ArgumentTypeError(f"Directory '{path}' does not exist.") def main(): + # Set up logging + logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') + # Define command-line arguments parser = argparse.ArgumentParser() - parser.add_argument("image_folder", type=str, help="the folder where the image files are located") + parser.add_argument("image_folder", type=writable_dir, help="The folder where the image files are located") parser.add_argument("--file_pattern", type=str, default="*.png, *.jpg, *.jpeg, *.webp", help="the pattern to match the image file names") parser.add_argument("--caption_file_ext", type=str, default=".caption", help="the caption file extension.") parser.add_argument("--overwrite", action="store_true", default=False, help="whether to overwrite existing caption files") # Create a mutually exclusive group for the caption_text and caption_file arguments - group = parser.add_mutually_exclusive_group() - group.add_argument("--caption_text", type=str, help="the text to include in the caption files") - group.add_argument("--caption_file", type=argparse.FileType("r"), help="the file containing the text to include in the caption files") + caption_group = parser.add_mutually_exclusive_group(required=True) + caption_group.add_argument("--caption_text", type=str, help="the text to include in the caption files") + caption_group.add_argument("--caption_file", type=argparse.FileType("r"), help="the file containing the text to include in the caption files") # Parse the command-line arguments args = parser.parse_args() - image_folder = args.image_folder - file_pattern = args.file_pattern - caption_file_ext = args.caption_file_ext - overwrite = args.overwrite - - # Get the caption text from either the caption_text or caption_file argument - if 
args.caption_text: - caption_text = args.caption_text - elif args.caption_file: - caption_text = args.caption_file.read() - - # Create a Path object for the image folder - folder = Path(image_folder) - # Check if the image folder exists and is a directory - if not folder.is_dir(): - raise ValueError(f"{image_folder} is not a valid directory.") - # Create the caption files - create_caption_files(image_folder, file_pattern, caption_text, caption_file_ext, overwrite) + # Get the caption text from either the caption_text or caption_file argument + caption_text = args.caption_text if args.caption_text is not None else args.caption_file.read() + + # Create the caption files + create_caption_files(args.image_folder, args.file_pattern, caption_text, args.caption_file_ext, args.overwrite) if __name__ == "__main__": main() \ No newline at end of file diff --git a/tools/caption_from_filename.py b/tools/caption_from_filename.py index e579edcaf..a3326a5fb 100644 --- a/tools/caption_from_filename.py +++ b/tools/caption_from_filename.py @@ -1,50 +1,99 @@ # Proposed by https://github.com/kainatquaderee import os import argparse +import logging +from pathlib import Path + +def is_image_file(filename, image_extensions): + """Check if a file is an image file based on its extension.""" + return Path(filename).suffix.lower() in image_extensions + +def create_text_file(image_filename, output_directory, text_extension): + """Create a text file with the same name as the image file.""" + # Extract prompt from filename + prompt = Path(image_filename).stem + + # Construct path for the output text file + text_file_path = Path(output_directory) / (prompt + text_extension) + try: + # Write prompt to text file + with open(text_file_path, 'w') as text_file: + text_file.write(prompt) + + logging.info(f"Text file created: {text_file_path}") + + return 1 + + except IOError as e: + logging.error(f"Failed to write to {text_file_path}: {e}") + return 0 def main(image_directory, output_directory, image_extension, text_extension): + # If no output directory is provided, use the image directory + if not output_directory: + output_directory = image_directory + # Ensure the output directory exists,
create it if necessary - os.makedirs(output_directory, exist_ok=True) + Path(output_directory).mkdir(parents=True, exist_ok=True) # Initialize a counter for the number of text files created text_files_created = 0 # Iterate through files in the directory - for image_filename in os.listdir(image_directory): + for image_filename in Path(image_directory).iterdir(): # Check if the file is an image - if any(image_filename.lower().endswith(ext) for ext in image_extension): - # Extract prompt from filename - prompt = os.path.splitext(image_filename)[0] - - # Construct path for the output text file - text_file_path = os.path.join(output_directory, prompt + text_extension) - - # Write prompt to text file - with open(text_file_path, 'w') as text_file: - text_file.write(prompt) - - print(f"Text file saved: {text_file_path}") - - # Increment the counter - text_files_created += 1 + if is_image_file(image_filename, image_extension): + # Create a text file with the same name as the image file and increment the counter if successful + text_files_created += create_text_file(image_filename, output_directory, text_extension) # Report if no text files were created if text_files_created == 0: - print("No image matching extensions were found in the specified directory. No caption files were created.") + logging.info("No image matching extensions were found in the specified directory. No caption files were created.") else: - print(f"{text_files_created} text files created successfully.") + logging.info(f"{text_files_created} text files created successfully.") + +def create_gui(image_directory, output_directory, image_extension, text_extension): + try: + import gradio + import gradio.blocks as blocks + except ImportError: + print("gradio module is not installed. 
Please install it to use the GUI.") + exit(1) + + """Create a Gradio interface for the caption creation process.""" + with gradio.Blocks() as demo: + gradio.Markdown("## Caption From Filename") + with gradio.Row(): + with gradio.Column(): + image_dir = gradio.Textbox(label="Image Directory", value=image_directory) + output_dir = gradio.Textbox(label="Output Directory", value=output_directory) + image_ext = gradio.Textbox(label="Image Extensions", value=" ".join(image_extension)) + text_ext = gradio.Textbox(label="Text Extension", value=text_extension) + run_button = gradio.Button("Run") + with gradio.Column(): + output = gradio.Textbox(label="Output", placeholder="Output will be displayed here...", lines=10, max_lines=10) + run_button.click(main, inputs=[image_dir, output_dir, image_ext, text_ext], outputs=output) + demo.launch() if __name__ == "__main__": + # Set up logging + logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') + # Create an argument parser parser = argparse.ArgumentParser(description='Generate caption files from image filenames.') # Add arguments for the image directory, output directory, and file extension - parser.add_argument('image_directory', help='Directory containing the image files') - parser.add_argument('output_directory', help='Output directory where text files will be saved') - parser.add_argument('--image_extension', nargs='+', default=['.jpg', '.jpeg', '.png', '.webp', '.bmp'], help='Extension for the image files') - parser.add_argument('--text_extension', default='.txt', help='Extension for the output text files') + parser.add_argument('image_directory', help='Directory containing the image files.') + parser.add_argument('--output_directory', help='Optional: Output directory where text files will be saved. 
If not provided, the files will be saved in the same directory as the images.') + parser.add_argument('--image_extension', nargs='+', default=['.jpg', '.jpeg', '.png', '.webp', '.bmp'], help='Extension(s) for the image files. Defaults to common image extensions .jpg, .jpeg, .png, .webp, .bmp.') + parser.add_argument('--text_extension', default='.txt', help='Extension for the output text files. Defaults to .txt.') + parser.add_argument('--gui', action='store_true', help='Launch a Gradio interface for the caption creation process.') # Parse the command-line arguments args = parser.parse_args() - main(args.image_directory, args.output_directory, args.image_extension, args.text_extension) + if args.gui: + create_gui(args.image_directory, args.output_directory, args.image_extension, args.text_extension) + else: + main(args.image_directory, args.output_directory, args.image_extension, args.text_extension) diff --git a/tools/cleanup_captions.py b/tools/cleanup_captions.py index 628c0f47c..326eaa35d 100644 --- a/tools/cleanup_captions.py +++ b/tools/cleanup_captions.py @@ -1,27 +1,53 @@ import os import argparse +import logging +from pathlib import Path -parser = argparse.ArgumentParser(description="Remove specified keywords from all text files in a directory.") -parser.add_argument("folder_path", type=str, help="path to directory containing text files") -parser.add_argument("-e", "--extension", type=str, default=".txt", help="file extension of text files to be processed (default: .txt)") -args = parser.parse_args() +def writable_dir(target_path): + """ Check if a path is a valid directory and that it can be written to. 
""" + path = Path(target_path) + if path.is_dir(): + if os.access(path, os.W_OK): + return path + else: + raise argparse.ArgumentTypeError(f"Directory '{path}' is not writable.") + else: + raise argparse.ArgumentTypeError(f"Directory '{path}' does not exist.") + +def main(folder_path:Path, extension:str, keywords:set=None): + for file_name in os.listdir(folder_path): + if file_name.endswith(extension): + file_path = os.path.join(folder_path, file_name) + try: + with open(file_path, "r") as f: + text = f.read() + # extract tags from text and split into a list using comma as the delimiter + tags = [tag.strip() for tag in text.split(",")] + # remove the specified keywords from the tags list + if keywords: + tags = [tag for tag in tags if tag not in keywords] + # remove empty or whitespace-only tags + tags = [tag for tag in tags if tag.strip() != ""] + # join the tags back into a comma-separated string and write back to the file + with open(file_path, "w") as f: + f.write(", ".join(tags)) + logging.info(f"Processed {file_name}") + except Exception as e: + logging.error(f"Error processing {file_name}: {e}") -folder_path = args.folder_path -extension = args.extension -keywords = ["1girl", "solo", "blue eyes", "brown eyes", "blonde hair", "black hair", "realistic", "red lips", "lips", "artist name", "makeup", "realistic","brown hair", "dark skin", - "dark-skinned female", "medium breasts", "breasts", "1boy"] +if __name__ == "__main__": + # Set up logging + logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') -for file_name in os.listdir(folder_path): - if file_name.endswith(extension): - file_path = os.path.join(folder_path, file_name) - with open(file_path, "r") as f: - text = f.read() - # extract tags from text and split into a list using comma as the delimiter - tags = [tag.strip() for tag in text.split(",")] - # remove the specified keywords from the tags list - tags = [tag for tag in tags if tag not in keywords] - # remove empty or 
whitespace-only tags - tags = [tag for tag in tags if tag.strip() != ""] - # join the tags back into a comma-separated string and write back to the file - with open(file_path, "w") as f: - f.write(", ".join(tags)) \ No newline at end of file + parser = argparse.ArgumentParser(description="Remove specified keywords from all text files in a directory.") + parser.add_argument("folder_path", type=writable_dir, help="path to directory containing text files") + parser.add_argument("-e", "--extension", type=str, default=".txt", help="file extension of text files to be processed (default: .txt)") + parser.add_argument("-k", "--keywords", type=str, nargs="*", help="Optional: list of keywords to be removed from text files. If not provided, the default list will be used.") + args = parser.parse_args() + + folder_path = args.folder_path + extension = args.extension + keywords = set(args.keywords) if args.keywords else set(["1girl", "solo", "blue eyes", "brown eyes", "blonde hair", "black hair", "realistic", "red lips", "lips", "artist name", "makeup", "realistic","brown hair", "dark skin", + "dark-skinned female", "medium breasts", "breasts", "1boy"]) + + main(folder_path, extension, keywords) diff --git a/tools/convert_html_to_md.py b/tools/convert_html_to_md.py index d32507d19..39220bbf4 100644 --- a/tools/convert_html_to_md.py +++ b/tools/convert_html_to_md.py @@ -1,39 +1,64 @@ +import argparse import os import requests from bs4 import BeautifulSoup from urllib.parse import urljoin from html2text import html2text +from pathlib import Path -# Specify the URL of the webpage you want to scrape -url = "https://hoshikat-hatenablog-com.translate.goog/entry/2023/05/26/223229?_x_tr_sl=auto&_x_tr_tl=en&_x_tr_hl=en-US&_x_tr_pto=wapp" - -# Send HTTP request to the specified URL and save the response from server in a response object called r -r = requests.get(url) - -# Create a BeautifulSoup object and specify the parser -soup = BeautifulSoup(r.text, 'html.parser') - -# Find all image 
tags -images = soup.find_all('img') - -for image in images: - # Get the image source - image_url = urljoin(url, image['src']) - - # Get the image response - image_response = requests.get(image_url, stream=True) - - # Get the image name by splitting the url at / and taking the last string, and add it to the desired path - image_name = os.path.join("./logs", image_url.split("/")[-1]) - - # Open the image file in write binary mode - with open(image_name, 'wb') as file: - # Write the image data to the file - file.write(image_response.content) - -# Convert the HTML content to markdown -markdown_content = html2text(r.text) - -# Save the markdown content to a file -with open("converted_markdown.md", "w", encoding="utf8") as file: - file.write(markdown_content) +def is_writable_path(target_path): + """ + Check if a path is writable. + """ + path = Path(os.path.dirname(target_path)) + if path.is_dir(): + if os.access(path, os.W_OK): + return target_path + else: + raise argparse.ArgumentTypeError(f"Directory '{path}' is not writable.") + else: + raise argparse.ArgumentTypeError(f"Directory '{path}' does not exist.") + +def main(url, markdown_path): + # Create a session object + with requests.Session() as session: + # Send HTTP request to the specified URL + response = session.get(url) + response.raise_for_status() # Check for HTTP issues + + # Create a BeautifulSoup object and specify the parser + soup = BeautifulSoup(response.text, 'html.parser') + + # Ensure the directory for saving images exists + os.makedirs("./logs", exist_ok=True) + + # Find all image tags and save images + for image in soup.find_all('img'): + image_url = urljoin(url, image['src']) + try: + image_response = session.get(image_url, stream=True) + image_response.raise_for_status() + image_name = os.path.join("./logs", os.path.basename(image_url)) + with open(image_name, 'wb') as file: + file.write(image_response.content) + except requests.RequestException as e: + print(f"Failed to download {image_url}: 
{e}") + + # Convert the HTML content to markdown + markdown_content = html2text(response.text) + + # Save the markdown content to a file + try: + with open(markdown_path, "w", encoding="utf8") as file: + file.write(markdown_content) + print(f"Markdown content successfully written to {markdown_path}") + except Exception as e: + print(f"Failed to write markdown to {markdown_path}: {e}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert HTML to Markdown") + parser.add_argument("url", help="The URL of the webpage to convert") + parser.add_argument("markdown_path", help="The path to save the converted markdown file", type=is_writable_path) + args = parser.parse_args() + + main(args.url, args.markdown_path) diff --git a/tools/convert_images_to_hq_jpg.py b/tools/convert_images_to_hq_jpg.py index efc404778..7304b538b 100644 --- a/tools/convert_images_to_hq_jpg.py +++ b/tools/convert_images_to_hq_jpg.py @@ -5,25 +5,19 @@ from PIL import Image -def main(): - # Define the command-line arguments - parser = argparse.ArgumentParser() - parser.add_argument("directory", type=str, - help="the directory containing the images to be converted") - parser.add_argument("--in_ext", type=str, default="webp", - help="the input file extension") - parser.add_argument("--quality", type=int, default=95, - help="the JPEG quality (0-100)") - parser.add_argument("--delete_originals", action="store_true", - help="whether to delete the original files after conversion") +def writable_dir(target_path): + """ Check if a path is a valid directory and that it can be written to. 
""" + path = Path(target_path) + if path.is_dir(): + if os.access(path, os.W_OK): + return path + else: + raise argparse.ArgumentTypeError(f"Directory '{path}' is not writable.") + else: + raise argparse.ArgumentTypeError(f"Directory '{path}' does not exist.") - # Parse the command-line arguments - args = parser.parse_args() - directory = args.directory - in_ext = args.in_ext +def main(directory, in_ext, quality, delete_originals): out_ext = "jpg" - quality = args.quality - delete_originals = args.delete_originals # Create the file pattern string using the input file extension file_pattern = f"*.{in_ext}" @@ -54,4 +48,18 @@ def main(): if __name__ == "__main__": - main() + # Define the command-line arguments + parser = argparse.ArgumentParser() + parser.add_argument("directory", type=writable_dir, + help="the directory containing the images to be converted") + parser.add_argument("--in_ext", type=str, default="webp", + help="the input file extension") + parser.add_argument("--quality", type=int, default=95, + help="the JPEG quality (0-100)") + parser.add_argument("--delete_originals", action="store_true", + help="whether to delete the original files after conversion") + + # Parse the command-line arguments + args = parser.parse_args() + + main(directory=args.directory, in_ext=args.in_ext, quality=args.quality, delete_originals=args.delete_originals) diff --git a/tools/convert_images_to_webp.py b/tools/convert_images_to_webp.py index 4833459e1..2bb492617 100644 --- a/tools/convert_images_to_webp.py +++ b/tools/convert_images_to_webp.py @@ -1,56 +1,70 @@ import argparse -import glob -import os from pathlib import Path +import os from PIL import Image - +def writable_dir(target_path): + """ Check if a path is a valid directory and that it can be written to. 
""" + path = Path(target_path) + if path.is_dir(): + if os.access(path, os.W_OK): + return path + else: + raise argparse.ArgumentTypeError(f"Directory '{path}' is not writable.") + else: + raise argparse.ArgumentTypeError(f"Directory '{path}' does not exist.") + def main(): # Define the command-line arguments parser = argparse.ArgumentParser() - parser.add_argument("directory", type=str, + parser.add_argument("directory", type=writable_dir, help="the directory containing the images to be converted") parser.add_argument("--in_ext", type=str, default="webp", help="the input file extension") + parser.add_argument("--out_ext", type=str, default="webp", + help="the output file extension") parser.add_argument("--delete_originals", action="store_true", help="whether to delete the original files after conversion") # Parse the command-line arguments args = parser.parse_args() - directory = args.directory + directory = Path(args.directory) in_ext = args.in_ext delete_originals = args.delete_originals - # Set the output file extension to .webp - out_ext = "webp" - # Create the file pattern string using the input file extension file_pattern = f"*.{in_ext}" # Get the list of files in the directory that match the file pattern - files = glob.glob(os.path.join(directory, file_pattern)) + files = list(directory.glob(file_pattern)) # Iterate over the list of files for file in files: - # Open the image file - img = Image.open(file) + try: + # Open the image file + img = Image.open(file) + + # Create a new file path with the output file extension + new_path = file.with_suffix(f".{args.out_ext}") + print(new_path) - # Create a new file path with the output file extension - new_path = Path(file).with_suffix(f".{out_ext}") - print(new_path) + # Check if the output file already exists + if new_path.exists(): + # Skip the conversion if the output file already exists + print(f"Skipping {file} because {new_path} already exists") + continue - # Check if the output file already exists - if 
new_path.exists(): - # Skip the conversion if the output file already exists - print(f"Skipping {file} because {new_path} already exists") - continue + # Save the image to the new file as lossless + img.save(new_path, lossless=True) - # Save the image to the new file as lossless - img.save(new_path, lossless=True) + # Close the image file + img.close() - # Optionally, delete the original file - if delete_originals: - os.remove(file) + # Optionally, delete the original file + if delete_originals: + file.unlink() + except Exception as e: + print(f"Error processing {file}: {e}") if __name__ == "__main__": diff --git a/tools/crop_images_to_n_buckets.py b/tools/crop_images_to_n_buckets.py index e2bdbd085..82a27cdb6 100644 --- a/tools/crop_images_to_n_buckets.py +++ b/tools/crop_images_to_n_buckets.py @@ -10,11 +10,25 @@ import shutil def aspect_ratio(img_path): - """Return aspect ratio of an image""" - image = cv2.imread(img_path) - height, width = image.shape[:2] - aspect_ratio = float(width) / float(height) - return aspect_ratio + """ + Calculate and return the aspect ratio of an image. + + Parameters: + img_path: A string representing the path to the input image. + + Returns: + float: Aspect ratio of the input image, defined as width / height. + Returns None if the image cannot be read. + """ + try: + image = cv2.imread(img_path) + if image is None: + raise ValueError("Image not found or could not be read.") + height, width = image.shape[:2] + return float(width) / float(height) + except Exception as e: + print(f"Error: {e}") + return None def sort_images_by_aspect_ratio(path): """Sort all images in a folder by aspect ratio""" @@ -29,7 +43,26 @@ def sort_images_by_aspect_ratio(path): return sorted_images def create_groups(sorted_images, n_groups): - """Create n groups from sorted list of images""" + """ + Create groups of images from a sorted list of images. 
+ + This function takes a sorted list of images and the desired number of groups as input, and returns a list of groups, + where each group contains roughly an equal share of the images. + + Parameters: + sorted_images (list of tuples): A list of tuples, where each tuple contains the path to an image and its aspect ratio. + n_groups (int): The number of groups to create. + + Returns: + list of lists: A list of groups, where each group is a list of tuples representing the images in the group. + + Raises: + ValueError: If n_groups is not a positive integer or if it is greater than the number of images. + """ + if not isinstance(n_groups, int) or n_groups <= 0: + raise ValueError("Error: n_groups must be a positive integer.") + if n_groups > len(sorted_images): + raise ValueError("Error: n_groups must be less than or equal to the number of images.") n = len(sorted_images) size = n // n_groups groups = [sorted_images[i * size : (i + 1) * size] for i in range(n_groups - 1)] @@ -37,11 +70,30 @@ def create_groups(sorted_images, n_groups): return groups def average_aspect_ratio(group): - """Calculate average aspect ratio for a group""" - aspect_ratios = [aspect_ratio for _, aspect_ratio in group] - avg_aspect_ratio = sum(aspect_ratios) / len(aspect_ratios) - print(f"Average aspect ratio for group: {avg_aspect_ratio}") - return avg_aspect_ratio + """ + Calculate the average aspect ratio for a given group of images. + + Parameters: + group (list of tuples): A list of tuples, where each tuple contains the path to an image and its aspect ratio. + + Returns: + float: The average aspect ratio of the images in the group.
+ """ + if not group: + print("Error: The group is empty") + return None + + try: + aspect_ratios = [aspect_ratio for _, aspect_ratio in group] + avg_aspect_ratio = sum(aspect_ratios) / len(aspect_ratios) + print(f"Average aspect ratio for group: {avg_aspect_ratio}") + return avg_aspect_ratio + except TypeError: + print("Error: Check the structure of the input group elements. They should be tuples of (image_path, aspect_ratio).") + return None + except Exception as e: + print(f"Error: {e}") + return None def center_crop_image(image, target_aspect_ratio): """Crop the input image to the target aspect ratio. @@ -54,20 +106,33 @@ def center_crop_image(image, target_aspect_ratio): Returns: A numpy array representing the cropped image. + + Raises: + ValueError: If the input image is not a valid numpy array with at least two dimensions or if the calculated new width or height is zero. """ + # Check if the input image is a valid numpy array with at least two dimensions + if not isinstance(image, np.ndarray) or image.ndim < 2: + raise ValueError("Input image must be a valid numpy array with at least two dimensions.") + height, width = image.shape[:2] current_aspect_ratio = float(width) / float(height) + # If the current aspect ratio is already equal to the target aspect ratio, return the image as is if current_aspect_ratio == target_aspect_ratio: return image + # Calculate the new width and height based on the target aspect ratio if current_aspect_ratio > target_aspect_ratio: new_width = int(target_aspect_ratio * height) + if new_width == 0: + raise ValueError("Calculated new width is zero. Please check the input image and target aspect ratio.") x_start = (width - new_width) // 2 cropped_image = image[:, x_start:x_start+new_width] else: new_height = int(width / target_aspect_ratio) + if new_height == 0: + raise ValueError("Calculated new height is zero. 
Please check the input image and target aspect ratio.") y_start = (height - new_height) // 2 cropped_image = image[y_start:y_start+new_height, :] @@ -77,8 +142,10 @@ def copy_related_files(img_path, save_path): """ Copy all files in the same directory as the input image that have the same base name as the input image to the output directory with the corresponding new filename. - :param img_path: Path to the input image. - :param save_path: Path to the output image. + + Args: + img_path (str): Path to the input image file. + save_path: Path to the output directory where the files should be copied with a new name. """ # Get the base filename and directory img_dir, img_basename = os.path.split(img_path)
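The `writable_dir` argparse validator that this patch repeats across several of the `tools/` scripts can be exercised on its own. A minimal standalone sketch of the same pattern (the parser wiring below is illustrative, not taken from any one script):

```python
import argparse
import os
from pathlib import Path

def writable_dir(target_path):
    """Argparse type: accept only an existing, writable directory."""
    path = Path(target_path)
    if path.is_dir():
        if os.access(path, os.W_OK):
            return path
        raise argparse.ArgumentTypeError(f"Directory '{path}' is not writable.")
    raise argparse.ArgumentTypeError(f"Directory '{path}' does not exist.")

parser = argparse.ArgumentParser()
parser.add_argument("folder_path", type=writable_dir)

# A valid directory parses to a pathlib.Path; a bogus one
# makes argparse print a clear error and exit with status 2.
args = parser.parse_args(["."])
print(args.folder_path)
```

Because `argparse.ArgumentTypeError` is raised inside the `type=` callable, argparse reports the message itself and the calling script never sees a half-validated path.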
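The center-crop logic hardened in `tools/crop_images_to_n_buckets.py` can also be checked in isolation. This sketch mirrors `center_crop_image` using numpy only (the helper name here is illustrative; the real function operates on cv2-loaded arrays, which are plain numpy arrays):

```python
import numpy as np

def center_crop(image: np.ndarray, target_aspect_ratio: float) -> np.ndarray:
    """Center-crop an (H, W[, C]) array to the target width/height ratio."""
    if not isinstance(image, np.ndarray) or image.ndim < 2:
        raise ValueError("Input image must be a numpy array with at least two dimensions.")
    height, width = image.shape[:2]
    current = float(width) / float(height)
    if current == target_aspect_ratio:
        return image
    if current > target_aspect_ratio:
        # Too wide: narrow the width, keep the full height
        new_width = int(target_aspect_ratio * height)
        x_start = (width - new_width) // 2
        return image[:, x_start:x_start + new_width]
    # Too tall: shorten the height, keep the full width
    new_height = int(width / target_aspect_ratio)
    y_start = (height - new_height) // 2
    return image[y_start:y_start + new_height, :]

# A 100x200 image cropped to a square aspect ratio becomes 100x100
print(center_crop(np.zeros((100, 200, 3)), 1.0).shape)  # (100, 100, 3)
```

Cropping always discards pixels symmetrically from the longer axis, which is why the patch adds the zero-size guards: an extreme target ratio on a small image can truncate `new_width` or `new_height` to 0.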