diff --git a/.gitignore b/.gitignore index 8f9add240..af16d51fd 100644 --- a/.gitignore +++ b/.gitignore @@ -11,5 +11,6 @@ gui-user.bat gui-user.ps1 .vscode wandb +setup.log logs -SmilingWolf \ No newline at end of file +SmilingWolf diff --git a/README.md b/README.md index 1c6b50544..340b81702 100644 --- a/README.md +++ b/README.md @@ -345,6 +345,17 @@ This will store a backup file with your current locally installed pip packages a ## Change History +* 2023/06/02 (v21.6.0) +- Merge kohya_ss repo changes +- Improve logging of the kohya_ss GUI +* 2023/05/28 (v21.5.15) +- Show warning when image caption file does not exist during training. [PR #533](https://github.com/kohya-ss/sd-scripts/pull/533) Thanks to TingTingin! + - Warning is also displayed when using a class+identifier dataset. Please ignore it if this is intended. +- `train_network.py` now supports merging network weights before training. [PR #542](https://github.com/kohya-ss/sd-scripts/pull/542) Thanks to u-haru! + - `--base_weights` option specifies LoRA or other model files (multiple files are allowed) to merge. + - `--base_weights_multiplier` option specifies multipliers for the weights to merge (multiple values are allowed). If omitted, or if fewer values than `base_weights` are given, 1.0 is used. + - This is useful for incremental learning. See the PR for details, and the sketch below for an example invocation. +- Show warning and continue training when uploading to HuggingFace fails. * 2023/05/28 (v21.5.14) - Add Create Group tool and GUI * 2023/05/24 (v21.5.13) @@ -391,69 +402,3 @@ This will store a backup file with your current locally installed pip packages a - The actual value of the noise offset is calculated as `noise_offset + abs(mean(latents, dim=(2,3))) * adaptive_noise_scale`. Since the latent is close to a normal distribution, it may be a good idea to specify a value of about 1/10 to the same as the noise offset. - Negative values can also be specified, in which case the noise offset will be clipped to 0 or more. - Other minor fixes. -* 2023/04/06 (v21.5.9) - - Implement headless mode to enable easier support under headless services like vast.ai. To make use of it, start the GUI with the `--headless` argument like: - - `.\gui.ps1 --headless` or `.\gui.bat --headless` or `./gui.sh --headless` - - Added the option for the user to put the wandb api key in a textbox under the advanced configuration dropdown and a checkbox to toggle for using wandb logging. @x-CK-x - - Docker build image @Trojaner - - Updated README to use docker compose run instead of docker compose up to fix broken tqdm - - Related: Doesn't work with docker-compose tqdm/tqdm#771 - - Fixed build for latest release - - Replace pillow with pillow-simd - - Removed --no-cache again as pip cache is not enabled anyway - - While overwriting .txt files with prefix and postfix including different encodings you might encounter this decoder error. This small fix gets rid of it... @ertugrul-dmr - - Docker Add --no-cache-dir to reduce image size @chiragjn - - Reverting bitsandbytes version to 0.35.0 due to issues with 0.38.1 on some systems -* 2023/04/05 (v21.5.8) - - Add `Cache latents to disk` option to the gui. - - When saving v2 models in Diffusers format in training scripts and conversion scripts, it was found that the U-Net configuration is different from those of Hugging Face's stabilityai models (this repository is `"use_linear_projection": false`, stabilityai is `true`). Please note that the weight shapes are different, so please be careful when using the weight files directly. We apologize for the inconvenience.
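As a quick reference for the `--base_weights` entry above, here is a minimal sketch of a training invocation using the new PR #542 options, written in the same command-string style the GUI uses; the file names and multiplier values are placeholders, not part of the PR:

```python
# Hypothetical sketch: merge two existing LoRA files into the network before
# training starts (PR #542). Paths and values below are examples only.
run_cmd = 'accelerate launch --num_cpu_threads_per_process=2 "train_network.py"'
run_cmd += ' --base_weights "lora_a.safetensors" "lora_b.safetensors"'
# The second file is merged at half strength; values omitted for later files
# default to 1.0.
run_cmd += ' --base_weights_multiplier 1.0 0.5'
# ...remaining options (model, dataset, output folder) as usual.
```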
- - Since the U-Net model is created based on the configuration, it should not cause any problems in training or inference. - - Added `--unet_use_linear_projection` option to `convert_diffusers20_original_sd.py` script. If you specify this option, you can save a Diffusers format model with the same configuration as stabilityai's model from an SD format model (a single `*.safetensors` or `*.ckpt` file). Unfortunately, it is not possible to convert a Diffusers format model to the same format. - - Lion8bit optimizer is supported. [PR #447](https://github.com/kohya-ss/sd-scripts/pull/447) Thanks to sdbds! - - Currently it is optional because you need to update `bitsandbytes` version. See "Optional: Use Lion8bit" in installation instructions to use it. - - Multi-GPU training with DDP is supported in each training script. [PR #448](https://github.com/kohya-ss/sd-scripts/pull/448) Thanks to Isotr0py! - - Multi resolution noise (pyramid noise) is supported in each training script. [PR #471](https://github.com/kohya-ss/sd-scripts/pull/471) Thanks to pamparamm! - - See PR and this page [Multi-Resolution Noise for Diffusion Model Training](https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2) for details. - - Add --no-cache-dir to reduce image size @chiragjn -* 2023/05/01 (v21.5.7) - - `tag_images_by_wd14_tagger.py` can now get arguments from outside. [PR #453](https://github.com/kohya-ss/sd-scripts/pull/453) Thanks to mio2333! - - Added `--save_every_n_steps` option to each training script. The model is saved every specified steps. - - `--save_last_n_steps` option can be used to save only the specified number of models (old models will be deleted). - - If you specify the `--save_state` option, the state will also be saved at the same time. You can specify the number of steps to keep the state with the `--save_last_n_steps_state` option (the same value as `--save_last_n_steps` is used if omitted). - - You can use the epoch-based model saving and state saving options together. - - Not tested in multi-GPU environment. Please report any bugs. - - `--cache_latents_to_disk` option automatically enables `--cache_latents` option when specified. [#438](https://github.com/kohya-ss/sd-scripts/issues/438) - - Fixed a bug in `gen_img_diffusers.py` where latents upscaler would fail with a batch size of 2 or more. - - Fix triton error - - Fix issue with merge lora path with spaces - - Added support for logging to wandb. Please refer to PR #428. Thank you p1atdev! - - wandb installation is required. Please install it with pip install wandb. Login to wandb with wandb login command, or set --wandb_api_key option for automatic login. - - Please let me know if you find any bugs as the test is not complete. - - You can automatically login to wandb by setting the --wandb_api_key option. Please be careful with the handling of API Key. PR #435 Thank you Linaqruf! - - Improved the behavior of --debug_dataset on non-Windows environments. PR #429 Thank you tsukimiya! - - Fixed --face_crop_aug option not working in Fine tuning method. - - Prepared code to use any upscaler in gen_img_diffusers.py. - - Fixed to log to TensorBoard when --logging_dir is specified and --log_with is not specified. - - Add new docker image solution.. Thanks to @Trojaner -* 2023/04/22 (v21.5.5) - - Update LoRA merge GUI to support SD checkpoint merge and up to 4 LoRA merging - - Fixed `lora_interrogator.py` not working. 
Please refer to [PR #392](https://github.com/kohya-ss/sd-scripts/pull/392) for details. Thank you A2va and heyalexchoi! - - Fixed the handling of tags containing `_` in `tag_images_by_wd14_tagger.py`. - - Add new Extract DyLoRA gui to the Utilities tab. - - Add new Merge LyCORIS models into checkpoint gui to the Utilities tab. - - Add new info on startup to help debug things -* 2023/04/17 (v21.5.4) - - Fixed a bug that caused an error when loading DyLoRA with the `--network_weight` option in `train_network.py`. - - Added the `--recursive` option to each script in the `finetune` folder to process folders recursively. Please refer to [PR #400](https://github.com/kohya-ss/sd-scripts/pull/400/) for details. Thanks to Linaqruf! - - Upgrade Gradio to latest release - - Fix issue when Adafactor is used as optimizer and LR Warmup is not 0: https://github.com/bmaltais/kohya_ss/issues/617 - - Added support for DyLoRA in `train_network.py`. Please refer to [here](./train_network_README-ja.md#dylora) for details (currently only in Japanese). - - Added support for caching latents to disk in each training script. Please specify __both__ `--cache_latents` and `--cache_latents_to_disk` options. - - The files are saved in the same folder as the images with the extension `.npz`. If you specify the `--flip_aug` option, the files with `_flip.npz` will also be saved. - - Multi-GPU training has not been tested. - - This feature is not tested with all combinations of datasets and training scripts, so there may be bugs. - - Added workaround for an error that occurs when training with `fp16` or `bf16` in `fine_tune.py`. - - Implemented DyLoRA GUI support. There will now be a new 'DyLoRA Unit` slider when the LoRA type is selected as `kohya DyLoRA` to specify the desired Unit value for DyLoRA training. - - Update gui.bat and gui.ps1 based on: https://github.com/bmaltais/kohya_ss/issues/188 - - Update `setup.bat` to install torch 2.0.0 instead of 1.2.1. If you want to upgrade from 1.2.1 to 2.0.0 run setup.bat again, select 1 to uninstall the previous torch modules, then select 2 for torch 2.0.0 diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 2acc9157a..09bc73a11 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -40,6 +40,11 @@ from library.utilities import utilities_tab from library.sampler_gui import sample_gradio_config, run_cmd_sample +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + # from easygui import msgbox folder_symbol = '\U0001f4c2' # 📂 @@ -129,14 +134,14 @@ def save_configuration( save_as_bool = True if save_as.get('label') == 'True' else False if save_as_bool: - print('Save as...') + log.info('Save as...') file_path = get_saveasfile_path(file_path) else: - print('Save...') + log.info('Save...') if file_path == None or file_path == '': file_path = get_saveasfile_path(file_path) - # print(file_path) + # log.info(file_path) if file_path == None or file_path == '': return original_file_path # In case a file_path was provided and the user decide to cancel the open action @@ -253,7 +258,7 @@ def open_configuration( # load variables from JSON file with open(file_path, 'r') as f: my_data = json.load(f) - print('Loading config...') + log.info('Loading config...') # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True my_data = update_my_data(my_data) else: @@ -407,7 +412,7 @@ def train_model( # Check if subfolders are present. 
If not, let the user know and return if not subfolders: - print( + log.info( '\033[33mNo subfolders were found in', train_data_dir, " can't train...\033[0m", @@ -422,7 +427,7 @@ def train_model( try: repeats = int(folder.split('_')[0]) except ValueError: - print( + log.info( '\033[33mSubfolder', folder, "does not have a proper repeat value, please correct the name or remove it... can't train...\033[0m", @@ -444,17 +449,17 @@ def train_model( ) if num_images == 0: - print(f'{folder} folder contains no images, skipping...') + log.info(f'{folder} folder contains no images, skipping...') else: # Calculate the total number of steps for this folder steps = repeats * num_images total_steps += steps # Print the result - print('\033[33mFolder', folder, ':', steps, 'steps\033[0m') + log.info('\033[33mFolder', folder, ':', steps, 'steps\033[0m') if total_steps == 0: - print( + log.info( '\033[33mNo images were found in folder', train_data_dir, '... please rectify!\033[0m', @@ -462,12 +467,12 @@ def train_model( return # Print the result - # print(f"{total_steps} total steps") + # log.info(f"{total_steps} total steps") if reg_data_dir == '': reg_factor = 1 else: - print( + log.info( '\033[94mRegularisation images are used... Will double the number of steps required...\033[0m' ) reg_factor = 2 @@ -482,7 +487,7 @@ train_model( * int(reg_factor) ) ) - print(f'max_train_steps = {max_train_steps}') + log.info(f'max_train_steps = {max_train_steps}') # calculate stop encoder training if int(stop_text_encoder_training_pct) == -1: @@ -493,10 +498,10 @@ train_model( stop_text_encoder_training = math.ceil( float(max_train_steps) / 100 * int(stop_text_encoder_training_pct) ) - print(f'stop_text_encoder_training = {stop_text_encoder_training}') + log.info(f'stop_text_encoder_training = {stop_text_encoder_training}') lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) - print(f'lr_warmup_steps = {lr_warmup_steps}') + log.info(f'lr_warmup_steps = {lr_warmup_steps}') run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_db.py"' if v2: @@ -606,7 +611,7 @@ def train_model( output_dir, ) - print(run_cmd) + log.info(run_cmd) # Run the command if os.name == 'posix': @@ -979,11 +984,11 @@ def UI(**kwargs): css = '' headless = kwargs.get('headless', False) - print(f'headless: {headless}') + log.info(f'headless: {headless}') if os.path.exists('./style.css'): with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: - print('Load CSS...') + log.info('Load CSS...') css += file.read() + '\n' interface = gr.Blocks( diff --git a/finetune_gui.py b/finetune_gui.py index 525d62a3b..a8a1716d4 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -31,6 +31,11 @@ from library.utilities import utilities_tab from library.sampler_gui import sample_gradio_config, run_cmd_sample +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + # from easygui import msgbox folder_symbol = '\U0001f4c2' # 📂 @@ -128,14 +133,14 @@ def save_configuration( save_as_bool = True if save_as.get('label') == 'True' else False if save_as_bool: - print('Save as...') + log.info('Save as...') file_path = get_saveasfile_path(file_path) else: - print('Save...') + log.info('Save...') if file_path == None or file_path == '': file_path = get_saveasfile_path(file_path) - # print(file_path) + # log.info(file_path) if file_path == None or file_path == '': return original_file_path # In case a file_path was provided and the user decide to cancel the open
action @@ -258,7 +263,7 @@ def open_configuration( # load variables from JSON file with open(file_path, 'r') as f: my_data = json.load(f) - print('Loading config...') + log.info('Loading config...') # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True my_data = update_my_data(my_data) else: @@ -391,7 +396,7 @@ def train_model( if full_path: run_cmd += f' --full_path' - print(run_cmd) + log.info(run_cmd) # Run the command if os.name == 'posix': @@ -416,7 +421,7 @@ def train_model( if full_path: run_cmd += f' --full_path' - print(run_cmd) + log.info(run_cmd) # Run the command if os.name == 'posix': @@ -433,10 +438,10 @@ def train_model( if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp')) ] ) - print(f'image_num = {image_num}') + log.info(f'image_num = {image_num}') repeats = int(image_num) * int(dataset_repeats) - print(f'repeats = {str(repeats)}') + log.info(f'repeats = {str(repeats)}') # calculate max_train_steps max_train_steps = int( @@ -452,10 +457,10 @@ def train_model( if flip_aug: max_train_steps = int(math.ceil(float(max_train_steps) / 2)) - print(f'max_train_steps = {max_train_steps}') + log.info(f'max_train_steps = {max_train_steps}') lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) - print(f'lr_warmup_steps = {lr_warmup_steps}') + log.info(f'lr_warmup_steps = {lr_warmup_steps}') run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "./fine_tune.py"' if v2: @@ -560,7 +565,7 @@ def train_model( output_dir, ) - print(run_cmd) + log.info(run_cmd) # Run the command if os.name == 'posix': @@ -931,11 +936,11 @@ def UI(**kwargs): css = '' headless = kwargs.get('headless', False) - print(f'headless: {headless}') + log.info(f'headless: {headless}') if os.path.exists('./style.css'): with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: - print('Load CSS...') + log.info('Load CSS...') css += file.read() + '\n' interface = gr.Blocks( diff --git a/gui.bat b/gui.bat index 6cbde1bd9..e0e691e42 100644 --- a/gui.bat +++ b/gui.bat @@ -5,7 +5,7 @@ call .\venv\Scripts\activate.bat set PATH=%PATH%;%~dp0venv\Lib\site-packages\torch\lib :: Debug info about system -python.exe .\tools\debug_info.py +:: python.exe .\tools\debug_info.py :: Validate the requirements and store the exit code python.exe .\tools\validate_requirements.py diff --git a/gui.ps1 b/gui.ps1 index 27c5ba205..dcb00a098 100644 --- a/gui.ps1 +++ b/gui.ps1 @@ -3,7 +3,7 @@ $env:PATH += ";$($MyInvocation.MyCommand.Path)\venv\Lib\site-packages\torch\lib" # Debug info about system -python.exe .\tools\debug_info.py +# python.exe .\tools\debug_info.py # Validate the requirements and store the exit code python.exe .\tools\validate_requirements.py @@ -16,6 +16,6 @@ if ($LASTEXITCODE -eq 0) { $argsFromFile = Get-Content .\gui_parameters.txt -Encoding UTF8 | Where-Object { $_ -notmatch "^#" } | Foreach-Object { $_ -split " " } } $args_combo = $argsFromFile + $args - Write-Host "The arguments passed to this script were: $args_combo" + # Write-Host "The arguments passed to this script were: $args_combo" python.exe kohya_gui.py $args_combo } diff --git a/kohya_gui.py b/kohya_gui.py index 4fcf285ab..2e91775dc 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -14,81 +14,20 @@ from lora_gui import lora_tab import os -import sys -import json -import time -import shutil -import logging -import subprocess - -log = logging.getLogger('sd') - -# setup console and file logging -def setup_logging(clean=False): - try: - if clean 
and os.path.isfile('setup.log'): - os.remove('setup.log') - time.sleep(0.1) # prevent race condition - except: - pass - logging.basicConfig( - level=logging.DEBUG, - format='%(asctime)s | %(levelname)s | %(pathname)s | %(message)s', - filename='setup.log', - filemode='a', - encoding='utf-8', - force=True, - ) - from rich.theme import Theme - from rich.logging import RichHandler - from rich.console import Console - from rich.pretty import install as pretty_install - from rich.traceback import install as traceback_install - - console = Console( - log_time=True, - log_time_format='%H:%M:%S-%f', - theme=Theme( - { - 'traceback.border': 'black', - 'traceback.border.syntax_error': 'black', - 'inspect.value.border': 'black', - } - ), - ) - pretty_install(console=console) - traceback_install( - console=console, - extra_lines=1, - width=console.width, - word_wrap=False, - indent_guides=False, - suppress=[], - ) - rh = RichHandler( - show_time=True, - omit_repeated_times=False, - show_level=True, - show_path=False, - markup=False, - rich_tracebacks=True, - log_time_format='%H:%M:%S-%f', - level=logging.DEBUG if args.debug else logging.INFO, - console=console, - ) - rh.set_name(logging.DEBUG if args.debug else logging.INFO) - log.addHandler(rh) +from library.custom_logging import setup_logging +# Set up logging +log = setup_logging() def UI(**kwargs): css = '' headless = kwargs.get('headless', False) - print(f'headless: {headless}') + log.info(f'headless: {headless}') if os.path.exists('./style.css'): with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: - print('Load CSS...') + log.info('Load CSS...') css += file.read() + '\n' interface = gr.Blocks( diff --git a/library/basic_caption_gui.py b/library/basic_caption_gui.py index 20b890a7a..1019e35fc 100644 --- a/library/basic_caption_gui.py +++ b/library/basic_caption_gui.py @@ -4,6 +4,11 @@ from .common_gui import get_folder_path, add_pre_postfix, find_replace import os +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + def caption_images( caption_text, @@ -25,7 +30,7 @@ def caption_images( return if caption_text: - print(f'Captioning files in {images_dir} with {caption_text}...') + log.info(f'Captioning files in {images_dir} with {caption_text}...') run_cmd = f'python "tools/caption.py"' run_cmd += f' --caption_text="{caption_text}"' if overwrite: @@ -34,7 +39,7 @@ def caption_images( run_cmd += f' --caption_file_ext="{caption_ext}"' run_cmd += f' "{images_dir}"' - print(run_cmd) + log.info(run_cmd) # Run the command if os.name == 'posix': @@ -64,7 +69,7 @@ def caption_images( 'Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected...' 
) - print('...captioning done') + log.info('...captioning done') # Gradio UI diff --git a/library/blip_caption_gui.py b/library/blip_caption_gui.py index 7a5766cc7..bc2cce34c 100644 --- a/library/blip_caption_gui.py +++ b/library/blip_caption_gui.py @@ -4,6 +4,11 @@ import os from .common_gui import get_folder_path, add_pre_postfix +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' @@ -33,7 +38,7 @@ def caption_images( msgbox('Please provide an extension for the caption files.') return - print(f'Captioning files in {train_data_dir}...') + log.info(f'Captioning files in {train_data_dir}...') run_cmd = f'{PYTHON} "finetune/make_captions.py"' run_cmd += f' --batch_size="{int(batch_size)}"' run_cmd += f' --num_beams="{int(num_beams)}"' @@ -47,7 +52,7 @@ def caption_images( run_cmd += f' "{train_data_dir}"' run_cmd += f' --caption_weights="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth"' - print(run_cmd) + log.info(run_cmd) # Run the command if os.name == 'posix': @@ -63,7 +68,7 @@ def caption_images( postfix=postfix, ) - print('...captioning done') + log.info('...captioning done') ### diff --git a/library/common_gui.py b/library/common_gui.py index 1b2e88ded..afad5ea28 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -6,6 +6,11 @@ import shutil import sys +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 @@ -39,7 +44,7 @@ def check_if_model_exist( output_name, output_dir, save_model_as, headless=False ): if headless: - print( + log.info( 'Headless mode, skipping verification if model already exist... if model already exist it will be overwritten...' ) return False @@ -49,7 +54,7 @@ def check_if_model_exist( if os.path.isdir(ckpt_folder): msg = f'A diffuser model with the same name {ckpt_folder} already exists. Do you want to overwrite it?' if not easygui.ynbox(msg, 'Overwrite Existing Model?'): - print( + log.info( 'Aborting training due to existing model with same name...' ) return True @@ -58,12 +63,12 @@ def check_if_model_exist( if os.path.isfile(ckpt_file): msg = f'A model with the same file name {ckpt_file} already exists. Do you want to overwrite it?' if not easygui.ynbox(msg, 'Overwrite Existing Model?'): - print( + log.info( 'Aborting training due to existing model with same name...' ) return True else: - print( + log.info( 'Can\'t verify if existing model exist when save model is set a "same as source model", continuing to train model...' 
) return False @@ -73,7 +78,7 @@ def check_if_model_exist( def output_message(msg='', title='', headless=False): if headless: - print(msg) + log.info(msg) else: msgbox(msg=msg, title=title) @@ -120,9 +125,9 @@ def update_my_data(my_data): ) and my_data.get('save_model_as') not in ['safetensors', 'ckpt']: message = 'Updating save_model_as to safetensors because the current value in the config file is no longer applicable to {}' if my_data.get('LoRA_type'): - print(message.format('LoRA')) + log.info(message.format('LoRA')) if my_data.get('num_vectors_per_token'): - print(message.format('TI')) + log.info(message.format('TI')) my_data['save_model_as'] = 'safetensors' return my_data @@ -151,7 +156,7 @@ def get_file_path( and sys.platform != 'darwin' ): current_file_path = file_path - # print(f'current file path: {current_file_path}') + # log.info(f'current file path: {current_file_path}') initial_dir, initial_file = get_dir_and_file(file_path) @@ -178,7 +183,7 @@ def get_file_path( if not file_path: file_path = current_file_path current_file_path = file_path - # print(f'current file path: {current_file_path}') + # log.info(f'current file path: {current_file_path}') return file_path @@ -189,7 +194,7 @@ def get_any_file_path(file_path=''): and sys.platform != 'darwin' ): current_file_path = file_path - # print(f'current file path: {current_file_path}') + # log.info(f'current file path: {current_file_path}') initial_dir, initial_file = get_dir_and_file(file_path) @@ -257,7 +262,7 @@ def get_saveasfile_path( and sys.platform != 'darwin' ): current_file_path = file_path - # print(f'current file path: {current_file_path}') + # log.info(f'current file path: {current_file_path}') initial_dir, initial_file = get_dir_and_file(file_path) @@ -275,15 +280,15 @@ def get_saveasfile_path( ) root.destroy() - # print(save_file_path) + # log.info(save_file_path) if save_file_path == None: file_path = current_file_path else: - print(save_file_path.name) + log.info(save_file_path.name) file_path = save_file_path.name - # print(file_path) + # log.info(file_path) return file_path @@ -296,7 +301,7 @@ def get_saveasfilename_path( and sys.platform != 'darwin' ): current_file_path = file_path - # print(f'current file path: {current_file_path}') + # log.info(f'current file path: {current_file_path}') initial_dir, initial_file = get_dir_and_file(file_path) @@ -317,7 +322,7 @@ def get_saveasfilename_path( if save_file_path == '': file_path = current_file_path else: - # print(save_file_path) + # log.info(save_file_path) file_path = save_file_path return file_path @@ -401,7 +406,7 @@ def find_replace( search_text (str, optional): Text to search for in the caption files. replace_text (str, optional): Text to replace the search text with. 
""" - print('Running caption find/replace') + log.info('Running caption find/replace') if not has_ext_files(folder_path, caption_file_ext): msgbox( @@ -453,7 +458,7 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name): # Copy the v2-inference-v.yaml file to the current file, with a .yaml extension if v2 and v_parameterization: - print( + log.info( f'Saving v2-inference-v.yaml as {output_dir}/{file_name}.yaml' ) shutil.copy( @@ -461,7 +466,7 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name): f'{output_dir}/{file_name}.yaml', ) elif v2: - print( + log.info( f'Saving v2-inference.yaml as {output_dir}/{file_name}.yaml' ) shutil.copy( @@ -475,14 +480,14 @@ def set_pretrained_model_name_or_path_input( ): # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list if str(model_list) in V2_BASE_MODELS: - print('SD v2 model detected. Setting --v2 parameter') + log.info('SD v2 model detected. Setting --v2 parameter') v2 = True v_parameterization = False pretrained_model_name_or_path = str(model_list) # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list if str(model_list) in V_PARAMETERIZATION_MODELS: - print( + log.info( 'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization' ) v2 = True @@ -832,7 +837,7 @@ def get_int_or_default(kwargs, key, default_value=0): elif isinstance(value, float): return int(value) else: - print(f'{key} is not an int, float or a string, setting value to {default_value}') + log.info(f'{key} is not an int, float or a string, setting value to {default_value}') return default_value def get_float_or_default(kwargs, key, default_value=0.0): @@ -844,7 +849,7 @@ def get_float_or_default(kwargs, key, default_value=0.0): elif isinstance(value, str): return float(value) else: - print(f'{key} is not an int, float or a string, setting value to {default_value}') + log.info(f'{key} is not an int, float or a string, setting value to {default_value}') return default_value def get_str_or_default(kwargs, key, default_value=""): @@ -872,7 +877,7 @@ def run_cmd_training(**kwargs): lr_warmup_steps = kwargs.get("lr_warmup_steps", "") if lr_warmup_steps: if lr_scheduler == 'constant': - print('Can\'t use LR warmup with LR Scheduler constant... ignoring...') + log.info('Can\'t use LR warmup with LR Scheduler constant... 
ignoring...') else: run_cmd += f' --lr_warmup_steps="{lr_warmup_steps}"' diff --git a/library/convert_model_gui.py b/library/convert_model_gui.py index 70b32e00e..2f3b94c1c 100644 --- a/library/convert_model_gui.py +++ b/library/convert_model_gui.py @@ -5,6 +5,11 @@ import shutil from .common_gui import get_folder_path, get_file_path +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 @@ -28,16 +33,16 @@ def convert_model( # Check if source model exist if os.path.isfile(source_model_input): - print('The provided source model is a file') + log.info('The provided source model is a file') elif os.path.isdir(source_model_input): - print('The provided model is a folder') + log.info('The provided model is a folder') else: msgbox('The provided source model is neither a file nor a folder') return # Check if source model exist if os.path.isdir(target_model_folder_input): - print('The provided model folder exist') + log.info('The provided model folder exist') else: msgbox('The provided target folder does not exist') return @@ -51,10 +56,10 @@ def convert_model( # check if v1 models if str(source_model_type) in v1_models: - print('SD v1 model specified. Setting --v1 parameter') + log.info('SD v1 model specified. Setting --v1 parameter') run_cmd += ' --v1' else: - print('SD v2 model specified. Setting --v2 parameter') + log.info('SD v2 model specified. Setting --v2 parameter') run_cmd += ' --v2' if not target_save_precision_type == 'unspecified': @@ -94,7 +99,7 @@ def convert_model( ) run_cmd += f' "{target_model_path}"' - print(run_cmd) + log.info(run_cmd) # Run the command if os.name == 'posix': @@ -120,7 +125,7 @@ def convert_model( inference_file = os.path.join( target_model_folder_input, f'{target_model_name_input}.yaml' ) - print(f'Saving v2-inference.yaml as {inference_file}') + log.info(f'Saving v2-inference.yaml as {inference_file}') shutil.copy( f'./v2_inference/v2-inference.yaml', f'{inference_file}', @@ -130,7 +135,7 @@ def convert_model( inference_file = os.path.join( target_model_folder_input, f'{target_model_name_input}.yaml' ) - print(f'Saving v2-inference-v.yaml as {inference_file}') + log.info(f'Saving v2-inference-v.yaml as {inference_file}') shutil.copy( f'./v2_inference/v2-inference-v.yaml', f'{inference_file}', diff --git a/library/custom_logging.py b/library/custom_logging.py new file mode 100644 index 000000000..f50aa1801 --- /dev/null +++ b/library/custom_logging.py @@ -0,0 +1,40 @@ +import os +import logging +import time + +from rich.theme import Theme +from rich.logging import RichHandler +from rich.console import Console +from rich.pretty import install as pretty_install +from rich.traceback import install as traceback_install + +log = None + +def setup_logging(clean=False, debug=False): + global log + + if log is not None: + return log + + try: + if clean and os.path.isfile('setup.log'): + os.remove('setup.log') + time.sleep(0.1) # prevent race condition + except: + pass + + logging.basicConfig(level=logging.DEBUG, format='%(asctime)s | %(levelname)s | %(pathname)s | %(message)s', filename='setup.log', filemode='a', encoding='utf-8', force=True) + + console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=Theme({ + "traceback.border": "black", + "traceback.border.syntax_error": "black", + "inspect.value.border": "black", + })) + pretty_install(console=console) + traceback_install(console=console, extra_lines=1, 
width=console.width, word_wrap=False, indent_guides=False, suppress=[]) + rh = RichHandler(show_time=True, omit_repeated_times=False, show_level=True, show_path=False, markup=False, rich_tracebacks=True, log_time_format='%H:%M:%S-%f', level=logging.DEBUG if debug else logging.INFO, console=console) + rh.set_name(logging.DEBUG if debug else logging.INFO) + log = logging.getLogger("sd") + log.addHandler(rh) + + return log diff --git a/library/dataset_balancing_gui.py b/library/dataset_balancing_gui.py index de74e561e..93697f00d 100644 --- a/library/dataset_balancing_gui.py +++ b/library/dataset_balancing_gui.py @@ -4,6 +4,11 @@ from easygui import msgbox, boolbox from .common_gui import get_folder_path +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + # def select_folder(): # # Open a file dialog to select a directory # folder = filedialog.askdirectory() @@ -46,7 +51,7 @@ def dataset_balancing(concept_repeats, folder, insecure): images = len(image_files) if images == 0: - print( + log.info( f'No images of type .jpg, .jpeg, .png, .gif, .webp were found in {os.listdir(os.path.join(folder, subdir))}' ) @@ -86,7 +91,7 @@ def dataset_balancing(concept_repeats, folder, insecure): os.rename(old_name, new_name) else: - print( + log.info( f'Skipping folder {subdir} because it does not match kohya_ss expected syntax...' ) diff --git a/library/dreambooth_folder_creation_gui.py b/library/dreambooth_folder_creation_gui.py index 01df33d5f..0456d885c 100644 --- a/library/dreambooth_folder_creation_gui.py +++ b/library/dreambooth_folder_creation_gui.py @@ -4,6 +4,11 @@ import shutil import os +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + def copy_info_to_Folders_tab(training_folder): img_folder = os.path.join(training_folder, 'img') @@ -29,7 +34,7 @@ def dreambooth_folder_preparation( # Check if the input variables are empty if not len(util_training_dir_output): - print( + log.info( "Destination training directory is missing... can't perform the required task..." ) return @@ -49,7 +54,7 @@ def dreambooth_folder_preparation( # Create the training_dir path if util_training_images_dir_input == '': - print( + log.info( "Training images directory is missing... can't perform the required task..." ) return @@ -61,17 +66,17 @@ def dreambooth_folder_preparation( # Remove folders if they exist if os.path.exists(training_dir): - print(f'Removing existing directory {training_dir}...') + log.info(f'Removing existing directory {training_dir}...') shutil.rmtree(training_dir) # Copy the training images to their respective directories - print(f'Copy {util_training_images_dir_input} to {training_dir}...') + log.info(f'Copy {util_training_images_dir_input} to {training_dir}...') shutil.copytree(util_training_images_dir_input, training_dir) if not util_regularization_images_dir_input == '': # Create the regularization_dir path if not util_regularization_images_repeat_input > 0: - print('Repeats is missing... not copying regularisation images...') + log.info('Repeats is missing... 
not copying regularisation images...') else: regularization_dir = os.path.join( util_training_dir_output, @@ -80,18 +85,18 @@ def dreambooth_folder_preparation( # Remove folders if they exist if os.path.exists(regularization_dir): - print(f'Removing existing directory {regularization_dir}...') + log.info(f'Removing existing directory {regularization_dir}...') shutil.rmtree(regularization_dir) # Copy the regularisation images to their respective directories - print( + log.info( f'Copy {util_regularization_images_dir_input} to {regularization_dir}...' ) shutil.copytree( util_regularization_images_dir_input, regularization_dir ) else: - print( + log.info( 'Regularization images directory is missing... not copying regularisation images...' ) @@ -104,7 +109,7 @@ def dreambooth_folder_preparation( if not os.path.exists(os.path.join(util_training_dir_output, 'model')): os.makedirs(os.path.join(util_training_dir_output, 'model')) - print( + log.info( f'Done creating kohya_ss training folder structure at {util_training_dir_output}...' ) diff --git a/library/extract_lora_from_dylora_gui.py b/library/extract_lora_from_dylora_gui.py index 4bd70fc2e..5e84fb58b 100644 --- a/library/extract_lora_from_dylora_gui.py +++ b/library/extract_lora_from_dylora_gui.py @@ -7,6 +7,11 @@ get_file_path, ) +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 @@ -36,7 +41,7 @@ def extract_dylora( run_cmd += f' --model "{model}"' run_cmd += f' --unit {unit}' - print(run_cmd) + log.info(run_cmd) # Run the command if os.name == 'posix': @@ -44,7 +49,7 @@ def extract_dylora( else: subprocess.run(run_cmd) - print('Done extracting DyLoRA...') + log.info('Done extracting DyLoRA...') ### diff --git a/library/extract_lora_gui.py b/library/extract_lora_gui.py index 81c72f1ad..45ddb5f6d 100644 --- a/library/extract_lora_gui.py +++ b/library/extract_lora_gui.py @@ -8,6 +8,11 @@ get_file_path, ) +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 @@ -57,7 +62,7 @@ def extract_lora( if v2: run_cmd += f' --v2' - print(run_cmd) + log.info(run_cmd) # Run the command if os.name == 'posix': diff --git a/library/extract_lycoris_locon_gui.py b/library/extract_lycoris_locon_gui.py index a0aaff8ac..d3c19da6f 100644 --- a/library/extract_lycoris_locon_gui.py +++ b/library/extract_lycoris_locon_gui.py @@ -8,6 +8,11 @@ get_file_path, ) +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 @@ -79,7 +84,7 @@ def extract_lycoris_locon( run_cmd += f' "{db_model}"' run_cmd += f' "{output_name}"' - print(run_cmd) + log.info(run_cmd) # Run the command if os.name == 'posix': diff --git a/library/git_caption_gui.py b/library/git_caption_gui.py index 1159bff0d..83e54cfbb 100644 --- a/library/git_caption_gui.py +++ b/library/git_caption_gui.py @@ -4,6 +4,11 @@ import os from .common_gui import get_folder_path, add_pre_postfix +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' @@ -26,7 +31,7 @@ def caption_images( msgbox('Please provide an extension for the caption files.') return - print(f'GIT captioning files 
in {train_data_dir}...') + log.info(f'GIT captioning files in {train_data_dir}...') run_cmd = f'{PYTHON} finetune/make_captions_by_git.py' if not model_id == '': run_cmd += f' --model_id="{model_id}"' @@ -39,7 +44,7 @@ def caption_images( run_cmd += f' --caption_extension="{caption_ext}"' run_cmd += f' "{train_data_dir}"' - print(run_cmd) + log.info(run_cmd) # Run the command if os.name == 'posix': @@ -55,7 +60,7 @@ def caption_images( postfix=postfix, ) - print('...captioning done') + log.info('...captioning done') ### diff --git a/library/group_images_gui.py b/library/group_images_gui.py index a100e32ae..a623edbf1 100644 --- a/library/group_images_gui.py +++ b/library/group_images_gui.py @@ -4,6 +4,11 @@ from .common_gui import get_folder_path import os +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' def group_images( @@ -21,7 +26,7 @@ def group_images( msgbox('Please provide an output folder.') return - print(f'Grouping images in {input_folder}...') + log.info(f'Grouping images in {input_folder}...') run_cmd = f'{PYTHON} "{os.path.join("tools","group_images.py")}"' run_cmd += f' "{input_folder}"' @@ -32,14 +37,14 @@ def group_images( if do_not_copy_other_files: run_cmd += f' --do_not_copy_other_files' - print(run_cmd) + log.info(run_cmd) if os.name == 'posix': os.system(run_cmd) else: subprocess.run(run_cmd) - print('...grouping done') + log.info('...grouping done') def gradio_group_images_gui_tab(headless=False): diff --git a/library/huggingface_util.py b/library/huggingface_util.py index 2d0e1980e..1dc496ff5 100644 --- a/library/huggingface_util.py +++ b/library/huggingface_util.py @@ -6,9 +6,7 @@ from library.utils import fire_in_thread -def exists_repo( - repo_id: str, repo_type: str, revision: str = "main", token: str = None -): +def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None): api = HfApi( token=token, ) @@ -32,27 +30,35 @@ def upload( private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public" api = HfApi(token=token) if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token): - api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) + try: + api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) + except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので + print("===========================================") + print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}") + print("===========================================") - is_folder = (type(src) == str and os.path.isdir(src)) or ( - isinstance(src, Path) and src.is_dir() - ) + is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir()) def uploader(): - if is_folder: - api.upload_folder( - repo_id=repo_id, - repo_type=repo_type, - folder_path=src, - path_in_repo=path_in_repo, - ) - else: - api.upload_file( - repo_id=repo_id, - repo_type=repo_type, - path_or_fileobj=src, - path_in_repo=path_in_repo, - ) + try: + if is_folder: + api.upload_folder( + repo_id=repo_id, + repo_type=repo_type, + folder_path=src, + path_in_repo=path_in_repo, + ) + else: + api.upload_file( + repo_id=repo_id, + repo_type=repo_type, + path_or_fileobj=src, + path_in_repo=path_in_repo, + ) + except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので + print("===========================================") + print(f"failed to upload to HuggingFace / 
HuggingFaceへのアップロードに失敗しました : {e}") + print("===========================================") if args.async_upload and not force_sync_upload: fire_in_thread(uploader) @@ -71,7 +77,5 @@ def list_dir( token=token, ) repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type) - file_list = [ - file for file in repo_info.siblings if file.rfilename.startswith(subfolder) - ] + file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)] return file_list diff --git a/library/merge_lora_gui.py b/library/merge_lora_gui.py index 3031aee25..b4fc3e680 100644 --- a/library/merge_lora_gui.py +++ b/library/merge_lora_gui.py @@ -8,6 +8,11 @@ get_file_path, ) +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 @@ -47,13 +52,13 @@ def merge_lora( precision, save_precision, ): - print('Merge model...') + log.info('Merge model...') models = [sd_model, lora_a_model, lora_b_model, lora_c_model, lora_d_model] lora_models = models[1:] ratios = [ratio_a, ratio_b, ratio_c, ratio_d] if not verify_conditions(sd_model, lora_models): - print( + log.info( 'Warning: Either provide at least one LoRa model along with the sd_model or at least two LoRa models if no sd_model is provided.' ) return @@ -80,7 +85,7 @@ def merge_lora( run_cmd += f' --models {models_cmd}' run_cmd += f' --ratios {ratios_cmd}' - print(run_cmd) + log.info(run_cmd) # Run the command if os.name == 'posix': @@ -88,7 +93,7 @@ def merge_lora( else: subprocess.run(run_cmd) - print('Done merging...') + log.info('Done merging...') ### diff --git a/library/merge_lycoris_gui.py b/library/merge_lycoris_gui.py index 0fd8e1522..7d56f1e07 100644 --- a/library/merge_lycoris_gui.py +++ b/library/merge_lycoris_gui.py @@ -7,6 +7,11 @@ get_file_path, ) +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 @@ -23,7 +28,7 @@ def merge_lycoris( device, is_v2, ): - print('Merge model...') + log.info('Merge model...') run_cmd = f'{PYTHON} "{os.path.join("tools","merge_lycoris.py")}"' run_cmd += f' "{base_model}"' @@ -35,7 +40,7 @@ def merge_lycoris( if is_v2: run_cmd += f' --is_v2' - print(run_cmd) + log.info(run_cmd) # Run the command if os.name == 'posix': @@ -43,7 +48,7 @@ def merge_lycoris( else: subprocess.run(run_cmd) - print('Done merging...') + log.info('Done merging...') ### diff --git a/library/resize_lora_gui.py b/library/resize_lora_gui.py index 9f11f1e09..92e766b56 100644 --- a/library/resize_lora_gui.py +++ b/library/resize_lora_gui.py @@ -4,6 +4,11 @@ import os from .common_gui import get_saveasfilename_path, get_file_path +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 @@ -64,7 +69,7 @@ def resize_lora( if verbose: run_cmd += f' --verbose' - print(run_cmd) + log.info(run_cmd) # Run the command if os.name == 'posix': @@ -72,7 +77,7 @@ def resize_lora( else: subprocess.run(run_cmd) - print('Done resizing...') + log.info('Done resizing...') ### diff --git a/library/sampler_gui.py b/library/sampler_gui.py index ce953138b..b416dd2cd 100644 --- a/library/sampler_gui.py +++ b/library/sampler_gui.py @@ -3,6 +3,11 @@ import gradio as gr from easygui 
import msgbox +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 diff --git a/library/svd_merge_lora_gui.py b/library/svd_merge_lora_gui.py index a62bafbad..fb550620f 100644 --- a/library/svd_merge_lora_gui.py +++ b/library/svd_merge_lora_gui.py @@ -8,6 +8,11 @@ get_file_path, ) +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 @@ -57,7 +62,7 @@ def svd_merge_lora( run_cmd += f' --new_rank "{new_rank}"' run_cmd += f' --new_conv_rank "{new_conv_rank}"' - print(run_cmd) + log.info(run_cmd) # Run the command if os.name == 'posix': diff --git a/library/tensorboard_gui.py b/library/tensorboard_gui.py index d08a02d94..bfaf79968 100644 --- a/library/tensorboard_gui.py +++ b/library/tensorboard_gui.py @@ -4,6 +4,11 @@ import subprocess import time +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + tensorboard_proc = None # I know... bad but heh TENSORBOARD = 'tensorboard' if os.name == 'posix' else 'tensorboard.exe' @@ -12,37 +17,37 @@ def start_tensorboard(logging_dir): global tensorboard_proc if not os.listdir(logging_dir): - print('Error: log folder is empty') + log.info('Error: log folder is empty') msgbox(msg='Error: log folder is empty') return run_cmd = [f'{TENSORBOARD}', '--logdir', f'{logging_dir}'] - print(run_cmd) + log.info(run_cmd) if tensorboard_proc is not None: - print( + log.info( 'Tensorboard is already running. Terminating existing process before starting new one...' ) stop_tensorboard() # Start background process - print('Starting tensorboard...') + log.info('Starting tensorboard...') tensorboard_proc = subprocess.Popen(run_cmd) # Wait for some time to allow TensorBoard to start up time.sleep(5) # Open the TensorBoard URL in the default browser - print('Opening tensorboard url in browser...') + log.info('Opening tensorboard url in browser...') import webbrowser webbrowser.open('http://localhost:6006') def stop_tensorboard(): - print('Stopping tensorboard process...') + log.info('Stopping tensorboard process...') tensorboard_proc.kill() - print('...process stopped') + log.info('...process stopped') def gradio_tensorboard(): diff --git a/library/train_util.py b/library/train_util.py index b3968c431..d963537db 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -348,6 +348,8 @@ def __init__( self.is_reg = is_reg self.class_tokens = class_tokens self.caption_extension = caption_extension + if self.caption_extension and not self.caption_extension.startswith("."): + self.caption_extension = "." + self.caption_extension def __eq__(self, other) -> bool: if not isinstance(other, DreamBoothSubset): @@ -1081,16 +1083,37 @@ def load_dreambooth_dir(subset: DreamBoothSubset): # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う captions = [] + missing_captions = [] for img_path in img_paths: cap_for_img = read_caption(img_path, subset.caption_extension) if cap_for_img is None and subset.class_tokens is None: - print(f"neither caption file nor class tokens are found. use empty caption for {img_path}") + print( + f"neither caption file nor class tokens are found. 
use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}" + ) captions.append("") + missing_captions.append(img_path) else: - captions.append(subset.class_tokens if cap_for_img is None else cap_for_img) + if cap_for_img is None: + captions.append(subset.class_tokens) + missing_captions.append(img_path) + else: + captions.append(cap_for_img) self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 + if missing_captions: + number_of_missing_captions = len(missing_captions) + number_of_missing_captions_to_show = 5 + remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show + + print( + f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。" + ) + for i, missing_caption in enumerate(missing_captions): + if i >= number_of_missing_captions_to_show: + print(missing_caption + f"... and {remaining_missing_captions} more") + break + print(missing_caption) return img_paths, captions print("prepare images.") diff --git a/library/verify_lora_gui.py b/library/verify_lora_gui.py index bc3db5be3..b98abf66d 100644 --- a/library/verify_lora_gui.py +++ b/library/verify_lora_gui.py @@ -8,6 +8,11 @@ get_file_path, ) +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 @@ -34,7 +39,7 @@ def verify_lora( f'{lora_model}', ] - print(' '.join(run_cmd)) + log.info(' '.join(run_cmd)) # Run the command process = subprocess.Popen( diff --git a/library/wd14_caption_gui.py b/library/wd14_caption_gui.py index 74a58aa3e..5c7c550cb 100644 --- a/library/wd14_caption_gui.py +++ b/library/wd14_caption_gui.py @@ -4,6 +4,11 @@ from .common_gui import get_folder_path import os +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + def caption_images( train_data_dir, @@ -28,7 +33,7 @@ def caption_images( msgbox('Please provide an extension for the caption files.') return - print(f'Captioning files in {train_data_dir}...') + log.info(f'Captioning files in {train_data_dir}...') run_cmd = f'accelerate launch "./finetune/tag_images_by_wd14_tagger.py"' run_cmd += f' --batch_size={int(batch_size)}' run_cmd += f' --general_threshold={general_threshold}' @@ -52,7 +57,7 @@ def caption_images( run_cmd += f' --undesired_tags="{undesired_tags}"' run_cmd += f' "{train_data_dir}"' - print(run_cmd) + log.info(run_cmd) # Run the command if os.name == 'posix': @@ -60,7 +65,7 @@ def caption_images( else: subprocess.run(run_cmd) - print('...captioning done') + log.info('...captioning done') ### diff --git a/lora_gui.py b/lora_gui.py index 563751c28..724d0b1e5 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -4,6 +4,8 @@ # v3.1: Adding captionning of images to utilities import gradio as gr +import logging +import time # import easygui import json @@ -47,6 +49,11 @@ from library.resize_lora_gui import gradio_resize_lora_tab from library.sampler_gui import sample_gradio_config, run_cmd_sample +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + # from easygui import msgbox folder_symbol = '\U0001f4c2' # 📂 @@ -55,7 +62,6 @@ document_symbol = '\U0001F4C4' # 📄 path_of_this_folder = 
os.getcwd() - def save_configuration( save_as, file_path, @@ -156,14 +162,14 @@ def save_configuration( save_as_bool = True if save_as.get('label') == 'True' else False if save_as_bool: - print('Save as...') + log.info('Save as...') file_path = get_saveasfile_path(file_path) else: - print('Save...') + log.info('Save...') if file_path == None or file_path == '': file_path = get_saveasfile_path(file_path) - # print(file_path) + # log.info(file_path) if file_path == None or file_path == '': return original_file_path # In case a file_path was provided and the user decide to cancel the open action @@ -299,7 +305,7 @@ def open_configuration( # load variables from JSON file with open(file_path, 'r') as f: my_data = json.load(f) - print('Loading config...') + log.info('Loading config...') # Update values to fix deprecated use_8bit_adam checkbox, set appropriate optimizer if it is set to True, etc. my_data = update_my_data(my_data) @@ -415,6 +421,7 @@ def train_model( wandb_api_key, ): print_only_bool = True if print_only.get('label') == 'True' else False + log.info(f'Start training LoRA {LoRA_type} ...') headless_bool = True if headless.get('label') == 'True' else False if pretrained_model_name_or_path == '': @@ -534,35 +541,35 @@ def train_model( ] ) - print(f'Folder {folder}: {num_images} images found') + log.info(f'Folder {folder}: {num_images} images found') # Calculate the total number of steps for this folder steps = repeats * num_images - # Print the result - print(f'Folder {folder}: {steps} steps') + # log.info the result + log.info(f'Folder {folder}: {steps} steps') total_steps += steps except ValueError: # Handle the case where the folder name does not contain an underscore - print( + log.info( f"Error: '{folder}' does not contain an underscore, skipping..." ) if reg_data_dir == '': reg_factor = 1 else: - print( + log.info( '\033[94mRegularisation images are used... 
Will double the number of steps required...\033[0m' ) reg_factor = 2 - print(f'Total steps: {total_steps}') - print(f'Train batch size: {train_batch_size}') - print(f'Gradient accumulation steps: {gradient_accumulation_steps}') - print(f'Epoch: {epoch}') - print(f'Regularization factor: {reg_factor}') + log.info(f'Total steps: {total_steps}') + log.info(f'Train batch size: {train_batch_size}') + log.info(f'Gradient accumulation steps: {gradient_accumulation_steps}') + log.info(f'Epoch: {epoch}') + log.info(f'Regularization factor: {reg_factor}') # calculate max_train_steps max_train_steps = int( @@ -574,7 +581,7 @@ def train_model( * int(reg_factor) ) ) - print(f'max_train_steps ({total_steps} / {train_batch_size} / {gradient_accumulation_steps} * {epoch} * {reg_factor}) = {max_train_steps}') + log.info(f'max_train_steps ({total_steps} / {train_batch_size} / {gradient_accumulation_steps} * {epoch} * {reg_factor}) = {max_train_steps}') # calculate stop encoder training if stop_text_encoder_training_pct == None: @@ -583,10 +590,10 @@ def train_model( stop_text_encoder_training = math.ceil( float(max_train_steps) / 100 * int(stop_text_encoder_training_pct) ) - print(f'stop_text_encoder_training = {stop_text_encoder_training}') + log.info(f'stop_text_encoder_training = {stop_text_encoder_training}') lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) - print(f'lr_warmup_steps = {lr_warmup_steps}') + log.info(f'lr_warmup_steps = {lr_warmup_steps}') run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_network.py"' @@ -625,7 +632,7 @@ def train_model( try: import lycoris except ModuleNotFoundError: - print( + log.info( "\033[1;31mError:\033[0m The required module 'lycoris_lora' is not installed. Please install by running \033[33mupgrade.ps1\033[0m before running this program." ) return @@ -635,7 +642,7 @@ try: import lycoris except ModuleNotFoundError: - print( + log.info( "\033[1;31mError:\033[0m The required module 'lycoris_lora' is not installed. Please install by running \033[33mupgrade.ps1\033[0m before running this program." ) return @@ -820,12 +827,12 @@ # run_cmd += f' --conv_alphas="{conv_alphas}"' if print_only_bool: - print( + log.info( '\033[93m\nHere is the trainer command as a reference. 
It will not be executed:\033[0m\n' ) - print('\033[96m' + run_cmd + '\033[0m\n') + log.info('\033[96m' + run_cmd + '\033[0m\n') else: - print(run_cmd) + log.info(run_cmd) # Run the command if os.name == 'posix': os.system(run_cmd) @@ -1078,7 +1085,7 @@ def lora_tab( # Show of hide LoCon conv settings depending on LoRA type selection def update_LoRA_settings(LoRA_type): # Print a message when LoRA type is changed - print('LoRA type changed...') + log.info('LoRA type changed...') # Determine if LoCon_row should be visible based on LoRA_type LoCon_row = LoRA_type in { @@ -1442,11 +1449,11 @@ def UI(**kwargs): css = '' headless = kwargs.get('headless', False) - print(f'headless: {headless}') + log.info(f'headless: {headless}') if os.path.exists('./style.css'): with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: - print('Load CSS...') + log.info('Load CSS...') css += file.read() + '\n' interface = gr.Blocks( @@ -1489,6 +1496,7 @@ def UI(**kwargs): launch_kwargs['inbrowser'] = inbrowser if share: launch_kwargs['share'] = share + log.info(launch_kwargs) interface.launch(**launch_kwargs) diff --git a/requirements.txt b/requirements.txt index 29b2fb57e..6f1fa8774 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ altair==4.2.2 # https://github.com/bmaltais/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.38.1-py3-none-any.whl; sys_platform == 'win32' # This next line is not an error but rather there to properly catch if the url based bitsandbytes was properly installed by the line above... bitsandbytes==0.35.0; sys_platform == 'win32' -bitsandbytes==0.38.1; (sys_platform == "darwin" or sys_platform == "linux") +bitsandbytes==0.35.0; (sys_platform == "darwin" or sys_platform == "linux") dadaptation==3.1 diffusers[torch]==0.10.2 easygui==0.98.3 @@ -16,6 +16,7 @@ gradio==3.23.0; sys_platform == 'darwin' lion-pytorch==0.0.6 opencv-python==4.7.0.68 pytorch-lightning==1.9.0 +rich==13.4.1 safetensors==0.2.6 tensorboard==2.10.1 ; sys_platform != 'darwin' tensorboard==2.12.1 ; sys_platform == 'darwin' diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index 9bb068d83..d20570765 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -40,6 +40,11 @@ from library.utilities import utilities_tab from library.sampler_gui import sample_gradio_config, run_cmd_sample +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + # from easygui import msgbox folder_symbol = '\U0001f4c2' # 📂 @@ -134,14 +139,14 @@ def save_configuration( save_as_bool = True if save_as.get('label') == 'True' else False if save_as_bool: - print('Save as...') + log.info('Save as...') file_path = get_saveasfile_path(file_path) else: - print('Save...') + log.info('Save...') if file_path == None or file_path == '': file_path = get_saveasfile_path(file_path) - # print(file_path) + # log.info(file_path) if file_path == None or file_path == '': return original_file_path # In case a file_path was provided and the user decide to cancel the open action @@ -263,7 +268,7 @@ def open_configuration( # load variables from JSON file with open(file_path, 'r') as f: my_data = json.load(f) - print('Loading config...') + log.info('Loading config...') # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True my_data = update_my_data(my_data) else: @@ -456,15 +461,15 @@ def train_model( total_steps += steps # Print the result - print(f'Folder {folder}: {steps} steps') + log.info(f'Folder {folder}: {steps} 
steps') # Print the result - # print(f"{total_steps} total steps") + # log.info(f"{total_steps} total steps") if reg_data_dir == '': reg_factor = 1 else: - print( + log.info( 'Regularisation images are used... Will double the number of steps required...' ) reg_factor = 2 @@ -483,7 +488,7 @@ def train_model( else: max_train_steps = int(max_train_steps) - print(f'max_train_steps = {max_train_steps}') + log.info(f'max_train_steps = {max_train_steps}') # calculate stop encoder training if stop_text_encoder_training_pct == None: @@ -492,10 +497,10 @@ def train_model( stop_text_encoder_training = math.ceil( float(max_train_steps) / 100 * int(stop_text_encoder_training_pct) ) - print(f'stop_text_encoder_training = {stop_text_encoder_training}') + log.info(f'stop_text_encoder_training = {stop_text_encoder_training}') lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) - print(f'lr_warmup_steps = {lr_warmup_steps}') + log.info(f'lr_warmup_steps = {lr_warmup_steps}') run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_textual_inversion.py"' if v2: @@ -612,7 +617,7 @@ def train_model( output_dir, ) - print(run_cmd) + log.info(run_cmd) # Run the command if os.name == 'posix': @@ -1039,11 +1044,11 @@ def UI(**kwargs): css = '' headless = kwargs.get('headless', False) - print(f'headless: {headless}') + log.info(f'headless: {headless}') if os.path.exists('./style.css'): with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: - print('Load CSS...') + log.info('Load CSS...') css += file.read() + '\n' interface = gr.Blocks( diff --git a/tools/group_images.py b/tools/group_images.py index c859cdf81..406a8dd25 100644 --- a/tools/group_images.py +++ b/tools/group_images.py @@ -4,6 +4,11 @@ import os import numpy as np +from library.custom_logging import setup_logging + +# Set up logging +log = setup_logging() + class ImageProcessor: def __init__(self, input_folder, output_folder, group_size, include_subfolders, do_not_copy_other_files, pad): @@ -52,7 +57,7 @@ def crop_images(self, group, avg_aspect_ratio): cropped_images = [] for j, path in enumerate(group): with Image.open(path) as img: - print(f" Processing image {j+1}: {path}") + log.info(f" Processing image {j+1}: {path}") img = self.crop_image(img, avg_aspect_ratio) cropped_images.append(img) return cropped_images @@ -80,7 +85,7 @@ def resize_and_save_images(self, cropped_images, group_index): img = img.resize((max_width, max_height)) os.makedirs(self.output_folder, exist_ok=True) output_path = os.path.join(self.output_folder, f"group-{group_index+1}-image-{j+1}.jpg") - print(f" Saving processed image to {output_path}") + log.info(f" Saving processed image to {output_path}") img.convert('RGB').save(output_path) def copy_other_files(self, group, group_index): @@ -98,7 +103,7 @@ def process_images(self): images = self.get_image_paths() groups = self.group_images(images) for i, group in enumerate(groups): - print(f"Processing group {i+1} with {len(group)} images...") + log.info(f"Processing group {i+1} with {len(group)} images...") self.process_group(group, i) def process_group(self, group, group_index): @@ -118,7 +123,7 @@ def pad_images(self, group, avg_aspect_ratio): padded_images = [] for j, path in enumerate(group): with Image.open(path) as img: - print(f" Processing image {j+1}: {path}") + log.info(f" Processing image {j+1}: {path}") img = self.pad_image(img, avg_aspect_ratio) padded_images.append(img) return padded_images diff --git a/tools/validate_requirements.py 
b/tools/validate_requirements.py index 714714afe..5a1072b3f 100644 --- a/tools/validate_requirements.py +++ b/tools/validate_requirements.py @@ -1,74 +1,132 @@ import os import sys -import pkg_resources +import shutil +import time +import re import argparse +import pkg_resources from packaging.requirements import Requirement -from packaging.markers import default_environment +from packaging.markers import Marker, default_environment -import re -# Parse command line arguments -parser = argparse.ArgumentParser(description="Validate that requirements are satisfied.") -parser.add_argument('-r', '--requirements', type=str, default='requirements.txt', help="Path to the requirements file.") -args = parser.parse_args() +from library.custom_logging import setup_logging -print("Validating that requirements are satisfied.") +# Set up logging +log = setup_logging() -# Load the requirements from the specified requirements file -with open(args.requirements) as f: - requirements = f.readlines() +def check_torch(): + if shutil.which('nvidia-smi') is not None or os.path.exists(os.path.join(os.environ.get('SystemRoot') or r'C:\Windows', 'System32', 'nvidia-smi.exe')): + log.info('nVidia toolkit detected') + elif shutil.which('rocminfo') is not None or os.path.exists('/opt/rocm/bin/rocminfo'): + log.info('AMD toolkit detected') + else: + log.info('Using CPU-only Torch') -# Check each requirement against the installed packages -missing_requirements = [] -wrong_version_requirements = [] + try: + import torch + log.info(f'Torch {torch.__version__}') + + if not torch.cuda.is_available(): + log.warning("Torch reports CUDA not available") + else: + if torch.version.cuda: + log.info(f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}') + elif torch.version.hip: + log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}') + else: + log.warning('Unknown Torch backend') + + for device in [torch.cuda.device(i) for i in range(torch.cuda.device_count())]: + log.info(f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}') + except Exception as e: + log.error(f'Could not load torch: {e}') + sys.exit(1) -url_requirement_pattern = re.compile(r"(?P<url>https?://.+);?\s?(?P<marker>.+)?") -for requirement in requirements: - requirement = requirement.strip() - if requirement == ".": - # Skip the current requirement if it is a dot (.) - continue +def validate_requirements(requirements_file): + log.info("Validating that requirements are satisfied.") - url_match = url_requirement_pattern.match(requirement) - if url_match: - if url_match.group("marker"): - marker = url_match.group("marker") - parsed_marker = Marker(marker) - if not parsed_marker.evaluate(default_environment()): - continue - requirement = url_match.group("url") + with open(requirements_file) as f: + requirements = f.readlines() - try: - parsed_req = Requirement(requirement) + missing_requirements = [] + wrong_version_requirements = [] + url_requirement_pattern = re.compile(r"(?P<url>https?://.+);?\s?(?P<marker>.+)?") + + for requirement in requirements: + requirement = requirement.strip() - # Check if the requirement has an environment marker and if it evaluates to False - if parsed_req.marker and not parsed_req.marker.evaluate(default_environment()): + if requirement == ".": + # Skip the current requirement if it is a dot (.) 
continue - pkg_resources.require(str(parsed_req)) - except ValueError: - # This block will handle URL-based requirements - pass - except pkg_resources.DistributionNotFound: - missing_requirements.append(requirement) - except pkg_resources.VersionConflict as e: - wrong_version_requirements.append((requirement, str(e.req), e.dist.version)) - -# If there are any missing or wrong version requirements, print an error message and exit with a non-zero exit code -if missing_requirements or wrong_version_requirements: - if missing_requirements: - print("Error: The following packages are missing:") - for requirement in missing_requirements: - print(f" - {requirement}") - if wrong_version_requirements: - print("Error: The following packages have the wrong version:") - for requirement, expected_version, actual_version in wrong_version_requirements: - print(f" - {requirement} (expected version {expected_version}, found version {actual_version})") - upgrade_script = "upgrade.ps1" if os.name == "nt" else "upgrade.sh" - print(f"\nRun \033[33m{upgrade_script}\033[0m or \033[33mpip install -U -r {args.requirements}\033[0m to resolve the missing requirements listed above...") - - sys.exit(1) - -# All requirements satisfied -print("All requirements satisfied.") -sys.exit(0) + url_match = url_requirement_pattern.match(requirement) + + if url_match: + if url_match.group("marker"): + marker = url_match.group("marker") + parsed_marker = Marker(marker) + + if not parsed_marker.evaluate(default_environment()): + continue + + requirement = url_match.group("url") + + try: + parsed_req = Requirement(requirement) + + if parsed_req.marker and not parsed_req.marker.evaluate(default_environment()): + continue + + pkg_resources.require(str(parsed_req)) + except ValueError: + # This block will handle URL-based requirements + pass + except pkg_resources.DistributionNotFound: + missing_requirements.append(requirement) + except pkg_resources.VersionConflict as e: + wrong_version_requirements.append((requirement, str(e.req), e.dist.version)) + + return missing_requirements, wrong_version_requirements + + +def print_error_messages(missing_requirements, wrong_version_requirements, requirements_file): + if missing_requirements or wrong_version_requirements: + if missing_requirements: + log.error("Error: The following packages are missing:") + + for requirement in missing_requirements: + log.error(f" - {requirement}") + + if wrong_version_requirements: + log.error("Error: The following packages have the wrong version:") + + for requirement, expected_version, actual_version in wrong_version_requirements: + log.error(f" - {requirement} (expected version {expected_version}, found version {actual_version})") + + upgrade_script = "upgrade.ps1" if os.name == "nt" else "upgrade.sh" + log.info(f"\nRun \033[33m{upgrade_script}\033[0m or \033[33mpip install -U -r {requirements_file}\033[0m to resolve the missing requirements listed above...") + sys.exit(1) + + log.info("All requirements satisfied.") + sys.exit(0) + + +def main(): + # Parse command line arguments + parser = argparse.ArgumentParser(description="Validate that requirements are satisfied.") + parser.add_argument('-r', '--requirements', type=str, default='requirements.txt', help="Path to the requirements file.") + parser.add_argument('--debug', action='store_true', help='Debug on') + args = parser.parse_args() + + # Check Torch + check_torch() + + # Validate requirements + missing_requirements, wrong_version_requirements = validate_requirements(args.requirements) + + # Print error 
messages if there are missing or wrong version requirements + print_error_messages(missing_requirements, wrong_version_requirements, args.requirements) + + +if __name__ == "__main__": + main() diff --git a/train_network.py b/train_network.py index f2fd20093..f8db030c4 100644 --- a/train_network.py +++ b/train_network.py @@ -80,25 +80,25 @@ def train(args): # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True)) if use_user_config: - print(f"Load dataset config from {args.dataset_config}") + print(f"Loading dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "reg_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): print( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( ", ".join(ignored) ) ) else: if use_dreambooth_method: - print("Use DreamBooth method.") + print("Using DreamBooth method.") user_config = { "datasets": [ {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} ] } else: - print("Train with captions.") + print("Training with captions.") user_config = { "datasets": [ { @@ -135,7 +135,7 @@ def train(args): ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" # acceleratorを準備する - print("prepare accelerator") + print("preparing accelerator") accelerator, unwrap_model = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process @@ -147,7 +147,30 @@ def train(args): # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + + # load models to merge before training, for incremental learning / 差分追加学習のためにモデルを読み込む + import sys + + sys.path.append(os.path.dirname(__file__)) + print("import network module:", args.network_module) + network_module = importlib.import_module(args.network_module) + if args.base_weights is not None: + # if base_weights is specified, load and merge the specified weights / base_weights が指定されている場合は、指定された重みを読み込みマージする + for i, weight_path in enumerate(args.base_weights): + if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i: + multiplier = 1.0 + else: + multiplier = args.base_weights_multiplier[i] + + print(f"merging module: {weight_path} with multiplier {multiplier}") + + module, weights_sd = network_module.create_network_from_weights( + multiplier, weight_path, vae, text_encoder, unet, for_inference=True + ) + module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") + + print(f"all weights merged: {', '.join(args.base_weights)}") # 学習を準備する if cache_latents: vae.to(accelerator.device, dtype=weight_dtype) @@ -163,12 +186,6 @@ def train(args): accelerator.wait_for_everyone() # prepare network - import sys - - sys.path.append(os.path.dirname(__file__)) - print("import network module:", args.network_module) - network_module = importlib.import_module(args.network_module) - net_kwargs = {} if args.network_args is not None: for net_arg in args.network_args: @@ -192,7 +209,7 @@ def train(args): if args.network_weights is not None: info = network.load_weights(args.network_weights) - print(f"load network weights from {args.network_weights}: {info}") + print(f"loaded network weights from {args.network_weights}: {info}") if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -200,7 
+217,7 @@ def train(args): network.enable_gradient_checkpointing() # may have no effect # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") + print("preparing optimizer, data loader etc.") # 後方互換性を確保するよ try: @@ -245,7 +262,7 @@ def train(args): assert ( args.mixed_precision == "fp16" ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - print("enable full fp16 training.") + print("enabling full fp16 training.") network.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい @@ -770,6 +787,20 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する", ) + parser.add_argument( + "--base_weights", + type=str, + default=None, + nargs="*", + help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル", + ) + parser.add_argument( + "--base_weights_multiplier", + type=float, + default=None, + nargs="*", + help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率", + ) return parser
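
The `library/custom_logging.py` module imported throughout these diffs is not itself included in the patch, so its exact contents are unknown. Given that `rich==13.4.1` is added to `requirements.txt` and every caller only relies on the return value behaving like a standard `logging.Logger` (`log.info`, `log.warning`, `log.error`), a minimal sketch of what `setup_logging()` might look like (logger name and all details are assumptions):

```python
# Hypothetical sketch only: the real library/custom_logging.py is not shown
# in this diff. Assumes the rich package pinned in requirements.txt.
import logging

from rich.logging import RichHandler


def setup_logging():
    # Route log records through rich for colorized console output, replacing
    # the raw ANSI escape sequences previously embedded in print() calls.
    logging.basicConfig(
        level=logging.INFO,
        format='%(message)s',
        datefmt='[%X]',
        handlers=[RichHandler()],
    )
    return logging.getLogger('kohya_ss')  # logger name is an assumption
```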
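
For reference, the requirement-checking loop in `tools/validate_requirements.py` is built on the `packaging` library: each line is parsed as a PEP 508 requirement, its environment marker (if any) is evaluated against the running interpreter, and only matching requirements are handed to `pkg_resources.require()`. A minimal standalone illustration, using one of the platform-gated pins from `requirements.txt`:

```python
# Minimal illustration of the marker handling in validate_requirements().
import pkg_resources
from packaging.markers import default_environment
from packaging.requirements import Requirement

req = Requirement("bitsandbytes==0.35.0; sys_platform == 'win32'")

# Skip requirements whose environment marker does not match this interpreter.
if req.marker is None or req.marker.evaluate(default_environment()):
    try:
        pkg_resources.require(str(req))
        print(f'{req} is satisfied')
    except pkg_resources.DistributionNotFound:
        print(f'{req} is missing')
    except pkg_resources.VersionConflict as e:
        print(f'{req} has the wrong version (found {e.dist.version})')
else:
    print(f'{req} skipped: marker does not match this platform')
```

This is also why `Marker` must be imported from `packaging.markers`: for URL-based requirements the marker is captured by the regex as a bare string and has to be parsed with `Marker(...)` before it can be evaluated.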
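
The new `--base_weights` / `--base_weights_multiplier` options in `train_network.py` each accept multiple values, and any missing multiplier falls back to 1.0, as the merge loop above shows. A standalone sketch of just that pairing rule (the file names are hypothetical):

```python
# Pairing rule used for --base_weights / --base_weights_multiplier:
# each weight file gets its own multiplier; missing ones default to 1.0.
from itertools import zip_longest

base_weights = ['style_lora.safetensors', 'detail_lora.safetensors']  # hypothetical
base_weights_multiplier = [0.8]  # only the first multiplier was supplied

pairs = zip_longest(
    base_weights,
    base_weights_multiplier[: len(base_weights)],  # ignore any extra multipliers
    fillvalue=1.0,
)
for weight_path, multiplier in pairs:
    print(f'merging module: {weight_path} with multiplier {multiplier}')
# -> style_lora.safetensors at 0.8, detail_lora.safetensors at 1.0
```

A corresponding invocation might pass `--base_weights "style_lora.safetensors" "detail_lora.safetensors" --base_weights_multiplier 0.8` to `train_network.py`.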